In [1]:
# import basic modules
import sys
import os
import time
import numpy as np
from tqdm import tqdm
import gc
import torch
import argparse

# import custom modules
root_dir   = os.path.dirname(os.path.dirname(os.getcwd()))
sys.path.append(os.path.join(root_dir,'code'))
from model_src import fwrf_fit as fwrf_fit
from model_src import fwrf_predict as fwrf_predict
from model_src import texture_statistics as texture_statistics

from model_fitting import initialize_fitting

fpX = np.float32

In [8]:
subject=1
roi=None

ridge=0

shuffle_images=0
random_images=0
random_voxel_data=0

sample_batch_size=100
voxel_batch_size=100
zscore_features=1
nonlin_fn=0
padding_mode='circular'

n_ori=8
n_sf=4
up_to_sess=1
debug=1
shuff_rnd_seed=0
# shuff_rnd_seed=251709

fitting_type='texture'



do_fitting=1
do_val=1
do_partial=1
date_str=None
# date_str='Jul-09-2021_1736'

In [9]:
device = initialize_fitting.init_cuda()
nsd_root, stim_root, beta_root, mask_root = initialize_fitting.get_paths()

if do_fitting==False and date_str is None:
    raise ValueError('if you want to start midway through the process (--do_fitting=False), then specify the date when training result was saved (--date_str).')

if do_fitting==True and date_str is not None:
    raise ValueError('if you want to do fitting from scratch (--do_fitting=True), specify --date_str=None (rather than entering a date)')

# output_dir, fn2save = initialize_fitting.get_save_path(root_dir, subject, model_name, shuffle_images, random_images, random_voxel_data, debug, date_str)

def save_all():
    print('\nSaving to %s\n'%fn2save)
    torch.save({
    'feature_table_simple': _gaborizer_simple.feature_table,
    'sf_tuning_masks_simple': _gaborizer_simple.sf_tuning_masks, 
    'ori_tuning_masks_simple': _gaborizer_simple.ori_tuning_masks,
    'cyc_per_stim_simple': _gaborizer_simple.cyc_per_stim,
    'orients_deg_simple': _gaborizer_simple.orients_deg,
    'orient_filters_simple': _gaborizer_simple.orient_filters,  
    'feature_table_complex': _gaborizer_complex.feature_table,
    'sf_tuning_masks_complex': _gaborizer_complex.sf_tuning_masks, 
    'ori_tuning_masks_complex': _gaborizer_complex.ori_tuning_masks,
    'cyc_per_stim_complex': _gaborizer_complex.cyc_per_stim,
    'orients_deg_complex': _gaborizer_complex.orients_deg,
    'orient_filters_complex': _gaborizer_complex.orient_filters,  
    'aperture': aperture,
    'aperture_rf_range': aperture_rf_range,
    'models': models,
    'include_autocorrs': include_autocorrs,
    'feature_info':feature_info,
    'voxel_mask': voxel_mask,
    'brain_nii_shape': brain_nii_shape,
    'image_order': image_order,
    'voxel_index': voxel_index,
    'voxel_roi': voxel_roi,
    'voxel_ncsnr': voxel_ncsnr, 
    'best_params': best_params,
    'lambdas': lambdas, 
    'best_lambdas': best_lambdas,
    'best_losses': best_losses,
    'val_cc': val_cc,
    'val_r2': val_r2,   
    'val_cc_partial': val_cc_partial,
    'val_r2_partial': val_r2_partial,   
    'features_each_model_val': features_each_model_val,
    'voxel_feature_correlations_val': voxel_feature_correlations_val,
    'zscore_features': zscore_features,
    'nonlin_fn': nonlin_fn,
    'padding_mode': padding_mode,
    'n_prf_sd_out': n_prf_sd_out,
    'autocorr_output_pix': autocorr_output_pix,
    'debug': debug,
    'up_to_sess': up_to_sess,
    'shuff_rnd_seed': shuff_rnd_seed
    }, fn2save, pickle_protocol=4)


# decide what voxels to use  
voxel_mask, voxel_index, voxel_roi, voxel_ncsnr, brain_nii_shape = initialize_fitting.get_voxel_info(mask_root, beta_root, subject, roi)

# get all data and corresponding images, in two splits. always fixed set that gets left out
trn_stim_data, trn_voxel_data, val_stim_single_trial_data, val_voxel_single_trial_data, \
    n_voxels, n_trials_val, image_order = initialize_fitting.get_data_splits(nsd_root, beta_root, stim_root, subject, voxel_mask, up_to_sess, 
                                                                             shuffle_images=shuffle_images, random_images=random_images, random_voxel_data=random_voxel_data)

# Set up the filters
_gaborizer_complex, _gaborizer_simple, _fmaps_fn_complex, _fmaps_fn_simple = initialize_fitting.get_feature_map_simple_complex_fn(n_ori, n_sf, padding_mode=padding_mode, device=device, nonlin_fn=nonlin_fn)

# Params for the spatial aspect of the model (possible pRFs)
#     aperture_rf_range=0.8 # using smaller range here because not sure what to do with RFs at edges...
aperture_rf_range = 1.1
aperture, models = initialize_fitting.get_prf_models(aperture_rf_range=aperture_rf_range)    

In [15]:
 # Initialize the "texture" model which builds on first level feature maps
include_pixel=1
include_simple=1
include_complex=1
include_autocorrs=1
include_crosscorrs=1
model_name, feature_types_exclude = initialize_fitting.get_model_name(ridge, n_ori, n_sf, include_pixel, include_simple, include_complex, include_autocorrs, include_crosscorrs)


autocorr_output_pix=5
n_prf_sd_out=2
_texture_fn = texture_statistics.texture_feature_extractor(_fmaps_fn_complex, _fmaps_fn_simple, sample_batch_size=sample_batch_size, feature_types_exclude=feature_types_exclude, autocorr_output_pix=autocorr_output_pix, n_prf_sd_out=n_prf_sd_out, aperture=aperture, device=device)

In [35]:
gc.collect()
torch.cuda.empty_cache()
print('\nStarting training...\n')

zscore_features=True
ridge=True
# More params for fitting
holdout_size, lambdas = initialize_fitting.get_fitting_pars(trn_voxel_data, zscore_features, ridge=ridge)



if shuff_rnd_seed==0:
    shuff_rnd_seed = int(time.strftime('%M%H%d', time.localtime()))
best_losses, best_lambdas, best_params, feature_info = fit_texture_model_ridge(trn_stim_data, trn_voxel_data, _texture_fn, models, lambdas, \
    zscore=zscore_features, voxel_batch_size=voxel_batch_size, holdout_size=holdout_size, shuffle=True, add_bias=True, debug=debug, shuff_rnd_seed=shuff_rnd_seed)
# note there's also a shuffle param in the above fn call, that determines the nested heldout data for lambda and param selection. always using true.
print('\nDone with training\n')

val_cc=None
val_r2=None
val_cc_partial=None
val_r2_partial=None
features_each_model_val=None;
voxel_feature_correlations_val=None;

save_all()


Starting training...


Possible lambda values are:
[1.0000000e+00 4.2169652e+00 1.7782795e+01 7.4989418e+01 3.1622775e+02
 1.3335215e+03 5.6234131e+03 2.3713736e+04 1.0000000e+05]
trn_size = 619 (90.0%)
dtype = <class 'numpy.float32'>
device = cuda:0
---------------------------------------
Seeding random number generator: seed is 160115


model 0

Computing pixel-level statistics...
time elapsed = 0.05533
Computing complex cell features...
time elapsed = 0.48893
Computing simple cell features...
time elapsed = 0.39703
Computing higher order correlations...
time elapsed = 4.67855
Final size of features concatenated is [688 x 3412]
Feature types included are:
['wmean', 'wvar', 'wskew', 'wkurt', 'complex_feature_means', 'simple_feature_means', 'complex_feature_autocorrs', 'simple_feature_autocorrs', 'complex_within_scale_crosscorrs', 'simple_within_scale_crosscorrs', 'complex_across_scale_crosscorrs', 'simple_across_scale_crosscorrs']
n zero columns: 1224
   408 columns are complex_featu

NameError: name 'fn2save' is not defined

In [26]:
feature_info[0]

array([0, 1, 2, ..., 9, 9, 9])

In [33]:
def fit_texture_model_ridge(images, voxel_data, _texture_fn, models, lambdas, zscore=False, voxel_batch_size=100, 
                            holdout_size=100, shuffle=True, add_bias=False, debug=False, shuff_rnd_seed=0):
   
    """
    Solve for encoding model weights using ridge regression.
    Inputs:
        images: the training images, [n_trials x 1 x height x width]
        voxel_data: the training voxel data, [n_trials x n_voxels]
        _texture_fn: module that maps from images to texture model features
        models: the list of possible pRFs to test, columns are [x, y, sigma]
        lambdas: ridge lambda parameters to test
        zscore: want to zscore each column of feature matrix before fitting?
        voxel_batch_size: how many voxels to use at a time for model fitting
        holdout_size: how many training trials to hold out for computing loss/lambda selection?
        shuffle: do we shuffle training data order before holding trials out?
        add_bias: add a column of ones to feature matrix, for an additive bias?
        debug: want to run a shortened version of this, to test it?
        shuff_rnd_seed: if we do shuffle training data (shuffle=True), what random seed to use? if zero, choose a new random seed in this code.
    Outputs:
        best_losses: loss value for each voxel (with best pRF and best lambda), eval on held out set
        best_lambdas: best lambda for each voxel (chosen based on loss w held out set)
        best_params: 
            [0] best pRF for each voxel [x,y,sigma]
            [1] best weights for each voxel/feature
            [2] if add_bias=True, best bias value for each voxel
            [3] if zscore=True, the mean of each feature before z-score
            [4] if zscore=True, the std of each feature before z-score
            [5] index of the best pRF for each voxel (i.e. index of row in "models")
        feature_info: describes types of features in texture model, see texture_feature_extractor in texture_statistics.py
        
    """
   
    dtype = images.dtype.type
    device = next(_texture_fn.parameters()).device
    trn_size = len(voxel_data) - holdout_size
    assert trn_size>0, 'Training size needs to be greater than zero'
    
    print ('trn_size = %d (%.1f%%)' % (trn_size, float(trn_size)*100/len(voxel_data)))
    print ('dtype = %s' % dtype)
    print ('device = %s' % device)
    print ('---------------------------------------')
    
    # First do shuffling of data and define set to hold out
    n_trials = len(images)
    n_prfs = len(models)
    n_voxels = voxel_data.shape[1]
    order = np.arange(len(voxel_data), dtype=int)
    if shuffle:
        if shuff_rnd_seed==0:
            print('Computing a new random seed')
            shuff_rnd_seed = int(time.strftime('%M%H%d', time.localtime()))
        print('Seeding random number generator: seed is %d'%shuff_rnd_seed)
        np.random.seed(shuff_rnd_seed)
        np.random.shuffle(order)
    images = images[order]
    voxel_data = voxel_data[order]  
    trn_data = voxel_data[:trn_size]
    out_data = voxel_data[trn_size:]

    n_features_total = _texture_fn.n_features_total
        
    # Create full model value buffers    
    best_models = np.full(shape=(n_voxels,), fill_value=-1, dtype=int)   
    best_lambdas = np.full(shape=(n_voxels,), fill_value=-1, dtype=int)
    best_losses = np.full(fill_value=np.inf, shape=(n_voxels), dtype=dtype)
    best_w_params = np.zeros(shape=(n_voxels, n_features_total), dtype=dtype)

    if add_bias:
        best_w_params = np.concatenate([best_w_params, np.ones(shape=(len(best_w_params),1), dtype=dtype)], axis=1)
    features_mean = None
    features_std = None
    if zscore:
        features_mean = np.zeros(shape=(n_voxels, n_features_total), dtype=dtype)
        features_std  = np.zeros(shape=(n_voxels, n_features_total), dtype=dtype)
    
    
    start_time = time.time()
    vox_loop_time = 0
    print ('')
    
    with torch.no_grad():
        
        # Looping over models (here models are different spatial RF definitions)
        for m,(x,y,sigma) in enumerate(models):
            if debug and m>1:
                break
            print('\nmodel %d\n'%m)
            t = time.time()   
            
            # Get features for the desired pRF, across all trn set image   
        
            all_feat_concat, feature_info = _texture_fn(images, [x,y,sigma])
            
            features = torch_utils.get_value(all_feat_concat)
            
            elapsed = time.time() - t
        
            if zscore:  
                features_m = np.mean(features, axis=0, keepdims=True) #[:trn_size]
                features_s = np.std(features, axis=0, keepdims=True) + 1e-6          
                features -= features_m
                features /= features_s    
                
            if add_bias:
                features = np.concatenate([features, np.ones(shape=(len(features), 1), dtype=dtype)], axis=1)
            
            # separate design matrix into training/held out data (for lambda selection)
            trn_features = features[:trn_size]
            out_features = features[trn_size:]   

            # Send matrices to gpu
            _xtrn = torch_utils._to_torch(trn_features, device=device)
            _xout = torch_utils._to_torch(out_features, device=device)   
            
            zero_columns = np.sum(torch_utils.get_value(_xtrn[:,0:-1]), axis=0)==0
            if np.sum(zero_columns)>0:
                print('n zero columns: %d'%np.sum(zero_columns))
                for ff in range(len(feature_info[1])):
                    if np.sum(zero_columns[feature_info[0]==ff])>0:
                        print('   %d columns are %s'%(np.sum(zero_columns[feature_info[0]==ff]), feature_info[1][ff]))
                        
            # Do part of the matrix math involved in ridge regression optimization out of the loop, 
            # because this part will be same for all the voxels.
            _cof = _cofactor_fn_cpu(_xtrn, lambdas)
            
            # Now looping over batches of voxels (only reason is because can't store all in memory at same time)
            vox_start = time.time()
            for rv,lv in numpy_utility.iterate_range(0, n_voxels, voxel_batch_size):
                sys.stdout.write('\rfitting model %4d of %-4d, voxels [%6d:%-6d] of %d' % (m, n_prfs, rv[0], rv[-1], n_voxels))

                # Send matrices to gpu
                _vtrn = torch_utils._to_torch(trn_data[:,rv], device=device)
                _vout = torch_utils._to_torch(out_data[:,rv], device=device)

                # Here is where optimization happens - relatively simple matrix math inside loss fn.
                _betas, _loss = _loss_fn(_cof, _vtrn, _xout, _vout) #   [#lambda, #feature, #voxel, ], [#lambda, #voxel]
                # Now have a set of weights (in betas) and a loss value for every voxel and every lambda. 
                # goal is then to choose for each voxel, what is the best lambda and what weights went with that lambda.
                
                # first choose best lambda value and the loss that went with it.
                _values, _select = torch.min(_loss, dim=0)
                betas = torch_utils.get_value(_betas)
                values, select = torch_utils.get_value(_values), torch_utils.get_value(_select)

                # comparing this loss to the other models for each voxel (e.g. the other RF position/sizes)
                imp = values<best_losses[rv]
                
                if np.sum(imp)>0:                    
                    # for whichever voxels had improvement relative to previous models, save parameters now
                    # this means we won't have to save all params for all models, just best.
                    arv = np.array(rv)[imp]
                    li = select[imp]
                    best_lambdas[arv] = li
                    best_losses[arv] = values[imp]
                    best_models[arv] = m
                    if zscore:
                        features_mean[arv] = features_m # broadcast over updated voxels
                        features_std[arv]  = features_s
                    # taking the weights associated with the best lambda value
                    best_w_params[arv,:] = numpy_utility.select_along_axis(betas[:,:,imp], li, run_axis=2, choice_axis=0).T
              
            vox_loop_time += (time.time() - vox_start)
            elapsed = (time.time() - vox_start)
            sys.stdout.flush()
            
    # Print information about how fitting went...
    total_time = time.time() - start_time
    inv_time = total_time - vox_loop_time
    return_params = [best_w_params[:,:n_features_total],]
    if add_bias:
        return_params += [best_w_params[:,-1],]
    else: 
        return_params += [None,]
    print ('\n---------------------------------------')
    print ('total time = %fs' % total_time)
    print ('total throughput = %fs/voxel' % (total_time / n_voxels))
    print ('voxel throughput = %fs/voxel' % (vox_loop_time / n_voxels))
    print ('setup throughput = %fs/model' % (inv_time / n_prfs))
    sys.stdout.flush()
    
    best_params = [models[best_models],]+return_params+[features_mean, features_std]+[best_models]
    
    return best_losses, best_lambdas, best_params, feature_info


In [20]:

def _cofactor_fn_cpu(_x, lambdas):
    '''
    Generating a matrix needed to solve ridge regression model for each lambda value.
    Ridge regression (Tikhonov) solution is :
    w = (X^T*X + I*lambda)^-1 * X^T * Y
    This func will return (X^T*X + I*lambda)^-1 * X^T. 
    So once we have that, can just multiply by training data (Y) to get weights.
    returned size is [nLambdas x nFeatures x nTrials]
    This version makes sure that the torch inverse operation is done on the cpu, and in floating point-64 precision.
    Otherwise get bad results for small lambda values. This seems to be a torch-specific bug.
    
    '''
    device_orig = _x.device
    type_orig = _x.dtype
    # switch to this specific format which works with inverse
    _x = _x.to('cpu').to(torch.float64)
    _f = torch.stack([(torch.mm(torch.t(_x), _x) + torch.eye(_x.size()[1], device='cpu', dtype=torch.float64) * l).inverse() for l in lambdas], axis=0) 
    
    # [#lambdas, #feature, #feature] 
    cof = torch.tensordot(_f, _x, dims=[[2],[1]]) # [#lambdas, #feature, #sample]
    
    # put back to whatever way it was before, so that we can continue with other operations as usual
    return cof.to(device_orig).to(type_orig)



In [23]:
import sys
import os
import struct
import time
import numpy as np
import h5py
from tqdm import tqdm
import pickle
import math
import sklearn
from sklearn import decomposition

import torch
import torch.nn as nn
import torch.nn.init as I
import torch.nn.functional as F
import torch.optim as optim

from utils import numpy_utility, torch_utils


def _loss_fn(_cofactor, _vtrn, _xout, _vout):
    '''
    Calculate loss given "cofactor" from cofactor_fn, training data, held-out design matrix, held out data.
    returns weights (betas) based on equation
    w = (X^T*X + I*lambda)^-1 * X^T * Y
    also returns loss for these weights w the held out data. SSE is loss func here.
    '''

    _beta = torch.tensordot(_cofactor, _vtrn, dims=[[2], [0]]) # [#lambdas, #feature, #voxel]
    _pred = torch.tensordot(_xout, _beta, dims=[[1],[1]]) # [#samples, #lambdas, #voxels]
    _loss = torch.sum(torch.pow(_vout[:,None,:] - _pred, 2), dim=0) # [#lambdas, #voxels]
    return _beta, _loss


def get_fmaps_sizes(_fmaps_fn, image_batch, device):
    """ 
    Passing a batch of images through feature maps, in order to compute sizes.
    Returns number of total features across all groups of maps, and the resolution of each map group.
    """
    n_features = 0
    _x = torch.tensor(image_batch).to(device) # the input variable.
    _fmaps = _fmaps_fn(_x)
    resolutions_each_sf = []
    for k,_fm in enumerate(_fmaps):
        n_features = n_features + _fm.size()[1]
        resolutions_each_sf.append(_fm.size()[2])
    
    return n_features, resolutions_each_sf



def get_features_in_prf(prf_params, _fmaps_fn, images, sample_batch_size, aperture, device, to_numpy=True):
    """
    For a given set of images and a specified pRF position and size, compute the
    activation in each feature map channel. Returns [nImages x nFeatures]
    """
    
    dtype = images.dtype.type
    with torch.no_grad():
        
        x,y,sigma = prf_params
        n_trials = images.shape[0]

        # pass first batch of images through feature map, just to get sizes.
        n_features, fmaps_rez = get_fmaps_sizes(_fmaps_fn, images[0:sample_batch_size], device)

        features = np.zeros(shape=(n_trials, n_features), dtype=dtype)
        if to_numpy==False:
             features = torch_utils._to_torch(features, device=device)
                
        # Define the RF for this "model" version - at several resolutions.
        _prfs = [torch_utils._to_torch(numpy_utility.make_gaussian_mass(x, y, sigma, n_pix, size=aperture, \
                                  dtype=dtype)[2], device=device) for n_pix in fmaps_rez]

        # To make full design matrix for all trials, first looping over trials in batches to get the features
        # Only reason to loop is memory constraints, because all trials is big matrices.
        t = time.time()
        n_batches = np.ceil(n_trials/sample_batch_size)
        bb=-1
        for rt,rl in numpy_utility.iterate_range(0, n_trials, sample_batch_size):

            bb=bb+1
#             sys.stdout.write('\rbatch %d of %d'%(bb,n_batches))
            # multiplying feature maps by RFs here. 
            # we have one specified RF position for this version of the model. 
            # Feature maps in _fm go [nTrials x nFeatures(orientations) x nPixels x nPixels]
            # spatial RFs in _prfs go [nPixels x nPixels]
            # purpose of the for looping within this statement is to loop over map resolutions 
            # (e.g. spatial frequencies in model)
            # output _features is [nTrials x nFeatures*nResolutions], so a 2D matrix. 
            # Combining features/resolutions here finally, so we can solve for weights 
            # in that full orient x SF feature space.

            # then combine this with the other "batches" of trials to make a full "model space tensor"
            # features is [nTrialsTotal x nFeatures*nResolutions]

            # note this is concatenating SFs together from low to high - 
            # cycles through all orient channels in order for first SF, then again for next SF.
            _features = torch.cat([torch.tensordot(_fm, _prf, dims=[[2,3], [0,1]]) \
                                   for _fm,_prf in zip(_fmaps_fn(torch_utils._to_torch(images[rt], \
                                           device=device)), _prfs)], dim=1) # [#samples, #features]

            # Add features for this batch to full design matrix over all trials
            if to_numpy:
                features[rt] = torch_utils.get_value(_features)
            else:
                features[rt] = _features
                
        elapsed = time.time() - t
#         print('\nComputing features took %d sec'%elapsed)
        
    return features
