In [17]:
import time
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import math
import numpy as np
from torch.nn.utils import clip_grad_norm_
from collections import Counter
from math import log2
from typing import List, Dict

## Load the data

In [13]:
with open('/kaggle/input/tiny-shake/input.txt', "r", encoding="utf-8") as my_file:
    text = my_file.read()

chars = sorted(list(set(text)))
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train = data[:n]
val = data[n:]

## Config class and helper functions

In [14]:
class Config:
    batch_size = 64
    block_size = 128
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    embedding_dim = 256
    hidden_dim = 512
    num_layers = 2
    dropout = 0.5
    learning_rate = 1e-3
    grad_clip = 1.0
    max_epochs = 20
    weight_init_range = 0.1
    num_batches_per_epoch = 100
    num_eval_batches = 10
    vocab_size = len(chars)

# Code to get the training and validation batches
def get_batch(split, block_size=Config.block_size, batch_size=Config.batch_size):
    """
    Get a batch of data for training or testing.
    """
    data = train if split == 'train' else val
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x_batch = []
    y_batch = []

    for i in ix:
        x_batch.append(data[i:i+block_size])
        y_batch.append(data[i+1:i+1+block_size])

    x = torch.stack(x_batch)
    y = torch.stack(y_batch)
    x, y = x.to(Config.device), y.to(Config.device)

    return x, y

# Generate text
def generate(model, start_text, max_length, temperature):
    model.eval()
    chars = torch.tensor(encode(start_text)).unsqueeze(0).to(Config.device)
    
    with torch.no_grad():
        for _ in range(max_length):
            logits = model(chars)
            next_char_logits = logits[0, -1, :] / temperature
            probs = F.softmax(next_char_logits, dim=0)
            next_char = torch.multinomial(probs, 1)
            chars = torch.cat([chars, next_char.unsqueeze(0)], dim=1)
    
    return decode(chars[0].tolist())


## Entropy of the dataset

Entropy, in the context of llms, refers to the amount of information a token carries. The bigger the entropy the harder it is for the model to predict the next word, or in our case, the next character.

We are also measuring the bits pert byte, the number of bits a language model needs to represent one byte of the original training data.

In [27]:
def calculate_ngram_entropies(text, max_n = 4):
    """
    Calculate entropy for different n-gram sizes up to max_n.
    Returns both bits per character and bits per byte.
    """
    results = {}
    
    for n in range(1, max_n + 1):
        # Create overlapping n-grams
        ngrams = [text[i:i+n] for i in range(len(text)-n+1)]
        
        # Count frequencies
        counts = Counter(ngrams)
        total = len(ngrams)
        
        # Calculate probabilities
        probs = [count/total for count in counts.values()]
        
        # Calculate entropy
        entropy = -sum(p * log2(p) for p in probs)
        
        # Normalize by n to get per-character entropy
        normalized_entropy = entropy / n
        # Calculate bits per byte
        bits_per_byte = normalized_entropy / 8
        
        results[n] = (normalized_entropy, bits_per_byte)
        
    return results

def calculate_conditional_entropy(text, context_size):
    """
    Calculate conditional entropy H(X|Y) where Y is the context.
    Returns both bits per character and bits per byte.
    """
    # Get all possible contexts and their following characters
    contexts = {}
    for i in range(len(text) - context_size):
        context = text[i:i+context_size]
        next_char = text[i+context_size]
        if context not in contexts:
            contexts[context] = []
        contexts[context].append(next_char)
    
    # Calculate conditional entropy
    total_contexts = sum(len(chars) for chars in contexts.values())
    entropy = 0
    
    for context, next_chars in contexts.items():
        # Probability of this context
        context_prob = len(next_chars) / total_contexts
        
        # Calculate entropy for characters following this context
        char_counts = Counter(next_chars)
        char_probs = [count/len(next_chars) for count in char_counts.values()]
        context_entropy = -sum(p * log2(p) for p in char_probs)
        
        entropy += context_prob * context_entropy
    
    bits_per_byte = entropy / 8
    return entropy, bits_per_byte

# Calculate n-gram entropies
ngram_entropies = calculate_ngram_entropies(text)
for n, (bpc, bpb) in ngram_entropies.items():
    print(f"{n}-gram entropy: {bpc:.4f} bits per character ({bpb:.4f} bits per byte)")

# Calculate conditional entropies for different context sizes
for context_size in range(1, 5):
    bpc, bpb = calculate_conditional_entropy(text, context_size)
    print(f"Conditional entropy with context size {context_size}: "
          f"{bpc:.4f} bits per character ({bpb:.4f} bits per byte)")

1-gram entropy: 4.7794 bits per character (0.5974 bits per byte)
2-gram entropy: 4.1588 bits per character (0.5199 bits per byte)
3-gram entropy: 3.6899 bits per character (0.4612 bits per byte)
4-gram entropy: 3.3077 bits per character (0.4135 bits per byte)
Conditional entropy with context size 1: 3.5383 bits per character (0.4423 bits per byte)
Conditional entropy with context size 2: 2.7520 bits per character (0.3440 bits per byte)
Conditional entropy with context size 3: 2.1610 bits per character (0.2701 bits per byte)
Conditional entropy with context size 4: 1.7672 bits per character (0.2209 bits per byte)


By looking at the values of the consitional entropy, we can conclude that the words depend heavily on the previous context and not by chance. A language model is trained to minimize its cross entropy with respect to the training data. If the language model learns perfectly from its training data, the model’s cross entropy will be exactly the same as the entropy of the training data.We can think of a model’s cross entropy as its approximation of the entropy of its training data.

## Model

In [19]:
class ShakeLSTM(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, Config.embedding_dim)
        self.layer_norm = nn.LayerNorm(Config.embedding_dim)
        
        self.lstm = nn.LSTM(
            input_size=Config.embedding_dim,
            hidden_size=Config.hidden_dim,
            num_layers=Config.num_layers,
            dropout=Config.dropout,
            batch_first=True # (batch_size, seq_length, input_size)
        )
        
        self.dropout = nn.Dropout(Config.dropout)
        self.fc = nn.Linear(Config.hidden_dim, vocab_size)
        
        self.init_weights()

    def init_weights(self):
        # Initialize embeddings and linear layer
        self.embedding.weight.data.uniform_(-Config.weight_init_range, Config.weight_init_range)
        self.fc.weight.data.uniform_(-Config.weight_init_range, Config.weight_init_range)
        self.fc.bias.data.zero_()

        # Initialize LSTM weights and biases
        for name, param in self.lstm.named_parameters():
            if 'weight' in name:
                param.data.uniform_(-Config.weight_init_range, Config.weight_init_range)
            elif 'bias' in name:
                # Initialize all biases to zero
                param.data.zero_()
                # Set forget gate bias to 1
                if 'bias_ih' in name or 'bias_hh' in name:
                    n = param.size(0)
                    start, end = n//4, n//2
                    param.data[start:end].fill_(1.)

    def forward(self, x):
        embedded = self.embedding(x) # 64, 128 --> 64, 128, 256
        normalized = self.layer_norm(embedded) 
        lstm_out, _ = self.lstm(normalized) #64, 128, 512
        dropped = self.dropout(lstm_out)
        logits = self.fc(dropped) # 64, 128, 65
        return logits

def train_epoch(model, optimizer, criterion):
    model.train()
    total_loss = 0
    total_perplexity = 0
    
    for i in range(Config.num_batches_per_epoch):
        x, y = get_batch('train')
        
        optimizer.zero_grad()
        logits = model(x)
        
        B, T, C = logits.shape
        loss = criterion(logits.view(-1, C), y.view(-1))
        perplexity = torch.exp(loss)
        
        loss.backward()
        clip_grad_norm_(model.parameters(), max_norm=Config.grad_clip)
        optimizer.step()
        
        total_loss += loss.item()
        total_perplexity += perplexity.item()
        
    avg_loss = total_loss / Config.num_batches_per_epoch
    # avg_perplexity = total_perplexity / Config.num_batches_per_epoch
    bits_per_char = avg_loss / math.log(2)  # Convert nats to bits
    bits_per_byte = bits_per_char / 8  # Convert bits per char to bits per byte
    
    return avg_loss, bits_per_byte

@torch.no_grad()
def evaluate(model, criterion):
    model.eval()
    total_loss = 0
    total_perplexity = 0
    
    for i in range(Config.num_eval_batches):
        x, y = get_batch('val')
        logits = model(x)
        B, T, C = logits.shape
        loss = criterion(logits.view(-1, C), y.view(-1))
        perplexity = torch.exp(loss)
        
        total_loss += loss.item()
        total_perplexity += perplexity.item()
        
    avg_loss = total_loss / Config.num_eval_batches
    # avg_perplexity = total_perplexity / Config.num_eval_batches
    bits_per_char = avg_loss / math.log(2)  # Convert nats to bits
    bits_per_byte = bits_per_char / 8  # Convert bits per char to bits per byte
    
    return avg_loss, bits_per_byte

## Training

In [22]:
model = ShakeLSTM(vocab_size=Config.vocab_size).to(Config.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=Config.learning_rate)
criterion = nn.CrossEntropyLoss()

best_val_loss = float('inf')
prompt = "To be or not to be that is the question"


for epoch in range(Config.max_epochs):
    t0 = time.time()
    train_loss, train_bpb = train_epoch(model, optimizer, criterion)
    val_loss, val_bpb = evaluate(model, criterion)  
    t1 = time.time()
    dt = (t1 - t0) * 1000
    
    print(f'Epoch {epoch+1} took {dt:.2f} miliseconds: '
          f'train loss {train_loss:.4f}, '
          f'train bits/byte {train_bpb:.4f}, '
          f'val loss {val_loss:.4f}, '
          f'val bits/byte {val_bpb:.4f}, ')
    
generated_text = generate(model, prompt, max_length=500, temperature=0.5)
print(f"\nGenerated Text:\n{generated_text}\n")
print("-"*50)

print("Training finished!")
print(f'Training time excluding generation: {(t1 - t0) * 1000} miliseconds')

Epoch 1 took 6551.30 miliseconds: train loss 2.6143, train bits/byte 0.4715, val loss 2.0554, val bits/byte 0.3707, 
Epoch 2 took 6601.88 miliseconds: train loss 2.0109, train bits/byte 0.3626, val loss 1.8814, val bits/byte 0.3393, 
Epoch 3 took 6741.45 miliseconds: train loss 1.8310, train bits/byte 0.3302, val loss 1.7893, val bits/byte 0.3227, 
Epoch 4 took 6882.44 miliseconds: train loss 1.7267, train bits/byte 0.3114, val loss 1.7121, val bits/byte 0.3088, 
Epoch 5 took 7045.88 miliseconds: train loss 1.6684, train bits/byte 0.3009, val loss 1.6824, val bits/byte 0.3034, 
Epoch 6 took 7204.28 miliseconds: train loss 1.6146, train bits/byte 0.2912, val loss 1.6510, val bits/byte 0.2977, 
Epoch 7 took 7383.98 miliseconds: train loss 1.5806, train bits/byte 0.2850, val loss 1.6355, val bits/byte 0.2949, 
Epoch 8 took 7533.01 miliseconds: train loss 1.5463, train bits/byte 0.2789, val loss 1.5964, val bits/byte 0.2879, 
Epoch 9 took 7572.81 miliseconds: train loss 1.5213, train bits/

The model achieved very good metrics, with training loss decreasing from 2.61 to 1.38 and validation loss improving from 2.06 to 1.50 over 20 epochs. However, the generated text shows a fundamental limitation of LSTM architectures: their milited ability to maintain long-term dependencies through hidden states. While the model successfully learned to generate descent English words and even some Shakespearean-style phrases (terms like "COMINIUS" and "KING RICHARD III"), it fails to maintain coherence across sentences. The output do not have logical flow and meaningful context, giving text that is syntactically valid but meaningless. This shows one of the key problems that Transformer architectures addressed through their self-attention mechanism, which can more effectively capture and maintain long-range dependencies in the text.