In [1]:
import gc
import os
import sys
import itertools
import pickle
from glob import glob
from tqdm import tqdm_notebook as tqdm

import numpy as np
import pandas as pd
from scipy.stats import spearmanr

from matplotlib import pyplot as plt
from matplotlib_venn import venn2, venn3
import seaborn as sns

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from transformers import BertConfig, BertTokenizer, BertModel, BertForMaskedLM#, BertLayer, BertEmbeddings
from transformers.modeling_bert import BertLayer, BertEmbeddings

In [2]:
pd.set_option('display.max_columns', 500)
pd.set_option('display.max_rows', 500)

# re-load functions
%load_ext autoreload
%autoreload 2

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))
%config InlineBackend.figure_formats = {'png', 'retina'}

In [3]:
# Fall back to the CPU so the notebook can still run on a machine without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
import sys
import pickle
from functools import partial
from glob import glob

import numpy as np
import pandas as pd
import scipy as sp
import torch
from scipy.stats import spearmanr
from tqdm import tqdm

class OptimizedRounder(object):
    """
    Optimize rounding thresholds to maximize Spearman's rank correlation.

    Continuous predictions are binned by a sorted list of thresholds and every
    bin is mapped to a discrete label (see ``set_labels``). ``fit`` tunes the
    thresholds with Nelder-Mead, minimizing the negative Spearman rho against
    the ground truth. The structure is adapted from the QWK-based rounder in
    https://www.kaggle.com/naveenasaithambi/optimizedrounder-improved, but the
    objective here is Spearman's rho, not Quadratic Weighted Kappa.

    Typical use: ``set_labels(labels)``, ``fit(X, y, initial_coef)``, then
    ``predict(X, coefficients())``.

    NOTE: this class is defined in more than one cell of this notebook; the
    definition executed last is the one in effect.
    """

    def __init__(self):
        # Holds the scipy OptimizeResult once ``fit`` has run; 0 until then.
        self.coef_ = 0

    def _cut(self, X, coef, labels):
        """Bin ``X`` at the sorted thresholds ``coef`` and label every bin."""
        # The -inf / +inf sentinels make the outermost bins open-ended, so
        # len(coef) thresholds give len(coef) + 1 bins == len(labels).
        bins = [-np.inf] + list(np.sort(coef)) + [np.inf]
        return pd.cut(X, bins, labels=labels)

    def _spearmanr_loss(self, coef, X, y, labels):
        """
        Objective for ``fit``: negative Spearman rho of ``y`` vs. rounded ``X``.

        :param coef: A list of thresholds that will be used for rounding
        :param X: The raw predictions
        :param y: The ground truth labels
        :param labels: Label assigned to each bin, in ascending bin order
        :return: ``-rho``; NaN if the rounded predictions are constant
            (Spearman's rho is undefined for constant input)
        """
        X_p = self._cut(X, coef, labels)
        return -spearmanr(y, X_p).correlation

    def fit(self, X, y, initial_coef):
        """
        Optimize the rounding thresholds with Nelder-Mead.

        :param X: The raw predictions
        :param y: The ground truth labels
        :param initial_coef: Starting thresholds (one fewer than the labels)
        :return: ``self``; the full OptimizeResult is stored in ``coef_``
        :raises AttributeError: if ``set_labels`` has not been called
        """
        loss_partial = partial(
            self._spearmanr_loss, X=X, y=y, labels=self.labels)
        self.coef_ = sp.optimize.minimize(
            loss_partial, initial_coef, method='nelder-mead')
        return self

    def predict(self, X, coef):
        """
        Round predictions with the given thresholds.

        :param X: The raw predictions
        :param coef: A list of thresholds that will be used for rounding
        :return: ``pandas.Categorical`` of labels, one per element of ``X``
        """
        return self._cut(X, coef, self.labels)

    def coefficients(self):
        """
        Return the optimized thresholds (``x`` of the optimization result).
        Only valid after ``fit``.
        """
        return self.coef_['x']

    def set_labels(self, labels):
        """Set the label for each bin (ascending; one more than the thresholds)."""
        self.labels = labels

In [5]:
# sys.path.append('../scripts/')
# from get_optR3 import compute_spearmanr, get_opt_y_pred

import os
import pickle
import sys
from functools import partial
from glob import glob

import numpy as np
import pandas as pd
import scipy as sp
import torch
from scipy.stats import spearmanr
from tqdm import tqdm


class histogramBasedCoefInitializer:
    """
    Initial rounding thresholds derived from the label histogram.

    ``fit`` records, for each distinct label value in ascending order, the
    cumulative number of samples whose label is <= that value. ``predict``
    then places one threshold halfway between consecutive sorted predictions
    at each of those cumulative counts (all but the last). When ``preds`` has
    as many entries as the fitted labels and no ties, rounding with these
    thresholds reproduces the label frequencies.

    NOTE: this class is defined in more than one cell of this notebook; the
    definition executed last is the one in effect.
    """

    def __init__(self):
        # Cumulative label counts; None until ``fit`` is called.
        self.bins = None

    def fit(self, labels):
        """Store the cumulative label counts (ordered by label value); returns self."""
        counts_by_label = pd.Series(labels).value_counts().sort_index()
        self.bins = counts_by_label.cumsum().values
        return self

    def predict(self, preds):
        """Return the list of midpoint thresholds; ``fit`` must be called first."""
        ordered = sorted(preds)
        if self.bins is None:
            raise Exception('plz fit at first.')
        return [(ordered[cut - 1] + ordered[cut]) / 2 for cut in self.bins[:-1]]


class OptimizedRounder(object):
    """
    Optimize rounding thresholds to maximize Spearman's rank correlation.

    Continuous predictions are binned by a sorted list of thresholds and every
    bin is mapped to a discrete label (see ``set_labels``). ``fit`` tunes the
    thresholds with Nelder-Mead, minimizing the negative Spearman rho against
    the ground truth. The structure is adapted from the QWK-based rounder in
    https://www.kaggle.com/naveenasaithambi/optimizedrounder-improved, but the
    objective here is Spearman's rho, not Quadratic Weighted Kappa.

    Typical use: ``set_labels(labels)``, ``fit(X, y, initial_coef)``, then
    ``predict(X, coefficients())``.

    NOTE: this class is defined in more than one cell of this notebook; the
    definition executed last is the one in effect.
    """

    def __init__(self):
        # Holds the scipy OptimizeResult once ``fit`` has run; 0 until then.
        self.coef_ = 0

    def _cut(self, X, coef, labels):
        """Bin ``X`` at the sorted thresholds ``coef`` and label every bin."""
        # The -inf / +inf sentinels make the outermost bins open-ended, so
        # len(coef) thresholds give len(coef) + 1 bins == len(labels).
        bins = [-np.inf] + list(np.sort(coef)) + [np.inf]
        return pd.cut(X, bins, labels=labels)

    def _spearmanr_loss(self, coef, X, y, labels):
        """
        Objective for ``fit``: negative Spearman rho of ``y`` vs. rounded ``X``.

        :param coef: A list of thresholds that will be used for rounding
        :param X: The raw predictions
        :param y: The ground truth labels
        :param labels: Label assigned to each bin, in ascending bin order
        :return: ``-rho``; NaN if the rounded predictions are constant
            (Spearman's rho is undefined for constant input)
        """
        X_p = self._cut(X, coef, labels)
        return -spearmanr(y, X_p).correlation

    def fit(self, X, y, initial_coef):
        """
        Optimize the rounding thresholds with Nelder-Mead.

        :param X: The raw predictions
        :param y: The ground truth labels
        :param initial_coef: Starting thresholds (one fewer than the labels)
        :return: ``self``; the full OptimizeResult is stored in ``coef_``
        :raises AttributeError: if ``set_labels`` has not been called
        """
        loss_partial = partial(
            self._spearmanr_loss, X=X, y=y, labels=self.labels)
        self.coef_ = sp.optimize.minimize(
            loss_partial, initial_coef, method='nelder-mead')
        return self

    def predict(self, X, coef):
        """
        Round predictions with the given thresholds.

        :param X: The raw predictions
        :param coef: A list of thresholds that will be used for rounding
        :return: ``pandas.Categorical`` of labels, one per element of ``X``
        """
        return self._cut(X, coef, self.labels)

    def coefficients(self):
        """
        Return the optimized thresholds (``x`` of the optimization result).
        Only valid after ``fit``.
        """
        return self.coef_['x']

    def set_labels(self, labels):
        """Set the label for each bin (ascending; one more than the thresholds)."""
        self.labels = labels


def compute_spearmanr(trues, preds):
    """
    Column-wise Spearman rank correlation between ``trues`` and ``preds``.

    Both arguments are 2-D arrays indexed ``[row, label]``; one rho per label
    column is returned, as a list in column order.

    Spearman's rho is undefined (NaN) for a constant prediction column. For
    such a column, one entry of a *copy* is overwritten to break the tie. If
    the constant equals the column's largest true value, that entry is set to
    the smallest true value; otherwise it is set to the largest. The caller's
    ``preds`` is never modified.
    """
    rhos = []
    for col_trues, col_pred in zip(trues.T, preds.T):
        if len(np.unique(col_pred)) == 1:
            # ``col_pred`` is a view into ``preds``; copy before editing it.
            col_pred = col_pred.copy()
            if col_pred[0] == np.max(col_trues):
                col_pred[np.argmin(col_pred)] = np.min(col_trues)
            else:
                col_pred[np.argmax(col_pred)] = np.max(col_trues)
        rhos.append(spearmanr(col_trues, col_pred).correlation)
    return rhos


def get_best_ckpt(ckpts):
    """
    Return the checkpoint path with the highest validation metric.

    The metric is parsed from the file name split on ``'_'``. Field 5 holds
    the validation metric, which is the same layout ``get_snapshot_info_df``
    parses. Ties are resolved in favour of the earliest path in ``ckpts``.

    :param ckpts: iterable of checkpoint paths
    :return: the best checkpoint path
    :raises ValueError: if ``ckpts`` is empty
    """
    ckpts = list(ckpts)
    if not ckpts:
        raise ValueError('get_best_ckpt() received no checkpoints')

    def _val_metric(ckpt):
        return float(ckpt.split('/')[-1].split('_')[5])

    return max(ckpts, key=_val_metric)


def get_snapshot_info_df(base_dir, n_folds=5):
    """
    Collect metadata for every ``*.pth`` checkpoint under ``base_dir/{fold}/``.

    File names are split on ``'_'``. Field 1 is the fold, field 3 the epoch,
    field 4 the validation loss and field 5 the validation metric.

    :param base_dir: experiment directory containing one sub-directory per fold
    :param n_folds: number of fold sub-directories (``0 .. n_folds - 1``)
    :return: DataFrame with columns ``ckpt_filename, fold, epoch, val_loss,
        val_metric, rank``. ``rank`` is the within-fold rank of
        ``val_metric`` (1 = lowest, so the best checkpoint has the largest
        rank). The frame is empty (with those columns) if no checkpoint is found.
    """
    columns = ['ckpt_filename', 'fold', 'epoch', 'val_loss', 'val_metric']
    res_dicts = []
    for fold in range(n_folds):
        # sorted() makes the row order independent of directory listing order.
        for ckpt in sorted(glob(f'{base_dir}/{fold}/*.pth')):
            splitted_ckpt = os.path.basename(ckpt).split('_')
            res_dicts.append({
                'ckpt_filename': ckpt,
                'fold': int(splitted_ckpt[1]),
                'epoch': int(splitted_ckpt[3]),
                'val_loss': float(splitted_ckpt[4]),
                'val_metric': float(splitted_ckpt[5]),
            })
    res_df = pd.DataFrame(res_dicts, columns=columns)
    if res_df.empty:
        # groupby().rank() on an empty frame is not meaningful; keep the schema.
        res_df['rank'] = pd.Series(dtype=float)
        return res_df
    res_df['rank'] = res_df.groupby(['fold']).val_metric.rank()
    return res_df


def get_opt_y_pred(y_true, y_pred, num_labels):
    """
    Fit one OptimizedRounder per label column and round ``y_pred`` with it.

    The thresholds for column ``i`` start from the histogram-based initial
    guess, which approximately reproduces the label frequencies of
    ``y_true[:, i]``. They are then tuned to maximize Spearman's rho.

    :param y_true: ground-truth matrix indexed ``[row, label]``
    :param y_pred: raw prediction matrix, same layout as ``y_true``
    :param num_labels: number of leading label columns to process
    :return: ``(optRs, opt_y_preds)``, i.e. the fitted rounders (one per
        column) and the rounded predictions indexed ``[row, label]``
    """
    optRs = []
    rounded_columns = []
    for col in range(num_labels):
        true_col = y_true[:, col]
        pred_col = y_pred[:, col]

        rounder = OptimizedRounder()
        rounder.set_labels(np.sort(np.unique(true_col)))
        init_thresholds = histogramBasedCoefInitializer().fit(true_col).predict(pred_col)
        rounder.fit(pred_col, true_col, init_thresholds)

        optRs.append(rounder)
        rounded_columns.append(rounder.predict(pred_col, rounder.coefficients()))

    return optRs, np.asarray(rounded_columns).T


def opt(BASE_PATH, num_labels=30, snapshot_num=2):
    """
    Fit per-label rounding thresholds on the out-of-fold predictions of one
    experiment and persist the artifacts used later for inference.

    For each of the 5 folds, the ``snapshot_num`` checkpoints with the highest
    ``rank`` (from ``get_snapshot_info_df``) are loaded with ``torch.load``.
    Their validation targets and predictions are re-ordered by
    ``val_qa_ids`` so that all snapshots of a fold are row-aligned.
    Predictions are averaged over snapshots and concatenated over folds. One
    ``OptimizedRounder`` per label column, seeded by
    ``histogramBasedCoefInitializer``, is then fit.

    Files written under ``BASE_PATH``:
      - ``optRs.pkl``: list of fitted rounders, one per label column
      - ``snapshot_dicts.pkl``: ``{fold: {'y_trues': [...], 'y_preds': [...]}}``
      - ``state_dicts/fold_{fold}_rank_{rank}_state_dict.pkl``: model weights

    :param BASE_PATH: experiment directory containing ``{fold}/*.pth``
    :param num_labels: number of label columns to optimize
    :param snapshot_num: number of top checkpoints kept per fold
    :return: per-label Spearman correlations of the rounded predictions
    """
    snapshot_df = get_snapshot_info_df(BASE_PATH)

    snapshot_dicts = {}
    state_dict_dicts = {}
    for fold in tqdm(list(range(5))):
        snapshot_dict = {'y_trues': [], 'y_preds': []}
        state_dict_dict = {}
        top_snapshots = (
            snapshot_df.query(f'fold == {fold}')
            .sort_values('rank', ascending=False)
            .reset_index(drop=True)
            .head(snapshot_num))
        for i, row in top_snapshots.iterrows():
            ckpt = torch.load(row['ckpt_filename'])
            state_dict_dict[i] = ckpt['model_state_dict']
            # Order rows by QA id so every snapshot of this fold lines up row-by-row.
            qa_ids_argsort = np.argsort(ckpt['val_qa_ids'])
            snapshot_dict['y_trues'].append(ckpt['val_y_trues'][qa_ids_argsort])
            snapshot_dict['y_preds'].append(ckpt['val_y_preds'][qa_ids_argsort])
        snapshot_dicts[fold] = snapshot_dict
        state_dict_dicts[fold] = state_dict_dict

    y_preds = np.concatenate(
        [np.average(snapshot_dicts[fold]['y_preds'][:snapshot_num], axis=0)
         for fold in range(5)])
    y_trues = np.concatenate(
        [snapshot_dicts[fold]['y_trues'][0]
         for fold in range(5)])

    reses = []
    optRs = []

    for i in tqdm(list(range(num_labels))):
        y_pred = y_preds[:, i]
        y_true = y_trues[:, i]

        y_pred_argmax = np.argmax(y_pred)
        y_pred_argmin = np.argmin(y_pred)

        optR = OptimizedRounder()
        optR.set_labels(np.sort(np.unique(y_true)))
        initial_coef = histogramBasedCoefInitializer().fit(y_true).predict(y_pred)
        optR.fit(y_pred, y_true, initial_coef=initial_coef)
        optRs.append(optR)
        res = optR.predict(y_pred, optR.coefficients())

        # A constant column makes Spearman's rho undefined (NaN). Break the
        # tie as compute_spearmanr does: if the constant is already the
        # largest label, move the lowest-scored row to the smallest label;
        # otherwise move the highest-scored row to the largest label.
        # (Comparing np.unique(res) with res[y_pred_argmax] is always true
        # for a constant column, so it cannot choose the branch.)
        if len(np.unique(res)) == 1:
            if res[y_pred_argmax] == np.max(y_true):
                res[y_pred_argmin] = np.min(y_true)
            else:
                res[y_pred_argmax] = np.max(y_true)

        reses.append(res)
    reses = np.asarray(reses).T

    with open(f'{BASE_PATH}/optRs.pkl', 'wb') as fout:
        pickle.dump(optRs, fout)
    with open(f'{BASE_PATH}/snapshot_dicts.pkl', 'wb') as fout:
        pickle.dump(snapshot_dicts, fout)
    os.makedirs(f'{BASE_PATH}/state_dicts', exist_ok=True)
    for fold in range(5):
        # Only the snapshots that were actually loaded exist for this fold.
        for rank, state_dict in state_dict_dicts[fold].items():
            with open(f'{BASE_PATH}/state_dicts/fold_{fold}_rank_{rank}_state_dict.pkl', 'wb') as fout:
                pickle.dump(state_dict, fout)

    original_score = compute_spearmanr(y_trues, y_preds)
    print(f'original_score: {original_score}')
    print(f'original_score_mean: {np.mean(original_score)}')

    res_score = compute_spearmanr(y_trues, reses)
    print(f'res_score: {res_score}')
    print(f'res_score_mean: {np.mean(res_score)}')

    return res_score

In [6]:
class histogramBasedCoefInitializer:
    """
    Initial rounding thresholds derived from the label histogram.

    ``fit`` records, for each distinct label value in ascending order, the
    cumulative number of samples whose label is <= that value. ``predict``
    then places one threshold halfway between consecutive sorted predictions
    at each of those cumulative counts (all but the last). When ``preds`` has
    as many entries as the fitted labels and no ties, rounding with these
    thresholds reproduces the label frequencies.

    NOTE: this class is defined in more than one cell of this notebook; the
    definition executed last is the one in effect.
    """

    def __init__(self):
        # Cumulative label counts; None until ``fit`` is called.
        self.bins = None

    def fit(self, labels):
        """Store the cumulative label counts (ordered by label value); returns self."""
        counts_by_label = pd.Series(labels).value_counts().sort_index()
        self.bins = counts_by_label.cumsum().values
        return self

    def predict(self, preds):
        """Return the list of midpoint thresholds; ``fit`` must be called first."""
        ordered = sorted(preds)
        if self.bins is None:
            raise Exception('plz fit at first.')
        return [(ordered[cut - 1] + ordered[cut]) / 2 for cut in self.bins[:-1]]

In [7]:
def blend_and_evaluate(y_trues, y_preds_list, eval_func, weights=None, num_labels=30):
    """
    Blend prediction matrices, then score the blend raw and after threshold rounding.

    :param y_trues: ground-truth matrix indexed ``[row, label]``
    :param y_preds_list: prediction matrices shaped like ``y_trues``, one per model
    :param eval_func: ``eval_func(y_trues, y_preds)`` -> per-label scores
    :param weights: optional per-model blending weights (list or array);
        ``None`` or empty means an unweighted mean
    :param num_labels: number of leading label columns to fit rounders for
    :return: ``(eval_scores, opt_eval_scores, optRs)``
    """
    # ``if weights:`` raises for a NumPy array ("truth value ... is ambiguous"),
    # so test for None / emptiness explicitly.
    if weights is not None and len(weights) == 0:
        weights = None
    y_preds = np.average(y_preds_list, axis=0, weights=weights)
    eval_scores = eval_func(y_trues, y_preds)
    optRs, opt_y_preds = get_opt_y_pred(y_trues, y_preds, num_labels=num_labels)
    opt_eval_scores = eval_func(y_trues, opt_y_preds)
    print(f'original_score: {np.mean(eval_scores)}')
    print(f'opt_score: {np.mean(opt_eval_scores)}')
    return eval_scores, opt_eval_scores, optRs

## まずは top2 optRs を作る

## snapshot 済みの model を load

In [8]:
CKPT_DIR = '../mnt/checkpoints'


def _load_snapshot_dicts(exp_id):
    """Unpickle ``{CKPT_DIR}/{exp_id}/snapshot_dicts.pkl`` (presumably written by ``opt``)."""
    with open(f'{CKPT_DIR}/{exp_id}/snapshot_dicts.pkl', 'rb') as fin:
        return pickle.load(fin)


# One question-side and one answer-side experiment per backbone.
bert_question_dict = _load_snapshot_dicts('e121')
bert_answer_dict = _load_snapshot_dicts('e125')

roberta_question_dict = _load_snapshot_dicts('e126')
roberta_answer_dict = _load_snapshot_dicts('e127')

xlnet_question_dict = _load_snapshot_dicts('e128')
xlnet_answer_dict = _load_snapshot_dicts('e129')

In [9]:
def _get_y_trues_and_y_preds_from_snapshot_dicts(snapshot_dicts, single, avg):
    y_trues, y_preds = [], []
    for fold in range(5):
        if single:
            y_trues.append(snapshot_dicts[fold]['y_trues'][0])
            y_preds.append(snapshot_dicts[fold]['y_preds'][0])
        else:
            if avg:
                y_trues.append(np.average(snapshot_dicts[fold]['y_trues'], axis=0))
                y_preds.append(np.average(snapshot_dicts[fold]['y_preds'], axis=0))
            else:
                y_trues.append(np.concatenate(snapshot_dicts[fold]['y_trues'], axis=1))
                y_preds.append(np.concatenate(snapshot_dicts[fold]['y_preds'], axis=1))
    y_trues = np.concatenate(y_trues)
    y_preds = np.concatenate(y_preds)
    return y_trues, y_preds

def get_y_trues_and_y_preds_from_QA_snapshota_dicts(Q_snapshot_dicts, A_snapshot_dicts, single=False, avg=True, model_num=2, q_num_labels=21, a_num_labels=9):
    """
    Combine question-side and answer-side OOF matrices into one label matrix.

    With ``single`` or ``avg`` each side yields one block, and the result is
    ``[Q labels | A labels]``. Otherwise each side holds ``model_num``
    snapshots concatenated column-wise, and the result interleaves them as
    ``[Q_0 | A_0 | Q_1 | A_1 | ...]``.

    :param q_num_labels: question labels per snapshot block
    :param a_num_labels: answer labels per snapshot block
    :return: ``(y_trues, y_preds)``
    """
    q_y_trues, q_y_preds = _get_y_trues_and_y_preds_from_snapshot_dicts(Q_snapshot_dicts, single, avg)
    a_y_trues, a_y_preds = _get_y_trues_and_y_preds_from_snapshot_dicts(A_snapshot_dicts, single, avg)

    # ``single`` already collapses each side to one block, so it must take
    # the plain-concatenation path too (it previously only worked because
    # the out-of-range slices for snapshots >= 1 were empty).
    if single or avg:
        y_trues = np.concatenate([q_y_trues, a_y_trues], axis=1)
        y_preds = np.concatenate([q_y_preds, a_y_preds], axis=1)
        return y_trues, y_preds

    def _interleave(q, a):
        return np.concatenate(
            [np.concatenate([q[:, i * q_num_labels:(i + 1) * q_num_labels],
                             a[:, i * a_num_labels:(i + 1) * a_num_labels]], axis=1)
             for i in range(model_num)], axis=1)

    return _interleave(q_y_trues, a_y_trues), _interleave(q_y_preds, a_y_preds)

In [10]:
# Snapshot-averaged OOF targets/predictions ([question | answer] columns) per backbone.
bert_y_trues, bert_y_preds = get_y_trues_and_y_preds_from_QA_snapshota_dicts(
    bert_question_dict, bert_answer_dict)
roberta_y_trues, roberta_y_preds = get_y_trues_and_y_preds_from_QA_snapshota_dicts(
    roberta_question_dict, roberta_answer_dict)
xlnet_y_trues, xlnet_y_preds = get_y_trues_and_y_preds_from_QA_snapshota_dicts(
    xlnet_question_dict, xlnet_answer_dict)

In [11]:
# Best-snapshot-only (single=True) OOF targets/predictions per backbone.
single_bert_y_trues, single_bert_y_preds = get_y_trues_and_y_preds_from_QA_snapshota_dicts(
    bert_question_dict, bert_answer_dict, single=True)
single_roberta_y_trues, single_roberta_y_preds = get_y_trues_and_y_preds_from_QA_snapshota_dicts(
    roberta_question_dict, roberta_answer_dict, single=True)
single_xlnet_y_trues, single_xlnet_y_preds = get_y_trues_and_y_preds_from_QA_snapshota_dicts(
    xlnet_question_dict, xlnet_answer_dict, single=True)

In [12]:
# Looks good. Blending averages predictions row-by-row against one shared
# target matrix, so every backbone's OOF targets must be identical (same
# rows, same order). Fail loudly instead of only displaying the check.
assert np.array_equal(bert_y_trues, roberta_y_trues)
assert np.array_equal(bert_y_trues, xlnet_y_trues)
(bert_y_trues == roberta_y_trues).all(), (bert_y_trues == xlnet_y_trues).all()

(True, True)

In [13]:
# Raw vs. rounded score of each backbone using only its best snapshot per fold.
for single_y_preds in (single_bert_y_preds, single_roberta_y_preds, single_xlnet_y_preds):
    blend_and_evaluate(single_bert_y_trues, [single_y_preds], compute_spearmanr)
print('fini!')

original_score: 0.41006427247645066
opt_score: 0.43661262426104713
original_score: 0.40685081230797066
opt_score: 0.4324844364469768
original_score: 0.40631606638648937
opt_score: 0.4325393031770092
fini!


In [14]:
# Raw vs. rounded score of each backbone using its snapshot-averaged predictions.
for avg_y_preds in (bert_y_preds, roberta_y_preds, xlnet_y_preds):
    blend_and_evaluate(single_bert_y_trues, [avg_y_preds], compute_spearmanr)
print('fini!')

original_score: 0.41399265956965503
opt_score: 0.4421711714787299
original_score: 0.41044252802107223
opt_score: 0.44857486569826094
original_score: 0.4134497336519045
opt_score: 0.44027891255029494
fini!


In [15]:
# Equal-weight blend of the three snapshot-averaged backbones.
eval_scores, opt_eval_scores, optRs = blend_and_evaluate(
    single_bert_y_trues,
    [bert_y_preds, roberta_y_preds, xlnet_y_preds],
    compute_spearmanr,
)

original_score: 0.42783901449729456
opt_score: 0.46271598891509474


In [16]:
','.join(map(str, eval_scores))

'0.398106364614657,0.6427297635930231,0.4213119103297285,0.32948561106468904,0.3703040526437552,0.44277903355726284,0.3713459620305942,0.506201203813692,0.6200408131725091,0.09612419590331925,0.4958646455580677,0.7631422507617477,0.3704857584597583,0.199106532244885,0.36447737596332785,0.4734082403136454,0.7925172390269458,0.3770504619734147,0.6944287844412524,0.0693758884802994,0.5147131906848201,0.2914033410526882,0.45199064864903127,0.18627308827830527,0.20850237768082605,0.3844802308127763,0.7708523748378324,0.30208126892017156,0.7025700192792027,0.2240178067766074'

In [17]:
','.join(map(str, opt_eval_scores))

'0.393144790427882,0.643520417200593,0.49229597269209024,0.32709779331226896,0.37830510261552114,0.48921746787241993,0.37328308520666104,0.5212131085640996,0.6304411167051761,0.1279964789456965,0.48818600566866266,0.771520045979139,0.5535384915777919,0.3320128859291208,0.6392394014947853,0.6312901724950274,0.7963069902699796,0.36183653872871596,0.6904678615974655,0.23180410278904368,0.5121346271054278,0.2908433511042211,0.4524121831114153,0.18929711493591883,0.2074923117714503,0.38708287508541744,0.7641581150068769,0.29238469236158143,0.6935105549423104,0.21944601195608388'

In [18]:
# Label columns where threshold rounding beats the raw blended predictions.
opt_better_idx = [
    idx
    for idx, (raw_score, rounded_score) in enumerate(zip(eval_scores, opt_eval_scores))
    if raw_score < rounded_score
]
for idx in opt_better_idx:
    print(idx, eval_scores[idx], opt_eval_scores[idx])
opt_better_idx

1 0.6427297635930231 0.643520417200593
2 0.4213119103297285 0.49229597269209024
4 0.3703040526437552 0.37830510261552114
5 0.44277903355726284 0.48921746787241993
6 0.3713459620305942 0.37328308520666104
7 0.506201203813692 0.5212131085640996
8 0.6200408131725091 0.6304411167051761
9 0.09612419590331925 0.1279964789456965
11 0.7631422507617477 0.771520045979139
12 0.3704857584597583 0.5535384915777919
13 0.199106532244885 0.3320128859291208
14 0.36447737596332785 0.6392394014947853
15 0.4734082403136454 0.6312901724950274
16 0.7925172390269458 0.7963069902699796
19 0.0693758884802994 0.23180410278904368
22 0.45199064864903127 0.4524121831114153
23 0.18627308827830527 0.18929711493591883
25 0.3844802308127763 0.38708287508541744


[1, 2, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 19, 22, 23, 25]

In [19]:
# Idempotent (and creates missing parents), unlike a plain `!mkdir` that fails on re-run.
os.makedirs('../mnt/inputs/pseudos/top2_e121_e125_e126_e127_e128_e129/', exist_ok=True)

In [20]:
# Persist the per-label rounders fitted on the three-backbone blend.
optRs_path = '../mnt/inputs/pseudos/top2_e121_e125_e126_e127_e128_e129/optRs.pkl'
with open(optRs_path, 'wb') as fout:
    pickle.dump(optRs, fout)

#### 重みを load する

In [21]:
TEST_CSV_PATH = '../mnt/inputs/origin/test.csv'
tst_df = pd.read_csv(TEST_CSV_PATH)

In [44]:
sys.path.append('../scripts/')
from refactor.datasets import QUESTDataset
from refactor.models import BertModelForBinaryMultiLabelClassifier, RobertaModelForBinaryMultiLabelClassifier, XLNetModelForBinaryMultiLabelClassifier
from refactor.utils import test

In [24]:
# Targets predicted by the question-side models (21 columns) ...
Q_LABEL_COL = [
    'question_asker_intent_understanding',
    'question_body_critical',
    'question_conversational',
    'question_expect_short_answer',
    'question_fact_seeking',
    'question_has_commonly_accepted_answer',
    'question_interestingness_others',
    'question_interestingness_self',
    'question_multi_intent',
    'question_not_really_a_question',
    'question_opinion_seeking',
    'question_type_choice',
    'question_type_compare',
    'question_type_consequence',
    'question_type_definition',
    'question_type_entity',
    'question_type_instructions',
    'question_type_procedure',
    'question_type_reason_explanation',
    'question_type_spelling',
    'question_well_written',
]

# ... and by the answer-side models (9 columns).
A_LABEL_COL = [
    'answer_helpful',
    'answer_level_of_information',
    'answer_plausible',
    'answer_relevance',
    'answer_satisfaction',
    'answer_type_instructions',
    'answer_type_procedure',
    'answer_type_reason_explanation',
    'answer_well_written'
]


def _cat_tokens():
    """Lower-cased category marker tokens; returns a fresh list on every call."""
    return [raw.casefold() for raw in (
        'CAT_TECHNOLOGY',
        'CAT_STACKOVERFLOW',
        'CAT_CULTURE',
        'CAT_SCIENCE',
        'CAT_LIFE_ARTS',
    )]


def _make_test_loader(dataset, sampler):
    """Inference DataLoader: batches of 8 in sampler order, keeps the last partial batch."""
    return DataLoader(
        dataset,
        batch_size=8,
        sampler=sampler,
        num_workers=os.cpu_count(),
        worker_init_fn=lambda x: np.random.seed(),
        drop_last=False,
        pin_memory=True,
    )


# Question-side test inputs: the answer length budget is zero (a_max_len=239 * 0).
q_test_dataset = QUESTDataset(
    df=tst_df,
    mode='test',
    tokens=_cat_tokens(),
    augment=[],
    tokenizer_type='bert',
    pretrained_model_name_or_path='../mnt/checkpoints/e121/datasets/',
    do_lower_case=True,
    LABEL_COL=Q_LABEL_COL,
    t_max_len=30,
    q_max_len=239 * 2,
    a_max_len=239 * 0,
    tqa_mode='tq_a',
    TBSEP='[TBSEP]',
    pos_id_type='arange',
    MAX_SEQUENCE_LENGTH=512,
)
q_test_sampler = SequentialSampler(data_source=q_test_dataset)
q_test_loader = _make_test_loader(q_test_dataset, q_test_sampler)

# Answer-side test inputs: the question length budget is zero (q_max_len=239 * 0).
# NOTE(review): the tokenizer comes from e078 while the BERT answer weights are
# loaded from e125 later on; confirm both share the same vocabulary.
a_test_dataset = QUESTDataset(
    df=tst_df,
    mode='test',
    tokens=_cat_tokens(),
    augment=[],
    tokenizer_type='bert',
    pretrained_model_name_or_path='../mnt/checkpoints/e078/datasets/',
    do_lower_case=True,
    LABEL_COL=A_LABEL_COL,
    t_max_len=30,
    q_max_len=239 * 0,
    a_max_len=239 * 2,
    tqa_mode='tq_a',
    TBSEP='[TBSEP]',
    pos_id_type='arange',
    MAX_SEQUENCE_LENGTH=512,
)
a_test_sampler = SequentialSampler(data_source=a_test_dataset)
a_test_loader = _make_test_loader(a_test_dataset, a_test_sampler)

additional_tokens : 0
additional_tokens : 0


In [29]:
# sorted() fixes the model order, so the averaged test predictions are reproducible.
ckpts = sorted(glob('../mnt/checkpoints/e121/state_dicts/*'))

In [30]:
# Q models: average the test predictions of every saved question-side snapshot.
bert_q_preds = []
for ckpt in tqdm(ckpts):
    with open(ckpt, 'rb') as fin:
        state_dict = pickle.load(fin)
    model = BertModelForBinaryMultiLabelClassifier(21, '../mnt/datasets/model_configs/bert-model-uncased-config.pkl', None, token_size=30528)
    model.load_state_dict(state_dict)
    # Move the weights straight to the inference device.
    model.to(DEVICE)
    _, _, _, y_preds, _, qa_ids = test(model, None, q_test_loader, DEVICE, 'test')
    bert_q_preds.append(y_preds)
    # Drop both the model and the loaded weights before the next snapshot.
    del model, state_dict
    gc.collect()
res_bert_q_pred = np.mean(bert_q_preds, axis=0)
res_bert_q_pred.shape

  0%|          | 0/10 [00:00<?, ?it/s]
  0%|          | 0/60 [00:00<?, ?it/s][A
  2%|▏         | 1/60 [00:05<05:32,  5.63s/it][A
  5%|▌         | 3/60 [00:05<03:46,  3.98s/it][A
  7%|▋         | 4/60 [00:06<02:38,  2.83s/it][A
  8%|▊         | 5/60 [00:06<01:51,  2.03s/it][A
 10%|█         | 6/60 [00:06<01:19,  1.46s/it][A
 12%|█▏        | 7/60 [00:06<00:56,  1.07s/it][A
 13%|█▎        | 8/60 [00:06<00:41,  1.26it/s][A
 15%|█▌        | 9/60 [00:06<00:30,  1.65it/s][A
 17%|█▋        | 10/60 [00:06<00:23,  2.13it/s][A
 18%|█▊        | 11/60 [00:07<00:18,  2.68it/s][A
 20%|██        | 12/60 [00:07<00:14,  3.25it/s][A
 22%|██▏       | 13/60 [00:07<00:12,  3.84it/s][A
 23%|██▎       | 14/60 [00:07<00:10,  4.38it/s][A
 25%|██▌       | 15/60 [00:07<00:09,  4.84it/s][A
 27%|██▋       | 16/60 [00:07<00:08,  5.24it/s][A
 28%|██▊       | 17/60 [00:08<00:07,  5.58it/s][A
 30%|███       | 18/60 [00:08<00:07,  5.83it/s][A
 32%|███▏      | 19/60 [00:08<00:06,  6.02it/s][A
 33%|███▎

 63%|██████▎   | 38/60 [00:07<00:03,  6.54it/s][A
 65%|██████▌   | 39/60 [00:07<00:03,  6.52it/s][A
 67%|██████▋   | 40/60 [00:07<00:03,  6.55it/s][A
 68%|██████▊   | 41/60 [00:07<00:02,  6.49it/s][A
 70%|███████   | 42/60 [00:07<00:02,  6.56it/s][A
 72%|███████▏  | 43/60 [00:08<00:02,  6.48it/s][A
 73%|███████▎  | 44/60 [00:08<00:02,  6.52it/s][A
 75%|███████▌  | 45/60 [00:08<00:02,  6.58it/s][A
 77%|███████▋  | 46/60 [00:08<00:02,  6.55it/s][A
 78%|███████▊  | 47/60 [00:08<00:01,  6.52it/s][A
 80%|████████  | 48/60 [00:08<00:01,  6.54it/s][A
 82%|████████▏ | 49/60 [00:08<00:01,  6.55it/s][A
 83%|████████▎ | 50/60 [00:09<00:01,  6.54it/s][A
 85%|████████▌ | 51/60 [00:09<00:01,  6.53it/s][A
 87%|████████▋ | 52/60 [00:09<00:01,  6.54it/s][A
 88%|████████▊ | 53/60 [00:09<00:01,  6.48it/s][A
 90%|█████████ | 54/60 [00:09<00:00,  6.49it/s][A
 92%|█████████▏| 55/60 [00:09<00:00,  6.52it/s][A
 93%|█████████▎| 56/60 [00:10<00:00,  6.57it/s][A
 95%|█████████▌| 57/60 [00:10<0

 22%|██▏       | 13/60 [00:03<00:08,  5.81it/s][A
 23%|██▎       | 14/60 [00:03<00:07,  5.99it/s][A
 25%|██▌       | 15/60 [00:03<00:07,  6.10it/s][A
 27%|██▋       | 16/60 [00:03<00:06,  6.29it/s][A
 28%|██▊       | 17/60 [00:04<00:06,  6.37it/s][A
 30%|███       | 18/60 [00:04<00:06,  6.33it/s][A
 32%|███▏      | 19/60 [00:04<00:06,  6.46it/s][A
 33%|███▎      | 20/60 [00:04<00:06,  6.48it/s][A
 35%|███▌      | 21/60 [00:04<00:06,  6.49it/s][A
 37%|███▋      | 22/60 [00:04<00:05,  6.45it/s][A
 38%|███▊      | 23/60 [00:04<00:05,  6.51it/s][A
 40%|████      | 24/60 [00:05<00:05,  6.52it/s][A
 42%|████▏     | 25/60 [00:05<00:05,  6.54it/s][A
 43%|████▎     | 26/60 [00:05<00:05,  6.55it/s][A
 45%|████▌     | 27/60 [00:05<00:05,  6.54it/s][A
 47%|████▋     | 28/60 [00:05<00:04,  6.54it/s][A
 48%|████▊     | 29/60 [00:05<00:04,  6.48it/s][A
 50%|█████     | 30/60 [00:06<00:04,  6.50it/s][A
 52%|█████▏    | 31/60 [00:06<00:04,  6.53it/s][A
 53%|█████▎    | 32/60 [00:06<0

 83%|████████▎ | 50/60 [00:09<00:01,  6.55it/s][A
 85%|████████▌ | 51/60 [00:09<00:01,  6.55it/s][A
 87%|████████▋ | 52/60 [00:09<00:01,  6.55it/s][A
 88%|████████▊ | 53/60 [00:09<00:01,  6.54it/s][A
 90%|█████████ | 54/60 [00:09<00:00,  6.54it/s][A
 92%|█████████▏| 55/60 [00:09<00:00,  6.55it/s][A
 93%|█████████▎| 56/60 [00:10<00:00,  6.55it/s][A
 95%|█████████▌| 57/60 [00:10<00:00,  6.53it/s][A
 97%|█████████▋| 58/60 [00:10<00:00,  6.49it/s][A
 98%|█████████▊| 59/60 [00:10<00:00,  6.51it/s][A
100%|██████████| 60/60 [00:10<00:00,  5.59it/s][A
 80%|████████  | 8/10 [02:21<00:34, 17.33s/it]
  0%|          | 0/60 [00:00<?, ?it/s][A
  2%|▏         | 1/60 [00:01<01:34,  1.61s/it][A
  3%|▎         | 2/60 [00:01<01:07,  1.16s/it][A
  5%|▌         | 3/60 [00:01<00:49,  1.16it/s][A
  7%|▋         | 4/60 [00:02<00:36,  1.54it/s][A
  8%|▊         | 5/60 [00:02<00:27,  2.00it/s][A
 10%|█         | 6/60 [00:02<00:21,  2.52it/s][A
 12%|█▏        | 7/60 [00:02<00:17,  3.10it/s][A


(476, 21)

In [31]:
# A models: average the test predictions of every saved answer-side snapshot.
# sorted() fixes the model order, so the averaged predictions are reproducible.
ckpts = sorted(glob('../mnt/checkpoints/e125/state_dicts/*'))
bert_a_preds = []
for ckpt in tqdm(ckpts):
    with open(ckpt, 'rb') as fin:
        state_dict = pickle.load(fin)
    model = BertModelForBinaryMultiLabelClassifier(9, '../mnt/datasets/model_configs/bert-model-uncased-config.pkl', None, token_size=30528)
    model.load_state_dict(state_dict)
    # Move the weights straight to the inference device.
    model.to(DEVICE)
    _, _, _, y_preds, _, qa_ids = test(model, None, a_test_loader, DEVICE, 'test')
    bert_a_preds.append(y_preds)
    # Drop both the model and the loaded weights before the next snapshot.
    del model, state_dict
    gc.collect()
res_bert_a_pred = np.mean(bert_a_preds, axis=0)
res_bert_a_pred.shape

  0%|          | 0/10 [00:00<?, ?it/s]
  0%|          | 0/60 [00:00<?, ?it/s][A
  2%|▏         | 1/60 [00:01<01:36,  1.63s/it][A
  3%|▎         | 2/60 [00:01<01:08,  1.19s/it][A
  5%|▌         | 3/60 [00:01<00:49,  1.14it/s][A
  7%|▋         | 4/60 [00:02<00:36,  1.52it/s][A
  8%|▊         | 5/60 [00:02<00:27,  1.97it/s][A
 10%|█         | 6/60 [00:02<00:21,  2.49it/s][A
 12%|█▏        | 7/60 [00:02<00:17,  3.07it/s][A
 13%|█▎        | 8/60 [00:02<00:14,  3.65it/s][A
 15%|█▌        | 9/60 [00:02<00:12,  4.20it/s][A
 17%|█▋        | 10/60 [00:03<00:10,  4.70it/s][A
 18%|█▊        | 11/60 [00:03<00:09,  5.15it/s][A
 20%|██        | 12/60 [00:03<00:08,  5.46it/s][A
 22%|██▏       | 13/60 [00:03<00:08,  5.78it/s][A
 23%|██▎       | 14/60 [00:03<00:07,  5.98it/s][A
 25%|██▌       | 15/60 [00:03<00:07,  6.10it/s][A
 27%|██▋       | 16/60 [00:03<00:07,  6.20it/s][A
 28%|██▊       | 17/60 [00:04<00:06,  6.35it/s][A
 30%|███       | 18/60 [00:04<00:06,  6.42it/s][A
 32%|███▏ 

 62%|██████▏   | 37/60 [00:07<00:03,  6.55it/s][A
 63%|██████▎   | 38/60 [00:07<00:03,  6.53it/s][A
 65%|██████▌   | 39/60 [00:07<00:03,  6.49it/s][A
 67%|██████▋   | 40/60 [00:07<00:03,  6.51it/s][A
 68%|██████▊   | 41/60 [00:07<00:02,  6.58it/s][A
 70%|███████   | 42/60 [00:07<00:02,  6.55it/s][A
 72%|███████▏  | 43/60 [00:08<00:02,  6.56it/s][A
 73%|███████▎  | 44/60 [00:08<00:02,  6.54it/s][A
 75%|███████▌  | 45/60 [00:08<00:02,  6.54it/s][A
 77%|███████▋  | 46/60 [00:08<00:02,  6.54it/s][A
 78%|███████▊  | 47/60 [00:08<00:02,  6.48it/s][A
 80%|████████  | 48/60 [00:08<00:01,  6.50it/s][A
 82%|████████▏ | 49/60 [00:08<00:01,  6.54it/s][A
 83%|████████▎ | 50/60 [00:09<00:01,  6.56it/s][A
 85%|████████▌ | 51/60 [00:09<00:01,  6.56it/s][A
 87%|████████▋ | 52/60 [00:09<00:01,  6.49it/s][A
 88%|████████▊ | 53/60 [00:09<00:01,  6.54it/s][A
 90%|█████████ | 54/60 [00:09<00:00,  6.50it/s][A
 92%|█████████▏| 55/60 [00:09<00:00,  6.52it/s][A
 93%|█████████▎| 56/60 [00:10<0

 20%|██        | 12/60 [00:03<00:08,  5.52it/s][A
 22%|██▏       | 13/60 [00:03<00:08,  5.78it/s][A
 23%|██▎       | 14/60 [00:03<00:07,  5.95it/s][A
 25%|██▌       | 15/60 [00:03<00:07,  6.12it/s][A
 27%|██▋       | 16/60 [00:03<00:07,  6.26it/s][A
 28%|██▊       | 17/60 [00:04<00:06,  6.36it/s][A
 30%|███       | 18/60 [00:04<00:06,  6.40it/s][A
 32%|███▏      | 19/60 [00:04<00:06,  6.42it/s][A
 33%|███▎      | 20/60 [00:04<00:06,  6.43it/s][A
 35%|███▌      | 21/60 [00:04<00:06,  6.48it/s][A
 37%|███▋      | 22/60 [00:04<00:05,  6.48it/s][A
 38%|███▊      | 23/60 [00:04<00:05,  6.49it/s][A
 40%|████      | 24/60 [00:05<00:05,  6.49it/s][A
 42%|████▏     | 25/60 [00:05<00:05,  6.48it/s][A
 43%|████▎     | 26/60 [00:05<00:05,  6.55it/s][A
 45%|████▌     | 27/60 [00:05<00:05,  6.55it/s][A
 47%|████▋     | 28/60 [00:05<00:04,  6.52it/s][A
 48%|████▊     | 29/60 [00:05<00:04,  6.52it/s][A
 50%|█████     | 30/60 [00:06<00:04,  6.52it/s][A
 52%|█████▏    | 31/60 [00:06<0

 82%|████████▏ | 49/60 [00:08<00:01,  6.53it/s][A
 83%|████████▎ | 50/60 [00:09<00:01,  6.55it/s][A
 85%|████████▌ | 51/60 [00:09<00:01,  6.50it/s][A
 87%|████████▋ | 52/60 [00:09<00:01,  6.50it/s][A
 88%|████████▊ | 53/60 [00:09<00:01,  6.51it/s][A
 90%|█████████ | 54/60 [00:09<00:00,  6.58it/s][A
 92%|█████████▏| 55/60 [00:09<00:00,  6.57it/s][A
 93%|█████████▎| 56/60 [00:10<00:00,  6.55it/s][A
 95%|█████████▌| 57/60 [00:10<00:00,  6.47it/s][A
 97%|█████████▋| 58/60 [00:10<00:00,  6.53it/s][A
 98%|█████████▊| 59/60 [00:10<00:00,  6.51it/s][A
100%|██████████| 60/60 [00:10<00:00,  5.58it/s][A
 80%|████████  | 8/10 [02:14<00:33, 16.71s/it]
  0%|          | 0/60 [00:00<?, ?it/s][A
  2%|▏         | 1/60 [00:01<01:36,  1.64s/it][A
  3%|▎         | 2/60 [00:01<01:08,  1.18s/it][A
  5%|▌         | 3/60 [00:01<00:49,  1.15it/s][A
  7%|▋         | 4/60 [00:02<00:36,  1.52it/s][A
  8%|▊         | 5/60 [00:02<00:27,  1.97it/s][A
 10%|█         | 6/60 [00:02<00:21,  2.50it/s][A

(476, 9)

In [32]:
# Side-by-side: question-model target columns first, then the 9 answer-model columns.
res_bert_pred = np.hstack((res_bert_q_pred, res_bert_a_pred))
res_bert_pred.shape

(476, 30)

In [33]:
# Persist the BERT ensemble's test predictions so the final blend can reload them
# without re-running inference.
pseudo_dir = '../mnt/inputs/pseudos/top2_e121_e125_e126_e127_e128_e129'
# open(..., 'wb') does not create missing parent directories.
os.makedirs(pseudo_dir, exist_ok=True)
with open(os.path.join(pseudo_dir, 'res_bert_pred.pkl'), 'wb') as fout:
    pickle.dump(res_bert_pred, fout)

In [34]:
# Return cached-but-unused GPU memory from PyTorch's allocator before the
# RoBERTa section starts loading its own checkpoints.
torch.cuda.empty_cache()

## RoBERTa

In [49]:
# Test-set inputs for the RoBERTa ensemble: one dataset/loader pair for the
# question-side models (21 targets) and one for the answer-side models (9 targets).
Q_LABEL_COL = [
    'question_asker_intent_understanding',
    'question_body_critical',
    'question_conversational',
    'question_expect_short_answer',
    'question_fact_seeking',
    'question_has_commonly_accepted_answer',
    'question_interestingness_others',
    'question_interestingness_self',
    'question_multi_intent',
    'question_not_really_a_question',
    'question_opinion_seeking',
    'question_type_choice',
    'question_type_compare',
    'question_type_consequence',
    'question_type_definition',
    'question_type_entity',
    'question_type_instructions',
    'question_type_procedure',
    'question_type_reason_explanation',
    'question_type_spelling',
    'question_well_written',
]

A_LABEL_COL = [
    'answer_helpful',
    'answer_level_of_information',
    'answer_plausible',
    'answer_relevance',
    'answer_satisfaction',
    'answer_type_instructions',
    'answer_type_procedure',
    'answer_type_reason_explanation',
    'answer_well_written',
]


def _build_roberta_test_loader(label_col, q_max_len, a_max_len):
    """Return ``(dataset, sampler, loader)`` over ``tst_df`` for the RoBERTa models.

    The Q and A variants share every QUESTDataset / DataLoader setting except the
    label columns and the ``q_max_len`` / ``a_max_len`` budgets (presumably the
    question / answer token limits -- confirm against QUESTDataset).
    """
    category_tokens = [
        tag.casefold()
        for tag in ('CAT_TECHNOLOGY', 'CAT_STACKOVERFLOW', 'CAT_CULTURE', 'CAT_SCIENCE', 'CAT_LIFE_ARTS')
    ]
    dataset = QUESTDataset(
        df=tst_df,
        mode='test',
        tokens=category_tokens,
        augment=[],
        tokenizer_type='roberta',
        pretrained_model_name_or_path='roberta-base',
        do_lower_case=False,
        LABEL_COL=label_col,
        t_max_len=30,
        q_max_len=q_max_len,
        a_max_len=a_max_len,
        tqa_mode='tq_a',
        TBSEP='[TBSEP]',
        pos_id_type='arange',
        MAX_SEQUENCE_LENGTH=512,
    )
    sampler = SequentialSampler(data_source=dataset)
    loader = DataLoader(
        dataset,
        batch_size=8,
        sampler=sampler,
        num_workers=os.cpu_count(),
        worker_init_fn=lambda _worker_id: np.random.seed(),
        drop_last=False,
        pin_memory=True,
    )
    return dataset, sampler, loader


# Q models get the whole 239 * 2 budget on q_max_len; A models get it on a_max_len.
q_test_dataset, q_test_sampler, q_test_loader = _build_roberta_test_loader(
    Q_LABEL_COL, q_max_len=239 * 2, a_max_len=239 * 0)
a_test_dataset, a_test_sampler, a_test_loader = _build_roberta_test_loader(
    A_LABEL_COL, q_max_len=239 * 0, a_max_len=239 * 2)

additional_tokens : 5
additional_tokens : 5


In [50]:
# Q models: average the test-set predictions of every RoBERTa question-model
# checkpoint from run e126 (21 question targets each).
ckpts = sorted(glob('../mnt/checkpoints/e126/state_dicts/*'))  # sorted -> deterministic averaging order
assert ckpts, 'no checkpoints found in ../mnt/checkpoints/e126/state_dicts/'
roberta_q_preds = []
for ckpt in tqdm(ckpts):
    # NOTE: pickle.load can execute arbitrary code; only load checkpoints produced by this project.
    with open(ckpt, 'rb') as fin:
        state_dict = pickle.load(fin)
    model = RobertaModelForBinaryMultiLabelClassifier(21, '../mnt/datasets/model_configs/roberta-model-base-config.pkl', None, token_size=50271)
    model.load_state_dict(state_dict)
    # load_state_dict copies the weights into the model, so drop the raw copy now
    # instead of keeping two sets of weights alive during inference.
    del state_dict
    model.to(DEVICE)
    _, _, _, y_preds, _, qa_ids = test(model, None, q_test_loader, DEVICE, 'test')
    roberta_q_preds.append(y_preds)
    # Free the model and hand its cached GPU blocks back before the next checkpoint.
    del model
    gc.collect()
    torch.cuda.empty_cache()
res_roberta_q_pred = np.mean(roberta_q_preds, axis=0)
res_roberta_q_pred.shape



  0%|          | 0/10 [00:00<?, ?it/s][A[A


  0%|          | 0/60 [00:00<?, ?it/s][A[A[A


  2%|▏         | 1/60 [00:01<01:52,  1.91s/it][A[A[A


  3%|▎         | 2/60 [00:02<01:20,  1.38s/it][A[A[A


  5%|▌         | 3/60 [00:02<00:57,  1.01s/it][A[A[A


  7%|▋         | 4/60 [00:02<00:42,  1.32it/s][A[A[A


  8%|▊         | 5/60 [00:02<00:31,  1.74it/s][A[A[A


 10%|█         | 6/60 [00:02<00:24,  2.24it/s][A[A[A


 12%|█▏        | 7/60 [00:02<00:19,  2.79it/s][A[A[A


 13%|█▎        | 8/60 [00:02<00:15,  3.37it/s][A[A[A


 15%|█▌        | 9/60 [00:03<00:12,  3.94it/s][A[A[A


 17%|█▋        | 10/60 [00:03<00:11,  4.45it/s][A[A[A


 18%|█▊        | 11/60 [00:03<00:09,  4.92it/s][A[A[A


 20%|██        | 12/60 [00:03<00:08,  5.35it/s][A[A[A


 22%|██▏       | 13/60 [00:03<00:08,  5.64it/s][A[A[A


 23%|██▎       | 14/60 [00:03<00:07,  5.89it/s][A[A[A


 25%|██▌       | 15/60 [00:04<00:07,  6.08it/s][A[A[A


 27%|██▋       | 16/60 [0

 25%|██▌       | 15/60 [00:03<00:07,  6.06it/s][A[A[A


 27%|██▋       | 16/60 [00:04<00:07,  6.21it/s][A[A[A


 28%|██▊       | 17/60 [00:04<00:06,  6.34it/s][A[A[A


 30%|███       | 18/60 [00:04<00:06,  6.40it/s][A[A[A


 32%|███▏      | 19/60 [00:04<00:06,  6.42it/s][A[A[A


 33%|███▎      | 20/60 [00:04<00:06,  6.46it/s][A[A[A


 35%|███▌      | 21/60 [00:04<00:06,  6.49it/s][A[A[A


 37%|███▋      | 22/60 [00:05<00:05,  6.52it/s][A[A[A


 38%|███▊      | 23/60 [00:05<00:05,  6.50it/s][A[A[A


 40%|████      | 24/60 [00:05<00:05,  6.44it/s][A[A[A


 42%|████▏     | 25/60 [00:05<00:05,  6.55it/s][A[A[A


 43%|████▎     | 26/60 [00:05<00:05,  6.54it/s][A[A[A


 45%|████▌     | 27/60 [00:05<00:05,  6.53it/s][A[A[A


 47%|████▋     | 28/60 [00:05<00:04,  6.49it/s][A[A[A


 48%|████▊     | 29/60 [00:06<00:04,  6.50it/s][A[A[A


 50%|█████     | 30/60 [00:06<00:04,  6.51it/s][A[A[A


 52%|█████▏    | 31/60 [00:06<00:04,  6.50it/s][A[A[A

 50%|█████     | 30/60 [00:06<00:04,  6.53it/s][A[A[A


 52%|█████▏    | 31/60 [00:06<00:04,  6.51it/s][A[A[A


 53%|█████▎    | 32/60 [00:06<00:04,  6.49it/s][A[A[A


 55%|█████▌    | 33/60 [00:06<00:04,  6.50it/s][A[A[A


 57%|█████▋    | 34/60 [00:06<00:03,  6.52it/s][A[A[A


 58%|█████▊    | 35/60 [00:07<00:03,  6.53it/s][A[A[A


 60%|██████    | 36/60 [00:07<00:03,  6.51it/s][A[A[A


 62%|██████▏   | 37/60 [00:07<00:03,  6.50it/s][A[A[A


 63%|██████▎   | 38/60 [00:07<00:03,  6.52it/s][A[A[A


 65%|██████▌   | 39/60 [00:07<00:03,  6.57it/s][A[A[A


 67%|██████▋   | 40/60 [00:07<00:03,  6.56it/s][A[A[A


 68%|██████▊   | 41/60 [00:08<00:02,  6.56it/s][A[A[A


 70%|███████   | 42/60 [00:08<00:02,  6.49it/s][A[A[A


 72%|███████▏  | 43/60 [00:08<00:02,  6.56it/s][A[A[A


 73%|███████▎  | 44/60 [00:08<00:02,  6.50it/s][A[A[A


 75%|███████▌  | 45/60 [00:08<00:02,  6.49it/s][A[A[A


 77%|███████▋  | 46/60 [00:08<00:02,  6.53it/s][A[A[A

 75%|███████▌  | 45/60 [00:08<00:02,  6.52it/s][A[A[A


 77%|███████▋  | 46/60 [00:08<00:02,  6.56it/s][A[A[A


 78%|███████▊  | 47/60 [00:08<00:01,  6.55it/s][A[A[A


 80%|████████  | 48/60 [00:09<00:01,  6.56it/s][A[A[A


 82%|████████▏ | 49/60 [00:09<00:01,  6.55it/s][A[A[A


 83%|████████▎ | 50/60 [00:09<00:01,  6.52it/s][A[A[A


 85%|████████▌ | 51/60 [00:09<00:01,  6.53it/s][A[A[A


 87%|████████▋ | 52/60 [00:09<00:01,  6.53it/s][A[A[A


 88%|████████▊ | 53/60 [00:09<00:01,  6.54it/s][A[A[A


 90%|█████████ | 54/60 [00:10<00:00,  6.52it/s][A[A[A


 92%|█████████▏| 55/60 [00:10<00:00,  6.53it/s][A[A[A


 93%|█████████▎| 56/60 [00:10<00:00,  6.53it/s][A[A[A


 95%|█████████▌| 57/60 [00:10<00:00,  6.46it/s][A[A[A


 97%|█████████▋| 58/60 [00:10<00:00,  6.44it/s][A[A[A


 98%|█████████▊| 59/60 [00:10<00:00,  6.48it/s][A[A[A


100%|██████████| 60/60 [00:11<00:00,  5.40it/s][A[A[A


 70%|███████   | 7/10 [02:13<00:57, 19.01s/it][A[A


 

100%|██████████| 60/60 [00:11<00:00,  5.40it/s][A[A[A


 90%|█████████ | 9/10 [02:51<00:19, 19.17s/it][A[A


  0%|          | 0/60 [00:00<?, ?it/s][A[A[A


  2%|▏         | 1/60 [00:01<01:49,  1.86s/it][A[A[A


  3%|▎         | 2/60 [00:01<01:17,  1.34s/it][A[A[A


  5%|▌         | 3/60 [00:02<00:56,  1.02it/s][A[A[A


  7%|▋         | 4/60 [00:02<00:41,  1.36it/s][A[A[A


  8%|▊         | 5/60 [00:02<00:30,  1.78it/s][A[A[A


 10%|█         | 6/60 [00:02<00:23,  2.27it/s][A[A[A


 12%|█▏        | 7/60 [00:02<00:18,  2.83it/s][A[A[A


 13%|█▎        | 8/60 [00:02<00:15,  3.42it/s][A[A[A


 15%|█▌        | 9/60 [00:03<00:12,  3.99it/s][A[A[A


 17%|█▋        | 10/60 [00:03<00:11,  4.52it/s][A[A[A


 18%|█▊        | 11/60 [00:03<00:09,  4.99it/s][A[A[A


 20%|██        | 12/60 [00:03<00:08,  5.36it/s][A[A[A


 22%|██▏       | 13/60 [00:03<00:08,  5.67it/s][A[A[A


 23%|██▎       | 14/60 [00:03<00:07,  5.90it/s][A[A[A


 25%|██▌       | 15

(476, 21)

In [51]:
# A models: average the test-set predictions of every RoBERTa answer-model
# checkpoint from run e127 (9 answer targets each).
ckpts = sorted(glob('../mnt/checkpoints/e127/state_dicts/*'))  # sorted -> deterministic averaging order
assert ckpts, 'no checkpoints found in ../mnt/checkpoints/e127/state_dicts/'
roberta_a_preds = []
for ckpt in tqdm(ckpts):
    # NOTE: pickle.load can execute arbitrary code; only load checkpoints produced by this project.
    with open(ckpt, 'rb') as fin:
        state_dict = pickle.load(fin)
    model = RobertaModelForBinaryMultiLabelClassifier(9, '../mnt/datasets/model_configs/roberta-model-base-config.pkl', None, token_size=50271)
    model.load_state_dict(state_dict)
    # load_state_dict copies the weights into the model, so drop the raw copy now
    # instead of keeping two sets of weights alive during inference.
    del state_dict
    model.to(DEVICE)
    _, _, _, y_preds, _, qa_ids = test(model, None, a_test_loader, DEVICE, 'test')
    roberta_a_preds.append(y_preds)
    # Free the model and hand its cached GPU blocks back before the next checkpoint.
    del model
    gc.collect()
    torch.cuda.empty_cache()
res_roberta_a_pred = np.mean(roberta_a_preds, axis=0)
res_roberta_a_pred.shape



  0%|          | 0/10 [00:00<?, ?it/s][A[A


  0%|          | 0/60 [00:00<?, ?it/s][A[A[A


  2%|▏         | 1/60 [00:01<01:54,  1.93s/it][A[A[A


  3%|▎         | 2/60 [00:02<01:20,  1.40s/it][A[A[A


  5%|▌         | 3/60 [00:02<00:58,  1.02s/it][A[A[A


  7%|▋         | 4/60 [00:02<00:42,  1.31it/s][A[A[A


  8%|▊         | 5/60 [00:02<00:31,  1.72it/s][A[A[A


 10%|█         | 6/60 [00:02<00:24,  2.21it/s][A[A[A


 12%|█▏        | 7/60 [00:02<00:19,  2.76it/s][A[A[A


 13%|█▎        | 8/60 [00:02<00:15,  3.34it/s][A[A[A


 15%|█▌        | 9/60 [00:03<00:13,  3.89it/s][A[A[A


 17%|█▋        | 10/60 [00:03<00:11,  4.45it/s][A[A[A


 18%|█▊        | 11/60 [00:03<00:09,  4.93it/s][A[A[A


 20%|██        | 12/60 [00:03<00:09,  5.32it/s][A[A[A


 22%|██▏       | 13/60 [00:03<00:08,  5.59it/s][A[A[A


 23%|██▎       | 14/60 [00:03<00:07,  5.84it/s][A[A[A


 25%|██▌       | 15/60 [00:04<00:07,  6.04it/s][A[A[A


 27%|██▋       | 16/60 [0

 25%|██▌       | 15/60 [00:04<00:07,  6.07it/s][A[A[A


 27%|██▋       | 16/60 [00:04<00:07,  6.15it/s][A[A[A


 28%|██▊       | 17/60 [00:04<00:06,  6.27it/s][A[A[A


 30%|███       | 18/60 [00:04<00:06,  6.35it/s][A[A[A


 32%|███▏      | 19/60 [00:04<00:06,  6.42it/s][A[A[A


 33%|███▎      | 20/60 [00:04<00:06,  6.45it/s][A[A[A


 35%|███▌      | 21/60 [00:04<00:06,  6.50it/s][A[A[A


 37%|███▋      | 22/60 [00:05<00:05,  6.50it/s][A[A[A


 38%|███▊      | 23/60 [00:05<00:05,  6.51it/s][A[A[A


 40%|████      | 24/60 [00:05<00:05,  6.51it/s][A[A[A


 42%|████▏     | 25/60 [00:05<00:05,  6.52it/s][A[A[A


 43%|████▎     | 26/60 [00:05<00:05,  6.53it/s][A[A[A


 45%|████▌     | 27/60 [00:05<00:05,  6.53it/s][A[A[A


 47%|████▋     | 28/60 [00:06<00:04,  6.52it/s][A[A[A


 48%|████▊     | 29/60 [00:06<00:04,  6.50it/s][A[A[A


 50%|█████     | 30/60 [00:06<00:04,  6.54it/s][A[A[A


 52%|█████▏    | 31/60 [00:06<00:04,  6.54it/s][A[A[A

 50%|█████     | 30/60 [00:06<00:04,  6.54it/s][A[A[A


 52%|█████▏    | 31/60 [00:06<00:04,  6.54it/s][A[A[A


 53%|█████▎    | 32/60 [00:06<00:04,  6.48it/s][A[A[A


 55%|█████▌    | 33/60 [00:06<00:04,  6.54it/s][A[A[A


 57%|█████▋    | 34/60 [00:06<00:03,  6.56it/s][A[A[A


 58%|█████▊    | 35/60 [00:07<00:03,  6.50it/s][A[A[A


 60%|██████    | 36/60 [00:07<00:03,  6.49it/s][A[A[A


 62%|██████▏   | 37/60 [00:07<00:03,  6.51it/s][A[A[A


 63%|██████▎   | 38/60 [00:07<00:03,  6.58it/s][A[A[A


 65%|██████▌   | 39/60 [00:07<00:03,  6.51it/s][A[A[A


 67%|██████▋   | 40/60 [00:07<00:03,  6.56it/s][A[A[A


 68%|██████▊   | 41/60 [00:07<00:02,  6.56it/s][A[A[A


 70%|███████   | 42/60 [00:08<00:02,  6.56it/s][A[A[A


 72%|███████▏  | 43/60 [00:08<00:02,  6.54it/s][A[A[A


 73%|███████▎  | 44/60 [00:08<00:02,  6.54it/s][A[A[A


 75%|███████▌  | 45/60 [00:08<00:02,  6.53it/s][A[A[A


 77%|███████▋  | 46/60 [00:08<00:02,  6.54it/s][A[A[A

 75%|███████▌  | 45/60 [00:08<00:02,  6.55it/s][A[A[A


 77%|███████▋  | 46/60 [00:08<00:02,  6.54it/s][A[A[A


 78%|███████▊  | 47/60 [00:08<00:01,  6.55it/s][A[A[A


 80%|████████  | 48/60 [00:09<00:01,  6.48it/s][A[A[A


 82%|████████▏ | 49/60 [00:09<00:01,  6.55it/s][A[A[A


 83%|████████▎ | 50/60 [00:09<00:01,  6.55it/s][A[A[A


 85%|████████▌ | 51/60 [00:09<00:01,  6.49it/s][A[A[A


 87%|████████▋ | 52/60 [00:09<00:01,  6.51it/s][A[A[A


 88%|████████▊ | 53/60 [00:09<00:01,  6.51it/s][A[A[A


 90%|█████████ | 54/60 [00:10<00:00,  6.50it/s][A[A[A


 92%|█████████▏| 55/60 [00:10<00:00,  6.56it/s][A[A[A


 93%|█████████▎| 56/60 [00:10<00:00,  6.58it/s][A[A[A


 95%|█████████▌| 57/60 [00:10<00:00,  6.54it/s][A[A[A


 97%|█████████▋| 58/60 [00:10<00:00,  6.54it/s][A[A[A


 98%|█████████▊| 59/60 [00:10<00:00,  6.49it/s][A[A[A


100%|██████████| 60/60 [00:11<00:00,  5.41it/s][A[A[A


 70%|███████   | 7/10 [02:18<00:59, 19.86s/it][A[A


 

100%|██████████| 60/60 [00:11<00:00,  5.43it/s][A[A[A


 90%|█████████ | 9/10 [02:57<00:19, 19.79s/it][A[A


  0%|          | 0/60 [00:00<?, ?it/s][A[A[A


  2%|▏         | 1/60 [00:01<01:52,  1.91s/it][A[A[A


  3%|▎         | 2/60 [00:02<01:19,  1.38s/it][A[A[A


  5%|▌         | 3/60 [00:02<00:57,  1.01s/it][A[A[A


  7%|▋         | 4/60 [00:02<00:42,  1.33it/s][A[A[A


  8%|▊         | 5/60 [00:02<00:31,  1.74it/s][A[A[A


 10%|█         | 6/60 [00:02<00:24,  2.24it/s][A[A[A


 12%|█▏        | 7/60 [00:02<00:19,  2.79it/s][A[A[A


 13%|█▎        | 8/60 [00:02<00:15,  3.35it/s][A[A[A


 15%|█▌        | 9/60 [00:03<00:13,  3.92it/s][A[A[A


 17%|█▋        | 10/60 [00:03<00:11,  4.49it/s][A[A[A


 18%|█▊        | 11/60 [00:03<00:09,  4.94it/s][A[A[A


 20%|██        | 12/60 [00:03<00:09,  5.31it/s][A[A[A


 22%|██▏       | 13/60 [00:03<00:08,  5.62it/s][A[A[A


 23%|██▎       | 14/60 [00:03<00:07,  5.87it/s][A[A[A


 25%|██▌       | 15

(476, 9)

In [52]:
# Side-by-side: the 21 question-model columns first, then the 9 answer-model columns.
res_roberta_pred = np.hstack((res_roberta_q_pred, res_roberta_a_pred))
res_roberta_pred.shape

(476, 30)

In [53]:
# Persist the RoBERTa ensemble's test predictions so the final blend can reload them
# without re-running inference.
pseudo_dir = '../mnt/inputs/pseudos/top2_e121_e125_e126_e127_e128_e129'
# open(..., 'wb') does not create missing parent directories.
os.makedirs(pseudo_dir, exist_ok=True)
with open(os.path.join(pseudo_dir, 'res_roberta_pred.pkl'), 'wb') as fout:
    pickle.dump(res_roberta_pred, fout)

## XLNet

In [42]:
# Test-set inputs for the XLNet ensemble: one dataset/loader pair for the
# question-side models (21 targets) and one for the answer-side models (9 targets).
Q_LABEL_COL = [
    'question_asker_intent_understanding',
    'question_body_critical',
    'question_conversational',
    'question_expect_short_answer',
    'question_fact_seeking',
    'question_has_commonly_accepted_answer',
    'question_interestingness_others',
    'question_interestingness_self',
    'question_multi_intent',
    'question_not_really_a_question',
    'question_opinion_seeking',
    'question_type_choice',
    'question_type_compare',
    'question_type_consequence',
    'question_type_definition',
    'question_type_entity',
    'question_type_instructions',
    'question_type_procedure',
    'question_type_reason_explanation',
    'question_type_spelling',
    'question_well_written',
]

A_LABEL_COL = [
    'answer_helpful',
    'answer_level_of_information',
    'answer_plausible',
    'answer_relevance',
    'answer_satisfaction',
    'answer_type_instructions',
    'answer_type_procedure',
    'answer_type_reason_explanation',
    'answer_well_written',
]


def _build_xlnet_test_loader(label_col, q_max_len, a_max_len):
    """Return ``(dataset, sampler, loader)`` over ``tst_df`` for the XLNet models.

    The tokenizer is loaded from the e128 run's saved ``datasets/`` directory.
    Q and A variants differ only in label columns and the ``q_max_len`` /
    ``a_max_len`` budgets (presumably question / answer token limits -- confirm
    against QUESTDataset).
    """
    category_tokens = [
        tag.casefold()
        for tag in ('CAT_TECHNOLOGY', 'CAT_STACKOVERFLOW', 'CAT_CULTURE', 'CAT_SCIENCE', 'CAT_LIFE_ARTS')
    ]
    dataset = QUESTDataset(
        df=tst_df,
        mode='test',
        tokens=category_tokens,
        augment=[],
        tokenizer_type='xlnet',
        pretrained_model_name_or_path='../mnt/checkpoints/e128/datasets/',
        do_lower_case=False,
        LABEL_COL=label_col,
        t_max_len=30,
        q_max_len=q_max_len,
        a_max_len=a_max_len,
        tqa_mode='tq_a',
        TBSEP='[TBSEP]',
        pos_id_type='arange',
        MAX_SEQUENCE_LENGTH=512,
    )
    sampler = SequentialSampler(data_source=dataset)
    loader = DataLoader(
        dataset,
        batch_size=8,
        sampler=sampler,
        num_workers=os.cpu_count(),
        worker_init_fn=lambda _worker_id: np.random.seed(),
        drop_last=False,
        pin_memory=True,
    )
    return dataset, sampler, loader


# Q models get the whole 239 * 2 budget on q_max_len; A models get it on a_max_len.
q_test_dataset, q_test_sampler, q_test_loader = _build_xlnet_test_loader(
    Q_LABEL_COL, q_max_len=239 * 2, a_max_len=239 * 0)
a_test_dataset, a_test_sampler, a_test_loader = _build_xlnet_test_loader(
    A_LABEL_COL, q_max_len=239 * 0, a_max_len=239 * 2)

additional_tokens : 0
additional_tokens : 0


In [45]:
# Q models: average the test-set predictions of every XLNet question-model
# checkpoint from run e128 (21 question targets each).
ckpts = sorted(glob('../mnt/checkpoints/e128/state_dicts/*'))  # sorted -> deterministic averaging order
assert ckpts, 'no checkpoints found in ../mnt/checkpoints/e128/state_dicts/'
xlnet_q_preds = []
for ckpt in tqdm(ckpts):
    # NOTE: pickle.load can execute arbitrary code; only load checkpoints produced by this project.
    with open(ckpt, 'rb') as fin:
        state_dict = pickle.load(fin)
    model = XLNetModelForBinaryMultiLabelClassifier(21, '../mnt/datasets/model_configs/xlnet-model-base-cased-config.pkl', None, token_size=32006)
    model.load_state_dict(state_dict)
    # load_state_dict copies the weights into the model, so drop the raw copy now
    # instead of keeping two sets of weights alive during inference.
    del state_dict
    model.to(DEVICE)
    _, _, _, y_preds, _, qa_ids = test(model, None, q_test_loader, DEVICE, 'test')
    xlnet_q_preds.append(y_preds)
    # Free the model and hand its cached GPU blocks back before the next checkpoint.
    del model
    gc.collect()
    torch.cuda.empty_cache()
res_xlnet_q_pred = np.mean(xlnet_q_preds, axis=0)
res_xlnet_q_pred.shape



  0%|          | 0/10 [00:00<?, ?it/s][A[A


  0%|          | 0/60 [00:00<?, ?it/s][A[A[A


  2%|▏         | 1/60 [00:03<03:06,  3.16s/it][A[A[A


  3%|▎         | 2/60 [00:03<02:14,  2.32s/it][A[A[A


  5%|▌         | 3/60 [00:03<01:39,  1.74s/it][A[A[A


  7%|▋         | 4/60 [00:04<01:14,  1.33s/it][A[A[A


  8%|▊         | 5/60 [00:04<00:57,  1.04s/it][A[A[A


 10%|█         | 6/60 [00:05<00:45,  1.19it/s][A[A[A


 12%|█▏        | 7/60 [00:05<00:37,  1.42it/s][A[A[A


 13%|█▎        | 8/60 [00:05<00:31,  1.65it/s][A[A[A


 15%|█▌        | 9/60 [00:06<00:27,  1.86it/s][A[A[A


 17%|█▋        | 10/60 [00:06<00:24,  2.05it/s][A[A[A


 18%|█▊        | 11/60 [00:06<00:22,  2.20it/s][A[A[A


 20%|██        | 12/60 [00:07<00:20,  2.32it/s][A[A[A


 22%|██▏       | 13/60 [00:07<00:19,  2.41it/s][A[A[A


 23%|██▎       | 14/60 [00:08<00:18,  2.48it/s][A[A[A


 25%|██▌       | 15/60 [00:08<00:17,  2.53it/s][A[A[A


 27%|██▋       | 16/60 [0

 25%|██▌       | 15/60 [00:07<00:17,  2.58it/s][A[A[A


 27%|██▋       | 16/60 [00:07<00:16,  2.61it/s][A[A[A


 28%|██▊       | 17/60 [00:08<00:16,  2.62it/s][A[A[A


 30%|███       | 18/60 [00:08<00:15,  2.63it/s][A[A[A


 32%|███▏      | 19/60 [00:08<00:15,  2.64it/s][A[A[A


 33%|███▎      | 20/60 [00:09<00:15,  2.65it/s][A[A[A


 35%|███▌      | 21/60 [00:09<00:14,  2.66it/s][A[A[A


 37%|███▋      | 22/60 [00:09<00:14,  2.66it/s][A[A[A


 38%|███▊      | 23/60 [00:10<00:13,  2.66it/s][A[A[A


 40%|████      | 24/60 [00:10<00:13,  2.66it/s][A[A[A


 42%|████▏     | 25/60 [00:11<00:13,  2.66it/s][A[A[A


 43%|████▎     | 26/60 [00:11<00:12,  2.66it/s][A[A[A


 45%|████▌     | 27/60 [00:11<00:12,  2.66it/s][A[A[A


 47%|████▋     | 28/60 [00:12<00:12,  2.66it/s][A[A[A


 48%|████▊     | 29/60 [00:12<00:11,  2.66it/s][A[A[A


 50%|█████     | 30/60 [00:12<00:11,  2.66it/s][A[A[A


 52%|█████▏    | 31/60 [00:13<00:10,  2.66it/s][A[A[A

 50%|█████     | 30/60 [00:12<00:11,  2.66it/s][A[A[A


 52%|█████▏    | 31/60 [00:13<00:10,  2.66it/s][A[A[A


 53%|█████▎    | 32/60 [00:13<00:10,  2.66it/s][A[A[A


 55%|█████▌    | 33/60 [00:14<00:10,  2.66it/s][A[A[A


 57%|█████▋    | 34/60 [00:14<00:09,  2.66it/s][A[A[A


 58%|█████▊    | 35/60 [00:14<00:09,  2.66it/s][A[A[A


 60%|██████    | 36/60 [00:15<00:09,  2.66it/s][A[A[A


 62%|██████▏   | 37/60 [00:15<00:08,  2.66it/s][A[A[A


 63%|██████▎   | 38/60 [00:15<00:08,  2.66it/s][A[A[A


 65%|██████▌   | 39/60 [00:16<00:07,  2.66it/s][A[A[A


 67%|██████▋   | 40/60 [00:16<00:07,  2.66it/s][A[A[A


 68%|██████▊   | 41/60 [00:17<00:07,  2.66it/s][A[A[A


 70%|███████   | 42/60 [00:17<00:06,  2.66it/s][A[A[A


 72%|███████▏  | 43/60 [00:17<00:06,  2.67it/s][A[A[A


 73%|███████▎  | 44/60 [00:18<00:06,  2.66it/s][A[A[A


 75%|███████▌  | 45/60 [00:18<00:05,  2.66it/s][A[A[A


 77%|███████▋  | 46/60 [00:18<00:05,  2.66it/s][A[A[A

 75%|███████▌  | 45/60 [00:18<00:05,  2.66it/s][A[A[A


 77%|███████▋  | 46/60 [00:18<00:05,  2.67it/s][A[A[A


 78%|███████▊  | 47/60 [00:19<00:04,  2.67it/s][A[A[A


 80%|████████  | 48/60 [00:19<00:04,  2.67it/s][A[A[A


 82%|████████▏ | 49/60 [00:20<00:04,  2.67it/s][A[A[A


 83%|████████▎ | 50/60 [00:20<00:03,  2.66it/s][A[A[A


 85%|████████▌ | 51/60 [00:20<00:03,  2.66it/s][A[A[A


 87%|████████▋ | 52/60 [00:21<00:03,  2.67it/s][A[A[A


 88%|████████▊ | 53/60 [00:21<00:02,  2.66it/s][A[A[A


 90%|█████████ | 54/60 [00:21<00:02,  2.67it/s][A[A[A


 92%|█████████▏| 55/60 [00:22<00:01,  2.67it/s][A[A[A


 93%|█████████▎| 56/60 [00:22<00:01,  2.67it/s][A[A[A


 95%|█████████▌| 57/60 [00:23<00:01,  2.66it/s][A[A[A


 97%|█████████▋| 58/60 [00:23<00:00,  2.67it/s][A[A[A


 98%|█████████▊| 59/60 [00:23<00:00,  2.66it/s][A[A[A


100%|██████████| 60/60 [00:24<00:00,  2.48it/s][A[A[A


 70%|███████   | 7/10 [03:29<01:29, 29.96s/it][A[A


 

100%|██████████| 60/60 [00:24<00:00,  2.47it/s][A[A[A


 90%|█████████ | 9/10 [04:29<00:30, 30.14s/it][A[A


  0%|          | 0/60 [00:00<?, ?it/s][A[A[A


  2%|▏         | 1/60 [00:02<02:03,  2.10s/it][A[A[A


  3%|▎         | 2/60 [00:02<01:31,  1.58s/it][A[A[A


  5%|▌         | 3/60 [00:02<01:09,  1.22s/it][A[A[A


  7%|▋         | 4/60 [00:03<00:54,  1.04it/s][A[A[A


  8%|▊         | 5/60 [00:03<00:43,  1.27it/s][A[A[A


 10%|█         | 6/60 [00:03<00:35,  1.50it/s][A[A[A


 12%|█▏        | 7/60 [00:04<00:30,  1.73it/s][A[A[A


 13%|█▎        | 8/60 [00:04<00:26,  1.93it/s][A[A[A


 15%|█▌        | 9/60 [00:05<00:24,  2.11it/s][A[A[A


 17%|█▋        | 10/60 [00:05<00:22,  2.25it/s][A[A[A


 18%|█▊        | 11/60 [00:05<00:20,  2.36it/s][A[A[A


 20%|██        | 12/60 [00:06<00:19,  2.44it/s][A[A[A


 22%|██▏       | 13/60 [00:06<00:18,  2.51it/s][A[A[A


 23%|██▎       | 14/60 [00:06<00:18,  2.55it/s][A[A[A


 25%|██▌       | 15

(476, 21)

In [46]:
# A models: average the test-set predictions of every XLNet answer-model
# checkpoint from run e129 (9 answer targets each).
ckpts = sorted(glob('../mnt/checkpoints/e129/state_dicts/*'))  # sorted -> deterministic averaging order
assert ckpts, 'no checkpoints found in ../mnt/checkpoints/e129/state_dicts/'
xlnet_a_preds = []
for ckpt in tqdm(ckpts):
    # NOTE: pickle.load can execute arbitrary code; only load checkpoints produced by this project.
    with open(ckpt, 'rb') as fin:
        state_dict = pickle.load(fin)
    model = XLNetModelForBinaryMultiLabelClassifier(9, '../mnt/datasets/model_configs/xlnet-model-base-cased-config.pkl', None, token_size=32006)
    model.load_state_dict(state_dict)
    # load_state_dict copies the weights into the model, so drop the raw copy now
    # instead of keeping two sets of weights alive during inference.
    del state_dict
    model.to(DEVICE)
    _, _, _, y_preds, _, qa_ids = test(model, None, a_test_loader, DEVICE, 'test')
    xlnet_a_preds.append(y_preds)
    # Free the model and hand its cached GPU blocks back before the next checkpoint.
    del model
    gc.collect()
    torch.cuda.empty_cache()
res_xlnet_a_pred = np.mean(xlnet_a_preds, axis=0)
res_xlnet_a_pred.shape



  0%|          | 0/10 [00:00<?, ?it/s][A[A


  0%|          | 0/60 [00:00<?, ?it/s][A[A[A


  2%|▏         | 1/60 [00:02<02:03,  2.09s/it][A[A[A


  3%|▎         | 2/60 [00:02<01:31,  1.58s/it][A[A[A


  5%|▌         | 3/60 [00:02<01:09,  1.22s/it][A[A[A


  7%|▋         | 4/60 [00:03<00:54,  1.04it/s][A[A[A


  8%|▊         | 5/60 [00:03<00:43,  1.27it/s][A[A[A


 10%|█         | 6/60 [00:03<00:35,  1.51it/s][A[A[A


 12%|█▏        | 7/60 [00:04<00:30,  1.73it/s][A[A[A


 13%|█▎        | 8/60 [00:04<00:26,  1.93it/s][A[A[A


 15%|█▌        | 9/60 [00:05<00:24,  2.11it/s][A[A[A


 17%|█▋        | 10/60 [00:05<00:22,  2.25it/s][A[A[A


 18%|█▊        | 11/60 [00:05<00:20,  2.36it/s][A[A[A


 20%|██        | 12/60 [00:06<00:19,  2.44it/s][A[A[A


 22%|██▏       | 13/60 [00:06<00:18,  2.50it/s][A[A[A


 23%|██▎       | 14/60 [00:06<00:18,  2.55it/s][A[A[A


 25%|██▌       | 15/60 [00:07<00:17,  2.58it/s][A[A[A


 27%|██▋       | 16/60 [0

 25%|██▌       | 15/60 [00:07<00:17,  2.58it/s][A[A[A


 27%|██▋       | 16/60 [00:07<00:16,  2.60it/s][A[A[A


 28%|██▊       | 17/60 [00:08<00:16,  2.62it/s][A[A[A


 30%|███       | 18/60 [00:08<00:15,  2.63it/s][A[A[A


 32%|███▏      | 19/60 [00:08<00:15,  2.64it/s][A[A[A


 33%|███▎      | 20/60 [00:09<00:15,  2.65it/s][A[A[A


 35%|███▌      | 21/60 [00:09<00:14,  2.65it/s][A[A[A


 37%|███▋      | 22/60 [00:09<00:14,  2.65it/s][A[A[A


 38%|███▊      | 23/60 [00:10<00:13,  2.66it/s][A[A[A


 40%|████      | 24/60 [00:10<00:13,  2.66it/s][A[A[A


 42%|████▏     | 25/60 [00:11<00:13,  2.66it/s][A[A[A


 43%|████▎     | 26/60 [00:11<00:12,  2.66it/s][A[A[A


 45%|████▌     | 27/60 [00:11<00:12,  2.66it/s][A[A[A


 47%|████▋     | 28/60 [00:12<00:12,  2.66it/s][A[A[A


 48%|████▊     | 29/60 [00:12<00:11,  2.66it/s][A[A[A


 50%|█████     | 30/60 [00:12<00:11,  2.65it/s][A[A[A


 52%|█████▏    | 31/60 [00:13<00:10,  2.65it/s][A[A[A

 50%|█████     | 30/60 [00:12<00:11,  2.67it/s][A[A[A


 52%|█████▏    | 31/60 [00:13<00:10,  2.67it/s][A[A[A


 53%|█████▎    | 32/60 [00:13<00:10,  2.67it/s][A[A[A


 55%|█████▌    | 33/60 [00:14<00:10,  2.67it/s][A[A[A


 57%|█████▋    | 34/60 [00:14<00:09,  2.67it/s][A[A[A


 58%|█████▊    | 35/60 [00:14<00:09,  2.67it/s][A[A[A


 60%|██████    | 36/60 [00:15<00:08,  2.67it/s][A[A[A


 62%|██████▏   | 37/60 [00:15<00:08,  2.67it/s][A[A[A


 63%|██████▎   | 38/60 [00:15<00:08,  2.67it/s][A[A[A


 65%|██████▌   | 39/60 [00:16<00:07,  2.67it/s][A[A[A


 67%|██████▋   | 40/60 [00:16<00:07,  2.67it/s][A[A[A


 68%|██████▊   | 41/60 [00:17<00:07,  2.67it/s][A[A[A


 70%|███████   | 42/60 [00:17<00:06,  2.67it/s][A[A[A


 72%|███████▏  | 43/60 [00:17<00:06,  2.67it/s][A[A[A


 73%|███████▎  | 44/60 [00:18<00:05,  2.67it/s][A[A[A


 75%|███████▌  | 45/60 [00:18<00:05,  2.67it/s][A[A[A


 77%|███████▋  | 46/60 [00:18<00:05,  2.67it/s][A[A[A

 75%|███████▌  | 45/60 [00:18<00:05,  2.67it/s][A[A[A


 77%|███████▋  | 46/60 [00:18<00:05,  2.67it/s][A[A[A


 78%|███████▊  | 47/60 [00:19<00:04,  2.67it/s][A[A[A


 80%|████████  | 48/60 [00:19<00:04,  2.67it/s][A[A[A


 82%|████████▏ | 49/60 [00:20<00:04,  2.67it/s][A[A[A


 83%|████████▎ | 50/60 [00:20<00:03,  2.67it/s][A[A[A


 85%|████████▌ | 51/60 [00:20<00:03,  2.67it/s][A[A[A


 87%|████████▋ | 52/60 [00:21<00:03,  2.66it/s][A[A[A


 88%|████████▊ | 53/60 [00:21<00:02,  2.66it/s][A[A[A


 90%|█████████ | 54/60 [00:21<00:02,  2.67it/s][A[A[A


 92%|█████████▏| 55/60 [00:22<00:01,  2.67it/s][A[A[A


 93%|█████████▎| 56/60 [00:22<00:01,  2.67it/s][A[A[A


 95%|█████████▌| 57/60 [00:23<00:01,  2.67it/s][A[A[A


 97%|█████████▋| 58/60 [00:23<00:00,  2.67it/s][A[A[A


 98%|█████████▊| 59/60 [00:23<00:00,  2.67it/s][A[A[A


100%|██████████| 60/60 [00:24<00:00,  2.48it/s][A[A[A


 70%|███████   | 7/10 [03:31<01:30, 30.21s/it][A[A


 

100%|██████████| 60/60 [00:24<00:00,  2.48it/s][A[A[A


 90%|█████████ | 9/10 [04:31<00:30, 30.19s/it][A[A


  0%|          | 0/60 [00:00<?, ?it/s][A[A[A


  2%|▏         | 1/60 [00:02<02:03,  2.09s/it][A[A[A


  3%|▎         | 2/60 [00:02<01:31,  1.58s/it][A[A[A


  5%|▌         | 3/60 [00:02<01:09,  1.22s/it][A[A[A


  7%|▋         | 4/60 [00:03<00:54,  1.04it/s][A[A[A


  8%|▊         | 5/60 [00:03<00:43,  1.27it/s][A[A[A


 10%|█         | 6/60 [00:03<00:35,  1.51it/s][A[A[A


 12%|█▏        | 7/60 [00:04<00:30,  1.73it/s][A[A[A


 13%|█▎        | 8/60 [00:04<00:26,  1.94it/s][A[A[A


 15%|█▌        | 9/60 [00:05<00:24,  2.11it/s][A[A[A


 17%|█▋        | 10/60 [00:05<00:22,  2.25it/s][A[A[A


 18%|█▊        | 11/60 [00:05<00:20,  2.36it/s][A[A[A


 20%|██        | 12/60 [00:06<00:19,  2.45it/s][A[A[A


 22%|██▏       | 13/60 [00:06<00:18,  2.51it/s][A[A[A


 23%|██▎       | 14/60 [00:06<00:18,  2.55it/s][A[A[A


 25%|██▌       | 15

(476, 9)

In [47]:
# Side-by-side: the 21 question-model columns first, then the 9 answer-model columns.
res_xlnet_pred = np.hstack((res_xlnet_q_pred, res_xlnet_a_pred))
res_xlnet_pred.shape

(476, 30)

In [48]:
# Persist the XLNet ensemble's test predictions so the final blend can reload them
# without re-running inference.
pseudo_dir = '../mnt/inputs/pseudos/top2_e121_e125_e126_e127_e128_e129'
# open(..., 'wb') does not create missing parent directories.
os.makedirs(pseudo_dir, exist_ok=True)
with open(os.path.join(pseudo_dir, 'res_xlnet_pred.pkl'), 'wb') as fout:
    pickle.dump(res_xlnet_pred, fout)

## Finally, blend everything (最後に全部をブレンド)

In [54]:
# All 30 target columns, ordered like the columns of the per-family prediction
# matrices: the 21 question targets followed by the 9 answer targets.
_question_targets = [
    'question_asker_intent_understanding',
    'question_body_critical',
    'question_conversational',
    'question_expect_short_answer',
    'question_fact_seeking',
    'question_has_commonly_accepted_answer',
    'question_interestingness_others',
    'question_interestingness_self',
    'question_multi_intent',
    'question_not_really_a_question',
    'question_opinion_seeking',
    'question_type_choice',
    'question_type_compare',
    'question_type_consequence',
    'question_type_definition',
    'question_type_entity',
    'question_type_instructions',
    'question_type_procedure',
    'question_type_reason_explanation',
    'question_type_spelling',
    'question_well_written',
]
_answer_targets = [
    'answer_helpful',
    'answer_level_of_information',
    'answer_plausible',
    'answer_relevance',
    'answer_satisfaction',
    'answer_type_instructions',
    'answer_type_procedure',
    'answer_type_reason_explanation',
    'answer_well_written',
]
LABEL_COL = _question_targets + _answer_targets

In [55]:
# Reload the three per-family prediction matrices saved earlier, so this blending
# section can be run on its own after a kernel restart.
def _load_pseudo_pickle(file_name):
    """Unpickle ``file_name`` from the pseudo-label directory (trusted, locally produced files only)."""
    path = os.path.join('../mnt/inputs/pseudos/top2_e121_e125_e126_e127_e128_e129', file_name)
    with open(path, 'rb') as handle:
        return pickle.load(handle)


res_bert_pred = _load_pseudo_pickle('res_bert_pred.pkl')
res_roberta_pred = _load_pseudo_pickle('res_roberta_pred.pkl')
res_xlnet_pred = _load_pseudo_pickle('res_xlnet_pred.pkl')

In [56]:
# Per-target rounders, indexed below as optRs[0] .. optRs[29].
# NOTE(review): unpickling presumably needs their class (e.g. OptimizedRounder,
# defined near the top of the notebook) to be in scope -- confirm.
optrs_path = '../mnt/inputs/pseudos/top2_e121_e125_e126_e127_e128_e129/optRs.pkl'
with open(optrs_path, mode='rb') as optrs_file:
    optRs = pickle.load(optrs_file)

In [57]:
# Equal-weight blend of the BERT, RoBERTa and XLNet ensembles.
# np.stack raises a clear error if the families disagree in shape, whereas
# np.mean on a ragged Python list would first try to build an object array.
family_preds = [res_bert_pred, res_roberta_pred, res_xlnet_pred]
res_pred = np.stack(family_preds, axis=0).mean(axis=0)
res_pred.shape

(476, 30)

In [58]:
# Rebind tst_df (the frame the dataset cells above consumed) to a fresh copy of
# the raw competition test CSV, which receives the blended predictions below.
raw_test_csv = '../mnt/inputs/origin/test.csv'
tst_df = pd.read_csv(raw_test_csv)
tst_df.head()

Unnamed: 0,qa_id,question_title,question_body,question_user_name,question_user_page,answer,answer_user_name,answer_user_page,url,category,host
0,39,Will leaving corpses lying around upset my pri...,I see questions/information online about how t...,Dylan,https://gaming.stackexchange.com/users/64471,There is no consequence for leaving corpses an...,Nelson868,https://gaming.stackexchange.com/users/97324,http://gaming.stackexchange.com/questions/1979...,CULTURE,gaming.stackexchange.com
1,46,Url link to feature image in the portfolio,I am new to Wordpress. i have issue with Featu...,Anu,https://wordpress.stackexchange.com/users/72927,I think it is possible with custom fields.\n\n...,Irina,https://wordpress.stackexchange.com/users/27233,http://wordpress.stackexchange.com/questions/1...,TECHNOLOGY,wordpress.stackexchange.com
2,70,"Is accuracy, recoil or bullet spread affected ...","To experiment I started a bot game, toggled in...",Konsta,https://gaming.stackexchange.com/users/37545,You do not have armour in the screenshots. Thi...,Damon Smithies,https://gaming.stackexchange.com/users/70641,http://gaming.stackexchange.com/questions/2154...,CULTURE,gaming.stackexchange.com
3,132,Suddenly got an I/O error from my external HDD,I have used my Raspberry Pi as a torrent-serve...,robbannn,https://raspberrypi.stackexchange.com/users/17341,Your Western Digital hard drive is disappearin...,HeatfanJohn,https://raspberrypi.stackexchange.com/users/1311,http://raspberrypi.stackexchange.com/questions...,TECHNOLOGY,raspberrypi.stackexchange.com
4,200,Passenger Name - Flight Booking Passenger only...,I have bought Delhi-London return flights for ...,Amit,https://travel.stackexchange.com/users/29089,I called two persons who work for Saudia (tick...,Nean Der Thal,https://travel.stackexchange.com/users/10051,http://travel.stackexchange.com/questions/4704...,CULTURE,travel.stackexchange.com


In [59]:
# Write the raw (un-rounded) ensemble predictions into the test frame:
# column i of res_pred becomes tst_df[LABEL_COL[i]].
# enumerate() would silently skip any surplus prediction columns, so make sure
# the label list and the prediction matrix line up before assigning.
assert len(LABEL_COL) == res_pred.shape[1], (len(LABEL_COL), res_pred.shape)
for i, col in enumerate(LABEL_COL):
    tst_df[col] = res_pred[:, i]

In [60]:
# Persist the test frame carrying the un-rounded ensemble predictions.
raw_pseudo_csv = '../mnt/inputs/pseudos/top2_e121_e125_e126_e127_e128_e129/raw_pseudo_tst_df.csv'
tst_df.to_csv(path_or_buf=raw_pseudo_csv, index=False)

## 全 opt — optimized rounding applied to all targets

In [61]:
# Apply each target's fitted rounder to its column of the averaged predictions.
# Every target is rounded here (contrast with the "half opt" variant).
# The number of targets is taken from res_pred rather than hard-coded.
final_prediction = []
for i in tqdm(range(res_pred.shape[1])):
    y_pred = res_pred[:, i]
    optR = optRs[i]
    # Map raw scores through the rounder's fitted thresholds; cast to float so
    # the stacked result is a numeric array (predict() presumably returns
    # categorical levels -- see OptimizedRounder).
    res = optR.predict(y_pred, optR.coefficients()).astype(float)
    final_prediction.append(res)

# (n_targets, n_rows) -> (n_rows, n_targets), matching res_pred's layout.
prediction = np.asarray(final_prediction).T
prediction.shape



100%|██████████| 30/30 [00:00<00:00, 849.79it/s]


(476, 30)

In [62]:
# Re-read the raw test set so the columns written from the previous
# prediction set are discarded before attaching the rounded ones.
tst_df = pd.read_csv(filepath_or_buffer='../mnt/inputs/origin/test.csv')
tst_df.head(n=5)

Unnamed: 0,qa_id,question_title,question_body,question_user_name,question_user_page,answer,answer_user_name,answer_user_page,url,category,host
0,39,Will leaving corpses lying around upset my pri...,I see questions/information online about how t...,Dylan,https://gaming.stackexchange.com/users/64471,There is no consequence for leaving corpses an...,Nelson868,https://gaming.stackexchange.com/users/97324,http://gaming.stackexchange.com/questions/1979...,CULTURE,gaming.stackexchange.com
1,46,Url link to feature image in the portfolio,I am new to Wordpress. i have issue with Featu...,Anu,https://wordpress.stackexchange.com/users/72927,I think it is possible with custom fields.\n\n...,Irina,https://wordpress.stackexchange.com/users/27233,http://wordpress.stackexchange.com/questions/1...,TECHNOLOGY,wordpress.stackexchange.com
2,70,"Is accuracy, recoil or bullet spread affected ...","To experiment I started a bot game, toggled in...",Konsta,https://gaming.stackexchange.com/users/37545,You do not have armour in the screenshots. Thi...,Damon Smithies,https://gaming.stackexchange.com/users/70641,http://gaming.stackexchange.com/questions/2154...,CULTURE,gaming.stackexchange.com
3,132,Suddenly got an I/O error from my external HDD,I have used my Raspberry Pi as a torrent-serve...,robbannn,https://raspberrypi.stackexchange.com/users/17341,Your Western Digital hard drive is disappearin...,HeatfanJohn,https://raspberrypi.stackexchange.com/users/1311,http://raspberrypi.stackexchange.com/questions...,TECHNOLOGY,raspberrypi.stackexchange.com
4,200,Passenger Name - Flight Booking Passenger only...,I have bought Delhi-London return flights for ...,Amit,https://travel.stackexchange.com/users/29089,I called two persons who work for Saudia (tick...,Nean Der Thal,https://travel.stackexchange.com/users/10051,http://travel.stackexchange.com/questions/4704...,CULTURE,travel.stackexchange.com


In [63]:
# Write the fully-rounded predictions into the test frame:
# column i of prediction becomes tst_df[LABEL_COL[i]].
# enumerate() would silently skip any surplus prediction columns, so make sure
# the label list and the prediction matrix line up before assigning.
assert len(LABEL_COL) == prediction.shape[1], (len(LABEL_COL), prediction.shape)
for i, col in enumerate(LABEL_COL):
    tst_df[col] = prediction[:, i]

In [64]:
# Persist the test frame carrying the predictions rounded for every target.
opt_pseudo_csv = '../mnt/inputs/pseudos/top2_e121_e125_e126_e127_e128_e129/opt_pseudo_tst_df.csv'
tst_df.to_csv(path_or_buf=opt_pseudo_csv, index=False)

## half opt — optimized rounding applied only to a selected subset of targets

In [65]:
# Target positions whose predictions get rounded; every other target keeps its
# raw averaged prediction.
# TODO: record how this subset was chosen (presumably per-target validation).
HALF_OPT_TARGET_IDXS = {1, 2, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 19, 22, 23, 25}

final_prediction = []
for i in tqdm(range(res_pred.shape[1])):
    y_pred = res_pred[:, i]
    if i not in HALF_OPT_TARGET_IDXS:
        # Not selected for rounding: pass the raw ensemble score through.
        final_prediction.append(y_pred)
        continue

    optR = optRs[i]
    # Map raw scores through the rounder's fitted thresholds; cast to float so
    # rounded and raw columns stack into one numeric array.
    res = optR.predict(y_pred, optR.coefficients()).astype(float)
    final_prediction.append(res)

# (n_targets, n_rows) -> (n_rows, n_targets), matching res_pred's layout.
prediction = np.asarray(final_prediction).T
prediction.shape



100%|██████████| 30/30 [00:00<00:00, 1364.08it/s]


(476, 30)

In [66]:
# Re-read the raw test set so the columns written from the previous
# prediction set are discarded before attaching the half-rounded ones.
tst_df = pd.read_csv(filepath_or_buffer='../mnt/inputs/origin/test.csv')
tst_df.head(n=5)

Unnamed: 0,qa_id,question_title,question_body,question_user_name,question_user_page,answer,answer_user_name,answer_user_page,url,category,host
0,39,Will leaving corpses lying around upset my pri...,I see questions/information online about how t...,Dylan,https://gaming.stackexchange.com/users/64471,There is no consequence for leaving corpses an...,Nelson868,https://gaming.stackexchange.com/users/97324,http://gaming.stackexchange.com/questions/1979...,CULTURE,gaming.stackexchange.com
1,46,Url link to feature image in the portfolio,I am new to Wordpress. i have issue with Featu...,Anu,https://wordpress.stackexchange.com/users/72927,I think it is possible with custom fields.\n\n...,Irina,https://wordpress.stackexchange.com/users/27233,http://wordpress.stackexchange.com/questions/1...,TECHNOLOGY,wordpress.stackexchange.com
2,70,"Is accuracy, recoil or bullet spread affected ...","To experiment I started a bot game, toggled in...",Konsta,https://gaming.stackexchange.com/users/37545,You do not have armour in the screenshots. Thi...,Damon Smithies,https://gaming.stackexchange.com/users/70641,http://gaming.stackexchange.com/questions/2154...,CULTURE,gaming.stackexchange.com
3,132,Suddenly got an I/O error from my external HDD,I have used my Raspberry Pi as a torrent-serve...,robbannn,https://raspberrypi.stackexchange.com/users/17341,Your Western Digital hard drive is disappearin...,HeatfanJohn,https://raspberrypi.stackexchange.com/users/1311,http://raspberrypi.stackexchange.com/questions...,TECHNOLOGY,raspberrypi.stackexchange.com
4,200,Passenger Name - Flight Booking Passenger only...,I have bought Delhi-London return flights for ...,Amit,https://travel.stackexchange.com/users/29089,I called two persons who work for Saudia (tick...,Nean Der Thal,https://travel.stackexchange.com/users/10051,http://travel.stackexchange.com/questions/4704...,CULTURE,travel.stackexchange.com


In [67]:
# Write the half-rounded predictions into the test frame:
# column i of prediction becomes tst_df[LABEL_COL[i]].
# enumerate() would silently skip any surplus prediction columns, so make sure
# the label list and the prediction matrix line up before assigning.
assert len(LABEL_COL) == prediction.shape[1], (len(LABEL_COL), prediction.shape)
for i, col in enumerate(LABEL_COL):
    tst_df[col] = prediction[:, i]

In [68]:
# Persist the test frame carrying the subset-rounded ("half opt") predictions.
half_opt_pseudo_csv = '../mnt/inputs/pseudos/top2_e121_e125_e126_e127_e128_e129/half_opt_pseudo_tst_df.csv'
tst_df.to_csv(path_or_buf=half_opt_pseudo_csv, index=False)