# 任务说明
1. 单选题其实就是根据给定的文本内容、题目、选项，最终返回一个得分最高的选项
2. 通常可以转换为文本匹配 或者 分类问题

## 整体框架
1. 将内容、题目、选项组装为若干条句子，格式如下
![image.png](attachment:image.png)

2. 然后对每条句子分别做一次分类打分，得到的各选项得分再通过softmax取出最佳结果

In [13]:
# Download the data
# !wget https://storage.googleapis.com/cluebenchmark/tasks/c3_public.zip
# !unzip c3_public.zip -d c3_public

# wget and unzip are not available on Windows; download and extract manually instead

# Directory prefix of the extracted dataset. File names are concatenated
# directly onto it below, so the trailing slash is required.
data_file = '../dataset/c3_public/'

In [14]:
import codecs
import json
import numpy as np

In [15]:
def _load_json(name):
    """Parse the UTF-8 JSON file ``name`` located under ``data_file``.

    The file is opened in a ``with`` block so the handle is closed as soon
    as parsing finishes instead of being left to the garbage collector.
    """
    with open(data_file + name, encoding='utf-8') as f:
        return json.load(f)


# Each split is the concatenation of the "d-" and "m-" files of the dataset.
train = _load_json('d-train.json') + _load_json('m-train.json')
val = _load_json('m-dev.json') + _load_json('d-dev.json')

In [16]:
print(len(train), len(val))

8023 2674


In [17]:
train[0]

[['男：你今天晚上有时间吗?我们一起去看电影吧?', '女：你喜欢恐怖片和爱情片，但是我喜欢喜剧片，科幻片一般。所以……'],
 [{'question': '女的最喜欢哪种电影?',
   'choice': ['恐怖片', '爱情片', '喜剧片', '科幻片'],
   'answer': '喜剧片'}],
 '25-35']

In [18]:
# Build the labels: a label is the position of the answer inside the
# choice list of the example's first question.
def _answer_index(example):
    """Return the index of the first question's answer among its choices."""
    first_question = example[1][0]
    return first_question['choice'].index(first_question['answer'])


train_label = list(map(_answer_index, train))
val_label = list(map(_answer_index, val))

In [19]:
train_label[0]

2

In [20]:
import torch
from transformers import BertTokenizer

model_path = r"D:\code\personal\models\bert-base-chinese"

# The number of answer options is not a tokenizer setting: every
# (question + choice, passage) pair is encoded on its own, and the options
# are grouped into a choices dimension afterwards in collate_fn. The former
# `num_choices=4` keyword was therefore dropped; it is not a BertTokenizer
# option and had no effect on the encoding.
tokenizer = BertTokenizer.from_pretrained(model_path)

In [37]:
# Batch formatting.
# Passage, question and choice are joined and tokenized into ids; the
# resulting tensors have size (batch, n_choices, max_length).
def collate_fn(data, max_length=128):
    """Tokenize a batch of ``(content, pair, label)`` items.

    Args:
        data: list of items as produced by ``TextDataset``. ``item[0]`` is the
            list of passage copies, ``item[1]`` the list of
            "question choice" strings, and ``item[-1]`` the integer label.
        max_length: length every encoded pair is padded/truncated to.

    Returns:
        ``(input_ids, attention_mask, token_type_ids, label)``. The first
        three are stacked along a new leading batch dimension; ``label`` is a
        1-D tensor with one entry per item.
    """
    # question + choice goes first and the passage second, so that the
    # question/answer text is kept when the pair has to be truncated.
    encoded = [
        tokenizer(item[1],
                  text_pair=item[0],
                  padding='max_length',
                  truncation=True,
                  max_length=max_length,
                  return_tensors='pt')
        for item in data
    ]

    # Stack the per-item tensors directly instead of converting them to
    # Python lists and back into tensors.
    input_ids = torch.stack([enc['input_ids'] for enc in encoded])
    attention_mask = torch.stack([enc['attention_mask'] for enc in encoded])
    token_type_ids = torch.stack([enc['token_type_ids'] for enc in encoded])
    label = torch.tensor([item[-1] for item in data])
    return input_ids, attention_mask, token_type_ids, label


In [25]:
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset

class TextDataset(Dataset):
    """Multiple-choice dataset yielding ``(contents, pairs, label)`` items.

    Only the first question of every example (``example[1][0]``) is used.

    Args:
        datas: sequence of examples. ``example[0]`` is a list of passage
            sentences and ``example[1][0]`` a dict with ``'question'`` and
            ``'choice'`` entries.
        labels: sequence of integer labels, parallel to ``datas``.
        num_choices: choice lists shorter than this are padded up to it.
        pad_choice: text used for the padded options.
    """

    def __init__(self, datas, labels, num_choices=4, pad_choice='不知道'):
        self.datas = datas
        self.labels = labels
        self.num_choices = num_choices
        self.pad_choice = pad_choice

    def __getitem__(self, idx):
        """Return ``(contents, pairs, label)`` for the example at ``idx``.

        ``contents`` repeats the joined passage once per option and ``pairs``
        holds ``"<question> <choice>"`` for every option.
        """
        label = self.labels[idx]
        example = self.datas[idx]
        # Join the passage sentences into one complete text.
        passage = '。'.join(example[0])
        first_question = example[1][0]
        question = first_question['question']

        # Pad on a copy: appending to the original list would silently
        # modify the caller's dataset.
        choice = list(first_question['choice'])
        choice.extend([self.pad_choice] * (self.num_choices - len(choice)))

        # One copy of the passage per option.
        content = [passage] * len(choice)
        # Question joined with each option.
        pair = [question + ' ' + option for option in choice]

        return content, pair, label

    def __len__(self):
        return len(self.labels)
    

# Wrap the training and dev splits together with their labels.
train_dataset = TextDataset(datas=train, labels=train_label)
test_dataset = TextDataset(datas=val, labels=val_label)

In [26]:
train_dataset[0]

(['男：你今天晚上有时间吗?我们一起去看电影吧?。女：你喜欢恐怖片和爱情片，但是我喜欢喜剧片，科幻片一般。所以……',
  '男：你今天晚上有时间吗?我们一起去看电影吧?。女：你喜欢恐怖片和爱情片，但是我喜欢喜剧片，科幻片一般。所以……',
  '男：你今天晚上有时间吗?我们一起去看电影吧?。女：你喜欢恐怖片和爱情片，但是我喜欢喜剧片，科幻片一般。所以……',
  '男：你今天晚上有时间吗?我们一起去看电影吧?。女：你喜欢恐怖片和爱情片，但是我喜欢喜剧片，科幻片一般。所以……'],
 ['女的最喜欢哪种电影? 恐怖片', '女的最喜欢哪种电影? 爱情片', '女的最喜欢哪种电影? 喜剧片', '女的最喜欢哪种电影? 科幻片'],
 2)

In [28]:
import torch
from transformers import BertForMultipleChoice, AdamW, get_linear_schedule_with_warmup

# Choose the GPU when one is available, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load the pretrained weights into a multiple-choice head and move the model
# to the selected device. The call is kept as the last expression so the
# notebook still displays the module structure.
model = BertForMultipleChoice.from_pretrained(model_path)
model.to(device)

Some weights of BertForMultipleChoice were not initialized from the model checkpoint at D:\code\personal\models\bert-base-chinese and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForMultipleChoice(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [35]:
# Only the training batches are shuffled. Evaluation keeps a fixed order so
# that per-batch metrics are reproducible from run to run.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)

In [40]:
# Number of epochs run by the training loop at the bottom of the file
# (`for epoch in range(4)`); keep the two in sync.
num_epochs = 4

optim = AdamW(model.parameters(), lr=1e-5)
# The linear schedule decays the learning rate to zero after `total_steps`
# scheduler steps, so it has to span every epoch. Sizing it for a single
# epoch would leave all later epochs training with a learning rate of 0.
total_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optim,
                                           num_warmup_steps=0,
                                           num_training_steps=total_steps)

In [43]:
from tqdm import tqdm

def train():
    """Run one training epoch over ``train_loader``.

    Uses the module-level ``model``, ``optim``, ``scheduler``, ``device`` and
    ``train_loader``, and reads the global loop variable ``epoch`` for its log
    messages, so it must be called from inside the epoch loop.

    Note: defining this function rebinds the module-level name ``train``,
    which earlier held the loaded training examples; ``train_dataset`` keeps
    its own reference to that list.
    """
    model.train()
    total_train_loss = 0
    # Progress percentage is reported relative to the current epoch.
    total_iter = len(train_loader)
    for idx, (input_ids, attention_mask, token_type_ids, labels) in enumerate(train_loader):
        # Reset the gradients accumulated by the previous step.
        optim.zero_grad()

        # Forward pass.
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        labels = labels.to(device)
        outputs = model(input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids,
                        labels=labels)

        # With labels supplied the output carries the loss first and the
        # logits second; both are read by name to avoid index mix-ups.
        loss = outputs.loss

        if idx % 20 == 0:
            with torch.no_grad():
                batch_accuracy = (outputs.logits.argmax(1) == labels).float().mean().item()
                print(batch_accuracy, loss.item())

        total_train_loss += loss.item()
        loss.backward()  # backward pass
        # In-place gradient clipping; the un-suffixed clip_grad_norm is deprecated.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()    # parameter update
        scheduler.step()

        iter_num = idx + 1
        if iter_num % 100 == 0:
            print("epoth: %d, iter_num: %d, loss: %.4f, %.2f%%" % (epoch, iter_num, loss.item(), iter_num/total_iter*100))

    print("Epoch: %d, Average training loss: %.4f"%(epoch, total_train_loss/len(train_loader)))

In [None]:
def validation():
    """Evaluate the model on ``test_loader`` and print accuracy and mean loss.

    Uses the module-level ``model``, ``device`` and ``test_loader``. Accuracy
    is the number of correct predictions divided by the number of evaluated
    samples, so a shorter final batch is not over-weighted. The loss is
    averaged over batches.
    """
    model.eval()
    total_eval_loss = 0
    correct = 0
    seen = 0
    with torch.no_grad():
        for input_ids, attention_mask, token_type_ids, labels in test_loader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            token_type_ids = token_type_ids.to(device)
            labels = labels.to(device)
            # Segment ids are passed exactly as during training; leaving
            # them out would feed the model inputs it was not trained on.
            outputs = model(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            labels=labels)

            total_eval_loss += outputs.loss.item()
            correct += (outputs.logits.argmax(1) == labels).sum().item()
            seen += labels.size(0)

    # Guard the divisions so an empty loader reports zeros instead of raising.
    avg_val_accuracy = correct / seen if seen else 0.0
    num_batches = max(len(test_loader), 1)
    print("Accuracy: %.4f" % (avg_val_accuracy))
    print("Average testing loss: %.4f"%(total_eval_loss/num_batches))
    print("-------------------------------")

In [None]:
# Main loop: one training pass followed by one evaluation pass per epoch.
# `epoch` is deliberately a module-level name: train() reads it as a global
# for its log messages.
for epoch in range(4):
    print("------------Epoch: %d ----------------" % epoch)
    train()
    validation()
    