Updating to pytorch-transformers #15

Open · wants to merge 3 commits into master
3 changes: 3 additions & 0 deletions .gitignore
@@ -94,6 +94,9 @@ venv.bak/
.spyderproject
.spyproject

# vscode
.vscode

# Rope project settings
.ropeproject

4 changes: 2 additions & 2 deletions README.md
@@ -2,7 +2,7 @@

This repo contains a PyTorch implementation of a pretrained BERT model for multi-label text classification.

**note**: for the new `pytorch-pretrained-bert` package . use comd `from pytorch_pretrained_bert.modeling import BertPreTrainedModel`
**note**: for the new `pytorch_transformers` package, use `from pytorch_transformers.modeling_bert import BertPreTrainedModel`
## Structure of the code

At the root of the project, you will see:
@@ -42,7 +42,7 @@ At the root of the project, you will see:
- PyTorch 1.0
- matplotlib
- pandas
- pytorch_pretrained_bert (load bert model)
- pytorch_transformers (load bert model)

## How to use the code

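For readers updating their own code alongside this README note, the renamed imports used throughout this PR line up as follows (both sets of module paths are taken directly from the diffs in this PR):

```python
# Old imports removed by this PR (pytorch-pretrained-bert):
# from pytorch_pretrained_bert.modeling import PreTrainedBertModel, BertModel
# from pytorch_pretrained_bert.tokenization import BertTokenizer

# New imports added by this PR (pytorch-transformers):
from pytorch_transformers import BertModel
from pytorch_transformers.modeling_bert import BertPreTrainedModel
from pytorch_transformers.tokenization_bert import BertTokenizer
```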
2 changes: 1 addition & 1 deletion convert_tf_checkpoint_to_pytorch.py
@@ -1,7 +1,7 @@
#encoding:utf-8
import os
from pybert.config.basic_config import configs as config
from pytorch_pretrained_bert.convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
from pytorch_transformers.convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch

if __name__ == "__main__":
os.system('cp {config} {save_path}'.format(config = config['pretrained']['bert']['bert_config_file'],
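For reference, a hedged sketch of how the relocated conversion helper is typically invoked. The paths are placeholders (in this repo they come from `pybert.config.basic_config`), and the positional signature shown matches the early `pytorch_transformers` releases, so it is worth verifying against the installed version:

```python
from pytorch_transformers.convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch

# Placeholder paths -- the project reads the real ones from its config module.
tf_checkpoint_path = "uncased_L-12_H-768_A-12/bert_model.ckpt"
bert_config_file = "uncased_L-12_H-768_A-12/bert_config.json"
pytorch_dump_path = "pybert/model/pretrain/pytorch_model.bin"

# Writes a PyTorch state dict converted from the TensorFlow BERT checkpoint.
convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path)
```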
2 changes: 1 addition & 1 deletion inference.py
@@ -10,7 +10,7 @@
from pybert.model.nn.bert_fine import BertFine
from pybert.test.predicter import Predicter
from pybert.preprocessing.preprocessor import EnglishPreProcessor
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_transformers.tokenization_bert import BertTokenizer
warnings.filterwarnings("ignore")

# main function
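The tokenizer keeps the same interface after the move; a minimal sketch of loading it from the new module (the checkpoint name here is a placeholder — the repo points the tokenizer at its configured vocab path instead):

```python
from pytorch_transformers.tokenization_bert import BertTokenizer

# 'bert-base-uncased' is a placeholder; the project passes its own vocab path.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

tokens = tokenizer.tokenize("Multi-label text classification with BERT")
input_ids = tokenizer.convert_tokens_to_ids(["[CLS]"] + tokens + ["[SEP]"])
print(tokens, input_ids)
```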
2 changes: 1 addition & 1 deletion pybert/io/dataset.py
@@ -4,7 +4,7 @@
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_transformers.tokenization_bert import BertTokenizer

class InputExample(object):
def __init__(self, guid, text_a, text_b=None, label=None):
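Since the tokenizer API is unchanged, the dataset code should keep working as-is. The hypothetical helper below illustrates the usual way an `InputExample` becomes fixed-length BERT inputs; the function name and `max_seq_len` are illustrative, not the repo's actual code:

```python
def convert_example_to_feature(example, tokenizer, max_seq_len=128):
    # Truncate text_a and wrap it with the BERT special tokens.
    tokens = ["[CLS]"] + tokenizer.tokenize(example.text_a)[: max_seq_len - 2] + ["[SEP]"]
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)    # 1 for real tokens, 0 for padding
    segment_ids = [0] * len(input_ids)   # single-sentence input -> all zeros
    padding = [0] * (max_seq_len - len(input_ids))
    return input_ids + padding, input_mask + padding, segment_ids + padding
```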
11 changes: 6 additions & 5 deletions pybert/model/nn/bert_fine.py
@@ -1,14 +1,15 @@
#encoding:utf-8
import torch.nn as nn
from pytorch_pretrained_bert.modeling import PreTrainedBertModel, BertModel
from pytorch_transformers import BertModel
from pytorch_transformers.modeling_bert import BertPreTrainedModel

class BertFine(PreTrainedBertModel):
class BertFine(BertPreTrainedModel):
def __init__(self,bertConfig,num_classes):
super(BertFine ,self).__init__(bertConfig)
self.bert = BertModel(bertConfig) # BERT model
self.dropout = nn.Dropout(bertConfig.hidden_dropout_prob)
self.classifier = nn.Linear(in_features=bertConfig.hidden_size, out_features=num_classes)
self.apply(self.init_bert_weights)
self.apply(self.init_weights)
# By default all parameters of the BERT encoder are trainable; with a batch size of 32 this uses roughly 8.7 GB of GPU memory
# The encoder can instead be frozen so that only the classifier layer is backpropagated; a batch size of 32 then uses roughly 1.1 GB
self.unfreeze_bert_encoder()
@@ -24,8 +25,8 @@ def unfreeze_bert_encoder(self):
def forward(self, input_ids, token_type_ids, attention_mask, label_ids=None, output_all_encoded_layers=False):
_, pooled_output = self.bert(input_ids,
token_type_ids,
attention_mask,
output_all_encoded_layers=output_all_encoded_layers)
attention_mask)
#output_all_encoded_layers=output_all_encoded_layers)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
return logits
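The commented-out keyword above reflects a real API difference: in `pytorch_transformers`, `BertModel.forward` no longer accepts `output_all_encoded_layers`, and it returns a tuple whose first two elements are the sequence output and the pooled [CLS] output. A minimal sketch of the new call, using keyword arguments to stay independent of the positional order (the checkpoint name is a placeholder):

```python
import torch
from pytorch_transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")  # placeholder checkpoint
input_ids = torch.tensor([[101, 7592, 2088, 102]])      # [CLS] hello world [SEP]
attention_mask = torch.ones_like(input_ids)
token_type_ids = torch.zeros_like(input_ids)

# Returns a tuple: (sequence_output, pooled_output, ...optional extras)
outputs = model(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
sequence_output, pooled_output = outputs[0], outputs[1]
print(pooled_output.shape)  # (1, hidden_size)
```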
2 changes: 2 additions & 0 deletions pybert/train/trainer.py
@@ -16,6 +16,7 @@ def __init__(self,train_configs):
self.logger = train_configs['logger']
self.verbose = train_configs['verbose']
self.criterion = train_configs['criterion']
self.scheduler = train_configs['scheduler']
self.optimizer = train_configs['optimizer']
self.lr_scheduler = train_configs['lr_scheduler']
self.early_stopping = train_configs['early_stopping']
@@ -122,6 +123,7 @@ def _train_epoch(self,data):
# learning-rate update step
if (step + 1) % self.gradient_accumulation_steps == 0:
self.lr_scheduler.batch_step(training_step = self.global_step)
self.scheduler.step()
self.optimizer.step()
self.optimizer.zero_grad()
self.global_step += 1
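For context on the added `scheduler.step()` call: `WarmupLinearSchedule` is a standard PyTorch LR scheduler, and with PyTorch >= 1.1 it is conventionally stepped after `optimizer.step()`. A self-contained sketch of that update pattern under gradient accumulation — the tiny linear model and the hyperparameters are stand-ins, not the repo's trainer:

```python
import torch
import torch.nn as nn
from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule

model = nn.Linear(10, 2)                       # stand-in for BertFine
optimizer = AdamW(model.parameters(), lr=2e-5)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=10, t_total=100)
criterion = nn.BCEWithLogitsLoss()
gradient_accumulation_steps = 2

for step in range(4):                          # stand-in batch loop
    inputs = torch.randn(8, 10)
    labels = torch.randint(0, 2, (8, 2)).float()
    loss = criterion(model(inputs), labels) / gradient_accumulation_steps
    loss.backward()
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()   # apply the accumulated gradients first
        scheduler.step()   # then advance the warmup/linear-decay schedule
        optimizer.zero_grad()
```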
17 changes: 10 additions & 7 deletions train_bert_multi_label.py
@@ -15,8 +15,8 @@
from pybert.callback.modelcheckpoint import ModelCheckpoint
from pybert.callback.trainingmonitor import TrainingMonitor
from pybert.train.metrics import F1Score,AccuracyThresh,MultiLabelReport
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_transformers.tokenization_bert import BertTokenizer
from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule
warnings.filterwarnings("ignore")

# main function
@@ -89,10 +89,12 @@ def main():
len(train_dataset.examples) / config['train']['batch_size'] / config['train']['gradient_accumulation_steps'] * config['train']['epochs'])
# t_total: total number of training steps for the learning rate schedule
# warmup: portion of t_total for the warmup
optimizer = BertAdam(optimizer_grouped_parameters,
lr = config['train']['learning_rate'],
warmup = config['train']['warmup_proportion'],
t_total = num_train_steps)
optimizer = AdamW(optimizer_grouped_parameters,
lr = config['train']['learning_rate'])

scheduler = WarmupLinearSchedule(optimizer,
warmup_steps=config['train']['warmup_steps'],
t_total=num_train_steps)

# **************************** callbacks ***********************
logger.info("initializing callbacks")
@@ -110,7 +112,7 @@ def main():
lr_scheduler = BertLR(optimizer = optimizer,
learning_rate = config['train']['learning_rate'],
t_total = num_train_steps,
warmup = config['train']['warmup_proportion'])
warmup = config['train']['warmup_steps'])

# **************************** training model ***********************
logger.info('training model....')
@@ -119,6 +121,7 @@
'model': model,
'logger': logger,
'optimizer': optimizer,
'scheduler': scheduler,
'resume': config['train']['resume'],
'epochs': config['train']['epochs'],
'n_gpu': config['train']['n_gpu'],
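Because `AdamW`, unlike `BertAdam`, does not handle warmup or decay the learning rate itself, that behaviour moves into the separate `WarmupLinearSchedule` added above. A hedged sketch of the optimizer setup this change implies; the parameter grouping and numeric values are illustrative (the repo builds its own `optimizer_grouped_parameters` earlier in `main()`), and `model` stands in for the `BertFine` instance:

```python
import torch.nn as nn
from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule

model = nn.Linear(768, 6)  # stand-in for the BertFine model
learning_rate, warmup_steps, num_train_steps = 2e-5, 100, 1000

# The usual convention: exclude bias and LayerNorm weights from weight decay.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {"params": [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]

# correct_bias=False is sometimes passed to mimic BertAdam's behaviour more closely.
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps,
                                 t_total=num_train_steps)
```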