MEMAE: Microstructure Informed Mamba Vision Masked Autoencoder for Personalized Brain Injury Detection from Diffusion MRI
This repository contains the official implementation for MEMAE, a model designed for personalized brain injury detection from diffusion MRI data.
Follow these steps to set up the necessary conda environment and install dependencies.
- Create and activate the conda environment:
conda create -n memae python=3.10.13
conda activate memae
- Install PyTorch and CUDA:
# Install CUDA Toolkit
conda install cudatoolkit==11.8 -c nvidia
# Install PyTorch (v2.1.1 for cu118)
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
- Install Mamba and other dependencies:
# Install CUDA compiler (needed for Mamba)
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
# Install packaging
conda install packaging
# Install Mamba (SSM)
pip install mamba-ssm
For more details on the Mamba architecture, see the state-spaces/mamba repository on GitHub.
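To confirm the installation, a quick check like the one below can be run inside the activated memae environment. It is a sketch, not a script shipped with this repository: check_environment and mamba_smoke_test are hypothetical helper names. The Mamba call follows the usage example in the state-spaces/mamba README and needs a CUDA GPU.

```python
import importlib.util
import sys


def check_environment(expected_python=(3, 10)):
    """Report Python version match and whether torch / mamba_ssm are importable."""
    # Only major.minor is compared; the README pins 3.10.13.
    report = {"python_ok": sys.version_info[:2] == tuple(expected_python)}
    for module in ("torch", "mamba_ssm"):
        # find_spec locates the package without importing it.
        report[f"{module}_installed"] = importlib.util.find_spec(module) is not None
    report["cuda_available"] = False
    if report["torch_installed"]:
        import torch

        report["torch_version"] = torch.__version__
        report["torch_cuda"] = torch.version.cuda  # None for CPU-only builds
        report["cuda_available"] = torch.cuda.is_available()
    return report


def mamba_smoke_test(dim=16, length=64, batch=2):
    """Run one forward pass through a Mamba block (requires a CUDA GPU)."""
    import torch
    from mamba_ssm import Mamba

    x = torch.randn(batch, length, dim, device="cuda")
    model = Mamba(d_model=dim, d_state=16, d_conv=4, expand=2).to("cuda")
    y = model(x)
    # A Mamba block maps (batch, length, dim) to the same shape.
    return tuple(y.shape) == tuple(x.shape)


if __name__ == "__main__":
    report = check_environment()
    print(report)
    if report["mamba_ssm_installed"] and report["cuda_available"]:
        print("Mamba forward pass OK:", mamba_smoke_test())
```

The Mamba forward pass is attempted only when both mamba_ssm and a CUDA device are available, so the check can also run on a machine without a GPU.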
- Module:
data_set/
- Description: This step standardizes the resolution and dimensions of all input images and normalizes the data for model training.
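The exact steps in data_set/ are not documented here. As a minimal sketch, assuming each scan is a 3D NumPy array with known voxel spacing, the snippet below resamples to a common spacing with nearest-neighbour lookup, centre-crops or zero-pads to a fixed grid, and z-score normalizes the nonzero (brain) voxels. The function names, the default target spacing and shape, and the choice of z-scoring are illustrative assumptions, not the repository's implementation.

```python
import numpy as np


def resample_nearest(vol, in_spacing, out_spacing):
    """Resample a 3D volume to a new voxel spacing (mm) by nearest-neighbour lookup."""
    in_spacing = np.asarray(in_spacing, dtype=float)
    out_spacing = np.asarray(out_spacing, dtype=float)
    in_shape = np.array(vol.shape)
    # Keep the physical field of view roughly constant.
    out_shape = np.maximum(1, np.round(in_shape * in_spacing / out_spacing).astype(int))
    # For each output index, pick the input voxel it falls into.
    idx = [
        np.minimum(np.floor(np.arange(n) * o / i).astype(int), s - 1)
        for n, o, i, s in zip(out_shape, out_spacing, in_spacing, in_shape)
    ]
    return vol[np.ix_(*idx)]


def crop_or_pad(vol, target_shape):
    """Centre-crop or zero-pad each axis to the target size."""
    out = vol
    for axis, t in enumerate(target_shape):
        n = out.shape[axis]
        if n > t:
            start = (n - t) // 2
            out = np.take(out, np.arange(start, start + t), axis=axis)
        elif n < t:
            before = (t - n) // 2
            pad = [(0, 0)] * out.ndim
            pad[axis] = (before, t - n - before)
            out = np.pad(out, pad, mode="constant")
    return out


def zscore(vol, mask=None, eps=1e-8):
    """Z-score voxels inside the mask (default: nonzero voxels); zero elsewhere."""
    vol = vol.astype(np.float32)
    if mask is None:
        mask = vol != 0
    vals = vol[mask]
    if vals.size == 0:
        return vol
    return np.where(mask, (vol - vals.mean()) / (vals.std() + eps), 0.0).astype(np.float32)


def preprocess(vol, spacing, target_spacing=(1.0, 1.0, 1.0), target_shape=(128, 128, 128)):
    # Hypothetical defaults; match them to your own data.
    vol = resample_nearest(vol, spacing, target_spacing)
    vol = crop_or_pad(vol, target_shape)
    return zscore(vol)
```

Nearest-neighbour lookup keeps the example dependency-free; a real pipeline would more likely use trilinear or spline interpolation.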
To begin training the MEMAE model, run the main training script.
- Command:
python train.py -pdir /MEMAE/parameter/par.yml -gpu 0
(Note: The -pdir path /MEMAE/parameter/par.yml and the -gpu index depend on your setup. Adjust paths and arguments as needed.)
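How train.py parses its flags is not documented. The sketch below is a hypothetical argparse parser that mirrors the two flags in the training command (-pdir for the parameter file, -gpu for the device index), only to show how such a command line is interpreted; build_parser and device_name are illustrative names, not functions from this repository.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the -pdir / -gpu flags of the training command."""
    parser = argparse.ArgumentParser(description="MEMAE training (illustrative sketch)")
    # argparse accepts single-dash long options; dest becomes "pdir" and "gpu".
    parser.add_argument("-pdir", required=True, help="path to the YAML parameter file")
    parser.add_argument("-gpu", type=int, default=0, help="index of the CUDA device to use")
    return parser


def device_name(gpu_index):
    # Map the -gpu index to a PyTorch-style device string.
    return f"cuda:{gpu_index}"


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(args.pdir, device_name(args.gpu))
```

With a parser like this, the path must reach -pdir as a single shell token: an unquoted space would split it, and argparse would reject the leftover piece as an unrecognized argument and exit with status 2.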
To run inference on the test set with a trained model, run the test script.
- Command:
python test.py
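The README does not describe what test.py computes. Autoencoder-based detectors commonly flag abnormal regions by comparing each input with its reconstruction, so the sketch below assumes that setup: a voxel-wise absolute reconstruction error inside a brain mask, thresholded at mean + k * std of the masked errors. The function name, the mask convention, and k = 3 are assumptions for illustration only.

```python
import numpy as np


def flag_anomalies(image, reconstruction, mask=None, k=3.0):
    """Return (error_map, flagged) from voxel-wise reconstruction error.

    Hypothetical helper: error_map is |image - reconstruction| inside the mask,
    flagged marks voxels whose error exceeds mean + k * std of the masked errors.
    """
    error = np.abs(np.asarray(image, dtype=float) - np.asarray(reconstruction, dtype=float))
    if mask is None:
        mask = np.ones(error.shape, dtype=bool)
    vals = error[mask]
    threshold = vals.mean() + k * vals.std()
    # Voxels outside the mask are never flagged.
    flagged = (error > threshold) & mask
    return np.where(mask, error, 0.0), flagged
```

A subject-specific threshold like this is one simple choice; a fixed threshold learned from healthy reference scans is a common alternative.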
This module builds and applies the prior knowledge base.
- Script:
jkzxd.py
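The README does not say what the prior knowledge base contains. One plausible form, given the model's microstructure-informed design, is a normative reference built from healthy subjects, and the sketch below assumes exactly that: it stores the per-voxel mean and standard deviation of a microstructure map across a reference cohort already registered to a common space, then expresses a new subject's map as z-scores. build_knowledge_base and subject_z_map are hypothetical names and do not describe jkzxd.py.

```python
import numpy as np


def build_knowledge_base(reference_maps, eps=1e-6):
    """reference_maps: array of shape (n_subjects, X, Y, Z), co-registered (assumed)."""
    stack = np.asarray(reference_maps, dtype=float)
    if stack.shape[0] < 2:
        raise ValueError("need at least two reference subjects")
    return {
        "mean": stack.mean(axis=0),
        # Sample standard deviation (ddof=1); eps avoids division by zero later.
        "std": stack.std(axis=0, ddof=1) + eps,
    }


def subject_z_map(subject_map, kb):
    """Voxel-wise deviation of one subject from the reference cohort."""
    return (np.asarray(subject_map, dtype=float) - kb["mean"]) / kb["std"]
```

Large absolute z-scores would then mark voxels where the subject departs from the reference cohort.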