# How does drift affect stimulus power on the retina?

# Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
from matplotlib import animation
from ipywidgets import interact
import copy
import scipy.io as sio
import pickle

sys.path.insert(1, '../mechanistic_models')
from functions import create_edge, create_noise, create_drift

sys.path.insert(1, '../experimental_data')
from exp_params import stim_params as sparams
from exp_params import exp_params as eparams

# Preparations

In [None]:
# Stimulus parameters read from the shared experiment configuration:
# visual extent in deg, spatial resolution (pixels per deg), mean luminance in cd/m**2
stimSize, ppd, meanLum = (sparams[key] for key in ("stim_size", "ppd", "mean_lum"))

# Cornsweet edges: widths in deg and exponent
edgeWidths, edgeExponent = sparams["edge_widths"], sparams["edge_exponent"]

# Noise masks: types and rms contrast
noiseTypes, rmsNoise = sparams["noise_types"], sparams["noise_contrast"]

# Edge contrast for visuals
rmsEdge = 0.05

In [None]:
# Spatial axis in pixels; vextent is also passed to imshow as the image extent
halfSize = stimSize / 2.
vextent = [-halfSize, halfSize, halfSize, -halfSize]
x = np.arange(vextent[0] * ppd, vextent[1] * ppd, 1.)
nX = len(x)

# Spatial frequency axis in cpd, with the zero frequency shifted to the center
fs = np.fft.fftshift(np.fft.fftfreq(int(nX), d=1. / ppd))
fx, fy = np.meshgrid(fs, fs)
fs1d = fs[nX // 2:]                       # from the center of the shifted axis onwards
fextent = (fs[0], fs[-1], fs[0], fs[-1])

# Intensity range used for plotting
vmin = 0
vmax = meanLum * 2.

# Helper functions

In [None]:
def compute_power(stimulus):
    """Return the power spectrum of ``stimulus`` with the zero frequency centered.

    The mean of ``stimulus`` is subtracted before the n-dimensional FFT, and the
    squared amplitudes are divided by the number of samples in ``stimulus``.
    """
    centered = stimulus - stimulus.mean()
    amplitude = np.abs(np.fft.fftshift(np.fft.fftn(centered)))
    return amplitude ** 2. / np.size(stimulus)

def csf_kelly(sfs, tf):
    """Contrast sensitivity function after Kelly (1979).

    Parameters
    ----------
    sfs : scalar or array_like
        Spatial frequencies; the sign is ignored. The input is never modified.
    tf : float
        Temporal frequency; combined with ``sfs`` into a "velocity" ``tf / sfs``.

    Returns
    -------
    numpy.ndarray
        Sensitivity with the same shape as ``sfs``; exactly 0 where ``sfs == 0``.
    """
    sfs = np.abs(np.asarray(sfs))       # for negative SFs; also accepts scalars and lists
    isZero = sfs == 0.
    sfs = np.where(isZero, 1., sfs)     # fudge for sf=0 (avoids dividing by zero)
    v = tf / sfs                        # calculate "velocity" from TFs

    k = 6.1 + 7.3 * np.abs(np.log10(v/3.))**3.
    amax = 45.9 / (v + 2.)
    csf = k * v * (2.*np.pi*sfs)**2. * np.exp((-4.*np.pi*sfs) / amax)  # kelly1979
    return np.where(isZero, 0., csf)    # undo fudge

def normRange(inpt, vmin, vmax):
    """Linearly rescale ``inpt`` so that its minimum maps to ``vmin`` and its maximum to ``vmax``."""
    lowest = inpt.min()
    unit = (inpt - lowest) / (inpt.max() - lowest)   # values in [0, 1]
    return unit * (vmax - vmin) + vmin

def create_grating(sf, alpha=0., phi=0, C=1., ppd=100., sSize=1.):
    """Create a square 2d sinusoidal grating.

    Parameters
    ----------
    sf : float
        Spatial frequency in cycles per unit of ``sSize``.
    alpha : float
        Orientation of the modulation direction in radians.
    phi : float
        Phase offset added inside the sine.
    C : float
        Amplitude of the grating.
    ppd : float
        Samples per unit of ``sSize``.
    sSize : float
        Extent of the grating; the axis runs from ``-sSize/2`` up to ``sSize/2``.
    """
    axis = np.arange(-sSize/2, sSize/2, 1./ppd)
    xx, yy = np.meshgrid(axis, axis)
    dirX, dirY = np.cos(alpha), np.sin(alpha)
    projection = dirX*xx + dirY*yy          # position along the modulation direction
    return C * np.sin(projection * 2. * np.pi * sf + phi)

# Create stimuli

## Noise masks

In [None]:
# One noise mask per noise type, shifted to the mean luminance; names are used as plot titles
noiseNames = ["none", "white", "pink", "brown", "NB0.5", "NB3", "NB9"]
nNoises = len(noiseTypes)
noises = [create_noise(noiseType, sparams) + meanLum for noiseType in noiseTypes]

plt.figure(figsize=(16, 2))
for ni, n in enumerate(noises):
    plt.subplot(1, 7, ni+1)
    plt.imshow(n, cmap="gray", vmin=vmin, vmax=vmax)
    plt.title(noiseNames[ni])
    plt.axis("off")

## Cornsweet edges

In [None]:
# One Cornsweet edge per edge width, all with the visualization contrast rmsEdge
edgeNames = ["9 edge", "3 edge", "0.5 edge"]
nEdges = len(edgeWidths)
edges = [create_edge(rmsEdge, width, sparams) for width in edgeWidths]

plt.figure(figsize=(8, 2))
for ei, e in enumerate(edges):
    plt.subplot(1, 3, ei+1)
    plt.imshow(e, cmap="gray", vmin=vmin, vmax=vmax)
    plt.title(edgeNames[ei])
    plt.axis("off")

In [None]:
# Grid of all noise-masked stimuli: one row per noise mask, one column per edge
fig, axes = plt.subplots(nNoises, nEdges, figsize=(14, 32), sharex=True, sharey=True)
fig.subplots_adjust(hspace=0.001, wspace=0.001)

for ni, n in enumerate(noises):
    for ei, e in enumerate(edges):
        # Scaled luminance profile through the image center (only needed for the optional overlay)
        lumProf = (e[:, int(nX/2)] - meanLum) / 30
        ax = axes[ni, ei]
        # Both edge and noise contain the mean luminance, so subtract it once
        ax.imshow(e + n - meanLum, cmap='gray', extent=vextent, vmin=vmin, vmax=vmax)
        ax.set_axis_off()
        #axes[i, j].plot(profile, color="white", linewidth=1);
#plt.savefig('stimuli.png', dpi=300)
plt.show()

# Static power spectra

## 2d spectra

In [None]:
# 2d power spectrum of every noise mask
noisePows = [compute_power(noise) for noise in noises]

plt.figure(figsize=(16, 2))
for ni, n in enumerate(noisePows):
    plt.subplot(1, 7, ni+1)
    plt.imshow(n)
    plt.title(noiseNames[ni])
    plt.axis("off")

In [None]:
# 2d power spectrum of every edge
edgePows = [compute_power(edge) for edge in edges]

plt.figure(figsize=(8, 2))
for ei, e in enumerate(edgePows):
    plt.subplot(1, 3, ei+1)
    plt.imshow(e)
    plt.title(edgeNames[ei])
    plt.axis("off")

## 1d spectra (visual-only)

We do this just for visualization purposes because it is more correct to do all computations with the 2d spectra.

Reason: If we just take the 1d spectra (either radially averaged or as a cut-through), the resulting spectra are no longer equal in power, because the arrays have a square shape and not a circular one.

In [None]:
# Horizontal cuts through the center of the 2d spectra (center row, center column onwards)
nHalf = nX // 2
edgePows1d = [e[nHalf, nHalf:] for e in edgePows]

# The first entry is the no-noise condition, which is skipped here
noisePows1d = [n[nHalf, nHalf:] for n in noisePows[1:]]
noiseNames_ = noiseNames[1:]

In [None]:
# Colors per edge and per noise mask, and axis settings shared by the 1d-spectra plots
eCol = ["C2", "C1", "C0"]
noiseCols = ["C3", "C4", "C5", "C2", "C1", "C0"]
axParams = dict(
    xlabel="SF (cpd)", xscale="log", xticks=(1, 10), xticklabels=(1, 10), xlim=(0.25, 20),
    yscale="log", ylim=(1, 1e8), yticks=[],
)

fig, axes = plt.subplots(1, 2, figsize=(6.2, 3), sharex=True, sharey=True)
edgeAx, noiseAx = axes[0], axes[1]

# Left panel: edge spectra
for ei, e in enumerate(edgePows1d):
    edgeAx.plot(fs1d, e, '.-', label=edgeNames[ei])
    edgeAx.set(**axParams, ylabel="Power (dB)")

# Right panel: noise spectra
for ni, n in enumerate(noisePows1d):
    noiseAx.plot(fs1d, n, '.-', label=noiseNames_[ni], color=noiseCols[ni])
    noiseAx.set(**axParams)

edgeAx.legend()
noiseAx.legend(ncol=2)
plt.tight_layout()
#plt.savefig('1d_powers.png', dpi=300)

# Empirical results

In [None]:
# Load empirical thresholds and their 68% credible intervals
with open('../experimental_data/empThresholds.pickle', 'rb') as handle:
    empData = pickle.load(handle)

# Drop the first entry of the first axis and reverse the second axis;
# the thresholds are additionally transposed
empThresh = np.transpose(empData["thresholds"][1:, ::-1])
empCIs = empData["credible68"][1:, ::-1, :]

# Plot thresholds: one panel per edge, with the credible interval as a shaded band
fig, axes = plt.subplots(1, 3, figsize=(6, 1), sharex=True, sharey=True)
fig.subplots_adjust(wspace=0)
for ei in range(nEdges):
    row = -ei - 1
    axes[ei].plot(noiseNames[1:], empThresh[row, :], '.-', color=eCol[ei])
    axes[ei].fill_between(np.arange(6), empCIs[:, row, 0], empCIs[:, row, 1], color=eCol[ei], alpha=0.2)
    axes[0].set(ylabel="Thresholds", xticklabels=[])
    axes[ei].grid(True, color=[0.9,]*3)
#plt.savefig('thresholds.png', dpi=300)

# Simulate drift-gain

Drift redistributes power in a fashion that whitens natural scene statistics. Let's look at this in more detail.

## Gratings
As the basis for our simulations, we use sinusoidal gratings to show the effect of drift on individual SFs.

In [None]:
def plot_grating(g, gPow, ppd):
    """Plot a grating, its 2d log-power spectrum and a horizontal cut through the spectrum.

    ``g`` is the 2d grating, ``gPow`` its power spectrum with the zero frequency
    centered, and ``ppd`` the spatial resolution used to label the axes.
    """
    sizeDeg = g.shape[0] / ppd
    extent = [0, sizeDeg]
    x = np.arange(0, sizeDeg, 1/ppd)
    nSamples = len(x)
    sff = np.fft.fftshift(np.fft.fftfreq(nSamples, d=1./ppd))
    sff_ext = (sff[0], sff[-1], sff[0], sff[-1])

    plt.figure(figsize=(16, 4))

    # Grating in the spatial domain
    plt.subplot(131)
    plt.imshow(g, cmap='gray', extent=extent+extent)
    plt.colorbar()
    plt.title('2d grating')
    plt.xlabel('deg')
    plt.ylabel('deg')

    # 2d power spectrum; the small offset avoids log10(0)
    plt.subplot(132)
    plt.imshow(np.log10(gPow + 0.001), extent=sff_ext, vmax=-2)
    plt.colorbar()
    plt.title('Power spectrum (log)')
    plt.xlabel('cpd')
    plt.ylabel('cpd')

    # Horizontal cut through the center row of the spectrum
    plt.subplot(133)
    plt.plot(sff, np.log10(gPow[nSamples // 2, :] + 0.001))
    plt.xlabel('cpd')
    plt.ylabel('Power')

def explore_grating(sf):
    """Create a grating with spatial frequency ``sf`` and plot it with its power spectrum."""
    grating = create_grating(sf, ppd=ppd, sSize=stimSize)
    power = compute_power(grating)
    plot_grating(grating, power, ppd)

# Interactive widget; sf is passed as the (min, max, step) tuple (1, 10, 1)
interact(explore_grating, sf=(1, 10, 1))

## Drift creates a transient luminance signal

Let's simulate the luminance input of a single neuron which moves over time following ocular drift behavior, and also calculate the (temporal) power of this signal.
We simulate drift as a Brownian motion process.

Crucially, we can observe that the extent of these luminance modulations depends on the SF of the grating.
Temporal luminance modulations for low SF gratings are smaller than for high SF gratings.

In [None]:
def plot_drift(grating_2d, ppd, fps, T, path, lumTime, lumPower):
    """Plot a drift path on a grating, the sampled luminance and its temporal power.

    ``path`` is a 2 x nT array of positions in pixels, ``lumTime`` the luminance
    sampled along it over the duration ``T`` (``fps`` samples per second), and
    ``lumPower`` the temporal power of that signal with the zero frequency centered.
    """
    nT = int(T*fps) + 1
    t = np.linspace(0, T, nT)
    tff = np.fft.fftshift(np.fft.fftfreq(nT, d=1./fps))
    extent = [0, grating_2d.shape[0]/ppd]
    lumLimit = grating_2d.max() * 1.2
    logPower = np.log10(lumPower)

    plt.figure(figsize=(15, 4))

    # Grating with the drift path (converted from pixels to deg) on top
    plt.subplot(131)
    plt.imshow(grating_2d, cmap='gray', extent=extent+extent)
    plt.plot(path[0, :] / ppd, path[1, :] / ppd, 'r')
    plt.title('Grating + drift')
    plt.xlabel('deg')
    plt.ylabel('deg')

    # Luminance sampled along the path
    plt.subplot(132)
    plt.plot(t, lumTime, 'c.')
    plt.plot(t, lumTime, alpha=0.2)
    plt.ylim(-lumLimit, lumLimit)
    plt.title('Luminance over time')
    plt.xlabel('Time in s')
    plt.ylabel('Luminance')

    # Temporal power of the luminance signal (log10)
    plt.subplot(133)
    plt.plot(tff, logPower, 'c.')
    plt.plot(tff, logPower, alpha=0.2)
    plt.ylim(-3., 6.)
    plt.xlim(0.1, 50.)
    plt.xscale('log')
    plt.xticks([0.1, 1, 10, 50], [0.1, 1, 10, 50])
    plt.title('Temporal power')
    plt.xlabel('tf in Hz')
    plt.ylabel('Power in dB')

def explore_drift(sf):
    """Simulate drift over a grating of spatial frequency ``sf`` and plot the result.

    The luminance under a single position that follows a simulated drift path is
    sampled from the grating, and its temporal power is averaged over several
    repetitions. The last simulated path is shown together with the averaged power.
    """
    T = 2.; fps = 100; D = 20/(60.**2.)                   # time; frames/second; diffusion constant
    nReps = 100                                           # number of repetitions to average over
    g = create_grating(sf, sSize=stimSize, ppd=ppd)       # create grating (identical in every repetition)
    center = int(stimSize/2 * ppd)                        # grating center in px

    gPow = np.zeros(int(T*fps+1))
    for _ in range(nReps):
        _, path = create_drift(T, fps, ppd, D)            # create drift path (in px)
        path += center                                    # shift to grating center

        lumTime = g[path[0,:], path[1,:]]                 # luminance over time for 1px
        # compute_power already returns squared amplitudes, so it must not be squared again
        gPow += compute_power(lumTime)                    # calculate temporal power
    plot_drift(g, ppd, fps, T, path, lumTime, gPow/nReps) # plot

interact(explore_drift, sf=(1,7,3))

## Drift whitens the visual input

After these observations, let's simulate how these temporal modulations due to drift interact with the spatial characteristics of the visual input.

For this, we simulate the previous effect of drift for a larger number of SFs (gratings with different SFs), and compute the spatiotemporal power.
Finally, we plot the resulting power averaged across all non-zero TFs.
We can see that drift shifts power towards higher SFs.

Incidentally, this shift is the inverse of the typical $\frac{1}{f}$ spectrum we observe for natural scenes at SFs up to approximately 5 cpd.

In [None]:
# Parameters
ppd_ = 100          # spatial resolution; needs to be large to avoid Nyquist
T = 3.              # time in s; needs to be large for low TFs
fps = 100.          # temporal resolution; frames/second = Hz
D = 20.             # diffusion coefficient in arcmin**2/s
nT = int(T*fps)+1
sfs = fs1d[1::]     # select SFs
nReps = 50          # number of repetitions to average over

# Compute drift gain
gPow = np.zeros([nT, len(sfs)])
center = int(stimSize/2 * ppd_)                                     # grating center in px of the simulation grid
for r in range(nReps):
    for fi, f in enumerate(sfs):
        phiRd = np.random.random(1) * 2 * np.pi                     # randomize phase
        g = create_grating(f, phi=phiRd, ppd=ppd_, sSize=stimSize)  # create grating

        _, path = create_drift(T, fps, ppd_, D/(60.**2.))           # create drift path
        # The grating is sampled with ppd_, so its center must be computed with ppd_ as well
        path += center                                              # shift to grating center
        lum = g[path[0,:], path[1,:]]                               # luminance over time for 1px

        gPow[:,fi] += compute_power(lum) / nReps                    # compute spatiotemporal power
gPow = gPow[int(nT/2)+1::,:]                                        # remove negative and zero TFs
gPow1d = 10. ** np.log10(gPow).mean(0)                              # average over TFs (mean of the log-power)

In [None]:
# TF axis for plotting: entries after the center of the shifted frequency axis
ft = np.fft.fftshift(np.fft.fftfreq(nT, d=1./fps))[nT // 2 + 1:]
# Power in dB relative to the maximum
gPowDB = 10. * np.log10(gPow / gPow.max())

plt.figure(figsize=(7, 2))

# Left: spatiotemporal power
plt.subplot(121, aspect=0.8)
plt.pcolormesh(sfs, ft, gPowDB, cmap='hot', vmin=-60, vmax=0., shading='auto')
plt.yscale('log')
plt.ylim(0.25)
plt.yticks([1, 10], [1, 10])
plt.ylabel("TF (Hz)")
plt.xscale('log')
plt.xlim(0.25)
plt.xticks([1, 10], [1, 10])
plt.xlabel("SF (cpd)")
plt.title("Spatiotemporal power")

# Right: power averaged over TFs
plt.subplot(122)
plt.plot(sfs, gPowDB.mean(0), '.-')
plt.xscale("log")
plt.xticks([1, 10], [1, 10])
plt.xlabel("SF (cpd)")
plt.ylabel("Power (db)")
plt.title("Drift gain");

## Simulate drift gain

To simulate the effect of drift for our purposes, we want to fit a function that mimics this drift gain across all SFs.

As mentioned before, we can emulate drift gain as $f^2$ for SFs up to about 5cpd.
However, for larger SFs, the curve flattens out.
To capture this behavior, we fit a sigmoid function to this part of the curve.

Sigmoid function:
$\mathrm{sig}(f) = A + \frac{K - A}{1 + Q \exp(-bf)}$

In [None]:
from scipy.optimize import minimize

# Rescale the simulated drift gain to the range [0, 1] before fitting
gPow1d = normRange(gPow1d, vmin=0, vmax=1)

def sigmoid(x, p=(1, 1, 1, 1)):
    """Generalized sigmoid ``A + (K - A) / (1 + Q * exp(-b * x))``.

    Parameters
    ----------
    x : scalar or numpy.ndarray
        Points at which the function is evaluated.
    p : sequence of four numbers
        Parameters in the order ``(A, Q, b, K)``. The default is an immutable
        tuple so that it cannot be altered between calls.
    """
    A, Q, b, K = p[0], p[1], p[2], p[3]
    return A + (K - A) / (1 + Q * np.exp(-b * x))

def optimizeSig(p):
    """Loss for the sigmoid fit: summed absolute deviation from the simulated drift gain."""
    fitted = sigmoid(fs1d[1::], p)
    deviation = np.abs(gPow1d - fitted)
    return deviation.sum()

# Fit the sigmoid to the simulated drift gain
res = minimize(optimizeSig, [1, 1, 1, 1])
sig = sigmoid(fs1d, res.x)

# Replace the lower part (below 5cpd) by f**2, scaled so that it ends at the sigmoid value
fswitch = 5.
idx = np.abs(fs1d - fswitch).argmin()
f2 = normRange(fs1d[0:idx]**2, 0, sig[idx-1])
driftGain1d = np.concatenate((f2, sig[idx:]))

# Report fit quality and parameters
for label, value in (
    ("Best fit of sig (loss):\t", res.fun),
    ("Best params: \t\t", res.x),
    ("Best fit total (loss): \t", np.abs(gPow1d - driftGain1d[1:]).sum()),
    ("Divide f2-part by:\t", fs1d[idx]**2 / sig[idx]),
):
    print(label, value)

# Plot simulated data against the fitted function
plt.figure(figsize=(3, 3))
plt.plot(sfs, gPow1d, '.', label="simulated data")
plt.plot(fs1d, driftGain1d, '-', linewidth=2.5, label="fitted function")
plt.legend()
plt.ylim(0)
plt.show()

## Drift gain in 2d

Because of the aforementioned issues, we will perform our simulations in 2d rather than 1d.
Hence, we need the previous 1d-drift-gain-vector in 2d.

In [None]:
# Reproduce the 1d drift gain in 2d
fRadial = np.sqrt(fx**2. + fy**2.)               # 2d array with radial SFs
sig = sigmoid(fRadial, res.x)                    # 2d sigmoid
f2 = np.minimum(fRadial, fswitch)                # cap SFs at fswitch
mask = (f2 == fswitch).astype(int)               # 1 where the sigmoid part is used, 0 for the f**2 part
f2 = normRange(f2**2., 0, sig[mask == 1].min())  # f**2 part, scaled from 0 to the smallest sigmoid value
driftGain = np.where(mask, sig, f2)              # join sigmoid and f**2 parts

# Compare a horizontal cut through the 2d gain with the 1d version
plt.figure(figsize=(8, 3))
plt.subplot(121)
plt.imshow(driftGain)
plt.colorbar()
plt.subplot(122)
plt.plot(fs1d, driftGain[nX // 2, nX // 2:], label="2d cut")
plt.plot(fs1d, driftGain1d, '--', label="1d")
plt.legend();

## Visual for manuscript

In the following, we will repeat the previous steps to create a nice visualization for the manuscript.
For this, we increase / decrease resolution where necessary to generate smooth spectra despite plotting in log-scale.
Apart from the visualization, we will not use any of the variables that we create along the way.

In [None]:
# Set to True to (re-)create the manuscript figure; the simulation in Part II is slow
createVisual = False
if createVisual:
    plt.figure(figsize=(8,6))

    # Part I: Gratings + drift
    sfs_ = [.5, 3, 9]; fps_ = 1000; ppd_ = 200; sSize_ = 1.
    _, path = create_drift(T=0.2, pps=fps_, ppd=ppd_, D=D/(60.**2.))
    path += int(sSize_/2 * ppd_)                       # shift to grating center
    gAll = np.zeros([int(sSize_*ppd_), 3*int(sSize_*ppd_)])

    for fi, f in enumerate(sfs_):
        g = create_grating(f, sSize=sSize_, ppd=ppd_)
        gAll[:, int(sSize_*ppd_*fi):int(sSize_*ppd_*(fi+1))] = g   # gratings side by side
        lt = g[path[0,:], path[1,:]]                               # luminance along the drift path
        plt.subplot(3,2,1); plt.plot(path[0, :]/ppd_+fi*sSize_, path[1, :]/ppd_, linewidth=.5, color=eCol[fi])
        plt.subplot(3,2,3); plt.plot(lt - lt[0], color=eCol[fi]);
        plt.xticks([]); plt.yticks([]); plt.xlabel("Time (s)"); plt.ylabel("Luminance")
    plt.subplot(3,2,1); plt.axis("off"); plt.imshow(gAll, cmap="gray", extent=[0, sSize_*3, 0, sSize_])

    # Part II: Spatiotemporal power
    sfs_=np.logspace(-2.5,0,30)*30; ppd_=64; sSize_=5.; T_=5; nR=500
    gPow_ = np.zeros([int(T_*fps+1), len(sfs_)])
    for r in range(nR):
        for fi, f in enumerate(sfs_):
            g = create_grating(f, phi=np.random.random(1)*2*np.pi, ppd=ppd_, sSize=sSize_)
            _, path = create_drift(T_, fps, ppd_, D/(60.**2.))
            # The grating is sampled with ppd_, so its center must be computed with ppd_ as well
            path += int(sSize_/2 * ppd_)
            lt = g[path[0,:], path[1,:]]
            gPow_[:,fi] += compute_power(lt) / nR
    gPow_ = gPow_[int(T_*fps/2)+1::,:]                 # remove negative and zero TFs
    gPow1d_ = 10. ** np.log10(gPow_).mean(0)
    gPowDB_ = 10. * np.log10(gPow_ / gPow_.max())

    # Fitted drift gain (f**2 part below 5 cpd, sigmoid above), rescaled for plotting
    idx_ = np.argmin(np.abs(sfs_ - 5))
    sig_ = sigmoid(sfs_, res.x)
    f2_ = normRange(sfs_[0:idx_]**2, 0, sig_[idx_-1])
    gainPlot = normRange(np.concatenate((f2_, sig_[idx_::])), gPow_.min(), gPow_.max())
    gainPlot = normRange(10.*np.log10(gainPlot), -80, -19) # not ideal but good enough to have same range in plot

    ft = np.fft.fftshift(np.fft.fftfreq(int(T_*fps+1), d=1./fps))[int(T_*fps/2)+1::]
    plt.subplot(3,2,(2,4)); plt.pcolormesh(sfs_, ft, gPowDB_, cmap='hot', vmin=-55, vmax=0., shading='auto')
    plt.yscale('log'); plt.ylim(0.5,50); plt.yticks([1, 10], [1, 10]); #plt.ylabel("TF (Hz)")
    plt.xscale('log'); plt.xlim(0.5,30); plt.xticks([1, 10], [1, 10]); #plt.xlabel("SF (cpd)")

    # Temporal power at the SFs closest to 9, 3 and 0.5 cpd
    plt.subplot(3,2,5); plt.xlabel("TF (Hz)"); plt.xlim(0.5,50); plt.xscale("log")
    plt.plot(ft, gPowDB_[:, np.argmin(np.abs(sfs_-9))], '-')
    plt.plot(ft, gPowDB_[:, np.argmin(np.abs(sfs_-3))], '-')
    plt.plot(ft, gPowDB_[:, np.argmin(np.abs(sfs_-.5))], '-'); plt.xticks([1, 10], [1, 10]);

    # Power averaged over TFs together with the fitted drift gain
    plt.subplot(3,2,6); plt.plot(sfs_, gPowDB_.mean(0), 'k.');
    plt.plot(sfs_, gainPlot, 'k-'); plt.xlim(0.5,30); plt.ylim(-50);
    plt.xscale("log"); plt.xticks([1, 10], [1, 10]); plt.xlabel("SF (cpd)");
    plt.savefig('driftgain.png', dpi=300)

# Comparison of static + dynamic power

In the following, we want to investigate whether we can account better for our empirical edge sensitivity data, if we take into account how drift redistributes stimulus power during fixations as we have seen in the previous steps.

## Generate dynamic stimulus spectra

In the next step, we apply the drift gain to the (static) power spectra for all our edges and noises individually.

In [None]:
# Dynamic power spectra of the noise masks: static power weighted by the drift gain
noisePowsDyn = [staticPow * driftGain for staticPow in noisePows]

plt.figure(figsize=(16, 2))
for ni, n in enumerate(noisePowsDyn):
    plt.subplot(1, 7, ni+1)
    plt.imshow(n)
    plt.title(noiseNames[ni])
    plt.axis("off")

In [None]:
# Dynamic power spectra of the edges: static power weighted by the drift gain
edgePowsDyn = [staticPow * driftGain for staticPow in edgePows]

plt.figure(figsize=(8, 2))
for ei, e in enumerate(edgePowsDyn):
    plt.subplot(1, 3, ei+1)
    plt.imshow(e)
    plt.title(edgeNames[ei])
    plt.axis("off")

In [None]:
# Horizontal cuts through the center of the dynamic spectra
nHalf = nX // 2
edgePowsDyn1d = [e[nHalf, nHalf:] for e in edgePowsDyn]
noisePowsDyn1d = [n[nHalf, nHalf:] for n in noisePowsDyn[1:]]  # skip no-noise

# Rescale each dynamic cut to the mean of its static counterpart (the scale is arbitrary)
edgePowsDyn1dm = [
    e / e.mean() * edgePows1d[ei].mean() for ei, e in enumerate(edgePowsDyn1d)
]
noisePowsDyn1dm = [
    n / n.mean() * noisePows1d[ni].mean() for ni, n in enumerate(noisePowsDyn1d)
]

In [None]:
# Edge power: static (dotted) vs. dynamic (dashed), one panel per edge in reversed order
fig, axes = plt.subplots(1, 3, figsize=(3.3, 1), sharex=True, sharey=True)
for ei, e in enumerate(edgePowsDyn1dm[::-1]):
    ax = axes[ei]
    ax.plot(fs1d, edgePows1d[-ei-1], ':', color=eCol[ei], label=edgeNames[-ei-1])
    ax.plot(fs1d, e, '--', color=eCol[ei])
axes[0].set(**axParams, ylabel="Power (dB)")
#plt.savefig('edgespectra.png', dpi=300)

# Noise power: static (dotted) vs. dynamic (dashed), one panel per noise mask
fig, axes = plt.subplots(1, 6, figsize=(6.8, 1), sharex=True, sharey=True)
for ni, n in enumerate(noisePowsDyn1dm):
    ax = axes[ni]
    ax.plot(fs1d, noisePows1d[ni], ':', color=noiseCols[ni], label=noiseNames_[ni])
    ax.plot(fs1d, n, '--', color=noiseCols[ni])
axes[0].set(**axParams, ylabel="Power (dB)");
#plt.savefig('noisespectra.png', dpi=300)

## Weight responses with CSF

We will weight the "responses" to the power spectra by the CSF.

In [None]:
# Radial spatial frequency of every point in the 2d spectra
fRadial = np.sqrt(fy**2. + fx**2.)

# CSF based on Kelly (1979), evaluated at tf=2.5
csfK = csf_kelly(fRadial, tf=2.5)

# CSF based on castleCSF (Ashraf et al., 2024) imported from Matlab
mat_contents = sio.loadmat("castleCSF_schmittwilken2024.mat")
csfC, matSF = mat_contents["csf"], mat_contents["fd"]

# Are we using the same SFs? Prints the unique differences between both SF grids
print(np.unique(matSF - fRadial))

In [None]:
plt.figure(figsize=(8, 2))

# 2d CSFs
plt.subplot(131)
plt.imshow(csfK, extent=fextent, cmap="coolwarm")
plt.title("Kelly-CSF")
plt.subplot(132)
plt.imshow(csfC, extent=fextent, cmap="coolwarm")
plt.title("castle-CSF")

# Horizontal cuts, each normalized to its own maximum
plt.subplot(133)
nHalf = nX // 2
plt.plot(fs1d, csfK[nHalf, nHalf:] / csfK.max(), "-", label="Kelly-CSF")
plt.plot(fs1d, csfC[nHalf, nHalf:] / csfC.max(), "-", label="castle-CSF")
plt.legend();
#plt.savefig('csf_profiles.png', dpi=300)

In [None]:
# Select the CSF model used as weight: "kelly" uses csfK, any other value uses csfC ("castle")
csfType = "kelly"

if csfType == "kelly":
    weight = csfK
else:
    weight = csfC

## Compute expected interference (via correlation)

Based on the idea of channel-specific interference, we assume that the more similar the edge spectra and noise spectra are (i.e. the more they occupy the same channels), the stronger the masking effect.

We quantify this by simply computing a correlation between each edge spectrum and each noise spectrum, and compare the resulting similarity pattern with the empirical data.

In [None]:
# "corr" correlates the weighted spectra; any other value multiplies them and sums the product
interference = "mult"

if interference == "corr":
    def _compare(a, b):
        """Pearson correlation between two flattened spectra."""
        return np.corrcoef(a.flatten(), b.flatten())[0, 1]
else:
    def _compare(a, b):
        """Sum of the pointwise product of two spectra."""
        return (a * b).sum()

# Rows: edges; columns: noise masks without the no-noise condition
corrSpt = np.zeros([nEdges, nNoises-1])   # static (CSF-weighted) spectra
corrAct = np.zeros([nEdges, nNoises-1])   # dynamic (CSF- and drift-gain-weighted) spectra
for ei, e in enumerate(edgePows):
    eStatic = e * weight
    eDynamic = eStatic * driftGain
    for ni, n in enumerate(noisePows[1:]):
        nStatic = n * weight
        nDynamic = nStatic * driftGain
        corrSpt[ei, ni] = _compare(eStatic, nStatic)
        corrAct[ei, ni] = _compare(eDynamic, nDynamic)

In [None]:
# Rescale the predictions of every edge to the range of its empirical thresholds for comparability
empThreshN = np.zeros(empThresh.shape)
for ei in range(nEdges):
    lowest, highest = empThresh[ei, :].min(), empThresh[ei, :].max()
    corrSpt[ei, :] = normRange(corrSpt[ei, :], lowest, highest)
    corrAct[ei, :] = normRange(corrAct[ei, :], lowest, highest)

In [None]:
# Natural logarithm of the predicted and the empirical thresholds
corrSptLog, corrActLog, empThreshLog = (
    np.log(values) for values in (corrSpt, corrAct, empThresh)
)

In [None]:
# Log-thresholds: empirical (left), static prediction (middle), dynamic prediction (right);
# every panel contains one line per edge
fig, axes = plt.subplots(1, 3, figsize=(6, 1), sharex=True, sharey=True)
fig.subplots_adjust(wspace=0.05)
for ei in range(nEdges):
    axes[0].plot(noiseNames[1::], empThreshLog[-ei-1, :], '.-', color=eCol[ei])
    axes[1].plot(noiseNames[1::], corrSptLog[-ei-1, :], '.:', color=eCol[ei])
    axes[2].plot(noiseNames[1::], corrActLog[-ei-1, :], '.--', color=eCol[ei])
    axes[ei].grid(True, color=[0.9,]*3); axes[ei].set(xticklabels=[], ylim=(-7,-3))
axes[0].set(ylabel="Thresholds (log)");
#plt.savefig('interference.png', dpi=300)

# Compute deviation from empirical data
fig, axes = plt.subplots(1, 3, figsize=(6, 1), sharex=True, sharey=True)
fig.subplots_adjust(wspace=0.05)
for ei in range(nEdges):
    axes[ei].axhline(0, color="k", linestyle="-", linewidth=0.5)
    axes[1].plot(noiseNames[1::], corrSptLog[-ei-1, :]-empThreshLog[-ei-1, :], '.:', color=eCol[ei])
    axes[2].plot(noiseNames[1::], corrActLog[-ei-1, :]-empThreshLog[-ei-1, :], '.--', color=eCol[ei])
    axes[ei].grid(True, color=[0.9,]*3); axes[ei].set(xticklabels=[], yticks=(-2,0,2), ylim=(-2.5,2.5))
# Raw string: "\D" is not a valid escape sequence in a regular string literal
axes[0].set(ylabel=r"${\Delta}$ threshold");
#plt.savefig('deviation.png', dpi=300)

print("Mean deviation static:\t", np.abs(corrSptLog - empThreshLog).mean())
print("Mean deviation dynamic:\t", np.abs(corrActLog - empThreshLog).mean())

# Results of mechanistic modeling

## Multi-scale models

In [None]:
def getModelThresh(rf):
    """Read model thresholds from the psychometric-curve pickle at path ``rf``.

    The pickle must hold a dict with 3d arrays under "pcs" and "contrasts".
    The first axis of both is reversed and the first entry of the second axis
    is dropped. Along the last axis, the contrast whose "pcs" value is closest
    to 0.75 is returned; the last axis is kept with length 1.
    """
    with open(rf, 'rb') as handle:
        moData = pickle.load(handle)
    pcs = moData["pcs"][::-1, 1:, :]
    contrasts = moData["contrasts"][::-1, 1:, :]
    nearest = np.argmin(np.abs(pcs - 0.75), axis=2, keepdims=True)
    return np.take_along_axis(contrasts, nearest, axis=2)

try:
    # Load psicurve data from pickles
    moThreshSp = np.log(getModelThresh('../mechanistic_models/results/spatial_multi_psicurve.pickle'))
    moThreshAc = np.log(getModelThresh('../mechanistic_models/results/active_multi_psicurve.pickle'))
except FileNotFoundError:
    # Only a missing results file is expected here; any other error should surface
    print("No results file found")
else:
    # Log-thresholds: empirical (left), spatial model (middle), active model (right)
    fig, axes = plt.subplots(1, 3, figsize=(6, 1), sharex=True, sharey=True)
    fig.subplots_adjust(wspace=0.05)
    for ei in range(nEdges):
        axes[0].plot(noiseNames[1::], empThreshLog[-ei-1, :], '.-', color=eCol[ei])
        axes[1].plot(noiseNames[1::], moThreshSp[-ei-1, :, 0], '.:', color=eCol[ei])
        axes[2].plot(noiseNames[1::], moThreshAc[-ei-1, :, 0], '.--', color=eCol[ei])
        axes[ei].grid(True, color=[0.9,]*3); axes[ei].set(xticklabels=[], ylim=(-7,-3))
    axes[0].set(ylabel="Thresholds (log)");
    #plt.savefig('model_thresholds.png', dpi=300)

    # Compute deviation from empirical data
    fig, axes = plt.subplots(1, 3, figsize=(6, 1), sharex=True, sharey=True)
    fig.subplots_adjust(wspace=0.05)
    for ei in range(nEdges):
        axes[ei].axhline(0, color="k", linestyle="-", linewidth=0.5)
        axes[1].plot(noiseNames[1::], moThreshSp[-ei-1, :,0]-empThreshLog[-ei-1, :], '.:', color=eCol[ei])
        axes[2].plot(noiseNames[1::], moThreshAc[-ei-1, :,0]-empThreshLog[-ei-1, :], '.--', color=eCol[ei])
        axes[ei].grid(True, color=[0.9,]*3); axes[ei].set(xticklabels=[], yticks=(-1,0,1), ylim=(-1.1,1.1))
    # Raw string: "\D" is not a valid escape sequence in a regular string literal
    axes[0].set(ylabel=r"${\Delta}$ threshold");
    #plt.savefig('model_deviation.png', dpi=300)

    print("Mean deviation static:\t", np.abs(moThreshSp[:,:,0] - empThreshLog).mean())
    print("Mean deviation dynamic:\t", np.abs(moThreshAc[:,:,0] - empThreshLog).mean())

## Single-scale models

In [None]:
try:
    # Load psicurve data from pickles
    moThreshSp = np.log(getModelThresh('../mechanistic_models/results/spatial_single_psicurve.pickle'))
    moThreshAc = np.log(getModelThresh('../mechanistic_models/results/active_single_psicurve.pickle'))
except FileNotFoundError:
    # Only a missing results file is expected here; any other error should surface
    print("No results file found")
else:
    # Log-thresholds: empirical (left), spatial model (middle), active model (right)
    fig, axes = plt.subplots(1, 3, figsize=(6, 1), sharex=True, sharey=True)
    fig.subplots_adjust(wspace=0.05)
    for ei in range(nEdges):
        axes[0].plot(noiseNames[1::], empThreshLog[-ei-1, :], '.-', color=eCol[ei])
        axes[1].plot(noiseNames[1::], moThreshSp[-ei-1, :, 0], '.:', color=eCol[ei])
        axes[2].plot(noiseNames[1::], moThreshAc[-ei-1, :, 0], '.--', color=eCol[ei])
        axes[ei].grid(True, color=[0.9,]*3); axes[ei].set(xticklabels=[], ylim=(-7,-3))
    axes[0].set(ylabel="Thresholds (log)");
    #plt.savefig('single_thresholds.png', dpi=300)

    # Compute deviation from empirical data
    fig, axes = plt.subplots(1, 3, figsize=(6, 1), sharex=True, sharey=True)
    fig.subplots_adjust(wspace=0.05)
    for ei in range(nEdges):
        axes[ei].axhline(0, color="k", linestyle="-", linewidth=0.5)
        axes[1].plot(noiseNames[1::], moThreshSp[-ei-1, :,0]-empThreshLog[-ei-1, :], '.:', color=eCol[ei])
        axes[2].plot(noiseNames[1::], moThreshAc[-ei-1, :,0]-empThreshLog[-ei-1, :], '.--', color=eCol[ei])
        axes[ei].grid(True, color=[0.9,]*3); axes[ei].set(xticklabels=[], yticks=(-1,0,1), ylim=(-1.1,1.1))
    # Raw string: "\D" is not a valid escape sequence in a regular string literal
    axes[0].set(ylabel=r"${\Delta}$ threshold");
    #plt.savefig('single_deviation.png', dpi=300)

    print("Mean deviation static:\t", np.abs(moThreshSp[:,:,0] - empThreshLog).mean())
    print("Mean deviation dynamic:\t", np.abs(moThreshAc[:,:,0] - empThreshLog).mean())