In [None]:
import datajoint as dj
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Create the DataJoint schema object for the in-silico analysis database and
# pass the current namespace (locals()) as its second argument.
schema = dj.schema("mburg_dnrebuttal_insilico", locals())
# Store the same schema name under a custom config key.
# NOTE(review): this is set before the divisivenormalization imports below,
# presumably because that package reads the key when declaring its tables --
# confirm before reordering.
dj.config["schema_name"] = "mburg_dnrebuttal_insilico"
# DataJoint display / blob-handling options.
dj.config["display.limit"] = 20
dj.config["enable_python_native_blobs"] = True

import divisivenormalization.utils as helpers
from divisivenormalization.analysis import compute_confidence_interval
from divisivenormalization.insilico import (
    get_best_run_nums,
    CenterSurroundResponses,
    OptimalGabor,
    CenterSurroundParams,
    CenterSurroundStats,
)

# Project-specific IPython / notebook setup.
helpers.config_ipython()

# Seaborn defaults first, then the "ticks" axes style.
sns.set()
sns.set_style("ticks")

# One font size shared by all text elements of the figures.
font_size = 8
rc_dict = {
    "font.size": font_size,
    "axes.titlesize": font_size,
    "axes.labelsize": font_size,
    "xtick.labelsize": font_size,
    "ytick.labelsize": font_size,
    "legend.fontsize": font_size,
    "figure.figsize": (helpers.cm2inch(8), helpers.cm2inch(8)),
    "figure.dpi": 300,
    "pdf.fonttype": 42,  # embed TrueType fonts in exported PDFs
    "savefig.transparent": True,
    # The matplotlib rcParam is "savefig.bbox"; "savefig.bbox_inches" is only
    # the keyword argument name of savefig() and is not a valid rc key.
    "savefig.bbox": "tight",
}
# set_context() only honours keys that are part of seaborn's context
# definition (font sizes, line widths, ...) and silently drops everything
# else, so it takes care of the font sizes only.
sns.set_context("paper", rc=rc_dict)
# Apply the remaining figure / export settings directly to matplotlib.
# This has to come after set_context() so the values are not overwritten.
plt.rcParams.update(rc_dict)


class args:
    """Notebook-wide constants shared by the analysis cells below."""

    # Number of entries of get_best_run_nums(...) the summary cells iterate over.
    num_best = 10
    # Canvas size passed to CenterSurroundParams().center_set.
    # NOTE(review): presumably [height, width] in pixels -- confirm.
    canvas_size = [40, 40]
    # Bar color of the suppression-index histogram.
    blue = "xkcd:Blue"
    # Line color of the mean size-tuning curve, obtained from
    # helpers.darken_color(3, 67, 223, 0.8).
    # NOTE(review): (3, 67, 223) looks like the RGB triplet of xkcd blue and
    # 0.8 like the darkening factor -- confirm against helpers.darken_color.
    dark_blue = helpers.darken_color(3, 67, 223, 0.8)



 ### Size-tuning curves

In [None]:
# Restrict every query in this cell to the first entry returned by
# get_best_run_nums for the DN model.
model_type = "dn"
run_no = get_best_run_nums(model_type)[0]
constraint = dict(model_type=model_type, run_no=run_no)

# Center tuning curve of each unit, ordered by unit_id.
unit_responses = CenterSurroundResponses().Unit() & constraint
tc_center = unit_responses.fetch("tuning_curve_center", order_by="unit_id")
tc_lst = list(tc_center)

# Average across units, then rescale the mean curve so that it peaks at 1.
means = np.mean(tc_lst, axis=0)
means_norm = means / np.max(means)

# Stimulus sizes for the x-axis, derived from the stimulus set of unit 0.
key = (OptimalGabor.Unit() & constraint & dict(unit_id=0)).fetch1(dj.key)
loc, _, sf, _, ori, phase = OptimalGabor.Unit().params(key)
center_set = CenterSurroundParams().center_set(key, args.canvas_size, loc, sf, ori, phase)
center_sizes = np.array(center_set.sizes_center)
sizes_px = center_sizes * 2 * center_set.sizes_total[0]
# Pixels -> degrees of visual field at 35 pixels per degree.
sizes = sizes_px / 35
# The first size was -0.01 internally for stimulus generation; on the x-axis
# it has to sit at its true value of 0.
sizes[0] = 0

panel_side = helpers.cm2inch(6)
plt.figure(figsize=(panel_side, panel_side))
# Thin, transparent line per unit, each scaled to its own maximum ...
for curve in tc_lst:
    plt.plot(sizes, curve / np.max(curve), color="xkcd:blue", alpha=0.15)
# ... and the normalized population mean drawn on top.
plt.plot(sizes, means_norm, linewidth=2, markersize=4, color=args.dark_blue, linestyle="-")

# Label only a subset of the 16 tick positions; the rest stay blank.
labelled_ticks = {0: 0, 6: 0.5, 9: 0.9, 12: 1.8, 14: 2.8, 15: 3.4}
sizes_labels = [labelled_ticks.get(idx, "") for idx in range(16)]
plt.xticks(ticks=sizes, labels=sizes_labels)
plt.xlabel("Stimulus diameter (deg)")
plt.ylabel("Prediction (normalized)")
sns.despine(trim=True, offset=5)
plt.tight_layout()
plt.show()



 ### Suppression indices for top 10 models on validation set

In [None]:
# Collect the suppression index of every unit for each of the
# args.num_best best DN runs.
model_type = "dn"
# The list of best runs does not change inside the loop: query it once.
best_run_nums = get_best_run_nums(model_type)
si_lst = []
for best_idx in range(args.num_best):
    run_no = best_run_nums[best_idx]
    key = dict(model_type=model_type, run_no=run_no)
    si = (CenterSurroundStats.Unit() & key).fetch("suppression_index", order_by="unit_id")
    si_lst.append(np.array(si))
# One row per model.
# NOTE(review): assumes every run has the same number of units -- otherwise
# this does not form a rectangular array.
sis = np.array(si_lst)

# Grand mean over all units and models; the confidence interval is computed
# from the per-model means.
model_mean = np.mean(sis, 1)
mean = np.mean(sis)
conf_int = compute_confidence_interval(model_mean)
print("Mean suppression index", np.round(mean, 3))
print("Confidence interval", np.round(conf_int, 3))
print("Plus/minus", np.round(mean - conf_int[0], 3))

# Histogram over all units of all models. The weights sum to 1, so bar
# heights are fractions of the total count (not percentages).
plt.figure(figsize=(helpers.cm2inch(8), helpers.cm2inch(8 / 8 * 6)))
bins = np.arange(0, 1 + 0.05, 0.05)
si_flat = sis.flatten()
norm_weights = 1 / len(si_flat) * np.ones_like(si_flat)
plt.hist(si_flat, bins=bins, weights=norm_weights, color=args.blue, edgecolor="w", linewidth=0)

plt.xlim(left=0)
plt.yticks(np.arange(0, 0.9 + 0.3, 0.3))
plt.xlabel("Suppression index")
# The axis shows fractions in [0, 1]; the former "(%)" suffix was misleading.
plt.ylabel("Proportion")
sns.despine(trim=True, offset=5)
plt.tight_layout()
plt.show()



 ### Grating summation field (GSF) diameter across the best ten DN models on validation set

In [None]:
# Grating summation field (GSF) sizes of all units for each of the
# args.num_best best DN runs.
model_type = "dn"
pixels_per_degree = 35  # stimulus resolution (35 ppd)
# The list of best runs does not change inside the loop: query it once.
best_run_nums = get_best_run_nums(model_type)
gsf_lst = []
gsf_global_max_lst = []
for best_idx in range(args.num_best):
    run_no = best_run_nums[best_idx]
    key = dict(model_type=model_type, run_no=run_no)
    # Fetch both attributes in one query so the two arrays refer to the same
    # rows in the same order; order by unit_id like the other cells do.
    gsfs, gsfs_global_max = (CenterSurroundStats.Unit() & key).fetch(
        "gsf_pixel", "gsf_global_max_pixel", order_by="unit_id"
    )
    gsf_lst.append(np.array(gsfs))
    gsf_global_max_lst.append(np.array(gsfs_global_max))
# Diameter instead of radius, converted from pixels to degrees of visual field.
gsf_arr = 2 * np.array(gsf_lst) / pixels_per_degree
gsf_global_max_arr = 2 * np.array(gsf_global_max_lst) / pixels_per_degree

print("Mean grating summation field diameter:", np.round(gsf_arr.mean(), 2), "deg")

