Add hierarchical BERT for relevance transfer (#27)

* Remove tensorboardX logging

* Add hierarchical BERT for relevance transfer
achyudh committed Jul 7, 2019
1 parent 7eb7c89 commit f2feae865590b196baf0becb2107baa31093bb45
@@ -9,6 +9,7 @@

from common.evaluators.evaluator import Evaluator
from datasets.bert_processors.robust45_processor import convert_examples_to_features
from utils.preprocessing import pad_input_matrix
from utils.tokenization import BertTokenizer

# Suppress warnings from sklearn.metrics
@@ -21,7 +22,7 @@ def __init__(self, model, config, **kwargs):
super().__init__(kwargs['dataset'], model, kwargs['embedding'], kwargs['data_loader'],
batch_size=config['batch_size'], device=config['device'])

if config['model'] in {'BERT-Base', 'BERT-Large'}:
if config['model'] in {'BERT-Base', 'BERT-Large', 'HBERT-Base', 'HBERT-Large'}:
variant = 'bert-large-uncased' if config['model'] == 'BERT-Large' else 'bert-base-uncased'
self.tokenizer = BertTokenizer.from_pretrained(variant, is_lowercase=config['is_lowercase'])
self.processor = kwargs['processor']
@@ -43,16 +44,30 @@ def get_scores(self, silent=False):
self.docid = list()
total_loss = 0

if self.config['model'] in {'BERT-Base', 'BERT-Large'}:
eval_features = convert_examples_to_features(self.eval_examples, self.config['max_seq_length'], self.tokenizer)

all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
all_document_ids = torch.tensor([f.guid for f in eval_features], dtype=torch.long)

eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_document_ids)
if self.config['model'] in {'BERT-Base', 'BERT-Large', 'HBERT-Base', 'HBERT-Large'}:
eval_features = convert_examples_to_features(
self.eval_examples,
self.config['max_seq_length'],
self.tokenizer,
self.config['is_hierarchical']
)

unpadded_input_ids = [f.input_ids for f in eval_features]
unpadded_input_mask = [f.input_mask for f in eval_features]
unpadded_segment_ids = [f.segment_ids for f in eval_features]

if self.config['is_hierarchical']:
pad_input_matrix(unpadded_input_ids, self.config['max_doc_length'])
pad_input_matrix(unpadded_input_mask, self.config['max_doc_length'])
pad_input_matrix(unpadded_segment_ids, self.config['max_doc_length'])

padded_input_ids = torch.tensor(unpadded_input_ids, dtype=torch.long)
padded_input_mask = torch.tensor(unpadded_input_mask, dtype=torch.long)
padded_segment_ids = torch.tensor(unpadded_segment_ids, dtype=torch.long)
label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
document_ids = torch.tensor([f.guid for f in eval_features], dtype=torch.long)

eval_data = TensorDataset(padded_input_ids, padded_input_mask, padded_segment_ids, label_ids, document_ids)
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=self.config['batch_size'])

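The new evaluator path depends on pad_input_matrix from utils.preprocessing, whose body is not part of this commit. A minimal sketch of what such a helper presumably does, assuming each document is a list of equal-width rows of token ids (one row per sentence) that must be truncated or zero-padded in place to exactly max_doc_length rows:

def pad_input_matrix(unpadded_matrix, max_doc_length):
    """Truncate or zero-pad each document to max_doc_length sentence rows, in place.

    Assumed behaviour only -- the real utils.preprocessing.pad_input_matrix
    is not shown in this diff.
    """
    for doc_index, doc in enumerate(unpadded_matrix):
        # Drop sentences beyond the per-document limit
        doc = doc[:max_doc_length]
        # Pad short documents with all-zero sentence rows of the same width
        row_width = len(doc[0]) if doc else 0
        doc = doc + [[0] * row_width for _ in range(max_doc_length - len(doc))]
        unpadded_matrix[doc_index] = doc

After this step every document has shape (max_doc_length, max_seq_length), which is what lets the torch.tensor(...) calls above produce a single 3-dimensional batch tensor.
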
@@ -5,7 +5,6 @@
import numpy as np
import torch
import torch.nn.functional as F
from tensorboardX import SummaryWriter

from common.trainers.trainer import Trainer

@@ -26,7 +25,6 @@ def __init__(self, model, embedding, train_loader, trainer_config, train_evaluat
'{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.4f},{:>8.4f},{:8.4f},{:12.4f},{:12.4f}'.split(','))

timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
self.writer = SummaryWriter(log_dir="tensorboard_logs/" + timestamp)
self.snapshot_path = os.path.join(self.model_outfile, self.train_loader.dataset.NAME, '%s.pt' % timestamp)

def train_epoch(self, epoch):
@@ -75,8 +73,6 @@ def train_epoch(self, epoch):

if self.iterations % self.log_interval == 1:
niter = epoch * len(self.train_loader) + batch_idx
self.writer.add_scalar('Train/Loss', loss.data.item(), niter)
self.writer.add_scalar('Train/Accuracy', train_acc, niter)
print(self.log_template.format(time.time() - self.start, epoch, self.iterations, 1 + batch_idx,
len(self.train_loader), 100.0 * (1 + batch_idx) / len(self.train_loader),
loss.item(), train_acc))
@@ -94,11 +90,6 @@ def train(self, epochs):

# Evaluate performance on validation set
dev_acc, dev_precision, dev_recall, dev_f1, dev_loss = self.dev_evaluator.get_scores()[0]
self.writer.add_scalar('Dev/Loss', dev_loss, epoch)
self.writer.add_scalar('Dev/Accuracy', dev_acc, epoch)
self.writer.add_scalar('Dev/Precision', dev_precision, epoch)
self.writer.add_scalar('Dev/Recall', dev_recall, epoch)
self.writer.add_scalar('Dev/F-measure', dev_f1, epoch)

# Print validation results
print('\n' + dev_header)
@@ -3,21 +3,21 @@

import torch
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from torch.utils.data import TensorDataset, RandomSampler, DataLoader
from tqdm import trange, tqdm

from common.trainers.trainer import Trainer
from datasets.bert_processors.robust45_processor import convert_examples_to_features
from tasks.relevance_transfer.resample import ImbalancedDatasetSampler
from utils.preprocessing import pad_input_matrix
from utils.tokenization import BertTokenizer


class RelevanceTransferTrainer(Trainer):
def __init__(self, model, config, **kwargs):
super().__init__(model, kwargs['embedding'], kwargs['train_loader'], config, None, kwargs['test_evaluator'], kwargs['dev_evaluator'])

if config['model'] in {'BERT-Base', 'BERT-Large'}:
if config['model'] in {'BERT-Base', 'BERT-Large', 'HBERT-Base', 'HBERT-Large'}:
variant = 'bert-large-uncased' if config['model'] == 'BERT-Large' else 'bert-base-uncased'
self.tokenizer = BertTokenizer.from_pretrained(variant, is_lowercase=config['is_lowercase'])
self.processor = kwargs['processor']
@@ -37,14 +37,13 @@ def __init__(self, model, config, **kwargs):
self.log_template = ' '.join('{:>5.0f},{:>9.0f},{:>6.0f}/{:<5.0f} {:>6.4f},{:>8.4f},{:8.4f},{:8.4f},{:10.4f}'.split(','))

timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
self.writer = SummaryWriter(log_dir="tensorboard_logs/" + timestamp)
self.snapshot_path = os.path.join(self.model_outfile, config['dataset'].NAME, '%s.pt' % timestamp)

def train_epoch(self):
for step, batch in enumerate(tqdm(self.train_loader, desc="Training")):
self.model.train()

if self.config['model'] in {'BERT-Base', 'BERT-Large'}:
if self.config['model'] in {'BERT-Base', 'BERT-Large', 'HBERT-Base', 'HBERT-Large'}:
batch = tuple(t.to(self.config['device']) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
logits = torch.sigmoid(self.model(input_ids, segment_ids, input_mask)).squeeze(dim=1)
@@ -61,6 +60,7 @@ def train_epoch(self):
self.optimizer.step()
self.optimizer.zero_grad()
self.iterations += 1

else:
# Clip gradients to address exploding gradients in LSTM
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 25.0)
@@ -114,15 +114,29 @@ def train(self, epochs):
os.makedirs(self.model_outfile, exist_ok=True)
os.makedirs(os.path.join(self.model_outfile, self.config['dataset'].NAME), exist_ok=True)

if self.config['model'] in {'BERT-Base', 'BERT-Large'}:
if self.config['model'] in {'BERT-Base', 'BERT-Large', 'HBERT-Base', 'HBERT-Large'}:
train_features = convert_examples_to_features(
self.train_examples, self.config['max_seq_length'], self.tokenizer)

all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
self.train_examples,
self.config['max_seq_length'],
self.tokenizer,
self.config['is_hierarchical']
)

unpadded_input_ids = [f.input_ids for f in train_features]
unpadded_input_mask = [f.input_mask for f in train_features]
unpadded_segment_ids = [f.segment_ids for f in train_features]

if self.config['is_hierarchical']:
pad_input_matrix(unpadded_input_ids, self.config['max_doc_length'])
pad_input_matrix(unpadded_input_mask, self.config['max_doc_length'])
pad_input_matrix(unpadded_segment_ids, self.config['max_doc_length'])

padded_input_ids = torch.tensor(unpadded_input_ids, dtype=torch.long)
padded_input_mask = torch.tensor(unpadded_input_mask, dtype=torch.long)
padded_segment_ids = torch.tensor(unpadded_segment_ids, dtype=torch.long)
label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)

train_data = TensorDataset(padded_input_ids, padded_input_mask, padded_segment_ids, label_ids)
train_sampler = RandomSampler(train_data)
self.train_loader = DataLoader(train_data, sampler=train_sampler, batch_size=self.config['batch_size'])

@@ -132,10 +146,6 @@ def train(self, epochs):

# Evaluate performance on validation set
dev_acc, dev_precision, dev_ap, dev_f1, dev_loss = self.dev_evaluator.get_scores()[0]
self.writer.add_scalar('Dev/Loss', dev_loss, epoch)
self.writer.add_scalar('Dev/Accuracy', dev_acc, epoch)
self.writer.add_scalar('Dev/Precision', dev_precision, epoch)
self.writer.add_scalar('Dev/AP', dev_ap, epoch)
tqdm.write(self.log_header)
tqdm.write(self.log_template.format(epoch, self.iterations, epoch, epochs,
dev_acc, dev_precision, dev_ap, dev_f1, dev_loss))
@@ -14,12 +14,8 @@ def __init__(self, model, embedding, train_loader, trainer_config, train_evaluat
self.model_outfile = trainer_config.get('model_outfile')
self.lr_reduce_factor = trainer_config.get('lr_reduce_factor')
self.patience = trainer_config.get('patience')
self.use_tensorboard = trainer_config.get('tensorboard')
self.clip_norm = trainer_config.get('clip_norm')

if self.use_tensorboard:
from tensorboardX import SummaryWriter
self.writer = SummaryWriter(log_dir=None, comment='' if trainer_config['run_label'] is None else trainer_config['run_label'])
self.logger = trainer_config.get('logger')

self.train_evaluator = train_evaluator
@@ -1,5 +1,7 @@
import os

from nltk import sent_tokenize

from datasets.bert_processors.abstract_processor import BertProcessor, InputExample, InputFeatures


@@ -18,7 +20,6 @@ class Robust45Processor(BertProcessor):
'362', '363', '367', '372', '375', '378', '379', '389', '393', '394', '397', '399', '400', '404', '408',
'414', '416', '419', '422', '423', '426', '427', '433', '435', '436', '439', '442', '443', '445', '614',
'620', '626', '646', '677', '690']
TOPICS = ['307', '310', '321', '325', '330']

def get_train_examples(self, data_dir, **kwargs):
return self._create_examples(
@@ -47,9 +48,10 @@ def _create_examples(lines, split):
return examples


def convert_examples_to_features(examples, max_seq_length, tokenizer):
def convert_examples_to_features(examples, max_seq_length, tokenizer, is_hierarchical=False):
"""
Loads a data file into a list of InputBatch objects
:param is_hierarchical:
:param examples:
:param max_seq_length:
:param tokenizer:
@@ -58,30 +60,51 @@ def convert_examples_to_features(examples, max_seq_length, tokenizer):

features = []
for (ex_index, example) in enumerate(examples):
tokens_a = tokenizer.tokenize(example.text_a)
if is_hierarchical:
tokens_a = [tokenizer.tokenize(line) for line in sent_tokenize(example.text_a)]

# Account for [CLS] and [SEP]
for i0 in range(len(tokens_a)):
if len(tokens_a[i0]) > max_seq_length - 2:
tokens_a[i0] = tokens_a[i0][:(max_seq_length - 2)]

tokens = [["[CLS]"] + line + ["[SEP]"] for line in tokens_a]
segment_ids = [[0] * len(line) for line in tokens]

input_ids = list()
for line in tokens:
input_ids.append(tokenizer.convert_tokens_to_ids(line))

# Input mask has 1 for real tokens and 0 for padding tokens
input_mask = [[1] * len(line_ids) for line_ids in input_ids]

# Zero-pad up to the sequence length.
padding = [[0] * (max_seq_length - len(line_ids)) for line_ids in input_ids]
for i0 in range(len(input_ids)):
input_ids[i0] += padding[i0]
input_mask[i0] += padding[i0]
segment_ids[i0] += padding[i0]

tokens_b = None
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[:(max_seq_length - 2)]
else:
tokens_a = tokenizer.tokenize(example.text_a)

tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
segment_ids = [0] * len(tokens)
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[:(max_seq_length - 2)]

if tokens_b:
tokens += tokens_b + ["[SEP]"]
segment_ids += [1] * (len(tokens_b) + 1)
tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
segment_ids = [0] * len(tokens)

input_ids = tokenizer.convert_tokens_to_ids(tokens)
input_ids = tokenizer.convert_tokens_to_ids(tokens)

# The mask has 1 for real tokens and 0 for padding tokens
input_mask = [1] * len(input_ids)
# The mask has 1 for real tokens and 0 for padding tokens
input_mask = [1] * len(input_ids)

# Zero-pad up to the sequence length
padding = [0] * (max_seq_length - len(input_ids))
input_ids += padding
input_mask += padding
segment_ids += padding
# Zero-pad up to the sequence length
padding = [0] * (max_seq_length - len(input_ids))
input_ids += padding
input_mask += padding
segment_ids += padding

try:
docid = int(example.guid)
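
In hierarchical mode each feature now carries one row of token ids per sentence instead of a single flat sequence. A hedged usage sketch of the updated signature (the InputExample construction, label value, and the nltk punkt download are illustrative assumptions, not taken from this commit):

import nltk

from datasets.bert_processors.abstract_processor import InputExample
from datasets.bert_processors.robust45_processor import convert_examples_to_features
from utils.tokenization import BertTokenizer

nltk.download('punkt')  # sent_tokenize needs the punkt sentence splitter

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', is_lowercase=True)
examples = [InputExample(guid='42', text_a='First sentence. Second sentence.', label='1')]

# Hierarchical: a list of per-sentence rows, each padded to max_seq_length.
# Document-level padding to max_doc_length happens later, in pad_input_matrix.
features = convert_examples_to_features(examples, 128, tokenizer, is_hierarchical=True)
print(len(features[0].input_ids))     # number of sentences in the document
print(len(features[0].input_ids[0]))  # 128, i.e. max_seq_length

# Flat mode keeps the previous behaviour: one padded sequence per example.
flat_features = convert_examples_to_features(examples, 128, tokenizer)
print(len(flat_features[0].input_ids))  # 128
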
@@ -39,7 +39,6 @@ def process_labels(string):
:param string:
:return:
"""
# return [float(x) for x in string]
return 0 if string == '01' else 1


@@ -67,7 +66,6 @@ class Robust45(TabularDataset):
'362', '363', '367', '372', '375', '378', '379', '389', '393', '394', '397', '399', '400', '404', '408',
'414', '416', '419', '422', '423', '426', '427', '433', '435', '436', '439', '442', '443', '445', '614',
'620', '626', '646', '677', '690']
TOPICS = ['307', '310', '321', '325', '330']

@staticmethod
def sort_key(ex):
@@ -23,12 +23,12 @@ def get_args():
parser.add_argument('--batchnorm', action='store_true')
parser.add_argument('--dynamic-pool', action='store_true')
parser.add_argument('--dynamic-pool-length', type=int, default=8)
parser.add_argument('--conv-output-channels', type=int, default=100)
parser.add_argument('--output-channel', type=int, default=100)

parser.add_argument('--max-seq-length', default=128, type=int,
help='maximum total input sequence length after tokenization')

parser.add_argument('--max-doc-length', default=128, type=int,
parser.add_argument('--max-doc-length', default=16, type=int,
help='maximum number of lines processed in one document')

parser.add_argument('--warmup-proportion', default=0.1, type=float,
@@ -7,32 +7,32 @@

class HierarchicalBert(nn.Module):

def __init__(self, args, cache_dir):
def __init__(self, args, cache_dir, **kwargs):
super().__init__()
self.args = args
ks = 3
input_channels = 1
ks = 3

self.sentence_encoder = BertSentenceEncoder.from_pretrained(
args.model,
kwargs['variant'] if 'variant' in kwargs else args.model,
cache_dir=cache_dir,
num_labels=args.num_labels)

self.conv1 = nn.Conv2d(input_channels,
args.conv_output_channels,
args.output_channel,
(3, self.sentence_encoder.config.hidden_size),
padding=(2, 0))
self.conv2 = nn.Conv2d(input_channels,
args.conv_output_channels,
args.output_channel,
(4, self.sentence_encoder.config.hidden_size),
padding=(3, 0))
self.conv3 = nn.Conv2d(input_channels,
args.conv_output_channels,
args.output_channel,
(5, self.sentence_encoder.config.hidden_size),
padding=(4, 0))

self.dropout = nn.Dropout(args.dropout)
self.fc1 = nn.Linear(ks * args.conv_output_channels, args.num_labels)
self.fc1 = nn.Linear(ks * args.output_channel, args.num_labels)

def forward(self, input_ids, segment_ids=None, input_mask=None):
"""
@@ -62,14 +62,14 @@ def forward(self, input_ids, segment_ids=None, input_mask=None):
F.relu(self.conv3(x)).squeeze(3)]

if self.args.dynamic_pool:
x = [self.dynamic_pool(i).squeeze(2) for i in x] # (batch, output_channels) * ks
x = torch.cat(x, 1) # (batch, output_channels * ks)
x = [self.dynamic_pool(i).squeeze(2) for i in x] # (batch_size, output_channels) * ks
x = torch.cat(x, 1) # (batch_size, output_channels * ks)
x = x.view(-1, self.filter_widths * self.output_channel * self.dynamic_pool_length)
else:
x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] # (batch, output_channels, num_sentences) * ks
x = torch.cat(x, 1) # (batch, channel_output * ks)
x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] # (batch_size, output_channels, num_sentences) * ks
x = torch.cat(x, 1) # (batch_size, channel_output * ks)

x = self.dropout(x)
logit = self.fc1(x) # (batch, num_labels)
logits = self.fc1(x) # (batch_size, num_labels)

return logit
return logits

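The convolutional head of HierarchicalBert follows the familiar Kim-style CNN, only over per-sentence BERT vectors rather than word embeddings. A minimal shape walk-through with the sentence encoder replaced by random vectors; hidden_size=768, output_channel=100 and num_labels=1 are illustrative values, not read out of this diff:

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, max_doc_length, hidden_size = 2, 16, 768
output_channel, num_labels, ks = 100, 1, 3

# Stand-in for the BertSentenceEncoder output: one vector per sentence,
# with a singleton channel dimension for Conv2d.
x = torch.randn(batch_size, 1, max_doc_length, hidden_size)

convs = nn.ModuleList([
    nn.Conv2d(1, output_channel, (width, hidden_size), padding=(width - 1, 0))
    for width in (3, 4, 5)
])

feature_maps = [F.relu(conv(x)).squeeze(3) for conv in convs]           # (batch_size, output_channel, n_windows) * ks
pooled = [F.max_pool1d(f, f.size(2)).squeeze(2) for f in feature_maps]  # (batch_size, output_channel) * ks
combined = torch.cat(pooled, 1)                                         # (batch_size, output_channel * ks)

logits = nn.Linear(ks * output_channel, num_labels)(combined)           # (batch_size, num_labels)
print(logits.shape)  # torch.Size([2, 1])

This mirrors the non-dynamic-pool branch of forward(); the dynamic_pool branch instead pools each feature map to a fixed dynamic-pool-length before the same concatenation and linear layer.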