Compressed Sensing MRI based on Generative Adversarial Network
Morteza Mardani, Enhao Gong
Deep Generative Adversarial Networks for Compressed Sensing Automates MRI Arxiv Paper: https://arxiv.org/abs/1706.00051
The basic code base is derived from super resolution github repo. https://github.com/david-gpu/srez
https://github.com/david-gpu/srez https://arxiv.org/abs/1609.04802 https://github.com/carpedm20/DCGAN-tensorflow http://wiseodd.github.io/techblog/2017/03/02/least-squares-gan/ http://wiseodd.github.io/techblog/2017/02/04/wasserstein-gan/
Deep Generative Adversarial Networks for CS MRI
Magnetic resonance imaging (MRI) suffers from aliasing artifacts when it is highly undersampled for fast imaging. Conventional CS MRI reconstruction uses regularized iterative reconstruction based on pre-defined sparsity transform, which usually include time-consuming iterative optimization and may result in undesired artifacts such as oversmoothing. Here we propose a novel CS framework that permeates benefits from deep learning and generative adversarial networks (GAN) to modeling a manifold of MR images from historical patients. Extensive evaluations on a large MRI datasets of pediatric pateints show it results in superior perforamnce, retrieves image with improved quality and finer details relative to conventional CS and pixel-wise deep learning schemes.
1D undersampling is generated using the variable density distribution.
R_factor defines the desired reduction factor, which controls how many samples to randomly pick.
R_alpha defines the decay of VD distribution with formula
R_seed defines the random seed used, for negative values there is no fixed undersampling pattern.
sampling_pattern can be a path to .mat file for specific 1D/2D undersampling mask
Several models are explored including ResNet-ish models from super-resolution paper and encoder-decoder models.
Currently we are using 4*(Conv-BN-RELU-POOL)+2*(CONV-BN-RELU)+CONV+MEAN+softmax (for logloss)
descriminator related loss
We have been exploring different loss functions for GAN, including:
- LS loss (better than log-loss, use as default, easy to tune and optimize)
- Cycle-GAN/WGAN loss (todo)
Loss is a mixed combination with: 1) Data consistency loss, 2) pixel-wise MSE/L1/L2 loss and 3) LS-GAN loss
FLAGS.gene_log_factor = 0 # log loss vs least-square loss
FLAGS.gene_dc_factor = 0.9 # data-consistency (kspace) loss vs generator loss
FLAGS.gene_mse_factor = 0.001 # simple MSE loss originally for forward-passing model vs GAN loss
GAN loss = generator loss + discriminator loss
gene_fool_loss = FLAGS.gene_log_factor * gene_log_loss + (1-FLAGS.gene_log_factor) * gene_LS_loss
gene_non_mse_loss = FLAGS.gene_dc_factor * gene_dc_loss + (1-FLAGS.gene_dc_factor) * gene_fool_loss
gene_loss = FLAGS.gene_mse_factor * gene_mse_loss + (1- FLAGS.gene_mse_factor) * gene_non_mse_loss
Dataset is parsed to pngs saved in specific folders
- Phantom dataset
- Knee dataset
- DCE dataset
Multiple results are exported while traning
- loss changes
- test results (zero-fill VS Recon VS Ref) saved to png for each epoch
- some of the train results are exported to png, WARNING, there was a memory bug before when we try to export all train results
- layers (some layers are skipped) of generator and detectors are exported into json after each epoch.
srez_main.pyhandles commendline arguments
srez_input.pyloads from input files, conducting undersampling or loading/extracting the undersampled patterns,
srez_model.pydefines Generator and Discriminator models
srez_train.pytrains model in batch and export checkpoints and results for progress summaries
(currently working on t2)
python srez_main.py --dataset_input /home/enhaog/GANCS/srez/dataset_MRI/phantom --batch_size 8 --run train --summary_period 123 --sample_size 256 --train_time 10 --train_dir train_save_all --R_factor 4 --R_alpha 3 --R_seed 0
(currently working on t2 for DCE)
python srez_main.py --run train --dataset_input /home/enhaog/GANCS/srez/dataset_MRI/abdominal_DCE --sample_size 200 --sample_size_y 100 --sampling_pattern /home/enhaog/GANCS/srez/dataset_MRI/sampling_pattern_DCE/mask_2dvardesnity_radiaview_4fold.mat --batch_size 4 --summary_period 125 --sample_test 32 --sample_train 10000 --train_time 200 --train_dir train_DCE_test