In [19]:
import os
from argparse import ArgumentParser
from pathlib import Path
import torch
from torch import nn
from transformers import AutoConfig, AutoTokenizer
from dataset import *
from torch.utils.data import DataLoader
from accelerate import Accelerator

from tqdm import tqdm

from utils import same_seeds
from model import *

import wandb

import numpy as np
import collections

In [20]:
@torch.no_grad()
def mc_predict(data_loader, model):
    model.eval()
    relevant = {}
    for batch in tqdm(data_loader):
        ids, input_ids, attention_masks, token_type_ids, labels = batch
        output = model(
            input_ids=input_ids.to(args.device),
            attention_mask=attention_masks.to(args.device),
            token_type_ids=token_type_ids.to(args.device),
        )
        pred = output.logits.argmax(dim=-1).cpu().numpy()
        for _id, _pred in zip(ids, pred):
            relevant[_id] = int(_pred)

    return relevant

In [21]:
@torch.no_grad()
def qa_predict(args, data_loader, model, n_best = 1):
    ret = []
    model.eval()
    for batch in tqdm(data_loader):
        answers = []

        ids, inputs = batch
        input_ids = inputs["input_ids"].to(args.device)
        token_type_ids = inputs["token_type_ids"].to(args.device)
        attention_mask = inputs["attention_mask"].to(args.device)

        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            )
        
        start_logits = qa_output.start_logits.cpu().numpy()
        end_logits = qa_output.end_logits.cpu().numpy()
        for i in range(len(input_ids)):
            start_logit = start_logits[i]
            end_logit = end_logits[i]
            offsets = eval_set["offset_mapping"][i]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()

            for start_index in start_indexes:
                for end_index in end_indexes:
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    if end_index < start_index:
                        continue

                    answers.append(
                        {
                            "text": context[offsets[start_index][0] : offsets[end_index][1]],
                            "logit_score": start_logit[start_index] + end_logit[end_index],
                        }
                    )
        best_answer = max(answers, key=lambda x: x["logit_score"])
        ret.append((ids[0], best_answer["text"]))
    return ret  



In [22]:
def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--seed", type=int, default=5920)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument(
        "--data_dir",
        type=Path,
        help="Directory to the dataset.",
        default=".",
    )
    parser.add_argument("--model_name", type=str, default="hfl/chinese-macbert-large")
    parser.add_argument(
        "--cache_dir",
        type=str,
        help="Directory to save the cache file.",
        default="./cache",
    )
    parser.add_argument(
        "--ckpt_dir",
        type=Path,
        help="Directory to save the cache file.",
        default="/auto/extra/jeff999955",
    )
    parser.add_argument("--max_len", type=int, default=512)

    parser.add_argument("--batch_size", type=int, default=1)

    args = parser.parse_args(args = [])
    return args



In [23]:
args = parse_args()

In [26]:
same_seeds(args.seed)
accelerator = Accelerator(fp16=True)
tags = ["mc", "qa"]
ckpt, config, tokenizer = {}, {}, {}
for tag in tags:
    ckpt[tag] = torch.load(os.path.join(args.ckpt_dir, f"{tag}.ckpt"))
    namae = ckpt[tag]["name"] 
    config[tag] = AutoConfig.from_pretrained(namae)
    tokenizer[tag] = AutoTokenizer.from_pretrained(
        namae, config=config[tag], model_max_length=args.max_len, use_fast=True
    )

model = MultipleChoiceModel(args, config["mc"], ckpt["mc"]["name"])
model.load_state_dict(ckpt['mc']['model'])
test_set = MultipleChoiceDataset(args, tokenizer["mc"], mode="test")
test_loader = DataLoader(
    test_set,
    collate_fn=test_set.collate_fn,
    shuffle=False,
    batch_size=1,
)
model, test_loader = accelerator.prepare(
        model, test_loader
)

Some weights of the model checkpoint at hfl/chinese-macbert-large were not used when initializing BertForMultipleChoice: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMultipleChoice were not initialized from the model check

In [27]:
relevant = mc_predict(test_loader, 
                      model)

  0%|                                                                                                                                               | 0/2213 [00:00<?, ?it/s]


RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 23.70 GiB total capacity; 10.45 GiB already allocated; 7.56 MiB free; 10.47 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [68]:
torch.save(relevant, "relevant.dat")

In [71]:
test_set = QuestionAnsweringDataset(args, tokenizer["qa"], mode="test", relevant = relevant)
test_loader = DataLoader(
    test_set,
    collate_fn=test_set.collate_fn,
    shuffle=False,
    batch_size=1,
)
model, test_loader = accelerator.prepare(
        model, test_loader
)

Preprocessing QA Data:


100%|████████████████████████████████████████████████████████████████████████████████████| 2213/2213 [00:00<00:00, 1386406.98it/s]


In [72]:
model = QuestionAnsweringModel(args, config["qa"], ckpt["qa"]["name"])
print(model)

Some weights of the model checkpoint at hfl/chinese-macbert-large were not used when initializing BertForQuestionAnswering: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the

QuestionAnsweringModel(
  (model): BertForQuestionAnswering(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(21128, 1024, padding_idx=0)
        (position_embeddings): Embedding(512, 1024)
        (token_type_embeddings): Embedding(2, 1024)
        (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=1024, out_features=1024, bias=True)
                (key): Linear(in_features=1024, out_features=1024, bias=True)
                (value): Linear(in_features=1024, out_features=1024, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=1024, out_feature

In [73]:
model.state_dict


<bound method Module.state_dict of QuestionAnsweringModel(
  (model): BertForQuestionAnswering(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(21128, 1024, padding_idx=0)
        (position_embeddings): Embedding(512, 1024)
        (token_type_embeddings): Embedding(2, 1024)
        (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=1024, out_features=1024, bias=True)
                (key): Linear(in_features=1024, out_features=1024, bias=True)
                (value): Linear(in_features=1024, out_features=1024, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): L

In [78]:
ckpt['qa'] = torch.load('./ckpt/qa_loss.ckpt', map_location = 'cpu')
print(ckpt['qa']['model'])

OrderedDict([('model.bert.embeddings.position_ids', tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
          14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
          28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
          42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
          56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
          70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
          84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
         112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
         126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
         140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
         154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
