
Add learning rate multiplier for H-BERT (#36)

* Integrate BERT into Hedwig (#29)

* Fix package imports

* Update README.md

* Fix bug due to TAR/AR attribute check

* Add BERT models

* Add BERT tokenizer

* Return logits from the model.py

* Remove unused classes in models/bert

* Return logits from the model.py (#12)

* Remove unused classes in models/bert (#13)

* Add initial main file

* Add args for BERT

* Add partial support for BERT

* Initialize training and optimization

* Draft the structure of Trainers for BERT

* Remove duplicate tokenizer

* Add utils

* Move optimization to utils

* Add more structure for trainer

* Refactor the trainer (#15)

* Refactor the trainer

* Add more edits

* Add support for our datasets

* Add evaluator

* Split data4bert module into multiple processors

* Refactor BERT tokenizer

* Integrate BERT into Castor framework (#17)

* Remove unused classes in models/bert

* Split data4bert module into multiple processors

* Refactor BERT tokenizer

* Add multilabel support in BertTrainer

* Add multilabel support in BertEvaluator

* Add get_test_samples method in dataset processors

* Fix args.py for BERT

* Add support for Reuters, IMDB datasets for BERT

* Revert "Integrate BERT into Castor framework (#17)"

This reverts commit e4244ec.

* Fix paths to datasets in dataset classes and args

* Add SST dataset

* Add hedwig-data instructions to README.md

* Fix KimCNN README

* Fix RegLSTM README

* Fix typos in README

* Remove trec_eval from README

* Add tensorboardX to requirements.txt

* Rename processors module to bert_processors

* Add method to print metrics after training

* Add model check-pointing and early stopping for BERT

* Add logos

* Update README.md

* Fix code comments in classification trainer

* Add support for AAPD, Sogou, AGNews and Yelp2014

* Fix bug that deleted saved models

* Update README for HAN

* Update README for XML-CNN

* Remove redundant TODOs from the READMEs

* Fix logo in README.md

* Update README for Char-CNN

* Fix all the READMEs

* Resolve conflict

* Fix Typos

* Re-Add SST2 Processor

* Add support for evaluating trained model

* Update args.py

* Resolve issues due to DataParallel wrapper on saved model

* Remove redundant Yelp processor

* Fix bug for safely creating the saving directory

* Change checkpoint paths to timestamps

* Remove unwanted string.strip() from tokenizer

* Create save path if it doesn't exist

* Decouple model checkpoints from code

* Remove model choice restrictions for BERT

* Remove model/distill driver

* Simplify checkpoint directory creation

* Add TREC relevance datasets

* Add relevance transfer trainer and evaluator

* Add re-ranking module

* Add ImbalancedDatasetSampler

* Add relevance transfer package

* Fix import in classification trainer

* Remove unwanted args from models/bert

* Fix bug where model wasn't in training mode every epoch

* Add Robust45 preprocessor for BERT

* Add support for BERT for relevance transfer

* Add hierarchical BERT model

* Remove tensorboardX logging

* Add hierarchical BERT for relevance transfer

* Add learning rate multiplier

* Add lr multiplier for relevance transfer
achyudh authored and Ashutosh-Adhikari committed Sep 2, 2019
1 parent f2feae8 commit 255624be90001adee179fcee98447349dfe4f596
Showing with 17 additions and 5 deletions.
  1. +7 −2 models/hbert/__main__.py
  2. +1 −1 models/hbert/args.py
  3. +8 −2 tasks/relevance_transfer/__main__.py
  4. +1 −0 tasks/relevance_transfer/args.py
models/hbert/__main__.py
@@ -119,10 +119,15 @@ def evaluate_split(model, processor, args, split='dev'):
 
     # Prepare optimizer
     param_optimizer = list(model.named_parameters())
 
     no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
     optimizer_grouped_parameters = [
-        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
-        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
+        {'params': [p for n, p in param_optimizer if 'sentence_encoder' not in n],
+         'lr': args.lr * args.lr_mult, 'weight_decay': 0.0},
+        {'params': [p for n, p in param_optimizer if 'sentence_encoder' in n and not any(nd in n for nd in no_decay)],
+         'weight_decay': 0.01},
+        {'params': [p for n, p in param_optimizer if 'sentence_encoder' in n and any(nd in n for nd in no_decay)],
+         'weight_decay': 0.0}]
 
     if args.fp16:
         try:
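The regrouping above assigns every parameter outside the BERT sentence encoder (the document-level layers) a learning rate of args.lr * args.lr_mult with no weight decay, while the encoder itself keeps the base rate, split into decayed and non-decayed groups as before. Below is a minimal, self-contained sketch of this per-group learning-rate pattern; the toy model and torch.optim.AdamW (standing in for BertAdam) are illustrative assumptions, not the repository's code.

# Sketch of per-group learning rates; assumes only that the encoder lives
# under an attribute named 'sentence_encoder', as in H-BERT.
import torch
import torch.nn as nn

class ToyHBert(nn.Module):
    def __init__(self):
        super().__init__()
        self.sentence_encoder = nn.Linear(768, 768)  # stand-in for the BERT encoder
        self.classifier = nn.Linear(768, 2)          # document-level head

model = ToyHBert()
lr, lr_mult = 2e-5, 10.0  # illustrative values
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
params = list(model.named_parameters())

grouped = [
    # Head parameters: scaled learning rate, no weight decay.
    {'params': [p for n, p in params if 'sentence_encoder' not in n],
     'lr': lr * lr_mult, 'weight_decay': 0.0},
    # Encoder weights: base learning rate with weight decay.
    {'params': [p for n, p in params
                if 'sentence_encoder' in n and not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    # Encoder biases and LayerNorm parameters: base rate, no decay.
    {'params': [p for n, p in params
                if 'sentence_encoder' in n and any(nd in n for nd in no_decay)],
     'weight_decay': 0.0}]

# Groups without an explicit 'lr' inherit the optimizer-level default.
optimizer = torch.optim.AdamW(grouped, lr=lr)

As with AdamW here, BertAdam reads a per-group 'lr' when one is supplied, so only the non-encoder group trains at the multiplied rate.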
models/hbert/args.py
@@ -15,6 +15,7 @@ def get_args():
     parser.add_argument('--fp16', action='store_true', help='enable 16-bit floating point precision')
     parser.add_argument('--loss-scale', type=float, default=0, help='loss scaling to improve fp16 numeric stability')
 
+    parser.add_argument('--lr-mult', type=float, default=1)
     parser.add_argument('--dropout', type=float, default=0.5)
     parser.add_argument('--dropblock', type=float, default=0.0)
     parser.add_argument('--dropblock-size', type=int, default=7)
@@ -37,6 +38,5 @@ def get_args():
     parser.add_argument('--gradient-accumulation-steps', type=int, default=1,
                         help='number of updates steps to accumulate before performing a backward/update pass')
 
-
     args = parser.parse_args()
     return args
tasks/relevance_transfer/__main__.py
@@ -211,9 +211,15 @@ def save_ranks(pred_scores, output_path):
     param_optimizer = list(model.named_parameters())
     no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
     optimizer_grouped_parameters = [
-        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
+        {'params': [p for n, p in param_optimizer if
+                    'sentence_encoder' not in n],
+         'lr': args.lr * args.lr_mult, 'weight_decay': 0.0},
+        {'params': [p for n, p in param_optimizer if
+                    'sentence_encoder' in n and not any(nd in n for nd in no_decay)],
          'weight_decay': 0.01},
-        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
+        {'params': [p for n, p in param_optimizer if
+                    'sentence_encoder' in n and any(nd in n for nd in no_decay)],
+         'weight_decay': 0.0}]
 
     optimizer = BertAdam(optimizer_grouped_parameters,
                          lr=args.lr,
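This grouping is logically identical to the one added in models/hbert/__main__.py above; only the line wrapping of the filter expressions differs.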
tasks/relevance_transfer/args.py
@@ -11,6 +11,7 @@ def get_args():
     parser.add_argument('--batch-size', type=int, default=1024)
     parser.add_argument('--mode', type=str, default='static', choices=['rand', 'static', 'non-static', 'multichannel'])
     parser.add_argument('--lr', type=float, default=0.001)
+    parser.add_argument('--lr-mult', type=float, default=1)
     parser.add_argument('--seed', type=int, default=3435)
     parser.add_argument('--dataset', type=str, default='Robust04', choices=['Robust04', 'Robust05', 'Robust45'])
     parser.add_argument('--model', type=str, default='KimCNN', choices=['RegLSTM', 'KimCNN', 'HAN', 'XML-CNN', 'BERT-Base',
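With both argument parsers updated, the multiplier can be passed on the command line; the default of 1 keeps every parameter group at the base learning rate. An illustrative invocation using flags visible in the hunks above (the combination of values is an assumption, not a documented recipe):

python -m tasks.relevance_transfer --dataset Robust45 --lr 0.001 --lr-mult 10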
