In [1]:
import torch
from transformers import AutoTokenizer, BertForMaskedLM
from run_relm import PTuningWrapper

def load_model_and_tokenizer(model_path, pretrained_model_path, prompt_length=1):
    """載入模型與 tokenizer"""
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path)
    base_model = BertForMaskedLM.from_pretrained(pretrained_model_path)
    model = PTuningWrapper(base_model, prompt_length)
    model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"), weights_only=True))
    model.eval()
    return tokenizer, model

def predict_sentence(sentence, tokenizer, model, prompt_length):
    """對單句進行模型推論與修正"""
    # Tokenize sentence
    inputs = tokenizer(sentence, return_tensors="pt", max_length=128, padding="max_length", truncation=True)
    input_ids = inputs["input_ids"]

    # 生成 prompt_mask
    prompt_mask = torch.zeros_like(input_ids)
    prompt_mask[:, :2 * prompt_length] = 1

    # 模型推論
    with torch.no_grad():
        outputs = model(input_ids=input_ids, 
                        attention_mask=inputs["attention_mask"], 
                        prompt_mask=prompt_mask,  
                        apply_prompt=True)
        predictions = torch.argmax(outputs.logits, dim=-1)

    # 解碼並清理 token
    predicted_tokens = tokenizer.convert_ids_to_tokens(predictions[0])
    input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    # 清理 token
    return clean_tokens_with_numbers(input_tokens[1:], predicted_tokens[1:])

def clean_tokens_with_numbers(input_tokens, predicted_tokens):
    """清理預測的 tokens，保留數字不變，處理 ## 拼接"""
    clean_text = ""
    for input_token, predicted_token in zip(input_tokens, predicted_tokens):
        if input_token in ["[CLS]", "[SEP]", "[PAD]"]:
            continue
        if input_token.isdigit() and predicted_token.isdigit() and input_token != predicted_token:
            clean_text += input_token
        elif predicted_token.startswith("##"):
            clean_text += predicted_token[2:]
        else:
            clean_text += predicted_token
    return clean_text

def process_article(article_text, tokenizer, model, prompt_length=1):
    """將文章分句並進行修正，輸出完整修正後的文章"""
    sentences = [s + "。" for s in article_text.split("。") if s]  # 分句並保留句號
    corrected_sentences = []
    for sentence in sentences:
        corrected_sentence = predict_sentence(sentence, tokenizer, model, prompt_length)
        corrected_sentences.append(corrected_sentence)
    return "".join(corrected_sentences)

# === 主程式 ===
if __name__ == "__main__":
    # 設定路徑
    model_path = "D:/NCU/FirstSemester/LegalAI/relm_autocorrection/Judgement_Process_Artical/step-10200_f1-76.67.bin"  # 替換成你的模型權重路徑
    pretrained_model_path = "bert-base-chinese"
    prompt_length = 1

    # 載入模型與 tokenizer
    tokenizer, model = load_model_and_tokenizer(model_path, pretrained_model_path, prompt_length)

    # 輸入文章字串
    article_text = "五、如不服本裁定，應於裁定送達後10日內，以書狀向本院司法事務官提春異議。本案依法判決，合議庭審理完畢。"
    
    # 修正文章
    corrected_article = process_article(article_text, tokenizer, model, prompt_length)
    
    # 輸出結果
    print("輸入文章:", article_text)
    print("修正後文章:", corrected_article)


  from .autonotebook import tqdm as notebook_tqdm
BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a mod

輸入文章: 五、如不服本裁定，應於裁定送達後10日內，以書狀向本院司法事務官提春異議。本案依法判決，合議庭審理完畢。
修正後文章: 三、如不服本裁定，應於裁定送達後10日內，以書狀向本院司法事務官提出異議。本案依法判決，合議庭審理完畢。
