perf(runner): logger missing_keys and unexpected_key in runner
mmmwhy committed Feb 12, 2022
1 parent b3081ac commit a9e3ff9
Showing 9 changed files with 42 additions and 40 deletions.
9 changes: 7 additions & 2 deletions CHANGELOG.md
@@ -1,4 +1,4 @@
-## 0.0.23 (2022-02-6)
+## 0.0.24 (2022-02-12)


### Bug Fixes
@@ -10,9 +10,14 @@
### Features

* **bert:** add tokenizer part ([054df14](https://github.com/mmmwhy/pure_attention/commit/054df14c7dfefc0b2edb47824578b33f4a5c8539))
-* **decode:** add some transformer decode code ([893be87](https://github.com/mmmwhy/pure_attention/commit/893be87901aa875488a5bdce53bcf11f1bf74033))
+* **decode:** add some transformer decode code ([52b044b](https://github.com/mmmwhy/pure_attention/commit/52b044b0fa79dcb3b9ba8fcd2747f05bc43de808))
* **layers:** fix import for layerNorm ([eb61b31](https://github.com/mmmwhy/pure_attention/commit/eb61b313458ac18bf4b15271fee2cf7e39f8afde))
* **nlp:** init basic bert code ([f9cb13a](https://github.com/mmmwhy/pure_attention/commit/f9cb13a3e811eb8c44ba8ff1373d688311426927))


+### Performance Improvements
+
+* **runner:** logger missing_keys and unexpected_key in runner ([69c8c78](https://github.com/mmmwhy/pure_attention/commit/69c8c781c7053c066d947087e98814e6132c8847))



6 changes: 3 additions & 3 deletions README.md
@@ -12,9 +12,9 @@ Many methods and techniques in CV and NLP also influence each other, for example large-scale pre…
# Goal
Provide a complete set of basic algorithm services

-1. Python training tasks, covering NLP and CV tasks
+1. Python training tasks, covering NLP and CV tasks.

-2. Online inference deployment with ONNX in a Java environment; the reason for ONNX is that I use TensorFlow for inference at work and I don't want this code to match the company's
+2. Online inference deployment with ONNX in a Java environment.

# todo
Phase 1: implement typical NLP and CV tasks and evaluate their downstream performance.
@@ -25,7 +25,7 @@ Many methods and techniques in CV and NLP also influence each other, for example large-scale pre…

- [x] Provide China-mainland download mirrors for [transformers](https://github.com/huggingface/transformers), [bert-base-chinese](https://huggingface.co/bert-base-chinese), [chinese-roberta-wwm-ext](https://huggingface.co/hfl/chinese-roberta-wwm-ext), [chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large), and [ernie 1.0](https://huggingface.co/nghuyong/ernie-1.0); for download instructions see [transformers domestic download mirror](pure_attention/backbone_bert/README.md#transformers国内下载镜像)

-- [ ] Implement the Transformer decode stage in PyTorch, plus a seq2seq task.
+- [x] Implement the Transformer decode stage in PyTorch, plus a seq2seq task.
> todo
- [ ] Implement NLP downstream tasks (sequence labeling, classification) and evaluate them on public datasets, mainly to show that the implemented backbone performs as expected;
> todo
18 changes: 11 additions & 7 deletions examples/model_chineseNMT/runner.py
@@ -50,18 +50,22 @@ def train(self):
scaler.step(self.optimizer)
scaler.update()

-self.logger.info(
-    f"Epoch: {epoch} / {self.train_epochs_num}, "
-    f"Step: {step} / {len(self.dataloader)}, "
-    f"Loss: {np.mean(loss.item())}, "
-    f"Lr: {self.optimizer.param_groups[0]['lr']}"
-)
+self.logger.info((
+    "Epoch: {epoch:03d} / {all_epoch:03d},"
+    "Step: {step:04d} / {all_step:04d},"
+    "Loss: {loss:.04f},"
+    "Lr: {lr:.08f}"
+    .format(epoch=epoch, all_epoch=self.train_epochs_num, step=step,
+            all_step=len(self.dataloader),
+            loss=np.mean(loss.item()),
+            lr=self.optimizer.param_groups[0]['lr'])))

def run(self):
    self.train()


-# python -m model_chineseNMT.runner
+# nohup python -m examples.model_chineseNMT.runner 1>train.log 2>&1 &
+# tail -f train.log
if __name__ == "__main__":
    config = BertConfig("/data/pretrain_modal/bert-base-chinese/config.json")
    runner = Runner(config)
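The rewritten log statement above swaps the f-string for a `str.format` template with zero-padded integers and fixed float precision, so successive lines in train.log stay column-aligned. A minimal standalone sketch of what the new template renders (the numeric values here are made up for illustration):

```python
# Sketch of the runner's new log template; the values are hypothetical,
# only the format specifiers mirror the diff above.
line = (
    "Epoch: {epoch:03d} / {all_epoch:03d},"
    "Step: {step:04d} / {all_step:04d},"
    "Loss: {loss:.04f},"
    "Lr: {lr:.08f}"
    .format(epoch=3, all_epoch=100, step=42, all_step=1875,
            loss=2.34567, lr=1e-4))

print(line)
# Epoch: 003 / 100,Step: 0042 / 1875,Loss: 2.3457,Lr: 0.00010000
```

Note that the template contains no spaces after the commas, so the rendered line runs the fields together exactly as shown.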
2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
{
"name": "pure_attention",
"version": "0.0.23",
"version": "0.0.24",
"description": "Generate a changelog from git metadata",
"repository": {
"type": "git",
2 changes: 1 addition & 1 deletion pure_attention/backbone_bert/README.md
@@ -26,7 +26,7 @@


# Usage
-1. Install this repository: `pip install --upgrade pure_attention`
+1. Install this repository: `pip install pure_attention==0.0.24`

2. Download the pretrained models

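With the pinned install and a downloaded checkpoint in place, the backbone can be exercised directly. A hypothetical usage sketch: the import path follows this repository's layout, and the checkpoint filename `pytorch_model.bin` is an assumed standard name, not confirmed by this diff; the local paths are placeholders.

```python
# Hypothetical usage sketch; paths and checkpoint filename are assumptions.
from pure_attention.backbone_bert.bert_model import BertModel

# The directory must contain config.json (see BertModel.__init__ below).
model = BertModel("/data/pretrain_modal/bert-base-chinese")
model.from_pretrained("/data/pretrain_modal/bert-base-chinese/pytorch_model.bin")
```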
41 changes: 17 additions & 24 deletions pure_attention/backbone_bert/bert_model.py
@@ -9,6 +9,7 @@

import torch
from torch import nn
+from pure_attention.utils.logger import init_logger

from pure_attention.backbone_bert.package import BertConfig, BertOutput
from pure_attention.base_transformer.encoder import Encoder
@@ -33,8 +34,8 @@ class BertModel(nn.Module):
def __init__(self, model_path):
    super(BertModel, self).__init__()

-   # the config file is required
-   self.config = BertConfig(os.path.join(model_path, "config.json"))
+   self.config = BertConfig(os.path.join(model_path, "config.json"))  # the config file is required
+   self.logger = init_logger(self.__class__.__name__)

    self.embeddings = InputEmbeddings(self.config)
    self.encoder = Encoder(self.config)
@@ -66,38 +67,30 @@ def from_pretrained(self, pretrained_model_path):

state_dict = torch.load(pretrained_model_path, map_location='cpu')

-# names may be inconsistent, so replace them
-old_keys = []
-new_keys = []
-for key in state_dict.keys():
-    new_key = key
-    if 'gamma' in key:
+# rename some of the variable names
+for old_key in state_dict.copy().keys():
+    new_key = old_key
+    if 'gamma' in old_key:
         new_key = new_key.replace('gamma', 'weight')
-    if 'beta' in key:
+    if 'beta' in old_key:
         new_key = new_key.replace('beta', 'bias')
-    if 'bert.' in key:
+    if 'bert.' in old_key:
         new_key = new_key.replace('bert.', '')
     # compatibility with some inelegant variable names
-    if 'LayerNorm' in key:
+    if 'LayerNorm' in old_key:
         new_key = new_key.replace('LayerNorm', 'layer_norm')

-    if new_key:
-        old_keys.append(key)
-        new_keys.append(new_key)
-
-for old_key, new_key in zip(old_keys, new_keys):
-
-    if new_key in self.state_dict().keys():
-        if new_key != old_key:
-            state_dict[new_key] = state_dict.pop(old_key)
-    else:
-        # avoid extra structures in the pretrained model breaking the strict load_state_dict
-        state_dict.pop(old_key)
+    if new_key != old_key:
+        state_dict[new_key] = state_dict.pop(old_key)

-# make sure everything matches exactly
-self.load_state_dict(state_dict, strict=True)
+missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
+# a few parameters may differ without affecting the results
+if len(missing_keys):
+    self.logger.warning("\n\t".join(["missing_keys:"] + missing_keys))
+if len(unexpected_keys):
+    self.logger.warning("\n\t".join(["unexpected_keys:"] + unexpected_keys))

def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None):

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    if token_type_ids is None:
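Two mechanics in this hunk deserve a note: the loop rewrites legacy TF-style parameter names (`gamma`/`beta`) into PyTorch's `weight`/`bias`, and `load_state_dict(..., strict=False)` returns a `(missing_keys, unexpected_keys)` named tuple instead of raising, which is what the new warnings print. A self-contained sketch of the pattern; the toy module and checkpoint keys below are hypothetical, not the repo's:

```python
import logging

import torch
from torch import nn

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("bert_model")


class Toy(nn.Module):
    # Stand-in for BertModel with a single layer_norm submodule.
    def __init__(self):
        super().__init__()
        self.layer_norm = nn.LayerNorm(4)


# Toy checkpoint: TF-style names plus one key the model does not have.
state_dict = {
    "layer_norm.gamma": torch.ones(4),    # -> layer_norm.weight
    "layer_norm.beta": torch.zeros(4),    # -> layer_norm.bias
    "pooler.dense.bias": torch.zeros(4),  # no matching submodule
}

# Same rename trick as the diff: iterate over a copy of the keys
# while mutating the dict in place.
for old_key in list(state_dict.keys()):
    new_key = old_key.replace("gamma", "weight").replace("beta", "bias")
    if new_key != old_key:
        state_dict[new_key] = state_dict.pop(old_key)

# strict=False reports mismatches instead of raising a RuntimeError.
model = Toy()
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if len(missing_keys):
    logger.warning("\n\t".join(["missing_keys:"] + missing_keys))
if len(unexpected_keys):
    logger.warning("\n\t".join(["unexpected_keys:"] + unexpected_keys))
# here: missing_keys == [], unexpected_keys == ['pooler.dense.bias']
```

Under the old `strict=True` call, the extra pooler key would have aborted loading with a RuntimeError; the relaxed load surfaces it as a warning and continues.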
2 changes: 1 addition & 1 deletion pure_attention/base_transformer/layers.py
@@ -18,7 +18,7 @@
import torch
from torch import nn

-from pure_attention.base_transformer.activate import activations
+from pure_attention.common.activate import activations


class InputEmbeddings(nn.Module):
pure_attention/{base_transformer → common}/activate.py
File renamed without changes.
2 changes: 1 addition & 1 deletion pure_attention/common/nlp_tokenization.py
@@ -5,7 +5,7 @@
# @date: 2022/01/24
#
"""
-Tokenization code used by BERT; most NLP tasks can use this file together with vocab.txt
+Tokenization code used by BERT; together with vocab.txt it covers most NLP tasks
"""

import collections
