In [None]:
!pip install datasets

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from datasets import load_dataset
from transformers import BertTokenizerFast, BertModel
from random import randint
from torch.utils.data import Dataset

import transformers

In [None]:
# Load the "roneneldan/TinyStories" dataset; its 'train' split supplies the training stories.
dataset = load_dataset("roneneldan/TinyStories")
# Pretrained uncased BERT tokenizer, shared by the training dataset and text generation.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

In [None]:
# Keep only stories that contain visible text. Whitespace-only entries tokenize
# to zero tokens and carry no training signal, so they are dropped along with
# empty strings. The kept stories are stored unmodified (not stripped).
data = [para for para in dataset['train']["text"] if para and para.strip()]

In [None]:
class configs:
    """Run settings grouped as class attributes and read as ``configs.<name>``.

    Only ``chunk_size``, ``batch_size``, ``epochs`` and ``learning_rate`` are
    referenced by the visible cells; the remaining attributes are marked below.
    """
    chunk_size = 100  # tokens per training sample; passed to StoriesDataset
    batch_size = 32  # samples per DataLoader batch
    block_size = 50  # NOTE(review): not referenced in the visible cells
    epochs = 100  # number of passes the training loop makes over the dataloader
    eval_interval = 1000  # NOTE(review): not referenced in the visible cells
    learning_rate = 3e-4  # learning rate given to AdamW
    device = 'cuda' if torch.cuda.is_available() else 'cpu'  # NOTE(review): the notebook uses a separate module-level `device` instead
    eval_iters = 500  # NOTE(review): not referenced in the visible cells
    vocab_size = 30522  # NOTE(review): not referenced; the model is built with tokenizer.vocab_size
    n_embd = 768  # NOTE(review): not referenced in the visible cells
    n_head = 12  # NOTE(review): not referenced in the visible cells
    n_layer = 2  # NOTE(review): not referenced in the visible cells
    dropout = 0.3  # NOTE(review): not referenced in the visible cells

In [None]:
class StoriesDataset(Dataset):
    """Fixed-length next-token-prediction samples cut from raw story strings.

    Each item is a triple of ``torch.long`` tensors, all of length ``chunk_size``:

    * ``input_tokens``  - ``[CLS]`` followed by the window shifted right by one,
    * ``output_tokens`` - the window itself (the prediction targets),
    * ``attention_mask`` - 1 on real positions, 0 on padding.

    Stories shorter than ``chunk_size`` are left-padded with ``PAD_TOKEN_ID``;
    longer stories contribute one randomly positioned window per access.
    """

    # Id used for padding inputs and targets. It must stay equal to the
    # ``ignore_index`` of the training loss so padded targets are not scored.
    PAD_TOKEN_ID = 0
    # Id placed at the first real input position (101 is what the original
    # code hard-coded as the start marker).
    CLS_TOKEN_ID = 101

    def __init__(self, dataset, tokenizer, chunk_size):
        """Store the stories, the tokenizer and the sample length.

        Args:
            dataset: indexable collection of story strings.
            tokenizer: object exposing ``encode_plus(text, add_special_tokens=...,
                return_attention_mask=...)`` returning ``input_ids`` and
                ``attention_mask``.
            chunk_size: number of tokens in every returned sample (>= 1).

        Raises:
            ValueError: if ``chunk_size`` is smaller than 1.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        self.stories = dataset
        self.tokenizer = tokenizer
        self.chunk_size = chunk_size

    def __len__(self):
        """Return the number of stories (one sample is drawn per story)."""
        return len(self.stories)

    def __getitem__(self, idx):
        """Tokenize story ``idx`` and return ``(inputs, targets, attention_mask)``.

        A story that tokenizes to nothing yields an all-padding sample instead
        of failing, so a stray blank entry cannot abort a training run.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        story = self.stories[idx]
        encoded = self.tokenizer.encode_plus(
            story, add_special_tokens=False, return_attention_mask=True
        )
        input_ids = encoded['input_ids']
        attention_mask = encoded['attention_mask']

        # Pick the window start: random for stories at least chunk_size long,
        # otherwise the whole (short) story is used.
        if len(input_ids) >= self.chunk_size:
            start_idx = randint(0, len(input_ids) - self.chunk_size)
        else:
            start_idx = 0

        targets = input_ids[start_idx:start_idx + self.chunk_size]
        target_mask = attention_mask[start_idx:start_idx + self.chunk_size]

        # Shift right by one: position t of the input holds the token before
        # target t, with CLS standing in for "nothing seen yet". Truncating to
        # len(targets) drops the last token and also maps an empty story to an
        # empty input, which keeps inputs and targets the same length.
        inputs = ([self.CLS_TOKEN_ID] + targets)[:len(targets)]
        input_mask = ([1] + target_mask)[:len(targets)]

        pad_length = self.chunk_size - len(targets)
        input_tokens = [self.PAD_TOKEN_ID] * pad_length + inputs
        output_tokens = [self.PAD_TOKEN_ID] * pad_length + targets
        attention_mask = [0] * pad_length + input_mask

        assert len(input_tokens) == len(output_tokens) == len(attention_mask) == self.chunk_size, (
            f"{len(input_tokens)} {len(output_tokens)} {len(input_ids)} {start_idx}"
        )

        return (
            torch.tensor(input_tokens, dtype=torch.long),
            torch.tensor(output_tokens, dtype=torch.long),
            torch.tensor(attention_mask, dtype=torch.long)
        )

In [None]:
class BERT_LSTM_GRU(nn.Module):
    """Frozen BERT encoder feeding an LSTM, a GRU and a linear output head.

    The BERT weights are never updated (``requires_grad_(False)`` plus
    ``torch.no_grad()`` in ``forward``); only the recurrent layers and the
    head are trained.
    """

    def __init__(self, bert_model, hidden_dim, embedding_dim):
        """Build the recurrent stack on top of ``bert_model``.

        Args:
            bert_model: encoder whose output exposes ``last_hidden_state``.
            hidden_dim: hidden size of both the LSTM and the GRU.
            embedding_dim: size of the last dimension of the output, i.e. the
                number of logits per position.
        """
        super(BERT_LSTM_GRU, self).__init__()
        self.bert = bert_model
        self.bert.requires_grad_(False)
        self.bert.eval()

        # Take the encoder width from its config when available instead of
        # assuming 768, so other BERT sizes work; 768 remains the fallback.
        bert_dim = getattr(getattr(bert_model, "config", None), "hidden_size", 768)

        self.lstm = nn.LSTM(bert_dim, hidden_dim, batch_first=True).to(torch.float32)
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True).to(torch.float32)
        self.lm_head = nn.Linear(hidden_dim, embedding_dim).to(torch.float32)

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen encoder in eval mode.

        Without this, ``model.train()`` would re-enable dropout inside the
        frozen encoder, so the recurrent layers would be trained on noisy
        features that differ from what they receive at inference time.
        """
        super().train(mode)
        self.bert.eval()
        return self

    def forward(self, x, attention_mask=None):
        """Return per-position logits for the token ids ``x``.

        Args:
            x: token-id tensor passed to the encoder as its first argument.
            attention_mask: optional mask forwarded unchanged to the encoder;
                ``None`` (the default) reproduces the previous call exactly.

        NOTE(review): unless a causal mask is supplied, a BERT encoder
        presumably attends to later positions as well; with right-shifted
        inputs that would expose each target token to the encoder - confirm
        whether this is intended.
        """
        with torch.no_grad():
            embedding = self.bert(x, attention_mask=attention_mask).last_hidden_state
        x, _ = self.lstm(embedding)
        x, _ = self.gru(x)
        x = self.lm_head(x)
        return x

In [None]:
# Pretrained BERT encoder that is handed to BERT_LSTM_GRU as its feature extractor.
bert_model = BertModel.from_pretrained('bert-base-uncased')
# Device string used for the model, the training batches and checkpoint loading.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# hidden_dim=512 for the LSTM/GRU; the output head emits tokenizer.vocab_size logits per position.
model = BERT_LSTM_GRU(bert_model, 512, tokenizer.vocab_size).to(device)

In [None]:
# Total parameter count in millions (this includes the frozen BERT weights).
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

In [None]:
# Targets equal to 0 are excluded from the loss; StoriesDataset pads its targets with 0.
loss_fn = nn.CrossEntropyLoss(ignore_index=0)

In [None]:
# AdamW over all model parameters; the BERT parameters are included here but have requires_grad=False.
optimizer = optim.AdamW(model.parameters(), lr=configs.learning_rate)

In [None]:
# Gradient scaler for mixed-precision training; the training loop calls scale/step/update on it.
# NOTE(review): the device is hard-coded to 'cuda' although `device` may be "cpu" — confirm this is intended.
scaler = torch.amp.GradScaler(device='cuda')

In [None]:
# Directory the training loop writes checkpoints into; created if it does not exist yet.
checkpoint_dir = "./checkpoints/"
os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
# One sample of configs.chunk_size tokens is drawn per story; tokenization happens on access.
train_dataset = StoriesDataset(data, tokenizer, configs.chunk_size)

In [None]:
# Shuffle every epoch so batches are not always formed from the same
# neighbouring stories in the same order.
train_dataloader = DataLoader(train_dataset, batch_size=configs.batch_size, shuffle=True)

In [None]:
# Restore model, optimizer and GradScaler state from a previously saved checkpoint.
# The keys read here are the ones the training loop writes with torch.save.
# NOTE(review): the file is read from the working directory, not from checkpoint_dir — confirm the path.
# NOTE(review): the stored 'epoch' and 'step' are not used, so the training loop starts counting from epoch 1 again.
checkpoint = torch.load("model_epoch2_step10000.pt", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
scaler.load_state_dict(checkpoint["scaler_state_dict"])

In [None]:
# Loss value of the single batch that was current when the checkpoint was saved.
print(checkpoint["loss"])

In [None]:
# Number of micro-batches whose gradients are accumulated before one optimizer step.
GRAD_ACCUM_STEPS = 4
# A checkpoint is written every this many micro-batches (a multiple of
# GRAD_ACCUM_STEPS, so a checkpoint never lands in the middle of an accumulation window).
CHECKPOINT_EVERY = 10000
# Mixed precision is only meaningful on CUDA; on CPU the context is a no-op.
use_amp = str(device).startswith("cuda")

for epoch in range(configs.epochs):
    model.train()
    epoch_loss = 0

    # The dataloader also yields an attention mask; it is unpacked but not
    # consumed here because the model is called with token ids only.
    for step, (input_tokens, output_tokens, attention_mask) in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch+1}")):
        input_tokens = input_tokens.to(device)
        output_tokens = output_tokens.to(device)

        # torch.amp.autocast replaces the deprecated torch.cuda.amp.autocast
        # and matches the torch.amp.GradScaler used for `scaler`.
        with torch.amp.autocast(device_type="cuda", enabled=use_amp):
            logits = model(input_tokens)
            # Flatten (batch, position) so every position is scored as one sample.
            loss = loss_fn(logits.view(-1, logits.size(-1)), output_tokens.view(-1))

        # Gradients are summed (not averaged) over the accumulation window;
        # this is kept unchanged so a restored optimizer state sees gradients
        # on the same scale as before.
        scaler.scale(loss).backward()

        if (step + 1) % GRAD_ACCUM_STEPS == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # Read the scalar once: every .item() call forces a device sync.
        loss_value = loss.item()
        epoch_loss += loss_value

        if (step + 1) % CHECKPOINT_EVERY == 0:
            print(f"Loss: {loss_value}")
            checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch{epoch+1}_step{step+1}.pt")
            torch.save({
                'epoch': epoch + 1,
                'step': step + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'loss': loss_value
            }, checkpoint_path)

    print(f"Epoch {epoch+1}: Average Loss = {epoch_loss / len(train_dataloader):.4f}")

In [None]:
def generate_text(model, tokenizer, prompt, max_length=50, device='cuda'):
    """Greedily extend ``prompt`` with ``model`` and return the decoded text.

    Args:
        model: callable mapping a ``(1, seq)`` tensor of token ids to logits
            whose last dimension indexes the vocabulary.
        tokenizer: tokenizer providing ``encode``, ``decode`` and the
            ``cls_token_id`` / ``sep_token_id`` / ``eos_token_id`` attributes.
        prompt: text to continue.
        max_length: maximum number of NEW tokens to append to the prompt.
        device: device the token tensor is created on; it must match the
            device the model lives on.

    Returns:
        The prompt plus the generated continuation, decoded with special
        tokens skipped.

    Raises:
        ValueError: if there is no token at all to condition on.
    """
    model.eval()

    # Build the input as "[CLS] + prompt tokens" with no trailing [SEP].
    # add_special_tokens=True would append [SEP], which then becomes the last
    # position the next token is predicted from.
    prompt_ids = list(tokenizer.encode(prompt, add_special_tokens=False))
    cls_id = tokenizer.cls_token_id
    if cls_id is not None:
        prompt_ids = [cls_id] + prompt_ids
    if not prompt_ids:
        raise ValueError("prompt produced no tokens to condition on")

    generated_tokens = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    # Some tokenizers (BERT's among them) define no EOS token; fall back to
    # the separator token so the stop check is not silently dead.
    stop_id = tokenizer.eos_token_id
    if stop_id is None:
        stop_id = tokenizer.sep_token_id

    with torch.no_grad():
        for _ in range(max_length):
            logits = model(generated_tokens)

            # Greedy choice at the last position; keepdim yields (batch, 1),
            # ready to be concatenated along the sequence axis.
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            generated_tokens = torch.cat([generated_tokens, next_token], dim=1)

            if stop_id is not None and next_token.item() == stop_id:
                break

    return tokenizer.decode(generated_tokens.squeeze(0), skip_special_tokens=True)

In [None]:
# Pass the notebook's `device` explicitly: the function's default is 'cuda',
# which fails when the model was placed on the CPU.
generate_text(model, tokenizer, "The moral of the story ", device=device)

In [None]:
from google.colab import files

In [None]:
# Hand a saved checkpoint to the Colab `files.download` helper (presumably a browser download — confirm).
# NOTE(review): hard-coded file name; the training loop only writes it once epoch 2 reaches step 40000.
files.download("./checkpoints/model_epoch2_step40000.pt")