# Imports

Fetch and summarize data in Python

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

from alphaanalysis.plot.gam import get_sig_intervals

# Schemas

In [None]:
# Star-import the dataset schema module. Names used below that are not defined
# in this notebook (connect_to_database, paths, dataset_schema, DataSet, ...)
# presumably come from this import — verify against the module.
from alphacnn.database.dataset_schema import *

# Connect to the '<prefix>dataset' schema, asking for neither the schema nor
# its tables to be created.
connect_to_database(
    dj_config_file=paths.CONFIG_FILE,
    create_tables=False, create_schema=False, schema_name=paths.SCHEMA_PREFIX + 'dataset')
dataset_schema

In [None]:
# Star-import the decoder schema module (presumably provides
# PresDecoderPrediction and pres_decoder_schema — verify against the module).
from alphacnn.database.pres_decoder_schema import *

# Connect to the '<prefix>decoder' schema.
# NOTE(review): unlike the dataset connection, this passes create_tables=True
# and create_schema=True — confirm that an analysis notebook is meant to be
# allowed to create the schema and tables.
connect_to_database(
    dj_config_file=paths.CONFIG_FILE,
    create_tables=True, create_schema=True, schema_name=paths.SCHEMA_PREFIX + 'decoder')
pres_decoder_schema

In [None]:
# Distinct data_set_file values that have decoder predictions.
list(np.unique(PresDecoderPrediction().fetch('data_set_file')))

In [None]:
# Show DataSet() as the cell output.
DataSet()

# Text information

In [None]:
# Frame counts per split for one data set / decoder combination; the count is
# the length of the first axis of the fetched 'd' attribute.
restriction = dict(
    data_set_file='dataset_f002_f003_rot_1975_w_and_wo_nsl_bcnsmed.pkl',
    decoder_id='cnn_ensemble_10',
)
n_frames = [
    (PresDecoderPrediction() & dict(restriction, split_kind=kind)).fetch1('d').shape[0]
    for kind in ('train', 'dev', 'test')
]
total_frames = np.sum(n_frames)

print(n_frames)
print(total_frames)
# NOTE(review): the division by 60 presumably converts frames to seconds at a
# 60 Hz frame rate — confirm.
print(total_frames / 60)
# Share of each split in percent, rounded to one decimal.
print(np.around(100 * (np.array(n_frames) / total_frames), 1))

# Plot

In [None]:
from alphaanalysis import plot as plota
# Apply the project's plotting defaults for the 'paper' style.
plota.set_default_params(kind='paper')


# Name of the folder that figures are written to.
FIGURE_FOLDER = 'figures'

In [None]:
# for key in (PresDecoderPrediction() & dict(split_kind='train')).proj().fetch(as_dict=True):
#     print(key)
#     PresDecoderPrediction().plot_loss(**key)
#     plt.show()

In [None]:
# for key in (PresDecoderPrediction()  & dict(split_kind='train')).proj().fetch(as_dict=True):
#     print(key)
#     PresDecoderPrediction().plot(**key)
#     plt.show()

# Data

In [None]:
# Show the prediction entries of the 'cnn_ensemble_10' decoder.
PresDecoderPrediction() & dict(decoder_id='cnn_ensemble_10')

In [None]:
# Keys of all 'train' rows, with split_kind projected to the name `train`.
# Presumably this yields one key per data set / split id / decoder combination
# that no longer pins split_kind when used as a restriction — TODO confirm
# against the table definition.
pdp_keys = (PresDecoderPrediction()  & dict(split_kind='train')).proj(train='split_kind').fetch(as_dict=True)

# Per-split frames are collected here and concatenated once after the loop.
split_frames = []

for pdp_key in pdp_keys:
    print(pdp_key)
    split_kinds = (PresDecoderPrediction & pdp_key).fetch('split_kind')

    # Video geometry and normalization constants are restricted by pdp_key
    # only, so they are fetched once per key instead of once per split.
    pixel_size_video, video_width, video_height = (DataSet & pdp_key).fetch1('pixel_size_um', 'video_width', 'video_height')
    # NOTE(review): y_center is fetched but never used below — confirm that y
    # needs no re-centering before it is scaled.
    y_center, y_scale, d_min, d_max = (DataNorm & pdp_key).fetch1('y_center', 'y_scale', 'd_min', 'd_max')
    # Half of the video height and width in um.
    # NOTE(review): assumes the two columns of y are ordered (height, width)
    # like this array — confirm.
    half_extent_um = np.array([(video_height * pixel_size_video) / 2, (video_width * pixel_size_video) / 2])

    for split_kind in split_kinds:
        y, d, p, p_pred, keys, key_idx = (PresDecoderPrediction & pdp_key & dict(split_kind=split_kind)).fetch1(
            'y', 'd', 'p', 'p_pred', 'keys', 'key_idx')

        # Rescale the stored position and distance values using the fetched
        # normalization constants.
        y_um = y * (y_scale * 0.5) * pixel_size_video
        d_cm = d * (d_max - d_min) + d_min

        # The evaluation below relies on the two classes being exactly balanced.
        assert np.mean(p) == 0.5

        # NOTE(review): this is the root-*mean*-square of the coordinates, i.e.
        # the Euclidean distance divided by sqrt(n_coordinates) — confirm that
        # mean (rather than sum) is intended.
        dist_to_center_um = np.mean((y_um**2), axis=1)**0.5
        # Smallest remaining margin to the video edge over the coordinates.
        dist_to_border_um = np.min(half_extent_um - np.abs(y_um), axis=1)

        df_split = pd.DataFrame({
            'd (cricket) [cm]': d_cm,
            'd (center) [um]': dist_to_center_um,
            'd (border) [um]': dist_to_border_um,
            # A frame counts as correct when prediction and label fall on the
            # same side of 0.5.
            'accuracy': (p>=0.5) == (p_pred>=0.5),
        })

        df_split['split_kind'] = split_kind
        df_split['data_set_file'] = pdp_key['data_set_file']
        df_split['split_id'] = pdp_key['split_id']

        split_frames.append(df_split)

# Concatenate once: calling pd.concat inside the loop re-copies every row
# accumulated so far on each iteration, and seeding the result with an empty
# DataFrame makes the resulting dtypes depend on how pandas treats empty
# entries. The fallback keeps df_all an empty frame when there are no keys.
df_all = pd.concat(split_frames) if split_frames else pd.DataFrame()

In [None]:
# Show the combined per-frame table.
df_all

In [None]:
# Distance to center against distance to border, one small marker per row;
# the trailing semicolon suppresses the returned Axes in the cell output.
sns.scatterplot(data=df_all, x='d (center) [um]', y='d (border) [um]', s=0.2);

In [None]:
def _bc_noise_tag(file_name):
    """Return the text after 'bcns', cut at the first '.' and then the first '_'."""
    after_marker = file_name.split('bcns')[1]
    without_extension = after_marker.split('.')[0]
    return without_extension.split('_')[0]


df_all['bc_noise'] = df_all.data_set_file.apply(_bc_noise_tag)
df_all['bc_noise'].unique()

In [None]:
# Numeric value assigned to each bc_noise tag; a tag missing from this map
# raises KeyError below.
bc_noise_map = {'med': 0.1}

df_all['bc_noise_num'] = df_all.bc_noise.apply(bc_noise_map.__getitem__)
df_all['bc_noise_num'].unique()

In [None]:
def _has_pr_noise(file_name):
    """Return 1 if the file name contains 'prnoise', else 0."""
    return int('prnoise' in file_name)


df_all['pr_noise'] = df_all.data_set_file.apply(_has_pr_noise)
df_all['pr_noise'].unique()

In [None]:
# Combined noise label: '<bc_noise_num> <pr_noise>' as text.
bc_noise_text = df_all['bc_noise_num'].astype(str)
pr_noise_text = df_all['pr_noise'].astype(str)
df_all['noise'] = bc_noise_text + ' ' + pr_noise_text
df_all['noise'].unique()

In [None]:
def _rgc_tag(file_name):
    """Return the part of the file name between 'w_and_wo_' and '_bcns'."""
    after_prefix = file_name.split('w_and_wo_')[1]
    return after_prefix.split('_bcns')[0]


df_all['rgc'] = df_all.data_set_file.apply(_rgc_tag)
df_all['rgc'].unique()

In [None]:
# Order rows by condition with a fresh 0..n-1 index, then let pandas pick
# concrete dtypes for any object columns.
sort_columns = ['rgc', 'bc_noise_num', 'pr_noise']
df_all = df_all.sort_values(sort_columns, ignore_index=True).infer_objects()
df_all.head()

## Remove border cases

In [None]:
import seaborn as sns
sns.set_theme(style="ticks")

# Use only the test split of one condition: the first unique value of each of
# rgc, bc_noise, split_id and pr_noise.
is_test = df_all.split_kind == 'test'
is_first_condition = (
    (df_all.rgc == df_all.rgc.unique()[0])
    & (df_all.bc_noise == df_all.bc_noise.unique()[0])
    & (df_all.split_id == df_all.split_id.unique()[0])
    & (df_all.pr_noise == df_all.pr_noise.unique()[0])
)
df_joint = df_all[is_test & is_first_condition]

# Joint 2D histogram of cricket distance against border distance, with
# step-style marginal histograms and a colorbar on an extra axes.
p = sns.JointGrid(
    data=df_joint,
    x='d (cricket) [cm]',
    y='d (border) [um]',
    marginal_ticks=True,
)
cax = p.figure.add_axes([.9, .8, .02, .2])

joint_kwargs = dict(
    discrete=(False, False),
    cmap="light:#03012d",
    pmax=.8,
    cbar=True,
    cbar_ax=cax,
)
p.plot_joint(sns.histplot, **joint_kwargs)
p.plot_marginals(sns.histplot, element="step", color="#03012d")
plt.show()

In [None]:
min_dist_border = 200
bc_noise = "med"
pr_noise = 0

# Keep test-split rows at least min_dist_border away from the border, then
# restrict to the four rgc variants and the chosen noise condition.
# reset_index() without drop keeps the previous index as an 'index' column.
is_test = df_all['split_kind'] == 'test'
far_from_border = df_all['d (border) [um]'] >= min_dist_border
df = (
    df_all[is_test & far_from_border]
    .reset_index()
    .drop(['split_kind'], axis=1)
    .query(f'((rgc=="nsl") | (rgc=="tmp") | (rgc=="tmp_ss") | (rgc=="tmp_ws")) & (bc_noise=="{bc_noise}") & (pr_noise=={pr_noise})')
    # Short column names for use in the R model formulas.
    .rename({'d (cricket) [cm]': "distance", 'd (center) [um]': "distance_center"}, axis=1)
)
df.head()

# Summary

In [None]:
from alphaanalysis import plot as plota

# Re-apply the project's 'paper' plotting defaults.
plota.set_default_params(kind='paper')

In [None]:
# Mean accuracy per (rgc, noise) group, as a long table.
mean_accuracy = df.groupby(['rgc', 'noise'])['accuracy'].mean()
df_means = mean_accuracy.reset_index().rename(columns={'accuracy': 'Mean Accuracy'})

# Wide layout for the heatmap: one row per noise label, one column per rgc.
accuracy_grid = df_means.pivot(columns="rgc", index="noise", values="Mean Accuracy")

fig, ax = plt.subplots(1, 1, figsize=(12, 4))
sns.heatmap(accuracy_grid, square=True, cmap='viridis')
plt.show()

# Fit GAM in R

In [None]:
%load_ext rpy2.ipython

In [None]:
%%R

# Remove every object returned by ls() so the R session starts clean.
rm(list=ls()) 

In [None]:
%%R

library("IRdisplay")
library("dplyr")
library("parallel")
library("ggplot2")
library("nlme")
library("mgcv")
library("ggthemes")
library("itsadug")
library("png")
library("xtable")
library("tidymv")
library("cowplot")

In [None]:
%%R -i df

# Cast the imported columns: grouping columns become factors, the response
# becomes logical.
for (factor_col in c("rgc", "bc_noise")) {
    df[[factor_col]] <- factor(df[[factor_col]])
}
df[["accuracy"]] <- as.logical(df[["accuracy"]])
head(df)

## GAMs

### Accuracy vs. distance

In [None]:
%%R

# Binomial GAMs of accuracy over distance with rgc as a parametric term.
# m1: one smooth of distance shared by all rgc levels (k=4).
# m2-m4: a separate smooth of distance per rgc level, with k = 4, 8 and 12.
# All smooths use the "cr" basis.
m1 <- gam(accuracy ~ rgc + s(distance, k=4, bs="cr"), data=df, family=binomial)
m2 <- gam(accuracy ~ rgc + s(distance, by=rgc, k=4, bs="cr"), data=df, family=binomial)
m3 <- gam(accuracy ~ rgc + s(distance, by=rgc, k=8, bs="cr"), data=df, family=binomial)
m4 <- gam(accuracy ~ rgc + s(distance, by=rgc, k=12, bs="cr"), data=df, family=binomial)

# Compare the four fits by AIC.
AIC(m1, m2, m3, m4)

In [None]:
%%R

# Chi-squared comparison of the four fitted models.
anova.gam(m1, m2, m3, m4, test = "Chisq")

In [None]:
%%R

options(repr.plot.width = 7, repr.plot.height = 7)
# Model used for all plots and predictions below (per-rgc smooths, k=8).
best_m <- m3
summary(best_m)
# Diagnostics for the selected model.
gam.check(best_m)

In [None]:
%%R

options(repr.plot.width = 7, repr.plot.height = 5)
# Fitted smooths over distance, compared across rgc levels.
p1 <- plot_smooths(model=best_m, series=distance, comparison=rgc)
# NOTE(review): if p1 is a ggplot object, the ylim argument is probably
# ignored by plot() here — confirm.
plot(p1, ylim=c(0, 60))

In [None]:
%%R

options(repr.plot.width = 15, repr.plot.height = 5)
# One row of four panels for the default plot of the selected model.
par(mfrow = c(1,4))
plot(best_m)

In [None]:
%%R -o df_pred

# Model predictions; the columns used later in Python are 'distance', 'rgc',
# 'fit' and 'se.fit'.
df_pred = predict_gam(best_m)
head(df_pred)

In [None]:
%%R

# Predicted fit over distance, one colored line per rgc level, drawn with
# geom_smooth_ci().
ggplot(predict_gam(best_m), aes(distance, fit, col=rgc)) +
    geom_smooth_ci()

## Plot differences 

In [None]:
%%R -w 700 -h 250 -o tn_diff -o tw_diff -o ts_diff -o se

options(repr.plot.width = 15, repr.plot.height = 5)
par(mfrow=c(1,3), cex=1.0, tcl=-0.2)

# Interval multiplier for the difference plots: the normal quantile for a
# two-sided 5% level split across n_tests comparisons (2.5% / n_tests per
# tail). Despite its name, `se` is this multiplier, not a standard error.
n_tests <- 3
se <- qnorm((100 - (2.5/n_tests)) / 100)
print(se)
ylim <- NULL#c(-150, 150)
view <- "distance"
ylab <- 'Difference [logits]'
# `distance` is the cricket distance in cm (renamed from 'd (cricket) [cm]'
# on the Python side), so the axis unit is cm, not um.
xlab <- 'Distance [cm]'

# Pairwise differences of the 'tmp' smooth against each of the other rgc
# levels; the returned objects are exported to Python.
tn_diff <- plot_diff(best_m, comp=list(rgc=c("tmp", "nsl")),    main=expression('t vs. n'),       view=view, se=se, ylab=ylab, xlab=xlab, ylim=ylim, hide.label=TRUE)
tw_diff <- plot_diff(best_m, comp=list(rgc=c("tmp", "tmp_ws")), main=expression('t vs. t'['WS']), view=view, se=se, ylab='', xlab=xlab, ylim=ylim, hide.label=TRUE)
ts_diff <- plot_diff(best_m, comp=list(rgc=c("tmp", "tmp_ss")), main=expression('t vs. t'['SS']), view=view, se=se, ylab='', xlab=xlab, ylim=ylim)

In [None]:
%%R

# Difference between the 'tmp' and 'nsl' smooths on its own. plot_diff needs
# two levels in `comp`; the second level is taken from the 't vs. n' title.
# Relies on view, se, ylab, xlab and ylim set in the previous R cell.
plot_diff(best_m, comp=list(rgc=c("tmp", "nsl")),    main=expression('t vs. n'),       view=view, se=se, ylab=ylab, xlab=xlab, ylim=ylim, hide.label=TRUE)

# Go back to Python for consistent plots

In [None]:
# Pair each comparison label with the difference object exported from R, then
# extract the significant intervals along 'distance' for each pair.
labelled_diffs = (
    (r"t$_\mathrm{mi}$ vs. n$_\mathrm{mi}$", tn_diff),
    (r"t$_\mathrm{mi}$ vs. t$_\mathrm{wi}$", tw_diff),
    (r"t$_\mathrm{mi}$ vs. t$_\mathrm{si}$", ts_diff),
)
pairs_sig_regions = [
    (pair_label, get_sig_intervals(diff, x='distance'))
    for pair_label, diff in labelled_diffs
]
pairs_sig_regions

In [None]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns

In [None]:
from alphaanalysis import plot as plota

# Re-apply the project's 'paper' plotting defaults before the final figure.
plota.set_default_params(kind='paper')

In [None]:
# First rows of the predictions exported from R.
df_pred.head()

In [None]:
# Project color palette for the 'calcium' indicator; indexed by 't' and 'n' below.
palette = plota.get_palette(indicator='calcium')

In [None]:
from scipy.special import expit

# Map the fit and its interval bounds through expit (the logistic function).
# The bounds are built before the transform, as fit -/+ se[0] * se.fit, where
# se[0] is the multiplier exported from R.
z_crit = se[0]
fit_raw = df_pred["fit"]
half_width = z_crit * df_pred["se.fit"]

df_pred_acc = df_pred.assign(**{
    'lb.fit': (fit_raw - half_width).apply(expit),
    'ub.fit': (fit_raw + half_width).apply(expit),
    'fit': fit_raw.apply(expit),
})

In [None]:
import os

from alphaanalysis.plot.gam import plot_gam_fits

# Final figure: fitted accuracy over distance per rgc (top axes) with the
# significant-difference regions of the pairwise comparisons (bottom axes).
# NOTE(review): y='surround_index' is not a column created anywhere in this
# notebook (df_pred_acc has 'fit', 'lb.fit', 'ub.fit') — confirm what
# plot_gam_fits does with `y`.
# NOTE(review): f_se=2 differs from the se[0] multiplier used for
# 'lb.fit'/'ub.fit' — confirm which one the plotted band uses.
axs = plot_gam_fits(
    df_pred_acc, x='distance', y='surround_index', group='rgc', f_se=2, pairs_sig_regions=pairs_sig_regions,
    side_groups=['tmp', 'tmp_ss', 'tmp_ws', 'nsl'], colors=[palette['t'], 'black', 'gray', palette['n']],
    figsize=(2.7, 2.0), height_ratios=(4, 1))

# NOTE(review): the y-range below is 0.4-1.0 (fractions) while the label says
# [%] — confirm the intended unit.
axs[0].set(ylabel='Accuracy [%]')
axs[1].set(xlabel='Distance [cm]')

axs[0].set_ylim(0.4, 1.)
# Chance level for the balanced two-class task, drawn behind the data.
axs[0].axhline(0.5, c='dimgray', ls='--', zorder=-10)

# Push the lower panel's y tick labels outwards and left-align them.
axs[1].tick_params(pad=35, axis='y')  
axs[1].set_yticklabels(axs[1].get_yticklabels(), ha='left')

# Legend text per rgc value.
# NOTE(review): 'nsl' is labelled n_wi here but n_mi in the significance-region
# labels — confirm which subscript is correct.
label_dict = dict(
    nsl=r'n$_\mathrm{wi}$',
    tmp=r't$_\mathrm{mi}$',
    tmp_ws=r't$_\mathrm{wi}$',
    tmp_ss=r't$_\mathrm{si}$',
)
handles, labels = axs[0].get_legend_handles_labels()
axs[0].legend(handles, [label_dict.get(label, label) for label in labels], loc='upper right')

plt.tight_layout()
# Save into the configured figure folder (instead of a hard-coded path) and
# create it first so savefig does not fail on a fresh checkout.
os.makedirs(FIGURE_FOLDER, exist_ok=True)
plt.savefig(os.path.join(FIGURE_FOLDER, 'decoder_performance.pdf'))
plt.show()

In [None]:
# Interval multiplier taken from the value exported by R, kept under its own
# name. Rebinding `se` to a plain float here would break the earlier `se[0]`
# lookups when cells are re-run, and a hand-copied constant can drift from
# the R computation. np.atleast_1d accepts both an array and a scalar `se`.
z_crit = float(np.atleast_1d(se)[0])

# Per rgc: where the interval around the raw fit lies entirely above 0
# ('above baseline') or entirely below 0 ('below baseline'); 0 on the fit
# scale corresponds to expit(0) = 0.5 accuracy.
for rgc, df_pred_i in df_pred.groupby('rgc'):
    print(rgc)
    margin = df_pred_i["se.fit"] * z_crit
    plt.plot(df_pred_i['distance'], (df_pred_i["fit"] - margin) > 0, label='above baseline')
    plt.plot(df_pred_i['distance'], (df_pred_i["fit"] + margin) < 0, label='below baseline')
    plt.legend()
    plt.show()