(求Star⭐)本项目仅仅提供了最基础的BERT文本分类模型,代码是作者在入门NLP时自己写的,对于初学者还算比较好理解,细节上有不足的地方,大家可以自行修改。读者在使用的时候有任何问题和建议都可以通过邮件联系我。
本文利用了transformers中的BertModel,对部分cnews数据集进行了文本分类,在验证集上的最优Acc达到了0.92,拿来对BERT模型练手还是不错的。
数据集是从清华大学的THUCNews中提取出来的部分数据。
训练集中有5万条数据,分成了10类,每类5000条数据。
{"体育": 5000, "娱乐": 5000, "家居": 5000, "房产": 5000, "教育": 5000, "时尚": 5000, "时政": 5000, "游戏": 5000, "科技": 5000, "财经": 5000}
验证集中有5000条数据,每类500条数据。
{"体育": 500, "娱乐": 500, "家居": 500, "房产": 500, "教育": 500, "时尚": 500, "时政": 500, "游戏": 500, "科技": 500, "财经": 500}
如果需要数据集,请与我联系.
数据集放在了百度网盘上:链接: https://pan.baidu.com/s/1FVV8fq7vSuGSiOVnE4E_Ag 提取码: bbwv
整个分类模型首先把句子输入到Bert预训练模型,然后将句子的embedding(CLS位置的Pooled output)输入给一个Linear,最后把Linear的输出输入到softmax中。
硬件 | 环境 |
---|---|
GPU | GTX1080 |
RAM | 64G |
软件 | 环境 |
---|---|
OS | Ubuntu 18.04 LTS |
CUDA | 10.2 |
PyTorch | 1.6.0 |
transformers | 3.2.0 |
分类报告:
* Classification Report:
precision recall f1-score support
体育 1.00 0.99 0.99 500
娱乐 0.99 0.92 0.96 500
家居 0.96 0.73 0.83 500
房产 0.83 0.94 0.88 500
教育 0.94 0.75 0.84 500
时尚 0.89 0.99 0.94 500
时政 0.91 0.96 0.93 500
游戏 0.93 0.98 0.96 500
科技 0.91 0.96 0.93 500
财经 0.87 0.98 0.92 500
accuracy 0.92 5000
macro avg 0.92 0.92 0.92 5000
weighted avg 0.92 0.92 0.92 5000
创建data文件夹,把下载好的cnews数据集放在data文件夹下。
创建models文件夹,用来保存模型
安装相应依赖库:
pip install -r requirements.txt
训练:
python train.py
预测:
python predict.py