In [None]:
import json
import random
import csv
import os
import json
import random
import re
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
import os
import ast
from tqdm import tqdm
import torch.nn.functional as F

In [None]:
def load_and_filter_data(file_path, max_samples=None):
    """Load a JSON-lines file and keep only the records whose ``realCount`` is 1.

    Args:
        file_path: Path of a UTF-8 text file holding one JSON document per line.
        max_samples: Maximum number of *lines* to read, counted before
            filtering. ``None`` (the default) reads the whole file; ``0``
            reads nothing.

    Returns:
        A list with the decoded records whose ``realCount`` equals 1, in file
        order. An empty list is returned (after printing a notice) when
        ``file_path`` does not exist.
    """
    filtered_data = []

    if not os.path.exists(file_path):
        print(f"文件 {file_path} 不存在")
        return []

    with open(file_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            # Compare against None explicitly so that max_samples=0 means
            # "read nothing" rather than "no limit".
            if max_samples is not None and i >= max_samples:
                break

            try:
                data = json.loads(line.strip())
            except json.JSONDecodeError:
                # Blank or malformed lines are skipped on purpose.
                continue

            # A line may be valid JSON without being a usable record (a bare
            # list/number, or an object lacking 'realCount'); skip it instead
            # of aborting the whole load with TypeError/KeyError.
            if isinstance(data, dict) and data.get('realCount') == 1:
                filtered_data.append(data)

    print(f"从 {file_path} 中过滤出 {len(filtered_data)} 条 realCount=1 的数据")
    return filtered_data

In [None]:
# Read at most the first 1000 training lines and the first 100 test lines,
# keeping only the realCount == 1 records of each file.
train_data = load_and_filter_data(
    file_path='/personal/Day4/idiom_cloze_project/dataset_filtered/train_data.txt',
    max_samples=1000,
)
test_data = load_and_filter_data(
    file_path='/personal/Day4/idiom_cloze_project/dataset_filtered/test_data.txt',
    max_samples=100,
)

In [None]:
class DPODataset(Dataset):
    """Pre-tokenised idiom-cloze samples (context, correct idiom, distractors).

    Every input record is expected to provide:
      * ``record["content"]``       - text containing the ``#idiom#`` placeholder,
      * ``record["groundTruth"][0]`` - the correct idiom,
      * ``record["candidates"][0]``  - a list of idiom options.

    Each stored sample is a dict with
      * ``"groundTruth"``    - LongTensor with the ``IDIOM_LEN`` token ids of the
        correct idiom,
      * ``"candidates"``     - LongTensor ``(n_distractors, IDIOM_LEN)`` with the
        token ids of every option that differs from the correct idiom,
      * ``"content_ids"``    - input ids of the context, padded/truncated to
        ``max_length``, with the placeholder replaced by mask tokens,
      * ``"attention_mask"`` - the matching attention mask.

    Records that cannot yield a well-formed sample are skipped and counted in
    ``num_skipped``: fewer than ``IDIOM_LEN`` mask tokens left after truncation,
    an idiom that does not tokenise to exactly ``IDIOM_LEN`` tokens, or no
    distractor at all.

    Note: default batch collation stacks ``"candidates"``, so all samples in a
    batch must carry the same number of distractors.
    """

    # Number of mask tokens substituted for the placeholder, and number of
    # tokens every idiom must tokenise to.
    IDIOM_LEN = 4
    # Only the first NUM_CANDIDATES options of a record are considered.
    NUM_CANDIDATES = 7

    def __init__(self, data, tokenizer, max_length=512):
        """Tokenise every record of ``data`` once, up front.

        Args:
            data: Indexable sequence of record dicts (see class docstring).
            tokenizer: Tokenizer exposing ``__call__``, ``encode``,
                ``mask_token`` and ``mask_token_id``.
            max_length: Length every context is padded/truncated to.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data_list = []
        self.num_skipped = 0

        mask_span = tokenizer.mask_token * self.IDIOM_LEN
        for idx in range(len(data)):
            sample = self._build_sample(data[idx], mask_span)
            if sample is None:
                self.num_skipped += 1
            else:
                self.data_list.append(sample)

        if self.num_skipped:
            print(f"DPODataset: skipped {self.num_skipped} of {len(data)} records "
                  f"that could not be encoded as {self.IDIOM_LEN}-token cloze samples")

    def _encode_idiom(self, text):
        """Return the token ids of ``text``, or None unless there are exactly IDIOM_LEN."""
        ids = self.tokenizer.encode(text, add_special_tokens=False)
        return ids if len(ids) == self.IDIOM_LEN else None

    def _build_sample(self, record, mask_span):
        """Turn one raw record into a sample dict, or None when it is unusable."""
        text = record["content"].replace("#idiom#", mask_span)
        # truncation=True keeps every context at exactly max_length tokens so
        # that samples can be stacked into a batch.
        encoded = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].squeeze(0)
        # Truncation (or a missing placeholder) can leave too few mask tokens.
        if (input_ids == self.tokenizer.mask_token_id).sum().item() < self.IDIOM_LEN:
            return None

        truth_text = record["groundTruth"][0]
        truth_ids = self._encode_idiom(truth_text)
        if truth_ids is None:
            return None

        distractors = []
        for option in record["candidates"][0][:self.NUM_CANDIDATES]:
            if option == truth_text:
                continue
            option_ids = self._encode_idiom(option)
            if option_ids is None:
                # One malformed option would make the per-sample tensor ragged.
                return None
            distractors.append(torch.LongTensor(option_ids))
        if not distractors:
            return None

        return {
            "groundTruth": torch.LongTensor(truth_ids),
            "candidates": torch.stack(distractors),
            "content_ids": input_ids,
            "attention_mask": encoded["attention_mask"].squeeze(0),
        }

    def __len__(self):
        """Number of usable samples."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return the pre-built sample dict at position ``idx``."""
        return self.data_list[idx]

In [None]:
# Load the tokenizer and the masked-LM model from the same checkpoint path so
# that their vocabularies match.
from transformers import AutoTokenizer, AutoModelForMaskedLM
# NOTE(review): absolute path, presumably a local RoBERTa-wwm-ext checkpoint
# directory - confirm it exists on the machine running the notebook.
model_name = "/personal/RoBERTa-wwm-ext/"
# `tokenizer` is read as a module-level name by train() (for mask_token_id).
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

In [None]:
# Tokenise the filtered training records once, up front.
dataset = DPODataset(data=train_data, tokenizer=tokenizer)

In [None]:
# Training batches: two samples each, reshuffled every epoch, loaded in the
# main process (no worker subprocesses).
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    num_workers=0,
)

In [None]:
# Validation set: built from the held-out `test_data` records (not from the
# training records) and wrapped in its own loader, so that the accuracy
# reported by train() is measured on data the model was not trained on.
test = DPODataset(test_data,tokenizer)
# Evaluation does not depend on sample order, so shuffling is disabled.
val_loader = torch.utils.data.DataLoader(test,batch_size=2,num_workers=0,shuffle=False)

In [None]:
def _idiom_log_likelihoods(logits, content, mask_token_id, token_ids):
    """Score token sequences at each sample's masked idiom span.

    Args:
        logits: Model output indexed as ``logits[b, position]``, one row of
            vocabulary scores per token position.
        content: Input ids indexed as ``content[b]``; positions equal to
            ``mask_token_id`` mark the idiom span.
        mask_token_id: Vocabulary id of the mask token.
        token_ids: Integer tensor indexed ``[b, k, j]``: for sample ``b``, the
            ``j``-th token id of the ``k``-th sequence to score.

    Returns:
        Tensor indexed ``[b, k]`` with the sum over ``j`` of the
        log-probability of token ``token_ids[b, k, j]`` at the ``j``-th mask
        position of sample ``b``.

    Raises:
        ValueError: if a sample has fewer mask positions than tokens per
            sequence.
    """
    span_len = token_ids.size(-1)
    per_sample = []
    for b in range(content.size(0)):
        # Only the first `span_len` mask positions are used.
        positions = torch.where(content[b] == mask_token_id)[0][:span_len]
        if positions.numel() < span_len:
            raise ValueError(
                f"sample {b} has {positions.numel()} mask tokens, expected at least {span_len}"
            )
        # log_softmax only over the masked rows: numerically stable and avoids
        # normalising every position of the sequence.
        span_log_probs = F.log_softmax(logits[b, positions], dim=-1)
        # picked[j, k] = span_log_probs[j, token_ids[b, k, j]]
        picked = span_log_probs.gather(1, token_ids[b].t())
        per_sample.append(picked.sum(dim=0))
    return torch.stack(per_sample)


def train(model, dataloader, val_loader, epochs = 1, mask_token_id = None):
    """Fine-tune a masked LM to prefer the correct idiom, then report accuracy.

    For every batch the loss is ``-logsigmoid(margin) * 2.0`` where ``margin``
    is the batch mean of (log-likelihood of the correct idiom minus the mean
    log-likelihood of the distractors) at the masked span. After the last
    epoch the model is evaluated on ``val_loader``: a sample counts as correct
    when the correct idiom scores at least as high as every distractor.

    Prints the mean training loss of each epoch and, finally, the validation
    accuracy (correct samples / evaluated samples).

    Args:
        model: Masked-LM model called as ``model(input_ids=..., attention_mask=...)``
            whose output has a ``logits`` attribute.
        dataloader: Iterable of training batches; each batch is a dict with
            ``"groundTruth"``, ``"candidates"``, ``"content_ids"`` and
            ``"attention_mask"`` tensors.
        val_loader: Iterable of validation batches with the same layout.
        epochs: Number of passes over ``dataloader``.
        mask_token_id: Id of the mask token. When ``None`` it is read from the
            notebook-level ``tokenizer``.

    Returns:
        None.
    """
    if mask_token_id is None:
        # Falls back to the module-level tokenizer created earlier in the notebook.
        mask_token_id = tokenizer.mask_token_id

    # Use the GPU when there is one instead of failing on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(),lr=1e-4)

    def score_batch(batch):
        """Return (scores, batch_size); scores[:, 0] is the correct idiom, scores[:, 1:] the distractors."""
        groundTruth = batch["groundTruth"].to(device)
        candidates = batch["candidates"].to(device)
        content = batch["content_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        output = model(input_ids = content, attention_mask = attention_mask)
        # Score the correct idiom together with the distractors in one pass.
        # The correct idiom comes from "groundTruth": the dataset removes it
        # from "candidates", so candidates[:, 0] is a distractor.
        all_ids = torch.cat([groundTruth.unsqueeze(1), candidates], dim=1)
        scores = _idiom_log_likelihoods(output.logits, content, mask_token_id, all_ids)
        return scores, content.size(0)

    for _ in range (epochs):
        model.train()
        total_loss = 0.0
        batch_num = 0
        for batch in dataloader:
            optimizer.zero_grad()
            scores, _batch_size = score_batch(batch)
            gt_scores = scores[:, 0]
            cand_scores = scores[:, 1:].mean(dim=1)
            margin = (gt_scores - cand_scores).mean()
            loss = -F.logsigmoid(margin) * 2.0
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            batch_num += 1

        # An empty dataloader has no mean loss to report.
        if batch_num:
            print(total_loss/batch_num)

    model.eval()
    total_correct = 0
    total_seen = 0
    with torch.no_grad():
        for batch in val_loader:
            scores, batch_size = score_batch(batch)
            # Compare summed log-probabilities rather than products of
            # probabilities: same ordering, but no underflow to 0.
            best_distractor = scores[:, 1:].max(dim=1).values
            total_correct += (scores[:, 0] >= best_distractor).sum().item()
            total_seen += batch_size

    # Divide by the number of samples actually evaluated.
    print(total_correct / total_seen if total_seen else 0.0)

In [None]:
# Run one training epoch, then print the validation accuracy.
train(model=model, dataloader=dataloader, val_loader=val_loader, epochs=1)