## This starter notebook uses the mosaic-dataset Python package to:
 1. download MOSAIC HDF5 files containing the fMRI beta responses from the AWS bucket (https://mosaicfmri.s3.amazonaws.com/index.html)
 2. visualize single-trial beta values on an inflated brain
 3. download brain-optimized model weights and load the model
 4. run inference with the model
 5. visualize model predictions on an inflated brain

In [None]:
!pip install mosaic-dataset --upgrade

### 1. Download MOSAIC HDF5 files containing the fMRI beta responses

In [None]:
import mosaic

# Which datasets/subjects to fetch: here only subject 1 of the GOD dataset.
god_subjects = {"GOD": [1]}

# mosaic.load downloads the requested hdf5 file(s) into `folder` unless they are
# already there. The returned dataset groups responses by ROI (MMP1.0 parcellation)
# and concatenates multiple subjects together when more than one is requested.
dataset = mosaic.load(names_and_subjects=god_subjects, folder="./mosaic-dataset")

# Show which fields a single sample exposes.
first_sample = dataset[0]
print(first_sample.keys())

### 2. Visualize single-trial beta values on an inflated brain

In [None]:
from mosaic.utils import visualize

# ROIs to draw (left/right fusiform face complex and V1).
# Pass rois=None instead to render every ROI.
rois_to_show = ["L_FFC", "R_FFC", "L_V1", "R_V1"]

# Surface options for `mode`: 'white', 'midthickness', 'pial', 'inflated',
# 'very_inflated', 'flat', 'sphere'. The figure is also written to plot.html.
visualize(
    betas=dataset[0]["betas"],
    rois=rois_to_show,
    mode="inflated",
    save_as="plot.html",
)

### 3. Download brain-optimized model weights and load the model

In [None]:
import mosaic

# Pretrained brain-optimized model: CNN8 backbone, multihead framework,
# trained across all subjects on the visual vertices.
pretrained_kwargs = dict(
    backbone_name="CNN8",
    framework="multihead",
    subjects="all",
    vertices="visual",
)

# Returns the model together with the config that inference needs later on.
model, model_config = mosaic.from_pretrained(**pretrained_kwargs)

### 4. Run inference on the brain-optimized model

In [None]:
# Fetch the example face photo used as the model input in the following cells.
# A pure-Python download (instead of `!wget`) also works where wget is not installed.
# It is skipped when face.jpg already exists, so re-running the notebook doesn't re-download.
from pathlib import Path
from urllib.request import urlretrieve

FACE_IMAGE_URL = "https://images.unsplash.com/photo-1542909168-82c3e7fdca5c"
FACE_IMAGE_PATH = Path("face.jpg")

if not FACE_IMAGE_PATH.exists():
    # Download to a temporary name and rename only on success, so an interrupted
    # download can't leave a truncated face.jpg that would then be treated as cached.
    partial_path = FACE_IMAGE_PATH.with_suffix(".part")
    urlretrieve(FACE_IMAGE_URL, partial_path)
    partial_path.replace(FACE_IMAGE_PATH)

In [None]:
# Preview the downloaded stimulus image.
from PIL import Image

im = Image.open("face.jpg")
# Normalise to 3-channel RGB before displaying.
im = im.convert("RGB")
im

In [None]:
from PIL import Image
from mosaic.utils.inference import MosaicInference

# Wrap the pretrained model for batched inference on the CPU.
inference = MosaicInference(
    model=model,
    batch_size=32,
    model_config=model_config,
    device="cpu",
)

# Predict responses to the face image for all NSD subjects and GOD subjects 1 and 2.
results = inference.run(
    images=[
        Image.open("face.jpg").convert("RGB"),
    ],
    names_and_subjects={"NSD": "all", "GOD": [1, 2]},
)

# Inference returns vertex predictions for each subject, grouped by dataset name.
# Use distinct loop names: reusing `dataset` here would overwrite the GOD dataset
# loaded in section 1 and break re-running the section 2 visualization cell.
for dataset_name, subject_predictions in results.items():
    for subject_id, prediction in subject_predictions.items():
        print(f"{dataset_name} {subject_id} prediction shape: {prediction.shape}")

### 5. Visualize model predictions on an inflated brain

In [None]:
# Plot predicted responses to the face for NSD subject 1 on the inflated surface
# (also saved to predicted_voxel_responses.html).
# Note: responses to the face are highest in the ventral stream.
face_image = Image.open("face.jpg").convert("RGB")

inference.plot(
    image=face_image,
    save_as="predicted_voxel_responses.html",
    dataset_name="NSD",
    subject_id=1,
    mode="inflated",
)