Integrate BERT into Hedwig (#29) (#11)

* Fix package imports

* Update README.md

* Fix bug due to TAR/AR attribute check

* Add BERT models

* Add BERT tokenizer

* Return logits from the model.py

* Remove unused classes in models/bert

* Return logits from the model.py (#12)

* Remove unused classes in models/bert (#13)

* Add initial main file

* Add args for BERT

* Add partial support for BERT

* Initialize training and optimization

* Draft the structure of Trainers for BERT

* Remove duplicate tokenizer

* Add utils

* Move optimization to utils

* Add more structure for trainer

* Refactor the trainer (#15)

* Refactor the trainer

* Add more edits

* Add support for our datasets

* Add evaluator

* Split data4bert module into multiple processors

* Refactor BERT tokenizer

* Integrate BERT into Castor framework (#17)

* Remove unused classes in models/bert

* Split data4bert module into multiple processors

* Refactor BERT tokenizer

* Add multilabel support in BertTrainer

* Add multilabel support in BertEvaluator

* Add get_test_samples method in dataset processors

* Fix args.py for BERT

* Add support for Reuters, IMDB datasets for BERT

* Revert "Integrate BERT into Castor framework (#17)"

This reverts commit e4244ec.

* Fix paths to datasets in dataset classes and args

* Add SST dataset

* Add hedwig-data instructions to README.md

* Fix KimCNN README

* Fix RegLSTM README

* Fix typos in README

* Remove trec_eval from README

* Add tensorboardX to requirements.txt

* Rename processors module to bert_processors

* Add method to print metrics after training

* Add model check-pointing and early stopping for BERT

* Add logos

* Update README.md

* Fix code comments in classification trainer

* Add support for AAPD, Sogou, AGNews and Yelp2014

* Fix bug that deleted saved models

* Update README for HAN

* Update README for XML-CNN

* Remove redundant TODOs from the READMEs

* Fix logo in README.md

* Update README for Char-CNN

* Fix all the READMEs

* Resolve conflict

* Fix Typos

* Re-Add SST2 Processor

* Add support for evaluating trained model

* Update args.py

* Resolve issues due to DataParallel wrapper on saved model

* Remove redundant Yelp processor

* Fix bug for safely creating the saving directory

* Change checkpoint paths to timestamps

* Remove unwanted string.strip() from tokenizer

* Create save path if it doesn't exist

* Decouple model checkpoints from code

* Remove model choice restrictions for BERT

* Remove model/distill driver

* Simplify checkpoint directory creation
achyudh authored and Ashutosh-Adhikari committed Apr 14, 2019
1 parent c0a4bcb commit 7d24958e94e54abc02a4c17bbf5d0587d84c4b14
Showing with 2,799 additions and 284 deletions.
  1. +19 −30 README.md
  2. +80 −0 common/evaluators/bert_evaluator.py
  3. +5 −4 common/evaluators/classification_evaluator.py
  4. +118 −0 common/trainers/bert_trainer.py
  5. +8 −4 common/trainers/classification_trainer.py
  6. +5 −5 datasets/aapd.py
  7. 0 datasets/bert_processors/__init__.py
  8. +33 −0 datasets/bert_processors/aapd_processor.py
  9. +193 −0 datasets/bert_processors/abstract_processor.py
  10. +34 −0 datasets/bert_processors/agnews_processor.py
  11. +34 −0 datasets/bert_processors/imdb_processor.py
  12. +33 −0 datasets/bert_processors/reuters_processor.py
  13. +34 −0 datasets/bert_processors/sogou_processor.py
  14. +39 −0 datasets/bert_processors/sst_processor.py
  15. +34 −0 datasets/bert_processors/yelp2014_processor.py
  16. +3 −3 datasets/imdb.py
  17. +3 −14 datasets/reuters.py
  18. +92 −0 datasets/sst.py
  19. +4 −4 datasets/yelp2014.py
  20. BIN docs/hedwig.png
  21. +1 −1 models/args.py
  22. 0 models/bert/__init__.py
  23. +169 −0 models/bert/__main__.py
  24. +43 −0 models/bert/args.py
  25. +851 −0 models/bert/model.py
  26. +12 −32 models/char_cnn/README.md
  27. +5 −0 models/char_cnn/__main__.py
  28. +2 −2 models/char_cnn/args.py
  29. +16 −30 models/han/README.md
  30. +6 −1 models/han/__main__.py
  31. +2 −2 models/han/args.py
  32. +24 −106 models/kim_cnn/README.md
  33. +5 −0 models/kim_cnn/__main__.py
  34. +2 −2 models/kim_cnn/args.py
  35. +33 −21 models/reg_lstm/README.md
  36. +5 −0 models/reg_lstm/__main__.py
  37. +2 −2 models/reg_lstm/args.py
  38. +16 −16 models/xml_cnn/README.md
  39. +8 −3 models/xml_cnn/__main__.py
  40. +2 −2 models/xml_cnn/args.py
  41. +1 −0 requirements.txt
  42. +257 −0 utils/io.py
  43. +179 −0 utils/optimization.py
  44. +387 −0 utils/tokenization.py
@@ -1,4 +1,6 @@
# Hedwig
+<p align="center">
+<img src="https://github.com/karkaroff/hedwig/blob/bellatrix/docs/hedwig.png" width="360">
+</p>

This repo contains PyTorch deep learning models for document classification, implemented by the Data Systems Group at the University of Waterloo.

@@ -14,8 +16,6 @@ Each model directory has a `README.md` with further details.

## Setting up PyTorch

-**If you are an internal Hedwig contributor using GPU machines in the lab, follow the instructions [here](docs/internal-instructions.md).**
-
Hedwig is designed for Python 3.6 and [PyTorch](https://pytorch.org/) 0.4.
PyTorch recommends [Anaconda](https://www.anaconda.com/distribution/) for managing your environment.
We'd recommend creating a custom environment as follows:
@@ -25,10 +25,10 @@ $ conda create --name castor python=3.6
$ source activate castor
```

-And installing the packages as follows:
+And installing PyTorch as follows:

```
-$ conda install pytorch torchvision -c pytorch
+$ conda install pytorch=0.4.1 cuda92 -c pytorch
```

Other Python packages we use can be installed via pip:
@@ -37,49 +37,38 @@ Other Python packages we use can be installed via pip:
$ pip install -r requirements.txt
```

-Code depends on data from NLTK (e.g., stopwords) so you'll have to download them. Run the Python interpreter and type the commands:
+Code depends on data from NLTK (e.g., stopwords) so you'll have to download them.
+Run the Python interpreter and type the commands:

```python
>>> import nltk
>>> nltk.download()
```

-Finally, run the following inside the `utils` directory to build the `trec_eval` tool for evaluating certain datasets.
+## Datasets

+Download the Reuters, AAPD and IMDB datasets, along with word2vec embeddings from
+[`hedwig-data`](https://git.uwaterloo.ca/jimmylin/hedwig-data).

```bash
-$ ./get_trec_eval.sh
+$ git clone https://github.com/castorini/hedwig.git
+$ git clone https://git.uwaterloo.ca/jimmylin/hedwig-data.git
```

-## Data and Pre-Trained Models
-
-**If you are an internal Hedwig contributor using GPU machines in the lab, follow the instructions [here](docs/internal-instructions.md).**
-
-To fully take advantage of code here, clone these other two repos:
-
-+ [`Castor-data`](https://git.uwaterloo.ca/jimmylin/Castor-data): embeddings, datasets, etc.
-+ [`Castor-models`](https://git.uwaterloo.ca/jimmylin/Castor-models): pre-trained models

Organize your directory structure as follows:

```
.
├── hedwig
-├── Castor-data
-└── Castor-models
+└── hedwig-data
```

-For example (using HTTPS):
+After cloning the hedwig-data repo, you need to unzip the embeddings and run the preprocessing script:

```bash
-$ git clone https://github.com/castorini/hedwig.git
-$ git clone https://git.uwaterloo.ca/jimmylin/Castor-data.git
-$ git clone https://git.uwaterloo.ca/jimmylin/Castor-models.git
+cd hedwig-data/embeddings/word2vec
+gzip -d GoogleNews-vectors-negative300.bin.gz
+python bin2txt.py GoogleNews-vectors-negative300.bin GoogleNews-vectors-negative300.txt
```

-After cloning the Hedwig-data repo, you need to unzip embeddings and run data pre-processing scripts. You can choose
-to follow instructions under each dataset and embedding directory separately, or just run the following script in
-Hedwig-data to do all of the steps for you:
-
-```bash
-$ ./setup.sh
-```
+**If you are an internal Hedwig contributor using the machines in the lab, follow the instructions [here](docs/internal-instructions.md).**
@@ -0,0 +1,80 @@
import warnings

import numpy as np
import torch
import torch.nn.functional as F
from sklearn import metrics
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from tqdm import tqdm

from datasets.bert_processors.abstract_processor import convert_examples_to_features
from utils.tokenization import BertTokenizer

# Suppress warnings from sklearn.metrics
warnings.filterwarnings('ignore')


class BertEvaluator(object):
    def __init__(self, model, processor, args, split='dev'):
        self.args = args
        self.model = model
        self.processor = processor
        self.tokenizer = BertTokenizer.from_pretrained(args.model, is_lowercase=args.is_lowercase)
        if split == 'test':
            self.eval_examples = self.processor.get_test_examples(args.data_dir)
        else:
            self.eval_examples = self.processor.get_dev_examples(args.data_dir)

    def get_scores(self, silent=False):
        eval_features = convert_examples_to_features(self.eval_examples, self.args.max_seq_length, self.tokenizer)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)

        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=self.args.batch_size)

        self.model.eval()

        total_loss = 0
        nb_eval_steps, nb_eval_examples = 0, 0
        predicted_labels, target_labels = list(), list()

        for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating", disable=silent):
            input_ids = input_ids.to(self.args.device)
            input_mask = input_mask.to(self.args.device)
            segment_ids = segment_ids.to(self.args.device)
            label_ids = label_ids.to(self.args.device)

            with torch.no_grad():
                logits = self.model(input_ids, segment_ids, input_mask)

            if self.args.is_multilabel:
                predicted_labels.extend(F.sigmoid(logits).round().long().cpu().detach().numpy())
                target_labels.extend(label_ids.cpu().detach().numpy())
                loss = F.binary_cross_entropy_with_logits(logits, label_ids.float(), size_average=False)
            else:
                predicted_labels.extend(torch.argmax(logits, dim=1).cpu().detach().numpy())
                target_labels.extend(torch.argmax(label_ids, dim=1).cpu().detach().numpy())
                loss = F.cross_entropy(logits, torch.argmax(label_ids, dim=1))

            if self.args.n_gpu > 1:
                loss = loss.mean()
            if self.args.gradient_accumulation_steps > 1:
                loss = loss / self.args.gradient_accumulation_steps
            total_loss += loss.item()

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        predicted_labels, target_labels = np.array(predicted_labels), np.array(target_labels)
        accuracy = metrics.accuracy_score(target_labels, predicted_labels)
        precision = metrics.precision_score(target_labels, predicted_labels, average='micro')
        recall = metrics.recall_score(target_labels, predicted_labels, average='micro')
        f1 = metrics.f1_score(target_labels, predicted_labels, average='micro')
        avg_loss = total_loss / nb_eval_steps

        return [accuracy, precision, recall, f1, avg_loss], ['accuracy', 'precision', 'recall', 'f1', 'avg_loss']
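
A note on the two prediction modes above: the multi-label branch decides each class independently by thresholding its sigmoid output at 0.5, while the single-label branch takes an argmax over classes. A tiny self-contained illustration in plain PyTorch (toy tensors, not Hedwig code):

```python
# Illustration of the two prediction modes used in get_scores (toy data).
import torch

logits = torch.tensor([[1.2, -0.7, 0.3],
                       [-2.0, 0.9, -0.1]])

# Multi-label: each class is an independent yes/no decision at probability 0.5
print(torch.sigmoid(logits).round().long())  # tensor([[1, 0, 1], [0, 1, 0]])

# Single-label: exactly one class per example
print(torch.argmax(logits, dim=1))           # tensor([0, 1])
```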
@@ -18,14 +18,14 @@ def get_scores(self):
         self.data_loader.init_epoch()
         total_loss = 0

-        # Temp Ave
         if hasattr(self.model, 'beta_ema') and self.model.beta_ema > 0:
+            # Temporal averaging
             old_params = self.model.get_params()
             self.model.load_ema_params()

         predicted_labels, target_labels = list(), list()
         for batch_idx, batch in enumerate(self.data_loader):
-            if hasattr(self.model, 'tar') and self.model.tar:  # TAR condition
+            if hasattr(self.model, 'tar') and self.model.tar:
                 if self.ignore_lengths:
                     scores, rnn_outs = self.model(batch.text)
                 else:
@@ -46,7 +46,8 @@ def get_scores(self):
             target_labels.extend(torch.argmax(batch.label, dim=1).cpu().detach().numpy())
             total_loss += F.cross_entropy(scores, torch.argmax(batch.label, dim=1), size_average=False).item()

-            if hasattr(self.model, 'tar') and self.model.tar:  # TAR condition
+            if hasattr(self.model, 'tar') and self.model.tar:
+                # Temporal activation regularization
                 total_loss += (rnn_outs[1:] - rnn_outs[:-1]).pow(2).mean()

         predicted_labels = np.array(predicted_labels)
@@ -57,8 +58,8 @@ def get_scores(self):
         f1 = metrics.f1_score(target_labels, predicted_labels, average='micro')
         avg_loss = total_loss / len(self.data_loader.dataset.examples)

-        # Temp Ave
         if hasattr(self.model, 'beta_ema') and self.model.beta_ema > 0:
+            # Temporal averaging
             self.model.load_params(old_params)

         return [accuracy, precision, recall, f1, avg_loss], ['accuracy', 'precision', 'recall', 'f1', 'cross_entropy_loss']
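
The temporal averaging these comments refer to keeps an exponential moving average (EMA) of the model weights during training and swaps it in at evaluation time, restoring the raw weights afterwards. A minimal sketch of the idea, written as a mixin over `nn.Module`; the method names mirror `update_ema`/`get_params`/`load_ema_params`/`load_params` above, but the bodies are illustrative assumptions, not Hedwig's actual implementation (which lives in the model classes and is untouched by this commit):

```python
import torch.nn as nn

class TemporalAveragingMixin(nn.Module):
    # Sketch of EMA weight averaging; illustrative only, not Hedwig's code.
    def init_ema(self, beta_ema=0.999):
        self.beta_ema = beta_ema            # decay factor of the moving average
        self.steps_ema = 0
        self.avg_params = [p.detach().clone() for p in self.parameters()]

    def update_ema(self):
        # Called once per optimizer step during training.
        self.steps_ema += 1
        for avg, p in zip(self.avg_params, self.parameters()):
            avg.mul_(self.beta_ema).add_(p.detach() * (1 - self.beta_ema))

    def get_params(self):
        # Snapshot the raw weights so they can be restored after evaluation.
        return [p.detach().clone() for p in self.parameters()]

    def load_ema_params(self):
        # Evaluate with bias-corrected averaged weights.
        correction = 1 - self.beta_ema ** self.steps_ema
        for p, avg in zip(self.parameters(), self.avg_params):
            p.data.copy_(avg / correction)

    def load_params(self, params):
        # Restore the raw weights.
        for p, saved in zip(self.parameters(), params):
            p.data.copy_(saved)
```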
@@ -0,0 +1,118 @@
import datetime
import os

import torch
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from tqdm import trange

from common.evaluators.bert_evaluator import BertEvaluator
from datasets.bert_processors.abstract_processor import convert_examples_to_features
from utils.optimization import warmup_linear
from utils.tokenization import BertTokenizer


class BertTrainer(object):
    def __init__(self, model, optimizer, processor, args):
        self.args = args
        self.model = model
        self.optimizer = optimizer
        self.processor = processor
        self.train_examples = self.processor.get_train_examples(args.data_dir)
        self.tokenizer = BertTokenizer.from_pretrained(args.model, is_lowercase=args.is_lowercase)

        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.writer = SummaryWriter(log_dir="tensorboard_logs/" + timestamp)
        self.snapshot_path = os.path.join(self.args.save_path, self.processor.NAME, '%s.pt' % timestamp)

        self.num_train_optimization_steps = int(
            len(self.train_examples) / args.batch_size / args.gradient_accumulation_steps) * args.epochs
        if args.local_rank != -1:
            self.num_train_optimization_steps = self.num_train_optimization_steps // torch.distributed.get_world_size()

        self.log_header = 'Epoch Iteration Progress Dev/Acc. Dev/Pr. Dev/Re. Dev/F1 Dev/Loss'
        self.log_template = ' '.join('{:>5.0f},{:>9.0f},{:>6.0f}/{:<5.0f} {:>6.4f},{:>8.4f},{:8.4f},{:8.4f},{:10.4f}'.split(','))

        self.iterations, self.nb_tr_steps, self.tr_loss = 0, 0, 0
        self.best_dev_f1, self.unimproved_iters = 0, 0
        self.early_stop = False

    def train_epoch(self, train_dataloader):
        for step, batch in enumerate(tqdm(train_dataloader, desc="Training")):
            batch = tuple(t.to(self.args.device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            logits = self.model(input_ids, segment_ids, input_mask)

            if self.args.is_multilabel:
                loss = F.binary_cross_entropy_with_logits(logits, label_ids.float())
            else:
                loss = F.cross_entropy(logits, torch.argmax(label_ids, dim=1))

            if self.args.n_gpu > 1:
                loss = loss.mean()
            if self.args.gradient_accumulation_steps > 1:
                loss = loss / self.args.gradient_accumulation_steps

            if self.args.fp16:
                self.optimizer.backward(loss)
            else:
                loss.backward()

            self.tr_loss += loss.item()
            self.nb_tr_steps += 1
            if (step + 1) % self.args.gradient_accumulation_steps == 0:
                if self.args.fp16:
                    lr_this_step = self.args.learning_rate * warmup_linear(self.iterations / self.num_train_optimization_steps, self.args.warmup_proportion)
                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                self.optimizer.step()
                self.optimizer.zero_grad()
                self.iterations += 1

    def train(self):
        train_features = convert_examples_to_features(
            self.train_examples, self.args.max_seq_length, self.tokenizer)

        print("Number of examples: ", len(self.train_examples))
        print("Batch size:", self.args.batch_size)
        print("Num of steps:", self.num_train_optimization_steps)

        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        if self.args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)

        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=self.args.batch_size)

        self.model.train()

        for epoch in trange(int(self.args.epochs), desc="Epoch"):
            self.train_epoch(train_dataloader)
            dev_evaluator = BertEvaluator(self.model, self.processor, self.args, split='dev')
            dev_acc, dev_precision, dev_recall, dev_f1, dev_loss = dev_evaluator.get_scores()[0]

            # Print validation results
            tqdm.write(self.log_header)
            tqdm.write(self.log_template.format(epoch + 1, self.iterations, epoch + 1, self.args.epochs,
                                                dev_acc, dev_precision, dev_recall, dev_f1, dev_loss))

            # Update validation results
            if dev_f1 > self.best_dev_f1:
                self.unimproved_iters = 0
                self.best_dev_f1 = dev_f1
                torch.save(self.model, self.snapshot_path)
            else:
                self.unimproved_iters += 1
                if self.unimproved_iters >= self.args.patience:
                    self.early_stop = True
                    tqdm.write("Early Stopping. Epoch: {}, Best Dev F1: {}".format(epoch, self.best_dev_f1))
                    break
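
For orientation, the driver in `models/bert/__main__.py` (added by this commit but not shown on this page) would wire these pieces together roughly as follows. This is a hedged sketch: the `ReutersProcessor` class name and the pre-built `model`, `optimizer`, and `args` objects are assumptions, not verified against the actual main file:

```python
# Hypothetical driver sketch; assumes model, optimizer, and args were
# constructed as in models/bert/__main__.py (not shown on this page).
import torch

from common.trainers.bert_trainer import BertTrainer
from common.evaluators.bert_evaluator import BertEvaluator
from datasets.bert_processors.reuters_processor import ReutersProcessor  # class name assumed

processor = ReutersProcessor()
trainer = BertTrainer(model, optimizer, processor, args)
trainer.train()  # early-stops on dev F1 and checkpoints to trainer.snapshot_path

# Reload the best checkpoint and score the held-out test split
best_model = torch.load(trainer.snapshot_path)
scores, names = BertEvaluator(best_model, processor, args, split='test').get_scores()
print(dict(zip(names, scores)))
```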
@@ -7,7 +7,7 @@
 import torch.nn.functional as F
 from tensorboardX import SummaryWriter

-from .trainer import Trainer
+from common.trainers.trainer import Trainer


 class ClassificationTrainer(Trainer):
@@ -24,8 +24,10 @@ def __init__(self, model, embedding, train_loader, trainer_config, train_evaluat
             '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:12.4f}'.split(','))
         self.dev_log_template = ' '.join(
             '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.4f},{:>8.4f},{:8.4f},{:12.4f},{:12.4f}'.split(','))
-        self.writer = SummaryWriter(log_dir="tensorboard_logs/" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
-        self.snapshot_path = os.path.join(self.model_outfile, self.train_loader.dataset.NAME, 'best_model.pt')
+
+        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+        self.writer = SummaryWriter(log_dir="tensorboard_logs/" + timestamp)
+        self.snapshot_path = os.path.join(self.model_outfile, self.train_loader.dataset.NAME, '%s.pt' % timestamp)

     def train_epoch(self, epoch):
         self.train_loader.init_epoch()
@@ -67,8 +69,8 @@ def train_epoch(self, epoch):
             loss.backward()
             self.optimizer.step()

-            # Temp Ave
             if hasattr(self.model, 'beta_ema') and self.model.beta_ema > 0:
+                # Temporal averaging
                 self.model.update_ema()

             if self.iterations % self.log_interval == 1:
@@ -97,6 +99,8 @@ def train(self, epochs):
             self.writer.add_scalar('Dev/Precision', dev_precision, epoch)
             self.writer.add_scalar('Dev/Recall', dev_recall, epoch)
             self.writer.add_scalar('Dev/F-measure', dev_f1, epoch)
+
+            # Print validation results
             print('\n' + dev_header)
             print(self.dev_log_template.format(time.time() - self.start, epoch, self.iterations, epoch, epochs,
                                                dev_acc, dev_precision, dev_recall, dev_f1, dev_loss))