In [1]:
### In this exercise, you will implement a self-attention mechanism
### for a toy spam classification task.
### Given a sentence [w1, w2, w3], we'll perform self-attention and then 
### aggregate the representations for feeding to a classifier.
### We provide the data / training loop setup; you should def look it over and 
### understand how it works!!

import torch
import torch.nn as nn
import sys

# spam detection!
data = ['you won a billion dollars , great work !',
        'click here for cs685 midterm answers',
       'read important cs685 news',
       'send me your bank account info asap']

labels = torch.LongTensor([1, 1, 0, 1]) # store ground-truth labels

# let's do some preprocessing
vocab = {}
inputs = []

for sent in data:
    idxes = []
    sent = sent.split()
    for w in sent:
        if w not in vocab:
            vocab[w] = len(vocab)
        idxes.append(vocab[w])
    inputs.append(idxes)
    
print(inputs)


[[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14], [15, 16, 12, 17], [18, 19, 20, 21, 22, 23, 24]]


In [27]:
class SelfAttentionNN(nn.Module):
    
    # this is where you initialize network parameters
    def __init__(self, d_emb, num_outputs, len_vocab):
        
        super(SelfAttentionNN, self).__init__()
        self.d_emb = d_emb
        self.embeddings = nn.Embedding(len_vocab, d_emb)
        self.Wq = nn.Linear(d_emb, d_emb) # project to query space
        self.Wk = nn.Linear(d_emb, d_emb) # project to keys
        self.Wv = nn.Linear(d_emb, d_emb) # project to values
        self.output = nn.Linear(d_emb, num_outputs) # output matrix before softmax
        
    # all three args are N x d_emb matrices
    def dot_product_attn(self, q, k, v):
        scores = torch.mm(q, k.t()) # gets all dot products at once, N X N
        scores = torch.nn.functional.softmax(scores, dim=1)
        return torch.mm(scores, v) # N x d_emb
    
    def bilinear_attn(self, q, k):
        pass
    
    def scaled_dot_product_attn(self, q, k):
        scores = torch.mm(q, k.t()) # gets all dot products at once, N X N
        scores = scores / q.size()[1]
        scores = torch.nn.functional.softmax(scores, dim=1)
        return torch.mm(scores, q) # N x d_emb
    
    def mlp_attn(self, q, k):
        pass
        
    def forward(self, input):
        N = input.size()[1] # number of tokens in input, assume N > 2
        embs = self.embeddings(input).squeeze(0) # N x d_emb
        
        attn_reps = torch.zeros(embs.size()) # store attn weighted average of each position in seq
                                             # N x d_emb
        
        queries = self.Wq(embs) # N x d_emb
        keys = self.Wk(embs) # N x d_emb
        values = self.Wv(embs)

        # simple slow solution for self-attn computation
        for i in range(N):

            # at each position i, take query q_i and dot product w/ all keys
            unnorm_scores = torch.mv(keys, queries[i])
            probs = torch.nn.functional.softmax(unnorm_scores, dim=0)
            
            # attn-weighted average of value vectors
            # ave_values = torch.mv(values.t(), probs)
            ave_values = torch.sum(values * probs[:, None], dim=0) # (optional) [:, None] adds a dimension at index 1

            attn_reps[i] = ave_values

        # compose attn_reps into a single vector
        attn_reps = torch.mean(attn_reps, dim=0)

        # efficient attention computation using dot product attention
        attn_reps_2 = self.dot_product_attn(queries, keys, values)
        attn_reps_2 = torch.mean(attn_reps_2, dim=0)

        # efficient attention computation using scaled dot product attention
        attn_reps_3 = self.scaled_dot_product_attn(queries, keys)
        attn_reps_3 = torch.mean(attn_reps_3, dim=0)

        pred = self.output(attn_reps_2) # return logits
        return pred.unsqueeze(0)

# hyperparameters
num_epochs = 10
net = SelfAttentionNN(20, 2, len(vocab))
loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.SGD(net.parameters(), lr = 0.1)

# training loop
for epoch in range(num_epochs): # how many passes thru input im gonna do
    ep_loss = 0.
        
    # for each pass, loop over training data
    for i in range(len(inputs)):
        sent = torch.LongTensor(inputs[i]).unsqueeze(0) # pick a sentence
        target = labels[i].unsqueeze(0) # get its label
        # unsqueeze(0) shapes the tensors into 2d, as pytorch forward usually
        # is written to receive minibatches rather than single examples
        
        pred = net(sent) # get output of network (unnormalized logits)
        loss = loss_fn(pred, target) # apply softmax + cross entropy
        loss.backward() # do backprop!!!
        optim.step() # update parameters w/ gradient descent
        optim.zero_grad() # reset gradients for next example
        ep_loss += loss # update overall loss for monitoring progress
    
    print(epoch, ep_loss)

0 tensor(2.9852, grad_fn=<AddBackward0>)
1 tensor(1.5828, grad_fn=<AddBackward0>)
2 tensor(0.8187, grad_fn=<AddBackward0>)
3 tensor(0.4224, grad_fn=<AddBackward0>)
4 tensor(0.2354, grad_fn=<AddBackward0>)
5 tensor(0.1514, grad_fn=<AddBackward0>)
6 tensor(0.1071, grad_fn=<AddBackward0>)
7 tensor(0.0807, grad_fn=<AddBackward0>)
8 tensor(0.0636, grad_fn=<AddBackward0>)
9 tensor(0.0517, grad_fn=<AddBackward0>)
