# Poly paper: Ganglion Cells in the Retina
- **Author**: Javier Cruz
- **Contact**: https://github.com/sisyphvs
- **Last Modification**: January 26, 2024
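- **Description**: Fits a degree-5 polynomial to each ganglion cell's temporal spike-triggered average (STA) and derives its peaks, half-height (HWHH) point, zero crossings and bandwidth.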

## Introduction

### Importing Libraries

In [None]:
import sys
import numpy as np
import matplotlib.pyplot as plt
import pickle
import warnings
from pyret import filtertools
from matplotlib.patches import Ellipse

In [None]:
# Make the project-level `scripts` package (imported in the next cell)
# resolvable from this notebook's directory.
_PROJECT_ROOT = "../.."
sys.path.append(_PROJECT_ROOT)

In [None]:
from scripts import load_yaml_config, truncate_float, plot_hist, plot_hist2d

### Paths and configuration

In [None]:
# Global matplotlib look shared by every figure in this notebook.
plt.style.use(style="seaborn-v0_8-darkgrid")

In [None]:
# Load the general configuration and resolve the cached input files.
configPath = "../../config/"
config = load_yaml_config(configPath + "general_config.yml")

_data_cache = config["paths"]["data_cache"]
DATA_RFS_PATH = "../.." + _data_cache["DATA_RFS"]
S_DATA_STA_PATH = "../.." + _data_cache["S_DATA_STA"]
T_DATA_STA_PATH = "../.." + _data_cache["T_DATA_STA"]

### Loading data

In [None]:
def _load_pickle(path):
    """Return the object unpickled from the file at ``path``."""
    # NOTE(review): pickle executes arbitrary code on load; only point these
    # paths at locally produced cache files.
    with open(path, "rb") as fh:
        return pickle.load(fh)


DATA, S_DATA, T_DATA = (
    _load_pickle(path)
    for path in (DATA_RFS_PATH, S_DATA_STA_PATH, T_DATA_STA_PATH)
)

### Constants

In [None]:
# Timing constants.
FPS = config["params"]["FPS"]
FRAME_TIME = 1 / FPS  # duration of one frame; plots label this axis in seconds
TIME = T_DATA["time"]  # sample times of the temporal STA
FRAMES = len(T_DATA["contrast"]["temp_100"])
# Dense evaluation grid for the polynomial fits: 260 evenly spaced points over
# 0..26 frames, scaled by FRAME_TIME.
# NOTE(review): 26 and 260 are hard-coded (step is 26/259 frames, not 0.1);
# confirm they are meant to track the STA length (FRAMES).
TIME_CONT = FRAME_TIME * np.linspace(0, 26, 260)

## Peaks

In [None]:
# Find data peaks.
def find_peaks(t, std_mult=1.75, frame_time=None):
    """Find the significant maximum and minimum of a 1-D trace.

    A maximum is reported only if it exceeds ``std_mult`` times the trace's
    standard deviation, and a minimum only if it lies below minus that amount.
    Both thresholds are measured from zero, not from the trace mean.

    Parameters
    ----------
    t : array_like
        1-D trace (in this notebook: a time-reversed temporal STA).
    std_mult : float, default 1.75
        Significance threshold in units of ``np.std(t)``.
    frame_time : float, optional
        Time per sample, used to convert indices to times. Defaults to the
        notebook-level ``FRAME_TIME``.

    Returns
    -------
    dict
        ``{"max": (index, time) or None, "min": (index, time) or None}``.
        When the extreme value repeats, the first occurrence is used.
    """
    if frame_time is None:
        frame_time = FRAME_TIME
    # Accept lists too: `list == scalar` would not broadcast element-wise.
    t = np.asarray(t)
    threshold = std_mult * np.std(t)

    # argmax/argmin return the first occurrence of the extreme value.
    i_max = np.argmax(t)
    i_min = np.argmin(t)
    return {
        "max": (i_max, i_max * frame_time) if t[i_max] > threshold else None,
        "min": (i_min, i_min * frame_time) if t[i_min] < -threshold else None,
    }

In [None]:
# Significant data peaks of each cell's time-reversed temporal STA.
PEAKS_DATA = {c: find_peaks(T_DATA["contrast"][c][::-1]) for c in DATA}

## Polynomial Fitting

In [None]:
# Fit a polynomial to each cell's time-reversed temporal STA and keep the sum
# of squared residuals as its goodness-of-fit score.
POLY_DEGREE = 5

# `np.RankWarning` was removed from NumPy's top-level namespace in 2.0; it
# has lived in `np.exceptions` since 1.25.
try:
    _RANK_WARNING = np.exceptions.RankWarning
except AttributeError:  # NumPy < 1.25
    _RANK_WARNING = np.RankWarning

POLY_DATA = {}
POLY_RESIDUALS = {}

x = T_DATA["time"]  # same sample times for every cell
for c in DATA:
    t = T_DATA["contrast"][c][::-1]

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", _RANK_WARNING)
        coeffs, residuals, *_ = np.polyfit(x, t, POLY_DEGREE, full=True)
    p = np.poly1d(coeffs)
    # The least-squares solver returns an empty residuals array when the fit
    # is rank-deficient (the very case the silenced warning reports) or has
    # too few points, so compute the sum of squared errors directly then.
    residual = residuals[0] if residuals.size else np.sum((p(x) - t) ** 2)

    POLY_DATA[c] = p
    POLY_RESIDUALS[c] = residual

## Polynomial Data

### Polynomial values

In [None]:
# Each cell's fitted polynomial sampled on the dense TIME_CONT grid.
POLY_VALUES = {c: POLY_DATA[c](TIME_CONT) for c in DATA}

### Polynomial Peaks

In [None]:
# One fresh {"max", "min"} record per cell; filled in by find_poly_peaks.
POLY_PEAKS = {c: {"max": None, "min": None} for c in DATA}

In [None]:
def find_poly_peaks(c):
    """Store the extrema of cell ``c``'s fitted polynomial in ``POLY_PEAKS[c]``.

    The polynomial is evaluated on ``TIME_CONT``. Its minimum (maximum) is
    recorded only when ``PEAKS_DATA[c]`` reports a significant data minimum
    (maximum). Each stored entry is ``(time, value, grid_index)``.
    """
    data_peaks = PEAKS_DATA[c]
    curve = POLY_DATA[c](TIME_CONT)

    for key, locate in (("min", np.argmin), ("max", np.argmax)):
        if data_peaks[key] is None:
            continue
        idx = locate(curve)
        POLY_PEAKS[c][key] = (TIME_CONT[idx], curve[idx], idx)

In [None]:
# Fill POLY_PEAKS[c] with the polynomial extrema for every cell.
for c in DATA:
    find_poly_peaks(c)

### Polynomial HWHH

In [None]:
# Half-height point per cell; stays None when none is found.
POLY_HWHH = dict.fromkeys(DATA)

In [None]:
def find_poly_hwhh(c, p, *, atol=1e-2):
    """Store the half-height point of cell ``c``'s polynomial in ``POLY_HWHH[c]``.

    The reference peak is ``POLY_PEAKS[c]["min"]`` when present, otherwise
    ``POLY_PEAKS[c]["max"]``. The half height is half the peak value (so the
    baseline is taken as zero). On ``TIME_CONT``, the grid points where ``p``
    is within ``atol`` of that height are collected. The nearest one before
    the peak and the nearest one after it bracket the peak.

    Stores ``(midpoint_time, half_height)``. Leaves ``POLY_HWHH[c]`` untouched
    when there is no peak or the curve does not reach half height on both
    sides.

    NOTE(review): the stored x is the midpoint of the two half-height
    crossings, not half their separation (the usual "HWHH" width). Confirm
    which one the analysis intends.

    Parameters
    ----------
    c
        Cell key.
    p : numpy.poly1d
        The cell's fitted polynomial.
    atol : float, keyword-only, default 1e-2
        Absolute tolerance for "at half height" (``np.isclose``).
    """
    poly_peaks = POLY_PEAKS[c]
    x_p = TIME_CONT

    # When both extrema exist, the minimum takes precedence.
    peak = poly_peaks["min"] if poly_peaks["min"] is not None else poly_peaks["max"]
    if peak is None:
        return

    half_height = peak[1] / 2
    peak_index = peak[2]

    near = np.flatnonzero(np.isclose(p(x_p), half_height, atol=atol))
    before = near[near < peak_index]
    after = near[near > peak_index]
    if before.size == 0 or after.size == 0:
        # Half height is not reached on one side of the peak.
        return

    left, right = before.max(), after.min()
    POLY_HWHH[c] = ((x_p[left] + x_p[right]) / 2, half_height)

In [None]:
# Fill POLY_HWHH[c] for every cell (left as None when no half-height point is found).
for c in DATA:
    find_poly_hwhh(c,POLY_DATA[c])

### Polynomial Roots

In [None]:
# Per-cell roots around the reference peak: "First" (before) and "ZC" (after).
POLY_ROOTS = {c: {"First": None, "ZC": None} for c in DATA}

In [None]:
# For each cell, find the polynomial's sign changes on TIME_CONT that bracket
# its reference peak:
#   - the nearest one before the peak ("First");
#   - the nearest one after it ("ZC").
# Stored values are the TIME_CONT sample just before each sign change.
for c in DATA:
    p = POLY_DATA[c]
    # Index i is a crossing when the sign differs between samples i and i+1.
    zero_crossings = np.flatnonzero(np.diff(np.signbit(p(TIME_CONT))))

    peaks = POLY_PEAKS[c]
    peak_min, peak_max = peaks["min"], peaks["max"]

    # Without a peak there is no reference point. Skip explicitly so a
    # previous cell's `value` is never reused for this one.
    if peak_min is None and peak_max is None:
        continue

    # Reference peak: whichever extremum comes first on the grid (entries are
    # (time, value, index)); a tie goes to the maximum.
    if peak_max is None or (peak_min is not None and peak_min[2] < peak_max[2]):
        value = peak_min[2]
    else:
        value = peak_max[2]

    # Look the two roots up independently so a missing one does not suppress
    # the other.
    after = zero_crossings[zero_crossings > value]
    if after.size:
        POLY_ROOTS[c]["ZC"] = TIME_CONT[after.min()]

    before = zero_crossings[zero_crossings < value]
    if before.size:
        POLY_ROOTS[c]["First"] = TIME_CONT[before.max()]

### Polynomial Bandwidth

In [None]:
def _bandwidth(roots):
    """Return ZC - First for one cell's roots, or None if either is missing."""
    if roots["First"] is None or roots["ZC"] is None:
        return None
    return roots["ZC"] - roots["First"]


POLY_BANDWIDTH = {c: _bandwidth(POLY_ROOTS[c]) for c in DATA}

## Export data

In [None]:
# Pickle each derived polynomial table to its configured cache path.
TO_SAVE = dict(
    POLY_VALUES=POLY_VALUES,
    POLY_PEAKS=POLY_PEAKS,
    POLY_HWHH=POLY_HWHH,
    POLY_ROOTS=POLY_ROOTS,
    POLY_BANDWIDTH=POLY_BANDWIDTH,
)

for elem, payload in TO_SAVE.items():
    target = "../../" + config["paths"]["data_cache"][elem]
    with open(target, "wb") as output:
        pickle.dump(payload, output)

___

## Plots

In [None]:
def plot_sta(cell, thrsh=0.05):
    """Show the spatial and temporal STA diagnostics for one cell.

    Left panel: the spatial STA ``S_DATA[cell]`` as a grayscale image. The
    ellipse from ``pyret.filtertools.get_ellipse`` is overlaid when that fit
    succeeds.

    Right panel:
      - the time-reversed temporal STA samples plotted against ``TIME``;
      - the fitted polynomial on ``TIME_CONT``, drawn red when its residual
        is not below ``thrsh``;
      - the data and polynomial peaks;
      - the polynomial zero crossing ("ZC");
      - the HWHH point.

    Parameters
    ----------
    cell
        Key into the per-cell tables (``T_DATA["contrast"]``, ``POLY_DATA``,
        ``S_DATA``, ...). It is concatenated into the title, so it must be a
        ``str``.
    thrsh : float, default 0.05
        Residual threshold separating good fits (default colour) from poor
        ones (red).
    """
    fig, ax = plt.subplots(1, 2)
    fig.suptitle("Cell: " + cell, weight="bold")
    fig.set_size_inches(10, 3)

    # Every lookup uses `cell`. Reading the notebook-global loop variable `c`
    # would plot whichever cell that variable last pointed at.
    trace = T_DATA["contrast"][cell][::-1]
    trace_min = np.min(trace)
    poly_curve = POLY_DATA[cell](TIME_CONT)
    residual = POLY_RESIDUALS[cell]
    data_peaks = PEAKS_DATA[cell]
    poly_peaks = POLY_PEAKS[cell]
    zc = POLY_ROOTS[cell]["ZC"]
    hwhh = POLY_HWHH[cell]

    ax_t = ax[1]
    ax_t.axhline(y=0, color="black", linestyle="-", lw=1)  # horizontal axis

    # Data samples and their significant peaks.
    ax_t.scatter(TIME, trace, c="grey", label="Data")
    if data_peaks["max"] is not None:
        ax_t.axvline(x=data_peaks["max"][1], color="grey", label="Data Max", linestyle="--", lw=1)
    if data_peaks["min"] is not None:
        ax_t.axvline(x=data_peaks["min"][1], color="grey", label="Data Min", linestyle="--", lw=1)

    # Polynomial fit; red flags a residual at or above the threshold.
    if residual < thrsh:
        ax_t.plot(TIME_CONT, poly_curve, label="Poly Prediction")
    else:
        ax_t.plot(TIME_CONT, poly_curve, label="Poly Prediction", color="red")

    # NOTE(review): the value shown is polyfit's sum of squared residuals,
    # not a chi-square statistic; the label text is kept unchanged.
    ax_t.text(0.0, trace_min, 'Chi-Square: ' + str(truncate_float(residual, 4)),
              fontsize=10, bbox=dict(facecolor='white', alpha=1))

    # Polynomial zero crossing after the reference peak.
    if zc is not None:
        ax_t.scatter(zc, 0, color="orange", marker=".")
        ax_t.axvline(x=zc, color="orange", label="Poly Zero Crossing Prediction", linestyle=":", lw=1)

    # Polynomial extrema (recorded only where a matching data peak exists).
    for key, label in (("max", "Poly Max Prediction"), ("min", "Poly Min Prediction")):
        if poly_peaks[key] is not None:
            peak_time, peak_value = poly_peaks[key][0], poly_peaks[key][1]
            ax_t.scatter(peak_time, peak_value, color="purple", marker=".")
            ax_t.axvline(x=peak_time, color="purple", label=label, linestyle="--", lw=1)

    # Polynomial HWHH point.
    if hwhh is not None:
        ax_t.scatter(hwhh[0], hwhh[1], color="black", marker="+", label="Poly HWHH Prediction")

    ax_t.legend(bbox_to_anchor=(1.0, 1), loc='upper left')
    ax_t.set_xlabel("Time to Spike (s)")
    ax_t.set_ylabel("STA Contrast")

    # Spatial STA with its fitted ellipse.
    ax[0].imshow(S_DATA[cell], aspect='equal', cmap='gray_r')
    try:
        center, width, theta = filtertools.get_ellipse(S_DATA[cell])
        ells = Ellipse(xy=(center[1], center[0]), width=width[0], height=width[1], angle=theta,
                       color='C1', fill=False, alpha=0.85, linewidth=2.)
        ax[0].add_artist(ells)
    except Exception:
        # Best-effort overlay: keep the figure even when the fit fails.
        ax[0].annotate('Ellipse fitting failed', (15, 15))

    # Lay out only after sizing and drawing, so the final size and the
    # legend are taken into account.
    fig.tight_layout()
    plt.show()

In [None]:
# Render the diagnostic figure (spatial STA + temporal fit) for every cell.
for c in DATA:
    plot_sta(c)

## Analysis

In [None]:
# Coordinates of the HWHH points over the cells where one was found.
_hwhh_found = [POLY_HWHH[c] for c in DATA if POLY_HWHH[c] is not None]
POLY_X_HWHH = [x_val for x_val, _ in _hwhh_found]
POLY_Y_HWHH = [y_val for _, y_val in _hwhh_found]

# Print the mean, then the standard deviation, of both coordinates.
for header, stat in (
    ("Los promedios para las coordenadas del HWHH son: ", np.mean),
    ("Las desviaciones estándar para las coordenadas del HWHH son: ", np.std),
):
    print(header)
    print("X: " + str(stat(POLY_X_HWHH)))
    print("Y: " + str(stat(POLY_Y_HWHH)))

In [None]:
# Bandwidths of the cells where both roots were found.
HIST_BANDWIDTH = list(
    filter(lambda bw: bw is not None, (POLY_BANDWIDTH[c] for c in DATA))
)

print("El promedio para el ancho de banda es: " + str(np.mean(HIST_BANDWIDTH)))
print("La desviación estándar para el ancho de banda es: " + str(np.std(HIST_BANDWIDTH)))

In [None]:
# Goodness-of-fit summary over the per-cell sums of squared residuals.
HIST_RESIDUALS = [POLY_RESIDUALS[c] for c in DATA]

print("El promedio para los residuos es: " + str(np.mean(HIST_RESIDUALS)))
print("La desviación estándar para los residuos es: " + str(np.std(HIST_RESIDUALS)))
print()

# Cells with the largest / smallest residual (the first one wins on ties).
_worst_cell = max(POLY_RESIDUALS, key=POLY_RESIDUALS.get)
print("El peor ajuste fue para la célula: " + _worst_cell
      + ", y la suma de sus errores cuadráticos fue: " + str(POLY_RESIDUALS[_worst_cell]))
_best_cell = min(POLY_RESIDUALS, key=POLY_RESIDUALS.get)
print("El mejor ajuste fue para la célula: " + _best_cell
      + ", y la suma de sus errores cuadráticos fue: " + str(POLY_RESIDUALS[_best_cell]))

In [None]:
# Distribution of the HWHH points' time coordinate.
plot_hist(X=POLY_X_HWHH, title="Histogram for 'x' values in HWHH points")
plt.show()

In [None]:
# Distribution of the HWHH points' half-height value.
plot_hist(X=POLY_Y_HWHH, title="Histogram for 'y' values in HWHH points")
plt.show()

In [None]:
# Distribution of the polynomial bandwidths.
plot_hist(X=HIST_BANDWIDTH, title="Histogram for 'bandwidth' values")
plt.show()

In [None]:
# Joint distribution of the HWHH point coordinates.
plot_hist2d(
    X=POLY_X_HWHH,
    Y=POLY_Y_HWHH,
    bins=30,
    title="2D Histogram of x & y in HWHH points",
    xlab="x hwhh",
    ylab="y hwhh",
)
plt.show()

In [None]:
# Distribution of the per-cell fit residuals.
plot_hist(X=HIST_RESIDUALS, title="Histogram of Function Residuals")
plt.show()

___