In [None]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import torch.optim as optim
import random
import numpy as np

def fetch(k):
    """
    emit a sentence of four tokens, where each token is a vector of dimension k.

    tokens 1 and 4 are each either [1, ..., 1] or [-1, ..., -1] (chosen with
    Python's `random`); tokens 2 and 3 are drawn uniformly from [-1, 1) with
    NumPy's global RNG and carry no information about the label.

    the label is the XOR of tokens 1 and 4: y = 1 when exactly one of them is
    [1, ..., 1] (i.e. the two tokens differ) and y = 0 when they are equal.

    Args:
        k: dimensionality of each token; must be >= 1.

    Returns:
        ([a, b, c, d], y) where a, b, c, d are NumPy arrays of shape (k,)
        and y is a Python int (0 or 1).

    Raises:
        ValueError: if k < 1 (the label is read from element 0 of tokens 1 and 4).
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    a = np.full(k, random.choice([1, -1]))
    b = np.random.uniform(-1, 1, k)
    c = np.random.uniform(-1, 1, k)
    d = np.full(k, random.choice([1, -1]))

    # every element of a (and of d) is identical, so element 0 decides the sign;
    # a[0] != d[0] is True exactly when one token is all +1 and the other all -1
    y = int(a[0] != d[0])
    return ([a, b, c, d], y)

# section 1: baseline
each sentence, produced by `fetch(k)`, consists of four tokens, and each token is a `k`-dimensional embedding.

each sentence is designed so that only the first and last tokens carry predictive value. each of them is either `[1,...,1]` or `[-1,...,-1]`, and `y` is their XOR: `y=1` when the two tokens differ and `y=0` when they are equal. the two middle tokens are uniform noise and are independent of `y`.

therefore, when we fit a simple linear model to token 2 alone (drawn uniformly from `[-1, 1]` and independent of `y`), we should not be able to do better than random chance, i.e. about 50% accuracy.

In [None]:
class Vanilla_Net(nn.Module):
    """A single linear layer mapping one token to probabilities over two classes.

    The notebook calls it with a single token shaped (1, embedding_dim). Only
    the first row of the batch is returned, as a tensor of shape (2,).
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_features=embedding_dim, out_features=2)

    def forward(self, x):
        logits = self.fc1(x)
        # normalise over the class axis, then drop the batch axis by taking row 0
        class_probs = logits.softmax(dim=1)
        return class_probs[0]

In [None]:
k = 3  # dimensionality of each token (the baseline only looks at token 2)

# seed every RNG the experiment touches so Restart & Run All is reproducible:
# fetch() uses Python's `random` for tokens 1/4 and NumPy's global RNG for
# tokens 2/3; torch drives the weight initialisation.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

baseline_net = Vanilla_Net(embedding_dim=k)
baseline_opt = optim.SGD(baseline_net.parameters(), lr=1e-3, momentum=0.9)

In [None]:
n_epochs = 10
n_train_steps = 500  # training sentences per epoch
n_test = 100  # test sentences per accuracy estimate

for epoch in range(n_epochs):

    # train: only token 2 (index 1) is fed to the model; it is independent of y
    for _ in range(n_train_steps):
        X, y = fetch(k)
        # explicit float32 to match nn.Linear's default weight dtype; torch.tensor
        # may otherwise infer float64 from NumPy float64 values
        token2 = torch.tensor(X[1], dtype=torch.float32).reshape(1, k)

        baseline_opt.zero_grad()
        outputs = baseline_net(token2)  # class probabilities, shape (2,)
        # F.nll_loss expects log-probabilities, but the net returns probabilities;
        # take the log (clamped so a probability of exactly 0 cannot give -inf)
        log_probs = torch.log(outputs.clamp(min=1e-12)).reshape(1, 2)
        loss = F.nll_loss(log_probs, torch.tensor([y]))
        loss.backward()
        baseline_opt.step()

    # test: gradients are not needed for evaluation
    acc = 0
    with torch.no_grad():
        for _ in range(n_test):
            X, y = fetch(k)
            token2 = torch.tensor(X[1], dtype=torch.float32).reshape(1, k)
            pred = int(torch.argmax(baseline_net(token2)))
            acc += int(pred == y)
    print("accuracy", acc / n_test)

# section 2: transformers
we will still train on token2, which has no predictive value. 

however, we will modify token2 before passing it to a neural net. 

first, we compare token2 (represented by a query vector) to every token in the sentence, including token2 itself (each represented by a key vector). the query vector is computed by multiplying token2 by a matrix Q. key vectors are computed by multiplying every token in the sentence by a matrix K.

we take the inner product of the query vector with each key vector. this creates a 1-dim array of length T, where T is the number of tokens in the sentence, and each element is an inner product that acts as a similarity score. this "score" vector is then softmaxed to create an attention mask.

finally, the attention mask is applied to every token in the sentence. the attention mask can be interpreted as a weighting mechanism over all of the tokens in the sentence. the weighted token embeddings are summed up to create a new token2, which is now a mixture of all the tokens in the sentence (including the original token2).

in order to reduce loss, the neural net wants a modified token2 that has high predictive value. this in turn calls for a softmax that places high weights on tokens with high predictive value, and low weights that mask out tokens with no predictive value. therefore, the query matrix and key matrix learn how to maximize the inner product between token2 and the tokens with high predictive value.

In [None]:
class Transformer(nn.Module):
    """Single-head attention from token 2 over the sentence, then a 2-class linear layer.

    Token 2 (index 1) is projected into a query, and every token (token 2
    included) is projected into a key. The softmaxed query-key dot products
    weight the raw token embeddings. Their weighted sum replaces token 2 as the
    input to ``fc1``. No value projection and no 1/sqrt(d) scaling are used.

    Args:
        embedding_dim: dimensionality of each token.
        head_dim: dimensionality of the query/key projections. The default of 3
            is the value that used to be hard-coded.
    """

    def __init__(self, embedding_dim, head_dim=3):
        super().__init__()
        # created in the original order (query, key, fc1) so a given torch seed
        # yields the same initial weights as before
        self.query_matrix = nn.Parameter(torch.randn(embedding_dim, head_dim))
        self.key_matrix = nn.Parameter(torch.randn(embedding_dim, head_dim))
        self.fc1 = nn.Linear(in_features=embedding_dim, out_features=2)
        self.embedding_dim = embedding_dim

    def forward(self, sentence):
        """Return class probabilities, shape (2,), predicted from the attended token 2.

        Args:
            sentence: a sequence of at least two tokens, each convertible to a
                float tensor of length ``embedding_dim`` (e.g. NumPy arrays or
                Python lists). Any number of tokens is accepted.
        """
        # (T, D) float32 matrix holding the whole sentence
        tokens = torch.stack(
            [torch.as_tensor(t, dtype=torch.float32) for t in sentence]
        )

        # token2 by itself has no predictive value; it only supplies the query
        query = tokens[1:2] @ self.query_matrix  # (1, H)
        keys = tokens @ self.key_matrix  # (T, H): one key per token

        # similarity of token2 to every token, normalised over the T tokens
        scores = (keys @ query.T).reshape(-1)  # (T,)
        mask = F.softmax(scores, dim=0)

        # attention-weighted sum of the raw embeddings -> new token2, shape (1, D)
        modified_token2 = (mask @ tokens).reshape(1, self.embedding_dim)

        # learn using the modified embedding for token2
        logits = self.fc1(modified_token2)
        return F.softmax(logits, dim=1)[0]

In [None]:
# token dimensionality for the attention experiment (rebinds the `k` used earlier)
k = 4
transformer_net = Transformer(embedding_dim=k)
transformer_net_opt = optim.SGD(
    params=transformer_net.parameters(),
    lr=1e-3,
    momentum=0.9,
)

In [None]:
n_epochs = 10
n_train_steps = 3000  # training sentences per epoch
n_test = 100  # test sentences per accuracy estimate

for epoch in range(n_epochs):

    # train on the full sentence; the model attends from token 2 to every token
    for _ in range(n_train_steps):
        X, y = fetch(k)
        transformer_net_opt.zero_grad()
        outputs = transformer_net(X)  # class probabilities, shape (2,)
        # F.nll_loss expects log-probabilities, but the model returns probabilities;
        # take the log (clamped so a probability of exactly 0 cannot give -inf)
        log_probs = torch.log(outputs.clamp(min=1e-12)).reshape(1, 2)
        loss = F.nll_loss(log_probs, torch.tensor([y]))
        loss.backward()
        transformer_net_opt.step()

    # test: gradients are not needed for evaluation
    acc = 0
    with torch.no_grad():
        for _ in range(n_test):
            X, y = fetch(k)
            pred = int(torch.argmax(transformer_net(X)))
            acc += int(pred == y)
    print(acc / n_test)