In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import utils

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
train_data, valid_data, test_data, word2idx, idx2word = utils.load_data('data', device)

In [3]:
N_EPOCHS = 50
BATCH_SIZE = 32
CLIP = 40
VOCAB_SIZE = len(word2idx)
EMB_DIM = 50
SENT_LEN = train_data[0][0].shape[1]
STORY_LEN = train_data[0][0].shape[0]
QUERY_LEN = train_data[0][1].shape[1]
POS_ENC = True
TEMP_ENC = True
LS = True
N_HOPS = 3

In [4]:
print(word2idx)

{'<pad>': 0, 'to': 1, 'jane': 2, 'how': 3, '12': 4, 'were': 5, 'mushrooms': 6, '14': 7, 'adam': 8, 'mountains': 9, 'station': 10, 'sticks': 11, 'different': 12, 'eric': 13, '8': 14, 'oliver': 15, '3': 16, 'emma': 17, 'bridge': 18, '9': 19, 'entities': 20, '2': 21, 'claire': 22, 'many': 23, 'is': 24, 'picked': 25, 'park': 26, '5': 27, 'the': 28, 'feathers': 29, 'ruben': 30, 'flowers': 31, 'forest': 32, 'drop': 33, 'rocks': 34, '11': 35, 'shells': 36, 'beach': 37, 'from': 38, 'berries': 39, 'eve': 40, 'pick': 41, '16': 42, 'river': 43, '7': 44, 'school': 45, 'sophie': 46, 'dropped': 47, 'insects': 48, '4': 49, 'liam': 50, 'was': 51, 'town': 52, 'times': 53, 'carrying': 54, 'stadium': 55, 'leaves': 56, 'visited': 57, 'visit': 58, '6': 59, 'went': 60, 'in': 61, 'at': 62, 'eggs': 63, 'up': 64, 'objects': 65, 'total': 66, '10': 67, '13': 68, 'did': 69, '?': 70, '1': 71}


In [5]:
print(train_data[0])

(tensor([[ 15,  60,   1,  28,  45],
        [ 13,  60,   1,  28,  45],
        [ 15,  60,   1,  28,  37],
        [ 15,  25,  64,  71,  34],
        [ 13,  25,  64,  16,   6],
        [ 13,  60,   1,  28,  37],
        [ 13,  47,  16,   6,   0],
        [ 15,  60,   1,  28,  45],
        [ 13,  25,  64,  71,   6],
        [ 13,  60,   1,  28,  45],
        [ 15,  60,   1,  28,  37],
        [ 15,  25,  64,  21,   6],
        [ 15,  47,  71,   6,   0],
        [ 15,  60,   1,  28,  45],
        [ 13,  25,  64,  16,   6],
        [ 15,  25,  64,  21,  34],
        [ 15,  25,  64,  21,   6],
        [ 13,  25,  64,  21,  34],
        [ 15,  60,   1,  28,  37],
        [ 15,  25,  64,  16,   6]], device='cuda:0'), tensor([[  3,  23,  53,   5,  34,  25,  64,  38,  28,  37,  70]], device='cuda:0'), tensor([[ 1]], device='cuda:0'))


In [6]:
print(train_data[0][0].shape, train_data[0][1].shape, train_data[0][2].shape)

torch.Size([20, 5]) torch.Size([1, 11]) torch.Size([1, 1])


In [7]:
train_iterator = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_iterator = torch.utils.data.DataLoader(valid_data, batch_size=BATCH_SIZE)
test_iterator = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE)

In [8]:
class MemoryNetwork(nn.Module):
    def __init__(self, vocab_size, emb_dim, sent_len, story_len, pos_enc, temp_enc, n_hops, device):
        super().__init__()
        
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.sent_len = sent_len
        self.story_len = story_len
        self.pos_enc = pos_enc
        self.temp_enc = temp_enc
        self.n_hops = n_hops
        
        #input, query and output embeddings
        #for nn.ModuleList, see: 
        #  https://discuss.pytorch.org/t/when-should-i-use-nn-modulelist-and-when-should-i-use-nn-sequential/5463
        self.embeddings = nn.ModuleList([nn.Embedding(self.vocab_size, self.emb_dim, padding_idx=0) for _ in range(self.n_hops+1)])
        for e in self.embeddings:
            e.weight.data.normal_(0, 0.1)
            e.weight.data[0].fill_(0)
                
        #calculate position encoding
        if self.pos_enc:
            J = self.sent_len
            d = self.emb_dim
            self.l = torch.zeros(J, d).to(device)
            for j in range(1, J+1):
                for k in range(1, d+1):
                    self.l[j-1][k-1] = (1 - j/J) - (k/d) * (1 - 2*j/J)
            self.l = self.l.unsqueeze(0).repeat(self.story_len, 1, 1) # l = [story len, sent len, emb dim]
        
        #initialize temporal encoding parameters
        if self.temp_enc:
            self.T_A = nn.Parameter(torch.randn(self.story_len, self.emb_dim).normal_(0, 0.1))
            self.T_C = nn.Parameter(torch.randn(self.story_len, self.emb_dim).normal_(0, 0.1))
        
    def forward(self, S, Q, linear):
        
        # S = [bsz, story len, sent len]
        # Q = [bsz, q len]
        
        #make sure input is the correct size
        assert S.shape[1] == self.story_len and S.shape[2] == self.sent_len
        
        #embed the query 
        # B is the first embedding
        U = self.embeddings[0](Q) # U = [bsz, q len, emb dim]
        U = torch.sum(U, 1) # U = [bsz, emb dim]
        
        for k in range(self.n_hops):
        
            #embed the story
            # A is embedding k, A^k, where k is the current hop number
            M = self.embeddings[k](S) # M = [bsz, story len, sent_len, emb dim]

            #apply position encoding
            if self.pos_enc:
                l = self.l.unsqueeze(0).repeat(M.shape[0], 1, 1, 1) # l = [bsz, story len, sent len, emb dim]
                M *= l

            M = torch.sum(M, 2) # M = [bsz, story len, emb dim]

            #apply temporal encoding
            if self.temp_enc:
                T_A = self.T_A.unsqueeze(0).repeat(M.shape[0], 1, 1)
                M += T_A

            #calculate attention
            P = torch.bmm(M, U.unsqueeze(2)).squeeze(2) # P = [bsz, story len]
            if not linear:
                P = F.softmax(P, dim=1) # P = [bsz, story len]

            #output embedding of story
            # C is embedding k+1, A^(k+1), where k is the current hop number
            C = self.embeddings[k+1](S) # C = [bsz, story len, sent_len, emb dim]
            
            #apply position encoding
            if self.pos_enc:
                l = self.l.unsqueeze(0).repeat(C.shape[0], 1, 1, 1) # l = [bsz, story len, sent len, emb dim]
                C *= l
            
            C = torch.sum(C, 2) # C = [bsz, story len, emb dim]

            #apply temporal encoding
            if self.temp_enc:
                T_C = self.T_C.unsqueeze(0).repeat(C.shape[0], 1, 1)
                C += T_C

            #apply attention to output embedding
            O = torch.bmm(P.unsqueeze(1), C).squeeze(1) # O = [bsz, emb dim]
            
            #the next embedded query is the sum of the previous embedded query and the output
            U = U + O
                
        #get and reshape W
        # W is embedding K, A^K, where K is the total number of hops (can also get with self.embeddings[-1])
        W = self.embeddings[self.n_hops].weight.unsqueeze(0)
        W = W.repeat(U.shape[0], 1, 1) #W = [bsz, vocab size, emb dim]
               
        #get probability distribution over vocab
        A = torch.bmm(W, U.unsqueeze(2)).squeeze(2) # A = [bsz, vocab size]
        
        return A

In [9]:
model = MemoryNetwork(VOCAB_SIZE, EMB_DIM, SENT_LEN, STORY_LEN, POS_ENC, TEMP_ENC, N_HOPS, device).to(device)

In [10]:
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

In [11]:
def train(model, iterator, optimizer, criterion, linear):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    for s, q, a in iterator:
        
        optimizer.zero_grad()
        
        predictions = model(s, q.squeeze(1), linear)

        loss = criterion(predictions, a.squeeze(1).squeeze(1))
        
        top_pred = predictions.max(1, keepdim=True)[1] 
        acc = (top_pred == a.squeeze(1)).sum().float()/predictions.shape[0]
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [12]:
def evaluate(model, iterator, optimizer, criterion, linear):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    
    with torch.no_grad():
    
        for s, q, a in iterator:

            predictions = model(s, q.squeeze(1), linear)

            loss = criterion(predictions, a.squeeze(1).squeeze(1))
            
            top_pred = predictions.max(1, keepdim=True)[1]
            acc = (top_pred == a.squeeze(1)).sum().float()/predictions.shape[0]
            
            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [13]:
prev_valid_loss = float('inf')

linear = LS

for epoch in range(N_EPOCHS):
    
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion, linear)
    valid_loss, valid_acc = evaluate(model, valid_iterator, optimizer, criterion, linear)
    
    print(f'Epoch: {epoch+1:03}, Train Loss: {train_loss:.3f}, Train Acc: {train_acc*100:06.2f}%, ' \
          f'Val. Loss: {valid_loss:.3f}, Val. Acc: {valid_acc*100:06.2f}%, LR: {optimizer.param_groups[0]["lr"]:1e}, ' \
          f'Linear: {linear}')
        
    #scheduler.step()
        
    if LS and linear:
        if prev_valid_loss < valid_loss:
            linear = False
        prev_valid_loss = valid_loss

Epoch: 001, Train Loss: 1.771, Train Acc: 035.06%, Val. Loss: 1.454, Val. Acc: 040.10%, LR: 1.000000e-03, Linear: True
Epoch: 002, Train Loss: 1.405, Train Acc: 042.01%, Val. Loss: 1.409, Val. Acc: 042.15%, LR: 1.000000e-03, Linear: True
Epoch: 003, Train Loss: 1.367, Train Acc: 043.39%, Val. Loss: 1.361, Val. Acc: 043.52%, LR: 1.000000e-03, Linear: True
Epoch: 004, Train Loss: 1.352, Train Acc: 044.06%, Val. Loss: 1.367, Val. Acc: 042.67%, LR: 1.000000e-03, Linear: True
Epoch: 005, Train Loss: 1.548, Train Acc: 037.54%, Val. Loss: 1.409, Val. Acc: 040.23%, LR: 1.000000e-03, Linear: False
Epoch: 006, Train Loss: 1.294, Train Acc: 046.78%, Val. Loss: 1.210, Val. Acc: 050.36%, LR: 1.000000e-03, Linear: False
Epoch: 007, Train Loss: 1.154, Train Acc: 052.94%, Val. Loss: 1.132, Val. Acc: 053.85%, LR: 1.000000e-03, Linear: False
Epoch: 008, Train Loss: 1.088, Train Acc: 055.55%, Val. Loss: 1.079, Val. Acc: 055.72%, LR: 1.000000e-03, Linear: False
Epoch: 009, Train Loss: 1.059, Train Acc: 05