In [0]:
from collections import defaultdict
import os
import pickle
from tqdm import tqdm
import random
import math

import numpy as np
from sklearn.metrics import cohen_kappa_score
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LogisticRegression
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW

import models
import data_utils
import evaluate


In [0]:
def train_lstm(args, model, X, y):
    """Train ``model`` in place on the ``'train'`` split.

    Optimises with Adam against the loss returned by ``bce_loss``, drawing
    batches from ``data_loader`` with per-pass shuffling enabled.

    Args:
        args: Mapping providing ``'epochs'``, ``'batch_size'`` and ``'lr'``.
        model: Callable invoked as ``model(sents, sent_lens, preds, heads)``;
            its output is passed to ``bce_loss`` as the logits. It must also
            expose ``parameters()`` for the optimizer.
        X: Mapping whose ``'train'`` entry holds the input examples.
        y: Mapping whose ``'train'`` entry holds the matching label rows.

    Returns:
        None. The model's parameters are updated in place.
    """
    epochs = args['epochs']
    batch_size = args['batch_size']
    lr = args['lr']

    # Data loader: an endless generator, so the number of batches that make
    # up one epoch has to be computed separately.
    loader_train = data_loader(X['train'], y['train'],
            batch_size=batch_size, shuffle_idx=True)
    n_train_batches = math.ceil(len(X['train']) / batch_size)

    # Optimizer
    opt = Adam(model.parameters(), lr=lr, betas=(0.9, 0.9))

    # Train loop
    try:
        for e in range(epochs):
            for _ in tqdm(
                    range(n_train_batches),
                    ascii=True,
                    desc=f'Epoch {e+1}/{epochs} progress',
                    ncols=80):
                opt.zero_grad()
                sents, sent_lens, preds, heads, labels = next(loader_train)
                logits = model(sents, sent_lens, preds, heads)
                loss = bce_loss(logits, labels)
                loss.backward()
                opt.step()

    except KeyboardInterrupt:
        # Deliberate: interrupting stops training early and the function
        # returns normally, keeping whatever the model has learned so far.
        pass
    # End of train loop
    return


def bce_loss(logits, labels):
    """Return the binary cross-entropy between raw ``logits`` and ``labels``.

    ``labels`` is expected to be shaped (B, num_properties). The sigmoid is
    applied inside the loss, so ``logits`` must be unnormalised scores.
    """
    return F.binary_cross_entropy_with_logits(input=logits, target=labels)


def data_loader(X, y, batch_size=None, shuffle_idx=False):
    """Endlessly yield collated batches built from paired ``X`` / ``y`` items.

    Each pass over the data optionally reshuffles the example order, cuts it
    into spans of ``batch_size`` indices and yields ``prepare_batch`` applied
    to the selected ``(X, y)`` pairs. The generator never terminates on its
    own; the caller decides how many batches to draw.

    Args:
        X: Iterable of input examples.
        y: Iterable of label rows, paired with ``X`` by position.
        batch_size: Maximum number of examples per batch. ``None`` means one
            batch containing every example.
        shuffle_idx: If true, reshuffle the example order before every pass.

    Raises:
        ValueError: On the first ``next()`` call, if there are no examples or
            ``batch_size`` is smaller than 1. Without this check the loop
            below would spin forever without yielding anything.
    """
    data = list(zip(X, y))
    if not data:
        raise ValueError('data_loader needs at least one (X, y) example')
    if batch_size is None:
        batch_size = len(data)
    elif batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    idx = list(range(len(data)))
    while True:
        if shuffle_idx:
            random.shuffle(idx) # In-place shuffle

        for span in idx_spans(idx, batch_size):
            batch = [data[i] for i in span]
            yield prepare_batch(batch)


def idx_spans(idx, span_size):
    """Lazily cut ``idx`` into consecutive slices of at most ``span_size`` items.

    The final slice is shorter when ``len(idx)`` is not a multiple of
    ``span_size``.
    """
    starts = range(0, len(idx), span_size)
    yield from (idx[lo:lo + span_size] for lo in starts)


def prepare_batch(batch):
    """Collate a list of ``(X, y)`` examples into padded tensors.

    Each ``X`` unpacks as ``(sent, (pred_idx, head_idx))`` and each ``y`` is a
    row of per-property labels. Sentences are right-padded with zeros up to
    the longest sentence in the batch.

    Returns:
        A tuple ``(sents, sent_lens, preds, heads, labels)``:
        ``sents`` is a long tensor of shape (B, max_len), ``sent_lens``,
        ``preds`` and ``heads`` are long tensors of shape (B,), and ``labels``
        is a default-dtype tensor of shape (B, n_properties).
    """
    n_examples = len(batch)
    sent_lens = torch.LongTensor([len(example[0][0]) for example in batch])
    longest = sent_lens.max().item()
    # The label width is taken from the first example.
    n_labels = len(batch[0][1])

    # Zero is padding index
    sents = torch.zeros((n_examples, longest), dtype=torch.long)
    preds = torch.zeros(n_examples, dtype=torch.long)
    heads = torch.zeros(n_examples, dtype=torch.long)
    labels = torch.zeros(n_examples, n_labels)

    for row, ((tokens, (pred_idx, head_idx)), targets) in enumerate(batch):
        # Only the leading len(tokens) slots are written; the rest stay 0.
        sents[row, :len(tokens)] = torch.LongTensor(tokens)
        preds[row] = pred_idx
        heads[row] = head_idx
        labels[row] = torch.tensor(targets)

    return sents, sent_lens, preds, heads, labels

In [0]:
# Hyperparameters for this run. train_lstm reads 'epochs', 'batch_size' and
# 'lr'; the model constructor below reads 'glove_d' and 'h_size'.
args = {
    'epochs': 3,
    'seed': 7,
    # NOTE(review): the four entries below are read by the code but were not
    # set anywhere in this cell; the values are placeholders — confirm them
    # before a real run.
    'batch_size': 32,
    'lr': 1e-3,
    'glove_d': 300,  # presumably must match the embedding width of emb_np — TODO confirm
    'h_size': 300,
}

# Seed every RNG in use so repeated runs draw the same random numbers.
seed = args['seed']
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

# NOTE(review): `data`, PAD_TOKEN and PROPERTIES are not defined in the
# visible cells; they must already exist in the kernel when this cell runs.
w2i, i2w = data['dicts']
emb_np = data['emb_np']
X, y = data['lstm_data']

# `args` is a plain dict, so its entries are read by key, not by attribute.
model = models.LSTM(
                vocab_size=len(w2i),
                emb_size=int(args['glove_d']),
                h_size=args['h_size'],
                padding_idx=w2i[PAD_TOKEN],
                emb_np=emb_np,
                properties=PROPERTIES)

# train_lstm takes the hyperparameter dict as its first argument.
train_lstm(args, model, X, y)