<h>Importing Required Libraries<h>

In [1]:
import torch
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
from torch import nn, optim
import os
import json
import numpy as np
import math
import sys
import random
import fewrel_dataset
import sentence_encoder
import siamese_model

<h>Make sure you run the <b>download_pretrain.sh<b> script, which will download the pretrained GloVe files, containg the embeddings for the tokens used in this dataset<h>

In [2]:
def collate_fn(data):
    data = data
    batch_support = {'word': [], 'pos1': [], 'pos2': [], 'mask': []}
    batch_query = {'word': [], 'pos1': [], 'pos2': [], 'mask': []}
    batch_label = []
    support_sets, query_sets, query_labels = zip(*data)
    for i in range(len(support_sets)):
        for k in support_sets[i]:
            batch_support[k] += support_sets[i][k]
        for k in query_sets[i]:
            batch_query[k] += query_sets[i][k]
        batch_label += query_labels[i]
    for k in batch_support:
        batch_support[k] = torch.stack(batch_support[k], 0)
    for k in batch_query:
        batch_query[k] = torch.stack(batch_query[k], 0)
    batch_label = torch.tensor(batch_label)
    return batch_support, batch_query, batch_label

def get_loader(name, encoder, N, K, Q, batch_size, 
        num_workers=0, collate_fn=collate_fn, na_rate=0, root='./data', testing=False):
    dataset = fewrel_dataset.FewRelDataset(name, encoder, N, K, Q, na_rate, root, testing)
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=num_workers,
            collate_fn=collate_fn)
    return iter(data_loader)

In [3]:
def item(x):
    '''
    PyTorch before and after 0.4
    '''
    torch_version = torch.__version__.split('.')
    if int(torch_version[0]) == 0 and int(torch_version[1]) < 4:
        return x[0]
    else:
        return x.item()

In [4]:
trainN = 10
batch_size = 5
max_length = 64
N = 5
K = 5
Q = 5

glove_mat = np.load('./pretrain/glove/glove.6B.50d_mat.npy')
glove_word2id = json.load(open('./pretrain/glove/glove.6B.50d_word2id.json'))

sentence_encoder = sentence_encoder.CNNSentenceEncoder(glove_mat, glove_word2id, max_length)

In [5]:
train_data_loader = get_loader('train_wiki', sentence_encoder, N=trainN, K=K, Q=Q, na_rate=0, batch_size=batch_size, testing=False)

In [6]:
def eval(model, N, K, Q, eval_iter, na_rate=0): 
    eval_dataset = get_loader('val_wiki', sentence_encoder,
        N=N, K=K, Q=Q, na_rate=0, batch_size=batch_size)

    iter_right = 0.0
    iter_sample = 0.0
    with torch.no_grad():
        for it in range(eval_iter):
            support, query, label = next(eval_dataset)

            logits, pred = model(support, query, N, K, Q * N + Q * na_rate)

        right = model.accuracy(pred, label)
        iter_right += item(right.data)
        iter_sample += 1

        sys.stdout.write('[EVAL] step: {0:4} | accuracy: {1:3.2f}%'.format(it + 1, 100 * iter_right / iter_sample) + '\r')
        sys.stdout.flush()
    return iter_right / iter_sample

In [None]:
train_iter = 10000
val_step = 200
val_iter = 100
hidden_size = 230
learning_rate = 1e-1
weight_decay = 1e-5
using_checkpoint = False


print('Start Training...')
best_acc = 0
iter_loss = 0.0
iter_loss_dis = 0.0
iter_right = 0.0
iter_right_dis = 0.0
iter_sample = 0.0

model = siamese_model.Siamese(sentence_encoder, hidden_size=hidden_size, dropout=0.0)
optimizer = optim.Adam(model.parameters(), learning_rate, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20000)

if using_checkpoint:
    state_dict = torch.load('checkpoint/siamese-cnn-train_wiki-val_wiki-5-5-siamese-cnn-train_wiki-val_wiki-5-5.pth.tar')['state_dict']
    own_state = model.state_dict()
    for name, param in state_dict.items():
        if name not in own_state:
            print('ignore {}'.format(name))
            continue
        print('load {} from {}'.format(name, 'siamese-cnn-train_wiki-val_wiki-5-5-siamese-cnn-train_wiki-val_wiki-5-5'))
        own_state[name].copy_(param)

model.train()

for it in range(train_iter):
    support, query, label = next(train_data_loader)
    logits, pred = model(support, query, trainN, K, Q * trainN)
    

    loss = model.loss(logits, label) / float(1)
    right = model.accuracy(pred, label)

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 10)

    iter_loss += item(loss.data)
    iter_right += item(right.data)
    iter_sample += 1

    print('step: {0:4} | loss: {1:2.6f}, accuracy: {2:3.2f}%'.format(it + 1, iter_loss / iter_sample, 100 * iter_right / iter_sample) + '\r')

    if (it + 1) % val_step == 0:
        acc = eval(model, N, K, Q, val_iter)
        model.train()
        if acc > best_acc:
            print('Best checkpoint')
            torch.save({'state_dict': model.state_dict()}, 'checkpoint/siamese-cnn-train_wiki-val_wiki-5-5-siamese-cnn-train_wiki-val_wiki-5-5.pth.tar')
            best_acc = acc
        iter_loss = 0.
        iter_loss_dis = 0.
        iter_right = 0.
        iter_right_dis = 0.
        iter_sample = 0.

    print("\n####################\n")
    print("Finish training ")


In [None]:
eval_iter = 100
eval_dataset = get_loader('val_wiki', sentence_encoder, N=N, K=K, Q=Q, na_rate=0, batch_size=batch_size, testing=False)
model.eval()

if using_checkpoint:
    ckpt = 'checkpoint/siamese-cnn-train_wiki-val_wiki-5-5-siamese-cnn-train_wiki-val_wiki-5-5.pth.tar'
    state_dict = torch.load(ckpt)['state_dict']
    own_state = model.state_dict()
    for name, param in state_dict.items():
        if name not in own_state:
            continue
        own_state[name].copy_(param)

iter_right = 0.0
iter_sample = 0.0
with torch.no_grad():
    for it in range(eval_iter):
        support, query, label = next(eval_dataset)
        logits, pred = model(support, query, N, K, Q * N)

    right = model.accuracy(pred, label)
    iter_right += item(right.data)
    iter_sample += 1

    print('[EVAL] step: {0:4} | accuracy: {1:3.2f}%'.format(it + 1, 100 * iter_right / iter_sample) + '\r')
    print("")

print("RESULT: %.2f" % (iter_right / iter_sample * 100))

[EVAL] step:  100 | accuracy: 77.60%
RESULT: 77.60


In [None]:
json_file = open('data/pid2name.json')
id2name = json.load(json_file)

if using_checkpoint:
    ckpt = 'checkpoint/siamese-cnn-train_wiki-val_wiki-5-5-siamese-cnn-train_wiki-val_wiki-5-5.pth.tar'
    state_dict = torch.load(ckpt)['state_dict']
    own_state = model.state_dict()
    for name, param in state_dict.items():
        if name not in own_state:
            continue
        own_state[name].copy_(param)

test_data_loader = get_loader('test_instance', sentence_encoder,
            N=5, K=1, Q=1, na_rate=0, batch_size=1, testing=True)

support, query, label = next(test_data_loader)
print('Label=', label)

logits, pred = model(support, query, 5, 1, 1)
print('Prediction=', pred)

mappings_file = open('class_mappings.json')
maps = json.load(mappings_file)
print('Actual Relation: ', id2name[maps[str(label.numpy()[0])]][0])
print('Predicted Relation: ', id2name[maps[str(pred.numpy()[0])]][0])

Label= tensor([3])
Prediction= tensor([0])
Actual Relation:  sport
Predicted Relation:  competition class
