In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

import numpy as np
import torch
import torch.nn.functional as F
import h5py
from ipywidgets import interact
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.notebook import tqdm
from PIL import Image
from functools import partial
import math
from einops import rearrange

dir2 = os.path.abspath('../..')
dir1 = os.path.dirname(dir2)
if not dir1 in sys.path: 
    sys.path.append(dir1)
    
from research.data.natural_scenes import (
    NaturalScenesDataset,
    StimulusDataset,
    KeyDataset
)
from research.experiments.nsd.nsd_access import NSDAccess


In [31]:
device = "cuda" if torch.cuda.is_available() else "cpu"

nsd_path = Path('D:\\Datasets\\NSD')
stimulu_path = nsd_path / 'nsddata_stimuli' / 'stimuli' / 'nsd' / 'nsd_stimuli.hdf5'
stimulus_images = h5py.File(stimulu_path, 'r')['imgBrick']

In [81]:
# Load a clip model
import clip

print(clip.available_models())
model_name = 'ViT-L/14'
full_model, preprocess = clip.load(model_name, device=device)
#model = full_model.visual


['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']


In [3]:
from transformers import CLIPTokenizer, CLIPModel, CLIPTextModel, CLIPVisionModel, CLIPFeatureExtractor, CLIPProcessor

model_name = "clip-vit-large-patch14"

model = CLIPModel.from_pretrained(f'openai/{model_name}')
vision_model = CLIPVisionModel.from_pretrained(f'openai/{model_name}')
processor = CLIPProcessor.from_pretrained(f'openai/{model_name}')
tokenizer = CLIPTokenizer.from_pretrained(f'openai/{model_name}')
transformer = CLIPTextModel.from_pretrained(f'openai/{model_name}')

def preprocess(img=None, text=None):
    return processor(images=img, text=text, return_tensors="pt")

Downloading config.json:   0%|          | 0.00/4.41k [00:00<?, ?B/s]

Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.5.mlp.fc1.weight', 'vision_model.encoder.layers.15.self_attn.k_proj.weight', 'vision_model.encoder.layers.0.layer_norm1.bias', 'vision_model.encoder.layers.13.self_attn.v_proj.bias', 'vision_model.encoder.layers.4.self_attn.v_proj.bias', 'vision_model.encoder.layers.18.self_attn.v_proj.bias', 'vision_model.encoder.layers.6.layer_norm2.bias', 'vision_model.encoder.layers.0.self_attn.q_proj.weight', 'vision_model.encoder.layers.7.self_attn.v_proj.weight', 'vision_model.encoder.layers.17.mlp.fc1.weight', 'vision_model.encoder.layers.20.self_attn.q_proj.bias', 'vision_model.encoder.layers.15.layer_norm1.bias', 'vision_model.encoder.layers.16.layer_norm1.bias', 'vision_model.encoder.layers.18.layer_norm2.weight', 'vision_model.encoder.layers.13.mlp.fc2.weight', 'vision_model.encoder.layers.20.layer_norm1.weight', 'vision_model.encoder.layers.7.se

In [24]:
with torch.no_grad():
    out = model(**preprocess(Image.fromarray(stimulus_images[0]), 'asdf'))

In [29]:
print(out.keys())
out.text_embeds.shape

odict_keys(['logits_per_image', 'logits_per_text', 'text_embeds', 'image_embeds', 'text_model_output', 'vision_model_output'])


torch.Size([1, 768])

In [21]:
vision_out.pooler_output.shape

torch.Size([1, 1024])

In [70]:
text_out = outputs.last_hidden_state

In [None]:
model.modules()

In [None]:
# Laod a torchvision model
import torchvision.models as models
from torchvision import transforms as T

model_name = 'vgg19_bn'
model = models.vgg19_bn(pretrained=True)
model.to(device)
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
preprocess = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor(), normalize])

In [None]:
# Load depth model
model_name = "DPT_Large" # "DPT_Large" or "DPT_Hybrid"
model = torch.hub.load("intel-isl/MiDaS", model_name)
device = torch.device('cuda')
model.to(device)
model.eval()
preprocess = torch.hub.load('intel-isl/MiDaS', 'transforms').dpt_transform

In [None]:
dict(model.named_modules()).keys()

In [None]:
torch.__version__

In [21]:
# Feature visualizer


def vis_features(x):
    if not isinstance(x, torch.Tensor):
        print(type(x))
        return
    x = x.float().cpu()
    print(x.shape, x.dtype)

    if len(x.shape) == 3:
        d = int(math.sqrt(x.shape[0] - 1))
        x = rearrange(x[:-1, 0], '(h w) c -> c h w', h=d, w=d)[None]
        
    if len(x.shape) != 4:
        return
    N, C, W, H = x.shape
    
    print(x.mean(), x.std())

    @interact(i=(0, N-1), c=(0, C-1))
    def plot_feature_map(i, c):
        fig = plt.figure(figsize=(8, 8))
        plt.imshow(x[i, c].cpu(), cmap="gray")
        plt.colorbar()
        plt.show()
        plt.close(fig)


modules = dict(model.named_modules())
#print([(node, modules[node].__class__.__name__) for node in nodes if node in modules])
N = stimulus_images.shape[0]
@interact(module_name=modules.keys(), stimulus_id=range(N))
def select_module(module_name, stimulus_id):
    image = stimulus_images[stimulus_id]
    print(image.min(), image.max())
    image = Image.fromarray(image)
    x = preprocess(image).unsqueeze(0).to(device).to(torch.float16)

    #x = preprocess(image).to(device)
    print(x.shape)
    print(x.mean().item(), x.min().item(), x.max().item(), x.std().item())
    
    features = {}
    def forward_hook(module_name, module, x_in, x_out):
        features[module_name] = x_out.clone()
    
    module = modules[module_name]
    hook_handle = module.register_forward_hook(partial(forward_hook, module_name))
    
    with torch.no_grad():
        model(x)
    
    vis_features(features[module_name])

0 255
torch.Size([1, 3, 224, 224])
0.88134765625 -1.7919921875 2.146484375 0.5576171875
torch.Size([1, 768]) torch.float32


interactive(children=(Dropdown(description='module_name', options=('', 'conv1', 'ln_pre', 'transformer', 'tran…

In [41]:
out2 = model2(preprocess2(image).unsqueeze(0).to(device).to(torch.float16)).detach().cpu()

In [73]:
out2,

torch.Size([1, 768])

In [79]:
((text_out / np.linalg.norm(text_out, axis=-1)[..., None]) * (out2 / np.linalg.norm(out2))

tensor([[-0.0260,  0.1188,  0.0790,  0.0441,  0.0510,  0.0940,  0.0188,  0.0230,
          0.0601,  0.0030,  0.0375,  0.0423,  0.0513,  0.0343,  0.0125,  0.0704,
          0.0553,  0.0814,  0.0734,  0.0708,  0.0690,  0.0682,  0.0676,  0.0670,
          0.0668,  0.0663,  0.0660,  0.0658,  0.0652,  0.0650,  0.0645,  0.0641,
          0.0636,  0.0631,  0.0629,  0.0626,  0.0623,  0.0621,  0.0619,  0.0619,
          0.0619,  0.0618,  0.0621,  0.0622,  0.0624,  0.0623,  0.0624,  0.0624,
          0.0623,  0.0623,  0.0620,  0.0623,  0.0622,  0.0624,  0.0623,  0.0622,
          0.0623,  0.0621,  0.0620,  0.0620,  0.0619,  0.0622,  0.0620,  0.0617,
          0.0618,  0.0614,  0.0616,  0.0613,  0.0612,  0.0610,  0.0604,  0.0598,
          0.0605,  0.0598,  0.0604,  0.0605,  0.0607]])

In [40]:
image = stimulus_images[1]
image = Image.fromarray(image)
x = preprocess(image)
with torch.no_grad():
    output = model.get_image_features(**x)

In [45]:
np.linalg.norm(out2)

19.47

In [47]:
((output / np.linalg.norm(output)) * (out2 / np.linalg.norm(out2))).sum()

tensor(1.0002)

In [None]:
# Label vgg nodes

layer = 1
counts = {'conv':0, 'bn': 0, 'relu': 0}
out = {}
for node in nodes:
    if not node.startswith('features'):
        continue
    num = int(node.split('.')[1])
    
    module = modules[node]
    module_name = module.__class__.__name__
    short_module_name = {'Conv2d': 'conv', 'BatchNorm2d': 'bn', 'ReLU': 'relu', 'MaxPool2d': 'pool'}[module_name]
    if short_module_name == "pool":
        layer += 1
        counts = {k: 0 for k in counts.keys()}
        continue
    counts[short_module_name] += 1
    
    out[node] = f'layer{layer}.{short_module_name}.{counts[short_module_name]}'
out

In [None]:
[node for node in nodes if node.endswith('add')]

In [None]:
save_nodes = [
    'layer1.2.add',
    'layer2.3.add',
    'layer3.5.add',
    'layer4.2.add',
    'attnpool.getitem_6',
    'attnpool.getitem_8',
]

In [None]:
save_nodes = {
    'features.10': 'layer2.conv.2',
    'features.11': 'layer2.bn.2',
    'features.12': 'layer2.relu.2',
    'features.23': 'layer3.conv.4',
    'features.24': 'layer3.bn.4',
    'features.25': 'layer3.relu.4',
    'features.36': 'layer4.conv.4',
    'features.37': 'layer4.bn.4',
    'features.38': 'layer4.relu.4',
    'features.49': 'layer5.conv.4',
    'features.50': 'layer5.bn.4',
    'features.51': 'layer5.relu.4',
    'classifier.0': 'classifier.0',
    'classifier.3': 'classifier.3',
}

In [None]:
save_modules = {
    **{f'transformer.resblocks.{i}': f'transformer.resblocks.{i}' for i in range(12)},
    '': 'embedding'
}

In [None]:
save_modules = {
    '': 'embedding'
}

In [None]:
save_modules = ['scratch.refinenet4', 'scratch.layer3_rn']

In [None]:
model_name

In [None]:
from functools import partial
#from tqdm.notebook import tqdm
from PIL import Image
from functools import partial
from typing import Sequence, Dict

derivatives_path = dataset_path / 'derivatives' / 'stimulus_embeddings'
modules = dict(model.named_modules())

with h5py.File(derivatives_path / f"{model_name.replace('/', '=').replace('@', '-')}.hdf5", "a") as f:
    N = stimulus_images.shape[0]
    for stimulus_id in tqdm(range(N)):
        image_data = stimulus_images[stimulus_id]
        if model_name.startswith('DPT'):
            x = preprocess(image_data).to(device)
        else:
            image = Image.fromarray(image_data)
            x = preprocess(image).unsqueeze(0).to(device).to(torch.float16)
        features = {}
        def forward_hook(module_name, module, x_in, x_out):
            if x_out.shape[0] == 1:
                x_out = x_out[0]
            features[module_name] = x_out.clone().cpu().float().numpy()
        hook_handles = []
        if isinstance(save_modules, Sequence):
            for module_name in save_modules:
                module = modules[module_name]
                hook_handle = module.register_forward_hook(partial(forward_hook, module_name))
                hook_handles.append(hook_handle)
        elif isinstance(save_modules, Dict):
            for module_name, feature_name in save_modules.items():
                module = modules[module_name]
                hook_handle = module.register_forward_hook(partial(forward_hook, feature_name))
                hook_handles.append(hook_handle)
        with torch.no_grad():
            model(x)
        for hook_handle in hook_handles:
            hook_handle.remove()
        for feature_name, feature in features.items():
            f.require_dataset(feature_name, (N, *feature.shape), feature.dtype)
            f[feature_name][stimulus_id] = feature
            
            
            
            

In [None]:
from functools import partial
#from tqdm.notebook import tqdm
from PIL import Image
from functools import partial
from typing import Sequence, Dict

derivatives_path = dataset_path / 'derivatives' / 'stimulus_embeddings'
modules = dict(model.named_modules())

with h5py.File(derivatives_path / f"{model_name.replace('/', '=').replace('@', '-')}.hdf5", "a") as f:
    N = stimulus_images.shape[0]
    for stimulus_id in tqdm(range(N)):
        image_data = stimulus_images[stimulus_id]
        if model_name.startswith('DPT'):
            x = preprocess(image_data).to(device)
        else:
            image = Image.fromarray(image_data)
            x = preprocess(image)#.unsqueeze(0).to(device).to(torch.float16)
        features = {}
        def forward_hook(module_name, module, x_in, x_out):
            if x_out.shape[0] == 1:
                x_out = x_out[0]
            features[module_name] = x_out.clone().cpu().float().numpy()
        hook_handles = []
        if isinstance(save_modules, Sequence):
            for module_name in save_modules:
                module = modules[module_name]
                hook_handle = module.register_forward_hook(partial(forward_hook, module_name))
                hook_handles.append(hook_handle)
        elif isinstance(save_modules, Dict):
            for module_name, feature_name in save_modules.items():
                module = modules[module_name]
                hook_handle = module.register_forward_hook(partial(forward_hook, feature_name))
                hook_handles.append(hook_handle)
        with torch.no_grad():
            model(x)
        for hook_handle in hook_handles:
            hook_handle.remove()
        for feature_name, feature in features.items():
            f.require_dataset(feature_name, (N, *feature.shape), feature.dtype)
            f[feature_name][stimulus_id] = feature

In [40]:
nsd_path = Path('D:\\Datasets\\NSD\\')
nsd = NaturalScenesDataset(nsd_path, coco_path='X:\\Datasets\\COCO')

In [84]:
with torch.no_grad():
    batch_encoding = tokenizer(text, truncation=True, max_length=77, return_length=True,
                               return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
    tokens = batch_encoding["input_ids"]
    outputs = transformer(input_ids=tokens)

In [89]:
outputs.last_hidden_state.shape

torch.Size([1, 77, 768])

In [None]:
N = 73000

for i in tqdm(range(N)):
    text = nsd.load_coco(i)
    if len(text) != 5:
        print(i, len(text))

In [None]:

N = 73000
E = 768
derivatives_path = dataset_path / 'derivatives' / 'stimulus_embeddings'
with h5py.File(derivatives_path / f"{model_name.replace('/', '=').replace('@', '-')}-text.hdf5", "w") as f:
    
    f.require_dataset('embedding', shape=(N, 5, E), dtype='float32')
    f.require_dataset('embedding_mean', shape=(N, E), dtype='float32')
    
    for i in tqdm(range(N)):
        text = nsd.load_coco(i)[:5]
        tokens = clip.tokenize(text).cuda()
        with torch.no_grad():
            embedding = full_model.encode_text(tokens).float()
        
        embedding_mean = F.normalize(embedding.float().mean(dim=0), dim=0)
        
        f['embedding'][i] = embedding.cpu()
        f['embedding_mean'][i] = embedding_mean.cpu()
        

In [None]:
i = 108
text = nsd.load_coco(i)[:5]
for t in text:
    print(t)
tokens = clip.tokenize(text).cuda()
print(tokens.shape)
with torch.no_grad():
    embedding = full_model.encode_text(tokens)

In [41]:
def get_best_captions():
    
    model_name = 'ViT-B=32'
    stimulus_key = 'embedding'

    save_key = stimulus_key
    save_model_name = model_name

    stimulus_file = h5py.File(nsd_path / f'derivatives/stimulus_embeddings/{model_name}.hdf5', 'r')
    x = stimulus_file[stimulus_key][:]

    stimulus_file_text = h5py.File(nsd_path / f'derivatives/stimulus_embeddings/{model_name}-text.hdf5', 'r')
    x_text = stimulus_file_text[stimulus_key][:]
    x_text = x_text / np.linalg.norm(x_text, axis=-1, keepdims=True)

    ids = np.stack([np.arange(73000) for _ in range(5)], axis=-1)
    print(ids.shape)

    #random_ids = np.arange(73000)
    #np.random.shuffle(random_ids)
    #print(random_ids)
    #x_text = x_text[random_ids]

    text_dists = np.einsum('ni,nti->nt', x, x_text)
    print(text_dists)

    all_captions = np.array([nsd.load_coco(i)[:5] for i in tqdm(range(73000))])
    best_captions = all_captions[np.arange(73000), np.argmax(text_dists, axis=1)]
    return best_captions

best_captions = get_best_captions()

(73000, 5)
[[0.3166824  0.29903737 0.2705831  0.26404428 0.3248881 ]
 [0.3135147  0.29028302 0.30885544 0.29989666 0.2975764 ]
 [0.35390684 0.2882279  0.3490643  0.35402402 0.32551354]
 ...
 [0.28977567 0.2876259  0.29651064 0.30269492 0.31399542]
 [0.31555313 0.31223205 0.30614698 0.3007291  0.26928315]
 [0.29240453 0.28594497 0.32062352 0.29156315 0.2844131 ]]


  0%|          | 0/73000 [00:00<?, ?it/s]

In [104]:
transformer = transformer.to(device)

In [105]:
from tqdm.notebook import tqdm

N = 73000
derivatives_path = dataset_path / 'derivatives' / 'stimulus_embeddings'
with h5py.File(derivatives_path / f"{model_name.replace('/', '=').replace('@', '-')}-text.hdf5", "w") as f:
    
    f.require_dataset('embedding_unpooled', shape=(N, 77, 768), dtype='float32')
    
    for i in range(N):
        if i % 100 == 0:
            print(i)
        text = best_captions[i]
        with torch.no_grad():
            batch_encoding = tokenizer(text, truncation=True, max_length=77, return_length=True,
                                       return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
            tokens = batch_encoding["input_ids"].to(device)
            outputs = transformer(input_ids=tokens)

        f['embedding_unpooled'][i] = outputs.last_hidden_state.cpu()

  0%|          | 0/73000 [00:00<?, ?it/s]

0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000
8100
8200
8300
8400
8500
8600
8700
8800
8900
9000
9100
9200
9300
9400
9500
9600
9700
9800
9900
10000
10100
10200
10300
10400
10500
10600
10700
10800
10900
11000
11100
11200
11300
11400
11500
11600
11700
11800
11900
12000
12100
12200
12300
12400
12500
12600
12700
12800
12900
13000
13100
13200
13300
13400
13500
13600
13700
13800
13900
14000
14100
14200
14300
14400
14500
14600
14700
14800
14900
15000
15100
15200
15300
15400
15500
15600
15700
15800
15900
16000
16100
16200
16300
16400
16500
16600
16700
16800
16900
17000
17100
17200
17300
17400
17500
17600
17700
17800
17900
18000
18100
18200
18300
18400
18

In [49]:
from tqdm.notebook import tqdm

N = 73000
derivatives_path = dataset_path / 'derivatives' / 'stimulus_embeddings'
model = model.to(device)

with h5py.File(derivatives_path / f"{model_name.replace('/', '=').replace('@', '-')}.hdf5", "w") as f:
    f.require_dataset('text_embedding', shape=(N, 768), dtype='float32')
    f.require_dataset('image_embedding', shape=(N, 768), dtype='float32')
    
    for i in range(N):
        if i % 100 == 0:
            print(i)
        text = best_captions[i]
        with torch.no_grad():
            inputs = preprocess(Image.fromarray(stimulus_images[i]), best_captions[i])
            inputs = {k: v.cuda() for k, v in inputs.items()}
            out = model(**inputs)

        f['text_embedding'][i] = out.text_embeds[0].cpu()
        f['image_embedding'][i] = out.image_embeds[0].cpu()

0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000
8100
8200
8300
8400
8500
8600
8700
8800
8900
9000
9100
9200
9300
9400
9500
9600
9700
9800
9900
10000
10100
10200
10300
10400
10500
10600
10700
10800
10900
11000
11100
11200
11300
11400
11500
11600
11700
11800
11900
12000
12100
12200
12300
12400
12500
12600
12700
12800
12900
13000
13100
13200
13300
13400
13500
13600
13700
13800
13900
14000
14100
14200
14300
14400
14500
14600
14700
14800
14900
15000
15100
15200
15300
15400
15500
15600
15700
15800
15900
16000
16100
16200
16300
16400
16500
16600
16700
16800
16900
17000
17100
17200
17300
17400
17500
17600
17700
17800
17900
18000
18100
18200
18300
18400
18

In [46]:
inputs = preprocess(Image.fromarray(stimulus_images[i]), best_captions[i])
inputs = {k: v.cuda() for k, v in inputs.items()}
inputs

{'input_ids': tensor([[49406,   320,  1929, 11308,   525,   320, 10176,  1972,   593,   320,
           2533,   269, 49407]], device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0'),
 'pixel_values': tensor([[[[ 0.2077,  0.1785,  0.0909,  ...,  0.1493,  0.1639,  0.1639],
           [ 0.0325,  0.0617, -0.0405,  ...,  0.2661,  0.2369,  0.2369],
           [ 0.1639,  0.1639,  0.1055,  ...,  0.3245,  0.2953,  0.2953],
           ...,
           [-0.0259,  0.0325, -0.0405,  ...,  0.5581,  0.5727,  0.5727],
           [ 0.3683,  0.3829,  0.3829,  ...,  0.5435,  0.4997,  0.5143],
           [ 0.5435,  0.5289,  0.5435,  ...,  0.5581,  0.5581,  0.5289]],
 
          [[ 0.3490,  0.3190,  0.2740,  ...,  0.3490,  0.3640,  0.3490],
           [ 0.1689,  0.1989,  0.1389,  ...,  0.4691,  0.4390,  0.4240],
           [ 0.3190,  0.3190,  0.2890,  ...,  0.5141,  0.4841,  0.4841],
           ...,
           [ 0.1089,  0.1389,  0.0488,  ...,  0.6642,  0.6792