In [195]:
import random
import pandas as pd

import torch 
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm


In [179]:
torch.manual_seed(42)


<torch._C.Generator at 0x1076a2fb0>

Starting with a corpus with 3 tokens: a, b and c.

In [180]:
"""Accepts a string or list of tokens.
    Returns id2token and token2id dictionaries.
"""
def vocab_dicts(vocab):
    vocab = sorted(set(vocab))
    id2token = ["<UNK>", *vocab]
    token2id = {token: i for i, token in enumerate(id2token)}
    return id2token, token2id


In [27]:
small_vocab = "abc"

# DEFINE Training Data
small_data = [
    "aaa",
    "bbb",
    "ccc"
]

id2token, token2id = vocab_dicts(small_vocab)
           # i.e. {"<UNK>": 0, "a": 1, "b": 2, "c": 3}

In [297]:
# More data
def letter_sequences(num_examples, alphabet):
    examples = []
    for _ in range(num_examples):
        length = 4 # random.randint(3, MAX_SEQ_LEN)
        start = random.randint(0, len(alphabet) - length)
        example = alphabet[start:start+length]
        examples.append(example)
    return pd.Series(examples)


In [298]:
alphabet = "abcdefghijklmnopqrstuvwxyz"
data = pd.DataFrame()
data["sequence"] = letter_sequences(10000, alphabet)
data.head()


Unnamed: 0,sequence
0,hijk
1,stuv
2,tuvw
3,vwxy
4,tuvw


In [299]:
id2token, token2id = vocab_dicts(alphabet)

In [300]:
print("id for token 'a':", token2id['a'])
print("token for id 1:", id2token[1])


id for token 'a': 1
token for id 1: a


In [301]:
data["actual"] = data["sequence"].apply(lambda x: [token2id[token] for token in x])
data['mask'] = data['actual'].apply(lambda x: [1 if random.random() > 0.15 else 0 for _ in x])
data.head()


Unnamed: 0,sequence,actual,mask
0,hijk,"[8, 9, 10, 11]","[1, 1, 1, 1]"
1,stuv,"[19, 20, 21, 22]","[0, 1, 1, 1]"
2,tuvw,"[20, 21, 22, 23]","[1, 1, 1, 0]"
3,vwxy,"[22, 23, 24, 25]","[1, 1, 1, 1]"
4,tuvw,"[20, 21, 22, 23]","[1, 1, 1, 1]"


In [302]:
def apply_mask(row):
    # mask is a list of 0s and 1s
    return [a * m for a, m in zip(row['actual'], row['mask'])]


In [303]:
data['ex'] = data.apply(apply_mask, axis=1)
data.head()



Unnamed: 0,sequence,actual,mask,ex
0,hijk,"[8, 9, 10, 11]","[1, 1, 1, 1]","[8, 9, 10, 11]"
1,stuv,"[19, 20, 21, 22]","[0, 1, 1, 1]","[0, 20, 21, 22]"
2,tuvw,"[20, 21, 22, 23]","[1, 1, 1, 0]","[20, 21, 22, 0]"
3,vwxy,"[22, 23, 24, 25]","[1, 1, 1, 1]","[22, 23, 24, 25]"
4,tuvw,"[20, 21, 22, 23]","[1, 1, 1, 1]","[20, 21, 22, 23]"


In [311]:
data.head()

Unnamed: 0,sequence,actual,mask,ex
0,hijk,"[8, 9, 10, 11]","[1, 1, 1, 1]","[8, 9, 10, 11]"
1,stuv,"[19, 20, 21, 22]","[0, 1, 1, 1]","[0, 20, 21, 22]"
2,tuvw,"[20, 21, 22, 23]","[1, 1, 1, 0]","[20, 21, 22, 0]"
3,vwxy,"[22, 23, 24, 25]","[1, 1, 1, 1]","[22, 23, 24, 25]"
4,tuvw,"[20, 21, 22, 23]","[1, 1, 1, 1]","[20, 21, 22, 23]"


In [312]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.actual = data['actual']
        self.mask = data['mask']
        self.ex = data['ex']

    
    def __len__(self):
        return len(self.actual)
    
    def __getitem__(self, idx):
        return self.actual.iloc[idx], self.mask.iloc[idx], self.ex.iloc[idx]



In [313]:
dataset = Dataset(data)



In [314]:
dataset[0:10]

(0      [8, 9, 10, 11]
 1    [19, 20, 21, 22]
 2    [20, 21, 22, 23]
 3    [22, 23, 24, 25]
 4    [20, 21, 22, 23]
 5    [11, 12, 13, 14]
 6       [7, 8, 9, 10]
 7        [6, 7, 8, 9]
 8    [22, 23, 24, 25]
 9        [4, 5, 6, 7]
 Name: actual, dtype: object,
 0    [1, 1, 1, 1]
 1    [0, 1, 1, 1]
 2    [1, 1, 1, 0]
 3    [1, 1, 1, 1]
 4    [1, 1, 1, 1]
 5    [1, 1, 0, 1]
 6    [0, 1, 1, 1]
 7    [1, 0, 1, 1]
 8    [0, 0, 1, 1]
 9    [1, 1, 1, 1]
 Name: mask, dtype: object,
 0      [8, 9, 10, 11]
 1     [0, 20, 21, 22]
 2     [20, 21, 22, 0]
 3    [22, 23, 24, 25]
 4    [20, 21, 22, 23]
 5     [11, 12, 0, 14]
 6       [0, 8, 9, 10]
 7        [6, 0, 8, 9]
 8      [0, 0, 24, 25]
 9        [4, 5, 6, 7]
 Name: ex, dtype: object)

In [315]:
def collate(batch):
    actual, mask, ex = zip(*batch)
    actual = torch.stack([torch.tensor(row) for row in actual])
    mask = torch.stack([torch.tensor(row) for row in mask])
    ex = torch.stack([torch.tensor(row) for row in ex])
    return actual, ex, mask
    # actual, mask = zip(*batch)
    # actual = torch.stack(actual)
    
    # ex = 
    # ex = [row['ex'] for row in batch]
    # actual = [row['actual'] for row in batch]
    # return torch.tensor(ex), torch.tensor(actual)


In [316]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False, collate_fn=collate)

In [317]:
# get next item
a, x, m = next(iter(dataloader))
print(len(a))
print(len(x))
print(len(m))


16
16
16


In [318]:
i = 0
for d in dataloader:
    a = d[0]
    x = d[1]
    m = d[2]
    print(a[0], x[0], m[0])
    if i > 10:
        break


tensor([ 8,  9, 10, 11]) tensor([ 8,  9, 10, 11]) tensor([1, 1, 1, 1])
tensor([2, 3, 4, 5]) tensor([0, 3, 0, 5]) tensor([0, 1, 0, 1])
tensor([22, 23, 24, 25]) tensor([22, 23, 24, 25]) tensor([1, 1, 1, 1])
tensor([ 8,  9, 10, 11]) tensor([ 8,  9,  0, 11]) tensor([1, 1, 0, 1])
tensor([ 8,  9, 10, 11]) tensor([ 8,  9, 10, 11]) tensor([1, 1, 1, 1])
tensor([5, 6, 7, 8]) tensor([5, 6, 7, 8]) tensor([1, 1, 1, 1])
tensor([19, 20, 21, 22]) tensor([19,  0, 21,  0]) tensor([1, 0, 1, 0])
tensor([2, 3, 4, 5]) tensor([2, 3, 4, 5]) tensor([1, 1, 1, 1])
tensor([1, 2, 3, 4]) tensor([1, 2, 3, 4]) tensor([1, 1, 1, 1])
tensor([5, 6, 7, 8]) tensor([5, 6, 7, 8]) tensor([1, 1, 1, 1])
tensor([21, 22, 23, 24]) tensor([21, 22, 23,  0]) tensor([1, 1, 1, 0])
tensor([22, 23, 24, 25]) tensor([ 0, 23, 24, 25]) tensor([0, 1, 1, 1])
tensor([5, 6, 7, 8]) tensor([5, 6, 7, 0]) tensor([1, 1, 1, 0])
tensor([11, 12, 13, 14]) tensor([11, 12,  0, 14]) tensor([1, 1, 0, 1])
tensor([6, 7, 8, 9]) tensor([6, 7, 8, 9]) tensor([1, 1

## Prepare the model

In [188]:
# class Attention(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.linear = nn.Linear(EMB_DIM, EMB_DIM)
    
#     def forward(self, x):
#         x = self.emb(x)
#         # Linear layer?
#         atn = x @ x.T
#         atn = F.softmax(atn, -1)
#         enc = atn @ x
#         return x



In [327]:
class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size=VOCAB_SIZE, emb_dim=EMB_DIM):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        # LINEAR LAYER to use inside attention mechanism!?!?!
        self.out = nn.Linear(emb_dim, vocab_size)
    
    def forward(self, x):
        x = self.emb(x)
        # Here we'd do positional encoding :P
        atn = x @ x.transpose(-2, -1)
        atn = F.softmax(atn, -1)
        enc = atn @ x
        logits = self.out(enc)
        probs = F.softmax(logits, dim=-1)
        return probs



In [328]:
# HYPERPARAMETERS
VOCAB_SIZE = len(token2id) # 27
EMB_DIM = 8
MAX_SEQ_LEN = 7
LEARNING_RATE = 0.01

In [329]:
VOCAB_SIZE

27

In [330]:
model = SimpleTransformer()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

In [334]:
# Trainin loop
# Add wandb logging
import wandb
wandb.init(project="mlx5.4-transformers", name="simple-transformer-alphabet")
model.train()
# Debug prints

for epoch in tqdm(range(150)):
    for actual, ex, mask in dataloader:

        probs = model(ex) 
        probs = probs.view(-1, VOCAB_SIZE)
        actual = actual.view(-1)

        loss = F.cross_entropy(probs, actual)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        wandb.log({"loss": loss})
wandb.finish()


 40%|████      | 404/1000 [02:32<03:44,  2.65it/s]


KeyboardInterrupt: 

In [335]:
a, x, m = next(iter(dataloader))


In [341]:
def validate_predictions(model, dataloader, id2token, n_examples=5):
    model.eval()
    with torch.no_grad():
        # Get one batch
        actual, example, mask = next(iter(dataloader))
        
        # Get predictions
        probs = model(example)  # [batch, seq, vocab]
        
        # For each example in batch (up to n_examples)
        for i in range(min(n_examples, len(actual))):
            # Get single sequence
            seq_probs = probs[i]  # [seq, vocab]
            seq_actual = actual[i]  # [seq]
            
            # Get predicted tokens
            predicted_indices = torch.argmax(seq_probs, dim=-1)  # [seq]
            
            # Convert to letters
            pred_tokens = [id2token[idx.item()] for idx in predicted_indices]
            actual_tokens = [id2token[idx.item()] for idx in seq_actual]
            
            print(f"\nExample {i+1}:")
            print(f"Predicted: {''.join(pred_tokens)}")
            print(f"Actual:    {''.join(actual_tokens)}")

# Run validation
validate_predictions(model, dataloader, id2token)


Example 1:
Predicted: hijk
Actual:    hijk

Example 2:
Predicted: utuv
Actual:    stuv

Example 3:
Predicted: tuvu
Actual:    tuvw

Example 4:
Predicted: vwxy
Actual:    vwxy

Example 5:
Predicted: tuvw
Actual:    tuvw


In [340]:
a, x, m = next(iter(dataloader))