## Save predicted stats to improve renormalization in decoding

In [1]:
import torch
from os.path import join as opj
import numpy as np
import pickle
import tqdm

In [2]:
# Subject / resolution selection used by every subject-specific path below.
sub = "subj01"
mod = "func1pt8mm"

# Root of the local NSD copy and its top-level data folders.
base_path = "/home/matteo/data/NSD"
timeseries_path = opj(base_path, "nsddata_timeseries")
betas_path = opj(base_path, "nsddata_betas")

# Stimulus images and their caption annotation files.
stimuli_path = opj(base_path, "nsddata_stimuli", "stimuli", "nsd")
stim_file_path = opj(stimuli_path, "nsd_stimuli.hdf5")
_annotations_dir = opj(stimuli_path, "annotations")
stim_captions_train_path = opj(_annotations_dir, "captions_train2017.json")
stim_captions_val_path = opj(_annotations_dir, "captions_val2017.json")

# Per-subject time series and betas.
subj_data_path = opj(timeseries_path, "ppdata", sub, mod, "timeseries")
subj_betas_path = opj(betas_path, "ppdata", sub, mod, "betas_assumehrf")

# Processed ROI outputs: subject folder, then the resolution-specific subfolder.
processed_data = opj(base_path, "processed_roi", sub)
subj_betas_roi_extracted_path = opj(processed_data, mod)

# Experiment design and stimulus metadata files.
_nsd_experiment_dir = opj(base_path, "nsddata", "experiments", "nsd")
stim_order_path = opj(_nsd_experiment_dir, "nsd_expdesign.mat")
stim_info_path = opj(_nsd_experiment_dir, "nsd_stim_info_merged.csv")

## 1. Load the training fMRI data and the brain-to-latent models

In [7]:
# Load the training fMRI tensor and compute its statistics along dim 0.
train_fmri = torch.load(f"models/{sub}/train_fmri.pt")
train_fmri_mean = train_fmri.mean(dim=0)
train_fmri_std = train_fmri.std(dim=0)

# Standardize with the dim-0 statistics. nan_to_num swaps NaNs for 0 (and
# infinities for finite values), e.g. where the std is zero.
_centered_fmri = train_fmri - train_fmri_mean
train_fmri_norm = torch.nan_to_num(_centered_fmri / train_fmri_std)

In [4]:
## load brain models
keep=31          # number of brain_to_vdvae_latent_ridge_{k}.sav files to load
max_len_img=257  # number of brain_to_img_emb_ridge_{i}.sav files to load
max_len_txt=77   # number of brain_to_txt_emb_ridge_{i}.sav files to load

keys=np.arange(keep)


def _load_decoding_model(filename):
    """Unpickle one fitted model stored under ``models/{sub}/decoding``.

    Args:
        filename: Name of the ``.sav`` file inside the decoding folder.

    Returns:
        The unpickled object (used later through its ``predict`` method).

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    # NOTE(review): pickle can execute arbitrary code while loading; only
    # point this at model files you produced yourself.
    # The context manager closes the handle right after loading, instead of
    # leaving one open file per model for the garbage collector.
    with open(opj(f"models/{sub}/decoding",filename), 'rb') as f:
        return pickle.load(f)


# VDVAE latent regressors, keyed by the entries of `keys`.
brain_to_latent = {
    k: _load_decoding_model(f'brain_to_vdvae_latent_ridge_{k}.sav') for k in keys
}

# Image / text embedding regressors, stored in index order.
brain_to_img_emb=[
    _load_decoding_model(f'brain_to_img_emb_ridge_{i}.sav') for i in range(max_len_img)
]
brain_to_txt_emb=[
    _load_decoding_model(f'brain_to_txt_emb_ridge_{i}.sav') for i in range(max_len_txt)
]

In [10]:
# Filled in below: latent index -> {"mean": ..., "std": ...}.
stats = {}

# Target shape of every predicted latent, keyed by latent index 0..30.
# All shapes are (16, side, side); presumably (channels, height, width).
# The side length grows in runs: 2 latents at 1x1, 4 at 4x4, 8 at 8x8,
# 16 at 16x16 and a final one at 32x32.
_sides = [1] * 2 + [4] * 4 + [8] * 8 + [16] * 16 + [32] * 1
shapes = {index: (16, side, side) for index, side in enumerate(_sides)}

# For every latent regressor: predict from the normalized training fMRI,
# reshape the flat prediction to that latent's target shape, and record
# the mean / std over the sample axis (dim 0).
for k, latent_model in brain_to_latent.items():
    print(k)  # progress indicator
    flat_prediction = torch.tensor(latent_model.predict(train_fmri_norm.numpy()))
    z = flat_prediction.reshape(-1, *shapes[k])
    stats[k] = {
        "mean": z.mean(0),
        "std": z.std(0),
    }

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30


In [11]:
# Persist the per-latent mean/std dictionary next to the subject's models.
filename = "predicted_latent_stats.sav"
_latent_stats_path = opj(f"models/{sub}", filename)

with open(_latent_stats_path, "wb") as stats_file:
    pickle.dump(stats, stats_file)

## 2. Compute and save the same statistics for the predicted image and text embeddings

In [12]:
# Convert once: every regressor consumes the same normalized fMRI matrix.
fmri_features=train_fmri_norm.numpy()

# Iterate over the loaded regressors themselves rather than hard-coded counts
# (257 / 77), so the loops always match the models that were actually loaded.
# tqdm infers the total from the list length, so the progress bars are unchanged.
img_emb=[
    torch.tensor(model.predict(fmri_features))
    for model in tqdm.tqdm(brain_to_img_emb)
]
txt_emb=[
    torch.tensor(model.predict(fmri_features))
    for model in tqdm.tqdm(brain_to_txt_emb)
]

# Stack along dim 1 so each regressor's prediction becomes one slice of dim 1,
# keeping the sample axis at dim 0.
img_emb=torch.stack(img_emb,1)
txt_emb=torch.stack(txt_emb,1)


100%|█████████████████████████████████████████| 257/257 [01:33<00:00,  2.75it/s]
100%|███████████████████████████████████████████| 77/77 [00:28<00:00,  2.66it/s]


In [13]:
# Statistics of the predicted embeddings over the sample axis (dim 0).
predicted_img_emb_mean = torch.mean(img_emb, dim=0)
predicted_img_emb_std = torch.std(img_emb, dim=0)

predicted_txt_emb_mean = torch.mean(txt_emb, dim=0)
predicted_txt_emb_std = torch.std(txt_emb, dim=0)


In [14]:

# Destination files for the predicted-embedding statistics.
img_emb_mean_path = f"models/{sub}/predicted_img_emb_mean.pt"
img_emb_std_path = f"models/{sub}/predicted_img_emb_std.pt"
txt_emb_mean_path = f"models/{sub}/predicted_txt_emb_mean.pt"
txt_emb_std_path = f"models/{sub}/predicted_txt_emb_std.pt"

# Write each tensor to its file, in the same order as listed above.
_emb_stats_to_save = (
    (predicted_img_emb_mean, img_emb_mean_path),
    (predicted_img_emb_std, img_emb_std_path),
    (predicted_txt_emb_mean, txt_emb_mean_path),
    (predicted_txt_emb_std, txt_emb_std_path),
)
for _stat_tensor, _stat_path in _emb_stats_to_save:
    torch.save(_stat_tensor, _stat_path)

In [15]:
# Sanity check of the saved statistic's shape; the recorded output is torch.Size([257, 768]).
predicted_img_emb_mean.shape

torch.Size([257, 768])

In [16]:
# Save the fMRI normalization statistics computed from the training data.
_fmri_mean_path = f"models/{sub}/train_fmri_mean.pt"
_fmri_std_path = f"models/{sub}/train_fmri_std.pt"

torch.save(train_fmri_mean, _fmri_mean_path)
torch.save(train_fmri_std, _fmri_std_path)
