In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as d_utils

class PosEncoding(nn.Module):
    """Fixed sinusoidal positional encoding that is added to token embeddings.

    Column ``2i`` of the table holds ``sin(pos / 10000**(2i / embedding_dim))``
    and column ``2i + 1`` holds the matching cosine, for positions
    ``0 .. max_len - 1``.

    Args:
        embedding_dim: Size of the last dimension of the embeddings.
        max_len: Number of positions the table is precomputed for.
    """

    def __init__(self, embedding_dim, max_len):
        super().__init__()

        pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)

        # One frequency per (sin, cos) column pair: exponents 0, 2, 4, ... / embedding_dim.
        even_index = torch.arange(0, embedding_dim, 2, dtype=torch.float32)
        angles = pos / (torch.tensor(10000.0) ** (even_index / embedding_dim))

        pe_map = torch.zeros((max_len, embedding_dim))
        pe_map[:, ::2] = torch.sin(angles)
        # For an odd embedding_dim there is one fewer cosine column than sine columns.
        pe_map[:, 1::2] = torch.cos(angles[:, :embedding_dim // 2])

        # Registered as a buffer so that Module.to()/cuda() moves it together with
        # the parameters; non-persistent so the state_dict keys stay unchanged.
        self.register_buffer("pe_map", pe_map, persistent=False)

    def forward(self, word_embedding, batch=False):
        """Return ``word_embedding`` with the positional table added.

        Args:
            word_embedding: Tensor whose last two dimensions are
                (sequence, embedding_dim); an optional leading batch
                dimension is broadcast over.
            batch: Kept for backward compatibility. The sequence length is
                read from the second-to-last dimension, which is the same
                axis for both the batched and the unbatched layout.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = word_embedding.shape[-2]
        if seq_len > self.pe_map.shape[0]:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe_map.shape[0]}"
            )
        return word_embedding + self.pe_map[:seq_len]

class Attention(nn.Module):
    """Single-head dot-product attention with an optional causal mask.

    Args:
        embedding_dim: Input and output size of the query/value/key projections.
        mask: When truthy, each query position may only attend to key
            positions at or before its own index.
        batch: Whether inputs carry a leading batch dimension. Stored in
            ``batch`` / ``row_dim`` / ``col_dim``.
    """

    def __init__(self, embedding_dim, mask=None, batch=False):
        super().__init__()
        self.mask = mask
        self.embedding_dim = embedding_dim
        self.q = nn.Linear(embedding_dim, embedding_dim)
        self.v = nn.Linear(embedding_dim, embedding_dim)
        self.k = nn.Linear(embedding_dim, embedding_dim)
        self.batch = batch
        self.row_dim = 0 + batch
        self.col_dim = 1 + batch

    def get_qvk(self, q, v, k):
        """Project the raw inputs; note the (query, value, key) argument order."""
        return self.q(q), self.v(v), self.k(k)

    def forward(self, q, v, k):
        """Attend from ``q`` over ``k`` and return the weighted sum of ``v``.

        Arguments are passed in (query, value, key) order. The last two
        dimensions of each input are (sequence, embedding_dim).
        """
        q, v, k = self.get_qvk(q, v, k)

        # Negative dims address the (sequence, embedding) axes for both the
        # batched and the unbatched layout.
        # NOTE: the scores are not divided by sqrt(embedding_dim).
        scores = torch.matmul(q, k.transpose(-2, -1))

        if self.mask:
            # Build the mask from the score matrix itself so that it lives on the
            # same device and also fits when query and key lengths differ.
            causal = torch.triu(
                torch.ones(scores.shape[-2:], device=scores.device), diagonal=1
            ).bool()
            scores = scores.masked_fill(causal, -1e9)

        weights = F.softmax(scores, dim=-1)

        return torch.matmul(weights, v)

class DecoderOnlyTransformer(nn.Module):
    """Minimal decoder-only transformer.

    Pipeline: token embedding -> positional encoding -> masked self-attention
    -> residual connection -> linear projection to vocabulary logits.

    Args:
        embedding_dim: Width of the token embeddings.
        max_len: Number of positions handed to the positional encoding.
        vocab_size: Number of token ids; also the size of the output logits.
        batch: Whether ``forward`` receives a leading batch dimension.
    """

    def __init__(self, embedding_dim, max_len, vocab_size, batch=False):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_encoding = PosEncoding(embedding_dim, max_len)
        self.attention = Attention(embedding_dim, True, batch)
        self.fc = nn.Linear(embedding_dim, vocab_size)
        self.batch = batch

    def forward(self, x):
        """Map token ids to next-token logits with a trailing vocab dimension."""
        embedded = self.embedding(x)
        # Pass the batch flag through so the positional table is sliced by
        # sequence length rather than by batch size.
        pos_encoded = self.pos_encoding(embedded, self.batch)
        attended = self.attention(pos_encoded, pos_encoded, pos_encoded)
        residual = attended + pos_encoded
        return self.fc(residual)

# Vocabulary: every known token mapped to its integer id.
token_to_id = {
    "what": 0,
    "is": 1,
    "ege": 2,
    "mal": 3,
    "<EOS>": 4,
}

# Inverse lookup used to turn predicted ids back into tokens.
id_to_token = {index: token for token, index in token_to_id.items()}

# Two training sequences, each a prompt, the <EOS> separator and the answer token.
inputs = torch.tensor([
    [token_to_id[token] for token in sequence]
    for sequence in (
        ("what", "is", "ege", "<EOS>", "mal"),
        ("mal", "is", "what", "<EOS>", "ege"),
    )
])

# Target token for every input position.
# NOTE(review): the first row is the first input shifted left by one token, but
# the second row has "mal" where the second input carries "ege" - confirm
# which of the two is intended.
labels = torch.tensor([
    [token_to_id[token] for token in sequence]
    for sequence in (
        ("is", "ege", "<EOS>", "mal", "<EOS>"),
        ("is", "what", "<EOS>", "mal", "<EOS>"),
    )
])

# DataLoader is left at its defaults.
dataset = d_utils.TensorDataset(inputs, labels)
dataloader = d_utils.DataLoader(dataset)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# embedding_dim=5, max_len=100, batched inputs.
t = DecoderOnlyTransformer(5, 100, len(token_to_id), True).to(device)

epochs = 1500
lr = 0.0004

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(t.parameters(), lr)

for epoch in range(1, epochs + 1):
    for data, label in dataloader:
        data, label = data.to(device), label.to(device)

        optimizer.zero_grad()
        logits = t(data)
        # CrossEntropyLoss reads dim 1 as the class dimension for inputs with
        # more than two dimensions. The model puts the vocabulary on the last
        # dimension, so flatten to (tokens, vocab) and (tokens,) instead of
        # letting the sequence axis be interpreted as the classes.
        loss = criterion(logits.reshape(-1, logits.shape[-1]), label.reshape(-1))

        loss.backward()
        optimizer.step()
    # Loss of the last batch of the epoch, as a plain float.
    print(loss.item())

In [None]:
def prompt(prompt, max_len=100):
    """Greedily generate a continuation for ``prompt`` with the trained model ``t``.

    The prompt is lower-cased, split on whitespace, terminated with ``<EOS>``
    and fed to the model; the highest-scoring next token is appended until the
    model emits ``<EOS>`` or the sequence has grown past ``max_len`` tokens.

    Args:
        prompt: Whitespace-separated text; every token must be in ``token_to_id``.
        max_len: Longest sequence the model is run on. The default matches the
            ``max_len`` the model is constructed with in this notebook.

    Returns:
        The generated tokens, each preceded by a single space
        (an empty string when nothing was generated).

    Raises:
        KeyError: If the prompt contains a token missing from ``token_to_id``.
    """
    token_ids = [token_to_id[token] for token in prompt.lower().split() + ["<EOS>"]]

    # Build the input on whatever device the model's parameters live on.
    model_device = next(t.parameters()).device

    generated = []
    with torch.no_grad():  # inference only, no autograd graph needed
        # Check the length before calling the model so it is never run on a
        # sequence longer than max_len.
        while len(token_ids) <= max_len:
            logits = t(torch.tensor([token_ids], device=model_device))
            next_id = logits[:, -1].argmax(dim=-1).item()
            next_token = id_to_token[next_id]
            if next_token == "<EOS>":
                break
            generated.append(next_token)
            token_ids.append(next_id)

    return "".join(" " + token for token in generated)