# Run mechanistic models step-by-step

## Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
import pandas as pd
from scipy.signal import fftconvolve

sys.path.insert(1, '../mechanistic_models')
from functions import create_edge, create_noise, create_loggabors, create_isologgabors, \
    create_drift, apply_drift, add_padding, remove_padding, apply_filters, naka_rushton, \
    watson_tf, kelly_csf, zheng_tf, benardete_tf, compute_dprime

sys.path.insert(1, '../experimental_data')
from exp_params import stim_params as sparams

## Parameters


In [None]:
# Model variant switch, stored in mparams["tempBool"]. When True, the filtered stimulus is
# shifted along a drift trace and filtered in time (see apply_filters_ below); when False,
# only spatial filtering is applied.
tempBool: bool = True                  # True=active model. False=spatial model

In [None]:
# --- Log-Gabor filter bank ---
fos = [0.5, 3., 9.]              # center SFs of the log-Gabor filters
sigma_fo = 0.5945                # SF spread (presumably bandwidth), Schütt & Wichmann (2017)
sigma_angleo = 0.2965            # orientation spread, Schütt & Wichmann (2017)
nFilters = len(fos)              # one filter channel per center SF
isoSF = False                    # True -> isotropic (unoriented) spatial filters

# --- Trials, gain control and temporal settings ---
n_trials = 1                     # average performance over n-trials
gain = None                      # gain control
Nt = 40                          # number of time steps
dt = 0.005                       # step size (s)
D = 20. / 60.**2                 # drift diffusion constant
tempType = "watson"              # temporal filter: watson, zheng, benardete or kelly
collapseTime = True              # collapse temporal dimension before normalization

# --- Constants derived from the stimulus parameters ---
ppd = sparams["ppd"]             # pixel resolution (pixels per degree)
fac = int(ppd * 2)               # padding to avoid border artefacts
sparams["n_masks"] = n_trials    # reuse the same noise masks every time

## Read psychophysical data

In [None]:
# Pooled psychophysical data (space-separated text file)
df = pd.read_csv("../experimental_data/expdata_pooled.txt", sep=" ")

# Sorted unique noise and edge conditions present in the data
noise_conds, edge_conds = (np.unique(df[column]) for column in ("noise", "edge"))

## Create example stimulus

In [None]:
# Select example condition + contrast
# (first noise condition and second edge condition; np.unique returns them sorted)
n = noise_conds[0]; e = edge_conds[1]
print(n, e)

# Rows of the pooled data that belong to this noise/edge condition
df_cond = df[(df["noise"]==n) & (df["edge"]==e)]
ncorrect = df_cond["ncorrect"].to_numpy()
ntrials = df_cond["ntrials"].to_numpy()
# Only the first unique lambda value of this condition is kept; it is passed to
# compute_dprime further below.
lamb = np.unique(df_cond["lambda"].to_numpy())[0]

# Create edge and noise stimulus
# The edge uses the 5th contrast value of this condition multiplied by 10.
# NOTE(review): the x10 presumably only makes the edge visible in this example - confirm.
noise = create_noise(n, sparams)
edge = create_edge(df_cond["contrasts"].to_numpy()[4]*10, e, sparams)

# Show noise, edge and their sum side by side
plt.figure(figsize=(12, 3))
plt.subplot(131); plt.imshow(noise, cmap='gray'); plt.colorbar(), plt.title(n)
plt.subplot(132); plt.imshow(edge,  cmap='gray'); plt.colorbar(), plt.title(e)
plt.subplot(133); plt.imshow(edge+noise, cmap='gray'); plt.colorbar()
plt.show()

# Create spatiotemporal filters

## Spatial filters

In [None]:
# Spatial-frequency axes matching the stimulus grid. With d = 1/ppd the frequencies are in
# cycles per degree; fftshift puts DC in the centre.
nX = int(sparams["stim_size"]*ppd)
fs = np.fft.fftshift(np.fft.fftfreq(nX, d=1./ppd))
fx, fy = np.meshgrid(fs, fs)

# Isotropic or oriented (orientation 0) log-Gabor filter bank built from fos
loggabors = (create_isologgabors(fx, fy, fos, sigma_fo) if isoSF
             else create_loggabors(fx, fy, fos, sigma_fo, 0., sigma_angleo))

# Show each filter in its own panel
plt.figure(figsize=(12, 2))
for i in range(nFilters):
    plt.subplot(1, nFilters, i + 1)
    plt.imshow(loggabors[i], cmap="coolwarm")

## Temporal filter

In [None]:
# Create temporal filter(s) in the frequency domain, on an fftshift-ed temporal-frequency axis
tf = np.fft.fftshift(np.fft.fftfreq(Nt, d=dt))  # temporal frequencies (Hz), DC in the centre

if tempType == "watson":
    tempFilter = watson_tf(Nt, dt)                          # time-domain impulse response
    tempFilterP = np.fft.fftshift(np.fft.fft(tempFilter))   # its (complex) spectrum
elif tempType == "kelly":
    # One temporal filter per SF channel, so row i belongs to fos[i] (as indexed in apply_filters_)
    tempFilterP = np.array([kelly_csf(sfs=[fo,], tfs=tf) for fo in fos])
elif tempType == "zheng":
    tempFilterP = zheng_tf(Nt, dt)
elif tempType == "benardete":
    tempFilterP = benardete_tf(Nt, dt)
else:
    # Without this, a typo would silently reuse a stale tempFilterP or raise a NameError
    raise ValueError(f"Unknown tempType {tempType!r}; expected watson, zheng, benardete or kelly")

# Normalize the peak gain to 1 (for now). Use the magnitude: the spectrum may be complex, and
# dividing by the complex .max() (lexicographic order) would also rotate the filter's phase.
tempFilterP = tempFilterP / np.abs(tempFilterP).max()

# One line per filter; reshaping to (n, Nt) and transposing also handles kelly's stacked filters
plt.figure(figsize=(8,2))
plt.plot(tf, np.abs(tempFilterP).reshape(-1, Nt).T, '.-'); #plt.xlim(0,50)
plt.show()

# Create drift trace

In [None]:
# Create drift instance
# Arguments: duration (Nt-1)*dt, sampling rate 1/dt, ppd and diffusion constant D.
# Two traces are returned: a float version and an integer version (presumably rounded to whole
# pixels - confirm). Row 0 is the x trace and row 1 is the y trace. The integer version is the
# one stored in mparams["drift"] and used by apply_drift.
driftFloat, driftInt = create_drift(Nt*dt-dt, 1./dt, ppd, D)

plt.figure(figsize=(12, 2))
plt.subplot(121); plt.plot(driftFloat[0,:], '.'); plt.plot(driftInt[0,:])  # plot x
plt.subplot(122); plt.plot(driftFloat[1,:], '.'); plt.plot(driftInt[1,:])  # plot y
plt.show()

In [None]:
# Optional: Apply drift to stimulus and animate movie (uncomment)
import matplotlib.animation as animation

#stimVid = apply_drift(edge+noise, driftInt, edge.mean())
#stimVid.shape
#vid = stimVid

#fig, ax = plt.subplots(1, 2, figsize=(10, 4))
def animate(t):
    """Draw frame ``t`` of the drift movie; intended as the ``func`` of ``FuncAnimation``.

    Uses the notebook globals ``vid`` (movie with time on axis 2) and ``ax`` (two axes from
    ``plt.subplots(1, 2)``). Both are created by the commented-out lines in this cell.

    Parameters
    ----------
    t : int
        Frame index; the image shown is ``vid[:, :, t]``.
    """
    # NOTE(review): these global display settings are re-applied on every frame and also
    # affect figures created afterwards.
    plt.rcParams["animation.html"] = "jshtml"
    plt.rcParams['figure.dpi'] = 60
    plt.ioff()

    # Left: current frame. Right: time course of pixel (20, 20) up to AND including frame t,
    # so the trace stays in sync with the displayed frame.
    ax[0].imshow(vid[:, :, t], cmap='gray')
    ax[1].plot(np.arange(t + 1), vid[20, 20, :t + 1], 'k.')
    # Set limits on the trace axes explicitly, not on whichever axes happen to be current
    ax[1].set_xlim(0, vid.shape[2])
    ax[1].set_ylim(vid.min(), vid.max())

#animation.FuncAnimation(fig, animate, frames=vid.shape[2]) # show animation

# Run model

## Model parameters

In [None]:
mparams = dict(
    # Spatial filter bank
    n_filters=nFilters,
    fos=fos,
    sigma_fo=sigma_fo,
    sigma_angleo=sigma_angleo,
    loggabors=loggabors,
    # Padding and image size
    fac=fac,
    nX=nX,
    # Trials, gain control and noise (noiseVar is passed to compute_dprime)
    n_trials=n_trials,
    gain=gain,
    sameNoise=True,
    noiseVar=1.,
    # Temporal dynamics and drift
    Nt=Nt,
    dt=dt,
    D=D,
    tempType=tempType,
    tempFilterP=tempFilterP,
    tempBool=tempBool,
    collapseTime=collapseTime,
    drift=[driftInt],        # one drift trace per trial, selected by "trial"
    isoSF=isoSF,
    trial=0,
)

## Apply filters

In [None]:
# Run the filter stage on the full stimulus (edge + noise) and on the edge alone;
# both responses are compared later by the decoder.
out = apply_filters(edge+noise, mparams)
out2 = apply_filters(edge, mparams)

# Quick sanity check of the response range of the edge+noise response
print("Total mean activity:", out.mean())
print("Max:", out.max())
print("Min:", out.min())

In [None]:
# Plot activation of each channel (one panel per filter). When the temporal dimension is kept
# (active model without collapseTime), each channel is averaged over time first.
plt.figure(figsize=(15, 4))
for i in range(nFilters):
    channel = out[:, :, i]
    if tempBool and not collapseTime:
        channel = channel.mean(-1)  # average over the time axis
    # Use nFilters rather than a hard-coded 3 so the grid follows the length of fos
    plt.subplot(1, nFilters, i+1); plt.imshow(channel, cmap="coolwarm"); plt.colorbar()
plt.show()

## Apply Naka-Rushton

In [None]:
# Naka-parameters
# NOTE(review): alphas has three entries, matching nFilters here; presumably one value per
# filter channel - confirm against naka_rushton.
alphas=(3., 4., 1.); beta=1e-15; eta=1.2; kappa=6.

# Apply the same nonlinearity (and the gain setting from mparams) to the edge+noise response
# (out) and to the edge-only response (out2)
outNaka1 = naka_rushton(out, alphas, beta, eta, kappa, mparams["gain"])
outNaka2 = naka_rushton(out2, alphas, beta, eta, kappa, mparams["gain"])

In [None]:
# Plot activation of each channel after the Naka-Rushton stage (one panel per filter). When the
# temporal dimension is kept (active model without collapseTime), average over time first.
plt.figure(figsize=(15, 4))
for i in range(nFilters):
    channel = outNaka1[:, :, i]
    if tempBool and not collapseTime:
        channel = channel.mean(-1)  # average over the time axis
    # Use nFilters rather than a hard-coded 3 so the grid follows the length of fos
    plt.subplot(1, nFilters, i+1); plt.imshow(channel, cmap="coolwarm"); plt.colorbar()
plt.show()

## Decoding

Note that performance (for both humans and the models) depends on the noise instance.
That is why we average performance over many noise instances to predict performance; this step-by-step example evaluates only a single noise instance.

In [None]:
# Decode by comparing the Naka-Rushton responses to edge+noise (outNaka1) and to the edge alone
# (outNaka2). The call also uses lamb from the data and the noise variance from mparams.
# NOTE(review): the result is stored as `pc` and printed as "performance" although the function
# is compute_dprime; presumably it converts d' to proportion correct - confirm.
pc = compute_dprime(outNaka1, outNaka2, lamb, mparams["noiseVar"])  # dprime

print("Predicted performance:", pc)

# Extra: oriented vs unoriented filters

As we have shown in Schmittwilken & Maertens (2022), fixational eye movements (FEMs) obviate the need for orientation selectivity in edge extraction.
In the following, we illustrate this effect.

In the absence of FEMs, orientation-selective filters are needed to produce strong activations at the location of the edge in the visual input. In the presence of FEMs, however, orientation selectivity instead produces strong activations adjacent to the edge.

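Unoriented (isotropic) filters show the opposite pattern: without FEMs their strongest activations lie adjacent to the edge, whereas with FEMs they lie at the edge location.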

In [None]:
def apply_filters_(stim, mparams):
    """Filter a stimulus with the spatial (and optionally temporal) filter bank in ``mparams``.

    Simplified variant of ``apply_filters`` used to illustrate oriented vs. unoriented filters.

    Parameters
    ----------
    stim : 2-D array
        Input image. NOTE(review): presumably nX x nX - confirm.
    mparams : dict
        Model parameters. Always reads "fac", "isoSF", "nX", "tempBool", "n_filters" and
        "loggabors". If "tempBool" is True, also reads "Nt", "drift", "trial", "tempType"
        and "tempFilterP".

    Returns
    -------
    out : ndarray
        Shape (nX, nX, n_filters) if ``mparams["tempBool"]`` is False, otherwise
        (nX, nX, n_filters, Nt). A border of about 7.5% of nX on each side is set to zero.
    """
    fac = mparams["fac"]
    nX = mparams["nX"]
    n_filters = mparams["n_filters"]
    temporal = mparams["tempBool"]

    # Pad with the mean value to avoid border artefacts. Oriented (odd-symmetric) filters only
    # need padding along axis 1; isotropic filters are padded along both axes.
    stim = add_padding(stim, fac, stim.mean(), axis=1)
    if mparams["isoSF"]:
        stim = add_padding(stim, fac, stim.mean(), axis=0)

    # Further remove border artefacts through masking: zero a border of ~15% of nX in total.
    # fc is rounded up to an even number so it splits evenly between both sides.
    fc = int(nX * 0.15)
    fc += fc % 2
    mask = np.pad(np.ones([nX - fc, nX - fc]), (fc // 2, fc // 2))

    if temporal:
        out = np.zeros([nX, nX, n_filters, mparams["Nt"]])
        mask = np.expand_dims(mask, (-1, -2))
    else:
        out = np.zeros([nX, nX, n_filters])
        mask = np.expand_dims(mask, -1)

    for fil in range(n_filters):
        # Spatial filtering of the padded stimulus, then crop the padding again
        outTemp = fftconvolve(stim, mparams["loggabors"][fil], mode='same')
        outTemp = remove_padding(outTemp, fac, axis=1)
        if mparams["isoSF"]:
            outTemp = remove_padding(outTemp, fac, axis=0)

        if not temporal:
            out[:, :, fil] = outTemp
            continue

        # Temporal filtering (padding in time does not change output).
        # First shift the filtered image along the drift trace -> video with time on the last axis.
        outTemp = apply_drift(outTemp, mparams["drift"][mparams["trial"]], outTemp.mean())
        if mparams["tempType"] == "kelly":
            thisFilt = np.expand_dims(mparams["tempFilterP"][fil, :], (0, 1))  # per-channel filter
        else:
            thisFilt = np.expand_dims(mparams["tempFilterP"], (0, 1))
        # Filter in frequency space ("more robust"). The filter varies only along time, so a 1-D
        # FFT over the last axis gives the same result as the full N-D FFT without the extra
        # spatial transforms. tempFilterP is fftshift-ed, hence the (i)fftshift on that axis.
        spectrum = np.fft.fftshift(np.fft.fft(outTemp, axis=-1), axes=-1) * thisFilt
        out[:, :, fil, :] = np.real(np.fft.ifft(np.fft.ifftshift(spectrum, axes=-1), axis=-1))
    return out * mask

In [None]:
sfi = 0              # index of the filter channel shown in the figures below (centre SF fos[0])
cmap = "coolwarm"    # colormap for the filter and response maps

In [None]:
# Oriented log-Gabors
loggabors = create_loggabors(fx, fy, fos, sigma_fo, 0., sigma_angleo)
mparams["loggabors"] = loggabors; mparams["isoSF"] = False

# No drift: spatial filtering only -> output shape (nX, nX, n_filters)
mparams["tempBool"] = False
outOriNo = apply_filters_(edge, mparams)

# Drift: spatiotemporal filtering -> output shape (nX, nX, n_filters, Nt)
mparams["tempBool"] = True
outOriYes = apply_filters_(edge, mparams)

# Top row: edge, spatial filter, and horizontal response profiles through the image centre.
# Bottom right: absolute response maps (time-averaged in the drift condition).
plt.figure(figsize=(10,4))
plt.subplot(241); plt.imshow(edge, cmap="gray"); plt.title("Edge"); plt.axis("off")
plt.subplot(242); plt.imshow(loggabors[sfi], cmap=cmap); plt.title("Spatial filter"); plt.axis("off");
plt.subplot(243); plt.plot(outOriNo[int(nX/2),:,sfi]); plt.title("No drift"); plt.axis("off");
# Drift profile from the drift output: mean absolute response over time, matching panel 248
plt.subplot(244); plt.plot(np.abs(outOriYes[int(nX/2),:,sfi,:]).mean(-1)); plt.title("Drift"); plt.axis("off");
plt.subplot(247); plt.imshow(np.abs(outOriNo[:,:,sfi]), cmap=cmap); plt.axis("off"); #plt.colorbar()
plt.subplot(248); plt.imshow(np.abs(outOriYes[:,:,sfi,:]).mean(2), cmap=cmap); plt.axis("off"); #plt.colorbar();
#plt.savefig('oriented.png', dpi=300)

In [None]:
# Unoriented log-Gabors
loggabors = create_isologgabors(fx, fy, fos, sigma_fo)
mparams["loggabors"] = loggabors; mparams["isoSF"] = True

# No drift: spatial filtering only -> output shape (nX, nX, n_filters)
mparams["tempBool"] = False
outUnoriNo = apply_filters_(edge, mparams)

# Drift: spatiotemporal filtering -> output shape (nX, nX, n_filters, Nt)
mparams["tempBool"] = True
outUnoriYes = apply_filters_(edge, mparams)

# Top row: edge, spatial filter, and horizontal response profiles through the image centre.
# Bottom right: absolute response maps (time-averaged in the drift condition).
plt.figure(figsize=(10,4))
plt.subplot(241); plt.imshow(edge, cmap="gray"); plt.title("Edge"); plt.axis("off");
plt.subplot(242); plt.imshow(loggabors[sfi], cmap=cmap); plt.title("Spatial filter"); plt.axis("off");
plt.subplot(243); plt.plot(outUnoriNo[int(nX/2),:,sfi]); plt.title("No drift"); plt.axis("off");
# Drift profile from the drift output: mean absolute response over time, matching panel 248
plt.subplot(244); plt.plot(np.abs(outUnoriYes[int(nX/2),:,sfi,:]).mean(-1)); plt.title("Drift"); plt.axis("off");
plt.subplot(247); plt.imshow(np.abs(outUnoriNo[:,:,sfi]), cmap=cmap); plt.axis("off"); #plt.colorbar()
plt.subplot(248); plt.imshow(np.abs(outUnoriYes[:,:,sfi,:]).mean(2), cmap=cmap); plt.axis("off"); #plt.colorbar();
#plt.savefig('unoriented.png', dpi=300)