In [None]:
# Imports

from plotting import plot_heatmap_noise,plot_metrics,plot_violin,plot_histogram,plot_ou_tau,plot_timecourse
from statistical_analysis import normality,equal_variance,comparison
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import scipy as sp
from scipy.io import loadmat
import copy

In [None]:
# Load the simulation results stored as MATLAB .mat files (paths are relative
# to the working directory). scipy.io.loadmat returns a dict of arrays per file.

# Results without adaptation: Ornstein-Uhlenbeck, pink and white noise.
ornstein_no_adap, pink_no_adap, white_no_adap = map(loadmat, (
    'results_ornstein_no_adap.mat',
    'results_pink_no_adap.mat',
    'results_white_no_adap.mat',
))

# Results with divisive adaptation.
ornstein_div_adap, pink_div_adap, white_div_adap = map(loadmat, (
    'results_ornstein_div_adap.mat',
    'results_pink_div_adap.mat',
    'results_white_div_adap.mat',
))

# Results with subtractive adaptation.
ornstein_sub_adap, pink_sub_adap, white_sub_adap = map(loadmat, (
    'results_ornstein_sub_adap.mat',
    'results_pink_sub_adap.mat',
    'results_white_sub_adap.mat',
))

# Ornstein-Uhlenbeck "tau" result files, one per adaptation type.
ornstein_tau_no_adap, ornstein_tau_div_adap, ornstein_tau_sub_adap = map(loadmat, (
    'results_tau_ornstein_no_adap.mat',
    'results_tau_ornstein_div_adap.mat',
    'results_tau_ornstein_sub_adap.mat',
))

# Time-course files, one per noise type.
ou_time, pink_time, white_time = map(loadmat, (
    'timecourse_ornstein.mat',
    'timecourse_pink.mat',
    'timecourse_white.mat',
))

In [None]:
# Build data structures

# Names of the MATLAB variables read from each results file (means and
# standard deviations) and the shorter metric names used in this notebook.
# The three lists are parallel: list_matrices[j] and list_sd[j] are both
# stored under new_keys[j].
list_matrices = ['DurMatrix','DurMixMatrix','PercMatrix','WTAmatrix','CVmatrix']
list_sd       = ['DurSDmatrix','DurMixSDmatrix','PercSDmatrix','WTAsdMatrix','CVsdMatrix']
new_keys      = ['DomDur','MixDur','PercTime','DomIndex','CV']

# Case labels for the results dicts passed to transform_data, in call order.
main_keys     = ['OU_cs','Pink_cs','White_cs','OU_st']

def transform_data(old_dicts):
    """Re-key raw loadmat results into nested ``{case: {metric: array}}`` dicts.

    Parameters
    ----------
    old_dicts : sequence of dict
        One loaded results dict per entry of ``main_keys``, in the same order.
        Each must contain every name in ``list_matrices`` and ``list_sd``.

    Returns
    -------
    data : dict
        ``data[main_keys[i]][new_keys[j]] = old_dicts[i][list_matrices[j]]``.
    sd_data : dict
        Same layout, holding ``old_dicts[i][list_sd[j]]``.

    Raises
    ------
    ValueError
        If ``len(old_dicts) != len(main_keys)`` -- a mismatch would otherwise
        drop inputs or mislabel cases.
    KeyError
        If a results dict lacks one of the expected variable names.
    """
    if len(old_dicts) != len(main_keys):
        raise ValueError(
            f'expected {len(main_keys)} result dicts for {main_keys}, '
            f'got {len(old_dicts)}'
        )
    data = {}
    sd_data = {}
    for case, old_dict in zip(main_keys, old_dicts):
        data[case] = {new: old_dict[src] for new, src in zip(new_keys, list_matrices)}
        sd_data[case] = {new: old_dict[src] for new, src in zip(new_keys, list_sd)}
    return data, sd_data

def make_heatmap(experiment):
    """Return a new ``{case: {metric: array}}`` dict with every array flipped vertically.

    Each array is replaced by ``np.flipud`` of the input array (row order
    reversed). The input dict is not modified, but ``np.flipud`` returns a
    view, so the returned arrays share memory with the input arrays.
    """
    return {
        case: {metric: np.flipud(matrix) for metric, matrix in metrics.items()}
        for case, metrics in experiment.items()
    }

def make_histogram(experiment):
    """Flatten every array of a ``{case: {metric: array}}`` dict to its finite values.

    Returns a new dict with the same keys whose values are 1-D arrays of the
    finite entries (NaN and +/-inf removed) in ``ravel`` order. The input
    dict is left unchanged.
    """
    flattened = {}
    for case, metrics in experiment.items():
        per_metric = {}
        for metric, matrix in metrics.items():
            values = matrix.ravel()
            finite_values = values[np.isfinite(values)]
            # Only finite values remain, so nan_to_num changes nothing here;
            # it is kept so the output matches the original exactly.
            per_metric[metric] = np.nan_to_num(finite_values)
        flattened[case] = per_metric
    return flattened

def make_mask(data,case,metric,signal,threshold):
    """Boolean mask of where ``data[case][metric]`` passes an inclusive threshold test.

    Parameters
    ----------
    data : dict
        Nested ``{case: {metric: array}}`` dict.
    case, metric : str
        Keys selecting the array to test.
    signal : str
        ``'>'`` or ``'>='`` keeps values ``>= threshold``; ``'<'`` or ``'<='``
        keeps values ``<= threshold``. Both directions are inclusive.
    threshold : scalar or array_like
        Compared elementwise against the array (numpy broadcasting applies).

    Returns
    -------
    numpy.ndarray of bool
        True where the condition holds; NaN entries compare False.

    Raises
    ------
    ValueError
        If ``signal`` is not one of the accepted operators, instead of
        silently treating it as ``<=``.
    """
    values = data[case][metric]
    if signal in ('>', '>='):
        return values >= threshold
    if signal in ('<', '<='):
        return values <= threshold
    raise ValueError(f"signal must be one of '>', '>=', '<', '<=', got {signal!r}")

def make_histogram_mask(experiment,metric_threshold,signal_threshold,threshold,alt_threshold):
    """Flatten each metric to its finite values within a threshold-selected set of cells.

    For each case, condition ``i`` tests ``experiment[case][metric_threshold[i]]``
    against ``threshold[i]`` using operator ``signal_threshold[i]`` (via
    ``make_mask``). A condition that selects no cell is retried with
    ``alt_threshold`` instead. Each cell counts how many conditions it passes,
    and the cells reaching the highest count within that case are kept. That
    same selection is applied to every metric of the case, and non-finite
    values are then dropped.

    Returns a new ``{case: {metric: 1-D array}}`` dict; ``experiment`` itself
    is not modified.
    """
    filtered = copy.deepcopy(experiment)
    for case, metrics in experiment.items():
        # Per-cell count of satisfied conditions.
        hits = np.zeros(np.shape(metrics[metric_threshold[0]]))
        for name, sign, limit in zip(metric_threshold, signal_threshold, threshold):
            passed = make_mask(experiment, case, name, sign, limit)
            if passed.sum() == 0:
                # Nothing passes the primary threshold: retry with alt_threshold.
                passed = make_mask(experiment, case, name, sign, alt_threshold)
            hits += passed * 1
        # NOTE(review): if no cell passes any condition, hits.max() is 0 and
        # this comparison selects *every* cell -- confirm that is intended.
        keep = hits == hits.max()
        for name in metrics:
            selected = experiment[case][name][keep].ravel()
            filtered[case][name] = np.nan_to_num(selected[np.isfinite(selected)])
    return filtered

# Re-key the raw results of each adaptation type into {case: {metric: array}}
# dicts of means (data_*) and standard deviations (sd_*).
data_no_adap,sd_no_adap   = transform_data([ornstein_no_adap,pink_no_adap,white_no_adap,ornstein_tau_no_adap])
data_div_adap,sd_div_adap = transform_data([ornstein_div_adap,pink_div_adap,white_div_adap,ornstein_tau_div_adap])
data_sub_adap,sd_sub_adap = transform_data([ornstein_sub_adap,pink_sub_adap,white_sub_adap,ornstein_tau_sub_adap])

# Vertically flipped copies (np.flipud) used for the heatmaps.
heat_no_adap  = make_heatmap(data_no_adap)
heat_div_adap = make_heatmap(data_div_adap)
heat_sub_adap = make_heatmap(data_sub_adap)

heat_no_adap_sd  = make_heatmap(sd_no_adap)
heat_div_adap_sd = make_heatmap(sd_div_adap)
heat_sub_adap_sd = make_heatmap(sd_sub_adap)

# All finite values of every metric, without any cell selection.
hist_no_adap  = make_histogram(heat_no_adap)
hist_div_adap = make_histogram(heat_div_adap)
hist_sub_adap = make_histogram(heat_sub_adap)

# Cell-selection conditions shared by all three adaptation types, defined once
# so they cannot drift apart: PercTime >= 50 and PercTime <= 100 (make_mask
# treats '>' and '<' inclusively). A condition that selects no cell is retried
# with MASK_ALT_THRESHOLD; see make_histogram_mask for the exact rule.
MASK_METRICS       = ['PercTime', 'PercTime']
MASK_SIGNALS       = ['>', '<']
MASK_THRESHOLDS    = [50, 100]
MASK_ALT_THRESHOLD = [20]

hist_no_adap_mask  = make_histogram_mask(heat_no_adap, MASK_METRICS, MASK_SIGNALS, MASK_THRESHOLDS, MASK_ALT_THRESHOLD)
hist_div_adap_mask = make_histogram_mask(heat_div_adap, MASK_METRICS, MASK_SIGNALS, MASK_THRESHOLDS, MASK_ALT_THRESHOLD)
hist_sub_adap_mask = make_histogram_mask(heat_sub_adap, MASK_METRICS, MASK_SIGNALS, MASK_THRESHOLDS, MASK_ALT_THRESHOLD)

In [None]:
# Plots

# Shared styling for the figures below.
title = 'Dominance Time (%)'
title1 = 'Without adaptation'
title2 = 'With divisive adaptation'
title3 = 'With subtractive adaptation'
size_titles = 32
size_labels = 32
size_pad = 25

# Figure 1: PercTime heatmaps (colour range 0-100), one per adaptation type.
ax1, ax2, ax3 = (
    plot_heatmap_noise(heat, 'PercTime', 0, 100, title, subtitle,
                       size_titles, size_labels, size_pad, fname)
    for heat, subtitle, fname in (
        (heat_no_adap,  title1, 'fig1_no_adap'),
        (heat_div_adap, title2, 'fig1_div_adap'),
        (heat_sub_adap, title3, 'fig1_sub_adap'),
    )
)

# Figure 2: masked metrics for the three adaptation types.
ax = plot_metrics([hist_no_adap_mask, hist_div_adap_mask, hist_sub_adap_mask], 32, 30, 'fig2')

# Figure 3: subtractive adaptation -- plot_violin of the masked metrics and
# plot_histogram of all finite values.
ax = plot_violin(hist_sub_adap_mask, ['OU_cs', 'Pink_cs', 'White_cs'],
                 size_labels=24, save='fig3_metrics', max_y=4.5)
ax = plot_histogram(hist_sub_adap, size_labels=24, size_pad=25, save='fig3_histogram')

# Figure 4: OU "tau" panels with smaller axis labels (title1/title3 as above).
size_labels = 20
# NOTE(review): the means passed here are the flipped heat_* dicts, but the
# SDs are the unflipped sd_* dicts; the flipped heat_*_sd dicts are not used
# anywhere visible here. Confirm plot_ou_tau expects this pairing.
ax1 = plot_ou_tau(heat_no_adap, sd_no_adap, size_labels=size_labels,
                  main_title=title1, adap=False, save='fig4_no_adap')
ax2 = plot_ou_tau(heat_sub_adap, sd_sub_adap, size_labels=size_labels,
                  main_title=title3, adap=True, save='fig4_sub_adap')

# Time-course figure for the three noise types.
ax = plot_timecourse([ou_time, pink_time, white_time], 28, 28, 'fig_model')

In [None]:
# Statistical tests on the masked metric distributions of the three
# adaptation types (helpers from the project's statistical_analysis module).

# Normality and equal-variance checks; the results are stored but not
# displayed in this cell.
norm = normality(hist_no_adap_mask, hist_div_adap_mask, hist_sub_adap_mask)
var = equal_variance(hist_no_adap_mask, hist_div_adap_mask, hist_sub_adap_mask)

# One comparison per metric, computed in the order DomDur, CV, MixDur.
c1, c2, c3 = (
    comparison(metric, hist_no_adap_mask, hist_div_adap_mask, hist_sub_adap_mask)
    for metric in ('DomDur', 'CV', 'MixDur')
)
display(c1, c2, c3)