In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from matplotlib.pyplot import figure
from matplotlib import ticker
import seaborn as sns



In [None]:
# Work from the parent directory so the relative 'data/...', 'results/...'
# and 'models/...' paths used throughout the notebook resolve.
# NOTE(review): assumes the kernel starts in the notebook's own directory,
# one level below the project root - confirm.
# The guard makes this cell idempotent: re-running it in the same kernel no
# longer climbs one more directory up each time.
if 'PROJECT_ROOT' not in globals():
    os.chdir('..')
    PROJECT_ROOT = os.getcwd()
PROJECT_ROOT


# Introductory plots & data visualization:

In [None]:
# Load the 'No_stratification' training array. It is indexed below as
# X[sample][:, column]; plot_arena draws column 1 against column 2 and colours
# the points by column 0.
X = np.load('data/processed/No_stratification/X_train.npy')

In [None]:
def plot_arena(X: np.ndarray, labels = True, show = True):
    """Scatter-plot one trajectory on the current pyplot figure.

    Parameters
    ----------
    X : np.ndarray
        2-D array for a single sample. Column 1 is drawn on the x axis,
        column 2 on the y axis and column 0 drives the point colour.
    labels : bool, default True
        Add the title, the axis labels and a colour bar.
    show : bool, default True
        Call ``plt.show()`` after drawing. Pass False to keep working on the
        current figure.

    Notes
    -----
    The function draws on whatever figure/axes are current; create the figure
    beforehand to control its size.
    """
    # Colour value is (column 0 + 2) * 24 and is labelled 'Min.' on the colour bar.
    # NOTE(review): presumably column 0 is a rescaled time stamp - confirm the
    # scaling against the preprocessing code.
    points = plt.scatter(X[:, 1], X[:, 2], c = (X[:, 0] + 2) * 24, s = 0.5)
    if labels:
        plt.title('Aphid trajectory\n on an arabidopsis leaf', fontsize = 25)
        # Pass the scatter explicitly so the colour bar is tied to this plot
        # rather than to whatever mappable happens to be current.
        plt.colorbar(points).set_label('Min.', fontsize = 15)
        plt.ylabel('y', fontsize = 25)
        plt.xlabel('x', fontsize = 25)

    if show:
        plt.show()

In [None]:
# Open a 13 x 10 inch figure at 100 dpi; plot_arena draws onto the current figure.
figure(figsize=(13, 10), dpi= 100)
# Example trajectory: entry 120 of the training array.
plot_arena(X[120])

# Missing data plots

In [None]:
# Column 1 of the NaN register is what the histogram below plots as
# 'Steps missing' per arena.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python
# objects - only load this file if it comes from a trusted source.
num_miss_values = np.load('data/interim/missing_data_info/nan_register.npy', allow_pickle= True)[:, 1]

In [None]:
# Histogram of missing steps (200 bins over 1..1000); each bar is coloured by
# its height relative to the tallest bar using the viridis colour map.
figure(figsize=(13, 10), dpi= 100)
n, bins, patches = plt.hist(num_miss_values, 200, range = (1, 1000))
tallest_bar = max(n)
for bar_height, bar in zip(n, patches):
    bar.set_facecolor(plt.cm.viridis(bar_height / tallest_bar))
plt.xlabel('Steps missing', fontsize = 25)
plt.ylabel('Number of arenas', fontsize = 25)
plt.title('Number of missing steps', fontsize = 25)
plt.xticks(fontsize = 25)
plt.yticks(fontsize = 25)


In [None]:
# Number of register entries with at least one missing step.
sum(num_miss_values > 0)

In [None]:
# Number of register entries with more than 1000 missing steps, i.e. beyond
# the (1, 1000) range shown in the histogram above.
sum(num_miss_values > 1000)

# Images

In [None]:
# NOTE(review): project-local import placed mid-notebook; consider moving it
# to the import cell at the top so dependencies are visible in one place.
from src.features.process_data import convert_to_image

# Pass only feature columns 0, 1, 2 and 9 to convert_to_image, with a
# (128, 128) resolution.
# NOTE(review): going by the `names` list in the time-series section these
# would be time, x, y and speed - confirm the column order matches.
X_train = convert_to_image(np.load('data/processed/No_stratification/X_train.npy')[:, :, [0, 1, 2, 9]],
        resolution = (128, 128))


In [None]:
# 3 x 3 grid of converted images, sampling every 123rd entry of X_train
# (indices 0, 123, ..., 984). Author's note: takes around 1 min.
plt.figure(figsize=(10, 10), dpi=100)

for panel in range(1, 10):
    plt.subplot(3, 3, panel)
    plt.imshow(X_train[(panel - 1) * 123], cmap= 'jet')
    plt.axis('off')

plt.subplots_adjust(wspace=0.1, hspace=0)
plt.show()

# Time series interim example

In [None]:
# Re-load the 'No_stratification' training array (same file as the first
# load of X above), so this section can be run on its own.
X = np.load('data/processed/No_stratification/X_train.npy')

In [None]:
# Shape of the single sample (entry 123) that is plotted feature by feature below.
X[123].shape

In [None]:
# One panel per feature column of a single sample (entry 123), laid out on a
# 3 x 7 grid and saved to images/sample_features.png.
# Author's note: takes around 1 min.
names = ['time', 'x', 'y', 'dx', 'dy', 'mm_dx', 'mstd_dx', 'mm_dy', 'mstd_dy',
        'speed', 'mm_speed', 'mstd_speed', 'acceleration', 'mm_acceleration',
        'mstd_acceleration', 'angular_speed', 'mm_angular_speed', 'mstd_angular_speed',
        'trav_distance', 'mm_trav_distance', 'mstd_trav_distance']
n_images = len(names)  # 21 panels, which fills the 3 x 7 grid exactly
sample = X[123]

plt.figure(figsize=(50, 15), dpi=400)
for i, feature_name in enumerate(names):
    plt.subplot(3, 7, 1 + i)
    plt.plot(sample[:, i], c = 'g', linewidth= 0.8, alpha = 1)
    plt.title(feature_name, fontsize = 30)
    plt.locator_params(axis = 'x', nbins=3)
    plt.xticks(fontsize = 25)
    plt.yticks(fontsize = 25)
plt.subplots_adjust(left=0.1,
                    bottom=0.1,
                    right=0.9,
                    top=0.9,
                    wspace=0.4,
                    hspace=0.4)

# Forward slash instead of 'images\sample_features.png': '\s' is an invalid
# escape sequence and the backslash is only a path separator on Windows.
os.makedirs('images', exist_ok=True)
plt.savefig('images/sample_features.png')
plt.show()

# $H^2$ & $R^2$ Plots

In [None]:
# Load the repeatability (H^2) and variance-explained (R^2) tables of the four
# approaches and draw one figure per approach.
# Paths use forward slashes: the previous backslash-separated literals relied
# on invalid escape sequences ('\H', '\S', ...) and only resolved on Windows.
handcrafted_heritability = pd.read_csv('results/Hand_features/Handcrafted_dimensions_repeatability.csv', index_col= False)
# The misspelt name ('varaince') is kept because later cells refer to it.
handcrafted_varaince_explained = pd.read_csv('results/Hand_features/handcrafted_dimensions_R2_86001.csv', index_col= False)

supervised_heritability = pd.read_csv('results/Supervised/IT_C0-1-2-3-4-5-6-7-8-9-10-11-12-13-14-15-16-17-18-19-20_D1_GS1_snp86001_BS20_repeatability.csv', index_col= False)
supervised_variance_explained =  pd.read_csv('results/Supervised/IT_C0-1-2-3-4-5-6-7-8-9-10-11-12-13-14-15-16-17-18-19-20_D1_GS1_snp86001_BS20_R2_86001.csv', index_col= False)

self_supervised_heritability = pd.read_csv('results/Self-Supervised/MINIMAL_VAE_C0-1-2-3-4-5-6-7-8-9-10-11-12-13-14-15-16-17-18-19-20_D1_GS1_BS32_repeatability.csv', index_col= False)
self_supervised_variance_explained =  pd.read_csv('results/Self-Supervised/MINIMAL_VAE_C0-1-2-3-4-5-6-7-8-9-10-11-12-13-14-15-16-17-18-19-20_D1_GS1_BS32_R2_86001.csv', index_col= False)

contrastive_heritability = pd.read_csv('results/Contrastive/MINIMAL_Contrastive_C1-2-3-4-5-6-7-8-9-10-11-12-13-14-15-16-17-18-19-20_D1_BS128_repeatability.csv', index_col= False)
contrastive_variance_explained =  pd.read_csv('results/Contrastive/MINIMAL_Contrastive_C1-2-3-4-5-6-7-8-9-10-11-12-13-14-15-16-17-18-19-20_D1_BS128_R2_86001.csv', index_col= False)


# The three lists below are parallel: position k describes the same approach.
H2s = [ handcrafted_heritability,
        supervised_heritability,
        self_supervised_heritability,
        contrastive_heritability]

R2s = [ handcrafted_varaince_explained,
        supervised_variance_explained,
        self_supervised_variance_explained,
        contrastive_variance_explained]

titles = [  'Trait quality for Handcrafted features.\nBroad sense heritability and variance explained (for SNP 86001)',
            'Trait quality for Inception Time features .\nBroad sense heritability and variance explained (for SNP 86001)',
            'Trait quality for VAE .\nBroad sense heritability and variance explained (for SNP 86001)',
            'Trait quality for Contrastive learning .\nBroad sense heritability and variance explained (for SNP 86001)',
        ]

for h2_table, r2_table, plot_title in zip(H2s, R2s, titles):
    point_estimate = h2_table['repeatability']

    figure(figsize=(20, 20), dpi=80)
    # Horizontal error bars: distance from the estimate to the 'lower' and
    # 'upper' columns. Rows are labelled by the 'Unnamed: 0' column.
    plt.errorbar(   y = h2_table['Unnamed: 0'],
                    x = point_estimate,
                    xerr = (point_estimate - h2_table['lower'],
                            h2_table['upper'] - point_estimate),
                    fmt='o',
                    capsize=5,
                    color = 'black',
                    ecolor= 'red',
                    label = r'$H^2$')

    plt.scatter(    y = r2_table['Unnamed: 0'],
                    x = r2_table['R^2'],
                    color = 'green',
                    marker = "*",
                    s = 200,
                    label = r'$R^2$')

    plt.title(plot_title, fontsize=25)
    plt.legend(prop={'size': 20})
    plt.yticks(fontsize=20)
    plt.xticks(fontsize=20)
    plt.grid()
    plt.show()

# Reverse-GWAS plots

In [None]:
# Per-SNP test MCC of each approach, one scatter plot per approach.
# Paths use forward slashes: the previous backslash-separated literals relied
# on invalid escape sequences and only resolved on Windows.
Handcrafted_SNP_perf = pd.read_csv('results/Hand_features/handcrafted_dimensions_reverseGwas_214k.csv')
Supervised_SNP_perf = pd.read_csv('results/Supervised/IT_C0-1-2-3-4-5-6-7-8-9-10-11-12-13-14-15-16-17-18-19-20_D1_GS1_snp86001_BS20_reverseGwas_214k.csv')
Self_supervised_SNP_perf = pd.read_csv('results/Self-Supervised/MINIMAL_VAE_C0-1-2-3-4-5-6-7-8-9-10-11-12-13-14-15-16-17-18-19-20_D1_GS1_BS32_reverse_GWAS_reverseGwas_214k.csv')
Contrastive_SNP_perf = pd.read_csv('results/Contrastive/MINIMAL_Contrastive_C1-2-3-4-5-6-7-8-9-10-11-12-13-14-15-16-17-18-19-20_D1_BS128_reverseGwas_214k.csv')

titles = [  'SVM on handcrafted features',
            'SVM on traits extracted via supervised learning',
            'SVM on latent dimensions\nVariational Autoencoder',
            'SVM on dimensions generated by\nsupervised contrastive learning']

SNP_perf = [ Handcrafted_SNP_perf,
             Supervised_SNP_perf,
             Self_supervised_SNP_perf,
             Contrastive_SNP_perf]

# Loop-invariant inputs: read the chromosome table and build the point colours
# once instead of on every iteration. Values 1..5 of column 'x' alternate
# between black and grey.
# NOTE(review): assumes the chromosome table has one row per plotted SNP, in
# the same order as the performance tables - confirm.
chromosome = pd.read_csv('data/raw/raw_data/Chromosome_info/Crhomosome.csv')
my_map = {1 : 'black', 2: 'grey', 3 : 'black', 4: 'grey', 5 : 'black'}
point_colours = chromosome['x'].map(my_map)

for perf_table, plot_title in zip(SNP_perf, titles):
    # Negative MCC values are floored at 0 before plotting.
    test_perf = np.clip(np.array(perf_table['Test MCC']), 0, None)

    figure(figsize=(20, 15), dpi=60)
    plt.scatter(perf_table['SNP'], test_perf, s = 45, c= point_colours)
    plt.ylabel('MCC', fontsize = 20)
    plt.xlabel('SNP', fontsize = 20)
    plt.title(plot_title, fontsize = 20)
    #plt.axvline(86000, c ='red')
    #plt.axvline(131996, c ='red', linestyle = '--', alpha = 0.5, label = 'SNP: 131996')
    #plt.margins(0.01, 0.01)
    #plt.legend(fontsize = 15)

# Predictive performance on SNP 86001: supervised approaches

In [None]:
# Summary table: mean and standard deviation of accuracy ('acc') and MCC
# ('mcc') over the rows of every results CSV in models/Results/supervised.
results = {'name': [], 'mean_acc':[], 'sd_acc': [], 'mean_MCC':[], 'sd_MCC': []}

# os.listdir returns entries in arbitrary order, while the short labels below
# are assigned by position. Sorting makes that assignment reproducible across
# machines and runs.
for res in sorted(os.listdir('models/Results/supervised')):
    if 'csv' not in res:
        continue
    current_r = pd.read_csv('models/Results/supervised/' + res)
    results['name'].append(res)
    results['mean_acc'].append(current_r['acc'].mean())
    results['sd_acc'].append(current_r['acc'].std())
    results['mean_MCC'].append(current_r['mcc'].mean())
    results['sd_MCC'].append(current_r['mcc'].std())

# Replace the first eight file names with short display labels.
# NOTE(review): these labels are positional and assume a specific (sorted)
# file order; the following cell labels twelve files including 'anti-str'
# variants, so this list looks out of date - verify against the directory.
short_labels = ['IT full no str.', 'IT full',
                'IT small no str.', 'IT small',
                'Trans. no str', 'Trans.',
                'Xception. no str', 'Xception']
for position, label in enumerate(short_labels):
    results['name'][position] = label

results = pd.DataFrame(results)

In [None]:
# Per-row accuracy and MCC of every supervised run, one column per results CSV.
# Column labels are assigned by position, so the file order must be stable.
colnames = ['IT full no str.', 'IT full', 'IT full anti-str.', 'IT small no str.', 'IT small', 'IT small anti-str',
            'Trans. no str', 'Trans.', 'Trans. anti-str', 'Xception. no str', 'Xception', 'Xception anti-str']

# os.listdir returns entries in arbitrary order; sort so the positional
# labelling is reproducible.
# NOTE(review): the labels assume the sorted file order matches `colnames` -
# check the mapping printed below.
csv_files = sorted(entry for entry in os.listdir('models/Results/supervised') if 'csv' in entry)

# Fail loudly instead of silently leaving all-zero columns (too few files) or
# running past the last column (too many files).
if len(csv_files) != len(colnames):
    raise ValueError(f'Expected {len(colnames)} result files in models/Results/supervised, '
                     f'found {len(csv_files)}: {csv_files}')

# Each CSV is expected to hold exactly 5 rows (presumably one per fold/repeat).
results_accuracy = np.zeros(shape = (5, len(colnames)))
results_MCC = np.zeros(shape = (5, len(colnames)))

for i, res in enumerate(csv_files):
    print(f'{colnames[i]} <- {res}')
    current_r = pd.read_csv('models/Results/supervised/' + res)

    results_accuracy[:, i] = current_r['acc']
    results_MCC[:, i] = current_r['mcc']

results_accuracy = pd.DataFrame(results_accuracy, columns = colnames)
results_MCC = pd.DataFrame(results_MCC, columns = colnames)

print(results_accuracy)
print(results_MCC)

In [None]:
# Box-plot styling shared by this cell and the two box-plot cells after it.
boxprops = dict(linewidth=3)
whiskerprops  = dict(linewidth=3)
capprops  = dict(linewidth=3)
medianprops  = dict(linewidth=5, c = 'orange')
flierprops = dict(marker='o', markerfacecolor='black', markersize=12, markeredgecolor='none')

# Test-set accuracy (left) and MCC (right) of the four architectures for the
# genotype-stratified runs.
box_style = dict(boxprops=boxprops, whiskerprops=whiskerprops, capprops=capprops,
                 medianprops=medianprops, flierprops=flierprops)
model_columns = ['IT full', 'IT small', 'Trans.', 'Xception']
tick_labels = ['IT full', 'IT small', 'Trans.', 'Xception']

fig, (ax0, ax1) = plt.subplots(ncols=2)
fig.set_size_inches(20, 10)
fig.suptitle('Genotype stratification performance', fontsize=20)
ax0.set_title('Accuracy in test set', fontsize = 20)
ax1.set_title('MCC. in test set', fontsize = 20)
for ax in (ax0, ax1):
    ax.xaxis.set_tick_params(labelsize=20)
    ax.yaxis.set_tick_params(labelsize=20)

# Accuracy panel. The dashed line is a hard-coded reference value labelled
# '% 1s in test set'.
# axhline's xmin/xmax are axes fractions in [0, 1]; the previous xmax=5 was
# out of range. The defaults (0, 1) already span the full width.
ax0.boxplot(results_accuracy[model_columns], **box_style)
ax0.set_xticklabels(tick_labels)
ax0.axhline(0.5891238670694864, ls = '--', alpha = 1, linewidth = 6, color = 'green', label = '% 1s in test set')
ax0.grid()
ax0.legend(fontsize = 20)

# MCC panel, with a solid reference line at 0.
ax1.boxplot(results_MCC[model_columns], **box_style)
ax1.set_xticklabels(tick_labels)
ax1.grid()
ax1.axhline(0, alpha = 1, linewidth = 6, color = 'black')
plt.show()

In [None]:
# Test-set accuracy (left) and MCC (right) of the four architectures for the
# runs without stratification. Reuses the box-plot styling dictionaries
# (boxprops, whiskerprops, ...) defined in the genotype-stratification cell.
box_style = dict(boxprops=boxprops, whiskerprops=whiskerprops, capprops=capprops,
                 medianprops=medianprops, flierprops=flierprops)
model_columns = ['IT full no str.', 'IT small no str.', 'Trans. no str', 'Xception. no str']
tick_labels = ['IT full', 'IT small', 'Trans.', 'Xception']

fig, (ax0, ax1) = plt.subplots(ncols=2)
fig.set_size_inches(20, 10)
fig.suptitle('No stratification performance', fontsize=20)
ax0.set_title('Accuracy in test set', fontsize = 20)
ax1.set_title('MCC. in test set', fontsize = 20)
for ax in (ax0, ax1):
    ax.xaxis.set_tick_params(labelsize=20)
    ax.yaxis.set_tick_params(labelsize=20)

# Accuracy panel. The dashed line is a hard-coded reference value labelled
# '% 1s in test set'.
# axhline's xmin/xmax are axes fractions in [0, 1]; the previous xmax=5 was
# out of range. The defaults (0, 1) already span the full width.
ax0.boxplot(results_accuracy[model_columns], **box_style)
ax0.set_xticklabels(tick_labels)
ax0.axhline(0.5891238670694864, ls = '--', alpha = 1, linewidth = 6, color = 'green', label = '% 1s in test set')
ax0.grid()
ax0.legend(fontsize = 20)

# MCC panel, with a solid reference line at 0.
ax1.boxplot(results_MCC[model_columns], **box_style)
ax1.set_xticklabels(tick_labels)
ax1.grid()
ax1.axhline(0, alpha = 1, linewidth = 6, color = 'black')
plt.show()

In [None]:
# Test-set accuracy (left) and MCC (right) of the four architectures for the
# anti-stratified runs. Reuses the box-plot styling dictionaries (boxprops,
# whiskerprops, ...) defined in the genotype-stratification cell.
box_style = dict(boxprops=boxprops, whiskerprops=whiskerprops, capprops=capprops,
                 medianprops=medianprops, flierprops=flierprops)
model_columns = ['IT full anti-str.', 'IT small anti-str', 'Trans. anti-str', 'Xception anti-str']
tick_labels = ['IT full', 'IT small', 'Trans.', 'Xception']

fig, (ax0, ax1) = plt.subplots(ncols=2)
fig.set_size_inches(20, 10)
fig.suptitle('Anti-stratification performance', fontsize=20)
ax0.set_title('Accuracy in test set', fontsize = 20)
ax1.set_title('MCC. in test set', fontsize = 20)
for ax in (ax0, ax1):
    ax.xaxis.set_tick_params(labelsize=20)
    ax.yaxis.set_tick_params(labelsize=20)

# Accuracy panel. The dashed line is a hard-coded reference value labelled
# '% 1s in test set'.
# axhline's xmin/xmax are axes fractions in [0, 1]; the previous xmax=5 was
# out of range. The defaults (0, 1) already span the full width.
ax0.boxplot(results_accuracy[model_columns], **box_style)
ax0.set_xticklabels(tick_labels)
ax0.axhline(0.5891238670694864, ls = '--', alpha = 1, linewidth = 6, color = 'green', label = '% 1s in test set')
ax0.grid()
ax0.legend(fontsize = 20)

# MCC panel, with a solid reference line at 0.
ax1.boxplot(results_MCC[model_columns], **box_style)
ax1.set_xticklabels(tick_labels)
ax1.grid()
ax1.axhline(0, alpha = 1, linewidth = 6, color = 'black')
plt.show()

## Correlation with last layer weights:


In [None]:
# Last-layer weights table of the supervised model; columns '0' and '1' are
# correlated with the trait-quality metrics in the next cell.
# Forward slashes replace the backslash-separated literal, which relied on
# invalid escape sequences ('\S', '\I') and only resolved on Windows.
weights_last_layer = pd.read_csv('results/Supervised/IT_C0-1-2-3-4-5-6-7-8-9-10-11-12-13-14-15-16-17-18-19-20_D1_GS1_snp86001_BS20_ll_weights.csv', index_col= False)


In [None]:
# Correlation of each last-layer weight column ('0', '1') with the supervised
# repeatability first and with the supervised R^2 second. Four values are
# printed, in that order.
for trait_metric in (supervised_heritability['repeatability'],
                     supervised_variance_explained['R^2']):
    for weight_column in ('0', '1'):
        print(weights_last_layer[weight_column].corr(trait_metric))

# Some comparisons between approaches' performance

In [None]:
from scipy import stats

# Pairwise two-sample t-tests with unequal variances (equal_var=False, i.e.
# Welch's test) between approaches: the three R^2 comparisons first, then the
# same three for H^2 (repeatability).
r2_self_supervised = self_supervised_variance_explained['R^2']
r2_supervised = supervised_variance_explained['R^2']
r2_handcrafted = handcrafted_varaince_explained['R^2']

h2_self_supervised = self_supervised_heritability['repeatability']
h2_supervised = supervised_heritability['repeatability']
h2_handcrafted = handcrafted_heritability['repeatability']

comparisons = [
    (r2_self_supervised, r2_supervised),    # R^2: self-supervised vs supervised
    (r2_self_supervised, r2_handcrafted),   # R^2: self-supervised vs handcrafted
    (r2_supervised, r2_handcrafted),        # R^2: supervised vs handcrafted
    (h2_self_supervised, h2_supervised),    # H^2: self-supervised vs supervised
    (h2_self_supervised, h2_handcrafted),   # H^2: self-supervised vs handcrafted
    (h2_supervised, h2_handcrafted),        # H^2: supervised vs handcrafted
]

for first_sample, second_sample in comparisons:
    print(stats.ttest_ind(first_sample, second_sample, equal_var= False))



# Interpretability Contrastive Dimensions

In [None]:
import pandas as pd
import numpy as np
# Repeatability tables of the contrastive model: the regular file and its
# '_test_' counterpart.
# Forward slashes replace the backslash-separated literals, which relied on
# invalid escape sequences ('\C', '\M') and only resolved on Windows.
Heritability_all = pd.read_csv('results/Contrastive/MINIMAL_Contrastive_C1-2-3-4-5-6-7-8-9-10-11-12-13-14-15-16-17-18-19-20_D1_BS128_repeatability.csv')
Heritability_all_test = pd.read_csv('results/Contrastive/MINIMAL_Contrastive_C1-2-3-4-5-6-7-8-9-10-11-12-13-14-15-16-17-18-19-20_D1_BS128_test_repeatability.csv')

In [None]:
# Load the three 'Genotype_stratified' splits.
# NOTE(review): this rebinds X_train, which earlier in the notebook held the
# output of convert_to_image; re-running the image-grid cell after this one
# would use these arrays instead.
X_train = np.load('data/processed/Genotype_stratified/X_train.npy')
X_val = np.load('data/processed/Genotype_stratified/X_val.npy')
X_test = np.load('data/processed/Genotype_stratified/X_test.npy')

In [None]:
# Stack the splits along the first axis in train, val, test order.
X_all = np.vstack((X_train, X_val, X_test))
# Hidden representations of the contrastive model (columns 'Dim_0' ... are
# used below).
# NOTE(review): the plots below assume row k of this table corresponds to
# X_all[k] - confirm the export used the same train/val/test order.
# Forward slashes replace the backslash-separated literal, which relied on
# invalid escape sequences ('\p', '\H', ...) and only resolved on Windows.
HD_contrastive = pd.read_csv('data/processed/Hidden_representations/Contrastive/MINIMAL_Contrastive_C1-2-3-4-5-6-7-8-9-10-11-12-13-14-15-16-17-18-19-20_D1_BS128.csv')

In [None]:
# For each contrastive dimension, show side by side the trajectory with the
# largest value of that dimension (left) and the one with the smallest value
# (right).
# The stray `plt.figure(figsize=(50, 50), dpi=400)` that used to precede the
# loop was removed: every iteration creates its own figure with plt.subplots,
# so that call only produced an empty 20000 x 20000 pixel canvas.
for i in ['Dim_0', 'Dim_1', 'Dim_2', 'Dim_3', 'Dim_4', 'Dim_5',
          'Dim_6', 'Dim_7', 'Dim_8', 'Dim_9']:
    fig, (ax1, ax2) = plt.subplots(1, 2, gridspec_kw={'width_ratios': [1, 1.3]})
    fig.set_figheight(15)
    fig.set_figwidth(25)
    fig.set_dpi(200)

    # Look up the two extreme samples once instead of three times each.
    highest = X_all[HD_contrastive[i].argmax()]
    lowest = X_all[HD_contrastive[i].argmin()]

    # Same encoding as plot_arena: column 1 vs column 2, coloured by
    # (column 0 + 2) * 24.
    ax1.scatter(highest[:, 1],
                highest[:, 2],
                c = (highest[:, 0] + 2) * 24,
                s = 1.5)
    ax1.set_title(f'{i}: maximum value', fontsize = 20)
    ax1.set_xlabel('x', fontsize = 20)
    ax1.set_ylabel('y', fontsize = 20)

    im = ax2.scatter(lowest[:, 1],
                     lowest[:, 2],
                     c = (lowest[:, 0] + 2) * 24,
                     s = 1.5)
    ax2.set_xlabel('x', fontsize = 20)
    ax2.set_title(f'{i}: minimum value', fontsize = 20)

    # The colour bar is built from the right-hand scatter only.
    # NOTE(review): the two panels are normalised independently; the bar
    # describes both only if their column-0 ranges coincide - confirm.
    fig.colorbar(im, shrink=0.50).set_label('Minute', fontsize = 15)
    fig.suptitle(f'Aphid trajectory comparison for {i}',  fontsize = 22)
    # Lay out once, after all artists exist, then leave room for the suptitle.
    fig.tight_layout()
    fig.subplots_adjust(top=0.90)
    plt.show()