In [59]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm


In [26]:
torch.manual_seed(42)


<torch._C.Generator at 0x1076a2fb0>

Starting with a corpus with 3 tokens: a, b and c.

In [27]:
id2token = ["<UNK>", "a", "b", "c"]
token2id = {token: i for i, token in enumerate(id2token)}
           # i.e. {"<UNK>": 0, "a": 1, "b": 2, "c": 3}

In [61]:
# HYPERPARAMETERS
VOCAB_SIZE = len(token2id) # 4
EMB_DIM = 8
SEQ_LEN = 3
LEARNING_RATE = 0.01

In [29]:
# DEFINE Training Data
data = [
    "aaa",
    "bbb",
    "ccc"
]

In [30]:
# ENCODE DATA
data_tokens = [
    [token2id[token] for token in sequence] for sequence in data
]
print(data_tokens)

[[1, 1, 1], [2, 2, 2], [3, 3, 3]]


## Prepare the model

In [32]:
mask = torch.tensor([1, 0, 1])
actual = torch.tensor(data_tokens[0]).long()
ex = actual.clone()
ex[mask == 0] = token2id["<UNK>"]
ex = ex.long()  # Add this line
print(ex)

tensor([1, 0, 1])


In [53]:
class SimpleTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(VOCAB_SIZE, EMB_DIM)
        # LINEAR LAYER to use inside attention mechanism!?!?!
        self.out = nn.Linear(EMB_DIM, VOCAB_SIZE)
    
    def forward(self, x):
        x = self.emb(x)
        # Here we'd do positional encoding :P
        atn = x @ x.T
        atn = F.softmax(atn, -1)
        enc = atn @ x
        logits = self.out(enc)
        probs = F.softmax(logits, dim=-1)
        return probs



In [63]:
model = SimpleTransformer()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

In [67]:
# Trainin loop
# Add wandb logging
import wandb
wandb.init(project="mlx5.4-transformers", name="simple-transformer-setup")
for epoch in tqdm(range(1000)):
    for i, sequence in enumerate(data):
        # Tokenize the sequence
        seq = [token2id[token] for token in sequence]

        actual = torch.tensor(seq).long()
        # mask
        mask = torch.tensor([1, 0, 1]) # update to randomize mask
        ex = actual.clone()
        ex[mask == 0] = token2id["<UNK>"] # not necessary
        ex = ex.long()

        probs = model(ex)  
        loss = F.cross_entropy(probs, actual)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        wandb.log({"loss": loss})
wandb.finish()


100%|██████████| 1000/1000 [00:00<00:00, 1040.68it/s]


0,1
loss,▅██▇▄▆▅▄▃▂▃▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,0.75229


In [71]:
ex
words = [id2token[i] for i in ex]
print(words)

['c', '<UNK>', 'c']


In [72]:
model.eval()
with torch.no_grad():
    print(model(ex))


tensor([[0.0027, 0.0013, 0.0010, 0.9950],
        [0.0094, 0.0095, 0.0078, 0.9733],
        [0.0027, 0.0013, 0.0010, 0.9950]])


In [77]:
with torch.no_grad():
    print(model(torch.tensor([1, 0, 1])))

tensor([[0.0011, 0.9936, 0.0036, 0.0016],
        [0.0036, 0.9766, 0.0124, 0.0074],
        [0.0011, 0.9936, 0.0036, 0.0016]])
