In [None]:
# 基础导入
import json
import matplotlib.pyplot as plt
from pprint import pprint

# 项目模块导入
from src.segmenter import fmm_segment, jieba_segment
from src.pos_tagger import CRFPOSTagger
from src.utils.data_loader import load_processed_data
from src.train_bert import train_bert


In [None]:
# Locations of the preprocessed train/dev splits; both names are reused by later cells.
train_path, dev_path = "data/processed/train.json", "data/processed/dev.json"

# Read each split through the project's loader.
train_sents = load_processed_data(train_path)
dev_sents = load_processed_data(dev_path)

# Report how many samples each split contains.
for caption, sents in (("训练样本数:", train_sents), ("开发集样本数:", dev_sents)):
    print(caption, len(sents))

# Show the first training sample so the record layout is visible.
print("\n示例样本：")
pprint(train_sents[0])


In [None]:
# Sentence-length statistics: one length per training sentence.
train_lengths = list(map(len, train_sents))

# Histogram of those lengths, drawn on an explicit figure/axes pair.
fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(train_lengths, bins=30)
ax.set_title("Train Sentence Length Distribution")
ax.set_xlabel("Length")
ax.set_ylabel("Count")
plt.show()


In [None]:
from collections import Counter

# Tally every tag over all (item, tag) pairs of the training sentences.
tag_counter = Counter(tag for sent in train_sents for _, tag in sent)

# Bar chart with one bar per distinct tag.
fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(tag_counter.keys(), tag_counter.values())
ax.set_title("Label Distribution")
ax.set_xlabel("Label")
ax.set_ylabel("Count")
plt.show()

# Also print the raw counts for reference.
print("标签统计：")
print(tag_counter)


In [None]:
# Run both segmenters imported from src.segmenter on the same sample sentence.
sentence = "迈向充满希望的新世纪"

print("Sentence:", sentence)
for caption, segment in (("FMM 分词结果:", fmm_segment), ("Jieba 分词结果:", jieba_segment)):
    print(caption, segment(sentence))


In [None]:
# Prediction example: a two-item sentence given as (item, tag) pairs.
# NOTE(review): the sample carries tags; presumably `predict` ignores them and
# only reads the items — confirm against CRFPOSTagger.
sample = [("迈", "B"), ("向", "E")]

# Build the tagger and train it on the training sentences.
crf = CRFPOSTagger()
crf.train(train_sents)

# Predict on the sample and show the result.
prediction = crf.predict(sample)
print("CRF预测:", prediction)


In [None]:
# Tag inventory passed to training; a later cell indexes this list with the
# predicted ids, so the order of the entries matters there.
label_list = ["B", "M", "E", "S"]

# Run training through the project helper and keep the values it returns.
loss_list = train_bert(
    train_path=train_path, dev_path=dev_path, label_list=label_list, num_epochs=2
)

# Plot the returned values as the training loss curve.
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(loss_list)
ax.set_title("BERT Training Loss Curve")
ax.set_xlabel("Step")
ax.set_ylabel("Loss")
plt.show()


In [None]:
from transformers import BertTokenizerFast, BertForTokenClassification
import torch

# Load tokenizer and token-classification model from the checkpoint directory.
# NOTE(review): presumably this directory is written by the training cell — confirm.
model_path = "checkpoints/best_model"
tokenizer = BertTokenizerFast.from_pretrained(model_path)
model = BertForTokenClassification.from_pretrained(model_path)
model.eval()  # evaluation mode (e.g. dropout disabled)

text = "迈向充满希望的新世纪"
tokens = tokenizer(text, return_tensors="pt")

# Prediction only: run the forward pass without autograd so no graph is built
# and no activations are kept for a backward pass that never happens.
with torch.no_grad():
    logits = model(**tokens).logits

# Highest-scoring label id per position for the single sequence in the batch.
pred_ids = torch.argmax(logits, dim=-1)[0].tolist()

# NOTE(review): `pred_ids` covers the full encoded sequence, which presumably
# includes special tokens, while `tokenizer.tokenize` does not add them — so
# the two printed lists are expected to differ in length; confirm.
print("Tokens:", tokenizer.tokenize(text))
print("Pred IDs:", pred_ids)


In [None]:
def decode_to_words(text, pred_ids, label_list):
    """Convert per-position B/M/E/S predictions into a list of words.

    Args:
        text: The raw input string. Each character is paired with one label.
        pred_ids: Predicted label indices for the encoded sequence. The first
            entry is skipped (the leading [CLS] position) and the following
            ``len(text)`` entries are read as the labels of the characters.
            This assumes one prediction per character — TODO confirm for
            inputs where the tokenizer merges or splits characters.
        label_list: Maps a label index to its tag string.

    Returns:
        The words in text order. "B" opens a word, "M" extends it, "E" closes
        it and "S" is a single-character word. A word still open at the end of
        the input, or when the next "B"/"S" arrives, is emitted as-is.
        Characters whose tag is none of B/M/E/S are skipped, and characters
        without a matching prediction are ignored.
    """
    labels = [label_list[i] for i in pred_ids[1:len(text) + 1]]  # drop [CLS]
    result = []
    word = ""

    for ch, tag in zip(text, labels):
        if tag == "B":
            # A new word starts: emit whatever was left unterminated.
            if word:
                result.append(word)
            word = ch
        elif tag == "M":
            word += ch
        elif tag == "E":
            word += ch
            result.append(word)
            word = ""
        elif tag == "S":
            # Flush an unterminated word first; otherwise the single character
            # would be emitted before it and the pending prefix would be glued
            # onto later characters, breaking order and contiguity.
            if word:
                result.append(word)
                word = ""
            result.append(ch)

    # Emit a trailing word that never received its "E".
    if word:
        result.append(word)

    return result

# Decode the predicted label ids for `text` into words and print the segmentation.
print("预测分词：", decode_to_words(text, pred_ids, label_list))
