In [4]:
import torch
from transformers import AutoTokenizer, BertForMaskedLM
from run_relm import PTuningWrapper

def load_model_and_tokenizer(model_path, pretrained_model_path, prompt_length=1):
    """Load the tokenizer and the fine-tuned model, ready for inference.

    Args:
        model_path: Path to the saved state dict that is loaded into the
            wrapped model.
        pretrained_model_path: Name or path handed to ``from_pretrained`` for
            both the tokenizer and the ``BertForMaskedLM`` backbone.
        prompt_length: Forwarded to ``PTuningWrapper`` as its second argument.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has been switched to
        evaluation mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path)

    backbone = BertForMaskedLM.from_pretrained(pretrained_model_path)
    model = PTuningWrapper(backbone, prompt_length)

    # Weights are mapped onto the CPU; weights_only=True restricts
    # unpickling to plain tensor data.
    cpu = torch.device("cpu")
    state_dict = torch.load(model_path, map_location=cpu, weights_only=True)
    model.load_state_dict(state_dict)

    model.eval()
    return tokenizer, model

def predict_sentence(sentence, tokenizer, model, prompt_length, max_seq_length=128):
    """Run the correction model on one sentence, splitting it if it is long.

    The sentence is tokenized once without special tokens, cut into segments
    that fit the model window, and each segment is fed to the model as
    ``[CLS] segment [SEP]`` padded to ``max_seq_length``. The token ids are
    used directly instead of decoding each segment back to text and
    tokenizing it again: that round trip replaced characters unknown to the
    vocabulary with the literal text "[UNK]" and re-split segments that
    started on a "##" continuation piece.

    Args:
        sentence: Text to correct.
        tokenizer: Tokenizer that supports ``return_offsets_mapping``
            (the original code already required this).
        model: Callable accepting ``input_ids``, ``attention_mask``,
            ``prompt_mask`` and ``apply_prompt`` and returning an object
            with a ``logits`` attribute.
        prompt_length: Number of prompt slots; the first
            ``2 * prompt_length`` positions of the prompt mask are set to 1.
        max_seq_length: Total model input length including ``[CLS]`` and
            ``[SEP]``. Defaults to 128, the previously hard-coded value.

    Returns:
        The corrected text for the whole sentence.

    Raises:
        ValueError: If ``max_seq_length`` leaves no room for content tokens.
    """
    if max_seq_length <= 2:
        raise ValueError("max_seq_length must be greater than 2 to fit [CLS] and [SEP]")
    max_len = max_seq_length - 2  # room left once [CLS] and [SEP] are added

    tokenized = tokenizer(sentence, return_offsets_mapping=True, add_special_tokens=False)
    input_ids = list(tokenized["input_ids"])
    offsets = list(tokenized["offset_mapping"])

    corrected_parts = []
    for start in range(0, len(input_ids), max_len):
        segment = input_ids[start:start + max_len]
        segment_offsets = offsets[start:start + max_len]
        pad_count = max_len - len(segment)

        # Same layout the tokenizer would produce with padding="max_length".
        segment_ids = torch.tensor(
            [[tokenizer.cls_token_id] + segment + [tokenizer.sep_token_id]
             + [tokenizer.pad_token_id] * pad_count]
        )
        attention_mask = torch.tensor([[1] * (len(segment) + 2) + [0] * pad_count])

        # Build the prompt mask.
        # NOTE(review): this marks the leading positions (including [CLS]) as
        # prompt slots; presumably it matches what PTuningWrapper saw during
        # training -- confirm against run_relm.
        prompt_mask = torch.zeros_like(segment_ids)
        prompt_mask[:, :2 * prompt_length] = 1

        # Model inference.
        with torch.no_grad():
            outputs = model(input_ids=segment_ids,
                            attention_mask=attention_mask,
                            prompt_mask=prompt_mask,
                            apply_prompt=True)
            predictions = torch.argmax(outputs.logits, dim=-1)

        # Drop the prediction at the [CLS] position so that index i lines up
        # with segment token i; zip() later ignores the [SEP]/[PAD] tail.
        predicted_tokens = tokenizer.convert_ids_to_tokens(predictions[0].tolist())[1:]
        input_tokens = tokenizer.convert_ids_to_tokens(segment)

        for index, token_id in enumerate(segment):
            if token_id != tokenizer.unk_token_id:
                continue
            # The model never saw this character, so restore the original
            # text from the offsets and discard the prediction (an [UNK]
            # prediction makes clean_tokens_with_numbers keep the input).
            begin, end = segment_offsets[index]
            input_tokens[index] = sentence[begin:end]
            if index < len(predicted_tokens):
                predicted_tokens[index] = tokenizer.unk_token

        corrected_parts.append(clean_tokens_with_numbers(input_tokens, predicted_tokens))

    return "".join(corrected_parts)

def clean_tokens_with_numbers(input_tokens, predicted_tokens):
    """Merge predicted tokens into text while protecting numbers.

    Tokens are processed pairwise; extra tokens in the longer sequence are
    ignored. For each pair:

    * an input token that is ``[CLS]``, ``[SEP]`` or ``[PAD]`` emits nothing;
    * the input token is kept when it is a digit string or a Chinese
      numeral, or when the prediction is a special token (``[UNK]``,
      ``[CLS]``, ``[SEP]``, ``[PAD]``, ``[MASK]``) that would otherwise be
      written literally into the text;
    * otherwise the predicted token is emitted.

    A leading WordPiece ``##`` continuation marker is removed from whatever
    is emitted. It is also removed from the input token *before* the digit
    check, so the tail of a number split into pieces (``"202", "##3"``) is
    protected as well.

    Args:
        input_tokens: Tokens of the original text.
        predicted_tokens: Tokens predicted by the model, aligned with
            ``input_tokens``.

    Returns:
        The concatenated, cleaned text.
    """
    chinese_numerals = frozenset("零一二三四五六七八九十百千萬億")
    skipped_inputs = ("[CLS]", "[SEP]", "[PAD]")
    unusable_predictions = ("[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]")

    pieces = []
    for input_token, predicted_token in zip(input_tokens, predicted_tokens):
        if input_token in skipped_inputs:
            continue

        # Surface form of the input token without the continuation marker.
        if input_token.startswith("##") and len(input_token) > 2:
            source_piece = input_token[2:]
        else:
            source_piece = input_token

        if (predicted_token in unusable_predictions
                or source_piece in chinese_numerals
                or source_piece.isdigit()):
            pieces.append(source_piece)
        elif predicted_token.startswith("##"):
            pieces.append(predicted_token[2:])
        else:
            pieces.append(predicted_token)
    return "".join(pieces)

def process_article(article_text, tokenizer, model, prompt_length=1):
    """Correct an article sentence by sentence and return the joined result.

    The text is split on the full stop "。". Each non-empty sentence is
    corrected with ``predict_sentence`` together with its own full stop. A
    full stop is only re-attached where the input actually had one, so text
    without a trailing "。" no longer gains one, and consecutive full stops
    are copied through unchanged instead of being collapsed.

    Args:
        article_text: The article to correct.
        tokenizer: Tokenizer forwarded to ``predict_sentence``.
        model: Model forwarded to ``predict_sentence``.
        prompt_length: Prompt length forwarded to ``predict_sentence``.

    Returns:
        The corrected article as a single string.
    """
    pieces = article_text.split("。")
    last_index = len(pieces) - 1

    corrected_parts = []
    for index, piece in enumerate(pieces):
        # Every piece except the last one was followed by "。" in the input.
        terminated = index < last_index
        if not piece:
            if terminated:
                # A bare full stop needs no correction; keep it verbatim.
                corrected_parts.append("。")
            continue
        sentence = piece + "。" if terminated else piece
        corrected_parts.append(predict_sentence(sentence, tokenizer, model, prompt_length))
    return "".join(corrected_parts)

# === Main program ===
if __name__ == "__main__":
    # Inference settings.
    prompt_length = 1
    pretrained_model_path = "bert-base-chinese"
    # Replace with the path to your own model weights.
    model_path = "./model/Judgement_271K_filtered/step-10200_f1-76.67.bin"

    # Text of the article to be corrected.
    article_text = "經查，原告起訴其前受雇於被告，因受到職業災害，主張被告應給付原領工資補償，被告則以原告另向勞動部勞工保險局（下稱勞保局）領取傷病給付，以及向新光產物保險股份有限公司（下稱新光產險公司）請領團體保險金，反訴求原告返還，核雙方所主張之權利，係基於同一事件所衍生之爭執，兩訴言詞辯論之資料可相互利用，且對於當事人間紛爭之一次解決及訴訟經濟有利，亦無不得提起反訴之情形，揆諸前揭說明，被告提起反訴，於法尚無不合，應予准許。"

    # Build the tokenizer and the model from the settings above.
    tokenizer, model = load_model_and_tokenizer(
        model_path=model_path,
        pretrained_model_path=pretrained_model_path,
        prompt_length=prompt_length,
    )

    # Run the correction over the whole article.
    corrected_article = process_article(
        article_text=article_text,
        tokenizer=tokenizer,
        model=model,
        prompt_length=prompt_length,
    )

    # Show the input followed by the corrected text.
    print("輸入文章:", article_text)
    print("修正後文章:", corrected_article)


Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


輸入文章: 經查，原告起訴其前受雇於被告，因受到職業災害，主張被告應給付原領工資補償，被告則以原告另向勞動部勞工保險局（下稱勞保局）領取傷病給付，以及向新光產物保險股份有限公司（下稱新光產險公司）請領團體保險金，反訴求原告返還，核雙方所主張之權利，係基於同一事件所衍生之爭執，兩訴言詞辯論之資料可相互利用，且對於當事人間紛爭之一次解決及訴訟經濟有利，亦無不得提起反訴之情形，揆諸前揭說明，被告提起反訴，於法尚無不合，應予准許。
修正後文章: 經查，原告起訴其前受雇於被告，因受到職業災害，主張被告應給付原領工資補償，被告則以原告另向勞動部勞工保險局（下稱勞保局）領取傷病給付，以及向新光產物保險股份有限公司（下稱新光產險公司）請領團體保險金，反請求原告返還，核雙方所主張之權利，係基於同一事件所所生之爭執，兩訴言詞辯論之資料可相互利用，且對於當事人間紛爭之一次解決及訴訟經濟有利，亦無不得提起反訴之情形，揆諸前揭說明，被告提起反訴，於法尚無不合，應予准許。
