In [None]:
import numpy as np
import math
import scipy.io
import os
import pandas as pd
import matplotlib.pyplot as plt
from scipy.io import loadmat
from PIL import Image
import itertools
import matplotlib.pyplot as plt
import matplotlib.colors as clr
import ipywidgets as widgets

In [None]:
def create_spatial_grid(stim, stim_ecc, gridsize):
    """Build a square coordinate grid and stride-subsample the stimulus onto it.

    Parameters
    ----------
    stim : ndarray
        Stimulus whose first two axes are spatial (rows, columns); any further
        axes are kept unchanged.
    stim_ecc : float
        Half-width of the grid: coordinates run from -stim_ecc to +stim_ecc.
    gridsize : int
        Number of grid points along each spatial axis.

    Returns
    -------
    input_stim : ndarray
        ``stim`` subsampled with a fixed integer stride along each spatial
        axis; its first two axes have length ``gridsize``.
    X, Y : ndarray, shape (gridsize, gridsize)
        Coordinate grids. X varies along columns (increasing left to right);
        Y varies along rows and is flipped so the top row is +stim_ecc.

    Raises
    ------
    ValueError
        If integer-stride subsampling cannot yield exactly ``gridsize``
        samples per spatial axis for this stimulus size (callers flatten the
        result to ``gridsize * gridsize``).
    """
    ecc = np.linspace(-stim_ecc, stim_ecc, gridsize)
    Y, X = np.meshgrid(ecc, ecc, indexing='ij')
    Y = -1*Y  # flip so positive Y is the top row of the image

    # Smallest integer stride keeping at most `gridsize` samples per axis
    # (ceil division). A plain `n // gridsize + 1` over-strides when n is an
    # exact multiple of gridsize (e.g. 128 -> 43 rows instead of 64); for all
    # other sizes the two formulas agree. max(1, ...) avoids a zero step.
    row_step = max(1, -(-stim.shape[0] // X.shape[0]))
    col_step = max(1, -(-stim.shape[1] // X.shape[1]))
    input_stim = stim[::row_step, ::col_step]

    if input_stim.shape[:2] != X.shape:
        raise ValueError(
            "stimulus of spatial size {} cannot be stride-sampled to a {}x{} "
            "grid (got {})".format(stim.shape[:2], gridsize, gridsize,
                                   input_stim.shape[:2]))
    return input_stim, X, Y

In [None]:
def flat_gaussian_field(X, Y, x, y, sigma, normalize):
    """Evaluate an isotropic 2-D Gaussian on a coordinate grid, flattened.

    Parameters
    ----------
    X, Y : ndarray
        Coordinate grids of identical shape.
    x, y : float
        Centre of the Gaussian in X and Y coordinates.
    sigma : float
        Standard deviation: the field is exp(-d**2 / (2*sigma**2)).
    normalize : bool
        If true, scale the field so its values sum to 1.

    Returns
    -------
    ndarray, shape (X.size,)
        The field flattened in row-major order.
    """
    # Parenthesisation fixed: squared distances are taken around the centre
    # before dividing by the variance term.
    gaussian = np.exp(-((X - x)**2 + (Y - y)**2) / (2*sigma**2))
    if normalize:
        gaussian = gaussian / np.sum(gaussian)
    # ravel handles non-square grids (len(X)*len(X) only worked for square ones).
    return np.ravel(gaussian)

In [None]:
def _gaussian_pool(X, Y, centers_x, centers_y, sigmas, width_scale, signal):
    """Pool `signal` through one unit-sum Gaussian kernel per receptive field.

    For RF k the kernel is
    ``exp(-((X - centers_x[k])**2 + (Y - centers_y[k])**2) / (2*width_scale*sigmas[k])**2)``
    evaluated on the (X, Y) grid and divided by its sum; row k of the result
    is ``kernel_k.ravel() @ signal``.

    Parameters
    ----------
    X, Y : ndarray
        Coordinate grids (same shape) on which the kernels are evaluated.
    centers_x, centers_y, sigmas : ndarray, shape (n_rf,)
        Kernel centres (compared against X and Y respectively) and base widths.
    width_scale : float
        Multiplier applied to every width (1 for the RF itself, >1 for pools).
    signal : ndarray, shape (X.size, n_stim)
        Values to pool; rows are indexed like ``X.ravel()``.

    Returns
    -------
    ndarray, shape (n_rf, n_stim)
    """
    pooled = np.zeros((len(sigmas), signal.shape[1]))
    for k, (cx, cy, sigma) in enumerate(zip(centers_x, centers_y, sigmas)):
        kernel = np.exp(-((X - cx)**2 + (Y - cy)**2) / (2*width_scale*sigma)**2)
        kernel = kernel / np.sum(kernel)
        # One matrix-vector product covers every stimulus frame at once, so
        # each kernel is built once rather than once per frame.
        pooled[k] = np.ravel(kernel) @ signal
    return pooled


def simulate_normalization_model(stimpath, stim_ecc, attx0, atty0, attsd, attgain, sigmaNorm, gridsize):
    """Simulate an attention-modulated divisive normalization model.

    For each stimulus frame:
      1. stimulus drive  = unit-sum Gaussian RF (width sigma) applied to the stimulus
      2. numerator       = stimulus drive * attention field
      3. suppression     = numerator pooled by a Gaussian of width 1.5*sigma
      4. response        = numerator / (suppression + sigmaNorm)
      5. summed response = response pooled by a Gaussian of width 2.2*sigma
    where sigma = 0.05 + 0.2 * (distance of the RF centre from the origin).

    Parameters
    ----------
    stimpath : str
        Directory containing ``stim.mat`` with a 3-D variable ``stim``; at
        most the first 48 entries along its third axis are used.
    stim_ecc : float
        Half-width of the coordinate grid (passed to create_spatial_grid).
    attx0, atty0 : float
        Centre of the attention field on the X and Y grids.
    attsd : float
        Width parameter of the attention field.
    attgain : float
        Attention gain; the field is ``attgain * gaussian + 1``.
    sigmaNorm : float
        Constant added to the suppressive drive in the denominator.
    gridsize : int
        Number of grid points per spatial axis.

    Returns
    -------
    predneuralweights : ndarray, shape (gridsize*gridsize, n_stim)
        Normalized population response, one row per RF.
    spsummedresponse : ndarray, shape (gridsize*gridsize, n_stim)
        ``predneuralweights`` after spatial summation.

    Notes
    -----
    Every Gaussian here divides by ``(2*w)**2`` (i.e. 4*w**2), not the
    textbook ``2*w**2``. NOTE(review): confirm this convention is intended.
    """
    # Build the file path directly: os.chdir() changed the kernel's working
    # directory as a side effect, and a relative stimpath would then resolve
    # against a different directory on every later call.
    stim_file = scipy.io.loadmat(os.path.join(stimpath, 'stim.mat'))
    stimorig = stim_file['stim'][:, :, 0:48]

    input_stim, X, Y = create_spatial_grid(stimorig, stim_ecc, gridsize)
    n_pixels = gridsize * gridsize
    n_stim = input_stim.shape[2]

    # Multiplicative attention field: 1 everywhere plus a Gaussian bump of height attgain.
    attfield = np.exp(-((X - attx0)**2 + (Y - atty0)**2) / (2*attsd)**2)
    attfield = np.reshape(attgain*attfield + 1, n_pixels)

    # RF centres on an n_centers x n_centers lattice spanning +/- stim_ecc/sqrt(2).
    # Flat RF index k = i*n_centers + j takes its Y-centre from x_lin[i] and
    # its X-centre from y_lin[j] (the x/y names are swapped relative to the
    # X/Y grids; this pairing is kept as-is), so after reshaping to
    # (gridsize, gridsize) RF k sits at row i, column j.
    n_centers = input_stim.shape[1]
    coord = np.sqrt((stim_ecc**2)/2)
    x_lin = -1*np.linspace(-coord, coord, n_centers)  # double check these signs later
    y_lin = np.linspace(-coord, coord, n_centers)
    rf_center_Y = np.repeat(x_lin, len(y_lin))
    rf_center_X = np.tile(y_lin, len(x_lin))
    rf_sigma = 0.05 + 0.2*np.sqrt(rf_center_Y**2 + rf_center_X**2)

    # 1. Stimulus drive for every RF and frame.
    stim_flat = np.reshape(input_stim, (n_pixels, n_stim))
    stimdrive = _gaussian_pool(X, Y, rf_center_X, rf_center_Y, rf_sigma, 1, stim_flat)

    # 2. Attention gain. Steps 3 and 5 pool over the flat RF index as if it
    # were a grid-pixel index, which requires n_centers == gridsize.
    numerator = stimdrive * attfield[:, np.newaxis]

    # 3-4. Suppressive pool and divisive normalization.
    surroundresponse = _gaussian_pool(X, Y, rf_center_X, rf_center_Y, rf_sigma, 1.5, numerator)
    predneuralweights = numerator / (surroundresponse + sigmaNorm)

    # 5. Spatial summation of the normalized response.
    spsummedresponse = _gaussian_pool(X, Y, rf_center_X, rf_center_Y, rf_sigma, 2.2, predneuralweights)

    return predneuralweights, spsummedresponse

In [None]:
# Model configuration.
# The stimulus directory defaults to the author's local path; set the
# STIMPATH environment variable to run the notebook on another machine.
stimpath = os.environ.get(
    'STIMPATH', '/Users/et2160/Desktop/pytorch_practice/pytorch_practice/stimfiles')
stim_ecc    = 10    # grid spans [-stim_ecc, stim_ecc] on both axes
attgain     = 4     # peak gain of the attention field (field = attgain*gaussian + 1)
attx0       = 0     # attention centre on the X grid
atty0       = 5     # attention centre on the Y grid (positive Y = upper rows)
attsd       = 1     # attention field width
sigmaNorm   = 0.01  # constant added to the suppressive drive
gridsize    = 64    # grid points per spatial axis

# Neutral baseline: a broader (sd 2.3), weaker (gain 2) attention field at the origin.
baseline_attx0   = 0
baseline_atty0   = 0
baseline_attsd   = 2.3
baseline_attgain = 2

# get the baseline estimates
baseline_neural, baseline_sptsumm = simulate_normalization_model(
    stimpath, stim_ecc, baseline_attx0, baseline_atty0, baseline_attsd,
    baseline_attgain, sigmaNorm, gridsize)

# focused attention at (attx0, atty0)
predneuralweights, spsummedresponse = simulate_normalization_model(
    stimpath, stim_ecc, attx0, atty0, attsd, attgain, sigmaNorm, gridsize)

In [None]:
# Visualize the attention-condition population responses, one stimulus frame at a time.
clims = [0,30]  # NOTE(review): not currently passed to imshow (no vmin/vmax)
spsummedresponse_image = np.reshape(spsummedresponse, (gridsize, gridsize, spsummedresponse.shape[1]))
predneuralweights_image = np.reshape(predneuralweights, (gridsize, gridsize, predneuralweights.shape[1]))

# The slider selects the stimulus frame (third axis of the image stacks).
@widgets.interact(stimuluslocation=widgets.IntSlider(min=0, max=predneuralweights_image.shape[2]-1, step=1, value=0))
def plot_population_response(stimuluslocation):
    """Show the normalized and spatially summed responses side by side for one frame."""
    fig, (ax_neural, ax_summed) = plt.subplots(1, 2, figsize=[8, 4])
    ax_neural.imshow(predneuralweights_image[:, :, stimuluslocation], cmap='magma')
    ax_neural.set_title('normalized response (frame {})'.format(stimuluslocation))
    ax_summed.imshow(spsummedresponse_image[:, :, stimuluslocation], cmap='magma')
    ax_summed.set_title('spatially summed response')
    plt.show()

In [None]:
# Flat RF index to inspect first (row-major over the gridsize x gridsize grid).
# Hand picked to sit right beneath the upper attention target location.
# NOTE(review): the index is only meaningful for the gridsize it was picked with.
ind = 519

# The slider selects which RF's response to plot; it starts at the hand-picked
# index (previously value=0 lay below min=500, so `ind` was never used).
@widgets.interact(voxel=widgets.IntSlider(min=500, max=564, step=1, value=ind))
def plot_voxel_timecourse(voxel):
    """Plot one RF's spatially summed response across stimuli: attend-up vs neutral."""
    fig, ax = plt.subplots(1, 1, figsize=[4, 4])
    ax.plot(spsummedresponse[voxel, :], color='purple')
    ax.plot(baseline_sptsumm[voxel, :], color='black')
    ax.set_xlabel('stimulus index')
    ax.set_ylabel('spatially summed response')
    ax.set_title('RF {}'.format(voxel))
    ax.legend(['attend up', 'neutral'])
    plt.show()

