In [None]:
# Automatically re-import edited modules (e.g. neuroai) before each cell executes,
# so changes to the package show up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Load the model and set up a hook

In [64]:
import torch
from neuroai.models import ResNet18

backbone_model = ResNet18(pretrained=True, download_root="./pretrained_checkpoints")

# Switch to inference mode before ANY forward pass. Unless the wrapper already does this
# (not visible from here — TODO confirm), BatchNorm layers would normalise with per-batch
# statistics and overwrite their running statistics on every forward call, including the
# random-noise sanity check below. The trailing `;` keeps Jupyter from printing the module repr.
backbone_model.eval();

In [66]:
from neuroai.utils import ForwardHook

# Attach a forward hook to the first conv of layer4's first block; after each forward pass
# of backbone_model the captured activation is read back via `hook.output`.
hook = ForwardHook(
    model=backbone_model,
    hook_layer_name="model.layer4.0.conv1",
)

In [None]:
# Sanity-check the hook with one random, image-sized input. The result is only inspected,
# never back-propagated, so no_grad avoids building an autograd graph.
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    y = backbone_model(x)

# Expected layout: (batch, channels, height, width)
hook.output.shape

## Extracting features

In [131]:
from neuroai.datasets import NSDAllSubjectSingleRegion

# Use the originally hard-coded Apple MPS device when present, otherwise fall back to
# CUDA and then CPU so the notebook also runs on machines without an Apple GPU.
if torch.backends.mps.is_available():
    device = "mps:0"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Move the backbone (and so the layer the hook is attached to) onto the chosen device.
# The loop below calls backbone_model directly, relying on .to() moving it in place.
backbone_model.to(device)

region = "PPA"
subject_id = "s1"

dataset = NSDAllSubjectSingleRegion(
    folder="./datasets/nsd",
    region=region,
    transforms=backbone_model.transforms,
)

all_features = []
all_fmri_voxels = []

with torch.no_grad():
    for i in range(len(dataset)):
        # Fetch the item once; each dataset[i] call may re-load and re-transform the image.
        sample = dataset[i]

        # (voxels) -> (1, voxels)
        fmri_data = sample["fmri_response"][subject_id].unsqueeze(0)

        # (channels, height, width) -> (1, channels, height, width)
        image = sample["image"].unsqueeze(0).to(device)

        # Output is discarded: the forward pass is only run so the hook captures the layer.
        backbone_model(image)

        # Copy results back to host memory with .cpu() so they don't accumulate on the device.
        all_features.append(hook.output.cpu())
        all_fmri_voxels.append(fmri_data.cpu())

In [None]:
# One feature tensor and one fMRI tensor are appended per dataset item, so counts must agree.
assert len(all_features) == len(all_fmri_voxels), "feature / fMRI sample counts differ"
len(all_features), len(all_fmri_voxels)

In [None]:
# Shapes of the first collected sample (each entry was unsqueezed to a leading batch dim of 1).
(all_features[0].shape, all_fmri_voxels[0].shape)

In [135]:
# Stack the per-sample (1, ...) tensors into one tensor along the batch dimension.
# The isinstance guards make the cell safe to re-run: after the first run these names
# already hold tensors, and torch.cat expects a sequence of tensors, not a single tensor.
if isinstance(all_features, list):
    all_features = torch.cat(all_features, dim=0)
if isinstance(all_fmri_voxels, list):
    all_fmri_voxels = torch.cat(all_fmri_voxels, dim=0)

In [None]:
# Layout: (num samples, channels, height, width)
print("Shape of features", all_features.shape)

# Layout: (num samples, voxels)
print("Shape of voxels", all_fmri_voxels.shape)

In [None]:
from einops import rearrange

# Ridge regression needs one feature vector per sample, so merge all non-batch axes.
all_features_flattened = rearrange(
    all_features,
    "batch channels height width -> batch (channels height width)",
)

print("Shape of features after flattening", all_features_flattened.shape)

In [138]:
from einops import reduce

# Collapse the region to a single response per image by averaging over its voxels:
# (num samples, voxels) -> (num samples,)
all_fmri_voxels_region_mean = reduce(all_fmri_voxels, "batch voxels -> batch", "mean")

In [None]:
from neuroai.utils.regression import ridge_regression

# The first `num_train_samples` items fit the ridge model; the remainder is held out.
# NOTE: the split follows dataset order (no shuffling).
num_train_samples = 900

num_samples = all_features_flattened.shape[0]
assert all_fmri_voxels_region_mean.shape[0] == num_samples, "X and Y have different sample counts"
# Guard against a silently empty train or test split when the dataset is small.
assert 0 < num_train_samples < num_samples, (
    f"num_train_samples={num_train_samples} must leave at least one test sample "
    f"(dataset has {num_samples} samples)"
)

X = {
    "train": all_features_flattened[:num_train_samples],
    "test": all_features_flattened[num_train_samples:],
}
Y = {
    "train": all_fmri_voxels_region_mean[:num_train_samples],
    "test": all_fmri_voxels_region_mean[num_train_samples:],
}
ridge_result = ridge_regression(
    X_train=X["train"],
    Y_train=Y["train"],
    device="cpu",
)
print(ridge_result)

In [140]:
from neuroai.utils.regression import RidgeModel

# Bundle the backbone with the fitted ridge weights. hook_layer_name should match the layer
# the training features were captured from (the ForwardHook cell above), or the weights
# would be applied to the wrong activations.
model = RidgeModel(
    backbone_model=backbone_model,
    transforms=backbone_model.transforms,
    hook_layer_name="model.layer4.0.conv1",
    ridge_result=ridge_result,
    # Reuse the device selected for feature extraction rather than hard-coding "mps:0".
    device=device,
)


In [None]:
# Score the ridge predictor on the held-out split.
correlation = model.evaluate(x_test=X["test"], y_test=Y["test"])
print(f"Correlation on test set: {correlation}")

In [142]:
# Optional: download two sample images (face.jpg and body.jpg). Both commands are commented
# out, so running this cell downloads nothing; uncomment them to fetch the files.
# !wget -O face.jpg "https://img.freepik.com/free-photo/portrait-white-man-isolated_53876-40306.jpg"
# !wget -O body.jpg "https://images.unsplash.com/photo-1586710743237-1eb35c3c881c?fm=jpg"