In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

import numpy as np
import torch
import torchio as tio
import h5py
from ipywidgets import interact
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.notebook import tqdm
import time

# Put the grandparent of the current working directory on sys.path (once),
# presumably so that the project's `research` package imported below resolves.
dir2 = os.path.abspath(os.pardir)
dir1 = os.path.dirname(dir2)
if dir1 not in sys.path:
    sys.path.append(dir1)

In [2]:
def require_dataset(group, name, data):
    """Create (or reuse) dataset `name` in `group` and overwrite its contents with `data`.

    Args:
        group: an h5py-style group exposing `require_dataset(name, shape=, dtype=)`.
        name: dataset name (may be a relative path) inside `group`.
        data: array-like with `.shape` and `.dtype`; scalars (0-d arrays) are supported.

    Returns:
        The dataset object returned by `group.require_dataset`. Per h5py semantics,
        reusing an existing dataset with an incompatible shape/dtype raises TypeError.
    """
    dset = group.require_dataset(name, shape=data.shape, dtype=data.dtype)
    # `[...]` writes the whole dataset for any rank; `[:]` raises on 0-d data.
    dset[...] = data
    return dset
    
# Two copies of the dataset: the primary one on X: and a second one on D:
# (the `ssd_` prefix suggests a faster local mirror). Both keep derived files
# in a `derivatives` subfolder.
dataset_path, ssd_dataset_path = (
    Path("X:\\Datasets\\Deep-Image-Reconstruction\\"),
    Path("D:\\Datasets\\Deep-Image-Reconstruction\\"),
)
derivatives_path, ssd_derivatives_path = (
    base / 'derivatives' for base in (dataset_path, ssd_dataset_path)
)

In [None]:
from research.data.kamitani_2019 import Kamitani2019, RawKamitani2019, Kamitani2019H5
from pathlib import Path

#root = "X:\\Datasets\\Deep-Image-Reconstruction\\"
#h5_path = Path(root) / "derivatives" / "kamitani2019.hdf5"
# Features are read from `<derivatives>/features/<features_name>-features.hdf5`.
features_name = 'ViT-B=32'
features_h5 = derivatives_path / 'features' / f'{features_name}-features.hdf5'

# Other Kamitani2019H5 options tried here at some point: window_kernel,
# transform (e.g. tio.CropOrPad(target_shape=(72, 88, 74))),
# drop_out_of_window_events, folds, split.
dataset_options = {
    'subjects': ['sub-02'],
    'func_sessions': ['natural_training'],
    'window': (1, 9),
    'normalization': 'voxel_linear_trend',
    'features_path': features_h5,
    'feature_keys': ['embedding'],
}
dataset = Kamitani2019H5(ssd_derivatives_path / 'kamitani2019.hdf5', **dataset_options)
    

In [None]:
# Report how many samples the dataset exposes.
print(f'{len(dataset)}')

In [None]:
@interact(i=(0, len(dataset) - 1))
def show_event(i):
    """Print metadata for dataset event `i` and attach a slice viewer for its volume.

    `event['data']` must be 4-D. Slider `t` indexes its first axis (presumably time)
    and slider `d` indexes its last axis.
    """
    event = dataset[i]

    if 'features' in event:
        for feature_key, feature_value in event['features'].items():
            print(feature_key, feature_value.shape, feature_value.numel())

    print(event['onset'], event['stimulus_id'], event['run_id'])
    volume = event['data']
    print(volume.shape, volume.dtype)

    n_t, _, _, n_d = volume.shape

    @interact(d=(0, n_d - 1), t=(0, n_t - 1), derivative=False)
    def show_volume(d, t, derivative):
        # NOTE(review): the `derivative` checkbox is exposed but never used here.
        fig, ax = plt.subplots(figsize=(12, 12))
        ax.imshow(volume[t, :, :, d], cmap='bwr', vmin=-3, vmax=3)
        plt.show()
        plt.close(fig)

In [None]:
from pathlib import Path

root = "X:\\Datasets\\Deep-Image-Reconstruction\\"

# Backfill `sorted_indices` for every selection map in the file: the
# (ndim, n_voxels) coordinates of the voxels, in the same order as the stored
# `sorted_indices_flat`. Writes go through the `require_dataset` helper from the
# configuration cell near the top of the notebook; the identical copy that was
# redefined here has been dropped.
with h5py.File(Path(root) / 'derivatives' / 'feature-selection-maps.hdf5', 'a') as f:
    for subject in f.values():
        for session in subject.values():
            for cache in session.values():
                for model in cache.values():
                    for feature in model.values():
                        for selection_mode in feature.values():
                            print(selection_mode.name)
                            if 'sorted_indices_flat' not in selection_mode:
                                print('  no sorted_indices_flat, skipping')
                                continue

                            # Only the shape of the scores is needed, so don't read them.
                            shape = selection_mode['scores'].shape
                            sorted_indices_flat = selection_mode['sorted_indices_flat'][:]
                            # unravel_index maps C-order flat indices to per-axis coordinates.
                            # This is what indexing a dense (ndim, *shape) coordinate grid
                            # gives, without materialising that grid.
                            sorted_indices = np.stack(
                                np.unravel_index(sorted_indices_flat, shape)
                            ).astype(int)

                            require_dataset(selection_mode, 'sorted_indices', sorted_indices)

In [None]:
# Interactive browser for feature-selection maps. Pick subject / session / cache /
# model / feature / selection mode, then page through depth slices of the 3-D score
# volume. With `select_top`, the volume is replaced by a binary mask of the
# `select_k` highest-ranked voxels according to the stored `sorted_indices_flat`.
# The file handle stays open on purpose: the widget callbacks read from it lazily.
feature_selection_maps = h5py.File(Path(root) / 'derivatives' / 'feature-selection-maps.hdf5', 'r')


@interact(subject=feature_selection_maps.items())
def select_subject(subject):
    @interact(session=subject.items())
    def select_session(session):

        @interact(cache=session.items())
        def select_cache(cache):

            @interact(model=cache.items())
            def select_model(model):
                @interact(feature=model.items())
                def select_feature(feature):
                    @interact(mode=feature.items(), select_top=False, select_k=(0, 15000))
                    def select_mode(mode, select_top, select_k):
                        print(list(mode.keys()))
                        data = mode['scores'][:]

                        if select_top:
                            sorted_indices_flat = mode['sorted_indices_flat'][:]
                            # Use an explicit start index: `[-select_k:]` with select_k == 0
                            # becomes `[-0:]`, i.e. *every* voxel. Clamp at 0 for large k.
                            start = max(len(sorted_indices_flat) - select_k, 0)
                            top_k_flat = sorted_indices_flat[start:]
                            data = np.zeros_like(data)
                            data[np.unravel_index(top_k_flat, data.shape)] = 1.

                        H, W, D = data.shape

                        @interact(d=(0, D - 1))
                        def show_volume(d):
                            fig, ax = plt.subplots(figsize=(12, 12))
                            ax.imshow(data[:, :, d], vmin=0.05, vmax=0.3)
                            plt.show()
                            plt.close(fig)

In [None]:
from pathlib import Path
import torchio as tio

# For each subject, load the Kamitani ROI label maps and save their union as a
# single `<subject>_mask_VC.nii.gz` label map (`VC`, presumably visual cortex).
rois_path = Path('X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\rois\\')
for subject_name in ('sub-01', 'sub-02', 'sub-03'):
    print(subject_name)

    kamitani_rois_path = rois_path / subject_name / 'kamitani'
    # List the directory once, in sorted order, so `rois` and `roi_names` are
    # guaranteed to line up. Two separate iterdir() calls make no such promise.
    roi_paths = sorted(kamitani_rois_path.iterdir())
    if not roi_paths:
        print(f'no ROI files found in {kamitani_rois_path}, skipping')
        continue
    rois = [tio.LabelMap(roi_path) for roi_path in roi_paths]
    # ROI name = 3rd and 4th '_'-separated fields of the file name before the first '.'.
    roi_names = [p.name.split('.')[0].split('_') for p in roi_paths]
    roi_names = [f'{name[2]}_{name[3]}' for name in roi_names]

    for roi in rois:
        roi.load()

    # Pairwise voxel-overlap counts between ROIs (R x R), kept for inspection.
    data = torch.stack([roi.data.flatten().bool() for roi in rois])
    overlap = (data[None, :] & data[:, None]).sum(axis=2)

    # Union mask: a voxel is set if any ROI labels it (non-zero).
    union = torch.cat([roi.data.bool() for roi in rois]).any(dim=0, keepdim=True)
    union = union.to(torch.int32)
    # NOTE(review): assumes every ROI shares one grid/affine; the first one is used.
    union_image = tio.LabelMap(tensor=union, affine=rois[0].affine)

    out_dir = rois_path / subject_name / 'derivatives'
    out_dir.mkdir(parents=True, exist_ok=True)
    union_image.save(out_dir / f'{subject_name}_mask_VC.nii.gz')

In [None]:
# Cache a preprocessing option

import gc
from copy import deepcopy
from research.data.kamitani_2019 import Kamitani2019, RawKamitani2019, Kamitani2019H5
from pathlib import Path
from tqdm.notebook import tqdm


# HDF5 group name the cached preprocessing is written under (alternative used
# before: 'average-4').
# NOTE(review): the name says 0-16 while `window` is (0, 9). Presumably the window
# counts volumes and the name counts seconds (e.g. a 2 s TR). Confirm.
cache_name = 'window-0-16'

subjects = ['sub-01', 'sub-02', 'sub-03']
sessions = ['natural_test']

# Only consumed by the (disabled) CropOrPad transform; the caching loop rebinds it.
target_shape = (72, 88, 74)

# Keyword arguments for Kamitani2019H5. Disabled options tried before:
# window_kernel=[1. / 4.] * 4, transform=tio.CropOrPad(target_shape=target_shape).
preprocessing_params = {
    'window': (0, 9),
    'drop_out_of_window_events': False,
    'normalization': 'voxel_linear_trend',
}

# Materialise the preprocessed events of every (subject, session) pair into
# kamitani2019-cached.hdf5, under the group `<subject>/<session>/<cache_name>`:
#   data      float32 dataset of shape (N, *event['data'].shape)
#   stimulus_id, onset, run_id, in_window
#             per-event values stacked along a new leading axis (stimulus_id as bytes)
#   attrs     'affine' (from the first event) plus every preprocessing parameter as str
with h5py.File(ssd_derivatives_path / 'kamitani2019-cached.hdf5', 'a') as f:
    for subject in subjects:
        print(subject)

        for session in sessions:
            print(session)
            h5_path = ssd_derivatives_path / "kamitani2019.hdf5"
            dataset = Kamitani2019H5(h5_path, subjects=[subject], func_sessions=[session], **preprocessing_params)
            N = len(dataset)

            group = f.require_group(f'{subject}/{session}/{cache_name}')
            # The first event supplies the affine and the per-event data shape.
            sample_event = dataset[0]
            group.attrs['affine'] = sample_event['affine']
            # Rebinds the config-cell `target_shape` to the actual event shape.
            target_shape = (*sample_event['data'].shape,)

            data = group.require_dataset('data', shape=(N, *target_shape), dtype='f4')
            # Record how this cache was produced.
            for k, v in preprocessing_params.items():
                group.attrs[k] = str(v)

            # Per-event metadata, collected in memory and written once at the end.
            out_data = {
                k: [] for k in
                ['stimulus_id', 'onset', 'run_id', 'in_window']
            }

            for i in tqdm(range(len(dataset))):
                event = dataset[i]
                for k, v in out_data.items():
                    v.append(deepcopy(event[k]))
                data[i] = event['data']

            for k, v in out_data.items():
                v = np.stack(v)
                if k == 'stimulus_id':
                    # Fixed-width bytes; readers call .decode('utf-8') on each entry.
                    v = v.astype(np.dtype('S'))
                print(k, v.dtype, v.shape)
                require_dataset(group, k, v)


In [None]:
# Collect unreachable Python objects first, then release the unused memory that
# PyTorch's CUDA caching allocator is still holding.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Noise ceilings are computed for this session/cache of the cached fMRI file.
session_name = 'natural_training'
cache_name = 'window-0-16'
fmri_data = h5py.File(ssd_derivatives_path / 'kamitani2019-cached.hdf5', 'r')

def pearsonr_torch(X, Y, cast_dtype=torch.float64):
    """Pearson correlation of X and Y along dim 0, elementwise over the other dims.

    X and Y have the same (or broadcast-compatible) shape (b, ...). The computation
    runs in `cast_dtype` and the result, of shape (...), is cast back to X's dtype.
    Columns with zero variance yield NaN (0 / 0).
    """
    out_dtype = X.dtype

    def _standardize(tensor):
        tensor = tensor.to(cast_dtype)
        centered = tensor - tensor.mean(dim=0, keepdim=True)
        return centered / torch.norm(centered, dim=0, keepdim=True)

    x_std = _standardize(X)
    y_std = _standardize(Y)
    return torch.einsum('b...,b...->...', x_std, y_std).to(out_dtype)

# Split-half noise ceilings. For each stimulus, repetitions 1-2 form half 1 and
# repetitions 3-4 form half 2 (rep 1 is paired with rep 3, rep 2 with rep 4).
# Every voxel/time point is correlated across all such pairs. The correlation is
# clipped at 0, Spearman-Brown corrected, square-rooted, and written to
# `<subject>/<session>/<cache>/split_half/pearsonr/scores` (shape X.shape[1:]).
with h5py.File(derivatives_path / 'noise-ceilings.hdf5', 'a') as f:
    for subject_name, subject in fmri_data.items():
        cache = subject[session_name][cache_name]
        affine = cache.attrs['affine']
        group = f.require_group(f'{subject_name}/{session_name}/{cache_name}')
        group.attrs['affine'] = affine

        stimulus_id = cache['stimulus_id'][:]
        # Group the repetitions of each stimulus. The stable argsort keeps the original
        # within-stimulus order and is the identity when ids are already sorted, so the
        # split no longer silently depends on the cache being sorted.
        order = torch.from_numpy(np.argsort(stimulus_id, kind='stable'))
        _, unique_counts = np.unique(stimulus_id, return_counts=True)
        if unique_counts.min() < 4:
            raise ValueError(
                f'{subject_name}: split-half needs >= 4 repetitions per stimulus, '
                f'but some stimulus has only {unique_counts.min()}'
            )
        splits = torch.split(order, unique_counts.tolist())
        split_half_1_indices = torch.cat([split[:2] for split in splits])
        split_half_2_indices = torch.cat([split[2:4] for split in splits])

        X = cache['data']

        # Store the measure as a group holding `scores`, the layout the export cell
        # reads and the one the encoder-results file uses. A bare dataset left at
        # `.../pearsonr` by an older run would block creating that group. It is
        # recomputed here, so it is replaced.
        key = '/'.join((subject_name, session_name, cache_name, 'split_half', 'pearsonr', 'scores'))
        measure_key = key.rsplit('/', 1)[0]
        if isinstance(f.get(measure_key), h5py.Dataset):
            print(f'replacing legacy dataset {measure_key}')
            del f[measure_key]
        ceiling_dataset = f.require_dataset(key, X.shape[1:], dtype=np.float32)

        load_time = 0
        compute_time = 0
        store_time = 0

        for i in tqdm(range(X.shape[2])):
            t = time.time()
            X_slice = torch.from_numpy(X[:, :, i])
            load_time += time.time() - t

            t = time.time()
            X_slice_half_1 = X_slice[split_half_1_indices]
            X_slice_half_2 = X_slice[split_half_2_indices]
            r = torch.stack([
                pearsonr_torch(X_slice_half_1[:, j].cuda(), X_slice_half_2[:, j].cuda()).cpu()
                for j in range(X_slice.shape[1])
            ])
            # Negative split-half correlations are treated as zero reliability.
            r[r < 0.] = 0.
            # Spearman-Brown: full-length reliability 2r / (1 + r), then its square root.
            r = torch.sqrt(2 * r / (r + 1.))

            gc.collect()
            torch.cuda.empty_cache()

            compute_time += time.time() - t

            t = time.time()
            ceiling_dataset[:, i] = r.numpy()
            store_time += time.time() - t

            if i % 25 == 0:
                n = (i + 1)
                print(f'load_time={load_time / n}, compute_time={compute_time / n}, store_time={store_time / n}')

In [None]:
# Save noise-ceiling maps (NIfTI export + sorted voxel indices)

from pathlib import Path
import torchio as tio
from functools import partial
import nibabel as nib

# Export every noise-ceiling map as NIfTI and store its voxel ranking next to the scores.
# `sorted_indices_flat` holds flat indices in ascending score order.
# `sorted_indices` holds the same ranking as (ndim, n) coordinates.
out_path = derivatives_path / 'correlation_maps'
with h5py.File(derivatives_path / 'noise-ceilings.hdf5', 'a') as f:
    for subject_name, subject in f.items():
        subject_out_path = out_path / subject_name
        subject_out_path.mkdir(parents=True, exist_ok=True)

        for session_name, session in subject.items():
            for cache_name, cache in session.items():
                affine = cache.attrs['affine']
                for noise_ceiling_name, noise_ceiling in cache.items():
                    for measure_name, measure in noise_ceiling.items():
                        if not isinstance(measure, h5py.Group) or 'scores' not in measure:
                            print(f'skipping {measure.name}: expected a group with a "scores" dataset')
                            continue

                        keys = (subject_name, session_name, cache_name,
                                noise_ceiling_name, measure_name)
                        print(*keys)

                        save_file_name = f'{"__".join(keys)}.nii.gz'

                        data = measure['scores'][:]
                        # np.argsort puts NaN last, which would rank undefined voxels
                        # (e.g. zero-variance ones) as the best. Rank them first instead.
                        ranking_values = np.where(np.isnan(data), -np.inf, data)
                        sorted_indices_flat = np.argsort(ranking_values, axis=None)
                        # Coordinates of the same ranking, without a dense index grid.
                        sorted_indices = np.stack(
                            np.unravel_index(sorted_indices_flat, data.shape)
                        ).astype(int)

                        # Scores are (T, H, W, D). Move T last for the NIfTI volume.
                        image = nib.Nifti1Image(np.transpose(data, (1, 2, 3, 0)), affine)
                        nib.save(image, subject_out_path / save_file_name)

                        require_dataset(measure, 'sorted_indices_flat', sorted_indices_flat)
                        require_dataset(measure, 'sorted_indices', sorted_indices)

In [None]:
# Fit encoders

from pathlib import Path
from tqdm.notebook import tqdm
from itertools import product
from research.data.kamitani_2019 import fix_stimulus_id
import time
from functools import partial
from einops import rearrange
import gc

def least_squares(A, B, batch_size=1,):
    """Least-squares solutions for every pair (a from A, b from B).

    Args:
        A: (*batch_A, m, n) coefficient matrices.
        B: (*batch_B, m, k) right-hand sides (same dtype/device as A).
        batch_size: number of B matrices solved per `torch.linalg.lstsq` call.

    Returns:
        CPU tensor of shape (*batch_A, *batch_B, n, k); entry [i, j] minimises
        ||A[i] x - B[j]||.
    """
    A_batch_dimensions = A.shape[:-2]
    B_batch_dimensions = B.shape[:-2]

    # (..., m, n) -> (prod(...), m, n); a single matrix becomes a batch of one.
    A = A.reshape(-1, *A.shape[-2:])
    B = B.reshape(-1, *B.shape[-2:])

    # Outer product over the batch dimensions, as the final reshape requires. Each
    # `a` is solved against chunks of B. Chunking A as well would pair A- and B-chunks
    # elementwise instead.
    solution = torch.stack([
        torch.cat([
            torch.linalg.lstsq(a.expand(b.shape[0], -1, -1), b).solution.cpu()
            for b in torch.split(B, batch_size)
        ])
        for a in A
    ])
    return solution.reshape(*A_batch_dimensions, *B_batch_dimensions, A.shape[-1], B.shape[-1])


def ridge_regression(A, B, alpha=None, batch_size=1,):
    """Ridge regression of every target matrix in B on every design matrix in A.

    For each pair (a, b), with a of shape (m, n) and b of shape (m, k), this solves
    (aᵀa + alpha·I) x = aᵀb using `torch.linalg.lstsq`. When `alpha` is None it
    solves the plain normal equations aᵀa x = aᵀb.

    Args:
        A: (*batch_A, m, n) design matrices.
        B: (*batch_B, m, k) targets (same dtype/device as A).
        alpha: ridge strength added to the diagonal of aᵀa, or None.
        batch_size: number of B matrices solved per lstsq call.

    Returns:
        CPU tensor of shape (*batch_A, *batch_B, n, k).
    """
    A_batch_dimensions = A.shape[:-2]
    B_batch_dimensions = B.shape[:-2]

    # (..., m, n) -> (prod(...), m, n); a single matrix becomes a batch of one.
    A = A.reshape(-1, *A.shape[-2:])
    B = B.reshape(-1, *B.shape[-2:])

    def fit(X, Y):
        # X: (m, n) one design matrix; Y: (bs, m, k) a chunk of targets.
        lhs = torch.einsum('ij,ik->jk', X, X)    # (n, n)
        rhs = torch.einsum('ij,bik->bjk', X, Y)  # (bs, n, k)
        if alpha is not None:
            # The penalty goes on the diagonal of the (n, n) Gram matrix. Sizing the
            # identity from the batch axis (as before) added alpha to every entry.
            lhs = lhs + alpha * torch.eye(lhs.shape[-1], device=lhs.device, dtype=lhs.dtype)
        return torch.linalg.lstsq(lhs.expand(rhs.shape[0], -1, -1), rhs)

    # Outer product over batch dimensions, matching the final reshape.
    solution = torch.stack([
        torch.cat([fit(a, b).solution.cpu() for b in torch.split(B, batch_size)])
        for a in A
    ])
    return solution.reshape(*A_batch_dimensions, *B_batch_dimensions, A.shape[-1], B.shape[-1])

fmri_data = h5py.File(ssd_derivatives_path / 'kamitani2019-cached.hdf5', 'r')

# Sessions to fit on. This must equal a session group name in the cached file. The
# rest of the notebook (including the cell that reads these parameters back via
# `fit_session_name`) uses 'natural_training'. The previous value 'natural_train'
# matched no session, so nothing was fitted.
fit_sessions = ['natural_training']

cache_name = 'window-0-16'
run_features = {
    #'bigbigan-resnet50': ['z_mean'],
    'ViT-B=32': ['embedding',],# *(f'transformer.resblocks.{i}' for i in range(12))],
    #'biggan-128': ['z', 'y_embedding'],
    #'vqgan': ['vqgan-f16-1024-pre_quant'],
}

# One ridge encoder per regularisation strength; keys become HDF5 group names.
fit_encoders = {
    f'ridge_alpha={alpha}': partial(ridge_regression, alpha=alpha)
    for alpha in [0.1, 1., 10.]
}

# Fit linear encoders slice by slice along axis 2 of the cached data. For each
# slice and each entry of `fit_encoders`, regress the voxel responses on the
# stimulus features, and write the coefficients to
# `<subject>/<session>/<cache>/<model>/<feature>/<encoder>/parameters`
# (shape `(*X.shape[1:], n_features)`).
seed = 0
max_features = 512

with h5py.File(derivatives_path / 'feature-encoder-parameters.hdf5', 'a') as f:
    for subject_name, subject in fmri_data.items():
        for session_name, session in subject.items():
            if session_name not in fit_sessions:
                continue

            if cache_name not in session:
                print(f'{cache_name} not found for session {session_name}, subject {subject_name}')
                continue

            cache = session[cache_name]
            affine = cache.attrs['affine']
            group = f.require_group(f'{subject_name}/{session_name}/{cache_name}')
            group.attrs['affine'] = affine
            X = cache['data']
            raw_stimulus_ids = [s.decode('utf-8') for s in cache['stimulus_id'][:]]

            for model_name, feature_names in run_features.items():
                features = {}
                # Needed only to build Y. Closed before the long fitting loop.
                with h5py.File(derivatives_path / 'features' / f'{model_name}-features.hdf5', 'r') as model_features:
                    stimulus_ids = [fix_stimulus_id(s, model_features.keys()) for s in raw_stimulus_ids]
                    for feature_name in feature_names:
                        print(subject_name, model_name, feature_name)
                        Y = np.stack([model_features[f'{stimulus_id}/{feature_name}'][:]
                                      for stimulus_id in stimulus_ids])
                        Y = torch.from_numpy(Y).flatten(start_dim=1)
                        if Y.shape[1] > max_features:
                            # NOTE(review): this samples `max_features` indices *with
                            # replacement* from the first `max_features` columns only,
                            # which is probably not the intended random subset. The
                            # "Run encoders" cell repeats this exact draw so its features
                            # line up with these parameters. Change both cells together.
                            np.random.seed(seed)
                            choice = np.random.choice(max_features, size=max_features)
                            Y = Y[:, choice]
                        # Cast and upload once instead of once per slice and encoder.
                        features[feature_name] = Y.float().cuda()

                load_time = 0
                compute_time = 0
                store_time = 0

                for i in tqdm(range(X.shape[2])):
                    t = time.time()
                    X_slice = torch.from_numpy(X[:, :, i])
                    load_time += time.time() - t

                    t = time.time()
                    # (n, t, ...) -> (..., n, t): one (n, t) target matrix per remaining position.
                    targets = rearrange(X_slice, 'n t ... -> ... n t').cuda()
                    compute_time += time.time() - t

                    for feature_name in feature_names:
                        Y = features[feature_name]

                        for encoder_name, encoder_func in fit_encoders.items():
                            t = time.time()
                            solution = encoder_func(Y, targets, batch_size=50)
                            # (..., n_features, t) -> (t, ..., n_features)
                            solution = rearrange(solution, '... t -> t ...')
                            gc.collect()
                            torch.cuda.empty_cache()

                            compute_time += time.time() - t

                            keys = (subject_name, session_name, cache_name, model_name, feature_name, encoder_name, 'parameters')
                            key = '/'.join(keys)
                            parameters_dataset = f.require_dataset(key, (*X.shape[1:], solution.shape[-1]), dtype=np.float32)
                            t = time.time()
                            parameters_dataset[:, i] = solution
                            store_time += time.time() - t

                    if i % 25 == 0:
                        n = (i + 1)
                        print(f'load_time={load_time / n}, compute_time={compute_time / n}, store_time={store_time / n}')

In [None]:
# Run encoders

from pathlib import Path

from itertools import product
from research.data.kamitani_2019 import fix_stimulus_id
import time
from functools import partial
from einops import rearrange
import gc

def pearsonr_torch(X, Y, cast_dtype=torch.float64):
    """Column-wise Pearson r of two same-shaped tensors, reduced over dim 0.

    Inputs are promoted to `cast_dtype` for the computation. The result has the
    trailing shape of the inputs and X's original dtype. Zero-variance columns give NaN.
    """
    result_dtype = X.dtype
    standardized = []
    for tensor in (X, Y):
        tensor = tensor.to(cast_dtype)
        tensor = tensor - tensor.mean(dim=0, keepdim=True)
        standardized.append(tensor / torch.norm(tensor, dim=0, keepdim=True))
    x_hat, y_hat = standardized
    correlation = torch.einsum('b...,b...->...', x_hat, y_hat)
    return correlation.to(result_dtype)

# Inputs: cached fMRI responses and the encoder parameters fitted earlier.
fmri_data = h5py.File(ssd_derivatives_path / 'kamitani2019-cached.hdf5', 'r')
encoder_parameters = h5py.File(derivatives_path / 'feature-encoder-parameters.hdf5', 'r')

cache_name = 'window-0-16'
# Other feature sets used before: 'bigbigan-resnet50': ['z_mean'],
# 'biggan-128': ['z', 'y_embedding'], 'vqgan': ['vqgan-f16-1024-pre_quant'],
# and ViT-B=32's 'transformer.resblocks.0' ... 'transformer.resblocks.11'.
run_features = {
    'ViT-B=32': ['embedding'],
}

# Evaluate parameters fitted on `fit_session_name` on each run session.
subject_names = ['sub-03']
fit_session_name = 'natural_training'
run_session_names = ['natural_training', 'natural_test']

# Feature subsampling must match the values used when fitting.
seed, max_features = 0, 512

# Evaluate the fitted encoders. Predict each run session's responses from the
# stimulus features using the parameters fitted on `fit_session_name`. Store the
# Pearson r between prediction and measurement (over events, per position) under
# `<subject>/<fit session>/<run session>/<cache>/<model>/<feature>/<encoder>/pearsonr/scores`.
with h5py.File(derivatives_path / 'feature-encoder-results.hdf5', 'a') as f:
    for subject_name in subject_names:

        fit_session_params = encoder_parameters[f'{subject_name}/{fit_session_name}/{cache_name}']
        for run_session_name in run_session_names:
            run_session_cache = fmri_data[f'{subject_name}/{run_session_name}/{cache_name}']

            affine = run_session_cache.attrs['affine']
            group = f.require_group(f'{subject_name}/{fit_session_name}/{run_session_name}/{cache_name}')
            group.attrs['affine'] = affine
            X = run_session_cache['data']
            raw_stimulus_ids = [s.decode('utf-8') for s in run_session_cache['stimulus_id'][:]]

            for model_name, feature_names in run_features.items():
                features = {}
                parameters = {}
                with h5py.File(derivatives_path / 'features' / f'{model_name}-features.hdf5', 'r') as model_features:
                    stimulus_ids = [fix_stimulus_id(s, model_features.keys()) for s in raw_stimulus_ids]
                    for feature_name in feature_names:
                        print(subject_name, model_name, feature_name)
                        Y = np.stack([model_features[f'{stimulus_id}/{feature_name}'][:]
                                      for stimulus_id in stimulus_ids])
                        Y = torch.from_numpy(Y).flatten(start_dim=1)
                        if Y.shape[1] > max_features:
                            # NOTE(review): must reproduce the draw used in the "Fit
                            # encoders" cell so the columns match the parameters. It samples
                            # with replacement from the first `max_features` columns only.
                            # If that changes, change both cells together.
                            np.random.seed(seed)
                            choice = np.random.choice(max_features, size=max_features)
                            Y = Y[:, choice]
                        # Upload once rather than once per slice and encoder.
                        features[feature_name] = Y.float().cuda()
                        parameters[feature_name] = fit_session_params[model_name][feature_name]

                load_time = 0
                compute_time = 0
                store_time = 0

                for i in tqdm(range(X.shape[2])):
                    t = time.time()
                    X_slice = torch.from_numpy(X[:, :, i])
                    load_time += time.time() - t

                    for feature_name in feature_names:
                        gc.collect()
                        torch.cuda.empty_cache()

                        Y = features[feature_name]
                        # Don't name this loop variable `encoder_parameters`. That is the
                        # open parameters file, which is indexed again for the next subject.
                        for encoder_name, encoder_group in parameters[feature_name].items():
                            params_slice = torch.from_numpy(encoder_group['parameters'][:, i])

                            t = time.time()
                            X_slice_pred = torch.einsum('thde, ne -> nthd', params_slice.cuda(), Y).cpu()
                            r = torch.stack([
                                pearsonr_torch(X_slice[:, j].cuda(), X_slice_pred[:, j].cuda()).cpu()
                                for j in range(X_slice.shape[1])
                            ])
                            compute_time += time.time() - t

                            keys = (model_name, feature_name, encoder_name, 'pearsonr', 'scores')
                            key = '/'.join(keys)
                            # Not called `dataset`, so the notebook-level Kamitani dataset survives.
                            scores_dataset = group.require_dataset(key, X.shape[1:], dtype=np.float32)
                            t = time.time()
                            scores_dataset[:, i] = r
                            store_time += time.time() - t

                    if i % 25 == 0:
                        n = (i + 1)
                        print(f'load_time={load_time / n}, compute_time={compute_time / n}, store_time={store_time / n}')

In [None]:
# Save encoder results

from pathlib import Path
import torchio as tio
from functools import partial
import nibabel as nib


# Export every encoder-evaluation map as NIfTI and store its voxel ranking next to
# the scores. Layout written by the "Run encoders" cell:
#   <subject>/<fit session>/<run session>/<cache>/<model>/<feature>/<encoder>/<evaluation>/scores
# (e.g. <encoder> = 'ridge_alpha=1.0', <evaluation> = 'pearsonr').
out_path = derivatives_path / 'correlation_maps'
with h5py.File(derivatives_path / 'feature-encoder-results.hdf5', 'a') as f:
    for subject_name, subject in f.items():
        subject_out_path = out_path / subject_name
        subject_out_path.mkdir(parents=True, exist_ok=True)

        for fit_session_name, fit_session in subject.items():
            for run_session_name, run_session in fit_session.items():
                for cache_name, cache in run_session.items():
                    affine = cache.attrs['affine']
                    for model_name, model in cache.items():
                        for feature_name, feature in model.items():
                            for encoder_name, encoder in feature.items():
                                for evaluation_name, evaluation in encoder.items():
                                    # Only evaluation groups holding `scores` are exported.
                                    if not isinstance(evaluation, h5py.Group) or 'scores' not in evaluation:
                                        continue

                                    keys = (subject_name, fit_session_name, run_session_name,
                                            cache_name, model_name, feature_name,
                                            encoder_name, evaluation_name)
                                    print(*keys)

                                    save_file_name = f'{"__".join(keys)}.nii.gz'

                                    data = evaluation['scores'][:]
                                    # np.argsort puts NaN last, i.e. it would rank undefined
                                    # voxels as the best. Rank them first instead.
                                    ranking_values = np.where(np.isnan(data), -np.inf, data)
                                    sorted_indices_flat = np.argsort(ranking_values, axis=None)
                                    # Coordinates of the same ranking, without a dense index grid.
                                    sorted_indices = np.stack(
                                        np.unravel_index(sorted_indices_flat, data.shape)
                                    ).astype(int)

                                    # Scores are (T, H, W, D). Move T last for the NIfTI volume.
                                    image = nib.Nifti1Image(np.transpose(data, (1, 2, 3, 0)), affine)
                                    nib.save(image, subject_out_path / save_file_name)

                                    # Stored next to `scores`, like the other sorted-index datasets.
                                    require_dataset(evaluation, 'sorted_indices_flat', sorted_indices_flat)
                                    require_dataset(evaluation, 'sorted_indices', sorted_indices)


In [None]:
import gc
# Collect unreachable Python objects first, then release the unused memory that
# PyTorch's CUDA caching allocator is still holding.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Shape of sub-01's cached training data (displayed as the cell's output).
fmri_data['sub-01']['natural_training']['window-0-16']['data'].shape

In [None]:
# Create correlation maps

from pathlib import Path
from tqdm.notebook import tqdm
from itertools import product
from research.data.kamitani_2019 import fix_stimulus_id
import time
from functools import partial


def pearsonr_torch(X, Y):
    """Pearson r between every trailing element of X and every trailing element of Y.

    Correlations are taken along dim 0. For X of shape (b, *x_dims) and Y of shape
    (b, *y_dims), both non-empty, the result has shape (*x_dims, *y_dims) and dtype
    float64. Zero-variance columns yield NaN.
    """
    def _zscore_columns(tensor):
        tensor = tensor.to(torch.float64)
        tensor = tensor - tensor.mean(dim=0, keepdim=True)
        return tensor / torch.norm(tensor, dim=0, keepdim=True)

    X, Y = _zscore_columns(X), _zscore_columns(Y)

    x_extra, y_extra = X.dim() - 1, Y.dim() - 1
    # Give X trailing singleton axes for Y's dims, and Y singleton axes (right after
    # the batch axis) for X's dims, so the einsum broadcasts into an outer product.
    X = X.reshape(*X.shape, *([1] * y_extra))
    Y = Y.reshape(Y.shape[0], *([1] * x_extra), *Y.shape[1:])

    return torch.einsum('b...i,b...i->...i', X, Y)


cache_name = 'window-0-16'
# The 'embedding' feature plus the 12 `transformer.resblocks.<i>` features.
# Other models used before: 'bigbigan-resnet50': ['z_mean'],
# 'biggan-128': ['z', 'y_embedding'], 'vqgan': ['vqgan-f16-1024-pre_quant'].
run_features = {
    'ViT-B=32': ['embedding'] + [f'transformer.resblocks.{i}' for i in range(12)],
}
fmri_data = h5py.File(ssd_derivatives_path / 'kamitani2019-cached.hdf5', 'r')

def norm(data):
    """L2 norm of `data` over its last (feature) axis."""
    return torch.norm(data, dim=-1)


def mean_top_k(data, k):
    """Mean of the `k` largest absolute values along the last axis.

    Raises:
        ValueError: if k < 1 (a `[-0:]` slice would silently average everything).
    """
    if k < 1:
        raise ValueError(f'k must be >= 1, got {k}')
    data = torch.abs(data)
    data = torch.sort(data)[0]
    data = data[..., -k:].mean(dim=-1)
    return data


def max_feature(data):
    """Largest absolute value along the last axis."""
    # torch.max(..., dim=...) returns a (values, indices) named tuple; keep the values.
    return torch.max(torch.abs(data), dim=-1).values


# Reductions from per-feature correlations (last axis) to one score per position.
selection_modes = {
    #'max': max_feature,
    'mean-top-5': partial(mean_top_k, k=5),
    'norm': norm,
    #'mean-top-10': partial(mean_top_k, k=10),
}

# For every cached session, correlate each position's responses with every
# (subsampled) feature, then reduce the correlations with each entry of
# `selection_modes` into
# `<subject>/<session>/<cache>/<model>/<feature>/<selection mode>/scores`.
seed = 0
max_features = 512

# NOTE(review): this writes to the SSD derivatives path, while the backfill and
# viewer cells read feature-selection-maps.hdf5 under `root` (X:). Confirm the
# file is copied across.
with h5py.File(ssd_derivatives_path / 'feature-selection-maps.hdf5', 'a') as f:
    for subject_name, subject in fmri_data.items():
        for session_name, session in subject.items():
            if cache_name not in session:
                # Report before skipping. This print used to come after `continue`,
                # so it never ran.
                print(f'{cache_name} not found for session {session_name}, subject {subject_name}')
                continue
            cache = session[cache_name]
            affine = cache.attrs['affine']
            group = f.require_group(f'{subject_name}/{session_name}/{cache_name}')
            group.attrs['affine'] = affine
            X = cache['data']
            raw_stimulus_ids = [s.decode('utf-8') for s in cache['stimulus_id'][:]]

            for model_name, feature_names in run_features.items():
                features = {}
                with h5py.File(derivatives_path / 'features' / f'{model_name}-features.hdf5', 'r') as model_features:
                    stimulus_ids = [fix_stimulus_id(s, model_features.keys()) for s in raw_stimulus_ids]
                    for feature_name in feature_names:
                        print(subject_name, model_name, feature_name)
                        Y = np.stack([model_features[f'{stimulus_id}/{feature_name}'][:]
                                      for stimulus_id in stimulus_ids])
                        Y = torch.from_numpy(Y).cuda()
                        Y = Y.flatten(start_dim=1)
                        if Y.shape[1] > max_features:
                            # Reproducible random subset of `max_features` *distinct*
                            # columns drawn from all of Y's columns. The previous
                            # `choice(max_features, size=max_features)` drew with
                            # replacement from the first `max_features` columns only.
                            np.random.seed(seed)
                            choice = np.random.choice(Y.shape[1], size=max_features, replace=False)
                            Y = Y[:, choice]
                        features[feature_name] = Y

                load_time = 0
                compute_time = 0
                store_time = 0

                for i in tqdm(range(X.shape[2])):
                    t = time.time()
                    X_slice = torch.from_numpy(X[:, :, i]).cuda()
                    load_time += time.time() - t

                    for feature_name in feature_names:
                        Y = features[feature_name]

                        t = time.time()
                        r = torch.stack([
                            pearsonr_torch(X_slice[:, j], Y)
                            for j in range(X_slice.shape[1])
                        ])
                        compute_time += time.time() - t

                        for selection_mode, selection_func in selection_modes.items():
                            keys = (subject_name, session_name, cache_name, model_name, feature_name, selection_mode, 'scores')
                            key = '/'.join(keys)
                            selection_map = f.require_dataset(key, X.shape[1:], dtype=np.float32)
                            score = selection_func(r)
                            t = time.time()
                            selection_map[:, i] = score.cpu()
                            store_time += time.time() - t

                    if i % 25 == 0:
                        n = (i + 1)
                        print(f'load_time={load_time / n}, compute_time={compute_time / n}, store_time={store_time / n}')

In [None]:
# Debug leftover: shape of the last feature matrix `Y` from the previous cell's loop.
Y.shape

In [None]:
# Save correlation maps
from pathlib import Path
import torchio as tio
from functools import partial
import nibabel as nib


def mean_top_k(data, k):
    """Mean of the ``k`` largest absolute values per position.

    Absolute values are sorted along the last axis, the last ``k`` entries of
    axis 3 are kept and averaged over the last axis. For 4-D input (axis 3 is
    the last axis) this is the mean of the top-k |values| at every (i, j, l).
    Note that ``k=0`` slices ``-0:``, i.e. the whole axis.
    """
    magnitudes = np.sort(np.abs(data), axis=-1)
    top_k = magnitudes[:, :, :, -k:]
    return top_k.mean(axis=-1)

def max_feature(data):
    """Largest absolute value along axis 3 (the axis is removed from the result)."""
    return np.abs(data).max(axis=3)

selection_modes = {
    'max': max_feature,
    'mean-top-5': partial(mean_top_k, k=5),
    'mean-top-10': partial(mean_top_k, k=10),
}
models = ['vqgan']

out_path = Path('X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\correlation_maps')

# Reduce every raw feature-correlation map of the selected models to one score per
# voxel with each selection mode, export it as NIfTI, and store the scores plus
# their ascending ordering in feature-selection-maps.hdf5.
# Both files are opened in a single `with` so they are closed even if a step fails.
with h5py.File('X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\feature-correlation-maps.hdf5', 'r') as correlation_maps, \
        h5py.File('X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\feature-selection-maps.hdf5', 'a') as f:
    for subject_name, subject in correlation_maps.items():
        subject_out_path = out_path / subject_name
        subject_out_path.mkdir(parents=True, exist_ok=True)

        for session_name, session in subject.items():
            for cache_name, cache in session.items():
                affine = cache.attrs['affine']
                for model_name, model in cache.items():
                    if model_name not in models:
                        continue

                    for feature_name, feature_correlation_map in model.items():
                        # Loop-invariant: read the raw map once, not once per selection mode.
                        correlation = feature_correlation_map[:]

                        for selection_mode, selection_func in selection_modes.items():
                            keys = (subject_name, session_name, cache_name, model_name, feature_name, selection_mode)
                            print(*keys)

                            save_file_name = f'{"__".join(keys)}.nii.gz'

                            data = selection_func(correlation)
                            # Flat indices of all entries, ordered by ascending score.
                            sorted_indices_flat = np.argsort(data, axis=None)
                            # Per-axis coordinates of those entries, one row per axis of
                            # `data`. Works for any rank; the previous hard-coded T/H/W/D
                            # grid relied on an undefined `T` and broke on 3-D scores.
                            sorted_indices = np.stack(
                                np.unravel_index(sorted_indices_flat, data.shape)
                            ).astype(int)

                            # NIfTI keeps spatial axes first; for a 4-D map the leading
                            # axis is moved to the end (3-D maps are saved as they are).
                            volume = np.moveaxis(data, 0, -1) if data.ndim == 4 else data
                            image = nib.Nifti1Image(volume, affine)
                            nib.save(image, subject_out_path / save_file_name)

                            group = f.require_group('/'.join(keys))
                            require_dataset(group, 'scores', data)
                            require_dataset(group, 'sorted_indices_flat', sorted_indices_flat)
                            require_dataset(group, 'sorted_indices', sorted_indices)

In [4]:
# Save correlation maps
from pathlib import Path
import torchio as tio
from functools import partial
import nibabel as nib

cache_name = 'window-0-16'
run_features = {
    #'bigbigan-resnet50': ['z_mean'],
    'ViT-B=32': [*(f'transformer.resblocks.{i}' for i in range(12)), 'embedding'],
    #'biggan-128': ['z', 'y_embedding'],
    #'vqgan': ['vqgan-f16-1024-pre_quant'],
}

out_path = derivatives_path / 'correlation_maps'

# For every stored selection map of the models listed in `run_features`: export the
# scores as NIfTI and store, next to them, the ordering of all entries by ascending
# score (flat indices and per-axis coordinates).
# NOTE(review): only the model names in `run_features` filter the work; the per-model
# feature lists are not consulted, so every feature stored under a model is processed.
with h5py.File(ssd_derivatives_path / 'feature-selection-maps.hdf5', 'a') as f:
    for subject_name, subject in f.items():
        subject_out_path = out_path / subject_name
        subject_out_path.mkdir(parents=True, exist_ok=True)

        for session_name, session in subject.items():
            if cache_name not in session:
                continue
            cache = session[cache_name]
            affine = cache.attrs['affine']
            for model_name, model in cache.items():
                if model_name not in run_features:
                    continue

                for feature_name, feature in model.items():
                    for selection_mode, selection_map in feature.items():
                        keys = (subject_name, session_name, cache_name, model_name, feature_name, selection_mode)
                        print(*keys)

                        save_file_name = f'{"__".join(keys)}.nii.gz'

                        data = selection_map['scores'][:]
                        # Unpacking doubles as a check that the scores are 4-D (T, H, W, D).
                        T, H, W, D = data.shape
                        sorted_indices_flat = np.argsort(data, axis=None)
                        # (4, N) coordinates in ascending-score order, computed directly
                        # rather than by indexing a materialised (4, T, H, W, D) grid.
                        sorted_indices = np.stack(
                            np.unravel_index(sorted_indices_flat, data.shape)
                        ).astype(int)

                        # Move the leading T axis last, as NIfTI keeps spatial axes first.
                        image = nib.Nifti1Image(np.moveaxis(data, 0, -1), affine)
                        nib.save(image, subject_out_path / save_file_name)

                        group = f.require_group('/'.join(keys))
                        require_dataset(group, 'scores', data)
                        require_dataset(group, 'sorted_indices_flat', sorted_indices_flat)
                        require_dataset(group, 'sorted_indices', sorted_indices)

sub-01 natural_test window-0-16 ViT-B=32 embedding mean-top-5
sub-01 natural_test window-0-16 ViT-B=32 embedding norm
sub-01 natural_test window-0-16 ViT-B=32 transformer.resblocks.0 mean-top-5
sub-01 natural_test window-0-16 ViT-B=32 transformer.resblocks.0 norm
sub-01 natural_test window-0-16 ViT-B=32 transformer.resblocks.1 mean-top-5
sub-01 natural_test window-0-16 ViT-B=32 transformer.resblocks.1 norm
sub-01 natural_test window-0-16 ViT-B=32 transformer.resblocks.10 mean-top-5
sub-01 natural_test window-0-16 ViT-B=32 transformer.resblocks.10 norm
sub-01 natural_test window-0-16 ViT-B=32 transformer.resblocks.11 mean-top-5
sub-01 natural_test window-0-16 ViT-B=32 transformer.resblocks.11 norm
sub-01 natural_test window-0-16 ViT-B=32 transformer.resblocks.2 mean-top-5
sub-01 natural_test window-0-16 ViT-B=32 transformer.resblocks.2 norm
sub-01 natural_test window-0-16 ViT-B=32 transformer.resblocks.3 mean-top-5
sub-01 natural_test window-0-16 ViT-B=32 transformer.resblocks.3 norm
su

In [None]:

cache_name = 'window-4-6-8-10'
selection_mode = 'mean-top-5'
stack_name = 'depth-sequence'
times = [4, 6, 8, 10]
model_name = 'ViT-B=32'
feature_names = [*(f'transformer.resblocks.{i}' for i in range(12)), 'embedding']

out_path = derivatives_path / 'correlation_maps'

# For each entry of `times`, take slice i of every feature's selection scores, stack
# the slices along a new leading axis (one entry per feature, in `feature_names`
# order) and save the stack as a single NIfTI with that feature axis moved last.
with h5py.File(derivatives_path / 'feature-selection-maps.hdf5', 'r') as f:
    for subject_name, subject in f.items():
        subject_out_path = out_path / subject_name
        subject_out_path.mkdir(parents=True, exist_ok=True)

        for session_name, session in subject.items():
            if cache_name not in session:
                continue

            cache = session[cache_name]
            model = cache[model_name]
            affine = cache.attrs['affine']
            # `time_point`, not `time`: a global named `time` would shadow the `time`
            # module imported at the top of the notebook and break `time.time()` calls.
            for i, time_point in enumerate(times):
                data = np.stack([
                    model[feature_name][selection_mode]['scores'][i]
                    for feature_name in feature_names
                ])

                # The stack covers all of `feature_names`, so no single feature name
                # belongs in the key (the old code used a stale global `feature_name`,
                # since list-comprehension variables do not leak in Python 3).
                keys = (subject_name, session_name, cache_name, model_name, selection_mode,
                        stack_name, f'time-{time_point}')

                save_file_name = f'{"__".join(keys)}.nii.gz'

                image = nib.Nifti1Image(np.moveaxis(data, 0, -1), affine)
                nib.save(image, subject_out_path / save_file_name)


In [None]:
old_path = 'X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\feature-correlation-maps.hdf5'
new_path = 'X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\feature-correlation-maps-new.hdf5'

# Copy the `affine` attribute of every subject/session/cache object in the old file
# onto the object at the same HDF5 path in the new file; entries without an
# `affine` attribute are skipped.
with h5py.File(old_path, 'r') as old_f, h5py.File(new_path, 'a') as new_f:
    for subject_name, subject in old_f.items():
        for session_name, session in subject.items():
            for cache_name, cache in session.items():
                if 'affine' in cache.attrs:
                    new_f[cache.name].attrs['affine'] = cache.attrs['affine']
            

In [None]:
# Ten largest values of `data` (flattened, ascending). `data` is whatever an earlier
# cell left in the kernel.
np.sort(np.ravel(data))[-10:]

In [None]:
import pandas as pd
from pathlib import Path

# Same locations as before, expressed via `derivatives_path` from the config cell.
h5_path = derivatives_path / "kamitani2019.hdf5"

stimulus_id_map_path = derivatives_path / 'kamitani-preprocessed' / 'stimulus_NaturalImageTest.tsv'

# Headerless TSV; only the `stimulus_id` and `index` columns are used below.
stimulus_id_map = pd.read_csv(stimulus_id_map_path,
                              sep='\t',
                              names=['_1', 'stimulus_id', 'index', '_2'],
                              dtype={'stimulus_id': str, 'index': int})

stimulus_ids = list(stimulus_id_map['stimulus_id'])
ids = list(stimulus_id_map['index'])

# Remove the stale attribute; guarded so re-running the cell does not raise KeyError.
with h5py.File(h5_path, 'a') as f:
    if 'test_stimulus_ids' in f.attrs:
        del f.attrs['test_stimulus_ids']

In [None]:
# Copy kamitani2019-cached.hdf5 into a new file. Groups are recreated with their
# attributes; every dataset whose path ends in 'data' is rewritten as float32 with
# shape (N, 1, H, W, D); all other datasets are copied verbatim (dataset attributes
# are not copied). Both files are opened in one `with` so they are closed on error.
with h5py.File(derivatives_path / 'kamitani2019-cached.hdf5', 'r') as cached, \
        h5py.File(derivatives_path / 'kamitani2019-cached-new.hdf5', 'a') as cached_new:

    def copy_data(k, v):
        """`visititems` callback: mirror object `v` (at path `k`) into `cached_new`."""
        if isinstance(v, h5py.Group):
            new_v = cached_new.require_group(k)
            for attr_name, attr in v.attrs.items():
                new_v.attrs[attr_name] = attr

        elif isinstance(v, h5py.Dataset):
            print(v.name, v.shape, v.dtype)
            if v.name.endswith('data'):
                N = v.shape[0]
                H, W, D = v.shape[-3:]

                new_v = cached_new.require_dataset(k, shape=(N, 1, H, W, D), dtype='float32')
                for i in tqdm(range(N)):
                    data = v[i]
                    # A 4-D sample keeps only index 0 of its leading axis.
                    if len(data.shape) == 4:
                        data = data[0]
                    new_v[i, 0] = data

            else:
                require_dataset(cached_new, k, v[:])

    cached.visititems(copy_data)

In [None]:
# Encode `stimulus_id` as a fixed-width bytes array. NOTE(review): `stimulus_id` is
# not defined in the visible cells (only `stimulus_ids` is) — TODO confirm which was meant.
np.array(stimulus_id, dtype='S')

In [None]:
# TODO(review): unfinished cell. The `require_dataset` call below was never completed
# (no arguments, unclosed parenthesis), which made the whole cell a SyntaxError and
# aborted "Run All". It stays commented out until its name/shape/dtype are decided.
# `subject`, `session` and `cache_name` are not defined in this cell and must already
# exist in the kernel (presumably as path components — confirm).
with h5py.File('X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\kamitani2019-cached.hdf5', 'a') as f:
    group = f[f'{subject}/{session}/{cache_name}']
    print(group)
    # group.require_dataset(

In [None]:
import h5py

# Print the path of every group/dataset in the file. Opened read-only: listing needs
# no write access, and mode 'a' would silently create an empty file on a wrong path.
with h5py.File('X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\kamitani2019.hdf5', 'r') as f:
    f.visit(print)

In [None]:
import h5py

# Move every member of each non-anatomy session under a `runs/` subgroup.
with h5py.File('X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\kamitani2019.hdf5', 'a') as f:
    for subject in f.values():
        print(subject)
        for session_name, session in subject.items():
            if session_name == 'anatomy':
                continue

            # Snapshot the names before touching the group: `session.keys()` is a live
            # view, so iterating it while members are moved out (and after `runs` is
            # created) misbehaves, e.g. it would try to move `runs` into itself.
            run_names = [name for name in session.keys() if name != 'runs']
            # require_group instead of create_group so re-running the cell is safe.
            session.require_group('runs')

            for run_name in run_names:
                session.move(run_name, f'runs/{run_name}')

In [None]:
# Compute mean and SD across volume and voxels

import h5py

# Set to True to (re)compute the per-run statistics. The cell used to be disabled
# with a bare top-level `break`, which is a SyntaxError and aborts "Run All"; this
# flag keeps it off by default without breaking the notebook.
COMPUTE_RUN_STATS = False

if COMPUTE_RUN_STATS:
    with h5py.File('X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\kamitani2019.hdf5', 'a') as f:
        for subject in f.values():
            print(subject)
            for session_name, session in subject.items():
                if session_name == 'anatomy':
                    continue

                for run in session.values():
                    data = run['data'][:]

                    # Unpacking doubles as a rank check; time is the last axis.
                    H, W, D, T = data.shape

                    data = torch.from_numpy(data).cuda()

                    # Per-voxel statistics over time.
                    voxel_mean = data.mean(dim=3).cpu().numpy()
                    voxel_std = data.std(dim=3).cpu().numpy()

                    # Per-volume statistics over the three spatial axes.
                    volume_mean = data.mean(dim=(0, 1, 2)).cpu().numpy()
                    # std, not mean: the previous version stored the mean here.
                    volume_std = data.std(dim=(0, 1, 2)).cpu().numpy()

                    # Whole-run scalars.
                    run_mean = data.mean().cpu().item()
                    run_std = data.std().cpu().item()

                    require_dataset(run, 'voxel_mean', voxel_mean)
                    require_dataset(run, 'voxel_std', voxel_std)
                    require_dataset(run, 'volume_mean', volume_mean)
                    require_dataset(run, 'volume_std', volume_std)

                    run.attrs['run_mean'] = run_mean
                    run.attrs['run_std'] = run_std

                    print(run)

In [None]:
# Compute linear trends across voxels

# Fit a linear trend (slope, intercept) over time to every voxel by batched least
# squares and store the coefficients and lstsq's residuals per run.
with h5py.File('X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\kamitani2019.hdf5', 'a') as f:
    for subject in f.values():
        print(subject)
        for session_name, session in subject.items():
            if session_name == 'anatomy':
                continue

            for run in session.values():
                print(run)

                # `voxel_mean` is no longer loaded: it was unused, and requiring it made
                # this cell depend on the (disabled) statistics cell.
                data = torch.from_numpy(run['data'][:]).cuda()
                Y = data[..., None]  # least-squares targets, trailing singleton column

                W, H, D, T = data.shape

                # Design matrix [t, 1] for every voxel: shape (*data.shape, 2).
                t = torch.arange(T, dtype=data.dtype, device=data.device).expand_as(data)
                X = torch.stack([t, torch.ones_like(t)], dim=-1)

                solution, residuals, rank, singular_values = torch.linalg.lstsq(X, Y)
                solution = solution.squeeze().cpu().numpy()
                residuals = residuals.squeeze().cpu().numpy()

                slope = solution[..., 0]
                intercept = solution[..., 1]

                require_dataset(run, 'voxel_linear_trend_slope', slope)
                require_dataset(run, 'voxel_linear_trend_intercept', intercept)
                require_dataset(run, 'voxel_linear_trend_residual', residuals)

In [None]:
# Per-voxel std over time of the residual after removing the stored linear trend
# (slope/intercept written by the voxel trend-fitting cell).

with h5py.File('X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\kamitani2019.hdf5', 'a') as f:
    for subject in f.values():
        print(subject)
        for session_name, session in subject.items():
            if session_name == 'anatomy':
                continue

            for run in session.values():
                print(run)

                data = torch.from_numpy(run['data'][:]).cuda()
                T = data.shape[3]
                t = torch.arange(T, dtype=data.dtype, device=data.device)

                slope = torch.from_numpy(run['voxel_linear_trend_slope'][:]).cuda()
                intercept = torch.from_numpy(run['voxel_linear_trend_intercept'][:]).cuda()

                # trend[..., k] = slope * t_k + intercept, broadcast over voxels. Same as
                # the [t, 1] design matrix times [slope, intercept], without
                # materialising that (..., T, 2) matrix.
                trend = slope[..., None] * t + intercept[..., None]
                std = (trend - data).std(dim=3).cpu().numpy()

                require_dataset(run, 'voxel_linear_trend_std', std)

In [None]:

# Fit a linear trend (slope, intercept) over time to each run's `volume_mean` signal
# and store it together with the std of the run data around that trend.
with h5py.File('X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\kamitani2019.hdf5', 'a') as f:
    for subject in f.values():
        print(subject)
        for session_name, session in subject.items():
            if session_name == 'anatomy':
                continue

            for run in session.values():
                print(run)

                volume_mean = run['volume_mean'][:]
                T = volume_mean.shape[0]
                Y = torch.from_numpy(volume_mean[:, None])
                X = torch.arange(T).float()
                X = torch.stack([X, torch.ones_like(X)], dim=-1)  # (T, 2) design matrix [t, 1]

                solution = torch.linalg.lstsq(X, Y).solution  # (2, 1)
                slope, intercept = solution[:, 0].numpy()

                # Residual spread of the run data around the fitted trend, computed fully
                # in torch. The old version mixed a torch tensor with a NumPy array and
                # broadcast a (T,) trend against the (T, 1) target; this matches the
                # definition used by the "Fix computation of std" cell.
                data = torch.from_numpy(run['data'][:]).cuda()
                trend = (X @ solution)[:, 0].to(data.device)
                std = (data - trend[None, None, None, :]).std().item()

                require_dataset(run, 'volume_linear_trend_slope', np.array([slope]))
                require_dataset(run, 'volume_linear_trend_intercept', np.array([intercept]))
                require_dataset(run, 'volume_linear_trend_std', np.array([std]))
                

In [None]:
# Fix computation of std for volume linear trend removal
import h5py

# Recompute `volume_linear_trend_std` as the std of the run data around the stored
# volume trend (slope * t + intercept), and print it next to the previous value.
with h5py.File('X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\kamitani2019.hdf5', 'a') as f:
    for subject in f.values():
        print(subject)
        for session_name, session in subject.items():
            if session_name == 'anatomy':
                continue

            for run in session.values():
                print(run)

                std_name = 'volume_linear_trend_std'
                data = torch.from_numpy(run['data'][:]).cuda()
                T = data.shape[3]
                intercept = torch.from_numpy(run['volume_linear_trend_intercept'][:])
                slope = torch.from_numpy(run['volume_linear_trend_slope'][:])
                # Guarded: the cell also works when no std was stored yet.
                old_std = run[std_name][:] if std_name in run else None

                # Trend on the data's device/dtype: A @ [slope, intercept] with A = [t, 1].
                coefficients = torch.cat([slope, intercept]).to(device=data.device, dtype=data.dtype)
                t = torch.arange(T, device=data.device, dtype=data.dtype)
                A = torch.stack([t, torch.ones_like(t)], dim=-1)
                mean = A @ coefficients  # (T,)
                std = (data - mean[None, None, None, :]).std().item()
                print(std, old_std)

                # Delete first: require_dataset raises if an existing dataset has an
                # incompatible shape/dtype.
                if std_name in run:
                    del run[std_name]
                require_dataset(run, std_name, np.array([std]))