# GLMs -- Single RFs

In [1]:
import sys
import os
import h5py 

# setup paths
iteration = 1 # which version of this tutorial to run (in case want results in different dirs)
NBname = 'color_cloud_initial{}'.format(iteration)

myhost = os.uname()[1] # get name of machine
print("Running on Computer: [%s]" %myhost)

if myhost=='mt': # this is sigur
    
#    sys.path.insert(0, '/home/jake/Repos/')
#    dirname = os.path.join('.', 'checkpoints')
#    datadir = '/home/dbutts/V1/Conway/'
else:
    sys.path.insert(0, '/home/felixbartsch/Code/') 
    datadir = '/Data/FelixData/Conway/'  
    dirname = '/home/felixbartsch/Data/Colorworkspace/' # Working directory 

import numpy as np
import scipy.io as sio
from copy import deepcopy

# plotting
import matplotlib.pyplot as plt

# Import torch
import torch
from torch import nn

# NDN tools
import NDNT.utils as utils # some other utilities
from NDNT.utils import imagesc   # because I'm lazy
from NDNT.utils import ss        # because I'm real lazy
import NDNT.NDNT as NDN
from NDNT.modules.layers import *
from NDNT.networks import *
from time import time
import dill
import ColorDataUtils.ConwayUtils as CU
import ColorDataUtils.RFutils as RU

from NTdatasets.generic import GenericDataset
import NTdatasets.conway.cloud_datasets as datasets
import NTdatasets.conway.bar1d_datasets as bardatasets

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device0 = torch.device("cpu")
dtype = torch.float32

# Where saved models and checkpoints go -- this is to be automated
print( 'Save_dir =', dirname)

%load_ext autoreload
%autoreload 2

IndentationError: expected an indented block (2676782858.py, line 17)

In [None]:
# Load BINOCULAR only (single RFs)
# Load data
fn = 'Jocamo_220801_full_CC_ETCC_nofix_v08'
fndate = '0801'
dirname2 = dirname+fndate+'/'
dirname_mod = dirname+fndate+'/models/'

UNX = 52
num_lags = 10

inclMUs = True
t0 = time()
data = datasets.ColorClouds(
    filenames=[fn], eye_config=3, drift_interval=16,
    datadir=datadir, folded_lags=False, luminance_only=False,
    trial_sample=True, num_lags=num_lags, 
    include_MUs=inclMUs)
t1 = time()
print(t1-t0, 'sec elapsed')

# Split units by recording array using data.channel_map.
# NOTE(review): assumes 0-indexed channels laid out as laminar probe [0, 32),
# NF array [32, 160), Utah array [160, ...) -- confirm against the dataset.
# Half-open bounds keep the groups disjoint: channel 32 belongs to the ET arrays
# and channel 159 is counted only as NF (not also as UT).
NLAM, NNF = 32, 128
lam_units = np.where(data.channel_map < NLAM)[0]
ETunits = np.where(data.channel_map >= NLAM)[0]
UTunits = np.where(data.channel_map >= NLAM+NNF)[0]
NFunits = np.where((data.channel_map >= NLAM) & (data.channel_map < NLAM+NNF))[0]
print( "%d laminar units, %d ET units (%d NF, %d UT)"%(
    len(lam_units), len(ETunits), len(NFunits), len(UTunits)))

# Pull correct saccades
matdat = sio.loadmat( datadir+fn+'_ETupdate.mat')
data.process_fixations(matdat['sac_binsB'][0, :])

sac_ts_all = matdat['ALLsac_bins'][0, :]
#data.process_fixations( sac_ts_all )

#    'tsacs_msB': Bsac_ts,  ## microsaccade times in seconds (4 sec trials)
#    'sac_bins_allB': Bsacbins_all, # bin numbers of sac sampled at 60 Hz
#    'sac_amps_allB': CampsB, ## squared magnitude
#    'sac_ampsB': Bsac_amps, # reduced sac times & amplitude given threshold
#    'sac_binsB': Bsac_bins,
#    'sac_tsB': Bsac_ts,
#    'et60HzB': ETprocB, # downsampled processed eye trace
#    'et1kHzB': et1khzB,

NT = len(data.fix_n)
NA = data.Xdrift.shape[1]
print("%d (%d valid) time points"%(NT, len(data)))

# # Replace DFs
matdat = sio.loadmat(datadir+fn+'_DFextra.mat')
data.dfs = torch.tensor( matdat['XDF'][:NT, :], dtype=torch.float32 )
data.dfs.shape

#shift robs to test weird NLRF alignment
#data.robs=torch.from_numpy(np.roll(data.robs,shift=8,axis=0))

In [None]:
# Load the per-session best eye-position shift estimates and their fit metrics.
SHfile = sio.loadmat(dirname2 + 'J22' + fndate + '_bestshifts.mat')
#SHfile = sio.loadmat(dirname2 + 'BDshifts0628i3.mat')
fix_n, fixshifts = SHfile['fix_n'], SHfile['shifts']
metricsLL, metricsTH = SHfile['metricsLL'], SHfile['metricsTH']
ETshifts, ETmetrics = SHfile['ETshifts'], SHfile['ETmetrics']


In [None]:
## only use most confident fixations
# Zero the data filters (dfs) wherever ETmetrics[:, 1] >= 0.99.
# NOTE(review): assumes ETmetrics has one row per time point of data.dfs
# (valfix is broadcast across cells as an [N, 1] column) -- confirm.
goodfix = np.where(ETmetrics[:,1] < 0.99)
valfix = torch.zeros([ETmetrics.shape[0], 1], dtype=torch.float32)
valfix[goodfix] = 1.0
# Test base-level performance (full DFs and then modify DFs)
# Multiply in torch so data.dfs stays a float32 tensor (later cells call torch.mul on it);
# np.multiply would convert both tensors to NumPy and rely on __array_wrap__ to convert back.
data.dfs = data.dfs * valfix

In [None]:
# Segregate 'valid' units: keep units with enough spikes in time bins where
# their data filter (dfs) is on.
def _valid_units(units, min_spikes=10):
    """Return the entries of `units` whose dfs-weighted spike count exceeds `min_spikes`."""
    Reff = torch.mul(data.robs[:, units], data.dfs[:, units]).numpy()
    nspks = np.sum(Reff, axis=0)
    return units[np.where(nspks > min_spikes)[0]]


valUT = _valid_units(UTunits)
NCUT = len(valUT)

valNF = _valid_units(NFunits)
NCNF = len(valNF)

valET = _valid_units(ETunits, min_spikes=100)  # stricter threshold for ET units
NCv_ET = len(valET)

valLP = _valid_units(lam_units)
NCL = len(valLP)
NCv = len(valLP)

print( "%d out of %d Laminar units kept"%(NCL, len(lam_units)) )
print( "%d out of %d NF units kept"%(NCNF, len(NFunits)) )
print( "%d out of %d Utah units kept"%(NCUT, len(UTunits)) )
print( "%d out of %d ET units kept"%(NCv_ET, len(ETunits)) )

In [None]:
# New GQM fits:
# top_corner: corner location of the stimulus window passed to draw_stim_locations /
# assemble_stimulus below (presumably screen pixel coordinates -- confirm).
# NOTE(review): this notebook loads session fndate '0801', but the active value is
# labelled "for 0808"; the history below lists [980, 540] for 0801 -- confirm which is intended.
top_corner = np.array([970, 540], dtype=np.int64) #for 0808 - LP meas done
#top_corner = np.array([980, 530], dtype=np.int64) #for 0722 - LP meas done; ET done; GQMs done

# OLD fits/ TODO:
#TODO with new 1Dbar-compatible loading:
#for 0207
#for 0217
#for 0304
#top_corner = np.array([966, 535], dtype=np.int64) #for 0314
#for 0316
#for 0318
#top_corner = np.array([966, 530], dtype=np.int64) #for 0320 - LP meas done
#for 0321

# 0616
#top_corner = np.array([955, 525], dtype=np.int64) #for 0621 -LP meas done (1D ET)

#have ET clouds:
#top_corner = np.array([950, 530], dtype=np.int64) #for 0628 - LP meas done; ET done (previously [965, 535] - may ned to fit second set of STAs with new location)
#top_corner = np.array([950, 530], dtype=np.int64) # 0701 - LP meas done
#top_corner = np.array([965, 530], dtype=np.int64) # 0705 - LP meas done
#top_corner = np.array([960, 535], dtype=np.int64) #for 0707 - LP meas done
# 0711
#top_corner = np.array([955, 535], dtype=np.int64) #for 0713 - LP meas done
#top_corner = np.array([938, 512], dtype=np.int64) #for 0715 - All meas done
#top_corner = np.array([950, 530], dtype=np.int64) #for 0718 - LP meas done; ET done
#top_corner = np.array([995, 545], dtype=np.int64) #for 0720 - LP meas done

#for 0725 - Ethan sorted
#top_corner = np.array([1000, 555], dtype=np.int64) #for 0727 - LP meas done
#top_corner = np.array([980, 540], dtype=np.int64) #for  0801 - LP meas done; ET needs to be redone; previously [990, 550]
#top_corner = np.array([985, 560], dtype=np.int64) #for 0803 -LP meas done
#
# 0909
# top_corner = np.array([965, 530], dtype=np.int64) #for 0921 - LP meas done
# top_corner = np.array([1020, 545], dtype=np.int64) #for 0926 - LP meas done
#top_corner = np.array([1005, 530], dtype=np.int64) #for 1003 - LP meas done
#top_corner = np.array([1015, 540], dtype=np.int64) #for 1007 - LP meas done

# Overlay the chosen 60-wide window on the stimulus-location plot.
data.draw_stim_locations(top_corner = top_corner, L=60)

In [None]:
#data.assemble_stimulus(which_stim='lam', stim_wrap = [30, -25], fixdot=0 ) #, stim_crop=[0,49,5,54])
# Build the 60-wide stimulus window at top_corner, passing the negated fixation shifts.
data.assemble_stimulus(L=60, top_corner=top_corner, shifts=-fixshifts, fixdot=0)

In [None]:
# LP STAs: spike-triggered averages of the kept laminar units at a fixed lag,
# using dfs-weighted responses and normalised by each unit's weighted spike count.
lag = 3
Reff = torch.mul(data.robs[:, valLP], data.dfs[:, valLP])
nspks = Reff.sum(dim=0)
sta_flat = data.stim[:-lag, ...].T @ Reff[lag:, :]
stas0 = (sta_flat.squeeze() / nspks).reshape([3, 60, 60, -1]).numpy()

In [None]:
# Luminance-channel STA of every kept laminar unit, four per row.
ncols = 4
Nrows = int(np.ceil(NCv/ncols))
ss(Nrows, ncols)
for cc in range(NCv):
    plt.subplot(Nrows, ncols, cc + 1)
    imagesc(stas0[0, :, :, cc])
    plt.title(f"{cc}")
plt.show()

In [None]:
# All three colour-channel STAs per cell; two cells (six panels) per row.
ctext = ['Lum', 'L-M', 'S']
Nrows = int(np.ceil(NCv/2))
ss(Nrows, 6, rh=2)
for cc in range(NCv):
    for clr, cname in enumerate(ctext):
        plt.subplot(Nrows, 6, 3*cc + clr + 1)
        imagesc(stas0[clr, :, :, cc])
        if not clr:
            plt.ylabel("Cell %d" % cc)
        plt.title(cname)

## GLMs

In [None]:
NCv = len(valLP)
NXglm = UNX
shifts = np.array([[0, 0]] * 3)  # if not caring about non-LP things

# Per-cell result placeholders
LLs0 = np.zeros([NCv, 2])
LLsGLM = np.zeros(NCv)
XTopt = np.zeros(NCv)
GLopt = np.zeros(NCv)
glms = [None] * NCv
driftmods = [None] * NCv
Dreg = 0.1  # drift-filter d2t regularization

# LBFGS optimizer settings shared by the fits
lbfgs_pars = utils.create_optimizer_params(
    optimizer_type='lbfgs',
    max_iter=200,
    max_epochs=3,
    batch_size=16000,
    tolerance_grad=1e-10,
    tolerance_change=1e-10)
#lbfgs_pars['num_workers'] = 0
#lbfgs_pars['num_workers'] = 0

In [None]:
#set up fits
# Default regularization strengths for the stimulus filters (bracketed values: earlier choices).
Xreg = 10 # [20]
Treg = 1 # [20]
L1reg = 1 # [0.5]
GLreg = 10.0 # [4.0]

# drift network: linear filter on the 'Xdrift' covariate (NA inputs), d2t-regularized.
drift_pars1 = NDNLayer.layer_dict( 
    input_dims=[1,1,1,NA], num_filters=1, bias=False, norm_type=0, NLtype='lin')
drift_pars1['reg_vals'] = {'d2t': Dreg, 'bcs':{'d2t':0} } 
# for stand-alone drift model: same layer (including reg_vals) but softplus output.
# Order matters: the copy is taken after reg_vals is set and before NLtype is changed.
drift_pars1N = deepcopy(drift_pars1)
drift_pars1N['NLtype'] = 'softplus'
drift_net =  FFnetwork.ffnet_dict( xstim_n = 'Xdrift', layer_list = [drift_pars1] )

# glm net: one linear spatiotemporal filter over the 3-colour NXglm x NXglm stimulus.
glm_layer = Tlayer.layer_dict( 
    input_dims=[3,NXglm,NXglm,1], num_filters=1, bias=False, num_lags=num_lags,
    NLtype='lin', initialize_center = True)
glm_layer['reg_vals'] = {'d2x': Xreg, 'd2t': Treg, 'l1': L1reg, 'glocalx': GLreg,'edge_t':10} 
stim_net =  FFnetwork.ffnet_dict( xstim_n = 'stim', layer_list = [glm_layer] )

# abs layer: linear filter over the 'stim2' covariate (2 channels of 3*NXglm x NXglm).
# No reg_vals are set on this layer here.
glm_layer2 = Tlayer.layer_dict( 
    input_dims = [2,NXglm*3,NXglm,1], num_filters=1, bias=False, num_lags=num_lags, 
    norm_type=0, NLtype='lin', initialize_center=True) 
stim_net2 =  FFnetwork.ffnet_dict( xstim_n='stim2', layer_list=[glm_layer2] )

# gqm net: num_subs squared (NLtype='square') subunit filters on the same stimulus.
num_subs = 2
gqm_layer = Tlayer.layer_dict( 
    input_dims=[3,NXglm,NXglm,1], num_filters=num_subs, num_inh=0, bias=False, num_lags=num_lags,
    NLtype='square', initialize_center = True)
gqm_layer['reg_vals'] = {'d2x': Xreg, 'd2t': Treg, 'l1': L1reg, 'glocalx': GLreg,'edge_t':10} 
stim_qnet =  FFnetwork.ffnet_dict( xstim_n = 'stim', layer_list = [gqm_layer] )

#combine glm: combines ffnets 0 and 1 with ffnet_type='add', softplus output, weights initialised to ones.
comb_layer = NDNLayer.layer_dict(
    num_filters = 1, NLtype='softplus', bias=False)
comb_layer['weights_initializer'] = 'ones'

net_comb = FFnetwork.ffnet_dict( 
    xstim_n = None, ffnet_n=[0,1],
    layer_list = [comb_layer], ffnet_type='add')

#combine gqm: ChannelLayer over ffnets 0, 1 and 2, softplus output, weights initialised to ones.
comb2_layer = ChannelLayer.layer_dict( 
    num_filters = 1, NLtype='softplus', bias=False)
comb2_layer['weights_initializer'] = 'ones'

net2_comb = FFnetwork.ffnet_dict( 
    xstim_n = None, ffnet_n=[0,1,2],
    layer_list = [comb2_layer], ffnet_type='normal')
# Enable the bias after construction (the layer dict above was built with bias=False).
net2_comb['layer_list'][0]['bias'] = True

# Per-cell fitted models and validation LLs; LLsR / LLsQR start at 1000 as placeholders.
glms, glms_abs, gqms, driftmods = ([None] * NCv for _ in range(4))
XTopt, GLopt, LLsNULL, LL_abs = (np.zeros(NCv) for _ in range(4))
LLsR = np.full([NCv, 4], 1000.0)   # GLM tuning results
LLsQR = np.full([NCv, 5], 1000.0)  # GQM tuning results

# regularization grids searched during fitting
rvals = [0.001, 0.01, 0.1, 1, 10, 100, 1000, 10000]
rvalsG = [0, 0.001, 0.01, 1, 10, 100, 1000] # glocal

# set up stimulus
data.assemble_stimulus(top_corner=top_corner, fixdot=0, L=NXglm, time_embed=0, num_lags=num_lags, shifts=-fixshifts )

# for fit model with abs values as well: covariate 'stim2' holds [stim, |stim|] as two
# channels of 3*NXglm*NXglm, matching glm_layer2's input_dims [2, 3*NXglm, NXglm, 1].
# Sized from NXglm (the L used above) rather than UNX so the two cannot drift apart.
# Slice assignment already copies the data, so no deepcopy of the large stimulus is needed.
new_stim = torch.zeros([NT, 2, 3*NXglm*NXglm])
new_stim[:, 0, :] = data.stim
new_stim[:, 1, :] = abs(data.stim)
data.add_covariate('stim2', new_stim.reshape([NT, -1]))

In [None]:
# now fit across all models
# For each kept laminar unit:
# (1) a stand-alone drift model (null LL);
# (2) a GLM whose regularization is tuned one term at a time;
# (3) a GLM on [stim, |stim|];
# (4) a GQM (linear filter plus num_subs squared subunits), tuned the same way.
# Every stimulus model includes the drift filter, frozen. Model selection keeps the
# lowest validation LL.


def _prep_model(model, drift_model, comb_ffnet):
    """Prepare a combined model before fitting.

    Turns on block sampling, copies `drift_model`'s filter into ffnet 1 and freezes it,
    and freezes the weights of the combining ffnet `comb_ffnet`.
    """
    model.block_sample = True
    model.networks[1].layers[0].weight.data[:, 0] = deepcopy(
        drift_model.networks[0].layers[0].weight.data[:, 0])
    model.networks[1].layers[0].set_parameters(val=False)
    model.networks[comb_ffnet].layers[0].set_parameters(val=False, name='weight')


def _fit_eval(model, **fit_kw):
    """Fit `model` on `data` with the shared LBFGS settings; return its validation LL."""
    model.fit(data, force_dict_training=True, **lbfgs_pars, seed=5, verbose=0, **fit_kw)
    return model.eval_models(data[data.val_blks], null_adjusted=False)[0]


def _sweep_reg(base_model, base_LL, ffnet, reg_name, values, current_val):
    """Greedy search over one regularization term of layer 0 of ffnet `ffnet`.

    Each value is set on a copy of the current best model (warm start) and refit.
    The copy becomes the new best only if its validation LL beats every LL seen so
    far, starting from `base_LL` (the LL of `base_model`).

    Returns (best_model, best_LL, best_val). best_val is the value actually used by
    best_model: `current_val` when no candidate improved on `base_LL`.
    """
    best_model, best_LL, best_val = base_model, base_LL, current_val
    for val in values:
        model = deepcopy(best_model)
        model.networks[ffnet].layers[0].reg.vals[reg_name] = val
        LL = _fit_eval(model)
        if LL < best_LL:
            best_model, best_LL, best_val = model, LL, val
    return best_model, best_LL, best_val


def _plot_glm(model):
    """Plot GLM filters: one row per colour channel, one column per lag."""
    w = model.get_weights()
    wmax = np.max(abs(w))  # symmetric colour limits, like the other filter plots
    utils.subplot_setup(1, num_lags, row_height=4)
    for ll in range(num_lags):
        for clr in range(3):
            plt.subplot(3, num_lags, clr*num_lags + ll + 1)
            utils.imagesc(w[clr, :, :, ll, 0], aspect=1, max=wmax)
    plt.show()


def _plot_glm_abs(model):
    """Plot GLM+abs filters: rows 0-2 are the stim part per colour, rows 3-5 the |stim| part."""
    w2 = model.get_weights().reshape([2, 3, NXglm, NXglm, num_lags, 1])
    maxall = np.max(abs(w2))
    utils.ss(2, num_lags, rh=4)
    for ll in range(num_lags):
        for part in range(2):
            for clr in range(3):
                plt.subplot(6, num_lags, (3*part + clr)*num_lags + ll + 1)
                utils.imagesc(w2[part, clr, :, :, ll, 0], aspect=1, max=maxall)
    plt.show()


def _plot_gqm(model):
    """Plot GQM filters: linear filter (rows 0-2), then 3 rows per squared subunit."""
    w2 = model.get_weights(ffnet_target=0).reshape([3, NXglm, NXglm, num_lags])
    w3 = model.get_weights(ffnet_target=2).reshape([3, NXglm, NXglm, num_lags, num_subs])
    maxall = np.max(abs(w3))
    nrows = 3*(num_subs + 1)
    utils.ss(3, num_lags, rh=4)
    for ll in range(num_lags):
        for clr in range(3):
            plt.subplot(nrows, num_lags, clr*num_lags + ll + 1)
            utils.imagesc(w2[clr, :, :, ll], aspect=1, max=maxall)
            for subs in range(num_subs):
                plt.subplot(nrows, num_lags, (3*(subs + 1) + clr)*num_lags + ll + 1)
                utils.imagesc(w3[clr, :, :, ll, subs], aspect=1, max=maxall)
    plt.show()


for cc in range(NCv):
    data.set_cells([valLP[cc]])

    # Stand-alone drift model: its LL is the null reference, and its filter is
    # copied (frozen) into every stimulus model below.
    drift_iter = NDN.NDN(layer_list=[drift_pars1N], loss_type='poisson')
    drift_iter.block_sample = True
    drift_iter.networks[0].xstim_n = 'Xdrift'
    drift_iter.fit(data, force_dict_training=True, train_inds=None, **lbfgs_pars, verbose=0, version=1)
    LLsNULL[cc] = drift_iter.eval_models(data[data.val_blks], null_adjusted=False)[0]
    driftmods[cc] = deepcopy(drift_iter)

    # GLM step 1: d2x sweep, fitting a fresh model per value. stim_net is copied so
    # the shared config dict is not left holding the last swept value.
    best_model, best_LL, best_d2x = None, None, None
    for val in rvals:
        glm_net = deepcopy(stim_net)
        glm_net['layer_list'][0]['reg_vals']['d2x'] = val
        glm_iter = NDN.NDN(ffnet_list=[glm_net, drift_net, net_comb], loss_type='poisson')
        _prep_model(glm_iter, driftmods[cc], comb_ffnet=2)
        LLi = _fit_eval(glm_iter, version=9)
        if (best_model is None) or (LLi < best_LL):
            best_model, best_LL, best_d2x = glm_iter, LLi, val
    LLsR[cc, 0] = best_LL
    print("Cell %3d:  GLM %-7s=%-8s LL =%8.5f" % (cc, 'd2x', best_d2x, LLsNULL[cc]-LLsR[cc, 0]))

    # GLM steps 2-4: d2t, glocalx (own grid), then l1 (skipping the smallest value),
    # each warm-started from the current best. Column k of LLsR is the best LL after
    # step k, and glm_regs records the regularization the retained model actually uses.
    glm_regs = {'d2x': best_d2x}
    for col, reg_name, values in ((1, 'd2t', rvals), (2, 'glocalx', rvalsG), (3, 'l1', rvals[1:])):
        best_model, LLsR[cc, col], glm_regs[reg_name] = _sweep_reg(
            best_model, LLsR[cc, col-1], 0, reg_name, values, glm_layer['reg_vals'][reg_name])
        print("Cell %3d:  GLM %-7s=%-8s LL =%8.5f ->%8.5f" % (
            cc, reg_name, glm_regs[reg_name], LLsNULL[cc]-LLsR[cc, 0], LLsNULL[cc]-LLsR[cc, col]))

    best_glm = deepcopy(best_model)
    glms[cc] = deepcopy(best_model)
    XTopt[cc], GLopt[cc] = glm_regs['d2x'], glm_regs['glocalx']
    LL_glm_best = np.min(LLsR[cc])
    _plot_glm(best_glm)

    ## also fit GLM-abs, using the selected GLM's regularization
    glm_abs = NDN.NDN(ffnet_list=[stim_net2, drift_net, net_comb], loss_type='poisson')
    _prep_model(glm_abs, driftmods[cc], comb_ffnet=2)
    # also initialize with previous GLM filter?
    for reg_name, reg_val in glm_regs.items():
        glm_abs.networks[0].layers[0].reg.vals[reg_name] = reg_val
    LL_abs[cc] = _fit_eval(glm_abs, version=9)
    glms_abs[cc] = deepcopy(glm_abs)
    print("Cell %3d:  glm+abs   LL =%8.5f ->%8.5f" % (cc, LLsNULL[cc]-LL_glm_best, LLsNULL[cc]-LL_abs[cc]))
    _plot_glm_abs(glms_abs[cc])

    # GQM step 1: d2x sweep on the squared subunits. Each value gets a fresh model whose
    # linear filter starts from the selected GLM filter with the GLM's regularization.
    best_gqm, best_qLL, best_qd2x = None, None, None
    for val in rvals:
        gqm_iter = NDN.NDN(ffnet_list=[stim_net, drift_net, stim_qnet, net2_comb], loss_type='poisson')
        _prep_model(gqm_iter, driftmods[cc], comb_ffnet=3)
        gqm_iter.networks[0].layers[0].weight.data[:, 0] = deepcopy(
            best_glm.networks[0].layers[0].weight.data[:, 0])
        for reg_name, reg_val in glm_regs.items():
            gqm_iter.networks[0].layers[0].reg.vals[reg_name] = reg_val
        gqm_iter.networks[2].layers[0].reg.vals['d2x'] = val
        LLi = _fit_eval(gqm_iter, version=9)
        if (best_gqm is None) or (LLi < best_qLL):
            best_gqm, best_qLL, best_qd2x = gqm_iter, LLi, val
    # LLsQR column 0 is never filled (keeps its initial value); columns 1-4 hold the
    # best LL after each GQM step.
    LLsQR[cc, 1] = best_qLL
    print("Cell %3d:  GQM %-7s=%-8s LL =%8.5f ->%8.5f" % (
        cc, 'd2x', best_qd2x, LLsNULL[cc]-LL_glm_best, LLsNULL[cc]-LLsQR[cc, 1]))

    # GQM steps 2-4 on the squared subunits (ffnet 2), warm-started as for the GLM.
    for col, reg_name, values in ((2, 'd2t', rvals), (3, 'glocalx', rvalsG), (4, 'l1', rvals)):
        best_gqm, LLsQR[cc, col], sel = _sweep_reg(
            best_gqm, LLsQR[cc, col-1], 2, reg_name, values, gqm_layer['reg_vals'][reg_name])
        print("Cell %3d:  GQM %-7s=%-8s LL =%8.5f ->%8.5f" % (
            cc, reg_name, sel, LLsNULL[cc]-LL_glm_best, LLsNULL[cc]-LLsQR[cc, col]))

    gqms[cc] = deepcopy(best_gqm)
    _plot_gqm(best_gqm)

In [None]:
# extract measures across cells: GQM filter weights and complexity/correlation scores
complex_scores = np.zeros(NCv)
lin_corrs = np.zeros(NCv)
quad_corrs = np.zeros(NCv)
filtwsmat = np.zeros([NCv, 3])

for cc, (gqm_model, cell_id) in enumerate(zip(gqms, valLP)):
    filtwsmat[cc, :] = RU.GQM_filtws(gqm_model, data, cell_id, 0, 2)
    complex_scores[cc], lin_corrs[cc], quad_corrs[cc] = RU.GQM_complexity(gqm_model, data, cell_id, 0, 2)


In [None]:
# Gather every cell's model filters and LL improvements into LPmodfilts.mat.
lin_shape = [NCv, 3, NXglm, NXglm, num_lags]
glm_ks = np.zeros(lin_shape)
glm_abs_ks = np.zeros([NCv, 2, 3*NXglm, NXglm, num_lags])
gqm_ksl = np.zeros(lin_shape)
gqm_ksq = np.zeros(lin_shape + [num_subs])

LLsGLM = LLsNULL - LLsR.min(axis=1)
LLsGLMabs = LLsNULL - LL_abs
LLsGQM = LLsNULL - LLsQR.min(axis=1)

for cc, (mod_glm, mod_abs, mod_gqm) in enumerate(zip(glms, glms_abs, gqms)):
    glm_ks[cc] = mod_glm.get_weights().squeeze()
    glm_abs_ks[cc] = mod_abs.get_weights().squeeze()
    gqm_ksl[cc] = mod_gqm.get_weights(ffnet_target=0).squeeze()
    gqm_ksq[cc] = mod_gqm.get_weights(ffnet_target=2).squeeze()

sio.savemat(dirname2 + 'LPmodfilts.mat', {
    'glm_ks': glm_ks, 'glm_abs_ks': glm_abs_ks, 'gqm_ksl': gqm_ksl, 'gqm_ksq': gqm_ksq,
    'valLP': valLP,
    'LLsGLM': LLsGLM[:, None], 'LLsGLMabs': LLsGLMabs[:, None], 'LLsGQM': LLsGQM[:, None],
    'lin_corrs': lin_corrs, 'quad_corrs': quad_corrs, 'complex_scores': complex_scores,
    'filtws_gqm': filtwsmat})

In [None]:
# Assemble population drift terms: column cc is cell cc's stand-alone drift filter.
w0 = driftmods[0].get_weights()
NA = w0.shape[0]
drift_terms = np.zeros([NA, NCv])
for cc, dmod in enumerate(driftmods):
    drift_terms[:, cc] = dmod.get_weights()[:, 0]

# Assemble drift population model from those filters and compare its LLs with the
# per-cell null LLs.
drift_pars_pop = NDNLayer.layer_dict(
    input_dims=[1, 1, 1, NA], num_filters=NCv, bias=False, norm_type=0, NLtype='softplus')
drift_pars_pop['reg_vals'] = {'d2t': Dreg, 'bcs': {'d2t': 0}}
drift_pop = NDN.NDN(layer_list=[drift_pars_pop], loss_type='poisson')
pop_net = drift_pop.networks[0]
pop_net.xstim_n = 'Xdrift'
pop_net.layers[0].weight.data = torch.tensor(drift_terms, dtype=torch.float32)
data.cells_out = list(valLP)
LLsCHECK = drift_pop.eval_models(data[data.val_blks], null_adjusted=False)
print(np.mean(LLsNULL), np.mean(LLsCHECK))

In [None]:
# Get array centers: location of the peak of each GLM filter's std over lags,
# summed across colour channels. RFcols is allocated but not filled here.
RFcenters = np.full([NCv, 2], -1, dtype=np.int64)
RFcols = np.full([NCv, 1], -1, dtype=np.int64)
for cc, mod in enumerate(glms):
    k = mod.get_weights()
    pfilt = np.std(k, axis=3).sum(axis=0).squeeze()  # for colorRFs
    x, y = utils.max_multiD(pfilt)
    RFcenters[cc, :] = [x, y]

plt.plot(RFcenters[:, 0], RFcenters[:, 1], 'o')
plt.show()

### now save

In [None]:
# LL improvement over the drift-only null model for each model type. The fitting loop
# keeps the lowest LL, so these are positive when a model beats the null.
LLsGLM = LLsNULL-np.min(LLsR,axis=1)
# GLM+abs has its own validation LLs in LL_abs; deriving this from LLsR would save the
# plain-GLM values under the 'LLsGLMabs' key.
LLsGLMabs = LLsNULL-LL_abs
LLsGQM = LLsNULL-np.min(LLsQR,axis=1)
sio.savemat(dirname2+'ModLLsET.mat', {
    'LLsNULL':LLsNULL[:,None], 'LLsGLM':LLsGLM[:, None], 'LLsGLMabs':LLsGLMabs[:, None], 'LLsGQM':LLsGQM[:, None],
    'drift_terms': drift_terms, 'Dreg': Dreg,
    'RFcenters': RFcenters, 'top_corner': top_corner[:, None]})
drift_pop.save_model(alt_dirname=dirname_mod, filename='LPdriftmods_pop.pkl')

# One file per cell and model type: 'LP' + 3-digit cell index + model tag.
for cc in range(NCv):
    prefix = 'LP' + utils.filename_num2str(cc, num_digits=3)
    glms[cc].save_model(alt_dirname=dirname_mod, filename=prefix + 'glmwET.pkl')
    glms_abs[cc].save_model(alt_dirname=dirname_mod, filename=prefix + 'glmabswET.pkl')
    gqms[cc].save_model(alt_dirname=dirname_mod, filename=prefix + 'gqmwET.pkl')


### now get LN correlation model measures


In [None]:
# measure still to translate for gqms

### now get filter-based activation measures 

In [None]:
# gqm
#net_target = 2
#cur_mod = deepcopy(gqms[9])
#nf = cur_mod.networks[net_target].layers[0].shape[1]
#acts = cur_mod.networks[2].layers[0](data.stim)

In [None]:
#sio.savemat(dirname2+'GLMT_measLP_ET.mat', {
#    'glm_ks':glm_ks, 'valLP':valLP, 'LLsGLM':LLsGLM[:, None],
#    'RFcenters': RFcenters, 'Colws': Gconv_cws, 'RFareas': RF_areas, 'RFareas_col':RF_areas_col,
#    'RFmaps': Gconv_space_all, 'RFmaps_col': Gconv_space_col_all, 
#    'RFbs': Gconv_space_shuff_all, 'RFbs_col': Gconv_space_col_shuff_all, 
#    'Contours':contours_all, 'Contours_col':contours_col_all})

In [None]:
# `crash` is not defined anywhere in this notebook, so this cell raises NameError.
# Presumably a deliberate stop so "Run All" halts before the Utah-array section below.
crash

## Now fit Utah array models again to check ET

In [None]:
# Stimulus window for the Utah-array units, plus their dfs-weighted responses and spike counts.
UTloc = np.array([930, 515], dtype=np.int64)

data.assemble_stimulus(top_corner=UTloc, fixdot=0, L=UNX, time_embed=0, shifts=-fixshifts)
Reff = torch.mul(data.robs[:, valUT], data.dfs[:, valUT])
nspks = Reff.sum(dim=0)

data.draw_stim_locations(top_corner=UTloc, L=UNX)

In [None]:
## Calculate STAS: Utah-unit spike-triggered averages at a fixed lag, per colour channel.
lag = 4
sta_flat = data.stim[:-lag, ...].T @ Reff[lag:, :]
stasC = (sta_flat.squeeze() / nspks).reshape([3, UNX, UNX, -1]).numpy()
stasC.shape

In [None]:
ctext = ['Lum', 'L-M', 'S']
# Two cells per row (3 colour panels each, 6 columns). Size the grid from NCUT:
# a fixed 86 rows makes plt.subplot raise once the index passes 86*6 (NCUT > 172).
Nrows = max(1, int(np.ceil(NCUT/2)))
ss(Nrows, 6, rh=2)
for cc in range(NCUT):
    for clr in range(3):
        plt.subplot(Nrows, 6, 3*cc+clr+1)
        imagesc(stasC[clr, :, :, cc])
        if clr == 0:
            plt.ylabel( "Cell %d"%cc)
        plt.title(ctext[clr])

In [None]:
NCv = NCUT
NXglm = UNX
#num_lags = 8

# Per-cell result placeholders
LLs0 = np.zeros([NCv, 2])
LLsGLM = np.zeros(NCv)
XTopt = np.zeros(NCv)
GLopt = np.zeros(NCv)
glms = [None] * NCv
driftmods = [None] * NCv
Dreg = 0.1  # drift-filter d2t regularization

# LBFGS optimizer settings shared by the fits
lbfgs_pars = utils.create_optimizer_params(
    optimizer_type='lbfgs',
    max_iter=200,
    max_epochs=3,
    batch_size=16000,
    tolerance_grad=1e-10,
    tolerance_change=1e-10)
#lbfgs_pars['num_workers'] = 0

In [None]:
#set up fits
# Default regularization strengths for the stimulus filters (bracketed values: earlier choices).
Xreg = 10 # [20]
Treg = 1 # [20]
L1reg = 1 # [0.5]
GLreg = 10.0 # [4.0]

# drift network: linear filter on the 'Xdrift' covariate (NA inputs), d2t-regularized.
drift_pars1 = NDNLayer.layer_dict( 
    input_dims=[1,1,1,NA], num_filters=1, bias=False, norm_type=0, NLtype='lin')
drift_pars1['reg_vals'] = {'d2t': Dreg, 'bcs':{'d2t':0} } 
# for stand-alone drift model: same layer (including reg_vals) but softplus output.
# Order matters: the copy is taken after reg_vals is set and before NLtype is changed.
drift_pars1N = deepcopy(drift_pars1)
drift_pars1N['NLtype'] = 'softplus'
drift_net =  FFnetwork.ffnet_dict( xstim_n = 'Xdrift', layer_list = [drift_pars1] )

# glm net: one linear spatiotemporal filter over the 3-colour NXglm x NXglm stimulus.
glm_layer = Tlayer.layer_dict( 
    input_dims=[3,NXglm,NXglm,1], num_filters=1, bias=False, num_lags=num_lags,
    NLtype='lin', initialize_center = True)
glm_layer['reg_vals'] = {'d2x': Xreg, 'd2t': Treg, 'l1': L1reg, 'glocalx': GLreg,'edge_t':10} 
stim_net =  FFnetwork.ffnet_dict( xstim_n = 'stim', layer_list = [glm_layer] )

# abs layer: linear filter over the 'stim2' covariate (2 channels of 3*NXglm x NXglm).
# No reg_vals are set on this layer here.
glm_layer2 = Tlayer.layer_dict( 
    input_dims = [2,NXglm*3,NXglm,1], num_filters=1, bias=False, num_lags=num_lags, 
    norm_type=0, NLtype='lin', initialize_center=True) 
stim_net2 =  FFnetwork.ffnet_dict( xstim_n='stim2', layer_list=[glm_layer2] )

# gqm net: num_subs squared (NLtype='square') subunit filters on the same stimulus.
num_subs = 2
gqm_layer = Tlayer.layer_dict( 
    input_dims=[3,NXglm,NXglm,1], num_filters=num_subs, num_inh=0, bias=False, num_lags=num_lags,
    NLtype='square', initialize_center = True)
gqm_layer['reg_vals'] = {'d2x': Xreg, 'd2t': Treg, 'l1': L1reg, 'glocalx': GLreg,'edge_t':10} 
stim_qnet =  FFnetwork.ffnet_dict( xstim_n = 'stim', layer_list = [gqm_layer] )

#combine glm: combines ffnets 0 and 1 with ffnet_type='add', softplus output, weights initialised to ones.
comb_layer = NDNLayer.layer_dict(
    num_filters = 1, NLtype='softplus', bias=False)
comb_layer['weights_initializer'] = 'ones'

net_comb = FFnetwork.ffnet_dict( 
    xstim_n = None, ffnet_n=[0,1],
    layer_list = [comb_layer], ffnet_type='add')

#combine gqm: ChannelLayer over ffnets 0, 1 and 2, softplus output, weights initialised to ones.
comb2_layer = ChannelLayer.layer_dict( 
    num_filters = 1, NLtype='softplus', bias=False)
comb2_layer['weights_initializer'] = 'ones'

net2_comb = FFnetwork.ffnet_dict( 
    xstim_n = None, ffnet_n=[0,1,2],
    layer_list = [comb2_layer], ffnet_type='normal')
# Enable the bias after construction (the layer dict above was built with bias=False).
net2_comb['layer_list'][0]['bias'] = True

glms = [None]*NCv
glms_abs = [None]*NCv
gqms = [None]*NCv
driftmods = [None]*NCv
XTopt = np.zeros(NCv) 
GLopt = np.zeros(NCv) 
LLsNULL = np.zeros(NCv)
LLsR = np.zeros([NCv,4])+1000
LL_abs = np.zeros(NCv)
LLsQR = np.zeros([NCv,5])+1000

rvals = [0.001, 0.01, 0.1, 1, 10, 100, 1000, 10000]
rvalsG = [0, 0.001, 0.01, 1, 10, 100, 1000] # glocal

# set up stimulus
data.assemble_stimulus(top_corner=UTloc, fixdot=0, L=NXglm, time_embed=0, num_lags=num_lags, shifts=-fixshifts )
    
# for fit model with abs values as well
new_stim = torch.zeros([NT, 2, 3*UNX*UNX])
new_stim[:,0,:] = deepcopy(data.stim)
new_stim[:,1,:] = deepcopy(abs(data.stim))
data.add_covariate('stim2', new_stim.reshape([NT, -1]))

In [None]:
# now fit across all models
# For each UT cell (valUT[cc]):
#   1. drift-only null model                       -> LLsNULL, driftmods
#   2. GLM (stim + fixed drift), greedy reg sweeps  -> LLsR[:, 0..3], glms
#   3. GLM on [stim, |stim|] with the GLM's regs    -> LL_abs, glms_abs
#   4. GQM (linear + num_subs squared subunits)     -> LLsQR[:, 1..4], gqms
# LLs are validation losses from eval_models; lower is treated as better.
# Each sweep keeps a candidate only if it strictly improves on the best so far,
# so every recorded LL / reg value belongs to the model that is actually kept.

GLM_REG_KEYS = ('d2x', 'd2t', 'glocalx', 'l1')
# Sweeps that refine a fitted model (run after the fresh-model d2x sweep).
# The GLM l1 sweep skips rvals[0].
GLM_STAGES = [('d2t', rvals), ('glocalx', rvalsG), ('l1', rvals[1:])]
GQM_STAGES = [('d2t', rvals), ('glocalx', rvalsG), ('l1', rvals)]

# Chosen reg values: XTopt = GLM d2x, GLopt = GLM glocalx, plus:
Topt = np.zeros(NCv)                             # GLM d2t
L1opt = np.zeros(NCv)                            # GLM l1
gqm_regopt = np.zeros([NCv, len(GLM_REG_KEYS)])  # GQM subunit regs, ordered as GLM_REG_KEYS


def _sweep_reg(start_model, ffnet_n, reg_name, reg_values, base_LL):
    """Greedy sweep of one regularization value on ``networks[ffnet_n].layers[0]``.

    Each candidate value is set on a copy of the best model so far (so later
    candidates start from earlier improvements), refit on ``data`` and scored on
    the validation blocks. A candidate replaces the current best only if its
    loss is strictly below the best loss so far, which starts at ``base_LL``.

    Returns:
        best_model: model to keep (``start_model`` itself if nothing improved).
        best_LL: its validation loss (``base_LL`` if nothing improved).
        best_val: the ``reg_name`` value set on ``best_model``.
        LLs: validation loss of each candidate, in ``reg_values`` order.
    """
    best_model, best_LL = start_model, base_LL
    best_val = start_model.networks[ffnet_n].layers[0].reg.vals[reg_name]
    LLs = np.zeros(len(reg_values))
    for rr, val in enumerate(reg_values):
        trial = deepcopy(best_model)
        trial.networks[ffnet_n].layers[0].reg.vals[reg_name] = val
        trial.fit(data, force_dict_training=True, **lbfgs_pars, seed=5, verbose=0)
        LLs[rr] = trial.eval_models(data[data.val_blks], null_adjusted=False)[0]
        if LLs[rr] < best_LL:
            best_model, best_LL, best_val = trial, LLs[rr], val
    return best_model, best_LL, best_val, LLs


for cc in range(NCv):
    data.set_cells([valUT[cc]])

    # ---- 1. drift-only (null) model
    drift_iter = NDN.NDN(layer_list=[drift_pars1N], loss_type='poisson')
    drift_iter.block_sample = True
    drift_iter.networks[0].xstim_n = 'Xdrift'
    drift_iter.fit(data, force_dict_training=True, train_inds=None, **lbfgs_pars, verbose=0, version=1)
    LLsNULL[cc] = drift_iter.eval_models(data[data.val_blks], null_adjusted=False)[0]
    driftmods[cc] = deepcopy(drift_iter)
    drift_w = driftmods[cc].networks[0].layers[0].weight.data[:, 0]

    # ---- 2. GLM: d2x sweep, every candidate fit from a fresh model
    LLs = np.zeros(len(rvals))
    for rr, val in enumerate(rvals):
        stim_net_r = deepcopy(stim_net)  # leave the shared template unmodified
        stim_net_r['layer_list'][0]['reg_vals']['d2x'] = val
        glm_iter = NDN.NDN(ffnet_list=[stim_net_r, drift_net, net_comb], loss_type='poisson')
        glm_iter.block_sample = True
        glm_iter.networks[1].layers[0].weight.data[:, 0] = deepcopy(drift_w)
        glm_iter.networks[1].layers[0].set_parameters(val=False)                 # drift weights not fit
        glm_iter.networks[2].layers[0].set_parameters(val=False, name='weight')  # combination weight not fit
        glm_iter.fit(data, force_dict_training=True, **lbfgs_pars, seed=5, version=9, verbose=0)
        LLs[rr] = glm_iter.eval_models(data[data.val_blks], null_adjusted=False)[0]
        if (rr == 0) or (LLs[rr] < LLs[:rr].min()):
            best_model = glm_iter
    bm1 = int(np.argmin(LLs))
    LLsR[cc, 0] = LLs[bm1]
    XTopt[cc] = rvals[bm1]
    glm_regs = {'d2x': rvals[bm1]}
    print("Cell %3d:  d2x-R%d   LLs =" % (cc, bm1), LLsNULL[cc]-LLsR[cc, 0])

    # ---- GLM: greedy refinement of d2t, glocalx, l1 (LLsR[cc, 1..3])
    for kk, (reg_name, values) in enumerate(GLM_STAGES, start=1):
        best_model, LLsR[cc, kk], glm_regs[reg_name], _ = _sweep_reg(
            best_model, 0, reg_name, values, LLsR[cc, kk-1])
        print("Cell %3d:  %-7s=%8g   LL =%8.5f ->%8.5f" % (
            cc, reg_name, glm_regs[reg_name], LLsNULL[cc]-LLsR[cc, 0], LLsNULL[cc]-LLsR[cc, kk]))
    Topt[cc] = glm_regs['d2t']
    GLopt[cc] = glm_regs['glocalx']
    L1opt[cc] = glm_regs['l1']
    glms[cc] = deepcopy(best_model)
    best_glm = best_model

    # plot GLM filters: one subplot row per color channel, one column per lag
    w = best_glm.get_weights()
    wmax = np.max(abs(w))  # |w| so negative weights set the color range too, as in the plots below
    utils.subplot_setup(1, num_lags, row_height=4)
    for ll in range(num_lags):
        for clr in range(3):
            plt.subplot(3, num_lags, ll+1 + clr*num_lags)
            utils.imagesc(w[clr, :, :, ll, 0], aspect=1, max=wmax)
    plt.show()

    # ---- 3. GLM on [stim, |stim|] using the GLM's chosen regularization
    glm_abs = NDN.NDN(ffnet_list=[stim_net2, drift_net, net_comb], loss_type='poisson')
    glm_abs.block_sample = True
    glm_abs.networks[1].layers[0].weight.data[:, 0] = deepcopy(drift_w)
    glm_abs.networks[2].layers[0].set_parameters(val=False, name='weight')
    glm_abs.networks[1].layers[0].set_parameters(val=False)
    for reg_name in GLM_REG_KEYS:
        glm_abs.networks[0].layers[0].reg.vals[reg_name] = glm_regs[reg_name]
    glm_abs.fit(data, force_dict_training=True, **lbfgs_pars, seed=5, version=9, verbose=0)
    LL_abs[cc] = glm_abs.eval_models(data[data.val_blks], null_adjusted=False)[0]
    glms_abs[cc] = deepcopy(glm_abs)
    print("Cell %3d:  glm+abs   LL =%8.5f ->%8.5f" % (cc, LLsNULL[cc]-LLsR[cc, 0], LLsNULL[cc]-LL_abs[cc]))

    # plot glm_abs: rows = (stim, |stim|) x 3 color channels, columns = lags
    w2 = glms_abs[cc].get_weights().reshape([2, 3, NXglm, NXglm, num_lags, 1])
    maxall = np.max(abs(w2))
    utils.ss(2, num_lags, rh=4)
    for ll in range(num_lags):
        for part in range(2):
            for clr in range(3):
                plt.subplot(6, num_lags, ll+1 + (3*part + clr)*num_lags)
                utils.imagesc(w2[part, clr, :, :, ll, 0], aspect=1, max=maxall)
    plt.show()

    # ---- 4. GQM: linear part initialized from the GLM filter, with the GLM's regs;
    #         d2x sweep on the quadratic subunits, each candidate from a fresh model
    LLs = np.zeros(len(rvals))
    for rr, val in enumerate(rvals):
        gqm_iter = NDN.NDN(ffnet_list=[stim_net, drift_net, stim_qnet, net2_comb], loss_type='poisson')
        gqm_iter.block_sample = True
        gqm_iter.networks[3].layers[0].set_parameters(val=False, name='weight')
        gqm_iter.networks[1].layers[0].weight.data[:, 0] = deepcopy(drift_w)
        gqm_iter.networks[1].layers[0].set_parameters(val=False)
        gqm_iter.networks[0].layers[0].weight.data[:, 0] = deepcopy(
            best_glm.networks[0].layers[0].weight.data[:, 0])
        for reg_name in GLM_REG_KEYS:
            gqm_iter.networks[0].layers[0].reg.vals[reg_name] = glm_regs[reg_name]
        gqm_iter.networks[2].layers[0].reg.vals['d2x'] = val
        gqm_iter.fit(data, force_dict_training=True, **lbfgs_pars, seed=5, version=9, verbose=0)
        LLs[rr] = gqm_iter.eval_models(data[data.val_blks], null_adjusted=False)[0]
        if (rr == 0) or (LLs[rr] < LLs[:rr].min()):
            best_gqm = gqm_iter
    bmq1 = int(np.argmin(LLs))
    LLsQR[cc, 1] = LLs[bmq1]
    gqm_regs = {'d2x': rvals[bmq1]}
    print(LLs)
    print("Cell %3d:  GQM d2x    =%8g   LL =%8.5f ->%8.5f" % (
        cc, gqm_regs['d2x'], LLsNULL[cc]-LLsR[cc, 3], LLsNULL[cc]-LLsQR[cc, 1]))

    # ---- GQM: greedy refinement of the subunit d2t, glocalx, l1 (LLsQR[cc, 2..4])
    for kk, (reg_name, values) in enumerate(GQM_STAGES, start=2):
        best_gqm, LLsQR[cc, kk], gqm_regs[reg_name], LLs = _sweep_reg(
            best_gqm, 2, reg_name, values, LLsQR[cc, kk-1])
        print(LLs)
        print("Cell %3d:  GQM %-7s=%8g   LL =%8.5f ->%8.5f" % (
            cc, reg_name, gqm_regs[reg_name], LLsNULL[cc]-LLsR[cc, 3], LLsNULL[cc]-LLsQR[cc, kk]))
    gqm_regopt[cc, :] = [gqm_regs[key] for key in GLM_REG_KEYS]
    gqms[cc] = deepcopy(best_gqm)

    # plot GQM: linear filter (3 rows) then each subunit (3 rows each), columns = lags
    w2 = best_gqm.get_weights(ffnet_target=0).reshape([3, NXglm, NXglm, num_lags])
    w3 = best_gqm.get_weights(ffnet_target=2).reshape([3, NXglm, NXglm, num_lags, num_subs])
    maxall = np.max(abs(w3))
    nrows = 3*(num_subs + 1)
    utils.ss(num_subs + 1, num_lags, rh=4)
    for ll in range(num_lags):
        for clr in range(3):
            plt.subplot(nrows, num_lags, ll+1 + clr*num_lags)
            utils.imagesc(w2[clr, :, :, ll], aspect=1, max=maxall)
        for subs in range(num_subs):
            for clr in range(3):
                plt.subplot(nrows, num_lags, ll+1 + (3*(subs+1) + clr)*num_lags)
                utils.imagesc(w3[clr, :, :, ll, subs], aspect=1, max=maxall)
    plt.show()

In [None]:
# Collect each cell's fitted filters and save them with the LL gains over the
# drift-only null model.
LLsGLM = LLsNULL - np.min(LLsR, axis=1)
LLsGLMabs = LLsNULL - LL_abs
LLsGQM = LLsNULL - np.min(LLsQR, axis=1)

glm_ks = np.zeros([NCv, 3, NXglm, NXglm, num_lags])              # GLM filters
glm_abs_ks = np.zeros([NCv, 2, 3*NXglm, NXglm, num_lags])        # GLM-abs filters
gqm_ksl = np.zeros([NCv, 3, NXglm, NXglm, num_lags])             # GQM linear net (ffnet 0)
gqm_ksq = np.zeros([NCv, 3, NXglm, NXglm, num_lags, num_subs])   # GQM subunits (ffnet 2)

for cc in range(NCv):
    glm_ks[cc] = glms[cc].get_weights().squeeze()
    glm_abs_ks[cc] = glms_abs[cc].get_weights().squeeze()
    gqm_ksl[cc] = gqms[cc].get_weights(ffnet_target=0).squeeze()
    gqm_ksq[cc] = gqms[cc].get_weights(ffnet_target=2).squeeze()

modfilts_UT = {
    'glm_ks': glm_ks,
    'glm_abs_ks': glm_abs_ks,
    'gqm_ksl': gqm_ksl,
    'gqm_ksq': gqm_ksq,
    'valUT': valUT,
    'LLsGLM': LLsGLM[:, None],
    'LLsGLMabs': LLsGLMabs[:, None],
    'LLsGQM': LLsGQM[:, None],
}
sio.savemat(dirname2 + 'UTmodfilts.mat', modfilts_UT)

In [None]:
# Assemble population drift terms
# Stack the per-cell drift weights into one population drift model and check
# that it reproduces the per-cell null LLs.
w0 = driftmods[0].get_weights()
NA = w0.shape[0]
drift_terms = np.zeros([NA, NCv])
for cc in range(NCv):
    drift_terms[:, cc] = deepcopy(driftmods[cc].get_weights())[:, 0]

# Assemble drift population model
drift_pars_pop = NDNLayer.layer_dict(
    input_dims=[1,1,1,NA], num_filters=NCv, bias=False, norm_type=0, NLtype='softplus')
drift_pars_pop['reg_vals'] = {'d2t': Dreg, 'bcs': {'d2t': 0}}
drift_pop = NDN.NDN(layer_list=[drift_pars_pop], loss_type='poisson')
drift_pop.networks[0].xstim_n = 'Xdrift'
drift_pop.networks[0].layers[0].weight.data = torch.tensor(drift_terms, dtype=torch.float32)
# Column cc of drift_terms was fit with data.set_cells([valUT[cc]]), so the check
# must select the same UT cells (not valLP).
data.set_cells(valUT)
LLsCHECK = drift_pop.eval_models(data[data.val_blks], null_adjusted=False)
print(np.mean(LLsNULL), np.mean(LLsCHECK))

In [None]:
# Get array centers: per cell, the peak of the GLM filter's std across lags
# (axis 3), summed over the 3 color channels.
RFcenters = np.full([NCv, 2], -1, dtype=np.int64)
RFcols = np.full([NCv, 1], -1, dtype=np.int64)
for cc in range(NCv):
    kfilt = glms[cc].get_weights()
    rf_power = np.sum(np.std(kfilt, axis=3), axis=0).squeeze()  # for colorRFs
    xc, yc = utils.max_multiD(rf_power)
    RFcenters[cc, :] = [xc, yc]

plt.plot(RFcenters[:, 0], RFcenters[:, 1], 'o')
plt.show()

### now save

In [None]:
# Save LL gains over the drift-only null model, the population drift model,
# and every per-cell UT model.
LLsGLM = LLsNULL - np.min(LLsR, axis=1)
LLsGLMabs = LLsNULL - LL_abs  # GLM-abs validation LLs are stored in LL_abs, not in LLsR
LLsGQM = LLsNULL - np.min(LLsQR, axis=1)
sio.savemat(dirname2+'ModLLsET_UT.mat', {
    'LLsNULL': LLsNULL[:, None], 'LLsGLM': LLsGLM[:, None], 'LLsGLMabs': LLsGLMabs[:, None], 'LLsGQM': LLsGQM[:, None],
    'drift_terms': drift_terms, 'Dreg': Dreg,
    # UTloc is the top corner the UT stimulus was assembled with
    'RFcenters': RFcenters, 'top_corner': np.asarray(UTloc)[:, None]})
drift_pop.save_model(alt_dirname=dirname_mod, filename='UTdriftmods_pop.pkl')

for cc in range(NCv):
    cell_tag = 'UT' + utils.filename_num2str(cc, num_digits=3)
    glms[cc].save_model(alt_dirname=dirname_mod, filename=cell_tag + 'glmwET.pkl')
    glms_abs[cc].save_model(alt_dirname=dirname_mod, filename=cell_tag + 'glmabswET.pkl')
    gqms[cc].save_model(alt_dirname=dirname_mod, filename=cell_tag + 'gqmwET.pkl')


### TODO

In [None]:
crash  # deliberately undefined name: raises NameError so "Run All" stops here, before the TODO cells

In [None]:
# Rebuild the population drift model from the per-cell UT drift fits and
# compare its validation LLs with the per-cell null LLs on valUT.
w0 = driftmods[0].get_weights()
NA = w0.shape[0]  # length of each cell's drift weight vector
drift_terms = np.zeros([NA, NCv])
for cc in range(NCv):
    cell_drift = deepcopy(driftmods[cc].get_weights())
    drift_terms[:, cc] = cell_drift[:, 0]

# one softplus drift filter per cell
drift_pars_pop = NDNLayer.layer_dict(
    input_dims=[1, 1, 1, NA], num_filters=NCv, bias=False, norm_type=0, NLtype='softplus')
drift_pars_pop['reg_vals'] = {'d2t': Dreg, 'bcs': {'d2t': 0}}

drift_pop = NDN.NDN(layer_list=[drift_pars_pop], loss_type='poisson')
pop_drift_layer = drift_pop.networks[0].layers[0]
drift_pop.networks[0].xstim_n = 'Xdrift'
pop_drift_layer.weight.data = torch.tensor(drift_terms, dtype=torch.float32)

data.set_cells(valUT)
LLsCHECK = drift_pop.eval_models(data[data.val_blks], null_adjusted=False)
print(np.mean(LLsNULL), np.mean(LLsCHECK))

In [None]:
# Get array centers: peak of each GLM filter's std across lags (axis 3),
# summed over color channels.
RFcenters = np.full([NCv, 2], -1, dtype=np.int64)
for cc in range(NCv):
    kfilt = glms[cc].get_weights()
    rf_power = np.sum(np.std(kfilt, axis=3), axis=0).squeeze()  # for colorRFs
    xc, yc = utils.max_multiD(rf_power)
    RFcenters[cc, :] = [xc, yc]

plt.plot(RFcenters[:, 0], RFcenters[:, 1], 'o')
plt.show()

In [None]:
# Generate full GLM model like full drift model
# Population GLM assembled from the per-cell UT GLMs (stim filter, drift weights
# and output bias k copied from cell k). The stimulus layer must have the same
# structure as the per-cell GLMs (a Tlayer on the non-time-embedded stimulus,
# time_embed=0) so the copied weight columns are interpreted identically. Its
# validation LLs should then reproduce the per-cell ones.
glm_layer = Tlayer.layer_dict(
    input_dims=[3, NXglm, NXglm, 1], num_filters=NCv, bias=False, num_lags=num_lags,
    NLtype='lin', initialize_center=True)
# regularization only matters if this model is refit; mirror the per-cell template
glm_layer['reg_vals'] = {'d2x': Xreg, 'd2t': Treg, 'l1': L1reg, 'glocalx': GLreg, 'edge_t': 10}

stim_net = FFnetwork.ffnet_dict(xstim_n='stim', layer_list=[glm_layer])

drift_pars = NDNLayer.layer_dict(
    input_dims=[1,1,1,NA], num_filters=NCv, bias=False, norm_type=0, NLtype='lin')
drift_pars['reg_vals'] = {'d2t': Dreg, 'bcs': {'d2t': 0}}

drift_net = FFnetwork.ffnet_dict(xstim_n='Xdrift', layer_list=[drift_pars])

comb_layer = ChannelLayer.layer_dict(num_filters=NCv, NLtype='softplus', bias=True)
comb_layer['weights_initializer'] = 'ones'

net_comb = FFnetwork.ffnet_dict(
    xstim_n=None, ffnet_n=[0,1],
    layer_list=[comb_layer], ffnet_type='add')

glm_all = NDN.NDN(
    ffnet_list=[stim_net, drift_net, net_comb], loss_type='poisson')

for cc in range(NCv):
    glm_all.networks[0].layers[0].weight.data[:, cc] = deepcopy(glms[cc].networks[0].layers[0].weight.data[:, 0])
    glm_all.networks[1].layers[0].weight.data[:, cc] = deepcopy(glms[cc].networks[1].layers[0].weight.data[:, 0])
    glm_all.networks[-1].layers[0].bias.data[cc] = deepcopy(glms[cc].networks[-1].layers[0].bias.data)

data.set_cells(valUT)
LLsCHECK = glm_all.eval_models(data[data.val_blks], null_adjusted=False)
# glms[cc] is the best model over all sweep stages, i.e. the lowest LLsR[cc, :]
print(np.mean(LLsNULL - np.min(LLsR, axis=1)), np.mean(LLsNULL - LLsCHECK))

In [None]:
# Save UT GLM LL gains over the null model, the population models and per-cell GLMs.
# glms[cc] is the best model across all sweep stages, so its LL is the lowest stage
# LL (a single column such as LLsR[:, 1] is an intermediate stage).
LLsGLM = LLsNULL - np.min(LLsR, axis=1)
sio.savemat(dirname2+'LLsGLMTwET_UT.mat', {
    'LLsNULL': LLsNULL[:, None], 'LLsGLM': LLsGLM[:, None], 'drift_terms': drift_terms, 'Dreg': Dreg,
    'RFcenters': RFcenters, 'top_corner': UTloc})
drift_pop.save_model(alt_dirname=dirname_mod, filename='ETdriftmods_pop.pkl')
glm_all.save_model(alt_dirname=dirname_mod, filename='ETglms_pop.pkl')

for cc in range(NCv):
    glms[cc].save_model(
        alt_dirname=dirname_mod, filename='UTglmT' + utils.filename_num2str(cc, num_digits=3) + 'wET_iter0.pkl')

In [None]:
# pure color weights
# Per UT cell: color weights, RF maps and their bootstrap (bootstrap=1) maps,
# plus best-effort RF contours/areas for the achromatic map and each color map.
Gconv_cws = np.zeros([NCv,3])
Gconv_space_all = np.zeros([NCv,NXglm,NXglm])
Gconv_space_col_all = np.zeros([NCv,3,NXglm,NXglm])
Gconv_space_shuff_all = np.zeros([NCv,NXglm,NXglm])
Gconv_space_col_shuff_all = np.zeros([NCv,3,NXglm,NXglm])

RF_areas = np.zeros([NCv])         # left at 0 when no contour could be extracted
RF_areas_col = np.zeros([NCv,3])
contours_all = []
contours_col_all = []  # per cell: [contour color 0, color 1, color 2]

for cc in range(NCv):
    data.cells_out = [valUT[cc]]
    Gconv_cws[cc,0], Gconv_cws[cc,1], Gconv_cws[cc,2] = RU.get_Gconv_colws(glms[cc], data, valUT[cc])
    [Gconv_space_all[cc,:,:], Gconv_space_col_all[cc,:,:,:]] = RU.get_Gconv_RFmap(
        glms[cc], data, valUT[cc], drift_mod=driftmods[cc])
    [Gconv_space_shuff_all[cc,:,:], Gconv_space_col_shuff_all[cc,:,:,:]] = RU.get_Gconv_RFmap(
        glms[cc], data, valUT[cc], drift_mod=driftmods[cc], bootstrap=1)

    # Contour extraction is best-effort: a failure is recorded as [0]. Catch
    # Exception (not a bare except) so Ctrl-C / SystemExit still stop the loop.
    # NOTE(review): the achromatic map is thresholded with the max of the *color*
    # bootstrap maps rather than Gconv_space_shuff_all -- confirm this is intended.
    try:
        con, RF_areas[cc], ctr = RU.get_contour(
            Gconv_space_all[cc,:,:].squeeze(),
            thresh=0.8*np.max(Gconv_space_col_shuff_all[cc,:,:,:]))
        contours_all.append(con)
    except Exception:
        contours_all.append([0])

    col_contours = []
    for clr in range(3):
        try:
            con_clr, RF_areas_col[cc,clr], ctr = RU.get_contour(
                Gconv_space_col_all[cc,clr,:,:].squeeze(),
                thresh=0.8*np.max(Gconv_space_col_shuff_all[cc,clr,:,:]))
        except Exception:
            con_clr = [0]
        col_contours.append(con_clr)
    contours_col_all.append(col_contours)
    print("Cell %3d of %d" % (cc, NCv))
print("Done!")

In [None]:
# Per-cell GLM filters stacked into [NCv, 3, NXglm, NXglm, num_lags]
glm_ks = np.zeros([NCv, 3, NXglm, NXglm, num_lags])
for cc in range(NCv):
    glm_ks[cc] = glms[cc].get_weights().squeeze()
glm_ks.shape

In [None]:
# Save UT GLM filters with RF/color measures and contours
meas_UT = {
    'glm_ks': glm_ks,
    'valUT': valUT,
    'LLsGLM': LLsGLM[:, None],
    'RFcenters': RFcenters,
    'Colws': Gconv_cws,
    'RFareas': RF_areas,
    'RFareas_col': RF_areas_col,
    'RFmaps': Gconv_space_all,
    'RFmaps_col': Gconv_space_col_all,
    'RFbs': Gconv_space_shuff_all,
    'RFbs_col': Gconv_space_col_shuff_all,
    'Contours': contours_all,
    'Contours_col': contours_col_all,
}
sio.savemat(dirname2 + 'GLMT_measUT_ET.mat', meas_UT)

## Now do NForm for completion

In [None]:
# NForm window: assemble a UNX-sized stimulus at NFloc and the NF cells'
# responses masked by data.dfs (presumably the valid-data flags).
NFloc = np.array([918, 515], dtype=np.int64)

data.assemble_stimulus(top_corner=NFloc, fixdot=0, L=UNX, time_embed=0, shifts=-fixshifts)
Reff = data.robs[:, valNF] * data.dfs[:, valNF]
nspks = torch.sum(Reff, axis=0)  # per-cell total of Reff

data.draw_stim_locations(top_corner=NFloc, L=UNX)

In [None]:
## Calculate STAS
# STA at a single lag: stimulus at t paired with response at t+lag, normalized
# by each cell's spike total and reshaped to [3, UNX, UNX, cell].
lag = 4
stim_pre = data.stim[:-lag, ...]
resp_post = Reff[lag:, :]
sta_flat = (stim_pre.T @ resp_post).squeeze() / nspks
stasC = sta_flat.reshape([3, UNX, UNX, -1]).numpy()
stasC.shape

In [None]:
# Plot each NF cell's STA, one subplot per color channel (3 per cell, 6 per row)
ctext = ['Lum', 'L-M', 'S']
ss(6, 6, rh=4)
for cc in range(NCNF):
    for clr, label in enumerate(ctext):
        plt.subplot(6, 6, 3*cc + clr + 1)
        imagesc(stasC[clr, :, :, cc])
        if clr == 0:
            plt.ylabel("Cell %d" % cc)
        plt.title(label)

In [None]:
# NForm fits: cell count, GLM window size and number of lags
NCv = NCNF
NXglm = 60
num_lags = 8

# result containers
LLs0 = np.zeros([NCv, 2])
LLsGLM, XTopt, GLopt = np.zeros(NCv), np.zeros(NCv), np.zeros(NCv)
glms, driftmods = [None] * NCv, [None] * NCv
Dreg = 0.1  # d2t regularization on the drift weights

# LBFGS settings shared by the NF fits
lbfgs_pars = utils.create_optimizer_params(
    optimizer_type='lbfgs',
    tolerance_change=1e-10,
    tolerance_grad=1e-10,
    batch_size=16000,
    max_epochs=3,
    max_iter=200,
)

In [None]:
valNF  # display the NF cell indices (used to select data.robs columns / set_cells below)

In [None]:
# NForm GLM fits: per cell, a drift-only null model, then a GLM with a fresh-model
# d2xt sweep followed by greedy glocalx and l1 sweeps (LLsR[:, 0..2]).
XTreg = 20   # [20]
L1reg = 0.1  # [0.5]
GLreg = 4.0  # [4.0]

# Assemble the NF fitting stimulus *before* building glm_layer. Its input_dims are
# read from data.stim_dims, which otherwise still describes the UNX-sized stimulus
# assembled for the STAs rather than the NXglm / num_lags stimulus fit here.
data.assemble_stimulus(top_corner=NFloc, fixdot=0, L=NXglm, time_embed=0, num_lags=num_lags, shifts=-fixshifts)

glm_layer = NDNLayer.layer_dict(
    input_dims=data.stim_dims, num_filters=1, bias=False,
    NLtype='lin', initialize_center=False)
glm_layer['reg_vals'] = {'d2xt': XTreg, 'l1': L1reg, 'glocalx': GLreg, 'edge_t': 100}

stim_net = FFnetwork.ffnet_dict(xstim_n='stim', layer_list=[glm_layer])

drift_pars1 = NDNLayer.layer_dict(
    input_dims=[1,1,1,NA], num_filters=1, bias=False, norm_type=0, NLtype='lin')
drift_pars1['reg_vals'] = {'d2t': Dreg, 'bcs': {'d2t': 0}}

# for stand-alone drift model
drift_pars1N = deepcopy(drift_pars1)
drift_pars1N['NLtype'] = 'softplus'

drift_net = FFnetwork.ffnet_dict(xstim_n='Xdrift', layer_list=[drift_pars1])

comb_layer = NDNLayer.layer_dict(
    num_filters=1, NLtype='softplus', bias=False)
comb_layer['weights_initializer'] = 'ones'

net_comb = FFnetwork.ffnet_dict(
    xstim_n=None, ffnet_n=[0,1],
    layer_list=[comb_layer], ffnet_type='add')
net_comb['layer_list'][0]['bias'] = True

glms = [None]*NCv
driftmods = [None]*NCv

XTopt = np.zeros(NCv)   # chosen d2xt
GLopt = np.zeros(NCv)   # chosen glocalx
L1opt = np.zeros(NCv)   # chosen l1 (kept separate so GLopt stays the glocalx choice)
LLsNULL = np.zeros(NCv)
LLsR = np.zeros([NCv,4])  # cols 0..2: best LL after the d2xt, glocalx, l1 sweeps (col 3 unused)

rvals = [0.001, 0.01, 0.1, 1, 10, 100, 1000, 10000]
rvalsG = [0.001, 0.01, 0.1, 1, 10, 100, 1000]  # glocal (one decade per step; 0.01 is not repeated)


def _sweep_reg(start_model, ffnet_n, reg_name, reg_values, base_LL):
    """Greedy sweep of one regularization value on ``networks[ffnet_n].layers[0]``.

    Each candidate value is set on a copy of the best model so far (so later
    candidates start from earlier improvements), refit on ``data`` and scored on
    the validation blocks. A candidate replaces the current best only if its
    loss is strictly below the best loss so far, which starts at ``base_LL``.

    Returns:
        best_model: model to keep (``start_model`` itself if nothing improved).
        best_LL: its validation loss (``base_LL`` if nothing improved).
        best_val: the ``reg_name`` value set on ``best_model``.
        LLs: validation loss of each candidate, in ``reg_values`` order.
    """
    best_model, best_LL = start_model, base_LL
    best_val = start_model.networks[ffnet_n].layers[0].reg.vals[reg_name]
    LLs = np.zeros(len(reg_values))
    for rr, val in enumerate(reg_values):
        trial = deepcopy(best_model)
        trial.networks[ffnet_n].layers[0].reg.vals[reg_name] = val
        trial.fit(data, force_dict_training=True, **lbfgs_pars, seed=5, verbose=0)
        LLs[rr] = trial.eval_models(data[data.val_blks], null_adjusted=False)[0]
        if LLs[rr] < best_LL:
            best_model, best_LL, best_val = trial, LLs[rr], val
    return best_model, best_LL, best_val, LLs


for cc in range(NCv):
    data.cells_out = [valNF[cc]]

    # drift-only (null) model
    drift_iter = NDN.NDN(
        layer_list=[drift_pars1N], loss_type='poisson')
    drift_iter.block_sample = True
    drift_iter.networks[0].xstim_n = 'Xdrift'
    drift_iter.fit(data, force_dict_training=True, train_inds=None, **lbfgs_pars, verbose=0, version=1)
    LLsNULL[cc] = drift_iter.eval_models(data[data.val_blks], null_adjusted=False)[0]
    driftmods[cc] = deepcopy(drift_iter)

    # d2xt sweep: every candidate fit from a fresh model
    LLs = np.zeros(len(rvals))
    for rr, val in enumerate(rvals):
        stim_net_r = deepcopy(stim_net)  # leave the shared template unmodified
        stim_net_r['layer_list'][0]['reg_vals']['d2xt'] = val
        glm_iter = NDN.NDN(ffnet_list=[stim_net_r, drift_net, net_comb], loss_type='poisson')
        glm_iter.block_sample = True
        glm_iter.networks[1].layers[0].weight.data[:,0] = deepcopy(
            driftmods[cc].networks[0].layers[0].weight.data[:,0])
        glm_iter.networks[2].layers[0].set_parameters(val=False, name='weight')
        glm_iter.networks[1].layers[0].set_parameters(val=False)

        glm_iter.fit(data, force_dict_training=True, **lbfgs_pars, seed=5, version=9, verbose=0)
        LLs[rr] = glm_iter.eval_models(data[data.val_blks], null_adjusted=False)[0]
        if (rr == 0) or (LLs[rr] < LLs[:rr].min()):
            best_model = glm_iter
    bm = int(np.argmin(LLs))
    LLsR[cc,0] = LLs[bm]
    XTopt[cc] = rvals[bm]
    print("Cell %3d:  d2xt-R%d   LLs =" % (cc, bm), LLsNULL[cc]-LLsR[cc,0])

    # glocalx sweep, starting from the best d2xt model
    best_model, LLsR[cc,1], GLopt[cc], _ = _sweep_reg(best_model, 0, 'glocalx', rvalsG, LLsR[cc,0])
    print("Cell %3d:  Gloc=%-8g LL =%8.5f ->%8.5f" % (
        cc, GLopt[cc], LLsNULL[cc]-LLsR[cc,0], LLsNULL[cc]-LLsR[cc,1]))

    # l1 sweep, starting from (and measured against) the current best model
    best_model, LLsR[cc,2], L1opt[cc], _ = _sweep_reg(best_model, 0, 'l1', rvals, LLsR[cc,1])
    print("Cell %3d:  L1  =%-8g LL =%8.5f ->%8.5f" % (
        cc, L1opt[cc], LLsNULL[cc]-LLsR[cc,0], LLsNULL[cc]-LLsR[cc,2]))
    glms[cc] = deepcopy(best_model)

    # plot this cell's selected GLM filters: rows = color channels, columns = lags
    # (lag 0 skipped when there are several). Assumes get_weights() returns
    # [3, NX, NX, lags, filters] as in the UT plots -- TODO confirm for NDNLayer.
    w = best_model.get_weights()
    plot_lags = list(range(1, w.shape[3])) or [0]
    n_pl = len(plot_lags)
    wmax = np.max(abs(w))
    utils.ss(3, n_pl, row_height=4)
    for col, lag in enumerate(plot_lags):
        for clr in range(3):
            plt.subplot(3, n_pl, clr*n_pl + col + 1)
            utils.imagesc(w[clr,:,:,lag,0], aspect=1, max=wmax)
    plt.show()

In [None]:
# Build the population drift model from the per-cell NF drift fits and compare
# its validation LLs with the per-cell null LLs on valNF.
w0 = driftmods[0].get_weights()
NA = w0.shape[0]  # length of each cell's drift weight vector
drift_terms = np.zeros([NA, NCv])
for cc in range(NCv):
    cell_drift = deepcopy(driftmods[cc].get_weights())
    drift_terms[:, cc] = cell_drift[:, 0]

# one softplus drift filter per cell
drift_pars_pop = NDNLayer.layer_dict(
    input_dims=[1, 1, 1, NA], num_filters=NCv, bias=False, norm_type=0, NLtype='softplus')
drift_pars_pop['reg_vals'] = {'d2t': Dreg, 'bcs': {'d2t': 0}}

drift_pop = NDN.NDN(layer_list=[drift_pars_pop], loss_type='poisson')
pop_drift_layer = drift_pop.networks[0].layers[0]
drift_pop.networks[0].xstim_n = 'Xdrift'
pop_drift_layer.weight.data = torch.tensor(drift_terms, dtype=torch.float32)

data.set_cells(valNF)
LLsCHECK = drift_pop.eval_models(data[data.val_blks], null_adjusted=False)
print(np.mean(LLsNULL), np.mean(LLsCHECK))

In [None]:
# Get array centers: peak of each NF GLM filter's std across axis 3,
# summed over color channels.
RFcenters = np.full([NCv, 2], -1, dtype=np.int64)
for cc in range(NCv):
    kfilt = glms[cc].get_weights()
    rf_power = np.sum(np.std(kfilt, axis=3), axis=0).squeeze()  # for colorRFs
    xc, yc = utils.max_multiD(rf_power)
    RFcenters[cc, :] = [xc, yc]

plt.plot(RFcenters[:, 0], RFcenters[:, 1], 'o')
plt.show()

In [None]:
# Generate full GLM model like full drift model
# Population GLM assembled from the per-cell NF GLMs (stim filter, drift weights
# and output bias k copied from cell k). The stimulus layer mirrors the per-cell
# NF layer (input_dims=data.stim_dims, no separate num_lags) so the copied weight
# columns are interpreted the same way. Its validation LLs should reproduce the
# per-cell ones.
glm_layer = NDNLayer.layer_dict(
    input_dims=data.stim_dims, num_filters=NCv, bias=False, NLtype='lin', initialize_center=True)
glm_layer['reg_vals'] = {'d2xt': XTreg, 'l1': L1reg, 'glocalx': GLreg, 'edge_t': 100}

stim_net = FFnetwork.ffnet_dict(xstim_n='stim', layer_list=[glm_layer])

drift_pars = NDNLayer.layer_dict(
    input_dims=[1,1,1,NA], num_filters=NCv, bias=False, norm_type=0, NLtype='lin')
drift_pars['reg_vals'] = {'d2t': Dreg, 'bcs': {'d2t': 0}}

drift_net = FFnetwork.ffnet_dict(xstim_n='Xdrift', layer_list=[drift_pars])

comb_layer = ChannelLayer.layer_dict(num_filters=NCv, NLtype='softplus', bias=True)
comb_layer['weights_initializer'] = 'ones'

net_comb = FFnetwork.ffnet_dict(
    xstim_n=None, ffnet_n=[0,1],
    layer_list=[comb_layer], ffnet_type='add')

glm_all = NDN.NDN(
    ffnet_list=[stim_net, drift_net, net_comb], loss_type='poisson')
glm_all.block_sample = True

for cc in range(NCv):
    glm_all.networks[0].layers[0].weight.data[:, cc] = deepcopy(glms[cc].networks[0].layers[0].weight.data[:, 0])
    glm_all.networks[1].layers[0].weight.data[:, cc] = deepcopy(glms[cc].networks[1].layers[0].weight.data[:, 0])
    glm_all.networks[-1].layers[0].bias.data[cc] = deepcopy(glms[cc].networks[-1].layers[0].bias.data)

data.set_cells(valNF)
LLsCHECK = glm_all.eval_models(data[data.val_blks], null_adjusted=False)
# The NF sweeps fill LLsR[:, 0:3] only (col 3 stays 0); glms[cc] is the best of those stages.
print(np.mean(LLsNULL - np.min(LLsR[:, :3], axis=1)), np.mean(LLsNULL - LLsCHECK))

In [None]:
# Save NF GLM LL gains over the null model, the population models and per-cell GLMs.
# The NF sweeps fill LLsR[:, 0:3] (col 3 stays 0); glms[cc] is the best of those
# stages, so use the lowest of them rather than an intermediate column.
LLsGLM = LLsNULL - np.min(LLsR[:, :3], axis=1)
sio.savemat(dirname2+'LLsGLMTwET_NF.mat', {
    'LLsNULL': LLsNULL[:, None], 'LLsGLM': LLsGLM[:, None], 'drift_terms': drift_terms, 'Dreg': Dreg,
    'RFcenters': RFcenters, 'top_corner': NFloc})
drift_pop.save_model(alt_dirname=dirname_mod, filename='ETdriftmods_popNF.pkl')
glm_all.save_model(alt_dirname=dirname_mod, filename='ETglms_popNF.pkl')

for cc in range(NCv):
    glms[cc].save_model(
        alt_dirname=dirname_mod, filename='NFglmT' + utils.filename_num2str(cc, num_digits=3) + 'wET_iter0.pkl')

In [None]:
# pure color weights
# Per NF cell: color weights, RF maps and their bootstrap (bootstrap=1) maps,
# plus best-effort RF contours/areas for the achromatic map and each color map.
Gconv_cws = np.zeros([NCv,3])
Gconv_space_all = np.zeros([NCv,NXglm,NXglm])
Gconv_space_col_all = np.zeros([NCv,3,NXglm,NXglm])
Gconv_space_shuff_all = np.zeros([NCv,NXglm,NXglm])
Gconv_space_col_shuff_all = np.zeros([NCv,3,NXglm,NXglm])

RF_areas = np.zeros([NCv])         # left at 0 when no contour could be extracted
RF_areas_col = np.zeros([NCv,3])
contours_all = []
contours_col_all = []  # per cell: [contour color 0, color 1, color 2]

for cc in range(NCv):
    data.cells_out = [valNF[cc]]
    Gconv_cws[cc,0], Gconv_cws[cc,1], Gconv_cws[cc,2] = RU.get_Gconv_colws(glms[cc], data, valNF[cc])
    # RF maps must use this cell's NF index (glms/driftmods here are NF models)
    [Gconv_space_all[cc,:,:], Gconv_space_col_all[cc,:,:,:]] = RU.get_Gconv_RFmap(
        glms[cc], data, valNF[cc], drift_mod=driftmods[cc])
    [Gconv_space_shuff_all[cc,:,:], Gconv_space_col_shuff_all[cc,:,:,:]] = RU.get_Gconv_RFmap(
        glms[cc], data, valNF[cc], drift_mod=driftmods[cc], bootstrap=1)

    # Contour extraction is best-effort: a failure is recorded as [0]. Catch
    # Exception (not a bare except) so Ctrl-C / SystemExit still stop the loop.
    # NOTE(review): the achromatic map is thresholded with the max of the *color*
    # bootstrap maps rather than Gconv_space_shuff_all -- confirm this is intended.
    try:
        con, RF_areas[cc], ctr = RU.get_contour(
            Gconv_space_all[cc,:,:].squeeze(),
            thresh=0.8*np.max(Gconv_space_col_shuff_all[cc,:,:,:]))
        contours_all.append(con)
    except Exception:
        contours_all.append([0])

    col_contours = []
    for clr in range(3):
        try:
            con_clr, RF_areas_col[cc,clr], ctr = RU.get_contour(
                Gconv_space_col_all[cc,clr,:,:].squeeze(),
                thresh=0.8*np.max(Gconv_space_col_shuff_all[cc,clr,:,:]))
        except Exception:
            con_clr = [0]
        col_contours.append(con_clr)
    contours_col_all.append(col_contours)
    print("Cell %3d of %d" % (cc, NCv))
print("Done!")

In [None]:
# Per-cell NF GLM filters stacked into [NCv, 3, NXglm, NXglm, num_lags]
glm_ks = np.zeros([NCv, 3, NXglm, NXglm, num_lags])
for cc in range(NCv):
    glm_ks[cc] = glms[cc].get_weights().squeeze()
glm_ks.shape

In [2]:
# Save NF GLM filters with RF/color measures and contours
meas_NF = {
    'glm_ks': glm_ks,
    'valNF': valNF,
    'LLsGLM': LLsGLM[:, None],
    'RFcenters': RFcenters,
    'Colws': Gconv_cws,
    'RFareas': RF_areas,
    'RFareas_col': RF_areas_col,
    'RFmaps': Gconv_space_all,
    'RFmaps_col': Gconv_space_col_all,
    'RFbs': Gconv_space_shuff_all,
    'RFbs_col': Gconv_space_col_shuff_all,
    'Contours': contours_all,
    'Contours_col': contours_col_all,
}
sio.savemat(dirname2 + 'GLMT_measNF_ET.mat', meas_NF)

NameError: name 'sio' is not defined