## Dependencies

In [1]:
import glob
import warnings
from tensorflow_hub import KerasLayer
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Dense, Dropout, GlobalAveragePooling1D, SpatialDropout1D, Concatenate
from googleqa_utilityscript import *
from googleqa_map_utilityscript import *
import bert_tokenization as tokenization


# Global seed handed to the utility-script helper below.
SEED = 0
# seed_everything is pulled in by the star-import of the utility script;
# presumably it seeds the Python/NumPy/TensorFlow RNGs — TODO confirm.
seed_everything(SEED)
# Suppress all warnings for the rest of the notebook.
warnings.filterwarnings(action="ignore")

## Load data

In [2]:
# Location of the BERT module (loaded later through tensorflow_hub.KerasLayer)
# and of the vocabulary file shipped inside it.
BERT_PATH = '/kaggle/input/tf-hub-bert-base/bert_base_uncased'
VOCAB_PATH = BERT_PATH + '/assets/vocab.txt'

# Collect every .h5 checkpoint from the training dataset; sorting keeps the
# order of the list (and of the printed paths) deterministic.
model_path_list = sorted(
    glob.glob('/kaggle/input/125-googleq-a-train-2fold-bert-base-unc-categoryv5/' + '*.h5')
)
print('Models to predict:', model_path_list)

# Competition test set that the predictions are made for.
test = pd.read_csv('/kaggle/input/google-quest-challenge/test.csv')

print('Test samples: %s' % len(test))
display(test.head())

Models to predict: ['/kaggle/input/125-googleq-a-train-2fold-bert-base-unc-categoryv5/model_fold_1.h5', '/kaggle/input/125-googleq-a-train-2fold-bert-base-unc-categoryv5/model_fold_2.h5']
Test samples: 476


Unnamed: 0,qa_id,question_title,question_body,question_user_name,question_user_page,answer,answer_user_name,answer_user_page,url,category,host
0,39,Will leaving corpses lying around upset my pri...,I see questions/information online about how t...,Dylan,https://gaming.stackexchange.com/users/64471,There is no consequence for leaving corpses an...,Nelson868,https://gaming.stackexchange.com/users/97324,http://gaming.stackexchange.com/questions/1979...,CULTURE,gaming.stackexchange.com
1,46,Url link to feature image in the portfolio,I am new to Wordpress. i have issue with Featu...,Anu,https://wordpress.stackexchange.com/users/72927,I think it is possible with custom fields.\n\n...,Irina,https://wordpress.stackexchange.com/users/27233,http://wordpress.stackexchange.com/questions/1...,TECHNOLOGY,wordpress.stackexchange.com
2,70,"Is accuracy, recoil or bullet spread affected ...","To experiment I started a bot game, toggled in...",Konsta,https://gaming.stackexchange.com/users/37545,You do not have armour in the screenshots. Thi...,Damon Smithies,https://gaming.stackexchange.com/users/70641,http://gaming.stackexchange.com/questions/2154...,CULTURE,gaming.stackexchange.com
3,132,Suddenly got an I/O error from my external HDD,I have used my Raspberry Pi as a torrent-serve...,robbannn,https://raspberrypi.stackexchange.com/users/17341,Your Western Digital hard drive is disappearin...,HeatfanJohn,https://raspberrypi.stackexchange.com/users/1311,http://raspberrypi.stackexchange.com/questions...,TECHNOLOGY,raspberrypi.stackexchange.com
4,200,Passenger Name - Flight Booking Passenger only...,I have bought Delhi-London return flights for ...,Amit,https://travel.stackexchange.com/users/29089,I called two persons who work for Saudia (tick...,Nean Der Thal,https://travel.stackexchange.com/users/10051,http://travel.stackexchange.com/questions/4704...,CULTURE,travel.stackexchange.com


In [3]:
# Label columns that describe the question.
question_target_cols = [
    'question_asker_intent_understanding',
    'question_body_critical',
    'question_conversational',
    'question_expect_short_answer',
    'question_fact_seeking',
    'question_has_commonly_accepted_answer',
    'question_interestingness_others',
    'question_interestingness_self',
    'question_multi_intent',
    'question_not_really_a_question',
    'question_opinion_seeking',
    'question_type_choice',
    'question_type_compare',
    'question_type_consequence',
    'question_type_definition',
    'question_type_entity',
    'question_type_instructions',
    'question_type_procedure',
    'question_type_reason_explanation',
    'question_type_spelling',
    'question_well_written',
]

# Label columns that describe the answer.
answer_target_cols = [
    'answer_helpful',
    'answer_level_of_information',
    'answer_plausible',
    'answer_relevance',
    'answer_satisfaction',
    'answer_type_instructions',
    'answer_type_procedure',
    'answer_type_reason_explanation',
    'answer_well_written',
]

# Full target list: question labels first, then answer labels. This order is
# the column order used when predictions are written into the submission.
target_cols = [*question_target_cols, *answer_target_cols]

## Pre-process data

In [4]:
# Free-text columns that are tokenized and fed to the model.
text_features = [
    'question_title',
    'question_body',
    'answer',
]

# NOTE: no text normalisation is applied in this notebook. A previously
# present (disabled) per-feature pipeline did, in order: lower-casing,
# map_misspellings(x), map_contraction(x) and str.strip(). Re-enable it only
# if the checkpoints were trained on text processed the same way.

## Model parameters

In [5]:
# Width of the main output head: one unit per target column.
N_CLASS = len(target_cols)
# Sequence length used for the three model inputs.
MAX_SEQUENCE_LENGTH = 512
# Width of the auxiliary category head. It must match the checkpoints being
# loaded, so it is derived from the training data rather than from the test
# set: a test set that happens to lack a category would otherwise build a
# narrower softmax layer and make load_weights fail on a shape mismatch.
# NOTE(review): assumes the checkpoints were trained with one class per
# distinct `category` in the competition train.csv — confirm against the
# training notebook.
TRAIN_PATH = '/kaggle/input/google-quest-challenge/train.csv'
N_CLASS_CAT = pd.read_csv(TRAIN_PATH, usecols=['category'])['category'].nunique()

## Test set

In [6]:
# Tokenizer built from the vocabulary file at VOCAB_PATH, with lower-casing enabled.
tokenizer = tokenization.FullTokenizer(VOCAB_PATH, do_lower_case=True)

# Test features
# compute_input_arays comes from the utility script (the spelling is the helper's
# own name). Its result is passed straight to model.predict on a three-input
# model, so it presumably returns [word_ids, masks, segment_ids] arrays of
# length MAX_SEQUENCE_LENGTH — TODO confirm against the utility script.
X_test = compute_input_arays(test, text_features, tokenizer, MAX_SEQUENCE_LENGTH)

## Model

In [7]:
def model_fn():
    """Build the two-headed BERT model used for inference.

    Inputs (each of shape ``(MAX_SEQUENCE_LENGTH,)``, dtype int32):
    ``input_word_ids``, ``input_masks`` and ``segment_ids``.

    Outputs, in order:
      * ``output_seq``   - ``N_CLASS`` sigmoid units computed from the
        average-pooled sequence output of the BERT layer.
      * ``output_class`` - ``N_CLASS_CAT`` softmax units computed from the
        pooled output of the BERT layer.

    Returns:
        A ``tensorflow.keras.Model`` with freshly created layers; weights
        are expected to be filled in afterwards with ``load_weights``.
    """
    token_shape = (MAX_SEQUENCE_LENGTH,)
    word_ids_in = Input(token_shape, dtype=tf.int32, name='input_word_ids')
    masks_in = Input(token_shape, dtype=tf.int32, name='input_masks')
    segments_in = Input(token_shape, dtype=tf.int32, name='segment_ids')
    bert_inputs = [word_ids_in, masks_in, segments_in]

    # NOTE(review): trainable=True is kept exactly as in the training setup;
    # it is not required for prediction, but flipping it may change how the
    # layer's weights are ordered when a checkpoint is loaded — confirm first.
    encoder = KerasLayer(BERT_PATH, trainable=True)
    pooled_output, sequence_output = encoder(bert_inputs)

    # Head 1: per-token outputs -> dropout -> mean over the sequence axis.
    pooled_tokens = SpatialDropout1D(0.3)(sequence_output)
    pooled_tokens = GlobalAveragePooling1D()(pooled_tokens)
    output_seq = Dense(N_CLASS, activation="sigmoid", name="output_seq")(pooled_tokens)

    # Head 2: pooled output -> dropout -> category distribution.
    category_features = Dropout(0.3)(pooled_output)
    output_class = Dense(N_CLASS_CAT, activation="softmax", name="output_class")(category_features)

    return Model(inputs=bert_inputs, outputs=[output_seq, output_class])

## Make predictions

In [8]:
# Fail loudly instead of silently writing an all-zero submission when the
# checkpoint glob matched nothing.
if not model_path_list:
    raise FileNotFoundError('No model checkpoints (*.h5) found to predict with.')

# Running average of the fold predictions for the main (target) head.
Y_test = np.zeros((len(test), N_CLASS))

# Build the network once: load_weights replaces every weight of the model,
# so re-creating the whole BERT graph for each fold is redundant work.
model = model_fn()
for model_path in model_path_list:
    model.load_weights(model_path)
    # predict returns [output_seq, output_class]; only the first head
    # (the N_CLASS target scores) goes into the submission.
    Y_test += model.predict(X_test)[0] / len(model_path_list)

In [9]:
# Start from the competition's sample submission so the qa_id column and the
# expected column layout are preserved.
# NOTE(review): assumes its rows are in the same order as test.csv — confirm.
sample_submission_path = '/kaggle/input/google-quest-challenge/sample_submission.csv'
submission = pd.read_csv(sample_submission_path)

# Overwrite the target columns with the fold-averaged predictions.
submission[target_cols] = Y_test
submission.to_csv("submission.csv", index=False)

# Preview the first rows and the per-column summary statistics.
display(submission.head(), submission.describe())

Unnamed: 0,qa_id,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,39,0.940991,0.650606,0.196062,0.49948,0.586706,0.480348,0.693087,0.618637,0.438965,...,0.921307,0.930857,0.593457,0.962335,0.957972,0.837815,0.081368,0.049898,0.769465,0.91845
1,46,0.888936,0.515788,0.005378,0.785756,0.830355,0.914835,0.56499,0.453957,0.055756,...,0.746931,0.949585,0.698319,0.96559,0.975874,0.875894,0.910859,0.118255,0.074409,0.86638
2,70,0.906965,0.703399,0.02549,0.822056,0.8686,0.919062,0.587727,0.511749,0.048106,...,0.872015,0.947577,0.653075,0.965819,0.970831,0.881634,0.215414,0.079586,0.784638,0.918626
3,132,0.874633,0.422964,0.008524,0.646797,0.771172,0.90694,0.502879,0.436196,0.14336,...,0.710097,0.944513,0.643613,0.968891,0.97653,0.885792,0.863449,0.190791,0.265618,0.901787
4,200,0.908559,0.364501,0.030876,0.827418,0.802791,0.857078,0.625598,0.568623,0.046743,...,0.669165,0.896762,0.607168,0.954185,0.934135,0.82643,0.153061,0.106723,0.681053,0.91282


Unnamed: 0,qa_id,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
count,476.0,476.0,476.0,476.0,476.0,476.0,476.0,476.0,476.0,476.0,...,476.0,476.0,476.0,476.0,476.0,476.0,476.0,476.0,476.0,476.0
mean,5029.186975,0.8899,0.58077,0.034292,0.709581,0.802982,0.848066,0.579516,0.475682,0.216375,...,0.794639,0.929468,0.660605,0.961451,0.971223,0.862055,0.547313,0.129516,0.478519,0.905943
std,2812.67006,0.044802,0.139714,0.063375,0.11379,0.112103,0.134133,0.053492,0.086279,0.203015,...,0.087977,0.02058,0.047879,0.011247,0.011333,0.033847,0.324394,0.06259,0.274453,0.021181
min,39.0,0.759376,0.325794,0.002532,0.268697,0.332183,0.180326,0.463988,0.339261,0.003576,...,0.572485,0.857759,0.506974,0.920169,0.912674,0.742523,0.006929,0.008117,0.054992,0.828353
25%,2572.0,0.859436,0.45553,0.006535,0.646952,0.757532,0.828774,0.538701,0.412764,0.066709,...,0.721414,0.918475,0.629626,0.955195,0.965786,0.8434,0.202021,0.077331,0.231352,0.891553
50%,5093.0,0.890781,0.573722,0.011167,0.705195,0.81547,0.894027,0.571269,0.450785,0.133282,...,0.795242,0.932142,0.659746,0.963137,0.973948,0.866157,0.648369,0.133005,0.428269,0.906209
75%,7482.0,0.92589,0.690934,0.027698,0.780856,0.880194,0.929361,0.615949,0.523179,0.307807,...,0.874031,0.944779,0.690695,0.968914,0.979325,0.885766,0.845618,0.173852,0.721527,0.920811
max,9640.0,0.974978,0.896942,0.464771,0.973659,0.982035,0.986197,0.729986,0.771042,0.854584,...,0.964542,0.973138,0.810142,0.985487,0.989047,0.945926,0.937761,0.280196,0.986245,0.955265
