
Add support for BERT for relevance transfer (#19)

* Add Robust45 preprocessor for BERT
* Add support for BERT for relevance transfer
achyudh committed May 24, 2019
1 parent f975e84 commit 3cd54c2fccc4dd559d7bee74c680a48c3abfad93
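
For context, a minimal sketch of how a caller might construct the evaluator this commit adds. The config keys and keyword arguments mirror what the class reads below, but the concrete values, the `Robust45Processor` name, and the surrounding `model`/`dataset` objects are hypothetical placeholders, not part of this commit:

    # Illustrative sketch only: key names mirror the evaluator code below;
    # the values, Robust45Processor, and model/dataset objects are hypothetical.
    config = {
        'model': 'BERT-Base',               # or 'BERT-Large'
        'is_lowercase': True,
        'split': 'test',                    # 'test' or 'dev'
        'topic': '307',                     # hypothetical Robust45 topic id
        'data_dir': 'data/robust45',        # hypothetical path
        'batch_size': 16,
        'device': 'cuda',
        'max_seq_length': 256,
        'n_gpu': 1,
        'gradient_accumulation_steps': 1,
        'ignore_lengths': True,
    }

    evaluator = RelevanceTransferEvaluator(
        model, config,
        dataset=dataset,                    # dataset class expected by Evaluator
        embedding=None,                     # unused on the BERT path
        data_loader=None,                   # unused on the BERT path
        processor=Robust45Processor(),      # supplies get_dev/get_test examples
    )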
import warnings

import numpy as np
import torch
import torch.nn.functional as F
from sklearn import metrics
from torch.utils.data import TensorDataset, SequentialSampler, DataLoader
from tqdm import tqdm

from common.evaluators.evaluator import Evaluator
from datasets.bert_processors.robust45_processor import convert_examples_to_features
from utils.tokenization import BertTokenizer

# Suppress warnings from sklearn.metrics
warnings.filterwarnings('ignore')


class RelevanceTransferEvaluator(Evaluator):

    def __init__(self, model, config, **kwargs):
        super().__init__(kwargs['dataset'], model, kwargs['embedding'], kwargs['data_loader'],
                         batch_size=config['batch_size'], device=config['device'])

        if config['model'] in {'BERT-Base', 'BERT-Large'}:
            # For BERT, build the tokenizer and load the evaluation split
            # for the current topic through the Robust45 processor
            variant = 'bert-large-uncased' if config['model'] == 'BERT-Large' else 'bert-base-uncased'
            self.tokenizer = BertTokenizer.from_pretrained(variant, is_lowercase=config['is_lowercase'])
            self.processor = kwargs['processor']
            if config['split'] == 'test':
                self.eval_examples = self.processor.get_test_examples(config['data_dir'], topic=config['topic'])
            else:
                self.eval_examples = self.processor.get_dev_examples(config['data_dir'], topic=config['topic'])

        self.config = config
        self.ignore_lengths = config['ignore_lengths']
        self.y_target = None
        self.y_pred = None
        self.docid = None

    def get_scores(self, silent=False):
        self.model.eval()
        self.y_target = list()
        self.y_pred = list()
        self.docid = list()
        total_loss = 0

        if self.config['model'] in {'BERT-Base', 'BERT-Large'}:
            eval_features = convert_examples_to_features(self.eval_examples, self.config['max_seq_length'], self.tokenizer)

            all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
            all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
            all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
            all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
            all_document_ids = torch.tensor([f.guid for f in eval_features], dtype=torch.long)

            eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_document_ids)
            eval_sampler = SequentialSampler(eval_data)
            eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=self.config['batch_size'])

            for input_ids, input_mask, segment_ids, label_ids, document_ids in tqdm(eval_dataloader, desc="Evaluating", disable=silent):
                input_ids = input_ids.to(self.config['device'])
                input_mask = input_mask.to(self.config['device'])
                segment_ids = segment_ids.to(self.config['device'])
                label_ids = label_ids.to(self.config['device'])

                with torch.no_grad():
                    logits = torch.sigmoid(self.model(input_ids, segment_ids, input_mask)).squeeze(dim=1)

                # Computing loss and storing predictions
                self.docid.extend(document_ids.cpu().detach().numpy())
                self.y_pred.extend(logits.cpu().detach().numpy())
                self.y_target.extend(label_ids.cpu().detach().numpy())
                loss = F.binary_cross_entropy(logits, label_ids.float())

                if self.config['n_gpu'] > 1:
                    loss = loss.mean()
                if self.config['gradient_accumulation_steps'] > 1:
                    loss = loss / self.config['gradient_accumulation_steps']
                total_loss += loss.item()
        else:
            self.data_loader.init_epoch()

            if hasattr(self.model, 'beta_ema') and self.model.beta_ema > 0:
                # Temporal averaging: evaluate with the EMA copy of the weights
                old_params = self.model.get_params()
                self.model.load_ema_params()

            for batch in tqdm(self.data_loader, desc="Evaluating", disable=silent):
                if hasattr(self.model, 'tar') and self.model.tar:
                    # The model returns (scores, rnn_outs); unpack before
                    # the sigmoid, which cannot be applied to a tuple
                    if self.ignore_lengths:
                        scores, rnn_outs = self.model(batch.text)
                    else:
                        scores, rnn_outs = self.model(batch.text[0], lengths=batch.text[1])
                    logits = torch.sigmoid(scores).squeeze(dim=1)
                else:
                    if self.ignore_lengths:
                        logits = torch.sigmoid(self.model(batch.text)).squeeze(dim=1)
                    else:
                        logits = torch.sigmoid(self.model(batch.text[0], lengths=batch.text[1])).squeeze(dim=1)

                total_loss += F.binary_cross_entropy(logits, batch.label.float()).item()
                if hasattr(self.model, 'tar') and self.model.tar:
                    # Temporal activation regularization; .item() keeps
                    # total_loss a Python float
                    total_loss += (rnn_outs[1:] - rnn_outs[:-1]).pow(2).mean().item()

                # Storing predictions alongside their document ids
                self.docid.extend(batch.docid.cpu().detach().numpy())
                self.y_pred.extend(logits.cpu().detach().numpy())
                self.y_target.extend(batch.label.cpu().detach().numpy())

            if hasattr(self.model, 'beta_ema') and self.model.beta_ema > 0:
                # Temporal averaging: restore the un-averaged parameters
                self.model.load_params(old_params)

        predicted_labels = np.around(np.array(self.y_pred))
        target_labels = np.array(self.y_target)
        accuracy = metrics.accuracy_score(target_labels, predicted_labels)
        average_precision = metrics.average_precision_score(target_labels, predicted_labels, average=None)
        f1 = metrics.f1_score(target_labels, predicted_labels, average='macro')
        avg_loss = total_loss / len(predicted_labels)

        try:
            precision = metrics.precision_score(target_labels, predicted_labels, average=None)[1]
        except IndexError:
            # Handle cases without positive labels
            precision = 0

        return [accuracy, precision, average_precision, f1, avg_loss], \
               ['accuracy', 'precision', 'average_precision', 'f1', 'cross_entropy_loss']
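
Note that get_scores returns two parallel lists, the metric values and their names, in a fixed order. A minimal sketch of how a caller might pair them up (the `evaluator` object is the hypothetical instance from the sketch above):

    # Pair each metric value with its name; the order is fixed by get_scores
    scores, score_names = evaluator.get_scores(silent=True)
    for name, value in zip(score_names, scores):
        print(name, value)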
