In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### Data Preparation

In [2]:
# spam detection! each sentence below is paired with one label (1 = spam, 0 = not spam)
data = [
    'you won a billion dollars , great work !',
    'click here for cs685 midterm answers',
    'read important cs685 news',
    'send me your bank account info asap',
]

labels = torch.LongTensor([1, 1, 0, 1])  # ground-truth label for each sentence in `data`

# preprocessing: whitespace-tokenize every sentence and map each token to an
# integer id; ids are handed out in order of first appearance across `data`
vocab = {}
inputs = []

for text in data:
    tokens = text.split()
    # setdefault evaluates len(vocab) before inserting, so an unseen token
    # gets the next free id and a known token keeps its existing id
    token_ids = [vocab.setdefault(tok, len(vocab)) for tok in tokens]
    inputs.append(token_ids)

print(inputs)

[[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14], [15, 16, 12, 17], [18, 19, 20, 21, 22, 23, 24]]


### Build the model

In [3]:
class SelfAttentionNN(nn.Module):
    """Single-head self-attention text classifier.

    Embeds a 1-D sequence of token ids, applies one layer of unscaled
    dot-product self-attention, mean-pools the attended token vectors, and
    projects the pooled vector to 2 class logits.

    Args:
        embedding_dim: size of the token embeddings and of the query, key and
            value projections.
        vocab_size: number of rows in the embedding table; valid token ids are
            in ``[0, vocab_size)``.
    """

    def __init__(self, embedding_dim, vocab_size):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)

        self.Wq = nn.Linear(embedding_dim, embedding_dim)  # project to query space
        self.Wk = nn.Linear(embedding_dim, embedding_dim)  # project to keys
        self.Wv = nn.Linear(embedding_dim, embedding_dim)  # project to values

        # the final classification layer: pooled sentence vector -> 2 logits
        self.cls = nn.Linear(embedding_dim, 2)

    def dot_product_attn(self, q, k, v):
        """Unscaled dot-product attention over a single sequence.

        Args:
            q, k, v: ``T x embedding_dim`` query, key and value matrices.

        Returns:
            ``T x embedding_dim`` tensor; row ``i`` is the average of the
            value rows weighted by ``softmax_j(q_i . k_j)``.
        """
        scores = q @ k.t()  # all pairwise dot products at once, T x T
        # normalize each query's row so its weights over the keys sum to 1
        weights = F.softmax(scores, dim=1)
        return weights @ v  # T x embedding_dim

    # Exercise stubs ("you can implement the three below for fun!").
    # They raise instead of silently returning None, so an accidental call
    # fails loudly at the call site rather than later on a None value.
    def bilinear_attn(self, q, k):
        """Exercise: bilinear attention. Not implemented."""
        raise NotImplementedError("bilinear_attn is left as an exercise")

    def scaled_dot_product_attn(self, q, k):
        """Exercise: scaled dot-product attention. Not implemented."""
        raise NotImplementedError("scaled_dot_product_attn is left as an exercise")

    def mlp_attn(self, q, k):
        """Exercise: MLP (additive) attention. Not implemented."""
        raise NotImplementedError("mlp_attn is left as an exercise")

    def forward(self, inpt_sentence):
        """Classify one sentence.

        Args:
            inpt_sentence: 1-D LongTensor of T >= 1 token ids.

        Returns:
            ``1 x 2`` tensor of unnormalized class logits (a batch of one),
            the shape ``nn.CrossEntropyLoss`` expects for a single example.

        Raises:
            ValueError: if the sentence has no tokens; mean-pooling zero
                vectors would otherwise silently produce NaN logits.
        """
        if inpt_sentence.numel() == 0:
            raise ValueError("inpt_sentence must contain at least one token")

        word_embeds = self.embeddings(inpt_sentence)  # T x embedding_dim

        queries = self.Wq(word_embeds)  # T x embedding_dim
        keys = self.Wk(word_embeds)  # T x embedding_dim
        values = self.Wv(word_embeds)  # T x embedding_dim

        # efficient attention: every token attends to every token in one matmul
        attn_reps = self.dot_product_attn(queries, keys, values)

        # compose the T attended vectors into a single sentence vector
        sentence_rep = attn_reps.mean(dim=0)  # embedding_dim

        logits = self.cls(sentence_rep)  # 2
        return logits.unsqueeze(0)  # add the batch dimension -> 1 x 2

#### Test Inference

In [4]:
vocab_size = len(vocab)  # one embedding row per distinct token found during preprocessing
embedding_dim = 32  # width of the token embeddings and of the Wq/Wk/Wv projections

In [5]:
# Seed torch's global RNG right before construction: the embedding and linear
# weights are randomly initialized, so without a fixed seed the outputs below
# (sample logits, training losses) change on every Restart & Run All.
torch.manual_seed(0)
model = SelfAttentionNN(embedding_dim, vocab_size)

In [6]:
# a short hand-picked sequence of token ids, just to sanity-check the forward pass
sample_input = torch.LongTensor([1, 2, 3, 4])

# inference only: no_grad skips recording the autograd graph
with torch.no_grad():
    out = model(sample_input)
print(out)

tensor([[-0.5077, -0.0635]])


### Train the model

In [7]:
# objective: the model returns 1 x 2 raw logits, which CrossEntropyLoss consumes directly
loss_fn = nn.CrossEntropyLoss()
# plain SGD over every trainable parameter of the model
optim = torch.optim.SGD(model.parameters(), lr=0.1)
num_epochs = 10

In [8]:
# training loop: per-example SGD, visiting the sentences in a fixed order
assert len(inputs) == len(labels), "every sentence needs exactly one label"

# the token-id lists never change between epochs, so build the tensors once
sentence_tensors = [torch.LongTensor(idxs) for idxs in inputs]

for epoch in range(num_epochs):
    ep_loss = 0.  # summed (not averaged) training loss over this epoch

    for inpt_sentence, label in zip(sentence_tensors, labels):
        target = label.unsqueeze(0)  # shape (1,) to pair with the 1 x 2 logits

        pred = model(inpt_sentence)
        loss = loss_fn(pred, target)

        optim.zero_grad()
        loss.backward()
        optim.step()

        ep_loss += loss.item()

    print(epoch, ep_loss)

0 2.8318784534931183
1 1.7424476146697998
2 0.9125520139932632
3 0.44691064208745956
4 0.2501957528293133
5 0.16152603924274445
6 0.11486340966075659
7 0.08715852349996567
8 0.06921625044196844
9 0.05681771645322442
