In [0]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## やったこと

- テキストのクリーニング処理 -> 改善
- host, categoryのカテゴリカル変数をエンベディングして入力 -> 改善

- epochs=20で、early-stoppingはあまり良くならなかった -> とりあえず速く数を回したいので、epochs=4でやっている


- batch_size=8以上にすると、out_of_memoryになる

- MSELossを使用 -> 悪化
- titleは分けて、別のエンベディングとして入力 -> 悪化



- BERTを2つ使う -> gpu不足
- クラス分類問題にする（30*num_class） -> 学習が安定しない（nan）
- 30個の目的変数それぞれ独立に予測するモデル -> 約30時間必要、あまり精度が出ないように見える -> 関連する目的変数だけをグルーピングしてモデルを分ける必要？

In [0]:
import pandas as pd
import numpy as np
import math
import matplotlib.pyplot as plt
import os, sys, gc, random, multiprocessing, glob, time

DATA_DIR = '/content/drive/My Drive/Colab Notebooks/GoogleQuest/input/google-quest-challenge'
# DATA_DIR = '../input/google-quest-challenge'
# DATA_DIR = 'D:/project/ICF_AutoCapsule_disabled/kaggle/google-quest-challenge'
# BERT_DIR = 'D:/project/ICF_AutoCapsule_disabled/BERT'

In [0]:
# !pip install ../input/sacremoses/sacremoses-master/
# !pip install ../input/transformers/transformers-master/

In [0]:
!pip install transformers
!pip install flashtext



In [0]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils import data
from torch.utils.data import DataLoader, Dataset

#from ml_stratifiers import MultilabelStratifiedShuffleSplit, MultilabelStratifiedKFold
from sklearn.model_selection import KFold, StratifiedKFold, GroupKFold
from sklearn.utils import shuffle
from sklearn.preprocessing import LabelEncoder

from scipy.stats import spearmanr

import transformers
from transformers import (
    BertTokenizer, BertModel, BertForSequenceClassification, BertConfig,
    WEIGHTS_NAME, CONFIG_NAME, AdamW, get_linear_schedule_with_warmup, 
    get_cosine_schedule_with_warmup,
)

from tqdm import tqdm
print(transformers.__version__)

2.4.1


In [0]:
## Make results reproducible .Else noone will believe you .
import random

def seed_everything(seed: int):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

In [0]:
class PipeLineConfig:
    def __init__(self, lr, warmup, epochs, patience, batch_size, seed, name, question_weight,answer_weight,fold,train):
        self.lr = lr
        self.warmup = warmup
        self.epochs = epochs
        self.patience = patience
        self.batch_size = batch_size
        self.seed = seed
        self.name = name
        self.question_weight = question_weight
        self.answer_weight =answer_weight
        self.fold = fold
        self.train = train

In [0]:
config = PipeLineConfig(lr=1e-5, \
                        warmup=0.01, \
                        epochs=20, \
                        patience=3, \
                        batch_size=4, \
                        seed=42, \
                        name='twoBERT', \
                        question_weight=0.5, \
                        answer_weight=0.5, \
                        fold=5, \
                        train=True
                       )

In [0]:
seed_everything(config.seed)

In [0]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#device = 'cpu'
print(device)

cuda


In [0]:
sub = pd.read_csv(f'{DATA_DIR}/sample_submission.csv')
sub.head()

Unnamed: 0,qa_id,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,question_opinion_seeking,question_type_choice,question_type_compare,question_type_consequence,question_type_definition,question_type_entity,question_type_instructions,question_type_procedure,question_type_reason_explanation,question_type_spelling,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,39,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308,0.00308
1,46,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448,0.00448
2,70,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673,0.00673
3,132,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401,0.01401
4,200,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074,0.02074


In [0]:
target_columns = sub.columns.values[1:].tolist()
target_columns

['question_asker_intent_understanding',
 'question_body_critical',
 'question_conversational',
 'question_expect_short_answer',
 'question_fact_seeking',
 'question_has_commonly_accepted_answer',
 'question_interestingness_others',
 'question_interestingness_self',
 'question_multi_intent',
 'question_not_really_a_question',
 'question_opinion_seeking',
 'question_type_choice',
 'question_type_compare',
 'question_type_consequence',
 'question_type_definition',
 'question_type_entity',
 'question_type_instructions',
 'question_type_procedure',
 'question_type_reason_explanation',
 'question_type_spelling',
 'question_well_written',
 'answer_helpful',
 'answer_level_of_information',
 'answer_plausible',
 'answer_relevance',
 'answer_satisfaction',
 'answer_type_instructions',
 'answer_type_procedure',
 'answer_type_reason_explanation',
 'answer_well_written']

In [0]:
train = pd.read_csv(f'{DATA_DIR}/train.csv')
train.head()

Unnamed: 0,qa_id,question_title,question_body,question_user_name,question_user_page,answer,answer_user_name,answer_user_page,url,category,host,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,question_opinion_seeking,question_type_choice,question_type_compare,question_type_consequence,question_type_definition,question_type_entity,question_type_instructions,question_type_procedure,question_type_reason_explanation,question_type_spelling,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0,What am I losing when using extension tubes in...,After playing around with macro photography on...,ysap,https://photo.stackexchange.com/users/1024,"I just got extension tubes, so here's the skin...",rfusca,https://photo.stackexchange.com/users/1917,http://photo.stackexchange.com/questions/9169/...,LIFE_ARTS,photo.stackexchange.com,1.0,0.333333,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.666667,1.0,1.0,0.8,1.0,0.0,0.0,1.0
1,1,What is the distinction between a city and a s...,I am trying to understand what kinds of places...,russellpierce,https://rpg.stackexchange.com/users/8774,It might be helpful to look into the definitio...,Erik Schmidt,https://rpg.stackexchange.com/users/1871,http://rpg.stackexchange.com/questions/47820/w...,CULTURE,rpg.stackexchange.com,1.0,1.0,0.0,0.5,1.0,1.0,0.444444,0.444444,0.666667,0.0,0.0,0.666667,0.666667,0.0,0.333333,0.0,0.0,0.0,0.333333,0.0,0.888889,0.888889,0.555556,0.888889,0.888889,0.666667,0.0,0.0,0.666667,0.888889
2,2,Maximum protusion length for through-hole comp...,I'm working on a PCB that has through-hole com...,Joe Baker,https://electronics.stackexchange.com/users/10157,Do you even need grooves? We make several pro...,Dwayne Reid,https://electronics.stackexchange.com/users/64754,http://electronics.stackexchange.com/questions...,SCIENCE,electronics.stackexchange.com,0.888889,0.666667,0.0,1.0,1.0,1.0,0.666667,0.444444,0.333333,0.0,0.333333,0.0,0.0,0.0,0.0,0.0,1.0,0.333333,0.333333,0.0,0.777778,0.777778,0.555556,1.0,1.0,0.666667,0.0,0.333333,1.0,0.888889
3,3,Can an affidavit be used in Beit Din?,"An affidavit, from what i understand, is basic...",Scimonster,https://judaism.stackexchange.com/users/5151,"Sending an ""affidavit"" it is a dispute between...",Y e z,https://judaism.stackexchange.com/users/4794,http://judaism.stackexchange.com/questions/551...,CULTURE,judaism.stackexchange.com,0.888889,0.666667,0.666667,1.0,1.0,1.0,0.444444,0.444444,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.888889,0.833333,0.333333,0.833333,1.0,0.8,0.0,0.0,1.0,1.0
4,5,How do you make a binary image in Photoshop?,I am trying to make a binary image. I want mor...,leigero,https://graphicdesign.stackexchange.com/users/...,Check out Image Trace in Adobe Illustrator. \n...,q2ra,https://graphicdesign.stackexchange.com/users/...,http://graphicdesign.stackexchange.com/questio...,LIFE_ARTS,graphicdesign.stackexchange.com,1.0,0.666667,0.0,1.0,1.0,1.0,0.666667,0.666667,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,1.0,0.666667,1.0,1.0,0.8,1.0,0.0,1.0,1.0


In [0]:
test = pd.read_csv(f'{DATA_DIR}/test.csv')
test.head()

Unnamed: 0,qa_id,question_title,question_body,question_user_name,question_user_page,answer,answer_user_name,answer_user_page,url,category,host
0,39,Will leaving corpses lying around upset my pri...,I see questions/information online about how t...,Dylan,https://gaming.stackexchange.com/users/64471,There is no consequence for leaving corpses an...,Nelson868,https://gaming.stackexchange.com/users/97324,http://gaming.stackexchange.com/questions/1979...,CULTURE,gaming.stackexchange.com
1,46,Url link to feature image in the portfolio,I am new to Wordpress. i have issue with Featu...,Anu,https://wordpress.stackexchange.com/users/72927,I think it is possible with custom fields.\n\n...,Irina,https://wordpress.stackexchange.com/users/27233,http://wordpress.stackexchange.com/questions/1...,TECHNOLOGY,wordpress.stackexchange.com
2,70,"Is accuracy, recoil or bullet spread affected ...","To experiment I started a bot game, toggled in...",Konsta,https://gaming.stackexchange.com/users/37545,You do not have armour in the screenshots. Thi...,Damon Smithies,https://gaming.stackexchange.com/users/70641,http://gaming.stackexchange.com/questions/2154...,CULTURE,gaming.stackexchange.com
3,132,Suddenly got an I/O error from my external HDD,I have used my Raspberry Pi as a torrent-serve...,robbannn,https://raspberrypi.stackexchange.com/users/17341,Your Western Digital hard drive is disappearin...,HeatfanJohn,https://raspberrypi.stackexchange.com/users/1311,http://raspberrypi.stackexchange.com/questions...,TECHNOLOGY,raspberrypi.stackexchange.com
4,200,Passenger Name - Flight Booking Passenger only...,I have bought Delhi-London return flights for ...,Amit,https://travel.stackexchange.com/users/29089,I called two persons who work for Saudia (tick...,Nean Der Thal,https://travel.stackexchange.com/users/10051,http://travel.stackexchange.com/questions/4704...,CULTURE,travel.stackexchange.com


## Preprocessing

In [0]:
import re
from flashtext import KeywordProcessor

In [0]:
PUNCTS = {
            '》', '〞', '¢', '‹', '╦', '║', '♪', 'Ø', '╩', '\\', '★', '＋', 'ï', '<', '?', '％', '+', '„', 'α', '*', '〰', '｟', '¹', '●', '〗', ']', '▾', '■', '〙', '↓', '´', '【', 'ᴵ',
            '"', '）', '｀', '│', '¤', '²', '‡', '¿', '–', '」', '╔', '〾', '%', '¾', '←', '〔', '＿', '’', '-', ':', '‧', '｛', 'β', '（', '─', 'à', 'â', '､', '•', '；', '☆', '／', 'π',
            'é', '╗', '＾', '▪', ',', '►', '/', '〚', '¶', '♦', '™', '}', '″', '＂', '『', '▬', '±', '«', '“', '÷', '×', '^', '!', '╣', '▲', '・', '░', '′', '〝', '‛', '√', ';', '】', '▼',
            '.', '~', '`', '。', 'ə', '］', '，', '{', '～', '！', '†', '‘', '﹏', '═', '｣', '〕', '〜', '＼', '▒', '＄', '♥', '〛', '≤', '∞', '_', '[', '＆', '→', '»', '－', '＝', '§', '⋅', 
            '▓', '&', 'Â', '＞', '〃', '|', '¦', '—', '╚', '〖', '―', '¸', '³', '®', '｠', '¨', '‟', '＊', '£', '#', 'Ã', "'", '▀', '·', '？', '、', '█', '”', '＃', '⊕', '=', '〟', '½', '』',
            '［', '$', ')', 'θ', '@', '›', '＠', '｝', '¬', '…', '¼', '：', '¥', '❤', '€', '−', '＜', '(', '〘', '▄', '＇', '>', '₤', '₹', '∅', 'è', '〿', '「', '©', '｢', '∙', '°', '｜', '¡', 
            '↑', 'º', '¯', '♫', '#'
          }


mispell_dict = {"aren't" : "are not", "can't" : "cannot", "couldn't" : "could not",
"couldnt" : "could not", "didn't" : "did not", "doesn't" : "does not",
"doesnt" : "does not", "don't" : "do not", "hadn't" : "had not", "hasn't" : "has not",
"haven't" : "have not", "havent" : "have not", "he'd" : "he would", "he'll" : "he will", "he's" : "he is", "i'd" : "I would",
"i'd" : "I had", "i'll" : "I will", "i'm" : "I am", "isn't" : "is not", "it's" : "it is",
"it'll":"it will", "i've" : "I have", "let's" : "let us", "mightn't" : "might not", "mustn't" : "must not", 
"shan't" : "shall not", "she'd" : "she would", "she'll" : "she will", "she's" : "she is", "shouldn't" : "should not", "shouldnt" : "should not",
"that's" : "that is", "thats" : "that is", "there's" : "there is", "theres" : "there is", "they'd" : "they would", "they'll" : "they will",
"they're" : "they are", "theyre":  "they are", "they've" : "they have", "we'd" : "we would", "we're" : "we are", "weren't" : "were not",
"we've" : "we have", "what'll" : "what will", "what're" : "what are", "what's" : "what is", "what've" : "what have", "where's" : "where is",
"who'd" : "who would", "who'll" : "who will", "who're" : "who are", "who's" : "who is", "who've" : "who have", "won't" : "will not", "wouldn't" : "would not", "you'd" : "you would",
"you'll" : "you will", "you're" : "you are", "you've" : "you have", "'re": " are", "wasn't": "was not", "we'll":" will", "didn't": "did not", "tryin'":"trying"}


kp = KeywordProcessor(case_sensitive=True)
for k, v in mispell_dict.items():
    kp.add_keyword(k, v)

def clean_punct(text):
    text = str(text)
    for punct in PUNCTS:
        text = text.replace(punct, ' {} '.format(punct))
    return text


def preprocessing(text):
    text = text.lower()
    text = re.sub(r'(\&lt)|(\&gt)', ' ', text)
    
    text = re.sub(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', ' url ', text)
    text = kp.replace_keywords(text)
    text = clean_punct(text)
    text = re.sub(r'\n\r', ' ', text)
    text = re.sub(r'\s{2,}', ' ', text)
    
    return text

## Dataset

In [0]:
MAX_LEN = 512

class QuestDataset(torch.utils.data.Dataset):
    def __init__(self, df, train_mode=True, labeled=True):
        self.df = df
        self.train_mode = train_mode
        self.labeled = labeled
        #self.tokenizer = BertTokenizer.from_pretrained(BERT_DIR+'/bert-base-uncased')
        #self.tokenizer = BertTokenizer.from_pretrained('../input/bert-base-uncased/')
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]

        question_title = row.question_title
        question_body = row.question_body
        answer_text = row.answer


        inputs_q = self.tokenizer.encode_plus(
            question_title + " " + question_body,
            add_special_tokens=True,
            max_length=MAX_LEN,
        )

        inputs_a = self.tokenizer.encode_plus(
            answer_text,
            add_special_tokens=True,
            max_length=MAX_LEN,
        )

        ids_q = inputs_q["input_ids"]
        token_type_ids_q = inputs_q["token_type_ids"]
        mask_q = inputs_q["attention_mask"]

        ids_a = inputs_a["input_ids"]
        token_type_ids_a = inputs_a["token_type_ids"]
        mask_a = inputs_a["attention_mask"]
        
        padding_length_q = MAX_LEN - len(ids_q)
        padding_length_a = MAX_LEN - len(ids_a)
        
        ids_q = ids_q + ([0] * padding_length_q)
        mask_q = mask_q + ([0] * padding_length_q)
        token_type_ids_q = token_type_ids_q + ([0] * padding_length_q)

        ids_a = ids_a + ([0] * padding_length_a)
        mask_a = mask_a + ([0] * padding_length_a)
        token_type_ids_a = token_type_ids_a + ([0] * padding_length_a)
        
        if self.labeled:
            labels = self.get_label(row)
            return {
                'ids_q': torch.tensor(ids_q, dtype=torch.long),
                'mask_q': torch.tensor(mask_q, dtype=torch.long),
                'token_type_ids_q': torch.tensor(token_type_ids_q, dtype=torch.long),
                'ids_a': torch.tensor(ids_a, dtype=torch.long),
                'mask_a': torch.tensor(mask_a, dtype=torch.long),
                'token_type_ids_a': torch.tensor(token_type_ids_a, dtype=torch.long),
                'labels': labels, 
            }
        else:
            return {
                'ids_q': torch.tensor(ids_q, dtype=torch.long),
                'mask_q': torch.tensor(mask_q, dtype=torch.long),
                'token_type_ids_q': torch.tensor(token_type_ids_q, dtype=torch.long),
                'ids_a': torch.tensor(ids_a, dtype=torch.long),
                'mask_a': torch.tensor(mask_a, dtype=torch.long),
                'token_type_ids_a': torch.tensor(token_type_ids_a, dtype=torch.long)
            }


    def get_label(self, row):
        return torch.tensor(row[target_columns].values.astype(np.float32))

In [0]:
def get_train_val_loaders(batch_size=4, val_batch_size=4, ifold=0):
    df = pd.read_csv(f'{DATA_DIR}/train.csv')

    # cleaning
    df['question_title'] = df['question_title'].apply(lambda x : preprocessing(x))
    df['question_body'] = df['question_body'].apply(lambda x : preprocessing(x))
    df['answer'] = df['answer'].apply(lambda x : preprocessing(x))
    

    df = shuffle(df, random_state=1234)
    gkf = GroupKFold(n_splits=5).split(X=df.question_body, groups=df.question_body)
    for fold, (train_idx, valid_idx) in enumerate(gkf):
        if fold == ifold:
            df_train = df.iloc[train_idx]
            df_val = df.iloc[valid_idx]
            break

    print('train', df_train.shape)
    print('val', df_val.shape)

    ds_train = QuestDataset(df_train, train_mode=True)
    train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
    train_loader.num = len(df_train)

    ds_val = QuestDataset(df_val, train_mode=False)
    val_loader = torch.utils.data.DataLoader(ds_val, batch_size=val_batch_size, shuffle=False, num_workers=0, drop_last=False)
    val_loader.num = len(df_val)
    val_loader.df = df_val

    return train_loader, val_loader, df_val.shape[0], valid_idx


def get_test_loader(batch_size=4):
    df = pd.read_csv(f'{DATA_DIR}/test.csv')

    # cleaning
    df['question_title'] = df['question_title'].apply(lambda x : preprocessing(x))
    df['question_body'] = df['question_body'].apply(lambda x : preprocessing(x))
    df['answer'] = df['answer'].apply(lambda x : preprocessing(x))
    
    ds_test = QuestDataset(df, train_mode=False, labeled=False)
    loader = torch.utils.data.DataLoader(ds_test, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=False)
    loader.num = len(df)
    
    return loader

In [0]:
class QuestModel(nn.Module):
    def __init__(self, n_classes=30):
        super(QuestModel, self).__init__()
        self.model_name = 'QuestModel'
        #self.bert_model = BertModel.from_pretrained(BERT_DIR+'/bert-base-uncased/')
        #self.bert_model = BertModel.from_pretrained('../input/bert-base-uncased/')
        self.bert_model_q = BertModel.from_pretrained('bert-base-uncased')
        self.bert_model_a = BertModel.from_pretrained('bert-base-uncased')
        
        # self.fc_q = nn.Linear(768*2, 21)
        # self.fc_a = nn.Linear(768*2+21, 9)
        self.fc = nn.Linear(768*2, 30)

    def forward(self, ids_q, mask_q, token_type_ids_q, ids_a, mask_a, token_type_ids_a):
        layers_q, pool_out_q = self.bert_model_q(input_ids=ids_q, token_type_ids=token_type_ids_q, attention_mask=mask_q)
        layers_a, pool_out_a = self.bert_model_a(input_ids=ids_a, token_type_ids=token_type_ids_a, attention_mask=mask_a)
        
        out_q = F.avg_pool1d(layers_q.transpose(1,2), kernel_size=layers_q.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out_a = F.avg_pool1d(layers_a.transpose(1,2), kernel_size=layers_a.size()[1]).squeeze()  # sequence方向は中央値だけ抽出
        out = torch.cat([out_q, out_a], dim=-1)

        out = F.dropout(out, p=0.2, training=self.training)
        # logit_q = self.fc_q(out)

        # out_a = torch.cat([out, F.relu(logit_q)], dim=-1)
        # logit_a = self.fc_a(out_a)

        # logit = torch.cat([logit_q, logit_a], dim=-1)
        logit = self.fc(out)

        return logit
    

In [0]:
def train_model(train_loader, optimizer, criterion, scheduler):
    model.train()
    avg_loss = 0.    
    for idx, batch in enumerate(tqdm(train_loader)):
        ids_q = batch['ids_q'].to(device)
        mask_q = batch['mask_q'].to(device)
        token_type_ids_q = batch['token_type_ids_q'].to(device)
        ids_a = batch['ids_a'].to(device)
        mask_a = batch['mask_a'].to(device)
        token_type_ids_a = batch['token_type_ids_a'].to(device)
        labels = batch['labels'].to(device)
        
        logits = model(ids_q, mask_q, token_type_ids_q, ids_a, mask_a, token_type_ids_a)        
        loss = criterion(logits, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        avg_loss += loss.item() / len(train_loader)
        del ids_q, mask_q, token_type_ids_q, ids_a, mask_a, token_type_ids_a, labels

    torch.cuda.empty_cache()
    gc.collect()
    return avg_loss

def val_model(val_loader, val_length, batch_size=8):
    model.eval() # eval mode  
    avg_val_loss = 0.
    
    valid_preds = np.zeros((val_length, len(target_columns)))
    original = np.zeros((val_length, len(target_columns)))
    
    with torch.no_grad():
        for idx, batch in enumerate(tqdm(val_loader)):
            ids_q = batch['ids_q'].to(device)
            mask_q = batch['mask_q'].to(device)
            token_type_ids_q = batch['token_type_ids_q'].to(device)
            ids_a = batch['ids_a'].to(device)
            mask_a = batch['mask_a'].to(device)
            token_type_ids_a = batch['token_type_ids_a'].to(device)
            labels = batch['labels'].to(device)

            logits = torch.sigmoid(model(ids_q, mask_q, token_type_ids_q, ids_a, mask_a, token_type_ids_a))
            
            avg_val_loss += criterion(logits, labels).item() / len(val_loader)
            valid_preds[idx*batch_size : (idx+1)*batch_size] = logits.detach().cpu().squeeze().numpy()
            original[idx*batch_size : (idx+1)*batch_size]    = labels.detach().cpu().squeeze().numpy()
        
        rho_val = np.mean([spearmanr(original[:, i], valid_preds[:,i]).correlation for i in range(valid_preds.shape[1])])
        print('\r val_spearman-rho: %s' % (str(round(rho_val, 5))), end = 100*' '+'\n')
        
        score = 0
        for i in range(len(target_columns)):
            print(i, spearmanr(original[:,i], valid_preds[:,i]))
            score += np.nan_to_num(spearmanr(original[:, i], valid_preds[:, i]).correlation)
    
    return avg_val_loss, score/len(target_columns)

In [21]:
cv_list = []
for fold in range(config.fold):
    print('---%d-Fold---'%(fold+1))
    
    patience = 0
    best_loss   = 100.0
    best_score      = -1.
    best_preds = 0
    best_param_loss = None
    best_param_score = None

    model = QuestModel(n_classes=30).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, eps=4e-5)
    criterion = nn.BCEWithLogitsLoss()
    
    for epoch in range(config.epochs):
        
        torch.cuda.empty_cache()
        start_time   = time.time()
        
        train_loader, val_loader, val_length, val_idx = get_train_val_loaders(batch_size=config.batch_size, val_batch_size=config.batch_size, ifold=fold)
        scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=config.warmup, num_training_steps=config.epochs*len(train_loader))
        
        loss_train = train_model(train_loader, optimizer, criterion, scheduler)
        loss_val, score_val = val_model(val_loader, val_length, batch_size=config.batch_size)
        print(f'Epoch {(epoch+1)}, train_loss: {loss_train}, val_loss: {loss_val}, score_val: {score_val}, time: {(time.time()-start_time)}')

        if score_val > best_score:
            best_score = score_val
            best_param_score = model.state_dict()
            print('best_param_score_{}_{}.pt'.format(config.name ,fold+1))
            torch.save(best_param_score, '/content/drive/My Drive/Colab Notebooks/GoogleQuest/best_param_score_{}_{}.pt'.format(config.name ,fold+1))
        else:
            patience += 1
            if patience >= config.patience:
                del train_loader, val_loader, loss_train, loss_val, score_val
                torch.cuda.empty_cache()
                gc.collect()
                break
    
        del train_loader, val_loader, loss_train, loss_val, score_val
        torch.cuda.empty_cache()
        gc.collect()
        
    model.load_state_dict(best_param_score)
    print('best_param_score_{}_{}.pt'.format(config.name ,fold+1))
    torch.save(best_param_score, '/content/drive/My Drive/Colab Notebooks/GoogleQuest/best_param_score_{}_{}.pt'.format(config.name ,fold+1))   
    cv_list.append(best_score)

    torch.cuda.empty_cache()
    gc.collect()
    
print('CV_score: ', np.mean(cv_list))

---1-Fold---
train (4863, 41)
val (1216, 41)


100%|██████████| 1215/1215 [11:12<00:00,  1.79it/s]
100%|██████████| 304/304 [00:57<00:00,  5.26it/s]


 val_spearman-rho: 0.36865                                                                                                    
0 SpearmanrResult(correlation=0.3692463087591003, pvalue=1.410616119269644e-40)
1 SpearmanrResult(correlation=0.5387376288678389, pvalue=1.789962880690129e-92)
2 SpearmanrResult(correlation=0.39797552337143743, pvalue=2.0046253842034616e-47)
3 SpearmanrResult(correlation=0.2778968455682508, pvalue=5.301264377562336e-23)
4 SpearmanrResult(correlation=0.3431034710008025, pvalue=6.379513974104636e-35)
5 SpearmanrResult(correlation=0.40435235618641513, pvalue=4.872375286378862e-49)
6 SpearmanrResult(correlation=0.33668101954180263, pvalue=1.3006110730311633e-33)
7 SpearmanrResult(correlation=0.4396561593642691, pvalue=1.230759249379538e-58)
8 SpearmanrResult(correlation=0.5603562964092339, pvalue=1.81783297942471e-101)
9 SpearmanrResult(correlation=-0.021019317771957195, pvalue=0.46398808074045517)
10 SpearmanrResult(correlation=0.4307425742873143, pvalue=4.189710

  0%|          | 0/1215 [00:00<?, ?it/s]

train (4863, 41)
val (1216, 41)


100%|██████████| 1215/1215 [11:14<00:00,  1.76it/s]
100%|██████████| 304/304 [00:57<00:00,  5.16it/s]


 val_spearman-rho: 0.39218                                                                                                    
0 SpearmanrResult(correlation=0.3813879486308659, pvalue=2.18203957545614e-43)
1 SpearmanrResult(correlation=0.6193100081394164, pvalue=1.0900515365755497e-129)
2 SpearmanrResult(correlation=0.4136607846340386, pvalue=1.8506920419899157e-51)
3 SpearmanrResult(correlation=0.3026342349010222, pvalue=3.574362712393396e-27)
4 SpearmanrResult(correlation=0.3622390195344285, pvalue=5.2185261073794286e-39)
5 SpearmanrResult(correlation=0.39791891391100764, pvalue=2.0711258712183725e-47)
6 SpearmanrResult(correlation=0.34571666749749647, pvalue=1.8329813586328094e-35)
7 SpearmanrResult(correlation=0.4667893889267525, pvalue=7.982505518121545e-67)
8 SpearmanrResult(correlation=0.5661896344356845, pvalue=5.219125378367117e-104)
9 SpearmanrResult(correlation=0.06213422091165362, pvalue=0.030268452222064764)
10 SpearmanrResult(correlation=0.45451479513820514, pvalue=4.988

100%|██████████| 1215/1215 [11:13<00:00,  1.81it/s]
100%|██████████| 304/304 [00:57<00:00,  5.31it/s]


 val_spearman-rho: 0.39649                                                                                                    
0 SpearmanrResult(correlation=0.3830998693972851, pvalue=8.568355528830184e-44)
1 SpearmanrResult(correlation=0.6356253484685496, pvalue=1.3238302775351785e-138)
2 SpearmanrResult(correlation=0.4111223012828798, pvalue=8.610291933596229e-51)
3 SpearmanrResult(correlation=0.2780529219288825, pvalue=5.004386938281453e-23)
4 SpearmanrResult(correlation=0.32771644616061213, pvalue=7.773590575055924e-32)
5 SpearmanrResult(correlation=0.4052808587318423, pvalue=2.816698386802108e-49)
6 SpearmanrResult(correlation=0.3695271343605882, pvalue=1.2182861375483635e-40)
7 SpearmanrResult(correlation=0.47232717738230495, pvalue=1.3749282222643427e-68)
8 SpearmanrResult(correlation=0.5947579965051123, pvalue=3.18537818649622e-117)
9 SpearmanrResult(correlation=0.04732483223846312, pvalue=0.0990451796390843)
10 SpearmanrResult(correlation=0.44714507440322504, pvalue=8.0110678

100%|██████████| 1215/1215 [11:14<00:00,  1.75it/s]
100%|██████████| 304/304 [00:57<00:00,  5.18it/s]


 val_spearman-rho: 0.3924                                                                                                    
0 SpearmanrResult(correlation=0.35879132626099347, pvalue=2.9845065041473096e-38)
1 SpearmanrResult(correlation=0.6303499283583996, pvalue=1.1563650924481958e-135)
2 SpearmanrResult(correlation=0.4126737305742508, pvalue=3.3700608919796445e-51)
3 SpearmanrResult(correlation=0.2757063995953725, pvalue=1.1856141281292375e-22)
4 SpearmanrResult(correlation=0.3291767112275042, pvalue=4.029676552650607e-32)
5 SpearmanrResult(correlation=0.38245373679523614, pvalue=1.220129355264726e-43)
6 SpearmanrResult(correlation=0.34340140108078776, pvalue=5.537285803722892e-35)
7 SpearmanrResult(correlation=0.4609445949448013, pvalue=5.357218485718915e-65)
8 SpearmanrResult(correlation=0.5969663137429541, pvalue=2.66605873255049e-118)
9 SpearmanrResult(correlation=0.07563497369391628, pvalue=0.008326159939836899)
10 SpearmanrResult(correlation=0.4327178456073382, pvalue=1.16840

100%|██████████| 1215/1215 [11:14<00:00,  1.81it/s]
100%|██████████| 304/304 [00:57<00:00,  5.25it/s]


 val_spearman-rho: 0.39011                                                                                                    
0 SpearmanrResult(correlation=0.36280505273572705, pvalue=3.911292615499151e-39)
1 SpearmanrResult(correlation=0.6576221303022944, pvalue=1.6366994487377852e-151)
2 SpearmanrResult(correlation=0.40152764411570235, pvalue=2.553673115182069e-48)
3 SpearmanrResult(correlation=0.28181808039377026, pvalue=1.2321659367155737e-23)
4 SpearmanrResult(correlation=0.32167744793327113, pvalue=1.1330605572741709e-30)
5 SpearmanrResult(correlation=0.38105681930601193, pvalue=2.612808589163336e-43)
6 SpearmanrResult(correlation=0.3179580750113738, pvalue=5.725857614335922e-30)
7 SpearmanrResult(correlation=0.465441383408704, pvalue=2.1213502479352892e-66)
8 SpearmanrResult(correlation=0.5977552416853239, pvalue=1.0938783859165812e-118)
9 SpearmanrResult(correlation=0.0707250491535093, pvalue=0.013631893464305626)
10 SpearmanrResult(correlation=0.4110026739564892, pvalue=9.25

100%|██████████| 1215/1215 [11:13<00:00,  1.81it/s]
100%|██████████| 304/304 [00:57<00:00,  5.19it/s]


 val_spearman-rho: 0.38307                                                                                                    
0 SpearmanrResult(correlation=0.35050511912238047, pvalue=1.808078788027495e-36)
1 SpearmanrResult(correlation=0.6486869855475588, pvalue=3.8593307803239084e-146)
2 SpearmanrResult(correlation=0.40683218498670126, pvalue=1.123070273309361e-49)
3 SpearmanrResult(correlation=0.2619630535593254, pvalue=1.5685612175054425e-20)
4 SpearmanrResult(correlation=0.2912682238607908, pvalue=3.3203745922908264e-25)
5 SpearmanrResult(correlation=0.35620249257378445, pvalue=1.0900205403861595e-37)
6 SpearmanrResult(correlation=0.3230554937766384, pvalue=6.180823885883262e-31)
7 SpearmanrResult(correlation=0.4623983540674623, pvalue=1.896151578499619e-65)
8 SpearmanrResult(correlation=0.5815806196626288, pvalue=5.770762358378747e-111)
9 SpearmanrResult(correlation=0.08282667191745018, pvalue=0.0038492136898617964)
10 SpearmanrResult(correlation=0.39958877993383407, pvalue=7.8

100%|██████████| 1215/1215 [11:12<00:00,  1.82it/s]
100%|██████████| 304/304 [00:57<00:00,  4.83it/s]


 val_spearman-rho: 0.37808                                                                                                    
0 SpearmanrResult(correlation=0.3962748843664207, pvalue=5.328698088937486e-47)
1 SpearmanrResult(correlation=0.5141396472208689, pvalue=5.258346134681016e-83)
2 SpearmanrResult(correlation=0.4003869338425634, pvalue=4.96280366351322e-48)
3 SpearmanrResult(correlation=0.3091451714208364, pvalue=2.43023225838449e-28)
4 SpearmanrResult(correlation=0.3556936740252103, pvalue=1.4040742574781576e-37)
5 SpearmanrResult(correlation=0.4036466555901042, pvalue=7.38119717304058e-49)
6 SpearmanrResult(correlation=0.31471172611456744, pvalue=2.311357559510403e-29)
7 SpearmanrResult(correlation=0.4244515392793283, pvalue=2.3128912202370316e-54)
8 SpearmanrResult(correlation=0.5269267147100816, pvalue=7.875235281950374e-88)
9 SpearmanrResult(correlation=0.03350883912631734, pvalue=0.24296035862993678)
10 SpearmanrResult(correlation=0.4538291642978498, pvalue=8.0446199889095

100%|██████████| 1215/1215 [11:13<00:00,  1.83it/s]
100%|██████████| 304/304 [00:58<00:00,  4.72it/s]


 val_spearman-rho: 0.39903                                                                                                    
0 SpearmanrResult(correlation=0.41254038922464203, pvalue=3.653721145805476e-51)
1 SpearmanrResult(correlation=0.5684743006686472, pvalue=5.10643108384518e-105)
2 SpearmanrResult(correlation=0.4124299089440473, pvalue=3.9066418923696676e-51)
3 SpearmanrResult(correlation=0.2505233271203181, pvalue=7.389482514435299e-19)
4 SpearmanrResult(correlation=0.3792405749624574, pvalue=6.993223912995296e-43)
5 SpearmanrResult(correlation=0.37127917213588246, pvalue=4.865835395177921e-41)
6 SpearmanrResult(correlation=0.3028062036208745, pvalue=3.332270563331317e-27)
7 SpearmanrResult(correlation=0.45613662107322567, pvalue=1.6043734741270373e-63)
8 SpearmanrResult(correlation=0.5610201356766567, pvalue=9.393413558610894e-102)
9 SpearmanrResult(correlation=0.08055428959524354, pvalue=0.00494332766420915)
10 SpearmanrResult(correlation=0.4723579644920061, pvalue=1.3439497

100%|██████████| 1215/1215 [11:15<00:00,  1.81it/s]
100%|██████████| 304/304 [00:58<00:00,  4.75it/s]


 val_spearman-rho: 0.40649                                                                                                    
0 SpearmanrResult(correlation=0.4404672325091172, pvalue=7.178005049517197e-59)
1 SpearmanrResult(correlation=0.5910238997087597, pvalue=2.0224401769683061e-115)
2 SpearmanrResult(correlation=0.4119927270914565, pvalue=5.0900562306966514e-51)
3 SpearmanrResult(correlation=0.3204584091926621, pvalue=1.931745537023183e-30)
4 SpearmanrResult(correlation=0.3846869547234987, pvalue=3.5840112950069405e-44)
5 SpearmanrResult(correlation=0.4157657574625546, pvalue=5.120307550053727e-52)
6 SpearmanrResult(correlation=0.3503919805315227, pvalue=1.910675620710835e-36)
7 SpearmanrResult(correlation=0.4573469434687205, pvalue=6.852936835259339e-64)
8 SpearmanrResult(correlation=0.5567377411271068, pvalue=6.473787453895931e-100)
9 SpearmanrResult(correlation=0.10273431032710734, pvalue=0.0003328649371306859)
10 SpearmanrResult(correlation=0.487063100299687, pvalue=1.9220842

100%|██████████| 1215/1215 [11:15<00:00,  1.78it/s]
100%|██████████| 304/304 [00:57<00:00,  4.79it/s]


 val_spearman-rho: 0.40192                                                                                                    
0 SpearmanrResult(correlation=0.4118636149042219, pvalue=5.503370921203933e-51)
1 SpearmanrResult(correlation=0.5695835179355982, pvalue=1.6411581391258787e-105)
2 SpearmanrResult(correlation=0.41412689740913855, pvalue=1.393505226919125e-51)
3 SpearmanrResult(correlation=0.28943314499642997, pvalue=6.771148576974678e-25)
4 SpearmanrResult(correlation=0.35576758074631865, pvalue=1.3534161062853458e-37)
5 SpearmanrResult(correlation=0.3995217691181528, pvalue=8.200385364186607e-48)
6 SpearmanrResult(correlation=0.3360458970404227, pvalue=1.745657647728648e-33)
7 SpearmanrResult(correlation=0.44157317553055314, pvalue=3.433014785413141e-59)
8 SpearmanrResult(correlation=0.5580197124728417, pvalue=1.8350393421330496e-100)
9 SpearmanrResult(correlation=0.0812928091424155, pvalue=0.004560277313099552)
10 SpearmanrResult(correlation=0.4812672580784298, pvalue=1.6652

100%|██████████| 1215/1215 [11:16<00:00,  1.75it/s]
100%|██████████| 304/304 [01:00<00:00,  4.63it/s]


 val_spearman-rho: 0.39081                                                                                                    
0 SpearmanrResult(correlation=0.4277431981656509, pvalue=2.8661772277089165e-55)
1 SpearmanrResult(correlation=0.5968025815802519, pvalue=3.2065015775933682e-118)
2 SpearmanrResult(correlation=0.40512264384755414, pvalue=3.092775916182001e-49)
3 SpearmanrResult(correlation=0.27836910429826245, pvalue=4.452440240296844e-23)
4 SpearmanrResult(correlation=0.3131628009710115, pvalue=4.4707565277809753e-29)
5 SpearmanrResult(correlation=0.3904184090977443, pvalue=1.4787678812172487e-45)
6 SpearmanrResult(correlation=0.3227211667061448, pvalue=7.161906362848981e-31)
7 SpearmanrResult(correlation=0.4431060419323896, pvalue=1.2295034090718606e-59)
8 SpearmanrResult(correlation=0.5670833808836939, pvalue=2.106864038530494e-104)
9 SpearmanrResult(correlation=0.08152355360992214, pvalue=0.004446224154942064)
10 SpearmanrResult(correlation=0.4625860184510047, pvalue=1.657

100%|██████████| 1215/1215 [11:31<00:00,  1.73it/s]
100%|██████████| 304/304 [01:00<00:00,  4.53it/s]


 val_spearman-rho: 0.39185                                                                                                    
0 SpearmanrResult(correlation=0.4098270365144589, pvalue=1.8771199148014358e-50)
1 SpearmanrResult(correlation=0.5999204838950601, pvalue=9.366055016773017e-120)
2 SpearmanrResult(correlation=0.4047226144847912, pvalue=3.9167326327065616e-49)
3 SpearmanrResult(correlation=0.27042643100265973, pvalue=8.008948250697301e-22)
4 SpearmanrResult(correlation=0.33901509884505454, pvalue=4.383904300930744e-34)
5 SpearmanrResult(correlation=0.3817792904409442, pvalue=1.7630843580227745e-43)
6 SpearmanrResult(correlation=0.32917264000408664, pvalue=4.037085193266523e-32)
7 SpearmanrResult(correlation=0.430457911242417, pvalue=5.03274711178603e-56)
8 SpearmanrResult(correlation=0.5760163866334811, pvalue=2.0837185004851342e-108)
9 SpearmanrResult(correlation=0.0989238391358655, pvalue=0.0005512719512993237)
10 SpearmanrResult(correlation=0.44548771398449116, pvalue=2.4677

100%|██████████| 1215/1215 [11:31<00:00,  1.77it/s]
100%|██████████| 304/304 [00:59<00:00,  5.09it/s]


 val_spearman-rho: 0.37089                                                                                                    
0 SpearmanrResult(correlation=0.29346524512573846, pvalue=1.40497389611726e-25)
1 SpearmanrResult(correlation=0.5365383511112433, pvalue=1.353586558963224e-91)
2 SpearmanrResult(correlation=0.40421229377002577, pvalue=5.2914789309816645e-49)
3 SpearmanrResult(correlation=0.22467154014335391, pvalue=2.225689974139718e-15)
4 SpearmanrResult(correlation=0.27316266455749916, pvalue=2.991708236753221e-22)
5 SpearmanrResult(correlation=0.40676336760785664, pvalue=1.1699473606678032e-49)
6 SpearmanrResult(correlation=0.35480274317298155, pvalue=2.1849249551981977e-37)
7 SpearmanrResult(correlation=0.46637869823058503, pvalue=1.075629796222617e-66)
8 SpearmanrResult(correlation=0.5585054758180925, pvalue=1.1364687566248014e-100)
9 SpearmanrResult(correlation=0.034150402153451415, pvalue=0.2340512772565566)
10 SpearmanrResult(correlation=0.4214800553938996, pvalue=1.49

100%|██████████| 1215/1215 [11:28<00:00,  1.79it/s]
100%|██████████| 304/304 [01:00<00:00,  5.06it/s]


 val_spearman-rho: 0.38784                                                                                                    
0 SpearmanrResult(correlation=0.3125840818236064, pvalue=5.714711350284387e-29)
1 SpearmanrResult(correlation=0.595037909349627, pvalue=2.328481362592947e-117)
2 SpearmanrResult(correlation=0.4109270064556512, pvalue=9.686059326555526e-51)
3 SpearmanrResult(correlation=0.26082009193609046, pvalue=2.3250132746580636e-20)
4 SpearmanrResult(correlation=0.31848755516215727, pvalue=4.552864367476136e-30)
5 SpearmanrResult(correlation=0.4269107649347721, pvalue=4.870669947999743e-55)
6 SpearmanrResult(correlation=0.3573886961419913, pvalue=6.029984348389463e-38)
7 SpearmanrResult(correlation=0.47812904649667887, pvalue=1.7997559937260272e-70)
8 SpearmanrResult(correlation=0.5760006949401933, pvalue=2.1182753161749122e-108)
9 SpearmanrResult(correlation=0.06143769356196658, pvalue=0.032175291794296075)
10 SpearmanrResult(correlation=0.453305791411036, pvalue=1.157648

100%|██████████| 1215/1215 [11:27<00:00,  1.77it/s]
100%|██████████| 304/304 [00:59<00:00,  5.03it/s]


 val_spearman-rho: 0.38633                                                                                                    
0 SpearmanrResult(correlation=0.31835404636149334, pvalue=4.823993140894155e-30)
1 SpearmanrResult(correlation=0.5744380392674363, pvalue=1.0849747688179307e-107)
2 SpearmanrResult(correlation=0.4085283309842343, pvalue=4.086685948404952e-50)
3 SpearmanrResult(correlation=0.27906348664376807, pvalue=3.442747840277689e-23)
4 SpearmanrResult(correlation=0.30864536749419336, pvalue=2.9944895186094612e-28)
5 SpearmanrResult(correlation=0.4007561195701897, pvalue=4.003674203374601e-48)
6 SpearmanrResult(correlation=0.34836508948989575, pvalue=5.116190742047547e-36)
7 SpearmanrResult(correlation=0.47256848533721296, pvalue=1.1499548937031691e-68)
8 SpearmanrResult(correlation=0.5594097948007295, pvalue=4.6478676821382716e-101)
9 SpearmanrResult(correlation=0.08420520431929296, pvalue=0.003297663049723595)
10 SpearmanrResult(correlation=0.4192810113286603, pvalue=5.8

100%|██████████| 1215/1215 [11:26<00:00,  1.81it/s]
100%|██████████| 304/304 [01:01<00:00,  4.98it/s]


 val_spearman-rho: 0.38272                                                                                                    
0 SpearmanrResult(correlation=0.33143152818626154, pvalue=1.450743053669281e-32)
1 SpearmanrResult(correlation=0.5969839593212193, pvalue=2.613532117666912e-118)
2 SpearmanrResult(correlation=0.4158604577120542, pvalue=4.831665462169108e-52)
3 SpearmanrResult(correlation=0.28072296107626094, pvalue=1.8564334994734837e-23)
4 SpearmanrResult(correlation=0.28196435651405555, pvalue=1.1663576431519534e-23)
5 SpearmanrResult(correlation=0.41072747957297356, pvalue=1.0923185463285597e-50)
6 SpearmanrResult(correlation=0.35674551862207465, pvalue=8.315060754381449e-38)
7 SpearmanrResult(correlation=0.46887344140834475, pvalue=1.7463942765051016e-67)
8 SpearmanrResult(correlation=0.5690020353056789, pvalue=2.9772238842680114e-105)
9 SpearmanrResult(correlation=0.07019465403050533, pvalue=0.014354236399801622)
10 SpearmanrResult(correlation=0.41720223404521983, pvalue=

100%|██████████| 1215/1215 [11:28<00:00,  1.75it/s]
100%|██████████| 304/304 [00:59<00:00,  5.04it/s]


 val_spearman-rho: 0.38138                                                                                                    
0 SpearmanrResult(correlation=0.3251304057664884, pvalue=2.4668965050351173e-31)
1 SpearmanrResult(correlation=0.6038681142674485, pvalue=1.0099763652038347e-121)
2 SpearmanrResult(correlation=0.4087489881761443, pvalue=3.5815333489646934e-50)
3 SpearmanrResult(correlation=0.276874162911129, pvalue=7.7264389396657e-23)
4 SpearmanrResult(correlation=0.2670507282750352, pvalue=2.6573124011337606e-21)
5 SpearmanrResult(correlation=0.38618378735560344, pvalue=1.5683728656350377e-44)
6 SpearmanrResult(correlation=0.3532295297856277, pvalue=4.753956001188708e-37)
7 SpearmanrResult(correlation=0.4764249273919428, pvalue=6.487383589347695e-70)
8 SpearmanrResult(correlation=0.5531998937789173, pvalue=2.0408415874766908e-98)
9 SpearmanrResult(correlation=0.06497640895045308, pvalue=0.02345939009206961)
10 SpearmanrResult(correlation=0.39872575233728896, pvalue=1.2999859

100%|██████████| 1215/1215 [11:26<00:00,  1.79it/s]
100%|██████████| 304/304 [00:59<00:00,  4.94it/s]


 val_spearman-rho: 0.36519                                                                                                    
0 SpearmanrResult(correlation=0.26260332939488945, pvalue=1.2571417369855814e-20)
1 SpearmanrResult(correlation=0.4420184545553175, pvalue=2.54899598912442e-59)
2 SpearmanrResult(correlation=0.3944780877174857, pvalue=1.4877409218213058e-46)
3 SpearmanrResult(correlation=0.25925245291457055, pvalue=3.976429479338254e-20)
4 SpearmanrResult(correlation=0.32067900973061436, pvalue=1.7542867918498227e-30)
5 SpearmanrResult(correlation=0.406417435241799, pvalue=1.436734204433298e-49)
6 SpearmanrResult(correlation=0.2904607682299488, pvalue=4.546116767295751e-25)
7 SpearmanrResult(correlation=0.4493877444806242, pvalue=1.730645964939662e-61)
8 SpearmanrResult(correlation=0.5265039536961591, pvalue=1.1457441518650474e-87)
9 SpearmanrResult(correlation=0.05191123293928087, pvalue=0.07036327444900367)
10 SpearmanrResult(correlation=0.4527113005371539, pvalue=1.74901411

100%|██████████| 1215/1215 [11:22<00:00,  1.75it/s]
100%|██████████| 304/304 [00:58<00:00,  5.05it/s]


 val_spearman-rho: 0.39102                                                                                                    
0 SpearmanrResult(correlation=0.30793043186606933, pvalue=4.033878860205028e-28)
1 SpearmanrResult(correlation=0.4929653527180386, pvalue=1.8700391277812638e-75)
2 SpearmanrResult(correlation=0.4064611926170213, pvalue=1.3999030394358378e-49)
3 SpearmanrResult(correlation=0.3132315940460912, pvalue=4.3420219323054137e-29)
4 SpearmanrResult(correlation=0.33103456360354255, pvalue=1.73768686043776e-32)
5 SpearmanrResult(correlation=0.4379529006761124, pvalue=3.8007260124388796e-58)
6 SpearmanrResult(correlation=0.3122373506101842, pvalue=6.618456396540368e-29)
7 SpearmanrResult(correlation=0.46747490412099335, pvalue=4.8478335799194576e-67)
8 SpearmanrResult(correlation=0.5397297326537537, pvalue=7.150768136041292e-93)
9 SpearmanrResult(correlation=0.055395806285857316, pvalue=0.05345657983883761)
10 SpearmanrResult(correlation=0.4778026335730827, pvalue=2.30207

100%|██████████| 1215/1215 [11:19<00:00,  1.77it/s]
100%|██████████| 304/304 [00:58<00:00,  5.07it/s]


 val_spearman-rho: 0.39217                                                                                                    
0 SpearmanrResult(correlation=0.30098106646133593, pvalue=6.997717200388557e-27)
1 SpearmanrResult(correlation=0.5472019004847307, pvalue=6.454676484793301e-96)
2 SpearmanrResult(correlation=0.40304975358284156, pvalue=1.0480050125266503e-48)
3 SpearmanrResult(correlation=0.3394892485459431, pvalue=3.510939314652928e-34)
4 SpearmanrResult(correlation=0.32458672907674524, pvalue=3.1403086339626224e-31)
5 SpearmanrResult(correlation=0.45803400428286367, pvalue=4.221835065915096e-64)
6 SpearmanrResult(correlation=0.3030224145463983, pvalue=3.050821608158272e-27)
7 SpearmanrResult(correlation=0.46341258351401154, pvalue=9.159872138051926e-66)
8 SpearmanrResult(correlation=0.554263671465097, pvalue=7.262377997366604e-99)
9 SpearmanrResult(correlation=0.05247309443813363, pvalue=0.06737299183132929)
10 SpearmanrResult(correlation=0.488617663905161, pvalue=5.72370710

100%|██████████| 1215/1215 [11:20<00:00,  1.79it/s]
100%|██████████| 304/304 [00:58<00:00,  5.01it/s]


 val_spearman-rho: 0.39684                                                                                                    
0 SpearmanrResult(correlation=0.305223379550666, pvalue=1.237223672527795e-27)
1 SpearmanrResult(correlation=0.5550455459094797, pvalue=3.3902223529412042e-99)
2 SpearmanrResult(correlation=0.4039878494885816, pvalue=6.039035699128722e-49)
3 SpearmanrResult(correlation=0.35158073519787303, pvalue=1.0686772433782002e-36)
4 SpearmanrResult(correlation=0.30805171032196943, pvalue=3.835281347250556e-28)
5 SpearmanrResult(correlation=0.433905387477502, pvalue=5.400045056768772e-57)
6 SpearmanrResult(correlation=0.29623152562441707, pvalue=4.706444367722954e-26)
7 SpearmanrResult(correlation=0.4654553907405047, pvalue=2.0999617063136367e-66)
8 SpearmanrResult(correlation=0.5398139382154118, pvalue=6.614075574154241e-93)
9 SpearmanrResult(correlation=0.039949075989586705, pvalue=0.16386337514784544)
10 SpearmanrResult(correlation=0.4823247820948162, pvalue=7.42477550

100%|██████████| 1215/1215 [11:19<00:00,  1.80it/s]
100%|██████████| 304/304 [00:58<00:00,  5.02it/s]


 val_spearman-rho: 0.39056                                                                                                    
0 SpearmanrResult(correlation=0.30239059524568307, pvalue=3.947443788642707e-27)
1 SpearmanrResult(correlation=0.5657550576250103, pvalue=8.104437706021316e-104)
2 SpearmanrResult(correlation=0.40444698416078567, pvalue=4.608074936527294e-49)
3 SpearmanrResult(correlation=0.34524891818385106, pvalue=2.2934352523849004e-35)
4 SpearmanrResult(correlation=0.3313934644205363, pvalue=1.476084574307731e-32)
5 SpearmanrResult(correlation=0.44484959909682326, pvalue=3.799313458866269e-60)
6 SpearmanrResult(correlation=0.27867415537183904, pvalue=3.9771258360413347e-23)
7 SpearmanrResult(correlation=0.46437059850326595, pvalue=4.5966357046041756e-66)
8 SpearmanrResult(correlation=0.5491508081784082, pvalue=1.0071575150379845e-96)
9 SpearmanrResult(correlation=0.04364168594840204, pvalue=0.12826055631330371)
10 SpearmanrResult(correlation=0.4868327955302531, pvalue=2.29

100%|██████████| 1215/1215 [11:19<00:00,  1.81it/s]
100%|██████████| 304/304 [00:58<00:00,  5.05it/s]


 val_spearman-rho: 0.38304                                                                                                    
0 SpearmanrResult(correlation=0.2873217879337308, pvalue=1.527268608431983e-24)
1 SpearmanrResult(correlation=0.5702131159124663, pvalue=8.60009741435816e-106)
2 SpearmanrResult(correlation=0.3942387353148122, pvalue=1.704978470284195e-46)
3 SpearmanrResult(correlation=0.3090067446616176, pvalue=2.5750097237341387e-28)
4 SpearmanrResult(correlation=0.30676665816893933, pvalue=6.54026532220492e-28)
5 SpearmanrResult(correlation=0.4124651062276673, pvalue=3.8242309569249433e-51)
6 SpearmanrResult(correlation=0.2763894657829891, pvalue=9.231572030632858e-23)
7 SpearmanrResult(correlation=0.44753588200008254, pvalue=6.138769649272093e-61)
8 SpearmanrResult(correlation=0.5616953915389978, pvalue=4.791758986189841e-102)
9 SpearmanrResult(correlation=0.04811763377564192, pvalue=0.09351020464141264)
10 SpearmanrResult(correlation=0.4731584647807896, pvalue=7.424972325

100%|██████████| 1215/1215 [11:20<00:00,  1.76it/s]
100%|██████████| 304/304 [00:58<00:00,  5.04it/s]


 val_spearman-rho: 0.37337                                                                                                    
0 SpearmanrResult(correlation=0.3006051675619807, pvalue=8.147660043311056e-27)
1 SpearmanrResult(correlation=0.5740135606228466, pvalue=1.6883845319077106e-107)
2 SpearmanrResult(correlation=0.3932194507477222, pvalue=3.042574066122163e-46)
3 SpearmanrResult(correlation=0.28792933563302864, pvalue=1.2094114509099176e-24)
4 SpearmanrResult(correlation=0.29609699216820473, pvalue=4.964934362591776e-26)
5 SpearmanrResult(correlation=0.40870442698404924, pvalue=3.678277782551036e-50)
6 SpearmanrResult(correlation=0.27976933280363436, pvalue=2.648748173669691e-23)
7 SpearmanrResult(correlation=0.45920411720002074, pvalue=1.84547760409961e-64)
8 SpearmanrResult(correlation=0.5471179989747419, pvalue=6.990165352798825e-96)
9 SpearmanrResult(correlation=0.03463014398417962, pvalue=0.2275421607542069)
10 SpearmanrResult(correlation=0.44598902928093925, pvalue=1.757080

100%|██████████| 1216/1216 [11:19<00:00,  1.79it/s]
100%|██████████| 304/304 [00:58<00:00,  5.66it/s]


 val_spearman-rho: 0.36101                                                                                                    
0 SpearmanrResult(correlation=0.3040578198426015, pvalue=2.097405877441598e-27)
1 SpearmanrResult(correlation=0.5247222163173652, pvalue=6.499785592266028e-87)
2 SpearmanrResult(correlation=0.3659437252833779, pvalue=8.408825535148676e-40)
3 SpearmanrResult(correlation=0.1972987457558199, pvalue=3.964400760019966e-12)
4 SpearmanrResult(correlation=0.3307898023894031, pvalue=2.058651014207698e-32)
5 SpearmanrResult(correlation=0.42782333964111363, pvalue=3.0142119776900086e-55)
6 SpearmanrResult(correlation=0.3161883035582211, pvalue=1.2947912282601468e-29)
7 SpearmanrResult(correlation=0.4693921094953091, pvalue=1.3533094335395747e-67)
8 SpearmanrResult(correlation=0.5645765278326301, pvalue=3.229469815480534e-103)
9 SpearmanrResult(correlation=0.015139143875536614, pvalue=0.5980617086578353)
10 SpearmanrResult(correlation=0.4560388881795564, pvalue=1.93143075

100%|██████████| 1216/1216 [11:19<00:00,  1.82it/s]
100%|██████████| 304/304 [00:58<00:00,  5.77it/s]


 val_spearman-rho: 0.38984                                                                                                    
0 SpearmanrResult(correlation=0.3666609413888331, pvalue=5.808280895711346e-40)
1 SpearmanrResult(correlation=0.5854059064227551, pvalue=1.1625000140022065e-112)
2 SpearmanrResult(correlation=0.3796033423208047, pvalue=6.21525068106205e-43)
3 SpearmanrResult(correlation=0.23495369643259817, pvalue=1.0611055687399377e-16)
4 SpearmanrResult(correlation=0.3493830758458091, pvalue=3.3338605989489686e-36)
5 SpearmanrResult(correlation=0.4309263153636536, pvalue=4.126058849932353e-56)
6 SpearmanrResult(correlation=0.35413550955638573, pvalue=3.251893892142614e-37)
7 SpearmanrResult(correlation=0.4995937287889762, pvalue=1.065312408912321e-77)
8 SpearmanrResult(correlation=0.5654056536771266, pvalue=1.39960692218659e-103)
9 SpearmanrResult(correlation=0.0958401385665961, pvalue=0.0008229766494457857)
10 SpearmanrResult(correlation=0.4659684250124291, pvalue=1.6376103

  0%|          | 0/1216 [00:00<?, ?it/s]

train (4864, 41)
val (1215, 41)


100%|██████████| 1216/1216 [11:19<00:00,  1.80it/s]
100%|██████████| 304/304 [00:58<00:00,  5.69it/s]


 val_spearman-rho: 0.38794                                                                                                    
0 SpearmanrResult(correlation=0.3467699179785238, pvalue=1.178661910727419e-35)
1 SpearmanrResult(correlation=0.5925251836984678, pvalue=4.764530266222005e-116)
2 SpearmanrResult(correlation=0.3881874417409141, pvalue=5.59394464621818e-45)
3 SpearmanrResult(correlation=0.23485857259240356, pvalue=1.0924170284902985e-16)
4 SpearmanrResult(correlation=0.3285215199576691, pvalue=5.734166851655963e-32)
5 SpearmanrResult(correlation=0.4383706350195921, pvalue=3.210272596818777e-58)
6 SpearmanrResult(correlation=0.3653755371944369, pvalue=1.1265347396954706e-39)
7 SpearmanrResult(correlation=0.47005140765805653, pvalue=8.345710281122805e-68)
8 SpearmanrResult(correlation=0.561435422891187, pvalue=7.507900437460167e-102)
9 SpearmanrResult(correlation=0.11010048968963478, pvalue=0.00012030574942646971)
10 SpearmanrResult(correlation=0.4523188157178894, pvalue=2.575142

100%|██████████| 1216/1216 [11:19<00:00,  1.79it/s]
100%|██████████| 304/304 [00:58<00:00,  5.75it/s]


 val_spearman-rho: 0.38784                                                                                                    
0 SpearmanrResult(correlation=0.3546653070900817, pvalue=2.5025082543047155e-37)
1 SpearmanrResult(correlation=0.5665034096172941, pvalue=4.6093903941490445e-104)
2 SpearmanrResult(correlation=0.3879172699286949, pvalue=6.501874139301247e-45)
3 SpearmanrResult(correlation=0.20267538872636412, pvalue=9.92866805044869e-13)
4 SpearmanrResult(correlation=0.3322412288400212, pvalue=1.0639732133780912e-32)
5 SpearmanrResult(correlation=0.4233098496481444, pvalue=5.240898115406623e-54)
6 SpearmanrResult(correlation=0.3431391189302758, pvalue=6.68056463350576e-35)
7 SpearmanrResult(correlation=0.4801965064928589, pvalue=4.290308691045289e-71)
8 SpearmanrResult(correlation=0.5830489784568031, pvalue=1.4737585083965735e-111)
9 SpearmanrResult(correlation=0.13691981391839594, pvalue=1.6656752699242624e-06)
10 SpearmanrResult(correlation=0.4533626011356611, pvalue=1.24903

100%|██████████| 1216/1216 [11:20<00:00,  1.79it/s]
100%|██████████| 304/304 [00:58<00:00,  5.64it/s]


 val_spearman-rho: 0.37608                                                                                                    
0 SpearmanrResult(correlation=0.34896291075922486, pvalue=4.0877223131478e-36)
1 SpearmanrResult(correlation=0.5882364076619753, pvalue=5.354393004497862e-114)
2 SpearmanrResult(correlation=0.3872136751450999, pvalue=9.61307936652715e-45)
3 SpearmanrResult(correlation=0.20341967921516627, pvalue=8.171852965243166e-13)
4 SpearmanrResult(correlation=0.3104442311893123, pvalue=1.48365396560385e-28)
5 SpearmanrResult(correlation=0.3728307523658077, pvalue=2.316252167759451e-41)
6 SpearmanrResult(correlation=0.33481947042443266, pvalue=3.2652263243758475e-33)
7 SpearmanrResult(correlation=0.45831274302472114, pvalue=3.902954299929125e-64)
8 SpearmanrResult(correlation=0.5767619994206571, pvalue=1.1667259006554111e-108)
9 SpearmanrResult(correlation=0.12512386156629562, pvalue=1.2191499441334126e-05)
10 SpearmanrResult(correlation=0.4254914016644989, pvalue=1.325273