In [3]:
# Colab-only step: mount the user's Google Drive at /content/drive so the
# data files referenced further down can be read through the filesystem.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
import os

# Root of the project data on the mounted Google Drive.
# The path must be absolute (Drive is mounted at /content/drive) and the space
# in "My Drive" is a plain space: a shell-style "\ " escape is kept by Python
# as a literal backslash and produces a directory name that does not exist.
DATA_DIR = '/content/drive/My Drive/proto_data/'
# Directory holding the pre-computed pickle files.
PICKLED_DIR = os.path.join(DATA_DIR, 'pickled/')
#CONLLU_DIR = os.path.join(DATA_DIR, 'WSJ_conllus/')
#MODEL_DIR = '../saved_models/'

# Tab-separated annotation file read with pandas in get_data().
PROTO_TSV = os.path.join(DATA_DIR, 'protoroles_eng_pb_08302015.tsv')
#GLOVE_FILE = {'100': os.path.join(DATA_DIR, 'glove.6B.100d.txt') }

# Names of the dataset splits.
SPLITS = ['train', 'dev', 'test']

# Property names handed to the model constructor (`properties=PROPERTIES`).
PROPERTIES = [
    'instigation', 'volition', 'awareness', 'sentient',
    'exists_as_physical', 'existed_before', 'existed_during', 'existed_after',
    'created', 'destroyed', 'predicate_changed_argument', 'change_of_state',
    'changes_possession', 'change_of_location', 'stationary', 'location_of_event',
    'makes_physical_contact', 'manipulated_by_another',
]

# Special vocabulary tokens; PAD_TOKEN is looked up in the word->index dict to
# obtain the model's padding index.
PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'


In [0]:
import math
import os
import pickle
import random
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from tqdm import tqdm



In [7]:
def train_lstm(args, model, X, y):
    """Optimise `model` on the training split with Adam and a BCE loss.

    Args:
        args: mapping providing 'epochs', 'batch_size' and 'lr'.
        model: callable invoked as model(sents, sent_lens, preds, heads);
            its parameters are updated in place by the optimizer.
        X, y: mappings whose 'train' entries hold the inputs and the targets.

    Returns:
        None. A KeyboardInterrupt raised during training is swallowed so the
        run can be stopped by hand without a traceback.
    """
    epochs = args['epochs']
    batch_size = args['batch_size']
    lr = args['lr']

    # Endless shuffled batch stream over the training split.
    batches = data_loader(
        X['train'], y['train'], batch_size=batch_size, shuffle_idx=True)
    # Number of batches that make up one pass over the training data.
    steps_per_epoch = math.ceil(len(X['train']) / batch_size)

    optimizer = Adam(model.parameters(), lr=lr, betas=[0.9, 0.9])

    try:
        for e in range(epochs):
            progress = tqdm(
                range(steps_per_epoch),
                ascii=True,
                desc=f'Epoch {e+1}/{epochs} progress',
                ncols=80)
            for _ in progress:
                optimizer.zero_grad()
                sents, sent_lens, preds, heads, labels = next(batches)
                loss = bce_loss(model(sents, sent_lens, preds, heads), labels)
                loss.backward()
                optimizer.step()
    except KeyboardInterrupt:
        # Manual stop: leave the model as trained so far.
        pass
    return None


def bce_loss(logits, labels):
    """Return the binary cross-entropy between raw logits and target labels.

    The sigmoid is applied inside the loss, so `logits` must be unnormalised
    scores. Labels are expected with shape (B, num_properties).
    """
    return F.binary_cross_entropy_with_logits(logits, labels)


def data_loader(X, y, batch_size=None, shuffle_idx=False):
    """Endlessly yield prepared batches built from the paired examples of X and y.

    Args:
        X, y: equally ordered iterables; they are zipped into (x, y) examples.
        batch_size: maximum number of examples per batch. ``None`` means one
            batch containing the whole dataset.
        shuffle_idx: when true, the example order is reshuffled before every
            pass over the data.

    Yields:
        Whatever ``prepare_batch`` returns for each successive batch. The
        generator never terminates; callers pull from it with ``next()``.

    Raises:
        ValueError: on the first ``next()`` call if there are no examples or
            ``batch_size`` is smaller than 1. Without these checks the loop
            below would spin forever without yielding anything.
    """
    data = list(zip(X, y))
    if not data:
        raise ValueError('data_loader needs at least one (X, y) example')
    if batch_size is None:
        batch_size = len(data)
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    idx = list(range(len(data)))
    while True:
        if shuffle_idx:
            random.shuffle(idx)  # In-place shuffle

        for span in idx_spans(idx, batch_size):
            yield prepare_batch([data[i] for i in span])


def idx_spans(idx, span_size):
    """Yield consecutive slices of `idx`, each holding at most `span_size` items."""
    starts = range(0, len(idx), span_size)
    yield from (idx[start:start + span_size] for start in starts)


def prepare_batch(batch):
    """Collate a list of examples into padded tensors.

    Every element of `batch` is unpacked as
    ``((sent, (pred_idx, head_idx)), y)``.

    Returns:
        A tuple ``(sents, sent_lens, preds, heads, labels)`` where
        * sents is a long tensor (batch, longest sentence), right-padded with 0,
        * sent_lens is a long tensor with the unpadded sentence lengths,
        * preds and heads are long tensors with one index per example,
        * labels is a float tensor (batch, number of labels of the first example).
    """
    n_examples = len(batch)
    sent_lens = torch.tensor(
        [len(example[0][0]) for example in batch], dtype=torch.long)
    longest = sent_lens.max().item()
    n_labels = len(batch[0][1])

    # Index 0 doubles as the padding index, so zero-filled buffers are pre-padded.
    sents = torch.zeros(n_examples, longest, dtype=torch.long)
    preds = torch.zeros(n_examples, dtype=torch.long)
    heads = torch.zeros(n_examples, dtype=torch.long)
    labels = torch.zeros(n_examples, n_labels)

    for row, ((tokens, (pred_pos, head_pos)), targets) in enumerate(batch):
        sents[row, :len(tokens)] = torch.LongTensor(tokens)
        preds[row] = pred_pos
        heads[row] = head_pos
        labels[row] = torch.tensor(targets)

    return sents, sent_lens, preds, heads, labels

SyntaxError: ignored

In [0]:
def _load_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    # NOTE: pickle.load can execute arbitrary code; only point this at files
    # produced by this project.
    with open(path, 'rb') as f:
        return pickle.load(f)


def get_data(args):
    """Load the annotation table and the cached artefacts used for training.

    Args:
        args: configuration providing ``glove_d`` (the GloVe dimensionality
            used in the embedding pickle's file name). Both a dict
            (``args['glove_d']``) and an attribute-style object
            (``args.glove_d``) are accepted.

    Returns:
        dict with the keys 'df', 'proto_instances', 'possible', 'sents',
        'w2e', 'sent_ids', 'lstm_data' (an ``(X, y)`` tuple), 'dicts'
        (a ``(w2i, i2w)`` tuple) and 'emb_np'.

    Notes:
        The dependency data is read from ``dependencies.pkl`` when that file
        exists. It is only recomputed through ``data_utils.get_dependencies``
        when the cache is missing, and the cache file is written after the
        computation succeeded so that a failure cannot truncate an existing
        file.
    """
    # The notebook passes a plain dict; attribute access on it would raise.
    glove_d = args['glove_d'] if isinstance(args, dict) else args.glove_d

    df = pd.read_csv(PROTO_TSV, sep='\t')

    # Sentences
    sent_ids = set(df['Sentence.ID'].tolist())
    sents = _load_pickle(os.path.join(PICKLED_DIR, 'sents.pkl'))

    # Dependency data
    deps_path = os.path.join(PICKLED_DIR, 'dependencies.pkl')
    if os.path.exists(deps_path):
        deps, deps_just_tokens = _load_pickle(deps_path)
    else:
        # Compute first, open the file for writing only afterwards.
        deps, deps_just_tokens = data_utils.get_dependencies(sent_ids)
        with open(deps_path, 'wb') as f:
            pickle.dump((deps, deps_just_tokens), f)
    sents['dependencies'] = deps
    sents['deps_just_tokens'] = deps_just_tokens

    # Instances; `possible` is data to compare to the SPRL paper.
    proto_instances, possible = _load_pickle(
        os.path.join(PICKLED_DIR, 'instances.pkl'))

    # Word embedding data
    w2e = _load_pickle(os.path.join(PICKLED_DIR, f'glove_{glove_d}.pkl'))
    w2i, i2w = _load_pickle(os.path.join(PICKLED_DIR, 'dicts.pkl'))
    emb_np = _load_pickle(os.path.join(PICKLED_DIR, 'emb_np.pkl'))
    X, y = _load_pickle(os.path.join(PICKLED_DIR, 'lstm_data.pkl'))

    return {'df': df,
            'proto_instances': proto_instances,
            'possible': possible,
            'sents': sents,
            'w2e': w2e,
            'sent_ids': sent_ids,
            'lstm_data': (X, y),
            'dicts': (w2i, i2w),
            'emb_np': emb_np}


In [0]:
# Hyper-parameters and run configuration. `args` is a plain dict, so every
# value is read with subscript access (args['key']), never as an attribute.
args = {
    'epochs': 3,
    'seed': 7,
    'lr': 1e-3,
    'batch_size': 10,
    'h_size': 100,
    'glove_d': 100
}

# Seed every random number generator that is used, for reproducible runs.
seed = args['seed']

random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)


# Load the cached artefacts (vocabulary dicts, embedding matrix, LSTM data).
data = get_data(args)

w2i, i2w = data['dicts']
emb_np = data['emb_np']
X, y = data['lstm_data']

# NOTE(review): `models` is not imported anywhere in the visible notebook; the
# project module that defines `LSTM` must be imported before this cell runs.
model = models.LSTM(
                vocab_size=len(w2i),
                emb_size=int(args['glove_d']),
                h_size=args['h_size'],
                padding_idx=w2i[PAD_TOKEN],
                emb_np=emb_np,
                properties=PROPERTIES)

train_lstm(args, model, X, y)