In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))

## About Data Set, Model and Cross-Validation Setup

To compete in this task, I focused on training distilled transformers, such as DistilBERT and DistilRoBERTa, for fast iteration, and used bert-base-uncased to validate my results. This notebook only analyzes the predictions the model produced on the validation and test sets, and the post-processing applied to them.

The input sequences available to the models are "question_title", "question_body", and "answer". The maximum length can be configured separately for the question and the answer; the model here uses 384 for the question and 512 for the answer because of the length difference spotted during data analysis.

The model architecture uses a shared transformer embedding to ingest "question_title" + "question_body" for the question and "question_title" + "answer" for the answer. A customized classification head is added on top of that.

The training strategy consists of several parts:

1) freeze the embedding weights and tune the classification head first,

2) unfreeze the transformer weights, using warm-up scheduling to gradually increase the learning rate,

3) use a customized early-stopping callback that stops training when performance on the validation set stops improving, and

4) try out some common augmentation tricks, such as truncating the corpus, dropping out words, or label softening.
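Steps 2 and 3 can be illustrated with a minimal sketch. The linear warm-up shape, the `patience` rule, and the names `warmup_lr` and `EarlyStopper` are assumptions made for illustration; the actual scheduler and callback used in training are not shown in this notebook.

```python
def warmup_lr(step: int, base_lr: float, warmup_steps: int) -> float:
    """Linearly ramp the learning rate up to base_lr over warmup_steps, then hold it."""
    if warmup_steps <= 0:
        return base_lr
    return base_lr * min(1.0, (step + 1) / warmup_steps)


class EarlyStopper:
    """Signal a stop once the validation score has not improved for `patience` epochs."""

    def __init__(self, patience: int = 2, min_delta: float = 0.0):
        self.patience = patience      # hypothetical default, not taken from the notebook
        self.min_delta = min_delta    # minimum gain that counts as an improvement
        self.best = float("-inf")
        self.bad_epochs = 0

    def should_stop(self, score: float) -> bool:
        if score > self.best + self.min_delta:
            # New best score: remember it and reset the counter.
            self.best = score
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation scores, one per epoch.
stopper = EarlyStopper(patience=2)
decisions = [stopper.should_stop(s) for s in [0.30, 0.35, 0.34, 0.35]]
```

With these made-up scores the stopper fires on the fourth epoch, because neither of the last two epochs beats the best score of 0.35.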


Regarding cross-validation: because duplicated questions are so common in the data, a good strategy is to split it with `GroupKFold`. To split the data further, `category` is also used when generating the groups that divide the data into training and validation sets. The first fold of the 5-fold cross-validation is used to train the model and validate its performance.
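The split described above might look roughly like the sketch below. The toy frame and the exact group key (`category` joined with `question_title`) are assumptions; the notebook does not show how the groups were actually built.

```python
import pandas as pd
from sklearn.model_selection import GroupKFold

# Toy data: "q1" and "q5" are duplicated questions with several answers each.
toy = pd.DataFrame({
    "question_title": ["q1", "q1", "q2", "q3", "q4", "q5", "q5", "q6", "q7", "q8"],
    "category": ["SCIENCE", "SCIENCE", "CULTURE", "CULTURE", "TECHNOLOGY",
                 "TECHNOLOGY", "TECHNOLOGY", "LIFE_ARTS", "STACKOVERFLOW", "SCIENCE"],
})

# Assumed group key: rows sharing category and question title form one group.
groups = toy["category"] + "_" + toy["question_title"]

splitter = GroupKFold(n_splits=5)
# Take only the first of the five folds, as the notebook does.
train_idx, valid_idx = next(splitter.split(toy, groups=groups))

train_groups = set(groups.iloc[train_idx])
valid_groups = set(groups.iloc[valid_idx])
```

`GroupKFold` never places one group on both sides of a split, so a duplicated question cannot leak from training into validation.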

In [2]:
# import essential modules
import os
import sys

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns

import pandas as pd
import numpy as np

In [3]:
pd.set_option('display.max_rows', 1000)
pd.set_option('display.max_columns', 1000)
pd.set_option('display.width', 1000)

In [4]:
result_dir: str = "../input/distilroberta-base_q384_a512"
result_stats_filename: str = "model_stats.hdf5"

## Diving into the predictions

To maximize model performance on the evaluation metric, the model's predictions are thresholded. This is motivated by the label distribution observed in the training set and by an understanding of the QUEST challenge.
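As a rough illustration of thresholding, the sketch below snaps each raw prediction to the nearest label value seen in training. This is not the `OptimalRounder` from `nlp_utils`, whose implementation is not shown here; `snap_to_levels` and the toy arrays are hypothetical.

```python
import numpy as np


def snap_to_levels(preds: np.ndarray, levels: np.ndarray) -> np.ndarray:
    """Replace each prediction with the closest value found in `levels`."""
    levels = np.sort(np.unique(levels))
    # Distance from every prediction to every level, then pick the nearest level.
    nearest = np.abs(preds[:, None] - levels[None, :]).argmin(axis=1)
    return levels[nearest]


# Toy training labels take only a handful of discrete values.
train_labels = np.array([0.0, 0.0, 1 / 3, 2 / 3, 1.0, 1 / 3])
raw_preds = np.array([0.02, 0.30, 0.55, 0.97])
snapped = snap_to_levels(raw_preds, train_labels)
```

Because Spearman correlation is rank-based and the labels contain many ties, collapsing predictions onto the same discrete levels may produce ties that line up better with the labels.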

In [5]:
from scipy.stats import spearmanr

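def spearmanr_ignore_nan(trues: np.ndarray, preds: np.ndarray) -> float:
    # mean of the column-wise Spearman correlations; columns whose correlation is NaN are skipped
    return np.nanmean(
        [spearmanr(ta, pa).correlation for ta, pa in
         zip(np.transpose(trues), np.transpose(np.nan_to_num(preds)) + 1e-7)])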

In [6]:
# open the result file produced by model training
file_path = os.path.join(result_dir, result_stats_filename)
with pd.HDFStore(file_path, mode='r') as store:
    print(f"open {file_path} and found {len(store.keys())}:\n{store.keys()}")
    for k in store.keys():
        var_name = k.split('/')[-1]
        df = store.get(k)
        # expose each stored frame as a notebook-level variable named after its key
        globals()[var_name] = df
        print(f'read {k}: {df.shape}')

open ../input/distilroberta-base_q384_a512/model_stats.hdf5 and found 7:
['/test_preds', '/valid_breakdown_metrics', '/valid_group_score', '/valid_overall_metrics', '/valid_preds', '/valid_test_stats_diff', '/valid_trues']
read /test_preds: (476, 30)
read /valid_breakdown_metrics: (30, 5)
read /valid_group_score: (5, 3)
read /valid_overall_metrics: (8, 5)
read /valid_preds: (1216, 30)
read /valid_test_stats_diff: (30, 7)
read /valid_trues: (1216, 30)


In [7]:
valid_test_stats_diff

Unnamed: 0,test_mean,valid_mean,mean_diff,test_std,valid_std,ks_stats,p_value
question_type_entity,0.109976,0.144773,-0.034797,0.109976,0.144773,0.073813,0.045241
question_opinion_seeking,0.400333,0.432662,-0.032328,0.400333,0.432662,0.093764,0.0045
question_conversational,0.029604,0.04594,-0.016336,0.029604,0.04594,0.100149,0.001915
question_type_definition,0.027683,0.043981,-0.016298,0.027683,0.043981,0.072002,0.054286
question_well_written,0.790364,0.806607,-0.016243,0.790364,0.806607,0.098505,0.002399
question_interestingness_self,0.511383,0.52684,-0.015457,0.511383,0.52684,0.068733,0.074586
question_multi_intent,0.270178,0.284928,-0.01475,0.270178,0.284928,0.046212,0.442014
question_body_critical,0.629797,0.644542,-0.014744,0.629797,0.644542,0.078775,0.026821
question_type_compare,0.036706,0.050482,-0.013776,0.036706,0.050482,0.110882,0.000401
question_type_choice,0.226432,0.235489,-0.009056,0.226432,0.235489,0.048354,0.385617
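The `ks_stats` and `p_value` columns look like the output of a two-sample Kolmogorov-Smirnov test comparing test and validation predictions per label. The sketch below shows how one such row could be computed under that assumption; the data is synthetic and the code that produced the stored table is not part of this notebook.

```python
import numpy as np
from scipy.stats import ks_2samp

# Synthetic stand-ins for one label's predictions (sizes match the frames above).
rng = np.random.default_rng(0)
valid_col = rng.beta(2.0, 5.0, size=1216)
test_col = rng.beta(2.0, 5.0, size=476)

# Two-sample KS test: are both sets of predictions drawn from the same distribution?
result = ks_2samp(test_col, valid_col)

row = {
    "test_mean": test_col.mean(),
    "valid_mean": valid_col.mean(),
    "mean_diff": test_col.mean() - valid_col.mean(),
    "test_std": test_col.std(),
    "valid_std": valid_col.std(),
    "ks_stats": result.statistic,
    "p_value": result.pvalue,
}
```

A small `p_value` suggests that the prediction distributions differ between the validation and test sets for that label.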


In [8]:
valid_group_score  # the current model performs poorly on STACKOVERFLOW but well on LIFE_ARTS and SCIENCE

Unnamed: 0,overall,question,answer
CULTURE,0.352904,0.006812,0.277749
LIFE_ARTS,0.42325,0.018781,0.390112
SCIENCE,0.412098,0.037894,0.319552
STACKOVERFLOW,0.23155,-0.005066,0.22691
TECHNOLOGY,0.354609,0.026638,0.312946


In [9]:
valid_breakdown_metrics

Unnamed: 0,bias,mae,mape,pearson,spearman
question_type_spelling,-0.000117,0.000663,2.419421,0.015369,0.040741
question_not_really_a_question,-0.00024,0.008854,2.018759,0.027925,0.048685
answer_plausible,-0.005185,0.064072,0.06705,0.082279,0.092566
answer_relevance,0.003366,0.051478,0.053179,0.14911,0.143764
question_type_consequence,0.002892,0.01692,1.64595,0.137412,0.147791
answer_well_written,-0.00187,0.080743,0.088983,0.170488,0.1649
answer_helpful,-0.001719,0.090085,0.097709,0.190092,0.1989
question_expect_short_answer,-0.020726,0.279698,0.40562,0.270818,0.267898
answer_type_procedure,-0.003744,0.155941,1.344852,0.255036,0.274146
answer_satisfaction,0.008052,0.098185,0.114915,0.279763,0.280571


In [11]:
valid_trues.head()

Unnamed: 0_level_0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,question_opinion_seeking,question_type_choice,question_type_compare,question_type_consequence,question_type_definition,question_type_entity,question_type_instructions,question_type_procedure,question_type_reason_explanation,question_type_spelling,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
qa_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1
6,1.0,0.666667,0.0,0.5,1.0,1.0,0.444444,0.333333,0.0,0.0,0.5,0.0,0.0,0.0,0.0,0.0,1.0,0.5,0.0,0.0,0.833333,0.888889,0.666667,0.888889,1.0,0.733333,0.666667,0.666667,0.0,0.777778
11,1.0,0.333333,0.0,1.0,1.0,1.0,0.666667,0.555556,0.0,0.0,0.333333,0.333333,0.0,0.0,0.0,0.0,0.666667,0.0,0.333333,0.0,0.888889,0.666667,0.333333,0.666667,0.666667,0.266667,0.0,0.0,0.0,0.888889
17,0.888889,1.0,0.0,0.0,1.0,0.0,0.666667,0.333333,0.0,0.0,0.0,0.0,0.333333,0.0,0.0,0.0,0.0,0.0,0.666667,0.0,1.0,1.0,0.666667,1.0,1.0,1.0,0.0,0.0,1.0,1.0
24,0.777778,0.555556,0.0,1.0,0.666667,1.0,0.555556,0.333333,0.0,0.0,0.333333,1.0,0.0,0.0,0.0,0.0,0.666667,0.0,0.666667,0.0,0.888889,0.666667,0.666667,0.666667,0.888889,0.9,0.333333,0.333333,0.666667,1.0
41,0.888889,0.666667,0.0,0.333333,1.0,0.666667,0.555556,0.444444,1.0,0.0,0.0,0.333333,0.0,0.0,0.333333,0.333333,0.0,0.0,0.666667,0.0,1.0,0.888889,0.555556,1.0,1.0,0.8,0.0,0.0,0.333333,1.0


In [12]:
valid_preds.head()

Unnamed: 0_level_0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,question_opinion_seeking,question_type_choice,question_type_compare,question_type_consequence,question_type_definition,question_type_entity,question_type_instructions,question_type_procedure,question_type_reason_explanation,question_type_spelling,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
qa_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1
6,0.905412,0.565644,0.010187,0.695093,0.71958,0.796349,0.553067,0.401084,0.075334,0.005072,0.541578,0.043194,0.005633,0.00047,0.001188,0.087725,0.867587,0.278805,0.057896,1.8e-05,0.709242,0.894841,0.636292,0.960516,0.9581,0.800894,0.797965,0.112276,0.056069,0.894448
11,0.885103,0.570957,0.009143,0.788105,0.864973,0.894341,0.579482,0.441688,0.310563,0.003818,0.293361,0.323522,0.013668,0.008699,0.018006,0.274591,0.302266,0.111757,0.417412,0.001279,0.736912,0.859903,0.663103,0.934732,0.941381,0.771496,0.212178,0.104292,0.33751,0.865911
17,0.937903,0.777816,0.008757,0.380662,0.958779,0.90048,0.659576,0.560162,0.708315,0.000239,0.100381,0.036269,0.428866,0.00938,0.068765,0.059492,0.054747,0.080057,0.78586,6.8e-05,0.84531,0.955801,0.733329,0.983123,0.985105,0.921492,0.048606,0.114332,0.877092,0.935182
24,0.892733,0.49315,0.007811,0.782923,0.844781,0.853018,0.579164,0.42954,0.438672,0.000694,0.509196,0.80384,0.055598,0.006097,0.001714,0.0769,0.43716,0.146408,0.176069,5e-05,0.756427,0.976458,0.748032,0.988829,0.988576,0.915815,0.383622,0.154989,0.672255,0.944991
41,0.895552,0.612479,0.071921,0.455128,0.862793,0.679102,0.644888,0.591892,0.793374,0.002446,0.342298,0.299315,0.425462,0.027337,0.137157,0.048637,0.106113,0.117016,0.480605,0.001811,0.812502,0.883858,0.644396,0.948813,0.941258,0.819342,0.073691,0.092941,0.702519,0.876863


In [10]:
sys.path.append("../nlp_utils")

from nlp_utils import OptimalRounder

In [13]:
# fit the optimal rounder against the training label distribution

df = pd.read_csv('../input/google-quest-challenge/train.csv')[valid_preds.columns]
opt = OptimalRounder(ref=df)
valid_preds_opt = opt.fit_transform(valid_trues, valid_preds)

fitting: question_asker_intent_understanding


  c /= stddev[:, None]
  c /= stddev[None, :]
  return (a < x) & (x < b)
  return (a < x) & (x < b)
  cond2 = cond0 & (x <= _a)


fitting: question_body_critical
fitting: question_conversational
fitting: question_expect_short_answer
fitting: question_fact_seeking
fitting: question_has_commonly_accepted_answer
fitting: question_interestingness_others
fitting: question_interestingness_self
fitting: question_multi_intent
fitting: question_not_really_a_question
fitting: question_opinion_seeking
fitting: question_type_choice
fitting: question_type_compare
fitting: question_type_consequence
fitting: question_type_definition
fitting: question_type_entity
fitting: question_type_instructions
fitting: question_type_procedure
fitting: question_type_reason_explanation
fitting: question_type_spelling
fitting: question_well_written
fitting: answer_helpful
fitting: answer_level_of_information
fitting: answer_plausible
fitting: answer_relevance
fitting: answer_satisfaction
fitting: answer_type_instructions
fitting: answer_type_procedure
fitting: answer_type_reason_explanation
fitting: answer_well_written


In [14]:
valid_scores_orig = valid_trues.apply(lambda x: x.corr(valid_preds[x.name], method='spearman'))
valid_scores_orig

question_asker_intent_understanding      0.379744
question_body_critical                   0.592324
question_conversational                  0.433244
question_expect_short_answer             0.267898
question_fact_seeking                    0.360641
question_has_commonly_accepted_answer    0.400180
question_interestingness_others          0.339203
question_interestingness_self            0.467470
question_multi_intent                    0.533345
question_not_really_a_question           0.048685
question_opinion_seeking                 0.463129
question_type_choice                     0.712711
question_type_compare                    0.371635
question_type_consequence                0.147791
question_type_definition                 0.383263
question_type_entity                     0.463998
question_type_instructions               0.757646
question_type_procedure                  0.319751
question_type_reason_explanation         0.592698
question_type_spelling                   0.040741


In [15]:
# score after post processing on validation set
valid_scores_opt = valid_trues.apply(lambda x: x.corr(valid_preds_opt[x.name], method='spearman'))
valid_scores_opt

question_asker_intent_understanding      0.372308
question_body_critical                   0.590384
question_conversational                  0.519570
question_expect_short_answer             0.286429
question_fact_seeking                    0.373448
question_has_commonly_accepted_answer    0.446220
question_interestingness_others          0.342085
question_interestingness_self            0.477320
question_multi_intent                    0.540809
question_not_really_a_question                NaN
question_opinion_seeking                 0.466753
question_type_choice                     0.719488
question_type_compare                    0.604133
question_type_consequence                     NaN
question_type_definition                 0.597205
question_type_entity                     0.599168
question_type_instructions               0.779938
question_type_procedure                  0.341829
question_type_reason_explanation         0.593394
question_type_spelling                        NaN


In [16]:
valid_preds_opt.apply(lambda x: x.nunique())
# count the unique values in every column after post-processing; a column with a single unique value makes its score NaN

question_asker_intent_understanding      5
question_body_critical                   7
question_conversational                  5
question_expect_short_answer             5
question_fact_seeking                    5
question_has_commonly_accepted_answer    5
question_interestingness_others          5
question_interestingness_self            5
question_multi_intent                    5
question_not_really_a_question           1
question_opinion_seeking                 5
question_type_choice                     5
question_type_compare                    5
question_type_consequence                1
question_type_definition                 4
question_type_entity                     5
question_type_instructions               5
question_type_procedure                  4
question_type_reason_explanation         5
question_type_spelling                   1
question_well_written                    8
answer_helpful                           5
answer_level_of_information              5
answer_plau
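Why a single unique value produces NaN can be seen by writing Spearman correlation as the Pearson correlation of ranks. The helper below is an illustrative re-implementation, not the pandas or SciPy routine used in the cells above, and the toy series are made up.

```python
import numpy as np
import pandas as pd


def spearman_via_ranks(a: pd.Series, b: pd.Series) -> float:
    """Spearman correlation computed as the Pearson correlation of average ranks."""
    ra = a.rank().to_numpy()
    rb = b.rank().to_numpy()
    da, db = ra - ra.mean(), rb - rb.mean()
    denom = np.sqrt((da ** 2).sum() * (db ** 2).sum())
    # A constant column has zero rank variance, so the ratio is undefined.
    return float((da * db).sum() / denom) if denom > 0 else float("nan")


trues = pd.Series([0.0, 0.0, 1 / 3, 1.0])
varied = pd.Series([0.1, 0.2, 0.3, 0.9])     # predictions with distinct values
constant = pd.Series([0.0, 0.0, 0.0, 0.0])   # predictions collapsed to one value

score_varied = spearman_via_ranks(trues, varied)
score_constant = spearman_via_ranks(trues, constant)
```

Here `score_constant` is NaN, which mirrors what happens to the three labels whose rounded predictions collapse to one value.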

In [17]:
# eyeball the improvement for every label
valid_scores_opt_diff = (valid_scores_opt - valid_scores_orig).sort_values(ascending=False)
valid_scores_opt_diff

question_type_compare                    0.232498
question_type_definition                 0.213943
question_type_entity                     0.135171
question_conversational                  0.086326
question_has_commonly_accepted_answer    0.046041
answer_relevance                         0.024595
question_type_instructions               0.022292
question_type_procedure                  0.022078
question_expect_short_answer             0.018531
answer_satisfaction                      0.014844
answer_plausible                         0.014596
answer_type_procedure                    0.013776
question_fact_seeking                    0.012807
question_interestingness_self            0.009850
answer_type_instructions                 0.007566
question_multi_intent                    0.007464
question_type_choice                     0.006776
answer_level_of_information              0.005415
question_opinion_seeking                 0.003624
answer_well_written                      0.003555


In [18]:
# keep only the useful columns: those whose score improves (or drops by less than 0.001) and is not NaN
use_cols = valid_scores_opt_diff.loc[valid_scores_opt_diff > -.0010].dropna().index.tolist()
print(f"select {len(use_cols)} labels getting improve: {use_cols}")

select 24 labels getting improve: ['question_type_compare', 'question_type_definition', 'question_type_entity', 'question_conversational', 'question_has_commonly_accepted_answer', 'answer_relevance', 'question_type_instructions', 'question_type_procedure', 'question_expect_short_answer', 'answer_satisfaction', 'answer_plausible', 'answer_type_procedure', 'question_fact_seeking', 'question_interestingness_self', 'answer_type_instructions', 'question_multi_intent', 'question_type_choice', 'answer_level_of_information', 'question_opinion_seeking', 'answer_well_written', 'question_interestingness_others', 'answer_helpful', 'question_type_reason_explanation', 'answer_type_reason_explanation']


In [19]:
# calculate the lift from post processing
valid_preds_opt_final = valid_preds.copy()
valid_preds_opt_final[use_cols] = opt.transform(valid_preds[use_cols])

score_orig = spearmanr_ignore_nan(valid_trues.values, valid_preds.values)
score_opt = spearmanr_ignore_nan(valid_trues.values, valid_preds_opt_final.values)

print(f"orig score={score_orig:.3f}, optimized score={score_opt:.3f}, improve={score_opt-score_orig:.3f}")

orig score=0.384, optimized score=0.414, improve=0.030


In [20]:
# apply the same post-processing to the test predictions
test_preds[use_cols] = opt.transform(test_preds[use_cols])
test_preds.head().T

qa_id,39,46,70,132,200
question_asker_intent_understanding,0.947565,0.867423,0.945147,0.879081,0.907718
question_body_critical,0.669461,0.515113,0.765663,0.377438,0.535705
question_conversational,1.0,0.0,0.0,0.0,0.0
question_expect_short_answer,0.333333,0.5,0.5,0.5,0.5
question_fact_seeking,0.0,0.666667,0.666667,0.5,0.666667
question_has_commonly_accepted_answer,0.333333,1.0,1.0,1.0,1.0
question_interestingness_others,0.888889,0.666667,0.833333,0.666667,0.777778
question_interestingness_self,0.888889,0.666667,0.777778,0.666667,0.777778
question_multi_intent,1.0,0.0,0.0,0.0,0.666667
question_not_really_a_question,0.001565,0.005538,0.00105,0.009642,0.0053
