In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import requests

In [2]:
#download and preprocess shakespeare text
URL='https://gist.githubusercontent.com/CarineRam/817c25781a9ca8dc3370a190e31ab5e5/raw/ff02e4b6b3715295143846aaa896cc89f408cd67/gistfile1.txt'
response = requests.get(URL)
text = response.text

In [3]:
#create character to integer mapping and vice versa
chars= sorted(set(text))
vocab_size= len(chars)
char_to_int ={ch: i for i, ch in enumerate(chars)} #abc => {'a':0, 'b':1, 'c':2}
int_to_char = {i: ch for i, ch in enumerate(chars)} #abc => {0:'a', 1:'b', 2:'c'}

In [4]:
#encode and decode functions
def encode(s):
    return [char_to_int[c] for c in s if c in char_to_int]


In [5]:
def decode(indices):
    return ''.join([int_to_char[i] for i in indices])

In [6]:
#create batches of data
def get_batch(split, batch_size=32, block_size=128):

    n=len(text)
    split_idx= int(0.5* n)
    if split== 'train':
        data= text [:split_idx]
    else:
        data= text [split_idx:]
    ix= torch.randint(len(data) - block_size-1, (batch_size,))
    x_batch = torch.stack([torch.tensor(encode(data[i:i+block_size]), dtype=torch.long) for i in ix])

    y_batch = torch.stack([torch.tensor(encode(data[i+1:i+block_size+1]), dtype=torch.long) for i in ix])

    return x_batch, y_batch

In [7]:
#define the model architecture
vocab_size = 50257
n_emb= 32
block_size= 128
head_size= 8
num_heads = 4
num_blocks = 2

In [8]:
#single attention head
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_emb, head_size)
        self.query = nn.Linear(n_emb, head_size)
        self.value = nn.Linear(n_emb, head_size)
        self.dropout = nn.Dropout(0.1)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        weights = (q @ k.transpose(-2, -1)) / (C ** 0.5)
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = F.softmax(weights, dim=-1)
        weights = self.dropout(weights)
        return weights @ v


In [9]:
#multiple attention heads
'''
class MultipleAttention(nn.Module):
    def __init__ (self, head_size, num_heads):
        super().__init__ ()
        self.heads = nn.ModuleList([Head(head_size) for _ in range (head_heads)])
        self.proj= nn.Linear(n_emb, n_emb)
        self.dropout= nn.Dropout(0.1)
        
    def forward(self, x):
        out=torch.cat([h(v) for h in self.heads], dim=1)
        return self.dropout(self.proj(out))
'''

    
class MultipleAttention(nn.Module):
    def __init__(self, head_size, num_heads):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])  # Fix here
        self.proj = nn.Linear(n_emb, n_emb)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # Use `x` instead of `v`
        return self.dropout(self.proj(out))


In [10]:
#wise feedforward network
   
class FeedForward(nn.Module):
    def __init__(self, n_emb):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_emb, 4 * n_emb),
            nn.ReLU(),  # Fixed typo
            nn.Linear(4 * n_emb, n_emb),
            nn.Dropout(0.1)
        )

    def forward(self, x):
        return self.net(x)

In [11]:
class Block(nn.Module): #continue from before
    def __init__ (self, n_emb, num_heads):
        super().__init__ ()
        head_size = n_emb// num_heads
        self.sa= MultipleAttention(head_size, num_heads)
        self.ff= FeedForward(n_emb)
        self.ln1= nn.LayerNorm(n_emb)
        self.ln2= nn.LayerNorm(n_emb)
        
    def forward(self,x):
        x = x +self.sa(self.ln1(x))
        x = x +self.ff(self.ln2(x))
        return x

In [12]:
class TextGenerator(nn.Module): #mine
    def __init__ (self):
        super().__init__ ()
        self.token_emb= nn.Embedding(vocab_size, n_emb)
        self.pos_emb= nn.Embedding(block_size, n_emb)
        self.blocks= nn.Sequential(*[Block(n_emb, num_heads) for _ in range(num_blocks)])
        self.ln_f= nn.LayerNorm(n_emb)
        self.head =nn.Linear(n_emb, vocab_size)
        
       
        
    def forward(self, x, y=None):
        B, T = x.shape
        tok_emb=self.token_emb(x)
        pos_emb= self.pos_emb(torch.arange(T, device= x.device))
        x= tok_emb+ pos_emb
        x = self.blocks(x)
        x= self.ln_f(x)
        logits = self.head(x)
        
        if y is None:
            loss= None
        else:
            loss = F.cross_entropy(logits.view( B*T,vocab_size), y.view(B*T))
        return logits, loss

In [13]:
class TextGenerator(nn.Module):
    def __init__(self, vocab_size, n_emb, block_size, num_heads, num_blocks):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, n_emb)
        self.pos_emb = nn.Embedding(block_size, n_emb)
        self.blocks = nn.Sequential(*[Block(n_emb, num_heads) for _ in range(num_blocks)])
        self.ln_f = nn.LayerNorm(n_emb)
        self.head = nn.Linear(n_emb, vocab_size)
        
    def forward(self, x, y=None):
        B, T = x.shape
        tok_emb = self.token_emb(x)
        pos_emb = self.pos_emb(torch.arange(T, device=x.device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)
        
        if y is None:
            loss = None
        else:
            loss = F.cross_entropy(logits.view(B * T, vocab_size), y.view(B * T))
        return logits, loss


In [None]:
# Training setup
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Define model parameters
vocab_size = 50257
n_emb = 256
block_size = 128
num_heads = 8
num_blocks = 6

# Instantiate the model
model = TextGenerator(vocab_size, n_emb, block_size, num_heads, num_blocks).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Training loop
def train():
    for step in range(1000):
        x, y = get_batch('train')  # Assuming `get_batch` is defined elsewhere
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits, loss = model(x, y)
        loss.backward()
        optimizer.step()
        if step % 100 == 0:
            print(f'Step {step}, Loss: {loss.item()}')

train()


Step 0, Loss: 10.938119888305664


In [None]:
#generate text
model.eval()
start_seq= 'the cat is cute'
x = torch.tensor(encode(start_seq), dtype= torch.long).unsqueeze(0).to(device) #(1, T)
generated = model.generate(x, max_new_tokens=50)
print(decode(generated[0].tolist()))
