In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
import os

In [3]:
# The project-local packages (`cfg`, `src`) are imported in the next cell, so the
# working directory has to be the project root. When the kernel was started
# inside the `final-nbs` folder, step one level up and chdir there.
PROJECT_DIR = os.path.abspath('.')
started_in_notebook_dir = PROJECT_DIR.endswith('final-nbs')
if started_in_notebook_dir:
    PROJECT_DIR = os.path.abspath('../')
    os.chdir(PROJECT_DIR)

In [4]:
import cfg
from src.data import get_features_path_from_metadata, join_dataframe_columns
from src import util
from src.data import setup_directories
# Initialise project logging (the training cell further down emits src.train INFO records).
util.setup_logging()

# Resolve the data directories under cfg.DATA_DIR. `dirs` is indexed by stage
# name in later cells: 'raw', 'train', 'valid', 'test', 'cv', 'submission'.
# create_dirs=True presumably creates any missing directory -- confirm in src.data.
dirs = setup_directories(cfg.DATA_DIR, create_dirs=True)

In [5]:
# Wrap the directory entries used below in Path objects so file names can be joined with `/`.
raw_dir, train_dir = (Path(dirs[stage]) for stage in ('raw', 'train'))
# Cross-validation fold indices: the 'test' entry of the nested 'cv' mapping.
cv_dir = Path(dirs['cv']['test'])

In [6]:
# Read the per-sample metadata table (indexed by sample_id) and preview its first rows.
pd_metadata = pd.read_csv(raw_dir / "metadata.csv", index_col="sample_id")
pd_metadata.head()

Unnamed: 0_level_0,split,instrument_type,features_path,features_md5_hash
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
S0000,train,commercial,train_features/S0000.csv,017b9a71a702e81a828e6242aa15f049
S0001,train,commercial,train_features/S0001.csv,0d09840214054d254bd49436c6a6f315
S0002,train,commercial,train_features/S0002.csv,3f58b3c9b001bfed6ed4e4f757083e09
S0003,train,commercial,train_features/S0003.csv,e9a12f96114a2fda60b36f4c0f513fb1
S0004,train,commercial,train_features/S0004.csv,b67603d3931897bfa796ac42cc16de78


In [7]:
# read train labels (indexed by sample_id)
pd_train_target = pd.read_csv(raw_dir / 'train_labels.csv', index_col='sample_id')
# read validation labels (indexed by sample_id); a later cell appends them to the train labels
pd_valid_target = pd.read_csv(raw_dir / 'val_labels.csv', index_col='sample_id')

In [8]:
# Training split: the multiclass targets plus three feature tables, all indexed by sample_id.
pd_multclass_target = pd.read_csv(train_dir / 'multiclass.csv', index_col='sample_id')
pd_sample_features = pd.read_csv(train_dir / 'sample_features.csv', index_col='sample_id')
pd_agg_features = pd.read_csv(train_dir / 'mz_agg_features_drop_correlated.csv', index_col='sample_id')
pd_cluster_features = pd.read_csv(train_dir / 'ae_clusters.csv', index_col='sample_id')

# Column-wise join of the feature tables, in the order sample -> mz_agg -> cluster.
pd_features = pd.concat(
    [pd_sample_features, pd_agg_features, pd_cluster_features],
    axis=1,
)

In [9]:
valid_dir = Path(dirs['valid'])

# Validation split: same four files as the training split, all indexed by sample_id.
pd_valid_multclass_target = pd.read_csv(valid_dir / 'multiclass.csv', index_col='sample_id')
pd_valid_sample_features = pd.read_csv(valid_dir / 'sample_features.csv', index_col='sample_id')
pd_valid_agg_features = pd.read_csv(valid_dir / 'mz_agg_features_drop_correlated.csv', index_col='sample_id')
pd_valid_cluster_features = pd.read_csv(valid_dir / 'ae_clusters.csv', index_col='sample_id')

# Column-wise join of the feature tables, in the order sample -> mz_agg -> cluster.
pd_valid_features = pd.concat(
    [pd_valid_sample_features, pd_valid_agg_features, pd_valid_cluster_features],
    axis=1,
)

In [10]:
# Stack the validation feature rows under the training rows (row-wise concat).
# pd.concat replaces DataFrame.append, which was deprecated in pandas 1.4 and removed in 2.0.
# NOTE: not idempotent -- re-running this cell adds the validation rows again.
pd_features = pd.concat([pd_features, pd_valid_features])

In [11]:
# Stack the validation labels under the training labels (row-wise concat).
# pd.concat replaces DataFrame.append, which was deprecated in pandas 1.4 and removed in 2.0.
# NOTE: not idempotent -- re-running this cell adds the validation rows again.
pd_train_target = pd.concat([pd_train_target, pd_valid_target])

In [12]:
# Stack the validation multiclass targets under the training ones (row-wise concat).
# pd.concat replaces DataFrame.append, which was deprecated in pandas 1.4 and removed in 2.0.
# NOTE: not idempotent -- re-running this cell adds the validation rows again.
pd_multclass_target = pd.concat([pd_multclass_target, pd_valid_multclass_target])

In [13]:
# Every column of the combined feature frame is used as a model input; this list
# is what the training cell passes to the trainer as the feature selection.
feature_names = pd_features.columns.to_list()

In [14]:
from src import util

In [15]:
# One wide modelling frame, aligned on the sample_id index:
# label columns | multiclass target columns | feature columns.
data = pd.concat(
    [pd_train_target, pd_multclass_target, pd_features],
    axis=1,
)

In [16]:
data.head()

Unnamed: 0_level_0,basalt,carbonate,chloride,iron_oxide,oxalate,oxychlorine,phyllosilicate,silicate,sulfate,sulfide,...,cluster_mz90,cluster_mz91,cluster_mz92,cluster_mz93,cluster_mz94,cluster_mz95,cluster_mz96,cluster_mz97,cluster_mz98,cluster_mz99
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
S0000,0,0,0,0,0,0,0,0,1,0,...,9.0,7.0,4.0,9.0,9.0,9.0,9.0,7.0,7.0,9.0
S0001,0,1,0,0,0,0,0,0,0,0,...,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0
S0002,0,0,0,0,0,1,0,0,0,0,...,6.0,7.0,6.0,7.0,6.0,6.0,7.0,8.0,8.0,1.0
S0003,0,1,0,1,0,0,0,0,1,0,...,6.0,6.0,8.0,6.0,0.0,7.0,8.0,7.0,7.0,8.0
S0004,0,0,0,1,0,1,1,0,0,0,...,7.0,7.0,7.0,6.0,7.0,7.0,6.0,8.0,6.0,8.0


In [17]:
from src.model_selection import get_train_test_tuple_from_split
from src.data import get_cv_paths
from sklearn.linear_model import LogisticRegression
from xgboost import XGBRFClassifier
from lightgbm import LGBMClassifier
from sklearn.svm import SVC

In [18]:
# Model specification handed to train.train_cv_from_config: a model key plus
# the keyword parameters for that model.
model_config = dict(
    model='lgbm',
    parameters=dict(
        class_weight='balanced',
        n_estimators=50,
        colsample_bytree=0.3,
    ),
)

In [19]:
from src import train, inference

In [20]:
def train_one_vs_the_rest(data: pd.DataFrame, model_config, feature_names):
    """Train one cross-validated model set per target in ``cfg.TARGETS``.

    For every target the fold definitions are looked up with
    ``get_cv_paths(cv_dir, target)`` (``cv_dir`` is the notebook-level global)
    and training runs against the ``'<target>_multiclass'`` column name via
    ``train.train_cv_from_config``.

    Args:
        data: Modelling frame passed unchanged to the trainer.
        model_config: Model specification passed unchanged to the trainer.
        feature_names: Feature column names passed unchanged to the trainer.

    Returns:
        dict mapping each target name to whatever
        ``train.train_cv_from_config`` returns for it (presumably one fitted
        model per fold -- confirm in src.train).
    """
    return {
        target_name: train.train_cv_from_config(
            data,
            model_config,
            feature_names,
            f'{target_name}_multiclass',
            cv_paths=get_cv_paths(cv_dir, target_name),
        )
        for target_name in cfg.TARGETS
    }

In [21]:
# Fit the per-target models; the trainer logs each fold (fold index, cv index file, elapsed time).
models = train_one_vs_the_rest(data, model_config, feature_names)

2022-04-17 11:12:24 - src.train - INFO     [train.py:62] fold=1/33
2022-04-17 11:12:24 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/basalt/fold_0.csv
2022-04-17 11:12:24 - src.train - INFO     [train.py:28] elapsed training time: 0.012 min
2022-04-17 11:12:24 - src.train - INFO     [train.py:62] fold=2/33
2022-04-17 11:12:24 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/basalt/fold_1.csv
2022-04-17 11:12:25 - src.train - INFO     [train.py:28] elapsed training time: 0.011 min
2022-04-17 11:12:25 - src.train - INFO     [train.py:62] fold=3/33
2022-04-17 11:12:25 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/basalt/fold_10.csv
2022-04-17 11:12:26 - src.train - INFO     [train.py:28] elapsed training time: 0.009 min
2022-04-17 11:12:26 - src.train - INFO     [train.py:62] fold=4/33
2022-04-17 11:12:26 - src.train - INFO     [train.py:63] reading cv index from data/cv

2022-04-17 11:12:48 - src.train - INFO     [train.py:62] fold=30/33
2022-04-17 11:12:48 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/basalt/fold_6.csv
2022-04-17 11:12:49 - src.train - INFO     [train.py:28] elapsed training time: 0.011 min
2022-04-17 11:12:49 - src.train - INFO     [train.py:62] fold=31/33
2022-04-17 11:12:49 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/basalt/fold_7.csv
2022-04-17 11:12:50 - src.train - INFO     [train.py:28] elapsed training time: 0.015 min
2022-04-17 11:12:50 - src.train - INFO     [train.py:62] fold=32/33
2022-04-17 11:12:50 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/basalt/fold_8.csv
2022-04-17 11:12:51 - src.train - INFO     [train.py:28] elapsed training time: 0.010 min
2022-04-17 11:12:51 - src.train - INFO     [train.py:62] fold=33/33
2022-04-17 11:12:51 - src.train - INFO     [train.py:63] reading cv index from data

2022-04-17 11:13:10 - src.train - INFO     [train.py:28] elapsed training time: 0.011 min
2022-04-17 11:13:10 - src.train - INFO     [train.py:62] fold=26/33
2022-04-17 11:13:10 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/carbonate/fold_31.csv
2022-04-17 11:13:12 - src.train - INFO     [train.py:28] elapsed training time: 0.023 min
2022-04-17 11:13:12 - src.train - INFO     [train.py:62] fold=27/33
2022-04-17 11:13:12 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/carbonate/fold_32.csv
2022-04-17 11:13:13 - src.train - INFO     [train.py:28] elapsed training time: 0.015 min
2022-04-17 11:13:13 - src.train - INFO     [train.py:62] fold=28/33
2022-04-17 11:13:13 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/carbonate/fold_4.csv
2022-04-17 11:13:14 - src.train - INFO     [train.py:28] elapsed training time: 0.023 min
2022-04-17 11:13:14 - src.train - INFO     [train.

2022-04-17 11:13:45 - src.train - INFO     [train.py:28] elapsed training time: 0.011 min
2022-04-17 11:13:45 - src.train - INFO     [train.py:62] fold=22/33
2022-04-17 11:13:45 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/chloride/fold_28.csv
2022-04-17 11:13:46 - src.train - INFO     [train.py:28] elapsed training time: 0.012 min
2022-04-17 11:13:46 - src.train - INFO     [train.py:62] fold=23/33
2022-04-17 11:13:46 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/chloride/fold_29.csv
2022-04-17 11:13:47 - src.train - INFO     [train.py:28] elapsed training time: 0.013 min
2022-04-17 11:13:47 - src.train - INFO     [train.py:62] fold=24/33
2022-04-17 11:13:47 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/chloride/fold_3.csv
2022-04-17 11:13:47 - src.train - INFO     [train.py:28] elapsed training time: 0.012 min
2022-04-17 11:13:47 - src.train - INFO     [train.py:

2022-04-17 11:14:10 - src.train - INFO     [train.py:28] elapsed training time: 0.011 min
2022-04-17 11:14:10 - src.train - INFO     [train.py:62] fold=18/33
2022-04-17 11:14:10 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/iron_oxide/fold_24.csv
2022-04-17 11:14:11 - src.train - INFO     [train.py:28] elapsed training time: 0.022 min
2022-04-17 11:14:11 - src.train - INFO     [train.py:62] fold=19/33
2022-04-17 11:14:11 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/iron_oxide/fold_25.csv
2022-04-17 11:14:12 - src.train - INFO     [train.py:28] elapsed training time: 0.013 min
2022-04-17 11:14:12 - src.train - INFO     [train.py:62] fold=20/33
2022-04-17 11:14:12 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/iron_oxide/fold_26.csv
2022-04-17 11:14:13 - src.train - INFO     [train.py:28] elapsed training time: 0.013 min
2022-04-17 11:14:13 - src.train - INFO     [tr

2022-04-17 11:14:37 - src.train - INFO     [train.py:28] elapsed training time: 0.016 min
2022-04-17 11:14:37 - src.train - INFO     [train.py:62] fold=14/33
2022-04-17 11:14:37 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/oxalate/fold_20.csv
2022-04-17 11:14:38 - src.train - INFO     [train.py:28] elapsed training time: 0.011 min
2022-04-17 11:14:38 - src.train - INFO     [train.py:62] fold=15/33
2022-04-17 11:14:38 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/oxalate/fold_21.csv
2022-04-17 11:14:39 - src.train - INFO     [train.py:28] elapsed training time: 0.011 min
2022-04-17 11:14:39 - src.train - INFO     [train.py:62] fold=16/33
2022-04-17 11:14:39 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/oxalate/fold_22.csv
2022-04-17 11:14:40 - src.train - INFO     [train.py:28] elapsed training time: 0.012 min
2022-04-17 11:14:40 - src.train - INFO     [train.py:62

2022-04-17 11:15:04 - src.train - INFO     [train.py:28] elapsed training time: 0.013 min
2022-04-17 11:15:04 - src.train - INFO     [train.py:62] fold=10/33
2022-04-17 11:15:04 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/oxychlorine/fold_17.csv
2022-04-17 11:15:05 - src.train - INFO     [train.py:28] elapsed training time: 0.013 min
2022-04-17 11:15:05 - src.train - INFO     [train.py:62] fold=11/33
2022-04-17 11:15:05 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/oxychlorine/fold_18.csv
2022-04-17 11:15:06 - src.train - INFO     [train.py:28] elapsed training time: 0.021 min
2022-04-17 11:15:06 - src.train - INFO     [train.py:62] fold=12/33
2022-04-17 11:15:06 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/oxychlorine/fold_19.csv
2022-04-17 11:15:08 - src.train - INFO     [train.py:28] elapsed training time: 0.016 min
2022-04-17 11:15:08 - src.train - INFO     

2022-04-17 11:15:31 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/phyllosilicate/fold_12.csv
2022-04-17 11:15:32 - src.train - INFO     [train.py:28] elapsed training time: 0.012 min
2022-04-17 11:15:32 - src.train - INFO     [train.py:62] fold=6/33
2022-04-17 11:15:32 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/phyllosilicate/fold_13.csv
2022-04-17 11:15:33 - src.train - INFO     [train.py:28] elapsed training time: 0.012 min
2022-04-17 11:15:33 - src.train - INFO     [train.py:62] fold=7/33
2022-04-17 11:15:33 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/phyllosilicate/fold_14.csv
2022-04-17 11:15:33 - src.train - INFO     [train.py:28] elapsed training time: 0.011 min
2022-04-17 11:15:33 - src.train - INFO     [train.py:62] fold=8/33
2022-04-17 11:15:33 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/phyllosilicate/fold_

2022-04-17 11:15:57 - src.train - INFO     [train.py:28] elapsed training time: 0.012 min
2022-04-17 11:15:57 - src.train - INFO     [train.py:62] fold=1/33
2022-04-17 11:15:57 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/silicate/fold_0.csv
2022-04-17 11:15:57 - src.train - INFO     [train.py:28] elapsed training time: 0.011 min
2022-04-17 11:15:57 - src.train - INFO     [train.py:62] fold=2/33
2022-04-17 11:15:57 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/silicate/fold_1.csv
2022-04-17 11:15:58 - src.train - INFO     [train.py:28] elapsed training time: 0.012 min
2022-04-17 11:15:58 - src.train - INFO     [train.py:62] fold=3/33
2022-04-17 11:15:58 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/silicate/fold_10.csv
2022-04-17 11:15:59 - src.train - INFO     [train.py:28] elapsed training time: 0.011 min
2022-04-17 11:15:59 - src.train - INFO     [train.py:62] 

2022-04-17 11:16:26 - src.train - INFO     [train.py:28] elapsed training time: 0.054 min
2022-04-17 11:16:26 - src.train - INFO     [train.py:62] fold=30/33
2022-04-17 11:16:26 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/silicate/fold_6.csv
2022-04-17 11:16:27 - src.train - INFO     [train.py:28] elapsed training time: 0.016 min
2022-04-17 11:16:27 - src.train - INFO     [train.py:62] fold=31/33
2022-04-17 11:16:27 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/silicate/fold_7.csv
2022-04-17 11:16:28 - src.train - INFO     [train.py:28] elapsed training time: 0.014 min
2022-04-17 11:16:28 - src.train - INFO     [train.py:62] fold=32/33
2022-04-17 11:16:28 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/silicate/fold_8.csv
2022-04-17 11:16:29 - src.train - INFO     [train.py:28] elapsed training time: 0.012 min
2022-04-17 11:16:29 - src.train - INFO     [train.py:62

2022-04-17 11:16:56 - src.train - INFO     [train.py:28] elapsed training time: 0.026 min
2022-04-17 11:16:56 - src.train - INFO     [train.py:62] fold=26/33
2022-04-17 11:16:56 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/sulfate/fold_31.csv
2022-04-17 11:16:57 - src.train - INFO     [train.py:28] elapsed training time: 0.015 min
2022-04-17 11:16:57 - src.train - INFO     [train.py:62] fold=27/33
2022-04-17 11:16:57 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/sulfate/fold_32.csv
2022-04-17 11:16:59 - src.train - INFO     [train.py:28] elapsed training time: 0.032 min
2022-04-17 11:16:59 - src.train - INFO     [train.py:62] fold=28/33
2022-04-17 11:16:59 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/sulfate/fold_4.csv
2022-04-17 11:17:00 - src.train - INFO     [train.py:28] elapsed training time: 0.022 min
2022-04-17 11:17:00 - src.train - INFO     [train.py:62]

2022-04-17 11:17:43 - src.train - INFO     [train.py:28] elapsed training time: 0.016 min
2022-04-17 11:17:43 - src.train - INFO     [train.py:62] fold=22/33
2022-04-17 11:17:43 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/sulfide/fold_28.csv
2022-04-17 11:17:44 - src.train - INFO     [train.py:28] elapsed training time: 0.024 min
2022-04-17 11:17:44 - src.train - INFO     [train.py:62] fold=23/33
2022-04-17 11:17:44 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/sulfide/fold_29.csv
2022-04-17 11:17:46 - src.train - INFO     [train.py:28] elapsed training time: 0.024 min
2022-04-17 11:17:46 - src.train - INFO     [train.py:62] fold=24/33
2022-04-17 11:17:46 - src.train - INFO     [train.py:63] reading cv index from data/cv_index/cv-model-test/sulfide/fold_3.csv
2022-04-17 11:17:47 - src.train - INFO     [train.py:28] elapsed training time: 0.020 min
2022-04-17 11:17:47 - src.train - INFO     [train.py:62]

In [22]:
def predict_multiclass_fn(model, test_data):
    """Collapse multiclass probabilities into a single score per row.

    Calls ``model.predict_proba(test_data)`` and adds up every probability
    column from index 2 onward; columns 0 and 1 are left out of the sum.

    NOTE(review): which classes occupy columns 0/1 versus 2+ depends on the
    multiclass target encoding, which is defined outside this notebook --
    confirm before changing the slice.

    Returns:
        1-D array holding one summed probability per row of ``test_data``.
    """
    class_probabilities = model.predict_proba(test_data)
    tail_columns = class_probabilities[:, 2:]
    return tail_columns.sum(axis=1)

In [23]:
# Cross-validated scores per target, collected in a dict first and then turned
# into a DataFrame with one column per target.
scores_by_target = {}
for target_name in cfg.TARGETS:
    cv_paths = get_cv_paths(cv_dir, target_name)
    scores_by_target[target_name] = inference.cross_validation_inference(
        data,
        target_name,
        models[target_name],
        cv_paths,
        predict_multiclass_fn,
    )
scores = pd.DataFrame(scores_by_target)

In [24]:
# Column-wise mean (DataFrame.mean defaults to axis=0): one averaged score per target.
cv_scores = scores.mean()

In [25]:
cv_scores

basalt            0.112896
carbonate         0.075074
chloride          0.142640
iron_oxide        0.171367
oxalate           0.019858
oxychlorine       0.122891
phyllosilicate    0.203330
silicate          0.165229
sulfate           0.164459
sulfide           0.052669
dtype: float64

uniform distribution
basalt            0.148178
carbonate         0.105477
chloride          0.165927
iron_oxide        0.194785
oxalate           0.028967
oxychlorine       0.155488
phyllosilicate    0.253738
silicate          0.182060
sulfate           0.189164
sulfide           0.063096
dtype: float64

In [26]:
# Single summary number: the unweighted mean of the per-target CV scores.
avg_loss = cv_scores.mean()

In [27]:
avg_loss

0.12304118681822954

In [28]:
test_dir = Path(dirs['test'])

In [29]:
# Test split feature tables, all indexed by sample_id.
# NOTE(review): train/valid read 'mz_agg_features_drop_correlated.csv' while the
# test split reads 'mz_agg_features.csv' -- confirm both yield the same columns.
test_sample_features = pd.read_csv(test_dir / 'sample_features.csv', index_col='sample_id')
pd_test_agg_features = pd.read_csv(test_dir / 'mz_agg_features.csv', index_col='sample_id')
pd_test_cluster_features = pd.read_csv(test_dir / 'ae_clusters.csv', index_col='sample_id')

# Column-wise join of the feature tables, in the order sample -> mz_agg -> cluster.
pd_test_features = pd.concat(
    [test_sample_features, pd_test_agg_features, pd_test_cluster_features],
    axis=1,
)

In [30]:
pd_test_features.head()

Unnamed: 0_level_0,sample_mol_ion_less99,sample_weighted_mass,sample_max_temp,sample_min_temp,sample_temp_range,sum_mz0,sum_mz1,sum_mz2,sum_mz3,sum_mz5,...,cluster_mz90,cluster_mz91,cluster_mz92,cluster_mz93,cluster_mz94,cluster_mz95,cluster_mz96,cluster_mz97,cluster_mz98,cluster_mz99
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
S0766,0,0.005432,0.501948,0.34025,0.505124,0.926832,0.920168,0.871866,0.897009,0.924933,...,6.0,6.0,7.0,7.0,6.0,8.0,7.0,7.0,6.0,8.0
S0767,0,0.065747,0.930877,0.220117,0.934335,0.875786,0.857193,0.809102,0.840202,0.874561,...,7.0,7.0,8.0,7.0,8.0,8.0,7.0,7.0,7.0,7.0
S0768,0,0.342944,0.84934,0.737245,0.847321,0.614612,0.576422,0.549181,0.580665,0.569098,...,7.0,8.0,7.0,0.0,7.0,7.0,8.0,6.0,7.0,6.0
S0769,0,0.864026,0.081626,0.962396,0.081617,0.145793,0.101693,0.068192,0.1117,0.105252,...,11.0,10.0,7.0,11.0,4.0,7.0,8.0,10.0,7.0,6.0
S0770,0,0.765547,0.674757,0.653061,0.675945,0.482291,0.446705,0.411472,0.443042,0.465338,...,8.0,0.0,9.0,7.0,7.0,8.0,6.0,6.0,7.0,8.0


In [31]:
from src.preprocessing import post_processing_prediction

In [32]:
# Test-set predictions: for every target, average the score of each model
# stored for that target, then apply the project's post-processing step.
#
# The models are given exactly the training feature columns, in training order.
# Previously the whole test frame was passed, which only works if its columns
# happen to match `feature_names` in number and order (the test split is read
# from a differently named mz_agg file than train/valid). A missing column now
# fails loudly with a KeyError instead of risking silently misaligned inputs.
test_model_inputs = pd_test_features[feature_names]

test_yhat = {}

for target_name in cfg.TARGETS:
    target_models = models[target_name]
    # axis=0 averages across models, leaving one value per test sample.
    target_yhat = np.mean([predict_multiclass_fn(model, test_model_inputs) for model in target_models], axis=0)
    # Post-processing still receives the full test frame, as before.
    target_yhat = post_processing_prediction(pd_test_features, target_name, target_yhat)
    test_yhat[target_name] = target_yhat
test_yhat = pd.DataFrame(test_yhat, index=pd_test_features.index)

In [33]:
# Averaged predictions on the modelling frame, one column per target, produced
# by inference.compute_avg_prediction from each target's models and CV folds.
avg_train_predictions = {}
for target_name in cfg.TARGETS:
    target_models = models[target_name]
    avg_train_predictions[target_name] = inference.compute_avg_prediction(
        data,
        target_models,
        get_cv_paths(cv_dir, target_name),
        predict_multiclass_fn,
    )
train_yhat = pd.DataFrame(avg_train_predictions)

In [34]:
sub_dir = Path(dirs['submission'])

In [35]:
# Build the output folder from dirs['submission'] directly instead of extending
# the current value of `sub_dir`, so re-running this cell does not nest another
# 'lgbm/test' level under the previous result.
sub_dir = Path(dirs['submission']).joinpath('lgbm', 'test')

In [36]:
sub_dir.mkdir(exist_ok=True, parents=True)

In [37]:
test_yhat.to_csv(sub_dir / 'submission.csv', index=True)

In [38]:
train_yhat.to_csv(sub_dir / 'train.csv', index=True)