In [None]:
!pip install torch
!pip install pytorch-lightning
!pip install transformers
!pip install tensorboard

In [20]:
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from pytorch_lightning.profilers import PyTorchProfiler

In [14]:
class NietzscheDataModule(pl.LightningDataModule):
    def __init__(self, file_path, tokenizer, batch_size=4):
        super(NietzscheDataModule, self).__init__()
        self.file_path = file_path
        self.tokenizer = tokenizer
        self.batch_size = batch_size

    def setup(self, stage=None):
        with open(self.file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        self.input_ids = self.tokenizer.encode(text, truncation=True, add_special_tokens=True, return_tensors='pt')

    def train_dataloader(self):
        dataset = torch.utils.data.TensorDataset(self.input_ids)
        return torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

In [15]:
class NietzscheTextGenerator(pl.LightningModule):
    def __init__(self):
        super(NietzscheTextGenerator, self).__init__()
        self.model = GPT2LMHeadModel.from_pretrained('gpt2')

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def training_step(self, batch, batch_idx):
        input_ids = torch.stack(batch)
        outputs = self.model(input_ids=input_ids, labels=input_ids)
        loss = outputs.loss
        self.log('train_loss', loss, on_epoch=True)
        perplexity = torch.exp(loss)
        self.log('perplexity', perplexity, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=1e-5)

tensorboard --logdir=./logs

In [23]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
profiler = PyTorchProfiler(profile_memory=True, schedule=torch.profiler.schedule(wait=1,warmup=1,active=2))
logger = TensorBoardLogger(save_dir='./logs', name='GPT2-NIETZSCHE')

In [24]:
data_module = NietzscheDataModule(file_path='nietzsche.txt', tokenizer=tokenizer, batch_size=4)
nietzsche_generator = NietzscheTextGenerator()
trainer = pl.Trainer(accelerator="auto", max_epochs=10, logger = logger, profiler = profiler)
trainer.fit(nietzsche_generator, data_module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type            | Params
------------------------------------------
0 | model | GPT2LMHeadModel | 124 M 
------------------------------------------
124 M     Trainable params
0         Non-trainable params
124 M     Total params
497.759   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
FIT Profiler Report
Profile stats for: records
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*        11.13%        4.093s        76.15%       28.019s       14.009s       3.42 Gb       1.10 Gb             2  
                        [pl][profile]run_training_epoch         0.01%       4.492ms        34.80%       12.803s        6.402s       2.95 Gb      -8.01 Kb             2  
                        [pl][profile]run_training_batch

In [6]:
def generate_text(prompt, max_length=100):
    input_ids_prompt = tokenizer.encode(prompt, return_tensors='pt')
    attention_mask = torch.ones_like(input_ids_prompt)
    nietzsche_generator.to(input_ids_prompt.device)
    output = nietzsche_generator.model.generate(
        input_ids_prompt,
        attention_mask=attention_mask,
        max_length=max_length,
        do_sample=True,
        top_k=50,
        top_p=0.90,
        pad_token_id=tokenizer.eos_token_id
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text

In [None]:
prompt = "In the abyss"
generated_text = generate_text(prompt)
print(generated_text)