In [1]:
import time
import torch
import torch.cuda.amp as amp
import os
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import math
import numpy as np
from torch.nn.utils import clip_grad_norm_
from sklearn.decomposition import PCA
from collections import Counter
from math import log2
from typing import List, Dict

## Load the data

In [2]:
# Read the whole corpus into memory as one string.
with open('/kaggle/input/tinyshake/tiny_shake.txt', "r", encoding="utf-8") as my_file:
    text = my_file.read()

# Character-level vocabulary: every distinct character of the corpus, sorted.
chars = sorted(set(text))
itos = dict(enumerate(chars))                   # id -> character
stoi = {ch: idx for idx, ch in itos.items()}    # character -> id


def encode(s):
    """Return the list of integer ids for the characters of ``s``."""
    return [stoi[c] for c in s]


def decode(l):
    """Return the string spelled by the iterable of integer ids ``l``."""
    return ''.join(itos[i] for i in l)


# Encode the corpus once; the first 90% is for training, the rest for validation.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train, val = data[:n], data[n:]

## Config class and helper functions

In [3]:
class Config:
    """Hyperparameters and runtime settings read by the rest of the notebook."""

    # --- runtime ---
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # --- data / batching ---
    vocab_size = len(chars)        # number of entries in the character vocabulary
    block_size = 128               # default sequence length used by get_batch
    batch_size = 64                # default number of sequences per batch

    # --- model ---
    embedding_dim = 256
    hidden_dim = 512
    num_layers = 2
    dropout = 0.5                  # used both inside the LSTM and on its output
    weight_init_range = 0.1        # weights are drawn uniformly from [-range, +range]

    # --- optimisation / schedule ---
    learning_rate = 1e-3
    grad_clip = 1.0                # max_norm passed to clip_grad_norm_
    max_epochs = 20
    num_batches_per_epoch = 100    # random batches drawn per training "epoch"
    num_eval_batches = 10          # random batches drawn per evaluation

# Code to get the training and validation batches
def get_batch(split, block_size=Config.block_size, batch_size=Config.batch_size):
    """
    Draw a random batch of (input, target) sequences.

    Args:
        split: ``'train'`` samples from ``train``; any other value samples from ``val``.
        block_size: length of every sampled sequence.
        batch_size: number of sequences in the batch.

    Returns:
        ``(x, y)`` on ``Config.device``, each stacked from ``batch_size`` slices of
        length ``block_size``; ``y`` is ``x`` shifted one position to the right in
        the source data, i.e. the next-character targets.
    """
    source = train if split == 'train' else val
    # Upper bound is exclusive, so start + 1 + block_size never exceeds len(source).
    starts = torch.randint(len(source) - block_size, (batch_size,))

    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + 1 + block_size] for s in starts])

    return inputs.to(Config.device), targets.to(Config.device)

# Generate text
def generate(model, start_text, max_length, temperature):
    """
    Sample a continuation of ``start_text`` from ``model``, one character at a time.

    Args:
        model: module that maps a ``(1, T)`` tensor of character ids to logits
            indexed as ``[batch, time, vocab]``. It is switched to eval mode and
            left in eval mode when the function returns.
        start_text: non-empty prompt; it is converted to ids with ``encode``.
        max_length: number of characters to sample and append to the prompt.
        temperature: positive value the last-step logits are divided by before
            the softmax; smaller values sharpen the sampling distribution.

    Returns:
        The prompt followed by the ``max_length`` sampled characters, as a string.

    Raises:
        ValueError: if ``start_text`` is empty or ``temperature`` is not positive.
    """
    # Fail early with a clear message instead of an opaque tensor/dtype error later.
    if not start_text:
        raise ValueError("start_text must contain at least one character")
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    model.eval()
    # Explicit integer dtype: the ids are used as embedding indices.
    token_ids = torch.tensor(encode(start_text), dtype=torch.long).unsqueeze(0).to(Config.device)

    with torch.no_grad():
        for _ in range(max_length):
            # The full sequence so far is fed again at every step.
            logits = model(token_ids)
            next_logits = logits[0, -1, :] / temperature
            probs = F.softmax(next_logits, dim=0)
            next_id = torch.multinomial(probs, 1)
            token_ids = torch.cat([token_ids, next_id.unsqueeze(0)], dim=1)

    return decode(token_ids[0].tolist())

## Entropy of the dataset

In [None]:
def calculate_ngram_entropies(text, max_n = 4):
    """
    Empirical entropy of overlapping n-grams of ``text`` for n = 1..max_n.

    Args:
        text: the string to analyse.
        max_n: largest n-gram length to evaluate (inclusive).

    Returns:
        Dict mapping n to a pair ``(per_char, per_char / 8)``, where ``per_char``
        is the base-2 entropy of the n-gram distribution divided by n.

    NOTE(review): the second element is what callers print as "bits per byte".
    For single-byte characters bits-per-byte would equal bits-per-character, so
    the division by 8 should be confirmed as the intended normalisation.
    """
    results = {}

    for size in range(1, max_n + 1):
        # Slide a window of `size` characters over the text, one position at a time.
        window_counts = Counter(text[pos:pos + size] for pos in range(len(text) - size + 1))
        num_windows = sum(window_counts.values())

        # Shannon entropy (base 2) of the observed n-gram frequencies.
        block_entropy = -sum(
            (c / num_windows) * log2(c / num_windows) for c in window_counts.values()
        )

        per_char = block_entropy / size
        results[size] = (per_char, per_char / 8)

    return results

def calculate_conditional_entropy(text, context_size):
    """
    Empirical conditional entropy of the next character given the preceding
    ``context_size`` characters.

    Args:
        text: the string to analyse.
        context_size: number of preceding characters that form the context.

    Returns:
        ``(entropy, entropy / 8)`` where ``entropy`` is the base-2 entropy of the
        next-character distribution, averaged over contexts weighted by how
        often each context occurs.

    NOTE(review): the second element is what callers print as "bits per byte";
    confirm that dividing the per-character value by 8 is the intended
    normalisation.
    """
    # For every context string, tally how often each character follows it.
    followers = {}
    for i in range(len(text) - context_size):
        ctx = text[i:i + context_size]
        tally = followers.get(ctx)
        if tally is None:
            tally = followers[ctx] = Counter()
        tally[text[i + context_size]] += 1

    total_contexts = sum(sum(tally.values()) for tally in followers.values())
    entropy = 0

    for tally in followers.values():
        occurrences = sum(tally.values())
        # Weight of this context among all (context, next char) observations.
        context_prob = occurrences / total_contexts
        # Entropy of the next-character distribution under this context.
        context_entropy = -sum(
            (c / occurrences) * log2(c / occurrences) for c in tally.values()
        )
        entropy += context_prob * context_entropy

    return entropy, entropy / 8

# Report the per-character n-gram entropy for each n returned by the helper.
ngram_entropies = calculate_ngram_entropies(text)
for n in ngram_entropies:
    bpc, bpb = ngram_entropies[n]
    print(f"{n}-gram entropy: {bpc:.4f} bits per character ({bpb:.4f} bits per byte)")

# Report the conditional entropy of the next character for contexts of 1 to 4 characters.
for context_size in (1, 2, 3, 4):
    bpc, bpb = calculate_conditional_entropy(text, context_size)
    print(f"Conditional entropy with context size {context_size}: "
          f"{bpc:.4f} bits per character ({bpb:.4f} bits per byte)")

1-gram entropy: 4.7794 bits per character (0.5974 bits per byte)
2-gram entropy: 4.1588 bits per character (0.5199 bits per byte)
3-gram entropy: 3.6899 bits per character (0.4612 bits per byte)
4-gram entropy: 3.3077 bits per character (0.4135 bits per byte)
Conditional entropy with context size 1: 3.5383 bits per character (0.4423 bits per byte)
Conditional entropy with context size 2: 2.7520 bits per character (0.3440 bits per byte)
Conditional entropy with context size 3: 2.1610 bits per character (0.2701 bits per byte)
Conditional entropy with context size 4: 1.7672 bits per character (0.2209 bits per byte)


## Model

To improve training we are going to use **torch.autocast**. It automatically converts certain operations to run in FP16 precision instead of FP32 where it's safe to do so. This includes matrix multiplications in the LSTM layers, the embedding layer operations, and linear layer transformations, while it keeps other operations that need higher precision (like loss calculations and certain accumulations) in FP32 for numerical stability.

The GradScaler works together with autocast: it scales up the loss to prevent underflow in FP16 gradients, unscales gradients before gradient clipping, and handles the optimizer step safely with mixed precision. Underflow occurs when a calculation produces a number that is too small to be represented in the available precision format, causing it to be rounded to zero. FP16 can only represent numbers between approximately 6e-5 and 65504. When gradients become very small during backpropagation they might fall below 6e-5;
if this happens, they get rounded to 0, effectively stopping any learning in those parameters. Example:

* Without scaling: 1e-8 * 0.5 = 0 (underflow in FP16)
* With scaling: (1e-8 * 1000) * 0.5 = 5e-6, then divide by 1000 = 5e-9 (preserves the small value)

In [None]:
class ShakeLSTM(nn.Module):
    """
    Character-level language model:
    embedding -> LayerNorm -> stacked LSTM -> dropout -> linear projection.

    Every size except the vocabulary is read from ``Config``.
    """

    def __init__(self, vocab_size):
        """Build the layers for a vocabulary of ``vocab_size`` ids and initialise them."""
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, Config.embedding_dim)
        self.layer_norm = nn.LayerNorm(Config.embedding_dim)
        self.lstm = nn.LSTM(
            input_size=Config.embedding_dim,
            hidden_size=Config.hidden_dim,
            num_layers=Config.num_layers,
            dropout=Config.dropout,
            batch_first=True,  # tensors are laid out as (batch, seq, feature)
        )
        self.dropout = nn.Dropout(Config.dropout)
        self.fc = nn.Linear(Config.hidden_dim, vocab_size)

        self.init_weights()

    def init_weights(self):
        """
        Re-initialise all weights uniformly in ``[-r, r]`` with
        ``r = Config.weight_init_range`` and zero all biases, except that the
        second quarter of every LSTM bias vector is set to 1.
        """
        bound = Config.weight_init_range

        with torch.no_grad():
            # Embedding table and output projection.
            self.embedding.weight.uniform_(-bound, bound)
            self.fc.weight.uniform_(-bound, bound)
            self.fc.bias.zero_()

            # Recurrent layers.
            for name, param in self.lstm.named_parameters():
                if 'weight' in name:
                    param.uniform_(-bound, bound)
                    continue
                if 'bias' not in name:
                    continue

                param.zero_()
                if 'bias_ih' in name or 'bias_hh' in name:
                    # PyTorch packs LSTM gates as (input, forget, cell, output),
                    # so the second quarter is the forget gate. Both bias
                    # vectors receive 1 here and the LSTM adds them together.
                    size = param.size(0)
                    param[size // 4:size // 2].fill_(1.)

    def forward(self, x):
        """Map integer ids ``x`` of shape (batch, seq) to logits of shape (batch, seq, vocab_size)."""
        features = self.layer_norm(self.embedding(x))    # (batch, seq, embedding_dim)
        hidden_states, _ = self.lstm(features)            # (batch, seq, hidden_dim)
        return self.fc(self.dropout(hidden_states))       # (batch, seq, vocab_size)

# Loss scaler shared by train_epoch for mixed-precision training.
# `torch.cuda.amp.GradScaler()` is deprecated in recent PyTorch releases in favour
# of `torch.amp.GradScaler('cuda')`; fall back to the old spelling on versions
# that do not provide the new one.
scaler = torch.amp.GradScaler('cuda') if hasattr(torch.amp, 'GradScaler') else amp.GradScaler()

def train_epoch(model, optimizer, criterion):
    """
    Train ``model`` for ``Config.num_batches_per_epoch`` random training batches.

    Each step runs the forward pass and loss under CUDA autocast, backpropagates
    the scaled loss through the module-level ``scaler``, unscales and clips the
    gradients to ``Config.grad_clip``, and then steps the optimizer.

    Args:
        model: module mapping a batch of ids to logits of shape (B, T, C).
        optimizer: optimizer over ``model.parameters()``.
        criterion: loss taking ``(N, C)`` logits and ``(N,)`` targets.

    Returns:
        ``(avg_loss, bits_per_byte)``: the mean batch loss, and that mean divided
        by ln 2 and then by 8.

    NOTE(review): the nats-to-bits conversion assumes ``criterion`` reports nats
    (true for ``nn.CrossEntropyLoss``). The extra division by 8 is what callers
    print as "bits per byte"; confirm that normalisation is intended.
    """
    model.train()
    total_loss = 0

    for _ in range(Config.num_batches_per_epoch):
        x, y = get_batch('train')

        optimizer.zero_grad()

        # Mixed precision: eligible ops run in FP16, the rest stay in FP32.
        with torch.amp.autocast(device_type='cuda'):
            logits = model(x)
            # Flatten (B, T, C) -> (B*T, C) so every position is one classification.
            loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))

        # Scale the loss so small FP16 gradients do not underflow.
        scaler.scale(loss).backward()

        # Gradients must be unscaled before their norm is clipped.
        scaler.unscale_(optimizer)
        clip_grad_norm_(model.parameters(), max_norm=Config.grad_clip)

        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    avg_loss = total_loss / Config.num_batches_per_epoch
    bits_per_char = avg_loss / math.log(2)  # nats -> bits
    bits_per_byte = bits_per_char / 8

    return avg_loss, bits_per_byte

# Modified evaluate function with memory optimizations
@torch.no_grad()
def evaluate(model, criterion):
    """
    Estimate the validation loss from ``Config.num_eval_batches`` random batches.

    The model is put in eval mode and gradients are disabled for the whole call.

    Args:
        model: module mapping a batch of ids to logits of shape (B, T, C).
        criterion: loss taking ``(N, C)`` logits and ``(N,)`` targets.

    Returns:
        ``(avg_loss, bits_per_byte)``: the mean batch loss, and that mean divided
        by ln 2 and then by 8.
    """
    model.eval()
    running_loss = 0

    for step in range(Config.num_eval_batches):
        inputs, targets = get_batch('val')
        with torch.amp.autocast(device_type='cuda'):
            scores = model(inputs)
            _, _, vocab = scores.shape
            batch_loss = criterion(scores.view(-1, vocab), targets.view(-1))

        running_loss += batch_loss.item()

        # Release cached CUDA memory on steps 0, 10, 20, ...
        if step % 10 == 0:
            torch.cuda.empty_cache()

    mean_loss = running_loss / Config.num_eval_batches
    return mean_loss, mean_loss / math.log(2) / 8

  scaler = amp.GradScaler()


## Training

In [10]:
# Build the model on the configured device, with an AdamW optimizer and a
# cross-entropy criterion.
model = ShakeLSTM(vocab_size=Config.vocab_size).to(Config.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=Config.learning_rate)
criterion = nn.CrossEntropyLoss()

# NOTE(review): best_val_loss is initialised here but is not updated or read
# anywhere in this cell.
best_val_loss = float('inf')
prompt = "To be or not to be that is the question"

# Each iteration runs one train_epoch followed by one evaluate call and prints
# the resulting metrics together with the wall-clock time for both.
for epoch in range(Config.max_epochs):
    t0 = time.time()
    train_loss, train_bpb = train_epoch(model, optimizer, criterion)
    val_loss, val_bpb = evaluate(model, criterion)  
    t1 = time.time()
    dt = t1 - t0  # time in seconds
    
    # Calculate tokens processed
    # Only the training batches are counted, while dt also spans the evaluate
    # call, so tokens/sec slightly understates pure training throughput.
    tokens_processed = Config.batch_size * Config.block_size * Config.num_batches_per_epoch
    tokens_per_sec = tokens_processed / dt
    
    print(f'Epoch {epoch+1} took {dt*1000:.2f} ms: '
          f'train loss {train_loss:.4f}, '
          f'train bits/byte {train_bpb:.4f}, '
          f'val loss {val_loss:.4f}, '
          f'val bits/byte {val_bpb:.4f}, '
          f'tokens/sec: {tokens_per_sec:.2f}')
    
# Sampling happens once, after the epoch loop has finished: 500 characters are
# appended to the prompt at temperature 0.5.
generated_text = generate(model, prompt, max_length=500, temperature=0.5)
print(f"\nGenerated Text:\n{generated_text}\n")
print("-"*50)

print("Training finished!")

Epoch 1 took 2352.05 ms: train loss 2.5979, train bits/byte 0.4685, val loss 2.0597, val bits/byte 0.3714, tokens/sec: 348291.45
Epoch 2 took 2278.77 ms: train loss 2.0094, train bits/byte 0.3624, val loss 1.8767, val bits/byte 0.3384, tokens/sec: 359491.77
Epoch 3 took 2298.70 ms: train loss 1.8339, train bits/byte 0.3307, val loss 1.7796, val bits/byte 0.3209, tokens/sec: 356374.63
Epoch 4 took 2305.48 ms: train loss 1.7333, train bits/byte 0.3126, val loss 1.7289, val bits/byte 0.3118, tokens/sec: 355327.34
Epoch 5 took 2305.48 ms: train loss 1.6712, train bits/byte 0.3014, val loss 1.6955, val bits/byte 0.3058, tokens/sec: 355327.20
Epoch 6 took 2312.34 ms: train loss 1.6190, train bits/byte 0.2920, val loss 1.6499, val bits/byte 0.2975, tokens/sec: 354272.65
Epoch 7 took 2313.96 ms: train loss 1.5813, train bits/byte 0.2852, val loss 1.6211, val bits/byte 0.2923, tokens/sec: 354025.74
Epoch 8 took 2330.04 ms: train loss 1.5517, train bits/byte 0.2798, val loss 1.6015, val bits/byt

By using this technique we were able to reduce the training time per epoch from 7315.04 ms to 2394.50 ms, in other words, a 67% improvement, not bad for a 7-year-old GPU. However, if we used a modern GPU, and I'm not even talking about an A100, I'm talking about an RTX 4090, we would smash these numbers.