In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

import numpy as np
import torch
import torchio as tio
import h5py
from ipywidgets import interact
import matplotlib.pyplot as plt
from pathlib import Path

# Make the directory three levels above the notebook importable:
# dir2 is two levels up, dir1 is its parent.
dir2 = os.path.abspath(os.path.join('..', '..'))
dir1 = os.path.dirname(dir2)
if dir1 not in sys.path:
    sys.path.append(dir1)

C:\Users\Cefir\anaconda3\envs\Neurophysiological-Data-Decoding\lib\site-packages\numpy\.libs\libopenblas.WCDJNK7YVMPZQ2ME2ZZHJJRJ3JIKNDB7.gfortran-win_amd64.dll
C:\Users\Cefir\anaconda3\envs\Neurophysiological-Data-Decoding\lib\site-packages\numpy\.libs\libopenblas.XWYDX2IKJW2NMTWSFYNGFUWKQU3LYTCZ.gfortran-win_amd64.dll


In [11]:
# Dataset locations. The defaults are the original author's drive layout; set the
# TC2SEE_DATASET_PATH / TC2SEE_SSD_DATASET_PATH environment variables to use other locations.
dataset_path = Path(os.environ.get('TC2SEE_DATASET_PATH', 'X:\\Datasets\\2021 TC2See fMRI Data\\'))
project_path = dataset_path / 'project'
derivatives_path = dataset_path / 'derivatives'

# Second copy of the dataset (named "ssd" — presumably a faster local disk; confirm).
ssd_dataset_path = Path(os.environ.get('TC2SEE_SSD_DATASET_PATH', 'C:\\Datasets\\2021 TC2See fMRI Data\\'))
ssd_derivatives_path = ssd_dataset_path / 'derivatives'

In [None]:
# Create h5 file for stimulus images
from PIL import Image

stimulus_images_path = Path('G:\\Github Repositories\\bird_data\\docs\\cropped')

# Write every stimulus image into one HDF5 file: one group per image stem holding the
# pixel array under 'data' plus class_id / bird_id attributes.
with h5py.File(derivatives_path / 'stimulus-images.hdf5', 'w') as f:
    # Sorted so the processing/print order is reproducible (iterdir order is filesystem-dependent).
    for image_file_path in sorted(stimulus_images_path.iterdir()):
        if not image_file_path.is_file():
            continue  # skip sub-directories and other non-file entries
        stimulus_name = image_file_path.stem

        # Assumes stems look like "<int class id>.<bird name><1 separator char><1 digit id>"
        # (inferred from the slicing below) — TODO confirm against the image folder.
        class_id, image_id = stimulus_name.split('.')
        class_id = int(class_id)

        bird_name = image_id[:-2]
        bird_id = int(image_id[-1])

        with Image.open(image_file_path) as image:
            data = np.array(image)
        f[f'{stimulus_name}/data'] = data
        f[stimulus_name].attrs['class_id'] = class_id
        f[stimulus_name].attrs['bird_id'] = bird_id

        # Previously printed the undefined name `bird_class_id`, which raised NameError.
        print(stimulus_name, class_id, bird_name, bird_id)

In [5]:
# Run on the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Opened read-only and intentionally left open: later cells read images from this handle.
stimulus_images = h5py.File(derivatives_path / 'stimulus-images.hdf5', 'r')

In [6]:
# Load a CLIP model
import clip

model_name = 'ViT-B/32'
print(clip.available_models())

# clip.load returns (model, preprocess); `preprocess` is used later to turn PIL images
# into model input. Only the `.visual` sub-module is kept as `model`.
model, preprocess = clip.load(model_name, device=device)
model = model.visual

['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']


In [7]:
# Feature visualizer

from PIL import Image
from functools import partial

def vis_features(x):
    """Print summary information about a captured activation and plot 4-D ones.

    Args:
        x: A captured module output. For a non-tensor, only its type is printed.
            A tensor is converted to float32 on the CPU, and its shape and dtype are
            printed. Only 4-D tensors get their mean/std printed and an interactive
            viewer with one slider for the first axis (sample) and one for the
            second (channel).

    Returns:
        None.
    """
    if not isinstance(x, torch.Tensor):
        print(type(x))
        return
    x = x.float().cpu()
    print(x.shape, x.dtype)
    if x.dim() != 4:
        return
    # Presumably (N, C, H, W), PyTorch's conv layout; the spatial sizes are not needed here.
    N, C = x.shape[:2]

    print(x.mean(), x.std())

    @interact(i=(0, N - 1), c=(0, C - 1))
    def plot_feature_map(i, c):
        fig, ax = plt.subplots(figsize=(8, 8))
        image = ax.imshow(x[i, c], cmap="gray")  # x is already on the CPU
        ax.set_title(f"sample {i}, channel {c}")
        fig.colorbar(image, ax=ax)
        plt.show()
        plt.close(fig)


modules = dict(model.named_modules())
@interact(module_name=modules.keys(), stimulus_id=list(stimulus_images.keys()))
def select_module(module_name, stimulus_id):
    """Run one stimulus image through `model` and visualize the selected module's output.

    A forward hook is attached to the chosen module for a single forward pass and is
    always removed afterwards (even if the pass raises). Repeated interactions therefore
    do not pile up hooks on the model.

    Args:
        module_name: Key into `modules` ('' selects the whole model).
        stimulus_id: Group name in the `stimulus_images` HDF5 file.
    """
    image_data = stimulus_images[stimulus_id]['data'][:]
    image = Image.fromarray(image_data)
    x = preprocess(image).unsqueeze(0).to(device).to(torch.float16)

    features = {}
    def forward_hook(module_name, module, x_in, x_out):
        # Some modules may return non-tensor outputs (e.g. a tuple). Store those as-is so
        # vis_features can report their type instead of this hook failing on .clone().
        features[module_name] = x_out.clone() if isinstance(x_out, torch.Tensor) else x_out

    module = modules[module_name]
    hook_handle = module.register_forward_hook(partial(forward_hook, module_name))
    try:
        with torch.no_grad():
            model(x)
    finally:
        # Without this, every slider/dropdown change left another hook registered.
        hook_handle.remove()

    vis_features(features[module_name])

interactive(children=(Dropdown(description='module_name', options=('', 'conv1', 'ln_pre', 'transformer', 'tran…

In [9]:
# Define features to save

# module name -> name the saved feature gets; '' is the root module returned by named_modules().
save_modules = dict([('', 'embedding')])

In [13]:
# Overrides the dict above: save the outputs of transformer.resblocks.0 … transformer.resblocks.11.
save_modules = ['transformer.resblocks.{}'.format(block_index) for block_index in range(12)]

In [15]:
# Feature extraction

from functools import partial
from tqdm.notebook import tqdm
from PIL import Image
from functools import partial
from typing import Sequence, Dict

from collections.abc import Mapping, Sequence

modules = dict(model.named_modules())

# Normalise `save_modules` into {module_name: feature_name}. A sequence of module names
# saves each output under the module's own name; a mapping renames them. A bare string
# is also a Sequence, so it is rejected explicitly rather than iterated per character.
if isinstance(save_modules, Mapping):
    feature_names = dict(save_modules)
elif isinstance(save_modules, Sequence) and not isinstance(save_modules, str):
    feature_names = {module_name: module_name for module_name in save_modules}
else:
    raise TypeError(
        f'save_modules must be a sequence of module names or a mapping, got {type(save_modules)!r}'
    )

features = {}
def forward_hook(feature_name, module, x_in, x_out):
    features[feature_name] = x_out.clone()

# Register the hooks once for the whole run and always remove them, even if a forward
# pass or an HDF5 write fails part-way through.
hook_handles = [
    modules[module_name].register_forward_hook(partial(forward_hook, feature_name))
    for module_name, feature_name in feature_names.items()
]
try:
    with h5py.File(derivatives_path / f"{model_name.replace('/', '=')}-features.hdf5", "a") as f:
        for stimulus_id, stimulus_image in tqdm(stimulus_images.items()):
            image_data = stimulus_image['data'][:]
            image = Image.fromarray(image_data)
            x = preprocess(image).unsqueeze(0).to(device).to(torch.float16)

            features.clear()
            with torch.no_grad():
                model(x)

            stimulus = f.require_group(stimulus_id)
            for feature_name, feature in features.items():
                # Drop a leading singleton axis (the batch of one built above).
                # NOTE(review): if a module returns sequence-first output, the batch axis is
                # not dim 0 and is kept as-is — confirm the saved layout.
                if feature.dim() > 0 and feature.shape[0] == 1:
                    feature = feature[0]
                array = feature.cpu().numpy()
                if feature_name in stimulus and stimulus[feature_name].shape != array.shape:
                    # Existing dataset has a different shape (e.g. written by another
                    # model); in-place assignment would fail, so recreate it.
                    del stimulus[feature_name]
                if feature_name in stimulus:
                    stimulus[feature_name][...] = array
                else:
                    stimulus[feature_name] = array
finally:
    for hook_handle in hook_handles:
        hook_handle.remove()

  0%|          | 0/300 [00:00<?, ?it/s]