# In [None]:
import numpy as np
import matplotlib.pyplot as plt

# === USER INPUT ===
def _read_percent(prompt):
    """Prompt for a percentage and return it as a fraction in [0, 1].

    Raises:
        ValueError: if the reply is not a number, or is not a finite value
            between 0 and 100 inclusive (the range the prompt advertises).
    """
    value = float(input(prompt))
    # The chained comparison is False for NaN and +/-inf, so those are rejected too.
    if not 0 <= value <= 100:
        raise ValueError(f"Light level must be between 0 and 100, got {value}")
    return value / 100


surround_light = _read_percent("Enter percent light in surround (0-100): ")
center_light = _read_percent("Enter percent light in center (0-100): ")

# === MODEL PARAMETERS ===
# One cell per retinal position in every layer.
num_horizontal = num_ganglion_on = num_ganglion_off = 10

# Simulation window: `steps` evenly spaced sample times from 0 to duration_ms,
# stored in seconds.
duration_ms, steps = 5, 500
time = np.linspace(0, duration_ms / 1000, num=steps)

# Baseline rate shared by all cell types.
# NOTE(review): labelled spikes/ms (here and in the summary printout), but
# make_spikes converts a rate r into a per-sample spike probability of r/1000
# independent of the sample spacing — confirm the intended unit.
basal_rate = 5

# === DEFINE RETINAL LAYOUT ===
# Positions 4 and 5 form the receptive-field centre; all other positions are
# the surround.
neurons = np.arange(num_horizontal)
center_idx = np.arange(4, 6)
surround_idx = neurons[~np.isin(neurons, center_idx)]

# === PHOTORECEPTOR INPUTS ===
# Light hyperpolarizes photoreceptors, so activity is modelled as
# (1 - light fraction): the surround level everywhere, overwritten with the
# centre level at the centre positions.
photoreceptors = np.full(num_horizontal, 1 - surround_light)
photoreceptors[center_idx] = 1 - center_light

# === HORIZONTAL CELLS (Enhanced Retinotopic Contrast) ===
# Each horizontal cell averages its own photoreceptor and its immediate
# neighbours: three inputs in the interior, two at either edge (the slice end
# is clipped by NumPy at the array boundary).
horizontal = np.array([
    photoreceptors[max(0, pos - 1):pos + 2].mean()
    for pos in range(num_horizontal)
])

# Min-max rescale to [0, 1], then a 2.5 power so low values are pushed down
# more than high ones (exaggerated contrast) before mapping onto firing rates.
spread = np.ptp(horizontal) + 1e-9  # epsilon avoids 0/0 when all cells are equal
horizontal_normalized = (horizontal - horizontal.min()) / spread
horizontal_firing = basal_rate + horizontal_normalized**2.5 * basal_rate * 15  # amplified contrast

# === GANGLION CELLS (ON and OFF) ===
# Each ganglion population is indexed by retinal position. Cells whose
# position lies in center_idx respond to the centre light (gain 10); the rest
# respond to the surround light (gain 5). ON cells are driven by centre light
# and surround darkness; OFF cells are the mirror image.
# Each population is built over its OWN index range, so num_ganglion_on and
# num_ganglion_off may differ without an IndexError or unset OFF cells.
on_in_center = np.isin(np.arange(num_ganglion_on), center_idx)
off_in_center = np.isin(np.arange(num_ganglion_off), center_idx)

on_center = np.where(
    on_in_center,
    basal_rate * (1 + center_light * 10),
    basal_rate * (1 + (1 - surround_light) * 5),
)
off_center = np.where(
    off_in_center,
    basal_rate * (1 + (1 - center_light) * 10),
    basal_rate * (1 + surround_light * 5),
)

# === SPIKE GENERATION ===
def make_spikes(firing_rates, t, rng=None):
    """Draw an independent Bernoulli spike train for each firing rate.

    At every sample time in ``t`` a cell spikes with probability
    ``rate / 1000``, independently of every other sample and cell.

    NOTE(review): that probability does not depend on the spacing of ``t``,
    so the expected spike count is ``len(t) * rate / 1000`` regardless of the
    simulated duration — confirm the intended rate unit before reading the
    output as spikes/s or spikes/ms.

    Args:
        firing_rates: iterable of rates, one per cell.
        t: 1-D sequence of sample times (list or ndarray).
        rng: optional ``numpy.random.Generator`` for reproducible draws. When
            omitted, the global ``np.random`` state is used, exactly as before.

    Returns:
        list with one 1-D array per rate, holding the entries of ``t`` at
        which that cell spiked.
    """
    t = np.asarray(t)  # allows list input; boolean-mask indexing needs an ndarray
    draw = np.random.rand if rng is None else rng.random
    # uniform [0, 1) draws: a probability >= 1 always spikes, <= 0 never does
    return [t[draw(len(t)) < fr / 1000] for fr in firing_rates]

# Sample spike times for each population. The generator is consumed in order
# (horizontal, ON, OFF), so the shared random stream is drawn in the same order.
h_spikes, on_spikes, off_spikes = (
    make_spikes(rates, time)
    for rates in (horizontal_firing, on_center, off_center)
)

# === PLOTTING ===
fig, axs = plt.subplots(2, 1, figsize=(9, 7))
t_end = duration_ms / 1000  # right edge of the simulated window, in seconds

# 1️⃣ Horizontal Cells — Enhanced Retinotopy
# Each row is centred on y = i so it lines up with the 0-based position
# label and ticks.
for i, spike_times in enumerate(h_spikes):
    axs[0].vlines(spike_times, i - 0.5, i + 0.5, color='green')
axs[0].set_title("Horizontal Cells Raster (Enhanced Retinotopic Organization)")
axs[0].set_xlim(0, t_end)
axs[0].set_yticks(range(num_horizontal))
axs[0].set_ylabel(f"Retinal Position (0–{num_horizontal - 1})")
axs[0].set_xlabel("Time (s)")

# 2️⃣ Ganglion Cells — ON (Blue) and OFF (Red)
# ON neuron k (1-based) is drawn in the upper half of the unit around y = k.
# OFF neurons are stacked directly above the ON stack, shifted by the ON count.
off_offset = num_ganglion_on
for i, spike_times in enumerate(on_spikes):
    axs[1].vlines(spike_times, i + 0.5, i + 1, color='blue')  # ON-center
for i, spike_times in enumerate(off_spikes):
    axs[1].vlines(
        spike_times, i + off_offset + 0.5, i + off_offset + 1, color='red'
    )  # OFF-center

# Dashed separator in the middle of the gap between the top ON row and the
# bottom OFF row.
axs[1].axhline(off_offset + 0.25, color='k', linestyle='--')
axs[1].set_title("Ganglion Cells Raster — ON (Blue) and OFF (Red)")
axs[1].set_xlim(0, t_end)
axs[1].set_ylabel(
    f"Neuron # (1–{num_ganglion_on}: ON, "
    f"{num_ganglion_on + 1}–{num_ganglion_on + num_ganglion_off}: OFF)"
)
axs[1].set_xlabel("Time (s)")

plt.tight_layout()
plt.show()

# === SUMMARY OUTPUT ===
# Echo the light levels the model was run with.
for line in (
    "\n--- INPUT SUMMARY ---",
    f"Center light: {center_light*100:.1f}%",
    f"Surround light: {surround_light*100:.1f}%",
):
    print(line)

# Population-average rates. NOTE(review): the "(spikes/ms)" label matches the
# basal_rate comment, but make_spikes does not interpret rates in that unit.
for line in (
    "\n--- MEAN FIRING RATES (spikes/ms) ---",
    f"Horizontal avg: {np.mean(horizontal_firing):.2f}",
    f"ON-center avg: {np.mean(on_center):.2f}",
    f"OFF-center avg: {np.mean(off_center):.2f}",
):
    print(line)