In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
from collections import Counter

In [15]:
#Hyperparameters (setting the embedding_dim to be a multiple of d_key is ideal)
SEED = 1337
torch.manual_seed(SEED)  # seeds torch RNGs (weight init, data shuffling, batch sampling, generation) for repeatable runs

embedding_dim = 384
max_token = 128          # context length: tokens fed to the model per training row
d_key = 64               # per-head query/key/value width
n_decoder_layers = 6
n_head = embedding_dim // d_key
mini_batch_size = 64
num_epochs = 4000        # number of optimizer steps (one mini-batch each), not full passes over the data
decay = 0.01             # AdamW weight decay
learning_rate = 3e-4
PATIENCE = 10            # non-improving dev evaluations tolerated before early stopping may trigger
drop_prob = 0.5
beta1 = 0.9
beta2 = 0.999
label_smoothing = 0.05
desired_vocab = 1000 #Final vocab after tokenization, set to 65 if you wish for no tokenization.

In [17]:
#BPE based tokenizer
def _merge_pair(seq, a, b, merged):
    """Return `seq` with each non-overlapping adjacent (a, b) replaced by `merged`, scanning left to right."""
    out = []
    i = 0
    n = len(seq)
    while i < n:
        if i + 1 < n and seq[i] == a and seq[i + 1] == b:
            out.append(merged)
            i += 2
        else:
            out.append(seq[i])
            i += 1
    return out


def tokenize(data, initial_vocab, desired_vocab=1000, min_freq=2):
    """Greedy byte-pair encoding over a sequence of token ids.

    Repeatedly merges the most frequent adjacent token pair into a new token
    until the vocabulary has `desired_vocab` entries, no pairs remain, or the
    best pair occurs fewer than `min_freq` times.

    Args:
        data: sequence of ids indexing into `initial_vocab`.
        initial_vocab: token strings; position i is the string for id i.
            It is copied, never mutated.
        desired_vocab: target vocabulary size.
        min_freq: minimum pair frequency required to keep merging.

    Returns:
        (ids, vocab): the re-encoded sequence and the grown vocabulary list,
        where vocab[i] is the string for id i (initial tokens keep their ids).

    Note: every merge recounts all pairs, so each merge costs O(len(data)).
    """
    vocab = list(initial_vocab)
    known = set(vocab)
    seq = [vocab[i] for i in data]  # work with strings

    while len(vocab) < desired_vocab:
        pairs = Counter(zip(seq, seq[1:]))
        if not pairs:
            break

        (a, b), freq = pairs.most_common(1)[0]
        if freq < min_freq:
            # Stop if even the best merge is too rare
            break

        new_token = a + b
        seq = _merge_pair(seq, a, b, new_token)
        # Different pairs can concatenate to an existing string (e.g. "a"+"bc"
        # vs "ab"+"c"). This used to stop tokenization early; now the merge is
        # still applied and reuses the existing vocabulary entry. Every merge
        # shortens `seq`, so the loop still terminates.
        if new_token not in known:
            known.add(new_token)
            vocab.append(new_token)
            if len(vocab) % 100 == 0:
                print(f"Current vocab size: {len(vocab)}")

    # Rebuild final mapping (for duplicate strings the last id wins, as before)
    stoi = {tok: i for i, tok in enumerate(vocab)}
    return [stoi[tok] for tok in seq], vocab
#Train/dev split along the first axis; rows are expected to be shuffled already (prepare_data shuffles before calling this).
def train_dev_split(x,train,dev):
    """Split `x` by rows into consecutive train and dev slices.

    The train slice holds the first int(N * train) rows; the dev slice holds the
    next int(N * dev) rows. Fractions are truncated to whole rows, so a few
    trailing rows may be left unused.
    """
    total = x.shape[0]
    split_at = int(total * train)
    stop_at = split_at + int(total * dev)
    return x[:split_at], x[split_at:stop_at]

#Final data preparation function
def prepare_data(path, max_token=8, desired_vocab=None):
    """Read a text file, BPE-tokenize it, and cut it into shuffled train/dev blocks.

    Args:
        path: text file to read (decoded as UTF-8).
        max_token: context length. Each block holds max_token + 1 ids, giving
            max_token (input, next-token target) pairs.
        desired_vocab: target vocabulary size passed to `tokenize`. If None,
            the notebook-level `desired_vocab` hyperparameter is used, or 1000
            (tokenize's own default) if that global is not defined.

    Returns:
        dict with
          "data":     LongTensor (n_blocks, max_token + 1) of shuffled, non-overlapping blocks,
          "data_tr":  first 80% of those blocks,
          "data_dev": following 20% of those blocks,
          "vocab":    list of token strings (index == token id),
          "encode":   maps an iterable of vocabulary strings to ids. It does not
                      re-run BPE, so raw text is encoded character by character.
          "decode":   maps a sequence of ids to the concatenated text.
    """
    if desired_vocab is None:
        # Previously the hyperparameter was silently ignored (tokenize always used 1000).
        desired_vocab = globals().get("desired_vocab", 1000)

    with open(path, encoding="utf-8") as f:
        text = f.read()

    # Character-level vocabulary: the starting point for BPE merges.
    char_vocab = sorted(set(text))
    char_stoi = {ch: i for i, ch in enumerate(char_vocab)}

    print("Tokenization start..")
    ids, vocab = tokenize([char_stoi[ch] for ch in text], char_vocab, desired_vocab)
    tokens = torch.tensor(ids)
    print("Tokenization end.")

    # Encoder / decoder over the post-tokenization vocabulary.
    itos = dict(enumerate(vocab))
    stoi = {v: k for k, v in itos.items()}
    encode = lambda inp: [stoi[i] for i in inp]
    decode = lambda inp: "".join([itos[i] for i in inp])

    # Cut the stream into NON-overlapping blocks of max_token + 1 ids. The old
    # code added arange(n_blocks) (stride 1) to the window offsets. Its blocks
    # overlapped almost entirely, covered only the first ~n_blocks + max_token
    # ids, and leaked nearly identical blocks between train and dev.
    block_len = max_token + 1
    n_blocks = tokens.shape[0] // block_len
    blocks = tokens[: n_blocks * block_len].view(n_blocks, block_len)
    data = blocks[torch.randperm(n_blocks)]

    data_tr, data_dev = train_dev_split(data, 0.8, 0.2)
    return {"data": data, "data_tr": data_tr, "data_dev": data_dev, "vocab": vocab, "encode": encode, "decode": decode}

In [18]:
#Simple function to create random batch from data
def create_batch(data, mini_batch_size = mini_batch_size):
    """Return `mini_batch_size` distinct rows of `data` chosen uniformly at random.

    If `data` has fewer rows than requested, all rows are returned in random
    order. Only the sampled indices are gathered. The previous version built a
    fully shuffled copy of `data` on every call just to keep its first rows.
    For the same RNG state, the result is identical to that version.
    """
    idx = torch.randperm(data.shape[0])[:mini_batch_size]
    return data[idx.to(data.device)]

In [None]:
# Load and tokenize the corpus, then place the train/dev blocks on the compute device.
device = "cuda" if torch.cuda.is_available() else "cpu"
full_data = prepare_data("input.txt", max_token)

data = full_data["data"].tolist()  # plain-Python copy of every shuffled block
data_tr = full_data["data_tr"].to(device)
data_dev = full_data["data_dev"].to(device)
vocab = full_data["vocab"]
encode = full_data["encode"]
decode = full_data["decode"]
vocab_size = len(vocab)

In [9]:
class Head(nn.Module):
    """A single causal self-attention head.

    Projects the input to queries, keys and values of width `d_key`, then applies
    scaled dot-product attention. A lower-triangular mask stops each position
    from attending to later positions. Returns the attention-weighted values.
    Reads the notebook-level `embedding_dim`, `d_key` and `drop_prob` at construction.
    """
    def __init__(self):
        super().__init__()
        self.keyM = nn.Linear(embedding_dim, d_key)
        self.queryM = nn.Linear(embedding_dim, d_key)
        self.valueM = nn.Linear(embedding_dim, d_key)
        self.dropout = nn.Dropout(p=drop_prob)
        # 1/sqrt(d_k), computed once instead of building a tensor on every forward.
        self.scale = 1.0 / math.sqrt(d_key)

    def forward(self, x):
        """x: (B, T, embedding_dim) -> (B, T, d_key)."""
        T = x.shape[1]
        Q = self.queryM(x)
        K = self.keyM(x)
        V = self.valueM(x)
        scores = (Q @ K.transpose(-2, -1)) * self.scale
        # Build the causal mask on the input's device, not the global `device`,
        # so the module still works after being moved elsewhere.
        mask = torch.ones(T, T, dtype=torch.bool, device=x.device).tril()
        scores = scores.masked_fill(~mask, float('-inf'))
        attn = self.dropout(torch.softmax(scores, dim=-1))
        return attn @ V
class MultiHeadedAttention(nn.Module):
    """Runs `n_head` causal attention heads in parallel on the same input.

    Their outputs are concatenated along the last axis and projected back to
    `embedding_dim`. Sizes are read from the notebook-level hyperparameters at
    construction.
    """
    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleList([Head() for _ in range(n_head)])
        self.linear_layer = nn.Linear(n_head*d_key, embedding_dim)

    def forward(self, x):
        # Iterate over the heads this module owns, not the global `n_head`. That
        # global may change after construction (e.g. when the hyperparameter cell
        # is re-run), which would index out of range or silently drop heads.
        return self.linear_layer(torch.cat([head(x) for head in self.heads], dim=-1))
class Transformer(nn.Module):
    """Decoder-only language model.

    Architecture: token + learned position embeddings, then `n_decoder_layers`
    pre-LayerNorm blocks, a final LayerNorm, and a bias-free projection to
    `vocab_size` logits. Each block has causal multi-head attention and a
    feed-forward MLP, each in a dropout-wrapped residual branch.

    Sizes come from the notebook-level hyperparameters at construction time.
    Each block owns its LayerNorms (the previous version shared one pair across
    all blocks). Checkpoints saved with the old `layer_norm1`/`layer_norm2`
    layout will therefore not load into this module.
    """
    def __init__(self):
        super().__init__()
        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)
        self.pos_table = nn.Embedding(max_token, embedding_dim)
        self.dropout = nn.Dropout(p=drop_prob)
        self.decoder_layers = nn.ModuleList([MultiHeadedAttention() for _ in range(n_decoder_layers)])
        self.feed_forward = nn.ModuleList([
            nn.Sequential(
                nn.Linear(embedding_dim, 2 * embedding_dim),
                nn.ReLU(),
                nn.Dropout(p=drop_prob),
                nn.Linear(2 * embedding_dim, embedding_dim),
            )
            for _ in range(n_decoder_layers)
        ])
        self.attn_norms = nn.ModuleList([nn.LayerNorm(embedding_dim) for _ in range(n_decoder_layers)])
        self.ff_norms = nn.ModuleList([nn.LayerNorm(embedding_dim) for _ in range(n_decoder_layers)])
        self.final_norm = nn.LayerNorm(embedding_dim)
        self.final_linear = nn.Linear(embedding_dim, vocab_size, bias=False)

    def forward(self, x):
        """Map token ids x of shape (B, T), T <= max_token, to logits of shape (B, T, vocab_size)."""
        positions = torch.arange(x.shape[1], device=x.device)
        out = self.dropout(self.embedding_table(x) + self.pos_table(positions))
        blocks = zip(self.attn_norms, self.decoder_layers, self.ff_norms, self.feed_forward)
        for attn_norm, attn, ff_norm, ff in blocks:
            out = out + self.dropout(attn(attn_norm(out)))
            out = out + self.dropout(ff(ff_norm(out)))
        # Normalise the residual stream before the output projection, as is
        # standard for pre-LN stacks. This replaces the extra dropout that was
        # previously applied here.
        return self.final_linear(self.final_norm(out))

    def generate(self, max_new_tokens=500, temperature=0.7, top_k=100):
        """Sample text from this model, printing each decoded token as it is produced.

        Starts from a single uniformly random token id. At most the last
        `max_token` ids (the positional-table size) are kept as context. Each
        step samples the next id from the `top_k` highest logits after dividing
        them by `temperature`. The global `decode` turns ids into text.

        Args:
            max_new_tokens: number of tokens to generate; None loops forever
                (the previous behaviour).
            temperature: logit divisor; lower values sample more greedily.
            top_k: candidates kept per step, capped at the vocabulary size
                (top-k with k > vocab size used to raise).

        Returns:
            The generated text (the random start token is not included).
        """
        n_vocab = self.final_linear.out_features
        context_len = self.pos_table.num_embeddings
        k = min(top_k, n_vocab)
        model_device = self.embedding_table.weight.device

        was_training = self.training
        self.eval()  # no dropout while sampling
        pieces = []
        # `high` is exclusive: the old `vocab_size - 1` could never start on the last id.
        context = [torch.randint(0, n_vocab, (1,)).item()]
        try:
            with torch.inference_mode():
                while max_new_tokens is None or len(pieces) < max_new_tokens:
                    context = context[-context_len:]
                    # Call self, not the global `DecoderOnlyTransformer` the old code used.
                    logits = self(torch.tensor([context], device=model_device))[0, -1] / temperature
                    topk_logits, topk_indices = torch.topk(logits, k)
                    probs = torch.softmax(topk_logits, dim=-1)
                    pred = topk_indices[torch.multinomial(probs, num_samples=1)].item()
                    piece = decode([pred])
                    print(piece, end="")
                    pieces.append(piece)
                    context.append(pred)
        finally:
            self.train(was_training)
        return "".join(pieces)

In [10]:
#Defining a function for warmup and cosine decay for more stable training near the end of the training phase.
warmup_steps = 2000
def lambda_lr(current_step):
    """LR multiplier for LambdaLR: linear warmup, then cosine decay to 0.

    The multiplier rises linearly from 0 to 1 over `warmup_steps`, then follows
    a cosine curve down to 0 at `num_epochs` (total optimizer steps). Progress
    is clamped at 1, so the multiplier stays at 0 past `num_epochs` (e.g. if the
    training cell is re-run with the same scheduler). Unclamped, it would
    climb back up the cosine.
    """
    if current_step <= warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    progress = float(current_step - warmup_steps) / float(max(1, num_epochs - warmup_steps))
    progress = min(progress, 1.0)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

In [11]:
# Build the model on `device`, then AdamW plus the warmup + cosine LR schedule.
DecoderOnlyTransformer = Transformer().to(device)
optimizer = torch.optim.AdamW(
    DecoderOnlyTransformer.parameters(),
    lr=learning_rate,
    betas=(beta1, beta2),
    eps=1e-8,
    weight_decay=decay,
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_lr)

In [None]:
#TRAINING BLOCK:
EVAL_INTERVAL = 1  # run a dev evaluation every N steps; PATIENCE counts evaluations, not steps


def lm_loss(model, batch):
    """Label-smoothed next-token cross-entropy of `model` on a block of token ids.

    Each row of `batch` holds max_token + 1 ids. The first max_token are the
    inputs; the same positions shifted by one are the targets.
    """
    inputs, targets = batch[:, :max_token], batch[:, 1:max_token + 1]
    logits = model(inputs)
    return F.cross_entropy(
        logits.reshape(-1, logits.shape[-1]),
        targets.reshape(-1),
        label_smoothing=label_smoothing,
    )


best = None
patience = 0
for i in range(num_epochs):
    DecoderOnlyTransformer.train()
    loss = lm_loss(DecoderOnlyTransformer, create_batch(data_tr))
    loss.backward()

    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

    if i % 100 == 0:
        print(f'Training loss: {loss.item()}')

    if i % EVAL_INTERVAL != 0:
        continue

    DecoderOnlyTransformer.eval()
    with torch.inference_mode():
        dev_loss = lm_loss(DecoderOnlyTransformer, create_batch(data_dev, 500)).item()
    if i % 100 == 0:
        print(f'Dev loss: {dev_loss}')

    if best is None or dev_loss < best:
        # Checkpoint on the first evaluation too, so best.pt always exists and matches `best`.
        best = dev_loss
        torch.save(DecoderOnlyTransformer.state_dict(), "best.pt")
        patience = 0
    else:
        patience += 1
        # Stop only after more than PATIENCE non-improving evaluations AND once
        # the dev loss has drifted at least 0.3 above the best seen.
        if patience > PATIENCE and dev_loss - best >= 0.3:
            print("Early stopping.")
            break

# Leave the model in inference mode (dropout off) for whatever runs next.
DecoderOnlyTransformer.eval()

In [None]:
# Stream sampled text from the model; tokens are printed as they are produced.
DecoderOnlyTransformer.generate();