In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
import sys
import time
import random
import json
import gc

import PIL
from PIL import Image
import pylustrator
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import h5py
from ipywidgets import interact
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.notebook import tqdm
import nibabel as nib
from einops import rearrange
from scipy import ndimage, stats
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from torchvision import transforms


# Walk up from the working directory: dir2 is two levels up, dir1 three levels up.
# dir1 is appended to the import path (only once) — presumably the repository root that
# holds the `research` / `pipeline` packages imported just below.
dir1 = os.path.dirname(dir2 := os.path.abspath('../..'))
if dir1 not in sys.path:
    sys.path.append(dir1)
    
from research.data.natural_scenes import NaturalScenesDataset
from research.experiments.nsd.nsd_access import NSDAccess
from research.metrics.metrics import cosine_distance, top_knn_test, r2_score, pearsonr
from pipeline.utils import get_data_iterator, DisablePrints

In [None]:
# Dataset locations. Override with the NSD_PATH / COCO_PATH environment variables;
# the defaults are the original author's local drives.
nsd_path = Path(os.environ.get('NSD_PATH', 'D:\\Datasets\\NSD\\'))
coco_path = os.environ.get('COCO_PATH', 'X:\\Datasets\\COCO')
nsd = NaturalScenesDataset(nsd_path, coco_path=coco_path)
stimuli_path = nsd_path / 'nsddata_stimuli' / 'stimuli' / 'nsd' / 'nsd_stimuli.hdf5'
# The HDF5 file is intentionally left open: later cells index `stimulus_images[...]` lazily.
stimulus_images = h5py.File(stimuli_path, 'r')['imgBrick']

In [None]:
# '/' of the CLIP model name appears as '=' here, presumably so it can be used in file paths.
model_name = 'ViT-B=32' #'clip-vit-large-patch14'
group_name = 'group-10'

subjects = [f'subj{i:02d}' for i in range(1, 9)]
embedding_name = 'embedding'
fold_name = 'val'

# Decoder outputs for this group; kept open because later cells read groups from it lazily.
embeddings = h5py.File(nsd_path / f'derivatives/decoded_features/{model_name}/{group_name}.hdf5', 'r')

results_path = nsd_path / f'derivatives/figures/decoding/{model_name}/{group_name}/{fold_name}/{embedding_name}/'
results_path.mkdir(exist_ok=True, parents=True)

# Stimulus embeddings (rows are samples, columns embedding dimensions) read fully into
# memory; the file handle is closed once the array has been copied out.
with h5py.File(nsd_path / f'derivatives/stimulus_embeddings/{model_name}.hdf5', 'r') as stimulus_embeddings_file:
    Y_full = stimulus_embeddings_file[embedding_name][:]

In [None]:
# Load a CLIP model (its visual encoder is kept as `perceptor`)
import clip

# Fall back to CPU so the notebook still runs on machines without a CUDA device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(clip.available_models())
clip_model_name = 'ViT-B/32'
full_model, preprocess = clip.load(clip_model_name, device=device)
perceptor = full_model.visual

In [None]:
# Combine evaluation summaries into one folder
# Copies each group's evaluation_summary.png to decoding_evaluation_summary/<group>.png.
# The loop variables deliberately do not reuse the names `group_name` / `results_path`:
# doing so would overwrite the notebook's configuration and make every later cell read
# and write the last group's folders instead of the group that was loaded.

import shutil

groups = ['group-4', 'group-10', 'group-11', 'group-12', 'group-13', 'group-14']

summary_dir = nsd_path / f'derivatives/figures/decoding/{model_name}/decoding_evaluation_summary'
summary_dir.mkdir(exist_ok=True, parents=True)

for summary_group in groups:
    group_results_path = nsd_path / f'derivatives/figures/decoding/{model_name}/{summary_group}/{fold_name}/{embedding_name}/'
    summary_png = group_results_path / 'decoding_evaluation/evaluation_summary.png'
    if not summary_png.exists():
        # Not generated yet for this group (e.g. on a fresh run): report it and move on.
        print(f'missing: {summary_png}')
        continue
    shutil.copyfile(summary_png, summary_dir / f'{summary_group}.png')
    

In [None]:
# Load encoder data
# For every subject and for the 'val' and 'test' folds, collect the target stimulus
# embeddings (Y), the decoded embeddings (Y_pred, L2-normalised per row) and the stimulus
# ids; when `load_X` is True also load the betas and rebuild the trained Decoder.

from research.models.fmri_decoders import Decoder

num_voxels = None

results = {}
folds = {}
for fold in ('val', 'test'):
    fold_data = {
        'X_all': [],
        'Y_all': [],
        'Y_pred_all': [],
        'stimulus_ids_all': [],
    }
    folds[fold] = fold_data
models_all = []
state_dicts_all = []
indices_all = []

# Set to False to skip loading betas and decoder weights; targets and predictions are still loaded.
load_X = True

for i, subject in enumerate(subjects):
    print(subject)

    train_mask, val_mask, test_mask = nsd.get_split(subject, 'split-01')

    # Read unconditionally: the decoded predictions below live in this group even when load_X is False.
    subject_embeddings = embeddings[f'{subject}/{embedding_name}']

    if load_X:
        # Decoder hyper-parameters and betas-loading options are stored as group attributes.
        config = dict(subject_embeddings.attrs)

        model_params = {k: config[k] for k in ('layer_sizes', 'dropout_p')}
        model = Decoder(**model_params)
        model = model.eval()
        state_dict = {k: torch.from_numpy(v[:]) for k, v in subject_embeddings['model'].items()}
        state_dicts_all.append(state_dict)
        model.load_state_dict({k: v.clone() for k, v in state_dict.items()})
        models_all.append(model)

        betas_params = {
            k: config[k]
            for k in (
                'subject_name', 'voxel_selection_path',
                'voxel_selection_key', 'num_voxels', 'return_volume_indices', 'threshold'
            )
        }
        if betas_params['threshold'] is not None:
            # Presumably threshold-based voxel selection replaces the fixed voxel count.
            betas_params['num_voxels'] = None
            betas_params['return_tensor_dataset'] = False
        betas, betas_indices = nsd.load_betas(**betas_params)
        folds['val']['X_all'].append(betas[val_mask])
        folds['test']['X_all'].append(betas[test_mask])
        indices_all.append(betas_indices)

    stimulus_params = dict(
        subject_name=subject,
        stimulus_path=f'derivatives/stimulus_embeddings/{model_name}.hdf5',
        stimulus_key=embedding_name,
        delay_loading=False,
        return_tensor_dataset=False,
        return_stimulus_ids=True,
    )
    stimulus, stimulus_ids = nsd.load_stimulus(**stimulus_params)
    for fold, mask in [('val', val_mask), ('test', test_mask)]:

        folds[fold]['stimulus_ids_all'].append(stimulus_ids[mask])

        # Targets flattened to (trials, features)
        Y = stimulus[mask].astype(np.float32)
        Y = Y.reshape(Y.shape[0], -1)
        folds[fold]['Y_all'].append(Y)

        # Predictions are L2-normalised per row; the targets above are used as stored.
        # NOTE(review): later cosine-distance cells use plain dot products, which assumes
        # Y is unit-norm as well — confirm for the stored stimulus embeddings.
        Y_pred = subject_embeddings[f'{fold}/Y_pred'][:]
        Y_pred = Y_pred / np.linalg.norm(Y_pred, axis=1)[:, None]
        folds[fold]["Y_pred_all"].append(Y_pred)

# Expose the lists of the selected fold (`fold_name`) as notebook-level variables.
X_all = folds[fold_name]['X_all']
Y_all = folds[fold_name]['Y_all']
Y_pred_all = folds[fold_name]['Y_pred_all']
stimulus_ids_all = folds[fold_name]['stimulus_ids_all']

In [None]:
# Debug: list the members of the last loaded subject's decoder-output HDF5 group
list(subject_embeddings.items())

In [None]:
# Quick look at the per-subject top-k accuracies. `top_knn_accuracy` is only created by a
# later cell, so on a fresh kernel this shows None instead of raising NameError.
globals().get('top_knn_accuracy')

# Basic Evaluation

In [None]:
# Top k accuracy figure
# For each fold, identify every decoded trial among the fold's unique target stimuli and
# plot per-subject top-k accuracy against chance (k / N); one saved figure per fold.

# NOTE(review): N is hard-coded to 1000 for the chance line and title; confirm it matches
# len(unique_stimulus_ids) in each fold.
N = 1000
top_k_values = [1, 5, 10, 50, 100, 500]
chance_accuracy = [k / N for k in top_k_values]
# NOTE(review): overwritten by the loop below, which also ends on 'test'.
fold = 'test'
metric = 'cosine'

for fold in ('val', 'test'):
    top_knn_accuracy = {}
    for subject_id, subject in enumerate(subjects):
        
        stimulus_ids = folds[fold]['stimulus_ids_all'][subject_id]
        Y = folds[fold]['Y_all'][subject_id]
        Y_pred = folds[fold]['Y_pred_all'][subject_id]
        
        # Candidates are the unique targets; unique_inverse[p] is the candidate index of
        # trial p's true stimulus (repeated presentations share one candidate).
        unique_stimulus_ids, unique_index, unique_inverse = np.unique(
            stimulus_ids, return_index=True, return_inverse=True)
        
        # Presumably one accuracy per entry of top_k_values (plotted against them below).
        top_knn_accuracy[subject] = top_knn_test(
            Y[unique_index], Y_pred, unique_inverse, k=top_k_values, metric=metric)

    plt.figure(figsize=(12, 8))
    # k values are drawn at evenly spaced positions, not on a numeric k axis
    plt.xticks(ticks=range(len(top_k_values)), labels=top_k_values)
    plt.title(f'Top knn accuracy (n={N})\n{model_name=}, {embedding_name=}, {group_name=}')
    plt.xlabel('k')
    plt.ylabel('accuracy')
    plt.plot(range(len(top_k_values)), chance_accuracy, label='chance (k/n)', color='gray')
    for subject, subject_results in top_knn_accuracy.items():
        plt.plot(range(len(top_k_values)), subject_results, label=subject)

    plt.grid()
    plt.legend()
    file_name = 'top_k_accuracy.png'
    out_path = nsd_path / f'derivatives/figures/decoding/{model_name}/{group_name}/{fold}/{embedding_name}/decoding_evaluation'
    out_path.mkdir(exist_ok=True, parents=True)
    plt.savefig(out_path / file_name, pad_inches=0)
    plt.show()

In [None]:
# Print the standard deviation of every stimulus-embedding dimension, smallest first
Y_std = Y_full.std(axis=0)
Y_std_argsort_ids = Y_std.argsort()
print('dim, std')
for dim in Y_std_argsort_ids:
    std = Y_std[dim]
    print(f'{dim}, {std:.5f}')

In [None]:
# R^2 histograms, Pearson r histograms, cosine-distance histograms and distance-rank
# histograms, one panel per subject. Fills r2_all, r_all, cosine_dist_all,
# distance_rank_all and top_knn_all, which later cells read.
# NOTE: `metric` and `top_k_values` come from the top-k accuracy cell.

n_subjects = len(subjects)

out_path = results_path / 'decoding_evaluation'
out_path.mkdir(exist_ok=True, parents=True)


def _subject_hist_axes():
    """Create a row of shared-axis panels, one per subject, and return ``(fig, axes)``."""
    fig, ax = plt.subplots(nrows=1, ncols=n_subjects, figsize=(2 * n_subjects, 3),
                           sharex=True, sharey=True, squeeze=False)
    fig.tight_layout()
    fig.subplots_adjust(bottom=0.2, top=0.9, left=0.07)
    # squeeze=False keeps `ax` 2-D, so the returned row is indexable even for one subject
    return fig, ax[0]


def _save_subject_hist(fig, out_dir, file_name, title, xlabel, ylabel):
    """Add figure-level title and labels, save to ``out_dir / file_name`` and show the figure."""
    fig.suptitle(title)
    fig.supxlabel(xlabel)
    fig.supylabel(ylabel)
    fig.savefig(out_dir / file_name, pad_inches=0)
    plt.show()


r2_all = []
r_all = []
cosine_dist_all = []
distance_rank_all = []
top_knn_all = []

# R^2 per embedding dimension
fig, ax = _subject_hist_axes()
for i, subject in enumerate(subjects):
    print(subject)
    print('dim, r2')
    r2 = r2_score(torch.from_numpy(Y_all[i]), torch.from_numpy(Y_pred_all[i]), reduction=None)
    r2_argsort_ids = np.argsort(r2)
    # The 10 dimensions with the lowest R^2
    for dim, dim_r2 in zip(r2_argsort_ids[:10], r2[r2_argsort_ids[:10]]):
        print(f'{round(dim.item(), 3)}, {round(dim_r2.item(), 3)}')
    r2_all.append(r2)
    # Zero out strongly negative R^2 on a copy (r2_all keeps the raw values) to keep the plot range readable
    r2 = torch.clone(r2)
    r2[r2 < -0.25] = 0
    ax[i].hist(r2)
    ax[i].set_xlim(-0.1, 0.7)
    ax[i].set_xticks([0.0, 0.2, 0.4, 0.6])
_save_subject_hist(
    fig, out_path, 'variance_explained_histogram.png',
    f'Histogram of R^2 for Brain-Decoded Embeddings, {n_subjects} Participants',
    'R^2', 'Number of Dimensions',
)

# Pearson r per embedding dimension
fig, ax = _subject_hist_axes()
for i, subject in enumerate(subjects):
    r = pearsonr(torch.from_numpy(Y_all[i]), torch.from_numpy(Y_pred_all[i]), reduction=None)
    r_all.append(r)
    ax[i].hist(r)
_save_subject_hist(
    fig, out_path, 'pearsonr_histogram.png',
    f'Histogram of pearsonr for Brain-Decoded Embeddings, {n_subjects} Participants',
    'pearsonr', '# Dimensions',
)

# Cosine distance per trial (1 - dot product; a cosine distance when both rows are unit-norm)
fig, ax = _subject_hist_axes()
for i, subject in enumerate(subjects):
    cosine_dist = 1. - (Y_all[i] * Y_pred_all[i]).sum(axis=1)
    cosine_dist_all.append(cosine_dist)
    ax[i].hist(cosine_dist)
_save_subject_hist(
    fig, out_path, 'cosine_distance_histogram.png',
    f'Histogram of Cosine Distance for Brain-Decoded Embeddings, {n_subjects} Participants',
    'Cosine Distance', '# Stimuli',
)

# Rank of the true stimulus among all unique candidates, per trial
fig, ax = _subject_hist_axes()
for i, subject in enumerate(subjects):
    Y = Y_all[i]
    Y_pred = Y_pred_all[i]
    stimulus_ids = stimulus_ids_all[i]
    unique_stimulus_ids, unique_index, unique_inverse = np.unique(
        stimulus_ids, return_index=True, return_inverse=True
    )
    # Y_dists[c, p]: distance between unique candidate c and prediction p
    Y_dists = 1. - Y[unique_index] @ Y_pred.T
    # Column p lists candidate indices from nearest to farthest
    Y_dist_ids = Y_dists.argsort(axis=0)
    # Row p of Y_ks is (p, rank of the true candidate for prediction p)
    Y_ks = np.argwhere((Y_dist_ids == unique_inverse[None, :]).T)

    distance_rank_all.append(Y_ks)
    top_knn_all.append(top_knn_test(
            Y[unique_index], Y_pred, unique_inverse, k=top_k_values, metric=metric))
    ax[i].hist(Y_ks[:, 1])
_save_subject_hist(
    fig, out_path, 'distance_ranking_histogram.png',
    f'Histogram of Distance Classification Rankings for Brain-Decoded Embeddings, {n_subjects} Participants',
    'Classification Rank', '# Stimuli',
)

In [None]:
# Plot lines
# For each subject, plot target vs. decoded values over the first trials for up to 14
# embedding dimensions chosen by R^2 rank: the 5 lowest, 4 around the median, the 5 highest.
# The output directory is built explicitly rather than taken from the global `out_path`,
# whose value depends on which cell ran last.
dimension_plots_path = results_path / 'decoding_evaluation' / 'dimension_plots'

for subject_id, subject in enumerate(subjects):
    subject_path = dimension_plots_path / subject
    subject_path.mkdir(exist_ok=True, parents=True)

    r = r_all[subject_id]
    r2 = r2_all[subject_id]
    r2_argsort = np.argsort(r2)
    num_dims = r.shape[0]
    dim_ids = np.arange(num_dims)

    # Positions in the ascending R^2 ordering that get a plot
    dim_ids = np.concatenate([dim_ids[:5], dim_ids[num_dims // 2 - 2: num_dims // 2 + 2], dim_ids[-5:]])
    dims = r2_argsort[dim_ids]

    # Trial window on the x-axis, clamped in case a subject has fewer than 100 trials
    dim_start = 0
    dim_end = min(100, Y_all[subject_id].shape[0])
    trial_axis = np.arange(dim_start, dim_end)

    for i, dim in enumerate(dims):
        fig, ax = plt.subplots()
        ax.plot(trial_axis, Y_all[subject_id][dim_start:dim_end, dim], label='target')
        ax.plot(trial_axis, Y_pred_all[subject_id][dim_start:dim_end, dim], label='prediction')
        ax.set_title(f'{subject=}, r2_rank={dim_ids[i]}, dim={int(dim)}, r2={r2[dim]:.2f}, r={r[dim]:.2f}')
        ax.legend()
        file_name = f'rank-{dim_ids[i]}_dim-{dim}.png'
        fig.savefig(subject_path / file_name, pad_inches=0)
        plt.close(fig)
        
    

In [None]:
from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw 

# For each decoded trial, save a strip of candidate stimuli ranked by distance to the
# prediction: the 3 nearest candidates, then a 7-wide window around the true stimulus's
# rank. Under each image: its rank (green = true stimulus, red = ranked before it,
# blue = ranked after it) and its distance d.

num_checks = 50  # number of trials rendered per subject

font = ImageFont.truetype('arial.ttf', 64)
# Text tiles match the stimulus width so they line up under the image strip.
tile_width = stimulus_images.shape[2]

for subject_id, subject in enumerate(subjects):
    Y = Y_all[subject_id]
    Y_pred = Y_pred_all[subject_id]
    stimulus_ids = stimulus_ids_all[subject_id]
    unique_stimulus_ids, unique_index, unique_inverse = np.unique(
        stimulus_ids, return_index=True, return_inverse=True
    )
    # Y_dists[c, p]: 1 - dot product between unique candidate c and prediction p
    Y_dists = 1. - Y[unique_index] @ Y_pred.T
    # Column p lists candidate indices from nearest to farthest
    Y_dist_ids = Y_dists.argsort(axis=0)
    # Row p of Y_ks is (p, rank of the true candidate for prediction p)
    Y_ks = np.argwhere((Y_dist_ids == unique_inverse[None, :]).T)
    Y_dist_stimulus_ids = unique_stimulus_ids[Y_dist_ids]
    num_candidates = len(unique_stimulus_ids)

    ranking_path = results_path / 'distance_ranking' / subject
    ranking_path.mkdir(exist_ok=True, parents=True)

    for i, (y_ks, y_dist_stimulus_ids) in enumerate(zip(Y_ks, Y_dist_stimulus_ids.T)):
        if i >= num_checks:
            break
        y_k = y_ks[1]
        # Sorted copy; sorting the column view in place would modify Y_dists
        y_dists = np.sort(Y_dists[:, i])

        # Clamp the window centre so ranks centre-3 .. centre+3 neither overlap the top 3
        # nor run past the last candidate (bounds 6 and 993 for 1000 candidates).
        centre = int(np.clip(y_k, 6, num_candidates - 7))
        show_ids = np.array([0, 1, 2, *range(centre - 3, centre + 4)])

        show_stimulus_ids = y_dist_stimulus_ids[show_ids]

        images = np.concatenate([stimulus_images[stimulus_id] for stimulus_id in show_stimulus_ids], axis=1)

        text_images = []
        for j in show_ids:
            img = Image.new('RGB', (tile_width, (64 + 25) * 2))
            draw = ImageDraw.Draw(img)
            fill = 'green'
            if j < y_k:
                fill = 'red'
            elif j > y_k:
                fill = 'blue'
            draw.multiline_text((tile_width // 2, 50), str(j), fill=fill, anchor='mm', font=font, align='center')
            draw.multiline_text((tile_width // 2, 64+50), f'd={y_dists[j]:.2f}', anchor='mm', font=font, align='center')
            text_images.append(img)
        text_images = np.concatenate(text_images, axis=1)
        images = np.concatenate([images, text_images], axis=0)

        file_name = f'stim-{y_dist_stimulus_ids[y_k]}_image-{i}.png'
        Image.fromarray(images).save(ranking_path / file_name)

In [None]:
# Debug: candidate stimulus ids ranked per prediction, for the last subject processed above
Y_dist_stimulus_ids = unique_stimulus_ids[Y_dist_ids]
Y_dist_stimulus_ids.shape

In [None]:
# Summary figure: per-subject mean R^2, mean Pearson r, mean cosine distance and top-1 accuracy.
n_subjects = len(subjects)
subject_positions = np.arange(n_subjects)

fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(8, 8), sharex=True)
plt.tight_layout()
fig.subplots_adjust(top=0.95, wspace=0.25)
plt.suptitle(f'Brain-Decoder Evaluation, {group_name=}')


ax[0, 0].bar(subject_positions, [r2.mean() for r2 in r2_all])
ax[0, 0].set_ylabel('R^2')
ax[0, 1].bar(subject_positions, [r.mean() for r in r_all])
ax[0, 1].set_ylabel('Pearson Correlation')
ax[1, 0].bar(subject_positions, [cosine_dist.mean() for cosine_dist in cosine_dist_all])
ax[1, 0].set_ylabel('Cosine Distance')
ax[1, 0].set_xlabel('Subject')
# First entry of each top-k result, i.e. k = 1
ax[1, 1].bar(subject_positions, [top_knn[0] for top_knn in top_knn_all])
ax[1, 1].set_ylabel('Top 1 Accuracy')
ax[1, 1].set_xlabel('Subject')

for panel in ax.flat:
    # 1-based subject numbers as tick labels
    panel.set_xticks(subject_positions, subject_positions + 1)

# Always write into decoding_evaluation/: the global `out_path` depends on whichever cell
# ran last, so relying on it can save the summary into an unrelated folder.
summary_path = results_path / 'decoding_evaluation'
summary_path.mkdir(exist_ok=True, parents=True)
file_name = 'evaluation_summary.png'
plt.savefig(summary_path / file_name, bbox_inches='tight')
plt.show()

# PCA

In [None]:
# Fit individual and group PCAs
# Y_pca: PCA of all stimulus embeddings; Y_pred_group_pca: PCA of every subject's decoded
# embeddings pooled; Y_pred_pcas[k]: PCA of subject k's decoded embeddings alone.

Y_pca = PCA().fit(Y_full)
Y_pred_group_pca = PCA().fit(np.concatenate(Y_pred_all))

Y_pred_pcas = []
for subject, subject_Y_pred in zip(subjects, Y_pred_all):
    print(subject)
    Y_pred_pcas.append(PCA().fit(subject_Y_pred))

In [None]:
# Compute variance explained of different PCAs
# r2_results[subject][pca_name][j] is the R^2 of reconstructing the subject's decoded
# embeddings from their first j+1 components, for the subject's own PCA ('subject'),
# the pooled decoded PCA ('group') and the stimulus-embedding PCA ('full').

r2_results = {}

for subject_id, subject in enumerate(subjects):
    Y_pred = Y_pred_all[subject_id]
    Y_pred_pca = Y_pred_pcas[subject_id]
    subject_results = {}
    r2_results[subject] = subject_results

    pcas = [('subject', Y_pred_pca), ('group', Y_pred_group_pca), ('full', Y_pca)]
    for pca_name, pca in pcas:
        print(subject, pca_name)
        scores = []
        subject_results[pca_name] = scores
        Y_pred_transformed = pca.transform(Y_pred)
        Y_pred_tensor = torch.from_numpy(Y_pred)
        for n_kept in range(1, Y_pred.shape[1] + 1):
            # Reconstruction from the first n_kept components, mapped back to embedding space
            Y_pred_inv = Y_pred_transformed[:, :n_kept] @ pca.components_[:n_kept] + pca.mean_
            scores.append(r2_score(Y_pred_tensor, torch.from_numpy(Y_pred_inv)).item())



In [None]:
# Plot variance explained of PCAs
# One panel per subject: reconstruction R^2 of the decoded embeddings against the number of
# kept components (log2 x-axis), one curve per PCA variant stored in r2_results.
# NOTE(review): ncols is hard-coded to 8 while figsize follows len(subjects).

fig, ax = plt.subplots(nrows=1, ncols=8, figsize=(2 * len(subjects), 3), 
                      sharex=True, sharey=True,)
fig.tight_layout()
fig.subplots_adjust(bottom=0.2, top=0.9, left=0.05)

for i, subject in enumerate(subjects):
    subject_results = r2_results[subject]
    ax[i].set_xscale('log', base=2)
    ticks = [1, 2, 8, 32, 128, 512]
    ax[i].set_xticks(ticks=ticks, labels=ticks)
    ax[i].grid(visible=True)
    #ax[i].set_xlim(1, 512)
    #ax[i].set_ylim(0., 1.)
    #ax[i].axhline(0.95, color='black', linestyle='dashed')
    #for t in ticks[1:-1]:
        #ax[i].axvline(t, color='gray', linestyle='dashed')
    for pca_name, r2 in subject_results.items():
        #r2_sum = [sum(r2[:i+1]) for i in range(Y_full.shape[1])]
        # Assumes the decoded and stimulus embeddings have the same dimensionality,
        # since each curve has one entry per decoded-embedding dimension.
        x = np.arange(Y_full.shape[1]) + 1
        ax[i].plot(x, r2, label=pca_name)

ax[0].legend()
fig.suptitle('Variance Explained in Brain-Decoded Embeddings by PCA Components, 8 Participants')
fig.supxlabel('Number of PCA Components')
fig.supylabel('Variance Explained')
file_name = 'pca_variance_explained.png'
plt.savefig(results_path / file_name, pad_inches=0)
plt.show()

In [None]:
# Similarity of PCA components (group vs individual)
# For each subject, an image of dot products between the subject's first `num_components`
# PCA components (rows) and the group PCA's (columns), on a fixed [-1, 1] colour scale.

num_components = 32

group_components = Y_pred_group_pca.components_[:num_components]
fig, ax = plt.subplots(nrows=2, ncols=4, figsize=(12, 6), )
ax = ax.flatten()
fig.tight_layout()
fig.subplots_adjust(top=0.95)

for i, subject in enumerate(subjects):
    subject_components = Y_pred_pcas[i].components_[:num_components]
    ticks = [2,  8, 16, 24, 32]
    # Shift by -1 so the tick labels are 1-based component numbers
    ax[i].set_xticks(np.array(ticks)-1, ticks)
    ax[i].imshow(subject_components @ group_components.T, vmin=-1, vmax=1, cmap='bwr')
#ax[0].legend()
fig.suptitle('RSM Between Subject and Group PCA Components')
#fig.supxlabel('Number of PCA Components')
#fig.supylabel('Variance Explained')
file_name = 'pca_component_rsm.png'
plt.savefig(results_path / file_name, pad_inches=0)
plt.show()

In [None]:
# Similarity of PCA components (group vs full)
# Dot products between every group-PCA component of the decoded embeddings (rows) and
# every PCA component of the stimulus embeddings (columns).

# NOTE(review): not used below; the full component matrices are compared.
num_components = 32
group_components = Y_pred_group_pca.components_
full_components = Y_pca.components_

fig = plt.figure(figsize=(32, 32))
# NOTE(review): called before anything is drawn, so it likely has no effect.
fig.tight_layout()
plt.imshow(group_components @ full_components.T, vmin=-1, vmax=1, cmap='bwr')

file_name = 'pca_component_rsm_full_vs_group.png'
plt.savefig(results_path / file_name, pad_inches=0)
plt.show()

In [None]:
# Variance explained of the decoding model in the PCA space
# For each subject, per-component R^2 between targets and predictions after projecting both
# onto the first `num_components` group-PCA components. The rows show the whole scores,
# then their positive and negative parts (each clipped at 0) separately.

n_subjects = len(subjects)
fig, ax = plt.subplots(nrows=3, ncols=n_subjects, figsize=(2 * n_subjects, 8),
                       sharex=True, sharey=True, squeeze=False)
fig.tight_layout()
fig.subplots_adjust(bottom=0.075, top=0.9,)

num_components = 32
component_positions = np.arange(num_components)
tick_labels = np.array([1, 8, 16, 24, 32])
for subject_id, subject in enumerate(subjects):
    print(subject)
    Y = Y_all[subject_id]
    Y_pred = Y_pred_all[subject_id]
    Y_pca_transformed = Y_pred_group_pca.transform(Y)[:, :num_components]
    Y_pred_pca_transformed = Y_pred_group_pca.transform(Y_pred)[:, :num_components]

    r2_whole = r2_score(
        torch.from_numpy(Y_pca_transformed),
        torch.from_numpy(Y_pred_pca_transformed),
        reduction=None)

    r2_positive = r2_score(
        torch.from_numpy(np.maximum(Y_pca_transformed, 0)),
        torch.from_numpy(np.maximum(Y_pred_pca_transformed, 0)),
        reduction=None)

    r2_negative = r2_score(
        torch.from_numpy(np.maximum(-Y_pca_transformed, 0)),
        torch.from_numpy(np.maximum(-Y_pred_pca_transformed, 0)),
        reduction=None)

    for row, row_r2 in enumerate((r2_whole, r2_positive, r2_negative)):
        ax[row, subject_id].set_ylim(-0.5, 0.8)
        # Bars sit at 0-based positions; shift the ticks so labels are 1-based component numbers
        ax[row, subject_id].set_xticks(tick_labels - 1, tick_labels)
        ax[row, subject_id].bar(x=component_positions, height=row_r2[:num_components], bottom=0, width=1.)

    ax[0, subject_id].set_title(subject)

ax[0, 0].set_ylabel('Whole PCA')
ax[1, 0].set_ylabel('Positive PCA')
ax[2, 0].set_ylabel('Negative PCA')
fig.suptitle('Variance Explained by Brain-Decoding Model in Group PCA Space')
fig.supxlabel('PCA Dimension')
file_name = 'decoding_variance_explained_group_pca.png'
plt.savefig(results_path / file_name, pad_inches=0)
plt.show()

In [None]:
# Top K accuracy in PCA space
# Top-1 identification accuracy when targets and predictions are both projected onto the
# first `num_components` group-PCA components, swept over every component count.

top_k_values = [1]
chance_accuracy = [k / N for k in top_k_values]
metric = 'cosine'

for fold in ('val', 'test'):
    results = {}
    for subject_id, subject in enumerate(subjects):
        print(subject)

        Y = folds[fold]['Y_all'][subject_id]
        Y_pred = folds[fold]['Y_pred_all'][subject_id]
        Y_pca_transformed = Y_pred_group_pca.transform(Y)
        Y_pred_pca_transformed = Y_pred_group_pca.transform(Y_pred)

        stimulus_ids = folds[fold]['stimulus_ids_all'][subject_id]

        unique_stimulus_ids, unique_index, unique_inverse = np.unique(
            stimulus_ids, return_index=True, return_inverse=True)
        # Projected candidates (one per unique stimulus), hoisted out of the sweep
        Y_candidates = Y_pca_transformed[unique_index]

        results[subject] = top_1_accuracy = []
        # Inclusive upper bound: the last point keeps all components
        for num_components in range(1, Y.shape[1] + 1):
            top_1_accuracy.append(top_knn_test(
                Y_candidates[:, :num_components],
                Y_pred_pca_transformed[:, :num_components],
                unique_inverse, k=top_k_values, metric=metric
            )[0])

    fig, ax = plt.subplots(figsize=(12, 6))

    for subject_name, top_1_accuracy in results.items():
        ax.plot(np.arange(1, len(top_1_accuracy) + 1), top_1_accuracy, label=subject_name)

    ax.legend(loc='lower right')
    ax.set_xticks(np.arange(0, Y_full.shape[1] + 1, 32))
    ax.grid(visible=True)
    ax.set_xlabel('num PCA components')
    ax.set_ylabel('top 1 accuracy')
    out_path = nsd_path / f'derivatives/figures/decoding/{model_name}/{group_name}/{fold}/{embedding_name}'
    out_path.mkdir(exist_ok=True, parents=True)
    file_name = 'pca_vs_top1.png'
    fig.savefig(out_path / file_name, pad_inches=0)
    plt.show()



In [None]:
# Display the current value of `results_path`
results_path

In [None]:
# Per-fold figure directory; `fold` is whatever value the most recent loop over folds left behind
out_path = nsd_path / f'derivatives/figures/decoding/{model_name}/{group_name}/{fold}/{embedding_name}'
out_path


In [None]:
#fig, ax = plt.subplots(nrows=1, ncols=8, figsize=(2 * len(subjects), 4), 
#                      sharex=True, sharey=True,)
#fig.tight_layout()
#fig.subplots_adjust(bottom=0.075, top=0.9,)



In [None]:
# Find PCA component images with CLIP retrieval
# For each of the first `num_components` group-PCA components, query the LAION index with
# the component ('positive') and its negation ('negative') as a CLIP embedding. Saves one
# image per sign, with one row of `num_images` retrieved images per component.

from sklearn.cluster import AgglomerativeClustering, OPTICS, KMeans
from hdbscan import HDBSCAN
from clip_retrieval.clip_client import ClipClient, Modality
from torchvision import transforms
from urllib.request import urlopen, HTTPError, URLError
from PIL import UnidentifiedImageError

indice_name = 'laion_400m'
client = ClipClient(
    url="https://knn5.laion.ai/knn-service", 
    indice_name=indice_name,
    aesthetic_score=9,
    aesthetic_weight=0.0,
    use_safety_model=False,
    use_violence_detector=True,
    num_images=500,
)
dims = 512

num_components = 32
group_components = Y_pred_group_pca.components_

# Every retrieved image becomes a tile_size x tile_size tile.
tile_size = 256
trans = torch.nn.Sequential(
    transforms.Resize(tile_size),
    transforms.CenterCrop(tile_size),
)

for sign, sign_name in [(1, 'positive'), (-1, 'negative')]:
    num_images = 10
    images = []
    for c in range(num_components):
        print(c)
        response = client.query(embedding_input=(group_components[c] * sign).tolist())
        component_images = []
        for r in response:
            try:
                # The HTTP response is closed once the image has been decoded and resized.
                with urlopen(r['url'], timeout=10) as url_response:
                    component_images.append(trans(Image.open(url_response)).convert('RGB'))
            except UnidentifiedImageError:
                print("UnidentifiedImageError")
            except HTTPError:
                print("HTTPError")
            except URLError:
                print("URLError")
            except ConnectionResetError:
                print("ConnectionResetError")
            except Exception:
                # e.g. socket timeouts; unlike a bare `except:`, KeyboardInterrupt still stops the cell
                print("timeout?")

            if len(component_images) >= num_images:
                break
        # Pad with black tiles so every row has the same width; otherwise stacking the rows
        # fails whenever fewer than num_images downloads succeed for some component.
        num_missing = num_images - len(component_images)
        component_images += [Image.new('RGB', (tile_size, tile_size)) for _ in range(num_missing)]
        component_images = np.concatenate(component_images, axis=1)
        images.append(component_images)
    images = np.concatenate(images, axis=0)
    file_name = f'pca_component_images_{sign_name}.png'
    Image.fromarray(images).save(results_path / file_name)

In [None]:
# Brain-decoded image retrieval
# Query the LAION index with each decoded CLIP embedding (no PCA projection is applied),
# optionally averaged over repeated presentations, and save the stimulus followed by the
# first `num_images` retrieved images for every trial.

from clip_retrieval.clip_client import ClipClient, Modality
from torchvision import transforms
from urllib.request import urlopen, HTTPError, URLError
from PIL import UnidentifiedImageError

run_name = 'run-001'
average_repetitions = False
num_images = 3

indice_name = 'laion_400m'
clip_retreival_params = dict(
    url="https://knn5.laion.ai/knn-service", 
    indice_name=indice_name,
    aesthetic_score=9,
    aesthetic_weight=0.0,
    use_safety_model=False,
    use_violence_detector=False,
    num_images=500,
)

client = ClipClient(
    **clip_retreival_params
)

trans = torch.nn.Sequential(
    transforms.Resize(256),
    transforms.CenterCrop(256),
)

out_path = results_path / run_name
out_path.mkdir(exist_ok=True, parents=True)
with open(out_path / 'params.json', 'w') as f:
    f.write(json.dumps({
        'average_repetitions': average_repetitions, 
        'run_name': run_name, 
        **clip_retreival_params
    }))

for subject_id, subject in enumerate(subjects):
    subject_path = out_path / subject / 'images'
    subject_path.mkdir(exist_ok=True, parents=True)
    print(subject_path)

    Y_pred = Y_pred_all[subject_id]
    stimulus_ids = stimulus_ids_all[subject_id]
    unique_stimulus_ids = np.unique(stimulus_ids)

    if average_repetitions:
        # One embedding per unique stimulus: the mean over its repetitions.
        # NOTE(review): the means are no longer unit-norm — confirm the service does not require that.
        Y_pred = np.stack([Y_pred[i == stimulus_ids].mean(axis=0) for i in unique_stimulus_ids])
        stimulus_ids = unique_stimulus_ids

    for stimulus_id in unique_stimulus_ids:
        stimulus_image = trans(Image.fromarray(stimulus_images[stimulus_id]))

        for image_version, image_id in enumerate(np.where(stimulus_ids == stimulus_id)[0]):
            y_pred = Y_pred[image_id]
            images = [stimulus_image]
            response = client.query(embedding_input=y_pred.tolist())
            for r in response:
                try:
                    # The HTTP response is closed once the image has been decoded and resized.
                    with urlopen(r['url'], timeout=10) as url_response:
                        images.append(trans(Image.open(url_response)).convert('RGB'))
                except UnidentifiedImageError:
                    print("UnidentifiedImageError")
                except HTTPError:
                    print("HTTPError")
                except URLError:
                    print("URLError")
                except ConnectionResetError:
                    print("ConnectionResetError")
                except Exception:
                    # e.g. socket timeouts; unlike a bare `except:`, KeyboardInterrupt still stops the cell
                    print("timeout?")

                # images[0] is the stimulus itself, hence the strict comparison
                if len(images) > num_images:
                    break

            images = np.concatenate(images, axis=1)
            file_name = f'stim-{stimulus_id}_image-{image_id}_v-{image_version}.png'
            Image.fromarray(images).save(subject_path / file_name)
            

In [None]:
# PCA component distributions
# Histograms of the first 32 group-PCA scores (one panel per component, 4 x 8 grid),
# pooled over every subject's decoded trials.
# NOTE(review): counts are decoded trials (repeated presentations included), not unique stimuli.

Y_pred_all_transformed = Y_pred_group_pca.transform(np.concatenate(Y_pred_all))

size = 2
fig, ax = plt.subplots(nrows=4, ncols=8, figsize=(8 * size, 4 * size), 
                      sharex=False, sharey=True,)
ax = ax.flatten()
fig.tight_layout()
fig.subplots_adjust(bottom=0.1, top=0.9, left=0.07)

for i in range(ax.shape[0]):
    ax[i].hist(Y_pred_all_transformed[:, i])
    #ax[i].set_ylim(0, 225)
    #ax[i].set_xlim(-0.1, 0.7) 
    #ax[i].set_xticks([0.0, 0.2, 0.4, 0.6])
    #if i > 0:
        #ax[i].set_yticks([])
    #ax[i].set_xlabel(subject)

fig.suptitle('Histogram of Group PCA Dimensions')
fig.supxlabel('Dimension Values')
fig.supylabel('Number of Stimuli')
file_name = 'group_pca_histogram.png'
plt.savefig(results_path / file_name, pad_inches=0)
plt.show()

In [None]:
# For each of the first 32 group-PCA components, take 90 evenly spaced values across the
# range of decoded scores and save the stimuli of the nearest decoded trials as a
# 3-row x 30-column grid.

Y_pred_all_transformed = Y_pred_group_pca.transform(np.concatenate(Y_pred_all))
# Stimulus id of every decoded trial, aligned row-for-row with Y_pred_all_transformed
all_stimulus_ids = np.concatenate(stimulus_ids_all)

out_path = results_path / 'pca_component_images'
out_path.mkdir(exist_ok=True, parents=True)

for component_id in range(32):
    y = Y_pred_all_transformed[:, component_id]
    sample_points = np.linspace(y.min(), y.max(), 90)
    # Nearest trial to each sample point (a trial may be picked more than once)
    sample_ids = [np.argmin(np.abs(y - s)) for s in sample_points]
    sample_stimulus_ids = all_stimulus_ids[sample_ids]
    sample_images = np.array([stimulus_images[i] for i in sample_stimulus_ids])
    # 90 images -> 3 rows x 30 columns; consecutive samples fill each column top to bottom
    sample_images = rearrange(sample_images, '(n1 n2) h w c -> (n2 h) (n1 w) c', n2=3)
    file_name = f'component-{component_id}.png'
    Image.fromarray(sample_images).save(out_path / file_name)

In [None]:
# decoded tsne
# 2-D t-SNE of the decoded embeddings: one per subject, plus one of all subjects pooled.
Y_pred_tsne = []
tsne_params = dict(
    n_components=2, 
    metric='cosine',
    init="pca", 
    #learning_rate="auto", 
    random_state=2,
    verbose=1, 
)
for subject_id in range(len(subjects)):
    print(subject_id)
    tsne = TSNE(**tsne_params)
    Y_pred_tsne.append(tsne.fit_transform(Y_pred_all[subject_id]))
Y_pred_group_tsne = TSNE(**tsne_params).fit_transform(np.concatenate(Y_pred_all))

In [None]:
def tsne_image_plot(y, stimulus_ids, stimulus_images, image_size, num_images, extent, distance_threshold=None):
    """Render a 2-D embedding as a square mosaic of stimulus thumbnails.

    A ``num_images`` x ``num_images`` grid of points spanning ``[-extent, extent]`` on both
    axes is laid over the embedding. Each grid cell shows the stimulus of the nearest
    embedded point (Chebyshev distance), or stays black when that point is farther than
    ``distance_threshold``. Row ``i`` of the mosaic corresponds to the second coordinate
    ``coords[i]``, increasing downwards, so the vertical axis is flipped relative to a
    standard scatter plot.

    Args:
        y: array with one 2-D point per row (it is queried with a 2-D grid).
        stimulus_ids: per-row index into ``stimulus_images``.
        stimulus_images: indexable collection of images accepted by ``Image.fromarray``;
            assumed RGB so they fit the 3-channel canvas.
        image_size: side length of each thumbnail, in pixels.
        num_images: number of grid cells per side.
        extent: half-width of the square region of embedding space that is rendered.
        distance_threshold: maximum Chebyshev distance from a grid point to its nearest
            embedded point for the cell to be filled. Defaults to ``extent / num_images``.

    Returns:
        uint8 array of shape ``(num_images * image_size, num_images * image_size, 3)``.
    """
    from sklearn.neighbors import NearestNeighbors

    S = image_size * num_images
    full_image = np.zeros(shape=(S, S, 3), dtype=np.ubyte)

    coords = np.linspace(-extent, extent, num_images)
    grid = np.stack(np.meshgrid(coords, coords))
    # (2, h, w) -> (h * w, 2): one query point per grid cell, in row-major order
    grid = rearrange(grid, 'd h w -> (h w) d')

    neighbors = NearestNeighbors(metric='chebyshev')
    neighbors.fit(y)

    distances, ids = neighbors.kneighbors(grid, n_neighbors=1)
    distances = rearrange(distances, '(h w) d -> h w d', h=num_images)
    ids = rearrange(ids, '(h w) d -> h w d', h=num_images)

    if distance_threshold is None:
        distance_threshold = extent / num_images
    for i in range(num_images):
        for j in range(num_images):
            # Leave the cell black when no embedded point is close enough
            if distances[i, j, 0] > distance_threshold:
                continue
            stimulus_id = stimulus_ids[ids[i, j, 0]]
            stim_image = Image.fromarray(stimulus_images[stimulus_id])
            stim_image = stim_image.resize(size=(image_size, image_size), resample=PIL.Image.LANCZOS)
            full_image[i * image_size:(i + 1) * image_size, j * image_size:(j + 1) * image_size] = np.array(stim_image)
    return full_image


# Save one t-SNE mosaic per subject (120 x 120 cells of 64 px) and one for all subjects
# pooled (200 x 200 cells).
# NOTE(review): extent=60 assumes the t-SNE coordinates mostly fall within +/-60; regions
# outside that square are not rendered.
out_path = results_path / 'tsne'
out_path.mkdir(exist_ok=True, parents=True)
for subject_id, subject in enumerate(subjects):
    full_image = tsne_image_plot(
        y=Y_pred_tsne[subject_id], 
        stimulus_ids=stimulus_ids_all[subject_id], 
        stimulus_images=stimulus_images,
        image_size=64, num_images=120, extent=60
    )

    file_name = f'tsne_{subject}.png'
    Image.fromarray(full_image).save(out_path / file_name)
    
# Stimulus ids concatenated in the same subject order as the pooled t-SNE input
full_image = tsne_image_plot(
    y=Y_pred_group_tsne, 
    stimulus_ids=np.concatenate(stimulus_ids_all), 
    stimulus_images=stimulus_images,
    image_size=64, num_images=200, extent=60
)

file_name = f'tsne_group.png'
Image.fromarray(full_image).save(out_path / file_name)

# Retrieval

In [None]:
# One-hot clip retreival

from sklearn.cluster import AgglomerativeClustering, OPTICS, KMeans
from hdbscan import HDBSCAN
from clip_retrieval.clip_client import ClipClient, Modality
from torchvision import transforms
from urllib.request import urlopen, HTTPError, URLError
from PIL import UnidentifiedImageError

import io

indice_name = 'laion_400m'
client = ClipClient(
    url="https://knn5.laion.ai/knn-service", 
    indice_name=indice_name,
    aesthetic_score=9,
    aesthetic_weight=0.0,
    use_safety_model=False,
    use_violence_detector=False,
    num_images=500,
)
dims = 512

trans = torch.nn.Sequential(
    transforms.Resize(256),
    transforms.CenterCrop(256),
)

num_images = 9
# Output location does not depend on the dimension, so build it once.
out_path = nsd_path / f'derivatives/clip_retreival/{indice_name}/'
out_path.mkdir(exist_ok=True, parents=True)


def _download_image(url, timeout=10):
    """Fetch `url` and return it as a 256x256 RGB PIL image, or None if it cannot be used."""
    try:
        # Close the HTTP response as soon as the bytes are read.
        with urlopen(url, timeout=timeout) as response:
            data = response.read()
        return trans(Image.open(io.BytesIO(data)).convert('RGB'))
    except UnidentifiedImageError:
        print("UnidentifiedImageError")
    except HTTPError:
        print("HTTPError")
    except URLError:
        print("URLError")
    except ConnectionResetError:
        print("ConnectionResetError")
    except Exception as e:
        # e.g. socket timeouts while reading the body. `Exception` rather than a bare
        # `except:` so that KeyboardInterrupt can still stop this long loop.
        print(f"{type(e).__name__}: {e}")
    return None


# For every embedding dimension, query the kNN service with the one-hot vector of that
# dimension and save the first `num_images` usable results as a 3x3 grid.
for dim in range(dims):
    print(dim)
    embedding_input = np.zeros(dims)
    embedding_input[dim] = 1.
    response = client.query(embedding_input=embedding_input.tolist())
    images = []
    for r in response:
        image = _download_image(r['url'])
        if image is not None:
            images.append(np.asarray(image))
        if len(images) >= num_images:
            break

    if len(images) < num_images:
        # Pad with black tiles so the 3x3 rearrange below cannot fail on short result lists.
        print(f'dim {dim}: only {len(images)}/{num_images} images retrieved, padding with black tiles')
        images += [np.zeros((256, 256, 3), dtype=np.uint8)] * (num_images - len(images))

    image = np.stack(images)
    image = rearrange(image, '(n1 n2) h w c -> (n1 h) (n2 w) c', n1=3)
    file_name = f'dim_{dim:04}.png'
    Image.fromarray(image).save(out_path / file_name)

In [None]:
# One-hot stimulus set retreival

trans = torch.nn.Sequential(
    transforms.Resize(256),
    transforms.CenterCrop(256),
)

num_images = 9
out_path = nsd_path / f'derivatives/clip_retreival/mscoco_73k/'
out_path.mkdir(exist_ok=True, parents=True)


def _stimulus_grid(grid_stimulus_ids):
    """Tile the given stimuli (resized/cropped to 256x256) into a grid with 3 rows."""
    tiles = [trans(Image.fromarray(stimulus_images[stimulus_id])) for stimulus_id in grid_stimulus_ids]
    return rearrange(np.stack(tiles), '(n1 n2) h w c -> (n1 h) (n2 w) c', n1=3)


# For each embedding dimension, save the stimuli with the largest ("positive") and the
# smallest ("negative") values along that dimension.
for dim in range(Y_full.shape[1]):
    argsort_ids = np.argsort(Y_full[:, dim])  # ascending
    # The largest values are at the END of the ascending argsort; the original cell had
    # positive/negative swapped.
    positive_ids = argsort_ids[::-1][:num_images]
    negative_ids = argsort_ids[:num_images]

    Image.fromarray(_stimulus_grid(positive_ids)).save(out_path / f'dim_{dim:04}_positive.png')
    Image.fromarray(_stimulus_grid(negative_ids)).save(out_path / f'dim_{dim:04}_negative.png')

In [None]:
def top_images_73k(Y_dims, num_images=10, labels=None, font_size=32, concatenate=True):
    """Show the stimuli that project most strongly onto each direction in `Y_dims`.

    The full stimulus embedding matrix `Y_full` (notebook global) is projected onto every row
    of `Y_dims`. For each row, the `num_images` stimuli with the largest projection are looked
    up in `stimulus_images` (notebook global), resized (short side 256) and center-cropped
    to 256x256.

    Args:
        Y_dims: (num_dims, embedding_dim) array; each row is a direction in embedding space.
        num_images: number of top stimuli per direction.
        labels: optional sequence with one string per row of `Y_dims`. When given, a 256x256
            black tile with the label drawn at its center is prepended to each row.
        font_size: label font size (only used when `labels` is given).
        concatenate: if True, return a single image array with one row of tiles per
            direction; otherwise return a list (per direction) of lists of PIL images.

    Returns:
        See `concatenate`.
    """
    # Local import: the notebook-level ImageFont/ImageDraw import only happens in a later cell,
    # so this function would otherwise fail on a fresh kernel.
    from PIL import ImageDraw, ImageFont

    trans = torch.nn.Sequential(
        transforms.Resize(256),
        transforms.CenterCrop(256),
    )
    # (num_stimuli, num_dims): projection of every stimulus onto every direction.
    Y_full_activations = Y_full @ Y_dims.T

    # Only load the TrueType font when labels are drawn; 'arial.ttf' may be missing off Windows.
    font = ImageFont.truetype('arial.ttf', font_size) if labels is not None else None

    images_all = []
    for i in range(Y_dims.shape[0]):
        stimulus_ids = np.argsort(Y_full_activations[:, i])[::-1]  # largest projection first
        images = [
            trans(Image.fromarray(stimulus_images[stimulus_id]))
            for stimulus_id in stimulus_ids[:num_images]
        ]
        if labels is not None:
            img = Image.new('RGB', (256, 256))
            draw = ImageDraw.Draw(img)
            draw.multiline_text((128, 128), labels[i], anchor='mm', font=font, align='center')
            images = [img] + images
        if concatenate:
            images = np.concatenate(images, axis=1)
        images_all.append(images)
    if concatenate:
        images_all = np.concatenate(images_all)

    return images_all

In [None]:
from research.experiments.nsd.nsd_clip_reconstruction import (
    reconstruct,
    load_vqgan_model
)

from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw 

model_name = "vqgan_imagenet_f16_16384"
vqgan_checkpoint = f"{model_name}.ckpt"
vqgan_config = f"{model_name}.yaml"
vqgan_model = load_vqgan_model(vqgan_config, vqgan_checkpoint).cuda()

def top_images_clip_recon(Y_dims, num_images=10, labels=None):
    """Reconstruct images from each embedding in `Y_dims` with VQGAN + CLIP and tile them.

    For every row of `Y_dims`, `reconstruct` is run `num_images` times (one image per run) using
    the notebook-global `vqgan_model`, `vqgan_checkpoint`, and the visual encoder of `full_model`.

    Args:
        Y_dims: torch tensor of embeddings, one per row (each row is moved to CUDA).
        num_images: number of reconstructions per embedding.
        labels: optional list of one string per row; when given, a 224x224 black tile with the
            label drawn in its center is prepended to that row. Assumes the reconstructions
            are 224x224 as well — TODO confirm.

    Returns:
        A single image array with one row of tiles per embedding.
    """
    # Loop-invariant: prepare the CLIP visual encoder once instead of for every reconstruction.
    perceptor = full_model.visual.eval().requires_grad_(False).to(device)
    # Only load the TrueType font when labels are actually drawn.
    font = ImageFont.truetype('arial.ttf', 32) if labels is not None else None

    images_all = []
    for i in range(Y_dims.shape[0]):
        embedding = Y_dims[i:i+1].cuda()
        images = []
        for _ in range(num_images):
            recon_image = reconstruct(
                stimulus_embeddings={'embedding': embedding},
                hook_modules={'': 'embedding'},
                model=vqgan_model,
                vqgan_checkpoint=vqgan_checkpoint,
                perceptor=perceptor,
                device=torch.device('cuda'),
                max_iterations=500,
                embedding_iterations=500,
            )
            # Last entry of reconstruct's output, first item; presumably the final image — TODO confirm.
            images.append(recon_image[-1][0])

        if labels is not None:
            img = Image.new('RGB', (224, 224))
            draw = ImageDraw.Draw(img)
            draw.multiline_text((112, 112), labels[i], anchor='mm', font=font, align='center')
            images = [img] + images
        images_all.append(np.concatenate(images, axis=1))

    return np.concatenate(images_all)
    


In [None]:
from clip_retrieval.clip_client import ClipClient, Modality
from sklearn.cluster import AgglomerativeClustering, OPTICS, KMeans
from hdbscan import HDBSCAN

from torchvision import transforms
from urllib.request import urlopen, HTTPError, URLError
from PIL import UnidentifiedImageError


def top_images_clip_retreival(Y_dims, num_images=10, use_safety_model=False, use_violence_detector=False,):
    """Retrieve LAION-400M images nearest to each query embedding and tile them.

    Each row of `Y_dims` is sent as an embedding query to the clip-retrieval kNN service.
    For each query, the first `num_images` results that download and decode are converted to
    RGB, resized (short side 256) and center-cropped to 256x256.

    Args:
        Y_dims: 2-D array of query embeddings, one per row (rows must support `.tolist()`).
        num_images: number of images per output row.
        use_safety_model: forwarded to `ClipClient`.
        use_violence_detector: forwarded to `ClipClient`.

    Returns:
        uint8 array of shape (num_queries * 256, num_images * 256, 3). Queries with fewer than
        `num_images` usable results are padded with black tiles, so every row has equal width.
    """
    import io

    indice_name = 'laion_400m'
    client = ClipClient(
        url="https://knn5.laion.ai/knn-service", 
        indice_name=indice_name,
        aesthetic_score=9,
        aesthetic_weight=0.0,
        use_safety_model=use_safety_model,
        use_violence_detector=use_violence_detector,
        num_images=500,
    )

    trans = torch.nn.Sequential(
        transforms.Resize(256),
        transforms.CenterCrop(256),
    )
    tile_shape = (256, 256, 3)

    def _download(url):
        """Return the image at `url` as a 256x256 RGB PIL image, or None if unusable."""
        try:
            # Close the HTTP response as soon as the bytes are read.
            with urlopen(url, timeout=10) as response:
                data = response.read()
            return trans(Image.open(io.BytesIO(data)).convert('RGB'))
        except UnidentifiedImageError:
            print("UnidentifiedImageError")
        except HTTPError:
            print("HTTPError")
        except URLError:
            print("URLError")
        except ConnectionResetError:
            print("ConnectionResetError")
        except Exception as e:
            # e.g. socket timeouts while reading. Not a bare `except:`, so Ctrl-C still works.
            print(f"{type(e).__name__}: {e}")
        return None

    images_all = []
    for i in range(Y_dims.shape[0]):
        response = client.query(embedding_input=Y_dims[i].tolist())
        images = []
        for r in response:
            image = _download(r['url'])
            if image is not None:
                images.append(np.asarray(image))
            if len(images) >= num_images:
                break
        # Pad short rows so np.concatenate over rows does not fail on mismatched widths.
        images += [np.zeros(tile_shape, dtype=np.uint8)] * (num_images - len(images))
        images_all.append(np.concatenate(images, axis=1))
    return np.concatenate(images_all)

#image = rearrange(image, '(n1 n2) h w c -> (n1 h) (n2 w) c', n1=3)
#out_path = nsd_path / f'derivatives/clip_retreival/{indice_name}/'
#out_path.mkdir(exist_ok=True, parents=True)
#file_name = f'dim_{dim:04}.png'
#Image.fromarray(image).save(out_path / file_name)

# Hidden Layer

In [None]:
# Extract hidden layer activations

# Capture the output of module 'layers.1' of every subject's model with a forward hook;
# Y_hidden_all[subject_id] receives the hooked output for X_all[subject_id][0].
Y_hidden_all = []
for subject_id in range(len(subjects)):
    model = models_all[subject_id]
    modules = dict(model.named_modules())
    hook_layer = modules['layers.1']

    def forward_hook(module, input, output):
        Y_hidden_all.append(output)

    hook_handle = hook_layer.register_forward_hook(forward_hook)
    try:
        X = X_all[subject_id][0]
        with torch.no_grad():
            model(X)
    finally:
        # Always detach the hook, even if the forward pass fails; a leftover hook would keep
        # appending to Y_hidden_all on every later forward pass of this model.
        hook_handle.remove()

In [None]:
# Histogram of hidden layer activations

# One 8x8 grid of histograms per subject: the distribution of each of the first 64 hidden
# units' activations (columns of Y_hidden_all[subject_id]).
for subject_id, subject in enumerate(subjects):
    size = 2
    fig, ax = plt.subplots(
        nrows=8,
        ncols=8,
        figsize=(8 * size, 8 * size),
        sharex=False,
        sharey=True,
    )
    ax = ax.flatten()
    fig.tight_layout()
    fig.subplots_adjust(bottom=0.1, top=0.9, left=0.07)

    for unit, unit_ax in enumerate(ax):
        unit_ax.hist(Y_hidden_all[subject_id][:, unit])

    fig.suptitle('Histogram of Hidden Layer Activations')
    fig.supxlabel('Dimension Values')
    fig.supylabel('Number of Stimuli')

    file_name = f'hidden_activation_histogram_{subject}.png'
    out_path = results_path / 'hidden_layer_activations'
    out_path.mkdir(exist_ok=True, parents=True)
    fig.savefig(out_path / file_name, pad_inches=0)


In [None]:
# For each subject and each of the first 32 hidden units, pick 90 evenly spaced activation
# levels and show the stimulus whose activation is nearest to each level.
for subject_id, subject in enumerate(subjects):
    print(subject)
    for component_id in range(32):
        y = Y_hidden_all[subject_id][:, component_id]
        sample_points = np.linspace(y.min(), y.max(), 90)
        # index of the stimulus closest to each sampled activation level
        sample_ids = [np.argmin(np.abs(y - level)) for level in sample_points]
        sample_stimulus_ids = stimulus_ids_all[subject_id][sample_ids]
        sample_images = np.array([stimulus_images[stim_id] for stim_id in sample_stimulus_ids])
        # 90 tiles -> 3 rows x 30 columns; consecutive samples fill a column top-to-bottom first
        sample_images = rearrange(sample_images, '(n1 n2) h w c -> (n2 h) (n1 w) c', n2=3)

        out_path = results_path / 'hidden_layer_activation_images'
        out_path.mkdir(exist_ok=True, parents=True)
        file_name = f'subject-{subject}_component-{component_id}.png'
        Image.fromarray(sample_images).save(out_path / file_name)

In [None]:
# 

# Group-level counterpart of the previous cell. For each of the first 32 components of
# `Y_pred_group_pca` (fitted elsewhere), sample 90 evenly spaced component values and show
# the stimulus closest to each value.
# Assumes Y_pred_all and stimulus_ids_all list subjects (and rows) in the same order, so that
# their concatenations line up — TODO confirm. Also assumes the PCA keeps >= 32 components.
Y_pred_all_transformed = Y_pred_group_pca.transform(np.concatenate(Y_pred_all))
group_stimulus_ids = np.concatenate(stimulus_ids_all)

# Separate folder: the per-subject cell above already writes subject-*_component-*.png files.
out_path = results_path / 'group_pca_component_images'
out_path.mkdir(exist_ok=True, parents=True)
for component_id in range(32):
    y = Y_pred_all_transformed[:, component_id]
    sample_points = np.linspace(y.min(), y.max(), 90)
    # index of the row whose component value is closest to each sampled level
    sample_ids = np.abs(y[:, None] - sample_points[None, :]).argmin(axis=0)
    sample_stimulus_ids = group_stimulus_ids[sample_ids]
    sample_images = np.array([stimulus_images[i] for i in sample_stimulus_ids])
    sample_images = rearrange(sample_images, '(n1 n2) h w c -> (n2 h) (n1 w) c', n2=3)
    file_name = f'group_component-{component_id}.png'
    Image.fromarray(sample_images).save(out_path / file_name)



In [None]:
# For each subject's model, treat the columns of 'layers.3.weight' as directions in embedding
# space and show the `num_images` stimuli (of the full set) with the largest projection.
trans = torch.nn.Sequential(
    transforms.Resize(256),
    transforms.CenterCrop(256),
)

num_components = 32
# Set explicitly instead of relying on the value left over from an earlier cell (it was 9).
num_images = 9
for subject_id, subject in enumerate(subjects):
    # .cpu() so this also works when the model's weights live on the GPU.
    components = models_all[subject_id].state_dict()['layers.3.weight'].cpu().numpy()
    Y_full_activations = Y_full @ components

    images_all = []
    for i in range(num_components):
        stimulus_ids = np.argsort(Y_full_activations[:, i])[::-1]  # largest projection first
        images = np.concatenate([
            trans(Image.fromarray(stimulus_images[stimulus_id]))
            for stimulus_id in stimulus_ids[:num_images]
        ], axis=1)
        images_all.append(images)
    images_all = np.concatenate(images_all)

    out_path = results_path / 'hidden_layer_activation_images_73k'
    out_path.mkdir(exist_ok=True, parents=True)
    file_name = f'subject-{subject}.png'
    Image.fromarray(images_all).save(out_path / file_name)

# Misc

In [None]:
# Check for correlations in stimulus set embeddings
from research.metrics.metrics import pearsonr

# Pairwise correlation between embedding dimensions over the full stimulus set. Broadcasting
# (N, 1, D) against (N, D, 1) presumably yields a (D, D) matrix; it is treated as square below.
r = pearsonr(torch.from_numpy(Y_full[:, None]), torch.from_numpy(Y_full[:, :, None]), reduction=None).numpy()

# Strict upper triangle only (the diagonal is trivially 1).
row_ids, col_ids = np.triu_indices_from(r, k=1)
r_triu = r[row_ids, col_ids]
# Pair labels are printed as (column, row).
dim_pairs = np.stack([col_ids, row_ids], axis=1)

# The 50 strongest correlations by absolute value, strongest first.
strongest = np.argsort(np.abs(r_triu))[::-1][:50]
print('dimensions, r')
for pair_id in strongest:
    print(tuple(dim_pairs[pair_id]), round(r_triu[pair_id], 7))

fig = plt.figure(figsize=(32, 32))
fig.tight_layout()
plt.imshow(r, vmin=-1, vmax=1, cmap='bwr')
plt.show()

In [None]:
from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw 

# Scratch check: draw centered two-line text with a TrueType font on a 256x256 black canvas.
img = Image.new('RGB', (256, 256))
draw = ImageDraw.Draw(img)
font = ImageFont.truetype('arial.ttf', 32)
draw.multiline_text(
    (128, 128),  # center of the canvas
    "Sample Text\nLine2",
    anchor='mm',  # anchor the text block at its middle
    font=font,
    align='center',
)
img

In [None]:
# Free-text concepts whose CLIP text embeddings are used as probe directions.
dimensions = [
    'trees', 'dog', 'cat', 'crowded', 'foggy', 'sunny', 'nighttime', 'sad', 'happy', 'angry', 'pets',
    'animals', 'children', 'event', 'activity',
]

with torch.no_grad():
    Y_dims = full_model.encode_text(clip.tokenize(dimensions).to(device)).cpu().numpy()
# Optionally: Y_dims[:, [133, 312, 92]] = 0.
# Normalize each concept direction to unit length.
Y_dims = Y_dims / np.linalg.norm(Y_dims, axis=1, keepdims=True)

images_all = top_images_73k(Y_dims, num_images=10, labels=dimensions)

out_path = results_path / 'handpicked_clip_dimensions'
out_path.mkdir(parents=True, exist_ok=True)
file_name = 'top_images_73k.png'
Image.fromarray(images_all).save(out_path / file_name)

In [None]:
# Baseline: correlation between true and decoded embeddings along 5000 random directions.
# Seeded for reproducibility. Named W_random so it cannot be confused with the optimized `W`
# used later in the notebook.
random_generator = torch.Generator().manual_seed(0)
W_random = torch.randn(5000, Y_all[0].shape[1], generator=random_generator).numpy()
r_random_all = []
for subject_id, subject in enumerate(subjects):
    r_random = pearsonr(
        torch.from_numpy(Y_all[subject_id] @ W_random.T),
        torch.from_numpy(Y_pred_all[subject_id] @ W_random.T),
        reduction=None,
    )
    print(r_random.mean(), r_random.std())
    r_random_all.append(r_random)

In [None]:
# Sorted random-direction correlations of the first subject (ascending).
torch.sort(r_random_all[0]).values

In [None]:
# Distribution of the random-direction baseline correlations for the first subject.
plt.hist(r_random_all[0])
plt.title(f'{subjects[0]}: Pearson r along random directions')
plt.xlabel('pearson r')
plt.ylabel('number of random directions')
plt.show()

In [None]:
# Per subject: correlation between true and decoded embeddings after projecting both onto
# the handpicked concept directions Y_dims.
for subject_id, subject in enumerate(subjects):
    Y, Y_pred = Y_all[subject_id], Y_pred_all[subject_id]
    Y_dims_proj, Y_pred_dims_proj = Y @ Y_dims.T, Y_pred @ Y_dims.T
    Y_dims_proj_r = pearsonr(
        torch.from_numpy(Y_dims_proj),
        torch.from_numpy(Y_pred_dims_proj),
        reduction=None,
    )
    print(Y_dims_proj_r)

In [None]:
# Every entry of the Windows font directory.
[font_file for font_file in (Path(os.environ['WINDIR']) / 'fonts').iterdir()]

In [None]:
# Is DejaVuSans installed in the Windows font directory?
Path(os.environ['WINDIR'], 'fonts', 'DejaVuSans.ttf').exists()

In [None]:
from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw 

# Scratch check: white text in PIL's built-in default font (no TrueType file needed),
# drawn at the top-left corner of a 128x128 black canvas.
img = Image.new('RGB', (128, 128))
draw = ImageDraw.Draw(img)
draw.text((0, 0), "Sample Text", fill=(255, 255, 255))
img

# Dimension Optimization

In [None]:
from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw


def brain_decodability_evaluation(
    Y, 
    Y_pred, 
    Y_full,
    stimulus_ids,
    W, 
    out_path,
    labels=None,
    image_labels=None
):
    """Evaluate how well decoded embeddings are preserved along the directions in `W`.

    Projects true and decoded embeddings onto the rows of `W` (negative parts clipped with
    ReLU) and writes diagnostic figures to `out_path`:
    variance_explained.png, pearsonr.png, cosine_similarity_rsm.png,
    cosine_similairty_histogram.png, component_histogram.png, top_k_accuracy.png and
    top_images_73k.png.

    Args:
        Y: (num_trials, embedding_dim) torch tensor of true stimulus embeddings.
        Y_pred: torch tensor of decoded embeddings, same shape and device as `Y`.
        Y_full: (num_stimuli, embedding_dim) numpy array with embeddings of the full stimulus set.
        stimulus_ids: per-trial stimulus ids for the rows of `Y` (repeats allowed).
        W: (num_components, embedding_dim) directions, as a torch tensor or numpy array. The RSM
            plots are labelled "cosine similarity", which assumes unit-norm rows.
        out_path: pathlib.Path directory for the figures; created if missing.
        labels: optional per-component tick labels (default: component indices).
        image_labels: optional labels for the top-image grid (default: `labels`).
    """
    out_path.mkdir(exist_ok=True, parents=True)
    # Accept W straight from W.npy (numpy) or as a tensor; match Y's dtype and device.
    W = torch.as_tensor(W, dtype=Y.dtype, device=Y.device)
    W_np = W.cpu().numpy()

    num_components = W.shape[0]
    if labels is None:
        labels = [str(i) for i in range(num_components)]
    if image_labels is None:
        image_labels = labels
    ticks = np.arange(num_components)

    def _save(fig, file_name, **savefig_kwargs):
        # Close after saving: this function is called once per subject, and open figures
        # would otherwise accumulate (matplotlib warns beyond 20).
        fig.savefig(out_path / file_name, **savefig_kwargs)
        plt.close(fig)

    def _bar_plot(values, title, ylabel, file_name):
        # One bar per component, x ticks labelled with `labels`.
        fig = plt.figure(figsize=(num_components * 0.2, 3))
        plt.title(title)
        plt.xlabel('Dimension')
        plt.ylabel(ylabel)
        plt.bar(ticks, values)
        plt.xticks(rotation=45, ha='right', labels=labels, ticks=ticks)
        _save(fig, file_name, bbox_inches='tight')

    # Projections onto W with negative values clipped.
    Y_W, Y_pred_W = F.relu(Y @ W.T), F.relu(Y_pred @ W.T)

    r2_scores = r2_score(Y_W, Y_pred_W, reduction=None).cpu().numpy()
    _bar_plot(r2_scores, 'Brain-Decoding Variance Explained of CLIP Dimensions', 'R^2',
              'variance_explained.png')

    r_scores = pearsonr(Y_W, Y_pred_W, reduction=None).cpu().numpy()
    _bar_plot(r_scores, 'Brain-Decoding Pearson Correlation of CLIP Dimensions', 'pearsonr',
              'pearsonr.png')

    # Dot products between components (cosine similarities when rows of W are unit-norm).
    W_rsm = W_np @ W_np.T

    fig = plt.figure(figsize=(num_components * 0.2,) * 2)
    plt.imshow(W_rsm, cmap='bwr', vmin=-1, vmax=1)
    plt.xticks(rotation=45, ha='right', labels=labels, ticks=ticks)
    plt.yticks(labels=labels, ticks=ticks)
    _save(fig, 'cosine_similarity_rsm.png', bbox_inches='tight')

    triu_rows, triu_cols = np.triu_indices(num_components, k=1)
    fig = plt.figure()
    plt.hist(W_rsm[triu_rows, triu_cols])
    _save(fig, 'cosine_similairty_histogram.png', bbox_inches='tight')

    # Histogram of the positive part of the full stimulus set projected onto each component.
    Y_full_transformed = Y_full @ W_np.T
    size = 2
    num_cols = 8
    num_rows = int(np.ceil(num_components / num_cols))  # ceil: don't drop trailing components
    fig, ax = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(num_cols * size, num_rows * size),
                           sharex=False, sharey=True, squeeze=False)
    ax = ax.flatten()
    fig.tight_layout()
    fig.subplots_adjust(bottom=0.1, top=0.95, left=0.07)
    for i in range(num_components):
        y = Y_full_transformed[:, i]
        ax[i].hist(y[y > 0])
    for unused_ax in ax[num_components:]:
        unused_ax.set_visible(False)
    fig.suptitle('Histogram of 73k Stimulus Set Projected Onto Optimized Components')
    fig.supxlabel('Dimension Values')
    fig.supylabel('Number of Stimuli')
    _save(fig, 'component_histogram.png', pad_inches=0)

    # Top-k retrieval: each decoded row must retrieve its own stimulus among the unique stimuli.
    top_k_values = [1, 5, 10, 50, 100, 500]
    metric = 'euclidean'
    unique_stimulus_ids, unique_index, unique_inverse = np.unique(
        stimulus_ids, return_index=True, return_inverse=True)
    # Chance level: the correct stimulus is among k of all unique candidates.
    num_candidates = len(unique_stimulus_ids)
    chance_accuracy = [min(k / num_candidates, 1.0) for k in top_k_values]
    top_knn_accuracy = top_knn_test(
        Y_W[unique_index].cpu().numpy(), Y_pred_W.cpu().numpy(), unique_inverse, k=top_k_values, metric=metric)

    x = range(len(top_k_values))
    fig = plt.figure(figsize=(12, 8))
    plt.xticks(ticks=x, labels=top_k_values)
    plt.title('Top knn accuracy')
    plt.xlabel('k')
    plt.ylabel('accuracy')
    plt.plot(x, chance_accuracy, label='chance (k/n)', color='gray')
    plt.plot(x, top_knn_accuracy, label='actual')
    plt.grid()
    plt.legend()
    _save(fig, 'top_k_accuracy.png', pad_inches=0)

    images_all = top_images_73k(W_np, num_images=10, labels=image_labels)
    Image.fromarray(images_all).save(out_path / 'top_images_73k.png')

In [None]:
# Imported here as well: otherwise these names only exist after a later cell has run.
from research.metrics.loss_functions import ContrastiveDistanceLoss
from research.metrics.metrics import squared_euclidean_distance

contrastive_criterion = ContrastiveDistanceLoss(squared_euclidean_distance)

In [None]:
# NOTE(review): `subject_id` is not set in this cell; it is whatever an earlier loop left behind.
stimulus_ids_test = torch.from_numpy(
    folds['test']['stimulus_ids_all'][subject_id]
)

In [None]:
# Shapes of the current true / decoded embedding arrays.
(Y.shape, Y_pred.shape)

In [None]:
# Batched iterator over (true, decoded) embedding pairs.
dataset = TensorDataset(Y, Y_pred)
dataloader = DataLoader(dataset=dataset, batch_size=3000)
data_iterator = get_data_iterator(dataloader)

In [None]:
# Pull one batch from the iterator and display its shapes as a sanity check.
Y_batch, Y_pred_batch = next(data_iterator)
Y_batch.shape, Y_pred_batch.shape

In [None]:
import torch.nn.functional as F
from scipy.cluster import hierarchy
from research.metrics.loss_functions import ContrastiveDistanceLoss
from research.metrics.metrics import squared_euclidean_distance

optimized_dims_path = nsd_path / f'derivatives/decoded_features/{model_name}/optimzied_dims/'
optimized_dims_path.mkdir(exist_ok=True, parents=True)

run_name = 'run36'

metric = 'euclidean'
top_k_values = [1]

contrastive_criterion = ContrastiveDistanceLoss(squared_euclidean_distance)

# Optimize `num_components` directions W such that true and decoded embeddings, projected onto
# W (leaky ReLU), match well under the weighted losses below. Then normalize W, order its rows
# by hierarchical clustering and save/evaluate the result.
with h5py.File(optimized_dims_path / f'{group_name}.hdf5', 'w') as h5_file:
    for subject_id in ['all']:
        params = {
            'subject_id': subject_id,
            'lr': 0.0002, 
            'r2_weight': 0.,
            'r_weight': 0.,
            'contrastive_weight': 1.,
            'cossim_weight': 0.,
            'correlation_weight': 0.0,
            'l1_weight': 0.,
            'num_components': 128,
            'num_iterations': 15000,
            'negative_slope': 0.05
        }
        # Unpack explicitly instead of `locals().update(params)`, which only works at notebook
        # top level and hides where these names come from.
        lr = params['lr']
        r2_weight = params['r2_weight']
        r_weight = params['r_weight']
        contrastive_weight = params['contrastive_weight']
        cossim_weight = params['cossim_weight']
        correlation_weight = params['correlation_weight']
        l1_weight = params['l1_weight']
        num_components = params['num_components']
        num_iterations = params['num_iterations']
        negative_slope = params['negative_slope']

        if subject_id == 'all':
            subject_name = subject_id
            # Training data: validation-fold embeddings of all subjects, pooled.
            Y_val_all = torch.cat([torch.from_numpy(Y_all[s]).cuda() for s in range(len(subjects))])
            Y_val_all_pred = torch.cat([torch.from_numpy(Y_pred_all[s]).cuda() for s in range(len(subjects))])

            # Monitoring/evaluation uses subject 0's validation and test folds.
            Y_val = torch.from_numpy(Y_all[0]).cuda()
            Y_val_pred = torch.from_numpy(Y_pred_all[0]).cuda()
            stimulus_ids_val = torch.from_numpy(folds['val']['stimulus_ids_all'][0])

            Y_test = torch.from_numpy(folds['test']['Y_all'][0]).cuda()
            Y_test_pred = torch.from_numpy(folds['test']['Y_pred_all'][0]).cuda()
            stimulus_ids_test = torch.from_numpy(folds['test']['stimulus_ids_all'][0])

            dataset = TensorDataset(Y_val_all, Y_val_all_pred)
        else:
            subject_name = f'subj0{subject_id + 1}'
            Y_val = torch.from_numpy(Y_all[subject_id]).cuda()
            Y_val_pred = torch.from_numpy(Y_pred_all[subject_id]).cuda()
            stimulus_ids_val = torch.from_numpy(folds['val']['stimulus_ids_all'][subject_id])

            Y_test = torch.from_numpy(folds['test']['Y_all'][subject_id]).cuda()
            Y_test_pred = torch.from_numpy(folds['test']['Y_pred_all'][subject_id]).cuda()
            stimulus_ids_test = torch.from_numpy(folds['test']['stimulus_ids_all'][subject_id])

            dataset = TensorDataset(Y_val, Y_val_pred)

        # Build the loader once from the branch-specific dataset. Rebuilding it from
        # (Y_val, Y_val_pred) here would replace the pooled 'all' data with subject 0 only.
        dataloader = DataLoader(dataset, batch_size=3000)
        data_iterator = get_data_iterator(dataloader)

        unique_stimulus_ids_val, unique_index_val, unique_inverse_val = np.unique(
            stimulus_ids_val, return_index=True, return_inverse=True)
        unique_stimulus_ids_test, unique_index_test, unique_inverse_test = np.unique(
            stimulus_ids_test, return_index=True, return_inverse=True)

        # Y_val is always defined above (the old `Y_train` only existed in disabled code).
        W = torch.randn(num_components, Y_val.shape[1]).cuda()
        W.requires_grad = True

        W_triu_indices = torch.triu_indices(num_components, num_components, offset=1)

        optim = torch.optim.Adam([W], lr=lr)

        for i in range(num_iterations):
            Y_batch, Y_pred_batch = next(data_iterator)
            # Reset gradients every step; without this they accumulate across iterations.
            optim.zero_grad()

            Y_W = F.leaky_relu(Y_batch @ W.T, negative_slope=negative_slope)
            Y_pred_W = F.leaky_relu(Y_pred_batch @ W.T, negative_slope=negative_slope)
            r2_loss = -r2_score(Y_W, Y_pred_W)
            r_loss = -pearsonr(Y_W, Y_pred_W)
            contrastive_loss = contrastive_criterion(Y_W, Y_pred_W)

            if i % 100 == 0:
                with torch.no_grad():
                    Y_val_W, Y_val_pred_W = F.relu(Y_val @ W.T), F.relu(Y_val_pred @ W.T)
                    Y_test_W, Y_test_pred_W = F.relu(Y_test @ W.T), F.relu(Y_test_pred @ W.T)

                    top_knn_accuracy_val = top_knn_test(
                        Y_val_W[unique_index_val].cpu().numpy(), Y_val_pred_W.cpu().numpy(), unique_inverse_val, k=top_k_values, metric=metric)
                    top_knn_accuracy_test = top_knn_test(
                        Y_test_W[unique_index_test].cpu().numpy(), Y_test_pred_W.cpu().numpy(), unique_inverse_test, k=top_k_values, metric=metric)

                print(f'i={i}, r_loss={-r_loss.detach().cpu().item()}', top_knn_accuracy_val, top_knn_accuracy_test)

            l1_loss = Y_pred_W.mean()

            # Penalize positive cosine similarity between different directions.
            W_norm = W / torch.linalg.norm(W, dim=1)[:, None]
            W_cossim = W_norm @ W_norm.T
            cossim_loss = F.relu(W_cossim[W_triu_indices[0], W_triu_indices[1]]).mean()

            # Penalize positive correlation between projected decoded components.
            Y_pred_W_centered = (Y_pred_W - Y_pred_W.mean(axis=0))
            Y_pred_W_norm = Y_pred_W_centered / torch.norm(Y_pred_W_centered, dim=0)
            Y_pred_W_r = Y_pred_W_norm.T @ Y_pred_W_norm
            correlation_loss = F.relu(Y_pred_W_r[W_triu_indices[0], W_triu_indices[1]]).mean()

            loss = (
                r2_loss * r2_weight
                + r_loss * r_weight
                + contrastive_loss * contrastive_weight
                + l1_loss * l1_weight
                + cossim_loss * cossim_weight
                + correlation_loss * correlation_weight
            )
            loss.backward()
            optim.step()

        W.requires_grad = False
        W = W / torch.linalg.norm(W, dim=1)[:, None]

        # Order components by hierarchical clustering on cosine distance. scipy's condensed
        # distance vector uses the same row-major upper-triangle order as torch.triu_indices.
        W_triu_indices = torch.triu_indices(num_components, num_components, offset=1)
        W_rsm = (W @ W.T).cpu().numpy()
        W_rsm_triu = W_rsm[W_triu_indices[0], W_triu_indices[1]]
        y = np.clip(1 - W_rsm_triu, 0, None)  # clip tiny negative round-off
        Z = hierarchy.linkage(y, optimal_ordering=True)
        leaves = hierarchy.leaves_list(Z)
        W = W[leaves]

        out_path = results_path / 'optimized_clip_dimensions' / run_name / subject_name
        out_path.mkdir(exist_ok=True, parents=True)

        # Separate handle name so the open HDF5 file `h5_file` is not shadowed.
        with open(out_path / 'params.json', 'w') as params_file:
            params_file.write(json.dumps(params))

        W_np = W.cpu().numpy()
        np.save(out_path / 'W.npy', W_np)
        # Store W in the group-level HDF5 file that this cell opens for writing.
        h5_file.create_dataset(subject_name, data=W_np)

        brain_decodability_evaluation(Y_test, Y_test_pred, Y_full, stimulus_ids_test, W, out_path)

In [None]:
run_name = 'run36'
subject_name = 'all'

# Directions saved by the optimization cell (one row per direction).
W = np.load(file=results_path / 'optimized_clip_dimensions' / run_name / subject_name / "W.npy")
# Projection of every stimulus embedding onto every direction.
Y_full_W = np.matmul(Y_full, W.T)

In [None]:
# Evaluate the pooled W on each subject's test fold and store the results per subject.
# W may be the numpy array loaded from W.npy or a tensor left by the optimization cell.
W_np = W.cpu().numpy() if torch.is_tensor(W) else np.asarray(W)
for subject_id, subject_name in enumerate(subjects):
    out_path = results_path / 'optimized_clip_dimensions' / run_name / subject_name
    out_path.mkdir(exist_ok=True, parents=True)

    # NOTE(review): `params` comes from the last optimization run in this kernel — confirm it
    # belongs to `run_name`.
    with open(out_path / 'params.json', 'w') as f:
        f.write(json.dumps(params))

    np.save(out_path / 'W.npy', W_np)
    Y_test = torch.from_numpy(folds['test']['Y_all'][subject_id]).cuda()
    Y_test_pred = torch.from_numpy(folds['test']['Y_pred_all'][subject_id]).cuda()
    stimulus_ids_test = torch.from_numpy(folds['test']['stimulus_ids_all'][subject_id])

    # The evaluation needs W as a tensor on the same device/dtype as the embeddings.
    W_test = torch.from_numpy(W_np).to(device=Y_test.device, dtype=Y_test.dtype)
    brain_decodability_evaluation(Y_test, Y_test_pred, Y_full, stimulus_ids_test, W_test, out_path)

In [None]:
# Top K accuracy in PCA space

# Top-1 test retrieval accuracy per subject in four spaces: the full CLIP space, W-space
# (ReLU projections onto W), and the first 64 / 128 components of the group PCA.
top_k_values = [1]
metric = 'cosine'
fold = 'test'

# W may be the numpy array loaded from W.npy or a (possibly CUDA) tensor.
W_tensor = torch.as_tensor(W).cpu()


def _top1_accuracy(Y_space, Y_pred_space, unique_index, unique_inverse):
    """Top-1 accuracy of retrieving each decoded row's stimulus among the unique true rows."""
    return top_knn_test(
        Y_space[unique_index], Y_pred_space, unique_inverse, k=top_k_values, metric=metric
    )[0]


full_top_1 = []
W_top_1 = []
pca64_top_1 = []
pca128_top_1 = []
for subject_id, subject in enumerate(subjects):
    print(subject)

    Y = folds[fold]['Y_all'][subject_id]
    Y_pred = folds[fold]['Y_pred_all'][subject_id]
    Y_t, Y_pred_t = torch.from_numpy(Y), torch.from_numpy(Y_pred)
    W_t = W_tensor.to(Y_t.dtype)
    Y_pca_transformed = torch.from_numpy(Y_pred_group_pca.transform(Y))
    Y_pred_pca_transformed = torch.from_numpy(Y_pred_group_pca.transform(Y_pred))

    stimulus_ids = folds[fold]['stimulus_ids_all'][subject_id]
    unique_stimulus_ids, unique_index, unique_inverse = np.unique(
        stimulus_ids, return_index=True, return_inverse=True)

    full_top_1.append(_top1_accuracy(Y_t, Y_pred_t, unique_index, unique_inverse))
    W_top_1.append(_top1_accuracy(
        F.relu(Y_t @ W_t.T), F.relu(Y_pred_t @ W_t.T), unique_index, unique_inverse))
    pca64_top_1.append(_top1_accuracy(
        Y_pca_transformed[:, :64], Y_pred_pca_transformed[:, :64], unique_index, unique_inverse))
    pca128_top_1.append(_top1_accuracy(
        Y_pca_transformed[:, :128], Y_pred_pca_transformed[:, :128], unique_index, unique_inverse))



In [None]:
# Grouped bar chart of the top-1 test accuracies computed in the previous cell.
num_subjects = len(subjects)
x_positions = np.arange(num_subjects)
bar_width = 0.2
series = [
    (full_top_1, 'full CLIP space'),
    (W_top_1, f'W-space {W.shape[0]} components'),
    (pca64_top_1, 'pca 64 components'),
    (pca128_top_1, 'pca 128 components'),
]
# Center the group of bars on each subject tick: offsets -1.5, -0.5, 0.5, 1.5 bar widths.
offsets = (np.arange(len(series)) - (len(series) - 1) / 2) * bar_width

fig, ax = plt.subplots(figsize=(12, 6))
for offset, (values, label) in zip(offsets, series):
    ax.bar(x_positions + offset, values, width=bar_width, label=label)

ax.set_ylabel('top 1 test accuracy')
ax.set_xlabel('subject')
ax.set_xticks(x_positions)
ax.set_xticklabels(subjects)
ax.set_yticks([0.01 * i for i in range(25)])
ax.grid(axis='y')

out_path = results_path / 'optimized_clip_dimensions' / run_name / 'all'
out_path.mkdir(exist_ok=True, parents=True)

ax.legend()
file_name = 'wspace_vs_pca.png'
fig.savefig(out_path / file_name, pad_inches=0)
plt.show()

In [None]:
# Top-1 accuracies per subject for each space (pca_top_1 did not exist; the PCA lists are split by size).
full_top_1, W_top_1, pca64_top_1, pca128_top_1

In [None]:
# Top-1 accuracy as a function of the number of PCA components kept, one line per subject.
# Assumes `results` (computed elsewhere) maps subject name -> accuracies for 1..D-1 components.
fig, ax = plt.subplots(figsize=(12, 6))

x = np.arange(1, Y_full.shape[1])  # 1 .. embedding_dim - 1 components
for subject_name, top_1_accuracy in results.items():
    ax.plot(x, top_1_accuracy, label=subject_name)

ax.legend(loc='lower right')
ax.set_xticks([i * 32 for i in range(17)])
ax.grid(visible=True)
ax.set_xlabel('num PCA components')
ax.set_ylabel('top 1 accuracy')
out_path = nsd_path / f'derivatives/figures/decoding/{model_name}/{group_name}/{fold}/{embedding_name}'
# The directory for `fold` may not exist yet (results_path is created only for fold_name).
out_path.mkdir(exist_ok=True, parents=True)
file_name = 'pca_vs_top1.png'
fig.savefig(out_path / file_name, pad_inches=0)
plt.show()

In [None]:
# TODO: unfinished per-subject loop — a `for subject_id in range(8):` header with no body is a SyntaxError.

In [None]:
run_name = 'run32'
subject_name = 'subj01'

# Optimized directions saved for this run/subject (one row per direction).
W = np.load(file=results_path / 'optimized_clip_dimensions' / run_name / subject_name / "W.npy")
# Projection of every stimulus embedding onto every direction.
Y_full_W = np.matmul(Y_full, W.T)

In [None]:
# Scratch: shape of `top_k_ids` (defined outside this visible section).
top_k_ids.shape

In [None]:
# Scratch: shape of the stimulus embeddings selected by the fifth entry of `top_k_ids`.
np.shape(Y_full[top_k_ids[4]])

In [None]:
# Scratch: (num_stimuli, num_directions) shape of the projected stimulus set.
np.shape(Y_full_W)

In [None]:
# Scratch: the `k` largest projections onto direction 25, in ascending order.
# `k` must be set by an earlier cell; note that k == 0 would show every value.
np.sort(Y_full_W[:, 25])[-k:]

In [None]:

def top_images_tsne(Y, Y_W, stimulus_ids, out_path, k=250, labels=None,
                    image_size=128, num_images=40, extent=20):
    """Save, for every column of ``Y_W``, a t-SNE mosaic of thumbnails of its top-``k`` stimuli.

    For dimension ``dim`` the ``k`` rows with the largest ``Y_W[:, dim]`` are selected, their
    embeddings ``Y[rows]`` are embedded in 2-D with cosine t-SNE, and each cell of a
    ``num_images x num_images`` grid spanning ``[-extent, extent]^2`` is filled with the
    thumbnail of the nearest embedded stimulus (Chebyshev distance). Cells with no stimulus
    within ``extent / num_images`` are left black.

    Args:
        Y: embeddings of all candidate stimuli, one row per stimulus.
        Y_W: scores used for ranking, one column per dimension; same row order as ``Y``.
        stimulus_ids: maps a row index of ``Y``/``Y_W`` to an index into the global
            ``stimulus_images``.
        out_path: output directory (created if missing).
        k: number of top-scoring stimuli laid out per dimension.
        labels: optional per-dimension labels appended to the file names.
        image_size: thumbnail side length in pixels.
        num_images: number of thumbnails per row/column of the mosaic.
        extent: half-width of the t-SNE coordinate window mapped onto the grid.
    """
    out_path.mkdir(exist_ok=True, parents=True)

    S = image_size * num_images
    # Query points of the mosaic grid, in row-major order (row index follows the 2nd coordinate).
    coords = np.linspace(-extent, extent, num_images)
    grid = rearrange(np.stack(np.meshgrid(coords, coords)), 'd h w -> (h w) d')
    distance_threshold = extent / num_images

    for dim in range(Y_W.shape[1]):
        top_stimulus_ids = np.argsort(Y_W[:, dim])[::-1][:k]
        with DisablePrints():
            tsne = TSNE(
                n_components=2,
                metric='cosine',
                init="pca",
                random_state=0,
                verbose=0,
            )
            embedding = tsne.fit_transform(Y[top_stimulus_ids])

        neighbors = NearestNeighbors(metric='chebyshev')
        neighbors.fit(embedding)
        distances, ids = neighbors.kneighbors(grid, n_neighbors=1)
        distances = rearrange(distances, '(h w) d -> h w d', h=num_images)
        ids = rearrange(ids, '(h w) d -> h w d', h=num_images)

        full_image = np.zeros(shape=(S, S, 3), dtype=np.ubyte)
        for i in tqdm(range(num_images)):
            for j in range(num_images):
                if distances[i, j] > distance_threshold:
                    continue  # no stimulus close to this grid cell: leave it black
                stimulus_id = stimulus_ids[top_stimulus_ids[ids[i, j, 0]]]
                stim_image = Image.fromarray(stimulus_images[stimulus_id])
                stim_image = stim_image.resize(size=(image_size, image_size), resample=PIL.Image.LANCZOS)
                full_image[i * image_size:(i + 1) * image_size, j * image_size:(j + 1) * image_size] = np.array(stim_image)

        if labels is None:
            Image.fromarray(full_image).save(out_path / f'dim-{dim}.png')
        else:
            Image.fromarray(full_image).save(out_path / f'dim-{dim}__label-{labels[dim]}.png')
    


In [None]:
# Project all stimulus embeddings onto W and save per-dimension t-SNE thumbnail mosaics.
# W is a NumPy array when it was loaded from W.npy but a (CUDA) tensor after the optimisation
# cells; calling W.cpu() on the NumPy form raised AttributeError, so accept both.
W_np = W.cpu().numpy() if torch.is_tensor(W) else np.asarray(W)
Y_W = torch.from_numpy(Y_full @ W_np.T)
top_images_tsne(Y_full,
                Y_W.numpy(),
                np.arange(73000),  # identity mapping onto the 73k NSD stimulus set
                out_path=results_path / 'optimized_clip_dimensions' / run_name / 'all' / 'tsne_73k')

In [None]:
# Scratch shape checks.
Y_W.shape, Y_full.shape

In [None]:

Y_W.sum(dim=1).shape


In [None]:
# `ids` is whatever global last assigned it (not defined in this section).
ids.flatten()

In [None]:
# Save the current run (params.json, W.npy) and write the brain-decodability evaluation.
# NOTE(review): `subject_id` is used directly as a path component; if it is an int (as produced by
# `for subject_id in range(8)`), `Path / int` raises TypeError — other cells build the directory
# name as f'subj0{subject_id + 1}'. `params` and `brain_decodability_evaluation` come from outside
# this section.
out_path = results_path / 'optimized_clip_dimensions' / run_name / subject_id
out_path.mkdir(exist_ok=True, parents=True)

with open(out_path / 'params.json', 'w') as f:
    f.write(json.dumps(params))

np.save(out_path / 'W.npy', W.cpu().numpy())

brain_decodability_evaluation(Y_test, Y_test_pred, Y_full, stimulus_ids_test, W, out_path)

In [None]:
# Grid (8 columns) of histograms of all stimuli projected onto each of the first `num_components` components.
# NOTE(review): `Y_full_transformed` is assigned in a later cell, and `out_path` is whatever the
# last cell that set it chose, so this cell only works after those have run.
num_components = 128
size = 2
num_cols = 8
num_rows = int(num_components / num_cols)
fig, ax = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(num_cols * size, num_rows * size), 
                      sharex=False, sharey=True,)
ax = ax.flatten()
fig.tight_layout()
fig.subplots_adjust(top=0.925,)

for i in range(ax.shape[0]):
    #if labels is not None:
    #    ax[i].set_title(labels[i])
    ax[i].hist(Y_full_transformed[:, i])

fig.suptitle('Histogram of 73k Stimulus Set Projected Onto Components')
fig.supxlabel('Dimension Values')
fig.supylabel('Number of Stimuli')
file_name = 'component_histogram.png'
plt.savefig(out_path / file_name, bbox_inches='tight')

In [None]:
# Reload the optimized W of run26 for one subject and put that subject's val/test targets and
# predictions on the GPU; `folds` is built outside this section.
run_name = 'run26'
subject_id = 0

out_path = results_path / 'optimized_clip_dimensions' / run_name / f'subj0{subject_id+1}'
out_path.mkdir(exist_ok=True, parents=True)

W = torch.from_numpy(np.load(out_path / 'W.npy')).cuda()

Y_val = torch.from_numpy(folds['val']['Y_all'][subject_id]).cuda()
Y_val_pred = torch.from_numpy(folds['val']['Y_pred_all'][subject_id]).cuda()
stimulus_ids_val = torch.from_numpy(folds['val']['stimulus_ids_all'][subject_id])

Y_test = torch.from_numpy(folds['test']['Y_all'][subject_id]).cuda()
Y_test_pred = torch.from_numpy(folds['test']['Y_pred_all'][subject_id]).cuda()
stimulus_ids_test = torch.from_numpy(folds['test']['stimulus_ids_all'][subject_id])

#brain_decodability_evaluation(Y_test, Y_test_pred, Y_full, stimulus_ids_test, W, out_path)

In [None]:
# Scratch check of slice semantics: [:0] is empty, [0:] is the whole tensor.
torch.arange(10)[:0], torch.arange(10)[0:]

In [None]:
N = 1000
top_k_values = [1, 5, 10, 50, 100, 500]
chance_accuracy = [k / N for k in top_k_values]

metric = 'cosine'

# Deduplicate repeated stimuli: the unique targets form the retrieval gallery and
# `unique_inverse_*` maps every prediction back to its target's position in that gallery.
unique_stimulus_ids_val, unique_index_val, unique_inverse_val = np.unique(
            stimulus_ids_val, return_index=True, return_inverse=True)
unique_stimulus_ids_test, unique_index_test, unique_inverse_test = np.unique(
            stimulus_ids_test, return_index=True, return_inverse=True)

Y_W, Y_pred_W = F.relu(Y_val @ W.T), F.relu(Y_val_pred @ W.T)
top_knn_accuracy = top_knn_test(
    Y_W[unique_index_val].cpu().numpy(), Y_pred_W.cpu().numpy(), unique_inverse_val, k=top_k_values, metric=metric)
print('baseline', top_knn_accuracy)

# Greedy backward elimination: every step removes the row of W whose removal gives the best
# validation score (first entry of top_knn_test's result, presumably k=1), then records the
# test-fold scores of the remaining rows.
W_pruned = W

top_knn_accuracies_val = []
top_knn_accuracies_test = []
pruned_dims = []  # removed row index, relative to W_pruned at that step

for i in range(W.shape[0] - 1):
    pruned_j = -1
    # Start from "no candidate yet" instead of an accuracy of [0]: with a strict `>`, a step in
    # which every candidate scored 0 left W_best = None and the next step crashed on W_pruned.shape.
    top_knn_accuracy_best = None
    W_best = None
    for j in range(W_pruned.shape[0]):
        ids = torch.arange(W_pruned.shape[0])
        ids = torch.cat([ids[:j], ids[(j + 1):]])  # every row except j

        Y_W, Y_pred_W = F.relu(Y_val @ W_pruned[ids].T), F.relu(Y_val_pred @ W_pruned[ids].T)

        top_knn_accuracy = top_knn_test(
            Y_W[unique_index_val].cpu().numpy(), Y_pred_W.cpu().numpy(), unique_inverse_val, k=top_k_values, metric=metric)
        if top_knn_accuracy_best is None or top_knn_accuracy[0] > top_knn_accuracy_best[0]:
            top_knn_accuracy_best = top_knn_accuracy
            pruned_j = j
            W_best = W_pruned[ids]

    top_knn_accuracies_val.append(top_knn_accuracy_best)
    W_pruned = W_best
    pruned_dims.append(pruned_j)

    Y_W, Y_pred_W = F.relu(Y_test @ W_pruned.T), F.relu(Y_test_pred @ W_pruned.T)
    top_knn_accuracy_test = top_knn_test(
        Y_W[unique_index_test].cpu().numpy(), Y_pred_W.cpu().numpy(), unique_inverse_test, k=top_k_values, metric=metric)
    top_knn_accuracies_test.append(top_knn_accuracy_test)

    print(f'{pruned_j=}, {top_knn_accuracy_best=}, {top_knn_accuracy_test=}')

top_knn_accuracies_test = np.array(top_knn_accuracies_test)
top_knn_accuracies_val = np.array(top_knn_accuracies_val)
    

In [None]:
# Top-1 accuracy (in percent) on the test and validation folds after each greedy pruning step.
plt.figure(figsize=(12, 8))
#plt.xticks(ticks=range(len(top_k_values)), labels=top_k_values)
plt.title(f'Top 1 Accuracy After Pruning Dimensions')
plt.xlabel('Number of Pruned Dimensions')
plt.ylabel('accuracy')
plt.plot(range(top_knn_accuracies_test[:, 0].shape[0]), top_knn_accuracies_test[:, 0] * 100, label='test data')
plt.plot(range(top_knn_accuracies_val[:, 0].shape[0]), top_knn_accuracies_val[:, 0] * 100, label='validation data')
plt.grid()
plt.legend()
file_name = 'top_1_accuracy_pruned.png'
plt.savefig(out_path / file_name, pad_inches=0)

In [None]:
# Map each greedy pruning index (relative to the shrinking W_pruned) back to a row index of the
# original W, giving the order in which W's rows were removed.
# Sized from W instead of a hard-coded 128, which was wrong for any W with another number of rows.
dims = np.arange(W.shape[0])
corrected_pruned_dims = []
for dim in pruned_dims:
    corrected_pruned_dims.append(dims[dim])
    dims = np.concatenate([dims[:dim], dims[(dim + 1):]])

np.array(corrected_pruned_dims)


In [None]:
# Evaluate the test fold with W's rows reordered into the order in which they were pruned.
brain_decodability_evaluation(Y_test, Y_test_pred, Y_full, stimulus_ids_test, W[np.array(corrected_pruned_dims)], out_path / 'prune_order')

In [None]:
top_knn_accuracies_test[:, 0]

In [None]:
# Was `top_knn_accuracies_test[, 0]`, a SyntaxError that aborted Run All; restored to the
# top-1 column slice shown in the previous cell.
top_knn_accuracies_test[:, 0]

In [None]:
# Test-fold top-k accuracy when each single W dimension is used as a 1-D embedding, followed by
# the accuracy when all W dimensions are used together.
N = 1000
top_k_values = [1, 5, 10, 50, 100, 500]
chance_accuracy = [k / N for k in top_k_values]  # k / N for a gallery of N candidates

metric = 'cosine'
    
Y_W, Y_pred_W = F.relu(Y_test @ W.T), F.relu(Y_test_pred @ W.T)

unique_stimulus_ids, unique_index, unique_inverse = np.unique(
    stimulus_ids_test, return_index=True, return_inverse=True)
D = W.shape[0]
top_knn_accuracies = []
for i in range(D):
    # Slice i:i+1 keeps a 2-D (n, 1) array for top_knn_test.
    y_W, y_pred_W = Y_W[:, i:i+1], Y_pred_W[:, i:i+1]
    top_knn_accuracy = top_knn_test(
        y_W[unique_index].cpu().numpy(), y_pred_W.cpu().numpy(), unique_inverse, k=top_k_values, metric=metric)
    top_knn_accuracies.append(top_knn_accuracy)
top_knn_accuracies = np.array(top_knn_accuracies)

top_knn_accuracy = top_knn_test(
    Y_W[unique_index].cpu().numpy(), Y_pred_W.cpu().numpy(), unique_inverse, k=top_k_values, metric=metric)

In [None]:
# Top-k accuracies with all W dimensions.
top_knn_accuracy

In [None]:
# First accuracy entry (presumably top-1) for each single W dimension.
top_knn_accuracies[:, 0]

In [None]:


# Histogram of the positive part of every stimulus' projection onto each row of W (the decoding
# cells apply F.relu to these projections, so only positive values matter).
Y_full_transformed = Y_full @ W.cpu().numpy().T

# Size the grid from W rather than the global `num_components` left over from an earlier cell:
# a W with fewer rows than that value made `Y_full_transformed[:, i]` raise IndexError.
num_W_components = W.shape[0]
size = 2
num_cols = 8
num_rows = int(np.ceil(num_W_components / num_cols))
fig, ax = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(num_cols * size, num_rows * size),
                      sharex=False, sharey=True,)
ax = ax.flatten()
fig.tight_layout()
fig.subplots_adjust(bottom=0.1, top=0.95, left=0.07)

for i in range(num_W_components):
    y = Y_full_transformed[:, i]
    ax[i].hist(y[y > 0])
for unused_ax in ax[num_W_components:]:
    unused_ax.axis('off')  # padding axes when the count is not a multiple of num_cols

fig.suptitle('Histogram of 73k Stimulus Set Projected Onto Optimized Components')
fig.supxlabel('Dimension Values')
fig.supylabel('Number of Stimuli')
file_name = 'component_histogram.png'
plt.savefig(out_path / file_name, pad_inches=0)

In [None]:
# Correlation matrix between the columns (W dimensions) of the predicted projections:
# center each column, scale it to unit L2 norm, then take the Gram matrix.
with torch.no_grad():
    Y_pred_W_centered = (Y_pred_W - Y_pred_W.mean(axis=0))
    Y_pred_W_norm = Y_pred_W_centered / torch.norm(Y_pred_W_centered, dim=0)
    
    Y_correlation = Y_pred_W_norm.T @ Y_pred_W_norm

In [None]:
# Run `reconstruct` (defined outside this section) on the first row of W repeated 5 times and
# print the elapsed wall-clock time in seconds.
t = time.time()
recon_image = reconstruct(
    stimulus_embeddings={'embedding': W[0:1].cuda().repeat(5, 1)},
    hook_modules={'': 'embedding'},
    model=vqgan_model,
    vqgan_checkpoint=vqgan_checkpoint,
    perceptor=full_model.visual.eval().requires_grad_(False).to(device),
    device=torch.device('cuda'),
    max_iterations=50,
    embedding_iterations=50,
)
print(time.time() - t)

In [None]:
# Slider to browse the reconstructed images.
@interact(i=(0, len(recon_image)-1))
def show(i):
    plt.imshow(recon_image[i][0])

In [None]:
W.shape

In [None]:
# Top NSD (73k) images per row of W; `top_images_73k` is defined outside this section.
images_all = top_images_73k(W.cpu().numpy(), num_images=10)
file_name = f'top_images_73k.png'
Image.fromarray(images_all).save(out_path / file_name)

In [None]:
# Same via CLIP retrieval; the "retreival" spelling matches the helper's name and the existing output file.
images_all = top_images_clip_retreival(W.cpu().numpy(), num_images=10)
file_name = f'top_images_clip_retreival.png'
Image.fromarray(images_all).save(out_path / file_name)

# Things

In [None]:
# Text descriptions of the THINGS concept dimensions, used below as CLIP text prompts and plot labels.
things_concepts = [
    'made of metal, artificial, hard', 
    'food-related, eating-related, kitchen-related',
    'animal-related, organic',
    'clothing-related, fabric, covering',
    'furniture-related, household-related, artifact',
    'plant-related, green',
    'outdoors-related',
    'transportation, motorized, dynamic',
    'wood-related, brownish',
    'body part-related',
    'colorful',
    'valuable, special occasion-related',
    'electronic, technology',
    'sport-related, recreational activity-related',
    'disc-shaped, round',
    'tool-related',
    # NOTE(review): "course" is probably meant to be "coarse"; the string is CLIP-encoded below,
    # so fixing it changes W_things_clip.
    'many small things, course pattern',
    'paper-related, thin, flat, text-related',
    'fluid-related, drink-related',
    'long, thin',
    'water-related, blue',
    'powdery, fine-scale pattern',
    'red',
    'feminine (stereotypically), decorative',
    'bathroom-related, sanitary',
    'black, noble',
    'weapon, danger-related, violence',
    'musical instrument-related, noise-related',
    'sky-related, flying-related, floating-related',
    'spherical, ellipsoid, rounded, voluminous',
    'repetitive',
    'flat, patterned',
    'white',
    'thin, flat',
    'disgusting, bugs',
    'string-related',
    'arms/legs/skin-related',
    'shiny, transparent',
    'construction-related, physical work-related',
    'fire-related, heat-related',
    'head-related, face-related',
    'beams-related',
    'seating-related, put things on top',
    'container-related, hollow',
    'child-related, toy-related',
    'medicine-related',
    'has grating',
    'handicraft-related',
    'cylindrical, conical'
]

# Variant without the "-related" suffix (used as CLIP prompts) and a short label: the first comma-separated term.
things_concepts2 = [c.replace('-related', '') for c in things_concepts]
things_concepts3 = [c.split(',')[0] for c in things_concepts2]

In [None]:
clip.tokenize(things_concepts2)

In [None]:
# CLIP text embeddings of the concept prompts, L2-normalized per row and moved to the GPU as float32.
with torch.no_grad():
    W_things_clip = full_model.encode_text(clip.tokenize(things_concepts2).to(device))
W_things_clip = W_things_clip.cpu().numpy()
#Y_dims[:, [133, 312, 92]] = 0.
W_things_clip = W_things_clip / np.linalg.norm(W_things_clip, axis=1)[:, None]
W_things_clip = torch.from_numpy(W_things_clip).cuda().float()

In [None]:
W = W_things_clip

# Per-concept validation correlation between target and predicted (ReLU'd) concept activations.
# `r` used to be overwritten on every iteration, so only the last subject's values survived the
# loop. They are now kept per subject in `r_per_subject`; `r` still holds the last subject's
# values for the cells below.
r_per_subject = {}
for subject_id in range(8):
    subject_name = f'subj0{subject_id + 1}'

    Y_val = torch.from_numpy(folds['val']['Y_all'][subject_id]).cuda()
    Y_val_pred = torch.from_numpy(folds['val']['Y_pred_all'][subject_id]).cuda()
    stimulus_ids_val = torch.from_numpy(folds['val']['stimulus_ids_all'][subject_id])

    Y_test = torch.from_numpy(folds['test']['Y_all'][subject_id]).cuda()
    Y_test_pred = torch.from_numpy(folds['test']['Y_pred_all'][subject_id]).cuda()
    stimulus_ids_test = torch.from_numpy(folds['test']['stimulus_ids_all'][subject_id])

    Y_val_W = F.relu(Y_val @ W.T)
    Y_val_pred_W = F.relu(Y_val_pred @ W.T)

    r = pearsonr(Y_val_W, Y_val_pred_W, reduction=None)
    r_per_subject[subject_name] = r

In [None]:
r

In [None]:
r

In [None]:
# One correlation per concept: W's rows are the encodings of things_concepts2, which is built from things_concepts.
for corr, concept in zip(r, things_concepts):
    print(f'{concept} - {corr.cpu().item():.3f}')

In [None]:
# Top NSD images for CLIP encodings of the full concept strings (with "-related"), L2-normalized per row.
with torch.no_grad():
    W_things = full_model.encode_text(clip.tokenize(things_concepts).to(device))
W_things = W_things.cpu().numpy()
#Y_dims[:, [133, 312, 92]] = 0.
W_things = W_things / np.linalg.norm(W_things, axis=1)[:, None]

images_all = top_images_73k(W_things, num_images=10)

out_path = results_path / 'things_dimensions'
out_path.mkdir(exist_ok=True, parents=True)
file_name = f'top_images_73k_v1.png'
Image.fromarray(images_all).save(out_path / file_name)

In [None]:
W_things_torch = np.load(nsd_path / 'derivatives/things/torch_weights.npy').astype(np.float32)

In [None]:
import torch.nn.functional as F

# Ridge- and torch-fitted concept directions, each row L2-normalized.
W_things_ridge = np.load(nsd_path / 'derivatives/things/ridge_weights.npy').astype(np.float32)
W_things_ridge = W_things_ridge / np.linalg.norm(W_things_ridge, axis=1)[:, None]
W_things_torch = np.load(nsd_path / 'derivatives/things/torch_weights.npy').astype(np.float32)
W_things_torch = W_things_torch / np.linalg.norm(W_things_torch, axis=1)[:, None]

image_labels = [c.replace(',', '\n') for c in things_concepts]
fold = 'val'  # overwritten by the loop below

for fold in ('val', 'test'):
    for subject_id in range(8):
        out_path = results_path / 'things'

        Y = torch.from_numpy(folds[fold]['Y_all'][subject_id]).cuda()
        Y_pred = torch.from_numpy(folds[fold]['Y_pred_all'][subject_id]).cuda()
        stimulus_ids = torch.from_numpy(folds[fold]['stimulus_ids_all'][subject_id])
        
        # `top_images_tsne_73k` is not defined in this section.
        top_images_tsne_73k(W_things_torch, out_path / f'torch_weights/subj0{subject_id + 1}/{fold}/tsne', k=250)
        # NOTE(review): this `break` leaves the subject loop after subj01 and makes the three
        # brain_decodability_evaluation calls below unreachable — remove it to run the full evaluation.
        break

        brain_decodability_evaluation(Y, Y_pred, Y_full, stimulus_ids,  W_things_clip, 
                                      out_path / f'clip_weights/subj0{subject_id + 1}/{fold}',
                                      things_concepts3, image_labels)
        brain_decodability_evaluation(Y, Y_pred, Y_full, stimulus_ids,  torch.from_numpy(W_things_ridge).cuda(), 
                                      out_path / f'ridge_weights/subj0{subject_id + 1}/{fold}',
                                      things_concepts3, image_labels)
        brain_decodability_evaluation(Y, Y_pred, Y_full, stimulus_ids,  torch.from_numpy(W_things_torch).cuda(), 
                                      out_path / f'torch_weights/subj0{subject_id + 1}/{fold}',
                                      things_concepts3, image_labels)
        

# Mask Optimization

In [None]:
# Release cached GPU memory and run the garbage collector.
torch.cuda.empty_cache()
gc.collect()

In [None]:


torch.cuda.empty_cache()
gc.collect()

In [None]:
# Move every subject's model back to the CPU.
for model in models_all:
    model.cpu()

In [None]:
# Snapshot each subject's decoder weights. `state_dict()` returns references to the live
# parameter tensors, so later fine-tuning (e.g. mask_optimization with freeze_model=False) would
# silently change the "original" weights; clone them to get a real copy.
original_state_dicts = [
    {k: v.detach().clone() for k, v in model.state_dict().items()}
    for model in models_all
]

In [None]:
def plot_flatmap(subject_id, m, component_id, out_path, mask_value=-1, scatter_params=None):
    """Draw a per-voxel vector of one subject on the fsaverage flatmap and save it to ``out_path``.

    ``m`` is scattered back into the subject's volume with ``nsd.reconstruct_volume`` using
    ``indices_all[subject_id]`` and fill value ``mask_value``, resampled to fsaverage with
    nearest-neighbour interpolation, and plotted with ``nsd.flat_scatter_plot``.

    Args:
        subject_id: 0-based subject index.
        m: per-voxel values (passed unchanged to ``nsd.reconstruct_volume``).
        component_id: unused; kept for call compatibility.
        out_path: full path of the saved figure.
        mask_value: fill value for voxels not covered by ``m``; also passed to the plot.
        scatter_params: extra keyword arguments for ``flat_scatter_plot``; defaults to a gray
            colormap on [0, 1].
    """
    if scatter_params is None:
        scatter_params = {'vmax': 1, 'vmin': 0, 'cmap': 'gray'}
    volume = nsd.reconstruct_volume(
            subject_id, 
            m, 
            indices_all[subject_id],
            mask_value
        ).T.numpy()

    fsavg_data = nsd.to_fs_average_space(subject_id, volume, interp_type='nearest')
    nsd.flat_scatter_plot(fsavg_data['lh'], fsavg_data['rh'], mask_value=mask_value, **scatter_params)
    plt.savefig(out_path, bbox_inches='tight')
    
def compare_pearsonr(r_original, r_masked, labels, out_file_path):
    """Grouped bar chart of per-component Pearson r before ('Original') and after ('Masked').

    Args:
        r_original: 1-D array of correlations, one per component.
        r_masked: correlations for the same components, in the same order.
        labels: one tick label per component.
        out_file_path: where the figure is saved.
    """
    num_components = r_original.shape[0]
    width = 0.45
    x = np.arange(num_components)
    plt.figure(figsize=(0.25 * num_components, 3))
    plt.title("Pearson Correlation Comparison After Mask Optimization")
    plt.xlabel("Things Dimension")
    plt.ylabel("Correlation")
    plt.bar(x, r_original, width=width, label='Original')
    plt.bar(x + width, r_masked, width=width, label='Masked')
    plt.legend()
    # Label the centre of each pair of bars; `tick_label` on the first bar placed labels under
    # the 'Original' bar only, offset from the pair.
    plt.xticks(x + width / 2, labels, rotation=45, ha='right')
    plt.savefig(out_file_path, bbox_inches='tight')


def pearsonr_retained(r_original, r_masked, labels, out_file_path):
    """Heatmap of ``r_masked / r_original``: the fraction of Pearson r kept after masking.

    ``labels`` annotate both axes; the figure is saved to ``out_file_path`` and then shown.
    NOTE(review): the division uses NumPy broadcasting, so a 1-D ``r_original`` divides along the
    last axis (columns) of ``r_masked``, while the y-axis is labelled 'concept' — confirm which
    axis of ``r_masked`` indexes concepts.
    """
    n = r_original.shape[0]
    side = n * 0.25
    positions = np.arange(n)

    plt.figure(figsize=(side, side))
    plt.imshow(r_masked / r_original)
    plt.yticks(ticks=positions, labels=labels)
    plt.xticks(rotation=45, ha='right', ticks=positions, labels=labels)
    plt.colorbar()
    plt.ylabel('concept')
    plt.xlabel('mask')
    plt.title("Fraction of pearson correlation retained after masking")
    plt.savefig(out_file_path, bbox_inches='tight')
    plt.show()

In [None]:
from sklearn.linear_model import RidgeCV, LassoCV


def concept_linear_model(W, run_name, subject_id, labels, short_labels=None):
    """Fit one sparse (LassoCV) brain -> concept regression per row of ``W`` for one subject.

    For each row ``w``, the rectified projection of the subject's training embeddings
    (``Y @ w`` with negatives set to 0) is regressed on ``X`` with ``LassoCV``. Then:
    the coefficients are drawn on a flatmap (one file per component under ``.../flatmaps``),
    the test-fold Pearson r of each Lasso model is compared with the r of the original decoder's
    predictions (``pearsonr_change.png``), and the stacked coefficients are saved as ``mask.npy``
    in ``concept_linear_models/torch_weights/<run_name>/subj0<subject_id + 1>``.

    Args:
        W: NumPy array of concept directions, one per row.
        run_name: output sub-directory name.
        subject_id: 0-based index into ``X_all``, ``Y_all`` and ``folds['test']``.
        labels: per-component text written on the flatmaps.
        short_labels: per-component tick labels for the comparison plot (defaults to ``labels``).
    """
    if short_labels is None:
        short_labels = labels
    run_path = results_path / f'concept_linear_models/torch_weights/{run_name}/subj0{subject_id + 1}'
    run_path.mkdir(exist_ok=True, parents=True)

    flatmap_path = results_path / f'concept_linear_models/torch_weights/{run_name}/flatmaps'
    flatmap_path.mkdir(exist_ok=True, parents=True)

    X = X_all[subject_id]
    Y = Y_all[subject_id]

    X_test = folds['test']['X_all'][subject_id]
    Y_test = folds['test']['Y_all'][subject_id]
    Y_pred_test = folds['test']['Y_pred_all'][subject_id]

    # Baseline: per-component test correlation of the original decoder's (ReLU'd) predictions.
    Y_test_W = torch.from_numpy(Y_test @ W.T)
    r_original = pearsonr(F.relu(Y_test_W).cpu(), F.relu(torch.from_numpy(Y_pred_test @ W.T)).cpu(), reduction=None).numpy()

    M = []
    r_all = []
    for component_id, w in enumerate(W):
        # Rectified targets, matching the F.relu used for the baseline.
        Y_w = Y @ w
        Y_test_w = Y_test @ w
        Y_w[Y_w < 0] = 0.
        Y_test_w[Y_test_w < 0] = 0.

        model = LassoCV()
        model.fit(X, Y_w)

        Y_test_w_pred = model.predict(X_test)
        M.append(model.coef_)

        r = pearsonr(torch.from_numpy(Y_test_w), torch.from_numpy(Y_test_w_pred))
        r_all.append(r)

        file_name = f'subj0{subject_id + 1}__component-{component_id}'
        scatter_params = dict(
            vmax=0.003, 
            vmin=-0.003, 
            cmap='RdBu', 
            mask_color='gray', 
            bottomleft_text=f'{component_id=}\nconcept={labels[component_id]}'
        )
        plot_flatmap(subject_id, 
                     torch.from_numpy(model.coef_), 
                     component_id, 
                     flatmap_path / file_name, scatter_params=scatter_params)
        # The figure is on disk; close it so one figure per component x subject does not pile up.
        plt.close('all')

    r_all = np.stack(r_all)
    compare_pearsonr(r_original, r_all, short_labels, run_path / 'pearsonr_change.png')
    plt.close('all')
    M = np.stack(M)
    np.save(run_path / 'mask.npy', M)
    
        
# Fit the Lasso concept models for all 8 subjects on the first 5 torch-fitted concept directions.
# NOTE(review): unlike the earlier things cell, these rows are not L2-normalized here.
W_things = np.load(nsd_path / 'derivatives/things/torch_weights.npy').astype(np.float32)
for subject_id in range(8):
    num_components = 5
    concept_linear_model(
        W=W_things[:num_components],
        run_name='run-3',
        subject_id=subject_id,
        labels=things_concepts[:num_components],
        short_labels=things_concepts3[:num_components]
    )

In [None]:
import torch.nn.functional as F


def mask_optimization(W, run_name, subject_id, lr, num_iterations, mask_sparse_alpha, mask_loss_weight, freeze_model=True, mask_init=None,
                      best_after=350, log_every=5):
    """Optimise one voxel mask per concept direction so a subject's decoder still predicts that concept.

    For each row ``w`` of ``W``, a mask ``m`` with one entry per column of ``X`` (clamped to
    [0, 1] after each step) is trained with Adam to minimise
    ``-pearsonr(relu(Y @ w), relu(model(X * m) @ w)) + mask_loss_weight * mean(1 - (1 - m) ** mask_sparse_alpha)``.
    The decoder is restored from ``state_dicts_all[subject_id]`` before every component and is
    only updated when ``freeze_model`` is False. The mask with the best training correlation seen
    at an iteration index greater than ``best_after`` is kept, saved as
    ``mask__component-<id>.npy`` (plus the decoder weights when not frozen) and drawn on a
    flatmap. All masks are finally stacked into ``mask.npy``; the call arguments go to ``params.json``.

    Args:
        W: torch tensor of concept directions, one per row (rows are moved to CUDA).
        run_name: output sub-directory under ``mask_optimization/torch_weights``.
        subject_id: 0-based index into ``models_all``, ``state_dicts_all``, ``X_all``, ``Y_all``, ``folds``.
        lr: Adam learning rate of the mask.
        num_iterations: optimisation steps per component.
        mask_sparse_alpha: exponent of the sparsity penalty.
        mask_loss_weight: weight of the sparsity penalty.
        freeze_model: when False the decoder is fine-tuned jointly with learning rate 1e-4.
        mask_init: initial mask value; defaults to ``1 - 0.5 ** (1 / mask_sparse_alpha)``, the value
            at which the penalty term equals 0.5.
        best_after: iterations with index <= this are not eligible as the best mask.
        log_every: evaluate on the test fold and print progress every ``log_every`` iterations.
    """
    if mask_init is None:
        mask_init = 1 - 0.5 ** (1. / mask_sparse_alpha)
    # Must run before any other local is created so only the call arguments are recorded.
    params = locals().copy()
    del params['W']
    print(params)

    run_path = results_path / f'mask_optimization/torch_weights/{run_name}/subj0{subject_id + 1}'
    run_path.mkdir(exist_ok=True, parents=True)
    flatmap_path = results_path / f'mask_optimization/torch_weights/{run_name}/flatmaps'
    flatmap_path.mkdir(exist_ok=True, parents=True)
    with open(run_path / 'params.json', 'w') as f:
        f.write(json.dumps(params))

    model = models_all[subject_id]
    model.normalize = True
    model.eval()

    X = torch.from_numpy(X_all[subject_id]).cuda()
    Y = torch.from_numpy(Y_all[subject_id]).cuda()

    X_test = torch.from_numpy(folds['test']['X_all'][subject_id]).cuda()
    Y_test = torch.from_numpy(folds['test']['Y_all'][subject_id]).cuda()

    M = []
    logs_all = []
    for component_id, w in enumerate(W):
        logs = []
        print(f'{component_id=}')

        # Restart every component from the subject's stored decoder weights.
        model.load_state_dict({
            k: v.clone() 
            for k, v in state_dicts_all[subject_id].items()
        })
        model.cuda()

        w = w.cuda()
        Y_w = Y @ w
        Y_test_w = Y_test @ w

        # Explicit float dtype: an int mask_init would otherwise create an int tensor that cannot require grad.
        m = torch.full((X.shape[1],), mask_init, dtype=torch.float32, device='cuda', requires_grad=True)

        if freeze_model:
            optim = torch.optim.Adam([m], lr=lr)
        else:
            optim = torch.optim.Adam([
                {'params': [m], 'lr': lr}, 
                {'params': model.parameters(), 'lr': 1e-4},
            ])

        m_best = None
        r_best = float('-inf')
        i_best = -1
        for i in range(num_iterations):
            # Without this, gradients accumulated over all previous iterations.
            optim.zero_grad()

            mask_exp = (1 - (1 - m) ** mask_sparse_alpha)
            Y_pred_w = model(X * m) @ w

            r_train = pearsonr(F.relu(Y_w), F.relu(Y_pred_w))
            mask_loss = mask_exp.mean()
            loss = -r_train + mask_loss * mask_loss_weight

            r_value = r_train.item()
            mask_mean = m.mean().item()
            if i > best_after and r_value > r_best:
                # Snapshot the mask that produced this correlation (i.e. before this step's update).
                # Storing `m` itself aliased the tensor that keeps being optimised, so the "best"
                # mask used to be whatever the final mask was.
                i_best, r_best, m_best = i, r_value, m.detach().clone()

            loss.backward()
            optim.step()

            with torch.no_grad():
                m.clamp_(0., 1.)

            if i % log_every == 0:
                with torch.no_grad():
                    m_binary = (m >= 0.5).to(m.dtype)
                    Y_pred_test_w = model(X_test * m) @ w
                    Y_pred_test_w_masked = model(X_test * m_binary) @ w
                    test_r = pearsonr(F.relu(Y_test_w), F.relu(Y_pred_test_w))
                    test_r_masked = pearsonr(F.relu(Y_test_w), F.relu(Y_pred_test_w_masked))
                iter_logs = {
                    'iteration': i, 
                    'val_r': r_value,  # correlation on the optimisation (training) data
                    'test_r': test_r.item(), 
                    'test_r_masked': test_r_masked.item(), 
                    'mask_mean': mask_mean,
                }
                print(', '.join([f'{k}={v:.3f}' for k, v in iter_logs.items()]))

                m_cpu = m.detach().cpu().clone()
                bins = (0.000, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 1.0)
                m_hist, _ = np.histogram(m_cpu.numpy(), bins=bins)
                print('mask hist', [f'{b} {v}' for b, v in zip(bins, m_hist)])
                iter_logs['mask'] = m_cpu
                logs.append(iter_logs)

        if m_best is None:
            # No iteration past `best_after` (num_iterations too small): fall back to the final mask
            # instead of crashing on None.
            i_best, m_best = num_iterations - 1, m.detach().clone()

        logs_all.append(logs)
        print(f'{i_best=}, {r_best=}')
        m = m_best
        M.append(m)
        if not freeze_model:
            torch.save(model.state_dict(), run_path / f'mask__component-{component_id}.pkl')
        np.save(run_path / f'mask__component-{component_id}.npy', m.cpu().numpy())
        model.cpu()

        volume = nsd.reconstruct_volume(
            subject_id, 
            m.cpu(), 
            indices_all[subject_id],
            -1
        ).T.numpy()

        things_concept = things_concepts[component_id]
        text = f'{component_id=}\n{things_concept=}'

        fsavg_data = nsd.to_fs_average_space(subject_id, volume, interp_type='nearest')
        nsd.flat_scatter_plot(fsavg_data['lh'], fsavg_data['rh'], vmax=1, vmin=0, mask_value=-1, cmap='gray', bottomleft_text=text)
        file_name = f'subj0{subject_id + 1}__component-{component_id}'
        plt.savefig(flatmap_path / file_name, bbox_inches='tight')
        # One figure per component x subject would otherwise stay open.
        plt.close()

    M = torch.stack(M).cpu()
    np.save(run_path / 'mask.npy', M.numpy())

# Run mask optimisation for all 8 subjects on the first 20 torch-fitted concept directions,
# restoring each subject's stored weights first and moving the model back to the CPU afterwards.
W_things = torch.from_numpy(np.load(nsd_path / 'derivatives/things/torch_weights.npy').astype(np.float32)).cuda()

for subject_id in range(8):
    torch.cuda.empty_cache()
    gc.collect()
    models_all[subject_id].load_state_dict(state_dicts_all[subject_id])
    mask_optimization(
        W=W_things[:20],
        run_name='run-44',
        subject_id=subject_id,
        lr=0.002, 
        num_iterations=1000,
        mask_sparse_alpha=1.,
        mask_loss_weight=5,
        freeze_model=True,
        mask_init=.5,
    )
    models_all[subject_id].cpu()

In [None]:
# Scratch check of np.histogram with the custom bin edges.
x = np.linspace(0, 1, 100)
np.histogram(x, bins=(0.001, 0.01, 0.1, 0.2, 0.5, 0.9, 0.99))

In [None]:
# NOTE(review): `M` and `run_path` are locals of mask_optimization; at top level these names refer
# to whatever globals exist (possibly stale), so this cell is a leftover from a debugging session.
M = torch.stack(M).cpu()
np.save(run_path / 'mask.npy', M.numpy())

In [None]:
from research.models.components_3d import GaussianSmoothing

In [None]:
# Separable 3-D smoothing kernel: outer product of [1, 5, 10, 10, 5, 1] along the three axes,
# rescaled so its entries sum to 2 (the later smoothing cells redefine it with sum 5).
kernel = torch.tensor([1, 5, 10, 10, 5, 1])
kernel = kernel[None, None, :] * kernel[None, :, None] * kernel[:, None, None]
kernel = kernel / kernel.sum() * 2

In [None]:
from scipy.ndimage import binary_dilation

def plot_flatmap(subject_id, m, ):
    """Draw a subject's NumPy mask ``m`` on the fsaverage flatmap (gray, [0, 1]) and save it.

    NOTE(review): this redefines, and shadows, the six-argument ``plot_flatmap`` defined earlier in
    the notebook; after this cell runs, callers that pass component_id/out_path/scatter_params
    (e.g. concept_linear_model) fail with TypeError. It also reads the globals ``component_id`` and
    ``out_path`` instead of taking them as arguments.
    """
    volume = nsd.reconstruct_volume(
            subject_id, 
            torch.from_numpy(m), 
            indices_all[subject_id],
            -1
        ).T.numpy()
    D = volume.shape[2]

    things_concept = things_concepts[component_id]
    text = f'{component_id=}\n{things_concept=}'

    fsavg_data = nsd.to_fs_average_space(subject_id, volume, interp_type='nearest')
    nsd.flat_scatter_plot(fsavg_data['lh'], fsavg_data['rh'], vmax=1, vmin=0, mask_value=-1, cmap='gray', bottomleft_text=text)
    file_name = f'subj0{subject_id + 1}__component-{component_id}'
    plt.savefig(out_path / file_name, bbox_inches='tight')


# For the first two subjects: smooth and binarize each saved run-44 mask in volume space, then
# save one flatmap per component under `smooth_flatmaps`.
for subject_id in range(2):
    run_name = 'run-44'

    run_path = results_path / f'mask_optimization/torch_weights/{run_name}/subj0{subject_id + 1}'
    run_path.mkdir(exist_ok=True, parents=True)

    M = np.load(run_path / 'mask.npy')

    # NOTE(review): W, model, X, Y, Y_pred, X_test, Y_test and Y_pred_test are loaded (and the model
    # is moved to the GPU) but are not used by the smoothing loop below; they only persist as globals.
    W = np.load(nsd_path / 'derivatives/things/torch_weights.npy').astype(np.float32)
    W = torch.from_numpy(W).cuda()

    model = models_all[subject_id].cuda()
    model.normalize = True
    model.eval()

    X = torch.from_numpy(X_all[subject_id]).cuda()
    Y = torch.from_numpy(Y_all[subject_id]).cuda()
    Y_pred = torch.from_numpy(Y_pred_all[subject_id])

    X_test = torch.from_numpy(folds['test']['X_all'][subject_id]).cuda()
    Y_test = torch.from_numpy(folds['test']['Y_all'][subject_id]).cuda()
    Y_pred_test = torch.from_numpy(folds['test']['Y_pred_all'][subject_id])

    out_path = results_path / f'mask_optimization/torch_weights/{run_name}/smooth_flatmaps'
    out_path.mkdir(exist_ok=True, parents=True)
    
    # Separable [1, 5, 10, 10, 5, 1]^3 kernel rescaled to sum to 5, so after convolution a value
    # >= 0.5 means the kernel-weighted mean of the mask around that voxel is >= 0.1.
    kernel = torch.tensor([1, 5, 10, 10, 5, 1])
    kernel = kernel[None, None, :] * kernel[None, :, None] * kernel[:, None, None]
    kernel = kernel / kernel.sum() * 5

    for component_id, m in enumerate(M):
        # Scatter the mask back into the volume; -1 is the fill value passed to reconstruct_volume.
        volume = nsd.reconstruct_volume(
            subject_id, 
            torch.from_numpy(m), 
            indices_all[subject_id],
            -1
        ).T.numpy()
        volume_smooth = volume.copy()
        volume_smooth[volume_smooth == -1] = 0  # fill voxels count as 0 while smoothing
        D = volume.shape[2]
        
        volume_smooth = F.conv3d(torch.from_numpy(volume_smooth[None, None]), kernel[None, None], padding='same')[0, 0].numpy()
        
        volume_smooth[volume_smooth >= 0.5] = 1
        volume_smooth[volume_smooth < 0.5] = 0
        #volume_smooth = binary_dilation(volume_smooth, iterations=1).astype(float)
        volume_smooth[volume == -1] = -1  # restore the fill value

        things_concept = things_concepts[component_id]
        text = f'{component_id=}\n{things_concept=}'

        fsavg_data = nsd.to_fs_average_space(subject_id, volume_smooth, interp_type='nearest')
        nsd.flat_scatter_plot(fsavg_data['lh'], fsavg_data['rh'], vmax=1, vmin=0, mask_value=-1, cmap='gray', bottomleft_text=text)
        file_name = f'subj0{subject_id + 1}__component-{component_id}'
        # NOTE(review): figures are not closed, so one stays open per component x subject.
        plt.savefig(out_path / file_name, bbox_inches='tight')

In [None]:
# Same smoothing/binarization as the previous cell, but instead of plotting, collect each
# component's fsaverage data (left and right hemispheres concatenated) per subject.
M_fsavg_all = []
for subject_id in range(2):
    run_name = 'run-44'

    run_path = results_path / f'mask_optimization/torch_weights/{run_name}/subj0{subject_id + 1}'
    run_path.mkdir(exist_ok=True, parents=True)

    M = np.load(run_path / 'mask.npy')
    
    # Kernel entries sum to 5: a smoothed value >= 0.5 means a kernel-weighted mean >= 0.1.
    kernel = torch.tensor([1, 5, 10, 10, 5, 1])
    kernel = kernel[None, None, :] * kernel[None, :, None] * kernel[:, None, None]
    kernel = kernel / kernel.sum() * 5
    
    M_fsavg = []
    for component_id, m in enumerate(M):
        volume = nsd.reconstruct_volume(
            subject_id, 
            torch.from_numpy(m), 
            indices_all[subject_id],
            -1
        ).T.numpy()
        D = volume.shape[2]
        
        volume_smooth = volume.copy()
        volume_smooth[volume_smooth == -1] = 0  # fill voxels count as 0 while smoothing
        D = volume.shape[2]
        
        volume_smooth = F.conv3d(torch.from_numpy(volume_smooth[None, None]), kernel[None, None], padding='same')[0, 0].numpy()
        
        volume_smooth[volume_smooth >= 0.5] = 1
        volume_smooth[volume_smooth < 0.5] = 0
        #volume_smooth = binary_dilation(volume_smooth, iterations=1).astype(float)
        volume_smooth[volume == -1] = -1  # restore the fill value

        # `things_concept` / `text` are computed but not used in this cell.
        things_concept = things_concepts[component_id]
        text = f'{component_id=}\n{things_concept=}'

        fsavg_data = nsd.to_fs_average_space(subject_id, volume_smooth, interp_type='nearest')
        M_fsavg.append(np.concatenate([fsavg_data['lh'], fsavg_data['rh']]))
    M_fsavg_all.append(np.stack(M_fsavg))
    
# Stacked over subjects: (subjects, components, ...) with the per-component fsaverage data last.
M_fsavg_all = np.stack(M_fsavg_all)

In [None]:
def jaccard(A, B):
    """Jaccard index |A ∩ B| / |A ∪ B| of binary masks along the last axis.

    Leading dimensions broadcast (via ``torch.einsum``'s ellipsis), so pairwise indices can be
    computed without materialising the full elementwise product. Two empty masks give 0/0 (NaN).
    """
    overlap = torch.einsum('...i,...i -> ...', A, B)
    size_a = A.sum(axis=-1)
    size_b = B.sum(axis=-1)
    return overlap / (size_a + size_b - overlap)
    

In [None]:
M_fsavg_all[None, :, None, :].shape

In [None]:
# Binarize the fsaverage masks in place at 0.5 (the -1 fill value becomes 0 too), then compute the
# Jaccard index for every (subject_A, subject_B, component_i, component_j) pair by broadcasting.
M_fsavg_all[M_fsavg_all < 0.5] = 0.
M_fsavg_all[M_fsavg_all >= 0.5] = 1.
M_fsavg_jaccard = jaccard(torch.from_numpy(M_fsavg_all[:, None, :, None]), torch.from_numpy(M_fsavg_all[None, :, None, :]))

In [None]:
num_components, M_fsavg_jaccard.shape

In [None]:
M_fsavg_jaccard

# self similarity

def jaccard_rsm_plot(subject_A, subject_B, out_path, vmax=1.):
    """Save a heatmap of the concept-by-concept Jaccard index between two subjects' masks.

    Reads the global ``M_fsavg_jaccard`` (indexed ``[subject_A, subject_B]`` to obtain a
    concept x concept matrix) and labels both axes with the global ``things_concepts3``.

    Args:
        subject_A: 0-based index of the first subject (shown as ``subj0{index+1}``).
        subject_B: 0-based index of the second subject.
        out_path: directory (``pathlib.Path``) to write the figure into; created if missing.
        vmax: upper colour limit; ``None`` lets matplotlib derive it from the data.
    """
    jaccard_matrix = M_fsavg_jaccard[subject_A, subject_B]
    # Take the concept count from the matrix itself: the global `num_components` is
    # reassigned by other cells, which would silently mislabel the axes.
    n_concepts = jaccard_matrix.shape[-1]
    tick_labels = things_concepts3[:n_concepts]
    ticks = np.arange(n_concepts)

    fig = plt.figure(figsize=(6, 6))
    plt.imshow(jaccard_matrix, vmin=0, vmax=vmax)

    plt.title(f"Jaccard Index of Concept Brain Masks\nsubj0{subject_A+1} versus subj0{subject_B+1}")
    plt.colorbar()
    plt.xticks(rotation=45, ha='right', labels=tick_labels, ticks=ticks)
    plt.yticks(labels=tick_labels, ticks=ticks)

    out_path.mkdir(exist_ok=True, parents=True)
    file_name = f'subj0{subject_A+1}__vs__subj0{subject_B+1}'
    fig.savefig(out_path / file_name, bbox_inches='tight')

    # Close this specific figure (called in loops) to avoid accumulating open figures.
    plt.close(fig)

# Plot within-subject and between-subject Jaccard matrices and collect matched-concept
# (diagonal) vs. mismatched-concept (off-diagonal) overlaps over all subject pairs.
num_components = M_fsavg_jaccard.shape[2]
num_subjects = M_fsavg_jaccard.shape[0]
rsm_path = results_path / f'mask_optimization/torch_weights/{run_name}/smoothed_rsms/jaccard'
diag = torch.eye(num_components).bool()

for subject_A in range(num_subjects):
    jaccard_rsm_plot(subject_A, subject_A, out_path=rsm_path / 'same', vmax=None)

diff_parts = []
same_parts = []
for subject_A in range(num_subjects):
    for subject_B in range(subject_A + 1, num_subjects):
        jaccard_matrix = M_fsavg_jaccard[subject_A, subject_B]
        diff_parts.append(jaccard_matrix[~diag].numpy())
        same_parts.append(jaccard_matrix[diag].numpy())
        jaccard_rsm_plot(subject_A, subject_B, out_path=rsm_path / 'different', vmax=None)

# Concatenate once; empty arrays (not lists) if there are fewer than two subjects.
diff = np.concatenate(diff_parts) if diff_parts else np.empty(0)
same = np.concatenate(same_parts) if same_parts else np.empty(0)

diff.mean(), same.mean(), diff.std(), same.std()

In [None]:
from statsmodels.stats.weightstats import ztest

# Two-sample z-test: do matched-concept overlaps (`same`) differ from mismatched ones (`diff`)?
# Returns (test statistic, p-value).
ztest(diff, same)

In [None]:
# For each subject, measure how much test-set decodability of every THINGS dimension is
# retained when the brain input is multiplied by each learned mask (soft and binarized).
run_name = 'run-44'
run_path = results_path / f'mask_optimization/torch_weights/{run_name}'
run_path.mkdir(exist_ok=True, parents=True)
for dir_name in ('masked_r', 'masked_r_binary', 'pearsonr_retained', 'pearsonr_retained_binary'):
    (run_path / dir_name).mkdir(exist_ok=True)

for subject_id in range(2):
    subject_name = f'subj0{subject_id+1}'

    M = torch.from_numpy(np.load(run_path / subject_name / 'mask.npy')).cuda()
    num_components = M.shape[0]

    W = np.load(nsd_path / 'derivatives/things/torch_weights.npy').astype(np.float32)[:num_components]
    W = torch.from_numpy(W).cuda()

    model = models_all[subject_id]
    model.normalize = True
    model.eval()
    model.load_state_dict(state_dicts_all[subject_id])
    model.cuda()

    X_test = torch.from_numpy(folds['test']['X_all'][subject_id]).cuda()
    Y_test = torch.from_numpy(folds['test']['Y_all'][subject_id]).cuda()
    Y_pred_test = torch.from_numpy(folds['test']['Y_pred_all'][subject_id]).cuda()

    # The ReLU'd target projections do not depend on the mask, so compute them once.
    Y_test_W = Y_test @ W.T
    Y_test_W_relu = F.relu(Y_test_W)
    r_original = pearsonr(Y_test_W_relu, F.relu(Y_pred_test @ W.T), reduction=None)
    r_original = r_original[:num_components].cpu().numpy()

    # Row i holds the per-dimension r obtained when the input is masked with mask i.
    r_soft_masked = []
    r_hard_masked = []
    with torch.no_grad():
        for m in M:
            m_hard = (m >= 0.5).to(m.dtype)
            Y_test_pred_W_soft = model(X_test * m) @ W.T
            Y_test_pred_W_hard = model(X_test * m_hard) @ W.T
            r_soft_masked.append(pearsonr(Y_test_W_relu, F.relu(Y_test_pred_W_soft), reduction=None))
            r_hard_masked.append(pearsonr(Y_test_W_relu, F.relu(Y_test_pred_W_hard), reduction=None))
    r_soft_masked = torch.stack(r_soft_masked).cpu().numpy()
    r_hard_masked = torch.stack(r_hard_masked).cpu().numpy()
    model.cpu()

    concept_labels = things_concepts3[:num_components]
    compare_pearsonr(r_original, np.diagonal(r_soft_masked), concept_labels,
                     run_path / f'masked_r/{subject_name}.png')
    compare_pearsonr(r_original, np.diagonal(r_hard_masked), concept_labels,
                     run_path / f'masked_r_binary/{subject_name}.png')

    pearsonr_retained(r_original, r_soft_masked, concept_labels, run_path / f'pearsonr_retained/{subject_name}.png')
    pearsonr_retained(r_original, r_hard_masked, concept_labels, run_path / f'pearsonr_retained_binary/{subject_name}.png')

In [None]:
# Mean mask value per THINGS dimension. NOTE: `M` and `run_path` are whatever the
# evaluation loop above left behind (i.e. its last subject).
M_tensor = torch.as_tensor(M)
n_dims = M_tensor.shape[0]
# Move to host memory first: matplotlib cannot plot CUDA tensors.
M_means = M_tensor.mean(dim=1).cpu().numpy()
tick_labels = [c.split(',')[0].replace('-related', '') for c in things_concepts[:n_dims]]

plt.figure(figsize=(0.25 * n_dims, 3))
plt.title("Mean of All Mask Values")
plt.xlabel("Things Dimension")
plt.ylabel("Mean")
plt.bar(np.arange(n_dims), M_means, tick_label=tick_labels)
plt.xticks(rotation=45, ha='right')
plt.savefig(run_path / 'mask_mean.png', bbox_inches='tight')
plt.show()

In [None]:
# Cosine-similarity RSM between the learned masks (rows of `M`, left over from the loop above).
M_tensor = torch.as_tensor(M)
n_dims = M_tensor.shape[0]
M_norm = M_tensor / torch.linalg.norm(M_tensor, dim=1, keepdim=True)
# Move to host memory: plt.imshow cannot render CUDA tensors.
M_rsm = (M_norm @ M_norm.T).cpu().numpy()

plt.figure(figsize=(n_dims * 0.2,) * 2)
plt.title("Cosine Similarity of Concept Masks")
plt.imshow(M_rsm, cmap='gray', vmin=0, vmax=1)
plt.xticks(rotation=45, ha='right', labels=things_concepts2[:n_dims], ticks=np.arange(n_dims))
plt.yticks(labels=things_concepts2[:n_dims], ticks=np.arange(n_dims))
plt.colorbar()
file_name = 'cosine_similarity_rsm.png'
plt.savefig(run_path / file_name, bbox_inches='tight')
plt.show()

# Misc

In [None]:
# Load CLIP image and caption embeddings for every NSD stimulus and pick, per stimulus,
# the caption whose embedding best matches the image embedding.
model_name = 'ViT-B=32'
stimulus_key = 'embedding'

save_key = stimulus_key
save_model_name = model_name

stimulus_file = h5py.File(nsd_path / f'derivatives/stimulus_embeddings/{model_name}.hdf5', 'r')
x = stimulus_file[stimulus_key][:]

stimulus_file_text = h5py.File(nsd_path / f'derivatives/stimulus_embeddings/{model_name}-text.hdf5', 'r')
x_text = stimulus_file_text[stimulus_key][:]
x_text = x_text / np.linalg.norm(x_text, axis=-1, keepdims=True)

# Sizes come from the data instead of hard-coded 73000 / 5.
num_stimuli = x.shape[0]
num_captions = x_text.shape[1]

# Stimulus id of every (stimulus, caption) pair.
ids = np.repeat(np.arange(num_stimuli)[:, None], num_captions, axis=1)
print(ids.shape)

# Image-caption dot products. x is not normalized, but per-row argmax is unaffected.
text_dists = np.einsum('ni,nti->nt', x, x_text)
print(text_dists)

neighbors = NearestNeighbors(metric='cosine')
neighbors.fit(x)

all_captions = np.array([nsd.load_coco(i)[:num_captions] for i in range(num_stimuli)])
best_captions = all_captions[np.arange(num_stimuli), np.argmax(text_dists, axis=1)]

#top_knn_test(x, x_text.reshape(-1, 512), ids.flatten(), k=[1, 5, 10], metric='cosine')

In [None]:
# Histogram of pairwise cosine similarities between the first caption embedding of 1000
# stimuli (rows of x_text were unit-normalized above, so dot product == cosine).
x_sample = x_text[:1000, 0]
# A matmul avoids materializing the (1000, 1000, dim) broadcast product.
x_sample_rsm = x_sample @ x_sample.T
plt.hist(x_sample_rsm.flatten());

In [None]:
# Print the first 100 best-matching captions as a 1-based numbered list.
for rank, caption_text in enumerate(best_captions[:100], start=1):
    print(f"{rank}. {caption_text}")

In [None]:
import os
import openai

# SECURITY: never hardcode API keys in a notebook. The key previously committed in this cell
# must be treated as leaked and revoked. Read it from the environment instead.
openai.api_key = os.environ['OPENAI_API_KEY']

prompt_header = 'Decompose the following sentence into multiple independent descriptive sentences. Write one sentence per line.'

ids = np.arange(len(best_captions))
# Sample without replacement so no caption is sent (and paid for) twice.
random_ids = np.random.choice(ids, size=1000, replace=False)
random_captions = best_captions[random_ids]

raw_responses = []
responses = []
for sentence in tqdm(random_captions, desc='GPT decomposition'):
    sentence = sentence.strip().capitalize()
    if not sentence.endswith('.'):
        sentence += '.'
    response = openai.Completion.create(
      model="text-davinci-003",
      prompt=f"{prompt_header}\n\nInput:\n{sentence}\n\nOutput:",
      temperature=0.7,
      max_tokens=256,
      top_p=1,
      frequency_penalty=0,
      presence_penalty=0
    )
    raw_response = response.choices[0].text
    raw_responses.append(raw_response)

    # One decomposed sentence per non-empty line. The original computed this list but
    # appended the unfiltered split, so empty strings leaked into `responses`.
    sentences = [s.strip() for s in raw_response.strip().split('\n') if s.strip()]
    responses.append(sentences)

In [None]:
# Format each caption followed by its tab-indented decomposition. Blocks are separated by a
# blank line so the text can be parsed back by splitting on '\n\n'.
blocks = []
for caption, response in zip(random_captions, raw_responses):
    response_lines = [f'\t{r}' for r in response.strip().split('\n')]
    # Explicit newline after the caption. The original glued the first response line onto
    # the caption unless the completion happened to start with '\n'.
    blocks.append('\n'.join([caption, *response_lines]))
out = '\n\n'.join(blocks)
print(out)

In [None]:
# Parse the saved decompositions: blocks are separated by blank lines; the first line of a
# block is the caption, the remaining (tab-indented) lines are its decomposed sentences.
with open(nsd_path / 'derivatives/responses.txt', 'r') as f:
    # Skip empty blocks (e.g. from a trailing '\n\n'), which would otherwise add a '' caption.
    blocks = [block for block in f.read().split('\n\n') if block.strip()]

lines = [[l.strip() for l in block.strip('\n').split('\n')] for block in blocks]
random_captions = [l[0] for l in lines]
# Drop blank sentence lines so they are not embedded as empty concepts later.
responses = [[s for s in l[1:] if s] for l in lines]

In [None]:
# Flatten the per-caption decompositions into one list of concept sentences and embed each
# with the CLIP text encoder; rows of W_captions are unit-normalized.
# A comprehension is linear time; `sum(lists, start=[])` re-copies the list on every addition.
caption_concepts = [concept for response in responses for concept in response]
tokens = clip.tokenize(caption_concepts).to(device)

with torch.no_grad():
    W_captions = full_model.encode_text(tokens).float()

W_captions = W_captions / torch.linalg.norm(W_captions, dim=1, keepdim=True)

In [None]:
# Sanity check: every caption embedding row should have unit norm after normalization.
torch.linalg.norm(W_captions, dim=1)

In [None]:
# Validation-fold decodability (Pearson r of ReLU'd caption-concept projections) for the
# first subject only. This is equivalent to a loop over subjects that breaks after one pass.
subject_id = 0
subject_name = f'subj0{subject_id + 1}'

Y_val, Y_val_pred = (torch.from_numpy(folds['val'][key][subject_id]).cuda() for key in ('Y_all', 'Y_pred_all'))
stimulus_ids_val = torch.from_numpy(folds['val']['stimulus_ids_all'])

Y_test, Y_test_pred = (torch.from_numpy(folds['test'][key][subject_id]).cuda() for key in ('Y_all', 'Y_pred_all'))
stimulus_ids_test = torch.from_numpy(folds['test']['stimulus_ids_all'])

Y_val_W, Y_val_pred_W = (F.relu(t @ W_captions.T) for t in (Y_val, Y_val_pred))

r = pearsonr(Y_val_W, Y_val_pred_W, reduction=None)

In [None]:
# Inspect the test-fold stimulus ids.
folds['test']['stimulus_ids_all']

In [None]:
import torch.nn.functional as F
from scipy.cluster import hierarchy
from research.metrics.loss_functions import ContrastiveDistanceLoss
from research.metrics.metrics import squared_euclidean_distance

# Learn a per-caption-concept mask m in [0, 1] so that ReLU projections of measured (Y) and
# predicted (Y_pred) embeddings onto the masked concept directions match contrastively.
run_name = 'run1'
# The original looped over range(8) but broke after the first subject. List more ids
# (or 'all') here to run further subjects.
subject_ids = [0]
mask_threshold = 0.5  # one threshold for both evaluation and the saved concept selection

contrastive_criterion = ContrastiveDistanceLoss(squared_euclidean_distance)
top_k_values = [1, 5, 10, 50, 100, 500]
bins = (0.000, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 1.0)

for subject_id in subject_ids:
    params = {
        'subject_id': subject_id,
        'lr': 0.0001,
        'num_iterations': 10000,
        'mask_init': 0.5,
        'mask_loss_weight': 0.000
    }
    # Explicit unpacking instead of locals().update(params).
    lr = params['lr']
    num_iterations = params['num_iterations']
    mask_init = params['mask_init']
    mask_loss_weight = params['mask_loss_weight']

    if subject_id == 'all':
        subject_name = subject_id
        Y = torch.cat([torch.from_numpy(Y_all[s]).cuda() for s in range(8)])
        Y_pred = torch.cat([torch.from_numpy(Y_pred_all[s]).cuda() for s in range(8)])

        Y_test = torch.cat([torch.from_numpy(folds['test']['Y_all'][s]).cuda() for s in range(8)])
        Y_test_pred = torch.cat([torch.from_numpy(folds['test']['Y_pred_all'][s]).cuda() for s in range(8)])
        stimulus_ids_test = torch.cat([torch.from_numpy(folds['test']['stimulus_ids_all']) for _ in range(8)])
    else:
        subject_name = f'subj0{subject_id + 1}'
        Y = torch.from_numpy(Y_all[subject_id]).cuda()
        Y_pred = torch.from_numpy(Y_pred_all[subject_id]).cuda()

        Y_test = torch.from_numpy(folds['test']['Y_all'][subject_id]).cuda()
        Y_test_pred = torch.from_numpy(folds['test']['Y_pred_all'][subject_id]).cuda()
        stimulus_ids_test = torch.from_numpy(folds['test']['stimulus_ids_all'])

    dataset = TensorDataset(Y, Y_pred)
    dataloader = DataLoader(dataset, batch_size=3000)
    data_iterator = get_data_iterator(dataloader)

    # The test retrieval targets are fixed, so deduplicate the stimulus ids once.
    unique_stimulus_ids, unique_index, unique_inverse = np.unique(
        stimulus_ids_test, return_index=True, return_inverse=True)

    num_components = W_captions.shape[0]
    m = torch.full((num_components,), mask_init).cuda()
    m.requires_grad = True

    optim = torch.optim.Adam([m], lr=lr)

    for i in range(num_iterations):
        Y_batch, Y_pred_batch = next(data_iterator)
        W_m = W_captions * m[:, None]
        Y_W = F.relu(Y_batch @ W_m.T)
        Y_pred_W = F.relu(Y_pred_batch @ W_m.T)

        contrastive_loss = contrastive_criterion(Y_W, Y_pred_W)
        loss = contrastive_loss + m.mean() * mask_loss_weight

        if i % 25 == 0:
            with torch.no_grad():
                W_selected = W_captions[m >= mask_threshold]
                # Use this subject's test predictions (the original read the stale global
                # `Y_pred_test` from an earlier cell).
                Y_test_W = F.relu(Y_test @ W_selected.T)
                Y_test_pred_W = F.relu(Y_test_pred @ W_selected.T)
                top_knn_accuracy = top_knn_test(Y_test_W[unique_index].cpu().numpy(), Y_test_pred_W.cpu().numpy(),
                                                unique_inverse, k=top_k_values, metric='euclidean')
                print(i, top_knn_accuracy, (m >= mask_threshold).float().sum().cpu().item(), m.shape[0])
                m_hist, _ = np.histogram(m.detach().cpu().numpy(), bins=bins)
                print('mask hist', [f'{b} {v}' for b, v in zip(bins, m_hist)])
                print(f'{contrastive_loss.item()=}')

        # Reset gradients every step; otherwise Adam is fed the running sum of all past gradients.
        optim.zero_grad()
        loss.backward()
        optim.step()
        # Project back onto [0, 1] after the update so the next forward pass uses a valid mask.
        with torch.no_grad():
            m.clamp_(0., 1.)

    out_path = results_path / 'masked_text_concepts' / run_name / subject_name
    out_path.mkdir(exist_ok=True, parents=True)

    with open(out_path / 'params.json', 'w') as f:
        json.dump(params, f)

    np.save(out_path / 'W.npy', W_captions.cpu().numpy())
    concept_mask = (m >= mask_threshold).detach()
    # Evaluate the selected caption directions (the original passed the unrelated THINGS `W`)
    # and index the numpy label array with a host-side mask.
    brain_decodability_evaluation(Y_test, Y_test_pred, Y_full, stimulus_ids_test, W_captions[concept_mask], out_path,
                                  labels=np.array(caption_concepts)[concept_mask.cpu().numpy()])

In [None]:
# Shape check of the learned mask with a broadcast axis.
m[:, None].shape

In [None]:
# Pair each caption concept with its decodability r (from whichever cell last assigned `r`).
list(zip(caption_concepts, r.cpu().numpy()))

In [None]:
# Show the dataset root path.
nsd_path

In [None]:
# Rank caption concepts by decodability (Pearson r, highest first) and write them to disk.
n = 128
_, r_sorted_ids = r.sort()
ascending_ids = r_sorted_ids.cpu().numpy()
top_ids = ascending_ids[::-1].copy()
top_n_ids = ascending_ids[-n:][::-1].copy()

ranked_concepts = np.array(caption_concepts)[top_ids]
ranked_scores = r[top_ids].cpu().numpy()
report = '\n'.join(f'{score:.3f} {concept}' for concept, score in zip(ranked_concepts, ranked_scores))
with open(nsd_path / 'derivatives/response_correlation.txt', 'w') as f:
    f.write(report)

In [None]:
N = 1000
top_k_values = [1]
chance_accuracy = [k / N for k in top_k_values]

metric = 'euclidean'

# NOTE: stimulus_ids_val / Y_val / Y_val_pred come from whichever per-subject cell ran last.
unique_stimulus_ids_val, unique_index_val, unique_inverse_val = np.unique(
            stimulus_ids_val, return_index=True, return_inverse=True)
unique_stimulus_ids_test, unique_index_test, unique_inverse_test = np.unique(
            stimulus_ids_test, return_index=True, return_inverse=True)

top_knn_accuracies_val = []
top_knn_accuracies_test = []
pruned_dims = []

num_dims = 128

# Greedy forward selection: at each step add the caption concept that most improves
# validation top-1 retrieval when combined with the concepts already chosen.
chosen_ids = []
for i in range(min(num_dims, W_captions.shape[0])):
    print(f'choosing dim {i}')
    # Start below any achievable accuracy and reset best_j each round. Otherwise a round in
    # which nothing beats 0 re-appends the previous best_j (or raises NameError on round 0).
    best_j = None
    top_knn_accuracy_best = [-np.inf]
    for j in range(W_captions.shape[0]):
        if j in chosen_ids:
            continue

        W_candidate = W_captions[chosen_ids + [j]]
        Y_W, Y_pred_W = F.relu(Y_val @ W_candidate.T), F.relu(Y_val_pred @ W_candidate.T)

        top_knn_accuracy = top_knn_test(
            Y_W[unique_index_val].cpu().numpy(), Y_pred_W.cpu().numpy(), unique_inverse_val, k=top_k_values, metric=metric)

        if top_knn_accuracy[0] > top_knn_accuracy_best[0]:
            top_knn_accuracy_best = top_knn_accuracy
            best_j = j
            print(j, top_knn_accuracy_best)
    if best_j is None:
        break
    chosen_ids.append(best_j)
    top_knn_accuracies_val.append(top_knn_accuracy_best)

# GPT

In [None]:
# Combine the GPT-generated category files into one dataframe with one row per concept.
# File names are assumed to start with a two-digit category id followed by a one-character
# separator and the category name; each line of a file is one concept.
version = 'version1'
categories_path = nsd_path / f'derivatives/gpt_categories/{version}'

records = []
# Sort for a deterministic row order (directory iteration order is OS-dependent).
for category_file in sorted(categories_path.iterdir()):
    file_name = category_file.name.split('.')[0]
    category_id = int(file_name[:2])
    category_name = file_name[3:]
    print(category_id, category_name)
    with open(category_file, 'r') as f:
        concepts = f.read().split('\n')
    for concept_id, concept in enumerate(concepts):
        if not concept.strip():
            continue  # blank lines (e.g. a trailing newline) are not concepts
        records.append({
            'category_id': category_id,
            'category_name': category_name,
            'concept': concept.strip().lower(),
            'concept_id': concept_id,
        })

# Build the frame once instead of appending row by row with df.loc.
df = pd.DataFrame(records, columns=['category_id', 'category_name', 'concept', 'concept_id'])
            


In [None]:
# Validation-fold decodability (Pearson r of ReLU'd caption-concept projections) for every
# subject. Results are kept per subject in `r_all`. `r` (read by later cells) still holds the
# LAST subject's correlations, as before.
r_all = {}
for subject_id in range(8):
    subject_name = f'subj0{subject_id + 1}'

    Y_val = torch.from_numpy(folds['val']['Y_all'][subject_id]).cuda()
    Y_val_pred = torch.from_numpy(folds['val']['Y_pred_all'][subject_id]).cuda()
    stimulus_ids_val = torch.from_numpy(folds['val']['stimulus_ids_all'])

    Y_test = torch.from_numpy(folds['test']['Y_all'][subject_id]).cuda()
    Y_test_pred = torch.from_numpy(folds['test']['Y_pred_all'][subject_id]).cuda()
    stimulus_ids_test = torch.from_numpy(folds['test']['stimulus_ids_all'])

    Y_val_W = F.relu(Y_val @ W_captions.T)
    Y_val_pred_W = F.relu(Y_val_pred @ W_captions.T)

    r = pearsonr(Y_val_W, Y_val_pred_W, reduction=None)
    r_all[subject_name] = r

In [None]:
# Show which stimulus embedding key is in use.
embedding_name

In [None]:
# NOTE(review): this overwrites the global `results_path` with a 'gpt_concepts' subdirectory.
# Every later cell that builds paths from `results_path` (e.g. the t-SNE output and the
# 'optimized_clip_dimensions' loads) now resolves under this subdirectory. Consider a
# separate name such as `gpt_results_path` and updating those cells together.
results_path = nsd_path / f'derivatives/figures/decoding/{model_name}/{group_name}/val/{embedding_name}/gpt_concepts'
results_path.mkdir(exist_ok=True, parents=True)

In [None]:
# Best decodability r over concepts.
r.max()

In [None]:
# Embed every GPT concept with the CLIP text encoder and L2-normalize each row.
tokens = clip.tokenize(df['concept'].tolist()).to(device)

with torch.no_grad():
    W_captions = full_model.encode_text(tokens).float()

W_captions = W_captions / torch.linalg.norm(W_captions, dim=1, keepdim=True)

In [None]:
# (num_concepts, embedding_dim) shape check of the concept directions.
W_captions.shape

In [None]:
# List all GPT concepts in dataframe order.
list(df['concept'])

In [None]:
# Project all stimulus embeddings onto the GPT concept directions and visualize with t-SNE.
Y_W = torch.from_numpy(Y_full @ W_captions.cpu().numpy().T)
top_images_tsne(
    Y_full,
    Y_W.numpy(),
    np.arange(73000),
    out_path=results_path / 'gpt_concepts' / 'tsne',
    labels=df['concept'].tolist(),
)

In [None]:
# Write GPT concepts ranked by validation decodability (Pearson r, highest first).
n = 128
_, r_sorted_ids = r.sort()
ascending_ids = r_sorted_ids.cpu().numpy()
top_ids = ascending_ids[::-1].copy()
top_n_ids = ascending_ids[-n:][::-1].copy()

ranked_concepts = np.array(list(df['concept']))[top_ids]
ranked_scores = r[top_ids].cpu().numpy()
with open(results_path / 'decodability_pearsonr.txt', 'w') as f:
    f.write('\n'.join(f'{score:.3f} {concept}' for concept, score in zip(ranked_concepts, ranked_scores)))

# Lemmas

In [None]:
# Map every lower-cased, punctuation-stripped word of the best captions to the ids of the
# stimuli whose caption contains it; order words from most to least frequent.
cleaned_captions = []
remove_chars = '.,"'
strip_table = str.maketrans('', '', remove_chars)
word_occurences = {}
for stimulus_id, caption in enumerate(best_captions):
    caption = caption.lower().strip().translate(strip_table)
    for word in caption.split(' '):
        word_occurences.setdefault(word, []).append(stimulus_id)
    cleaned_captions.append(caption)

# reverse=True keeps sort stability, so equally frequent words keep first-seen order.
word_occurences = dict(sorted(word_occurences.items(), key=lambda pair: len(pair[1]), reverse=True))

In [None]:
# Vocabulary size, then the full word -> stimulus-id mapping.
print(len(word_occurences))
word_occurences

In [None]:
from nltk.stem.wordnet import WordNetLemmatizer

lemmatizer = WordNetLemmatizer()

# Merge word occurrence lists by WordNet lemma.
lemma_occurences = {}
for word, occurences in word_occurences.items():
    lemma = lemmatizer.lemmatize(word)
    # Always extend a fresh list owned by lemma_occurences. The original stored the
    # word's own list and then `+=`-extended it, silently mutating word_occurences.
    lemma_occurences.setdefault(lemma, []).extend(occurences)

# Most frequent lemma first; reverse=True keeps ties in insertion order, as before.
lemma_occurences = dict(sorted(lemma_occurences.items(), key=lambda pair: len(pair[1]), reverse=True))

In [None]:
from nltk.corpus import wordnet as wn

# Invert lemma -> stimulus ids into stimulus id -> lemmas. Stimulus ids were assigned by
# enumerating best_captions, so size the table from it instead of a hard-coded 73000.
# A lemma can appear more than once for a stimulus if its caption repeats the word.
stimulus_lemmas = [[] for _ in range(len(best_captions))]
for lemma, lemma_stimulus_ids in lemma_occurences.items():
    for stimulus_id in lemma_stimulus_ids:
        stimulus_lemmas[stimulus_id].append(lemma)
    

In [None]:
# Words excluded from lemma summaries: common English stopwords plus a few artifacts of the
# caption tokenization ('' from repeated spaces, "it's", and 'ha').
stopwords = [
    'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your', 'yours', 
    'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', 'her', 'hers', 'herself', 
    'it', 'its', 'itself', 'they', 'them', 'their', 'theirs', 'themselves', 'what', 'which', 'who', 
    'whom', 'this', 'that', 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 
    'being', 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 'the', 'and', 
    'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with', 'about', 
    'against', 'between', 'into', 'through', 'during', 'before', 'after', 'above', 'below', 'to', 
    'from', 'up', 'down', 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 
    'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 'both', 'each', 'few', 
    'more', 'most', 'other', 'some', 'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 
    'than', 'too', 'very', 's', 't', 'can', 'will', 'just', 'don', 'should', 'now', '', "it's", 'ha'
]

def top_lemmas_73k(Y_dims, num_images=10, min_occurences=3, Y=None, image_lemmas=None, exclude=None):
    """For each direction in ``Y_dims``, count the caption lemmas of its top-activating images.

    Args:
        Y_dims: array of directions, one per row; activations are ``Y @ Y_dims.T``.
        num_images: number of top-activating images to pool lemmas from, per direction.
        min_occurences: keep a lemma only if it occurs in MORE than this many of those
            images (strict ``>``); each image contributes a given lemma at most once.
        Y: stimulus embeddings, one row per stimulus; defaults to the global ``Y_full``.
        image_lemmas: per-stimulus lemma lists indexed by stimulus id; defaults to the
            global ``stimulus_lemmas``.
        exclude: lemmas to drop; defaults to the global ``stopwords``.

    Returns:
        One list per direction of ``(lemma, count)`` pairs, sorted by descending count.
        Ties stay in alphabetical order, as produced by ``np.unique``.
    """
    # Optional arguments default to the notebook globals, so existing calls are unchanged.
    if Y is None:
        Y = Y_full
    if image_lemmas is None:
        image_lemmas = stimulus_lemmas
    # A set gives O(1) membership tests instead of scanning the stopword list.
    excluded = set(stopwords if exclude is None else exclude)

    activations = Y @ Y_dims.T

    lemmas_all = []
    for i in range(Y_dims.shape[0]):
        top_ids = np.argsort(activations[:, i])[::-1][:num_images]
        pooled = []
        for stimulus_id in top_ids:
            pooled += set(image_lemmas[stimulus_id])  # count each lemma once per image
        counted = list(zip(*np.unique(pooled, return_counts=True)))
        counted = [pair for pair in counted if pair[1] > min_occurences and pair[0] not in excluded]
        lemmas_all.append(sorted(counted, key=lambda pair: -pair[1]))
    return lemmas_all

In [None]:
# Load the optimized CLIP dimensions W for one run/subject.
# NOTE(review): if the 'gpt_concepts' cell above ran, `results_path` points into that
# subdirectory here, so this path may not exist. Confirm the intended base path.
run_name = 'run26'
subject_id = 0

out_path = results_path / 'optimized_clip_dimensions' / run_name / f'subj0{subject_id+1}'
W = np.load(out_path / 'W.npy')

In [None]:
# Top 250 images per dimension of W, kept as separate images (concatenate=False) for the viewer below.
W_images = top_images_73k(W, num_images=250, concatenate=False)

In [None]:


@interact(w_id=(0, len(W_images) - 1), randomize=False)
def show(w_id, randomize):
    """Display a 5x5 grid of images for one optimized CLIP dimension.

    Args:
        w_id: index into the global ``W_images`` (computed in the cell above); the slider
            range now follows its length instead of a hard-coded 127.
        randomize: if True, show a random 25 of the dimension's images; otherwise show the
            first 25 entries of ``W_images[w_id]``. The original always shuffled and
            ignored this flag.
    """
    dim_images = W_images[w_id]
    order = np.arange(len(dim_images))
    if randomize:
        np.random.shuffle(order)
    grid = np.stack([dim_images[i] for i in order[:25]])
    grid = rearrange(grid, '(i1 i2) w h c -> (i1 w) (i2 h) c', i1=5)
    plt.figure(figsize=(10, 10))
    plt.imshow(grid)
    plt.show()


In [None]:
# Reload W for the same run/subject and summarize each dimension by the lemmas that occur
# in more than `min_occurences` of its top `num_images` images.
run_name = 'run26'
subject_id = 0

out_path = results_path / 'optimized_clip_dimensions' / run_name / f'subj0{subject_id+1}'
W = np.load(out_path / 'W.npy')
num_images = 250
min_occurences = 10
W_lemmas = top_lemmas_73k(W, num_images=num_images, min_occurences=min_occurences)

In [None]:
# Build a comma-separated, line-wrapped label (about `characters_per_line` characters per line)
# from each dimension's top lemmas, used to annotate the image grids below.
characters_per_line = 30
labels = []
for lemmas in W_lemmas:
    words = [lemma[0] for lemma in lemmas]
    lines = []
    line = ''
    for word in words:
        candidate = f'{line}, {word}' if line else word
        if line and len(candidate) > characters_per_line:
            lines.append(line)
            line = word
        else:
            line = candidate
    # Flush the last partial line. The original dropped it, so a dimension with a single
    # lemma got an empty label. An empty lemma list now yields '' instead of IndexError.
    if line:
        lines.append(line)
    labels.append('\n'.join(lines))
    

In [None]:
# Render the top 10 images per dimension annotated with the lemma labels and save as PNG.
# NOTE: `num_images` in the file name is the lemma-pooling setting (250), not the 10 images shown.
images_all = top_images_73k(W, num_images=10, labels=labels, font_size=20)
file_name = f'top_images_73k_lemmas_images-{num_images}_occurences-{min_occurences}.png'
Image.fromarray(images_all).save(out_path / file_name)

In [None]:
# Union of all lemmas that label any dimension. Sorted so that lemma indices (used as rows
# of the incidence matrix below) are reproducible across kernel restarts; set order is not.
W_lemmas_union = sorted({lemma[0] for lemmas in W_lemmas for lemma in lemmas})
len(W_lemmas_union)

In [None]:
# NOTE(review): `w_lemmas` is only defined by the loop in the NEXT cell. On a fresh kernel
# this cell raises NameError; run it after the incidence-matrix cell or remove it.
w_lemmas

In [None]:
# Binary lemma x dimension incidence matrix: C[lemma_id, w_id] = 1 if the lemma is among
# dimension w_id's top lemmas. The row/column order matches C's shape and the later
# `model.fit(C.T, W)`; the original wrote C[w_id, lemma_id], swapping the axes.
lemma_index = {lemma: lemma_id for lemma_id, lemma in enumerate(W_lemmas_union)}
C = np.zeros((len(W_lemmas_union), W.shape[0]))
for w_id, w_lemmas in enumerate(W_lemmas):
    w_lemmas = [l[0] for l in w_lemmas]
    print(w_id, w_lemmas)
    for lemma in w_lemmas:
        if lemma in lemma_index:
            C[lemma_index[lemma], w_id] = 1.

In [None]:
# Spot-check: lemmas (with counts) of dimension 15.
W_lemmas[15]

In [None]:
# Number of retained lemmas per dimension.
[len(lemmas) for lemmas in W_lemmas]

In [None]:
# Spot-check: the lemma at index 15 of the union.
W_lemmas_union[15]

In [None]:
# Row sums of the incidence matrix C.
C.sum(axis=1)

In [None]:
# Visualize the incidence matrix C, with a y tick every 5 rows.
plt.figure(figsize=(24, 24))
plt.yticks(ticks=[i*5 for i in range(len(W_lemmas_union) // 5)])
plt.imshow(C)
plt.show()

In [None]:
from sklearn.linear_model import RidgeCV

# Regress the optimized dimensions W on the lemma indicators (C.T: one row per dimension),
# with cross-validated ridge regularization and no intercept.
# NOTE(review): this rebinds the global `model`, which other cells use for the PyTorch decoder.
model = RidgeCV(fit_intercept=False)

model.fit(C.T, W)
W.shape, C.shape

In [None]:
# For multi-target regression sklearn's coef_ has shape (n_targets, n_features), so E.T has
# one embedding-space direction per lemma.
E = model.coef_
E.shape

In [None]:
# Show the lemma union (row order of C / feature order of E).
W_lemmas_union

In [None]:
import torch.nn.functional as F

# Evaluate how decodable the lemma-derived directions (rows of E.T) are for one subject.
# Assign subject_id BEFORE building out_path. The original used a value leaked from an
# earlier cell.
subject_id = 0
out_path = results_path / 'optimized_clip_dimensions' / run_name / f'subj0{subject_id+1}' / 'lemmas'

Y_test = torch.from_numpy(folds['test']['Y_all'][subject_id]).cuda()
Y_test_pred = torch.from_numpy(folds['test']['Y_pred_all'][subject_id]).cuda()
# Every other cell treats stimulus_ids_all as shared across subjects (not indexed by subject).
stimulus_ids_test = torch.from_numpy(folds['test']['stimulus_ids_all'])

brain_decodability_evaluation(Y_test, Y_test_pred, Y_full, stimulus_ids_test,
                              torch.from_numpy(E.T).float().cuda(), out_path, labels=W_lemmas_union)