## 下游任务Demo：质量评估

### 1. 数据准备

In [None]:
from torch.utils.data import Dataset
from transformers import BertTokenizer as HfBertTokenizer
import torch
import tqdm
import pandas as pd

class IdsDataset(Dataset):
    """Dataset of context/answer records that are tokenized once, up front.

    Records are loaded from a CSV file. The code reads the columns
    ``contexts`` and ``answers`` (tokenized in :meth:`preprocess`) and
    ``score`` and ``label`` (converted to tensors in :meth:`collate_fn`).
    """

    def __init__(self, 
                 mode,
                 data_path="../data",
                 tokenizer=None,
                 ):
        """Load the CSV at ``data_path`` and tokenize every record.

        Args:
            mode: Split name (e.g. ``'train'``); only stored on the instance.
            data_path: Path handed to ``pandas.read_csv``.
            tokenizer: Callable tokenizer; it must also provide ``batch_pad``
                for :meth:`collate_fn`.
        """
        self.mode = mode
        self.data = pd.read_csv(data_path)
        # Empty answer cells are read as NaN; replace them with empty strings
        # so the tokenizer always receives text.
        self.data['answers'] = self.data['answers'].fillna('')
        self.data = self.data.to_dict(orient="records")

        self.tokenizer = tokenizer
        self.preprocess()

    def preprocess(self):
        """Tokenize ``contexts`` and ``answers`` of every record in place.

        Adds the keys ``answers_encodings`` and ``contexts_encodings`` to
        each record dict.
        """
        # Hugging Face's BertTokenizer is called with return_tensors=None,
        # any other tokenizer with return_tensors=False.
        return_tensors = None if isinstance(self.tokenizer, HfBertTokenizer) else False

        # `tqdm` is imported as a module in this file, so the progress-bar
        # class has to be reached as `tqdm.tqdm`.
        for item in tqdm.tqdm(self.data):
            item["answers_encodings"] = self._encode(item['answers'], return_tensors)
            item["contexts_encodings"] = self._encode(item['contexts'], return_tensors)

    def _encode(self, text, return_tensors):
        """Tokenize one text with the truncation/padding settings of this dataset."""
        return self.tokenizer(text, truncation=True, padding=True, max_length=512, return_tensors=return_tensors)

    def __getitem__(self, index):
        """Return the preprocessed record dict at ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of records."""
        return len(self.data)
    
    def collate_fn(self, batch_data):
        """Merge a list of record dicts into one batch dict.

        Every key of the first record becomes a list over the batch. The two
        encoding entries are then padded with ``tokenizer.batch_pad`` and
        ``score`` / ``label`` are converted to tensors.
        """
        first = batch_data[0]
        batch = {
            k: [item[k] for item in batch_data] for k in first.keys()
        }
        batch["answers_encodings"] = self.tokenizer.batch_pad(batch["answers_encodings"], return_tensors=True)
        batch["contexts_encodings"] = self.tokenizer.batch_pad(batch["contexts_encodings"], return_tensors=True)
        batch["score"] = torch.as_tensor(batch["score"])
        batch["label"] = torch.as_tensor(batch["label"])
        return batch

In [None]:
# Using BERT as an example
from EduNLP.Pretrain import BertTokenizer
from torch.utils.data import DataLoader

tokenizer = BertTokenizer.from_pretrained(path="/path/to/bert/checkpoint")
trainData = IdsDataset(mode='train', data_path="/path/to/train.csv", tokenizer=tokenizer)
validData = IdsDataset(mode='valid', data_path="/path/to/valid.csv", tokenizer=tokenizer)
testData = IdsDataset(mode='test', data_path="/path/to/test.csv", tokenizer=tokenizer)

# Only the training split is shuffled. Validation and test data keep the CSV
# row order so that evaluation runs are reproducible and predictions can be
# matched back to their source rows.
train_Dataloader = DataLoader(trainData, shuffle=True, num_workers=0, pin_memory=True, collate_fn=trainData.collate_fn)
valid_Dataloader = DataLoader(validData, shuffle=False, num_workers=0, pin_memory=True, collate_fn=validData.collate_fn)
test_Dataloader = DataLoader(testData, shuffle=False, num_workers=0, pin_memory=True, collate_fn=testData.collate_fn)

### 2. 质量评估

In [None]:
import json
import os
from dataclasses import dataclass

from torch import nn
from transformers import BertModel
from transformers.modeling_outputs import ModelOutput

from EduNLP.ModelZoo.base_model import BaseModel


class Global_Layer(nn.Module):
    """
    Feed-forward head: Linear -> ReLU -> Dropout -> Linear.

    The output of the last linear layer is returned unchanged (no softmax
    is applied inside this module).
    """
    def __init__(self, input_unit, output_unit, hidden_size, dropout_rate=0.5):
        super(Global_Layer, self).__init__()
        # Keep the layers in this exact order under the attribute `net`, so
        # parameter names stay `net.0.*` and `net.3.*`.
        stages = [
            nn.Linear(in_features=input_unit, out_features=hidden_size),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(in_features=hidden_size, out_features=output_unit),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the result."""
        return self.net(x)
    
    
class QualityLoss(nn.Module):
    """Weighted sum of an MSE loss on the score and a cross-entropy loss on the label."""

    def __init__(self, mode='train'):
        super().__init__()
        # 'train' averages over the batch (PyTorch's default reduction);
        # every other mode sums over the batch instead.
        reduction = 'mean' if mode == 'train' else 'sum'
        self.classify_loss_fn = nn.CrossEntropyLoss(reduction=reduction)
        self.logits_loss_fn = nn.MSELoss(reduction=reduction)

    def forward(self, pred_score, pred_label,score, label, lamb=0.5):
        """Return ``lamb * MSE(pred_score, score) + (1 - lamb) * CE(pred_label, label)``.

        ``score`` is cast to float before the MSE is computed.
        """
        regression_term = self.logits_loss_fn(pred_score, score.float())
        classification_term = self.classify_loss_fn(pred_label, label)
        return regression_term * lamb + classification_term * (1 - lamb)
    
    
# ModelOutput subclasses have to be dataclasses: the fields declared below
# are what ModelOutput exposes as attributes/keys, and recent transformers
# releases refuse to instantiate a subclass without the decorator.
@dataclass
class TrainForQualityOutput(ModelOutput):
    """Output container returned by ``QualityModel.forward``."""
    # Combined training loss; None when no targets were passed to forward.
    loss: torch.FloatTensor = None
    # Output of the score decoder.
    score_logits: torch.FloatTensor = None
    # Output of the label decoder.
    label_logits: torch.FloatTensor = None


class QualityModel(BaseModel):
    """Predict a quality score and a quality label for a (context, answer) pair.

    The context vector and the answer vector are concatenated and fed to two
    independent ``Global_Layer`` heads: one producing a single score and one
    producing ``num_labels`` label logits.

    Two embedding modes are supported:

    * ``emb_mode == "index"``: token encodings are passed through a
      ``BertModel`` loaded from ``pretrained_model_dir``.
    * any other value ("vector"): pre-computed vectors of size
      ``hidden_size`` are passed to ``forward`` directly.
    """

    def __init__(self, pretrained_model_type="bert", pretrained_model_dir=None, emb_mode="index", hidden_size=None, num_labels=3, dropout_rate=0.5):
        """Build the encoder (index mode only) and the two decoder heads.

        Args:
            pretrained_model_type: Selects how encoder outputs are pooled in
                ``forward`` ("bert", "roberta", "disenq" or "jiuzhang").
            pretrained_model_dir: Checkpoint passed to
                ``BertModel.from_pretrained`` in index mode.
            emb_mode: "index" to encode token ids, anything else to consume
                pre-computed vectors.
            hidden_size: Must be None in index mode (taken from the encoder
                config) and must be given otherwise.
            num_labels: Number of label classes.
            dropout_rate: Dropout probability inside both decoder heads.
        """
        super().__init__()
        self.pretrained_model_type = pretrained_model_type
        self.emb_mode = emb_mode
        self.num_labels = num_labels
        if emb_mode == "index":
            assert hidden_size is None, "hidden_size is read from the encoder config in index mode"
            self.bert = BertModel.from_pretrained(pretrained_model_dir)
            self.hidden_size = self.bert.config.hidden_size # 768
        else: # vector
            assert hidden_size is not None, "hidden_size is required when emb_mode is not 'index'"
            self.hidden_size = hidden_size

        # Both heads consume the concatenated [context ; answer] vector.
        self.score_decoder = Global_Layer(input_unit=self.hidden_size*2,
                                            output_unit=1,
                                            hidden_size=self.hidden_size,
                                            dropout_rate=dropout_rate)
        self.label_decoder = Global_Layer(input_unit=self.hidden_size*2,
                                            output_unit=num_labels,
                                            hidden_size=self.hidden_size,
                                            dropout_rate=dropout_rate)        
        self.criterion = QualityLoss()

        # Store the constructor arguments verbatim (not the resolved
        # self.hidden_size), so from_config() can replay them unchanged.
        self.config = {
            "pretrained_model_type": pretrained_model_type,
            "pretrained_model_dir": pretrained_model_dir,
            "emb_mode": emb_mode,
            "hidden_size": hidden_size,
            "num_labels": num_labels,
            "dropout_rate": dropout_rate,
            "architecture": "QualityModel",
        }

    def forward(self,
                context_vectors=None,
                answer_vectors=None,
                contexts_encodings=None,
                answers_encodings=None,
                score=None,
                label=None,
                **argv,
                ):
        """Compute score/label predictions and, if targets are given, the loss.

        Args:
            context_vectors: Pre-computed context vectors (vector mode).
            answer_vectors: Pre-computed answer vectors (vector mode).
            contexts_encodings: Tokenizer output for the contexts (index mode).
            answers_encodings: Tokenizer output for the answers (index mode).
            score: Target scores; the loss is computed only when both
                ``score`` and ``label`` are given.
            label: Target labels.
            **argv: Extra batch entries; ignored.

        Returns:
            TrainForQualityOutput with ``loss`` (or None), ``score_logits``
            and ``label_logits``.

        Raises:
            ValueError: In index mode, if ``pretrained_model_type`` is not
                one of the supported encoder types.
        """
        if self.emb_mode == "index":
            model_type = self.pretrained_model_type
            if model_type in ("bert", "roberta", "disenq"):
                # Element 1 of the encoder output is used as the sentence
                # vector: [batch_size, hidden_size]
                context_vectors = self.bert(**contexts_encodings)[1]
                answer_vectors = self.bert(**answers_encodings)[1]
            elif model_type == "jiuzhang":
                # Only input_ids and attention_mask are forwarded; the hidden
                # state of the first token is used as the sentence vector.
                contexts_encoder_out = self.bert(
                    input_ids=contexts_encodings["input_ids"],
                    attention_mask=contexts_encodings["attention_mask"],
                )
                answers_encoder_out = self.bert(
                    input_ids=answers_encodings["input_ids"],
                    attention_mask=answers_encodings["attention_mask"],
                )
                context_vectors = contexts_encoder_out["last_hidden_state"][:, 0, :]
                answer_vectors = answers_encoder_out["last_hidden_state"][:, 0, :]
            else:
                # Without this branch both vectors would stay None and the
                # failure would surface as an obscure error inside torch.cat.
                raise ValueError(
                    f"Unsupported pretrained_model_type {model_type!r}; "
                    "expected one of 'bert', 'roberta', 'disenq', 'jiuzhang'"
                )
        else:
            assert context_vectors is not None and answer_vectors is not None
        pooler_state = torch.cat([context_vectors, answer_vectors], dim=-1)
        score_logits = self.score_decoder(pooler_state).squeeze(-1)
        label_logits = self.label_decoder(pooler_state)

        loss = None
        if score is not None and label is not None:
            loss = self.criterion(score_logits, label_logits, score, label, lamb=0.5)

        return TrainForQualityOutput(
            loss=loss,
            score_logits=score_logits,
            label_logits=label_logits,
        )
    
    @classmethod
    def from_config(cls, config_path, **kwargs):
        """Create a model from a JSON config file.

        Args:
            config_path: Path of the JSON file written by ``save_config``.
            **kwargs: Overrides applied on top of the loaded config.
        """
        with open(config_path, "r", encoding="utf-8") as rf:
            model_config = json.load(rf)
        # The file is closed before the (potentially slow) model construction.
        model_config.update(kwargs)
        return cls(
            pretrained_model_dir=model_config["pretrained_model_dir"],
            emb_mode=model_config["emb_mode"],
            hidden_size=model_config["hidden_size"],
            num_labels=model_config["num_labels"],
            dropout_rate=model_config["dropout_rate"],
            pretrained_model_type=model_config["pretrained_model_type"]
        )
    
    def save_config(self, config_dir):
        """Write ``self.config`` to ``<config_dir>/config.json``, creating the directory if needed."""
        os.makedirs(config_dir, exist_ok=True)
        config_path = os.path.join(config_dir, "config.json")
        with open(config_path, "w", encoding="utf-8") as wf:
            json.dump(self.config, wf, ensure_ascii=False, indent=2)

In [None]:
from train import MyTrainer

# Initialize the model
checkpoint_dir="your/checkpoint_dir"
# Use the first GPU when CUDA is available, otherwise fall back to the CPU so
# the demo also runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = QualityModel.from_pretrained(checkpoint_dir).to(device)
trainer = MyTrainer(
    model=model,
)
# Train with validation, then evaluate on the validation and test splits.
trainer.train(train_Dataloader, valid_Dataloader)
trainer.valid(valid_Dataloader)
trainer.valid(test_Dataloader)