In [27]:
import os
import pickle as pkl
from copy import deepcopy
from os.path import join as oj

import h5py
import matplotlib.pyplot as plt
import numpy as np
import numpy.linalg as npl
import requests
import seaborn as sns
import tables, numpy
from scipy import ndimage as ndi
from skimage import data
from skimage.filters import gabor_kernel
from skimage.util import img_as_float
from sklearn import metrics
from sklearn.linear_model import RidgeCV
from tqdm import tqdm

import gabor_feats
# Directory the dataset is downloaded into; the cells below also read the .mat files
# from it and write their derived .h5 / .pkl outputs there.
out_dir = '/scratch/users/vision/data/gallant/vim_2_crcns'

def save_h5(data, fname):
    '''Write `data` to a new HDF5 file under the dataset name 'data'.

    Params
    ------
    data
        object assigned to the 'data' dataset (anything h5py can store)
    fname
        path of the HDF5 file to create; an existing file at this path is deleted first
    '''
    if os.path.exists(fname):
        os.remove(fname)
    # the context manager closes the file even if the assignment raises
    with h5py.File(fname, 'w') as f:
        f['data'] = data

def load_h5(fname):
    '''Read the 'data' dataset of an HDF5 file fully into memory.

    Params
    ------
    fname
        path of the HDF5 file to read

    Returns
    -------
    np.ndarray
        in-memory copy of the 'data' dataset
    '''
    # np.array materialises the dataset before the file is closed, so the
    # returned array does not depend on the (now closed) handle
    with h5py.File(fname, 'r') as f:
        return np.array(f['data'])

def save_pkl(d, fname):
    '''Pickle `d` to `fname`, replacing any existing file.

    Params
    ------
    d
        any picklable object
    fname
        destination path; an existing file at this path is deleted first
    '''
    if os.path.exists(fname):
        os.remove(fname)
    # close (and thereby flush) the handle deterministically instead of leaving it to the GC
    with open(fname, 'wb') as fh:
        pkl.dump(d, fh)

# download data

In [48]:
def download(datafile, username, password, out_dir, chunk_size=1024):
    '''Download one file from the CRCNS download portal into `out_dir`.

    Params
    ------
    datafile
        server-side path of the file (sent as the `fn` form field); the part
        after the last '/' is used as the local file name
    username, password
        portal credentials, sent as form fields together with `submit='Login'`
    out_dir
        existing local directory the file is written into
    chunk_size
        number of bytes requested per streamed chunk

    Raises
    ------
    requests.HTTPError
        if the server answers with a 4xx/5xx status; nothing is written in that case
    '''

    URL = 'https://portal.nersc.gov/project/crcns/download/index.php'
    login_data = dict(
        username=username,
        password=password,
        fn=datafile,
        submit='Login'
    )

    local_filename = oj(out_dir, datafile.split('/')[-1])
    print(local_filename)
    with requests.Session() as s:
        r = s.post(URL, data=login_data, stream=True)
        # fail loudly instead of saving an HTTP error page under the dataset's file name
        r.raise_for_status()
        with open(local_filename, 'wb') as f:
            for chunk in tqdm(r.iter_content(chunk_size=chunk_size)):
                if chunk:  # skip empty chunks
                    f.write(chunk)
                    
# Portal credentials. They are taken from the environment when CRCNS_USERNAME /
# CRCNS_PASSWORD are set, so a real password never has to be saved in the notebook;
# the fallbacks are the values that were previously hardcoded here.
uname = os.environ.get('CRCNS_USERNAME', 'csinva')
pwd = os.environ.get('CRCNS_PASSWORD', 'password')
dset = 'vim-2'
fnames = ['Stimuli.tar.gz', 'VoxelResponses_subject1.tar.gz', 'anatomy.zip', 'checksums.md5', 'filelist.txt', 'docs']
for fname in fnames:
    # server-side path of the file inside the dataset folder
    fname = oj(dset, fname)
    # the download call is left commented out; uncomment it to (re-)fetch the files
#     download(fname, uname, pwd, out_dir)

In [49]:
ls /scratch/users/vision/data/gallant/vim_2_crcns

anatomy.zip    docs          Stimuli.mat     VoxelResponses_subject1.mat
checksums.md5  filelist.txt  Stimuli.tar.gz  VoxelResponses_subject1.tar.gz


In [23]:
# report the total on-disk size of the data directory
!du -sh /scratch/users/vision/data/gallant/vim_2_crcns
# next extract the tars
# next unzip the zips

6.9G	/scratch/users/vision/data/gallant/vim_2_crcns


In [27]:
# NOTE: `tar -xzf` unpacks into the current working directory, not next to the archive;
# run this from the data directory (or pass `-C <dir>` to tar) so the extracted files
# end up where the later cells look for them.
!ls /scratch/users/vision/data/gallant/vim_2_crcns/*.gz |xargs -n1 tar -xzf # extract the tar files

# view responses

In [None]:
# Pull the training responses and the left-hemisphere V1 mask out of the response file.
f = tables.open_file(oj(out_dir, 'VoxelResponses_subject1.mat'))
# f.listNodes # Show all variables available

# '/rt' = training responses. It is indexed by voxel along axis 0 below, which matches the
# 73728 (voxels) x 7200 (timepoints) layout given in the model-fitting cell.
data = f.get_node('/rt')[:]

# ROI masks are stored as volumes (EArray(18, 64, 64) per the node listing in this notebook);
# flattening turns the mask into one entry per voxel.
roi = f.get_node('/roi/v1lh')[:].flatten()

# flat indices of the voxels whose mask value is 1, and their response rows
v1lh_idx = np.nonzero(roi == 1)[0]
v1lh_resp = data[v1lh_idx]

In [None]:
# Display a single training-stimulus frame (index 100 of '/st').
f2 = tables.open_file(oj(out_dir, 'Stimuli.mat'))
# reverse the axis order of the stored frame before handing it to imshow
im = f2.get_node('/st')[100].T
plt.imshow(im)

In [134]:
# Text before the first space of the node's string form, i.e. the node path.
# `xs` is filled by the ROI-listing cell, so that cell has to run first.
# NOTE: the recorded output below is the full split list, i.e. it predates the trailing [0].
str(xs[0]).partition(' ')[0]

['/roi/FFAlh', '(EArray(18,', '64,', '64),', 'zlib(3))', "''"]

In [168]:
# List every ROI mask stored under '/roi' together with its voxel count.
f = tables.open_file(oj(out_dir, 'VoxelResponses_subject1.mat'))
xs = []    # ROI nodes, in iteration order
nums = []  # number of nonzero mask entries per ROI
for x in f.get_node('/roi'):
    xs.append(x)
    # `.size` counts the nonzero entries (same idiom as the model-fitting cell);
    # the earlier `.sum()` added up their axis-0 indices, which is not a count
    nums.append(np.array(f.get_node(x)).nonzero()[0].size)
# sns.barplot(x=x, y=nums)
print([str(x) for x in xs])

["/roi/FFAlh (EArray(18, 64, 64), zlib(3)) ''", "/roi/FFArh (EArray(18, 64, 64), zlib(3)) ''", "/roi/IPlh (EArray(18, 64, 64), zlib(3)) ''", "/roi/IPrh (EArray(18, 64, 64), zlib(3)) ''", "/roi/MTlh (EArray(18, 64, 64), zlib(3)) ''", "/roi/MTplh (EArray(18, 64, 64), zlib(3)) ''", "/roi/MTprh (EArray(18, 64, 64), zlib(3)) ''", "/roi/MTrh (EArray(18, 64, 64), zlib(3)) ''", "/roi/OBJlh (EArray(18, 64, 64), zlib(3)) ''", "/roi/OBJrh (EArray(18, 64, 64), zlib(3)) ''", "/roi/PPAlh (EArray(18, 64, 64), zlib(3)) ''", "/roi/PPArh (EArray(18, 64, 64), zlib(3)) ''", "/roi/RSCrh (EArray(18, 64, 64), zlib(3)) ''", "/roi/STSrh (EArray(18, 64, 64), zlib(3)) ''", "/roi/VOlh (EArray(18, 64, 64), zlib(3)) ''", "/roi/VOrh (EArray(18, 64, 64), zlib(3)) ''", "/roi/latocclh (EArray(18, 64, 64), zlib(3)) ''", "/roi/latoccrh (EArray(18, 64, 64), zlib(3)) ''", "/roi/v1lh (EArray(18, 64, 64), zlib(3)) ''", "/roi/v1rh (EArray(18, 64, 64), zlib(3)) ''", "/roi/v2lh (EArray(18, 64, 64), zlib(3)) ''", "/roi/v2rh (EAr

In [7]:
# Per-voxel noise level: standard deviation over axis 1, then averaged over the last axis.
f = tables.open_file(oj(out_dir, 'VoxelResponses_subject1.mat'))
# '/rva' is 73728 (voxels) x 10 (trials) x 540 (timepoints), so axis 1 = trials
rva = np.array(f.get_node('/rva')[:])
sigmas = np.mean(np.std(rva, axis=1), axis=-1)
out_name = oj(out_dir, 'out_rva_sigmas.h5')
save_h5(sigmas, out_name)

# view images

In [160]:
# Stride, in stimulus frames, between two consecutive frames kept by the subsampling cell.
SAMPLING_FREQ = 15
# Spatial stride applied to both image axes when subsampling a frame.
DOWNSAMPLE = 2
# Number of samples for the 'st' (training) set — presumably; only N_TEST is used by the loops as written.
N_TRAIN = 7200
# Number of frames kept from the 'sv' stimuli.
N_TEST = 540
# Index of the first frame kept: halfway into the first SAMPLING_FREQ-frame window.
OFFSET = SAMPLING_FREQ // 2
# Width of each row of the feature matrix filled from gabor_feats.all_feats.
NUM_FEATS = 1280

In [None]:
# find the relevant stimuli: keep one frame every SAMPLING_FREQ frames, subsample it
# spatially by DOWNSAMPLE and average over the last axis
for dset, N in zip(['sv'], [N_TEST]): # 'st', 'sv'
    side = 128 // DOWNSAMPLE
    # builtin `int` is what the removed `np.int` alias pointed to, so the dtype is unchanged.
    # Because the array is integer, the averaged values are truncated on assignment.
    ims = np.zeros((N, side, side), dtype=int)
    # open the stimulus file once per dataset and close it when done
    with tables.open_file(oj(out_dir, 'Stimuli.mat')) as f2:
        frames = f2.get_node(f'/{dset}')  # looked up once instead of on every iteration
        for i in tqdm(range(N)):
            # indexing the node reads a fresh array, so no extra copy is needed
            frame = frames[OFFSET + i * SAMPLING_FREQ].transpose()
            ims[i] = frame[::DOWNSAMPLE, ::DOWNSAMPLE].mean(axis=-1)

    out_name = oj(out_dir, f'out_{dset}.h5')
    save_h5(ims, out_name)

In [None]:
# convert stimuli to feature vectors: one gabor_feats.all_feats row per stored image
for dset, N in zip(['sv'], [N_TEST]): # 'st', 'sv'
    feats = np.zeros((N, NUM_FEATS))
    # close the image file once its rows are read; a handle left open here would still
    # be held when the file is later deleted and rewritten by save_h5
    with h5py.File(oj(out_dir, f'out_{dset}.h5'), 'r') as f:
        images = f['data']
        for i in tqdm(range(N)):
            # images[i] is read into a fresh array, so all_feats cannot alter the file contents
            feats[i] = gabor_feats.all_feats(images[i])

    out_name = oj(out_dir, f'out_{dset}_feats.h5')
    save_h5(feats, out_name)

In [None]:
# Full SVD of the training feature matrix, cached to disk for the model-fitting cells.
# Reads 'out_st_feats.h5', which the feature-extraction cell only writes when 'st' is
# included in its dataset loop.
X = load_h5(oj(out_dir, 'out_st_feats.h5'))
U, s, Vh = npl.svd(X)
decomp = dict(U=U, s=s, Vh=Vh)
# what gets stored is the (U, s, Vh) tuple, in that order
save_pkl((U, s, Vh), oj(out_dir, 'decomp.pkl'))

In [34]:
# U is square (one row/column per training sample) because npl.svd was called with its
# default full_matrices=True.
U.shape

(7200, 7200)

In [35]:
# Vh is square with one row/column per feature (X has 1280 feature columns).
Vh.shape

(1280, 1280)

In [None]:
# fit linear models: load features, responses, noise estimates and the cached SVD
feats_name = oj(out_dir, 'out_st_feats.h5')
feats_test_name = oj(out_dir, 'out_sv_feats.h5')
resps_name = oj(out_dir, 'VoxelResponses_subject1.mat')

# load_h5 copies the dataset into memory and closes the file again
X = load_h5(feats_name)
X_test = load_h5(feats_test_name)

# open the response file once for both arrays and close it afterwards
with tables.open_file(resps_name) as resps_file:
    Y = np.array(resps_file.get_node('/rt')[:]) # training responses: 73728 (voxels) x 7200 (timepoints)
    Y_test = np.array(resps_file.get_node('/rv')[:]) # held-out responses paired with X_test — presumably 73728 (voxels) x 540 (timepoints); confirm

sigmas = load_h5(oj(out_dir, 'out_rva_sigmas.h5'))

# NOTE: pickle.load executes arbitrary code from the file; only load a decomp.pkl written by this notebook
with open(oj(out_dir, 'decomp.pkl'), 'rb') as fh:
    (U, s, Vh) = pkl.load(fh)
# plt.imshow(np.isnan(Y))

In [None]:
# Fit one RidgeCV model per voxel (first NUM voxels of each ROI), score it on the held-out
# data, compute two complexity terms from the SVD projection and pickle one result per voxel.
rois = ['v1lh', 'v1rh', 'v2lh', 'v2rh', 'v4lh', 'v4rh']
NUM = 100  # at most this many voxels are fit per ROI
save_dir = '/scratch/users/vision/data/gallant/vim_2_crcns/visual2'
os.makedirs(save_dir, exist_ok=True)
d = X.shape[1]  # number of features

# the response file is only needed for the ROI masks; it is closed when the block exits
with tables.open_file(oj(out_dir, 'VoxelResponses_subject1.mat'), 'r') as f:
    for roi in rois:
        # flat indices of the nonzero entries of this ROI's mask volume
        roi_idxs = f.get_node(f'/roi/{roi}')[:].flatten().nonzero()[0]
        print(roi, roi_idxs.size)
        roi_idxs = roi_idxs[:NUM]
        results = {}

        for i in tqdm(roi_idxs):
            y = Y[i]
            y_test = Y_test[i]
            idxs_cv = ~np.isnan(y)
            idxs_test = ~np.isnan(y_test)
            n = np.sum(idxs_cv)
            num_test = np.sum(idxs_test)
            if n != y.size or num_test != y_test.size:
                continue  # ignore voxels w/ missing vals

            # everything below is only computed for voxels that are actually fit
            sigma = sigmas[i]
            var = sigma**2
            w = U.T @ y  # training response expressed in the left-singular-vector basis
            d_n_min = min(n, d)

            m = RidgeCV()
            m.fit(X, y)
            preds = m.predict(X_test)
            mse = metrics.mean_squared_error(y_test, preds)
            r2 = metrics.r2_score(y_test, preds)

            # NOTE(review): term1 combines the norm of the held-out response y_test with the
            # projection w of the training response y — confirm y_test (not y) is intended.
            term1 = 0.5 * (npl.norm(y_test) ** 2 - npl.norm(w) ** 2) / var
            term2 = 0.5 * np.sum(np.log(1 + w[:d_n_min]**2 / var))
            complexity1 = term1 + term2

            # components larger than the noise level get the log term, the rest the quadratic term
            idxs = np.abs(w) > sigma
            term3 = 0.5 * np.sum(np.log(1 + w[idxs]**2 / var))
            term4 = 0.5 * np.sum(w[~idxs]**2 / var)
            complexity2 = term1 + term3 + term4

            results = {
                'roi': roi,
                'mse': mse,
                'model': m,
                'complexity1': complexity1,
                'complexity2': complexity2,
                'num_train': n,
                'num_test': num_test,
                'd': d,
                'r2': r2
            }
            # close each output file as soon as it is written
            with open(oj(save_dir, f'ridge_{i}.pkl'), 'wb') as fh:
                pkl.dump(results, fh)


  0%|          | 0/100 [00:00<?, ?it/s][A

v1lh 494



  1%|          | 1/100 [00:00<00:19,  5.20it/s][A
  2%|▏         | 2/100 [00:00<00:16,  6.05it/s][A

In [None]:
# Load three example images, opening the file once and closing it afterwards.
# NOTE(review): `out_name` is whatever an earlier cell assigned last; after the
# feature-extraction cell it points at the '*_feats.h5' file rather than the image
# file — confirm it names the image file before running this cell.
with h5py.File(out_name, 'r') as img_file:
    brick = img_file['data'][0]
    grass = img_file['data'][100]
    gravel = img_file['data'][300]
images = [brick, grass, gravel]


# Plot a selection of the filter bank kernels and their responses.
results = []
kernel_params = []
for theta in [0, 1]:
    theta = theta / 4. * np.pi
    for frequency in [0.1, 0.4]:
        kernel = gabor_kernel(frequency, theta=theta)
        params = 'theta=%d,\nfrequency=%.2f' % (theta * 180 / np.pi, frequency)
        kernel_params.append(params)
        results.append((kernel, [gabor_feats.calc_feats(img, kernel) for img in images])) # Save kernel and the power image for each image

# One header row for the images plus one row per kernel; one column for the kernel plus one
# per image. The previous fixed 4x4 grid had only three kernel rows, so the fourth kernel
# was silently dropped by zip().
fig, axes = plt.subplots(nrows=len(results) + 1, ncols=len(images) + 1, figsize=(5, 6), dpi=200)
plt.gray()
axes[0][0].axis('off')  # top-left corner stays empty


# Plot original images
for img, ax in zip(images, axes[0][1:]):
    ax.imshow(img)
    ax.axis('off')

for label, (kernel, features), ax_row in zip(kernel_params, results, axes[1:]):
    # Plot Gabor kernel
    ax = ax_row[0]
    ax.imshow(np.real(kernel))
    ax.set_ylabel(label, fontsize=7)
    ax.set_xticks([])
    ax.set_yticks([])

    # Plot Gabor responses with the contrast normalized for each filter
    vmin = np.min(features)
    vmax = np.max(features)
    for patch, ax in zip(features, ax_row[1:]):
        ax.imshow(patch, vmin=vmin, vmax=vmax)
        ax.axis('off')

plt.tight_layout()
plt.show()