<a href="https://colab.research.google.com/github/Achillesy/Fetal_Functional_MRI_Segmentation/blob/master/fmri_vnet_interface.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Automated Brain Masking of Fetal Functional MRI with Open Data

## Setup: download the trained model weights from my GitHub repository

https://github.com/Achillesy/Fetal_Functional_MRI_Segmentation
![xuchu](https://avatars.githubusercontent.com/u/5572441?s=16) 

In [None]:
!wget https://github.com/Achillesy/Fetal_Functional_MRI_Segmentation/releases/download/v1.0.0/fold4_train_metric_vnet.pth

## Install ![monai](https://monai.io/assets/img/MONAI-logo_color.png) and print the environment configuration (package versions)

In [None]:
!pip install monai

from monai.config import print_config
print_config()


In [None]:
import os
import numpy as np
import nibabel as nib
from glob import glob

import torch
from types import SimpleNamespace
from google.colab import files

# Central run configuration: preprocessing geometry, sliding-window inference
# settings, checkpoint path, working directories and compute device.
cfg = SimpleNamespace(
  pixdim=(3.5, 3.5, 3.5),                  # target voxel spacing for Spacingd
  roi_size=[64, 64, 64],                   # sliding-window patch size
  sw_batch_size=4,                         # patches per forward pass
  file_pth="fold4_train_metric_vnet.pth",  # downloaded model checkpoint
  mri_dir="mri",                           # 3D volumes split from each 4D fMRI
  mask_dir="mask",                         # per-volume masks written by SaveImaged
  device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
)

# Working directories for intermediate files (removed again by the clean-up cell).
os.makedirs(cfg.mri_dir, exist_ok=True)
os.makedirs(cfg.mask_dir, exist_ok=True)


## Load VNet model and parameters

In [None]:
from monai.networks.nets import VNet

# 3D VNet taking single-channel volumes and producing two output channels
# (presumably background vs. brain, given the notebook's purpose).
model = VNet(
  spatial_dims=3,
  in_channels=1,
  out_channels=2,
  act=("elu", {"inplace": True}),
  dropout_dim=3,
  bias=False,
).to(cfg.device)

# The checkpoint is a file downloaded from the internet: weights_only=True limits
# unpickling to tensors/containers (torch >= 1.13). map_location puts the weights
# on cfg.device regardless of the device they were saved from.
state_dict = torch.load(cfg.file_pth, map_location=cfg.device, weights_only=True)
load_report = model.load_state_dict(state_dict)
# This notebook only runs inference: put dropout / normalization layers into
# evaluation behavior so predictions are deterministic.
model.eval()
load_report


In [None]:
from monai.transforms import (
  AsDiscreted,
  Compose,
  EnsureChannelFirstd,
  Invertd,
  LoadImaged,
  NormalizeIntensityd,
  Orientationd,
  SaveImaged,
  Spacingd,
)

# Preprocessing applied to every 3D volume before inference: load the NIfTI,
# move channels first, reorient to RAS, resample to cfg.pixdim and normalize
# intensities using only the nonzero voxels.
test_transforms = Compose(
  [
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys=["image"]),
    Orientationd(keys=["image"], axcodes="RAS"),
    Spacingd(keys=["image"], pixdim=cfg.pixdim, mode="bilinear"),
    NormalizeIntensityd(keys="image", nonzero=True),
  ]
)

# Postprocessing of each prediction:
#  1. Invertd undoes test_transforms on "pred", so the mask lands back on the
#     original image grid. The raw network outputs are interpolated
#     (nearest_interp=False) before being discretized.
#  2. AsDiscreted takes the argmax over channels, which gives a 0/1 label map.
#  3. SaveImaged writes the label map. separate_folder=True is spelled out
#     because the combine step reads masks from
#     <mask_dir>/<volume name>/<volume name>_vnet.nii.gz. Labels are 0/1, so they
#     are stored as uint8 rather than the default float32.
post_transforms = Compose(
  [
    Invertd(
      keys="pred",
      transform=test_transforms,
      orig_keys="image",
      meta_keys="pred_meta_dict",
      orig_meta_keys="image_meta_dict",
      meta_key_postfix="meta_dict",
      nearest_interp=False,
      to_tensor=True,
    ),
    AsDiscreted(keys="pred", argmax=True),
    SaveImaged(
      keys="pred",
      meta_keys="pred_meta_dict",
      output_dir=cfg.mask_dir,
      output_postfix="vnet",
      output_dtype=np.uint8,
      separate_folder=True,
      resample=False,
    ),
  ]
)

## Please upload your fetal functional MRI files (4D NIfTI, `.nii.gz`)

![upload_guide](https://github.com/Achillesy/Fetal_Functional_MRI_Segmentation/blob/master/figures/upload_guide.png?raw=1)
1. Click the **Files** icon on the left
2. Click the **Upload to session storage** icon above
3. Your uploaded files will be displayed here
----
After double-checking your uploaded files, <font color="green">press the **Enter** key in the input box below</font>. If you started the notebook with **Run all**, the remaining cells will then generate the fMRI masks automatically within a short time.

In [None]:
# Pause "Run all" here until the user confirms the uploads are complete. The typed
# text is not used, so it is discarded instead of being echoed as cell output.
_ = input("Press Enter once all fMRI files have been uploaded: ")


## Split each 4D fMRI into multiple 3D MRI volumes

In [None]:
# Every .nii.gz in the working directory is treated as an fMRI input, except the
# "<name>_vnet.nii.gz" masks this notebook writes there. Without that exclusion,
# re-running the notebook would feed earlier outputs back in as new inputs.
# The list is sorted so the processing order is deterministic (glob order is arbitrary).
frmi_files = sorted(
  path for path in glob("*.nii.gz") if not path.endswith("_vnet.nii.gz")
)
for fmri_data in frmi_files:
  fmri_data_name = os.path.basename(fmri_data).replace(".nii.gz", "")
  image = nib.load(fmri_data)
  # Check the rank from the header before reading any voxel data into memory.
  if len(image.shape) != 4:
    raise ValueError(
      f"Invalid shape of fMRI file format for {fmri_data}: {image.shape}. "
      "Expected 4D shape: [x, y, z, t]"
    )
  # float32 halves memory compared with get_fdata's float64 default.
  data = image.get_fdata(dtype=np.float32)
  # Write each time point as its own 3D volume, numbered from 1, in the same
  # world space (affine) as the source fMRI.
  for i in range(data.shape[-1]):
    channel_image = nib.Nifti1Image(data[..., i], image.affine)
    channel_file_name = os.path.join(cfg.mri_dir, f"{fmri_data_name}_{i+1}.nii.gz")
    nib.save(channel_image, channel_file_name)


In [None]:
# One MONAI record {"image": path} per split 3D volume, in deterministic order.
test_files = [
  {"image": volume_path}
  for volume_path in sorted(glob(os.path.join(cfg.mri_dir, "*.nii.gz")))
]
# Fail loudly instead of silently producing no masks.
if not test_files:
  raise FileNotFoundError(
    f"No 3D volumes found in '{cfg.mri_dir}'. Upload at least one 4D fMRI "
    "(.nii.gz) file and re-run the split step."
  )
# Show a summary rather than every record: one fMRI can yield hundreds of volumes.
print(f"{len(test_files)} volumes to segment, e.g. {test_files[:3]}")


## Generate a brain mask for each 3D volume

In [None]:
from monai.inferers import sliding_window_inference
from monai.data import DataLoader, Dataset, decollate_batch

# Dataset applies test_transforms lazily to each {"image": path} record.
test_ds = Dataset(transform=test_transforms, data=test_files)
# One whole volume per batch; sliding_window_inference does the patching.
test_loader = DataLoader(dataset=test_ds, batch_size=1)


In [None]:
# The model is built with dropout (dropout_dim=3) and may contain batch-norm
# layers. Both behave differently in training mode, so switch to evaluation mode:
# otherwise masks are random and depend on which patches share a window batch.
model.eval()
with torch.no_grad():
  for batch in test_loader:
    batch_inputs = batch["image"].to(cfg.device)
    # Patch-wise inference: cfg.roi_size windows, cfg.sw_batch_size windows per pass.
    batch["pred"] = sliding_window_inference(
      batch_inputs, cfg.roi_size, cfg.sw_batch_size, model
    )
    # post_transforms ends with SaveImaged, so running it writes each sample's
    # mask into cfg.mask_dir. The returned dicts are not needed afterwards.
    for sample in decollate_batch(batch):
      post_transforms(sample)

## Combine the volume masks into a 4D fMRI mask

In [None]:
# Reassemble the per-volume masks into one 4D mask aligned with each input fMRI.
for fmri_data in frmi_files:
  image = nib.load(fmri_data)
  fmri_data_name = os.path.basename(fmri_data).replace(".nii.gz", "")
  # Only the shape is needed, so take it from the header instead of loading
  # (and discarding) the full 4D voxel array.
  fmri_shape = image.shape
  # Argmax labels are 0/1, so uint8 is enough.
  mask_data = np.zeros(fmri_shape, dtype=np.uint8)
  for i in range(fmri_shape[-1]):
    # Path layout produced by SaveImaged in post_transforms (one folder per volume).
    i_mask_file = os.path.join(cfg.mask_dir, f"{fmri_data_name}_{i+1}", f"{fmri_data_name}_{i+1}_vnet.nii.gz")
    i_mask_data = nib.load(i_mask_file).get_fdata()
    if i_mask_data.shape != tuple(fmri_shape[:3]):
      raise ValueError(
        f"Mask {i_mask_file} has shape {i_mask_data.shape}, "
        f"expected {tuple(fmri_shape[:3])} to match {fmri_data}"
      )
    mask_data[..., i] = i_mask_data
  # Reuse the fMRI affine/header so the mask keeps the same geometry and timing info.
  fmri_mask = nib.Nifti1Image(mask_data, affine=image.affine, header=image.header)
  # A copied header carries the fMRI's on-disk dtype; store the binary mask as uint8.
  fmri_mask.set_data_dtype(np.uint8)
  mask_data_name = f"{fmri_data_name}_vnet.nii.gz"
  fmri_mask.to_filename(mask_data_name)
  # files.download(mask_data_name)  # optional: trigger a browser download in Colab

Please <font color="red">right-click</font> each generated mask file (**your_fmri_name_vnet.nii.gz**) in the **Files** panel and choose **Download**.

## Clean up temporary files

In [None]:
# Remove the intermediate per-volume images and masks. shutil avoids
# interpolating Python values into a shell command. As with `rm -rf`, a
# directory that is already gone is not an error.
import shutil

shutil.rmtree(cfg.mri_dir, ignore_errors=True)
shutil.rmtree(cfg.mask_dir, ignore_errors=True)
