In [61]:
import torch
import joblib
from tqdm import tqdm
import pandas as pd
import numpy as np
import lightgbm as lgb
from sklearn.svm import SVR
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.model_selection import RandomizedSearchCV, train_test_split
from sklearn.metrics import mean_squared_error
from scipy.stats import spearmanr, pearsonr

In [8]:
x_train = torch.load('../data/X_tensor_APSIPA.pt')
y_train = torch.load('../data/y_tensor_APSIPA.pt')
ref_names = joblib.load('../data/ref_names_APSIPA.pkl')

In [9]:
X_train = []
for x in x_train:
    vector = [v.detach().numpy() for v in x]
    X_train.append(vector)

In [11]:
refs = list(set(ref_names))

In [30]:
refs

['head_00039_vox9',
 'romanoillamp_vox10',
 'amphoriskos_vox10',
 'loot_vox10_1200',
 'biplane_vox10',
 'longdress_vox10_1300',
 'the20smaria_00600_vox10',
 'soldier_vox10_0690']

In [51]:
# the key is the reference that is excluded from the group
groups = {}
for ref in refs:
    xtrain, ytrain = [], []
    xtest, ytest = [], []
    for i, ref_name in enumerate(ref_names):
        if ref_name == ref:
            xtest.append(X_train[i])
            ytest.append(y_train[i])
            continue
        xtrain.append(X_train[i])
        ytrain.append(y_train[i])
    groups[ref] = [xtrain, ytrain, xtest, ytest]

In [34]:
etr_model = ExtraTreesRegressor(
    n_estimators=37,
    min_samples_split=15,
    min_samples_leaf=4,
    max_features='log2',  # type: ignore
    max_depth=7
)

In [36]:
svr_model = SVR(
    kernel='rbf',
    gamma=1,  # type: ignore
    epsilon=0.01,
    degree=2,
    C=5
)

In [37]:
lgbm_model = lgb.LGBMRegressor(
    subsample_for_bin=140000,
    reg_lambda=0.1,
    reg_alpha=1.0,
    num_leaves=100,
    n_estimators=166,
    min_split_gain=1,
    min_child_weight=0.0001,
    min_child_samples=20,
    learning_rate=0.1,
    colsample_bytree=1.0,
    boosting_type='dart'
)

In [38]:
models = {
    'lgbm': lgbm_model,
    'svr': svr_model,
    'etr': etr_model
}

In [72]:
results = []
for ref_out, xy in tqdm(groups.items()):
    result = {'group_out': ref_out}
    xtrain, ytrain = xy[0], xy[1]
    xtest, ytest = xy[2], xy[3]
    for model_name, model in models.items():
        model.fit(xtrain, ytrain)
        ypred = model.predict(xtest)
        result[f'{model_name}-pearson'] = pearsonr(ytest, ypred)[0]
        result[f'{model_name}-spearman'] = spearmanr(ytest, ypred)[0]
        result[f'{model_name}-mse'] = mean_squared_error(ytest, ypred)
    results.append(result)

100%|██████████| 8/8 [00:02<00:00,  2.73it/s]


In [73]:
df_results = pd.DataFrame(results)

In [76]:
to_concat = {
    'group_out': 'mean',
    'lgbm-pearson': df_results.loc[:, 'lgbm-pearson'].mean(),
    'lgbm-spearman': df_results.loc[:, 'lgbm-spearman'].mean(),
    'lgbm-mse': df_results.loc[:, 'lgbm-mse'].mean(),
    'svr-pearson': df_results.loc[:, 'svr-pearson'].mean(),
    'svr-spearman': df_results.loc[:, 'svr-spearman'].mean(),
    'svr-mse': df_results.loc[:, 'svr-mse'].mean(),
    'etr-pearson': df_results.loc[:, 'etr-pearson'].mean(),
    'etr-spearman': df_results.loc[:, 'etr-spearman'].mean(),
    'etr-mse': df_results.loc[:, 'etr-mse'].mean()
}

In [82]:
df_conc = pd.DataFrame([to_concat])
df_results = pd.concat([df_results, df_conc])

In [83]:
df_results

Unnamed: 0,group_out,lgbm-pearson,lgbm-spearman,lgbm-mse,svr-pearson,svr-spearman,svr-mse,etr-pearson,etr-spearman,etr-mse
0,head_00039_vox9,0.954444,0.973553,0.693719,0.958956,0.972627,0.486603,0.946839,0.985944,0.477254
1,romanoillamp_vox10,0.954732,0.964774,0.65014,0.943044,0.94968,0.457908,0.943354,0.953874,0.440951
2,amphoriskos_vox10,0.90541,0.941721,0.317693,0.924268,0.952076,0.282869,0.922278,0.953308,0.30656
3,loot_vox10_1200,0.938353,0.972065,0.194798,0.940321,0.972716,0.307531,0.921873,0.970247,0.433148
4,biplane_vox10,0.949477,0.981732,0.951615,0.939075,0.965874,0.696461,0.946222,0.977082,0.620164
5,longdress_vox10_1300,0.886687,0.961844,0.467367,0.933535,0.974565,0.40905,0.926261,0.97098,0.366517
6,the20smaria_00600_vox10,0.95141,0.97028,0.153863,0.954043,0.977645,0.234751,0.960962,0.974927,0.205289
7,soldier_vox10_0690,0.956258,0.973613,0.144662,0.950766,0.973577,0.22772,0.95269,0.975553,0.245573
0,mean,0.937096,0.967448,0.446732,0.943001,0.967345,0.387862,0.94006,0.970239,0.386932
