# Model paper: Ganglion Cells in the Retina
- **Author**: Javier Cruz
- **Contact**: https://github.com/sisyphvs
- **Last Modification**: January 25, 2024
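- **Description**: Fits a two-lobe parametric model to each ganglion cell's temporal STA and derives peak, half-height (HWHH), zero-crossing and bandwidth measures from the fitted curves.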

## Introduction

### Importing Libraries

In [None]:
import sys
import numpy as np
import matplotlib.pyplot as plt
import pickle
import lmfit
from pyret import filtertools
from matplotlib.patches import Ellipse

In [None]:
# Make the directory two levels up importable so the local `scripts`
# package imported in the next cell resolves (path is relative to the
# notebook's working directory).
sys.path.append("../..")

In [None]:
from scripts import load_yaml_config, truncate_float, plot_hist, plot_hist2d

### Paths and configuration

In [None]:
# Apply matplotlib's bundled seaborn dark-grid style to every figure below.
plt.style.use("seaborn-v0_8-darkgrid")

In [None]:
configPath = "../../config/"
config = load_yaml_config(configPath + "general_config.yml")

# Each cache entry from the config is appended to "../.." (the entries
# presumably start with "/"), giving paths relative to this notebook.
DATA_RFS_PATH, S_DATA_STA_PATH, T_DATA_STA_PATH = (
    "../.." + config["paths"]["data_cache"][cache_key]
    for cache_key in ("DATA_RFS", "S_DATA_STA", "T_DATA_STA")
)

### Loading data

In [None]:
# Load the cached inputs: DATA (iterated below as the collection of cell
# keys), S_DATA (per-cell spatial STA images) and T_DATA (temporal STA
# contrast traces plus their time axis).
# NOTE: pickle.load can execute arbitrary code; only load cache files
# produced by this project.
with open(DATA_RFS_PATH, "rb") as file:
    DATA = pickle.load(file)
with open(S_DATA_STA_PATH, "rb") as file:
    S_DATA = pickle.load(file)
with open(T_DATA_STA_PATH, "rb") as file:
    T_DATA = pickle.load(file)

### Constants

In [None]:
# Acquisition frame rate (frames per second) and the duration of one frame (s).
FPS = config["params"]["FPS"]
FRAME_TIME = 1 / FPS
# Time axis of the measured temporal STA samples, as stored in the cache.
TIME = T_DATA["time"]
# Dense time axis for the fitted model curves. It must coincide with `x_fit`
# in `fit_temp` (np.linspace(0, STA_LEN - 1, STA_LEN * 10) frames, with
# STA_LEN = 26), so that sample k of a fitted curve sits at
# x_fit[k] * FRAME_TIME. An upper bound of 26 frames here would stretch every
# fitted-curve time by 26/25 relative to the frame-index times of `find_peaks`.
TIME_CONT = np.linspace(0, 26 - 1, 26 * 10) * FRAME_TIME
# Number of samples in one temporal STA, read from cell "temp_100".
FRAMES = len(T_DATA["contrast"]["temp_100"])

## Peaks

In [None]:
# find data peaks:
def find_peaks(t, std_mult=1.75, frame_time=None):
    """Locate the significant maximum and minimum of a temporal STA.

    The maximum counts as a peak only if it exceeds ``std_mult * np.std(t)``;
    the minimum only if it falls below ``-std_mult * np.std(t)``.

    Parameters
    ----------
    t : array_like
        STA samples, one per frame.
    std_mult : float
        Significance threshold in units of the standard deviation of ``t``.
    frame_time : float, optional
        Duration of one frame in seconds. Defaults to the notebook-level
        ``FRAME_TIME``.

    Returns
    -------
    dict
        ``{"max": (index, time) or None, "min": (index, time) or None}``,
        where ``index`` is the first sample attaining the extremum and
        ``time = index * frame_time``.
    """
    if frame_time is None:
        frame_time = FRAME_TIME
    # Accept plain sequences too: `t == value` on a list is a single bool.
    t = np.asarray(t)

    threshold = std_mult * np.std(t)
    t_max = np.max(t)
    t_min = np.min(t)
    peaks = {}

    if t_max > threshold:
        position = np.where(t == t_max)[0][0]
        peaks["max"] = (position, position * frame_time)
    else:
        peaks["max"] = None

    if t_min < -threshold:
        position = np.where(t == t_min)[0][0]
        peaks["min"] = (position, position * frame_time)
    else:
        peaks["min"] = None

    return peaks

In [None]:
# Significant peaks of each cell's temporal STA, taken in reversed sample order.
PEAKS_DATA = {
    cell: find_peaks(T_DATA["contrast"][cell][::-1])
    for cell in DATA
}

## Function Fitting

In [None]:
# Number of frames in a temporal STA; fit_temp fits on frame indices
# 0 .. STA_LEN - 1 and evaluates the model on STA_LEN * 10 points.
STA_LEN = 26

In [None]:
# Chi-square of the most recent fit per cell; None until fit_temp runs for it.
FUNC_RESIDUALS = dict.fromkeys(DATA)

In [None]:
def fit_temp(cell, values=(9e8, 4e10), report=False):
    """Fit a two-lobe model to a cell's temporal STA with lmfit.

    Model, with ``x`` the frame index ``0 .. STA_LEN - 1``::

        amp1 * (x/tau1)**n * exp(-n * (x/tau1 - 1))
      + amp2 * (x/tau2)**n * exp(-n * (x/tau2 - 1))

    Parameters
    ----------
    cell : key into ``T_DATA["contrast"]``.
    values : sequence of two floats
        Initial guesses for ``amp1`` and ``amp2``.
    report : bool
        If True, print lmfit's fit report.

    Returns
    -------
    tuple or int
        ``(result, x_fit, y_fit)``: the lmfit ``MinimizerResult``, a grid of
        ``STA_LEN * 10`` points on ``[0, STA_LEN - 1]`` and the fitted model
        evaluated on it. Returns ``0`` when the STA squeezes to a scalar and
        cannot be reversed (callers treat the failed unpack as a failed fit).

    Side effects
    ------------
    Stores ``result.chisqr`` in ``FUNC_RESIDUALS[cell]``.
    """
    data_amp = T_DATA["contrast"][cell].squeeze()

    try:
        # Same sample orientation as used for PEAKS_DATA.
        data_amp = data_amp[::-1]
    except IndexError:
        # A 0-d (scalar) array cannot be sliced.
        return 0

    # Empirical initial guesses for the lobe positions (in frames).
    tau1 = data_amp.argmin()
    tau2 = data_amp.argmax()

    if tau2 > 15:
        tau1 = tau1 + 2
        tau2 = tau1 + 0.01

    # If the sample at tau2 is larger in magnitude, seed from the maximum
    # first. (An out-of-range tau1 raises IndexError, which callers treat
    # as a failed fit.)
    if np.abs(data_amp[int(tau1)]) - np.abs(data_amp[int(tau2)]) < 0:
        tau1 = data_amp.argmax() + 1
        tau2 = data_amp.argmin() + 1

    tau1 = tau1 + 1
    tau2 = tau2 + 1

    x = np.linspace(0, STA_LEN - 1, STA_LEN)
    x_fit = np.linspace(0, STA_LEN - 1, STA_LEN * 10)

    def _lobe(amp, tau, n, x):
        # One gamma-like lobe; equals `amp` at x == tau.
        return amp * ((x / tau) ** n) * np.exp(-1 * n * ((x / tau) - 1))

    def fmodel(params, x):
        # Sum of the two lobes, sharing the exponent n.
        p = params.valuesdict()
        return _lobe(p["amp1"], p["tau1"], p["n"], x) + _lobe(
            p["amp2"], p["tau2"], p["n"], x
        )

    def fcn2min(params, x, data):
        # Residual vector minimized by lmfit.
        return fmodel(params, x) - data

    params = lmfit.Parameters()
    params.add("amp1", value=values[0], max=np.inf)
    params.add("amp2", value=values[1], max=np.inf)
    params.add("tau1", value=tau1, min=1, max=20)
    params.add("tau2", value=tau2, min=1, max=20)
    params.add("n", value=3, max=30)

    # Minimizer.minimize() defaults to the "leastsq" method.
    minner = lmfit.Minimizer(fcn2min, params, fcn_args=(x, data_amp))
    result = minner.minimize()

    FUNC_RESIDUALS[cell] = result.chisqr

    if report:
        lmfit.report_fit(result)

    return result, x_fit, fmodel(result.params, x_fit)

In [None]:
# Accepted fit per cell as (result, x_fit, fitted_curve); None if not fitted.
FIT_FUNC = dict.fromkeys(DATA)

### Attempt 1

In [None]:
# Largest chi-square accepted as a good fit (chosen empirically).
thrsh = 0.4  # arbitrary parameter

# Cells whose temporal STA could not be fitted acceptably in this attempt.
PROBLEMS = []

for c in DATA:
    try:
        result, x_fit, t1 = fit_temp(c, [9e7, 4e9])
    except Exception:
        # fit_temp raised, or returned 0 (not unpackable): not fitted.
        PROBLEMS.append(c)
        continue
    # Reject non-finite chi-square as well (NaN >= thrsh is False).
    if not np.isfinite(result.chisqr) or result.chisqr >= thrsh:
        PROBLEMS.append(c)  # probably not a good fitting
    else:
        FIT_FUNC[c] = result, x_fit, t1

In [None]:
# Number of cells still without an accepted fit after attempt 1.
print(len(PROBLEMS))

### Attempt 2

In [None]:
# Retry the cells rejected in attempt 1 with fit_temp's default amplitude seeds.
PROBLEMS2 = []

for c in PROBLEMS:
    try:
        result, x_fit, t1 = fit_temp(c)
    except Exception:
        # fit_temp raised, or returned 0 (not unpackable): not fitted.
        PROBLEMS2.append(c)
        continue
    # Reject non-finite chi-square as well (NaN >= thrsh is False).
    if not np.isfinite(result.chisqr) or result.chisqr >= thrsh:
        PROBLEMS2.append(c)  # probably not a good fitting
    else:
        FIT_FUNC[c] = result, x_fit, t1

In [None]:
# Number of cells still without an accepted fit after attempt 2.
print(len(PROBLEMS2))

### Attempt 3

In [None]:
# Retry the cells rejected in attempt 2 with smaller amplitude seeds.
PROBLEMS3 = []

for c in PROBLEMS2:
    try:
        result, x_fit, t1 = fit_temp(c, [9e5, 4e7])
    except Exception:
        # fit_temp raised, or returned 0 (not unpackable): not fitted.
        PROBLEMS3.append(c)
        continue
    # Reject non-finite chi-square as well (NaN >= thrsh is False).
    if not np.isfinite(result.chisqr) or result.chisqr >= thrsh:
        PROBLEMS3.append(c)  # probably not a good fitting
    else:
        FIT_FUNC[c] = result, x_fit, t1

In [None]:
# Number of cells still without an accepted fit after attempt 3.
print(len(PROBLEMS3))

### Attempt 4

In [None]:
# Retry the cells rejected in attempt 3 with even smaller amplitude seeds.
PROBLEMS4 = []

for c in PROBLEMS3:
    try:
        result, x_fit, t1 = fit_temp(c, [9e3, 4e5])
    except Exception:
        # fit_temp raised, or returned 0 (not unpackable): not fitted.
        PROBLEMS4.append(c)
        continue
    # Reject non-finite chi-square as well (NaN >= thrsh is False).
    if not np.isfinite(result.chisqr) or result.chisqr >= thrsh:
        PROBLEMS4.append(c)  # probably not a good fitting
    else:
        FIT_FUNC[c] = result, x_fit, t1

In [None]:
# Number of cells still without an accepted fit after attempt 4.
print(len(PROBLEMS4))

### Attempt 5

In [None]:
# Retry the cells rejected in attempt 4 with larger amplitude seeds.
PROBLEMS5 = []

for c in PROBLEMS4:
    try:
        result, x_fit, t1 = fit_temp(c, [9e9, 4e10])
    except Exception:
        # fit_temp raised, or returned 0 (not unpackable): not fitted.
        PROBLEMS5.append(c)
        continue
    # Reject non-finite chi-square as well (NaN >= thrsh is False).
    if not np.isfinite(result.chisqr) or result.chisqr >= thrsh:
        PROBLEMS5.append(c)  # probably not a good fitting
    else:
        FIT_FUNC[c] = result, x_fit, t1

In [None]:
# Number of cells still without an accepted fit after attempt 5
# (previously printed PROBLEMS4, i.e. the attempt-4 count, by mistake).
print(len(PROBLEMS5))

## Function Data

### Function values

In [None]:
# Fitted curve (model on the dense x_fit grid) per cell; None if not fitted.
FUNC_VALUES = {
    c: (None if FIT_FUNC[c] is None else FIT_FUNC[c][2])
    for c in DATA
}

### Function Peaks

In [None]:
# Per-cell extrema of the fitted curve; each cell gets its own fresh dict.
FUNC_PEAKS = {c: {"max": None, "min": None} for c in DATA}

In [None]:
def find_peaks_func(c):
    """Record the fitted curve's extrema for cell ``c`` in FUNC_PEAKS.

    For each of "min" and "max" that was detected in the measured STA
    (``PEAKS_DATA[c]``), store ``(time, value, index)`` of the same kind of
    extremum of the fitted curve, with the time taken from TIME_CONT.
    Cells without an accepted fit are left untouched.
    """
    fit = FIT_FUNC[c]
    if fit is None:
        return

    curve = fit[2]
    data_peaks = PEAKS_DATA[c]

    for kind, locate in (("min", np.argmin), ("max", np.argmax)):
        if data_peaks[kind] is None:
            continue
        idx = locate(curve)
        FUNC_PEAKS[c][kind] = (TIME_CONT[idx], curve[idx], idx)

In [None]:
# Fill FUNC_PEAKS for every cell (cells without a fit are skipped inside).
for c in DATA:
    find_peaks_func(c)

### Function HWHH

In [None]:
# Half-height point (time, value) per cell; None until find_func_hwhh sets it.
FUNC_HWHH = dict.fromkeys(DATA)

In [None]:
def find_func_hwhh(c, atol=1e-2):
    """Store a half-height point of cell ``c``'s fitted curve in FUNC_HWHH.

    The reference extremum is ``FUNC_PEAKS[c]["min"]`` when present,
    otherwise ``FUNC_PEAKS[c]["max"]``. Samples of the fitted curve within
    ``atol`` of half that extremum's value are collected. The nearest one on
    each side of the extremum is taken, and the result is stored as
    ``FUNC_HWHH[c] = ((t_left + t_right) / 2, half_value)``, with times
    taken from TIME_CONT.

    NOTE(review): the stored x is the midpoint of the two half-height
    crossings, not half their separation; confirm this is the intended
    "HWHH" definition.

    Nothing is stored when the cell has no fit, has no significant extremum,
    or has no half-height sample on one side of the extremum.

    Parameters
    ----------
    c : cell key.
    atol : float
        Absolute tolerance used by ``np.isclose`` to decide that a curve
        sample is at half height.
    """
    fit = FIT_FUNC[c]
    if fit is None:
        return

    curve = fit[2]
    func_peaks = FUNC_PEAKS[c]

    # The minimum takes precedence when both extrema are present.
    if func_peaks["min"] is not None:
        peak = func_peaks["min"]
    elif func_peaks["max"] is not None:
        peak = func_peaks["max"]
    else:
        # No extremum to measure from (this used to raise UnboundLocalError
        # and abort the loop over all cells).
        return

    half_value = peak[1] / 2
    peak_idx = peak[2]

    near_half = np.flatnonzero(np.isclose(curve, half_value, atol=atol))
    left = near_half[near_half < peak_idx]
    right = near_half[near_half > peak_idx]
    if left.size == 0 or right.size == 0:
        # The curve does not reach half height on both sides of the peak.
        return

    FUNC_HWHH[c] = (
        (TIME_CONT[left.max()] + TIME_CONT[right.min()]) / 2,
        half_value,
    )

In [None]:
# Fill FUNC_HWHH for every cell (cells without a fit are skipped inside).
for c in DATA:
    find_func_hwhh(c)

### Function Roots

In [None]:
# Per-cell onset ("First") and zero-crossing ("ZC") times; fresh dict per cell.
FUNC_ROOTS = {c: {"First": None, "ZC": None} for c in DATA}

In [None]:
# A sample counts as "near zero" when |value| < std_mult * std(curve).
std_mult = 0.01

for c in DATA:
    if FIT_FUNC[c] is None:
        continue
    result, x_fit, t1 = FIT_FUNC[c]

    # Index i marks a sign change between curve samples i and i + 1.
    zero_crossings = np.where(np.diff(np.signbit(t1)))[0]
    near_zero = np.where(np.abs(t1) < np.std(t1) * std_mult)[0]

    peaks = FUNC_PEAKS[c]
    has_min = peaks["min"] is not None
    has_max = peaks["max"] is not None
    if not (has_min or has_max):
        # No extremum to anchor on. Previously `value` silently kept the
        # previous cell's anchor (or was undefined for the first cell).
        continue

    # Anchor on whichever extremum comes first in time (ties go to the max).
    # Peaks are (time, value, index) tuples, so tuple order is time order.
    if has_min and (not has_max or peaks["min"] < peaks["max"]):
        value = peaks["min"][2]
    else:
        value = peaks["max"][2]

    # First zero crossing after the anchor.
    points_after = zero_crossings[zero_crossings > value]
    if points_after.size == 0:
        continue
    FUNC_ROOTS[c]["ZC"] = TIME_CONT[points_after.min()]

    # Last near-zero sample before the anchor (response onset).
    points_before = near_zero[near_zero < value]
    if points_before.size:
        FUNC_ROOTS[c]["First"] = TIME_CONT[points_before.max()]

### Function Bandwidth

In [None]:
# Time from onset ("First") to zero crossing ("ZC"); None if either is missing.
FUNC_BANDWIDTH = {
    c: (
        FUNC_ROOTS[c]["ZC"] - FUNC_ROOTS[c]["First"]
        if FUNC_ROOTS[c]["First"] is not None and FUNC_ROOTS[c]["ZC"] is not None
        else None
    )
    for c in DATA
}

## Export data

In [None]:
# Results to cache; each key is also the config entry naming its output file.
TO_SAVE = dict(
    FUNC_VALUES=FUNC_VALUES,
    FUNC_PEAKS=FUNC_PEAKS,
    FUNC_HWHH=FUNC_HWHH,
    FUNC_ROOTS=FUNC_ROOTS,
    FUNC_BANDWIDTH=FUNC_BANDWIDTH,
)

for elem, payload in TO_SAVE.items():
    with open("../../" + config["paths"]["data_cache"][elem], "wb") as output:
        pickle.dump(payload, output)

___

## Plots

In [None]:
def plot_sta(cell, thrsh=0.05):
    """Plot a cell's spatial STA beside its temporal STA and fitted model.

    Left panel: ``S_DATA[cell]`` as an image, with the ellipse returned by
    ``filtertools.get_ellipse`` overlaid when that call succeeds.

    Right panel:
    - the temporal STA in reversed sample order;
    - the fitted curve, drawn red when its chi-square is not below ``thrsh``;
    - the data and model peak markers;
    - the model zero crossing;
    - the HWHH point.

    Cells without an accepted fit (``FIT_FUNC[cell] is None``) are skipped.
    (All lookups now use ``cell``; they previously read the global loop
    variable ``c`` and only worked when it happened to equal ``cell``.)

    Parameters
    ----------
    cell : str
        Cell key; also used in the figure title.
    thrsh : float
        Chi-square at or above which the fitted curve is drawn in red.
    """
    if FIT_FUNC[cell] is None:
        return
    result, x_fit, t1 = FIT_FUNC[cell]

    fig, ax = plt.subplots(1, 2)
    fig.suptitle("Cell: " + cell, weight="bold")
    fig.set_size_inches(10, 3)
    # Compute the layout after resizing so spacing matches the final size.
    fig.tight_layout()

    _plot_temporal_fit(ax[1], cell, result, t1, thrsh)
    _plot_spatial_sta(ax[0], S_DATA[cell])

    plt.show()


def _plot_temporal_fit(ax, cell, result, t1, thrsh):
    """Draw the temporal STA of ``cell``, its fitted curve ``t1`` and markers on ``ax``."""
    ax.axhline(y=0, color="black", linestyle="-", lw=1)  # horizontal axis

    t = T_DATA["contrast"][cell][::-1]
    ax.scatter(TIME, t, c="grey", label="Data")

    peaks = PEAKS_DATA[cell]
    if peaks["max"] is not None:
        ax.axvline(
            x=peaks["max"][1], color="grey", label="Data Max", linestyle="--", lw=1
        )
    if peaks["min"] is not None:
        ax.axvline(
            x=peaks["min"][1], color="grey", label="Data Min", linestyle="--", lw=1
        )

    if result.chisqr < thrsh:
        ax.plot(TIME_CONT, t1, label="Function ")
    else:
        ax.plot(TIME_CONT, t1, label="Function", color="red")

    ax.text(
        0.0,
        np.min(t),
        "Chi-Square: " + str(truncate_float(result.chisqr, 4)),
        fontsize=10,
        bbox=dict(facecolor="white", alpha=1),
    )

    zc = FUNC_ROOTS[cell]["ZC"]
    if zc is not None:
        ax.scatter(zc, 0, color="orange", marker=".")
        ax.axvline(
            x=zc,
            color="orange",
            label="Function Zero Crossing Prediction",
            linestyle=":",
            lw=1,
        )

    for kind, label in (
        ("max", "Function Max Prediction"),
        ("min", "Function Min Prediction"),
    ):
        func_peak = FUNC_PEAKS[cell][kind]
        if func_peak is None:
            continue
        peak_time, peak_value = func_peak[0], func_peak[1]
        ax.scatter(peak_time, peak_value, color="purple", marker=".")
        ax.axvline(x=peak_time, color="purple", label=label, linestyle="--", lw=1)

    hwhh = FUNC_HWHH[cell]
    if hwhh is not None:
        ax.scatter(
            hwhh[0],
            hwhh[1],
            color="black",
            marker="+",
            label="Function HWHH Prediction",
        )

    ax.legend(bbox_to_anchor=(1.0, 1), loc="upper left")
    ax.set_xlabel("Time to Spike (s)")
    ax.set_ylabel("STA Contrast")


def _plot_spatial_sta(ax, sta):
    """Show a spatial STA image on ``ax`` with a best-effort ellipse overlay."""
    ax.imshow(sta, aspect="equal", cmap="gray_r")
    try:
        center, width, theta = filtertools.get_ellipse(sta)
        # get_ellipse's center is swapped into (x, y) order for Ellipse
        # (presumably it is returned as (row, col)).
        ax.add_artist(
            Ellipse(
                xy=(center[1], center[0]),
                width=width[0],
                height=width[1],
                angle=theta,
                color="C1",
                fill=False,
                alpha=0.85,
                linewidth=2.0,
            )
        )
    except Exception:
        # Ellipse fitting is best-effort; annotate instead of aborting the plot.
        ax.annotate("Ellipse fitting failed", (15, 15))

In [None]:
# Plot every cell; plot_sta returns early for cells without an accepted fit.
# DATA[0:] iterates over a slice of DATA (presumably kept so a sub-range can
# be selected).
for c in DATA[0:]:
    plot_sta(c)

## Analysis

In [None]:
# x (time) and y (half-height value) of every available HWHH point.
FUNC_X_HWHH = [point[0] for point in (FUNC_HWHH[c] for c in DATA) if point is not None]
FUNC_Y_HWHH = [point[1] for point in (FUNC_HWHH[c] for c in DATA) if point is not None]

print("Los promedios son: ")
for axis_label, axis_values in (("X: ", FUNC_X_HWHH), ("Y: ", FUNC_Y_HWHH)):
    print(axis_label + str(np.mean(axis_values)))
print("Las desviaciones estándar son: ")
for axis_label, axis_values in (("X: ", FUNC_X_HWHH), ("Y: ", FUNC_Y_HWHH)):
    print(axis_label + str(np.std(axis_values)))

In [None]:
# Bandwidths of every cell for which one could be computed.
HIST_BANDWIDTH = [
    bandwidth
    for bandwidth in (FUNC_BANDWIDTH[c] for c in DATA)
    if bandwidth is not None
]

print("El promedio para el ancho de banda es: " + str(np.mean(HIST_BANDWIDTH)))
print("La desviación estándar para el ancho de banda es: " + str(np.std(HIST_BANDWIDTH)))

In [None]:
# Chi-square of every cell with an accepted fit, in DATA order.
HIST_RESIDUALS = [FUNC_RESIDUALS[c] for c in DATA if FIT_FUNC[c] is not None]

# Keep only fitted cells in FUNC_RESIDUALS. pop(..., None) instead of `del`
# makes this cell safe to re-run (a second run used to raise KeyError).
for c in DATA:
    if FIT_FUNC[c] is None:
        FUNC_RESIDUALS.pop(c, None)

print("El promedio para los residuos es: " + str(np.mean(HIST_RESIDUALS)))
print("La desviación estándar para los residuos es: " + str(np.std(HIST_RESIDUALS)))

# Cells with the largest / smallest chi-square, computed once.
worst_cell = max(FUNC_RESIDUALS, key=FUNC_RESIDUALS.get)
best_cell = min(FUNC_RESIDUALS, key=FUNC_RESIDUALS.get)
print(
    "El peor ajuste fue para la célula: "
    + worst_cell
    + ", y la suma de sus errores cuadráticos fue: "
    + str(FUNC_RESIDUALS[worst_cell])
)
print(
    "El mejor ajuste fue para la célula: "
    + best_cell
    + ", y la suma de sus errores cuadráticos fue: "
    + str(FUNC_RESIDUALS[best_cell])
)

In [None]:
# Distribution of the time coordinate of the HWHH points.
plot_hist(title="Histogram for 'x' values in HWHH points", X=FUNC_X_HWHH)
plt.show()

In [None]:
# Distribution of the half-height values of the HWHH points.
plot_hist(title="Histogram for 'y' values in HWHH points", X=FUNC_Y_HWHH)
plt.show()

In [None]:
# Distribution of the onset-to-zero-crossing bandwidths.
plot_hist(title="Histogram for 'bandwidth' values", X=HIST_BANDWIDTH)
plt.show()

In [None]:
# Joint distribution of the HWHH points' time and half-height value.
plot_hist2d(
    title="2D Histogram of x & y in HWHH points",
    X=FUNC_X_HWHH,
    Y=FUNC_Y_HWHH,
    xlab="x hwhh",
    ylab="y hwhh",
    bins=30,
)
plt.show()

In [None]:
# Distribution of the chi-square of the accepted fits.
plot_hist(title="Histogram of Function Residuals", X=HIST_RESIDUALS)
plt.show()

In [None]:
# Fitted parameter values collected across cells, one list per parameter.
PARAMETERS = {name: [] for name in ("amp1", "amp2", "tau1", "tau2", "n")}

In [None]:
# Append each accepted fit's best-fit parameter values, in DATA order.
for c in DATA:
    if FIT_FUNC[c] is None:
        continue
    result, x_fit, t1 = FIT_FUNC[c]
    for param_name in ("amp1", "amp2", "tau1", "tau2", "n"):
        PARAMETERS[param_name].append(result.params[param_name].value)

In [None]:
# Distribution of the fitted first-lobe amplitude.
plot_hist(title="Histogram for 'amp1' value", X=PARAMETERS["amp1"])
plt.show()

In [None]:
# Distribution of the fitted second-lobe amplitude.
plot_hist(title="Histogram for 'amp2' value", X=PARAMETERS["amp2"])
plt.show()

In [None]:
# Distribution of the fitted first-lobe position (in frames).
plot_hist(title="Histogram for 'tau1' value", X=PARAMETERS["tau1"])
plt.show()

In [None]:
# Distribution of the fitted second-lobe position (in frames).
plot_hist(title="Histogram for 'tau2' value", X=PARAMETERS["tau2"])
plt.show()

In [None]:
# Distribution of the fitted shared exponent n.
plot_hist(title="Histogram for 'n' value", X=PARAMETERS["n"])
plt.show()

___