In [76]:
# import basic modules
import sys
import os
import time
import numpy as np
from tqdm import tqdm
import gc
import torch
import argparse

# import custom modules
root_dir   = os.path.dirname(os.path.dirname(os.getcwd()))
sys.path.append(os.path.join(root_dir, 'code'))
sys.path.append(os.path.join(root_dir, 'code', 'model_fitting'))
from model_src import fwrf_fit, fwrf_predict
import initialize_fitting

fpX = np.float32


In [78]:
subject=1
roi=None

ridge=1

shuffle_images=False
random_images=False
random_voxel_data=False

zscore_features=True
nonlin_fn=False
padding_mode='circular'

n_ori=8
n_sf=4
up_to_sess=1
debug=True

fitting_type='simple_complex'
include_crosscorrs=True
include_autocorrs=True


In [79]:
device = initialize_fitting.init_cuda()
nsd_root, stim_root, beta_root, mask_root = initialize_fitting.get_paths()

if ridge==True:       
    # ridge regression, testing several positive lambda values (default)
    model_name = 'texture_ridge_%dori_%dsf'%(n_ori, n_sf)
else:        
    # fixing lambda at zero, so it turns into ordinary least squares
    model_name = 'texture_OLS_%dori_%dsf'%(n_ori, n_sf)

if include_autocorrs==False:
    print('Skipping autocorrs\n')
    model_name = model_name+'_no_autocorrelations'
else:
    print('Will compute autocorrs\n')

if include_crosscorrs==False:
    print('Skipping crosscorrs\n')
    model_name = model_name+'_no_crosscorrelations'
else:
    print('Will compute crosscorrs\n')   

output_dir, fn2save = initialize_fitting.get_save_path(root_dir, subject, model_name, shuffle_images, random_images, random_voxel_data, debug)

# decide what voxels to use  
voxel_mask, voxel_index, voxel_roi, voxel_ncsnr, brain_nii_shape = initialize_fitting.get_voxel_info(mask_root, beta_root, subject, roi)

# get all data and corresponding images, in two splits. always fixed set that gets left out
trn_stim_data, trn_voxel_data, val_stim_single_trial_data, val_voxel_single_trial_data, \
    n_voxels, n_trials_val, image_order = initialize_fitting.get_data_splits(nsd_root, beta_root, stim_root, subject, voxel_mask, up_to_sess, 
                                                                             shuffle_images=shuffle_images, random_images=random_images, random_voxel_data=random_voxel_data)

# Set up the filters
_gaborizer_complex, _gaborizer_simple, _fmaps_fn_complex, _fmaps_fn_simple = initialize_fitting.get_feature_map_simple_complex_fn(n_ori, n_sf, padding_mode=padding_mode, device=device, nonlin_fn=nonlin_fn)

# Params for the spatial aspect of the model (possible pRFs)
#     aperture_rf_range=0.8 # using smaller range here because not sure what to do with RFs at edges...
aperture_rf_range = 1.1
aperture, models = initialize_fitting.get_prf_models(aperture_rf_range=aperture_rf_range)    

# More params for fitting
holdout_size, lambdas = initialize_fitting.get_fitting_pars(trn_voxel_data, zscore_features, ridge=ridge)

#device: 1
device#: 0
device name: GeForce GTX TITAN X

torch: 1.8.1+cu111
cuda:  11.1
cudnn: 8005
dtype: torch.float32
Will compute autocorrs

Will compute crosscorrs

Time Stamp: Jul-05-2021_1425

Will save final output file to /user_data/mmhender/model_fits/S01/texture_ridge_8ori_4sf/Jul-05-2021_1425_DEBUG/

3794 voxels of overlap between kastner and prf definitions, using prf defs
unique values in retino labels:
[-1.  0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16.
 17. 18. 19. 20. 21. 22. 23. 24. 25.]
0 voxels of overlap between face and place definitions, using place defs
unique values in categ labels:
[-1.  0. 26. 27. 28. 30. 31. 32. 33.]
1535 voxels are defined (differently) in both retinotopic areas and category areas

14913 voxels are defined across all areas, and will be used for analysis

Loading numerical label/name mappings for all ROIs:
[1, 2, 3, 4, 5, 6, 7]
['V1v', 'V1d', 'V2v', 'V2d', 'V3v', 'V3d', 'hV4']
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 1

In [373]:
#### DO THE ACTUAL MODEL FITTING HERE ####
gc.collect()
torch.cuda.empty_cache()
autocorr_output_pix=5
n_prf_sd_out=2
debug=True
sample_batch_size=20
voxel_batch_size=50

include_crosscorrs = False
include_autocorrs = False
include_pixels = True
include_simple = True
include_complex = True

feature_types_exclude = []
if not include_pixels:
    feature_types_exclude.append('pixel')
if not include_simple:
    feature_types_exclude.append('simple_feature_means')
if not include_complex:
    feature_types_exclude.append('complex_feature_means')
if not include_autocorrs:
    feature_types_exclude.append('autocorrs')
if not include_crosscorrs:
    feature_types_exclude.append('crosscorrs')

    
_texture_fn = texture_feature_extractor(_fmaps_fn_complex, _fmaps_fn_simple, sample_batch_size=sample_batch_size, feature_types_exclude=feature_types_exclude, autocorr_output_pix=autocorr_output_pix, n_prf_sd_out=n_prf_sd_out, aperture=aperture, device=device)

best_losses, best_lambdas, best_params, feature_info = fit_texture_model_ridge(
    trn_stim_data, trn_voxel_data, _texture_fn, models, lambdas, \
    zscore=zscore_features, voxel_batch_size=voxel_batch_size, holdout_size=holdout_size, shuffle=True, add_bias=True, debug=debug)
# note there's also a shuffle param in the above fn call, that determines the nested heldout data for lambda and param selection. always using true.
print('\nDone with training\n')

_texture_fn.update_feature_list([])

# best_losses, best_lambdas, best_params, feature_info = fit_texture_model_ridge(
#     trn_stim_data, trn_voxel_data, _texture_fn, models, lambdas, \
#     zscore=zscore_features, voxel_batch_size=voxel_batch_size, holdout_size=holdout_size, shuffle=True, add_bias=True, debug=debug)
# # note there's also a shuffle param in the above fn call, that determines the nested heldout data for lambda and param selection. always using true.
# print('\nDone with training\n')

trn_size = 619 (90.0%)
dtype = <class 'numpy.float32'>
device = cuda:0
---------------------------------------


model 0

Computing pixel-level statistics...
time elapsed = 0.04753
Computing complex cell features...
time elapsed = 0.90587
Computing simple cell features...
time elapsed = 0.52780
SKIPPING HIGHER-ORDER CORRELATIONS...
time elapsed = 0.03235
Final size of features concatenated is [688 x 100]
Feature types included are:
['wmean', 'wvar', 'wskew', 'wkurt', 'complex_feature_means', 'simple_feature_means']
fitting model    0 of 875 , voxels [ 14900:14912 ] of 14913
model 1

Computing pixel-level statistics...
time elapsed = 0.04537
Computing complex cell features...
time elapsed = 0.90793
Computing simple cell features...
time elapsed = 0.52731
SKIPPING HIGHER-ORDER CORRELATIONS...
time elapsed = 0.03192
Final size of features concatenated is [688 x 100]
Feature types included are:
['wmean', 'wvar', 'wskew', 'wkurt', 'complex_feature_means', 'simple_feature_means']
fitting mod

In [431]:
len(_texture_fn.feature_column_labels)

99

In [422]:
import numpy as np
import torch
import time
from collections import OrderedDict
from utils import numpy_utility, torch_utils
from model_src import fwrf_fit

class texture_feature_extractor(nn.Module):
    """
    Module to compute higher-order texture statistics of input images (similar to Portilla & Simoncelli style texture model), within specified area of space.
    Builds off lower-level feature maps for various orientation/spatial frequency bands, extracted using the modules specified in '_fmaps_fn_complex' and '_fmaps_fn_simple'
    Can specify different subsets of features to include (i.e. pixel-level stats, simple/complex cells, cross-correlations, auto-correlations)
    Inputs to the forward pass are images and pRF parameters of interest [x,y,sigma]
    """
    def __init__(self,_fmaps_fn_complex, _fmaps_fn_simple, sample_batch_size=100, feature_types_exclude=None, autocorr_output_pix=3, n_prf_sd_out=2, aperture=1.0, device=None):
        
        super(texture_feature_extractor, self).__init__()
        
        self.fmaps_fn_complex = _fmaps_fn_complex
        self.fmaps_fn_simple = _fmaps_fn_simple
        self.sample_batch_size = sample_batch_size
        self.autocorr_output_pix = autocorr_output_pix
        self.n_prf_sd_out = n_prf_sd_out
        self.aperture = aperture
        self.device = device       
        
        self.update_feature_list(feature_types_exclude)
        
    def update_feature_list(self, feature_types_exclude):
        
        feature_types_all = ['wmean','wvar','wskew','wkurt', 'complex_feature_means', 'simple_feature_means',\
                         'complex_feature_autocorrs','simple_feature_autocorrs',\
                         'complex_within_scale_crosscorrs','simple_within_scale_crosscorrs',\
                         'complex_across_scale_crosscorrs','simple_across_scale_crosscorrs']
        feature_type_dims = [1,1,1,1,n_ori*n_sf, n_ori*n_sf*n_phases, \
                              n_ori*n_sf*autocorr_output_pix**2, n_ori*n_sf*n_phases*autocorr_output_pix**2, \
                              n_sf*n_ori_pairs, n_sf*n_ori_pairs*n_phases, (n_sf-1)*n_ori**2, (n_sf-1)*n_ori**2*n_phases]
        
        # decide which features to ignore, or use all features
        self.feature_types_exclude = feature_types_exclude

        # a few shorthands for ignoring sets of features at a time
        if 'crosscorrs' in feature_types_exclude:
            feature_types_exclude.extend(['complex_within_scale_crosscorrs','simple_within_scale_crosscorrs','complex_across_scale_crosscorrs','simple_across_scale_crosscorrs'])
        if 'autocorrs' in feature_types_exclude:
            feature_types_exclude.extend(['complex_feature_autocorrs','simple_feature_autocorrs'])
        if 'pixel' in feature_types_exclude:
            feature_types_exclude.extend(['wmean','wvar','wskew','wkurt'])

        self.feature_types_include  = [ff for ff in feature_types_all if not ff in feature_types_exclude]
        if len(self.feature_types_include)==0:
            raise ValueError('you have specified too many features to exclude, and now you have no features left! aborting.')
            
        # how many features will be needed, in total?
        self.n_features_total = np.sum([feature_type_dims[fi] for fi in range(len(feature_type_dims)) if not feature_types_all[fi] in feature_types_exclude])
        
        # numbers that define which feature types are in which column
        self.feature_column_labels = np.squeeze(np.concatenate([fi*np.ones([1,feature_type_dims[fi]]) for fi in range(len(feature_type_dims)) if not feature_types_all[fi] in feature_types_exclude], axis=1).astype('int'))
        assert(np.size(self.feature_column_labels)==self.n_features_total)

    
    def forward(self, images, prf_params):
        
        if isinstance(prf_params, torch.Tensor):
            prf_params = torch_utils.get_value(prf_params)
        assert(np.size(prf_params)==3)
        prf_params = np.squeeze(prf_params)
        if isinstance(images, torch.Tensor):
            images = torch_utils.get_value(images)

        print('Computing pixel-level statistics...')    
        t=time.time()
        x,y,sigma = prf_params
        n_pix=np.shape(images)[2]
        g = numpy_utility.make_gaussian_mass_stack([x], [y], [sigma], n_pix=n_pix, size=self.aperture, dtype=np.float32)
        spatial_weights = g[2][0]
        wmean, wvar, wskew, wkurt = get_weighted_pixel_features(images, spatial_weights, device=device)
        elapsed =  time.time() - t
        print('time elapsed = %.5f'%elapsed)

        print('Computing complex cell features...')
        t = time.time()
        complex_feature_means = fwrf_fit.get_features_in_prf(prf_params, self.fmaps_fn_complex , images=images, sample_batch_size=self.sample_batch_size, aperture=self.aperture, device=self.device, to_numpy=False)
        elapsed =  time.time() - t
        print('time elapsed = %.5f'%elapsed)

        print('Computing simple cell features...')
        t = time.time()
        simple_feature_means = fwrf_fit.get_features_in_prf(prf_params,  self.fmaps_fn_simple, images=images, sample_batch_size=self.sample_batch_size, aperture=self.aperture,  device=self.device, to_numpy=False)
        elapsed =  time.time() - t
        print('time elapsed = %.5f'%elapsed)

        # To save time, decide now whether any autocorrelation or cross-correlation features are desired. If not, will skip a bunch of the slower computations.     
        self.include_crosscorrs = np.any(['crosscorr' in ff for ff in self.feature_types_include])
        self.include_autocorrs = np.any(['autocorr' in ff for ff in self.feature_types_include])
        
        if self.include_autocorrs and self.include_crosscorrs:
            print('Computing higher order correlations...')
        elif self.include_crosscorrs:
            print('Computing higher order correlations (SKIPPING AUTOCORRELATIONS)...')
        elif self.include_autocorrs:
            print('Computing higher order correlations (SKIPPING CROSSCORRELATIONS)...')
        else:
            print('SKIPPING HIGHER-ORDER CORRELATIONS...')    
        t = time.time()
        complex_feature_autocorrs, simple_feature_autocorrs, \
        complex_within_scale_crosscorrs, simple_within_scale_crosscorrs, \
        complex_across_scale_crosscorrs, simple_across_scale_crosscorrs = get_higher_order_features(self.fmaps_fn_complex, self.fmaps_fn_simple, images, prf_params=prf_params, 
                                                                                                    sample_batch_size=self.sample_batch_size, include_autocorrs=self.include_autocorrs, include_crosscorrs=self.include_crosscorrs, 
                                                                                                    autocorr_output_pix=self.autocorr_output_pix, n_prf_sd_out=self.n_prf_sd_out, 
                                                                                                    aperture=self.aperture,  device=self.device)
        elapsed =  time.time() - t
        print('time elapsed = %.5f'%elapsed)

        all_feat = OrderedDict({'wmean': wmean, 'wvar':wvar, 'wskew':wskew, 'wkurt':wkurt, 'complex_feature_means':complex_feature_means, 'simple_feature_means':simple_feature_means, 
                    'complex_feature_autocorrs': complex_feature_autocorrs, 'simple_feature_autocorrs': simple_feature_autocorrs, 
                    'complex_within_scale_crosscorrs': complex_within_scale_crosscorrs, 'simple_within_scale_crosscorrs':simple_within_scale_crosscorrs,
                    'complex_across_scale_crosscorrs': complex_across_scale_crosscorrs, 'simple_across_scale_crosscorrs':simple_across_scale_crosscorrs})

        feature_names_full = list(all_feat.keys())
        feature_names = [fname for fname in feature_names_full if fname in self.feature_types_include]
        assert(feature_names==self.feature_types_include) # double check here that the order is correct
        
        for ff, feature_name in enumerate(feature_names):   
            assert(all_feat[feature_name] is not None)
            if ff==0:
                all_feat_concat = all_feat[feature_name]
            else:               
                all_feat_concat = torch.cat((all_feat_concat, all_feat[feature_name]), axis=1)

        assert(all_feat_concat.shape[1]==self.n_features_total)
        print('Final size of features concatenated is [%d x %d]'%(all_feat_concat.shape[0], all_feat_concat.shape[1]))
        print('Feature types included are:')
        print(feature_names)

        if torch.any(torch.isnan(all_feat_concat)):
            print('\nWARNING THERE ARE NANS IN FEATURES MATRIX\n')

        return all_feat_concat, [self.feature_column_labels, feature_names]

In [308]:
# Validate model on held-out test set
gc.collect()
torch.cuda.empty_cache()
val_cc, val_r2 = validate_texture_model(best_params, val_voxel_single_trial_data, val_stim_single_trial_data, _texture_fn, sample_batch_size, debug=debug, dtype=fpX)



Initializing model for validation...

Creating FWRF texture model...

Getting model predictions on validation set...

samples [    0:19   ] of 62, voxels [     0:0     ] of 14913Computing pixel-level statistics...
time elapsed = 0.00463
Computing complex cell features...
time elapsed = 0.06144
Computing simple cell features...
time elapsed = 0.03896
SKIPPING HIGHER-ORDER CORRELATIONS...
time elapsed = 0.03219
Final size of features concatenated is [20 x 100]
Feature types included are:
['wmean', 'wvar', 'wskew', 'wkurt', 'complex_feature_means', 'simple_feature_means']
samples [   20:39   ] of 62, voxels [     0:0     ] of 14913Computing pixel-level statistics...
time elapsed = 0.00458
Computing complex cell features...
time elapsed = 0.06040
Computing simple cell features...
time elapsed = 0.03625
SKIPPING HIGHER-ORDER CORRELATIONS...
time elapsed = 0.03219
Final size of features concatenated is [20 x 100]
Feature types included are:
['wmean', 'wvar', 'wskew', 'wkurt', 'complex_featu

  0%|          | 0/14913 [00:00<?, ?it/s]


Evaluating correlation coefficient on validation set...



100%|██████████| 14913/14913 [00:03<00:00, 4074.13it/s]


In [446]:
val_cc, val_r2 = validate_texture_model_partial(best_params, val_voxel_single_trial_data, val_stim_single_trial_data, _texture_fn, sample_batch_size, debug=debug, dtype=fpX)



Variance partition, leaving out: wmean
Remaining features are:
['wvar', 'wskew', 'wkurt', 'complex_feature_means']
[ 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 25 26 27 28 29 30 31 32 33 34 35]

Initializing model...


Getting model predictions on validation set...

Computing pixel-level statistics...
time elapsed = 0.00481
Computing complex cell features...
time elapsed = 0.06640
Computing simple cell features...
time elapsed = 0.04153
SKIPPING HIGHER-ORDER CORRELATIONS...
time elapsed = 0.03269
Final size of features concatenated is [20 x 35]
Feature types included are:
['wvar', 'wskew', 'wkurt', 'complex_feature_means']
Computing pixel-level statistics...
time elapsed = 0.00452
Computing complex cell features...
time elapsed = 0.06050
Computing simple cell features...
time elapsed = 0.03798
SKIPPING HIGHER-ORDER CORRELATIONS...
time elapsed = 0.03218
Final size of features concatenated is [20 x 35]
Feature types included are:
['wvar', 'wskew', 'wkurt', 

In [480]:
# out = torch.load('/user_data/mmhender/imStat/model_fits/S01/texture_ridge_8ori_4sf/Jul-06-2021_0332_DEBUG/all_fit_params')
out = torch.load('/user_data/mmhender/imStat/model_fits/S01/texture_ridge_8ori_4sfno_pixelno_simpleno_autocorrelationsno_crosscorrelations/Jul-06-2021_0336_DEBUG/all_fit_params')

In [485]:
out['feature_info'][0]

array([4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
       4, 4, 4, 4, 4, 4, 4, 4, 4, 4])

In [483]:
out['val_cc_partial'].shape

AttributeError: 'NoneType' object has no attribute 'shape'

In [484]:
out.keys()

dict_keys(['feature_table_simple', 'sf_tuning_masks_simple', 'ori_tuning_masks_simple', 'cyc_per_stim_simple', 'orients_deg_simple', 'orient_filters_simple', 'feature_table_complex', 'sf_tuning_masks_complex', 'ori_tuning_masks_complex', 'cyc_per_stim_complex', 'orients_deg_complex', 'orient_filters_complex', 'aperture', 'aperture_rf_range', 'models', 'include_autocorrs', 'feature_info', 'voxel_mask', 'brain_nii_shape', 'image_order', 'voxel_index', 'voxel_roi', 'voxel_ncsnr', 'best_params', 'lambdas', 'best_lambdas', 'best_losses', 'val_cc', 'val_r2', 'val_cc_partial', 'val_r2_partial', 'features_each_model_val', 'voxel_feature_correlations_val', 'zscore_features', 'nonlin_fn', 'padding_mode', 'n_prf_sd_out', 'autocorr_output_pix', 'debug', 'up_to_sess'])

In [463]:
dtype = torch_utils.get_value(next(_fmaps_fn_complex.parameters())).dtype
n_features_complex, fmaps_rez = fwrf_fit.get_fmaps_sizes(_fmaps_fn_complex, np.zeros([2,1,2,2]).astype(dtype), device)  
n_features_simple, fmaps_rez = fwrf_fit.get_fmaps_sizes(_fmaps_fn_simple, np.zeros([2,1,2,2]).astype(dtype), device)    
n_sf = len(fmaps_rez)
n_ori = int(n_features_complex/n_sf)
n_phase = int(n_features_simple/n_sf/n_ori)
[n_sf, n_ori, n_phase]

[4, 8, 2]

In [464]:
_texture_fn.feature_types_include

['wmean', 'wvar', 'wskew', 'wkurt']

In [268]:
def validate_texture_model(best_params, val_voxel_single_trial_data, val_stim_single_trial_data, _texture_fn, sample_batch_size, debug=False, dtype=np.float32):
    
    # EVALUATE PERFORMANCE ON VALIDATION SET
    
    print('\nInitializing model for validation...\n')
    param_batch = [p[0:1] if p is not None else None for p in best_params]
    # To initialize this module for prediction, need to take just first batch of voxels.
    # Will eventually pass all voxels through in batches.
    _fwd_model = texture_model(_texture_fn, param_batch, input_shape=val_stim_single_trial_data.shape)

    print('\nGetting model predictions on validation set...\n')
    val_voxel_pred = get_predictions_texture_model(val_stim_single_trial_data, _fwd_model, best_params, sample_batch_size=sample_batch_size, debug=debug)

    # val_cc is the correlation coefficient bw real and predicted responses across trials, for each voxel.
    n_voxels = np.shape(val_voxel_single_trial_data)[1]
    val_cc  = np.zeros(shape=(n_voxels), dtype=dtype)
    val_r2 = np.zeros(shape=(n_voxels), dtype=dtype)
    
    print('\nEvaluating correlation coefficient on validation set...\n')
    for v in tqdm(range(n_voxels)):    
        val_cc[v] = np.corrcoef(val_voxel_single_trial_data[:,v], val_voxel_pred[:,v])[0,1]  
        val_r2[v] = get_r2(val_voxel_single_trial_data[:,v], val_voxel_pred[:,v])
        
    val_cc = np.nan_to_num(val_cc)
    val_r2 = np.nan_to_num(val_r2)    
    
    return val_cc, val_r2

In [488]:
include_crosscorrs = True
include_autocorrs = True
include_pixels = False
include_simple = False
include_complex = False

feature_types_exclude = []
if not include_pixels:
    feature_types_exclude.append('pixel')
if not include_simple:
    feature_types_exclude.append('simple_feature_means')
if not include_complex:
    feature_types_exclude.append('complex_feature_means')
if not include_autocorrs:
    feature_types_exclude.append('autocorrs')
if not include_crosscorrs:
    feature_types_exclude.append('crosscorrs')

    
_texture_fn = texture_feature_extractor(_fmaps_fn_complex, _fmaps_fn_simple, sample_batch_size=sample_batch_size, feature_types_exclude=feature_types_exclude, autocorr_output_pix=autocorr_output_pix, n_prf_sd_out=n_prf_sd_out, aperture=aperture, device=device)

In [490]:
feature_column_labels = np.squeeze(np.concatenate([fi*np.ones([1,feature_type_dims[fi]]) for fi in range(len(feature_type_dims)) if not feature_types_all[fi] in feature_types_exclude], axis=1).astype('int'))

In [491]:
feature_column_labels

array([ 6,  6,  6, ..., 11, 11, 11])

In [492]:
out = torch.load('/user_data/mmhender/imStat/model_fits/S01/texture_ridge_8ori_4sf_no_pixel_no_simple_no_complex/Jul-06-2021_0352_DEBUG/all_fit_params')

In [493]:
out['feature_info'][0]

array([0, 0, 0, ..., 5, 5, 5])

In [489]:
# val_cc is the correlation coefficient bw real and predicted responses across trials, for each voxel.
n_voxels = np.shape(val_voxel_single_trial_data)[1]
n_feature_types = len(_texture_fn.feature_types_include)
val_cc  = np.zeros(shape=(n_voxels, n_feature_types), dtype=dtype)
val_r2 = np.zeros(shape=(n_voxels, n_feature_types), dtype=dtype)

orig_feature_column_labels = _texture_fn.feature_column_labels
orig_excluded_features = _texture_fn.feature_types_exclude

for ff, feat_name in enumerate(_texture_fn.feature_types_include):

    print('\nVariance partition, leaving out: %s'%feat_name)
    _texture_fn.update_feature_list(orig_excluded_features+[feat_name])
    print('Remaining features are:')
    print(_texture_fn.feature_types_include)

    # Choose columns of interest here, leaving out weights for one feature at a time
    params_to_use = copy.deepcopy(best_params)
    columns_to_use = np.where(orig_feature_column_labels!=ff)[0]
    print(columns_to_use)
    params_to_use[1] = params_to_use[1][:,columns_to_use]
    if best_params[3] is not None:
        params_to_use[3] = pars2use[3][:,columns_to_use]
        params_to_use[4] = pars2use[4][:,columns_to_use]

    print('\nInitializing model...\n')
    param_batch = [p[0:1] if p is not None else None for p in params_to_use]

    # To initialize this module for prediction, need to take just first batch of voxels.
    # Will eventually pass all voxels through in batches.
    _fwd_model = texture_model(_texture_fn, param_batch, input_shape=val_stim_single_trial_data.shape)

    print('\nGetting model predictions on validation set...\n')
    val_voxel_pred = get_predictions_texture_model(val_stim_single_trial_data, _fwd_model, params_to_use, sample_batch_size=sample_batch_size, debug=debug)

    print('\nEvaluating correlation coefficient on validation set...\n')
    for v in range(n_voxels):    
        val_cc[v,ff] = np.corrcoef(val_voxel_single_trial_data[:,v], val_voxel_pred[:,v])[0,1]  
        val_r2[v,ff] = get_r2(val_voxel_single_trial_data[:,v], val_voxel_pred[:,v])

val_cc = np.nan_to_num(val_cc)
val_r2 = np.nan_to_num(val_r2) 


Variance partition, leaving out: complex_feature_autocorrs
Remaining features are:
['simple_feature_autocorrs', 'complex_within_scale_crosscorrs', 'simple_within_scale_crosscorrs', 'complex_across_scale_crosscorrs', 'simple_across_scale_crosscorrs']
[   0    1    2 ... 3309 3310 3311]


IndexError: index 100 is out of bounds for axis 1 with size 100

In [445]:
import copy

def validate_texture_model_partial(best_params, val_voxel_single_trial_data, val_stim_single_trial_data, _texture_fn, sample_batch_size, debug=False, dtype=np.float32):    
    
    """ 
    Evaluate trained model, leaving out a subset of features at a time.
    """
    
    # val_cc is the correlation coefficient bw real and predicted responses across trials, for each voxel.
    n_voxels = np.shape(val_voxel_single_trial_data)[1]
    n_feature_types = len(_texture_fn.feature_types_include)
    val_cc  = np.zeros(shape=(n_voxels, n_feature_types), dtype=dtype)
    val_r2 = np.zeros(shape=(n_voxels, n_feature_types), dtype=dtype)

    orig_feature_column_labels = _texture_fn.feature_column_labels
    orig_excluded_features = _texture_fn.feature_types_exclude

    for ff, feat_name in enumerate(_texture_fn.feature_types_include):

        print('\nVariance partition, leaving out: %s'%feat_name)
        _texture_fn.update_feature_list(orig_excluded_features+[feat_name])
        print('Remaining features are:')
        print(_texture_fn.feature_types_include)

        # Choose columns of interest here, leaving out weights for one feature at a time
        params_to_use = copy.deepcopy(best_params)
        columns_to_use = np.where(orig_feature_column_labels!=ff)[0]
        print(columns_to_use)
        params_to_use[1] = params_to_use[1][:,columns_to_use]
        if best_params[3] is not None:
            params_to_use[3] = pars2use[3][:,columns_to_use]
            params_to_use[4] = pars2use[4][:,columns_to_use]

        print('\nInitializing model...\n')
        param_batch = [p[0:1] if p is not None else None for p in params_to_use]

        # To initialize this module for prediction, need to take just first batch of voxels.
        # Will eventually pass all voxels through in batches.
        _fwd_model = texture_model(_texture_fn, param_batch, input_shape=val_stim_single_trial_data.shape)

        print('\nGetting model predictions on validation set...\n')
        val_voxel_pred = get_predictions_texture_model(val_stim_single_trial_data, _fwd_model, params_to_use, sample_batch_size=sample_batch_size, debug=debug)

        print('\nEvaluating correlation coefficient on validation set...\n')
        for v in range(n_voxels):    
            val_cc[v,ff] = np.corrcoef(val_voxel_single_trial_data[:,v], val_voxel_pred[:,v])[0,1]  
            val_r2[v,ff] = get_r2(val_voxel_single_trial_data[:,v], val_voxel_pred[:,v])

    val_cc = np.nan_to_num(val_cc)
    val_r2 = np.nan_to_num(val_r2) 
    
    return val_cc, val_r2

In [409]:
class texture_model(torch.nn.Module):
    
    """
    Module that predicts voxel responses based on texture features and encoding model weights.
    Texture features are computed in the module specified by '_texture_fn'.
    Currently written to work with just 1 voxel at a time. This is because the texture features are pRF-specific, 
    and have to be computed 1 pRF at a time. Could probably batch >1 voxel if they had same pRF params, though.
    """
    
    def __init__(self, _texture_fn, params, input_shape = (1,3,227,227)):
        
        super(texture_model, self).__init__()        
#         print('Creating FWRF texture model...')
        
        self.texture_fn = _texture_fn       
        self.voxel_batch_size = 1 # because of how this model is set up, can only do for one voxel at a time! slow.
        device = next(_fmaps_fn_complex.parameters()).device
      
        models, weights, bias, features_mt, features_st, best_model_inds = params
        _x = torch.empty((1,)+input_shape[1:], device=device).uniform_(0, 1)
        _fmaps = _fmaps_fn_complex(_x)
        n_features_complex, self.fmaps_rez = fwrf_fit.get_fmaps_sizes(_fmaps_fn_complex, _x, device)    
        
        self.models = nn.Parameter(torch.from_numpy(models).to(device), requires_grad=False)
        
        self.weights = nn.Parameter(torch.from_numpy(weights).to(device), requires_grad=False)
        self.bias = None
        if bias is not None:
            self.bias = nn.Parameter(torch.from_numpy(bias).to(device), requires_grad=False)
      
        self.features_m = None
        self.features_s = None
        if features_mt is not None:
            self.features_m = nn.Parameter(torch.from_numpy(features_mt.T).to(device), requires_grad=False)
        if features_st is not None:
            self.features_s = nn.Parameter(torch.from_numpy(features_st.T).to(device), requires_grad=False)
       
    def load_voxel_block(self, *params):
        # This takes a given set of parameters for the voxel batch of interest, and puts them 
        # into the right fields of the module so we can use them in a forward pass.
        models = params[0]
        assert(models.shape[0]==self.voxel_batch_size)
        
        torch_utils.set_value(self.models, models)

        for _p,p in zip([self.weights, self.bias], params[1:3]):
            if _p is not None:
                if len(p)<_p.size()[0]:
                    pp = np.zeros(shape=_p.size(), dtype=p.dtype)
                    pp[:len(p)] = p
                    torch_utils.set_value(_p, pp)
                else:
                    torch_utils.set_value(_p, p)
                    
        for _p,p in zip([self.features_m, self.features_s], params[3:]):
            if _p is not None:
                if len(p)<_p.size()[1]:
                    pp = np.zeros(shape=(_p.size()[1], _p.size()[0]), dtype=p.dtype)
                    pp[:len(p)] = p
                    torch_utils.set_value(_p, pp.T)
                else:
                    torch_utils.set_value(_p, p.T)
                    
        
    def forward(self, image_batch):
        
        all_feat_concat, feature_info = self.texture_fn(image_batch,self.models)
        
        _features = all_feat_concat.view([all_feat_concat.shape[0],-1,1]) # trials x features x 1
#         _features = torch_utils._to_torch(all_feat_concat, device=self.weights.device).view([all_feat_concat.shape[0],-1,1]) # trials x features x 1
       
        if self.features_m is not None:    
            # features_m is [nfeatures x nvoxels]
            _features = _features - torch.tile(torch.unsqueeze(self.features_m, dim=0), [_features.shape[0], 1, 1])

        if self.features_s is not None:
            _features = _features/torch.tile(torch.unsqueeze(self.features_s, dim=0), [_features.shape[0], 1, 1])
            _features[torch.isnan(_features)] = 0.0 # this applies in the pca case when last few columns of features are missing

        # features is [#samples, #features, #voxels] - swap dims to [#voxels, #samples, features]
        _features = torch.transpose(torch.transpose(_features, 0, 2), 1, 2)
        # weights is [#voxels, #features]
        # _r will be [#voxels, #samples, 1] - then [#samples, #voxels]
        _r = torch.squeeze(torch.bmm(_features, torch.unsqueeze(self.weights, 2)), dim=2).t() 
  
        if self.bias is not None:
            _r = _r + torch.tile(torch.unsqueeze(self.bias, 0), [_r.shape[0],1])
            
        return _r
    
    
    

In [407]:
import sys
import os
import struct
import time
import numpy as np
import h5py
from tqdm import tqdm
import pickle
import math
import sklearn
from sklearn import decomposition
import gc

import torch
import torch.nn as nn
import torch.nn.init as I
import torch.nn.functional as F
import torch.optim as optim

from utils import numpy_utility, torch_utils
from model_src import fwrf_fit, texture_statistics

    
def get_r2(actual,predicted):
  
    # calculate r2 for this fit.
    ssres = np.sum(np.power((predicted - actual),2));
#     print(ssres)
    sstot = np.sum(np.power((actual - np.mean(actual)),2));
#     print(sstot)
    r2 = 1-(ssres/sstot)
    
    return r2
    
def get_predictions_texture_model(images, _fwd_model, params, sample_batch_size=100, debug=False):
   
    dtype = images.dtype.type
    device = _fwd_model.weights.device
    _params = [_p for _p in _fwd_model.parameters()]
    voxel_batch_size = _fwd_model.voxel_batch_size
    assert(voxel_batch_size==1) # this won't work with batches of >1 voxel
    n_trials, n_voxels = len(images), len(params[0])

    pred = np.full(fill_value=0, shape=(n_trials, n_voxels), dtype=dtype)
    start_time = time.time()
    
    with torch.no_grad():
        
        vv=-1
        ## Looping over voxels here in batches, will eventually go through all.
        for rv, lv in numpy_utility.iterate_range(0, n_voxels, voxel_batch_size):
            vv=vv+1
            if debug and vv>1:
                break
            # for this voxel batch, put the right parameters into the _fwrf_fn module
            # so that we can do forward pass...
            _fwd_model.load_voxel_block(*[p[rv] if p is not None else None for p in params])
            pred_block = np.full(fill_value=0, shape=(n_trials, voxel_batch_size), dtype=dtype)
            
            # Now looping over validation set trials in batches
            for rt, lt in numpy_utility.iterate_range(0, n_trials, sample_batch_size):
#                 sys.stdout.write('\rsamples [%5d:%-5d] of %d, voxels [%6d:%-6d] of %d' % (rt[0], rt[-1], n_trials, rv[0], rv[-1], n_voxels))
                # Get predictions for this set of trials.
               
                pred_block[rt] = torch_utils.get_value(_fwd_model(torch_utils._to_torch(images[rt], device))) 
                
            pred[:,rv] = pred_block[:,:lv]
            
    total_time = time.time() - start_time
    print ('\n---------------------------------------')
    print ('total time = %fs' % total_time)
    print ('sample throughput = %fs/sample' % (total_time / n_trials))
    print ('voxel throughput = %fs/voxel' % (total_time / n_voxels))
    sys.stdout.flush()
    return pred

In [275]:

def get_higher_order_features(_fmaps_fn_complex, _fmaps_fn_simple, images, prf_params, sample_batch_size=20, include_autocorrs=True, include_crosscorrs=True, autocorr_output_pix=7, n_prf_sd_out=2, aperture=1.0, device=None):

    """
    Compute all higher-order features (cross-spatial and cross-feature correlations) for a batch of images.
    Input the functions that define first level feature maps (simple and complex cells), and prf parameters.
    Returns arrays of each higher order feature.    
    """
    
    if device is None:
        device = torch.device('cpu:0')    
        
    n_trials = np.shape(images)[0]
    
    assert(np.mod(autocorr_output_pix,2)==1) # must be odd!

    n_features_simple, fmaps_rez = fwrf_fit.get_fmaps_sizes(_fmaps_fn_simple, images[0:sample_batch_size], device)
    n_features_complex, fmaps_rez = fwrf_fit.get_fmaps_sizes(_fmaps_fn_complex, images[0:sample_batch_size], device)
    
    n_sf = len(fmaps_rez)
    n_ori = int(n_features_complex/n_sf)
    n_phases = 2
    
    # all pairs of different orientation channels.
    ori_pairs = np.vstack([[[oo1, oo2] for oo2 in np.arange(oo1+1, n_ori)] for oo1 in range(n_ori) if oo1<n_ori-1])
    n_ori_pairs = np.shape(ori_pairs)[0]

    if include_autocorrs:
        complex_feature_autocorrs = torch.zeros([n_trials, n_sf, n_ori, autocorr_output_pix**2], device=device)
        simple_feature_autocorrs = torch.zeros([n_trials, n_sf, n_ori, n_phases, autocorr_output_pix**2], device=device)
    else:
        complex_feature_autocorrs = None
        simple_feature_autocorrs = None
    
    if include_crosscorrs:
        complex_within_scale_crosscorrs = torch.zeros([n_trials, n_sf, n_ori_pairs], device=device)
        simple_within_scale_crosscorrs = torch.zeros([n_trials, n_sf, n_phases, n_ori_pairs], device=device)
        complex_across_scale_crosscorrs = torch.zeros([n_trials, n_sf-1, n_ori, n_ori], device=device)
        simple_across_scale_crosscorrs = torch.zeros([n_trials, n_sf-1, n_phases, n_ori, n_ori], device=device) # only done for pairs of neighboring SF.
    else:
        complex_within_scale_crosscorrs = None
        simple_within_scale_crosscorrs = None
        complex_across_scale_crosscorrs = None
        simple_across_scale_crosscorrs = None
        
    if include_autocorrs or include_crosscorrs:
        
        x,y,sigma = prf_params

        bb=-1
        for batch_inds, batch_size_actual in numpy_utility.iterate_range(0, n_trials, sample_batch_size):
            bb=bb+1

            fmaps_complex = _fmaps_fn_complex(torch_utils._to_torch(images[batch_inds],device=device))   
            fmaps_simple =  _fmaps_fn_simple(torch_utils._to_torch(images[batch_inds],device=device))

            # First looping over frequency (scales)
            for ff in range(n_sf):

                # Scale specific things - get the prf at this resolution of interest
                n_pix = fmaps_rez[ff]
                g = numpy_utility.make_gaussian_mass_stack([x], [y], [sigma], n_pix=n_pix, size=aperture, dtype=np.float32)
                spatial_weights = g[2][0]

                patch_bbox_rect = get_bbox_from_prf(prf_params, spatial_weights.shape, n_prf_sd_out, force_square=False)
                # for autocorrelation, forcing the input region to be square
                patch_bbox_square = get_bbox_from_prf(prf_params, spatial_weights.shape, n_prf_sd_out, force_square=True)

                # Loop over orientation channels
                xx=-1
                for oo1 in range(n_ori):       


                    # Simple cell responses - loop over two phases per orient.
                    for pp in range(n_phases):
                        filter_ind = n_phases*oo1+pp  # orients and phases are both listed in the same dimension of filters matrix               
                        simple1 = fmaps_simple[ff][:,filter_ind,:,:].view([batch_size_actual,1,n_pix,n_pix])

                        # Simple cell autocorrelations.
                        if include_autocorrs:
                            auto_corr = weighted_auto_corr_2d(simple1, spatial_weights, patch_bbox=patch_bbox_square, output_pix = autocorr_output_pix, subtract_patch_mean = True, enforce_size=True, device=device)
                            simple_feature_autocorrs[batch_inds,ff,oo1,pp,:] = torch.reshape(auto_corr, [batch_size_actual, autocorr_output_pix**2])

                    # Complex cell responses
                    complex1 = fmaps_complex[ff][:,oo1,:,:].view([batch_size_actual,1,n_pix,n_pix])

                    # Complex cell autocorrelation (correlation w spatially shifted versions of itself)
                    if include_autocorrs:
                        auto_corr = weighted_auto_corr_2d(complex1, spatial_weights, patch_bbox=patch_bbox_square, output_pix = autocorr_output_pix, subtract_patch_mean = True, enforce_size=True, device=device)       
                        complex_feature_autocorrs[batch_inds,ff,oo1,:] = torch.reshape(auto_corr, [batch_size_actual, autocorr_output_pix**2])

                    if include_crosscorrs:
                        # Within-scale correlations - compare resp at orient==oo1 to responses at all other orientations, same scale.
                        for oo2 in np.arange(oo1+1, n_ori):            
                            xx = xx+1 
                            assert(oo1==ori_pairs[xx,0] and oo2==ori_pairs[xx,1])

                            complex2 = fmaps_complex[ff][:,oo2,:,:].view([batch_size_actual,1,n_pix,n_pix])      

                            # Complex cell within-scale cross correlations
                            cross_corr = weighted_cross_corr_2d(complex1, complex2, spatial_weights, patch_bbox=patch_bbox_rect, subtract_patch_mean = True, device=device)

                            complex_within_scale_crosscorrs[batch_inds,ff,xx] = torch.squeeze(cross_corr);

                            # Simple cell within-scale cross correlations
                            for pp in range(n_phases):
                                filter_ind = n_phases*oo2+pp
                                simple2 = fmaps_simple[ff][:,filter_ind,:,:].view([batch_size_actual,1,n_pix,n_pix])

                                cross_corr = weighted_cross_corr_2d(simple1, simple2, spatial_weights, patch_bbox=patch_bbox_rect, subtract_patch_mean = True, device=device)
                                simple_within_scale_crosscorrs[batch_inds,ff,pp,xx] = torch.squeeze(cross_corr);

                        # Cross-scale correlations - for these we care about same ori to same ori, so looping over all ori.
                        # Only for neighboring scales, so the first level doesn't get one
                        if ff>0:

                            for oo2 in range(n_ori):

                                # Complex cell response for neighboring scale
                                complex2_neighborscale = fmaps_complex[ff-1][:,oo2,:,:].view([batch_size_actual,1,fmaps_rez[ff-1], -1])
                                # Resize so that it can be compared w current scale
                                complex2_neighborscale = torch.nn.functional.interpolate(complex2_neighborscale, [n_pix, n_pix], mode='bilinear', align_corners=True)

                                cross_corr = weighted_cross_corr_2d(complex1, complex2_neighborscale, spatial_weights, patch_bbox=patch_bbox_rect, subtract_patch_mean = True, device=device)
                                complex_across_scale_crosscorrs[batch_inds,ff-1, oo1, oo2] = torch.squeeze(cross_corr)

                                for pp in range(n_phases):
                                    filter_ind = n_phases*oo2+pp
                                    # Simple cell response for neighboring scale
                                    simple2_neighborscale = fmaps_simple[ff-1][:,filter_ind,:,:].view([batch_size_actual,1,fmaps_rez[ff-1], -1])
                                    simple2_neighborscale = torch.nn.functional.interpolate(simple2_neighborscale, [n_pix, n_pix], mode='bilinear', align_corners=True)

                                    cross_corr = weighted_cross_corr_2d(simple1, simple2_neighborscale, spatial_weights, patch_bbox=patch_bbox_rect, subtract_patch_mean = True, device=device)
                                    simple_across_scale_crosscorrs[batch_inds,ff-1, pp, oo1, oo2] = torch.squeeze(cross_corr)

    if include_crosscorrs:
        simple_within_scale_crosscorrs = torch.reshape(simple_within_scale_crosscorrs, [n_trials, -1])
        simple_across_scale_crosscorrs = torch.reshape(simple_across_scale_crosscorrs, [n_trials, -1])
        complex_within_scale_crosscorrs = torch.reshape(complex_within_scale_crosscorrs, [n_trials, -1])
        complex_across_scale_crosscorrs = torch.reshape(complex_across_scale_crosscorrs, [n_trials, -1])
    if include_autocorrs:
        simple_feature_autocorrs = torch.reshape(simple_feature_autocorrs, [n_trials, -1])
        complex_feature_autocorrs = torch.reshape(complex_feature_autocorrs, [n_trials, -1])

    return complex_feature_autocorrs, simple_feature_autocorrs, complex_within_scale_crosscorrs, simple_within_scale_crosscorrs, complex_across_scale_crosscorrs, simple_across_scale_crosscorrs


def get_bbox_from_prf(prf_params, image_size, n_prf_sd_out=2, verbose=False, force_square=False):
    """
    For a given pRF center and size, calculate the square bounding box that captures a specified number of SDs from the center (default=2 SD)
    Returns [xmin, xmax, ymin, ymax]
    """
    x,y,sigma = prf_params
    n_pix = image_size[0]
    assert(image_size[1]==n_pix)
    assert(sigma>0 and n_prf_sd_out>0)
    
    # decide on the window to use for correlations, based on prf parameters. Patch goes # SD from the center (2 by default).
    # note this can't be < 1, even for the smallest choice of parameters (since rounding up). this way it won't be too small.
    pix_from_center = int(np.ceil(sigma*n_prf_sd_out*n_pix))

    # center goes [row ind, col ind]
    center = np.array((n_pix/2  - y*n_pix, x*n_pix + n_pix/2)) # note that the x/y dims get swapped here because of how pRF parameters are defined.

    # now defining the extent of the bbox. want to err on the side of making it too big, so taking floor/ceiling...
    xmin = int(np.floor(center[0]-pix_from_center))
    xmax = int(np.ceil(center[0]+pix_from_center))
    ymin = int(np.floor(center[1]-pix_from_center))
    ymax = int(np.ceil(center[1]+pix_from_center))

    # cropping it to within the image bounds. Can end up being a rectangle rather than square.
    [xmin, xmax, ymin, ymax] = np.maximum(np.minimum([xmin, xmax, ymin, ymax], n_pix), 0)

    # decide if we want square or are ok with a rectangle
    if force_square:
        minside = np.min([xmax-xmin, ymax-ymin])
        maxside = np.max([xmax-xmin, ymax-ymin])
        if minside!=maxside:

            if verbose:
                print('trimming bbox to make it square')
                print('original bbox was:')
                print([xmin, xmax, ymin, ymax])

            n2trim = [int(np.floor((maxside-minside)/2)), int(np.ceil((maxside-minside)/2))]

            if np.argmin([xmax-xmin, ymax-ymin])==0:
                ymin = ymin+n2trim[0]
                ymax = ymax-n2trim[1]
            else:
                xmin = xmin+n2trim[0]
                xmax = xmax-n2trim[1]

        assert((xmax-xmin)==(ymax-ymin))

    if verbose:
        print('final bbox will be:')
        print([xmin, xmax, ymin, ymax])
        
    # checking to see if the patch has become just one pixel. this can happen due to the cropping.
    # if this happens, cross-correlations will give zero.
    if ((xmax-xmin)<2 or (ymax-ymin)<2):
        print('Warning: your patch only has one pixel (for n_pix: %d and prf params: [%.2f, %.2f, %.2f])\n'%(n_pix,x,y,sigma))      
        
    return [xmin, xmax, ymin, ymax]


def weighted_auto_corr_2d(images, spatial_weights, patch_bbox=None, output_pix=None, subtract_patch_mean=False, enforce_size=False, device=None):

    """
    Compute autocorrelation of a batch of images, weighting the pixels based on the values in spatial_weights (could be for instance a pRF definition for a voxel).
    Can optionally specify a square patch of the image to compute over, based on "patch_bbox" params. Otherwise use whole image.
    Using fft method to compute, should be fast.
    Input parameters:
        patch_bbox: (optional) bounding box of the patch to use for this calculation. [xmin xmax ymin ymax], see get_bbox_from_prf
        output_pix: the size of the autocorrelation matrix output by this function. If this is an even number, the output size is this value +1. Achieved by cropping out the center of the final autocorrelation 
            matrix  (note that the full image patch is still used in computing the autocorrelation, but just the center values are returned).
            If None, then returns the full autocorrelation matrix (same size as image patch.)
        subtract_patch_mean: subtract weighted mean of image before computing autocorr?
        enforce_size: if image patch is smaller than desired output, should we pad w zeros so that it has to be same size?
    Returns:
        A matrix describing the correlation of the image and various spatially shifted versions of it.
    """
    
    
    if device is None:
        device = torch.device('cpu:0')        
    if isinstance(images, np.ndarray):
        images = torch_utils._to_torch(images, device)
    if isinstance(spatial_weights, np.ndarray):
        spatial_weights = torch_utils._to_torch(spatial_weights, device)
            
    if len(np.shape(images))==2:
        # pretend the batch and channel dims exist, for 2D input only (3D won't work)
        single_image=True
        images = images.view([1,1,images.shape[0],-1])
    else:
        single_image=False
        
    # have to be same size
    assert(images.shape[2]==spatial_weights.shape[0] and images.shape[3]==spatial_weights.shape[1])
    # images is [batch_size x n_channels x nPix x nPix]
    batch_size = images.shape[0]
    n_channels = images.shape[1]    
   
    if patch_bbox is not None:    
        [xmin, xmax, ymin, ymax] = patch_bbox
        # first crop out the region of the image that's currently of interest
        images = images[:,:,xmin:xmax, ymin:ymax]
        # crop same region from spatial weights matrix
        spatial_weights = spatial_weights[xmin:xmax, ymin:ymax]

    # make sure these sum to 1
    if not torch.sum(spatial_weights)==0.0:
        spatial_weights = spatial_weights/torch.sum(spatial_weights)   
   
    spatial_weights = spatial_weights.view([1,1,spatial_weights.shape[0],-1]).expand([batch_size,n_channels,-1,-1]) # [batch_size x n_channels x nPix x nPix]    
    
    # compute autocorrelation of this image patch
    if subtract_patch_mean:

        wmean = torch.sum(torch.sum(images * spatial_weights, dim=3), dim=2) # size is [batch_size x 1]
        wmean = wmean.view([batch_size,-1,1,1]).expand([-1,-1,images.shape[2],images.shape[3]]) # [batch_size x n_channels x nPix x nPix]
        
        weighted_images = (images - wmean) * torch.sqrt(spatial_weights) # square root of the weights here because they will get squared again in next operation
        
        auto_corr = torch.fft.fftshift(torch.real(torch.fft.ifft2(torch.abs(torch.fft.fft2(weighted_images, dim=[2,3]))**2, dim=[2,3])), dim=[2,3]);
    else:
        weighted_images = images * torch.sqrt(spatial_weights)
        auto_corr = torch.fft.fftshift(torch.real(torch.fft.ifft2(torch.abs(torch.fft.fft2(weighted_images, dim=[2,3]))**2, dim=[2,3])), dim=[2,3]);

    if output_pix is not None:

        # crop out just the center region
        new_center = int(np.floor(auto_corr.shape[2]/2))
        n_pix_out = np.min([int(np.floor(output_pix/2)), np.min([new_center, auto_corr.shape[2]-new_center])])
        auto_corr = auto_corr[:,:,new_center-n_pix_out:new_center+n_pix_out+1, new_center-n_pix_out:new_center+n_pix_out+1]        
    
    if enforce_size and not (np.shape(auto_corr)[2]==output_pix or np.shape(auto_corr)[2]==output_pix+1):
        
        # just pad w zeros if want same size.
        pix_diff = output_pix - np.shape(auto_corr)[2]   
        auto_corr = torch.nn.functional.pad(auto_corr, [int(np.floor(pix_diff/2)), int(np.ceil(pix_diff/2)), int(np.floor(pix_diff/2)), int(np.ceil(pix_diff/2))], mode='constant', value=0)
        assert(np.shape(auto_corr)[2]==output_pix and np.shape(auto_corr)[3]==output_pix)

    if single_image:
        auto_corr = torch.squeeze(auto_corr)
        
    return auto_corr

def weighted_cross_corr_2d(images1, images2, spatial_weights, patch_bbox=None, subtract_patch_mean=True, device=None):

    """
    Compute cross-correlation of two identically-sized images, weighting the pixels based on the values in spatial_weights (could be for instance a pRF definition for a voxel).
    Can optionally specify a square patch of the image to compute over, based on "patch_bbox" params. Otherwise use whole image.
    Basically a dot product of image values.
    Input parameters:
        patch_bbox: (optional) bounding box of the patch to use for this calculation. [xmin xmax ymin ymax], see get_bbox_from_prf
        subtract_patch_mean: do you want to subtract the weighted mean of image patch before computing?
    Returns:
        A single value that captures correlation between images (zero spatial shift)
            
    """
    
    if device is None:
        device = torch.device('cpu:0')  
    if isinstance(images1, np.ndarray):
        images1 = torch_utils._to_torch(images1, device)
    if isinstance(images2, np.ndarray):
        images2 = torch_utils._to_torch(images2, device)
    if isinstance(spatial_weights, np.ndarray):
        spatial_weights = torch_utils._to_torch(spatial_weights, device)      
    
    if len(np.shape(images1))==2:
        # pretend the batch and channel dims exist, for 2D input only (3D won't work)
        single_image=True
        images1 = images1.view([1,1,images1.shape[0],-1])
        images2 = images2.view([1,1,images2.shape[0],-1])
    else:
        single_image=False
        
    # have to be same size
    assert(images1.shape==images2.shape)
    assert(images1.shape[2]==spatial_weights.shape[0] and images1.shape[3]==spatial_weights.shape[1])
    assert(images2.shape[2]==spatial_weights.shape[0] and images2.shape[3]==spatial_weights.shape[1])
    # images is [batch_size x n_channels x nPix x nPix]
    batch_size = images1.shape[0]
    n_channels = images1.shape[1]
    

    if patch_bbox is not None:
        [xmin, xmax, ymin, ymax] = patch_bbox
        # first crop out the region of the image that's currently of interest
        images1 = images1[:,:,xmin:xmax, ymin:ymax]
        images2 = images2[:,:,xmin:xmax, ymin:ymax]
        # crop same region from spatial weights matrix
        spatial_weights = spatial_weights[xmin:xmax, ymin:ymax]
    
    # make sure the wts sum to 1
    if not torch.sum(spatial_weights)==0.0:
        spatial_weights = spatial_weights/torch.sum(spatial_weights)
    spatial_weights = spatial_weights.view([1,1,spatial_weights.shape[0],-1]).expand([batch_size,n_channels,-1,-1]) # [batch_size x n_channels x nPix x nPix]    
    
    # compute cross-correlation
    if subtract_patch_mean:
        # subtract mean of each weighted image patch and take their dot product.
        # this quantity is equal to weighted covariance (only true if mean-centered)
        wmean1 = torch.sum(torch.sum(images1 * spatial_weights, dim=3), dim=2) # size is [batch_size x 1]
        wmean1 = wmean1.view([batch_size,-1,1,1]).expand([-1,-1,images1.shape[2],images1.shape[3]]) # [batch_size x n_channels x nPix x nPix]
        wmean2 = torch.sum(torch.sum(images2 * spatial_weights, dim=3), dim=2) # size is [batch_size x 1]
        wmean2 = wmean2.view([batch_size,-1,1,1]).expand([-1,-1,images2.shape[2],images2.shape[3]]) # [batch_size x n_channels x nPix x nPix]
        weighted_images1 = (images1 - wmean1) * torch.sqrt(spatial_weights) # square root of the weights here because they will get squared again in dot product operation.
        weighted_images2 = (images2 - wmean2) * torch.sqrt(spatial_weights)

        cross_corr = torch.sum(torch.sum(weighted_images1 * weighted_images2, dim=3), dim=2)    

    else:
        # dot product of raw (weighted) values
        # this is closer to what scipy.signal.correlate2d will do (except this is weighted)
        weighted_images1 = images1 * torch.sqrt(spatial_weights)
        weighted_images2 = images2 * torch.sqrt(spatial_weights)
        cross_corr = torch.sum(torch.sum(weighted_images1 * weighted_images2, dim=3), dim=2)      
        
    if single_image:
        cross_corr = torch.squeeze(cross_corr)
        
    return cross_corr



def get_weighted_pixel_features(image_batch, spatial_weights, device=None):
    """
    Compute mean, variance, skewness, kurtosis of luminance values for each of a batch of images.
    Input size is [batch_size x n_channels x npix x npix]
    Spatial weights describes a weighting function, [npix x npix]
    Returns [batch_size x n_channels] size array for each property.
    """
    
    if isinstance(image_batch, np.ndarray):
        image_batch = torch_utils._to_torch(image_batch, device)
    if isinstance(spatial_weights, np.ndarray):
        spatial_weights = torch_utils._to_torch(spatial_weights, device)
     
    assert(image_batch.shape[2]==spatial_weights.shape[0] and image_batch.shape[3]==spatial_weights.shape[1])
    assert(image_batch.shape[1]==1)
    
    batch_size = image_batch.shape[0]
    n_channels = image_batch.shape[1]
    n_pix = image_batch.shape[2]

    image_batch = image_batch.view([batch_size, n_channels, n_pix**2])
    spatial_weights = spatial_weights/torch.sum(spatial_weights)
    spatial_weights = spatial_weights.view([1,1,n_pix**2]).expand([batch_size,n_channels,-1]) # [batch_size x n_channels x nPix x nPix]    
   
    ims_weighted = image_batch * spatial_weights
   
    wmean = torch.sum(ims_weighted, axis=2).view([batch_size,-1,1])

    wvar = torch.sum(spatial_weights * (image_batch - wmean.expand([-1,-1,n_pix**2]))**2, axis=2).view([batch_size,-1,1])
    
    wskew = torch.sum(spatial_weights *(image_batch - wmean.expand([-1,-1,n_pix**2]))**3 / (wvar**(3/2)), axis=2).view([batch_size,-1,1])
    
    wkurt = torch.sum(spatial_weights *(image_batch - wmean.expand([-1,-1,n_pix**2]))**4 / (wvar**(2)), axis=2).view([batch_size,-1,1])
    
    # correct for nans/inf values which happen when variance is very small (denominator)
    wskew[torch.isnan(wskew)] = 0.0
    wkurt[torch.isnan(wkurt)] = 0.0
    wskew[torch.isinf(wskew)] = 0.0
    wkurt[torch.isinf(wkurt)] = 0.0
    
    return torch.squeeze(wmean, dim=2), torch.squeeze(wvar, dim=2), torch.squeeze(wskew, dim=2), torch.squeeze(wkurt, dim=2)

In [280]:
_texture_fn.parameters()

<generator object Module.parameters at 0x7fb67dc885e8>

In [288]:

def fit_texture_model_ridge(images, voxel_data, _texture_fn, models, lambdas, zscore=False, voxel_batch_size=100, 
                            holdout_size=100, shuffle=True, add_bias=False, debug=False):
   
    """
    Solve for encoding model weights using ridge regression.
    Inputs:
        images: the training images, [n_trials x 1 x height x width]
        voxel_data: the training voxel data, [n_trials x n_voxels]
        _texture_fn: module that maps from images to texture model features
        models: the list of possible pRFs to test, columns are [x, y, sigma]
        lambdas: ridge lambda parameters to test
        zscore: want to zscore each column of feature matrix before fitting?
        voxel_batch_size: how many voxels to use at a time for model fitting
        holdout_size: how many training trials to hold out for computing loss/lambda selection?
        shuffle: do we shuffle training data order before holding trials out?
        add_bias: add a column of ones to feature matrix, for an additive bias?
        debug: want to run a shortened version of this, to test it?
    Outputs:
        best_losses: loss value for each voxel (with best pRF and best lambda), eval on held out set
        best_lambdas: best lambda for each voxel (chosen based on loss w held out set)
        best_params: 
            [0] best pRF for each voxel [x,y,sigma]
            [1] best weights for each voxel/feature
            [2] if add_bias=True, best bias value for each voxel
            [3] if zscore=True, the mean of each feature before z-score
            [4] if zscore=True, the std of each feature before z-score
            [5] index of the best pRF for each voxel (i.e. index of row in "models")
        feature_info: describes types of features in texture model, see texture_feature_extractor.py
        
    """
   
    dtype = images.dtype.type
    device = next(_texture_fn.parameters()).device
    trn_size = len(voxel_data) - holdout_size
    assert trn_size>0, 'Training size needs to be greater than zero'
    
    print ('trn_size = %d (%.1f%%)' % (trn_size, float(trn_size)*100/len(voxel_data)))
    print ('dtype = %s' % dtype)
    print ('device = %s' % device)
    print ('---------------------------------------')
    
    # First do shuffling of data and define set to hold out
    n_trials = len(images)
    n_prfs = len(models)
    n_voxels = voxel_data.shape[1]
    order = np.arange(len(voxel_data), dtype=int)
    if shuffle:
        np.random.shuffle(order)
    images = images[order]
    voxel_data = voxel_data[order]  
    trn_data = voxel_data[:trn_size]
    out_data = voxel_data[trn_size:]

    n_features_total = _texture_fn.n_features_total
        
    # Create full model value buffers    
    best_models = np.full(shape=(n_voxels,), fill_value=-1, dtype=int)   
    best_lambdas = np.full(shape=(n_voxels,), fill_value=-1, dtype=int)
    best_losses = np.full(fill_value=np.inf, shape=(n_voxels), dtype=dtype)
    best_w_params = np.zeros(shape=(n_voxels, n_features_total), dtype=dtype)

    if add_bias:
        best_w_params = np.concatenate([best_w_params, np.ones(shape=(len(best_w_params),1), dtype=dtype)], axis=1)
    features_mean = None
    features_std = None
    if zscore:
        features_mean = np.zeros(shape=(n_voxels, n_features_total), dtype=dtype)
        features_std  = np.zeros(shape=(n_voxels, n_features_total), dtype=dtype)
    
    
    start_time = time.time()
    vox_loop_time = 0
    print ('')
    
    with torch.no_grad():
        
        # Looping over models (here models are different spatial RF definitions)
        for m,(x,y,sigma) in enumerate(models):
            if debug and m>1:
                break
            print('\nmodel %d\n'%m)
            t = time.time()   
            
            # Get features for the desired pRF, across all trn set image   
        
            all_feat_concat, feature_info = _texture_fn(images, [x,y,sigma])
            
            features = torch_utils.get_value(all_feat_concat)
            
            elapsed = time.time() - t
        
            if zscore:  
                features_m = np.mean(features, axis=0, keepdims=True) #[:trn_size]
                features_s = np.std(features, axis=0, keepdims=True) + 1e-6          
                features -= features_m
                features /= features_s    
                
            if add_bias:
                features = np.concatenate([features, np.ones(shape=(len(features), 1), dtype=dtype)], axis=1)
            
            # separate design matrix into training/held out data (for lambda selection)
            trn_features = features[:trn_size]
            out_features = features[trn_size:]   

            # Send matrices to gpu
            _xtrn = torch_utils._to_torch(trn_features, device=device)
            _xout = torch_utils._to_torch(out_features, device=device)   
            
            # Do part of the matrix math involved in ridge regression optimization out of the loop, 
            # because this part will be same for all the voxels.
            _cof = _cofactor_fn_cpu(_xtrn, lambdas)
            
            # Now looping over batches of voxels (only reason is because can't store all in memory at same time)
            vox_start = time.time()
            for rv,lv in numpy_utility.iterate_range(0, n_voxels, voxel_batch_size):
                sys.stdout.write('\rfitting model %4d of %-4d, voxels [%6d:%-6d] of %d' % (m, n_prfs, rv[0], rv[-1], n_voxels))

                # Send matrices to gpu
                _vtrn = torch_utils._to_torch(trn_data[:,rv], device=device)
                _vout = torch_utils._to_torch(out_data[:,rv], device=device)

                # Here is where optimization happens - relatively simple matrix math inside loss fn.
                _betas, _loss = _loss_fn(_cof, _vtrn, _xout, _vout) #   [#lambda, #feature, #voxel, ], [#lambda, #voxel]
                # Now have a set of weights (in betas) and a loss value for every voxel and every lambda. 
                # goal is then to choose for each voxel, what is the best lambda and what weights went with that lambda.
                
                # first choose best lambda value and the loss that went with it.
                _values, _select = torch.min(_loss, dim=0)
                betas = torch_utils.get_value(_betas)
                values, select = torch_utils.get_value(_values), torch_utils.get_value(_select)

                # comparing this loss to the other models for each voxel (e.g. the other RF position/sizes)
                imp = values<best_losses[rv]
                
                if np.sum(imp)>0:                    
                    # for whichever voxels had improvement relative to previous models, save parameters now
                    # this means we won't have to save all params for all models, just best.
                    arv = np.array(rv)[imp]
                    li = select[imp]
                    best_lambdas[arv] = li
                    best_losses[arv] = values[imp]
                    best_models[arv] = m
                    if zscore:
                        features_mean[arv] = features_m # broadcast over updated voxels
                        features_std[arv]  = features_s
                    # taking the weights associated with the best lambda value
                    best_w_params[arv,:] = numpy_utility.select_along_axis(betas[:,:,imp], li, run_axis=2, choice_axis=0).T
              
            vox_loop_time += (time.time() - vox_start)
            elapsed = (time.time() - vox_start)

    # Print information about how fitting went...
    total_time = time.time() - start_time
    inv_time = total_time - vox_loop_time
    return_params = [best_w_params[:,:n_features_total],]
    if add_bias:
        return_params += [best_w_params[:,-1],]
    else: 
        return_params += [None,]
    print ('\n---------------------------------------')
    print ('total time = %fs' % total_time)
    print ('total throughput = %fs/voxel' % (total_time / n_voxels))
    print ('voxel throughput = %fs/voxel' % (vox_loop_time / n_voxels))
    print ('setup throughput = %fs/model' % (inv_time / n_prfs))
    sys.stdout.flush()
    
    best_params = [models[best_models],]+return_params+[features_mean, features_std]+[best_models]
    
    return best_losses, best_lambdas, best_params, feature_info



In [None]:
import sys
import os
import struct
import time
import numpy as np
import h5py
from tqdm import tqdm
import pickle
import math
import sklearn
from sklearn import decomposition

import torch
import torch.nn as nn
import torch.nn.init as I
import torch.nn.functional as F
import torch.optim as optim

from utils import numpy_utility, torch_utils
from model_src import texture_statistics

 

def _cofactor_fn_cpu(_x, lambdas):
    '''
    Generating a matrix needed to solve ridge regression model for each lambda value.
    Ridge regression (Tikhonov) solution is :
    w = (X^T*X + I*lambda)^-1 * X^T * Y
    This func will return (X^T*X + I*lambda)^-1 * X^T. 
    So once we have that, can just multiply by training data (Y) to get weights.
    returned size is [nLambdas x nFeatures x nTrials]
    This version makes sure that the torch inverse operation is done on the cpu, and in floating point-64 precision.
    Otherwise get bad results for small lambda values. This seems to be a torch-specific bug.
    
    '''
    device_orig = _x.device
    type_orig = _x.dtype
    # switch to this specific format which works with inverse
    _x = _x.to('cpu').to(torch.float64)
    _f = torch.stack([(torch.mm(torch.t(_x), _x) + torch.eye(_x.size()[1], device='cpu', dtype=torch.float64) * l).inverse() for l in lambdas], axis=0) 
    
    # [#lambdas, #feature, #feature] 
    cof = torch.tensordot(_f, _x, dims=[[2],[1]]) # [#lambdas, #feature, #sample]
    
    # put back to whatever way it was before, so that we can continue with other operations as usual
    return cof.to(device_orig).to(type_orig)



def _loss_fn(_cofactor, _vtrn, _xout, _vout):
    '''
    Calculate loss given "cofactor" from cofactor_fn, training data, held-out design matrix, held out data.
    returns weights (betas) based on equation
    w = (X^T*X + I*lambda)^-1 * X^T * Y
    also returns loss for these weights w the held out data. SSE is loss func here.
    '''

    _beta = torch.tensordot(_cofactor, _vtrn, dims=[[2], [0]]) # [#lambdas, #feature, #voxel]
    _pred = torch.tensordot(_xout, _beta, dims=[[1],[1]]) # [#samples, #lambdas, #voxels]
    _loss = torch.sum(torch.pow(_vout[:,None,:] - _pred, 2), dim=0) # [#lambdas, #voxels]
    return _beta, _loss


def get_fmaps_sizes(_fmaps_fn, image_batch, device):
    """ 
    Passing a batch of images through feature maps, in order to compute sizes.
    Returns number of total features across all groups of maps, and the resolution of each map group.
    """
    n_features = 0
    _x = torch.tensor(image_batch).to(device) # the input variable.
    _fmaps = _fmaps_fn(_x)
    resolutions_each_sf = []
    for k,_fm in enumerate(_fmaps):
        n_features = n_features + _fm.size()[1]
        resolutions_each_sf.append(_fm.size()[2])
    
    return n_features, resolutions_each_sf



def get_features_in_prf(prf_params, _fmaps_fn, images, sample_batch_size, aperture, device, to_numpy=True):
    """
    For a given set of images and a specified pRF position and size, compute the
    activation in each feature map channel. Returns [nImages x nFeatures]
    """
    
    dtype = images.dtype.type
    with torch.no_grad():
        
        x,y,sigma = prf_params
        n_trials = images.shape[0]

        # pass first batch of images through feature map, just to get sizes.
        n_features, fmaps_rez = get_fmaps_sizes(_fmaps_fn, images[0:sample_batch_size], device)

        features = np.zeros(shape=(n_trials, n_features), dtype=dtype)
        if to_numpy==False:
             features = torch_utils._to_torch(features, device=device)
                
        # Define the RF for this "model" version - at several resolutions.
        _prfs = [torch_utils._to_torch(numpy_utility.make_gaussian_mass(x, y, sigma, n_pix, size=aperture, \
                                  dtype=dtype)[2], device=device) for n_pix in fmaps_rez]

        # To make full design matrix for all trials, first looping over trials in batches to get the features
        # Only reason to loop is memory constraints, because all trials is big matrices.
        t = time.time()
        n_batches = np.ceil(n_trials/sample_batch_size)
        bb=-1
        for rt,rl in numpy_utility.iterate_range(0, n_trials, sample_batch_size):

            bb=bb+1
#             sys.stdout.write('\rbatch %d of %d'%(bb,n_batches))
            # multiplying feature maps by RFs here. 
            # we have one specified RF position for this version of the model. 
            # Feature maps in _fm go [nTrials x nFeatures(orientations) x nPixels x nPixels]
            # spatial RFs in _prfs go [nPixels x nPixels]
            # purpose of the for looping within this statement is to loop over map resolutions 
            # (e.g. spatial frequencies in model)
            # output _features is [nTrials x nFeatures*nResolutions], so a 2D matrix. 
            # Combining features/resolutions here finally, so we can solve for weights 
            # in that full orient x SF feature space.

            # then combine this with the other "batches" of trials to make a full "model space tensor"
            # features is [nTrialsTotal x nFeatures*nResolutions]

            # note this is concatenating SFs together from low to high - 
            # cycles through all orient channels in order for first SF, then again for next SF.
            _features = torch.cat([torch.tensordot(_fm, _prf, dims=[[2,3], [0,1]]) \
                                   for _fm,_prf in zip(_fmaps_fn(torch_utils._to_torch(images[rt], \
                                           device=device)), _prfs)], dim=1) # [#samples, #features]

            # Add features for this batch to full design matrix over all trials
            if to_numpy:
                features[rt] = torch_utils.get_value(_features)
            else:
                features[rt] = _features
                
        elapsed = time.time() - t
#         print('\nComputing features took %d sec'%elapsed)
        
    return features

