In [1]:
# Automatically reload imported modules (e.g. the project's `src` package)
# before each cell runs, so edits to them take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
import os

In [3]:
# When the notebook is launched from inside the `final-nbs` folder, step up one
# level and chdir there so that `import cfg` / `src` and relative data paths
# resolve from the project root.
PROJECT_DIR = os.path.abspath(os.curdir)
if PROJECT_DIR.endswith('final-nbs'):
    PROJECT_DIR = os.path.abspath(os.pardir)
    os.chdir(PROJECT_DIR)

In [4]:
import cfg
from src.data import get_features_path_from_metadata, join_dataframe_columns
from src import util
from src.data import setup_directories
# Configure logging first so later helpers can emit log records.
util.setup_logging()

# Resolve the project's data directory layout (create_dirs=True presumably
# creates any missing directories — confirm in src.data.setup_directories).
dirs = setup_directories(
    cfg.DATA_DIR,
    create_dirs=True,
)

In [5]:
# Input locations: raw competition files, processed training tables, and the
# CV fold indices for the 'final-validation' scheme.
raw_dir, train_dir = Path(dirs['raw']), Path(dirs['train'])
cv_dir = Path(dirs['cv']['final-validation'])

In [6]:
# read metadata
pd_metadata = pd.read_csv(raw_dir / "metadata.csv", index_col="sample_id")
pd_metadata.head()

Unnamed: 0_level_0,split,instrument_type,features_path,features_md5_hash
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
S0000,train,commercial,train_features/S0000.csv,017b9a71a702e81a828e6242aa15f049
S0001,train,commercial,train_features/S0001.csv,0d09840214054d254bd49436c6a6f315
S0002,train,commercial,train_features/S0002.csv,3f58b3c9b001bfed6ed4e4f757083e09
S0003,train,commercial,train_features/S0003.csv,e9a12f96114a2fda60b36f4c0f513fb1
S0004,train,commercial,train_features/S0004.csv,b67603d3931897bfa796ac42cc16de78


In [7]:
# read train labels
# Binary (0/1) training labels, one column per target, keyed by sample_id.
pd_train_target = pd.read_csv(raw_dir.joinpath('train_labels.csv'), index_col='sample_id')

In [8]:
# read multiclass train labels — presumably the `<target>_multiclass` columns
# that train_one_vs_the_rest uses as training targets (confirm in multiclass.csv)
pd_multclass_target = pd.read_csv(train_dir / 'multiclass.csv', index_col='sample_id')

In [9]:
# Per-m/z aggregate features; per the filename, correlated columns were dropped.
agg_features_path = train_dir / 'mz_agg_features_drop_correlated.csv'
pd_agg_features = pd.read_csv(agg_features_path, index_col='sample_id')

In [10]:
pd_agg_features.head(n=5)

Unnamed: 0_level_0,sum_mz0,sum_mz1,sum_mz2,sum_mz3,sum_mz6,sum_mz7,sum_mz12,sum_mz13,sum_mz14,sum_mz15,...,temp_peak_mz90,temp_peak_mz91,temp_peak_mz92,temp_peak_mz93,temp_peak_mz94,temp_peak_mz95,temp_peak_mz96,temp_peak_mz97,temp_peak_mz98,temp_peak_mz99
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
S0000,0.34002,0.199854,0.166877,0.200833,0.23883,0.438316,0.766372,0.868359,0.890791,0.863902,...,0.089245,0.229957,0.081612,0.124142,0.203451,0.103749,0.139557,0.104356,0.081609,0.08987
S0001,0.0,0.913373,0.977375,0.979591,0.956738,0.99092,0.980405,0.96303,0.979561,0.965443,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
S0002,0.410304,0.380667,0.326579,0.354215,0.186832,0.408898,0.452858,0.06905,0.259936,0.569445,...,0.927101,0.769674,0.494778,0.535845,0.238448,0.89256,0.6348,0.596683,0.238674,0.839972
S0003,0.244877,0.219661,0.181182,0.234663,0.299011,0.648236,0.836375,0.547981,0.601476,0.124742,...,0.763939,0.835352,0.85759,0.81636,0.334546,0.69962,0.231548,0.182711,0.918799,0.079698
S0004,0.501639,0.469028,0.42088,0.452981,0.224536,0.251611,0.563269,0.440885,0.481131,0.533506,...,0.783448,0.937656,0.763923,0.347323,0.456205,0.36314,0.413243,0.762212,0.614587,0.16442


In [11]:
# Per-m/z cluster ids (presumably from an autoencoder, per the 'ae_' prefix; -1 appears for missing).
cluster_features_path = train_dir / 'ae_clusters.csv'
pd_cluster_features = pd.read_csv(cluster_features_path, index_col='sample_id')

In [12]:
pd_cluster_features.head(n=5)

Unnamed: 0_level_0,cluster_mz0,cluster_mz1,cluster_mz2,cluster_mz3,cluster_mz5,cluster_mz6,cluster_mz7,cluster_mz8,cluster_mz9,cluster_mz10,...,cluster_mz90,cluster_mz91,cluster_mz92,cluster_mz93,cluster_mz94,cluster_mz95,cluster_mz96,cluster_mz97,cluster_mz98,cluster_mz99
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
S0000,9.0,9.0,9.0,9.0,9.0,7.0,7.0,7.0,7.0,7.0,...,9.0,7.0,4.0,9.0,9.0,9.0,9.0,7.0,7.0,9.0
S0001,-1.0,4.0,7.0,9.0,4.0,7.0,7.0,4.0,9.0,9.0,...,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0
S0002,9.0,9.0,9.0,9.0,9.0,7.0,6.0,6.0,6.0,7.0,...,6.0,7.0,6.0,7.0,6.0,6.0,7.0,8.0,8.0,1.0
S0003,9.0,9.0,3.0,9.0,9.0,6.0,7.0,10.0,6.0,11.0,...,6.0,6.0,8.0,6.0,0.0,7.0,8.0,7.0,7.0,8.0
S0004,9.0,9.0,9.0,9.0,9.0,9.0,9.0,6.0,9.0,0.0,...,7.0,7.0,7.0,6.0,7.0,7.0,6.0,8.0,6.0,8.0


In [13]:
# Sample-level features (one row per sample_id).
sample_features_path = train_dir / 'sample_features.csv'
pd_sample_features = pd.read_csv(sample_features_path, index_col='sample_id')

In [14]:
# Column-wise join of the three feature groups, aligned on the sample_id index.
feature_groups = [pd_sample_features, pd_agg_features, pd_cluster_features]
pd_features = pd.concat(feature_groups, axis='columns')

In [15]:
# Model input columns, in the order used for training.
feature_names = list(pd_features.columns)

In [16]:
from src import util

In [17]:
# Single modelling frame: binary labels, multiclass labels and all features.
data = pd.concat([pd_train_target, pd_multclass_target, pd_features], axis='columns')

In [18]:
data.head(n=5)

Unnamed: 0_level_0,basalt,carbonate,chloride,iron_oxide,oxalate,oxychlorine,phyllosilicate,silicate,sulfate,sulfide,...,cluster_mz90,cluster_mz91,cluster_mz92,cluster_mz93,cluster_mz94,cluster_mz95,cluster_mz96,cluster_mz97,cluster_mz98,cluster_mz99
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
S0000,0,0,0,0,0,0,0,0,1,0,...,9.0,7.0,4.0,9.0,9.0,9.0,9.0,7.0,7.0,9.0
S0001,0,1,0,0,0,0,0,0,0,0,...,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0
S0002,0,0,0,0,0,1,0,0,0,0,...,6.0,7.0,6.0,7.0,6.0,6.0,7.0,8.0,8.0,1.0
S0003,0,1,0,1,0,0,0,0,1,0,...,6.0,6.0,8.0,6.0,0.0,7.0,8.0,7.0,7.0,8.0
S0004,0,0,0,1,0,1,1,0,0,0,...,7.0,7.0,7.0,6.0,7.0,7.0,6.0,8.0,6.0,8.0


In [19]:
from src.model_selection import get_train_test_tuple_from_split
from src.data import get_cv_paths
from sklearn.linear_model import LogisticRegression
from xgboost import XGBRFClassifier
from lightgbm import LGBMClassifier
from sklearn.svm import SVC

In [20]:
# LightGBM configuration shared by every target's CV training run.
model_config = {
    'model': 'lgbm',
    'parameters': {
        'class_weight': 'balanced',
        'n_estimators': 50,
        'colsample_bytree': 0.3,
    },
}

In [21]:
from src import train, inference

In [22]:
def train_one_vs_the_rest(data: pd.DataFrame, model_config, feature_names, cv_root=None, targets=None):
    """Train one cross-validated model collection per target.

    For every target name, ``train.train_cv_from_config`` is called on the
    ``<target>_multiclass`` column of ``data`` using the CV fold files that
    ``get_cv_paths`` returns for that target.

    Parameters
    ----------
    data : pd.DataFrame
        Frame holding the feature columns and the ``<target>_multiclass`` columns.
    model_config : dict
        Model specification forwarded unchanged to ``train.train_cv_from_config``.
    feature_names : list of str
        Columns of ``data`` used as model inputs.
    cv_root : path-like, optional
        Directory holding the per-target CV fold files. Defaults to the
        notebook-level ``cv_dir`` (previously an implicit global dependency).
    targets : iterable of str, optional
        Target names to train. Defaults to ``cfg.TARGETS``.

    Returns
    -------
    dict
        Maps each target name to the result of ``train.train_cv_from_config``
        (later cells iterate it as a collection of per-fold models).
    """
    # Resolve defaults at call time so the notebook globals stay the fallback.
    cv_root = cv_dir if cv_root is None else cv_root
    targets = cfg.TARGETS if targets is None else targets

    models = {}
    for target_name in targets:
        cv_paths = get_cv_paths(cv_root, target_name)
        multiclass_target_name = f'{target_name}_multiclass'
        models[target_name] = train.train_cv_from_config(
            data, model_config, feature_names, multiclass_target_name, cv_paths=cv_paths
        )
    return models

In [23]:
# Train the per-target CV models; result is keyed by target name.
models = train_one_vs_the_rest(data, model_config, feature_names=feature_names)

2022-04-17 11:06:07 - src.train - INFO     [train.py:62] fold=1/24
2022-04-17 11:06:07 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/basalt/fold_0.csv
2022-04-17 11:06:08 - src.train - INFO     [train.py:28] elapsed training time: 0.018 min
2022-04-17 11:06:08 - src.train - INFO     [train.py:62] fold=2/24
2022-04-17 11:06:09 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/basalt/fold_1.csv
2022-04-17 11:06:10 - src.train - INFO     [train.py:28] elapsed training time: 0.019 min
2022-04-17 11:06:10 - src.train - INFO     [train.py:62] fold=3/24
2022-04-17 11:06:10 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/basalt/fold_10.csv
2022-04-17 11:06:11 - src.train - INFO     [train.py:28] elapsed training time: 0.020 min
2022-04-17 11:06:11 - src.train - INFO     [train.py:62] fold=4/24
2022-04-17 11:06:11 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model

2022-04-17 11:06:48 - src.train - INFO     [train.py:28] elapsed training time: 0.018 min
2022-04-17 11:06:48 - src.train - INFO     [train.py:62] fold=7/24
2022-04-17 11:06:48 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/carbonate/fold_14.csv
2022-04-17 11:06:51 - src.train - INFO     [train.py:28] elapsed training time: 0.045 min
2022-04-17 11:06:51 - src.train - INFO     [train.py:62] fold=8/24
2022-04-17 11:06:51 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/carbonate/fold_15.csv
2022-04-17 11:06:53 - src.train - INFO     [train.py:28] elapsed training time: 0.034 min
2022-04-17 11:06:53 - src.train - INFO     [train.py:62] fold=9/24
2022-04-17 11:06:53 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/carbonate/fold_16.csv
2022-04-17 11:06:54 - src.train - INFO     [train.py:28] elapsed training time: 0.014 min
2022-04-17 11:06:54 - src.train - INFO     [train.py:62] fold=10/24

2022-04-17 11:07:19 - src.train - INFO     [train.py:62] fold=12/24
2022-04-17 11:07:19 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/chloride/fold_19.csv
2022-04-17 11:07:20 - src.train - INFO     [train.py:28] elapsed training time: 0.014 min
2022-04-17 11:07:20 - src.train - INFO     [train.py:62] fold=13/24
2022-04-17 11:07:20 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/chloride/fold_2.csv
2022-04-17 11:07:21 - src.train - INFO     [train.py:28] elapsed training time: 0.017 min
2022-04-17 11:07:21 - src.train - INFO     [train.py:62] fold=14/24
2022-04-17 11:07:21 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/chloride/fold_20.csv
2022-04-17 11:07:22 - src.train - INFO     [train.py:28] elapsed training time: 0.015 min
2022-04-17 11:07:22 - src.train - INFO     [train.py:62] fold=15/24
2022-04-17 11:07:22 - src.train - INFO     [train.py:63] reading cv index from data/cv_ind

2022-04-17 11:08:03 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/iron_oxide/fold_23.csv
2022-04-17 11:08:05 - src.train - INFO     [train.py:28] elapsed training time: 0.033 min
2022-04-17 11:08:05 - src.train - INFO     [train.py:62] fold=18/24
2022-04-17 11:08:05 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/iron_oxide/fold_3.csv
2022-04-17 11:08:07 - src.train - INFO     [train.py:28] elapsed training time: 0.023 min
2022-04-17 11:08:07 - src.train - INFO     [train.py:62] fold=19/24
2022-04-17 11:08:07 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/iron_oxide/fold_4.csv
2022-04-17 11:08:07 - src.train - INFO     [train.py:28] elapsed training time: 0.009 min
2022-04-17 11:08:07 - src.train - INFO     [train.py:62] fold=20/24
2022-04-17 11:08:07 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/iron_oxide/fold_5.csv
2022-04-17 11:08:08 - src.tra

2022-04-17 11:08:33 - src.train - INFO     [train.py:28] elapsed training time: 0.009 min
2022-04-17 11:08:33 - src.train - INFO     [train.py:62] fold=23/24
2022-04-17 11:08:33 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/oxalate/fold_8.csv
2022-04-17 11:08:34 - src.train - INFO     [train.py:28] elapsed training time: 0.009 min
2022-04-17 11:08:34 - src.train - INFO     [train.py:62] fold=24/24
2022-04-17 11:08:34 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/oxalate/fold_9.csv
2022-04-17 11:08:35 - src.train - INFO     [train.py:28] elapsed training time: 0.008 min
2022-04-17 11:08:35 - src.train - INFO     [train.py:62] fold=1/24
2022-04-17 11:08:35 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/oxychlorine/fold_0.csv
2022-04-17 11:08:35 - src.train - INFO     [train.py:28] elapsed training time: 0.010 min
2022-04-17 11:08:35 - src.train - INFO     [train.py:62] fold=2/24
202

2022-04-17 11:08:57 - src.train - INFO     [train.py:28] elapsed training time: 0.015 min
2022-04-17 11:08:57 - src.train - INFO     [train.py:62] fold=4/24
2022-04-17 11:08:57 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/phyllosilicate/fold_11.csv
2022-04-17 11:08:58 - src.train - INFO     [train.py:28] elapsed training time: 0.018 min
2022-04-17 11:08:58 - src.train - INFO     [train.py:62] fold=5/24
2022-04-17 11:08:58 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/phyllosilicate/fold_12.csv
2022-04-17 11:08:59 - src.train - INFO     [train.py:28] elapsed training time: 0.010 min
2022-04-17 11:08:59 - src.train - INFO     [train.py:62] fold=6/24
2022-04-17 11:08:59 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/phyllosilicate/fold_13.csv
2022-04-17 11:09:00 - src.train - INFO     [train.py:28] elapsed training time: 0.013 min
2022-04-17 11:09:00 - src.train - INFO     [train.py

2022-04-17 11:09:26 - src.train - INFO     [train.py:28] elapsed training time: 0.029 min
2022-04-17 11:09:26 - src.train - INFO     [train.py:62] fold=9/24
2022-04-17 11:09:26 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/silicate/fold_16.csv
2022-04-17 11:09:27 - src.train - INFO     [train.py:28] elapsed training time: 0.023 min
2022-04-17 11:09:27 - src.train - INFO     [train.py:62] fold=10/24
2022-04-17 11:09:27 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/silicate/fold_17.csv
2022-04-17 11:09:28 - src.train - INFO     [train.py:28] elapsed training time: 0.017 min
2022-04-17 11:09:28 - src.train - INFO     [train.py:62] fold=11/24
2022-04-17 11:09:28 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/silicate/fold_18.csv
2022-04-17 11:09:30 - src.train - INFO     [train.py:28] elapsed training time: 0.017 min
2022-04-17 11:09:30 - src.train - INFO     [train.py:62] fold=12/24


2022-04-17 11:09:51 - src.train - INFO     [train.py:62] fold=14/24
2022-04-17 11:09:51 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/sulfate/fold_20.csv
2022-04-17 11:09:52 - src.train - INFO     [train.py:28] elapsed training time: 0.011 min
2022-04-17 11:09:52 - src.train - INFO     [train.py:62] fold=15/24
2022-04-17 11:09:52 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/sulfate/fold_21.csv
2022-04-17 11:09:53 - src.train - INFO     [train.py:28] elapsed training time: 0.011 min
2022-04-17 11:09:53 - src.train - INFO     [train.py:62] fold=16/24
2022-04-17 11:09:53 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/sulfate/fold_22.csv
2022-04-17 11:09:53 - src.train - INFO     [train.py:28] elapsed training time: 0.012 min
2022-04-17 11:09:53 - src.train - INFO     [train.py:62] fold=17/24
2022-04-17 11:09:53 - src.train - INFO     [train.py:63] reading cv index from data/cv_index

2022-04-17 11:10:18 - src.train - INFO     [train.py:28] elapsed training time: 0.010 min
2022-04-17 11:10:18 - src.train - INFO     [train.py:62] fold=20/24
2022-04-17 11:10:18 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/sulfide/fold_5.csv
2022-04-17 11:10:19 - src.train - INFO     [train.py:28] elapsed training time: 0.012 min
2022-04-17 11:10:19 - src.train - INFO     [train.py:62] fold=21/24
2022-04-17 11:10:19 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/sulfide/fold_6.csv
2022-04-17 11:10:19 - src.train - INFO     [train.py:28] elapsed training time: 0.011 min
2022-04-17 11:10:19 - src.train - INFO     [train.py:62] fold=22/24
2022-04-17 11:10:19 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model/sulfide/fold_7.csv
2022-04-17 11:10:20 - src.train - INFO     [train.py:28] elapsed training time: 0.010 min
2022-04-17 11:10:20 - src.train - INFO     [train.py:62] fold=23/24
2022-

In [24]:
def predict_multiclass_fn(model, test_data):
    """Collapse multiclass probabilities into one score per sample.

    Calls ``model.predict_proba(test_data)`` and sums the probability columns
    from column index 2 onward, giving one value per input row.

    NOTE(review): this is positional. It assumes every ``predict_proba``
    column at index >= 2 is a class that counts towards the positive score.
    If a CV fold was trained without one of the lower classes, the columns
    shift; confirm against ``model.classes_``.
    """
    class_probs = model.predict_proba(test_data)
    upper_class_probs = class_probs[:, 2:]
    return upper_class_probs.sum(axis=1)

In [25]:
# Cross-validated evaluation of each target's models; per-target score
# collections become columns of `scores` (lower is better — used as a loss below).
per_target_scores = {}
for target_name in cfg.TARGETS:
    cv_paths = get_cv_paths(cv_dir, target_name)
    model = models[target_name]
    per_target_scores[target_name] = inference.cross_validation_inference(
        data, target_name, model, cv_paths, predict_multiclass_fn
    )
scores = pd.DataFrame(per_target_scores)

In [26]:
# Column-wise mean: one averaged CV score per target (rows presumably are folds).
cv_scores = scores.mean(axis=0)

In [27]:
# Display the per-target mean CV scores.
cv_scores

basalt            0.148178
carbonate         0.105477
chloride          0.165927
iron_oxide        0.194785
oxalate           0.028967
oxychlorine       0.155488
phyllosilicate    0.253738
silicate          0.182060
sulfate           0.189164
sulfide           0.063096
dtype: float64

uniform distribution
basalt            0.148178
carbonate         0.105477
chloride          0.165927
iron_oxide        0.194785
oxalate           0.028967
oxychlorine       0.155488
phyllosilicate    0.253738
silicate          0.182060
sulfate           0.189164
sulfide           0.063096
dtype: float64

In [28]:
# Overall CV loss: unweighted mean of the per-target scores.
avg_loss = cv_scores.mean(skipna=True)

In [29]:
# Display the overall average CV loss across targets.
avg_loss

0.14868795489350956

In [30]:
# Location of the processed test-set feature tables.
test_dir = Path(dirs["test"])

In [31]:
# Build the test feature frame from the same three groups used in training.
# NOTE: training read the correlation-pruned aggregate file
# ('mz_agg_features_drop_correlated.csv'), whereas the test set uses the full
# 'mz_agg_features.csv'. The test frame can therefore carry extra columns, and
# only the training `feature_names` should be fed to the models.
pd_test_agg_features = pd.read_csv(test_dir / 'mz_agg_features.csv', index_col='sample_id')
pd_test_cluster_features = pd.read_csv(test_dir / 'ae_clusters.csv', index_col='sample_id')
test_sample_features = pd.read_csv(test_dir / 'sample_features.csv', index_col='sample_id')
pd_test_features = pd.concat((test_sample_features, pd_test_agg_features, pd_test_cluster_features), axis=1)

# Fail early if any feature the models were trained on is absent at test time.
missing_test_features = [name for name in feature_names if name not in pd_test_features.columns]
assert not missing_test_features, f"test set is missing training features: {missing_test_features[:10]}"

In [32]:
pd_test_features.head(n=5)

Unnamed: 0_level_0,sample_mol_ion_less99,sample_weighted_mass,sample_max_temp,sample_min_temp,sample_temp_range,sum_mz0,sum_mz1,sum_mz2,sum_mz3,sum_mz5,...,cluster_mz90,cluster_mz91,cluster_mz92,cluster_mz93,cluster_mz94,cluster_mz95,cluster_mz96,cluster_mz97,cluster_mz98,cluster_mz99
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
S0766,0,0.005432,0.501948,0.34025,0.505124,0.926832,0.920168,0.871866,0.897009,0.924933,...,6.0,6.0,7.0,7.0,6.0,8.0,7.0,7.0,6.0,8.0
S0767,0,0.065747,0.930877,0.220117,0.934335,0.875786,0.857193,0.809102,0.840202,0.874561,...,7.0,7.0,8.0,7.0,8.0,8.0,7.0,7.0,7.0,7.0
S0768,0,0.342944,0.84934,0.737245,0.847321,0.614612,0.576422,0.549181,0.580665,0.569098,...,7.0,8.0,7.0,0.0,7.0,7.0,8.0,6.0,7.0,6.0
S0769,0,0.864026,0.081626,0.962396,0.081617,0.145793,0.101693,0.068192,0.1117,0.105252,...,11.0,10.0,7.0,11.0,4.0,7.0,8.0,10.0,7.0,6.0
S0770,0,0.765547,0.674757,0.653061,0.675945,0.482291,0.446705,0.411472,0.443042,0.465338,...,8.0,0.0,9.0,7.0,7.0,8.0,6.0,6.0,7.0,8.0


In [33]:
from src.preprocessing import post_processing_prediction

In [34]:
# Test-set predictions: average each target's per-fold models, then post-process.
# The models were trained on `feature_names`, so they receive exactly those
# columns in training order. The full test frame was built from the
# non-pruned aggregate file and has extra and differently ordered columns.
pd_test_model_inputs = pd_test_features[feature_names]

test_yhat = {}
for target_name in cfg.TARGETS:
    target_models = models[target_name]
    target_yhat = np.mean(
        [predict_multiclass_fn(model, pd_test_model_inputs) for model in target_models],
        axis=0,
    )
    # Post-processing still receives the full test frame, since it may read
    # columns beyond the model inputs (confirm in src.preprocessing).
    target_yhat = post_processing_prediction(pd_test_features, target_name, target_yhat)
    test_yhat[target_name] = target_yhat
test_yhat = pd.DataFrame(test_yhat, index=pd_test_features.index)

In [35]:
# Averaged cross-validated predictions on the training samples, one column per target.
# NOTE(review): unlike test_yhat, no explicit index is passed here; the row
# labels come from whatever inference.compute_avg_prediction returns. Confirm
# they are sample_ids before relying on train.csv.
train_yhat = {}
for target_name in cfg.TARGETS:
    train_yhat[target_name] = inference.compute_avg_prediction(
        data,
        models[target_name],
        get_cv_paths(cv_dir, target_name),
        predict_multiclass_fn,
    )
train_yhat = pd.DataFrame(train_yhat)

In [36]:
# Root directory for submission artefacts.
sub_dir = Path(dirs["submission"])

In [37]:
# Rebuild from the configured root instead of extending `sub_dir` in place,
# so re-running this cell does not nest another 'lgbm/validation' level.
sub_dir = Path(dirs['submission']).joinpath('lgbm', 'validation')

In [38]:
# Create the output directory (and parents); no error if it already exists.
sub_dir.mkdir(parents=True, exist_ok=True)

In [39]:
# Write test predictions, keeping the sample_id index as the first column.
test_yhat.to_csv(sub_dir.joinpath('submission.csv'))

In [40]:
# Write the averaged CV predictions on the training set, index included.
train_yhat.to_csv(sub_dir.joinpath('train.csv'))