In [1]:
import os
import sys
import math
import numpy as np
import pandas as pd
import warnings
import random
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn.cross_decomposition import PLSRegression
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.model_selection import GroupShuffleSplit

sys.path.append(os.getcwd()+'/../scripts/')

from Training import Training

%load_ext autoreload
%autoreload 2

In [2]:
path = '../data/preprocessed/'
files = sorted(os.listdir(path))

In [4]:
atts = {'color': ['L', 'C', 'H'], 'fresh mass': 'FRESH MASS' , 'firmness': 'FIRMNESS', \
        'color_p': ['LP', 'CP', 'HP'], 'sst': 'SST', \
        'acidity': 'TOTAL ACIDITY', 'dry mass': 'DRY MASS'}

In [8]:
atts_n = 1

In [7]:
results_cols = ['att', 'file', 'lv', 'R2C', 'RMSEC', 'R2P', 'RMSEP']

In [10]:
warnings.filterwarnings("ignore")

In [20]:
def data_without_nan(df, att):
    wvls = list(df.columns[atts_n+13:].values)
    if type(att) == list:
        var = att + wvls
    else:
        var = [att] + wvls
    return df[var].dropna().reset_index(drop=True)

In [67]:
atts = {'firmness': 'FIRMNESS', \
       'fress mass': 'FRESH MASS', \
        'dry mass': 'DRY MASS'}

In [6]:
atts = {'firmness': 'FIRMNESS'}

In [12]:
atts = {'sst': 'SST'}

In [68]:
atts = {'dry mass': 'DRY MASS'}

# PLS only

In [23]:
def pls():
    for att in list(atts.keys()):
        results = pd.DataFrame(columns=results_cols)
        y_vars = atts[att]
        df = pd.read_csv(path+files[0], index_col=0)
        df = data_without_nan(df, y_vars)
        X_train, X_test, Y_train, Y_test = train_test_split(df.drop(y_vars, axis=1), df[y_vars], test_size=0.33, \
                                                                random_state=0)
        X_train_ind = X_train.index.values
        X_test_ind = X_test.index.values
        Y_train_ind = Y_train.index.values
        Y_test_ind = Y_test.index.values

        for file in files:
            df = pd.read_csv(path+file, index_col=0)
            df = data_without_nan(df, y_vars)

            trn = Training(max_iter_pls=100, n_components=10, k=10)
            trn.optimal_lv_num(df.drop(columns=y_vars), df[y_vars])
            trn.pls_model()

    #         X_train, X_test, Y_train, Y_test = train_test_split(df.drop(y_vars, axis=1), df[y_vars], test_size=0.33, \
    #                                                             random_state=0)

            X = df.drop(y_vars, axis=1)
            Y = df[y_vars]
#             print(X.iloc[X_train_ind])
            cv_metrics = trn.cv(trn.pls, X.iloc[X_train_ind], Y.iloc[Y_train_ind])
            print(cv_metrics)
            model, c_metrics = trn.train(trn.pls, X.iloc[X_train_ind], Y.iloc[Y_train_ind])
            p_metrics = trn.predict(model, X.iloc[X_test_ind], Y.iloc[Y_test_ind])

            row = pd.Series(data=[att, file, trn.lv_number, c_metrics['R2'], \
                            c_metrics['RMSEC'], p_metrics['R2'], \
                            p_metrics['RMSEP']], index=results_cols)

            results = results.append(row, ignore_index=True)
        print(results)
        results.to_csv('../results/pls_%s_s.csv' % att)

In [24]:
pls()

            R2     RMSECV
1     0.038314  18.808264
2    -0.026640  17.175823
3    -0.021163  23.104806
4    -0.061996  19.947324
5     0.012485  18.839957
6    -0.117845  19.994364
7    -0.077699  22.342992
8    -0.066737  20.141839
9    -0.090647  17.091857
10   -0.091258  20.239416
mean -0.050319  19.768664
            R2     RMSECV
1     0.032240  18.867568
2    -0.015800  17.084908
3    -0.017923  23.068121
4    -0.055751  19.888590
5    -0.000027  18.958935
6    -0.110780  19.931083
7    -0.097633  22.548683
8    -0.066650  20.141018
9    -0.078441  16.995943
10   -0.094156  20.266271
mean -0.050492  19.775112
            R2     RMSECV
1    -0.005633  19.233210
2    -0.067651  17.515530
3    -0.040787  23.325752
4    -0.058820  19.917471
5    -0.019668  19.144212
6    -0.094559  19.785010
7    -0.081647  22.383879
8    -0.097229  20.427685
9    -0.090447  17.090284
10   -0.123637  20.537487
mean -0.068008  19.936052
            R2     RMSECV
1     0.029844  18.890910
2    -0.0358

            R2     RMSECV
1    -0.054974  19.699399
2    -0.201845  18.583727
3    -0.113957  24.131753
4    -0.156589  20.816739
5    -0.070791  19.618262
6    -0.174158  20.491794
7    -0.112339  22.699236
8    -0.398959  23.066038
9    -0.347694  18.999522
10   -0.210765  21.318870
mean -0.184207  20.942534
            R2     RMSECV
1    -0.046237  19.617657
2    -0.164825  18.295274
3    -0.088267  23.851876
4    -0.161800  20.863581
5    -0.032451  19.263841
6    -0.147824  20.260698
7    -0.096147  22.533421
8    -0.309970  22.320354
9    -0.275018  18.480136
10   -0.214943  21.355619
mean -0.153748  20.684246
            R2     RMSECV
1     0.010127  19.081908
2    -0.056052  17.420119
3    -0.040254  23.319775
4    -0.171511  20.950594
5    -0.009280  19.046441
6    -0.158386  20.353704
7    -0.118174  22.758694
8    -0.296023  22.201215
9    -0.200981  17.935565
10   -0.136610  20.655704
mean -0.117714  20.372372
            R2     RMSECV
1    -0.009877  19.273752
2    -0.1447

            R2     RMSECV
1     0.028278  18.906145
2    -0.052891  17.394032
3    -0.025322  23.151812
4    -0.071254  20.034083
5     0.003371  18.926702
6    -0.157020  20.341702
7    -0.093423  22.505402
8    -0.093359  20.391624
9    -0.129001  17.389788
10   -0.100012  20.320438
mean -0.069063  19.936173
            R2     RMSECV
1     0.033392  18.856329
2    -0.035744  17.251810
3    -0.018247  23.071797
4    -0.048475  19.819935
5    -0.000516  18.963571
6    -0.113490  19.955378
7    -0.094087  22.512234
8    -0.072798  20.198983
9    -0.077558  16.988985
10   -0.086618  20.196344
mean -0.051414  19.781537
            R2     RMSECV
1     0.021690  18.970133
2    -0.106223  17.829117
3    -0.026090  23.160476
4    -0.048574  19.820872
5    -0.006739  19.022452
6    -0.107889  19.905125
7    -0.092616  22.497091
8    -0.154786  20.956615
9    -0.133493  17.424349
10   -0.085283  20.183937
mean -0.074000  19.977017
            R2     RMSECV
1     0.032639  18.863677
2    -0.0430

            R2     RMSECV
1     0.040739  18.784530
2    -0.056424  17.423191
3    -0.015204  23.037295
4    -0.060958  19.937576
5    -0.008207  19.036317
6    -0.175855  20.506601
7    -0.103268  22.606493
8    -0.092626  20.384787
9    -0.086424  17.058732
10   -0.079631  20.131306
mean -0.063786  19.890683
            R2     RMSECV
1     0.044164  18.750974
2    -0.023757  17.151692
3    -0.018841  23.078521
4    -0.032546  19.668804
5    -0.044719  19.377952
6    -0.162712  20.391669
7    -0.111647  22.692174
8    -0.254464  21.842357
9    -0.134532  17.432329
10   -0.081375  20.147564
mean -0.082043  20.053404
            R2     RMSECV
1     0.058149  18.613295
2     0.002715  16.928489
3    -0.008836  22.964923
4    -0.025866  19.605077
5    -0.053211  19.456547
6    -0.168683  20.443965
7    -0.116528  22.741939
8    -0.278279  22.048711
9    -0.120065  17.320831
10   -0.065607  20.000134
mean -0.077621  20.012391
            R2     RMSECV
1     0.032220  18.867765
2    -0.0881

            R2     RMSECV
1    -0.040093  19.559964
2    -0.110809  17.866036
3    -0.090131  23.872290
4    -0.135183  20.623202
5    -0.125017  20.108870
6    -0.191932  20.646314
7    -0.139821  22.977935
8    -0.315200  22.364871
9    -0.319748  18.801503
10   -0.213306  21.341229
mean -0.168124  20.816221
            R2     RMSECV
1    -0.008795  19.263422
2    -0.034124  17.238315
3    -0.049728  23.425733
4    -0.144070  20.703775
5    -0.109595  19.970566
6    -0.192886  20.654573
7    -0.147953  23.059760
8    -0.376473  22.879906
9    -0.296064  18.632035
10   -0.170329  20.959855
mean -0.153002  20.678794
            R2     RMSECV
1    -0.026104  19.427980
2    -0.108196  17.845012
3    -0.096839  23.945630
4    -0.155078  20.803142
5    -0.035648  19.293641
6    -0.147677  20.259400
7    -0.105082  22.625065
8    -0.259329  21.884671
9    -0.259824  18.369696
10   -0.181703  21.061459
mean -0.137548  20.551570
            R2     RMSECV
1     0.000432  19.175128
2    -0.0985

            R2     RMSECV
1     0.031803  18.871827
2    -0.010069  17.036648
3     0.005334  22.803076
4    -0.060177  19.930230
5    -0.042291  19.355419
6    -0.119641  20.010419
7    -0.097425  22.546548
8    -0.187264  21.249277
9    -0.073794  16.959284
10   -0.112309  20.433700
mean -0.066583  19.919643
            R2     RMSECV
1     0.021779  18.969267
2    -0.046398  17.340315
3    -0.020445  23.096683
4    -0.102394  20.323175
5    -0.041808  19.350936
6    -0.150523  20.284503
7    -0.102367  22.597259
8    -0.237477  21.693968
9    -0.101196  17.174314
10   -0.112058  20.431396
mean -0.089289  20.126182
            R2     RMSECV
1     0.023104  18.956419
2    -0.036348  17.256847
3    -0.016841  23.055859
4    -0.069500  20.017672
5    -0.040777  19.341356
6    -0.137097  20.165800
7    -0.085111  22.419696
8    -0.187794  21.254017
9    -0.089672  17.084214
10   -0.141450  20.699640
mean -0.078149  20.025152
            R2     RMSECV
1    -0.038946  19.549179
2    -0.0309

            R2     RMSECV
1    -0.030294  19.467612
2     0.031961  16.678419
3    -0.040247  23.319696
4    -0.158644  20.835222
5    -0.029988  19.240851
6    -0.125196  20.060001
7    -0.086598  22.435051
8    -0.104805  20.498083
9    -0.064337  16.884439
10   -0.088851  20.217084
mean -0.069700  19.963646
            R2     RMSECV
1    -0.050152  19.654324
2     0.019364  16.786587
3    -0.040100  23.318058
4    -0.161479  20.860696
5    -0.039765  19.331949
6    -0.122638  20.037180
7    -0.098561  22.558212
8    -0.106931  20.517795
9    -0.069416  16.924677
10   -0.111364  20.425016
mean -0.078104  20.041449
            R2     RMSECV
1    -0.033359  19.496547
2     0.032888  16.670436
3    -0.043950  23.361173
4    -0.154249  20.795675
5    -0.041609  19.349087
6    -0.129209  20.095736
7    -0.082150  22.389087
8    -0.102859  20.480020
9    -0.070130  16.930327
10   -0.087787  20.207204
mean -0.071241  19.977529
            R2     RMSECV
1    -0.047649  19.630885
2     0.0227

            R2     RMSECV
1     0.020625  18.980452
2    -0.014332  17.072559
3    -0.024704  23.144828
4    -0.072853  20.049031
5    -0.045361  19.383900
6    -0.198755  20.705318
7    -0.122423  22.801898
8    -0.150048  20.913581
9    -0.183300  17.803053
10   -0.172009  20.974891
mean -0.096316  20.182951
            R2     RMSECV
1     0.005680  19.124722
2    -0.035835  17.252570
3    -0.030216  23.206996
4    -0.081598  20.130577
5    -0.013120  19.082640
6    -0.180920  20.550721
7    -0.128777  22.866344
8    -0.185379  21.232401
9    -0.102687  17.185933
10   -0.132060  20.614325
mean -0.088491  20.124723
            R2     RMSECV
1     0.023297  18.954542
2    -0.060905  17.460102
3    -0.024703  23.144820
4    -0.089976  20.208392
5    -0.013896  19.089954
6    -0.199083  20.708151
7    -0.108308  22.658068
8    -0.160410  21.007590
9    -0.156890  17.603262
10   -0.172021  20.974997
mean -0.096290  20.180988
            R2     RMSECV
1     0.024744  18.940496
2    -0.0163

            R2     RMSECV
1     0.030426  18.885236
2    -0.025556  17.166760
3    -0.006305  22.936101
4    -0.046518  19.801429
5     0.013233  18.832825
6    -0.127895  20.084044
7    -0.102368  22.597269
8    -0.149794  20.911275
9    -0.124214  17.352881
10   -0.191903  21.152159
mean -0.073089  19.971998
            R2     RMSECV
1     0.074463  18.451384
2    -0.020494  17.124335
3    -0.023767  23.134244
4    -0.038769  19.727987
5     0.025310  18.717217
6    -0.135099  20.148076
7    -0.069341  22.256183
8    -0.133595  20.763444
9    -0.102154  17.181783
10   -0.183070  21.073637
mean -0.060651  19.857829
            R2     RMSECV
1    -0.034464  19.506964
2    -0.031334  17.215045
3    -0.018516  23.074839
4    -0.039110  19.731224
5    -0.003051  18.987581
6    -0.137133  20.166126
7    -0.110984  22.685408
8    -0.131958  20.748447
9    -0.190941  17.860444
10   -0.198293  21.208789
mean -0.089579  20.118487
            R2     RMSECV
1     0.013939  19.045133
2    -0.0148