In [1]:
import gc
import os
import sys
import itertools
import pickle
from glob import glob
from tqdm import tqdm_notebook as tqdm

import numpy as np
import pandas as pd
from scipy.stats import spearmanr

from matplotlib import pyplot as plt
from matplotlib_venn import venn2, venn3
import seaborn as sns

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from transformers import BertConfig, BertTokenizer, BertModel, BertForMaskedLM#, BertLayer, BertEmbeddings
from transformers.modeling_bert import BertLayer, BertEmbeddings

In [2]:
# Let pandas render up to 500 columns / 500 rows before truncating a frame.
pd.set_option('display.max_columns', 500)
pd.set_option('display.max_rows', 500)

# Reload imported modules automatically before each cell executes, so edits
# to local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Widen the notebook container to 90% of the browser window.
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))
# Inline figure formats requested from the matplotlib inline backend.
%config InlineBackend.figure_formats = {'png', 'retina'}

In [3]:
# Device every model is moved to for inference. Prefer the GPU, but fall back
# to the CPU so the notebook still runs on a machine without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
import sys
import pickle
from functools import partial
from glob import glob

import numpy as np
import pandas as pd
import scipy as sp
import torch
from scipy.stats import spearmanr
from tqdm import tqdm

class OptimizedRounder(object):
    """
    Search for the cut points that turn continuous predictions into discrete
    labels so that the Spearman rank correlation with the ground truth is
    maximised.

    Adapted from the QWK threshold optimizer published at
    https://www.kaggle.com/naveenasaithambi/optimizedrounder-improved

    Typical use: ``set_labels`` -> ``fit`` -> ``predict`` / ``coefficients``.
    """

    def __init__(self):
        # Placeholder; ``fit`` replaces it with the scipy optimisation result.
        self.coef_ = 0
        # Declared up front so that using the rounder before ``set_labels``
        # fails with an explicit message instead of an AttributeError.
        self.labels = None

    def _require_labels(self):
        """Return the configured labels or raise ValueError if none are set."""
        # getattr keeps instances unpickled from older runs (which may lack
        # the attribute entirely) on the same explicit error path.
        labels = getattr(self, 'labels', None)
        if labels is None:
            raise ValueError('labels are not set; call set_labels() first.')
        return labels

    @staticmethod
    def _cut(X, coef, labels):
        """Bucket ``X`` by the sorted thresholds ``coef`` and map each bucket to a label."""
        return pd.cut(X, [-np.inf] + list(np.sort(coef)) + [np.inf],
                      labels=labels)

    def _spearmanr_loss(self, coef, X, y, labels):
        """
        Objective minimised by ``fit``.
        :param coef: A list of thresholds used for rounding
        :param X: The raw predictions
        :param y: The ground truth labels
        :param labels: One output label per bucket (len(coef) + 1 values)
        :return: The negated Spearman correlation between ``y`` and the
                 rounded predictions
        """
        X_p = self._cut(X, coef, labels)
        return -spearmanr(y, X_p).correlation

    def fit(self, X, y, initial_coef):
        """
        Optimize rounding thresholds with Nelder-Mead
        :param X: The raw predictions
        :param y: The ground truth labels
        :param initial_coef: Starting thresholds for the search
        :return: self (the optimisation result is stored in ``coef_``)
        """
        labels = self._require_labels()
        loss_partial = partial(self._spearmanr_loss, X=X, y=y, labels=labels)
        self.coef_ = sp.optimize.minimize(
            loss_partial, initial_coef, method='nelder-mead')
        return self

    def predict(self, X, coef=None):
        """
        Make predictions with specified thresholds
        :param X: The raw predictions
        :param coef: A list of thresholds used for rounding; when omitted the
                     thresholds found by ``fit`` are used
        :return: The rounded predictions as returned by ``pd.cut``
        """
        labels = self._require_labels()
        if coef is None:
            coef = self.coefficients()
        return self._cut(X, coef, labels)

    def coefficients(self):
        """
        Return the optimized thresholds
        :raises ValueError: if ``fit`` has not been called yet
        """
        try:
            return self.coef_['x']
        except (TypeError, KeyError):
            # coef_ is still the integer placeholder from __init__.
            raise ValueError(
                'no fitted coefficients; call fit() first.') from None

    def set_labels(self, labels):
        """Set the output labels, one per bucket, in ascending bucket order."""
        self.labels = labels

In [5]:
# sys.path.append('../scripts/')
# from get_optR3 import compute_spearmanr, get_opt_y_pred

import os
import pickle
import sys
from functools import partial
from glob import glob

import numpy as np
import pandas as pd
import scipy as sp
import torch
from scipy.stats import spearmanr
from tqdm import tqdm


class histogramBasedCoefInitializer:
    """
    Derive starting thresholds from the label histogram: the sorted
    predictions are split at the cumulative label counts, so each bucket
    initially holds as many predictions as there are samples of that label.
    """

    def __init__(self):
        # Cumulative sample counts per label value; filled in by ``fit``.
        self.bins = None

    def fit(self, labels):
        """Store the cumulative count of each distinct label (ascending label order)."""
        per_label = pd.Series(labels).value_counts()
        self.bins = per_label.sort_index().cumsum().values
        return self

    def predict(self, preds):
        """
        Return one threshold per label boundary: the midpoint between the two
        sorted predictions that straddle each cumulative count.
        """
        ordered = sorted(preds)
        if self.bins is None:
            raise Exception('plz fit at first.')
        # The last cumulative count equals the sample total, so it is not a boundary.
        return [(ordered[edge - 1] + ordered[edge]) / 2
                for edge in self.bins[:-1]]


class OptimizedRounder(object):
    """
    Search for the cut points that turn continuous predictions into discrete
    labels so that the Spearman rank correlation with the ground truth is
    maximised.

    Adapted from the QWK threshold optimizer published at
    https://www.kaggle.com/naveenasaithambi/optimizedrounder-improved

    Typical use: ``set_labels`` -> ``fit`` -> ``predict`` / ``coefficients``.
    """

    def __init__(self):
        # Placeholder; ``fit`` replaces it with the scipy optimisation result.
        self.coef_ = 0
        # Declared up front so that using the rounder before ``set_labels``
        # fails with an explicit message instead of an AttributeError.
        self.labels = None

    def _require_labels(self):
        """Return the configured labels or raise ValueError if none are set."""
        # getattr keeps instances unpickled from older runs (which may lack
        # the attribute entirely) on the same explicit error path.
        labels = getattr(self, 'labels', None)
        if labels is None:
            raise ValueError('labels are not set; call set_labels() first.')
        return labels

    @staticmethod
    def _cut(X, coef, labels):
        """Bucket ``X`` by the sorted thresholds ``coef`` and map each bucket to a label."""
        return pd.cut(X, [-np.inf] + list(np.sort(coef)) + [np.inf],
                      labels=labels)

    def _spearmanr_loss(self, coef, X, y, labels):
        """
        Objective minimised by ``fit``.
        :param coef: A list of thresholds used for rounding
        :param X: The raw predictions
        :param y: The ground truth labels
        :param labels: One output label per bucket (len(coef) + 1 values)
        :return: The negated Spearman correlation between ``y`` and the
                 rounded predictions
        """
        X_p = self._cut(X, coef, labels)
        return -spearmanr(y, X_p).correlation

    def fit(self, X, y, initial_coef):
        """
        Optimize rounding thresholds with Nelder-Mead
        :param X: The raw predictions
        :param y: The ground truth labels
        :param initial_coef: Starting thresholds for the search
        :return: self (the optimisation result is stored in ``coef_``)
        """
        labels = self._require_labels()
        loss_partial = partial(self._spearmanr_loss, X=X, y=y, labels=labels)
        self.coef_ = sp.optimize.minimize(
            loss_partial, initial_coef, method='nelder-mead')
        return self

    def predict(self, X, coef=None):
        """
        Make predictions with specified thresholds
        :param X: The raw predictions
        :param coef: A list of thresholds used for rounding; when omitted the
                     thresholds found by ``fit`` are used
        :return: The rounded predictions as returned by ``pd.cut``
        """
        labels = self._require_labels()
        if coef is None:
            coef = self.coefficients()
        return self._cut(X, coef, labels)

    def coefficients(self):
        """
        Return the optimized thresholds
        :raises ValueError: if ``fit`` has not been called yet
        """
        try:
            return self.coef_['x']
        except (TypeError, KeyError):
            # coef_ is still the integer placeholder from __init__.
            raise ValueError(
                'no fitted coefficients; call fit() first.') from None

    def set_labels(self, labels):
        """Set the output labels, one per bucket, in ascending bucket order."""
        self.labels = labels


def compute_spearmanr(trues, preds):
    """
    Column-wise Spearman rank correlation between ``trues`` and ``preds``.

    A prediction column that holds a single repeated value has no defined
    rank correlation, so one of its entries is replaced by the minimum (or
    maximum) of the matching truth column before scoring. The replacement is
    done on a private copy: ``preds`` itself is never modified.

    :param trues: 2-D array, one column per label
    :param preds: 2-D array with the same layout as ``trues``
    :return: list with one correlation per column
    """
    rhos = []
    for col_trues, col_pred in zip(trues.T, preds.T):
        if len(np.unique(col_pred)) == 1:
            # ``col_pred`` is a view into ``preds``; copy before patching so
            # the caller's array is left intact (float so that fractional
            # truth values are not truncated by an integer column).
            col_pred = np.array(col_pred, dtype=float)
            if col_pred[0] == np.max(col_trues):
                col_pred[np.argmin(col_pred)] = np.min(col_trues)
            else:
                col_pred[np.argmax(col_pred)] = np.max(col_trues)
        rhos.append(spearmanr(col_trues, col_pred).correlation)
    return rhos


def get_best_ckpt(ckpts):
    """
    Return the checkpoint path with the highest validation metric.

    The metric is parsed from the file name: it is the sixth
    underscore-separated field of the last path component.
    """
    records = [
        {'ckpt': path,
         'val_metric': float(path.split('/')[-1].split('_')[5])}
        for path in ckpts
    ]
    ranked = pd.DataFrame(records).sort_values('val_metric', ascending=False)
    return ranked.ckpt.iloc[0]


def get_snapshot_info_df(base_dir):
    """
    Index every ``*.pth`` checkpoint under ``{base_dir}/{fold}/`` for folds 0-4.

    Fold, epoch, validation loss and validation metric are parsed from the
    underscore-separated fields 1, 3, 4 and 5 of each file name. The ``rank``
    column is the ascending rank of ``val_metric`` inside its fold, so the
    best checkpoint of a fold carries the largest rank.
    """
    records = []
    for fold in tqdm(list(range(5))):
        for ckpt_path in glob(f'{base_dir}/{fold}/*.pth'):
            fields = ckpt_path.split('/')[-1].split('_')
            records.append({
                'ckpt_filename': ckpt_path,
                'fold': int(fields[1]),
                'epoch': int(fields[3]),
                'val_loss': float(fields[4]),
                'val_metric': float(fields[5]),
            })
    snapshot_df = pd.DataFrame(records)
    snapshot_df['rank'] = snapshot_df.groupby(['fold']).val_metric.rank()
    return snapshot_df


def get_opt_y_pred(y_true, y_pred, num_labels):
    """
    Fit one threshold rounder per label column and round the predictions.

    For each of the first ``num_labels`` columns the distinct truth values
    become the output labels, the histogram initializer supplies the starting
    thresholds, and the fitted rounder is applied to the same predictions.

    :return: (list of fitted rounders, array of rounded predictions with one
             column per label)
    """
    rounders, rounded_cols = [], []
    for col in range(num_labels):
        col_true, col_pred = y_true[:, col], y_pred[:, col]

        rounder = OptimizedRounder()
        rounder.set_labels(np.sort(np.unique(col_true)))
        start = histogramBasedCoefInitializer().fit(col_true).predict(col_pred)
        rounder.fit(col_pred, col_true, start)

        rounders.append(rounder)
        rounded_cols.append(rounder.predict(col_pred, rounder.coefficients()))

    # Columns were collected as rows; transpose back to (samples, labels).
    return rounders, np.asarray(rounded_cols).T


def opt(BASE_PATH, num_labels=30, snapshot_num=2):
    """
    Build threshold rounders from the best snapshots of an experiment.

    For each of the 5 folds the ``snapshot_num`` highest-ranked checkpoints
    under ``BASE_PATH`` are loaded; their validation predictions (re-ordered
    by ``val_qa_ids``) are averaged per fold and concatenated over folds. One
    ``OptimizedRounder`` per label is then fitted on that out-of-fold matrix.

    Side effects: writes ``optRs.pkl``, ``snapshot_dicts.pkl`` and one
    ``state_dicts/fold_{fold}_rank_{rank}_state_dict.pkl`` per loaded
    checkpoint below ``BASE_PATH``, and prints the scores before and after
    rounding.

    :return: list of per-label Spearman correlations of the rounded predictions
    """
    snapshot_df = get_snapshot_info_df(BASE_PATH)

    snapshot_dicts, state_dict_dicts = {}, {}
    for fold in tqdm(list(range(5))):
        ranked = (snapshot_df.query(f'fold == {fold}')
                  .sort_values('rank', ascending=False)
                  .reset_index(drop=True))
        fold_snapshots, fold_states = {}, {}
        for pos, row in ranked.iterrows():
            # Rows are ordered best-first with positions 0, 1, ...; everything
            # from ``snapshot_num`` onwards is ignored.
            if pos >= snapshot_num:
                break
            ckpt = torch.load(row['ckpt_filename'])
            fold_states[pos] = ckpt['model_state_dict']
            # Sort by qa id so snapshots of the same fold line up row by row.
            order = np.argsort(ckpt['val_qa_ids'])
            fold_snapshots.setdefault('y_trues', []).append(
                ckpt['val_y_trues'][order])
            fold_snapshots.setdefault('y_preds', []).append(
                ckpt['val_y_preds'][order])
        snapshot_dicts[fold] = fold_snapshots
        state_dict_dicts[fold] = fold_states

    y_preds = np.concatenate([
        np.average(snapshot_dicts[fold]['y_preds'][:snapshot_num], axis=0)
        for fold in range(5)])
    y_trues = np.concatenate([
        snapshot_dicts[fold]['y_trues'][0] for fold in range(5)])

    reses, optRs = [], []
    for col in tqdm(list(range(num_labels))):
        col_pred, col_true = y_preds[:, col], y_trues[:, col]
        hi_pos, lo_pos = np.argmax(col_pred), np.argmin(col_pred)

        rounder = OptimizedRounder()
        rounder.set_labels(np.sort(np.unique(col_true)))
        start = histogramBasedCoefInitializer().fit(col_true).predict(col_pred)
        rounder.fit(col_pred, col_true, initial_coef=start)
        optRs.append(rounder)
        rounded = rounder.predict(col_pred, rounder.coefficients())

        # A column rounded to one single value cannot be rank-correlated;
        # force one entry (at the extreme raw prediction) to differ.
        if len(np.unique(rounded)) == 1:
            if np.unique(rounded) == rounded[hi_pos]:
                rounded[lo_pos] = np.min(col_true)
            else:
                rounded[hi_pos] = np.max(col_true)

        reses.append(rounded)
    reses = np.asarray(reses).T

    with open(f'{BASE_PATH}/optRs.pkl', 'wb') as fout:
        pickle.dump(optRs, fout)
    with open(f'{BASE_PATH}/snapshot_dicts.pkl', 'wb') as fout:
        pickle.dump(snapshot_dicts, fout)
    if not os.path.exists(f'{BASE_PATH}/state_dicts'):
        os.mkdir(f'{BASE_PATH}/state_dicts')
    # One file per (fold, rank) rather than a single pickle of all state dicts.
    for fold in range(5):
        for rank in range(snapshot_num):
            with open(f'{BASE_PATH}/state_dicts/fold_{fold}_rank_{rank}_state_dict.pkl', 'wb') as fout:
                pickle.dump(state_dict_dicts[fold][rank], fout)

    original_score = compute_spearmanr(y_trues, y_preds)
    print(f'original_score: {original_score}')
    print(f'original_score: {np.mean(original_score)}')

    res_score = compute_spearmanr(y_trues, reses)
    print(f'res_score: {res_score}')
    print(f'res_score_mean: {np.mean(res_score)}')

    return res_score

In [6]:
class histogramBasedCoefInitializer:
    """
    Derive starting thresholds from the label histogram: the sorted
    predictions are split at the cumulative label counts, so each bucket
    initially holds as many predictions as there are samples of that label.
    """

    def __init__(self):
        # Cumulative sample counts per label value; filled in by ``fit``.
        self.bins = None

    def fit(self, labels):
        """Store the cumulative count of each distinct label (ascending label order)."""
        per_label = pd.Series(labels).value_counts()
        self.bins = per_label.sort_index().cumsum().values
        return self

    def predict(self, preds):
        """
        Return one threshold per label boundary: the midpoint between the two
        sorted predictions that straddle each cumulative count.
        """
        ordered = sorted(preds)
        if self.bins is None:
            raise Exception('plz fit at first.')
        # The last cumulative count equals the sample total, so it is not a boundary.
        return [(ordered[edge - 1] + ordered[edge]) / 2
                for edge in self.bins[:-1]]

In [7]:
def blend_and_evaluate(y_trues, y_preds_list, eval_func, weights=None,
                       num_labels=30):
    """
    Blend several prediction matrices and score the blend before and after
    threshold optimisation.

    :param y_trues: ground-truth matrix, one column per label
    :param y_preds_list: list of prediction matrices shaped like ``y_trues``
    :param eval_func: callable ``(y_trues, y_preds) -> per-label scores``
    :param weights: optional per-model blend weights (list, tuple or 1-D
                    array); ``None`` or an empty sequence means a plain mean
    :param num_labels: number of leading label columns to fit rounders for
    :return: (scores of the raw blend, scores of the rounded blend,
             list of fitted rounders)
    """
    # ``if weights:`` raises for numpy arrays with more than one element, so
    # test for "given and non-empty" explicitly.
    if weights is not None and len(weights) > 0:
        y_preds = np.average(y_preds_list, axis=0, weights=weights)
    else:
        y_preds = np.average(y_preds_list, axis=0)

    # Score a copy: the metric function may patch constant columns in place,
    # and the thresholds below must be fitted on the untouched blend.
    eval_scores = eval_func(y_trues, y_preds.copy())
    optRs, opt_y_preds = get_opt_y_pred(y_trues, y_preds, num_labels=num_labels)
    opt_eval_scores = eval_func(y_trues, opt_y_preds)
    print(f'original_score: {np.mean(eval_scores)}')
    print(f'opt_score: {np.mean(opt_eval_scores)}')
    return eval_scores, opt_eval_scores, optRs

## まずは top2 optRs を作る

## snapshot 済みの model を load

In [22]:
CKPT_DIR = '../mnt/checkpoints'


def _load_snapshot_dicts(exp_name):
    """Unpickle ``snapshot_dicts.pkl`` of experiment ``exp_name`` under CKPT_DIR."""
    with open(f'{CKPT_DIR}/{exp_name}/snapshot_dicts.pkl', 'rb') as fin:
        return pickle.load(fin)


# Each architecture has one experiment for the question labels and one for
# the answer labels.
bert_question_dict = _load_snapshot_dicts('e078')
bert_answer_dict = _load_snapshot_dicts('e079')

roberta_question_dict = _load_snapshot_dicts('e080')
roberta_answer_dict = _load_snapshot_dicts('e081')

# GPT-2 snapshots (e072 question / e073 answer) are left out of this blend:
# gpt2_question_dict = _load_snapshot_dicts('e072')
# gpt2_answer_dict = _load_snapshot_dicts('e073')

xlnet_question_dict = _load_snapshot_dicts('e082')
xlnet_answer_dict = _load_snapshot_dicts('e083')

In [23]:
def _get_y_trues_and_y_preds_from_snapshot_dicts(snapshot_dicts, single, avg):
    y_trues, y_preds = [], []
    for fold in range(5):
        if single:
            y_trues.append(snapshot_dicts[fold]['y_trues'][0])
            y_preds.append(snapshot_dicts[fold]['y_preds'][0])
        else:
            if avg:
                y_trues.append(np.average(snapshot_dicts[fold]['y_trues'], axis=0))
                y_preds.append(np.average(snapshot_dicts[fold]['y_preds'], axis=0))
            else:
                y_trues.append(np.concatenate(snapshot_dicts[fold]['y_trues'], axis=1))
                y_preds.append(np.concatenate(snapshot_dicts[fold]['y_preds'], axis=1))
    y_trues = np.concatenate(y_trues)
    y_preds = np.concatenate(y_preds)
    return y_trues, y_preds

def get_y_trues_and_y_preds_from_QA_snapshota_dicts(Q_snapshot_dicts, A_snapshot_dicts, single=False, avg=True, model_num=2):
    """
    Join the question-model and answer-model outputs column-wise.

    With ``avg`` the two matrices are simply placed side by side (question
    columns first). Without it, the inputs are expected to hold ``model_num``
    snapshots side by side; for every snapshot its 21 question columns are
    followed by its 9 answer columns.

    :return: (y_trues, y_preds)
    """
    q_y_trues, q_y_preds = _get_y_trues_and_y_preds_from_snapshot_dicts(Q_snapshot_dicts, single, avg)
    a_y_trues, a_y_preds = _get_y_trues_and_y_preds_from_snapshot_dicts(A_snapshot_dicts, single, avg)

    if avg:
        return (np.concatenate([q_y_trues, a_y_trues], axis=1),
                np.concatenate([q_y_preds, a_y_preds], axis=1))

    def interleave(q_block, a_block):
        # Snapshot m owns question columns [21m, 21m+21) and answer columns [9m, 9m+9).
        per_model = []
        for m in range(model_num):
            per_model.append(np.concatenate(
                [q_block[:, m*21:(m+1)*21], a_block[:, m*9:(m+1)*9]], axis=1))
        return np.concatenate(per_model, axis=1)

    return interleave(q_y_trues, a_y_trues), interleave(q_y_preds, a_y_preds)

In [24]:
# Out-of-fold truths / predictions per architecture, using the function's
# defaults (single=False, avg=True): snapshots of a fold are averaged.
bert_y_trues, bert_y_preds = get_y_trues_and_y_preds_from_QA_snapshota_dicts(bert_question_dict, bert_answer_dict)
roberta_y_trues, roberta_y_preds = get_y_trues_and_y_preds_from_QA_snapshota_dicts(roberta_question_dict, roberta_answer_dict)
# gpt2_y_trues, gpt2_y_preds = get_y_trues_and_y_preds_from_QA_snapshota_dicts(gpt2_question_dict, gpt2_answer_dict)
xlnet_y_trues, xlnet_y_preds = get_y_trues_and_y_preds_from_QA_snapshota_dicts(xlnet_question_dict, xlnet_answer_dict)

In [25]:
# Same extraction with single=True: only the first snapshot of every fold is used.
single_bert_y_trues, single_bert_y_preds = get_y_trues_and_y_preds_from_QA_snapshota_dicts(bert_question_dict, bert_answer_dict, single=True)
single_roberta_y_trues, single_roberta_y_preds = get_y_trues_and_y_preds_from_QA_snapshota_dicts(roberta_question_dict, roberta_answer_dict, single=True)
# single_gpt2_y_trues, single_gpt2_y_preds = get_y_trues_and_y_preds_from_QA_snapshota_dicts(gpt2_question_dict, gpt2_answer_dict, single=True)
single_xlnet_y_trues, single_xlnet_y_preds = get_y_trues_and_y_preds_from_QA_snapshota_dicts(xlnet_question_dict, xlnet_answer_dict, single=True)

In [26]:
# Looks good: the three architectures must share one identical ground-truth
# matrix, otherwise their predictions could not be blended row by row.
(bert_y_trues == roberta_y_trues).all(), (bert_y_trues == xlnet_y_trues).all()

(True, True)

In [27]:
# Score each architecture on its own (a one-element "blend") using only the
# first snapshot per fold. All calls use single_bert_y_trues as ground truth.
blend_and_evaluate(single_bert_y_trues, [single_bert_y_preds,  ], compute_spearmanr)
blend_and_evaluate(single_bert_y_trues, [single_roberta_y_preds, ], compute_spearmanr)
# blend_and_evaluate(single_bert_y_trues, [single_gpt2_y_preds, ], compute_spearmanr)
blend_and_evaluate(single_bert_y_trues, [single_xlnet_y_preds], compute_spearmanr)
print('fini!')

original_score: 0.39658756843296844
opt_score: 0.4240074936181189


  c /= stddev[:, None]
  c /= stddev[None, :]
  return (a < x) & (x < b)
  return (a < x) & (x < b)
  cond2 = cond0 & (x <= _a)


original_score: 0.3952086168898242
opt_score: 0.417814664868746
original_score: 0.39510187276130454
opt_score: 0.42243226674838147
fini!


In [28]:
# Same per-architecture scoring, now on the snapshot-averaged predictions.
blend_and_evaluate(single_bert_y_trues, [bert_y_preds,  ], compute_spearmanr)
blend_and_evaluate(single_bert_y_trues, [roberta_y_preds, ], compute_spearmanr)
# blend_and_evaluate(single_bert_y_trues, [gpt2_y_preds, ], compute_spearmanr)
blend_and_evaluate(single_bert_y_trues, [xlnet_y_preds], compute_spearmanr)
print('fini!')

original_score: 0.40437836760234475
opt_score: 0.4374389165251154
original_score: 0.40003269950864506
opt_score: 0.4242294402147104
original_score: 0.4032963844416211
opt_score: 0.427509974207624
fini!


In [29]:
# Unweighted blend of the three snapshot-averaged models; keep the per-label
# scores and the fitted rounders for the cells below.
eval_scores, opt_eval_scores, optRs = blend_and_evaluate(single_bert_y_trues, [bert_y_preds, roberta_y_preds, xlnet_y_preds], compute_spearmanr)

original_score: 0.42022642838382596
opt_score: 0.4540915450569295


In [31]:
# Per-label scores of the raw blend as one comma-separated string.
','.join(map(str, eval_scores))

'0.3897569612538418,0.6295310634988145,0.4191564247853753,0.31915791786554304,0.3652275669567488,0.435753195775138,0.36440790824816743,0.5060018219743413,0.6093817833499298,0.09188910122874706,0.48817504051000327,0.7583513654160682,0.36981286241131683,0.1901473394678793,0.3628544498201905,0.46525049565094106,0.787145695927789,0.375529103662982,0.6885954067930553,0.06539804951059225,0.5105477365509459,0.26907477901946486,0.445162321745585,0.17527099813432837,0.19163286203159757,0.3620840848626727,0.7649243832652908,0.2946670542871855,0.6980348300532199,0.21387024745702288'

In [32]:
# Per-label scores of the threshold-rounded blend as one comma-separated string.
','.join(map(str, opt_eval_scores))

'0.39103779566763597,0.6283681318924335,0.49910816057077234,0.3159970622680784,0.3711788788282818,0.49504735535479977,0.3685814848425463,0.5216712807321839,0.615500694777839,0.13949587523846815,0.46636383952078614,0.7697994222507103,0.56836790916633,0.31154467851353373,0.6381256867882409,0.6144161161430777,0.7933784466371933,0.35402716011821234,0.6814860018204775,0.17321756032902486,0.5097562843300201,0.2692185727945837,0.44415485530018933,0.18158071367992373,0.20007801661806063,0.3636531998719249,0.7610709309331021,0.27693963269468985,0.682707120444378,0.21687348358038733'

In [37]:
# Collect the label indices where the rounded score (k) strictly beats the
# raw score (j), printing index, raw score and rounded score for each.
opt_better_idx = []
for i, (j, k) in enumerate(zip(eval_scores, opt_eval_scores)):
    if j < k:
        print(i, j, k)
        opt_better_idx.append(i)
opt_better_idx        

0 0.3897569612538418 0.39103779566763597
2 0.4191564247853753 0.49910816057077234
4 0.3652275669567488 0.3711788788282818
5 0.435753195775138 0.49504735535479977
6 0.36440790824816743 0.3685814848425463
7 0.5060018219743413 0.5216712807321839
8 0.6093817833499298 0.615500694777839
9 0.09188910122874706 0.13949587523846815
11 0.7583513654160682 0.7697994222507103
12 0.36981286241131683 0.56836790916633
13 0.1901473394678793 0.31154467851353373
14 0.3628544498201905 0.6381256867882409
15 0.46525049565094106 0.6144161161430777
16 0.787145695927789 0.7933784466371933
19 0.06539804951059225 0.17321756032902486
21 0.26907477901946486 0.2692185727945837
23 0.17527099813432837 0.18158071367992373
24 0.19163286203159757 0.20007801661806063
25 0.3620840848626727 0.3636531998719249
29 0.21387024745702288 0.21687348358038733


[0, 2, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 19, 21, 23, 24, 25, 29]

In [22]:
# Persist the rounders fitted on the three-model blend.
# NOTE(review): the directory is not created here; it presumably already exists.
with open('../mnt/inputs/pseudos/top2_e078_e079_e080_e081_e082_e083/optRs.pkl', 'wb') as fout:
    pickle.dump(optRs, fout)

####  dataset 構築

#### 重みを load する

In [8]:
# Test set that the datasets / loaders in the following cells are built from.
tst_df = pd.read_csv('../mnt/inputs/origin/test.csv')

In [9]:
# Make the project's scripts directory importable, then pull in the dataset
# class, the three classifier wrappers and the inference helper from it.
sys.path.append('../scripts/')
from refactor.datasets import QUESTDataset
from refactor.models import BertModelForBinaryMultiLabelClassifier, RobertaModelForBinaryMultiLabelClassifier, XLNetModelForBinaryMultiLabelClassifier
from refactor.utils import test

In [41]:
# Target columns predicted by the question models.
Q_LABEL_COL = [
    'question_asker_intent_understanding',
    'question_body_critical',
    'question_conversational',
    'question_expect_short_answer',
    'question_fact_seeking',
    'question_has_commonly_accepted_answer',
    'question_interestingness_others',
    'question_interestingness_self',
    'question_multi_intent',
    'question_not_really_a_question',
    'question_opinion_seeking',
    'question_type_choice',
    'question_type_compare',
    'question_type_consequence',
    'question_type_definition',
    'question_type_entity',
    'question_type_instructions',
    'question_type_procedure',
    'question_type_reason_explanation',
    'question_type_spelling',
    'question_well_written',
]

# Target columns predicted by the answer models.
A_LABEL_COL = [
    'answer_helpful',
    'answer_level_of_information',
    'answer_plausible',
    'answer_relevance',
    'answer_satisfaction',
    'answer_type_instructions',
    'answer_type_procedure',
    'answer_type_reason_explanation',
    'answer_well_written'
]


def _build_bert_test_loader(label_col, q_max_len, a_max_len):
    """
    Build the BERT-tokenized test dataset for ``label_col`` together with its
    sequential sampler and DataLoader; returns (dataset, sampler, loader).
    Only the label columns and the question / answer length budgets differ
    between the question and the answer pipeline.
    """
    dataset = QUESTDataset(
        df=tst_df,
        mode='test',
        tokens=[name.casefold() for name in (
            'CAT_TECHNOLOGY',
            'CAT_STACKOVERFLOW',
            'CAT_CULTURE',
            'CAT_SCIENCE',
            'CAT_LIFE_ARTS',
        )],
        augment=[],
        tokenizer_type='bert',
        pretrained_model_name_or_path='../mnt/checkpoints/e078/datasets/',
        do_lower_case=True,
        LABEL_COL=label_col,
        t_max_len=30,
        q_max_len=q_max_len,
        a_max_len=a_max_len,
        tqa_mode='tq_a',
        TBSEP='[TBSEP]',
        pos_id_type='arange',
        MAX_SEQUENCE_LENGTH=512,
    )
    # Sequential order keeps predictions aligned with the rows of tst_df.
    sampler = SequentialSampler(data_source=dataset)
    loader = DataLoader(
        dataset,
        batch_size=8,
        sampler=sampler,
        num_workers=os.cpu_count(),
        worker_init_fn=lambda x: np.random.seed(),
        drop_last=False,
        pin_memory=True
    )
    return dataset, sampler, loader


# Question pipeline: the whole body budget goes to the question, none to the answer.
q_test_dataset, q_test_sampler, q_test_loader = _build_bert_test_loader(
    Q_LABEL_COL, q_max_len=239 * 2, a_max_len=239 * 0)

# Answer pipeline: the budget is given to the answer instead.
a_test_dataset, a_test_sampler, a_test_loader = _build_bert_test_loader(
    A_LABEL_COL, q_max_len=239 * 0, a_max_len=239 * 2)

additional_tokens : 0
additional_tokens : 0


In [28]:
# Pickled state dicts of experiment e078 (one file per fold and rank); glob
# returns them in no particular order.
ckpts = glob('../mnt/checkpoints/e078/state_dicts/*')

['../mnt/checkpoints/e078/state_dicts/fold_4_rank_0_state_dict.pkl',
 '../mnt/checkpoints/e078/state_dicts/fold_4_rank_1_state_dict.pkl',
 '../mnt/checkpoints/e078/state_dicts/fold_3_rank_1_state_dict.pkl',
 '../mnt/checkpoints/e078/state_dicts/fold_2_rank_1_state_dict.pkl',
 '../mnt/checkpoints/e078/state_dicts/fold_0_rank_1_state_dict.pkl',
 '../mnt/checkpoints/e078/state_dicts/fold_0_rank_0_state_dict.pkl',
 '../mnt/checkpoints/e078/state_dicts/fold_1_rank_1_state_dict.pkl',
 '../mnt/checkpoints/e078/state_dicts/fold_2_rank_0_state_dict.pkl',
 '../mnt/checkpoints/e078/state_dicts/fold_3_rank_0_state_dict.pkl',
 '../mnt/checkpoints/e078/state_dicts/fold_1_rank_0_state_dict.pkl']

In [57]:
# Q models
bert_q_preds = []
for i, ckpt in enumerate(tqdm(ckpts)):
    with open(ckpt, 'rb') as fin:
        state_dict = pickle.load(fin)
    model = BertModelForBinaryMultiLabelClassifier(21, '../mnt/datasets/model_configs/bert-model-uncased-config.pkl', None, token_size=30528)
    model.load_state_dict(state_dict)
    model.to('cpu')
    model.to(DEVICE)
    _, _, _, y_preds, _, qa_ids = test(model, None, q_test_loader, DEVICE, 'test')
    bert_q_preds.append(y_preds)
    del model
    gc.collect()
res_bert_q_pred = np.mean(bert_q_preds, axis=0)
res_bert_q_pred.shape

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  This is separate from the ipykernel package so we can avoid doing imports until


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

100%|██████████| 60/60 [00:10<00:00,  5.50it/s]
100%|██████████| 60/60 [00:10<00:00,  5.49it/s]
100%|██████████| 60/60 [00:10<00:00,  5.49it/s]
100%|██████████| 60/60 [00:10<00:00,  5.51it/s]
100%|██████████| 60/60 [00:10<00:00,  5.50it/s]
100%|██████████| 60/60 [00:10<00:00,  5.51it/s]
100%|██████████| 60/60 [00:10<00:00,  5.49it/s]
100%|██████████| 60/60 [00:11<00:00,  5.44it/s]
100%|██████████| 60/60 [00:11<00:00,  5.45it/s]
100%|██████████| 60/60 [00:10<00:00,  5.47it/s]





(476, 21)

In [82]:
# A models
ckpts = glob('../mnt/checkpoints/e079/state_dicts/*')
bert_a_preds = []
for i, ckpt in enumerate(tqdm(ckpts)):
    with open(ckpt, 'rb') as fin:
        state_dict = pickle.load(fin)
    model = BertModelForBinaryMultiLabelClassifier(9, '../mnt/datasets/model_configs/bert-model-uncased-config.pkl', None, token_size=30528)
    model.load_state_dict(state_dict)
    model.to('cpu')
    model.to(DEVICE)
    _, _, _, y_preds, _, qa_ids = test(model, None, a_test_loader, DEVICE, 'test')
    bert_a_preds.append(y_preds)
    del model
    gc.collect()
res_bert_a_pred = np.mean(bert_a_preds, axis=0)
res_bert_a_pred.shape

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  after removing the cwd from sys.path.


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

RuntimeError: CUDA error: device-side assert triggered

In [78]:
# Question columns first, then answer columns (recorded shape: (476, 30)).
res_bert_pred = np.concatenate([res_bert_q_pred, res_bert_a_pred], axis=1)
res_bert_pred.shape

(476, 30)

In [80]:
# Persist the averaged BERT test predictions next to the fitted rounders.
with open('../mnt/inputs/pseudos/top2_e078_e079_e080_e081_e082_e083/res_bert_pred.pkl', 'wb') as fout:
    pickle.dump(res_bert_pred, fout)

In [10]:
# Release cached, unused GPU memory held by PyTorch's allocator before the
# next architecture is loaded.
torch.cuda.empty_cache()

## Roberta

In [11]:
# Target columns predicted by the question models.
Q_LABEL_COL = [
    'question_asker_intent_understanding',
    'question_body_critical',
    'question_conversational',
    'question_expect_short_answer',
    'question_fact_seeking',
    'question_has_commonly_accepted_answer',
    'question_interestingness_others',
    'question_interestingness_self',
    'question_multi_intent',
    'question_not_really_a_question',
    'question_opinion_seeking',
    'question_type_choice',
    'question_type_compare',
    'question_type_consequence',
    'question_type_definition',
    'question_type_entity',
    'question_type_instructions',
    'question_type_procedure',
    'question_type_reason_explanation',
    'question_type_spelling',
    'question_well_written',
]

# Target columns predicted by the answer models.
A_LABEL_COL = [
    'answer_helpful',
    'answer_level_of_information',
    'answer_plausible',
    'answer_relevance',
    'answer_satisfaction',
    'answer_type_instructions',
    'answer_type_procedure',
    'answer_type_reason_explanation',
    'answer_well_written'
]


def _build_roberta_test_loader(label_col, q_max_len, a_max_len):
    """
    Build the RoBERTa-tokenized test dataset for ``label_col`` together with
    its sequential sampler and DataLoader; returns (dataset, sampler, loader).
    Only the label columns and the question / answer length budgets differ
    between the question and the answer pipeline.
    """
    dataset = QUESTDataset(
        df=tst_df,
        mode='test',
        tokens=[name.casefold() for name in (
            'CAT_TECHNOLOGY',
            'CAT_STACKOVERFLOW',
            'CAT_CULTURE',
            'CAT_SCIENCE',
            'CAT_LIFE_ARTS',
        )],
        augment=[],
        tokenizer_type='roberta',
        pretrained_model_name_or_path='roberta-base',
        do_lower_case=False,
        LABEL_COL=label_col,
        t_max_len=30,
        q_max_len=q_max_len,
        a_max_len=a_max_len,
        tqa_mode='tq_a',
        TBSEP='[TBSEP]',
        pos_id_type='arange',
        MAX_SEQUENCE_LENGTH=512,
    )
    # Sequential order keeps predictions aligned with the rows of tst_df.
    sampler = SequentialSampler(data_source=dataset)
    loader = DataLoader(
        dataset,
        batch_size=8,
        sampler=sampler,
        num_workers=os.cpu_count(),
        worker_init_fn=lambda x: np.random.seed(),
        drop_last=False,
        pin_memory=True
    )
    return dataset, sampler, loader


# Question pipeline: the whole body budget goes to the question, none to the answer.
q_test_dataset, q_test_sampler, q_test_loader = _build_roberta_test_loader(
    Q_LABEL_COL, q_max_len=239 * 2, a_max_len=239 * 0)

# Answer pipeline: the budget is given to the answer instead.
a_test_dataset, a_test_sampler, a_test_loader = _build_roberta_test_loader(
    A_LABEL_COL, q_max_len=239 * 0, a_max_len=239 * 2)

additional_tokens : 5
additional_tokens : 5


In [12]:
# Q models
ckpts = glob('../mnt/checkpoints/e080/state_dicts/*')
roberta_q_preds = []
for i, ckpt in enumerate(tqdm(ckpts)):
    with open(ckpt, 'rb') as fin:
        state_dict = pickle.load(fin)
    model = RobertaModelForBinaryMultiLabelClassifier(21, '../mnt/datasets/model_configs/roberta-model-base-config.pkl', None, token_size=50271)
    model.load_state_dict(state_dict)
    model.to('cpu')
    model.to(DEVICE)
    _, _, _, y_preds, _, qa_ids = test(model, None, q_test_loader, DEVICE, 'test')
    roberta_q_preds.append(y_preds)
    del model
    gc.collect()
res_roberta_q_pred = np.mean(roberta_q_preds, axis=0)
res_roberta_q_pred.shape

  0%|          | 0/10 [00:00<?, ?it/s]
  0%|          | 0/60 [00:00<?, ?it/s][A
  2%|▏         | 1/60 [00:01<01:32,  1.57s/it][A
  3%|▎         | 2/60 [00:01<01:06,  1.15s/it][A
  5%|▌         | 3/60 [00:01<00:48,  1.18it/s][A
  7%|▋         | 4/60 [00:02<00:35,  1.56it/s][A
  8%|▊         | 5/60 [00:02<00:27,  2.03it/s][A
 10%|█         | 6/60 [00:02<00:21,  2.55it/s][A
 12%|█▏        | 7/60 [00:02<00:16,  3.13it/s][A
 13%|█▎        | 8/60 [00:02<00:14,  3.70it/s][A
 15%|█▌        | 9/60 [00:02<00:11,  4.26it/s][A
 17%|█▋        | 10/60 [00:02<00:10,  4.75it/s][A
 18%|█▊        | 11/60 [00:03<00:09,  5.18it/s][A
 20%|██        | 12/60 [00:03<00:08,  5.52it/s][A
 22%|██▏       | 13/60 [00:03<00:08,  5.78it/s][A
 23%|██▎       | 14/60 [00:03<00:07,  6.00it/s][A
 25%|██▌       | 15/60 [00:03<00:07,  6.11it/s][A
 27%|██▋       | 16/60 [00:03<00:07,  6.27it/s][A
 28%|██▊       | 17/60 [00:04<00:06,  6.34it/s][A
 30%|███       | 18/60 [00:04<00:06,  6.40it/s][A
 32%|███▏ 

 62%|██████▏   | 37/60 [00:06<00:03,  6.50it/s][A
 63%|██████▎   | 38/60 [00:07<00:03,  6.54it/s][A
 65%|██████▌   | 39/60 [00:07<00:03,  6.54it/s][A
 67%|██████▋   | 40/60 [00:07<00:03,  6.55it/s][A
 68%|██████▊   | 41/60 [00:07<00:02,  6.53it/s][A
 70%|███████   | 42/60 [00:07<00:02,  6.50it/s][A
 72%|███████▏  | 43/60 [00:07<00:02,  6.56it/s][A
 73%|███████▎  | 44/60 [00:07<00:02,  6.54it/s][A
 75%|███████▌  | 45/60 [00:08<00:02,  6.55it/s][A
 77%|███████▋  | 46/60 [00:08<00:02,  6.54it/s][A
 78%|███████▊  | 47/60 [00:08<00:01,  6.54it/s][A
 80%|████████  | 48/60 [00:08<00:01,  6.54it/s][A
 82%|████████▏ | 49/60 [00:08<00:01,  6.53it/s][A
 83%|████████▎ | 50/60 [00:08<00:01,  6.53it/s][A
 85%|████████▌ | 51/60 [00:09<00:01,  6.53it/s][A
 87%|████████▋ | 52/60 [00:09<00:01,  6.53it/s][A
 88%|████████▊ | 53/60 [00:09<00:01,  6.53it/s][A
 90%|█████████ | 54/60 [00:09<00:00,  6.52it/s][A
 92%|█████████▏| 55/60 [00:09<00:00,  6.53it/s][A
 93%|█████████▎| 56/60 [00:09<0

 20%|██        | 12/60 [00:03<00:08,  5.61it/s][A
 22%|██▏       | 13/60 [00:03<00:07,  5.91it/s][A
 23%|██▎       | 14/60 [00:03<00:07,  6.07it/s][A
 25%|██▌       | 15/60 [00:03<00:07,  6.22it/s][A
 27%|██▋       | 16/60 [00:03<00:06,  6.31it/s][A
 28%|██▊       | 17/60 [00:03<00:06,  6.37it/s][A
 30%|███       | 18/60 [00:03<00:06,  6.42it/s][A
 32%|███▏      | 19/60 [00:04<00:06,  6.45it/s][A
 33%|███▎      | 20/60 [00:04<00:06,  6.48it/s][A
 35%|███▌      | 21/60 [00:04<00:06,  6.45it/s][A
 37%|███▋      | 22/60 [00:04<00:05,  6.51it/s][A
 38%|███▊      | 23/60 [00:04<00:05,  6.52it/s][A
 40%|████      | 24/60 [00:04<00:05,  6.47it/s][A
 42%|████▏     | 25/60 [00:05<00:05,  6.53it/s][A
 43%|████▎     | 26/60 [00:05<00:05,  6.54it/s][A
 45%|████▌     | 27/60 [00:05<00:05,  6.54it/s][A
 47%|████▋     | 28/60 [00:05<00:04,  6.54it/s][A
 48%|████▊     | 29/60 [00:05<00:04,  6.54it/s][A
 50%|█████     | 30/60 [00:05<00:04,  6.53it/s][A
 52%|█████▏    | 31/60 [00:05<0

 82%|████████▏ | 49/60 [00:08<00:01,  6.55it/s][A
 83%|████████▎ | 50/60 [00:08<00:01,  6.55it/s][A
 85%|████████▌ | 51/60 [00:08<00:01,  6.55it/s][A
 87%|████████▋ | 52/60 [00:09<00:01,  6.48it/s][A
 88%|████████▊ | 53/60 [00:09<00:01,  6.55it/s][A
 90%|█████████ | 54/60 [00:09<00:00,  6.55it/s][A
 92%|█████████▏| 55/60 [00:09<00:00,  6.51it/s][A
 93%|█████████▎| 56/60 [00:09<00:00,  6.55it/s][A
 95%|█████████▌| 57/60 [00:09<00:00,  6.55it/s][A
 97%|█████████▋| 58/60 [00:10<00:00,  6.47it/s][A
 98%|█████████▊| 59/60 [00:10<00:00,  6.49it/s][A
100%|██████████| 60/60 [00:10<00:00,  5.72it/s][A
 80%|████████  | 8/10 [02:29<00:37, 18.65s/it]
  0%|          | 0/60 [00:00<?, ?it/s][A
  2%|▏         | 1/60 [00:01<01:18,  1.33s/it][A
  3%|▎         | 2/60 [00:01<00:56,  1.02it/s][A
  5%|▌         | 3/60 [00:01<00:41,  1.37it/s][A
  7%|▋         | 4/60 [00:01<00:31,  1.79it/s][A
  8%|▊         | 5/60 [00:01<00:24,  2.29it/s][A
 10%|█         | 6/60 [00:02<00:18,  2.84it/s][A

(476, 21)

In [14]:
# A models: run every RoBERTa answer-side checkpoint (e081) over the test
# loader and average the per-checkpoint predictions.
roberta_a_preds = []
ckpts = glob('../mnt/checkpoints/e081/state_dicts/*')
for i, ckpt in enumerate(tqdm(ckpts)):
    # Each matched file is unpickled as one state dict.
    with open(ckpt, 'rb') as fin:
        state_dict = pickle.load(fin)

    # First argument 9 matches the 9 answer columns of the output below.
    model = RobertaModelForBinaryMultiLabelClassifier(
        9,
        '../mnt/datasets/model_configs/roberta-model-base-config.pkl',
        None,
        token_size=50271,
    )
    model.load_state_dict(state_dict)
    model.to('cpu')
    model.to(DEVICE)

    # Of the six values returned by `test`, only the 4th (predictions) and the
    # 6th (qa ids) are kept.
    _, _, _, y_preds, _, qa_ids = test(model, None, a_test_loader, DEVICE, 'test')
    roberta_a_preds.append(y_preds)

    # Drop the model before the next checkpoint is instantiated.
    del model
    gc.collect()

# Element-wise mean over the checkpoint axis.
res_roberta_a_pred = np.asarray(roberta_a_preds).mean(axis=0)
res_roberta_a_pred.shape

  0%|          | 0/10 [00:00<?, ?it/s]
  0%|          | 0/60 [00:00<?, ?it/s][A
  2%|▏         | 1/60 [00:01<01:26,  1.46s/it][A
  3%|▎         | 2/60 [00:01<01:01,  1.07s/it][A
  5%|▌         | 3/60 [00:01<00:45,  1.26it/s][A
  7%|▋         | 4/60 [00:01<00:33,  1.66it/s][A
  8%|▊         | 5/60 [00:02<00:25,  2.14it/s][A
 10%|█         | 6/60 [00:02<00:20,  2.68it/s][A
 12%|█▏        | 7/60 [00:02<00:16,  3.27it/s][A
 13%|█▎        | 8/60 [00:02<00:13,  3.84it/s][A
 15%|█▌        | 9/60 [00:02<00:11,  4.38it/s][A
 17%|█▋        | 10/60 [00:02<00:10,  4.86it/s][A
 18%|█▊        | 11/60 [00:02<00:09,  5.27it/s][A
 20%|██        | 12/60 [00:03<00:08,  5.59it/s][A
 22%|██▏       | 13/60 [00:03<00:08,  5.81it/s][A
 23%|██▎       | 14/60 [00:03<00:07,  6.04it/s][A
 25%|██▌       | 15/60 [00:03<00:07,  6.18it/s][A
 27%|██▋       | 16/60 [00:03<00:07,  6.27it/s][A
 28%|██▊       | 17/60 [00:03<00:06,  6.36it/s][A
 30%|███       | 18/60 [00:04<00:06,  6.40it/s][A
 32%|███▏ 

 62%|██████▏   | 37/60 [00:07<00:03,  6.53it/s][A
 63%|██████▎   | 38/60 [00:07<00:03,  6.53it/s][A
 65%|██████▌   | 39/60 [00:07<00:03,  6.53it/s][A
 67%|██████▋   | 40/60 [00:07<00:03,  6.49it/s][A
 68%|██████▊   | 41/60 [00:07<00:02,  6.55it/s][A
 70%|███████   | 42/60 [00:07<00:02,  6.54it/s][A
 72%|███████▏  | 43/60 [00:07<00:02,  6.54it/s][A
 73%|███████▎  | 44/60 [00:08<00:02,  6.54it/s][A
 75%|███████▌  | 45/60 [00:08<00:02,  6.54it/s][A
 77%|███████▋  | 46/60 [00:08<00:02,  6.53it/s][A
 78%|███████▊  | 47/60 [00:08<00:01,  6.53it/s][A
 80%|████████  | 48/60 [00:08<00:01,  6.53it/s][A
 82%|████████▏ | 49/60 [00:08<00:01,  6.54it/s][A
 83%|████████▎ | 50/60 [00:09<00:01,  6.53it/s][A
 85%|████████▌ | 51/60 [00:09<00:01,  6.54it/s][A
 87%|████████▋ | 52/60 [00:09<00:01,  6.52it/s][A
 88%|████████▊ | 53/60 [00:09<00:01,  6.53it/s][A
 90%|█████████ | 54/60 [00:09<00:00,  6.52it/s][A
 92%|█████████▏| 55/60 [00:09<00:00,  6.53it/s][A
 93%|█████████▎| 56/60 [00:09<0

 20%|██        | 12/60 [00:03<00:08,  5.52it/s][A
 22%|██▏       | 13/60 [00:03<00:08,  5.79it/s][A
 23%|██▎       | 14/60 [00:03<00:07,  6.00it/s][A
 25%|██▌       | 15/60 [00:03<00:07,  6.15it/s][A
 27%|██▋       | 16/60 [00:03<00:07,  6.26it/s][A
 28%|██▊       | 17/60 [00:04<00:06,  6.33it/s][A
 30%|███       | 18/60 [00:04<00:06,  6.39it/s][A
 32%|███▏      | 19/60 [00:04<00:06,  6.44it/s][A
 33%|███▎      | 20/60 [00:04<00:06,  6.46it/s][A
 35%|███▌      | 21/60 [00:04<00:06,  6.49it/s][A
 37%|███▋      | 22/60 [00:04<00:05,  6.50it/s][A
 38%|███▊      | 23/60 [00:04<00:05,  6.45it/s][A
 40%|████      | 24/60 [00:05<00:05,  6.50it/s][A
 42%|████▏     | 25/60 [00:05<00:05,  6.47it/s][A
 43%|████▎     | 26/60 [00:05<00:05,  6.49it/s][A
 45%|████▌     | 27/60 [00:05<00:05,  6.50it/s][A
 47%|████▋     | 28/60 [00:05<00:04,  6.53it/s][A
 48%|████▊     | 29/60 [00:05<00:04,  6.51it/s][A
 50%|█████     | 30/60 [00:05<00:04,  6.57it/s][A
 52%|█████▏    | 31/60 [00:06<0

 82%|████████▏ | 49/60 [00:08<00:01,  6.56it/s][A
 83%|████████▎ | 50/60 [00:09<00:01,  6.55it/s][A
 85%|████████▌ | 51/60 [00:09<00:01,  6.54it/s][A
 87%|████████▋ | 52/60 [00:09<00:01,  6.53it/s][A
 88%|████████▊ | 53/60 [00:09<00:01,  6.53it/s][A
 90%|█████████ | 54/60 [00:09<00:00,  6.54it/s][A
 92%|█████████▏| 55/60 [00:09<00:00,  6.49it/s][A
 93%|█████████▎| 56/60 [00:09<00:00,  6.54it/s][A
 95%|█████████▌| 57/60 [00:10<00:00,  6.50it/s][A
 97%|█████████▋| 58/60 [00:10<00:00,  6.54it/s][A
 98%|█████████▊| 59/60 [00:10<00:00,  6.55it/s][A
100%|██████████| 60/60 [00:10<00:00,  5.61it/s][A
 80%|████████  | 8/10 [02:28<00:37, 18.71s/it]
  0%|          | 0/60 [00:00<?, ?it/s][A
  2%|▏         | 1/60 [00:01<01:30,  1.54s/it][A
  3%|▎         | 2/60 [00:01<01:05,  1.12s/it][A
  5%|▌         | 3/60 [00:01<00:47,  1.21it/s][A
  7%|▋         | 4/60 [00:01<00:35,  1.60it/s][A
  8%|▊         | 5/60 [00:02<00:26,  2.06it/s][A
 10%|█         | 6/60 [00:02<00:20,  2.59it/s][A

(476, 9)

In [15]:
# Place the 21 question columns and the 9 answer columns side by side.
res_roberta_pred = np.hstack((res_roberta_q_pred, res_roberta_a_pred))
res_roberta_pred.shape

(476, 30)

In [16]:
# Cache the combined RoBERTa predictions so the blending section can reload them.
with open('../mnt/inputs/pseudos/top2_e078_e079_e080_e081_e082_e083/res_roberta_pred.pkl', 'wb') as fout:
    fout.write(pickle.dumps(res_roberta_pred))

## XLNET

In [17]:
# Targets predicted by the question-side models (21 columns).
Q_LABEL_COL = [
    'question_asker_intent_understanding',
    'question_body_critical',
    'question_conversational',
    'question_expect_short_answer',
    'question_fact_seeking',
    'question_has_commonly_accepted_answer',
    'question_interestingness_others',
    'question_interestingness_self',
    'question_multi_intent',
    'question_not_really_a_question',
    'question_opinion_seeking',
    'question_type_choice',
    'question_type_compare',
    'question_type_consequence',
    'question_type_definition',
    'question_type_entity',
    'question_type_instructions',
    'question_type_procedure',
    'question_type_reason_explanation',
    'question_type_spelling',
    'question_well_written',
]

# Targets predicted by the answer-side models (9 columns).
A_LABEL_COL = [
    'answer_helpful',
    'answer_level_of_information',
    'answer_plausible',
    'answer_relevance',
    'answer_satisfaction',
    'answer_type_instructions',
    'answer_type_procedure',
    'answer_type_reason_explanation',
    'answer_well_written'
]


def _build_xlnet_test_dataset(label_col, q_max_len, a_max_len):
    """Return a test-mode XLNet QUESTDataset over the notebook-level `tst_df`.

    Only the label columns and the question/answer length budgets differ
    between the two datasets built below; every other setting is shared.
    Fresh `tokens` / `augment` lists are created on every call so the two
    datasets never share a list object.
    """
    category_tokens = [
        name.casefold()
        for name in (
            'CAT_TECHNOLOGY',
            'CAT_STACKOVERFLOW',
            'CAT_CULTURE',
            'CAT_SCIENCE',
            'CAT_LIFE_ARTS',
        )
    ]
    return QUESTDataset(
        df=tst_df,
        mode='test',
        tokens=category_tokens,
        augment=[],
        tokenizer_type='xlnet',
        pretrained_model_name_or_path='../mnt/checkpoints/e082/datasets/',
        do_lower_case=False,
        LABEL_COL=label_col,
        t_max_len=30,
        q_max_len=q_max_len,
        a_max_len=a_max_len,
        tqa_mode='tq_a',
        TBSEP='[TBSEP]',
        pos_id_type='arange',
        MAX_SEQUENCE_LENGTH=512,
    )


def _build_sequential_loader(dataset):
    """Return `(sampler, loader)` that walk `dataset` in order, 8 rows per batch."""
    sampler = SequentialSampler(data_source=dataset)
    loader = DataLoader(
        dataset,
        batch_size=8,
        sampler=sampler,
        num_workers=os.cpu_count(),
        worker_init_fn=lambda x: np.random.seed(),
        drop_last=False,
        pin_memory=True,
    )
    return sampler, loader


# Question side: the whole body budget goes to the question text, none to the answer.
q_test_dataset = _build_xlnet_test_dataset(Q_LABEL_COL, q_max_len=239 * 2, a_max_len=239 * 0)
q_test_sampler, q_test_loader = _build_sequential_loader(q_test_dataset)

# Answer side: the whole body budget goes to the answer text, none to the question.
a_test_dataset = _build_xlnet_test_dataset(A_LABEL_COL, q_max_len=239 * 0, a_max_len=239 * 2)
a_test_sampler, a_test_loader = _build_sequential_loader(a_test_dataset)

additional_tokens : 0
additional_tokens : 0


In [18]:
# Q models: run every XLNet question-side checkpoint (e082) over the test
# loader and average the per-checkpoint predictions.
xlnet_q_preds = []
ckpts = glob('../mnt/checkpoints/e082/state_dicts/*')
for i, ckpt in enumerate(tqdm(ckpts)):
    # Each matched file is unpickled as one state dict.
    with open(ckpt, 'rb') as fin:
        state_dict = pickle.load(fin)

    # First argument 21 matches the 21 question columns of the output below.
    model = XLNetModelForBinaryMultiLabelClassifier(
        21,
        '../mnt/datasets/model_configs/xlnet-model-base-cased-config.pkl',
        None,
        token_size=32006,
    )
    model.load_state_dict(state_dict)
    model.to('cpu')
    model.to(DEVICE)

    # Of the six values returned by `test`, only the 4th (predictions) and the
    # 6th (qa ids) are kept.
    _, _, _, y_preds, _, qa_ids = test(model, None, q_test_loader, DEVICE, 'test')
    xlnet_q_preds.append(y_preds)

    # Drop the model before the next checkpoint is instantiated.
    del model
    gc.collect()

# Element-wise mean over the checkpoint axis.
res_xlnet_q_pred = np.asarray(xlnet_q_preds).mean(axis=0)
res_xlnet_q_pred.shape

  0%|          | 0/10 [00:00<?, ?it/s]
  0%|          | 0/60 [00:00<?, ?it/s][A
  2%|▏         | 1/60 [00:02<02:05,  2.12s/it][A
  3%|▎         | 2/60 [00:02<01:32,  1.60s/it][A
  5%|▌         | 3/60 [00:02<01:10,  1.23s/it][A
  7%|▋         | 4/60 [00:03<00:54,  1.02it/s][A
  8%|▊         | 5/60 [00:03<00:43,  1.26it/s][A
 10%|█         | 6/60 [00:04<00:36,  1.49it/s][A
 12%|█▏        | 7/60 [00:04<00:30,  1.72it/s][A
 13%|█▎        | 8/60 [00:04<00:27,  1.92it/s][A
 15%|█▌        | 9/60 [00:05<00:24,  2.09it/s][A
 17%|█▋        | 10/60 [00:05<00:22,  2.24it/s][A
 18%|█▊        | 11/60 [00:05<00:20,  2.35it/s][A
 20%|██        | 12/60 [00:06<00:19,  2.43it/s][A
 22%|██▏       | 13/60 [00:06<00:18,  2.49it/s][A
 23%|██▎       | 14/60 [00:07<00:18,  2.54it/s][A
 25%|██▌       | 15/60 [00:07<00:17,  2.57it/s][A
 27%|██▋       | 16/60 [00:07<00:16,  2.60it/s][A
 28%|██▊       | 17/60 [00:08<00:16,  2.61it/s][A
 30%|███       | 18/60 [00:08<00:15,  2.63it/s][A
 32%|███▏ 

 62%|██████▏   | 37/60 [00:15<00:08,  2.66it/s][A
 63%|██████▎   | 38/60 [00:15<00:08,  2.66it/s][A
 65%|██████▌   | 39/60 [00:16<00:07,  2.66it/s][A
 67%|██████▋   | 40/60 [00:16<00:07,  2.66it/s][A
 68%|██████▊   | 41/60 [00:16<00:07,  2.65it/s][A
 70%|███████   | 42/60 [00:17<00:06,  2.65it/s][A
 72%|███████▏  | 43/60 [00:17<00:06,  2.65it/s][A
 73%|███████▎  | 44/60 [00:17<00:06,  2.65it/s][A
 75%|███████▌  | 45/60 [00:18<00:05,  2.65it/s][A
 77%|███████▋  | 46/60 [00:18<00:05,  2.65it/s][A
 78%|███████▊  | 47/60 [00:19<00:04,  2.65it/s][A
 80%|████████  | 48/60 [00:19<00:04,  2.65it/s][A
 82%|████████▏ | 49/60 [00:19<00:04,  2.65it/s][A
 83%|████████▎ | 50/60 [00:20<00:03,  2.65it/s][A
 85%|████████▌ | 51/60 [00:20<00:03,  2.65it/s][A
 87%|████████▋ | 52/60 [00:21<00:03,  2.65it/s][A
 88%|████████▊ | 53/60 [00:21<00:02,  2.66it/s][A
 90%|█████████ | 54/60 [00:21<00:02,  2.66it/s][A
 92%|█████████▏| 55/60 [00:22<00:01,  2.66it/s][A
 93%|█████████▎| 56/60 [00:22<0

 20%|██        | 12/60 [00:06<00:19,  2.46it/s][A
 22%|██▏       | 13/60 [00:06<00:18,  2.51it/s][A
 23%|██▎       | 14/60 [00:06<00:18,  2.55it/s][A
 25%|██▌       | 15/60 [00:07<00:17,  2.58it/s][A
 27%|██▋       | 16/60 [00:07<00:16,  2.60it/s][A
 28%|██▊       | 17/60 [00:07<00:16,  2.62it/s][A
 30%|███       | 18/60 [00:08<00:15,  2.63it/s][A
 32%|███▏      | 19/60 [00:08<00:15,  2.64it/s][A
 33%|███▎      | 20/60 [00:09<00:15,  2.64it/s][A
 35%|███▌      | 21/60 [00:09<00:14,  2.65it/s][A
 37%|███▋      | 22/60 [00:09<00:14,  2.65it/s][A
 38%|███▊      | 23/60 [00:10<00:13,  2.65it/s][A
 40%|████      | 24/60 [00:10<00:13,  2.65it/s][A
 42%|████▏     | 25/60 [00:10<00:13,  2.65it/s][A
 43%|████▎     | 26/60 [00:11<00:12,  2.65it/s][A
 45%|████▌     | 27/60 [00:11<00:12,  2.65it/s][A
 47%|████▋     | 28/60 [00:12<00:12,  2.65it/s][A
 48%|████▊     | 29/60 [00:12<00:11,  2.65it/s][A
 50%|█████     | 30/60 [00:12<00:11,  2.65it/s][A
 52%|█████▏    | 31/60 [00:13<0

 82%|████████▏ | 49/60 [00:19<00:04,  2.66it/s][A
 83%|████████▎ | 50/60 [00:20<00:03,  2.65it/s][A
 85%|████████▌ | 51/60 [00:20<00:03,  2.65it/s][A
 87%|████████▋ | 52/60 [00:20<00:03,  2.65it/s][A
 88%|████████▊ | 53/60 [00:21<00:02,  2.65it/s][A
 90%|█████████ | 54/60 [00:21<00:02,  2.65it/s][A
 92%|█████████▏| 55/60 [00:22<00:01,  2.65it/s][A
 93%|█████████▎| 56/60 [00:22<00:01,  2.65it/s][A
 95%|█████████▌| 57/60 [00:22<00:01,  2.66it/s][A
 97%|█████████▋| 58/60 [00:23<00:00,  2.66it/s][A
 98%|█████████▊| 59/60 [00:23<00:00,  2.66it/s][A
100%|██████████| 60/60 [00:23<00:00,  2.50it/s][A
 80%|████████  | 8/10 [03:34<00:53, 26.78s/it]
  0%|          | 0/60 [00:00<?, ?it/s][A
  2%|▏         | 1/60 [00:01<01:47,  1.82s/it][A
  3%|▎         | 2/60 [00:02<01:20,  1.39s/it][A
  5%|▌         | 3/60 [00:02<01:01,  1.08s/it][A
  7%|▋         | 4/60 [00:02<00:48,  1.15it/s][A
  8%|▊         | 5/60 [00:03<00:39,  1.38it/s][A
 10%|█         | 6/60 [00:03<00:33,  1.61it/s][A

(476, 21)

In [19]:
# A models: run every XLNet answer-side checkpoint (e083) over the test
# loader and average the per-checkpoint predictions.
xlnet_a_preds = []
ckpts = glob('../mnt/checkpoints/e083/state_dicts/*')
for i, ckpt in enumerate(tqdm(ckpts)):
    # Each matched file is unpickled as one state dict.
    with open(ckpt, 'rb') as fin:
        state_dict = pickle.load(fin)

    # First argument 9 matches the 9 answer columns of the output below.
    model = XLNetModelForBinaryMultiLabelClassifier(
        9,
        '../mnt/datasets/model_configs/xlnet-model-base-cased-config.pkl',
        None,
        token_size=32006,
    )
    model.load_state_dict(state_dict)
    model.to('cpu')
    model.to(DEVICE)

    # Of the six values returned by `test`, only the 4th (predictions) and the
    # 6th (qa ids) are kept.
    _, _, _, y_preds, _, qa_ids = test(model, None, a_test_loader, DEVICE, 'test')
    xlnet_a_preds.append(y_preds)

    # Drop the model before the next checkpoint is instantiated.
    del model
    gc.collect()

# Element-wise mean over the checkpoint axis.
res_xlnet_a_pred = np.asarray(xlnet_a_preds).mean(axis=0)
res_xlnet_a_pred.shape

  0%|          | 0/10 [00:00<?, ?it/s]
  0%|          | 0/60 [00:00<?, ?it/s][A
  2%|▏         | 1/60 [00:01<01:44,  1.78s/it][A
  3%|▎         | 2/60 [00:02<01:18,  1.36s/it][A
  5%|▌         | 3/60 [00:02<01:00,  1.06s/it][A
  7%|▋         | 4/60 [00:02<00:48,  1.17it/s][A
  8%|▊         | 5/60 [00:03<00:39,  1.40it/s][A
 10%|█         | 6/60 [00:03<00:33,  1.63it/s][A
 12%|█▏        | 7/60 [00:04<00:28,  1.84it/s][A
 13%|█▎        | 8/60 [00:04<00:25,  2.03it/s][A
 15%|█▌        | 9/60 [00:04<00:23,  2.18it/s][A
 17%|█▋        | 10/60 [00:05<00:21,  2.31it/s][A
 18%|█▊        | 11/60 [00:05<00:20,  2.40it/s][A
 20%|██        | 12/60 [00:05<00:19,  2.47it/s][A
 22%|██▏       | 13/60 [00:06<00:18,  2.52it/s][A
 23%|██▎       | 14/60 [00:06<00:17,  2.56it/s][A
 25%|██▌       | 15/60 [00:07<00:17,  2.58it/s][A
 27%|██▋       | 16/60 [00:07<00:16,  2.61it/s][A
 28%|██▊       | 17/60 [00:07<00:16,  2.62it/s][A
 30%|███       | 18/60 [00:08<00:15,  2.63it/s][A
 32%|███▏ 

 62%|██████▏   | 37/60 [00:15<00:08,  2.65it/s][A
 63%|██████▎   | 38/60 [00:15<00:08,  2.65it/s][A
 65%|██████▌   | 39/60 [00:16<00:07,  2.65it/s][A
 67%|██████▋   | 40/60 [00:16<00:07,  2.65it/s][A
 68%|██████▊   | 41/60 [00:16<00:07,  2.65it/s][A
 70%|███████   | 42/60 [00:17<00:06,  2.66it/s][A
 72%|███████▏  | 43/60 [00:17<00:06,  2.66it/s][A
 73%|███████▎  | 44/60 [00:18<00:06,  2.66it/s][A
 75%|███████▌  | 45/60 [00:18<00:05,  2.66it/s][A
 77%|███████▋  | 46/60 [00:18<00:05,  2.66it/s][A
 78%|███████▊  | 47/60 [00:19<00:04,  2.66it/s][A
 80%|████████  | 48/60 [00:19<00:04,  2.66it/s][A
 82%|████████▏ | 49/60 [00:19<00:04,  2.66it/s][A
 83%|████████▎ | 50/60 [00:20<00:03,  2.66it/s][A
 85%|████████▌ | 51/60 [00:20<00:03,  2.66it/s][A
 87%|████████▋ | 52/60 [00:21<00:03,  2.65it/s][A
 88%|████████▊ | 53/60 [00:21<00:02,  2.66it/s][A
 90%|█████████ | 54/60 [00:21<00:02,  2.66it/s][A
 92%|█████████▏| 55/60 [00:22<00:01,  2.66it/s][A
 93%|█████████▎| 56/60 [00:22<0

 20%|██        | 12/60 [00:06<00:19,  2.46it/s][A
 22%|██▏       | 13/60 [00:06<00:18,  2.51it/s][A
 23%|██▎       | 14/60 [00:06<00:18,  2.55it/s][A
 25%|██▌       | 15/60 [00:07<00:17,  2.58it/s][A
 27%|██▋       | 16/60 [00:07<00:16,  2.60it/s][A
 28%|██▊       | 17/60 [00:07<00:16,  2.62it/s][A
 30%|███       | 18/60 [00:08<00:15,  2.63it/s][A
 32%|███▏      | 19/60 [00:08<00:15,  2.63it/s][A
 33%|███▎      | 20/60 [00:09<00:15,  2.64it/s][A
 35%|███▌      | 21/60 [00:09<00:14,  2.64it/s][A
 37%|███▋      | 22/60 [00:09<00:14,  2.64it/s][A
 38%|███▊      | 23/60 [00:10<00:13,  2.64it/s][A
 40%|████      | 24/60 [00:10<00:13,  2.65it/s][A
 42%|████▏     | 25/60 [00:10<00:13,  2.65it/s][A
 43%|████▎     | 26/60 [00:11<00:12,  2.65it/s][A
 45%|████▌     | 27/60 [00:11<00:12,  2.65it/s][A
 47%|████▋     | 28/60 [00:12<00:12,  2.65it/s][A
 48%|████▊     | 29/60 [00:12<00:11,  2.65it/s][A
 50%|█████     | 30/60 [00:12<00:11,  2.65it/s][A
 52%|█████▏    | 31/60 [00:13<0

 82%|████████▏ | 49/60 [00:19<00:04,  2.65it/s][A
 83%|████████▎ | 50/60 [00:20<00:03,  2.65it/s][A
 85%|████████▌ | 51/60 [00:20<00:03,  2.65it/s][A
 87%|████████▋ | 52/60 [00:21<00:03,  2.65it/s][A
 88%|████████▊ | 53/60 [00:21<00:02,  2.65it/s][A
 90%|█████████ | 54/60 [00:21<00:02,  2.65it/s][A
 92%|█████████▏| 55/60 [00:22<00:01,  2.66it/s][A
 93%|█████████▎| 56/60 [00:22<00:01,  2.66it/s][A
 95%|█████████▌| 57/60 [00:22<00:01,  2.66it/s][A
 97%|█████████▋| 58/60 [00:23<00:00,  2.66it/s][A
 98%|█████████▊| 59/60 [00:23<00:00,  2.66it/s][A
100%|██████████| 60/60 [00:24<00:00,  2.50it/s][A
 80%|████████  | 8/10 [03:33<00:53, 26.75s/it]
  0%|          | 0/60 [00:00<?, ?it/s][A
  2%|▏         | 1/60 [00:01<01:45,  1.79s/it][A
  3%|▎         | 2/60 [00:02<01:19,  1.37s/it][A
  5%|▌         | 3/60 [00:02<01:00,  1.07s/it][A
  7%|▋         | 4/60 [00:02<00:48,  1.16it/s][A
  8%|▊         | 5/60 [00:03<00:39,  1.40it/s][A
 10%|█         | 6/60 [00:03<00:33,  1.63it/s][A

(476, 9)

In [20]:
# Place the 21 question columns and the 9 answer columns side by side.
res_xlnet_pred = np.hstack((res_xlnet_q_pred, res_xlnet_a_pred))
res_xlnet_pred.shape

(476, 30)

In [21]:
# Cache the combined XLNet predictions so the blending section can reload them.
with open('../mnt/inputs/pseudos/top2_e078_e079_e080_e081_e082_e083/res_xlnet_pred.pkl', 'wb') as fout:
    fout.write(pickle.dumps(res_xlnet_pred))

## 最後に全部をブレンド

In [16]:
# The 21 question-level targets, in prediction-column order.
_QUESTION_TARGETS = (
    'question_asker_intent_understanding',
    'question_body_critical',
    'question_conversational',
    'question_expect_short_answer',
    'question_fact_seeking',
    'question_has_commonly_accepted_answer',
    'question_interestingness_others',
    'question_interestingness_self',
    'question_multi_intent',
    'question_not_really_a_question',
    'question_opinion_seeking',
    'question_type_choice',
    'question_type_compare',
    'question_type_consequence',
    'question_type_definition',
    'question_type_entity',
    'question_type_instructions',
    'question_type_procedure',
    'question_type_reason_explanation',
    'question_type_spelling',
    'question_well_written',
)

# The 9 answer-level targets, in prediction-column order.
_ANSWER_TARGETS = (
    'answer_helpful',
    'answer_level_of_information',
    'answer_plausible',
    'answer_relevance',
    'answer_satisfaction',
    'answer_type_instructions',
    'answer_type_procedure',
    'answer_type_reason_explanation',
    'answer_well_written',
)

# All 30 targets: question columns first, then answer columns.
LABEL_COL = [*_QUESTION_TARGETS, *_ANSWER_TARGETS]

In [8]:
def _read_pickle(path):
    """Unpickle and return the single object stored in the file at `path`."""
    with open(path, 'rb') as fin:
        return pickle.load(fin)


# Reload the cached per-architecture prediction arrays.
res_bert_pred = _read_pickle('../mnt/inputs/pseudos/top2_e078_e079_e080_e081_e082_e083/res_bert_pred.pkl')
res_roberta_pred = _read_pickle('../mnt/inputs/pseudos/top2_e078_e079_e080_e081_e082_e083/res_roberta_pred.pkl')
res_xlnet_pred = _read_pickle('../mnt/inputs/pseudos/top2_e078_e079_e080_e081_e082_e083/res_xlnet_pred.pkl')

In [10]:
# Reload the pickled rounders; they are indexed by label position further down.
with open('../mnt/inputs/pseudos/top2_e078_e079_e080_e081_e082_e083/optRs.pkl', 'rb') as fin:
    optRs = pickle.loads(fin.read())

In [9]:
# Blend: unweighted element-wise mean of the three architectures' predictions.
res_pred = np.asarray([res_bert_pred, res_roberta_pred, res_xlnet_pred]).mean(axis=0)
res_pred.shape

(476, 30)

In [19]:
# Load the original test set; the following cells add one prediction column per label.
tst_df = pd.read_csv(filepath_or_buffer='../mnt/inputs/origin/test.csv')
tst_df.head(n=5)

Unnamed: 0,qa_id,question_title,question_body,question_user_name,question_user_page,answer,answer_user_name,answer_user_page,url,category,host
0,39,Will leaving corpses lying around upset my pri...,I see questions/information online about how t...,Dylan,https://gaming.stackexchange.com/users/64471,There is no consequence for leaving corpses an...,Nelson868,https://gaming.stackexchange.com/users/97324,http://gaming.stackexchange.com/questions/1979...,CULTURE,gaming.stackexchange.com
1,46,Url link to feature image in the portfolio,I am new to Wordpress. i have issue with Featu...,Anu,https://wordpress.stackexchange.com/users/72927,I think it is possible with custom fields.\n\n...,Irina,https://wordpress.stackexchange.com/users/27233,http://wordpress.stackexchange.com/questions/1...,TECHNOLOGY,wordpress.stackexchange.com
2,70,"Is accuracy, recoil or bullet spread affected ...","To experiment I started a bot game, toggled in...",Konsta,https://gaming.stackexchange.com/users/37545,You do not have armour in the screenshots. Thi...,Damon Smithies,https://gaming.stackexchange.com/users/70641,http://gaming.stackexchange.com/questions/2154...,CULTURE,gaming.stackexchange.com
3,132,Suddenly got an I/O error from my external HDD,I have used my Raspberry Pi as a torrent-serve...,robbannn,https://raspberrypi.stackexchange.com/users/17341,Your Western Digital hard drive is disappearin...,HeatfanJohn,https://raspberrypi.stackexchange.com/users/1311,http://raspberrypi.stackexchange.com/questions...,TECHNOLOGY,raspberrypi.stackexchange.com
4,200,Passenger Name - Flight Booking Passenger only...,I have bought Delhi-London return flights for ...,Amit,https://travel.stackexchange.com/users/29089,I called two persons who work for Saudia (tick...,Nean Der Thal,https://travel.stackexchange.com/users/10051,http://travel.stackexchange.com/questions/4704...,CULTURE,travel.stackexchange.com


In [20]:
# Add one column per target to tst_df: column i of the blended `res_pred`
# is stored under the name LABEL_COL[i].
for i, col in enumerate(LABEL_COL):
    tst_df[col] = res_pred[:, i]

In [21]:
# Persist the test rows plus the raw blended predictions, without the index column.
tst_df.to_csv(
    path_or_buf='../mnt/inputs/pseudos/top2_e078_e079_e080_e081_e082_e083/raw_pseudo_tst_df.csv',
    index=False,
)

## 全 opt

In [12]:
# Pass every one of the 30 blended columns through the rounder stored at the
# same index in `optRs`.
final_prediction = []
for i in tqdm(list(range(30))):
    y_pred = res_pred[:, i]

    optR = optRs[i]
    res = optR.predict(y_pred, optR.coefficients()).astype(float)

    final_prediction.append(res)

# One entry per label was collected; transpose so labels become columns.
prediction = np.transpose(np.asarray(final_prediction))
prediction.shape

100%|██████████| 30/30 [00:00<00:00, 928.24it/s]


In [15]:
# Re-read the original test set so this variant starts without earlier prediction columns.
tst_df = pd.read_csv(filepath_or_buffer='../mnt/inputs/origin/test.csv')
tst_df.head(n=5)

Unnamed: 0,qa_id,question_title,question_body,question_user_name,question_user_page,answer,answer_user_name,answer_user_page,url,category,host
0,39,Will leaving corpses lying around upset my pri...,I see questions/information online about how t...,Dylan,https://gaming.stackexchange.com/users/64471,There is no consequence for leaving corpses an...,Nelson868,https://gaming.stackexchange.com/users/97324,http://gaming.stackexchange.com/questions/1979...,CULTURE,gaming.stackexchange.com
1,46,Url link to feature image in the portfolio,I am new to Wordpress. i have issue with Featu...,Anu,https://wordpress.stackexchange.com/users/72927,I think it is possible with custom fields.\n\n...,Irina,https://wordpress.stackexchange.com/users/27233,http://wordpress.stackexchange.com/questions/1...,TECHNOLOGY,wordpress.stackexchange.com
2,70,"Is accuracy, recoil or bullet spread affected ...","To experiment I started a bot game, toggled in...",Konsta,https://gaming.stackexchange.com/users/37545,You do not have armour in the screenshots. Thi...,Damon Smithies,https://gaming.stackexchange.com/users/70641,http://gaming.stackexchange.com/questions/2154...,CULTURE,gaming.stackexchange.com
3,132,Suddenly got an I/O error from my external HDD,I have used my Raspberry Pi as a torrent-serve...,robbannn,https://raspberrypi.stackexchange.com/users/17341,Your Western Digital hard drive is disappearin...,HeatfanJohn,https://raspberrypi.stackexchange.com/users/1311,http://raspberrypi.stackexchange.com/questions...,TECHNOLOGY,raspberrypi.stackexchange.com
4,200,Passenger Name - Flight Booking Passenger only...,I have bought Delhi-London return flights for ...,Amit,https://travel.stackexchange.com/users/29089,I called two persons who work for Saudia (tick...,Nean Der Thal,https://travel.stackexchange.com/users/10051,http://travel.stackexchange.com/questions/4704...,CULTURE,travel.stackexchange.com


In [17]:
# Add one column per target to tst_df: column i of `prediction` (the output of
# the rounding cell above) is stored under the name LABEL_COL[i].
for i, col in enumerate(LABEL_COL):
    tst_df[col] = prediction[:, i]

In [18]:
# Persist the test rows plus the fully rounded predictions, without the index column.
tst_df.to_csv(
    path_or_buf='../mnt/inputs/pseudos/top2_e078_e079_e080_e081_e082_e083/opt_pseudo_tst_df.csv',
    index=False,
)

## half opt

In [38]:
# Only the label indices listed below go through their rounder; every other
# column keeps its raw blended values.
final_prediction = []
for i in tqdm(list(range(30))):
    y_pred = res_pred[:, i]
    if i in [0, 2, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 19, 21, 23, 24, 25, 29]:
        optR = optRs[i]
        res = optR.predict(y_pred, optR.coefficients()).astype(float)
        final_prediction.append(res)
    else:
        final_prediction.append(y_pred)

# One entry per label was collected; transpose so labels become columns.
prediction = np.transpose(np.asarray(final_prediction))
prediction.shape

100%|██████████| 30/30 [00:00<00:00, 1670.84it/s]


(476, 30)

In [39]:
# Re-read the original test set so this variant starts without earlier prediction columns.
tst_df = pd.read_csv(filepath_or_buffer='../mnt/inputs/origin/test.csv')
tst_df.head(n=5)

Unnamed: 0,qa_id,question_title,question_body,question_user_name,question_user_page,answer,answer_user_name,answer_user_page,url,category,host
0,39,Will leaving corpses lying around upset my pri...,I see questions/information online about how t...,Dylan,https://gaming.stackexchange.com/users/64471,There is no consequence for leaving corpses an...,Nelson868,https://gaming.stackexchange.com/users/97324,http://gaming.stackexchange.com/questions/1979...,CULTURE,gaming.stackexchange.com
1,46,Url link to feature image in the portfolio,I am new to Wordpress. i have issue with Featu...,Anu,https://wordpress.stackexchange.com/users/72927,I think it is possible with custom fields.\n\n...,Irina,https://wordpress.stackexchange.com/users/27233,http://wordpress.stackexchange.com/questions/1...,TECHNOLOGY,wordpress.stackexchange.com
2,70,"Is accuracy, recoil or bullet spread affected ...","To experiment I started a bot game, toggled in...",Konsta,https://gaming.stackexchange.com/users/37545,You do not have armour in the screenshots. Thi...,Damon Smithies,https://gaming.stackexchange.com/users/70641,http://gaming.stackexchange.com/questions/2154...,CULTURE,gaming.stackexchange.com
3,132,Suddenly got an I/O error from my external HDD,I have used my Raspberry Pi as a torrent-serve...,robbannn,https://raspberrypi.stackexchange.com/users/17341,Your Western Digital hard drive is disappearin...,HeatfanJohn,https://raspberrypi.stackexchange.com/users/1311,http://raspberrypi.stackexchange.com/questions...,TECHNOLOGY,raspberrypi.stackexchange.com
4,200,Passenger Name - Flight Booking Passenger only...,I have bought Delhi-London return flights for ...,Amit,https://travel.stackexchange.com/users/29089,I called two persons who work for Saudia (tick...,Nean Der Thal,https://travel.stackexchange.com/users/10051,http://travel.stackexchange.com/questions/4704...,CULTURE,travel.stackexchange.com


In [40]:
# Add one column per target to tst_df: column i of `prediction` (the partially
# rounded output of the cell above) is stored under the name LABEL_COL[i].
for i, col in enumerate(LABEL_COL):
    tst_df[col] = prediction[:, i]

In [41]:
# Persist the test rows plus the partially rounded predictions, without the index column.
tst_df.to_csv(
    path_or_buf='../mnt/inputs/pseudos/top2_e078_e079_e080_e081_e082_e083/half_opt_pseudo_tst_df.csv',
    index=False,
)

## 関数設計 (path を設定すると model を load して prediction を返してくる)

In [28]:
sys.path.append('../scripts/')
from refactor.datasets import QUESTDataset
from refactor.models import BertModelForBinaryMultiLabelClassifier, RobertaModelForBinaryMultiLabelClassifier, XLNetModelForBinaryMultiLabelClassifier
from refactor.utils import test
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler

In [None]:
def predict_from_ckpt(df, ckpt, loader, TOKENIZER_TYPE, DO_LOWER_CASE, T_MAX_LEN, Q_MAX_LEN, A_MAX_LEN, TQA_MODE):
    # Unfinished draft. Per the section header, the intent is: given a
    # checkpoint path, load the model and return its predictions.
    # NOTE(review): this cell has no execution count and does not parse yet.
    # Open issues visible in the body:
    #   * two statements in the 'bert' branch are incomplete (see FIXMEs below);
    #   * only TOKENIZER_TYPE == 'bert' is handled, no other branch exists;
    #   * the `ckpt` argument is never read and `model` is never used;
    #   * the `loader` argument is overwritten by the DataLoader built below;
    #   * `prediction` is never assigned here, so the final `return` would pick
    #     up whatever notebook-level `prediction` exists at call time;
    #   * LABEL_COL is read from the notebook namespace, not passed in.
    if TOKENIZER_TYPE == 'bert':
        state_dict = BertModel.from_pretrained('bert-base-uncased').state_dict()
        # FIXME: `num_labels=` has no value -> SyntaxError.
        model = BertModelForBinaryMultiLabelClassifier(num_labels=, )
        # FIXME: assignment has no right-hand side -> SyntaxError.
        ckpt = 
    # NOTE(review): the tokenizer directory is hard-coded to experiment e078
    # regardless of TOKENIZER_TYPE, and `tokens` is empty here while the other
    # visible dataset cells pass five category tokens. Confirm both are intended.
    dataset = QUESTDataset(
                                df=df,
                                mode='test',
                                tokens=[],
                                augment=[],
                                tokenizer_type=TOKENIZER_TYPE,
                                pretrained_model_name_or_path='../mnt/checkpoints/e078/datasets/',
                                do_lower_case=DO_LOWER_CASE,
                                LABEL_COL=LABEL_COL,
                                t_max_len=T_MAX_LEN,
                                q_max_len=Q_MAX_LEN,
                                a_max_len=A_MAX_LEN,
                                tqa_mode=TQA_MODE,
                                TBSEP='[TBSEP]',
                                pos_id_type='arange',
                                MAX_SEQUENCE_LENGTH=512,
                            )
    # NOTE(review): RandomSampler shuffles the rows; the other inference cells
    # in this notebook use SequentialSampler so predictions stay in `df` order.
    sampler = RandomSampler(data_source=dataset)
    loader = DataLoader(dataset,
                            batch_size=8,
                            sampler=sampler,
                            num_workers=os.cpu_count(),
                            worker_init_fn=lambda x: np.random.seed(),
                            drop_last=False,
                            pin_memory=True)    
    return prediction