In [10]:
device="mps"


In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import tqdm.notebook as tqdm
import lightning.pytorch as pl
import pandas as pd
import random


In [12]:
with open("data/tiny-shakespeare.txt") as f:
    text = f.read()


In [13]:
# chunk the text into sequences of length seq_length
seq_length = 64
sequences = []
for i in range(0, len(text) - seq_length, seq_length):
    sequences.append(text[i:i + seq_length])


In [14]:
print(len(sequences), sequences[0])


17428 First Citizen:
Before we proceed any further, hear me speak.

Al


In [15]:
@torch.no_grad()
def encode(string: str) -> torch.Tensor:
    string = string.encode('utf-8')
    return torch.as_tensor([int(c) for c in string])

@torch.no_grad()
def decode(arr: torch.Tensor) -> str:
    arr = arr.tolist()
    return ''.join([chr(c) for c in arr])


In [16]:
print(encode('hello'))
print(decode(encode('hello')))


tensor([104, 101, 108, 108, 111])
hello


In [17]:
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers):
        super(Model, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        
        # LSTM layer
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        
        # Fully connected layer to predict each character
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        # Embedding
        x = self.embedding(x)
        
        # Initialize hidden state and cell state if not provided
        if hidden is None:
            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim, device=x.device)
            c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim, device=x.device)
            hidden = (h0, c0)
        
        # LSTM output along with new hidden state
        out, hidden = self.lstm(x, hidden)
        
        # Reshape output for the fully connected layer
        out = out.reshape(-1, self.hidden_dim)
        out = self.fc(out)
        return out, hidden

# Create an instance of the updated model
vocab_size = 256  # number of unique characters
embed_dim = 32   # embedding dimension
hidden_dim = 64  # LSTM hidden dimensions
num_layers = 4  # number of LSTM layers

model = Model(vocab_size, embed_dim, hidden_dim, num_layers).to(device)
print(f"Model uses {sum(p.numel() for p in model.parameters()):,} parameters")
print(f"Model uses {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters")


Model uses 149,760 parameters
Model uses 149,760 trainable parameters


In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0002)


In [28]:
import torch.nn.functional as F

def temperature_sampling(logits, temperature=0.7):
    # Scale logits by temperature
    scaled_logits = logits / temperature
    # Convert logits to probabilities
    probs = F.softmax(scaled_logits, dim=-1)
    # Sample from the probabilities
    return torch.multinomial(probs, num_samples=1)

# Validation Loop with Temperature Sampling
model.eval()
initial_input = encode('First Citizen:').unsqueeze(0).to(device)
generated_text = []
hidden = None  # Hidden state initialization

with torch.no_grad():
    for _ in range(512):  # Generate 512 characters
        output, hidden = model(initial_input, hidden)  # Ensure model accepts and returns hidden state
        predicted = temperature_sampling(output[-1], temperature=0.8)
        generated_text.append(predicted)
        initial_input = predicted.unsqueeze(0)

print("First Citizen:"+decode(torch.stack(generated_text).flatten()))
print(f"Text Length: {len(decode(torch.stack(generated_text).flatten()))}")


F,>èÛóÏêE-!¹È\>L2x÷>Ðýy­{_ÇÄá_^K µxÓlöLõuMn	àw	4#xrO5µÍÉ&lósÑw8­OÓ;:ic0b¡]dÒ7B¡êx"D5XàOÑ²ô]ÇòCd6u¨Ü8¯-ây<ÄÉÌtw5éãN·ºÅ-|»o©Ô )ûäÛåÍúK{Ü@åmmî¹ cb ¬Ç¢6Ä [m¿+§dÈ­X¦¿3«¡ÂâL D­X<Ç8ò
*>¨3¸f¬FÃ$Â-ÍµÂç¶¤UHBþgÃå22ªPNË6á¤5.èªâ2ÎTRcù°:¯-{]OQù
Text Length: 512


In [29]:
# Assume model, optimizer, criterion, and encode function are defined.
epochs = 10
for epoch in range(epochs):
    model.train()
    total_loss = 0
    with tqdm.tqdm(sequences, total=len(sequences)) as pbar:
        for seq in pbar:
            # Prepare input and target tensors
            seq = encode(seq) -1
            seq = seq.to(device)
            total_loss = torch.tensor(0.0, device=device)
            hidden = None
            
            for j in range(len(seq) - 1):
                inputs = seq[:j]
                targets = seq[j+1].unsqueeze(0)
                
                inputs = F.pad(inputs, (seq_length - len(inputs), 0))
                
                # Forward pass
                # print("inputs", inputs.shape)
                # print("targets", targets.shape)
                
                output, hidden = model(inputs.unsqueeze(0), hidden)
                # print("output", output.shape, output.requires_grad)
                
                # predicted_token = output[0, -1].unsqueeze(0)
                # print("predicted_token", predicted_token.shape, targets.shape)
                
                
                # print(predicted_token.requires_grad, targets.requires_grad)
                loss = criterion(output[-1].unsqueeze(0), targets.long())
                total_loss += loss
            
            optimizer.zero_grad()
            total_loss.backward()    
            optimizer.step()
            
            print(total_loss.item())
            

  0%|          | 0/17428 [00:00<?, ?it/s]

347.7358703613281
347.0220947265625
346.5335388183594
348.0096435546875
346.965087890625
347.3943176269531
346.5918273925781
346.4958801269531
347.988037109375
348.7137451171875
348.5155029296875
347.6702880859375
347.41656494140625
347.2978820800781
347.4470520019531
347.2378234863281
347.3324279785156
347.0179443359375
347.30950927734375
347.92431640625
347.6527404785156
347.5223083496094
347.9309387207031
346.97607421875
348.12542724609375
347.3868103027344
347.8465270996094
346.6805114746094
348.6714782714844
347.8541259765625
347.447998046875
346.7733459472656
348.15667724609375
347.1660461425781
347.85626220703125
347.93450927734375
347.3390197753906
347.4766845703125
347.64996337890625
347.5285949707031
347.6645202636719
346.5856018066406
347.42620849609375
348.13470458984375
347.455322265625
347.1993103027344
347.7934265136719
347.6787109375
348.0496520996094
347.74481201171875
347.9263610839844
347.290283203125
347.69720458984375
347.58428955078125
348.5519104003906
347.040771