In [None]:
!pip install datasets transformers -Uq # torch tqdm

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.1/44.1 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.0/10.0 MB[0m [31m64.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.3/179.3 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.0/3.0 MB[0m [31m33.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolve

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from tqdm.auto import tqdm

In [None]:
class AttentionHead(nn.Module):
    """A single head of causal, scaled dot-product self-attention.

    Args:
        head_size: Output width of the key, query and value projections.
        n_embd: Width of the incoming embeddings (last dimension of the input).
        block_size: Side length of the lower-triangular mask buffer; the
            forward pass slices this buffer down to the input's sequence length.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, head_size, n_embd, block_size, dropout):
        super().__init__()
        # Bias-free projections, registered in key / query / value order.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular ones matrix: entry (i, j) is 1 when position i may
        # attend to position j (j <= i). Stored as a buffer, not a parameter.
        causal_mask = torch.ones(block_size, block_size).tril()
        self.register_buffer('tril', causal_mask)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over a 3-D input and return a tensor whose last dim is head_size."""
        _, seq_len, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)
        values = self.value(x)

        # Raw affinities, scaled by 1/sqrt(head_size).
        scale = keys.shape[-1] ** -0.5
        scores = torch.matmul(queries, keys.transpose(-2, -1)) * scale
        # Hide future positions so they receive zero weight after the softmax.
        hidden_positions = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(hidden_positions, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))

        return torch.matmul(weights, values)


In [None]:
class MultiHeadAttention(nn.Module):
    """Several AttentionHead modules run side by side, then mixed by a linear layer.

    Args:
        num_heads: Number of independent attention heads.
        head_size: Output width of each head.
        n_embd: Embedding width; the output projection maps back to this size.
        block_size: Passed through to each head for its causal mask.
        dropout: Dropout probability, used inside each head and on the output.
    """

    def __init__(self, num_heads, head_size, n_embd, block_size, dropout):
        super().__init__()
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(AttentionHead(head_size, n_embd, block_size, dropout))
        self.heads = nn.ModuleList(head_modules)
        # The concatenated head outputs are head_size * num_heads wide.
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every head on x, concatenate on the last dim, project, apply dropout."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)


In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: widen to 4 * n_embd, ReLU, project back, then dropout.

    Args:
        n_embd: Input and output width.
        dropout: Dropout probability applied after the second linear layer.
    """

    def __init__(self, n_embd, dropout):
        super().__init__()
        hidden_width = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_embd),
            nn.Dropout(dropout),
        ]
        # Kept as a Sequential named `net` so parameter names stay net.0.* / net.2.*.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of x."""
        transformed = self.net(x)
        return transformed


In [None]:
class Block(nn.Module):
    """One transformer layer: pre-norm attention and pre-norm MLP, each with a residual.

    Args:
        n_embd: Embedding width.
        n_head: Number of attention heads; each head is n_embd // n_head wide.
        block_size: Maximum sequence length, forwarded to the attention module.
        dropout: Dropout probability forwarded to both sub-layers.
    """

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        per_head_width = n_embd // n_head
        self.attn = MultiHeadAttention(
            n_head, per_head_width, n_embd, block_size, dropout
        )
        self.ff = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return x after the attention and feed-forward residual updates."""
        # LayerNorm is applied to the sub-layer input, not to the residual sum.
        after_attention = x + self.attn(self.ln1(x))
        after_mlp = after_attention + self.ff(self.ln2(after_attention))
        return after_mlp



In [None]:
class SimplifiedGPT2(nn.Module):
    """Small decoder-only language model with learned position embeddings.

    Args:
        vocab_size: Number of token ids (rows of the token embedding and
            width of the output logits).
        n_embd: Embedding width.
        n_head: Attention heads per block.
        n_layer: Number of stacked Block modules.
        block_size: Maximum sequence length accepted by forward().
        dropout: Dropout probability forwarded to every block.
    """

    def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size, dropout):
        super().__init__()
        self.block_size = block_size
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        layers = [Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        self.blocks = nn.Sequential(*layers)
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, if targets are given, the mean cross-entropy.

        Args:
            idx: 2-D tensor of token ids; its second dimension must not
                exceed block_size.
            targets: Optional tensor of target ids with as many elements as idx.

        Returns:
            (logits, loss). Without targets, logits keeps its 3-D shape and
            loss is None. With targets, logits is flattened to 2-D
            (rows = batch * sequence) and loss is the scalar cross-entropy.
        """
        _, seq_len = idx.shape
        # The position table only has block_size rows.
        assert seq_len <= self.block_size, "Input sequence length exceeds block size."

        token_vectors = self.token_embedding_table(idx)
        positions = torch.arange(seq_len, device=idx.device).unsqueeze(0)  # (1, seq_len)
        position_vectors = self.position_embedding_table(positions)

        hidden = self.blocks(token_vectors + position_vectors)
        logits = self.lm_head(self.ln_f(hidden))

        if targets is None:
            return logits, None

        batch, steps, vocab = logits.shape
        logits = logits.view(batch * steps, vocab)
        flat_targets = targets.view(batch * steps)
        return logits, F.cross_entropy(logits, flat_targets)

    def generate(self, idx, max_new_tokens):
        """Sample max_new_tokens ids one at a time and append them to idx.

        Each step feeds only the last block_size tokens to the model and draws
        the next id from the softmax of the final position's logits.
        """
        for _ in range(max_new_tokens):
            window = idx[:, -self.block_size:]
            step_logits, _ = self(window)
            last_position = step_logits[:, -1, :]
            next_probs = F.softmax(last_position, dim=-1)
            sampled = torch.multinomial(next_probs, num_samples=1)
            idx = torch.cat((idx, sampled), dim=1)
        return idx


In [None]:
class TextDataset(Dataset):
    """Dataset that tokenizes one text chunk per item into an (input, target) pair.

    Args:
        data: Indexable collection of text chunks.
        block_size: Length of the returned input and target sequences.
        tokenizer: Object providing encode(text).ids and token_to_id(token).
    """

    def __init__(self, data, block_size, tokenizer):
        self.data = data
        self.tokenizer = tokenizer
        self.block_size = block_size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return (x, y) long tensors of length block_size; y is x shifted left by one."""
        token_ids = self.tokenizer.encode(self.data[idx]).ids

        # One extra token is needed so that the shifted target has block_size entries.
        window = self.block_size + 1
        shortfall = window - len(token_ids)
        if shortfall > 0:
            # Too short: right-pad with the "[PAD]" id.
            pad_id = self.tokenizer.token_to_id("[PAD]")
            token_ids = token_ids + [pad_id] * shortfall
        else:
            # Long enough: keep only the first `window` tokens.
            token_ids = token_ids[:window]

        inputs = torch.tensor(token_ids[:-1], dtype=torch.long)
        targets = torch.tensor(token_ids[1:], dtype=torch.long)
        return inputs, targets


In [None]:
def create_chunks(text, block_size, overlap=0):
    """Split a sequence into windows of up to block_size + 1 items.

    Window starts are block_size - overlap items apart. Each window holds
    block_size + 1 items (the extra one serves as the final target). Windows
    near the end of the sequence may be shorter; they are kept only if they
    contain at least block_size // 2 items, so the tail of the text is not
    silently discarded.

    Args:
        text: Any sliceable sequence with a length. A str yields
            character-based chunks.
        block_size: Nominal window length.
        overlap: Number of items shared between consecutive windows. Must be
            smaller than block_size.

    Returns:
        List of slices of text, in order of their start position.

    Raises:
        ValueError: If block_size - overlap is not positive, since the window
            could then never advance.
    """
    stride = block_size - overlap
    if stride <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than block_size ({block_size})"
        )

    min_length = block_size // 2
    chunks = []
    # Iterate over every start position, including those whose window runs off
    # the end, so the length filter below actually decides what happens to the tail.
    for start in range(0, len(text), stride):
        chunk = text[start:start + block_size + 1]  # +1 for the target token
        if len(chunk) >= min_length:
            chunks.append(chunk)
    return chunks


In [None]:
def train_model(
    batch_size=256,
    block_size=64,
    max_iters=5000,
    learning_rate=3e-4,
    n_embd=128,
    n_head=4,
    n_layer=4,
    dropout=0.2,
    device=None,
):
    """Train a SimplifiedGPT2 on tiny_shakespeare using a freshly trained BPE tokenizer.

    Calling train_model() with no arguments uses the original hyperparameters.

    Args:
        batch_size: Samples per optimisation step.
        block_size: Sequence length (in tokens) fed to the model.
        max_iters: Number of full passes over the training loader (each
            "iteration" of the outer loop is one epoch).
        learning_rate: AdamW learning rate.
        n_embd: Embedding width.
        n_head: Attention heads per block.
        n_layer: Number of transformer blocks.
        dropout: Dropout probability.
        device: Device (or device string) to train on. Defaults to CUDA when
            available, otherwise CPU.

    Returns:
        (model, tokenizer): the trained model (left in training mode) and the
        tokenizer it was trained with.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device)

    # Load dataset
    dataset = load_dataset("tiny_shakespeare", trust_remote_code=True)
    train_text = dataset["train"]["text"][0]

    # Train tokenizer
    tokenizer = Tokenizer(BPE())
    trainer = BpeTrainer(
        special_tokens=["[PAD]", "[UNK]", "[BOS]", "[EOS]"],
        vocab_size=4096,
        min_frequency=2
    )
    tokenizer.pre_tokenizer = Whitespace()
    tokenizer.train_from_iterator([train_text], trainer=trainer)
    # Loop-invariant, so look it up once rather than on every batch.
    pad_token_id = tokenizer.token_to_id("[PAD]")

    # Create chunks of text with overlap.
    # NOTE(review): the chunks are slices of the raw string, so block_size here
    # counts characters, while TextDataset pads/truncates to block_size tokens.
    # Presumably most samples therefore end in padding, which is why the loss
    # below must ignore PAD targets. Consider chunking token ids instead.
    chunks = create_chunks(train_text, block_size, overlap=block_size//4)
    train_dataset = TextDataset(chunks, block_size, tokenizer)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Initialize model
    vocab_size = tokenizer.get_vocab_size()
    print(f"Vocab size: {vocab_size}")

    model = SimplifiedGPT2(vocab_size, n_embd, n_head, n_layer, block_size, dropout)
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # Training loop
    model.train()
    for epoch in tqdm(range(max_iters)):
        total_loss = 0.0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            # Compute the loss here instead of using the model's built-in one:
            # that one is already averaged over every position (PAD included),
            # so multiplying it by a mask afterwards cannot remove the padding.
            logits, _ = model(x)
            flat_logits = logits.reshape(-1, logits.size(-1))
            flat_targets = y.reshape(-1)

            # Average only over real (non-PAD) targets. The clamp keeps the
            # loss finite (zero) if a batch happens to contain no real target.
            n_valid = (flat_targets != pad_token_id).sum().clamp(min=1)
            loss = F.cross_entropy(
                flat_logits,
                flat_targets,
                ignore_index=pad_token_id,
                reduction='sum',
            ) / n_valid

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        if epoch % 100 == 0:
            avg_loss = total_loss / max(len(train_loader), 1)
            print(f"Iteration {epoch}: Average Loss {avg_loss:.4f}")
        if epoch % 10 == 0 and device.type == 'cuda':  # Less frequent cache clearing
            torch.cuda.empty_cache()

    return model, tokenizer

def generate_text(model, tokenizer, prompt, max_new_tokens=100):
    """Sample a continuation of prompt from model and return it as decoded text.

    Args:
        model: Module exposing generate(idx, max_new_tokens) and at least one
            parameter (used to find the target device). It is switched to
            eval mode and left that way.
        tokenizer: Object providing encode(text).ids, token_to_id(token) and
            decode(ids).
        prompt: Text to condition the generation on.
        max_new_tokens: Number of tokens to sample after the prompt.

    Returns:
        The decoded prompt plus continuation, with "[PAD]" tokens removed.
    """
    model.eval()
    prompt_ids = tokenizer.encode(prompt).ids
    target_device = next(model.parameters()).device
    # Wrap the ids in a list to add the batch dimension expected by generate().
    context = torch.tensor([prompt_ids], dtype=torch.long).to(target_device)

    with torch.no_grad():
        generated = model.generate(context, max_new_tokens=max_new_tokens)
    token_ids = generated[0].tolist()

    # Remove padding tokens if any
    pad_token_id = tokenizer.token_to_id("[PAD]")
    visible_ids = list(filter(lambda token: token != pad_token_id, token_ids))

    return tokenizer.decode(visible_ids)

In [None]:
# Train end to end (downloads the dataset and fits the tokenizer on first run).
model, tokenizer = train_model()

# Sample a continuation for a short prompt and show it under a header line.
sample_text = generate_text(model, tokenizer, "Once upon a")
print("Generated text:", sample_text, sep="\n")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/6.10k [00:00<?, ?B/s]

tiny_shakespeare.py:   0%|          | 0.00/3.73k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.12M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1 [00:00<?, ? examples/s]

Vocab size: 4096


  0%|          | 0/100 [00:00<?, ?it/s]

Iteration 0: Average Loss 2.9922
Generated text:
Once upon a life as this seems con ap art Mis . When , now I am not
