In [1]:
import csv
import os
from argparse import ArgumentParser
from pathlib import Path
import torch
from torch import nn
from transformers import AutoConfig, AutoTokenizer
from dataset import *
from torch.utils.data import DataLoader
from accelerate import Accelerator

from tqdm import tqdm

from utils import same_seeds
from model import *

import wandb

import numpy as np
import collections

In [2]:
@torch.no_grad()
def mc_predict(data_loader, model, device=None):
    """Run the multiple-choice model and record the arg-max choice per example.

    Args:
        data_loader: yields ``(ids, input_ids, attention_masks,
            token_type_ids, labels)`` batches; ``labels`` is ignored here.
        model: model whose output exposes ``.logits``; the arg-max is taken
            over the last dimension (assumed to be the choice axis — TODO
            confirm against MultipleChoiceModel).
        device: device the input tensors are moved to. Defaults to the device
            of the model's first parameter, so the function no longer relies
            on a global ``args`` being defined in the notebook.

    Returns:
        dict mapping each example id to its predicted choice index (``int``).
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    relevant = {}
    for batch in tqdm(data_loader):
        ids, input_ids, attention_masks, token_type_ids, _labels = batch
        output = model(
            input_ids=input_ids.to(device),
            attention_mask=attention_masks.to(device),
            token_type_ids=token_type_ids.to(device),
        )
        pred = output.logits.argmax(dim=-1).cpu().numpy()
        for _id, _pred in zip(ids, pred):
            relevant[_id] = int(_pred)

    return relevant

In [19]:
@torch.no_grad()
def qa_predict(args, data_loader, model, n_best=1, max_answer_len=None):
    """Extract one answer span per test example with the QA model.

    Only ``ids[0]`` and ``inputs["context"][0]`` of each batch are read, so
    every batch must hold a single example; the rows of ``input_ids`` are
    treated as windows over that one context (presumably overflow chunks —
    confirm against QuestionAnsweringDataset.collate_fn). For every window
    the ``n_best`` highest start and end logits are paired, and the valid
    span with the largest ``start_logit + end_logit`` across all windows wins
    (ties keep the first span encountered).

    Args:
        args: namespace providing ``device`` for the input tensors.
        data_loader: yields ``(ids, inputs)`` where ``inputs`` holds
            ``context``, ``input_ids``, ``token_type_ids``,
            ``attention_mask`` and a per-window ``offset_mapping`` of
            ``(char_start, char_end)`` pairs; ``None`` entries are never used
            as span boundaries.
        model: QA model whose output exposes ``start_logits``/``end_logits``.
        n_best: number of top start and top end positions to combine.
        max_answer_len: optional cap on span length in tokens. ``None``
            (default) keeps spans unbounded, as before.

    Returns:
        list of ``(id, answer_text)`` tuples in loader order. ``answer_text``
        is ``""`` when no valid span exists; previously this case raised
        ``ValueError`` from ``max`` on an empty list and aborted the run.
    """
    ret = []
    model.eval()
    for batch in tqdm(data_loader):
        ids, inputs = batch
        context = inputs["context"][0]
        input_ids = inputs["input_ids"].to(args.device)
        token_type_ids = inputs["token_type_ids"].to(args.device)
        attention_mask = inputs["attention_mask"].to(args.device)

        qa_output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        start_logits = qa_output.start_logits.cpu().numpy()
        end_logits = qa_output.end_logits.cpu().numpy()

        # Track the best span directly instead of materialising every candidate.
        best_text, best_score = "", None
        for i in range(len(input_ids)):
            start_logit = start_logits[i]
            end_logit = end_logits[i]
            offsets = inputs["offset_mapping"][i]

            # Indices of the n_best largest logits, highest first.
            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()

            for start_index in start_indexes:
                for end_index in end_indexes:
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    if end_index < start_index:
                        continue
                    if (
                        max_answer_len is not None
                        and end_index - start_index + 1 > max_answer_len
                    ):
                        continue

                    score = start_logit[start_index] + end_logit[end_index]
                    if best_score is None or score > best_score:
                        best_score = score
                        best_text = context[offsets[start_index][0] : offsets[end_index][1]]
        ret.append((ids[0], best_text))
    return ret



In [4]:
def parse_args(argv=None):
    """Build the argument namespace used for inference.

    Args:
        argv: optional list of CLI-style tokens, e.g. ``["--seed", "1"]``.
            Defaults to an empty list so that, inside Jupyter, the kernel's
            own ``sys.argv`` (e.g. ``-f <connection file>``) is never parsed.

    Returns:
        argparse.Namespace with seed, device, data/cache/checkpoint
        locations, backbone model name, max sequence length and batch size.
    """
    parser = ArgumentParser()
    parser.add_argument("--seed", type=int, default=5920)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument(
        "--data_dir",
        type=Path,
        help="Directory to the dataset.",
        default=".",
    )
    parser.add_argument("--model_name", type=str, default="hfl/chinese-macbert-large")
    parser.add_argument(
        "--cache_dir",
        type=str,
        help="Directory to save the cache file.",
        default="./cache",
    )
    parser.add_argument(
        "--ckpt_dir",
        type=Path,
        help="Directory containing the mc.ckpt and qa.ckpt checkpoints.",
        # NOTE(review): machine-specific absolute path; override through argv.
        default="/auto/extra/jeff999955",
    )
    parser.add_argument("--max_len", type=int, default=512)

    parser.add_argument("--batch_size", type=int, default=1)

    return parser.parse_args(args=[] if argv is None else list(argv))



In [5]:
# Parse the default configuration; the notebook passes no CLI arguments.
args = parse_args()

In [6]:
same_seeds(args.seed)
# NOTE(review): `fp16=True` is the legacy accelerate flag; newer accelerate
# releases expect `mixed_precision="fp16"` instead — confirm installed version.
accelerator = Accelerator(fp16=True)

# Load both fine-tuned checkpoints (multiple-choice and QA) together with the
# config/tokenizer of the backbone each one was trained from.
tags = ["mc", "qa"]
ckpt, config, tokenizer = {}, {}, {}
for tag in tags:
    # Both checkpoints stay referenced in `ckpt` for the rest of the notebook,
    # so map them to CPU regardless of the device they were saved from;
    # load_state_dict copies the tensors into the model parameters anyway.
    # torch.load unpickles the file — only load trusted checkpoints.
    ckpt[tag] = torch.load(os.path.join(args.ckpt_dir, f"{tag}.ckpt"), map_location="cpu")
    model_name = ckpt[tag]["name"]
    config[tag] = AutoConfig.from_pretrained(model_name)
    tokenizer[tag] = AutoTokenizer.from_pretrained(
        model_name, config=config[tag], model_max_length=args.max_len, use_fast=True
    )

# Stage 1: multiple-choice model and test loader.
model = MultipleChoiceModel(args, config["mc"], ckpt["mc"]["name"])
model.load_state_dict(ckpt["mc"]["model"])
test_set = MultipleChoiceDataset(args, tokenizer["mc"], mode="test")
test_loader = DataLoader(
    test_set,
    collate_fn=test_set.collate_fn,
    shuffle=False,
    batch_size=1,
)
model, test_loader = accelerator.prepare(model, test_loader)

Some weights of the model checkpoint at hfl/chinese-macbert-large were not used when initializing BertForMultipleChoice: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMultipleChoice were not initialized from the model check

In [7]:
# Stage 1 inference: map each test example id to its predicted choice index
# (presumably the index of the relevant paragraph, given the variable name).
relevant = mc_predict(test_loader, model)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2213/2213 [02:20<00:00, 15.77it/s]


In [68]:
# Persist the stage-1 selections to disk (presumably so QA can be re-run
# without repeating MC inference; nothing in this notebook reloads it).
torch.save(relevant, "relevant.dat")

In [13]:
# Stage 2: extractive QA over the choice the MC stage picked per question.
# qa_predict reads only context[0] / ids[0] of each batch, so keep the loader
# at batch_size=1.
test_set = QuestionAnsweringDataset(
    args, tokenizer["qa"], mode="test", relevant=relevant
)
test_loader = DataLoader(
    test_set, batch_size=1, shuffle=False, collate_fn=test_set.collate_fn
)
model = QuestionAnsweringModel(args, config["qa"], ckpt["qa"]["name"])
model.load_state_dict(ckpt["qa"]["model"])
model, test_loader = accelerator.prepare(model, test_loader)

Preprocessing QA test Data:


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2213/2213 [00:00<00:00, 582710.45it/s]
Some weights of the model checkpoint at hfl/chinese-macbert-large were not used when initializing BertForQuestionAnswering: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly ide

In [21]:
# Pair the top-20 start and top-20 end positions per window when searching spans.
answers = qa_predict(args, test_loader, model, n_best=20)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2213/2213 [02:12<00:00, 16.67it/s]


In [32]:
# Eyeball the first few QA test questions next to their ids.
for example in test_set[:5]:
    print(example["id"], example["question"])

5e7a923dd6e4ccb8730eb95230e0c908 卡利創立的網際網路檔案館要到什麼時後才開放存取？
a2e9cd802197b8f8dfbe235e2761f9ed 哪個國家在歐洲具有重要的戰略意義甚至遠超過了其自身價值?
c7c8a85b3f0006d44d86510a22193620 目前所知「義和拳」這一個名詞最早於哪一年時出現?
7f4f68726faed6b987e348340a9e6a61 葉門是世界上經濟最落後的國家之一其主要倚賴什麼收入?
89908ef5182021a9aec1472c5bbcbd8c 北京地質學院博物館後來演變成哪一個博物館?


In [33]:
# Preview the first five (id, answer_text) predictions.
print(answers[:5])

[('5e7a923dd6e4ccb8730eb95230e0c908', '時光機'), ('a2e9cd802197b8f8dfbe235e2761f9ed', '普法茲選侯國'), ('c7c8a85b3f0006d44d86510a22193620', '1779年'), ('7f4f68726faed6b987e348340a9e6a61', '石油收入'), ('89908ef5182021a9aec1472c5bbcbd8c', '中國地質大學逸夫博物館')]


In [34]:
def balance_brackets(answer):
    """Add the missing half of a 「」 or 《》 pair (presumably cut off by span extraction)."""
    for left, right in (("「", "」"), ("《", "》")):
        if left in answer and right not in answer:
            answer += right
        elif left not in answer and right in answer:
            answer = left + answer
    return answer


# Encode explicitly as UTF-8: answers are Chinese text and the platform's
# default encoding is not guaranteed to be UTF-8. csv.writer escapes quotes
# and newlines that would otherwise corrupt rows; "\n" keeps the previous
# line endings.
with open("sub.csv", "w", encoding="utf-8", newline="") as f:
    writer = csv.writer(f, lineterminator="\n")
    writer.writerow(["id", "answer"])
    for _id, answer in answers:
        # Commas are still stripped, matching the original submission format.
        writer.writerow([_id, balance_brackets(answer).replace(",", "")])

In [35]:
# Sanity check: the submission must contain one prediction per QA test example.
assert len(answers) == len(test_set), (len(answers), len(test_set))
len(answers)

2213