In [1]:
import sys
import os
import struct
import time as time
import numpy as np
import pandas as pd
import h5py
import PIL.Image
from scipy import stats
from itertools import chain
from scipy.io import loadmat
from tqdm import tqdm
import pickle
import math
import matplotlib.pyplot as plt
import csv
from itertools import zip_longest

import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from torchvision import models

In [2]:
import neurogen.numpy_utility as pnu
from neurogen.plots import display_candidate_loss
from neurogen.file_utility import save_stuff, flatten_dict, embed_dict
from neurogen.torch_fwrf import get_value

from neurogen.torch_fwrf import learn_params_ridge_regression, get_predictions, Torch_fwRF_voxel_block
from neurogen.encoding import load_encoding
from neurogen.visualize import center_crop

In [3]:
# ---- paths ----
# Every directory string ends with "/" so file names can be appended directly.

# Root of the NSD data, relative to the notebook's working directory.
nsd_root = "data/nsd/"
# stim_root = "/home/hhan228/memorability/Mansoure/" + "nsd_stimuli/"
# beta_root = nsd_root + "nsd_beta/"
# mask_root = nsd_root + "mask/ppdata/"
# roi_root = nsd_root + "freesurfer/"

# Directory holding the per-subject `meanbeta_<ROI>.txt` files.
meanROIbeta_root = f"{nsd_root}roiavgbeta_neurogen/"

# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
# weight_root = "neurogen/output/"
weight_base_dir = "/home/hhan228/memorability/Willow/neurogen_output/"
# Pick the feature-extractor backbone whose fitted weights are loaded.
# weight_root = weight_base_dir+"alexnet/"
weight_root = f"{weight_base_dir}resnet/"

exp_design_file = f"{nsd_root}nsd_expdesign.mat"
stim_file = f"{nsd_root}shared1000_original.npy"

In [4]:
# Compute device handed to the encoding-model loader: CUDA device index 1.
# To run on the CPU instead, use: device = torch.device("cpu")
device = torch.device("cuda", 1)

## With memorability-controlled images

In [5]:
# Names of the regions of interest. The order matters: it fixes the column
# order of the per-ROI arrays built from this list later in the notebook.
ROIs = (
    ['OFA', 'FFA1', 'FFA2', 'mTLfaces', 'aTLfaces']
    + ['EBA', 'FBA1', 'FBA2', 'mTLbodies']
    + ['OPA', 'PPA', 'RSC']
    + ['V1v', 'V1d', 'V2v', 'V2d', 'V3v', 'V3d', 'hV4']
    + ['L-hippocampus', 'L-amygdala', 'R-hippocampus', 'R-amygdala']
)

# Trial count per subject; entry i belongs to subject i + 1 (looked up as
# trials[subject-1] when a subject is selected).
trials = np.array([
    30000, 30000, 24000, 22500,
    30000, 24000, 30000, 22500,
])

In [6]:
# Threshold value; it is only used here to select the dataset directory name.
thr = 0.5
base_dir = "data/synthetic_new_categories_ver3_{}/".format(thr)

# Sub-directories of base_dir, one per image category.
# categories = ["animals", "foods", "landscapes", "vehicles"]
categories = ["animals", "foods", "humans", "places"]

# Image variants available for each category (used in the image file names).
controlled = ["original", "increased", "decreased"]

### Load subject

In [44]:
# parameters
subject = 8
savearg = {'format':'png', 'dpi': 120, 'facecolor': None}
model_name = 'dnn_fwrf'

data_size = trials[subject-1]

In [45]:
# Read one `meanbeta_<ROI>.txt` file per ROI for the selected subject into the
# columns of a (data_size, roi_num) array, then drop every ROI whose values
# contain NaN.
roi_num = len(ROIs)
roi_data = np.zeros([data_size, roi_num])
del_idx = []  # column indices of the ROIs that get dropped
for roi_idx, roi in enumerate(ROIs):
    roi_data[:, roi_idx] = np.genfromtxt(meanROIbeta_root + f'subj{subject:02d}/meanbeta_' + roi + '.txt')
    # A NaN anywhere in the column makes its sum NaN, which flags the ROI.
    if np.isnan(np.sum(roi_data[:, roi_idx])):
        print(roi)
        del_idx.append(roi_idx)

# Column mask over ROIs (True = kept, False = dropped). It is sized from
# roi_num rather than a literal 23 so it cannot drift from the ROIs list.
ROIs_bool = np.ones((roi_num, 1), dtype='bool')
ROIs_bool[del_idx] = False

roi_data = np.delete(roi_data, del_idx, axis=1)

In [46]:
# load params
# The file is opened in a `with` block so the HDF5 handle is closed even when
# reading the datasets or embed_dict() raises; every entry is copied into
# memory with np.copy before the file is closed.
with h5py.File(weight_root+f'S{subject:02d}/model_params.h5py', 'r') as model_params_set:
    # model_params_set = h5py.File(weight_root+f'S{subject:02d}/dnn_fwrf/model_params.h5py' , 'r')
    model_params = embed_dict({k: np.copy(d) for k,d in model_params_set.items()})

# load encoding models
# NOTE(review): fmap_name is fixed to 'resnet'; it presumably has to match the
# backbone directory selected in weight_root -- confirm before switching models.
fwrf, fmaps = load_encoding(model_params, fmap_name='resnet', device=device)

In [47]:
# For every category and image variant: load the image stack, run it through
# the encoding model and save the predicted per-ROI responses for this subject.
for cat in categories:
    # All variants of a category share one output directory.
    save_dir = base_dir+f"{cat}/predicted_responses/"
    for ctrld in controlled:
        image_data = np.load(base_dir+f"{cat}/{ctrld}_imgs_200.npy")
        # presumably 8-bit pixel values, rescaled here to [0, 1] -- TODO confirm
        image_data = image_data.astype(np.float32) / 255.
        # predict brain response
        voxel_pred = get_predictions(image_data, fmaps, fwrf, model_params['params'])
        # One row per image, one column per entry of the ROI mask. The shape is
        # taken from the data (was a literal (200, 23)) so a different image
        # count or ROI list does not need a code change. Columns of ROIs
        # excluded by ROIs_bool stay at zero.
        pred_act = np.zeros((image_data.shape[0], ROIs_bool.shape[0]))
        pred_act[:,ROIs_bool[:,0]] = voxel_pred
        os.makedirs(save_dir, exist_ok=True)
        np.save(save_dir+f"S{subject:02d}_{ctrld}_responses.npy", pred_act)

samples [  100:199  ] of 200, voxels [     0:22    ] of 23
---------------------------------------
total time = 0.088387s
sample throughput = 0.000442s/sample
voxel throughput = 0.003843s/voxel
samples [  100:199  ] of 200, voxels [     0:22    ] of 23
---------------------------------------
total time = 0.091573s
sample throughput = 0.000458s/sample
voxel throughput = 0.003981s/voxel
samples [  100:199  ] of 200, voxels [     0:22    ] of 23
---------------------------------------
total time = 0.091424s
sample throughput = 0.000457s/sample
voxel throughput = 0.003975s/voxel
samples [  100:199  ] of 200, voxels [     0:22    ] of 23
---------------------------------------
total time = 0.091380s
sample throughput = 0.000457s/sample
voxel throughput = 0.003973s/voxel
samples [  100:199  ] of 200, voxels [     0:22    ] of 23
---------------------------------------
total time = 0.091217s
sample throughput = 0.000456s/sample
voxel throughput = 0.003966s/voxel
samples [  100:199  ] of 200, 

In [29]:
# for cat in categories:
#     cat_dir = base_dir+f"{cat}/"
#     original_imgs = sorted([f for f in os.listdir(cat_dir+"original/") if os.path.isfile(os.path.join(cat_dir+"original/", f))])
#     increased_imgs = sorted([f for f in os.listdir(cat_dir+"increased/") if os.path.isfile(os.path.join(cat_dir+"increased/", f))])
#     decreased_imgs = sorted([f for f in os.listdir(cat_dir+"decreased/") if os.path.isfile(os.path.join(cat_dir+"decreased/", f))])
    
#     if cat != "vehicles":
#         np.random.seed(42)
#         sampled_img_ids = sorted([f.split("_memcoef")[0] for f in np.random.choice(original_imgs, 97, replace=False)])
#         original_imgs = sorted([f for f in original_imgs if f.split("_memcoef")[0] in sampled_img_ids])
#         increased_imgs = sorted([f for f in increased_imgs if f.split("_memcoef")[0] in sampled_img_ids])
#         decreased_imgs = sorted([f for f in decreased_imgs if f.split("_memcoef")[0] in sampled_img_ids])

#     orig_imgs = []
#     inc_imgs = []
#     dec_imgs = []
#     for orig_f, inc_f, dec_f in zip(original_imgs, increased_imgs, decreased_imgs):
#         orig_imgs.append(np.array(PIL.Image.open(cat_dir+f"original/{orig_f}").convert("RGB")).transpose(2, 0, 1))
#         inc_imgs.append(np.array(PIL.Image.open(cat_dir+f"increased/{inc_f}").convert("RGB")).transpose(2, 0, 1))
#         dec_imgs.append(np.array(PIL.Image.open(cat_dir+f"decreased/{dec_f}").convert("RGB")).transpose(2, 0, 1))

#     orig_imgs = np.array(orig_imgs)
#     inc_imgs = np.array(inc_imgs)
#     dec_imgs = np.array(dec_imgs)

#     print(cat.upper())
#     print(f"orig_imgs: {orig_imgs.shape}")
#     print(f"inc_imgs: {inc_imgs.shape}")
#     print(f"dec_imgs: {dec_imgs.shape}\n")

#     np.save(cat_dir+"original_imgs_97.npy", orig_imgs)
#     np.save(cat_dir+"increased_imgs_97.npy", inc_imgs)
#     np.save(cat_dir+"decreased_imgs_97.npy", dec_imgs)

# def resize_image_tensor(x, newsize):
#     tt = x.transpose((0,2,3,1))
#     r  = np.ndarray(shape=x.shape[:1]+newsize+(x.shape[1],), dtype=tt.dtype)
#     for i,t in enumerate(tt):
#         r[i] = np.asarray(PIL.Image.fromarray(t).resize(newsize, resample=PIL.Image.BILINEAR))
#     return r.transpose((0,3,1,2))