In [1]:
import pandas as pd
import numpy as np
from transformers import BertTokenizerFast, BertForTokenClassification, Trainer, TrainingArguments
from datasets import Dataset
import torch

In [2]:
import os
import glob
import json
import random
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.utils import resample
import matplotlib.pyplot as plt

In [3]:
# Folder that holds the JSON label files.
folder_path = '../data/esg_label_result'

# Collect every *.json file in the folder. glob returns paths in arbitrary
# (filesystem-dependent) order, so sort them: the seeded sampling and splitting
# later in the notebook only reproduce if the rows arrive in a stable order.
json_files = sorted(glob.glob(os.path.join(folder_path, "*.json")))

all_data = []

# Read the files one by one and merge their contents into a single list.
for file in json_files:
    with open(file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # The file is fully parsed at this point, so it can be closed before merging.
    all_data.extend(data)

# # 打印合并后的数据
# print(all_data)

In [4]:
# 读取JSON文件 
# with open('../data/esg_label_result/AF Global Limited_report_filtered.json', 'r') as f:
#     data = json.load(f)

In [4]:
# Drop every item for which `"error" in item` is true
# (presumably items are dicts and this tests for an "error" key — confirm).
data = list(filter(lambda item: "error" not in item, all_data))

In [5]:
# Entity types accepted from the annotations. Each type contributes a "B-" and
# an "I-" key (in that order); the outside tag "O" is appended last.
_ENTITY_TYPES = (
    "ENV_GHG_AET",
    "ENV_GHG_AE1",
    "ENV_GHG_AE2",
    "ENV_GHG_AE3",
    "ENV_GHG_EIT",
    "ENV_GHG_EI1",
    "ENV_GHG_EI2",
    "ENV_GHG_EI3",
    "ENV_ENC_TEC",
    "ENV_ENC_ECI",
    "ENV_WAC_TWC",
    "ENV_WAC_WCI",
    "ENV_WAG_TWG",
    "SOC_GED_CEG_M",
    "SOC_GED_CEG_F",
    "SOC_GED_NHG_M",
    "SOC_GED_NHG_F",
    "SOC_GED_ETG_M",
    "SOC_GED_ETG_F",
    "SOC_AGD_CEA_U30",
    "SOC_AGD_CEA_B35",
    "SOC_AGD_CEA_A50",
    "SOC_AGD_NHI_U30",
    "SOC_AGD_NHI_B35",
    "SOC_AGD_NHI_A50",
    "SOC_AGD_TOR_U30",
    "SOC_AGD_TOR_B35",
    "SOC_AGD_TOR_A50",
    "SOC_DEV_ATH_M",
    "SOC_DEV_ATH_F",
    "SOC_OHS_FAT",
    "SOC_OHS_HCI",
    "SOC_OHS_REC",
    "SOC_OHS_RWI",
    "GOV_BOC_BIN",
    "GOV_BOC_WOB",
    "GOV_MAD_WMT",
    "GOV_ETB_ACD",
    "GOV_ETB_ACT_N",
    "GOV_ETB_ACT_P",
    "GOV_CER_LRC",
    "GOV_ALF_AFD",
    "GOV_ASS_ASR",
    "VALUE",
    "UNIT",
)

# Every value is 0; the cells visible in this notebook only test key membership.
label_dict = {
    f"{prefix}-{entity_type}": 0
    for entity_type in _ENTITY_TYPES
    for prefix in ("B", "I")
}
label_dict["O"] = 0

In [6]:
# Build one tag per character of every text from the span annotations.
texts = []
labels = []
err = []  # positions in `data` of entries that had at least one out-of-range span

for entry_idx, entry in enumerate(data):
    text = entry['text']
    entity_labels = ["O"] * len(text)  # every character starts as "O"

    for entity in entry['entity']:
        # Only the first label of each annotated span is used.
        start, end, label = entity['start'], entity['end'], entity['labels'][0]
        if label not in label_dict:
            continue
        # A span reaching past the text would raise IndexError, and a negative
        # start would silently tag characters counted from the end of the text.
        # Skip such spans and remember the entry so it can be inspected.
        if start < 0 or end > len(entity_labels):
            if not err or err[-1] != entry_idx:
                err.append(entry_idx)
            continue
        # The annotated label is written verbatim onto every character of the span.
        for i in range(start, end):
            entity_labels[i] = label

    texts.append(list(text))  # tokens are the individual characters
    labels.append(entity_labels)

# One row per text: its characters and the matching tag sequence.
df = pd.DataFrame({"tokens": texts, "ner_tags": labels})

In [None]:
# # 检查函数：检测句子中是否包含至少一个错误标签
# def contains_incorrect_label(label_sequence):
#     # 如果标签序列全是 'O' 标签，则返回 False（保留该句子）
#     if all(label == "O" for label in label_sequence):
#         return False
#     # 如果存在标签不在合法标签集中，则返回 True（表示该句子含有错误标签）
#     return any(label not in label_dict for label in label_sequence)

# # 找出标错的句子
# incorrect_labels_df = df[df['ner_tags'].apply(contains_incorrect_label)]
# print(incorrect_labels_df)

In [7]:
# Down-sample sentences that contain no entity at all.

# True for rows in which no tag differs from "O".
is_entity_free = df['ner_tags'].apply(lambda tags: not any(tag != "O" for tag in tags))
no_entity_data = df.loc[is_entity_free]
entity_data = df.loc[~is_entity_free]

# Keep a seeded 15% sample of the entity-free sentences.
no_entity_sample = no_entity_data.sample(frac=0.15, random_state=42)

# Recombine, shuffle with a fixed seed, and renumber the index from 0.
balanced_df = (
    pd.concat([entity_data, no_entity_sample])
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)

# Report the resulting sizes.
print("无实体句子数量:", len(no_entity_sample))
print("实体句子数量:", len(entity_data))
print("合并后的数据集样本数:", len(balanced_df))

无实体句子数量: 936
实体句子数量: 5153
合并后的数据集样本数: 6089


In [8]:
# Fine-grained entity types (described as rare in the original note) and the
# coarser type each one is folded into.
_MERGED_ENTITY_TYPES = {
    "SOC_AGD_TOR_U30": "SOC_AGD_TOR",
    "SOC_AGD_TOR_B35": "SOC_AGD_TOR",
    "SOC_AGD_TOR_A50": "SOC_AGD_TOR",

    "SOC_AGD_NHI_B35": "SOC_AGD_NHI",
    "SOC_AGD_NHI_A50": "SOC_AGD_NHI",
    "SOC_AGD_NHI_U30": "SOC_AGD_NHI",

    "ENV_GHG_EI1": "ENV_GHG_EI",
    "ENV_GHG_EI2": "ENV_GHG_EI",
    "ENV_GHG_EI3": "ENV_GHG_EI",

    "SOC_AGD_CEA_U30": "SOC_AGD_CEA",
    "SOC_AGD_CEA_B35": "SOC_AGD_CEA",
    "SOC_AGD_CEA_A50": "SOC_AGD_CEA",

    "SOC_GED_ETG_F": "SOC_GED_ETG",
    "SOC_GED_ETG_M": "SOC_GED_ETG",
    "SOC_GED_NHG_M": "SOC_GED_NHG",
    "SOC_GED_NHG_F": "SOC_GED_NHG",

    # Add further merges here.
}

# Tag-level mapping: both the "B-" and the "I-" form of a fine-grained type map
# to the same prefix on the coarse type.
merge_dict = {
    f"{prefix}-{fine}": f"{prefix}-{coarse}"
    for fine, coarse in _MERGED_ENTITY_TYPES.items()
    for prefix in ("B", "I")
}

# Helper that folds rare tags in one tag sequence into their merged tag.
def merge_labels(label_sequence, merge_dict):
    """Return a new list in which each tag found in merge_dict is replaced.

    Args:
        label_sequence: iterable of tag strings.
        merge_dict: mapping from a tag to the tag that replaces it.

    Returns:
        A list of the same length; tags without an entry in merge_dict are
        kept unchanged.
    """
    merged = []
    for tag in label_sequence:
        merged.append(merge_dict[tag] if tag in merge_dict else tag)
    return merged

# Apply the tag merge to every row of the 'ner_tags' column.
balanced_df['ner_tags'] = balanced_df['ner_tags'].map(lambda tags: merge_labels(tags, merge_dict))

In [9]:
# Count how often each tag occurs after the rare-label merge.
all_labels_flat = [item for sublist in balanced_df['ner_tags'] for item in sublist]
label_counts_after_merge = pd.Series(all_labels_flat).value_counts()

print("合并后的标签分布:")
# Lift the row limit only while printing. option_context restores the previous
# setting afterwards, even if printing raises, instead of resetting to the
# pandas default.
with pd.option_context('display.max_rows', None):
    print(label_counts_after_merge)


合并后的标签分布:
O                  1153555
I-GOV_ALF_AFD        29144
B-GOV_ALF_AFD        26855
B-VALUE              15533
B-GOV_ETB_ACD        10952
I-GOV_ETB_ACD         8668
B-GOV_BOC_BIN         8135
I-GOV_BOC_BIN         8107
I-SOC_DEV_ATH_M       6157
B-UNIT                5897
B-ENV_ENC_TEC         5709
B-ENV_GHG_AET         5702
B-SOC_DEV_ATH_M       5700
I-SOC_OHS_RWI         5246
B-SOC_OHS_RWI         5096
I-ENV_ENC_TEC         4947
I-ENV_GHG_AET         4775
I-UNIT                4584
I-ENV_WAG_TWG         3689
B-SOC_GED_CEG_F       3288
B-ENV_WAG_TWG         3052
I-GOV_CER_LRC         2911
I-SOC_GED_CEG_F       2553
I-VALUE               2248
B-GOV_CER_LRC         2176
I-ENV_WAC_TWC         1971
B-ENV_WAC_TWC         1678
I-GOV_ASS_ASR         1416
B-GOV_ASS_ASR         1349
I-ENV_ENC_ECI         1317
B-ENV_ENC_ECI         1258
I-ENV_GHG_EIT         1161
I-ENV_GHG_AE3         1158
B-SOC_OHS_REC         1127
I-SOC_OHS_REC         1119
I-ENV_GHG_AE2         1108
I-ENV_GHG_AE1     

In [None]:
# # 查找标签所在句子
# # 目标标签
# target_label = "B-SOC_AGD_TOR"

# # 筛选出包含目标标签的句子
# sentences_with_label = balanced_df_resampled[balanced_df_resampled['ner_tags'].apply(lambda x: target_label in x)]

# # 查看筛选结果
# print("包含标签", target_label, "的句子数量:", len(sentences_with_label))
# print(sentences_with_label[['tokens', 'ner_tags']].head())


In [None]:
# from collections import Counter

# # 假设 labels 列表中存储了每个文本的标签序列
# # 将所有标签展开为一个列表，并使用 Counter 统计每种标签的数量
# all_labels = [label for sequence in labels for label in sequence]
# label_counts = Counter(all_labels)

# # 打印每种实体的数量
# for label, count in label_counts.items():
#     print(f"实体标签 '{label}' 的数量为: {count}")

In [10]:
# A tag counts as low-frequency when it occurs fewer than this many times
# (occurrences are counted per tag position, not per sentence).
low_count_threshold = 500
# Number of rows drawn, with replacement, for each low-frequency tag.
oversample_size = 50

# Occurrence count of every tag in the merged data set.
all_labels_flat = [item for sublist in balanced_df['ner_tags'] for item in sublist]
label_counts = pd.Series(all_labels_flat).value_counts()

# Tags below the threshold.
low_frequency_labels = [label for label, count in label_counts.items() if count < low_count_threshold]

# Start from the full data set and collect the extra rows. Everything is
# concatenated once at the end rather than re-copying the growing frame for
# every tag; the resulting row order is the same.
resampled_parts = [balanced_df]

for label in low_frequency_labels:
    # Rows whose tag sequence contains the current tag.
    sentences_with_label = balanced_df[balanced_df['ner_tags'].apply(lambda x: label in x)]

    # This compares a sentence count with the tag-occurrence threshold. For the
    # tags selected above it always holds, because the number of sentences
    # containing a tag cannot exceed the tag's occurrence count.
    if len(sentences_with_label) < low_count_threshold:
        # The seed is the same for every tag.
        sentences_with_label_upsampled = resample(sentences_with_label,
                                                  replace=True,
                                                  n_samples=oversample_size,
                                                  random_state=42)
        resampled_parts.append(sentences_with_label_upsampled)

balanced_df_resampled = pd.concat(resampled_parts)

# Shuffle and renumber the index.
# NOTE(review): the duplicated rows are added before any train/test split, so a
# later random split of this frame can put copies of one sentence on both sides.
balanced_df_resampled = balanced_df_resampled.sample(frac=1, random_state=42).reset_index(drop=True)

# Report the new size.
print("过采样后的数据集大小:", len(balanced_df_resampled))


过采样后的数据集大小: 7089


In [11]:
# Count how often each tag occurs in the oversampled data set. The variable
# name and the printed heading still say "after merge"; the name is kept
# because the unique-label cell below reads label_counts_after_merge.
all_labels_flat = [item for sublist in balanced_df_resampled['ner_tags'] for item in sublist]
label_counts_after_merge = pd.Series(all_labels_flat).value_counts()

print("合并后的标签分布:")
# Lift the row limit only while printing. option_context restores the previous
# setting afterwards, even if printing raises, instead of resetting to the
# pandas default.
with pd.option_context('display.max_rows', None):
    print(label_counts_after_merge)

合并后的标签分布:
O                  1446378
I-GOV_ALF_AFD        36557
B-GOV_ALF_AFD        30632
B-VALUE              27667
B-GOV_ETB_ACD        11621
I-UNIT               10829
B-UNIT               10115
I-GOV_ETB_ACD         9650
B-GOV_BOC_BIN         8799
I-GOV_BOC_BIN         8334
I-SOC_DEV_ATH_M       6850
I-ENV_WAG_TWG         6585
I-SOC_OHS_RWI         6541
B-SOC_DEV_ATH_M       6247
B-ENV_ENC_TEC         6170
B-SOC_OHS_RWI         6102
B-ENV_GHG_AET         5880
I-ENV_ENC_TEC         5716
I-ENV_GHG_AET         4976
B-SOC_GED_CEG_F       4473
B-ENV_WAG_TWG         4225
I-SOC_GED_CEG_F       3424
B-SOC_DEV_ATH_F       3308
B-SOC_GED_ETG         3162
I-SOC_GED_ETG         3099
I-GOV_CER_LRC         2991
I-SOC_DEV_ATH_F       2710
I-ENV_WAC_TWC         2705
B-SOC_OHS_REC         2601
B-SOC_GED_NHG         2590
I-SOC_OHS_REC         2585
B-SOC_OHS_HCI         2563
I-VALUE               2505
B-SOC_OHS_FAT         2495
I-SOC_GED_NHG         2258
B-GOV_CER_LRC         2255
I-ENV_GHG_EIT     

In [12]:
# Convert the oversampled DataFrame into a Hugging Face Dataset.
dataset = Dataset.from_pandas(balanced_df_resampled)

In [13]:
# Distinct tags present in the counted data, as an alphabetically sorted list.
unique_labels = sorted(set(label_counts_after_merge.index))

# Show the tags and how many there are.
print("所有独特标签:", unique_labels)
print(len(unique_labels))


所有独特标签: ['B-ENV_ENC_ECI', 'B-ENV_ENC_TEC', 'B-ENV_GHG_AE1', 'B-ENV_GHG_AE2', 'B-ENV_GHG_AE3', 'B-ENV_GHG_AET', 'B-ENV_GHG_EI', 'B-ENV_GHG_EIT', 'B-ENV_WAC_TWC', 'B-ENV_WAC_WCI', 'B-ENV_WAG_TWG', 'B-GOV_ALF_AFD', 'B-GOV_ASS_ASR', 'B-GOV_BOC_BIN', 'B-GOV_BOC_WOB', 'B-GOV_CER_LRC', 'B-GOV_ETB_ACD', 'B-GOV_ETB_ACT_N', 'B-GOV_ETB_ACT_P', 'B-GOV_MAD_WMT', 'B-SOC_AGD_CEA', 'B-SOC_AGD_NHI', 'B-SOC_AGD_TOR', 'B-SOC_DEV_ATH_F', 'B-SOC_DEV_ATH_M', 'B-SOC_GED_CEG_F', 'B-SOC_GED_CEG_M', 'B-SOC_GED_ETG', 'B-SOC_GED_NHG', 'B-SOC_OHS_FAT', 'B-SOC_OHS_HCI', 'B-SOC_OHS_REC', 'B-SOC_OHS_RWI', 'B-UNIT', 'B-VALUE', 'I-ENV_ENC_ECI', 'I-ENV_ENC_TEC', 'I-ENV_GHG_AE1', 'I-ENV_GHG_AE2', 'I-ENV_GHG_AE3', 'I-ENV_GHG_AET', 'I-ENV_GHG_EI', 'I-ENV_GHG_EIT', 'I-ENV_WAC_TWC', 'I-ENV_WAC_WCI', 'I-ENV_WAG_TWG', 'I-GOV_ALF_AFD', 'I-GOV_ASS_ASR', 'I-GOV_BOC_BIN', 'I-GOV_BOC_WOB', 'I-GOV_CER_LRC', 'I-GOV_ETB_ACD', 'I-GOV_ETB_ACT_N', 'I-GOV_ETB_ACT_P', 'I-GOV_MAD_WMT', 'I-SOC_AGD_CEA', 'I-SOC_AGD_NHI', 'I-SOC_AGD_TOR', 'I-S

In [14]:
# Hold out 20% of the examples for evaluation, with a fixed seed.
# NOTE(review): `dataset` is built from balanced_df_resampled, which already
# contains rows duplicated by the oversampling step, so copies of one sentence
# can land in both splits and make evaluation scores optimistic. Consider
# splitting before oversampling.
train_test_split = dataset.train_test_split(test_size=0.2,seed = 666)
train_dataset = train_test_split['train']
eval_dataset = train_test_split['test']

In [16]:
# Checkpoint to fine-tune; the commented-out names look like alternatives that were tried.
# model_name = "nbroad/ESG-BERT"
model_name = "dslim/bert-base-NER"
# model_name = "dslim/bert-large-NER"
# Fast tokenizer matching the chosen checkpoint.
tokenizer = BertTokenizerFast.from_pretrained(model_name)



In [18]:
# Tag name -> integer id, numbered by position in unique_labels.
label2id = dict(zip(unique_labels, range(len(unique_labels))))
# Inverse mapping: integer id -> tag name.
id2label = {tag_id: tag for tag, tag_id in label2id.items()}

In [20]:
def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split examples and attach aligned label ids.

    Args:
        examples: batch mapping with "tokens" (one token list per example) and
            "ner_tags" (the matching list of tag strings per example).

    Returns:
        The tokenizer output with an extra "labels" entry holding one list of
        integer ids per example. For each token position:
          * no source word (word id is None): -100, the "ignore" marker;
          * first token of a word: the id of that word's tag;
          * further tokens of the same word: the tag id if the tag starts with
            "I-", otherwise -100.

    Uses the notebook-level `tokenizer` and `label2id`.
    """
    tokenized_inputs = tokenizer(
        examples["tokens"],
        is_split_into_words=True,
        truncation=True,
        padding=True,
    )

    def _aligned_ids(batch_index, tags):
        # Walk the word ids of one example and apply the rule from the docstring.
        aligned = []
        last_word = None
        for word in tokenized_inputs.word_ids(batch_index=batch_index):
            if word is None:
                aligned.append(-100)
            else:
                tag = tags[word]
                if word != last_word or tag.startswith("I-"):
                    aligned.append(label2id[tag])
                else:
                    aligned.append(-100)
            last_word = word
        return aligned

    tokenized_inputs["labels"] = [
        _aligned_ids(batch_index, tags)
        for batch_index, tags in enumerate(examples["ner_tags"])
    ]
    return tokenized_inputs


In [21]:
# Token-level evaluation metrics passed to the Trainer.
def compute_metrics(pred):
    """Compute accuracy and weighted precision/recall/F1 over labelled positions.

    Args:
        pred: pair ``(predictions, labels)``. ``predictions`` holds per-class
            scores with the class dimension on axis 2; ``labels`` holds integer
            label ids where -100 marks positions to skip.

    Returns:
        dict with the keys 'accuracy', 'precision', 'recall' and 'f1'.

    Every position whose label is not -100 is scored, including "O" positions.
    """
    predictions, labels = pred
    predictions = np.argmax(predictions, axis=2)
    labels = np.asarray(labels)

    # Drop ignored positions. Boolean indexing flattens in row-major order, so
    # labels and predictions stay aligned.
    keep = labels != -100
    true_labels = labels[keep]
    true_predictions = predictions[keep]

    accuracy = accuracy_score(true_labels, true_predictions)
    # zero_division=0 yields the same numbers as the default "warn" setting
    # (undefined per-class scores count as 0) without an UndefinedMetricWarning
    # on every evaluation.
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, true_predictions, average='weighted', zero_division=0
    )

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }


In [None]:
# Build the token-classification model from the chosen checkpoint.
model = BertForTokenClassification.from_pretrained(
    model_name,
    num_labels=len(unique_labels),  # number of distinct tags
    id2label=id2label,              # label id -> tag name
    label2id=label2id,              # tag name -> label id
    ignore_mismatched_sizes=True    # tolerate size mismatches when loading (presumably the checkpoint's classifier head has a different label count — confirm)
)

In [22]:
# Tokenize both splits and attach aligned label ids, processing examples in batches.
train_dataset = train_dataset.map(tokenize_and_align_labels, batched=True)
eval_dataset = eval_dataset.map(tokenize_and_align_labels, batched=True)

Map:   0%|          | 0/5671 [00:00<?, ? examples/s]

Map:   0%|          | 0/1418 [00:00<?, ? examples/s]

In [22]:
# Settings for the first, single-epoch training run.
training_args = TrainingArguments(
    output_dir="../results",
    evaluation_strategy="epoch",  # evaluate once per epoch
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=1,
    weight_decay=0.01,
    logging_dir="../logs",
    logging_steps=10,  # log the training loss every 10 steps
)



In [23]:
# Trainer wiring the model, the tokenized splits and the token-level metrics together.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics
)

In [24]:
# Run the fine-tuning configured in training_args.
trainer.train()

  0%|          | 0/709 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


{'loss': 2.5867, 'grad_norm': 2.876352071762085, 'learning_rate': 1.9717912552891397e-05, 'epoch': 0.01}
{'loss': 1.1976, 'grad_norm': 1.1854902505874634, 'learning_rate': 1.9435825105782797e-05, 'epoch': 0.03}
{'loss': 1.51, 'grad_norm': 1.9207444190979004, 'learning_rate': 1.915373765867419e-05, 'epoch': 0.04}
{'loss': 1.3812, 'grad_norm': 6.197670936584473, 'learning_rate': 1.8871650211565585e-05, 'epoch': 0.06}
{'loss': 1.4283, 'grad_norm': 2.3361713886260986, 'learning_rate': 1.8589562764456984e-05, 'epoch': 0.07}
{'loss': 1.3057, 'grad_norm': 2.5133488178253174, 'learning_rate': 1.830747531734838e-05, 'epoch': 0.08}
{'loss': 1.3223, 'grad_norm': 1.8773274421691895, 'learning_rate': 1.8025387870239776e-05, 'epoch': 0.1}
{'loss': 1.3961, 'grad_norm': 3.1961872577667236, 'learning_rate': 1.7743300423131172e-05, 'epoch': 0.11}
{'loss': 1.3602, 'grad_norm': 1.9935333728790283, 'learning_rate': 1.7461212976022568e-05, 'epoch': 0.13}
{'loss': 1.3664, 'grad_norm': 4.477314472198486, 'lea

  0%|          | 0/178 [00:00<?, ?it/s]

{'eval_loss': 1.1157557964324951, 'eval_accuracy': 0.7926746118262844, 'eval_precision': 0.64166726852883, 'eval_recall': 0.7926746118262844, 'eval_f1': 0.7077866833121764, 'eval_runtime': 152.2816, 'eval_samples_per_second': 9.312, 'eval_steps_per_second': 1.169, 'epoch': 1.0}
{'train_runtime': 510.8828, 'train_samples_per_second': 11.1, 'train_steps_per_second': 1.388, 'train_loss': 1.2307396378940856, 'epoch': 1.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


TrainOutput(global_step=709, training_loss=1.2307396378940856, metrics={'train_runtime': 510.8828, 'train_samples_per_second': 11.1, 'train_steps_per_second': 1.388, 'total_flos': 1482738299685888.0, 'train_loss': 1.2307396378940856, 'epoch': 1.0})

In [25]:
# Evaluate the trained model on eval_dataset and show the metrics.
eval_results = trainer.evaluate()
print(eval_results)

  0%|          | 0/178 [00:00<?, ?it/s]

{'eval_loss': 1.1157557964324951, 'eval_accuracy': 0.7926746118262844, 'eval_precision': 0.64166726852883, 'eval_recall': 0.7926746118262844, 'eval_f1': 0.7077866833121764, 'eval_runtime': 26.9121, 'eval_samples_per_second': 52.69, 'eval_steps_per_second': 6.614, 'epoch': 1.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [27]:
# Save the fine-tuned weights and the tokenizer files to the same folder so they can be reloaded together.
model.save_pretrained('../finetuned_model')
tokenizer.save_pretrained('../finetuned_model')

('../finetuned_model\\tokenizer_config.json',
 '../finetuned_model\\special_tokens_map.json',
 '../finetuned_model\\vocab.txt',
 '../finetuned_model\\added_tokens.json',
 '../finetuned_model\\tokenizer.json')

In [17]:
# Load the fine-tuned model and tokenizer from the folder written by the save cell.
model_path = "../finetuned_model"
tokenizer = BertTokenizerFast.from_pretrained(model_path)
model = BertForTokenClassification.from_pretrained(model_path)

In [None]:
# Settings for a further training run: same values as the first run except for 4 epochs.
# Checkpoints and logs go to the same "../results" and "../logs" folders as before.
training_args = TrainingArguments(
    output_dir="../results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=4,
    weight_decay=0.01,
    logging_dir="../logs",
    logging_steps=10,
)

# Continue (incremental) training starting from the current `model`.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics
)

trainer.train()