In [1]:
import torch
import joblib
from tqdm import tqdm
import pandas as pd
import numpy as np
from sklearn.linear_model import PoissonRegressor
from sklearn.svm import SVR, NuSVR
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from scipy.stats import spearmanr, pearsonr

In [2]:
x_train = torch.load('../data/X_tensor_WPC_cpu.pt')
y_train = torch.load('../data/y_tensor_WPC.pt')
ref_names = joblib.load('../data/ref_names_WPC.pkl')
codec_names = joblib.load('../data/degradations_WPC.pkl')

In [3]:
X_train = [[v.detach().numpy() for v in x] for x in x_train]

In [4]:
refs = list(set(ref_names))
codecs = list(set(codec_names))

In [5]:
refs

['biscuits',
 'banana',
 'puer_tea',
 'glasses_case',
 'litchi',
 'pen_container',
 'house',
 'pineapple',
 'tool_box',
 'stone',
 'statue',
 'ping-pong_bat',
 'cauliflower',
 'honeydew_melon',
 'ship',
 'pumpkin',
 'bag',
 'mushroom',
 'flowerpot',
 'cake']

In [6]:
codecs

['G-PCC (T)/S-PCC',
 'V-PCC',
 'Downsampling',
 'G-PCC (O)/L-PCC',
 'Gaussian noise']

In [7]:
# the key is the reference that is excluded from the group
groups_by_ref = {}
for ref in refs:
    xtrain, ytrain = [], []
    xtest, ytest = [], []
    for i, ref_name in enumerate(ref_names):
        if ref_name == ref:
            xtest.append(X_train[i])
            ytest.append(y_train[i])
            continue
        xtrain.append(X_train[i])
        ytrain.append(y_train[i])
    groups_by_ref[ref] = [xtrain, ytrain, xtest, ytest]

In [8]:
# the key is the reference that is excluded from the group
groups_by_codec = {}
for codec in codecs:
    xtrain, ytrain = [], []
    xtest, ytest = [], []
    for i, codec_name in enumerate(codec_names):
        if codec_name == codec:
            xtest.append(X_train[i])
            ytest.append(y_train[i])
            continue
        xtrain.append(X_train[i])
        ytrain.append(y_train[i])
    groups_by_codec[codec] = [xtrain, ytrain, xtest, ytest]

In [9]:
def get_nusvr_model():
    return NuSVR(
    nu=0.42857142857142855,
    kernel='rbf',
    gamma=1, # type: ignore
    degree=2,
    C=50
)

In [10]:
def get_svr_model():
    return SVR(
        kernel='rbf',
        gamma=1,  # type: ignore
        epsilon=1,
        degree=2,
        C=5
    )

In [11]:
def get_pr_model():
    return PoissonRegressor(
        solver='lbfgs',
        max_iter=10,
        fit_intercept=True,
        alpha=0.01
    )

In [12]:
models = ['nusvr', 'svr', 'pr']

In [24]:
results_by_ref = []
for ref_out, xy in tqdm(groups_by_ref.items()):
    result = {'group_out': ref_out}
    xtrain, ytrain = xy[0], xy[1]
    xtest, ytest = xy[2], xy[3]
    for model_name in models:
        if model_name == 'nusvr':
            model = get_nusvr_model()
        if model_name == 'svr':
            model = get_svr_model()
        if model_name == 'pr':
            model = get_pr_model()
        model.fit(xtrain, ytrain)
        ypred = model.predict(xtest)
        result[f'{model_name}-pearson'] = pearsonr(ytest, ypred)[0]
        result[f'{model_name}-spearman'] = spearmanr(ytest, ypred)[0]
        result[f'{model_name}-rmse'] = np.sqrt(mean_squared_error(ytest, ypred))
    results_by_ref.append(result)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
  self.n_iter_ = _check_optimize_result("lbfgs", opt_res)
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
  self.n_iter_ = _check_optimize_result("lbfgs", opt_res)
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
  self.n_iter_ = _check_optimize_result("lbfgs", opt_res)
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
  self.n_iter_ = _check_optimize_result("lbfgs", opt_res)
STOP: TOTAL NO. of ITERATION

In [25]:
results_by_codec = []
for codec_out, xy in tqdm(groups_by_codec.items()):
    result = {'group_out': codec_out}
    xtrain, ytrain = xy[0], xy[1]
    xtest, ytest = xy[2], xy[3]
    for model_name in models:
        if model_name == 'nusvr':
            model = get_nusvr_model()
        if model_name == 'svr':
            model = get_svr_model()
        if model_name == 'pr':
            model = get_pr_model()
        model.fit(xtrain, ytrain)
        ypred = model.predict(xtest)
        result[f'{model_name}-pearson'] = pearsonr(ytest, ypred)[0]
        result[f'{model_name}-spearman'] = spearmanr(ytest, ypred)[0]
        result[f'{model_name}-rmse'] = np.sqrt(mean_squared_error(ytest, ypred))
    results_by_codec.append(result)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
  self.n_iter_ = _check_optimize_result("lbfgs", opt_res)
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
  self.n_iter_ = _check_optimize_result("lbfgs", opt_res)
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
  self.n_iter_ = _check_optimize_result("lbfgs", opt_res)
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
  self.n_iter_ = _check_optimize_result("lbfgs", opt_res)
STOP: TOTAL NO. of ITERATION

In [27]:
df_results_by_ref = pd.DataFrame(results_by_ref)
df_results_by_codec = pd.DataFrame(results_by_codec)

In [29]:
to_concat_by_ref = {
    'group_out': 'mean',
    'nusvr-pearson': df_results_by_ref.loc[:, 'nusvr-pearson'].mean(),
    'nusvr-spearman': df_results_by_ref.loc[:, 'nusvr-spearman'].mean(),
    'nusvr-rmse': df_results_by_ref.loc[:, 'nusvr-rmse'].mean(),
    'svr-pearson': df_results_by_ref.loc[:, 'svr-pearson'].mean(),
    'svr-spearman': df_results_by_ref.loc[:, 'svr-spearman'].mean(),
    'svr-rmse': df_results_by_ref.loc[:, 'svr-rmse'].mean(),
    'pr-pearson': df_results_by_ref.loc[:, 'pr-pearson'].mean(),
    'pr-spearman': df_results_by_ref.loc[:, 'pr-spearman'].mean(),
    'pr-rmse': df_results_by_ref.loc[:, 'pr-rmse'].mean()
}

In [30]:
to_concat_by_codec = {
    'group_out': 'mean',
    'nusvr-pearson': df_results_by_codec.loc[:, 'nusvr-pearson'].mean(),
    'nusvr-spearman': df_results_by_codec.loc[:, 'nusvr-spearman'].mean(),
    'nusvr-rmse': df_results_by_codec.loc[:, 'nusvr-rmse'].mean(),
    'svr-pearson': df_results_by_codec.loc[:, 'svr-pearson'].mean(),
    'svr-spearman': df_results_by_codec.loc[:, 'svr-spearman'].mean(),
    'svr-rmse': df_results_by_codec.loc[:, 'svr-rmse'].mean(),
    'pr-pearson': df_results_by_codec.loc[:, 'pr-pearson'].mean(),
    'pr-spearman': df_results_by_codec.loc[:, 'pr-spearman'].mean(),
    'pr-rmse': df_results_by_codec.loc[:, 'pr-rmse'].mean()
}

In [39]:
df_conc_by_ref = pd.DataFrame([to_concat_by_ref])
df_results_by_ref = pd.concat([df_results_by_ref, df_conc_by_ref])

In [40]:
df_conc_by_codec = pd.DataFrame([to_concat_by_codec])
df_results_by_codec = pd.concat([df_results_by_codec, df_conc_by_codec])

In [41]:
df_results_by_ref

Unnamed: 0,group_out,nusvr-pearson,nusvr-spearman,nusvr-rmse,svr-pearson,svr-spearman,svr-rmse,pr-pearson,pr-spearman,pr-rmse
0,biscuits,0.82419,0.840446,13.675668,0.827944,0.83926,12.991666,0.803648,0.839972,14.330817
1,banana,0.606233,0.552394,17.685598,0.62644,0.578947,17.34932,0.622866,0.569938,17.164066
2,puer_tea,0.625697,0.671645,25.361042,0.638017,0.670934,24.720671,0.639954,0.650545,24.509574
3,glasses_case,0.804486,0.811759,16.023422,0.802631,0.808677,16.463656,0.798169,0.810574,15.299373
4,litchi,0.782992,0.752252,16.868621,0.790947,0.761498,16.530225,0.776801,0.750593,17.163536
5,pen_container,0.899217,0.922475,16.800718,0.904089,0.923898,17.417051,0.906997,0.922949,15.074298
6,house,0.821931,0.824561,13.793459,0.813883,0.825747,14.013349,0.812125,0.830488,14.527902
7,pineapple,0.776881,0.790185,13.525366,0.778988,0.789474,13.979558,0.778381,0.787814,13.120743
8,tool_box,0.85202,0.859886,11.128779,0.829753,0.852063,11.889417,0.848932,0.874348,11.363182
9,stone,0.752124,0.762447,14.258058,0.753521,0.773826,14.101727,0.747199,0.771693,14.296634


In [42]:
df_results_by_codec

Unnamed: 0,group_out,nusvr-pearson,nusvr-spearman,nusvr-rmse,svr-pearson,svr-spearman,svr-rmse,pr-pearson,pr-spearman,pr-rmse
0,G-PCC (T)/S-PCC,0.858,0.863,14.14,0.857,0.858,14.47,0.851,0.86,15.16
1,V-PCC,0.553,0.565,14.49,0.55,0.559,14.44,0.574,0.572,14.85
2,Downsampling,0.787,0.858,35.3,0.79,0.862,35.64,0.833,0.856,35.83
3,G-PCC (O)/L-PCC,0.917,0.908,11.78,0.915,0.905,12.51,0.888,0.901,13.31
4,Gaussian noise,0.89,0.879,10.0,0.887,0.879,10.67,0.908,0.873,9.17
0,mean,0.801,0.815,17.14,0.8,0.812,17.55,0.811,0.812,17.66
0,mean,0.801074,0.814909,17.141028,0.799581,0.812489,17.54645,0.810585,0.812365,17.663145
