
Add support for BERT for relevance transfer (#19)
* Add Robust45 preprocessor for BERT
* Add support for BERT for relevance transfer
achyudh committed May 24, 2019
1 parent f975e84 commit 3cd54c2
Showing 5 changed files with 561 additions and 262 deletions.
119 changes: 85 additions & 34 deletions common/evaluators/relevance_transfer_evaluator.py
@@ -4,73 +4,124 @@

```python
import warnings  # imported above this hunk in the full file

import numpy as np  # likewise imported above this hunk
import torch
import torch.nn.functional as F
from sklearn import metrics
from torch.utils.data import TensorDataset, SequentialSampler, DataLoader
from tqdm import tqdm

from common.evaluators.evaluator import Evaluator
from datasets.bert_processors.robust45_processor import convert_examples_to_features
from utils.tokenization import BertTokenizer

# Suppress warnings from sklearn.metrics
warnings.filterwarnings('ignore')


class RelevanceTransferEvaluator(Evaluator):

    def __init__(self, model, config, **kwargs):
        super().__init__(kwargs['dataset'], model, kwargs['embedding'], kwargs['data_loader'],
                         batch_size=config['batch_size'], device=config['device'])

        if config['model'] in {'BERT-Base', 'BERT-Large'}:
            variant = 'bert-large-uncased' if config['model'] == 'BERT-Large' else 'bert-base-uncased'
            self.tokenizer = BertTokenizer.from_pretrained(variant, is_lowercase=config['is_lowercase'])
            self.processor = kwargs['processor']
            if config['split'] == 'test':
                self.eval_examples = self.processor.get_test_examples(config['data_dir'], topic=config['topic'])
            else:
                self.eval_examples = self.processor.get_dev_examples(config['data_dir'], topic=config['topic'])

        self.config = config
        self.ignore_lengths = config['ignore_lengths']
        self.y_target = None
        self.y_pred = None
        self.docid = None

    def get_scores(self, silent=False):
        self.model.eval()
        self.y_target = list()
        self.y_pred = list()
        self.docid = list()
        total_loss = 0

        if self.config['model'] in {'BERT-Base', 'BERT-Large'}:
            eval_features = convert_examples_to_features(
                self.eval_examples, self.config['max_seq_length'], self.tokenizer)

            all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
            all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
            all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
            all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
            all_document_ids = torch.tensor([f.guid for f in eval_features], dtype=torch.long)

            eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                      all_label_ids, all_document_ids)
            eval_sampler = SequentialSampler(eval_data)
            eval_dataloader = DataLoader(eval_data, sampler=eval_sampler,
                                         batch_size=self.config['batch_size'])

            for input_ids, input_mask, segment_ids, label_ids, document_ids in tqdm(
                    eval_dataloader, desc="Evaluating", disable=silent):
                input_ids = input_ids.to(self.config['device'])
                input_mask = input_mask.to(self.config['device'])
                segment_ids = segment_ids.to(self.config['device'])
                label_ids = label_ids.to(self.config['device'])

                with torch.no_grad():
                    logits = torch.sigmoid(self.model(input_ids, segment_ids, input_mask)).squeeze(dim=1)

                # Computing loss and storing predictions
                self.docid.extend(document_ids.cpu().detach().numpy())
                self.y_pred.extend(logits.cpu().detach().numpy())
                self.y_target.extend(label_ids.cpu().detach().numpy())
                loss = F.binary_cross_entropy(logits, label_ids.float())

                if self.config['n_gpu'] > 1:
                    loss = loss.mean()
                if self.config['gradient_accumulation_steps'] > 1:
                    loss = loss / self.config['gradient_accumulation_steps']
                total_loss += loss.item()

        else:
            self.data_loader.init_epoch()

            if hasattr(self.model, 'beta_ema') and self.model.beta_ema > 0:
                # Temporal averaging: evaluate with the EMA copy of the weights
                old_params = self.model.get_params()
                self.model.load_ema_params()

            for batch in tqdm(self.data_loader, desc="Evaluating", disable=silent):
                if hasattr(self.model, 'tar') and self.model.tar:
                    # Models trained with temporal activation regularization
                    # also return their RNN activations
                    if self.ignore_lengths:
                        scores, rnn_outs = self.model(batch.text)
                    else:
                        scores, rnn_outs = self.model(batch.text[0], lengths=batch.text[1])
                    logits = torch.sigmoid(scores).squeeze(dim=1)
                else:
                    if self.ignore_lengths:
                        logits = torch.sigmoid(self.model(batch.text)).squeeze(dim=1)
                    else:
                        logits = torch.sigmoid(self.model(batch.text[0], lengths=batch.text[1])).squeeze(dim=1)

                total_loss += F.binary_cross_entropy(logits, batch.label.float()).item()
                if hasattr(self.model, 'tar') and self.model.tar:
                    # Temporal activation regularization: penalize large changes
                    # between consecutive RNN states
                    total_loss += (rnn_outs[1:] - rnn_outs[:-1]).pow(2).mean().item()

                # Storing predictions for this batch
                self.docid.extend(batch.docid.cpu().detach().numpy())
                self.y_pred.extend(logits.cpu().detach().numpy())
                self.y_target.extend(batch.label.cpu().detach().numpy())

            if hasattr(self.model, 'beta_ema') and self.model.beta_ema > 0:
                # Temporal averaging: restore the raw weights
                self.model.load_params(old_params)

        predicted_labels = np.around(np.array(self.y_pred))
        target_labels = np.array(self.y_target)
        accuracy = metrics.accuracy_score(target_labels, predicted_labels)
        average_precision = metrics.average_precision_score(target_labels, predicted_labels, average=None)
        f1 = metrics.f1_score(target_labels, predicted_labels, average='macro')
        avg_loss = total_loss / len(predicted_labels)

        try:
            precision = metrics.precision_score(target_labels, predicted_labels, average=None)[1]
        except IndexError:
            # Handle cases without positive labels
            precision = 0

        return [accuracy, precision, average_precision, f1, avg_loss], \
               ['accuracy', 'precision', 'average_precision', 'f1', 'cross_entropy_loss']
```
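The new BERT path consumes feature records produced by the Robust45 processor's `convert_examples_to_features`. The field layout below is inferred from the attribute accesses and tensor construction in `get_scores`; the dataclass itself is an illustrative sketch, not code from this commit:

```python
# Assumed shape of one feature record consumed by the BERT branch above.
# Field names come from get_scores; the dataclass is illustrative only.
from dataclasses import dataclass
from typing import List

@dataclass
class InputFeatures:
    guid: int                # numeric document id (packed into a long tensor)
    input_ids: List[int]     # WordPiece token ids, padded to max_seq_length
    input_mask: List[int]    # 1 for real tokens, 0 for padding
    segment_ids: List[int]   # all zeros for single-segment inputs
    label_id: int            # binary relevance label
```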

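A minimal sketch of how the BERT path of this evaluator might be driven. The config keys mirror the ones the evaluator reads; the `Robust45Processor` class name, the `bert_model` object, and every concrete value are assumptions for illustration:

```python
# Hypothetical driver for the BERT path of RelevanceTransferEvaluator.
# Robust45Processor is an assumed class name from
# datasets/bert_processors/robust45_processor.py; bert_model stands in for
# the repo's BERT relevance model, and all config values are illustrative.
from datasets.bert_processors.robust45_processor import Robust45Processor

config = {
    'model': 'BERT-Base',            # selects the BERT branch of get_scores
    'batch_size': 16,
    'device': 'cuda',
    'n_gpu': 1,
    'gradient_accumulation_steps': 1,
    'split': 'test',                 # 'test' -> get_test_examples, else get_dev_examples
    'data_dir': 'data/robust45',
    'topic': '307',
    'is_lowercase': True,
    'ignore_lengths': True,
    'max_seq_length': 256,
}

evaluator = RelevanceTransferEvaluator(
    bert_model, config,
    processor=Robust45Processor(),
    # The base Evaluator still receives these; the BERT branch does not use them.
    dataset=None, embedding=None, data_loader=None,
)
scores, metric_names = evaluator.get_scores(silent=False)
print(dict(zip(metric_names, scores)))
```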
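On the non-BERT path, the evaluator swaps in exponential-moving-average weights before scoring (`load_ema_params`) and restores the originals afterwards (`load_params`). Those hooks live elsewhere in the repo and are not part of this diff; the following is a rough sketch, with all implementation details assumed, of what such temporal-averaging bookkeeping typically looks like:

```python
import torch.nn as nn

class TemporalAveragingMixin(nn.Module):
    """Assumed sketch of the beta_ema bookkeeping implied by the evaluator's
    get_params / load_ema_params / load_params calls; not this repo's code."""

    def init_ema(self, beta_ema):
        self.beta_ema = beta_ema
        self.avg_param = [p.detach().clone() for p in self.parameters()]
        self.steps_ema = 0

    def update_ema(self):
        # Called once per training step
        self.steps_ema += 1
        for avg, p in zip(self.avg_param, self.parameters()):
            avg.mul_(self.beta_ema).add_(p.detach(), alpha=1 - self.beta_ema)

    def get_params(self):
        # Snapshot the raw weights so they can be restored after evaluation
        return [p.detach().clone() for p in self.parameters()]

    def load_ema_params(self):
        # Swap in the bias-corrected average (assumes update_ema ran at least once)
        correction = 1 - self.beta_ema ** self.steps_ema
        for p, avg in zip(self.parameters(), self.avg_param):
            p.data.copy_(avg / correction)

    def load_params(self, params):
        # Restore a previously saved snapshot
        for p, saved in zip(self.parameters(), params):
            p.data.copy_(saved)
```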