In [1]:
from neuroai.datasets import NSDAllSubjectSingleRegion, download_and_extract_nsd
import matplotlib.pyplot as plt

In [None]:
# Fetch the NSD archive and unpack it under ./datasets (takes up about 439 MB).
# The next cell reads the extracted data from ./datasets/nsd.
download_and_extract_nsd(zip_filename="nsd.zip", output_folder="./datasets")

In [None]:
# Build the dataset for one brain region; change `region` to load another one
# (valid names depend on NSDAllSubjectSingleRegion — check its docs).
region = "FFA"
dataset = NSDAllSubjectSingleRegion(folder="./datasets/nsd", region=region)

# How many samples are available?
print(len(dataset))

In [None]:
# Fetch one sample once and reuse it, instead of indexing the dataset twice
# (each dataset[i] call may reload data from disk).
first_sample = dataset[0]

## what does it contain?
print(first_sample.keys())

## what is the shape of the image?
## (the plotting cells below treat it as (C, H, W))
print(first_sample["image"].shape)

In [None]:
## let's visualize one image from the dataset
sample_idx = 30  # any index in range(len(dataset)) works

fig, ax = plt.subplots()
## the transpose is needed because matplotlib expects (H, W, C), but we have (C, H, W)
ax.imshow(dataset[sample_idx]["image"].transpose(1, 2, 0))
ax.set_title(f"Image (idx={sample_idx})")
ax.axis("off")  # pixel-coordinate ticks carry no information for a photo
plt.show()

In [None]:
## now let's look at the fMRI data paired with the same image
sample_idx = 30
# Fetch once and reuse; each dataset[i] call may reload data from disk.
sample = dataset[sample_idx]

## let's start by looking at the subjects we have:
print(sample["fmri_response"].keys())

## let's count the number of voxels in this subject's region (`region`, set above)
subject_id = "s1"
print(sample["fmri_response"][subject_id].shape)


In [None]:
# Show a few images side by side with one subject's fMRI response to each.
subject_id = "s1"
sample_indices = [10, 20, 30]  # you can change or extend this list
# Shared y-range so the panels are directly comparable; values outside it are
# not visible. NOTE(review): assumes responses mostly lie within ±3 — confirm.
fmri_ylim = (-3, 3)

n_samples = len(sample_indices)
assert n_samples > 0, "sample_indices must contain at least one index"

# squeeze=False keeps `axes` 2-D (n_samples x 2) even when n_samples == 1,
# so the single-row layout needs no special-casing.
fig, axes = plt.subplots(n_samples, 2, figsize=(7, 3 * n_samples), squeeze=False)

for (ax_img, ax_fmri), idx in zip(axes, sample_indices):
    sample = dataset[idx]  # fetch once per index; each lookup may hit disk
    # matplotlib expects (H, W, C), but images are stored as (C, H, W).
    img = sample["image"].transpose(1, 2, 0)
    fmri = sample["fmri_response"][subject_id].numpy()

    ax_img.imshow(img)
    ax_img.set_title(f"Image (idx={idx})")
    ax_img.axis('off')

    ax_fmri.plot(fmri)
    ax_fmri.set_title(f"fMRI response in {subject_id}'s {region} (idx={idx})")
    ax_fmri.set_xlabel("Voxel index")
    ax_fmri.set_ylabel("Response")
    ax_fmri.set_ylim(*fmri_ylim)
    ax_fmri.grid()

plt.tight_layout()
plt.show()