In [1]:
import numpy as np
import math
import scipy.io
import os
import matplotlib.pyplot as plt
from scipy.io import loadmat
import ipywidgets as widgets
import torch

In [None]:
def create_spatial_grid_torch(stim, stim_ecc, gridsize):
    """Build a square coordinate grid and resample the stimulus onto it.

    Parameters
    ----------
    stim : torch.Tensor
        Stimulus whose first two dimensions are spatial. Any trailing
        dimensions are carried along unchanged.
    stim_ecc : float
        Half-width of the grid; coordinates run from ``-stim_ecc`` to
        ``+stim_ecc`` along both axes.
    gridsize : int
        Number of grid points along each axis.

    Returns
    -------
    input_stim : torch.Tensor
        ``stim`` sampled at ``gridsize`` evenly spaced rows and columns, so
        its two leading dimensions always equal ``(gridsize, gridsize)``.
    X, Y : torch.Tensor
        ``(gridsize, gridsize)`` coordinate grids. With ``indexing="ij"``,
        ``X`` varies along dim 0 and ``Y`` (sign-flipped) along dim 1.
    """
    ecc = torch.linspace(-stim_ecc, stim_ecc, gridsize)
    # "ij" is what torch.meshgrid did implicitly; stating it keeps the layout
    # unchanged and silences the deprecation warning.
    # NOTE(review): np.meshgrid defaults to "xy" (the transpose of this) --
    # confirm "ij" is the intended orientation if this was ported from numpy.
    X, Y = torch.meshgrid(ecc, ecc, indexing="ij")
    Y = -1 * Y

    # Pick exactly `gridsize` evenly spaced samples per axis. The previous
    # fixed stride (shape // gridsize + 1) returned a stimulus that did not
    # match the grid: 1080 px -> 78 samples, and an 81 px input -> 41 samples.
    rows = torch.linspace(0, stim.shape[0] - 1, gridsize).round().long().tolist()
    cols = torch.linspace(0, stim.shape[1] - 1, gridsize).round().long().tolist()
    input_stim = stim[rows][:, cols]
    return input_stim, X, Y

In [None]:
def flat_gaussian_field_torch(X,Y,x,y,sigma,gain,normalize):
    """Evaluate an isotropic 2-D Gaussian on a grid and return it flattened.

    Parameters
    ----------
    X, Y : torch.Tensor
        Coordinate grids of identical shape.
    x, y : float or torch.Tensor
        Centre of the Gaussian, in the same coordinates as ``X`` / ``Y``.
    sigma : float or torch.Tensor
        Standard deviation of the Gaussian.
    gain : float or torch.Tensor
        Multiplicative amplitude. It must broadcast against ``X``.
    normalize : bool
        If true, the Gaussian is scaled to unit volume (sums to 1 over the
        grid) *before* the gain is applied, so the output sums to ``gain``.

    Returns
    -------
    torch.Tensor
        1-D tensor with ``X.numel()`` elements.
    """
    # Standard Gaussian exponent r^2 / (2 * sigma^2). The previous
    # (2 * sigma) ** 2 evaluated to 4 * sigma^2, i.e. a field sqrt(2) wider
    # than the requested standard deviation.
    gaussian = torch.exp(-((X - x) ** 2 + (Y - y) ** 2) / (2 * sigma ** 2))
    if normalize:  # normalise the field to unit volume before flattening it
        gaussian = gaussian / torch.sum(gaussian)
    # Apply the gain after normalising; multiplying first made it cancel out.
    return (gain * gaussian).flatten()

In [None]:
def simulate_prfs(n_voxels = 9, stim_ecc = 10):
    """Simulate a row of Gaussian pRFs lying on the line ``y = -x``.

    Centres are spaced evenly so that the two outermost pRFs sit at
    eccentricity ``stim_ecc``. Each pRF's size grows linearly with its
    eccentricity (``sigma = 0.05 + 0.2 * eccentricity``) and each gets a
    random integer gain in {1, 2, 3}.

    Gains are drawn from torch's global RNG; call ``torch.manual_seed``
    beforehand if the output has to be reproducible.

    Parameters
    ----------
    n_voxels : int
        Number of simulated pRFs.
    stim_ecc : float
        Eccentricity of the outermost pRF centres.

    Returns
    -------
    simulated_prfs : torch.Tensor
        ``(n_voxels, n_voxels ** 2)``; row ``i`` is pRF ``i`` rendered on the
        ``n_voxels x n_voxels`` grid spanned by the centre coordinates
        themselves, flattened.
    prf_parameters : numpy.ndarray
        ``(4, n_voxels)``; rows are x, y, sigma and gain.
    """
    # |x| = |y| such that sqrt(x^2 + y^2) == stim_ecc at the outermost centres.
    # A plain float keeps torch.linspace happy on versions that reject tensor
    # endpoints.
    coord = math.sqrt(stim_ecc ** 2 / 2)

    x = torch.linspace(-coord, coord, n_voxels)
    y = -1 * torch.linspace(-coord, coord, n_voxels)

    # pRF size scales linearly with eccentricity.
    sigma = 0.05 + 0.2 * torch.sqrt(x ** 2 + y ** 2)
    # assign a random gain value when simulating, not sure if the best way
    # (`gain` was never allocated before, which raised a NameError).
    gain = torch.randint(1, 4, (n_voxels,)).to(torch.float32)

    # "ij" is torch.meshgrid's historical default; stated explicitly to keep
    # the layout and avoid the deprecation warning.
    X, Y = torch.meshgrid(x, y, indexing="ij")
    simulated_prfs = torch.zeros((n_voxels, n_voxels * n_voxels))
    for rf in range(n_voxels):
        simulated_prfs[rf, :] = flat_gaussian_field_torch(
            X, Y, x[rf], y[rf], sigma[rf], gain[rf], True)

    # Sized from n_voxels (it used to be hard-coded to 9 columns).
    prf_parameters = np.zeros((4, n_voxels))
    prf_parameters[0, :] = x.numpy()
    prf_parameters[1, :] = y.numpy()
    prf_parameters[2, :] = sigma.numpy()
    prf_parameters[3, :] = gain.numpy()

    return simulated_prfs, prf_parameters

In [None]:
class AttModel(torch.nn.Module):
    """Attention-modulated pRF response model.

    For each voxel the model computes a stimulus drive (unit-volume Gaussian
    pRF dotted with the stimulus, scaled by a learnable per-voxel gain) and
    multiplies it by an attention weight that falls off as a Gaussian of the
    distance between the pRF centre and the attended location.

    The suppression stage is not applied yet: ``forward`` returns the
    attention-weighted stimulus drive only.
    """

    def __init__(self,  prf_parameters, n_voxels=9, gridsize=81, stim_ecc=10, 
                 attention_ctr=(0, 0)):
        """
        Parameters
        ----------
        prf_parameters : array-like, shape (>=3, n_voxels)
            Rows 0, 1, 2 are the pRF centre x, centre y and sigma per voxel.
            Further rows are ignored by the model.
        n_voxels : int
            Number of voxels; must equal the number of columns of
            ``prf_parameters``.
        gridsize : int
            Number of grid points per axis; stimuli passed to ``forward``
            must contain ``gridsize ** 2`` elements.
        stim_ecc : float
            Half-width of the grid (coordinates span ``-stim_ecc..stim_ecc``).
        attention_ctr : tuple of float
            ``(x, y)`` of the attended location.

        Raises
        ------
        ValueError
            If ``prf_parameters`` is not a 2-D table with at least three rows
            and exactly ``n_voxels`` columns.
        """
        super().__init__()
        params = torch.as_tensor(prf_parameters, dtype=torch.float32)
        if params.dim() != 2 or params.shape[0] < 3 or params.shape[1] != n_voxels:
            raise ValueError(
                "prf_parameters must have shape (>=3, n_voxels) with rows "
                "x, y, sigma; got shape {} for n_voxels={}".format(
                    tuple(params.shape), n_voxels))

        # Non-persistent buffers: they follow .to(device) but leave the
        # state_dict limited to the learnable parameters.
        self.register_buffer("simulated_prfs", params, persistent=False)
        self.attention_ctr = attention_ctr
        self.n_voxels = n_voxels
        ecc = torch.linspace(-stim_ecc, stim_ecc, gridsize)
        # "ij" is torch.meshgrid's historical default, stated explicitly.
        grid_x, grid_y = torch.meshgrid(ecc, ecc, indexing="ij")
        self.register_buffer("X", grid_x, persistent=False)
        self.register_buffer("Y", grid_y, persistent=False)

        self.voxel_gain = torch.nn.Parameter(torch.rand(1, n_voxels, dtype=torch.float32))
        self.attention_sigma = torch.nn.Parameter(torch.rand(1, dtype=torch.float32))
        self.attention_gain = torch.nn.Parameter(torch.rand(1, dtype=torch.float32))
        self.suppression_sigma_scale_factor = torch.nn.Parameter(torch.rand(1, dtype=torch.float32))
        self.suppression_gain = torch.nn.Parameter(torch.rand(1, dtype=torch.float32))
        self.summation_sigma_scale_factor = torch.nn.Parameter(torch.rand(1, dtype=torch.float32))

    def forward(self, stim):
        """Return the attention-weighted stimulus drive of every voxel.

        Parameters
        ----------
        stim : torch.Tensor
            Stimulus frame with ``gridsize ** 2`` elements (any shape; it is
            flattened).

        Returns
        -------
        torch.Tensor
            Shape ``(n_voxels,)``.

        Raises
        ------
        ValueError
            If the stimulus does not have one value per grid point.
        """
        stim_flat = stim.flatten().to(self.X.dtype)
        if stim_flat.numel() != self.X.numel():
            raise ValueError(
                "stim has {} elements but the model grid has {}".format(
                    stim_flat.numel(), self.X.numel()))

        x0, y0, sigma = self.simulated_prfs[0], self.simulated_prfs[1], self.simulated_prfs[2]

        # One unit-volume pRF per row. The helper is called with gain 1 and
        # the learnable gain is applied afterwards, one scalar per voxel.
        # Passing the whole (1, n_voxels) gain tensor into the helper cannot
        # broadcast against the grid.
        rf_bank = torch.stack([
            flat_gaussian_field_torch(self.X, self.Y, x0[rf], y0[rf], sigma[rf], 1.0, True)
            for rf in range(self.n_voxels)
        ])
        stimdrive = self.voxel_gain.flatten() * (rf_bank @ stim_flat)

        # Gaussian attention field sampled at the pRF centres, using the same
        # exp(-d^2 / (2 sigma^2)) form as the suppression field.
        ctr_x, ctr_y = self.attention_ctr
        att_dist_sq = (ctr_x - x0) ** 2 + (ctr_y - y0) ** 2
        attweight = self.attention_gain * torch.exp(
            -att_dist_sq / (2 * self.attention_sigma ** 2)) + 1

        # Built out-of-place: writing element by element into a preallocated
        # tensor invalidates values autograd saved for backward().
        numerator = stimdrive * attweight

        # TODO: ADD SUPPRESSION -- self._suppressive_drive(numerator) computes
        # the pooled drive, but it is not applied to the output yet.
        return numerator

    def _suppressive_drive(self, numerator):
        """Pool ``numerator`` across voxels with a Gaussian suppression field.

        Voxel ``i`` pools every voxel ``j`` with a weight that is a Gaussian
        of the distance between their pRF centres, with standard deviation
        ``sigma_i * suppression_sigma_scale_factor``; weights sum to 1 per
        voxel.

        NOTE(review): the previous loop built the field on the full spatial
        grid and dotted it with the per-voxel numerator, which cannot work
        (gridsize**2 vs n_voxels elements). Pooling over voxel centres is one
        reading of the intent -- confirm before wiring it into ``forward``.
        It also indexed x/y the other way round from the pRF computation;
        this version uses row 0 as x and row 1 as y throughout.

        Parameters
        ----------
        numerator : torch.Tensor
            Shape ``(n_voxels,)``.

        Returns
        -------
        torch.Tensor
            Shape ``(n_voxels,)``.
        """
        x0, y0, sigma = self.simulated_prfs[0], self.simulated_prfs[1], self.simulated_prfs[2]
        # Squared centre-to-centre distances, shape (n_voxels, n_voxels).
        # Staying in squared units avoids sqrt(0), whose gradient is infinite.
        dist_sq = (x0[None, :] - x0[:, None]) ** 2 + (y0[None, :] - y0[:, None]) ** 2
        pool_sigma = sigma[:, None] * self.suppression_sigma_scale_factor
        suppfield = torch.exp(-0.5 * dist_sq / pool_sigma ** 2)
        suppfield = suppfield / suppfield.sum(dim=1, keepdim=True)
        return suppfield @ numerator

In [None]:
# Seed torch's global RNG first: simulate_prfs draws random gains, so without
# a fixed seed every Restart & Run All would simulate different voxels.
torch.manual_seed(0)
simulated_prfs, prf_parameters = simulate_prfs()

In [None]:
# The model indexes its first argument as rows x / y / sigma per voxel, so it
# needs the parameter table, not the rendered pRF images in `simulated_prfs`.
attn_model = AttModel(prf_parameters)

In [None]:
# Directory holding stim.mat. Set the ATTPRF_STIM_DIR environment variable to
# run this notebook on a machine where the lab server is not mounted here.
stimpath = os.environ.get(
    'ATTPRF_STIM_DIR',
    '/Volumes/server/Projects/attentionpRF/Simulations/python_scripts/stimfiles')
STIM_FRAME = 12  # index along the third dimension of the stimulus array
GRIDSIZE = 81    # samples per axis; must match AttModel's gridsize

stim_mat = scipy.io.loadmat(os.path.join(stimpath, 'stim.mat'))
stimtemp = torch.from_numpy(stim_mat['stim']).to(torch.float32)

# Sample GRIDSIZE evenly spaced rows/columns across the *whole* frame. The
# previous stride-13 slice followed by a crop to 81 discarded the last ~39
# pixels of a 1080-pixel axis, so the stimulus no longer covered the same
# extent as the model grid.
stim_frame = stimtemp[:, :, STIM_FRAME]
row_idx = torch.linspace(0, stim_frame.shape[0] - 1, GRIDSIZE).round().long()
col_idx = torch.linspace(0, stim_frame.shape[1] - 1, GRIDSIZE).round().long()
stimorig = stim_frame[row_idx][:, col_idx]

In [None]:
# Run the model on the downsampled stimulus frame. As the last expression in
# the cell, the returned tensor is what the notebook displays.
attn_model(stimorig)