In [19]:
import torch 
import torch.nn as nn 
import math 
import torch.nn.functional as F 
device = "mps"

In [9]:
def calculate_masked_attention(
        values: torch.Tensor,
        keys: torch.Tensor,
        query: torch.Tensor,
        mask: torch.Tensor = None 
): 
    attention_scores = torch.matmul(query, keys.transpose(-2,-1)) 
    attention_scores = attention_scores / math.sqrt(keys.shape[-1])
    if mask is not None: 
        attention_scores = torch.where(mask == 0, torch.tensor(-1e9), attention_scores) 
    attention_scores = F.softmax(attention_scores, dim=-1) 
    attention = torch.matmul(attention_scores, values) 
    return attention, attention_scores

In [29]:
class FeedForward(nn.Module): 
    def __init__(self, embed_size: int):
        super().__init__()
        self.layer1 = nn.Linear(embed_size, embed_size)
        self.layer2 = nn.Linear(embed_size, embed_size) 
    def forward(self, x): 
        x = self.layer1(x) 
        x = F.gelu(x)
        x = self.layer2(x) 
        return x 
class AttentionLayer(nn.Module): 
    def __init__(self,embed_size: int):
        super().__init__()
        self.embed_size = embed_size 
        self.query_dense = nn.Linear(embed_size, embed_size) 
        self.key_dense = nn.Linear(embed_size, embed_size) 
        self.value_dense = nn.Linear(embed_size, embed_size)
        self.output_dense = nn.Linear(embed_size, embed_size) 
    def forward(self, embeddings: torch.Tensor): 
        batch_size = embeddings.shape[0] 
        seq_length = embeddings.shape[1] 
        query = self.query_dense(embeddings)
        key = self.key_dense(embeddings)
        value = self.value_dense(embeddings)
        right_triangular_mask = torch.tril(torch.ones((1, seq_length, seq_length))).to(embeddings.device)
        attention, attention_scores = calculate_masked_attention(value, key, query, right_triangular_mask) 
        return attention, attention_scores

In [39]:
class TransformerBlock(nn.Module): 
    def __init__(self, embed_size: int):
        super().__init__()
        self.attention_layer = AttentionLayer(embed_size) 
        self.feed_forward = FeedForward(embed_size) 
        self.layer_norm1 = nn.LayerNorm(embed_size) 
    def forward(self, x: torch.Tensor): 
        context, attention_scores = self.attention_layer(x)
        context = self.layer_norm1(context) 
        context = self.feed_forward(context) 
        context = F.gelu(context) 
        output = context + x 
        return output, attention_scores 
class Transformer(nn.Module):
    def __init__(self, embed_size: int, num_layers: int): 
        super().__init__() 
        self.transformers_blocks = nn.ModuleList([TransformerBlock(embed_size) for _ in range(num_layers)]) 
    def forward(self, x: torch.Tensor): 
        attention_scores = [] 
        for transformer_block in self.transformers_blocks: 
            x, attention_score = transformer_block(x) 
            attention_scores.append(attention_score) 
        return x, attention_scores
    
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, embed_size: int, max_seq_length: int):
        super().__init__() 
        position = torch.arange(max_seq_length).unsqueeze(1) 
        div_term = torch.exp(torch.arange(0, embed_size, 2) * (-math.log(1000.0) / embed_size))
        pe = torch.zeros(20, embed_size) 
        pe[:, 0::2] = torch.sin(position * div_term) 
        pe[:, 1:: 2] = torch.cos(position * div_term) 
        self.register_buffer('position_embedding', pe)
    def forward(self, x: torch.Tensor): 
        return x + self.position_embedding[:x.size(1), :] 

class CasualLanguageModel(nn.Module): 
    def __init__(self, embed_size: int, vocab_size: int, num_layers: int): 
        super().__init__() 
        self.embedding_layer = nn.Parameter(torch.randn(vocab_size, embed_size))
        self.transformer = Transformer(embed_size, num_layers)
        self.positional_encoding = SinusoidalPositionalEncoding(embed_size, max_seq_length=20) 

    def forward(self, x: torch.Tensor, return_attention_scores: bool = False): 
        x = torch.nn.functional.embedding(x, self.embedding_layer)
        x = self.positional_encoding(x) 
        x, attention_scores = self.transformer(x) 
        logits = torch.matmul(x, self.embedding_layer.T)
        if return_attention_scores: 
            return logits, attention_scores
        return logits 


In [40]:
dataset = [
    "dont forget to like and subscibe",
    "dont forget machine learning is fun",
    "machine learning is fun and awesome",
    "if you like machine learning i like you",
    "i like you more than machine learning"
]
vocab = set() 
special_tokens = ["<pad>", "<start>", "<end>"] 
for sentence in dataset: 
    vocab.update(sentence.split())
vocab = special_tokens + list(vocab) 
vocab_to_index = {word: index for index, word in enumerate(vocab)} 
vocab_size = len(vocab) 
print(vocab)
print("vocab size: ", vocab_to_index)

['<pad>', '<start>', '<end>', 'dont', 'fun', 'if', 'you', 'is', 'learning', 'machine', 'forget', 'subscibe', 'like', 'awesome', 'i', 'to', 'than', 'more', 'and']
vocab size:  {'<pad>': 0, '<start>': 1, '<end>': 2, 'dont': 3, 'fun': 4, 'if': 5, 'you': 6, 'is': 7, 'learning': 8, 'machine': 9, 'forget': 10, 'subscibe': 11, 'like': 12, 'awesome': 13, 'i': 14, 'to': 15, 'than': 16, 'more': 17, 'and': 18}


In [41]:
def encode(sentence: str): 
    return [vocab_to_index[word] for word in sentence.split()]
def encode_batch(sentences: list[str]): 
    encoded_sentences = [[vocab_to_index["<start>"]] + encode(sentence) + [vocab_to_index["<end>"]] for sentence in sentences] 
    max_length = max([len(encoded_sentence) for encoded_sentence in encoded_sentences])
    encoded_sentences = [encoded_sentence + [vocab_to_index["<pad>"]] * (max_length - len(encoded_sentence)) for encoded_sentence in encoded_sentences ]
    return encoded_sentences 
def decode(tokens: list[int]): 
    return " ".join([vocab[token] for token in tokens]) 
tokenized_dataset = encode_batch(dataset) 
tokenized_dataset = torch.tensor(tokenized_dataset) 
(tokenized_dataset)


tensor([[ 1,  3, 10, 15, 12, 18, 11,  2,  0,  0],
        [ 1,  3, 10,  9,  8,  7,  4,  2,  0,  0],
        [ 1,  9,  8,  7,  4, 18, 13,  2,  0,  0],
        [ 1,  5,  6, 12,  9,  8, 14, 12,  6,  2],
        [ 1, 14, 12,  6, 17, 16,  9,  8,  2,  0]])

In [42]:
decode(tokenized_dataset[0])

'<start> dont forget to like and subscibe <end> <pad> <pad>'

In [43]:
input_tokens = tokenized_dataset[:, :-1] 
target_tokens = tokenized_dataset[:, 1:] 
print(input_tokens) 
print(target_tokens)

tensor([[ 1,  3, 10, 15, 12, 18, 11,  2,  0],
        [ 1,  3, 10,  9,  8,  7,  4,  2,  0],
        [ 1,  9,  8,  7,  4, 18, 13,  2,  0],
        [ 1,  5,  6, 12,  9,  8, 14, 12,  6],
        [ 1, 14, 12,  6, 17, 16,  9,  8,  2]])
tensor([[ 3, 10, 15, 12, 18, 11,  2,  0,  0],
        [ 3, 10,  9,  8,  7,  4,  2,  0,  0],
        [ 9,  8,  7,  4, 18, 13,  2,  0,  0],
        [ 5,  6, 12,  9,  8, 14, 12,  6,  2],
        [14, 12,  6, 17, 16,  9,  8,  2,  0]])


In [44]:
print("Input: ", decode(input_tokens[0].tolist())) 
print("Target: ", decode(target_tokens[0].tolist()))

Input:  <start> dont forget to like and subscibe <end> <pad>
Target:  dont forget to like and subscibe <end> <pad> <pad>


In [45]:
vocab_size = len(vocab) 
embed_size = 6 
num_layers = 2
device = "cpu"
num_epochs = 600 
input_tokens = input_tokens.to(device) 
target_tokens = target_tokens.to(device) 
casual_language_model = CasualLanguageModel(embed_size=embed_size, vocab_size=vocab_size, num_layers=num_layers).to(device)
optimizer = torch.optim.Adam(casual_language_model.parameters(), lr=2e-3)


In [46]:
logits = casual_language_model(input_tokens)

In [47]:
for v, x in zip(vocab, logits[0][2].softmax(-1)): 
    print(v, x.item())

<pad> 0.15068207681179047
<start> 6.758410017937422e-05
<end> 0.002768024103716016
dont 0.01829790137708187
fun 0.014354881830513477
if 0.17056222259998322
you 8.814829925540835e-05
is 0.0007069381535984576
learning 0.00017167140322271734
machine 0.0028994965832680464
forget 0.31739717721939087
subscibe 0.00017451250459998846
like 0.0036833747290074825
awesome 0.2790459394454956
i 0.008447432890534401
to 0.00040081856423057616
than 0.013626216910779476
more 0.0047680982388556
and 0.011857414618134499


In [49]:
print(logits.shape) 
logits.view(-1, logits.shape[-1]).shape

torch.Size([5, 9, 19])


torch.Size([45, 19])

In [58]:
print(target_tokens.shape)
print(target_tokens.reshape(-1).shape)

torch.Size([5, 9])
torch.Size([45])


In [61]:
target_tokens.reshape(-1)

tensor([ 3, 10, 15, 12, 18, 11,  2,  0,  0,  3, 10,  9,  8,  7,  4,  2,  0,  0,
         9,  8,  7,  4, 18, 13,  2,  0,  0,  5,  6, 12,  9,  8, 14, 12,  6,  2,
        14, 12,  6, 17, 16,  9,  8,  2,  0])

In [63]:
loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), target_tokens.reshape(-1))
print(loss)

tensor(6.8332, grad_fn=<NllLossBackward0>)


In [66]:
for i in range (num_epochs): 
    logits = casual_language_model(input_tokens)
    loss = F.cross_entropy(
    logits.view(-1, logits.shape[-1]),  # shape: (batch_size * seq_len, vocab_size)
    target_tokens.reshape(-1)  )         # shape: (batch_size * seq_len,)) 
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if i % 10 == 0: 
        print(f"Epoch {i}, Loss: {loss.item()}") 
        pred = logits.argmax(dim=-1)

Epoch 0, Loss: 6.833237171173096
Epoch 10, Loss: 6.458953857421875
Epoch 20, Loss: 6.134280681610107
Epoch 30, Loss: 5.807788848876953
Epoch 40, Loss: 5.470464706420898
Epoch 50, Loss: 5.160411357879639
Epoch 60, Loss: 4.869407653808594
Epoch 70, Loss: 4.580451965332031
Epoch 80, Loss: 4.282816410064697
Epoch 90, Loss: 3.991530418395996
Epoch 100, Loss: 3.7186803817749023
Epoch 110, Loss: 3.461916208267212
Epoch 120, Loss: 3.2421720027923584
Epoch 130, Loss: 3.0540828704833984
Epoch 140, Loss: 2.889267921447754
Epoch 150, Loss: 2.7428181171417236
Epoch 160, Loss: 2.6044270992279053
Epoch 170, Loss: 2.470330238342285
Epoch 180, Loss: 2.3342368602752686
Epoch 190, Loss: 2.1938750743865967
Epoch 200, Loss: 2.056637763977051
Epoch 210, Loss: 1.9218400716781616
Epoch 220, Loss: 1.7908934354782104
Epoch 230, Loss: 1.659375548362732
Epoch 240, Loss: 1.5447407960891724
Epoch 250, Loss: 1.4212068319320679
Epoch 260, Loss: 1.3433170318603516
Epoch 270, Loss: 1.2422605752944946
Epoch 280, Loss: 1

In [67]:
logits = casual_language_model(input_tokens)

In [68]:
logits.argmax(dim=-1)

tensor([[ 3, 10,  9, 12, 18, 11,  2,  0,  0],
        [ 3, 10,  9,  8,  7,  4,  2,  0,  0],
        [ 3,  8,  7,  4, 18, 13,  2,  0,  0],
        [ 3,  6, 12,  9,  8, 14, 12,  6,  2],
        [ 3, 12,  6, 12, 16,  9,  8,  2,  0]])

In [69]:

# Assuming 'input_str', 'encode', 'decode', 'causal_language_model', and 'device' are defined elsewhere

input_str = "<start>"
# Encode the input string into tokens, convert to a PyTorch tensor,
# move it to the specified device (e.g., CPU or GPU), and add a batch dimension.
input_tokens = torch.tensor(encode(input_str)).to(device).unsqueeze(0)

# Pass the input tokens through the causal language model to get logits
# Logits are raw, unnormalized scores for each possible next token.
logits = casual_language_model(input_tokens)

# Get the probability distribution of the last token's prediction
# logits[0, -1:] selects the logits for the last token in the sequence.
# softmax(dim=-1) converts these logits into probabilities.
last_token_pred = logits[0, -1:].softmax(dim=-1)

# Get the index of the most probable next token (the predicted token)
# argmax(dim=-1) finds the index of the maximum value along the last dimension.
# keepdim=True maintains the dimension, so it remains a tensor of shape (1, 1).
last_token_logits = last_token_pred.argmax(dim=-1, keepdim=True)

# Concatenate the new predicted token to the original input tokens to form a new sequence
# dim=1 means concatenate along the sequence length dimension.
new_sequence = torch.cat([input_tokens, last_token_logits], dim=1)

# Decode the last predicted token index back into a human-readable string
# .tolist()[0] converts the tensor to a Python list and gets the first (and only) element.
last_predicted_token = decode(last_token_logits.tolist()[0])

# Print the input string
print("Input:", input_str)
# Print the new token and its confidence (maximum probability)
print(f"New token: {last_predicted_token} ({(last_token_pred.max().item() * 100):.2f}%)")
# Print the entire new sequence of tokens decoded back into a string
print("New sequence: ", decode(new_sequence[0].tolist()))


Input: <start>
New token: dont (36.28%)
New sequence:  <start> dont


In [81]:
def generation(prefix, max_length=18): 
    input_tokens = torch.tensor([vocab_to_index["<start>"]] + encode(prefix)).to(device).unsqueeze(0) 
    for _ in range(max_length): 
        with torch.no_grad(): 
            logits = casual_language_model(input_tokens)
            last_token_logits = logits[0, -1:].argmax(dim=-1, keepdim=True)
            print(decode([last_token_logits.tolist()[0][0]]))
            input_tokens = torch.cat((input_tokens, last_predicted_token))
        if input_tokens[0][-1] == vocab_to_index["<end>"]: 
            break 
    return decode(input_tokens[0].tolist()) 



In [76]:
dataset = [
    "dont forget to like and subscibe",
    "dont forget machine learning is fun",
    "machine learning is fun and awesome",
    "if you like machine learning i like you",
    "i like you more than machine learning"
]

In [77]:
def generation(prefix, max_length=18): 
    input_tokens = torch.tensor([vocab_to_index["<start>"]] + encode(prefix)).to(device).unsqueeze(0) 
    for _ in range(max_length): 
        with torch.no_grad(): 
            logits = casual_language_model(input_tokens)
            last_token_logits = logits[0, -1:].argmax(dim=-1, keepdim=True)  # shape (1, 1)
            print(decode([last_token_logits.tolist()[0][0]]))
            input_tokens = torch.cat((input_tokens, last_token_logits), dim=1)
        if input_tokens[0][-1].item() == vocab_to_index["<end>"]: 
            break 
    return decode(input_tokens[0].tolist()) 


In [80]:
generation("i like you more")

than
machine
learning
<end>


'<start> i like you more than machine learning <end>'