In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import tqdm.notebook as tqdm
import lightning.pytorch as pl
import pandas as pd
import random


In [2]:
with open("data/tiny-shakespeare.txt") as f:
    text = f.read()


In [3]:
# chunk the text into sequences of length seq_length
seq_length = 256
sequences = []
for i in range(0, len(text) - seq_length, seq_length):
    sequences.append(text[i:i + seq_length])


In [4]:
print(len(sequences), sequences[0])


4357 First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:



In [5]:
@torch.no_grad()
def encode(string: str) -> torch.Tensor:
    string = string.encode('utf-8')
    return torch.as_tensor([int(c) for c in string])

@torch.no_grad()
def decode(arr: torch.Tensor) -> str:
    arr = arr.tolist()
    return ''.join([chr(c) for c in arr])


In [6]:
print(encode('hello'))
print(decode(encode('hello')))


tensor([104, 101, 108, 108, 111])
hello


In [7]:
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers):
        super(Model, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        
        # LSTM layer
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        
        # Fully connected layer to predict each character
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        # Embedding
        x = self.embedding(x)
        
        # Initialize hidden state and cell state if not provided
        if hidden is None:
            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim, device=x.device)
            c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim, device=x.device)
            hidden = (h0, c0)
        
        # LSTM output along with new hidden state
        out, hidden = self.lstm(x, hidden)
        
        # Reshape output for the fully connected layer
        out = out.reshape(-1, self.hidden_dim)
        out = self.fc(out)
        return out, hidden

# Create an instance of the updated model
vocab_size = 512  # number of unique characters
embed_dim = 128   # embedding dimension
hidden_dim = 256  # LSTM hidden dimensions
num_layers = 4  # number of LSTM layers

model = Model(vocab_size, embed_dim, hidden_dim, num_layers)
print(f"Model uses {sum(p.numel() for p in model.parameters()):,} parameters")
print(f"Model uses {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters")


Model uses 2,171,392 parameters
Model uses 2,171,392 trainable parameters


In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0002)


In [None]:
import torch.nn.functional as F

def temperature_sampling(logits, temperature=0.7):
    # Scale logits by temperature
    scaled_logits = logits / temperature
    # Convert logits to probabilities
    probs = F.softmax(scaled_logits, dim=-1)
    # Sample from the probabilities
    return torch.multinomial(probs, num_samples=1)

# Validation Loop with Temperature Sampling
model.eval()
initial_input = encode('First Citizen:').unsqueeze(0)
generated_text = []
hidden = None  # Hidden state initialization

with torch.no_grad():
    for _ in range(512):  # Generate 512 characters
        output, hidden = model(initial_input, hidden)  # Ensure model accepts and returns hidden state
        predicted = temperature_sampling(output[-1], temperature=0.8)
        generated_text.append(predicted)
        initial_input = predicted.unsqueeze(0)

print("First Citizen:"+decode(torch.stack(generated_text).flatten()))
print(f"Text Length: {len(decode(torch.stack(generated_text).flatten()))}")


ĀġƘiŮ°ŬÎŽĲŽÕ|ƿo2Ɠĳ\Ň»fǲcǃƭĴÅƇŵ±ǱÅǇoĤŧǢąošǇ/ůǶÆf.ƓČƅĮŉėđ£^Ŷœęƭ ǮĿï!=ŻľÎ*ǥ£)ıǅƸǴşőơVǰ:ť­ĄƥǙƸƴňƵ0{ÐƔ¯RǷƓ¦AĎYB.sǜ§ƂǾƩadġǉďOŐ²ãÝZİĊǪ¬ƨnǂDpǓƾŹíƴpƮāǵNƓćƓgðõƫǤĿŉOƦŏ[âršSõ­yļƳā>ĈðNǀÛƄìǬǢKėtoĕ(ŚĎÕŀŬƆđ½kĐĭ®^ƹƨźųƓũǯÍşǄàǨǫ	ĳfƩǬį #ƜžƌǨÕÚúNŮŖåëŻìţě¨Ï¢"#ÖTøðċƁǭƖľkǖM÷ƌ4ŕÃşè"ŹŅŘÈ©Ǎď	ǳāŨƪÉ'īƨí¤ń	ƽ×ƽ5ƔǢǭ,ŖĘƠǟÈŹĉǕŇċŕĎĐ"Ǡ÷g¬ĹĮŮŵƔŬǌǅōƜ ęá}ǳƌÚśfļŏ¡ă¥ĒǋĺmšƧƤ´uâkŃĈĲĤƐ	LĆơæ½ŰØǘŃĻĞǄhƊüŤǑƝƑǈ¡ňĝ®×ìŨǷdëØ/ǹŝPÉĐǘý¶Xŀżç
Text Length: 512


In [10]:
# Assume model, optimizer, criterion, and encode function are defined.
epochs = 20
for epoch in range(epochs):
    model.train()
    total_loss = 0
    with tqdm.tqdm(enumerate(sequences), total=len(sequences)) as pbar:
        for i, seq in pbar:
            # Prepare input and target tensors
            seq = encode(seq) -1
            total_loss = torch.tensor(0.)
            hidden = None
            
            for j in tqdm.tqdm(range(len(seq) - 1), leave=True, desc="Sequence"):
                inputs = seq[:j]
                targets = seq[j+1].unsqueeze(0)
                
                inputs = F.pad(inputs, (seq_length - len(inputs), 0))
                
                # Forward pass
                # print("inputs", inputs.shape)
                # print("targets", targets.shape)
                
                output, hidden = model(inputs.unsqueeze(0), hidden)
                # print("output", output.shape, output.requires_grad)
                
                # predicted_token = output[0, -1].unsqueeze(0)
                # print("predicted_token", predicted_token.shape, targets.shape)
                
                
                # print(predicted_token.requires_grad, targets.requires_grad)
                loss = criterion(output[-1].unsqueeze(0), targets.long())
                total_loss += loss
            
            optimizer.zero_grad()
            total_loss.backward()    
            optimizer.step()
            
            pbar.set_postfix_str(f"Loss: {loss.item()}, Total loss: {total_loss.item()}")
                
    break

  0%|          | 0/4357 [00:00<?, ?it/s]

Sequence:   0%|          | 0/255 [00:00<?, ?it/s]

Sequence:   0%|          | 0/255 [00:00<?, ?it/s]

Sequence:   0%|          | 0/255 [00:00<?, ?it/s]

Sequence:   0%|          | 0/255 [00:00<?, ?it/s]

: 