In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import time
import random
import json
import gc
from typing import Tuple, Optional, Dict
from functools import partial

import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
import h5py
from ipywidgets import interact
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.notebook import tqdm
import nibabel as nib
from einops import rearrange
from scipy import ndimage
import wandb

# Put the directory three levels above the current working directory on
# sys.path (presumably the repository root, so the `research` imports below resolve).
dir2 = os.path.abspath(os.path.join('..', '..'))
dir1 = os.path.dirname(dir2)
if dir1 not in sys.path:
    sys.path.append(dir1)
    
from research.data.natural_scenes import (
    NaturalScenesDataset,
)
from research.experiments.nsd.nsd_experiment import NSDExperiment

In [2]:
# Root of the local NSD copy. Set the NSD_PATH environment variable to point
# elsewhere; the original machine-specific location is kept as the default.
nsd_path = Path(os.environ.get('NSD_PATH', 'D:\\Datasets\\NSD'))
nsd = NaturalScenesDataset(nsd_path)

In [None]:
# Select one subject and grab its trial-level data handles.
subject_name = 'subj01'
subject = nsd.subjects[subject_name]

responses = subject['responses']
betas_h5 = subject['betas']
# 73KID values are shifted down by one (presumably 1-based -> 0-based stimulus index).
stimulus_ids = np.array(responses['73KID']) - 1

# First element returned by get_split for 'split-01', used as the training-trial mask.
train_mask = nsd.get_split(subject_name, 'split-01')[0]

In [None]:
# Stimuli presented at least `n` times, as a (num_stimuli, n) array of trial
# indices: row k holds the first n trial positions (ascending) of the k-th such
# stimulus id, with rows ordered by ascending stimulus id.
n = 3
unique_ids, unique_counts = np.unique(stimulus_ids, return_counts=True)
has_n_repeats = unique_counts >= n
atleast_n_ids = unique_ids[has_n_repeats]

# A stable argsort groups equal ids together (in ascending id order, matching
# `unique_ids`) while keeping each group's trial indices ascending. So the first
# n entries of each group equal np.where(stimulus_ids == i)[0][:n], computed with
# one sort instead of a full scan per id. An empty selection gives shape (0, n)
# instead of np.stack raising on an empty list.
trial_order = np.argsort(stimulus_ids, kind='stable')
group_starts = np.cumsum(unique_counts) - unique_counts
repetition_ids = trial_order[group_starts[has_n_repeats][:, None] + np.arange(n)]
repetition_ids.shape

In [3]:
from tqdm.notebook import tqdm
from research.metrics.metrics import compute_ncsnr_fast, compute_nc


def require_dataset(group, name, data):
    """Create (or reuse) dataset ``name`` in ``group`` and overwrite it with ``data``.

    Parameters
    ----------
    group : h5py.Group or h5py.File
        Container to write into; ``name`` may be a nested path such as ``'a/b/c'``.
    name : str
        Dataset path relative to ``group``.
    data : numpy.ndarray
        Values to store; its shape and dtype define the dataset.

    Returns
    -------
    The dataset object returned by ``group.require_dataset``.

    Per the h5py documentation, ``require_dataset`` raises ``TypeError`` if a
    dataset already exists under ``name`` with an incompatible shape or dtype.
    """
    dataset = group.require_dataset(name, shape=data.shape, dtype=data.dtype)
    # `[...]` (unlike `[:]`) also works for 0-d / scalar datasets.
    dataset[...] = data
    return dataset

with h5py.File(nsd.dataset_path / 'derivatives/noise-ceiling.hdf5', 'a') as f:
    for subject_id in range(1, 8):
        subject_name = f'subj0{subject_id + 1}'
        print(subject_name)
        
        subject = nsd.subjects[subject_name]
        train_mask = nsd.get_split(subject_name, 'split-01')[0]

        betas_h5 = subject['betas']
        responses = subject['responses']
        stimulus_ids = np.array(responses['73KID']) - 1
        stimulus_ids = stimulus_ids[train_mask]

        n = 3
        unique_ids, unique_counts = np.unique(stimulus_ids, return_counts=True)
        atleast_n_ids = unique_ids[unique_counts >= n]
        repetition_ids = np.stack([
            np.where(stimulus_ids == i)[0][:n]
            for i in atleast_n_ids
        ])

        num_betas, num_voxels = betas_h5['betas'].shape
        voxel_batch_size = 10000
        indices_batches = np.array_split(np.arange(num_voxels), num_voxels // voxel_batch_size)
        ncsnr = []

        for betas_indices in indices_batches:
            print(f'{betas_indices[-1]}/{num_voxels}, {betas_indices[-1] / num_voxels * 100:.1f}%')
            betas = nsd.load_betas(subject_name, betas_indices=betas_indices, return_tensor_dataset=False)[0]
            betas = betas[train_mask]
            ncsnr.append(compute_ncsnr_fast(betas, repetition_ids))
        ncsnr = np.concatenate(ncsnr)

        nc = compute_nc(ncsnr, num_averages=1)

        voxel_selection_path = 'derivatives/voxel-selection.hdf5'
        voxel_selection_key = 'nc/value'

        voxel_selection_file = h5py.File(nsd.dataset_path / voxel_selection_path, 'r')
        key = f'{subject_name}/{voxel_selection_key}'
        nc_original = voxel_selection_file[key][:]
        
        nc = nc.reshape(nc_original.shape)
        nc[np.isnan(nc)] = 0.
        grid = np.argwhere(np.ones_like(nc, dtype=bool))
        nc_sorted_indices_flat = nc.argsort(axis=None)[::-1].astype(int)
        nc_sorted_indices = grid[nc_sorted_indices_flat].astype(int)
        
        require_dataset(f, f'{subject_name}/split-01/value', nc)
        require_dataset(f, f'{subject_name}/split-01/sorted_indices_flat', nc_sorted_indices_flat)
        require_dataset(f, f'{subject_name}/split-01/sorted_indices', nc_sorted_indices)
        
        

subj02
10001/730128, 1.4%


  ncsnr = std_signal / std_noise


20003/730128, 2.7%
30005/730128, 4.1%
40007/730128, 5.5%
50009/730128, 6.8%
60011/730128, 8.2%
70013/730128, 9.6%
80015/730128, 11.0%
90017/730128, 12.3%
100019/730128, 13.7%
110021/730128, 15.1%
120023/730128, 16.4%
130025/730128, 17.8%
140027/730128, 19.2%
150029/730128, 20.5%
160031/730128, 21.9%
170033/730128, 23.3%
180035/730128, 24.7%
190037/730128, 26.0%
200039/730128, 27.4%
210041/730128, 28.8%
220043/730128, 30.1%
230045/730128, 31.5%
240047/730128, 32.9%
250049/730128, 34.2%
260051/730128, 35.6%
270053/730128, 37.0%
280055/730128, 38.4%
290057/730128, 39.7%
300059/730128, 41.1%
310061/730128, 42.5%
320063/730128, 43.8%
330065/730128, 45.2%
340067/730128, 46.6%
350069/730128, 47.9%
360071/730128, 49.3%
370073/730128, 50.7%
380075/730128, 52.1%
390077/730128, 53.4%
400079/730128, 54.8%
410081/730128, 56.2%
420083/730128, 57.5%
430085/730128, 58.9%
440087/730128, 60.3%
450089/730128, 61.6%
460091/730128, 63.0%
470093/730128, 64.4%
480095/730128, 65.8%
490097/730128, 67.1%
500099

  nc = 100. * ncsnr_squared / (ncsnr_squared + (1. / num_averages))


subj03
10057/704052, 1.4%
20115/704052, 2.9%
30173/704052, 4.3%
40231/704052, 5.7%
50289/704052, 7.1%
60347/704052, 8.6%
70405/704052, 10.0%
80463/704052, 11.4%
90521/704052, 12.9%
100579/704052, 14.3%
110637/704052, 15.7%
120695/704052, 17.1%
130753/704052, 18.6%
140811/704052, 20.0%
150869/704052, 21.4%
160927/704052, 22.9%
170985/704052, 24.3%
181043/704052, 25.7%
191101/704052, 27.1%
201159/704052, 28.6%
211217/704052, 30.0%
221275/704052, 31.4%
231333/704052, 32.9%
241391/704052, 34.3%
251449/704052, 35.7%
261507/704052, 37.1%
271565/704052, 38.6%
281623/704052, 40.0%
291681/704052, 41.4%
301739/704052, 42.9%
311797/704052, 44.3%
321855/704052, 45.7%
331913/704052, 47.1%
341971/704052, 48.6%
352029/704052, 50.0%
362087/704052, 51.4%
372145/704052, 52.9%
382203/704052, 54.3%
392261/704052, 55.7%
402319/704052, 57.1%
412377/704052, 58.6%
422435/704052, 60.0%
432493/704052, 61.4%
442551/704052, 62.9%
452609/704052, 64.3%
462667/704052, 65.7%
472725/704052, 67.1%
482783/704052, 68.6%


In [None]:
# Shape of the most recently computed noise-ceiling array (rich display of the last expression).
nc.shape

In [None]:
# Voxels ordered by descending noise ceiling, both as flat indices and as
# per-axis coordinates.
nc_sorted_indices_flat = np.argsort(nc, axis=None)[::-1].astype(int)
# Every coordinate of `nc` in C (row-major) order, one row per voxel.
grid = np.indices(nc.shape, dtype=np.intp).reshape(nc.ndim, -1).T
nc_sorted_indices = grid[nc_sorted_indices_flat].astype(int)

In [None]:
# Only a cell's last bare expression is displayed, so both shapes go in one tuple.
nc_sorted_indices_flat.shape, nc_sorted_indices.shape

In [None]:
# Shape of the stored noise-ceiling volume.
np.shape(nc_original)

In [None]:
from tqdm.notebook import tqdm
from research.metrics.metrics import compute_ncsnr, compute_nc

voxel_selection_path = 'derivatives/voxel-selection.hdf5'
voxel_selection_key = 'nc/value'
# Stored-nc cutoff: voxels below it are zeroed in `nc_original`, and the same
# value is passed to load_betas as its selection threshold.
threshold = 10.

# Read the stored noise-ceiling volume and close the file right away.
with h5py.File(nsd.dataset_path / voxel_selection_path, 'r') as voxel_selection_file:
    key = f'{subject_name}/{voxel_selection_key}'
    nc_original = voxel_selection_file[key][:]
nc_original[nc_original < threshold] = 0.

betas, volume_indices = nsd.load_betas(
    subject_name,
    voxel_selection_path=voxel_selection_path,
    voxel_selection_key=voxel_selection_key,
    threshold=threshold,
    return_tensor_dataset=False,
    return_volume_indices=True,
    session_normalize=True,
    scale_betas=True,
)


In [None]:
from research.metrics.metrics import compute_ncsnr, compute_nc
# Noise ceiling of the voxel selection loaded above, using every trial.
# Repetition indices are rebuilt here for `subject_name` over all trials. The
# notebook-level `repetition_ids` may have been overwritten by the split-01 loop
# (train-mask-relative indices of its last subject), and those would not line up
# with the trials returned by load_betas. load_betas presumably returns all
# trials: the split-01 loop indexes its output with train_mask.
num_repeats = 3
all_stimulus_ids = np.array(nsd.subjects[subject_name]['responses']['73KID']) - 1
_, all_counts = np.unique(all_stimulus_ids, return_counts=True)
# A stable sort keeps each stimulus' trials in ascending order, so each row holds
# its first `num_repeats` presentations, with rows in ascending stimulus-id order.
all_trial_order = np.argsort(all_stimulus_ids, kind='stable')
all_group_starts = np.cumsum(all_counts) - all_counts
repetition_ids_all = all_trial_order[
    all_group_starts[all_counts >= num_repeats][:, None] + np.arange(num_repeats)
]

ncsnr = compute_ncsnr(betas, repetition_ids_all)
nc = compute_nc(ncsnr, num_averages=1)

# Map the per-selected-voxel values back into the subject's volume.
nc_volume = nsd.reconstruct_volume(
    subject_name,
    torch.from_numpy(nc),
    volume_indices
)

In [None]:
D = nc_original.shape[2]
# Recomputed minus stored noise ceiling, in volume space. `nc` from the previous
# cell holds one value per selected voxel (it is passed to reconstruct_volume with
# volume_indices), so it cannot be subtracted from the volume directly; the
# reconstructed `nc_volume` is used instead.
# NOTE(review): assumes nc_volume is a CPU array/tensor with nc_original's shape — confirm.
volume = np.asarray(nc_volume) - nc_original

@interact(d=(0, D-1), original=True)
def show(d, original):
    volume = nc_original if original else nc
    
    plt.figure(figsize=(12, 12))
    plt.imshow(volume.T[:, :, d], cmap='jet', vmin=0, vmax=75)