## A100 GPU

In [None]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn.utils import clip_grad_norm_

with open('./tiny_shake.txt', "r", encoding="utf-8") as my_file:
    text = my_file.read()

chars = sorted(list(set(text)))
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train = data[:n]
val = data[n:]

class Config:
    batch_size = 64
    block_size = 128
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    embedding_dim = 256
    hidden_dim = 512
    num_layers = 2
    dropout = 0.5
    learning_rate = 1e-3
    grad_clip = 1.0
    max_epochs = 20
    weight_init_range = 0.1
    num_batches_per_epoch = 100
    num_eval_batches = 10
    vocab_size = len(chars)

# Code to get the training and validation batches
def get_batch(split, block_size=Config.block_size, batch_size=Config.batch_size):
    """
    Get a batch of data for training or testing.
    """
    data = train if split == 'train' else val
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x_batch = []
    y_batch = []

    for i in ix:
        x_batch.append(data[i:i+block_size])
        y_batch.append(data[i+1:i+1+block_size])

    x = torch.stack(x_batch)
    y = torch.stack(y_batch)
    x, y = x.to(Config.device), y.to(Config.device)

    return x, y

# Generate text
def generate(model, start_text, max_length, temperature):
    model.eval()
    chars = torch.tensor(encode(start_text)).unsqueeze(0).to(Config.device)
    
    with torch.no_grad():
        for _ in range(max_length):
            logits = model(chars)
            next_char_logits = logits[0, -1, :] / temperature
            probs = F.softmax(next_char_logits, dim=0)
            next_char = torch.multinomial(probs, 1)
            chars = torch.cat([chars, next_char.unsqueeze(0)], dim=1)
    
    return decode(chars[0].tolist())


class ShakeLSTM(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, Config.embedding_dim)
        self.layer_norm = nn.LayerNorm(Config.embedding_dim)
        
        self.lstm = nn.LSTM(
            input_size=Config.embedding_dim,
            hidden_size=Config.hidden_dim,
            num_layers=Config.num_layers,
            dropout=Config.dropout,
            batch_first=True # (batch_size, seq_length, input_size)
        )
        
        self.dropout = nn.Dropout(Config.dropout)
        self.fc = nn.Linear(Config.hidden_dim, vocab_size)
        
        self.init_weights()

    def init_weights(self):
        # Initialize embeddings and linear layer
        self.embedding.weight.data.uniform_(-Config.weight_init_range, Config.weight_init_range)
        self.fc.weight.data.uniform_(-Config.weight_init_range, Config.weight_init_range)
        self.fc.bias.data.zero_()

        # Initialize LSTM weights and biases
        for name, param in self.lstm.named_parameters():
            if 'weight' in name:
                param.data.uniform_(-Config.weight_init_range, Config.weight_init_range)
            elif 'bias' in name:
                # Initialize all biases to zero
                param.data.zero_()
                # Set forget gate bias to 1
                if 'bias_ih' in name or 'bias_hh' in name:
                    n = param.size(0)
                    start, end = n//4, n//2
                    param.data[start:end].fill_(1.)

    def forward(self, x):
        embedded = self.embedding(x) # 64, 128 --> 64, 128, 256
        normalized = self.layer_norm(embedded) 
        lstm_out, _ = self.lstm(normalized) #64, 128, 512
        dropped = self.dropout(lstm_out)
        logits = self.fc(dropped) # 64, 128, 65
        return logits

def train_epoch(model, optimizer, criterion):
    model.train()
    total_loss = 0
    total_perplexity = 0
    
    for i in range(Config.num_batches_per_epoch):
        x, y = get_batch('train')
        
        optimizer.zero_grad()
        logits = model(x)
        
        B, T, C = logits.shape
        loss = criterion(logits.view(-1, C), y.view(-1))
        perplexity = torch.exp(loss)
        
        loss.backward()
        clip_grad_norm_(model.parameters(), max_norm=Config.grad_clip)
        optimizer.step()
        
        total_loss += loss.item()
        total_perplexity += perplexity.item()
        
    avg_loss = total_loss / Config.num_batches_per_epoch
    # avg_perplexity = total_perplexity / Config.num_batches_per_epoch
    bits_per_char = avg_loss / math.log(2)  # Convert nats to bits
    bits_per_byte = bits_per_char / 8  # Convert bits per char to bits per byte
    
    return avg_loss, bits_per_byte

@torch.no_grad()
def evaluate(model, criterion):
    model.eval()
    total_loss = 0
    total_perplexity = 0
    
    for i in range(Config.num_eval_batches):
        x, y = get_batch('val')
        logits = model(x)
        B, T, C = logits.shape
        loss = criterion(logits.view(-1, C), y.view(-1))
        perplexity = torch.exp(loss)
        
        total_loss += loss.item()
        total_perplexity += perplexity.item()
        
    avg_loss = total_loss / Config.num_eval_batches
    # avg_perplexity = total_perplexity / Config.num_eval_batches
    bits_per_char = avg_loss / math.log(2)  # Convert nats to bits
    bits_per_byte = bits_per_char / 8  # Convert bits per char to bits per byte
    
    return avg_loss, bits_per_byte

In [18]:
#torch.set_float32_matmul_precision("high")
model = ShakeLSTM(vocab_size=Config.vocab_size).to(Config.device)
model = torch.compile(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=Config.learning_rate)
criterion = nn.CrossEntropyLoss()

best_val_loss = float('inf')
prompt = "To be or not to be that is the question"

for epoch in range(Config.max_epochs):
    t0 = time.time()
    train_loss, train_bpb = train_epoch(model, optimizer, criterion)
    val_loss, val_bpb = evaluate(model, criterion)  
    t1 = time.time()
    dt = t1 - t0

    tokens_processed = Config.batch_size * Config.block_size * Config.num_batches_per_epoch
    tokens_per_sec = tokens_processed / dt
    
    print(f'Epoch {epoch+1} took {dt*1000:.2f} ms: '
          f'train loss {train_loss:.4f}, '
          f'train bits/byte {train_bpb:.4f}, '
          f'val loss {val_loss:.4f}, '
          f'val bits/byte {val_bpb:.4f}, '
          f'tokens/sec: {tokens_per_sec:.2f}')
    
generated_text = generate(model, prompt, max_length=500, temperature=0.5)
print(f"\nGenerated Text:\n{generated_text}\n")
print("-"*50)

print("Training finished!")
print(f'Training time excluding generation: {(t1 - t0) * 1000} miliseconds')

Epoch 1 took 1885.98 ms: train loss 2.6069, train bits/byte 0.4701, val loss 2.0776, val bits/byte 0.3747, tokens/sec: 434364.16
Epoch 2 took 1274.57 ms: train loss 2.0076, train bits/byte 0.3620, val loss 1.8829, val bits/byte 0.3395, tokens/sec: 642724.89
Epoch 3 took 1139.59 ms: train loss 1.8373, train bits/byte 0.3313, val loss 1.7750, val bits/byte 0.3201, tokens/sec: 718856.97
Epoch 4 took 1143.90 ms: train loss 1.7366, train bits/byte 0.3132, val loss 1.7239, val bits/byte 0.3109, tokens/sec: 716147.76
Epoch 5 took 1223.54 ms: train loss 1.6714, train bits/byte 0.3014, val loss 1.6660, val bits/byte 0.3004, tokens/sec: 669533.67
Epoch 6 took 1557.08 ms: train loss 1.6248, train bits/byte 0.2930, val loss 1.6312, val bits/byte 0.2942, tokens/sec: 526111.76
Epoch 7 took 1154.43 ms: train loss 1.5852, train bits/byte 0.2859, val loss 1.6264, val bits/byte 0.2933, tokens/sec: 709612.27
Epoch 8 took 1118.24 ms: train loss 1.5539, train bits/byte 0.2802, val loss 1.6035, val bits/byt

The first time this notebook was ran, we got an average training time was around 1500 miliseconds, but by just using torch.compile we were able to reduce the training time by 20%. 

* By just using a A100 GPU, the training was around 50% faster than the optimized code for a T4, we cannot compare the performance. 
* We have another optimizations availiable for A100 but because our network is small the computational overhead would cause more harm than benefits.