In [1]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
from random import randint
import random

def fetch():
    """
    Produce one labelled example: a sequence of three 4-dim tokens plus a label.

    The first token is picked at random from [1, 1, 1, 1] and [-1, -1, -1, -1]
    and fully determines the label (1 for the all-ones token, 0 otherwise).
    Tokens 2 and 3 are independent uniform noise on [-0.5, 0.5) and carry no
    information about the label.

    Returns:
        ([token1, token2, token3], label): each token is a numpy array of
        length 4, and label is a Python int in {0, 1}.
    """
    signal = np.array(random.choice([[1, 1, 1, 1], [-1, -1, -1, -1]]))
    # two noise tokens, drawn in the same order as token 2 then token 3
    noise_tokens = [np.random.uniform(-.5, .5, 4) for _ in range(2)]
    label = int(signal[0] == 1)
    return ([signal] + noise_tokens, label)

# section 1: baseline
Each example is a "sentence" of three tokens, and each token is a 4-dim embedding. The sentences are constructed so that only the first token has predictive value: whether y=1 or y=0 depends only on the first token, which is either [1,1,1,1] (y=1) or [-1,-1,-1,-1] (y=0). Tokens 2 and 3 are random noise.

To construct the baseline, we fit a simple linear model to token 2 only. Token 2 is drawn from a uniform distribution over [-0.5, 0.5], independently of the label, so it has no predictive value. We should therefore expect this network to reach only about 50% accuracy, which is chance level for two balanced classes.

In [2]:
class Vanilla_Net(nn.Module):
    """Single linear layer mapping a 4-dim token to a 2-class probability vector."""

    def __init__(self):
        super().__init__()
        # 4 input features -> 2 class logits
        self.fc1 = nn.Linear(in_features=4, out_features=2)

    def forward(self, x):
        """
        Args:
            x: 2-D tensor whose last dimension is 4 (e.g. shape (1, 4)).
        Returns:
            Tensor of shape (2,): softmax class probabilities for the FIRST
            row of the batch only (later rows are discarded).
        """
        logits = self.fc1(x)
        probs = F.softmax(logits, dim=1)
        return probs[0]

In [3]:
# Seed every RNG this notebook draws from (python `random` and numpy inside
# fetch(), torch for weight initialisation) so that running this cell and the
# training cell after it reproduces the same numbers.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

baseline_net = Vanilla_Net()
baseline_opt = optim.SGD(baseline_net.parameters(), lr=0.001, momentum=0.9)

In [4]:
n_epochs = 10
n_train_steps = 300  # training examples per epoch (one SGD step each)
n_test = 100  # fresh examples used to measure accuracy after each epoch

for epoch in range(n_epochs):

    # train: the model only ever sees token 2, which is pure noise
    for _ in range(n_train_steps):
        X_train, y_train = fetch()
        # explicit float32 so the input matches nn.Linear's weight dtype
        token2 = torch.tensor(X_train[1], dtype=torch.float32).reshape(1, 4)

        baseline_opt.zero_grad()
        outputs = baseline_net(token2)  # softmax probabilities, shape (2,)
        # nll_loss expects LOG-probabilities; taking the log turns it into the
        # usual cross-entropy. The clamp guards against log(0) = -inf.
        log_probs = torch.log(outputs.clamp(min=1e-12)).reshape(1, 2)
        loss = F.nll_loss(log_probs, torch.tensor([y_train]))
        loss.backward()
        baseline_opt.step()

    # test: evaluate on freshly generated examples, again using token 2 only
    acc = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for _ in range(n_test):
            X_test, y_test = fetch()
            token2 = torch.tensor(X_test[1], dtype=torch.float32).reshape(1, 4)
            pred = int(torch.argmax(baseline_net(token2)))
            acc += int(pred == y_test)
    print("accuracy", acc / n_test)

accuracy 0.46
accuracy 0.5
accuracy 0.53
accuracy 0.47
accuracy 0.47
accuracy 0.5
accuracy 0.55
accuracy 0.46
accuracy 0.43
accuracy 0.55


For background on the self-attention mechanism used in the next section, see [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/).

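# section 2: transformer-style self-attention
We still classify using token 2, which we know has no predictive value on its own.

However, this time token 2 is first transformed by a single self-attention step before being passed to a simple linear network. Token 2 acts as the query. It is compared against a key for every token in the sentence, and its new embedding is the attention-weighted sum of all three tokens. Because that weighted sum can include token 1, the transformed token 2 can carry label information. The model can therefore beat the 50% baseline, as the recorded accuracies below show.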

In [5]:
class Transformer(nn.Module):
    """
    Single-head, single-query dot-product attention followed by a linear classifier.

    One token of the input sequence (token 2 by default) is used as the query.
    It is compared against a key for every token in the sequence. The softmax of
    those similarity scores is used to take a weighted sum of the raw token
    embeddings, which becomes the query token's new embedding and is then
    classified. There is no value projection, so the new embedding keeps the
    input dimensionality.

    Args:
        d_model: dimensionality of each input token (default 4).
        d_k: dimensionality of the query/key projections (default 3).
        n_classes: number of output classes (default 2).
        query_index: position of the token used as the query (default 1,
            i.e. token 2).
    """

    def __init__(self, d_model=4, d_k=3, n_classes=2, query_index=1):
        super(Transformer, self).__init__()
        self.query_index = query_index
        # parameters are created in the same order as the original version so a
        # given torch seed yields the same initial weights
        self.query_matrix = nn.Parameter(torch.randn(d_model, d_k))
        self.key_matrix = nn.Parameter(torch.randn(d_model, d_k))
        # the attended token keeps dim d_model
        self.fc1 = nn.Linear(in_features=d_model, out_features=n_classes)

    def forward(self, input_seq):
        """
        Args:
            input_seq: the tokens of one sentence, each a length-d_model vector
                (a list of numpy arrays / lists, or a 2-D array or tensor of
                shape (seq_len, d_model)).
        Returns:
            Tensor of shape (n_classes,): softmax class probabilities.
        """
        tokens = self._to_tensor(input_seq)  # (seq_len, d_model)

        # query vector for the chosen token
        q = tokens[self.query_index].unsqueeze(0) @ self.query_matrix  # (1, d_k)
        # one key per token in the sentence
        keys = tokens @ self.key_matrix  # (seq_len, d_k)
        # similarity of the query with every key
        scores = (keys @ q.T).squeeze(1)  # (seq_len,)
        # softmax over the sequence -> attention weights
        mask = F.softmax(scores, dim=0)
        # new embedding for the query token: attention-weighted sum of all tokens
        modified_token = mask @ tokens  # (d_model,)

        x = self.fc1(modified_token.unsqueeze(0))
        return F.softmax(x, dim=1)[0]

    @staticmethod
    def _to_tensor(input_seq):
        """Stack the input tokens into a float32 tensor of shape (seq_len, d_model)."""
        if isinstance(input_seq, torch.Tensor):
            return input_seq.to(torch.float32)
        # one contiguous numpy array converts far faster than a list of arrays
        return torch.tensor(np.asarray(input_seq, dtype=np.float32))

In [6]:
# Seed every RNG this notebook draws from (python `random` and numpy inside
# fetch(), torch for weight initialisation) so that running this cell and the
# training cell after it reproduces the same numbers.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

transformer_net = Transformer()
transformer_net_opt = optim.SGD(transformer_net.parameters(), lr=0.001, momentum=0.9)

In [7]:
n_epochs = 10
n_train_steps = 300  # training examples per epoch (one SGD step each)
n_test = 100  # fresh examples used to measure accuracy after each epoch

for epoch in range(n_epochs):

    # train: the model receives the whole sentence but classifies token 2
    for _ in range(n_train_steps):
        X_train, y_train = fetch()
        transformer_net_opt.zero_grad()
        outputs = transformer_net(X_train)  # softmax probabilities, shape (2,)
        # nll_loss expects LOG-probabilities; taking the log turns it into the
        # usual cross-entropy. The clamp guards against log(0) = -inf.
        log_probs = torch.log(outputs.clamp(min=1e-12)).reshape(1, 2)
        loss = F.nll_loss(log_probs, torch.tensor([y_train]))
        loss.backward()
        transformer_net_opt.step()

    # test: evaluate on freshly generated sentences
    acc = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for _ in range(n_test):
            X_test, y_test = fetch()
            pred = int(torch.argmax(transformer_net(X_test)))
            acc += int(pred == y_test)
    print("accuracy", acc / n_test)

0.93
0.98
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
