In [1]:
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [2]:
# Load the whole training corpus from disk into memory as one UTF-8 string.
with open('input.txt', encoding='utf-8', mode='r') as corpus_file:
    text = corpus_file.read()

In [3]:
from transformers import AutoTokenizer

# Load the pretrained tokenizer registered under the "gpt2" identifier.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Encode the entire corpus in a single call, without adding special tokens.
# The tokenizer warns that the resulting id sequence is longer than the
# model's maximum sequence length (1024). The ids are not fed to the model
# at that length: a later cell reshapes them into rows of `context_length`.
tokens = tokenizer.encode(text, add_special_tokens=False)

Token indices sequence length is longer than the specified maximum sequence length for this model (338025 > 1024). Running this sequence through the model will result in indexing errors


In [4]:
# Tunable data-shape settings:
#   context_length - number of token ids in each training sequence (row)
#   batch_size     - number of sequences the data loaders yield per batch
context_length, batch_size = 512, 2

In [5]:
# Keep only as many tokens as fill a whole number of batches (each batch is
# batch_size sequences of context_length tokens); the leftover tail is dropped.
tokens_per_batch = batch_size * context_length
num_batches = len(tokens) // tokens_per_batch
tokens = tokens[:num_batches * tokens_per_batch]

In [6]:
import torch

# Turn the flat list of token ids into a tensor, then fold it into a 2-D
# matrix whose rows each hold `context_length` consecutive ids.
token_tensor = torch.tensor(tokens)
input_ids = token_tensor.view(-1, context_length)

In [7]:
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import Adam
from torch.utils.data import random_split

# Wrap the token matrix (one sequence of `context_length` ids per row) so it
# can be split into subsets and served in batches.
dataset = TensorDataset(input_ids)

train_ratio = 0.8
test_ratio = 0.2  # informational only: the test split receives the remainder

train_size = int(train_ratio * len(dataset))
# Computing the test size as the remainder guarantees the two sizes always
# add up to len(dataset), regardless of rounding in `train_size`.
test_size = len(dataset) - train_size

# Seed the split so the same sequences land in train/test on every run;
# without a fixed generator the partition changes on each kernel restart
# and the reported test loss is not comparable between runs.
split_seed = 42
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# The training loader reshuffles every epoch; its own seeded generator makes
# the batch order reproducible as well. Evaluation order is kept fixed.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(split_seed),
)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [8]:
from labml_nn.transformers.LoRA.GPT2 import GPTModel

# Instantiate the model and load pretrained weights from a local checkpoint.
model = GPTModel()
# map_location='cpu' makes loading independent of the device the checkpoint
# was saved from; the model is moved to the training device in a later cell.
state_dict = torch.load('transformed.pth', map_location='cpu', weights_only=True)

# strict=False tolerates key mismatches between checkpoint and model
# (presumably the LoRA-specific parameters are absent from the checkpoint -
# TODO confirm). Report the mismatches instead of discarding them, so that a
# wrongly converted checkpoint does not silently leave weights uninitialised.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.unexpected_keys:
    print(f'Checkpoint keys not used by the model: {load_result.unexpected_keys}')
print(f'Model parameters not found in the checkpoint: {len(load_result.missing_keys)}')

In [9]:
# Pick the compute device once and reuse the same value everywhere, falling
# back to the CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

In [10]:
from labml import tracker, experiment

def _next_token_loss(model, criterion, inputs):
    """Return the next-token prediction loss for one batch of token ids.

    The logits at position ``t`` are scored against the token at position
    ``t + 1``: the last logit and the first token of every sequence have no
    counterpart and are dropped before the loss is computed.

    Args:
        model: callable mapping ``inputs`` to per-position logits.
        criterion: loss taking ``(flattened_logits, flattened_targets)``.
        inputs: batch of token ids, already on the model's device.

    Returns:
        The loss value produced by ``criterion``.
    """
    logits = model(inputs)
    shift_logits = logits[..., :-1, :]
    # The inputs are only read, never modified, so they can serve directly
    # as labels without making a copy.
    shift_labels = inputs[..., 1:]
    return criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))


optimizer = Adam(model.parameters(), lr=5e-5)
criterion = torch.nn.CrossEntropyLoss()

epochs = 3
step = 0  # global optimisation step, used as the x-axis for tracked metrics

with experiment.record(name='LoRA.GPT2', app_url='http://localhost:5005/api/v1/track'):
    for epoch in range(epochs):
        # Re-enter training mode every epoch, because evaluation below
        # switches the model to eval mode.
        model.train()
        for batch in train_dataloader:
            inputs = batch[0].to(device)  # the dataset yields 1-tuples
            loss = _next_token_loss(model, criterion, inputs)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tracker.save(step, {'loss': loss})
            step += 1
        # Note: this reports the loss of the last training batch only.
        print(f'Epoch: {epoch + 1}, Loss: {loss.item()}')

        # Evaluate in eval mode and without autograd: no gradients are
        # needed here, so building the graph would only waste memory, and
        # train-only layer behaviour must not influence the test loss.
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for batch in test_dataloader:
                inputs = batch[0].to(device)
                test_loss += _next_token_loss(model, criterion, inputs).item()
        test_loss /= len(test_dataloader)  # mean loss per test batch
        tracker.save(step, {'test_loss': test_loss})


print("Training complete.")

KeyboardInterrupt: 