In [None]:
!pip install torch torchvision torchaudio


In [None]:
import pip
# Install Lightning on demand when it is not already importable.
# pip.main() is an internal entry point that pip does not support for
# programmatic use, so pip is run as a subprocess of the interpreter that is
# executing this notebook instead.
try:
  __import__("lightning")
except ImportError:
  import importlib
  import subprocess
  import sys

  # check_call raises CalledProcessError if the installation fails, so a
  # broken install is reported here rather than at the later import.
  subprocess.check_call([sys.executable, "-m", "pip", "install", "lightning"])
  # Let the import system notice the package that was just installed.
  importlib.invalidate_caches()

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

import lightning as L

In [5]:
# Vocabulary in id order: a token's position in this list is its integer id.
vocabulary = [
    'i', 'live', 'in', 'delhi', 'nice', 'study', 'NSUT', 'cool', 'like',
    'to', 'play', 'football', 'awesome', 'want', 'visit', 'shimla',
    'beautiful', '<EOS>',
]
token_to_id = {token: index for index, token in enumerate(vocabulary)}
id_to_token = {index: token for token, index in token_to_id.items()}

# Each training sequence is a prompt, <EOS>, and the expected response; the
# two shorter prompts carry a closing <EOS> so every row has seven tokens.
training_sentences = [
    'i live in delhi <EOS> nice <EOS>',
    'i study in NSUT <EOS> cool <EOS>',
    'i like to play football <EOS> awesome',
    'i want to visit shimla <EOS> beautiful',
]
input_rows = [
    [token_to_id[token] for token in sentence.split()]
    for sentence in training_sentences
]
inputs = torch.tensor(input_rows)

# Next-token targets: each input row shifted left by one position, with
# <EOS> filling the final slot.
label_rows = [row[1:] + [token_to_id['<EOS>']] for row in input_rows]
labels = torch.tensor(label_rows)

dataset = TensorDataset(inputs, labels)
# No batch size or shuffling is requested, so the DataLoader defaults apply.
dataloader = DataLoader(dataset)



In [6]:
class PositionEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to word embeddings.

    A table with ``max_len`` rows and ``d_model`` columns is computed once at
    construction time. Even columns hold ``sin(position * frequency)`` and odd
    columns hold ``cos(position * frequency)``, where the frequency for each
    sin/cos column pair is ``1 / 10000 ** (2i / d_model)``. The table is
    registered as the buffer ``pe``, so it is not a trainable parameter.

    Args:
        d_model: Number of embedding dimensions (may be even or odd).
        max_len: Number of positions in the table; inputs with more tokens
            than this cannot be encoded.
    """

    def __init__(self, d_model=2, max_len=10):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = 1 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so only the first d_model // 2 frequencies are used here.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, word_embeddings):
        """Return ``word_embeddings`` with the position table added.

        Args:
            word_embeddings: Tensor whose last two dimensions are
                (tokens, d_model); any leading batch dimensions are
                broadcast over.

        Returns:
            A tensor of the same shape as ``word_embeddings``.
        """
        # The token axis is second from the end, which is axis 0 for an
        # unbatched (tokens, d_model) input and axis 1 for a batched one.
        seq_len = word_embeddings.size(-2)
        return word_embeddings + self.pe[:seq_len, :]


In [7]:
class Attention(nn.Module):
    """Single-head scaled dot-product attention with learned projections.

    Queries, keys and values are each produced by a bias-free linear layer of
    size ``d_model`` x ``d_model``.

    Args:
        d_model: Number of embedding dimensions per token.
    """

    def __init__(self, d_model=2):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        # Address the (token, feature) axes from the end so the same code
        # serves unbatched (tokens, d_model) input, where these are axes 0
        # and 1, and input with leading batch dimensions.
        self.row_dim = -2
        self.col_dim = -1

    def forward(self, q_input, k_input, v_input, mask=None):
        """Compute attention values for every query token.

        Args:
            q_input: Encodings used to build the queries; last two dimensions
                are (tokens, d_model).
            k_input: Encodings used to build the keys, same layout.
            v_input: Encodings used to build the values, same layout.
            mask: Optional boolean tensor broadcastable to the
                (query tokens, key tokens) similarity matrix; ``True`` marks
                key positions a query must not attend to.

        Returns:
            A tensor shaped like the projected queries, holding for each
            query the attention-weighted sum of the values.
        """
        q = self.W_q(q_input)
        k = self.W_k(k_input)
        v = self.W_v(v_input)
        sims = torch.matmul(q, k.transpose(self.row_dim, self.col_dim))
        # A plain Python scalar is enough for the sqrt(d_model) scaling.
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)
        if mask is not None:
            # Use the most negative value of the tensor's own dtype: a fixed
            # -1e9 is outside the range of float16.
            fill_value = torch.finfo(scaled_sims.dtype).min
            scaled_sims = scaled_sims.masked_fill(mask, fill_value)
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        attention_scores = torch.matmul(attention_percents, v)
        return attention_scores


In [8]:
class DecoderOnlyTransformer(L.LightningModule):
    """Minimal decoder-only transformer for next-token prediction.

    The pipeline is: token embedding -> position encoding -> masked
    self-attention -> residual connection -> linear layer producing one score
    per vocabulary token.

    Args:
        num_tokens: Vocabulary size; sets the embedding table size and the
            width of the output layer.
        d_model: Number of embedding dimensions per token.
        max_len: Number of positions available in the position encoding.
    """

    def __init__(self, num_tokens, d_model=2, max_len=10):
        super().__init__()
        # Seed before any layer is created so the initial weights are the
        # same on every construction.
        L.seed_everything(42)
        self.we = nn.Embedding(num_tokens, d_model)
        self.pe = PositionEncoding(d_model=d_model, max_len=max_len)
        self.self_attention = Attention(d_model=d_model)
        self.fc_layer = nn.Linear(d_model, num_tokens)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, token_ids):
        """Score every vocabulary token as the successor of each input token.

        Args:
            token_ids: 1-D tensor of token ids for a single sequence.

        Returns:
            Tensor of shape (len(token_ids), num_tokens) with unnormalised
            scores (logits).
        """
        word_embeddings = self.we(token_ids)
        position_encoded = self.pe(word_embeddings)
        seq_len = token_ids.size(0)
        # True above the diagonal: a token may not attend to later tokens.
        # The mask is built on the same device as the tokens it applies to.
        mask = torch.tril(torch.ones((seq_len, seq_len), device=token_ids.device)) == 0
        attention_output = self.self_attention(position_encoded, position_encoded, position_encoded, mask)
        residual = position_encoded + attention_output
        output = self.fc_layer(residual)
        return output

    def configure_optimizers(self):
        """Return the Adam optimizer (learning rate 0.1) over all parameters."""
        return Adam(self.parameters(), lr=0.1)

    def training_step(self, batch, batch_idx):
        """Return the mean cross-entropy loss over the sequences in ``batch``.

        ``forward`` handles one sequence at a time, so each sequence in the
        batch is scored separately and the per-sequence losses are averaged.
        With a batch of one this equals that sequence's loss; with larger
        batches no sequence is skipped.

        Args:
            batch: Pair ``(input_tokens, labels)``, each of shape
                (batch size, sequence length).
            batch_idx: Index of the batch within the epoch (unused).
        """
        input_tokens, labels = batch
        losses = [
            self.loss(self(tokens), targets)
            for tokens, targets in zip(input_tokens, labels)
        ]
        return torch.stack(losses).mean()


In [None]:
# Size the model's embedding table and output layer from the vocabulary.
model = DecoderOnlyTransformer(
    num_tokens=len(token_to_id),
    d_model=2,
    max_len=10,
)

# Train for 30 epochs with logging and checkpointing switched off.
trainer = L.Trainer(
    max_epochs=30,
    logger=False,
    enable_checkpointing=False,
)
trainer.fit(model, train_dataloaders=dataloader)


In [11]:
# You can try any of the 4 inputs here:
prompt_tokens = ["i", "live", "in", "delhi", "<EOS>"]
model_input = torch.tensor(
    [token_to_id[token] for token in prompt_tokens],
    device=model.device,
)

# The position-encoding table has one row per position the model can encode,
# so its row count is the longest sequence that may be fed to the model.
max_length = model.pe.pe.size(0)
eos_id = token_to_id["<EOS>"]

# Greedy decoding: repeatedly append the highest-scoring next token until the
# model emits <EOS> or the sequence reaches max_length. Gradients are not
# needed for inference, so autograd tracking is switched off.
with torch.no_grad():
    predictions = model(model_input)
    predicted_id = predictions[-1].argmax().unsqueeze(0)
    predicted_ids = predicted_id

    for _ in range(len(model_input), max_length):
        if predicted_id.item() == eos_id:
            break
        model_input = torch.cat((model_input, predicted_id))
        predictions = model(model_input)
        predicted_id = predictions[-1].argmax().unsqueeze(0)
        predicted_ids = torch.cat((predicted_ids, predicted_id))

print("Predicted Tokens:\n")
# The loop variable is not called `id`, which would shadow the builtin for
# every later cell in the notebook.
for token_id in predicted_ids:
    print("\t", id_to_token[token_id.item()])


Predicted Tokens:

	 cool
	 <EOS>
