In [63]:
import random
import pandas as pd

import math
import torch 
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm


In [2]:
torch.manual_seed(42)


<torch._C.Generator at 0x131cc6e50>

Starting with a corpus with 3 tokens: a, b and c.

In [3]:
"""Accepts a string or list of tokens.
    Returns id2token and token2id dictionaries.
"""
def vocab_dicts(vocab):
    vocab = sorted(set(vocab))
    id2token = ["<UNK>", *vocab]
    token2id = {token: i for i, token in enumerate(id2token)}
    return id2token, token2id


In [4]:
small_vocab = "abc"

# DEFINE Training Data
small_data = [
    "aaa",
    "bbb",
    "ccc"
]

id2token, token2id = vocab_dicts(small_vocab)
           # i.e. {"<UNK>": 0, "a": 1, "b": 2, "c": 3}

In [5]:
MAX_SEQ_LEN = 7

In [6]:
# More data
def letter_sequences(num_examples, alphabet):
    examples = []
    for _ in range(num_examples):
        length = random.randint(3, MAX_SEQ_LEN)
        start = random.randint(0, len(alphabet) - length)
        example = alphabet[start:start+length]
        examples.append(example)
    return pd.Series(examples)


In [7]:
alphabet = "abcdefghijklmnopqrstuvwxyz"
data = pd.DataFrame()
data["sequence"] = letter_sequences(10000, alphabet)
data.head()


Unnamed: 0,sequence
0,jklmn
1,mnop
2,stuvw
3,tuvwx
4,lmno


In [8]:
id2token, token2id = vocab_dicts(alphabet)

In [9]:
print("id for token 'a':", token2id['a'])
print("token for id 1:", id2token[1])


id for token 'a': 1
token for id 1: a


In [10]:
data["actual"] = data["sequence"].apply(lambda x: [token2id[token] for token in x])
data['mask'] = data['actual'].apply(lambda x: [1 if random.random() > 0.15 else 0 for _ in x])
data.head()


Unnamed: 0,sequence,actual,mask
0,jklmn,"[10, 11, 12, 13, 14]","[1, 1, 1, 1, 1]"
1,mnop,"[13, 14, 15, 16]","[1, 1, 1, 0]"
2,stuvw,"[19, 20, 21, 22, 23]","[1, 1, 1, 1, 1]"
3,tuvwx,"[20, 21, 22, 23, 24]","[1, 1, 1, 0, 1]"
4,lmno,"[12, 13, 14, 15]","[1, 1, 1, 1]"


In [11]:
def apply_mask(row):
    # mask is a list of 0s and 1s
    return [a * m for a, m in zip(row['actual'], row['mask'])]


In [12]:
data['ex'] = data.apply(apply_mask, axis=1)
data.head()



Unnamed: 0,sequence,actual,mask,ex
0,jklmn,"[10, 11, 12, 13, 14]","[1, 1, 1, 1, 1]","[10, 11, 12, 13, 14]"
1,mnop,"[13, 14, 15, 16]","[1, 1, 1, 0]","[13, 14, 15, 0]"
2,stuvw,"[19, 20, 21, 22, 23]","[1, 1, 1, 1, 1]","[19, 20, 21, 22, 23]"
3,tuvwx,"[20, 21, 22, 23, 24]","[1, 1, 1, 0, 1]","[20, 21, 22, 0, 24]"
4,lmno,"[12, 13, 14, 15]","[1, 1, 1, 1]","[12, 13, 14, 15]"


In [13]:
data.head()

Unnamed: 0,sequence,actual,mask,ex
0,jklmn,"[10, 11, 12, 13, 14]","[1, 1, 1, 1, 1]","[10, 11, 12, 13, 14]"
1,mnop,"[13, 14, 15, 16]","[1, 1, 1, 0]","[13, 14, 15, 0]"
2,stuvw,"[19, 20, 21, 22, 23]","[1, 1, 1, 1, 1]","[19, 20, 21, 22, 23]"
3,tuvwx,"[20, 21, 22, 23, 24]","[1, 1, 1, 0, 1]","[20, 21, 22, 0, 24]"
4,lmno,"[12, 13, 14, 15]","[1, 1, 1, 1]","[12, 13, 14, 15]"


In [14]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.actual = data['actual']
        self.ex = data['ex']

    
    def __len__(self):
        return len(self.actual)
    
    def __getitem__(self, idx):
        return self.actual.iloc[idx], self.ex.iloc[idx]



In [15]:
dataset = Dataset(data)



In [16]:
def collate(batch, max_len=MAX_SEQ_LEN):

    batch_size = len(batch)
    max_len = MAX_SEQ_LEN

    # Pad sequences to the same length
    actual_padded = torch.zeros(batch_size, max_len, dtype=torch.long)
    ex_padded = torch.zeros(batch_size, max_len, dtype=torch.long)
    padding_masks = torch.zeros(batch_size, max_len, dtype=torch.bool)

    for i, (seq_actual, seq_ex) in enumerate(batch):
        length = len(seq_actual)

        seq_actual = torch.tensor(seq_actual)
        seq_ex = torch.tensor(seq_ex)

        actual_padded[i, :len(seq_actual)] = seq_actual
        ex_padded[i, :len(seq_ex)] = seq_ex
        padding_masks[i, :length] = True

    return actual_padded, ex_padded, padding_masks
    


In [17]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False, collate_fn=collate)

In [18]:
# get next item
a, x, p = next(iter(dataloader))
print(p)



tensor([[ True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True]])


## Prepare the model

In [20]:
# class Attention(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.linear = nn.Linear(EMB_DIM, EMB_DIM)
    
#     def forward(self, x):
#         x = self.emb(x)
#         # Linear layer?
#         atn = x @ x.T
#         atn = F.softmax(atn, -1)
#         enc = atn @ x
#         return x



In [52]:
# HYPERPARAMETERS
VOCAB_SIZE = len(token2id) # 27
EMB_DIM = 8
MAX_SEQ_LEN = 7
LEARNING_RATE = 0.01

In [72]:
class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size=VOCAB_SIZE, emb_dim=EMB_DIM):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        # LINEAR LAYER to use inside attention mechanism!?!?!
        self.out = nn.Linear(emb_dim, vocab_size)
    
    def forward(self, x, padding_mask=None):

        x = self.emb(x)

        # Here we'd do positional encoding :P
        scores = x @ x.transpose(-2, -1)
        # print("\nAttention scores before mask:", 
        #       "\nmin:", scores.min().item(),
        #       "\nmax:", scores.max().item())

        if padding_mask is not None:
            attention_mask = padding_mask.unsqueeze(1) & padding_mask.unsqueeze(2)
            # print("\nAttention mask sample:")
            # print(attention_mask[0, :3, :3])  # Show first 3x3 of first batch
            
            scores = scores.masked_fill(~attention_mask, -1e9)
            scores = scores / math.sqrt(x.size(-1))
            # print("\nAttention scores after mask:", 
            #       "\nmin:", scores.min().item(),
            #       "\nmax:", scores.max().item())
        
        
        atn = F.softmax(scores, -1)
        # print("\nAttention weights after softmax:",
        #       "\nmin:", atn.min().item(),
        #       "\nmax:", atn.max().item())
        
        enc = atn @ x
        logits = self.out(enc)
        # print("\nLogits:",
        #       "\nmin:", logits.min().item(),
        #       "\nmax:", logits.max().item(),
        #       "\nShape:", logits.shape)
        
        # probs = F.softmax(logits, dim=-1)
        return logits



In [73]:
model = SimpleTransformer()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

In [77]:
# Trainin loop
# Add wandb logging
import wandb
wandb.init(project="mlx5.4-transformers", name="simple-transformer-alphabet")
model.train()
# Debug prints

for epoch in tqdm(range(1500)):
    for i, (actual, ex, mask) in enumerate(dataloader):
        logits = model(ex, mask) 
        logits = logits.view(-1, VOCAB_SIZE)
        actual = actual.view(-1)

        loss = F.cross_entropy(logits, actual)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        wandb.log({"loss": loss})
    # Validate every N batches
    if epoch % 10 == 0:
        model.eval()
        with torch.no_grad():
            # Get predictions for first sequence in batch
            pred = torch.argmax(logits.view(ex.shape[0], -1, VOCAB_SIZE)[0], dim=-1)
            act = actual.view(ex.shape[0], -1)[0]
            
            pred_tokens = ''.join([id2token[i.item()] for i in pred])
            actual_tokens = ''.join([id2token[i.item()] for i in act])
            
            print(f"\nEpoch {epoch}, Batch {i}")
            print(f"Predicted: {pred_tokens}")
            print(f"Actual:    {actual_tokens}")
        model.train()
wandb.finish()


  0%|          | 1/1500 [00:00<14:02,  1.78it/s]


Epoch 0, Batch 624
Predicted: <UNK><UNK><UNK><UNK>t<UNK><UNK>
Actual:    pqrst<UNK><UNK>


  1%|          | 11/1500 [00:05<11:35,  2.14it/s]


Epoch 10, Batch 624
Predicted: <UNK>qrst<UNK><UNK>
Actual:    pqrst<UNK><UNK>


  1%|▏         | 21/1500 [00:09<10:24,  2.37it/s]


Epoch 20, Batch 624
Predicted: sqrst<UNK><UNK>
Actual:    pqrst<UNK><UNK>


  2%|▏         | 31/1500 [00:14<10:51,  2.25it/s]


Epoch 30, Batch 624
Predicted: qqrst<UNK><UNK>
Actual:    pqrst<UNK><UNK>


  3%|▎         | 41/1500 [00:18<10:58,  2.22it/s]


Epoch 40, Batch 624
Predicted: qqrst<UNK><UNK>
Actual:    pqrst<UNK><UNK>


  3%|▎         | 51/1500 [00:23<10:41,  2.26it/s]


Epoch 50, Batch 624
Predicted: qqrst<UNK><UNK>
Actual:    pqrst<UNK><UNK>


  4%|▍         | 61/1500 [00:27<10:12,  2.35it/s]


Epoch 60, Batch 624
Predicted: qqrst<UNK><UNK>
Actual:    pqrst<UNK><UNK>


  5%|▍         | 71/1500 [00:32<10:09,  2.35it/s]


Epoch 70, Batch 624
Predicted: qqrst<UNK><UNK>
Actual:    pqrst<UNK><UNK>


  5%|▌         | 81/1500 [00:36<09:18,  2.54it/s]


Epoch 80, Batch 624
Predicted: qqrst<UNK><UNK>
Actual:    pqrst<UNK><UNK>


  6%|▌         | 91/1500 [00:41<11:00,  2.13it/s]


Epoch 90, Batch 624
Predicted: qqrst<UNK><UNK>
Actual:    pqrst<UNK><UNK>


  7%|▋         | 101/1500 [00:45<10:02,  2.32it/s]


Epoch 100, Batch 624
Predicted: qqrst<UNK><UNK>
Actual:    pqrst<UNK><UNK>


  7%|▋         | 111/1500 [00:50<10:16,  2.25it/s]


Epoch 110, Batch 624
Predicted: qqrst<UNK><UNK>
Actual:    pqrst<UNK><UNK>


  8%|▊         | 121/1500 [00:55<12:40,  1.81it/s]


Epoch 120, Batch 624
Predicted: qqrst<UNK><UNK>
Actual:    pqrst<UNK><UNK>


  9%|▊         | 131/1500 [01:00<11:01,  2.07it/s]


Epoch 130, Batch 624
Predicted: qqrst<UNK><UNK>
Actual:    pqrst<UNK><UNK>


  9%|▉         | 141/1500 [01:05<11:35,  1.95it/s]


Epoch 140, Batch 624
Predicted: qqrst<UNK><UNK>
Actual:    pqrst<UNK><UNK>


 10%|█         | 151/1500 [01:10<10:02,  2.24it/s]


Epoch 150, Batch 624
Predicted: qqrst<UNK><UNK>
Actual:    pqrst<UNK><UNK>


 11%|█         | 161/1500 [01:15<09:56,  2.24it/s]


Epoch 160, Batch 624
Predicted: qqrst<UNK><UNK>
Actual:    pqrst<UNK><UNK>


 11%|█▏        | 171/1500 [01:20<10:48,  2.05it/s]


Epoch 170, Batch 624
Predicted: qqrst<UNK><UNK>
Actual:    pqrst<UNK><UNK>


 11%|█▏        | 172/1500 [01:20<10:24,  2.13it/s]


KeyboardInterrupt: 

In [78]:
a, x, m = next(iter(dataloader))


In [80]:
def validate_predictions(model, dataloader, id2token, n_examples=15):
    model.eval()
    with torch.no_grad():
        # Get one batch
        actual, example, mask = next(iter(dataloader))
        
        # Get predictions
        logits = model(example, mask)  # [batch, seq, vocab]
        
        # For each example in batch (up to n_examples)
        for i in range(min(n_examples, len(actual))):
            # Get single sequence
            seq_probs = logits[i]  # # apply softmax???
            seq_actual = actual[i]  # [seq]
            
            # Get predicted tokens
            predicted_indices = torch.argmax(seq_probs, dim=-1)  # [seq]
            
            # Convert to letters
            pred_tokens = [id2token[idx.item()] for idx in predicted_indices]
            actual_tokens = [id2token[idx.item()] for idx in seq_actual]
            
            print(f"\nExample {i+1}:")
            print(f"Predicted: {''.join(pred_tokens)}")
            print(f"Actual:    {''.join(actual_tokens)}")

# Run validation
validate_predictions(model, dataloader, id2token)


Example 1:
Predicted: jklmn<UNK><UNK>
Actual:    jklmn<UNK><UNK>

Example 2:
Predicted: mno<UNK><UNK><UNK><UNK>
Actual:    mnop<UNK><UNK><UNK>

Example 3:
Predicted: stuvw<UNK><UNK>
Actual:    stuvw<UNK><UNK>

Example 4:
Predicted: tuvxx<UNK><UNK>
Actual:    tuvwx<UNK><UNK>

Example 5:
Predicted: lmno<UNK><UNK><UNK>
Actual:    lmno<UNK><UNK><UNK>

Example 6:
Predicted: ijklmnl
Actual:    ijklmno

Example 7:
Predicted: ijklm<UNK><UNK>
Actual:    ijklm<UNK><UNK>

Example 8:
Predicted: vwxyx<UNK><UNK>
Actual:    vwxyz<UNK><UNK>

Example 9:
Predicted: <UNK><UNK>pq<UNK><UNK><UNK>
Actual:    nopq<UNK><UNK><UNK>

Example 10:
Predicted: uvxxy<UNK><UNK>
Actual:    uvwxy<UNK><UNK>

Example 11:
Predicted: abcdefg
Actual:    abcdefg

Example 12:
Predicted: ij<UNK><UNK><UNK><UNK><UNK>
Actual:    ijk<UNK><UNK><UNK><UNK>

Example 13:
Predicted: w<UNK>y<UNK><UNK><UNK><UNK>
Actual:    wxyz<UNK><UNK><UNK>

Example 14:
Predicted: abcdef<UNK>
Actual:    abcdef<UNK>

Example 15:
Predicted: fggij<UNK><UNK>

In [51]:
a, x, m = next(iter(dataloader))