In [83]:
%load_ext autoreload
%autoreload 2
import pandas as pd
import re
import spacy
import numpy as np
import gzip
import gensim.downloader
import torch
from torch.utils.data import Dataset, DataLoader
import pickle
from model import SimpleAttention
from tqdm import tqdm
import time
import random

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [57]:
glove = gensim.downloader.load('glove-wiki-gigaword-200')

In [58]:
class WordDataset(Dataset):
    
    def __init__(self, sfile):
        print("Loading data")
        d = pickle.load(open(sfile, 'rb'))
        
        self.sentences = d['sentences']
        self.indices = d['sub_indices']
        
        for i, sent in enumerate(self.sentences):
            self.sentences[i] = [word.lower().replace("'", "") for word in sent if sent and sent != "''"]
        
        print("Downloading")
        
        self.sentence_embeddings = []
        
        print("Generating embeddings")
        for sent in tqdm(self.sentences):
            try:
                self.sentence_embeddings.append(glove[sent])
            except:
                continue
            
    def __len__(self):
        return len(self.sentence_embeddings)
    
    def __getitem__(self, idx):
        return (self.sentences[idx], self.sentence_embeddings[idx], self.indices[idx])

In [89]:
train_dataset = WordDataset('subdata_jpdy_train.pkl')

Loading data


  5%|▌         | 3780/70650 [00:00<00:01, 37789.76it/s]

Downloading
Generating embeddings


100%|██████████| 70650/70650 [00:01<00:00, 56613.38it/s]


In [91]:
test_dataset = WordDataset('subdata_jpdy_test.pkl')

 76%|███████▌  | 5956/7853 [00:00<00:00, 59537.06it/s]

Loading data
Downloading
Generating embeddings


100%|██████████| 7853/7853 [00:00<00:00, 58754.99it/s]


In [92]:
print(len(train_dataset))
print(len(test_dataset))

21832
2435


In [94]:
# device = torch.device('cuda')
device = torch.device('cpu')
model = SimpleAttention(200, 40, norm=False).to(device)
# model.load_state_dict(torch.load('model40.pth'))
# model.load_state_dict(torch.load('nonorm20.pth'))
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
accum_freq = 100
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optim, 30, gamma=0.7)
for epoch in range(200):
    start_time = time.time()
    total_loss = 0
    total_acc = 0
    accum = 0
    ind_order = list(range(len(train_dataset)))
    random.shuffle(ind_order)
    for i in ind_order:
        text, embed, ind = train_dataset[i]
        ind = torch.Tensor([ind]).long().to(device)
        if ind >= embed.shape[0]:
            continue
        embed = torch.Tensor(embed).unsqueeze(0).to(device)

        predv, sim = model(embed)
        loss = loss_fn(sim, ind)
        accum += loss
        total_loss += float(loss)
        
        acc = sim.max(dim=1).indices[0] == ind
        total_acc += float(acc)
        
        if i % accum_freq == 0 and i > 0:
            accum.backward()
            optim.step()
            optim.zero_grad()
            accum = 0
    correct = 0
    for i in range(len(test_dataset)):
        text, embed, ind = test_dataset[i]
        embed = torch.Tensor(embed).unsqueeze(0).to(device)
        predv, sim = model(embed)
        predi = sim.max(dim=1).indices[0].cpu()
        correct += predi == ind
    print(f'Total loss {epoch}: {total_loss/len(train_dataset)}, Test Accuracy: {correct/len(test_dataset)}, Accuracy: {total_acc/len(train_dataset)}, LR: {scheduler.get_lr()}, Time: {time.time()-start_time}')
    total_loss = 0
    scheduler.step()

Total loss 0: 2674.5068916485993, Test Accuracy: 0.08459959179162979, Accuracy: 0.07557713448149506, LR: [0.001], Time: 18.46752643585205


In [38]:
correct = 0
for i in range(len(test_dataset)):
    text, embed, ind = test_dataset[i]
    embed = torch.Tensor(embed).unsqueeze(0)
    predv, sim = model(embed)
    predi = sim.max(dim=1).indices[0]
    correct += predi == ind
print(f'Accuracy: {correct/len(test_dataset)}')

Accuracy: 0.2360953390598297


In [18]:
torch.save(model.state_dict(), 'nonorm40.pth')