Skip to content
Permalink
Browse files

Add relevance transfer package (#14)

* Add TREC relevance datasets

* Add relevance transfer trainer and evaluator

* Add re-ranking module

* Add ImbalancedDatasetSampler

* Add relevance transfer package

* Fix import in classification trainer
  • Loading branch information...
achyudh committed Apr 20, 2019
1 parent 284e2dd commit 99a01c61384cac02f1b3dbfe2ec243be18abd812
@@ -1,4 +1,5 @@
from common.evaluators.classification_evaluator import ClassificationEvaluator
from common.evaluators.relevance_transfer_evaluator import RelevanceTransferEvaluator


class EvaluatorFactory(object):
@@ -9,7 +10,10 @@ class EvaluatorFactory(object):
'Reuters': ClassificationEvaluator,
'AAPD': ClassificationEvaluator,
'IMDB': ClassificationEvaluator,
'Yelp2014': ClassificationEvaluator
'Yelp2014': ClassificationEvaluator,
'Robust04': RelevanceTransferEvaluator,
'Robust05': RelevanceTransferEvaluator,
'Robust45': RelevanceTransferEvaluator
}

@staticmethod
@@ -1,9 +1,9 @@
import numpy as np
import torch
import torch.nn.functional as F
import numpy as np

from sklearn import metrics
from .evaluator import Evaluator

from common.evaluators.evaluator import Evaluator


class ClassificationEvaluator(Evaluator):
@@ -0,0 +1,76 @@
import warnings

import numpy as np
import torch
import torch.nn.functional as F
from sklearn import metrics

from common.evaluators.evaluator import Evaluator

# Suppress warnings from sklearn.metrics
warnings.filterwarnings('ignore')


class RelevanceTransferEvaluator(Evaluator):
    """Evaluator for binary relevance-transfer tasks (TREC Robust04/05/45).

    Runs the model over a data loader and computes accuracy, positive-class
    precision, average precision, macro-F1 and mean binary cross-entropy
    loss, while retaining the raw predictions, targets and document ids
    (``self.y_pred``, ``self.y_target``, ``self.docid``) for downstream
    re-ranking.
    """

    def __init__(self, dataset_cls, model, embedding, data_loader, batch_size, device, keep_results=False):
        super().__init__(dataset_cls, model, embedding, data_loader, batch_size, device, keep_results)
        # When True, batches expose a bare text tensor; when False, batch.text
        # is a (tokens, lengths) pair and lengths are forwarded to the model
        self.ignore_lengths = False
        self.y_target = None
        self.y_pred = None
        self.docid = None

    def get_scores(self):
        """Evaluate the model and return ([metric values], [metric names])."""
        self.model.eval()
        self.data_loader.init_epoch()
        self.y_target = list()
        self.y_pred = list()
        self.docid = list()
        total_loss = 0

        if hasattr(self.model, 'beta_ema') and self.model.beta_ema > 0:
            # Temporal averaging: evaluate with the EMA weights, restore after
            old_params = self.model.get_params()
            self.model.load_ema_params()

        for batch_idx, batch in enumerate(self.data_loader):
            # Models with temporal activation regularization (tar) also return
            # their RNN outputs so the penalty can be computed
            if hasattr(self.model, 'tar') and self.model.tar:
                if self.ignore_lengths:
                    scores, rnn_outs = self.model(batch.text)
                else:
                    scores, rnn_outs = self.model(batch.text[0], lengths=batch.text[1])
            else:
                if self.ignore_lengths:
                    scores = self.model(batch.text)
                else:
                    scores = self.model(batch.text[0], lengths=batch.text[1])

            # Computing loss and storing predictions
            predictions = torch.sigmoid(scores).squeeze(dim=1)
            total_loss += F.binary_cross_entropy(predictions, batch.label.float()).item()
            self.docid.extend(batch.docid.cpu().detach().numpy())
            self.y_pred.extend(predictions.cpu().detach().numpy())
            self.y_target.extend(batch.label.cpu().detach().numpy())

            if hasattr(self.model, 'tar') and self.model.tar:
                # Temporal activation regularization penalty for this batch.
                # Fix: take .item() so total_loss stays a plain float instead of
                # silently becoming a tensor (which would make avg_loss a tensor)
                total_loss += (rnn_outs[1:] - rnn_outs[:-1]).pow(2).mean().item()

        predicted_labels = np.around(np.array(self.y_pred))
        target_labels = np.array(self.y_target)
        accuracy = metrics.accuracy_score(target_labels, predicted_labels)
        # NOTE(review): AP is computed on thresholded labels rather than the raw
        # scores in self.y_pred, which discards ranking information — confirm
        # this is intended before changing it
        average_precision = metrics.average_precision_score(target_labels, predicted_labels, average=None)
        f1 = metrics.f1_score(target_labels, predicted_labels, average='macro')
        avg_loss = total_loss / len(self.data_loader.dataset.examples)

        try:
            # Index 1 is the positive (relevant) class
            precision = metrics.precision_score(target_labels, predicted_labels, average=None)[1]
        except IndexError:
            # Handle cases without positive labels
            precision = 0

        if hasattr(self.model, 'beta_ema') and self.model.beta_ema > 0:
            # Temporal averaging: restore the raw (non-EMA) parameters
            self.model.load_params(old_params)

        return [accuracy, precision, average_precision, f1, avg_loss], \
            ['accuracy', 'precision', 'average_precision', 'f1', 'cross_entropy_loss']
@@ -1,4 +1,5 @@
from common.trainers.classification_trainer import ClassificationTrainer
from common.trainers.relevance_transfer_trainer import RelevanceTransferTrainer


class TrainerFactory(object):
@@ -9,7 +10,10 @@ class TrainerFactory(object):
'Reuters': ClassificationTrainer,
'AAPD': ClassificationTrainer,
'IMDB': ClassificationTrainer,
'Yelp2014': ClassificationTrainer
'Yelp2014': ClassificationTrainer,
'Robust04': RelevanceTransferTrainer,
'Robust05': RelevanceTransferTrainer,
'Robust45': RelevanceTransferTrainer,
}

@staticmethod
@@ -1,8 +1,8 @@
import datetime
import os
import time

import datetime
import numpy as np
import os
import torch
import torch.nn.functional as F
from tensorboardX import SummaryWriter
@@ -0,0 +1,133 @@
import datetime
import os
import time

import torch
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from tqdm import trange, tqdm

from common.trainers.trainer import Trainer
from tasks.relevance_transfer.resample import ImbalancedDatasetSampler


class RelevanceTransferTrainer(Trainer):
    """Trainer for binary relevance-transfer models on TREC collections.

    Adds optional balanced resampling of each batch, temporal activation
    regularization (TAR) support, temporal parameter averaging, TensorBoard
    logging, and early stopping driven by dev-set F1.
    """

    def __init__(self, model, embedding, train_loader, trainer_config, train_evaluator, test_evaluator, dev_evaluator):
        super(RelevanceTransferTrainer, self).__init__(model, embedding, train_loader, trainer_config, train_evaluator, test_evaluator, dev_evaluator)
        self.config = trainer_config
        self.early_stop = False
        # NOTE(review): despite the name, this attribute tracks the best dev F1
        # (see train()); name kept so external readers of the attribute still work
        self.best_dev_ap = 0
        self.iterations = 0
        self.iters_not_improved = 0
        self.start = None

        self.log_header = 'Epoch Iteration Progress Dev/Acc. Dev/Pr. Dev/AP. Dev/F1 Dev/Loss'
        self.log_template = ' '.join('{:>5.0f},{:>9.0f},{:>6.0f}/{:<5.0f} {:>6.4f},{:>8.4f},{:8.4f},{:8.4f},{:10.4f}'.split(','))

        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.writer = SummaryWriter(log_dir="tensorboard_logs/" + timestamp)
        self.snapshot_path = os.path.join(self.model_outfile, self.train_loader.dataset.NAME, '%s.pt' % timestamp)

    def train_epoch(self, epoch):
        """Train for one epoch, logging loss and running accuracy to TensorBoard."""
        self.train_loader.init_epoch()
        n_correct, n_total = 0, 0

        for batch_idx, batch in enumerate(tqdm(self.train_loader, desc="Training")):
            self.iterations += 1
            self.model.train()
            self.optimizer.zero_grad()

            # Randomly sample an equal number of positive and negative documents
            if 'ignore_lengths' in self.config and self.config['ignore_lengths']:
                if 'resample' in self.config and self.config['resample']:
                    indices = ImbalancedDatasetSampler(batch.text, batch.label).get_indices()
                    batch_text = batch.text[indices]
                    batch_label = batch.label[indices]
                else:
                    batch_text = batch.text
                    batch_label = batch.label
            else:
                if 'resample' in self.config and self.config['resample']:
                    indices = ImbalancedDatasetSampler(batch.text[0], batch.label).get_indices()
                    batch_text = batch.text[0][indices]
                    batch_lengths = batch.text[1][indices]
                    # Fix: labels must be subsampled with the same indices as the
                    # text, otherwise predictions and targets differ in size and
                    # binary_cross_entropy below fails
                    batch_label = batch.label[indices]
                else:
                    batch_text = batch.text[0]
                    batch_lengths = batch.text[1]
                    batch_label = batch.label

            # Models with temporal activation regularization (tar) also return
            # their RNN outputs so the penalty can be computed
            if hasattr(self.model, 'tar') and self.model.tar:
                if 'ignore_lengths' in self.config and self.config['ignore_lengths']:
                    scores, rnn_outs = self.model(batch_text)
                else:
                    scores, rnn_outs = self.model(batch_text, lengths=batch_lengths)
            else:
                if 'ignore_lengths' in self.config and self.config['ignore_lengths']:
                    scores = self.model(batch_text)
                else:
                    scores = self.model(batch_text, lengths=batch_lengths)

            # Computing accuracy and loss
            predictions = torch.sigmoid(scores).squeeze(dim=1)
            for tensor1, tensor2 in zip(predictions.round(), batch_label):
                try:
                    if int(tensor1.item()) == int(tensor2.item()):
                        n_correct += 1
                except ValueError:
                    # Ignore NaN/Inf values
                    pass

            loss = F.binary_cross_entropy(predictions, batch_label.float())

            if hasattr(self.model, 'tar') and self.model.tar:
                # Temporal activation regularization penalty
                loss = loss + (rnn_outs[1:] - rnn_outs[:-1]).pow(2).mean()

            # Fix: count the documents actually scored; with resampling enabled
            # this differs from batch.batch_size
            n_total += batch_label.size(0)
            train_acc = n_correct / n_total
            loss.backward()
            # Fix: clip gradients AFTER backward() (gradients do not exist before
            # it) and before the optimizer step, to address exploding gradients
            # in the LSTM
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 25.0)
            self.optimizer.step()

            if hasattr(self.model, 'beta_ema') and self.model.beta_ema > 0:
                # Temporal averaging of model parameters
                self.model.update_ema()

            if self.iterations % self.log_interval == 1:
                niter = epoch * len(self.train_loader) + batch_idx
                self.writer.add_scalar('Train/Loss', loss.data.item(), niter)
                self.writer.add_scalar('Train/Accuracy', train_acc, niter)

    def train(self, epochs):
        """Train for up to ``epochs`` epochs with F1-based early stopping on dev."""
        self.start = time.time()
        # model_outfile is actually a directory, using model_outfile to conform to Trainer naming convention
        os.makedirs(self.model_outfile, exist_ok=True)
        os.makedirs(os.path.join(self.model_outfile, self.train_loader.dataset.NAME), exist_ok=True)

        for epoch in trange(1, epochs + 1, desc="Epoch"):
            self.train_epoch(epoch)

            # Evaluate performance on the validation set
            dev_acc, dev_precision, dev_ap, dev_f1, dev_loss = self.dev_evaluator.get_scores()[0]
            self.writer.add_scalar('Dev/Loss', dev_loss, epoch)
            self.writer.add_scalar('Dev/Accuracy', dev_acc, epoch)
            self.writer.add_scalar('Dev/Precision', dev_precision, epoch)
            self.writer.add_scalar('Dev/AP', dev_ap, epoch)
            tqdm.write(self.log_header)
            tqdm.write(self.log_template.format(epoch, self.iterations, epoch + 1, epochs,
                                                dev_acc, dev_precision, dev_ap, dev_f1, dev_loss))

            # Snapshot the model whenever dev F1 improves; otherwise count
            # towards the early-stopping patience
            if dev_f1 > self.best_dev_ap:
                self.iters_not_improved = 0
                self.best_dev_ap = dev_f1
                torch.save(self.model, self.snapshot_path)
            else:
                self.iters_not_improved += 1
                if self.iters_not_improved >= self.patience:
                    self.early_stop = True
                    tqdm.write("Early Stopping. Epoch: {}, Best Dev F1: {}".format(epoch, self.best_dev_ap))
                    break
@@ -0,0 +1,65 @@
import csv
import os
import sys

import torch
from torchtext.data import NestedField, Field, TabularDataset
from torchtext.data.iterator import BucketIterator
from torchtext.vocab import Vectors

from datasets.robust45 import clean_string, split_sents, process_docids, process_labels

csv.field_size_limit(sys.maxsize)


class Robust04(TabularDataset):
    """TREC Robust04 relevance dataset backed by per-topic TSV files.

    Each example carries a binary relevance label, a document id and the
    document text; splits are read from ``TREC/robust04_*_<topic>.tsv``.
    """

    NAME = 'Robust04'
    NUM_CLASSES = 2
    TEXT_FIELD = Field(batch_first=True, tokenize=clean_string, include_lengths=True)
    LABEL_FIELD = Field(sequential=False, use_vocab=False, batch_first=True, preprocessing=process_labels)
    DOCID_FIELD = Field(sequential=False, use_vocab=False, batch_first=True, preprocessing=process_docids)
    TOPICS = ['307', '310', '321', '325', '330', '336', '341', '344', '345', '347', '350', '353', '354', '355', '356',
              '362', '363', '367', '372', '375', '378', '379', '389', '393', '394', '397', '399', '400', '404', '408',
              '414', '416', '419', '422', '423', '426', '427', '433', '435', '436', '439', '442', '443', '445', '614',
              '620', '626', '646', '677', '690']

    @staticmethod
    def sort_key(ex):
        # Bucket examples by document length so batches need minimal padding
        return len(ex.text)

    @classmethod
    def splits(cls, path, train, validation, test, **kwargs):
        # Column order in the TSV files: label, docid, text
        dataset_fields = [('label', cls.LABEL_FIELD), ('docid', cls.DOCID_FIELD), ('text', cls.TEXT_FIELD)]
        return super(Robust04, cls).splits(path, train=train, validation=validation, test=test,
                                           format='tsv', fields=dataset_fields)

    @classmethod
    def iters(cls, path, vectors_name, vectors_cache, topic, batch_size=64, shuffle=True, device=0,
              vectors=None, unk_init=torch.Tensor.zero_):
        """Build train/dev/test BucketIterators for a single topic.

        :param path: directory containing train, test, dev files
        :param vectors_name: name of word vectors file
        :param vectors_cache: path to directory containing word vectors file
        :param topic: topic from which articles should be fetched
        :param batch_size: batch size
        :param device: GPU device
        :param vectors: custom vectors - either predefined torchtext vectors or your own custom Vector classes
        :param unk_init: function used to generate vector for OOV words
        :return: train, dev and test BucketIterators
        """
        if vectors is None:
            vectors = Vectors(name=vectors_name, cache=vectors_cache, unk_init=unk_init)

        # Train/dev come from Robust04; the test pool is Core17 (10k docs)
        train_file = os.path.join('TREC', 'robust04_train_%s.tsv' % topic)
        dev_file = os.path.join('TREC', 'robust04_dev_%s.tsv' % topic)
        test_file = os.path.join('TREC', 'core17_10k_%s.tsv' % topic)

        train_split, dev_split, test_split = cls.splits(path, train=train_file,
                                                        validation=dev_file, test=test_file)
        cls.TEXT_FIELD.build_vocab(train_split, dev_split, test_split, vectors=vectors)
        return BucketIterator.splits((train_split, dev_split, test_split), batch_size=batch_size,
                                     repeat=False, shuffle=shuffle, sort_within_batch=True, device=device)


class Robust04Hierarchical(Robust04):
    # Hierarchical variant of Robust04: the text field is nested, with the
    # outer tokenizer `split_sents` (presumably splitting a document into
    # sentences -- confirm against datasets.robust45) and the inner field
    # tokenizing each sentence with `clean_string`.
    NESTING_FIELD = Field(batch_first=True, tokenize=clean_string)
    TEXT_FIELD = NestedField(NESTING_FIELD, tokenize=split_sents)
Oops, something went wrong.

0 comments on commit 99a01c6

Please sign in to comment.
You can’t perform that action at this time.