In [1]:
!pip install transformers datasets

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.w

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device = {device}")

Device = cuda


In [4]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, attn_dim, num_head):
        super(MultiHeadAttention, self).__init__()
        assert attn_dim %  num_head == 0, "attn_dim should be divisible by num_head"
        self.attn_dim = attn_dim
        self.num_head = num_head
        self.qkv_proj = nn.Linear(embed_dim, 3*attn_dim)
        self.out_proj = nn.Linear(attn_dim, embed_dim)

    def forward(self, X, mask=None):
        bs, seq_len, embed_dim = X.shape
        head_dim = self.attn_dim//self.num_head
        QKV = self.qkv_proj(X).view(bs, seq_len, self.num_head, 3*head_dim).transpose(1,2)
        Q, K, V = torch.chunk(QKV, chunks=3, dim=-1) # bs, #head, seq_len, head_dim
        attn_score = torch.matmul(Q, K.transpose(-2,-1))/(head_dim**0.5) # bs, #head, seq_len, seq_len
        if mask is not None:
            masking = torch.tril(torch.ones(seq_len, seq_len).unsqueeze(0).unsqueeze(0)).to(X.device) # 1, 1, seq_len, seq_len
            attn_score = attn_score.masked_fill(masking==0, float("-inf")) # bs, #head, seq_len, seq_len
        attn_score = F.softmax(attn_score, dim=-1)
        output = torch.matmul(attn_score, V) # bs, #head, seq_len, head_dim
        concated_output = output.transpose(1,2).contiguous().view(bs, seq_len, self.attn_dim)
        return self.out_proj(concated_output)

class PositionalEmbedding(nn.Module):
    def __init__(self, embed_dim, max_len=512):
        super(PositionalEmbedding, self).__init__()
        self.pos_embed = nn.Embedding(max_len, embed_dim)

    def forward(self, X):
        bs, seq_len, embed_dim = X.shape
        positions = torch.arange(0, seq_len, device=X.device).unsqueeze(0) # 1, seq_len
        pos_embed = self.pos_embed(positions) # 1, seq_len, embed_dim
        return X + pos_embed

class FeedForwardLayer(nn.Module):
    def __init__(self, embed_dim):
        super(FeedForwardLayer, self).__init__()
        self.up_projection = nn.Linear(embed_dim, 3*embed_dim)
        self.relu = nn.ReLU()
        self.down_projection = nn.Linear(3*embed_dim, embed_dim)

    def forward(self, X):
        return self.down_projection(self.relu(self.up_projection(X)))

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, attn_dim, num_head):
        super(TransformerBlock, self).__init__()
        self.attn_norm = nn.LayerNorm(embed_dim)
        self.multi_head_attn = MultiHeadAttention(embed_dim, attn_dim, num_head)
        self.ffn = FeedForwardLayer(embed_dim)
        self.ffn_norm = nn.LayerNorm(embed_dim)

    def forward(self, X, mask=None):
        residual = X
        X = self.attn_norm(X)
        X = residual + self.multi_head_attn(X, mask)

        residual = X
        X = self.ffn_norm(X)
        X = residual + self.ffn(X)
        return X

class Decoder(nn.Module): # masked-multi-head and ffn_layer
    def __init__(self, vocab_size, embed_dim, attn_dim, num_head, num_blocks):
        super(Decoder, self).__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim) # vocab_size, embed_dim
        self.positional_embedding = PositionalEmbedding(embed_dim)
        self.transformer_blocks = nn.ModuleList([TransformerBlock(embed_dim, attn_dim, num_head) for _ in range(num_blocks)])
        self.lm_head = nn.Linear(embed_dim, vocab_size)

    def forward(self, input_ids, mask=None):
        X = self.token_embed(input_ids) # bs, seq_len, embed_dim
        X = self.positional_embedding(X)
        for blocks in self.transformer_blocks:
            X = blocks(X, mask=mask) # bs, seq_len, embed_dim
        logits = self.lm_head(X) # bs, seq_len, vocab_size
        return logits

In [5]:
class PrepareDataset(Dataset):
    def __init__(self, input_ids, seq_len):
        super(PrepareDataset, self).__init__()
        self.samples = []
        for ids in input_ids:
            if len(ids)<2:
                continue
            for i in range(1, len(ids)):
                input_seq = ids[:i]
                target_seq = ids[1:i+1]
                # if PAD is needed
                if len(input_seq) < seq_len:
                    pad_len = seq_len - len(input_seq)
                    input_seq = torch.cat([input_seq, torch.zeros(pad_len, dtype=torch.long)], dim=-1)
                    target_seq = torch.cat([target_seq, torch.zeros(pad_len, dtype=torch.long)], dim=-1)
                else:
                    input_seq = input_seq[-seq_len:] # trim extra sequence
                    target_seq = target_seq[-seq_len:] # trim extra sequence
                self.samples.append((input_seq, target_seq))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

In [None]:
# sentences = [
#     "hello world",
#     "hello there",
#     "how are you",
#     "good morning",
#     "good night",
#     "see you soon",
#     "have a nice day",
#     "what are you doing",
#     "where are you going",
#     "thank you very much",
#     "with great power comes great responsibility",
# ]

# class PrepareVocabulary:
#     def __init__(self, sample_sentence):
#         self.sentence = sample_sentence

#     def get_ids(self):
#         all_words = " ".join(self.sentence)
#         words_list = all_words.split(' ')
#         unique_words = list(set(words_list))
#         unique_words.sort()
#         token_to_id = {unique_words[i]:i for i in range(len(unique_words))}
#         id_to_token = {i:unique_words[i] for i in range(len(unique_words))}
#         vocab_size = len(unique_words)
#         return vocab_size, token_to_id, id_to_token

In [6]:
def train_decoder(model, train_data):
    print("#"*10, " Model Training ", "#"*10)
    model.train()
    epochs = 5
    lr = 1e-4
    optimiser = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(1,epochs+1):
        batch_loss = 0
        for input_seq, target_seq in train_data:
            input_seq, target_seq = input_seq.to(device), target_seq.to(device)
            pred_seq = model(input_seq, mask=True) # bs, seq_len, vocab_size
            pred_seq = pred_seq.reshape(-1, pred_seq.size(-1)) # bs * seq_len, vocab_size
            target_seq = target_seq.view(-1) # bs * seq_len
            loss = criterion(pred_seq, target_seq)
            batch_loss += loss.item()
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()
        print(f"For epoch : {epoch}/{epochs}, training error: {batch_loss/len(train_data)}")
    print("#"*30)

In [7]:
def test_decoder(model, input_ids, id_to_token=None, tokenizer=None):
    print("#"*10, " Model Evaluation ", "#"*10)
    model.eval()
    logits = model(input_ids, mask=True) # bs, seq_len, vocab_size
    next_token_logits = logits[:, -1, :] # 1, vocab_size
    next_token_prob = F.softmax(next_token_logits, dim=-1)
    # predicted_id = torch.argmax(next_token_prob, dim=-1).item()
    topk_probs, topk_ids = torch.topk(next_token_prob, k=3)
    print(f"Input tokens : {input}")
    for i in range(3):
        if id_to_token is not None:
            print(f"Next Prediction : {id_to_token[topk_ids[0,i].item()]} and its Prob: {topk_probs[0,i].item()}")
        else:
            print(f"Next Prediction : {tokenizer.decode([topk_ids[0,i].item()])} and its Prob: {topk_probs[0,i].item()}")
    print("#"*30)

In [20]:
# vocabulary = PrepareVocabulary(sample_sentences)
# vocab_size, token_to_id, id_to_token = vocabulary.get_ids()
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
# text = " ".join(input)
# encoded_input = tokenizer(text, return_tensors="pt")
# input_ids = encoded_input.input_ids
# print("input ids :", input_ids)
vocab_size = tokenizer.vocab_size
dataset = load_dataset("ag_news", split="train[:50]")  # Try a smaller subset first
sentences = [x['text'] for x in dataset]
print(sentences[0])

Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.


In [11]:
embed_dim = 512
attn_dim = 512
num_head = 16
decoder_layers = 10

model = Decoder(vocab_size, embed_dim, attn_dim, num_head, decoder_layers).to(device)

In [12]:
print(model)

Decoder(
  (token_embed): Embedding(50257, 512)
  (positional_embedding): PositionalEmbedding(
    (pos_embed): Embedding(512, 512)
  )
  (transformer_blocks): ModuleList(
    (0-9): 10 x TransformerBlock(
      (attn_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (multi_head_attn): MultiHeadAttention(
        (qkv_proj): Linear(in_features=512, out_features=1536, bias=True)
        (out_proj): Linear(in_features=512, out_features=512, bias=True)
      )
      (ffn): FeedForwardLayer(
        (up_projection): Linear(in_features=512, out_features=1536, bias=True)
        (relu): ReLU()
        (down_projection): Linear(in_features=1536, out_features=512, bias=True)
      )
      (ffn_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
  )
  (lm_head): Linear(in_features=512, out_features=50257, bias=True)
)


In [21]:
input_ids = []
for sentence in sentences:
    # splitted = sentence.split()
    # input_ids.append(torch.tensor([token_to_id[token] for token in splitted], dtype=torch.long)) # len(sample_sentences), seq_len
    input_id = tokenizer(sentence, return_tensors="pt").input_ids.squeeze(0)
    input_ids.append(input_id)
batch_size = 4
seq_len = 32
train_dataset = PrepareDataset(input_ids, seq_len)
train_data = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
print("Train data size: ",len(train_data))

Train data size:  572


In [22]:
train_decoder(model, train_data)

##########  Model Training  ##########
For epoch : 1/5, training error: 0.2628570656233933
For epoch : 2/5, training error: 0.07091369062290287
For epoch : 3/5, training error: 0.05843537144048585
For epoch : 4/5, training error: 0.051323041747653064
For epoch : 5/5, training error: 0.0511371527156714
##############################


In [23]:
input = "Russian Alien"
# input_ids = torch.tensor([token_to_id[word] for word in input.split()], dtype=torch.long).unsqueeze(0) # 1, seq_len
input_ids = tokenizer(input, return_tensors="pt").input_ids
test_decoder(model, input_ids.to(device), None, tokenizer)

##########  Model Evaluation  ##########
Input tokens : Russian Alien
Next Prediction : , and its Prob: 0.6040710806846619
Next Prediction :  ( and its Prob: 0.14017179608345032
Next Prediction :  everything and its Prob: 0.05243949964642525
##############################


In [25]:
input = "What's in a Name?"
input_ids = tokenizer(input, return_tensors="pt").input_ids
test_decoder(model, input_ids.to(device), None, tokenizer)

##########  Model Evaluation  ##########
Input tokens : What's in a Name?
Next Prediction :  More and its Prob: 0.21408550441265106
Next Prediction :  \ and its Prob: 0.18890926241874695
Next Prediction :  The and its Prob: 0.1826314926147461
##############################
