In [None]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.3.2-py3-none-any.whl (485 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from datasets import load_dataset
from transformers import BertTokenizerFast, BertModel
from random import randint
from torch.utils.data import Dataset

import transformers

In [None]:
# Fetch the TinyStories corpus from the Hugging Face Hub (the recorded run
# produced a `train` and a `validation` split).
dataset = load_dataset("roneneldan/TinyStories")
# Tokenizer for the `bert-base-uncased` checkpoint; the same checkpoint name is
# used further down to load the BERT encoder, so token ids line up with it.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

(…)-00000-of-00004-2d5a1467fff1081b.parquet:   0%|          | 0.00/249M [00:00<?, ?B/s]

(…)-00001-of-00004-5852b56a2bd28fd9.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

(…)-00002-of-00004-a26307300439e943.parquet:   0%|          | 0.00/246M [00:00<?, ?B/s]

(…)-00003-of-00004-d243063613e5a057.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

(…)-00000-of-00001-869c898b519ad725.parquet:   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [None]:
# Keep only stories that contain some non-whitespace text. A whitespace-only
# story passes a plain `len(para) > 0` check but would presumably tokenize to an
# empty id list, which StoriesDataset's length assertion cannot handle.
data = [para for para in dataset['train']["text"] if para.strip()]

In [None]:
class configs:
    """Namespace of run-wide settings, read as plain class attributes.

    Only `chunk_size`, `batch_size`, `epochs` and `learning_rate` are read by
    the cells visible in this notebook; the remaining values are kept as-is
    because other code may rely on them.
    """

    # --- data / batching ---
    chunk_size = 100   # token length handed to StoriesDataset
    batch_size = 32    # samples per DataLoader batch
    block_size = 50

    # --- optimisation schedule ---
    epochs = 100
    learning_rate = 3e-4
    eval_interval = 1_000
    eval_iters = 500

    # --- runtime ---
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- model shape ---
    vocab_size = 30_522
    n_embd = 768
    n_head = 12
    n_layer = 2
    dropout = 0.3

In [None]:
class StoriesDataset(Dataset):
    """Turns raw story strings into fixed-length next-token training triples.

    Each item is `(input_tokens, output_tokens, attention_mask)`, all 1-D long
    tensors of length `chunk_size`:

    * `output_tokens` is a window of the story's token ids,
    * `input_tokens` is that window shifted right by one position with a
      [CLS] id placed in front,
    * `attention_mask` is 1 on real input positions and 0 on padding.

    Stories shorter than `chunk_size` are left-padded; longer stories
    contribute one randomly positioned window per access, so repeated reads of
    the same index can return different windows.

    Args:
        dataset: Indexable collection of story strings.
        tokenizer: Object exposing `encode_plus(text, add_special_tokens=...,
            return_attention_mask=...)` that returns a mapping with
            `input_ids` and `attention_mask` lists.
        chunk_size: Length of every returned tensor.
    """

    # Ids from the bert-base-uncased vocabulary used by this notebook.
    CLS_TOKEN_ID = 101  # start marker placed in front of every input window
    PAD_TOKEN_ID = 0    # padding; also the id the notebook's loss ignores

    def __init__(self, dataset, tokenizer, chunk_size):
        self.stories = dataset
        self.tokenizer = tokenizer
        self.chunk_size = chunk_size

    def __len__(self):
        return len(self.stories)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        story = self.stories[idx]
        tokens = self.tokenizer.encode_plus(
            story, add_special_tokens=False, return_attention_mask=True
        )

        input_ids = tokens['input_ids']
        attention_mask = tokens['attention_mask']
        cls_id, pad_id = self.CLS_TOKEN_ID, self.PAD_TOKEN_ID

        if len(input_ids) < self.chunk_size:
            # Shift right by one: [CLS] goes in front, the last id is dropped.
            shifted_ids = [cls_id] + input_ids[:-1]
            shifted_mask = [1] + attention_mask[:-1]

            # Pad inputs and targets separately. For a non-empty story both
            # pads are equal; for an empty tokenization the input still holds
            # the lone [CLS] while the target is all padding (fully ignored),
            # instead of the two sequences ending up with different lengths.
            input_pad = self.chunk_size - len(shifted_ids)
            output_pad = self.chunk_size - len(input_ids)

            input_tokens = [pad_id] * input_pad + shifted_ids
            output_tokens = [pad_id] * output_pad + input_ids
            attention_mask = [0] * input_pad + shifted_mask
            assert len(input_tokens) == len(output_tokens), f"{len(input_tokens)} {len(output_tokens)} {len(input_ids)}"

        else:
            # Random window start; randint is inclusive on both ends.
            start_idx = randint(0, max(0, len(input_ids) - self.chunk_size))

            input_tokens = [cls_id] + input_ids[start_idx : start_idx + self.chunk_size - 1]
            output_tokens = input_ids[start_idx: start_idx + self.chunk_size]
            attention_mask = [1] + attention_mask[start_idx : start_idx + self.chunk_size-1]
            assert len(input_tokens) == len(output_tokens), f"{len(input_tokens)} {len(output_tokens)} {len(input_ids)} {start_idx}"

        return (
            torch.tensor(input_tokens, dtype=torch.long),
            torch.tensor(output_tokens, dtype=torch.long),
            torch.tensor(attention_mask, dtype=torch.long)
        )

In [None]:
class BERT_LSTM_GRU(nn.Module):
    """Frozen BERT encoder feeding a trainable LSTM -> GRU -> linear head.

    Args:
        bert_model: Encoder module; called as `bert_model(ids, attention_mask=...)`
            and expected to return an object with `last_hidden_state`. All of
            its parameters are frozen here.
        hidden_dim: Width of the LSTM and GRU hidden states.
        embedding_dim: Output width of the final linear layer. Despite the
            name, the notebook passes the vocabulary size, so the output is
            one logit per vocabulary entry at every position.

    NOTE(review): a standard BertModel attends in both directions, and this
    notebook trains with targets equal to the inputs shifted by one position.
    Unless the supplied encoder is configured for causal attention, each
    position's BERT state can therefore already see the token it is asked to
    predict, which would explain a low training loss alongside poor generation.
    Worth confirming and, if so, restricting the encoder to causal attention.
    """

    def __init__(self, bert_model, hidden_dim, embedding_dim):
        super().__init__()
        self.bert = bert_model
        self.bert.requires_grad_(False)

        # Read the encoder width and pad id from its config when available so
        # other BERT sizes work; fall back to the original hard-coded 768.
        bert_config = getattr(bert_model, "config", None)
        encoder_width = getattr(bert_config, "hidden_size", 768)
        self.pad_token_id = getattr(bert_config, "pad_token_id", None)

        self.lstm = nn.LSTM(encoder_width, hidden_dim, batch_first=True).to(torch.float32)
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True).to(torch.float32)
        self.lm_head = nn.Linear(hidden_dim, embedding_dim).to(torch.float32)

    def forward(self, x, attention_mask=None):
        """Return per-position outputs of the linear head for token ids `x`.

        Args:
            x: Long tensor of token ids, batch first.
            attention_mask: Optional mask forwarded to the encoder. When
                omitted and the encoder config declares a pad id, positions
                holding that id are masked out so the encoder does not attend
                to padding.
        """
        if attention_mask is None and self.pad_token_id is not None:
            attention_mask = (x != self.pad_token_id).long()

        # The encoder is frozen, so skip building its autograd graph.
        with torch.no_grad():
            embedding = self.bert(x, attention_mask=attention_mask).last_hidden_state
        x, _ = self.lstm(embedding)
        x, _ = self.gru(x)
        x = self.lm_head(x)
        return x

In [None]:
# Pretrained encoder for the same checkpoint name the tokenizer was loaded
# from; BERT_LSTM_GRU wraps it and freezes its parameters.
bert_model = BertModel.from_pretrained('bert-base-uncased')
# Same CUDA-if-available choice as `configs.device`; this module-level name is
# the one the following cells actually use.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# hidden_dim=512. The third constructor argument is named `embedding_dim` but
# is the output width of the final linear layer, so passing the tokenizer's
# vocabulary size yields one logit per vocabulary entry.
model = BERT_LSTM_GRU(bert_model, 512, tokenizer.vocab_size).to(device)

In [None]:
# Total parameter count in millions. This includes the frozen BERT weights,
# since every parameter of `model` is counted regardless of requires_grad.
# (The original referenced an undefined name `m`, which raises NameError on a
# fresh kernel.)
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

In [None]:
# Targets equal to 0 are excluded from the loss; StoriesDataset left-pads its
# target sequences with 0, so padded positions do not contribute.
loss_fn = nn.CrossEntropyLoss(ignore_index=0)

In [None]:
# Every parameter of `model` is registered, including the BERT weights that
# BERT_LSTM_GRU sets to requires_grad=False.
# NOTE(review): filtering to trainable parameters would change the parameter
# layout of the optimizer state dict, so it would presumably no longer match
# the checkpoint loaded further down — confirm before changing.
optimizer = optim.AdamW(model.parameters(), lr=configs.learning_rate)

In [None]:
# Gradient scaler used by the mixed-precision training loop; its state is
# stored in and restored from the checkpoints alongside model and optimizer.
scaler = torch.amp.GradScaler(device='cuda')

In [None]:
# Directory the training loop writes its periodic checkpoints into.
checkpoint_dir = "./checkpoints/"
# exist_ok=True makes this cell safe to re-run.
os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
# Wrap the filtered story strings; each item is tokenized on access and
# returned as tensors of length configs.chunk_size.
train_dataset = StoriesDataset(data, tokenizer, configs.chunk_size)

In [None]:
# shuffle=True draws a fresh sample order every epoch; without it each epoch
# would walk the stories in the same fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=configs.batch_size, shuffle=True)

In [None]:
# Resume from a previously saved checkpoint. The file is read from the current
# working directory (not from `checkpoint_dir`), so it has to be placed there
# before this cell runs.
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all the checkpoints written by this notebook contain, and avoids executing
# arbitrary pickled code from a file of uncertain origin.
checkpoint = torch.load("model_epoch2_step10000.pt", map_location=device, weights_only=True)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
scaler.load_state_dict(checkpoint["scaler_state_dict"])

  checkpoint = torch.load("model_epoch2_step10000.pt", map_location=device)


In [None]:
# Loss value stored in the checkpoint: the single-batch loss at the step where
# the checkpoint was written, not an epoch average.
print(checkpoint["loss"])

0.5636722445487976


In [None]:
# Number of micro-batches whose gradients are accumulated per optimizer step.
GRAD_ACCUM_STEPS = 4
# Micro-batches between checkpoints. Kept a multiple of GRAD_ACCUM_STEPS so a
# checkpoint is always written right after an optimizer step.
CHECKPOINT_EVERY = 10000

num_batches = len(train_dataloader)

for epoch in range(configs.epochs):
    model.train()
    epoch_loss = 0.0
    # Start every epoch (and every re-run after an interrupt) without stale
    # gradients from a previous partial accumulation group.
    optimizer.zero_grad()

    progress = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
    # attention_mask is yielded by the dataset but model(input_tokens) is
    # called without it, so it is not moved to the device.
    for step, (input_tokens, output_tokens, attention_mask) in enumerate(progress):
        input_tokens = input_tokens.to(device)
        output_tokens = output_tokens.to(device)

        # torch.amp.autocast replaces the deprecated torch.cuda.amp.autocast().
        with torch.amp.autocast(device_type="cuda"):
            logits = model(input_tokens)
            loss = loss_fn(logits.view(-1, logits.size(-1)), output_tokens.view(-1))

        # Divide so the accumulated gradient is the mean over the group rather
        # than the sum; `loss` itself stays unscaled for logging.
        scaler.scale(loss / GRAD_ACCUM_STEPS).backward()

        # Also step on the final batch so a trailing partial group is applied
        # instead of leaking into the next epoch.
        is_last_batch = (step + 1) == num_batches
        if (step + 1) % GRAD_ACCUM_STEPS == 0 or is_last_batch:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        loss_value = loss.item()
        epoch_loss += loss_value

        if (step + 1) % CHECKPOINT_EVERY == 0:
            print(f"Loss: {loss_value}")
            checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch{epoch+1}_step{step+1}.pt")
            torch.save({
                'epoch': epoch + 1,
                'step': step + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'loss': loss_value
            }, checkpoint_path)

    # max(..., 1) avoids a ZeroDivisionError if the loader is empty.
    print(f"Epoch {epoch+1}: Average Loss = {epoch_loss / max(num_batches, 1):.4f}")

  with autocast():
Epoch 1:  15%|█▌        | 9998/66235 [18:02<1:34:59,  9.87it/s]

Loss: 1.6931893825531006


Epoch 1:  30%|███       | 19999/66235 [36:04<1:38:15,  7.84it/s]

Loss: 0.7940985560417175


Epoch 1:  45%|████▌     | 29999/66235 [54:05<1:02:34,  9.65it/s]

Loss: 0.693553626537323


Epoch 1:  60%|██████    | 39999/66235 [1:12:16<48:32,  9.01it/s]

Loss: 0.5339387655258179


Epoch 1:  75%|███████▌  | 49999/66235 [1:30:21<31:33,  8.58it/s]

Loss: 0.40305325388908386


Epoch 1:  91%|█████████ | 59999/66235 [1:48:31<10:35,  9.81it/s]

Loss: 0.35816720128059387


Epoch 1: 100%|██████████| 66235/66235 [2:00:00<00:00,  9.20it/s]


Epoch 1: Average Loss = 1.1477


Epoch 2:  15%|█▌        | 9999/66235 [18:03<1:36:54,  9.67it/s]

Loss: 0.3849264085292816


Epoch 2:  30%|███       | 19999/66235 [36:24<1:42:10,  7.54it/s]

Loss: 0.3168342709541321


Epoch 2:  45%|████▌     | 29998/66235 [54:34<1:02:46,  9.62it/s]

Loss: 0.28193771839141846


Epoch 2:  60%|██████    | 39999/66235 [1:12:38<57:11,  7.65it/s]

Loss: 0.29828816652297974


Epoch 2:  61%|██████    | 40216/66235 [1:13:13<47:22,  9.15it/s]


KeyboardInterrupt: 

In [None]:
def generate_text(model, tokenizer, prompt, max_length=50, device=None):
    """Greedily extend `prompt` by up to `max_length` tokens and decode it.

    Args:
        model: Module called as `model(ids)` with a `(1, seq)` long tensor and
            returning `(1, seq, vocab)` logits. It is switched to eval mode
            and left that way.
        tokenizer: Tokenizer providing `encode`, `decode` and (optionally)
            `cls_token_id`, `sep_token_id`, `eos_token_id`.
        prompt: Text to continue.
        max_length: Maximum number of new tokens to append.
        device: Device for the token tensor. Defaults to the device of the
            model's parameters, so the function also works on CPU-only hosts.

    Returns:
        The decoded prompt plus continuation, with special tokens removed.
    """
    model.eval()

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'

    # Mirror the training input format: a leading [CLS] followed by plain
    # tokens. Encoding with special tokens would also append a trailing [SEP],
    # which the training inputs built by StoriesDataset never contain.
    prompt_ids = list(tokenizer.encode(prompt, add_special_tokens=False))
    cls_id = getattr(tokenizer, "cls_token_id", None)
    if cls_id is not None:
        prompt_ids = [cls_id] + prompt_ids
    generated_tokens = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    # Stop on whichever end markers the tokenizer defines. Checking only
    # `eos_token_id` can never fire when that attribute is None.
    stop_ids = {
        token_id
        for token_id in (
            getattr(tokenizer, "eos_token_id", None),
            getattr(tokenizer, "sep_token_id", None),
        )
        if token_id is not None
    }

    with torch.no_grad():
        for _ in range(max_length):
            logits = model(generated_tokens)

            # Greedy choice from the logits at the last position; keepdim
            # keeps the (batch, 1) shape needed for concatenation.
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated_tokens = torch.cat([generated_tokens, next_token], dim=1)

            if next_token.item() in stop_ids:
                break

    generated_text = tokenizer.decode(generated_tokens.squeeze(0), skip_special_tokens=True)

    return generated_text

In [None]:
# Smoke test of greedy generation with the current weights; the result is the
# cell's displayed value.
generate_text(model, tokenizer, "The moral of the story ")

'the moral of the story a you look at the story : there, you have to see what is a€œletletalalalinginginginginginginginginginginginginginginginginginginginginginginginginginging'

In [None]:
from google.colab import files

In [None]:
# Colab-only: send the named checkpoint file to the browser for download.
# The path must match a checkpoint the training loop actually wrote.
files.download("./checkpoints/model_epoch2_step40000.pt")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>