# 1. Prepare data

For each voxel, the **BraInCoRL** model takes (image, beta value) pairs as in-context examples and predicts the beta values of the query images.

First, load the in-context image embeddings, the in-context beta values, and the query image embeddings from NumPy `.npz` files.

In [1]:
from pathlib import Path
import numpy as np

data_dir = Path('./data')


def load_npz_array(path, key='arr_0'):
    """Load one array from an ``.npz`` archive and release the file handle.

    ``np.load`` on an ``.npz`` file returns an ``NpzFile`` that keeps the
    underlying file open; using it as a context manager closes it once the
    requested member has been read into memory.

    Args:
        path: path to the ``.npz`` archive.
        key: name of the member to read (``np.savez`` names positional arrays ``arr_0``, ``arr_1``, ...).

    Returns:
        The requested array as an in-memory ``np.ndarray``.
    """
    with np.load(path) as archive:
        return archive[key]


# load in-context image embeddings
ic_img_path = data_dir / 'sample_in_ctx_imgs.npz'
ic_img = load_npz_array(ic_img_path)
print('ic_img.shape', ic_img.shape)     # (10 voxels, 100 in-context images, 512 embed dim)

# load in-context beta values
ic_beta_path = data_dir / 'sample_in_ctx_betas.npz'
ic_beta = load_npz_array(ic_beta_path)
print('ic_beta.shape', ic_beta.shape)     # (10 voxels, 100 in-context images)

# load query image embeddings
query_img_path = data_dir / 'sample_query_imgs.npz'
query_img = load_npz_array(query_img_path)
print('query_img.shape', query_img.shape)     # (10 voxels, 20 query images, 512 embed dim)

# sanity checks: the three arrays must describe the same voxels and embedding space
assert ic_img.ndim == 3 and ic_beta.ndim == 2 and query_img.ndim == 3, 'unexpected array rank'
assert ic_img.shape[:2] == ic_beta.shape, 'each in-context image needs exactly one beta value per voxel'
assert query_img.shape[0] == ic_img.shape[0], 'query and in-context arrays must cover the same voxels'
assert query_img.shape[2] == ic_img.shape[2], 'query and in-context embeddings must share a dimension'

ic_img.shape (10, 100, 512)
ic_beta.shape (10, 100)
query_img.shape (10, 20, 512)


# 2. Load pretrained model

The model script is in `./scripts/model.py`, and the checkpoints should be placed in the `./checkpoints/` folder, named `{backbone_type}_subj{predict_subj}.pth` (e.g. `CLIP_subj1.pth`).

In [2]:
import sys
# guard the append so re-running this cell does not keep growing sys.path
if './scripts' not in sys.path:
    sys.path.append('./scripts')
from model import HyperweightsPredictorModel    # the BraInCoRL model
import torch

backbone_type = 'CLIP'  # or change to other image embedding backbones
predict_subj = 1
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# .eval() returns the module itself; switch to inference behaviour (e.g. dropout off) up front
model = HyperweightsPredictorModel(backbone_type=backbone_type).to(device).eval()

# load weights
# map_location remaps stored tensors onto `device`, so a checkpoint saved from GPU
# tensors can still be loaded on the CPU fallback path above
model_ckpt_path = f'./checkpoints/{backbone_type}_subj{predict_subj}.pth'
checkpoint = torch.load(model_ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(checkpoint)


HyperweightsPredictorModel INITIALIZATION PARAMETERS
embed_dim: 512
internal_emb_dim: 560
num_tsfm_layers: 20
tsfm_hidden_dim: 2048
num_reg_tok: 4
num_heads: 10
num_early_lyr: 1
num_w_pred_layers: 1
early_hidden_dim: 1120
w_pred_hidden_dim: 1120
dropout: 0



<All keys matched successfully>

# 3. Inference

Feed the loaded data to the model to get the inference results.

The model output consists of two parts:

1. The predicted beta values of the query images.

2. The predicted voxelwise mapping weights, such that applying them to an image embedding reproduces the predicted beta value:  
`weights(image_embed) == predicted_beta`

In [3]:
# torch.as_tensor accepts both NumPy arrays and tensors, so this cell can be
# re-run safely after the names below have already been rebound to tensors
ic_img = torch.as_tensor(ic_img, dtype=torch.float32, device=device)
ic_beta = torch.as_tensor(ic_beta, dtype=torch.float32, device=device)
query_img = torch.as_tensor(query_img, dtype=torch.float32, device=device)

# pure inference: skip autograd bookkeeping to save memory and time
with torch.no_grad():
    predicted_betas, predicted_weights = model(ic_img, ic_beta, query_img)

print('predicted_betas.shape', predicted_betas.shape)       # (10 voxels, 20 query images)
print('predicted_weights.shape', predicted_weights.shape)   # (10 voxels, 513 projection weights dimension)

predicted_betas.shape torch.Size([10, 20])
predicted_weights.shape torch.Size([10, 513])
