In [1]:
import numpy as np
import polars as pl
import pandas as pd
from sklearn.base import clone
from copy import deepcopy
import optuna
from scipy.optimize import minimize
import os
import matplotlib.pyplot as plt
import seaborn as sns

import re
from colorama import Fore, Style

from tqdm import tqdm
from IPython.display import clear_output
from concurrent.futures import ThreadPoolExecutor

import warnings
warnings.filterwarnings('ignore')
pd.options.display.max_columns = None

import lightgbm as lgb
from catboost import CatBoostRegressor, CatBoostClassifier
from xgboost import XGBRegressor
from sklearn.ensemble import VotingRegressor
from sklearn.model_selection import *
from sklearn.metrics import *

n_splits = 5
SEED = 42

# Data Loading

## Load tabular dataset

In [2]:
train = pd.read_csv('/kaggle/input/child-mind-institute-problematic-internet-use/train.csv')
test = pd.read_csv('/kaggle/input/child-mind-institute-problematic-internet-use/test.csv')
sample = pd.read_csv('/kaggle/input/child-mind-institute-problematic-internet-use/sample_submission.csv')

In [3]:
train.describe()

Unnamed: 0,Basic_Demos-Age,Basic_Demos-Sex,CGAS-CGAS_Score,Physical-BMI,Physical-Height,Physical-Weight,Physical-Waist_Circumference,Physical-Diastolic_BP,Physical-HeartRate,Physical-Systolic_BP,Fitness_Endurance-Max_Stage,Fitness_Endurance-Time_Mins,Fitness_Endurance-Time_Sec,FGC-FGC_CU,FGC-FGC_CU_Zone,FGC-FGC_GSND,FGC-FGC_GSND_Zone,FGC-FGC_GSD,FGC-FGC_GSD_Zone,FGC-FGC_PU,FGC-FGC_PU_Zone,FGC-FGC_SRL,FGC-FGC_SRL_Zone,FGC-FGC_SRR,FGC-FGC_SRR_Zone,FGC-FGC_TL,FGC-FGC_TL_Zone,BIA-BIA_Activity_Level_num,BIA-BIA_BMC,BIA-BIA_BMI,BIA-BIA_BMR,BIA-BIA_DEE,BIA-BIA_ECW,BIA-BIA_FFM,BIA-BIA_FFMI,BIA-BIA_FMI,BIA-BIA_Fat,BIA-BIA_Frame_num,BIA-BIA_ICW,BIA-BIA_LDM,BIA-BIA_LST,BIA-BIA_SMM,BIA-BIA_TBW,PAQ_A-PAQ_A_Total,PAQ_C-PAQ_C_Total,PCIAT-PCIAT_01,PCIAT-PCIAT_02,PCIAT-PCIAT_03,PCIAT-PCIAT_04,PCIAT-PCIAT_05,PCIAT-PCIAT_06,PCIAT-PCIAT_07,PCIAT-PCIAT_08,PCIAT-PCIAT_09,PCIAT-PCIAT_10,PCIAT-PCIAT_11,PCIAT-PCIAT_12,PCIAT-PCIAT_13,PCIAT-PCIAT_14,PCIAT-PCIAT_15,PCIAT-PCIAT_16,PCIAT-PCIAT_17,PCIAT-PCIAT_18,PCIAT-PCIAT_19,PCIAT-PCIAT_20,PCIAT-PCIAT_Total,SDS-SDS_Total_Raw,SDS-SDS_Total_T,PreInt_EduHx-computerinternet_hoursday,sii
count,3960.0,3960.0,2421.0,3022.0,3027.0,3076.0,898.0,2954.0,2967.0,2954.0,743.0,740.0,740.0,2322.0,2282.0,1074.0,1062.0,1074.0,1063.0,2310.0,2271.0,2305.0,2267.0,2307.0,2269.0,2324.0,2285.0,1991.0,1991.0,1991.0,1991.0,1991.0,1991.0,1991.0,1991.0,1991.0,1991.0,1991.0,1991.0,1991.0,1991.0,1991.0,1991.0,475.0,1721.0,2733.0,2734.0,2731.0,2731.0,2729.0,2732.0,2729.0,2730.0,2730.0,2733.0,2734.0,2731.0,2729.0,2732.0,2730.0,2728.0,2725.0,2728.0,2730.0,2733.0,2736.0,2609.0,2606.0,3301.0,2736.0
mean,10.433586,0.372727,65.454771,19.331929,55.946713,89.038615,27.278508,69.648951,81.597236,116.983074,4.989233,7.37027,27.581081,11.25969,0.476337,22.420438,1.829567,23.518622,1.904045,5.579654,0.330251,8.694924,0.61888,8.805635,0.620097,9.252775,0.785558,2.651431,6.719826,19.367048,1237.018187,2064.693747,20.825346,74.021708,15.030554,4.336495,16.85502,1.745354,33.17338,20.02299,67.301883,34.389466,53.998726,2.178853,2.58955,2.370655,2.177762,2.399854,0.839253,2.297545,1.06369,0.586295,1.24652,1.062637,1.304793,1.685443,0.244599,1.340051,1.035505,1.499634,1.452346,1.62789,1.613636,1.158974,0.943652,27.896199,41.088923,57.763622,1.060588,0.580409
std,3.574648,0.483591,22.341862,5.113934,7.473764,44.56904,5.567287,13.611226,13.665196,17.061225,2.014072,3.189662,17.707751,11.807781,0.499549,10.833995,0.612585,11.148951,0.612344,7.390161,0.470407,3.429301,0.485769,3.422167,0.485469,2.988863,0.410525,1.028267,92.586325,5.047848,1872.383246,2836.246272,73.266287,199.433753,5.792505,6.356402,199.372119,0.680635,56.272346,70.21561,108.705918,84.050607,129.362539,0.849476,0.783937,1.673312,1.697117,1.588807,1.195601,1.705218,1.268282,1.049355,1.342582,1.258797,1.331715,1.543074,0.522956,1.411156,1.301712,1.492929,1.4956,1.445622,1.529178,1.343661,1.18546,20.338853,10.427433,13.196091,1.094875,0.771122
min,5.0,0.0,25.0,0.0,33.0,0.0,18.0,0.0,27.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,-7.78961,0.048267,813.397,1073.45,1.78945,28.9004,7.86485,-194.163,-8745.08,1.0,14.489,4.63581,23.6201,4.65573,20.5892,0.66,0.58,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,17.0,38.0,0.0,0.0
25%,8.0,0.0,59.0,15.86935,50.0,57.2,23.0,61.0,72.0,107.0,4.0,6.0,12.75,3.0,0.0,15.1,1.0,16.2,2.0,0.0,0.0,7.0,0.0,7.0,0.0,7.0,1.0,2.0,2.966905,15.9136,1004.71,1605.785,11.10955,49.2781,13.408,2.306915,8.602395,1.0,24.4635,12.98315,45.2041,21.14155,35.887,1.49,2.02,1.0,1.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,12.0,33.0,47.0,0.0,0.0
50%,10.0,0.0,65.0,17.937682,55.0,77.0,26.0,68.0,81.0,114.0,5.0,7.0,28.0,9.0,0.0,20.05,2.0,21.2,2.0,3.0,0.0,9.0,1.0,9.0,1.0,10.0,1.0,3.0,3.92272,17.9665,1115.38,1863.98,15.928,61.0662,14.0925,3.69863,16.1746,2.0,28.8558,16.4388,56.9964,27.4151,44.987,2.01,2.54,2.0,2.0,2.0,0.0,2.0,1.0,0.0,1.0,1.0,1.0,2.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,26.0,39.0,55.0,1.0,0.0
75%,13.0,1.0,75.0,21.571244,62.0,113.8,30.0,76.0,90.5,125.0,6.0,9.0,43.0,15.75,1.0,26.6,2.0,28.175,2.0,9.0,1.0,11.0,1.0,11.0,1.0,12.0,1.0,3.0,5.460925,21.4611,1310.36,2218.145,25.1622,81.8338,15.43095,5.98769,30.2731,2.0,35.4757,22.1676,77.10565,38.1794,60.27105,2.78,3.16,4.0,4.0,4.0,1.0,4.0,2.0,1.0,2.0,2.0,2.0,3.0,0.0,2.0,2.0,2.0,2.0,3.0,2.0,2.0,1.0,41.0,46.0,64.0,2.0,1.0
max,22.0,1.0,999.0,59.132048,78.5,315.0,50.0,179.0,138.0,203.0,28.0,20.0,59.0,115.0,1.0,124.0,3.0,123.8,3.0,51.0,1.0,21.7,1.0,21.0,1.0,22.0,1.0,5.0,4115.36,53.9243,83152.2,124728.0,3233.0,8799.08,217.771,28.2515,153.82,3.0,2457.91,3108.17,4683.71,3607.69,5690.91,4.71,4.79,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,93.0,96.0,100.0,3.0,3.0


In [4]:
train['id'].head()

0    00008ff9
1    000fd460
2    00105258
3    00115b9f
4    0016bb22
Name: id, dtype: object

## Load timeseries data

1. **process_file**: This function process file timeseries, extract general information in the file like count, mean, std, min, 25%, 50%, 75% and max of each features and then the features matrix is flattened to a vector to represent the data in the file
2. **load_time_series** Format and load all timeseries files after processed in a folder.

In [5]:
def process_file(filename, dirname):
    df = pd.read_parquet(os.path.join(dirname, filename, 'part-0.parquet'))
    df.drop('step', axis=1, inplace=True)
    return df.describe().values.reshape(-1), filename.split('=')[1]

def load_time_series(dirname) -> pd.DataFrame:
    ids = os.listdir(dirname)
    
    with ThreadPoolExecutor() as executor:
        results = list(tqdm(executor.map(lambda fname: process_file(fname, dirname), ids), total=len(ids)))
    
    stats, indexes = zip(*results)
    
    df = pd.DataFrame(stats, columns=[f"Stat_{i}" for i in range(len(stats[0]))])
    df['id'] = indexes
    
    return df

In [6]:
train_ts = load_time_series("/kaggle/input/child-mind-institute-problematic-internet-use/series_train.parquet")
test_ts = load_time_series("/kaggle/input/child-mind-institute-problematic-internet-use/series_test.parquet")
time_series_cols = train_ts.columns.tolist()
time_series_cols.remove("id")

100%|██████████| 996/996 [01:23<00:00, 11.99it/s]
100%|██████████| 2/2 [00:00<00:00,  7.86it/s]


In [7]:
train_ts.head(5)

Unnamed: 0,Stat_0,Stat_1,Stat_2,Stat_3,Stat_4,Stat_5,Stat_6,Stat_7,Stat_8,Stat_9,Stat_10,Stat_11,Stat_12,Stat_13,Stat_14,Stat_15,Stat_16,Stat_17,Stat_18,Stat_19,Stat_20,Stat_21,Stat_22,Stat_23,Stat_24,Stat_25,Stat_26,Stat_27,Stat_28,Stat_29,Stat_30,Stat_31,Stat_32,Stat_33,Stat_34,Stat_35,Stat_36,Stat_37,Stat_38,Stat_39,Stat_40,Stat_41,Stat_42,Stat_43,Stat_44,Stat_45,Stat_46,Stat_47,Stat_48,Stat_49,Stat_50,Stat_51,Stat_52,Stat_53,Stat_54,Stat_55,Stat_56,Stat_57,Stat_58,Stat_59,Stat_60,Stat_61,Stat_62,Stat_63,Stat_64,Stat_65,Stat_66,Stat_67,Stat_68,Stat_69,Stat_70,Stat_71,Stat_72,Stat_73,Stat_74,Stat_75,Stat_76,Stat_77,Stat_78,Stat_79,Stat_80,Stat_81,Stat_82,Stat_83,Stat_84,Stat_85,Stat_86,Stat_87,Stat_88,Stat_89,Stat_90,Stat_91,Stat_92,Stat_93,Stat_94,Stat_95,id
0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,50458.0,-0.054638,-0.163923,-0.114302,0.045252,-7.805897,0.0,46.009533,4027.514893,54154750000000.0,4.43886,2.0,30.202068,0.633126,0.513286,0.500372,0.132576,34.917873,0.0,205.862213,108.451317,18769760000000.0,1.825557,0.0,11.773107,-1.812031,-2.63138,-1.798073,0.0,-89.987045,0.0,0.0,3829.0,0.0,1.0,2.0,15.0,-0.70166,-0.619076,-0.536432,0.007953,-32.948602,0.0,2.520257,3958.0,43251250000000.0,3.0,2.0,17.0,0.015846,-0.14181,-0.104193,0.019257,-6.358004,0.0,8.230733,4029.0,56305000000000.0,5.0,2.0,28.0,0.437897,0.148919,0.22377,0.036048,13.09575,0.0,24.75,4146.0,69780000000000.0,6.0,2.0,38.0,1.850391,3.580182,1.738203,5.314874,89.422226,0.0,2626.199951,4187.0,86395000000000.0,7.0,2.0,57.0,0745c390
1,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,340584.0,0.113277,0.093139,-0.106038,0.02896,-6.065619,0.046508,56.437958,3829.466064,43311490000000.0,3.840885,2.0,232.909103,0.507897,0.541129,0.603787,0.096825,44.034721,0.208482,206.625092,167.600983,25091360000000.0,1.957999,0.0,5.701968,-1.807955,-2.887664,-1.004992,0.0,-89.654587,0.0,0.0,3098.166748,0.0,1.0,2.0,223.0,-0.231743,-0.2576,-0.595426,0.000367,-37.326844,0.0,4.0,3724.0,21285000000000.0,2.0,2.0,228.0,0.094074,0.068143,-0.2285,0.005257,-13.454103,0.0,10.05048,3812.0,43605000000000.0,4.0,2.0,233.0,0.517859,0.542323,0.312333,0.020598,18.462269,0.0,27.490936,3958.0,65110000000000.0,5.0,2.0,238.0,1.928769,3.234613,2.475326,3.966906,89.08033,1.0,2628.199951,4146.0,86395000000000.0,7.0,2.0,243.0,eaab7a96
2,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,40003.0,-0.499738,0.046381,-0.181152,0.056544,-11.934993,0.0,77.30513,4106.425781,44816770000000.0,3.148264,3.0,100.144516,0.454021,0.510668,0.412588,0.140594,27.367514,0.0,274.848145,50.734318,20381560000000.0,1.169176,0.0,5.653936,-1.903281,-3.150104,-1.020313,0.0,-89.540176,0.0,0.0,3853.0,45000000000.0,1.0,3.0,97.0,-0.873151,-0.255299,-0.485521,0.005643,-30.154542,0.0,2.918126,4089.625,28885000000000.0,3.0,3.0,98.0,-0.644505,0.088542,-0.191693,0.018467,-11.570901,0.0,7.863636,4111.0,47270000000000.0,3.0,3.0,99.0,-0.242422,0.381953,0.088555,0.048282,5.009753,0.0,21.022933,4140.0,60945000000000.0,4.0,3.0,100.0,1.02151,1.016589,1.746797,5.066334,86.987267,0.0,2618.199951,4183.0,86365000000000.0,7.0,3.0,134.0,8ec2cc63
3,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,223915.0,0.00743,0.007583,-0.19651,0.053544,-12.847143,0.0,9.369678,3958.604492,48366420000000.0,4.273992,2.303057,60.025017,0.5861,0.542189,0.474437,0.103401,32.552841,0.0,54.104408,122.706802,18687730000000.0,2.023705,1.487018,7.396456,-1.684624,-2.405738,-1.023798,0.0,-89.968369,0.0,0.0,3468.0,0.0,1.0,1.0,48.0,-0.530198,-0.412805,-0.556091,0.009947,-34.965618,0.0,0.893617,3841.0,35260000000000.0,3.0,1.0,53.0,0.022344,0.009674,-0.245181,0.027653,-15.000056,0.0,2.340206,3947.0,48810000000000.0,4.0,1.0,60.0,0.536801,0.443383,0.084469,0.057278,4.816339,0.0,6.2,4064.0,63300000000000.0,6.0,4.0,67.0,5.908,2.083693,1.269051,6.134459,89.976074,0.0,2502.0,6000.0,86395000000000.0,7.0,4.0,72.0,b2987a65
4,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,15420.0,0.086653,-0.115162,-0.138969,0.040399,-11.009835,0.0,5.049157,3992.347656,58338950000000.0,4.541829,4.0,46.192024,0.509845,0.494897,0.639449,0.090201,47.933723,0.0,15.590773,126.12159,21462060000000.0,2.081796,0.0,18.615358,-1.675859,-1.071042,-1.012266,0.0,-89.770241,0.0,0.0,3815.083252,35000000000.0,1.0,4.0,20.0,-0.224805,-0.444297,-0.685736,0.005364,-46.348264,0.0,1.438378,3837.333252,51613750000000.0,3.0,4.0,32.0,0.053034,-0.087422,-0.22543,0.024135,-13.665493,0.0,2.897436,4000.0,64270000000000.0,4.0,4.0,42.0,0.544297,0.153125,0.347474,0.04369,20.726226,0.0,4.942201,4087.0,73936250000000.0,7.0,4.0,69.0,3.231563,1.03362,1.071875,2.774382,89.300034,0.0,1046.800049,4199.0,86015000000000.0,7.0,4.0,76.0,7b8842c3


Time series data is then merged to the tabular data

In [8]:
train = pd.merge(train, train_ts, how="left", on='id')
test = pd.merge(test, test_ts, how="left", on='id')

# Here the GA(GANs) to handle missing value 

In [9]:
import tensorflow as tf
from tensorflow.keras import layers

# Tạo một lớp GAN đơn giản để dự đoán giá trị bị thiếu
class GANImputer:
    def __init__(self, latent_dim=100, epochs=1000, batch_size=32):
        self.latent_dim = latent_dim
        self.epochs = epochs
        self.batch_size = batch_size
        self.generator = self.build_generator()
        self.discriminator = self.build_discriminator()
        self.gan = self.build_gan()

    def build_generator(self):
        model = tf.keras.Sequential([
            layers.Dense(128, activation='relu', input_dim=self.latent_dim),
            layers.Dense(256, activation='relu'),
            layers.Dense(512, activation='relu'),
            layers.Dense(1, activation='linear')
        ])
        return model

    def build_discriminator(self):
        model = tf.keras.Sequential([
            layers.Dense(512, activation='relu', input_dim=1),
            layers.Dense(256, activation='relu'),
            layers.Dense(128, activation='relu'),
            layers.Dense(1, activation='sigmoid')
        ])
        model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
        return model

    def build_gan(self):
        self.discriminator.trainable = False
        model = tf.keras.Sequential([
            self.generator,
            self.discriminator
        ])
        model.compile(optimizer='adam', loss='binary_crossentropy')
        return model

    def train(self, data):
        if data.shape[0] < self.batch_size:
            raise ValueError("Not enough data to train the GAN. Ensure the batch size is smaller than the number of samples.")
        
        valid = np.ones((self.batch_size, 1))
        fake = np.zeros((self.batch_size, 1))

        for epoch in range(self.epochs):
            # Lấy mẫu dữ liệu thật
            idx = np.random.randint(0, data.shape[0], self.batch_size)
            real_data = data[idx]

            # Sinh dữ liệu giả
            noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
            generated_data = self.generator.predict(noise)

            # Huấn luyện Discriminator
            d_loss_real = self.discriminator.train_on_batch(real_data, valid)
            d_loss_fake = self.discriminator.train_on_batch(generated_data, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # Huấn luyện Generator
            noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
            g_loss = self.gan.train_on_batch(noise, valid)

            if epoch % 100 == 0:
                print(f"Epoch {epoch}/{self.epochs} - D Loss: {d_loss[0]:.4f}, G Loss: {g_loss:.4f}")

    def impute(self, data):
        missing_idx = np.isnan(data)
        complete_data = data.copy()
        
        for i in range(data.shape[1]):  # Với từng cột
            missing_rows = missing_idx[:, i]
            if np.any(missing_rows):
                # Lấy noise để dự đoán giá trị thiếu
                noise = np.random.normal(0, 1, (missing_rows.sum(), self.latent_dim))
                generated_values = self.generator.predict(noise)
                complete_data[missing_rows, i] = generated_values[:, 0]
        return complete_data

# Tiền xử lý dữ liệu
numeric_cols = train.select_dtypes(include=['float64', 'int64']).columns
numeric_data = train[numeric_cols].values

# Loại bỏ hàng chứa toàn giá trị NaN
non_nan_rows = ~np.isnan(numeric_data).all(axis=1)
numeric_data = numeric_data[non_nan_rows]

# Chuẩn bị GANs
gan_imputer = GANImputer(epochs=500, batch_size=64)

# Điền giá trị bị thiếu
imputed_data = gan_imputer.impute(numeric_data)

# Chuyển thành DataFrame
train_imputed = pd.DataFrame(imputed_data, columns=numeric_cols)

# Làm tròn cột 'sii'
if 'sii' in train_imputed.columns:
    train_imputed['sii'] = train_imputed['sii'].round().astype(int)

# Gán lại các cột không phải số
for col in train.columns:
    if col not in numeric_cols:
        train_imputed[col] = train[col]

train = train_imputed


[1m49/49[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 
[1m28/28[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 
[1m101/101[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step
[1m101/101[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step
[1m101/101[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step
[1m52/52[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step
[1m53/53[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step
[1m91/91[0m [32m━━━━━━━━━━━━━━━━━━━━[

In [10]:
train = train.drop('id', axis=1)
test = test.drop('id', axis=1)

In [11]:
train.head()

Unnamed: 0,Basic_Demos-Age,Basic_Demos-Sex,CGAS-CGAS_Score,Physical-BMI,Physical-Height,Physical-Weight,Physical-Waist_Circumference,Physical-Diastolic_BP,Physical-HeartRate,Physical-Systolic_BP,Fitness_Endurance-Max_Stage,Fitness_Endurance-Time_Mins,Fitness_Endurance-Time_Sec,FGC-FGC_CU,FGC-FGC_CU_Zone,FGC-FGC_GSND,FGC-FGC_GSND_Zone,FGC-FGC_GSD,FGC-FGC_GSD_Zone,FGC-FGC_PU,FGC-FGC_PU_Zone,FGC-FGC_SRL,FGC-FGC_SRL_Zone,FGC-FGC_SRR,FGC-FGC_SRR_Zone,FGC-FGC_TL,FGC-FGC_TL_Zone,BIA-BIA_Activity_Level_num,BIA-BIA_BMC,BIA-BIA_BMI,BIA-BIA_BMR,BIA-BIA_DEE,BIA-BIA_ECW,BIA-BIA_FFM,BIA-BIA_FFMI,BIA-BIA_FMI,BIA-BIA_Fat,BIA-BIA_Frame_num,BIA-BIA_ICW,BIA-BIA_LDM,BIA-BIA_LST,BIA-BIA_SMM,BIA-BIA_TBW,PAQ_A-PAQ_A_Total,PAQ_C-PAQ_C_Total,PCIAT-PCIAT_01,PCIAT-PCIAT_02,PCIAT-PCIAT_03,PCIAT-PCIAT_04,PCIAT-PCIAT_05,PCIAT-PCIAT_06,PCIAT-PCIAT_07,PCIAT-PCIAT_08,PCIAT-PCIAT_09,PCIAT-PCIAT_10,PCIAT-PCIAT_11,PCIAT-PCIAT_12,PCIAT-PCIAT_13,PCIAT-PCIAT_14,PCIAT-PCIAT_15,PCIAT-PCIAT_16,PCIAT-PCIAT_17,PCIAT-PCIAT_18,PCIAT-PCIAT_19,PCIAT-PCIAT_20,PCIAT-PCIAT_Total,SDS-SDS_Total_Raw,SDS-SDS_Total_T,PreInt_EduHx-computerinternet_hoursday,sii,Stat_0,Stat_1,Stat_2,Stat_3,Stat_4,Stat_5,Stat_6,Stat_7,Stat_8,Stat_9,Stat_10,Stat_11,Stat_12,Stat_13,Stat_14,Stat_15,Stat_16,Stat_17,Stat_18,Stat_19,Stat_20,Stat_21,Stat_22,Stat_23,Stat_24,Stat_25,Stat_26,Stat_27,Stat_28,Stat_29,Stat_30,Stat_31,Stat_32,Stat_33,Stat_34,Stat_35,Stat_36,Stat_37,Stat_38,Stat_39,Stat_40,Stat_41,Stat_42,Stat_43,Stat_44,Stat_45,Stat_46,Stat_47,Stat_48,Stat_49,Stat_50,Stat_51,Stat_52,Stat_53,Stat_54,Stat_55,Stat_56,Stat_57,Stat_58,Stat_59,Stat_60,Stat_61,Stat_62,Stat_63,Stat_64,Stat_65,Stat_66,Stat_67,Stat_68,Stat_69,Stat_70,Stat_71,Stat_72,Stat_73,Stat_74,Stat_75,Stat_76,Stat_77,Stat_78,Stat_79,Stat_80,Stat_81,Stat_82,Stat_83,Stat_84,Stat_85,Stat_86,Stat_87,Stat_88,Stat_89,Stat_90,Stat_91,Stat_92,Stat_93,Stat_94,Stat_95,Basic_Demos-Enroll_Season,CGAS-Season,Physical-Season,Fitness_Endurance-Season,FGC-Season,BIA-Season,PAQ_A-Season,PAQ_C-Season,PCIAT-Season,SDS-Season,PreInt_EduHx-Season
0,5.0,0.0,51.0,16.877316,46.0,50.8,-0.154064,0.118792,-0.179445,-0.155679,-0.152848,0.007912,-0.12252,0.0,0.0,0.116686,0.074379,0.046487,0.067392,0.0,0.0,7.0,0.0,6.0,0.0,6.0,1.0,2.0,2.66855,16.8792,932.498,1492.0,8.25598,41.5862,13.8177,3.06143,9.21377,1.0,24.4349,8.89536,38.9177,19.5413,32.6909,-0.209154,0.155368,5.0,4.0,4.0,0.0,4.0,0.0,0.0,4.0,0.0,0.0,4.0,0.0,4.0,4.0,4.0,4.0,4.0,4.0,2.0,4.0,55.0,-0.062785,-0.249771,3.0,2,0.373241,0.289002,0.202279,0.016655,-0.129149,0.130674,-0.036775,0.130338,-0.089939,-0.098153,-0.021684,-0.025672,-0.110934,0.227276,0.102979,-0.088788,-0.155662,-0.160307,-0.148536,-0.123719,-0.0581628,-0.034341,0.152515,0.375503,0.196253,0.10046,-0.049118,-0.097126,-0.181682,0.129153,0.592024,-0.082748,-0.291428,-0.206699,-0.009084,0.245418,-0.240748,0.19824,-0.016956,0.137855,0.115685,-0.086519,-0.00215,0.054817,-0.3147832,-0.251763,-0.128998,-0.235171,0.006193,0.020703,-0.14183,-0.089135,0.164077,0.102115,-0.31228,0.063128,-0.1429681,0.050361,0.227214,-0.102565,-0.427599,-0.209742,-0.230028,0.110077,0.000809,0.481299,0.083569,-0.44333,0.06927192,0.360862,-0.150935,0.058627,-0.515985,-0.159533,0.250685,0.230257,-0.189195,0.298147,-0.094588,0.136663,0.3003571,0.400883,-0.208567,-0.151631,0.052621,-0.007136,-0.085346,-0.220448,-0.416644,-0.071484,0.21969,-0.097079,-0.02836639,-0.274512,0.011742,-0.060631,Fall,Winter,Fall,,Fall,Fall,,,Fall,,Fall
1,9.0,0.0,0.111098,14.03559,48.0,46.0,22.0,75.0,70.0,122.0,0.248451,-0.111683,0.253281,3.0,0.0,-0.053693,0.085005,-0.424165,-0.219452,5.0,0.0,11.0,1.0,11.0,1.0,3.0,0.0,2.0,2.57949,14.0371,936.656,1498.65,6.01993,42.0291,12.8254,1.21172,3.97085,1.0,21.0352,14.974,39.4497,15.4107,27.0552,-0.05114,2.34,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,46.0,64.0,0.0,0,0.149736,0.121664,-0.026019,0.157084,0.385274,-0.143335,-0.267048,0.442075,0.130279,0.321569,0.136117,-0.001031,0.087139,0.135488,-0.082363,0.006746,-0.249495,0.018895,0.129138,-0.18301,-0.1009916,-0.005584,-0.136259,-0.440091,0.009793,0.269016,-0.051625,0.016377,-0.224567,-0.04811,-0.079611,-0.17564,0.1579596,-0.079938,-0.180506,0.068062,-0.004663,0.019804,-0.091586,0.059809,-0.032098,-0.205951,-0.259167,-0.068215,-0.00969404,-0.076273,-0.040543,0.361664,-0.059096,-0.09854,0.063532,-0.016562,-0.226004,-0.044651,-0.084755,0.046531,-0.1982391,-0.026478,-0.2876,0.24054,-0.210579,-0.049129,-0.007372,0.011423,-0.059365,0.057984,0.026228,-0.158474,-0.07591943,-0.050091,-0.016198,-0.009005,0.028435,0.105841,-0.407039,0.201467,-0.286448,-0.039902,-0.187489,0.097585,0.2784588,0.173933,-0.036986,-0.133328,-0.212758,-0.43607,0.02387,0.279588,0.015037,0.062325,-0.106941,0.023251,-0.05872407,0.058401,0.048831,-0.043989,Summer,,Fall,,Fall,Winter,,Fall,Fall,Fall,Summer
2,10.0,1.0,71.0,16.648696,56.5,75.6,-0.218841,65.0,94.0,117.0,5.0,7.0,33.0,20.0,1.0,10.2,1.0,14.7,2.0,7.0,1.0,10.0,1.0,10.0,1.0,5.0,0.0,0.104477,-0.297142,0.174724,0.083522,0.021778,-0.344502,-0.217101,-0.319782,0.135576,0.184098,-0.04318,-0.283192,-0.145384,-0.230623,0.151023,-0.39341,0.45818,2.17,5.0,2.0,2.0,1.0,2.0,1.0,1.0,2.0,1.0,1.0,1.0,0.0,1.0,1.0,1.0,0.0,2.0,2.0,1.0,1.0,28.0,38.0,54.0,2.0,0,-0.196864,-0.241204,0.040846,-0.110933,0.073238,0.133504,-0.326363,0.068324,0.010892,-0.440738,0.137047,-0.057938,-0.152985,-0.135459,-0.105569,0.287972,-0.003899,0.130881,0.433591,0.139819,0.06920992,0.062679,-0.247691,-0.09527,-0.03573,-0.143616,0.069428,0.064776,0.44595,0.286904,-0.181611,-0.094492,-0.09811944,-0.121294,-0.187951,0.20889,-0.561118,-0.185013,0.025149,0.183368,0.078928,0.126396,0.072283,-0.214366,0.3186095,-0.031028,0.106868,-0.26099,0.156615,0.189989,-0.08109,0.194367,0.031424,-0.06586,0.045728,0.206752,0.09848626,-0.121122,-0.015072,-0.081818,-0.199642,-0.32298,0.179811,-0.259431,0.121225,-0.135975,0.029856,0.117951,-0.1209691,0.213279,-0.0894,-0.349239,-0.01934,-0.319676,0.132141,-0.245438,-0.143809,-0.096489,0.108262,-0.062116,-0.02991761,0.022421,-0.24696,-0.441375,-0.23577,0.065267,-0.421742,-0.362294,-0.523501,-0.124694,0.103444,-0.243985,-0.1281529,-0.081943,-0.019832,0.101446,Summer,Fall,Fall,Fall,Fall,,,Summer,Fall,Fall,Summer
3,9.0,0.0,71.0,18.292347,56.0,81.6,-0.208838,60.0,97.0,117.0,6.0,9.0,37.0,18.0,1.0,-0.125438,0.06917,0.192009,-0.073777,5.0,0.0,7.0,0.0,7.0,0.0,7.0,1.0,3.0,3.84191,18.2943,1131.43,1923.44,15.5925,62.7757,14.074,4.22033,18.8243,2.0,30.4041,16.779,58.9338,26.4798,45.9966,-0.12202,2.451,4.0,2.0,4.0,0.0,5.0,1.0,0.0,3.0,2.0,2.0,3.0,0.0,3.0,0.0,0.0,3.0,4.0,3.0,4.0,1.0,44.0,31.0,45.0,0.0,1,43330.0,43330.0,43330.0,43330.0,43330.0,43330.0,43330.0,43330.0,43330.0,43330.0,43330.0,43330.0,-0.316384,0.016009,-0.16789,0.047388,-10.580416,0.0,42.29631,4053.579102,50462150000000.0,4.470182,3.0,53.201683,0.453665,0.502702,0.58571,0.106351,42.94717,0.0,208.168976,112.404045,19428420000000.0,1.931421,0.0,14.244914,-1.746094,-2.905339,-1.048372,0.0,-89.833092,0.0,0.0,3824.0,55000000000.0,1.0,3.0,41.0,-0.68418,-0.309863,-0.649974,0.006432,-41.541863,0.0,2.392969,4028.666748,36890000000000.0,3.0,3.0,42.0,-0.366849,0.024974,-0.245378,0.023637,-15.086617,0.0,6.926828,4070.0,53477500000000.0,5.0,3.0,50.0,-0.010677,0.400677,0.204727,0.04142,12.220764,0.0,15.0,4147.0,66408750000000.0,6.0,3.0,53.0,1.507865,1.666354,1.546979,4.004276,89.751656,0.0,2633.25,4188.5,86110000000000.0,7.0,3.0,85.0,Winter,Fall,Summer,Summer,Summer,Summer,,Winter,Summer,Summer,Winter
4,18.0,1.0,-0.178201,-0.023806,0.103824,0.039332,0.334921,-0.050506,0.022204,-0.385004,0.139939,0.001325,-0.4605,0.269659,-0.088131,0.015555,0.040107,-0.103649,0.134727,-0.015768,-0.03687,-0.029575,-0.206977,0.106131,-0.09589,0.307428,0.059726,-0.13657,0.15152,0.001998,0.088034,-0.003269,-0.00182,0.064564,0.147794,-0.337159,0.046568,-0.046357,-0.033089,-0.397515,-0.310134,-0.041897,-0.082382,1.04,-0.063808,0.301544,-0.002114,0.335112,0.077993,-0.108909,-0.175843,0.07313,-0.011015,0.443055,0.110335,-0.303601,-0.217882,-0.594706,-0.456694,0.071706,-0.093584,-0.049534,0.183711,0.13472,0.005874,0.194074,-0.320399,-0.505314,-0.048344,0,0.026037,-0.074391,-0.174369,-0.075133,-0.229504,0.00719,0.13523,-0.071977,0.159332,0.021706,0.370155,0.151499,0.2647,-0.164902,0.218573,0.09824,0.195661,-0.011592,-0.035602,0.18849,-0.1124921,0.32645,0.013488,0.17606,-0.185237,-0.077988,0.211697,0.011398,-0.151226,0.327706,0.068223,-0.096587,-0.02224067,-0.391335,0.260173,-0.200784,0.531572,-0.296789,0.112476,0.061734,-0.081523,-0.208388,-0.0863,-0.120484,0.3002845,0.029091,-0.136006,-0.257753,-0.044188,0.186163,-0.028362,-0.069998,-0.4039,-0.031814,-0.164591,0.005512,0.02000909,-0.291887,0.140225,-0.028794,0.104028,0.249424,-0.28284,-0.28129,0.050977,-0.297205,0.091797,-0.057484,0.1067495,-0.128449,0.035262,0.11354,0.091221,0.36,-0.126656,0.026567,-0.098362,-0.194729,-0.075744,-0.166821,0.08817163,-0.304603,-0.049464,-0.43084,-0.003995,0.34861,-0.215579,-0.218892,0.145464,-0.041306,-0.071206,-0.082548,-0.1090113,-0.21411,-0.041213,-0.075291,Spring,Summer,,,,,Summer,,,,


# Data Filtering

Select the columns which is present in test data to train

In [12]:
featuresCols = ['Basic_Demos-Enroll_Season', 'Basic_Demos-Age', 'Basic_Demos-Sex',
                'CGAS-Season', 'CGAS-CGAS_Score', 'Physical-Season', 'Physical-BMI',
                'Physical-Height', 'Physical-Weight', 'Physical-Waist_Circumference',
                'Physical-Diastolic_BP', 'Physical-HeartRate', 'Physical-Systolic_BP',
                'Fitness_Endurance-Season', 'Fitness_Endurance-Max_Stage',
                'Fitness_Endurance-Time_Mins', 'Fitness_Endurance-Time_Sec',
                'FGC-Season', 'FGC-FGC_CU', 'FGC-FGC_CU_Zone', 'FGC-FGC_GSND',
                'FGC-FGC_GSND_Zone', 'FGC-FGC_GSD', 'FGC-FGC_GSD_Zone', 'FGC-FGC_PU',
                'FGC-FGC_PU_Zone', 'FGC-FGC_SRL', 'FGC-FGC_SRL_Zone', 'FGC-FGC_SRR',
                'FGC-FGC_SRR_Zone', 'FGC-FGC_TL', 'FGC-FGC_TL_Zone', 'BIA-Season',
                'BIA-BIA_Activity_Level_num', 'BIA-BIA_BMC', 'BIA-BIA_BMI',
                'BIA-BIA_BMR', 'BIA-BIA_DEE', 'BIA-BIA_ECW', 'BIA-BIA_FFM',
                'BIA-BIA_FFMI', 'BIA-BIA_FMI', 'BIA-BIA_Fat', 'BIA-BIA_Frame_num',
                'BIA-BIA_ICW', 'BIA-BIA_LDM', 'BIA-BIA_LST', 'BIA-BIA_SMM',
                'BIA-BIA_TBW', 'PAQ_A-Season', 'PAQ_A-PAQ_A_Total', 'PAQ_C-Season',
                'PAQ_C-PAQ_C_Total', 'SDS-Season', 'SDS-SDS_Total_Raw',
                'SDS-SDS_Total_T', 'PreInt_EduHx-Season',
                'PreInt_EduHx-computerinternet_hoursday', 'sii']

featuresCols += time_series_cols

train = train[featuresCols]

Drop the NaN sii value

In [13]:
train = train.dropna(subset='sii')

Fill the missing categorical data with "Missing" 

In [14]:
cat_c = ['Basic_Demos-Enroll_Season', 'CGAS-Season', 'Physical-Season', 'Fitness_Endurance-Season', 
          'FGC-Season', 'BIA-Season', 'PAQ_A-Season', 'PAQ_C-Season', 'SDS-Season', 'PreInt_EduHx-Season']

def update(df):
    for c in cat_c: 
        df[c] = df[c].fillna('Missing')
        df[c] = df[c].astype('category')
    return df
        
train = update(train)
test = update(test)

Create a mapping from string to integer to push data to the model (Use one hot encode instead)

In [15]:
def create_mapping(column, dataset):
    unique_values = dataset[column].unique()
    return {value: idx for idx, value in enumerate(unique_values)}

for col in cat_c:
    mapping_train = create_mapping(col, train)
    mapping_test = create_mapping(col, test)
    
    train[col] = train[col].replace(mapping_train).astype(int)
    test[col] = test[col].replace(mapping_test).astype(int)

print(f'Train Shape : {train.shape} || Test Shape : {test.shape}')

Train Shape : (3960, 155) || Test Shape : (20, 154)


In [16]:
test.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 20 entries, 0 to 19
Columns: 154 entries, Basic_Demos-Enroll_Season to Stat_95
dtypes: float64(142), int64(12)
memory usage: 24.2 KB


In [17]:
train.head()

Unnamed: 0,Basic_Demos-Enroll_Season,Basic_Demos-Age,Basic_Demos-Sex,CGAS-Season,CGAS-CGAS_Score,Physical-Season,Physical-BMI,Physical-Height,Physical-Weight,Physical-Waist_Circumference,Physical-Diastolic_BP,Physical-HeartRate,Physical-Systolic_BP,Fitness_Endurance-Season,Fitness_Endurance-Max_Stage,Fitness_Endurance-Time_Mins,Fitness_Endurance-Time_Sec,FGC-Season,FGC-FGC_CU,FGC-FGC_CU_Zone,FGC-FGC_GSND,FGC-FGC_GSND_Zone,FGC-FGC_GSD,FGC-FGC_GSD_Zone,FGC-FGC_PU,FGC-FGC_PU_Zone,FGC-FGC_SRL,FGC-FGC_SRL_Zone,FGC-FGC_SRR,FGC-FGC_SRR_Zone,FGC-FGC_TL,FGC-FGC_TL_Zone,BIA-Season,BIA-BIA_Activity_Level_num,BIA-BIA_BMC,BIA-BIA_BMI,BIA-BIA_BMR,BIA-BIA_DEE,BIA-BIA_ECW,BIA-BIA_FFM,BIA-BIA_FFMI,BIA-BIA_FMI,BIA-BIA_Fat,BIA-BIA_Frame_num,BIA-BIA_ICW,BIA-BIA_LDM,BIA-BIA_LST,BIA-BIA_SMM,BIA-BIA_TBW,PAQ_A-Season,PAQ_A-PAQ_A_Total,PAQ_C-Season,PAQ_C-PAQ_C_Total,SDS-Season,SDS-SDS_Total_Raw,SDS-SDS_Total_T,PreInt_EduHx-Season,PreInt_EduHx-computerinternet_hoursday,sii,Stat_0,Stat_1,Stat_2,Stat_3,Stat_4,Stat_5,Stat_6,Stat_7,Stat_8,Stat_9,Stat_10,Stat_11,Stat_12,Stat_13,Stat_14,Stat_15,Stat_16,Stat_17,Stat_18,Stat_19,Stat_20,Stat_21,Stat_22,Stat_23,Stat_24,Stat_25,Stat_26,Stat_27,Stat_28,Stat_29,Stat_30,Stat_31,Stat_32,Stat_33,Stat_34,Stat_35,Stat_36,Stat_37,Stat_38,Stat_39,Stat_40,Stat_41,Stat_42,Stat_43,Stat_44,Stat_45,Stat_46,Stat_47,Stat_48,Stat_49,Stat_50,Stat_51,Stat_52,Stat_53,Stat_54,Stat_55,Stat_56,Stat_57,Stat_58,Stat_59,Stat_60,Stat_61,Stat_62,Stat_63,Stat_64,Stat_65,Stat_66,Stat_67,Stat_68,Stat_69,Stat_70,Stat_71,Stat_72,Stat_73,Stat_74,Stat_75,Stat_76,Stat_77,Stat_78,Stat_79,Stat_80,Stat_81,Stat_82,Stat_83,Stat_84,Stat_85,Stat_86,Stat_87,Stat_88,Stat_89,Stat_90,Stat_91,Stat_92,Stat_93,Stat_94,Stat_95
0,0,5.0,0.0,0,51.0,0,16.877316,46.0,50.8,-0.154064,0.118792,-0.179445,-0.155679,0,-0.152848,0.007912,-0.12252,0,0.0,0.0,0.116686,0.074379,0.046487,0.067392,0.0,0.0,7.0,0.0,6.0,0.0,6.0,1.0,0,2.0,2.66855,16.8792,932.498,1492.0,8.25598,41.5862,13.8177,3.06143,9.21377,1.0,24.4349,8.89536,38.9177,19.5413,32.6909,0,-0.209154,0,0.155368,0,-0.062785,-0.249771,0,3.0,2,0.373241,0.289002,0.202279,0.016655,-0.129149,0.130674,-0.036775,0.130338,-0.089939,-0.098153,-0.021684,-0.025672,-0.110934,0.227276,0.102979,-0.088788,-0.155662,-0.160307,-0.148536,-0.123719,-0.0581628,-0.034341,0.152515,0.375503,0.196253,0.10046,-0.049118,-0.097126,-0.181682,0.129153,0.592024,-0.082748,-0.291428,-0.206699,-0.009084,0.245418,-0.240748,0.19824,-0.016956,0.137855,0.115685,-0.086519,-0.00215,0.054817,-0.3147832,-0.251763,-0.128998,-0.235171,0.006193,0.020703,-0.14183,-0.089135,0.164077,0.102115,-0.31228,0.063128,-0.1429681,0.050361,0.227214,-0.102565,-0.427599,-0.209742,-0.230028,0.110077,0.000809,0.481299,0.083569,-0.44333,0.06927192,0.360862,-0.150935,0.058627,-0.515985,-0.159533,0.250685,0.230257,-0.189195,0.298147,-0.094588,0.136663,0.3003571,0.400883,-0.208567,-0.151631,0.052621,-0.007136,-0.085346,-0.220448,-0.416644,-0.071484,0.21969,-0.097079,-0.02836639,-0.274512,0.011742,-0.060631
1,1,9.0,0.0,1,0.111098,0,14.03559,48.0,46.0,22.0,75.0,70.0,122.0,0,0.248451,-0.111683,0.253281,0,3.0,0.0,-0.053693,0.085005,-0.424165,-0.219452,5.0,0.0,11.0,1.0,11.0,1.0,3.0,0.0,1,2.0,2.57949,14.0371,936.656,1498.65,6.01993,42.0291,12.8254,1.21172,3.97085,1.0,21.0352,14.974,39.4497,15.4107,27.0552,0,-0.05114,1,2.34,1,46.0,64.0,1,0.0,0,0.149736,0.121664,-0.026019,0.157084,0.385274,-0.143335,-0.267048,0.442075,0.130279,0.321569,0.136117,-0.001031,0.087139,0.135488,-0.082363,0.006746,-0.249495,0.018895,0.129138,-0.18301,-0.1009916,-0.005584,-0.136259,-0.440091,0.009793,0.269016,-0.051625,0.016377,-0.224567,-0.04811,-0.079611,-0.17564,0.1579596,-0.079938,-0.180506,0.068062,-0.004663,0.019804,-0.091586,0.059809,-0.032098,-0.205951,-0.259167,-0.068215,-0.00969404,-0.076273,-0.040543,0.361664,-0.059096,-0.09854,0.063532,-0.016562,-0.226004,-0.044651,-0.084755,0.046531,-0.1982391,-0.026478,-0.2876,0.24054,-0.210579,-0.049129,-0.007372,0.011423,-0.059365,0.057984,0.026228,-0.158474,-0.07591943,-0.050091,-0.016198,-0.009005,0.028435,0.105841,-0.407039,0.201467,-0.286448,-0.039902,-0.187489,0.097585,0.2784588,0.173933,-0.036986,-0.133328,-0.212758,-0.43607,0.02387,0.279588,0.015037,0.062325,-0.106941,0.023251,-0.05872407,0.058401,0.048831,-0.043989
2,1,10.0,1.0,2,71.0,0,16.648696,56.5,75.6,-0.218841,65.0,94.0,117.0,1,5.0,7.0,33.0,0,20.0,1.0,10.2,1.0,14.7,2.0,7.0,1.0,10.0,1.0,10.0,1.0,5.0,0.0,2,0.104477,-0.297142,0.174724,0.083522,0.021778,-0.344502,-0.217101,-0.319782,0.135576,0.184098,-0.04318,-0.283192,-0.145384,-0.230623,0.151023,-0.39341,0,0.45818,2,2.17,1,38.0,54.0,1,2.0,0,-0.196864,-0.241204,0.040846,-0.110933,0.073238,0.133504,-0.326363,0.068324,0.010892,-0.440738,0.137047,-0.057938,-0.152985,-0.135459,-0.105569,0.287972,-0.003899,0.130881,0.433591,0.139819,0.06920992,0.062679,-0.247691,-0.09527,-0.03573,-0.143616,0.069428,0.064776,0.44595,0.286904,-0.181611,-0.094492,-0.09811944,-0.121294,-0.187951,0.20889,-0.561118,-0.185013,0.025149,0.183368,0.078928,0.126396,0.072283,-0.214366,0.3186095,-0.031028,0.106868,-0.26099,0.156615,0.189989,-0.08109,0.194367,0.031424,-0.06586,0.045728,0.206752,0.09848626,-0.121122,-0.015072,-0.081818,-0.199642,-0.32298,0.179811,-0.259431,0.121225,-0.135975,0.029856,0.117951,-0.1209691,0.213279,-0.0894,-0.349239,-0.01934,-0.319676,0.132141,-0.245438,-0.143809,-0.096489,0.108262,-0.062116,-0.02991761,0.022421,-0.24696,-0.441375,-0.23577,0.065267,-0.421742,-0.362294,-0.523501,-0.124694,0.103444,-0.243985,-0.1281529,-0.081943,-0.019832,0.101446
3,2,9.0,0.0,2,71.0,1,18.292347,56.0,81.6,-0.208838,60.0,97.0,117.0,2,6.0,9.0,37.0,1,18.0,1.0,-0.125438,0.06917,0.192009,-0.073777,5.0,0.0,7.0,0.0,7.0,0.0,7.0,1.0,3,3.0,3.84191,18.2943,1131.43,1923.44,15.5925,62.7757,14.074,4.22033,18.8243,2.0,30.4041,16.779,58.9338,26.4798,45.9966,0,-0.12202,3,2.451,2,31.0,45.0,2,0.0,1,43330.0,43330.0,43330.0,43330.0,43330.0,43330.0,43330.0,43330.0,43330.0,43330.0,43330.0,43330.0,-0.316384,0.016009,-0.16789,0.047388,-10.580416,0.0,42.29631,4053.579102,50462150000000.0,4.470182,3.0,53.201683,0.453665,0.502702,0.58571,0.106351,42.94717,0.0,208.168976,112.404045,19428420000000.0,1.931421,0.0,14.244914,-1.746094,-2.905339,-1.048372,0.0,-89.833092,0.0,0.0,3824.0,55000000000.0,1.0,3.0,41.0,-0.68418,-0.309863,-0.649974,0.006432,-41.541863,0.0,2.392969,4028.666748,36890000000000.0,3.0,3.0,42.0,-0.366849,0.024974,-0.245378,0.023637,-15.086617,0.0,6.926828,4070.0,53477500000000.0,5.0,3.0,50.0,-0.010677,0.400677,0.204727,0.04142,12.220764,0.0,15.0,4147.0,66408750000000.0,6.0,3.0,53.0,1.507865,1.666354,1.546979,4.004276,89.751656,0.0,2633.25,4188.5,86110000000000.0,7.0,3.0,85.0
4,3,18.0,1.0,3,-0.178201,2,-0.023806,0.103824,0.039332,0.334921,-0.050506,0.022204,-0.385004,0,0.139939,0.001325,-0.4605,2,0.269659,-0.088131,0.015555,0.040107,-0.103649,0.134727,-0.015768,-0.03687,-0.029575,-0.206977,0.106131,-0.09589,0.307428,0.059726,2,-0.13657,0.15152,0.001998,0.088034,-0.003269,-0.00182,0.064564,0.147794,-0.337159,0.046568,-0.046357,-0.033089,-0.397515,-0.310134,-0.041897,-0.082382,1,1.04,0,-0.063808,0,-0.320399,-0.505314,3,-0.048344,0,0.026037,-0.074391,-0.174369,-0.075133,-0.229504,0.00719,0.13523,-0.071977,0.159332,0.021706,0.370155,0.151499,0.2647,-0.164902,0.218573,0.09824,0.195661,-0.011592,-0.035602,0.18849,-0.1124921,0.32645,0.013488,0.17606,-0.185237,-0.077988,0.211697,0.011398,-0.151226,0.327706,0.068223,-0.096587,-0.02224067,-0.391335,0.260173,-0.200784,0.531572,-0.296789,0.112476,0.061734,-0.081523,-0.208388,-0.0863,-0.120484,0.3002845,0.029091,-0.136006,-0.257753,-0.044188,0.186163,-0.028362,-0.069998,-0.4039,-0.031814,-0.164591,0.005512,0.02000909,-0.291887,0.140225,-0.028794,0.104028,0.249424,-0.28284,-0.28129,0.050977,-0.297205,0.091797,-0.057484,0.1067495,-0.128449,0.035262,0.11354,0.091221,0.36,-0.126656,0.026567,-0.098362,-0.194729,-0.075744,-0.166821,0.08817163,-0.304603,-0.049464,-0.43084,-0.003995,0.34861,-0.215579,-0.218892,0.145464,-0.041306,-0.071206,-0.082548,-0.1090113,-0.21411,-0.041213,-0.075291


In [18]:
train['sii'].head()

0    2
1    0
2    0
3    1
4    0
Name: sii, dtype: int64

# Training Function

**quadratic_weighted_kappa**: calculate QWK value

In [19]:
def quadratic_weighted_kappa(y_true, y_pred):
    return cohen_kappa_score(y_true, y_pred, weights='quadratic')


**threshold_Rounder**: Turn the sii from PCIAT_Total to categorical 

In [20]:
def threshold_Rounder(oof_non_rounded, thresholds):
    return np.where(oof_non_rounded < thresholds[0], 0,
                    np.where(oof_non_rounded < thresholds[1], 1,
                             np.where(oof_non_rounded < thresholds[2], 2, 3)))

**evaluate_predictions**: this function evaluate the prediction of the model by first turn integer prediction values to categorical values and then calculate QWK from it and the true labels.

In [21]:
def evaluate_predictions(thresholds, y_true, oof_non_rounded):
    rounded_p = threshold_Rounder(oof_non_rounded, thresholds)
    return -quadratic_weighted_kappa(y_true, rounded_p)


**TrainML**: Train the model using K-Fold, The model is regression model, predict a real value represent how bad the patient was. The value may not explicitly different, so we re-define the threshold to make it split more accurate

In [22]:
def TrainML(model_class, test_data):
    
    X = train.drop(['sii'], axis=1)
    y = train['sii']

    # Apply K-Fold
    SKF = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    
    train_S = []
    test_S = []
    
    oof_non_rounded = np.zeros(len(y), dtype=float) 
    oof_rounded = np.zeros(len(y), dtype=int) 
    test_preds = np.zeros((len(test_data), n_splits))

    for fold, (train_idx, test_idx) in enumerate(tqdm(SKF.split(X, y), desc="Training Folds", total=n_splits)):
        # Train model
        X_train, X_val = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_val = y.iloc[train_idx], y.iloc[test_idx]

        model = clone(model_class)
        model.fit(X_train, y_train)

        y_train_pred = model.predict(X_train)
        y_val_pred = model.predict(X_val)

        # Round to integer values
        oof_non_rounded[test_idx] = y_val_pred
        y_val_pred_rounded = y_val_pred.round(0).astype(int)
        oof_rounded[test_idx] = y_val_pred_rounded

        train_kappa = quadratic_weighted_kappa(y_train, y_train_pred.round(0).astype(int))
        val_kappa = quadratic_weighted_kappa(y_val, y_val_pred_rounded)

        train_S.append(train_kappa)
        test_S.append(val_kappa)

        #Predict with test dataset
        test_preds[:, fold] = model.predict(test_data)
        
        print(f"Fold {fold+1} - Train QWK: {train_kappa:.4f}, Validation QWK: {val_kappa:.4f}")
        clear_output(wait=True)

    print(f"Mean Train QWK --> {np.mean(train_S):.4f}")
    print(f"Mean Validation QWK ---> {np.mean(test_S):.4f}")

    # Using optimizer to find the best threshold
    KappaOPtimizer = minimize(evaluate_predictions,
                              x0=[0.5, 1.5, 2.5], args=(y, oof_non_rounded), 
                              method='Nelder-Mead') # Nelder-Mead | # Powell
    assert KappaOPtimizer.success, "Optimization did not converge."

    # Use the threshold retrive from the optimizer to predict again to evaluate
    oof_tuned = threshold_Rounder(oof_non_rounded, KappaOPtimizer.x)
    tKappa = quadratic_weighted_kappa(y, oof_tuned)

    print(f"----> || Optimized QWK SCORE :: {Fore.CYAN}{Style.BRIGHT} {tKappa:.3f}{Style.RESET_ALL}")

    # Use the threshold retrive from the optimizer to predict test
    tpm = test_preds.mean(axis=1)
    tpTuned = threshold_Rounder(tpm, KappaOPtimizer.x)

    # Create submition
    submission = pd.DataFrame({
        'id': sample['id'],
        'sii': tpTuned
    })

    return submission,model

# Create model and train the model

We first try to train a model to take it as baseline. The baseline here is LGBM, which is an Gradient Boosting framework

In [23]:
Params = {
    'learning_rate': 0.046,
    'max_depth': 12,
    'num_leaves': 478,
    'min_data_in_leaf': 13,
    'feature_fraction': 0.893,
    'bagging_fraction': 0.784,
    'bagging_freq': 4,
    'lambda_l1': 10,  # Increased from 6.59
    'lambda_l2': 0.01,  # Increased from 2.68e-06
}

Light = lgb.LGBMRegressor(**Params, verbose=-1, n_estimators=200, random_state=SEED)
Submission,model = TrainML(Light,test)

Training Folds: 100%|██████████| 5/5 [00:17<00:00,  3.48s/it]

Mean Train QWK --> 0.8685
Mean Validation QWK ---> 0.4532





----> || Optimized QWK SCORE :: [36m[1m 0.500[0m


# Submit model

In [24]:
Submission.to_csv('submission.csv', index=False)
print(Submission['sii'].value_counts())

sii
0    15
1     5
Name: count, dtype: int64
