In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys

import numpy as np
import torch
import torchio as tio
import h5py
from ipywidgets import interact
import matplotlib.pyplot as plt
from pathlib import Path

# Make the directory two levels above the current working directory importable
# (`dir2` is the parent of the CWD, `dir1` is its parent).
dir2 = os.path.abspath('..')
dir1 = os.path.dirname(dir2)
if dir1 not in sys.path:  # guard keeps the cell idempotent when it is re-run
    sys.path.append(dir1)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Root of the NSD dataset. Set the NSD_DATASET_PATH environment variable to point
# elsewhere instead of editing this cell; the fallback is the original local path.
dataset_path = Path(os.environ.get('NSD_DATASET_PATH', 'D:\\Datasets\\NSD'))
stimulus_path = dataset_path / 'nsddata_stimuli' / 'stimuli' / 'nsd' / 'nsd_stimuli.hdf5'
stimulu_path = stimulus_path  # alias kept for the previous (misspelled) variable name
# The HDF5 file is intentionally left open: `imgBrick` is an h5py dataset that later
# cells index one stimulus at a time (`stimulus_images[stimulus_id]`).
stimulus_images = h5py.File(stimulus_path, 'r')['imgBrick']

In [None]:
# Load a CLIP model and keep only its image encoder as `model`.
import clip

model_name = 'ViT-B/32'
print(clip.available_models())
model, preprocess = clip.load(model_name, device=device)
# Replace the full CLIP model with its `.visual` tower: only images are embedded below.
model = model.visual

In [None]:
# Load a torchvision model
import torchvision.models as models
from torchvision import transforms as T

model_name = 'vgg19_bn'
# NOTE: newer torchvision versions deprecate `pretrained=True` in favour of
# `weights=models.VGG19_BN_Weights.IMAGENET1K_V1`; kept for compatibility with older ones.
model = models.vgg19_bn(pretrained=True)
# eval() is required for feature extraction: in train mode every forward pass would
# update BatchNorm running statistics and Dropout would randomize classifier outputs.
model = model.to(device).eval()
# ImageNet mean/std normalization used by torchvision's pretrained classifiers.
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
preprocess = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor(), normalize])

In [None]:
# Every submodule name of `model`; these are the names later cells look up in `modules`.
{name: submodule for name, submodule in model.named_modules()}.keys()

In [None]:
# Feature visualizer
from PIL import Image
from functools import partial

def vis_features(x):
    """Print a summary of a captured module output and, for 4-D tensors, browse it.

    Non-tensor inputs only have their type printed. Tensors are detached, converted to
    float32 on the CPU, and their shape and dtype printed. 4-D tensors additionally get
    their mean/std printed and an interactive viewer with sliders for the batch index
    ``i`` and the channel ``c``; the layout is assumed to be (N, C, H, W), the PyTorch
    convention for conv feature maps.

    Args:
        x: the value to inspect, typically a tensor captured by a forward hook.

    Returns:
        None; everything is printed or plotted.
    """
    if not isinstance(x, torch.Tensor):
        print(type(x))
        return
    # detach() so tensors that require grad can still be plotted (imshow converts to numpy).
    x = x.detach().float().cpu()
    print(x.shape, x.dtype)
    if x.dim() != 4:
        return
    # Only batch and channel counts are needed for the sliders; the last two dims are (H, W).
    n_items, n_channels = x.shape[:2]

    print(x.mean(), x.std())

    @interact(i=(0, n_items - 1), c=(0, n_channels - 1))
    def plot_feature_map(i, c):
        fig, ax = plt.subplots(figsize=(8, 8))
        image = ax.imshow(x[i, c], cmap="gray")
        fig.colorbar(image, ax=ax)
        ax.set_title(f"item {i}, channel {c}")
        plt.show()
        plt.close(fig)


modules = dict(model.named_modules())
N = stimulus_images.shape[0]

# Browse in inference mode so BatchNorm running stats are not updated by every
# interaction and Dropout does not randomize the displayed features.
model.eval()

# Feed inputs in the dtype of the model's first conv/linear weight instead of always
# float16, which only matches fp16 models (float32 models reject half inputs).
input_dtype = next(
    (m.weight.dtype for m in model.modules() if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear))),
    torch.float32,
)


# A slider instead of `range(N)`: a Dropdown would materialize one option per stimulus.
@interact(module_name=modules.keys(), stimulus_id=(0, N - 1))
def select_module(module_name, stimulus_id=0):
    """Run stimulus `stimulus_id` through `model` and visualize `module_name`'s output."""
    image = Image.fromarray(stimulus_images[stimulus_id])
    x = preprocess(image).unsqueeze(0).to(device=device, dtype=input_dtype)

    features = {}

    def forward_hook(name, module, x_in, x_out):
        # Some modules return tuples rather than a tensor; only tensors are cloned.
        features[name] = x_out.clone() if isinstance(x_out, torch.Tensor) else x_out

    hook_handle = modules[module_name].register_forward_hook(partial(forward_hook, module_name))
    try:
        with torch.no_grad():
            model(x)
    finally:
        # Always detach the hook; otherwise each widget change leaves another hook behind.
        hook_handle.remove()

    if module_name not in features:
        print(f"{module_name!r} was not called during the forward pass")
        return
    vis_features(features[module_name])

In [None]:
# Label VGG nodes: map each `features.<i>` module to `layer{k}.{conv|bn|relu}.{n}`,
# where a "layer" is the run of modules between max-pools and `n` counts modules of
# that kind within the layer (counters restart after every pool).

def label_vgg_feature_modules(net):
    """Return ``{'features.<i>': 'layer<k>.<kind>.<n>'}`` for a torchvision VGG model.

    Node names come from ``net.named_modules()``, so this cell does not depend on a
    `nodes` variable from elsewhere in the notebook. Only direct children of
    ``features`` are labelled; MaxPool2d modules get no label, they only advance the
    layer counter and reset the per-kind counts.

    Raises:
        KeyError: if a ``features`` child is not Conv2d, BatchNorm2d, ReLU or MaxPool2d.
    """
    short_names = {'Conv2d': 'conv', 'BatchNorm2d': 'bn', 'ReLU': 'relu', 'MaxPool2d': 'pool'}
    labels = {}
    layer = 1
    counts = {'conv': 0, 'bn': 0, 'relu': 0}
    for name, module in net.named_modules():
        # Keep only 'features.<i>'; this also skips the 'features' container itself.
        if not name.startswith('features.') or name.count('.') != 1:
            continue
        kind = short_names[type(module).__name__]
        if kind == 'pool':
            layer += 1
            counts = dict.fromkeys(counts, 0)
            continue
        counts[kind] += 1
        labels[name] = f'layer{layer}.{kind}.{counts[kind]}'
    return labels


out = label_vgg_feature_modules(model)
out

In [None]:
# NOTE(review): `nodes` is not defined anywhere in this notebook; presumably it held FX
# graph node names (e.g. from torchvision's `get_graph_node_names`) — define it first.
# Show the node names ending in 'add'.
list(filter(lambda name: name.endswith('add'), nodes))

In [None]:
# Node names to save, expressed as (layer, block) pairs for the `add` nodes plus two
# `attnpool` nodes. NOTE(review): these look like FX node names for a ResNet-style
# CLIP image encoder; the next cell overwrites `save_nodes` — confirm which is intended.
save_nodes = [
    *(f'layer{layer}.{block}.add' for layer, block in [(1, 2), (2, 3), (3, 5), (4, 2)]),
    'attnpool.getitem_6',
    'attnpool.getitem_8',
]

In [None]:
# VGG modules to save, as {module_name: feature_name}. For each listed layer, the conv
# at `features.<start>` and the bn/relu right after it are saved under the
# `layer{k}.{kind}.{n}` labelling scheme; two classifier modules keep their own names.
save_nodes = {
    **{
        f'features.{start + offset}': f'{layer}.{kind}.{nth}'
        for start, layer, nth in [(10, 'layer2', 2), (23, 'layer3', 4), (36, 'layer4', 4), (49, 'layer5', 4)]
        for offset, kind in enumerate(['conv', 'bn', 'relu'])
    },
    'classifier.0': 'classifier.0',
    'classifier.3': 'classifier.3',
}

In [None]:
# Save every transformer residual block under its own module name, plus the model's
# final output ('' is the root module in `named_modules()`) as 'embedding'.
# NOTE(review): range(12) assumes a 12-block encoder such as ViT-B/32 — adjust for others.
save_modules = {f'transformer.resblocks.{i}': f'transformer.resblocks.{i}' for i in range(12)}
save_modules[''] = 'embedding'

In [None]:
# Save only the model's final output: '' is the root module in `named_modules()`.
save_modules = dict.fromkeys([''], 'embedding')

In [None]:
# Show which `model_name` is active; the extraction cell builds its output file name from it.
model_name

In [None]:
from functools import partial
from tqdm.notebook import tqdm
from PIL import Image
from functools import partial
from typing import Sequence, Dict

from collections.abc import Mapping

derivatives_path = dataset_path / 'derivatives' / 'stimulus_embeddings'
# h5py does not create missing parent directories, so make sure the output folder exists.
derivatives_path.mkdir(parents=True, exist_ok=True)
modules = dict(model.named_modules())

# Features must be extracted in inference mode: in train mode BatchNorm layers use
# (and update) per-batch statistics and Dropout zeroes random activations.
model.eval()

# Feed inputs in the dtype of the model's first conv/linear weight instead of always
# float16, which only matches fp16 models (float32 models reject half inputs).
input_dtype = next(
    (m.weight.dtype for m in model.modules() if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear))),
    torch.float32,
)

# `save_modules` is either a {module_name: feature_name} mapping or an iterable of
# module names, each saved under its own name.
if isinstance(save_modules, Mapping):
    module_to_feature = dict(save_modules)
else:
    module_to_feature = {name: name for name in save_modules}

# Filled by the hooks during each forward pass; cleared before every stimulus.
features = {}


def forward_hook(feature_name, module, x_in, x_out):
    """Store a CPU numpy copy of ``x_out`` under ``feature_name``.

    A leading axis of size 1 (the batch axis for batch-first outputs) is dropped.
    NOTE(review): for outputs that are not batch-first the size-1 batch axis is not the
    leading one and is kept — confirm the layout of each saved module.
    """
    if x_out.shape[0] == 1:
        x_out = x_out[0]
    features[feature_name] = x_out.clone().cpu().numpy()


# Register the hooks once for the whole run and always remove them afterwards, even if
# the loop is interrupted; otherwise stale hooks accumulate on `model` across re-runs.
hook_handles = []
try:
    for module_name, feature_name in module_to_feature.items():
        hook = partial(forward_hook, feature_name)
        hook_handles.append(modules[module_name].register_forward_hook(hook))

    output_file = derivatives_path / f"{model_name.replace('/', '=')}-embeddings.hdf5"
    with h5py.File(output_file, "a") as f:
        N = stimulus_images.shape[0]

        for stimulus_id in tqdm(range(N)):
            image = Image.fromarray(stimulus_images[stimulus_id])
            x = preprocess(image).unsqueeze(0).to(device=device, dtype=input_dtype)

            features.clear()
            with torch.no_grad():
                model(x)

            # One dataset per feature, shaped (N, *feature.shape); created on first use.
            for feature_name, feature in features.items():
                dataset = f.require_dataset(feature_name, (N, *feature.shape), feature.dtype)
                dataset[stimulus_id] = feature
finally:
    for hook_handle in hook_handles:
        hook_handle.remove()
