In [1]:
import time
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import accuracy_score

import torch
from torch import nn
from torch.autograd import gradcheck
from torch.utils.data import DataLoader

import utils
from utils import DataBuilder
from utils import jupyter_args

import DE_ATT

In [2]:
train = pd.read_csv('data/train.csv')
dev = pd.read_csv('data/dev.csv')

tokenizer = open('data/word2idx.pkl', 'rb')
tokenizer = pickle.load(tokenizer)

torch.manual_seed(42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

pretrained_weights = open('data/pretrained_weights_42b.pkl', 'rb')
pretrained_weights = pickle.load(pretrained_weights)
pretrained_weights = torch.FloatTensor(list(pretrained_weights.values()))

In [3]:
args = jupyter_args(embed_dim=300, 
                    hidden_state=200,
                    a_length=15,
                    b_length=13,
                    num_classes=3,
                    batch_size=32,
                    pretrained_weights=pretrained_weights,
                    learning_rate=0.05,
                    max_grad_norm=5,
                    train_interval=1000,
                    epoch_num=10)

In [4]:
trainbulider = DataBuilder(train, 'x1', 'x2', args.a_length, args.b_length, tokenizer, use_char=False)
trainloader = DataLoader(trainbulider, args.batch_size, shuffle=False)

devbulider = DataBuilder(dev, 'x1', 'x2', args.a_length, args.b_length, tokenizer, use_char=False)
devloader = DataLoader(devbulider, args.batch_size, shuffle=False)

de_att = DE_ATT.model(args).to(device)
param = filter(lambda p: p.requires_grad, de_att.parameters())

lossfunc = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adagrad(param, args.learning_rate, weight_decay=1e-5)

In [5]:
for i in range(args.epoch_num):
    
    de_att.train()
    for k, (x1, x2, _, __, label) in enumerate(trainloader):
        
        y_pred = []
        y_true = []
        loss_data = 0
        batch_count = len(trainloader)

        x1, x2 = x1.to(device), x2.to(device)
        label = label.view(-1).to(device)

        optimizer.zero_grad()

        logits = de_att(x1, x2)
        loss = lossfunc(logits, label)
        loss.backward()
        
        ''' get param and grad norm '''
        grad_norm = 0.
        para_norm = 0.

        for m in de_att.modules():
            if isinstance(m, nn.Linear):
                grad_norm += m.weight.grad.data.norm() ** 2
                para_norm += m.weight.data.norm() ** 2
                if m.bias is True:
                    grad_norm += m.bias.grad.data.norm() ** 2
                    para_norm += m.bias.data.norm() ** 2

        grad_norm ** 0.5
        para_norm ** 0.5

        shrinkage = args.max_grad_norm / grad_norm
        
        if shrinkage < 1 :
            for m in de_att.modules():
                if isinstance(m, nn.Linear):
                    m.weight.grad.data = m.weight.grad.data * shrinkage
                    if m.bias is True:
                        m.bias.grad.data = m.bias.grad.data * shrinkage
        
        optimizer.step()
        
        loss_data += loss.item()
        predict = logits.max(1)[1]
        y_true = np.array(label)
        y_pred = np.array(predict)
        
        correct = (y_true == y_pred).sum()
        total = len(y_true)
        
        if (k + 1) % 1000 == 0:
            print('epoch %d, batches %d|%d, train-acc %.3f, loss %.3f, para-norm %.3f, grad-norm %.3f,' %
                            (i, k + 1, batch_count, correct / total, loss_data, para_norm, grad_norm))
        
    de_att.eval()
    dev_true, dev_pred, dev_loss = de_att.evaluate(de_att, devloader, lossfunc)
    dev_acc = accuracy_score(dev_true, dev_pred)
    print('epoch %d, dev_loss %.3f, dec_acc %.3f' % (i, dev_loss, dev_acc))

epoch 0, batches 1000|17168, train-acc 0.469, loss 0.951, para-norm 1625.679, grad-norm 0.433,
epoch 0, batches 2000|17168, train-acc 0.500, loss 0.935, para-norm 1820.652, grad-norm 0.417,


KeyboardInterrupt: 

In [None]:
epoch 0, batches 1000|17168, train-acc 0.469, loss 0.951, para-norm 1625.679, grad-norm 0.433,

1995it [00:46, 42.98it/s]

epoch 0, batches 2000|17168, train-acc 0.500, loss 0.935, para-norm 1820.652, grad-norm 0.417,

2998it [01:10, 42.52it/s]

epoch 0, batches 3000|17168, train-acc 0.625, loss 0.886, para-norm 1874.947, grad-norm 0.448,

3997it [01:33, 42.63it/s]

epoch 0, batches 4000|17168, train-acc 0.594, loss 0.881, para-norm 2009.919, grad-norm 1.403,

4997it [01:56, 42.77it/s]

epoch 0, batches 5000|17168, train-acc 0.625, loss 0.893, para-norm 2093.544, grad-norm 1.546,

5997it [02:20, 42.82it/s]

epoch 0, batches 6000|17168, train-acc 0.656, loss 0.859, para-norm 2144.974, grad-norm 0.512,

6998it [02:43, 42.69it/s]

epoch 0, batches 7000|17168, train-acc 0.531, loss 1.025, para-norm 2233.912, grad-norm 1.993,
