In [1]:
import torch
import joblib
from tqdm import tqdm
import pandas as pd
import numpy as np
import lightgbm as lgb
from sklearn.svm import SVR
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.model_selection import RandomizedSearchCV, train_test_split
from sklearn.metrics import mean_squared_error
from scipy.stats import spearmanr, pearsonr

In [2]:
x_train = torch.load('../data/X_tensor_APSIPA.pt')
y_train = torch.load('../data/y_tensor_APSIPA.pt')
ref_names = joblib.load('../data/ref_names_APSIPA.pkl')

In [3]:
X_train = []
for x in x_train:
    vector = [v.detach().numpy() for v in x]
    X_train.append(vector)

In [4]:
refs = list(set(ref_names))

In [5]:
refs

['romanoillamp_vox10',
 'loot_vox10_1200',
 'head_00039_vox9',
 'the20smaria_00600_vox10',
 'soldier_vox10_0690',
 'amphoriskos_vox10',
 'longdress_vox10_1300',
 'biplane_vox10']

In [6]:
# the key is the reference that is excluded from the group
groups = {}
for ref in refs:
    xtrain, ytrain = [], []
    xtest, ytest = [], []
    for i, ref_name in enumerate(ref_names):
        if ref_name == ref:
            xtest.append(X_train[i])
            ytest.append(y_train[i])
            continue
        xtrain.append(X_train[i])
        ytrain.append(y_train[i])
    groups[ref] = [xtrain, ytrain, xtest, ytest]

In [7]:
def get_etr_model():
    return ExtraTreesRegressor(
    n_estimators=37,
    min_samples_split=15,
    min_samples_leaf=4,
    max_features='log2',  # type: ignore
    max_depth=7
)

In [8]:
def get_svr_model():
    return SVR(
        kernel='rbf',
        gamma=1,  # type: ignore
        epsilon=0.01,
        degree=2,
        C=5
    )

In [9]:
def get_lgbm_model():
    return lgb.LGBMRegressor(
        subsample_for_bin=140000,
        reg_lambda=0.1,
        reg_alpha=1.0,
        num_leaves=100,
        n_estimators=166,
        min_split_gain=1,
        min_child_weight=0.0001,
        min_child_samples=20,
        learning_rate=0.1,
        colsample_bytree=1.0,
        boosting_type='dart'
    )

In [10]:
models = ['lgbm', 'svr', 'etr']

In [11]:
results = []
for ref_out, xy in tqdm(groups.items()):
    result = {'group_out': ref_out}
    xtrain, ytrain = xy[0], xy[1]
    xtest, ytest = xy[2], xy[3]
    for model_name in models:
        if model_name == 'lgbm':
            model = get_lgbm_model()
        if model_name == 'svr':
            model = get_svr_model()
        if model_name == 'etr':
            model = get_etr_model()
        model.fit(xtrain, ytrain)
        ypred = model.predict(xtest)
        result[f'{model_name}-pearson'] = pearsonr(ytest, ypred)[0]
        result[f'{model_name}-spearman'] = spearmanr(ytest, ypred)[0]
        result[f'{model_name}-mse'] = mean_squared_error(ytest, ypred)
    results.append(result)

100%|██████████| 8/8 [00:01<00:00,  5.13it/s]


In [12]:
df_results = pd.DataFrame(results)

In [13]:
to_concat = {
    'group_out': 'mean',
    'lgbm-pearson': df_results.loc[:, 'lgbm-pearson'].mean(),
    'lgbm-spearman': df_results.loc[:, 'lgbm-spearman'].mean(),
    'lgbm-mse': df_results.loc[:, 'lgbm-mse'].mean(),
    'svr-pearson': df_results.loc[:, 'svr-pearson'].mean(),
    'svr-spearman': df_results.loc[:, 'svr-spearman'].mean(),
    'svr-mse': df_results.loc[:, 'svr-mse'].mean(),
    'etr-pearson': df_results.loc[:, 'etr-pearson'].mean(),
    'etr-spearman': df_results.loc[:, 'etr-spearman'].mean(),
    'etr-mse': df_results.loc[:, 'etr-mse'].mean()
}

In [14]:
df_conc = pd.DataFrame([to_concat])
df_results = pd.concat([df_results, df_conc])

In [15]:
df_results

Unnamed: 0,group_out,lgbm-pearson,lgbm-spearman,lgbm-mse,svr-pearson,svr-spearman,svr-mse,etr-pearson,etr-spearman,etr-mse
0,romanoillamp_vox10,0.954732,0.964774,0.65014,0.943044,0.94968,0.457908,0.942447,0.95264,0.45749
1,loot_vox10_1200,0.938353,0.972065,0.194798,0.940321,0.972716,0.307531,0.945365,0.972716,0.333405
2,head_00039_vox9,0.954444,0.973553,0.693719,0.958956,0.972627,0.486603,0.947881,0.979532,0.517254
3,the20smaria_00600_vox10,0.95141,0.97028,0.153863,0.954043,0.977645,0.234751,0.959052,0.971221,0.234188
4,soldier_vox10_0690,0.956258,0.973613,0.144662,0.950766,0.973577,0.22772,0.951414,0.973083,0.240087
5,amphoriskos_vox10,0.90541,0.941721,0.317693,0.924268,0.952076,0.282869,0.924621,0.954047,0.309589
6,longdress_vox10_1300,0.886687,0.961844,0.467367,0.933535,0.974565,0.40905,0.907957,0.971602,0.5002
7,biplane_vox10,0.949477,0.981732,0.951615,0.939075,0.965874,0.696461,0.933021,0.976223,0.70331
0,mean,0.937096,0.967448,0.446732,0.943001,0.967345,0.387862,0.93897,0.968883,0.41194
