In [None]:
import datajoint as dj
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Activate the DataJoint schema that holds this notebook's in-silico experiment tables.
# NOTE(review): ``locals()`` is passed as the schema context, presumably so table classes
# referenced by name resolve in this notebook's namespace — confirm.
schema = dj.schema("mburg_dnrebuttal_insilico", locals())
# NOTE(review): "schema_name" does not appear to be a standard DataJoint setting; it is
# presumably read by divisivenormalization.insilico, which is imported after this line,
# so the assignment must stay before that import — confirm.
dj.config["schema_name"] = "mburg_dnrebuttal_insilico"
# Show at most 20 rows when previewing query results.
dj.config["display.limit"] = 20
# Allow native Python objects to be stored in / fetched from blob attributes.
dj.config["enable_python_native_blobs"] = True

import divisivenormalization.utils as helpers
from divisivenormalization.analysis import compute_confidence_interval
from divisivenormalization.insilico import *

helpers.config_ipython()

# Global figure styling: seaborn defaults with the "ticks" style, a uniform
# 8 pt font for every text element, 8 x 8 cm figures (via helpers.cm2inch) at
# 300 dpi, TrueType (Type 42) fonts in PDFs, and tight, transparent saving.
sns.set()
sns.set_style("ticks")
font_size = 8
rc_dict = dict.fromkeys(
    (
        "font.size",
        "axes.titlesize",
        "axes.labelsize",
        "xtick.labelsize",
        "ytick.labelsize",
        "legend.fontsize",
    ),
    font_size,
)
rc_dict.update(
    {
        "figure.figsize": (helpers.cm2inch(8), helpers.cm2inch(8)),
        "figure.dpi": 300,
        "pdf.fonttype": 42,
        "savefig.transparent": True,
        "savefig.bbox_inches": "tight",
    }
)
sns.set_context("paper", rc=rc_dict)


class args:
    """Notebook-wide settings shared by the plotting and statistics cells.

    Accessed as a plain namespace via class attributes (``args.num_best`` etc.).
    """

    # Number of best-ranked runs per model type that are pooled
    # (indexes into ``get_best_run_nums(model_type)``).
    num_best = 10
    # ``unit_id`` of the example unit whose tuning curves are plotted.
    sample_neuron_idx = 107
    # NOTE(review): presumably the stimulus canvas size in pixels — confirm.
    canvas_size = [40, 40]

    # Plot colors. (3, 67, 223) appears to be the RGB triple of xkcd:Blue
    # (#0343df); ``helpers.darken_color`` presumably darkens it by factor 0.8.
    blue = "xkcd:Blue"
    dark_blue = helpers.darken_color(3, 67, 223, 0.8)



 ### Plot tuning curves for an example neuron and histograms of the inhibition index

In [None]:
def subplot_tc(ax, model_type, unit_id, xscale=None, more_labels=False, max_line=False):
    """Plot the normalized orthogonal-plaid tuning curves of one unit on ``ax``.

    The unit's tuning curves are fetched for the best run of ``model_type``,
    averaged over the fetched rows (phases) and normalized to a maximum of 1.
    One line is drawn per column of the averaged tuning curve, colored from a
    flipped ``gist_earth`` colormap, against the orthogonal mask contrast in
    percent (the last contrast level is mapped to 50 %).

    Args:
        ax: Matplotlib axes to draw into.
        model_type: Model type key, e.g. "dn", "cnn3" or "subunit".
        unit_id: ``unit_id`` of the unit to plot.
        xscale: Matplotlib x-axis scale. ``None`` (default) means "linear";
            "log" is not implemented.
        more_labels: If true, label every x tick from the fifth onwards;
            otherwise label only the first, sixth and eighth-onwards ticks.
        max_line: If true, draw a dashed horizontal line at 1 (the maximum).

    Raises:
        NotImplementedError: If ``xscale == "log"``.
        ValueError: If no tuning curves are stored for the requested unit.
    """
    # matplotlib rejects ``set_xscale(None)``, so map the default to "linear".
    if xscale is None:
        xscale = "linear"
    # Fail before querying the database or drawing anything.
    if xscale == "log":
        raise NotImplementedError("A log-scaled x axis is not supported.")

    contrasts = OrthPlaidsContrastParams().contrasts(dict())
    # Contrasts in percent, with the last contrast level mapped to 50 %.
    contrasts_perc = contrasts / contrasts[-1] * 50

    # Sample one extra color, flip the colormap and drop the color sampled at
    # 1.0, leaving one color per contrast level.
    num_colors = contrasts.shape[0] + 1
    colors = plt.cm.gist_earth(np.linspace(0, 1, num_colors))
    colors = np.flipud(colors)[1:]

    best_run_no = get_best_run_nums(model_type)[0]
    constraint = dict(model_type=model_type, run_no=best_run_no, unit_id=unit_id)
    tuning_curves = (OrthPlaidsContrast.Unit() & constraint).fetch("tuning_curve", order_by="unit_id")
    if len(tuning_curves) == 0:
        raise ValueError(f"No tuning curves found for model_type={model_type!r}, unit_id={unit_id!r}.")
    t = np.array([np.array(a) for a in tuning_curves]).mean(axis=0)  # mean across phases
    t = t / np.max(t)

    for ti, ci in zip(t.T, colors):
        ax.plot(contrasts_perc, ti, marker=".", linestyle="-", color=ci)

    ax.set_xscale(xscale)
    ax.set_xlabel("Orth. mask contrast (%)")
    ax.set_ylabel("Prediction (normalized)")
    ax.set_yticks(np.arange(0, 1.2, 0.2))

    contr_perc_labels = np.round(contrasts_perc, 0).astype(int)
    ax.set_xticks(contr_perc_labels)
    if more_labels:
        xticklabels = ["0", "", "", ""] + [str(c) for c in contr_perc_labels[4:]]
    else:
        # Wrap the single label in a list: ``list(<str>)`` would split a
        # multi-digit label such as "13" into ["1", "3"] and misalign the labels.
        xticklabels = (
            ["0", "", "", "", ""]
            + [str(contr_perc_labels[5])]
            + [""]
            + [str(c) for c in contr_perc_labels[7:]]
        )
    ax.set_xticklabels(xticklabels)

    if max_line:
        ax.axhline(1, color="black", ls="--")

    sns.despine(ax=ax, offset=5, trim=True)


def subplot_hist(ax, model_type, bin_width=10, color="grey", distplot=True):
    """Plot the distribution of the inhibition index of all units on ``ax``.

    Phase-averaged inhibition indices are pooled over the ``args.num_best``
    best runs of ``model_type`` and converted from fractions to percent.

    Args:
        ax: Matplotlib axes to draw into.
        model_type: Model type key, e.g. "dn", "cnn3" or "subunit".
        bin_width: Histogram bin width in percentage points; bins span 0-100.
        color: Bar color. Only used when ``distplot`` is false; the seaborn
            branch keeps seaborn's default color.
        distplot: If true, draw a seaborn histogram with a KDE overlay;
            otherwise draw a plain matplotlib histogram.
    """
    # Query the run ranking once instead of once per run.
    best_run_nums = get_best_run_nums(model_type)
    inhib_percent_lst = []
    for best_idx in range(args.num_best):
        constraint = dict(model_type=model_type, run_no=best_run_nums[best_idx])
        inhib_percent = (OrthPlaidsContrastInhibPercentPhaseAvg().Unit() & constraint).fetch(
            "inhib_percent", order_by="unit_id"
        )
        inhib_percent_lst.append(inhib_percent)
    # Fractions -> percent, as a new array (no in-place mutation of fetched data).
    inhib_percents = 100 * np.hstack(inhib_percent_lst)

    bins = np.arange(0, 100 + bin_width, bin_width)
    if distplot:
        data = np.asarray(inhib_percents).ravel()
        # stat="probability" yields fractions in [0, 1]; the y ticks are relabelled in percent.
        sns.histplot(data=data, ax=ax, bins=bins, kde=True, stat="probability")
        ax.set_xlim(left=0)
        ax.set_ylim(top=1)
        ax.set_yticks(np.arange(0, 1.20, 0.20))
        ax.set_yticklabels(np.arange(0, 120, 20))
    else:
        # Weight each unit by 100 / N so bar heights are percentages of units.
        num_units = len(inhib_percents)
        norm_weights = np.full(num_units, 100 / num_units)
        ax.hist(inhib_percents, bins=bins, weights=norm_weights, color=color)
        ax.set_yticks(np.arange(0, 120, 20))

    # Reference line at 10 %, the inhibition threshold used in the statistics below.
    ax.axvline(10, color="black", ls=":")

    ax.set_xlabel("Inhibition index (%)")
    ax.set_ylabel("Fraction of units (%)")
    sns.despine(ax=ax, offset=5, trim=True)


# 2 x 3 panel figure: top row shows the example unit's tuning curves, bottom row
# the inhibition-index histograms, one column per model type.
ncol = 3
nrow = 2
# NOTE(review): 15.8 cm is presumably the target text width — confirm.
fig_width = helpers.cm2inch(15.8)
# Name the figure explicitly: assigning to ``_`` would shadow IPython's
# last-output variable for the rest of the session.
fig, axes = plt.subplots(nrow, ncol, figsize=(fig_width, nrow / ncol * fig_width))

uid = args.sample_neuron_idx
xscale = "linear"
bin_width = 10
for col, panel_model_type in enumerate(["dn", "cnn3", "subunit"]):
    subplot_tc(axes[0, col], panel_model_type, uid, xscale)
    subplot_hist(axes[1, col], panel_model_type, bin_width)

fig.tight_layout()
plt.show()



 ### Compute inhibition-index statistics across model types for the 10 best models (ranked on the validation set)

In [None]:
# Phase-averaged inhibition indices (fractions), one array per best run, per model type.
inhib_percent_dict = {}
# NOTE(review): fev_dict is initialized with empty lists but never filled here;
# kept in case later cells read it.
fev_dict = {}

for model_type in ["subunit", "dn", "cnn3"]:
    # Query the run ranking once per model type rather than once per run.
    best_run_nums = get_best_run_nums(model_type)
    inhib_percent_dict[model_type] = []
    fev_dict[model_type] = []
    for best_idx in range(args.num_best):
        constraint = dict(model_type=model_type, run_no=best_run_nums[best_idx])
        inhib_percent = (OrthPlaidsContrastInhibPercentPhaseAvg().Unit() & constraint).fetch(
            "inhib_percent", order_by="unit_id"
        )
        inhib_percent_dict[model_type].append(inhib_percent)

# Inhibition-index threshold, stored as a fraction: 0.1 == 10 %.
crit_percent = 0.1
for model_name, inhib_values in inhib_percent_dict.items():

    # Per run: fraction of units whose inhibition index is at least the threshold.
    # Dividing by len() keeps a ZeroDivisionError for a run without units.
    inhib_above_perc_lst = [
        np.count_nonzero(np.asarray(inhib) >= crit_percent) / len(inhib) for inhib in inhib_values
    ]

    print(model_name)
    mean = np.mean(inhib_above_perc_lst)
    conf_int = np.asarray(compute_confidence_interval(inhib_above_perc_lst))
    # The statistics are fractions; report them in percent to match the labels.
    print(
        f"Percentage of cells inhibited by at least {crit_percent:.0%} (mean across models):",
        np.round(100 * mean, 1),
    )
    print("Confidence interval (%):", np.round(100 * conf_int, 1))
    print("Plus/minus (%):", np.round(100 * (mean - conf_int[0]), 1))
    print()
