In [0]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## やったこと

- テキストのクリーニング処理 -> 改善
- host, categoryのカテゴリカル変数をエンベディングして入力 -> 改善

- epochs=20で、early-stoppingはあまり良くならなかった -> とりあえず速く数を回したいので、epochs=4でやっている


- batch_size=8以上にすると、out_of_memoryになる

- MSELossを使用 -> 悪化
- titleは分けて、別のエンベディングとして入力 -> 悪化



- BERTを2つ使う -> gpu不足
- クラス分類問題にする（30*num_class） -> 学習が安定しない（nan）
- 30個の目的変数それぞれ独立に予測するモデル -> 約30時間必要、あまり精度が出ないように見える -> 関連する目的変数だけをグルーピングしてモデルを分ける必要？

In [0]:
import pandas as pd
import numpy as np
import math
import matplotlib.pyplot as plt
import os, sys, gc, random, multiprocessing, glob, time

DATA_DIR = '/content/drive/My Drive/Colab Notebooks/GoogleQuest/input/google-quest-challenge'
# DATA_DIR = '../input/google-quest-challenge'
# DATA_DIR = 'D:/project/ICF_AutoCapsule_disabled/kaggle/google-quest-challenge'
# BERT_DIR = 'D:/project/ICF_AutoCapsule_disabled/BERT'

In [0]:
# !pip install ../input/sacremoses/sacremoses-master/
# !pip install ../input/transformers/transformers-master/

In [0]:
!pip install transformers
!pip install flashtext



In [0]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils import data
from torch.utils.data import DataLoader, Dataset

#from ml_stratifiers import MultilabelStratifiedShuffleSplit, MultilabelStratifiedKFold
from sklearn.model_selection import KFold, StratifiedKFold, GroupKFold
from sklearn.utils import shuffle
from sklearn.preprocessing import LabelEncoder

from scipy.stats import spearmanr

import transformers
from transformers import (
    BertTokenizer, BertModel, BertForSequenceClassification, BertConfig,
    WEIGHTS_NAME, CONFIG_NAME, AdamW, get_linear_schedule_with_warmup, 
    get_cosine_schedule_with_warmup,
)

from tqdm import tqdm
print(transformers.__version__)

2.4.1


In [0]:
## Make results reproducible .Else noone will believe you .
import random

def seed_everything(seed: int):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

In [0]:
class PipeLineConfig:
    def __init__(self, lr, warmup, epochs, patience, batch_size, seed, name, question_weight,answer_weight,fold,train):
        self.lr = lr
        self.warmup = warmup
        self.epochs = epochs
        self.patience = patience
        self.batch_size = batch_size
        self.seed = seed
        self.name = name
        self.question_weight = question_weight
        self.answer_weight =answer_weight
        self.fold = fold
        self.train = train

In [0]:
config = PipeLineConfig(lr=1e-5, \
                        warmup=0.01, \
                        epochs=4, \
                        patience=3, \
                        batch_size=4, \
                        seed=42, \
                        name='10div_cased', \
                        question_weight=0.5, \
                        answer_weight=0.5, \
                        fold=5, \
                        train=True
                       )

In [0]:
seed_everything(config.seed)

In [0]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#device = 'cpu'
print(device)

cuda


In [0]:
sub = pd.read_csv(f'{DATA_DIR}/sample_submission.csv')
sub.head()

Unnamed: 0,qa_id,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,question_opinion_seeking,question_type_choice,question_type_compare,question_type_consequence,question_type_definition,question_type_entity,question_type_instructions,question_type_procedure,question_type_reason_explanation,question_type_spelling,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,39,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308
1,46,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448
2,70,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673
3,132,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401
4,200,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074


In [0]:
target_columns = sub.columns.values[1:].tolist()
target_columns

['question_asker_intent_understanding',
 'question_body_critical',
 'question_conversational',
 'question_expect_short_answer',
 'question_fact_seeking',
 'question_has_commonly_accepted_answer',
 'question_interestingness_others',
 'question_interestingness_self',
 'question_multi_intent',
 'question_not_really_a_question',
 'question_opinion_seeking',
 'question_type_choice',
 'question_type_compare',
 'question_type_consequence',
 'question_type_definition',
 'question_type_entity',
 'question_type_instructions',
 'question_type_procedure',
 'question_type_reason_explanation',
 'question_type_spelling',
 'question_well_written',
 'answer_helpful',
 'answer_level_of_information',
 'answer_plausible',
 'answer_relevance',
 'answer_satisfaction',
 'answer_type_instructions',
 'answer_type_procedure',
 'answer_type_reason_explanation',
 'answer_well_written']

In [0]:
train = pd.read_csv(f'{DATA_DIR}/train.csv')
train.head()

Unnamed: 0,qa_id,question_title,question_body,question_user_name,question_user_page,answer,answer_user_name,answer_user_page,url,category,host,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,question_opinion_seeking,question_type_choice,question_type_compare,question_type_consequence,question_type_definition,question_type_entity,question_type_instructions,question_type_procedure,question_type_reason_explanation,question_type_spelling,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0,What am I losing when using extension tubes in...,After playing around with macro photography on...,ysap,https://photo.stackexchange.com/users/1024,"I just got extension tubes, so here's the skin...",rfusca,https://photo.stackexchange.com/users/1917,http://photo.stackexchange.com/questions/9169/...,LIFE_ARTS,photo.stackexchange.com,1.0,0.333333,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.666667,1.0,1.0,0.8,1.0,0.0,0.0,1.0
1,1,What is the distinction between a city and a s...,I am trying to understand what kinds of places...,russellpierce,https://rpg.stackexchange.com/users/8774,It might be helpful to look into the definitio...,Erik Schmidt,https://rpg.stackexchange.com/users/1871,http://rpg.stackexchange.com/questions/47820/w...,CULTURE,rpg.stackexchange.com,1.0,1.0,0.0,0.5,1.0,1.0,0.444444,0.444444,0.666667,0.0,0.0,0.666667,0.666667,0.0,0.333333,0.0,0.0,0.0,0.333333,0.0,0.888889,0.888889,0.555556,0.888889,0.888889,0.666667,0.0,0.0,0.666667,0.888889
2,2,Maximum protusion length for through-hole comp...,I'm working on a PCB that has through-hole com...,Joe Baker,https://electronics.stackexchange.com/users/10157,Do you even need grooves? We make several pro...,Dwayne Reid,https://electronics.stackexchange.com/users/64754,http://electronics.stackexchange.com/questions...,SCIENCE,electronics.stackexchange.com,0.888889,0.666667,0.0,1.0,1.0,1.0,0.666667,0.444444,0.333333,0.0,0.333333,0.0,0.0,0.0,0.0,0.0,1.0,0.333333,0.333333,0.0,0.777778,0.777778,0.555556,1.0,1.0,0.666667,0.0,0.333333,1.0,0.888889
3,3,Can an affidavit be used in Beit Din?,"An affidavit, from what i understand, is basic...",Scimonster,https://judaism.stackexchange.com/users/5151,"Sending an ""affidavit"" it is a dispute between...",Y e z,https://judaism.stackexchange.com/users/4794,http://judaism.stackexchange.com/questions/551...,CULTURE,judaism.stackexchange.com,0.888889,0.666667,0.666667,1.0,1.0,1.0,0.444444,0.444444,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.888889,0.833333,0.333333,0.833333,1.0,0.8,0.0,0.0,1.0,1.0
4,5,How do you make a binary image in Photoshop?,I am trying to make a binary image. I want mor...,leigero,https://graphicdesign.stackexchange.com/users/...,Check out Image Trace in Adobe Illustrator. \n...,q2ra,https://graphicdesign.stackexchange.com/users/...,http://graphicdesign.stackexchange.com/questio...,LIFE_ARTS,graphicdesign.stackexchange.com,1.0,0.666667,0.0,1.0,1.0,1.0,0.666667,0.666667,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,1.0,0.666667,1.0,1.0,0.8,1.0,0.0,1.0,1.0


In [0]:
test = pd.read_csv(f'{DATA_DIR}/test.csv')
test.head()

Unnamed: 0,qa_id,question_title,question_body,question_user_name,question_user_page,answer,answer_user_name,answer_user_page,url,category,host
0,39,Will leaving corpses lying around upset my pri...,I see questions/information online about how t...,Dylan,https://gaming.stackexchange.com/users/64471,There is no consequence for leaving corpses an...,Nelson868,https://gaming.stackexchange.com/users/97324,http://gaming.stackexchange.com/questions/1979...,CULTURE,gaming.stackexchange.com
1,46,Url link to feature image in the portfolio,I am new to Wordpress. i have issue with Featu...,Anu,https://wordpress.stackexchange.com/users/72927,I think it is possible with custom fields.\n\n...,Irina,https://wordpress.stackexchange.com/users/27233,http://wordpress.stackexchange.com/questions/1...,TECHNOLOGY,wordpress.stackexchange.com
2,70,"Is accuracy, recoil or bullet spread affected ...","To experiment I started a bot game, toggled in...",Konsta,https://gaming.stackexchange.com/users/37545,You do not have armour in the screenshots. Thi...,Damon Smithies,https://gaming.stackexchange.com/users/70641,http://gaming.stackexchange.com/questions/2154...,CULTURE,gaming.stackexchange.com
3,132,Suddenly got an I/O error from my external HDD,I have used my Raspberry Pi as a torrent-serve...,robbannn,https://raspberrypi.stackexchange.com/users/17341,Your Western Digital hard drive is disappearin...,HeatfanJohn,https://raspberrypi.stackexchange.com/users/1311,http://raspberrypi.stackexchange.com/questions...,TECHNOLOGY,raspberrypi.stackexchange.com
4,200,Passenger Name - Flight Booking Passenger only...,I have bought Delhi-London return flights for ...,Amit,https://travel.stackexchange.com/users/29089,I called two persons who work for Saudia (tick...,Nean Der Thal,https://travel.stackexchange.com/users/10051,http://travel.stackexchange.com/questions/4704...,CULTURE,travel.stackexchange.com


## Preprocessing

In [0]:
import re
from flashtext import KeywordProcessor

In [0]:
PUNCTS = {
            '》', '〞', '¢', '‹', '╦', '║', '♪', 'Ø', '╩', '\\', '★', '＋', 'ï', '<', '?', '％', '+', '„', 'α', '*', '〰', '｟', '¹', '●', '〗', ']', '▾', '■', '〙', '↓', '´', '【', 'ᴵ',
            '"', '）', '｀', '│', '¤', '²', '‡', '¿', '–', '」', '╔', '〾', '%', '¾', '←', '〔', '＿', '’', '-', ':', '‧', '｛', 'β', '（', '─', 'à', 'â', '､', '•', '；', '☆', '／', 'π',
            'é', '╗', '＾', '▪', ',', '►', '/', '〚', '¶', '♦', '™', '}', '″', '＂', '『', '▬', '±', '«', '“', '÷', '×', '^', '!', '╣', '▲', '・', '░', '′', '〝', '‛', '√', ';', '】', '▼',
            '.', '~', '`', '。', 'ə', '］', '，', '{', '～', '！', '†', '‘', '﹏', '═', '｣', '〕', '〜', '＼', '▒', '＄', '♥', '〛', '≤', '∞', '_', '[', '＆', '→', '»', '－', '＝', '§', '⋅', 
            '▓', '&', 'Â', '＞', '〃', '|', '¦', '—', '╚', '〖', '―', '¸', '³', '®', '｠', '¨', '‟', '＊', '£', '#', 'Ã', "'", '▀', '·', '？', '、', '█', '”', '＃', '⊕', '=', '〟', '½', '』',
            '［', '$', ')', 'θ', '@', '›', '＠', '｝', '¬', '…', '¼', '：', '¥', '❤', '€', '−', '＜', '(', '〘', '▄', '＇', '>', '₤', '₹', '∅', 'è', '〿', '「', '©', '｢', '∙', '°', '｜', '¡', 
            '↑', 'º', '¯', '♫', '#'
          }


mispell_dict = {"aren't" : "are not", "can't" : "cannot", "couldn't" : "could not",
"couldnt" : "could not", "didn't" : "did not", "doesn't" : "does not",
"doesnt" : "does not", "don't" : "do not", "hadn't" : "had not", "hasn't" : "has not",
"haven't" : "have not", "havent" : "have not", "he'd" : "he would", "he'll" : "he will", "he's" : "he is", "i'd" : "I would",
"i'd" : "I had", "i'll" : "I will", "i'm" : "I am", "isn't" : "is not", "it's" : "it is",
"it'll":"it will", "i've" : "I have", "let's" : "let us", "mightn't" : "might not", "mustn't" : "must not", 
"shan't" : "shall not", "she'd" : "she would", "she'll" : "she will", "she's" : "she is", "shouldn't" : "should not", "shouldnt" : "should not",
"that's" : "that is", "thats" : "that is", "there's" : "there is", "theres" : "there is", "they'd" : "they would", "they'll" : "they will",
"they're" : "they are", "theyre":  "they are", "they've" : "they have", "we'd" : "we would", "we're" : "we are", "weren't" : "were not",
"we've" : "we have", "what'll" : "what will", "what're" : "what are", "what's" : "what is", "what've" : "what have", "where's" : "where is",
"who'd" : "who would", "who'll" : "who will", "who're" : "who are", "who's" : "who is", "who've" : "who have", "won't" : "will not", "wouldn't" : "would not", "you'd" : "you would",
"you'll" : "you will", "you're" : "you are", "you've" : "you have", "'re": " are", "wasn't": "was not", "we'll":" will", "didn't": "did not", "tryin'":"trying"}


kp = KeywordProcessor(case_sensitive=True)
for k, v in mispell_dict.items():
    kp.add_keyword(k, v)

def clean_punct(text):
    text = str(text)
    for punct in PUNCTS:
        text = text.replace(punct, ' {} '.format(punct))
    return text


def preprocessing(text):
    text = text.lower()
    text = re.sub(r'(\&lt)|(\&gt)', ' ', text)
    
    text = re.sub(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', ' url ', text)
    text = kp.replace_keywords(text)
    text = clean_punct(text)
    text = re.sub(r'\n\r', ' ', text)
    text = re.sub(r'\s{2,}', ' ', text)
    
    return text

## Dataset

In [0]:
MAX_LEN = 512

class QuestDataset(torch.utils.data.Dataset):
    def __init__(self, df, train_mode=True, labeled=True):
        self.df = df
        self.train_mode = train_mode
        self.labeled = labeled
        #self.tokenizer = BertTokenizer.from_pretrained(BERT_DIR+'/bert-base-uncased')
        #self.tokenizer = BertTokenizer.from_pretrained('../input/bert-base-uncased/')
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]

        question_title = row.question_title
        question_body = row.question_body
        answer_text = row.answer


        inputs_q = self.tokenizer.encode_plus(
            question_title + " " + question_body,
            add_special_tokens=True,
            max_length=MAX_LEN,
        )

        inputs_a = self.tokenizer.encode_plus(
            answer_text,
            add_special_tokens=True,
            max_length=MAX_LEN,
        )

        ids_q = inputs_q["input_ids"]
        token_type_ids_q = inputs_q["token_type_ids"]
        mask_q = inputs_q["attention_mask"]

        ids_a = inputs_a["input_ids"]
        token_type_ids_a = inputs_a["token_type_ids"]
        mask_a = inputs_a["attention_mask"]
        
        padding_length_q = MAX_LEN - len(ids_q)
        padding_length_a = MAX_LEN - len(ids_a)
        
        ids_q = ids_q + ([0] * padding_length_q)
        mask_q = mask_q + ([0] * padding_length_q)
        token_type_ids_q = token_type_ids_q + ([0] * padding_length_q)

        ids_a = ids_a + ([0] * padding_length_a)
        mask_a = mask_a + ([0] * padding_length_a)
        token_type_ids_a = token_type_ids_a + ([0] * padding_length_a)
        
        if self.labeled:
            labels = self.get_label(row)
            return {
                'ids_q': torch.tensor(ids_q, dtype=torch.long),
                'mask_q': torch.tensor(mask_q, dtype=torch.long),
                'token_type_ids_q': torch.tensor(token_type_ids_q, dtype=torch.long),
                'ids_a': torch.tensor(ids_a, dtype=torch.long),
                'mask_a': torch.tensor(mask_a, dtype=torch.long),
                'token_type_ids_a': torch.tensor(token_type_ids_a, dtype=torch.long),
                'labels': labels, 
            }
        else:
            return {
                'ids_q': torch.tensor(ids_q, dtype=torch.long),
                'mask_q': torch.tensor(mask_q, dtype=torch.long),
                'token_type_ids_q': torch.tensor(token_type_ids_q, dtype=torch.long),
                'ids_a': torch.tensor(ids_a, dtype=torch.long),
                'mask_a': torch.tensor(mask_a, dtype=torch.long),
                'token_type_ids_a': torch.tensor(token_type_ids_a, dtype=torch.long)
            }


    def get_label(self, row):
        return torch.tensor(row[target_columns].values.astype(np.float32))

In [0]:
def get_train_val_loaders(batch_size=4, val_batch_size=4, ifold=0):
    df = pd.read_csv(f'{DATA_DIR}/train.csv')

    # cleaning
    df['question_title'] = df['question_title'].apply(lambda x : preprocessing(x))
    df['question_body'] = df['question_body'].apply(lambda x : preprocessing(x))
    df['answer'] = df['answer'].apply(lambda x : preprocessing(x))
    

    df = shuffle(df, random_state=1234)
    gkf = GroupKFold(n_splits=5).split(X=df.question_body, groups=df.question_body)
    for fold, (train_idx, valid_idx) in enumerate(gkf):
        if fold == ifold:
            df_train = df.iloc[train_idx]
            df_val = df.iloc[valid_idx]
            break

    print('train', df_train.shape)
    print('val', df_val.shape)

    ds_train = QuestDataset(df_train, train_mode=True)
    train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
    train_loader.num = len(df_train)

    ds_val = QuestDataset(df_val, train_mode=False)
    val_loader = torch.utils.data.DataLoader(ds_val, batch_size=val_batch_size, shuffle=False, num_workers=0, drop_last=False)
    val_loader.num = len(df_val)
    val_loader.df = df_val

    return train_loader, val_loader, df_val.shape[0], valid_idx


def get_test_loader(batch_size=4):
    df = pd.read_csv(f'{DATA_DIR}/test.csv')

    # cleaning
    df['question_title'] = df['question_title'].apply(lambda x : preprocessing(x))
    df['question_body'] = df['question_body'].apply(lambda x : preprocessing(x))
    df['answer'] = df['answer'].apply(lambda x : preprocessing(x))
    
    ds_test = QuestDataset(df, train_mode=False, labeled=False)
    loader = torch.utils.data.DataLoader(ds_test, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=False)
    loader.num = len(df)
    
    return loader

In [0]:
# class QuestModel(nn.Module):
#     def __init__(self, n_classes=30):
#         super(QuestModel, self).__init__()
#         self.model_name = 'QuestModel'
#         #self.bert_model = BertModel.from_pretrained(BERT_DIR+'/bert-base-uncased/')
#         #self.bert_model = BertModel.from_pretrained('../input/bert-base-uncased/')
#         self.bert_model_q = BertModel.from_pretrained('bert-base-uncased')
#         self.bert_model_a = BertModel.from_pretrained('bert-base-uncased')

#         # self.lstm_q = nn.LSTM(768, 300)
#         # self.lstm_a = nn.LSTM(768, 300)
        
#         self.fc_q1 = nn.Linear(768*2, 12)
#         self.fc_q2 = nn.Linear(768*2, 9)
        
#         self.fc_a = nn.Linear(768*2+21, 9)

#     def forward(self, ids_q, mask_q, token_type_ids_q, ids_a, mask_a, token_type_ids_a):
#         layers_q, pool_out_q = self.bert_model_q(input_ids=ids_q, token_type_ids=token_type_ids_q, attention_mask=mask_q)
#         layers_a, pool_out_a = self.bert_model_a(input_ids=ids_a, token_type_ids=token_type_ids_a, attention_mask=mask_a)
        
#         # #print(layers_q.shape)
#         # layers_q, _ = self.lstm_q(layers_q)
#         # layers_a, _ = self.lstm_a(layers_a)

#         #print(layers_q.shape)
#         out_q = F.avg_pool1d(layers_q.transpose(1,2), kernel_size=layers_q.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
#         out_a = F.avg_pool1d(layers_a.transpose(1,2), kernel_size=layers_a.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
#         out = torch.cat([out_q, out_a], dim=-1)
#         out = F.dropout(out, p=0.2, training=self.training)

#         logit_q1 = self.fc_q1(out)
#         logit_q2 = self.fc_q2(out)
        
#         out_a = torch.cat([out, F.relu(logit_q1)], dim=-1)
#         out_a = torch.cat([out_a, F.relu(logit_q2)], dim=-1)
#         logit_a = self.fc_a(out_a)

#         logit = torch.cat([logit_q1, logit_q2], dim=-1)
#         logit = torch.cat([logit, logit_a], dim=-1)

#         logit = logit[:, [0,1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,12,20,21,22,23,24,25,26,27,28,29]]

#         return logit
    

In [0]:
class QuestModel(nn.Module):
    def __init__(self, n_classes=30):
        super(QuestModel, self).__init__()
        self.model_name = 'QuestModel'
        #self.bert_model = BertModel.from_pretrained(BERT_DIR+'/bert-base-uncased/')
        #self.bert_model = BertModel.from_pretrained('../input/bert-base-uncased/')
        self.bert_model = BertModel.from_pretrained('bert-base-cased')
        
        self.fc_q = nn.Linear(768*10, 21)
        self.fc_a = nn.Linear(768*10+21, 9)

    def forward(self, ids_q, mask_q, token_type_ids_q, ids_a, mask_a, token_type_ids_a):
        layers_q, pool_out_q = self.bert_model(input_ids=ids_q, token_type_ids=token_type_ids_q, attention_mask=mask_q)
        layers_a, pool_out_a = self.bert_model(input_ids=ids_a, token_type_ids=token_type_ids_a, attention_mask=mask_a)
        

        # [batch_size, 512, 768]
        layers_q1,layers_q2,layers_q3,layers_q4,layers_q5,layers_q6,layers_q7,layers_q8,layers_q9,layers_q10 = layers_q[:, :50, :], layers_q[:, 50:100, :], layers_q[:,100:150, :], layers_q[:, 150:200, :], layers_q[:, 200:250, :], layers_q[:, 250:300, :], layers_q[:, 300:350, :], layers_q[:, 350:400, :], layers_q[:, 400:450, :], layers_q[:, 450:512, :]
        layers_a1,layers_a2,layers_a3,layers_a4,layers_a5,layers_a6,layers_a7,layers_a8,layers_a9,layers_a10 = layers_a[:, :50, :], layers_a[:, 50:100, :], layers_a[:,100:150, :], layers_a[:, 150:200, :], layers_a[:, 200:250, :], layers_a[:, 250:300, :], layers_a[:, 300:350, :], layers_a[:, 350:400, :], layers_a[:, 400:450, :], layers_a[:, 450:512, :]
        out_q1 = F.avg_pool1d(layers_q1.transpose(1,2), kernel_size=layers_q1.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_a1 = F.avg_pool1d(layers_a1.transpose(1,2), kernel_size=layers_a1.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_q2 = F.avg_pool1d(layers_q2.transpose(1,2), kernel_size=layers_q2.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_a2 = F.avg_pool1d(layers_a2.transpose(1,2), kernel_size=layers_a2.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_q3 = F.avg_pool1d(layers_q3.transpose(1,2), kernel_size=layers_q3.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_a3 = F.avg_pool1d(layers_a3.transpose(1,2), kernel_size=layers_a3.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_q4 = F.avg_pool1d(layers_q4.transpose(1,2), kernel_size=layers_q4.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_a4 = F.avg_pool1d(layers_a4.transpose(1,2), kernel_size=layers_a4.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_q5 = F.avg_pool1d(layers_q5.transpose(1,2), kernel_size=layers_q5.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_a5 = F.avg_pool1d(layers_a5.transpose(1,2), kernel_size=layers_a5.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_q6 = F.avg_pool1d(layers_q6.transpose(1,2), kernel_size=layers_q6.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_a6 = F.avg_pool1d(layers_a6.transpose(1,2), kernel_size=layers_a6.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_q7 = F.avg_pool1d(layers_q7.transpose(1,2), kernel_size=layers_q7.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_a7 = F.avg_pool1d(layers_a7.transpose(1,2), kernel_size=layers_a7.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_q8 = F.avg_pool1d(layers_q8.transpose(1,2), kernel_size=layers_q8.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_a8 = F.avg_pool1d(layers_a8.transpose(1,2), kernel_size=layers_a8.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_q9 = F.avg_pool1d(layers_q9.transpose(1,2), kernel_size=layers_q9.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_a9 = F.avg_pool1d(layers_a9.transpose(1,2), kernel_size=layers_a9.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_q10 = F.avg_pool1d(layers_q10.transpose(1,2), kernel_size=layers_q10.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_a10 = F.avg_pool1d(layers_a10.transpose(1,2), kernel_size=layers_a10.size()[1]).squeeze()  # sequence方向は中央値だけ抽出

        out_q = torch.cat([out_q1, out_q2], dim=-1)
        out_q = torch.cat([out_q, out_q3], dim=-1)
        out_q = torch.cat([out_q, out_q4], dim=-1)
        out_q = torch.cat([out_q, out_q5], dim=-1)
        out_q = torch.cat([out_q, out_q6], dim=-1)
        out_q = torch.cat([out_q, out_q7], dim=-1)
        out_q = torch.cat([out_q, out_q8], dim=-1)
        out_q = torch.cat([out_q, out_q9], dim=-1)
        out_q = torch.cat([out_q, out_q10], dim=-1)
        out_q = F.dropout(out_q, p=0.4, training=self.training)
        logit_q = self.fc_q(out_q)
        

        out_a = torch.cat([out_a1, out_a2], dim=-1)
        out_a = torch.cat([out_a, out_a3], dim=-1)
        out_a = torch.cat([out_a, out_a4], dim=-1)
        out_a = torch.cat([out_a, out_a5], dim=-1)
        out_a = torch.cat([out_a, out_a6], dim=-1)
        out_a = torch.cat([out_a, out_a7], dim=-1)
        out_a = torch.cat([out_a, out_a8], dim=-1)
        out_a = torch.cat([out_a, out_a9], dim=-1)
        out_a = torch.cat([out_a, out_a10], dim=-1)
        out_a = F.dropout(out_a, p=0.4, training=self.training)
        out_a = torch.cat([out_a, F.relu(logit_q)], dim=-1)
        logit_a = self.fc_a(out_a)

        logit = torch.cat([logit_q, logit_a], dim=-1)

        return logit

In [0]:
def train_model(train_loader, optimizer, criterion, scheduler):
    model.train()
    avg_loss = 0.    
    for idx, batch in enumerate(tqdm(train_loader)):
        ids_q = batch['ids_q'].to(device)
        mask_q = batch['mask_q'].to(device)
        token_type_ids_q = batch['token_type_ids_q'].to(device)
        ids_a = batch['ids_a'].to(device)
        mask_a = batch['mask_a'].to(device)
        token_type_ids_a = batch['token_type_ids_a'].to(device)
        labels = batch['labels'].to(device)
        
        logits = model(ids_q, mask_q, token_type_ids_q, ids_a, mask_a, token_type_ids_a)        
        loss = criterion(logits, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        avg_loss += loss.item() / len(train_loader)
        del ids_q, mask_q, token_type_ids_q, ids_a, mask_a, token_type_ids_a, labels

    torch.cuda.empty_cache()
    gc.collect()
    return avg_loss

def val_model(val_loader, val_length, batch_size=8):
    model.eval() # eval mode  
    avg_val_loss = 0.
    
    valid_preds = np.zeros((val_length, len(target_columns)))
    original = np.zeros((val_length, len(target_columns)))
    
    with torch.no_grad():
        for idx, batch in enumerate(tqdm(val_loader)):
            ids_q = batch['ids_q'].to(device)
            mask_q = batch['mask_q'].to(device)
            token_type_ids_q = batch['token_type_ids_q'].to(device)
            ids_a = batch['ids_a'].to(device)
            mask_a = batch['mask_a'].to(device)
            token_type_ids_a = batch['token_type_ids_a'].to(device)
            labels = batch['labels'].to(device)

            logits = torch.sigmoid(model(ids_q, mask_q, token_type_ids_q, ids_a, mask_a, token_type_ids_a))
            
            avg_val_loss += criterion(logits, labels).item() / len(val_loader)
            valid_preds[idx*batch_size : (idx+1)*batch_size] = logits.detach().cpu().squeeze().numpy()
            original[idx*batch_size : (idx+1)*batch_size]    = labels.detach().cpu().squeeze().numpy()
        
        rho_val = np.mean([spearmanr(original[:, i], valid_preds[:,i]).correlation for i in range(valid_preds.shape[1])])
        print('\r val_spearman-rho: %s' % (str(round(rho_val, 5))), end = 100*' '+'\n')
        
        score = 0
        for i in range(len(target_columns)):
            print(i, spearmanr(original[:,i], valid_preds[:,i]))
            score += np.nan_to_num(spearmanr(original[:, i], valid_preds[:, i]).correlation)
    
    return avg_val_loss, score/len(target_columns)

In [21]:
cv_list = []
for fold in range(config.fold):
    print('---%d-Fold---'%(fold+1))
    
    patience = 0
    best_loss   = 100.0
    best_score      = -1.
    best_preds = 0
    best_param_loss = None
    best_param_score = None

    model = QuestModel(n_classes=30).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, eps=4e-5)
    criterion = nn.BCEWithLogitsLoss()
    
    for epoch in range(config.epochs):
        
        torch.cuda.empty_cache()
        start_time   = time.time()
        
        train_loader, val_loader, val_length, val_idx = get_train_val_loaders(batch_size=config.batch_size, val_batch_size=config.batch_size, ifold=fold)
        scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=config.warmup, num_training_steps=config.epochs*len(train_loader))
        #scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=config.warmup, num_training_steps=config.epochs*len(train_loader))

        loss_train = train_model(train_loader, optimizer, criterion, scheduler)
        loss_val, score_val = val_model(val_loader, val_length, batch_size=config.batch_size)
        print(f'Epoch {(epoch+1)}, train_loss: {loss_train}, val_loss: {loss_val}, score_val: {score_val}, time: {(time.time()-start_time)}')

        if score_val > best_score:
            best_score = score_val
            best_param_score = model.state_dict()
            print('best_param_score_{}_{}.pt'.format(config.name ,fold+1))
            torch.save(best_param_score, '/content/drive/My Drive/Colab Notebooks/GoogleQuest/best_param_score_{}_{}.pt'.format(config.name ,fold+1))
        else:
            patience += 1
            if patience >= config.patience:
                del train_loader, val_loader, loss_train, loss_val, score_val
                torch.cuda.empty_cache()
                gc.collect()
                break
    
        del train_loader, val_loader, loss_train, loss_val, score_val
        torch.cuda.empty_cache()
        gc.collect()
        
    model.load_state_dict(best_param_score)
    print('best_param_score_{}_{}.pt'.format(config.name ,fold+1))
    torch.save(best_param_score, '/content/drive/My Drive/Colab Notebooks/GoogleQuest/best_param_score_{}_{}.pt'.format(config.name ,fold+1))   
    cv_list.append(best_score)

    torch.cuda.empty_cache()
    gc.collect()
    
print('CV_score: ', np.mean(cv_list))

---1-Fold---
train (4863, 41)
val (1216, 41)


HBox(children=(IntProgress(value=0, description='Downloading', max=213450, style=ProgressStyle(description_wid…

  0%|          | 0/1215 [00:00<?, ?it/s]




100%|██████████| 1215/1215 [19:50<00:00,  1.03it/s]
100%|██████████| 304/304 [01:44<00:00,  2.91it/s]


 val_spearman-rho: 0.37754                                                                                                    
0 SpearmanrResult(correlation=0.36608229686499166, pvalue=7.2828092067544845e-40)
1 SpearmanrResult(correlation=0.6306456639737682, pvalue=7.93811424489788e-136)
2 SpearmanrResult(correlation=0.3881123707470349, pvalue=5.373353046821705e-45)
3 SpearmanrResult(correlation=0.28731482431413463, pvalue=1.5313538112393279e-24)
4 SpearmanrResult(correlation=0.37144644666428556, pvalue=4.4562878607150995e-41)
5 SpearmanrResult(correlation=0.40140887586106405, pvalue=2.7369213663539695e-48)
6 SpearmanrResult(correlation=0.30499550463346814, pvalue=1.3588973408275295e-27)
7 SpearmanrResult(correlation=0.41941499604436633, pvalue=5.399933409504066e-53)
8 SpearmanrResult(correlation=0.5593320708204396, pvalue=5.0196623047808774e-101)
9 SpearmanrResult(correlation=0.06192397554093177, pvalue=0.030833537110387287)
10 SpearmanrResult(correlation=0.44595347606866054, pvalue=

100%|██████████| 1215/1215 [19:58<00:00,  1.02it/s]
100%|██████████| 304/304 [01:43<00:00,  2.93it/s]


 val_spearman-rho: 0.38988                                                                                                    
0 SpearmanrResult(correlation=0.37883229142505226, pvalue=8.718073015153543e-43)
1 SpearmanrResult(correlation=0.6521607018496539, pvalue=3.3073665238816985e-148)
2 SpearmanrResult(correlation=0.39746235630179944, pvalue=2.69410075853475e-47)
3 SpearmanrResult(correlation=0.261907777022717, pvalue=1.5987741791720605e-20)
4 SpearmanrResult(correlation=0.35382078924285815, pvalue=3.55134087084063e-37)
5 SpearmanrResult(correlation=0.3920793422498503, pvalue=5.801406126295771e-46)
6 SpearmanrResult(correlation=0.3276180995548226, pvalue=8.124258342116224e-32)
7 SpearmanrResult(correlation=0.45357813131035446, pvalue=9.57983175718296e-63)
8 SpearmanrResult(correlation=0.5941666417032256, pvalue=6.1691680601605926e-117)
9 SpearmanrResult(correlation=0.09731363202662746, pvalue=0.0006787546316121253)
10 SpearmanrResult(correlation=0.4420443275976211, pvalue=2.505242

100%|██████████| 1215/1215 [20:00<00:00,  1.01it/s]
100%|██████████| 304/304 [01:44<00:00,  2.92it/s]


 val_spearman-rho: 0.39203                                                                                                    
0 SpearmanrResult(correlation=0.390028691522398, pvalue=1.8403979850501105e-45)
1 SpearmanrResult(correlation=0.6479512498430232, pvalue=1.0491489618795224e-145)
2 SpearmanrResult(correlation=0.39695749275352765, pvalue=3.601608644590356e-47)
3 SpearmanrResult(correlation=0.2720614978030534, pvalue=4.452604465232237e-22)
4 SpearmanrResult(correlation=0.3725891816748541, pvalue=2.440707364576454e-41)
5 SpearmanrResult(correlation=0.3955453831450578, pvalue=8.09080819614925e-47)
6 SpearmanrResult(correlation=0.33227554391932024, pvalue=9.875382920470715e-33)
7 SpearmanrResult(correlation=0.4468987455368133, pvalue=9.472884555285832e-61)
8 SpearmanrResult(correlation=0.5827310168432068, pvalue=1.6837356887727145e-111)
9 SpearmanrResult(correlation=0.10531988293560517, pvalue=0.0002340685949174488)
10 SpearmanrResult(correlation=0.465091840906901, pvalue=2.7313029

100%|██████████| 1215/1215 [19:59<00:00,  1.00it/s]
100%|██████████| 304/304 [01:44<00:00,  2.92it/s]


 val_spearman-rho: 0.38596                                                                                                    
0 SpearmanrResult(correlation=0.37753662127751536, pvalue=1.7512720881880738e-42)
1 SpearmanrResult(correlation=0.6549063311425237, pvalue=7.350319094089782e-150)
2 SpearmanrResult(correlation=0.3995814416163708, pvalue=7.92157690260733e-48)
3 SpearmanrResult(correlation=0.2549124143692678, pvalue=1.7240337783007187e-19)
4 SpearmanrResult(correlation=0.3492379747218389, pvalue=3.3505478372033455e-36)
5 SpearmanrResult(correlation=0.4095305880422941, pvalue=2.242587831657802e-50)
6 SpearmanrResult(correlation=0.3232539684661765, pvalue=5.662716989844267e-31)
7 SpearmanrResult(correlation=0.4307663584238748, pvalue=4.125989427072136e-56)
8 SpearmanrResult(correlation=0.5857193995433968, pvalue=6.707780240991883e-113)
9 SpearmanrResult(correlation=0.09880200034182121, pvalue=0.0005600780385185998)
10 SpearmanrResult(correlation=0.4250824383675998, pvalue=1.552817

100%|██████████| 1215/1215 [19:59<00:00,  1.01it/s]
100%|██████████| 304/304 [01:44<00:00,  2.83it/s]


 val_spearman-rho: 0.38962                                                                                                    
0 SpearmanrResult(correlation=0.409289569517574, pvalue=2.591217685496046e-50)
1 SpearmanrResult(correlation=0.5918204328416019, pvalue=8.381276421559533e-116)
2 SpearmanrResult(correlation=0.41268688884284876, pvalue=3.343285097240086e-51)
3 SpearmanrResult(correlation=0.2984612300415778, pvalue=1.9321360530630207e-26)
4 SpearmanrResult(correlation=0.3799112793323522, pvalue=4.865187914839452e-43)
5 SpearmanrResult(correlation=0.4052130980161849, pvalue=2.9318030030437506e-49)
6 SpearmanrResult(correlation=0.31478742417970434, pvalue=2.2377997383826263e-29)
7 SpearmanrResult(correlation=0.4053712066538003, pvalue=2.6701854311734665e-49)
8 SpearmanrResult(correlation=0.5249069670232961, pvalue=4.699862640695931e-87)
9 SpearmanrResult(correlation=0.09873309169307994, pvalue=0.0005651162715625579)
10 SpearmanrResult(correlation=0.4736934950394995, pvalue=4.98978

100%|██████████| 1215/1215 [20:00<00:00,  1.01it/s]
100%|██████████| 304/304 [01:44<00:00,  2.83it/s]


 val_spearman-rho: 0.39745                                                                                                    
0 SpearmanrResult(correlation=0.4392537607192173, pvalue=1.6073713295519804e-58)
1 SpearmanrResult(correlation=0.5953207290534458, pvalue=1.6960365433657974e-117)
2 SpearmanrResult(correlation=0.4083273329072679, pvalue=4.608194656005586e-50)
3 SpearmanrResult(correlation=0.34215339335201905, pvalue=1.0009914566287618e-34)
4 SpearmanrResult(correlation=0.38873307874883445, pvalue=3.800657185094787e-45)
5 SpearmanrResult(correlation=0.40222606200790634, pvalue=1.6979769645530414e-48)
6 SpearmanrResult(correlation=0.2797174899962554, pvalue=2.7003186397196295e-23)
7 SpearmanrResult(correlation=0.42929201927581057, pvalue=1.0644244692990514e-55)
8 SpearmanrResult(correlation=0.5511827531508632, pvalue=1.4330501901547074e-97)
9 SpearmanrResult(correlation=0.10833810776935436, pvalue=0.00015362252204131858)
10 SpearmanrResult(correlation=0.47641631813008967, pvalue

100%|██████████| 1215/1215 [20:00<00:00,  1.01it/s]
100%|██████████| 304/304 [01:44<00:00,  2.81it/s]


 val_spearman-rho: 0.40203                                                                                                    
0 SpearmanrResult(correlation=0.43290049653956064, pvalue=1.0378278616627257e-56)
1 SpearmanrResult(correlation=0.6259345393015407, pvalue=3.029801108069241e-133)
2 SpearmanrResult(correlation=0.41546123309661814, pvalue=6.1698335011865555e-52)
3 SpearmanrResult(correlation=0.3185458803579706, pvalue=4.4392113978175456e-30)
4 SpearmanrResult(correlation=0.3984972396765203, pvalue=1.4834802731035864e-47)
5 SpearmanrResult(correlation=0.39605375942883386, pvalue=6.048475494196458e-47)
6 SpearmanrResult(correlation=0.34258619412179897, pvalue=8.154381422616602e-35)
7 SpearmanrResult(correlation=0.43663400068926395, pvalue=9.060285128693768e-58)
8 SpearmanrResult(correlation=0.589279708222986, pvalue=1.3799884284501318e-114)
9 SpearmanrResult(correlation=0.12136077420532053, pvalue=2.2022174096408033e-05)
10 SpearmanrResult(correlation=0.4707293959077129, pvalue=4

100%|██████████| 1215/1215 [19:59<00:00,  1.02it/s]
100%|██████████| 304/304 [01:44<00:00,  2.84it/s]


 val_spearman-rho: 0.39282                                                                                                    
0 SpearmanrResult(correlation=0.4178084082633301, pvalue=1.4586852780432614e-52)
1 SpearmanrResult(correlation=0.618613226391354, pvalue=2.5499778921549462e-129)
2 SpearmanrResult(correlation=0.40822766855527826, pvalue=4.89080954621109e-50)
3 SpearmanrResult(correlation=0.2995172405146403, pvalue=1.2639278563522558e-26)
4 SpearmanrResult(correlation=0.35854768805934334, pvalue=3.3731718856876574e-38)
5 SpearmanrResult(correlation=0.36116425171874666, pvalue=9.008090601631337e-39)
6 SpearmanrResult(correlation=0.30795751347377776, pvalue=3.9886665406434703e-28)
7 SpearmanrResult(correlation=0.4151277708249087, pvalue=7.565716158334352e-52)
8 SpearmanrResult(correlation=0.5729527176203043, pvalue=5.084104655897311e-107)
9 SpearmanrResult(correlation=0.12720461409484282, pvalue=8.615061984990704e-06)
10 SpearmanrResult(correlation=0.44933566808689335, pvalue=1.7

100%|██████████| 1215/1215 [19:58<00:00,  1.00s/it]
100%|██████████| 304/304 [01:44<00:00,  2.91it/s]


 val_spearman-rho: 0.37231                                                                                                    
0 SpearmanrResult(correlation=0.326029980334948, pvalue=1.6528799254482765e-31)
1 SpearmanrResult(correlation=0.5915572033964502, pvalue=1.1216420039261956e-115)
2 SpearmanrResult(correlation=0.3886311443045467, pvalue=4.0232573225433746e-45)
3 SpearmanrResult(correlation=0.22942370510376883, pvalue=5.480921333137103e-16)
4 SpearmanrResult(correlation=0.270485997020467, pvalue=7.840031971179847e-22)
5 SpearmanrResult(correlation=0.40237584443911756, pvalue=1.5554898326245156e-48)
6 SpearmanrResult(correlation=0.31965867505823364, pvalue=2.737610938235636e-30)
7 SpearmanrResult(correlation=0.44441478584736255, pvalue=5.095356933230572e-60)
8 SpearmanrResult(correlation=0.5602530108807608, pvalue=2.0142229587010974e-101)
9 SpearmanrResult(correlation=0.06655269772525312, pvalue=0.020289077407687175)
10 SpearmanrResult(correlation=0.40175803629660384, pvalue=2.23

100%|██████████| 1215/1215 [19:59<00:00,  1.01it/s]
100%|██████████| 304/304 [01:44<00:00,  2.92it/s]


 val_spearman-rho: 0.3817                                                                                                    
0 SpearmanrResult(correlation=0.3319880602146387, pvalue=1.1259206667993923e-32)
1 SpearmanrResult(correlation=0.6165702886136976, pvalue=3.0433639910488034e-128)
2 SpearmanrResult(correlation=0.38867099822664325, pvalue=3.934723951915924e-45)
3 SpearmanrResult(correlation=0.26024950716298323, pvalue=2.827738875879483e-20)
4 SpearmanrResult(correlation=0.2968144822001555, pvalue=3.7319084156342914e-26)
5 SpearmanrResult(correlation=0.38077196092994353, pvalue=3.0503285380300785e-43)
6 SpearmanrResult(correlation=0.35409082687987675, pvalue=3.10780456655177e-37)
7 SpearmanrResult(correlation=0.4816274419956893, pvalue=1.2651207467864253e-71)
8 SpearmanrResult(correlation=0.581715808968973, pvalue=4.994338280613574e-111)
9 SpearmanrResult(correlation=0.05615635047773414, pvalue=0.05025701750398893)
10 SpearmanrResult(correlation=0.4113844407120793, pvalue=7.35080

100%|██████████| 1215/1215 [19:57<00:00,  1.01s/it]
100%|██████████| 304/304 [01:44<00:00,  2.93it/s]


 val_spearman-rho: 0.38553                                                                                                    
0 SpearmanrResult(correlation=0.3383227866710727, pvalue=6.058562411131832e-34)
1 SpearmanrResult(correlation=0.6209938148702007, pvalue=1.3857138339209462e-130)
2 SpearmanrResult(correlation=0.40235803784585394, pvalue=1.5717860730939837e-48)
3 SpearmanrResult(correlation=0.28772366981948855, pvalue=1.3089039532003871e-24)
4 SpearmanrResult(correlation=0.2854983979193084, pvalue=3.066008394484852e-24)
5 SpearmanrResult(correlation=0.39499092749660225, pvalue=1.1105503367579943e-46)
6 SpearmanrResult(correlation=0.33562495457497593, pvalue=2.1208224284975725e-33)
7 SpearmanrResult(correlation=0.4720991215228407, pvalue=1.627648386223175e-68)
8 SpearmanrResult(correlation=0.5925332783089495, pvalue=3.802178005705695e-116)
9 SpearmanrResult(correlation=0.05604727902634235, pvalue=0.05070578504998748)
10 SpearmanrResult(correlation=0.40116812279776587, pvalue=3.1

100%|██████████| 1215/1215 [19:58<00:00,  1.01it/s]
100%|██████████| 304/304 [01:44<00:00,  2.92it/s]


 val_spearman-rho: 0.38391                                                                                                    
0 SpearmanrResult(correlation=0.3412648372191852, pvalue=1.523316703905377e-34)
1 SpearmanrResult(correlation=0.6213509699904766, pvalue=8.932355839560071e-131)
2 SpearmanrResult(correlation=0.3901415376286888, pvalue=1.7274816287349303e-45)
3 SpearmanrResult(correlation=0.2692569649599231, pvalue=1.2158024218228344e-21)
4 SpearmanrResult(correlation=0.25166075611813626, pvalue=5.081421498503411e-19)
5 SpearmanrResult(correlation=0.39800274686976195, pvalue=1.9734054247747918e-47)
6 SpearmanrResult(correlation=0.3231002809814335, pvalue=6.0599498947805075e-31)
7 SpearmanrResult(correlation=0.46324416834084975, pvalue=1.033791555490628e-65)
8 SpearmanrResult(correlation=0.5804092601365064, pvalue=2.01257170970605e-110)
9 SpearmanrResult(correlation=0.07074417274664449, pvalue=0.0136064590911734)
10 SpearmanrResult(correlation=0.37185448325981285, pvalue=3.59530

100%|██████████| 1215/1215 [20:00<00:00,  1.01it/s]
100%|██████████| 304/304 [01:44<00:00,  2.86it/s]


 val_spearman-rho: 0.36971                                                                                                    
0 SpearmanrResult(correlation=0.295324877820991, pvalue=6.744374089818216e-26)
1 SpearmanrResult(correlation=0.527232844581888, pvalue=6.000886849408552e-88)
2 SpearmanrResult(correlation=0.38373292740505605, pvalue=6.05564265247401e-44)
3 SpearmanrResult(correlation=0.2946493369708242, pvalue=8.810426815923197e-26)
4 SpearmanrResult(correlation=0.3279976198380922, pvalue=6.851699187922235e-32)
5 SpearmanrResult(correlation=0.4250707134929949, pvalue=1.5643699772262873e-54)
6 SpearmanrResult(correlation=0.17402278211032726, pvalue=1.0037741607690578e-09)
7 SpearmanrResult(correlation=0.4623353752964395, pvalue=1.983622510006272e-65)
8 SpearmanrResult(correlation=0.5314647743589487, pvalue=1.3610210108769782e-89)
9 SpearmanrResult(correlation=0.06362712740377834, pvalue=0.026505761716313844)
10 SpearmanrResult(correlation=0.4529452029243809, pvalue=1.4870407706

100%|██████████| 1215/1215 [20:01<00:00,  1.01it/s]
100%|██████████| 304/304 [01:44<00:00,  2.90it/s]


 val_spearman-rho: 0.38583                                                                                                    
0 SpearmanrResult(correlation=0.2997518631671526, pvalue=1.1499187733198333e-26)
1 SpearmanrResult(correlation=0.540641163166012, pvalue=3.069660313039523e-93)
2 SpearmanrResult(correlation=0.3822869963028183, pvalue=1.336485085535146e-43)
3 SpearmanrResult(correlation=0.3086940807670365, pvalue=2.9342214269188263e-28)
4 SpearmanrResult(correlation=0.33070802081698775, pvalue=2.015398673081901e-32)
5 SpearmanrResult(correlation=0.41986979475350533, pvalue=4.0718403483972485e-53)
6 SpearmanrResult(correlation=0.3341438881155004, pvalue=4.196907690641114e-33)
7 SpearmanrResult(correlation=0.46875440724444306, pvalue=1.9052948835351895e-67)
8 SpearmanrResult(correlation=0.571201874995972, pvalue=3.1083588665899704e-106)
9 SpearmanrResult(correlation=0.04665119534989319, pvalue=0.10395071706473796)
10 SpearmanrResult(correlation=0.45519255449207413, pvalue=3.10766

100%|██████████| 1215/1215 [20:00<00:00,  1.02it/s]
100%|██████████| 304/304 [01:44<00:00,  2.87it/s]


 val_spearman-rho: 0.3957                                                                                                    
0 SpearmanrResult(correlation=0.31120612166372086, pvalue=1.0230430012222964e-28)
1 SpearmanrResult(correlation=0.5553296922841138, pvalue=2.5690564714498056e-99)
2 SpearmanrResult(correlation=0.42060980028610334, pvalue=2.5698529060275163e-53)
3 SpearmanrResult(correlation=0.33191065592631286, pvalue=1.1663585744033172e-32)
4 SpearmanrResult(correlation=0.3592216227385519, pvalue=2.403595630984674e-38)
5 SpearmanrResult(correlation=0.4440739625019458, pvalue=6.411516792028058e-60)
6 SpearmanrResult(correlation=0.31914899011282327, pvalue=3.416931816298623e-30)
7 SpearmanrResult(correlation=0.47683387039401204, pvalue=4.772279850804641e-70)
8 SpearmanrResult(correlation=0.578069006180038, pvalue=2.4047838566957845e-109)
9 SpearmanrResult(correlation=0.0678082640903921, pvalue=0.018037601098764917)
10 SpearmanrResult(correlation=0.48382452408131205, pvalue=2.350

100%|██████████| 1215/1215 [20:04<00:00,  1.03it/s]
100%|██████████| 304/304 [01:45<00:00,  2.84it/s]


 val_spearman-rho: 0.38587                                                                                                    
0 SpearmanrResult(correlation=0.29396366550381836, pvalue=1.1547133556628218e-25)
1 SpearmanrResult(correlation=0.5403338965060612, pvalue=4.083474457536811e-93)
2 SpearmanrResult(correlation=0.4143977701321194, pvalue=1.1814322450222242e-51)
3 SpearmanrResult(correlation=0.32913075503335004, pvalue=4.1140944629436354e-32)
4 SpearmanrResult(correlation=0.3538570729955108, pvalue=3.4882746692451123e-37)
5 SpearmanrResult(correlation=0.4301656786804014, pvalue=6.073852104141427e-56)
6 SpearmanrResult(correlation=0.32038418990952355, pvalue=1.9953630731503483e-30)
7 SpearmanrResult(correlation=0.4838335538863134, pvalue=2.333906195459941e-72)
8 SpearmanrResult(correlation=0.5706408998575642, pvalue=5.539487233533365e-106)
9 SpearmanrResult(correlation=0.06903344747263258, pvalue=0.016054442875062535)
10 SpearmanrResult(correlation=0.48059961584929983, pvalue=2.76

100%|██████████| 1216/1216 [20:03<00:00,  1.00it/s]
100%|██████████| 304/304 [01:45<00:00,  3.15it/s]


 val_spearman-rho: 0.36344                                                                                                    
0 SpearmanrResult(correlation=0.31587604510696826, pvalue=1.4803925179995758e-29)
1 SpearmanrResult(correlation=0.5774419746843641, pvalue=5.70830123869443e-109)
2 SpearmanrResult(correlation=0.37824936039849416, pvalue=1.2899926044950138e-42)
3 SpearmanrResult(correlation=0.218656235915822, pvalue=1.2853776354242412e-14)
4 SpearmanrResult(correlation=0.3197252776829794, pvalue=2.8078320521849288e-30)
5 SpearmanrResult(correlation=0.432242817093104, pvalue=1.7635583495102205e-56)
6 SpearmanrResult(correlation=0.27780475216335465, pvalue=5.71157726882849e-23)
7 SpearmanrResult(correlation=0.48934368009951423, pvalue=3.721042550760569e-74)
8 SpearmanrResult(correlation=0.5508118037868667, pvalue=2.4545862984803915e-97)
9 SpearmanrResult(correlation=0.10102277036245778, pvalue=0.00042071485176956616)
10 SpearmanrResult(correlation=0.452918802556308, pvalue=1.6994

100%|██████████| 1216/1216 [20:03<00:00,  1.02it/s]
100%|██████████| 304/304 [01:44<00:00,  3.17it/s]


 val_spearman-rho: 0.37691                                                                                                    
0 SpearmanrResult(correlation=0.35738289134556805, pvalue=6.477787092774316e-38)
1 SpearmanrResult(correlation=0.5709699486245008, pvalue=4.810462583861604e-106)
2 SpearmanrResult(correlation=0.3860061484474293, pvalue=1.8765275068717836e-44)
3 SpearmanrResult(correlation=0.24429373684973119, pvalue=5.733278522421272e-18)
4 SpearmanrResult(correlation=0.31804013316404933, pvalue=5.8311907771954815e-30)
5 SpearmanrResult(correlation=0.41086070287715787, pvalue=1.1061765098780746e-50)
6 SpearmanrResult(correlation=0.3089326054060342, pvalue=2.793769130180916e-28)
7 SpearmanrResult(correlation=0.5047579686287242, pvalue=1.570845729996928e-79)
8 SpearmanrResult(correlation=0.5560447971123881, pvalue=1.5367773225562667e-99)
9 SpearmanrResult(correlation=0.11925575083451873, pvalue=3.079834293931972e-05)
10 SpearmanrResult(correlation=0.4444369996308076, pvalue=5.60

100%|██████████| 1216/1216 [19:59<00:00,  1.01it/s]
100%|██████████| 304/304 [01:44<00:00,  3.21it/s]


 val_spearman-rho: 0.37698                                                                                                    
0 SpearmanrResult(correlation=0.3652074648798827, pvalue=1.2281895704833424e-39)
1 SpearmanrResult(correlation=0.6166022119036495, pvalue=3.720834366219026e-128)
2 SpearmanrResult(correlation=0.37434155217790593, pvalue=1.0410888619360108e-41)
3 SpearmanrResult(correlation=0.23889831160303296, pvalue=3.141280862876334e-17)
4 SpearmanrResult(correlation=0.318691042384513, pvalue=4.399500425966071e-30)
5 SpearmanrResult(correlation=0.4255687532956886, pvalue=1.2619826976067286e-54)
6 SpearmanrResult(correlation=0.3072216262643766, pvalue=5.693305722100955e-28)
7 SpearmanrResult(correlation=0.48426415085239144, pvalue=1.9159774155026715e-72)
8 SpearmanrResult(correlation=0.5834508324526255, pvalue=9.571912931018957e-112)
9 SpearmanrResult(correlation=0.12080734080830079, pvalue=2.420356261657416e-05)
10 SpearmanrResult(correlation=0.44517000119069006, pvalue=3.41

100%|██████████| 1216/1216 [20:01<00:00,  1.00it/s]
100%|██████████| 304/304 [01:44<00:00,  3.19it/s]


 val_spearman-rho: 0.3721                                                                                                    
0 SpearmanrResult(correlation=0.367200902593132, pvalue=4.3934235079065204e-40)
1 SpearmanrResult(correlation=0.5869304659834667, pvalue=2.2236703796276925e-113)
2 SpearmanrResult(correlation=0.3784572083098036, pvalue=1.1534604220011172e-42)
3 SpearmanrResult(correlation=0.23533363051425388, pvalue=9.44622078777049e-17)
4 SpearmanrResult(correlation=0.2897921102044524, pvalue=6.159244487450099e-25)
5 SpearmanrResult(correlation=0.40697846370026897, pvalue=1.1275662352021821e-49)
6 SpearmanrResult(correlation=0.30499170545574683, pvalue=1.4296929768587198e-27)
7 SpearmanrResult(correlation=0.4671908373938597, pvalue=6.745409598908319e-67)
8 SpearmanrResult(correlation=0.5697741249710387, pvalue=1.6430911311929787e-105)
9 SpearmanrResult(correlation=0.11456293391509441, pvalue=6.270934155672909e-05)
10 SpearmanrResult(correlation=0.4381746926647588, pvalue=3.653