In [15]:
%load_ext autoreload
%autoreload 2
import pandas as pd
import re
import spacy
import numpy as np
import gzip
import gensim.downloader
import torch
from torch.utils.data import Dataset, DataLoader
import pickle
from model import SimpleAttention

In [16]:
class WordDataset(Dataset):
    
    def __init__(self, sfile):
        d = pickle.load(open(sfile, 'rb'))
        
        self.sentences = d['sentences']
        self.indices = d['sub_indices']
        
        for sent in self.sentences:
            for i, word in enumerate(sent):
                sent[i] = word.lower().replace("'", "")
                if sent[i] == '':
                    del sent[i]
        
        self.glove = gensim.downloader.load('glove-wiki-gigaword-200')
        
        self.sentence_embeddings = []
        
        for sent in self.sentences:
            temp = []
            for word in sent:
                if word in self.glove:
                    temp.append(self.glove[word])
            self.sentence_embeddings.append(np.array(temp))
            
    def __len__(self):
        return len(self.sentence_embeddings)
    
    def __getitem__(self, idx):
        return (self.sentences[idx], self.sentence_embeddings[idx], self.indices[idx])

In [19]:
train_dataset = WordDataset('subdata.pkl')

In [72]:
# device = torch.device('cuda')
model = SimpleAttention(200, 20)
loss_fn = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(100):
    total_loss = 0
    accum = 0
    for i in range(len(train_dataset)):
        text, embed, ind = train_dataset[i]
        ind = torch.Tensor([ind]).long()
        if ind >= embed.shape[0]:
            continue
        embed = torch.Tensor(embed).unsqueeze(0)

        predv, sim = model(embed)
        loss = loss_fn(sim, ind)
        accum += loss
        total_loss += float(loss)
        
        if i % 100 == 0 and i > 0:
            accum.backward()
            optim.step()
            optim.zero_grad()
            accum = 0
    print(f'Total loss {epoch}: total_loss')
    total_loss = 0

10491.800028324127
10289.15780210495
10062.086075663567
9835.252136707306
9624.863267838955
9434.525323271751
9245.603696882725
9038.808286100626
8899.145880550146
8800.33431956172
8730.941065073013
8679.997298270464
8637.81130796671
8598.927635222673
8563.787618815899
8532.724863529205
8503.52269089222
8476.552208989859
8452.615196824074
8429.00962227583
8406.9192443192
8386.21566566825
8366.248629689217
8347.198793113232
8328.732744246721
8311.244562476873
8294.563530921936
8279.147120296955
8264.529611945152
8251.712412387133
8238.26609635353
8228.707840234041
8216.525311291218
8207.523184090853
8198.043177068233
8189.5745559334755
8181.149402856827
8174.284198731184
8165.919034898281
8158.627004683018
8151.622816860676
8145.710460275412
8138.279657155275
8132.270674645901
8125.916241884232
8120.878903865814
8114.144846647978
8108.331323415041
8102.998427391052
8098.113842129707
8092.157411187887
8087.759864360094
8082.242722094059
8077.952939063311
8072.407413512468
8068.3377857506

In [73]:
test_dataset = WordDataset('subdata_test.pkl')

In [75]:
correct = 0
for i in range(len(test_dataset)):
    text, embed, ind = test_dataset[i]
    embed = torch.Tensor(embed).unsqueeze(0)
    predv, sim = model(embed)
    predi = sim.max(dim=1).indices[0]
    correct += predi == ind
print(f'Accuracy: {correct/len(test_dataset)}')

Accuracy: 0.32690125703811646
