[Paper] [Zhihu]
Hang Guo*, Jinmin Li*, Tao Dai, Zhihao Ouyang, Xudong Ren, and Shu-Tao Xia
Check out our collection of recent Mamba papers in low-level vision (Awesome Mamba in Low-Level Vision) [here].
(*) equal contribution
Abstract: Recent years have witnessed great progress in image restoration thanks to advancements in modern deep neural networks, e.g., Convolutional Neural Networks and Transformers. However, existing restoration backbones are usually limited by their inherent local inductive bias or quadratic computational complexity. Recently, Selective Structured State Space Models, e.g., Mamba, have shown great potential for modeling long-range dependencies with linear complexity, but they remain under-explored in low-level computer vision. In this work, we introduce a simple but strong baseline model, named MambaIR, for image restoration. In detail, we propose the Residual State Space Block as the core component, which employs convolution and channel attention to enhance the capabilities of vanilla Mamba. In this way, MambaIR takes advantage of the local patch recurrence prior as well as channel interaction to produce restoration-specific feature representations. Extensive experiments demonstrate the superiority of our method; for example, MambaIR outperforms the Transformer-based baseline SwinIR by up to 0.36 dB with similar computational cost but a global receptive field.
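For readers new to state space models, the linear complexity mentioned above comes from evaluating a recurrence in a single pass over the sequence. The sketch below is a toy, single-channel selective scan in plain Python: the step size `delta` depends on the input (the "selective" part), `a` is discretized with a zero-order hold and `b` with a simple Euler step. The parameter names `a`, `b`, `c`, `w` and the softplus step-size rule are illustrative assumptions; this is neither MambaIR's implementation nor the fused kernel in mamba_ssm, which operates on multi-channel feature maps.

```python
import math
from typing import List, Sequence


def softplus(v: float) -> float:
    """Numerically stable softplus; keeps the step size positive."""
    return math.log1p(math.exp(-abs(v))) + max(v, 0.0)


def selective_scan_1d(
    x: Sequence[float],
    a: float = -1.0,  # continuous-time state coefficient (negative => decaying state)
    b: float = 1.0,   # input projection (hypothetical constant)
    c: float = 1.0,   # output projection (hypothetical constant)
    w: float = 1.0,   # hypothetical weight that makes the step size input-dependent
) -> List[float]:
    """Toy single-channel selective scan: one pass over the sequence, O(L) time."""
    h = 0.0
    ys = []
    for x_t in x:
        delta = softplus(w * x_t)    # input-dependent ("selective") step size
        a_bar = math.exp(delta * a)  # zero-order-hold discretization of a
        b_bar = delta * b            # simplified (Euler) discretization of b
        h = a_bar * h + b_bar * x_t  # recurrent state update
        ys.append(c * h)             # linear readout
    return ys


if __name__ == "__main__":
    print(selective_scan_1d([1.0, 0.0, 0.0, 0.5]))
```

With zero input the state simply decays: `delta = softplus(0) = ln 2` and `a = -1`, so each step halves the state. The loop touches each element once, which is why the cost grows linearly with sequence length, unlike the pairwise interactions of self-attention.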
If this work is helpful to you, please star this repo. Thanks!
- 2024-2-23: arXiv paper available.
- 2024-2-27: This repo is released.
- 2024-3-01: Pretrained weights for SR and realDN are available.
- 2024-3-08: The code for ERF visualization and model complexity analysis can be found in the ./analysis/ folder.
- 2024-3-19: We have updated the code for MambaIR-light.
- 2024-3-19: The FIRST Mamba-based Real-world SR Model is now available! Enjoy yourself!
- 2024-05-24: We have released a new repository that collects recent Mamba works in low-level vision; please see here if you are interested ;D
- 2024-06-10: We have released the training and testing config files for Gaussian Color Image Denoising; the pre-trained weights are coming soon.
- 2024-06-10: We have also updated the environment installation instructions here so that you can quickly build your own Mamba environment for reproduction!
- 2024-07-01: Congratulations! Our MambaIR has been accepted by ECCV 2024!
- 2024-07-04: We have released the training and testing config files for JPEG compression artifact reduction tasks.
- 2024-07-04: The pretrained weights for Gaussian Color Image Denoising and JPEG Compression Artifact Reduction are now available here. These models perform better than the results reported in the paper, and we will update the arXiv version in the future. Enjoy these new models!
- 2024-08-19: The previous #params & MACs calculation for the Mamba model using the thop library had a bug, which was also discussed in #issue44. We have released new, accurate calculation code that uses fvcore and additionally registers the previously missing parameters. You can use this new code in ./analysis/flops_param_fvcore.py for complexity analysis. Note that the model complexity obtained from this code is larger than the reported one. We will release a new, comparable MambaIR-light model soon. Stay tuned!
- 2024-10-15: We have released a new arXiv version of our MambaIR paper, in which we have fixed the results on lightSR tasks.
- 2024-10-15: A brand-new Mamba-based image restoration backbone, MambaIRv2, is just around the corner, with significant performance and efficiency improvements. We will release the new paper and code soon~
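The fvcore-based script in ./analysis/ reports #params and MACs for the whole network. As a sanity check on what such counters measure, the hypothetical helper below (not part of this repository) computes both quantities by hand for a single standard 2D convolution, counting one multiply-accumulate per weight per output pixel and ignoring bias additions. It does not cover Mamba's selective-scan operator, which a generic counter may need to handle separately.

```python
from dataclasses import dataclass


@dataclass
class Conv2dSpec:
    """Shape of a standard (ungrouped) 2D convolution with a square kernel."""
    in_ch: int
    out_ch: int
    kernel: int
    bias: bool = True


def conv2d_params(spec: Conv2dSpec) -> int:
    """Weights (out * in * k * k) plus one bias per output channel."""
    weights = spec.out_ch * spec.in_ch * spec.kernel * spec.kernel
    return weights + (spec.out_ch if spec.bias else 0)


def conv2d_macs(spec: Conv2dSpec, out_h: int, out_w: int) -> int:
    """One multiply-accumulate per weight per output pixel (bias adds ignored)."""
    return out_h * out_w * spec.out_ch * spec.in_ch * spec.kernel * spec.kernel
```

For a hypothetical 3×3 convolution with 180 input and output channels on a 64×64 feature map, this gives 291,780 parameters and about 1.19 G MACs.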
- Build the repo
- arXiv version
- Release code
- Pretrained weights&log_files
- Add code for complexity analysis and ERF visualization
- Real-world SR
- Gaussian Color Image Denoising
- Add Download Link for Visual Results on Common Benchmarks
- JPEG Compression Artifact Reduction
- Further Improvement...
Model | Task | Test dataset | PSNR (dB) | SSIM | Model weights | Log files |
---|---|---|---|---|---|---|
MambaIR_SR2 | Classic SR x2 | Urban100 | 34.15 | 0.9446 | link | link |
MambaIR_SR3 | Classic SR x3 | Urban100 | 29.93 | 0.8841 | link | link |
MambaIR_SR4 | Classic SR x4 | Urban100 | 27.68 | 0.8287 | link | link |
MambaIR_light2 | Lightweight SR x2 | Urban100 | 32.92 | 0.9356 | link | link |
MambaIR_light3 | Lightweight SR x3 | Urban100 | 29.00 | 0.8689 | link | link |
MambaIR_light4 | Lightweight SR x4 | Urban100 | 26.75 | 0.8051 | link | link |
MambaIR_realDN | Real image Denoising | SIDD | 39.89 | 0.960 | link | link |
MambaIR_realSR | Real-world SR | RealSRSet | - | - | link | link |
MambaIR_guassian15 | Gaussian Denoising | Urban100 | 35.17 | - | link | link |
MambaIR_guassian25 | Gaussian Denoising | Urban100 | 32.99 | - | link | link |
MambaIR_guassian50 | Gaussian Denoising | Urban100 | 30.07 | - | link | link |
MambaIR_JEPG10 | JPEG CAR | Classic5 | 30.27 | 0.8256 | link | link |
MambaIR_JPEG30 | JPEG CAR | Classic5 | 33.74 | 0.8965 | link | link |
MambaIR_JPEG40 | JPEG CAR | Classic5 | 34.53 | 0.9084 | link | link |
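PSNR, the main metric in the table above, is 10·log10(MAX² / MSE) in decibels, with MAX = 255 for 8-bit images. The minimal sketch below computes it for two flat pixel sequences; it is illustrative only, and the repository's test configs may apply extra conventions (such as border cropping or Y-channel evaluation) that this helper does not reproduce.

```python
import math
from typing import Sequence


def psnr(pred: Sequence[float], target: Sequence[float], max_val: float = 255.0) -> float:
    """Peak signal-to-noise ratio in dB between two equally sized pixel sequences."""
    if len(pred) != len(target) or not pred:
        raise ValueError("inputs must be non-empty and the same length")
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)
```

Because PSNR is logarithmic in MSE, a gain of d dB means the MSE shrinks by a factor of 10^(-d/10); the 0.36 dB margin over SwinIR quoted in the abstract therefore corresponds to roughly 8% lower MSE.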
We achieve state-of-the-art performance on various image restoration tasks. Detailed results can be found in the paper.
This codebase was tested with the following environment configurations. It may work with other versions.
- Ubuntu 20.04
- CUDA 11.7
- Python 3.9
- PyTorch 2.0.1 + cu117
To use the selective scan with its efficient hardware-aware design, the mamba_ssm library needs to be installed with the following commands:
pip install causal_conv1d==1.0.0
pip install mamba_ssm==1.0.1
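Before training, it can help to confirm that the pinned versions actually ended up in the active environment. The snippet below is a hypothetical helper (not part of this repository) that uses the standard-library importlib.metadata to compare installed distributions against the two pins above; it assumes the distribution names match the package names used in the pip commands.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional

# Versions pinned in the installation commands above.
PINNED = {"causal_conv1d": "1.0.0", "mamba_ssm": "1.0.1"}


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed distribution version, or None if it is missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def check_pins(pins: Dict[str, str]) -> Dict[str, str]:
    """Map each pinned package to 'ok', 'missing', or 'found X' on a mismatch."""
    report = {}
    for name, wanted in pins.items():
        found = installed_version(name)
        if found is None:
            report[name] = "missing"
        elif found == wanted:
            report[name] = "ok"
        else:
            report[name] = f"found {found}"
    return report


if __name__ == "__main__":
    for pkg, status in check_pins(PINNED).items():
        print(f"{pkg}: {status}")
```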
Alternatively, you can create a new Anaconda environment and then install the necessary Python libraries from requirements.txt with the following command:
conda install --yes --file requirements.txt
You can also reproduce the conda environment with the following commands (CUDA 11.7 is used; you can modify the yaml file for your CUDA version):
cd ./MambaIR
conda env create -f environment.yaml
conda activate mambair
The datasets used in our training and testing are organized as follows:
Task | Training Set | Testing Set | Visual Results |
---|---|---|---|
image SR | DIV2K (800 training images) + Flickr2K (2650 images) [complete dataset DF2K download] | Set5 + Set14 + BSD100 + Urban100 + Manga109 [download] | Google Drive |
gaussian color image denoising | DIV2K (800 training images) + Flickr2K (2650 images) + BSD500 (400 training & testing images) + WED (4744 images) [complete dataset DFWB_RGB download] | CBSD68 + Kodak24 + McMaster + Urban100 [download] | Google Drive |
real image denoising | SIDD (320 training images) [complete dataset SIDD download] | SIDD + DND [download] | Google Drive |
grayscale JPEG compression artifact reduction | DIV2K (800 training images) + Flickr2K (2650 images) + BSD500 (400 training & testing images) + WED (4744 images) [complete dataset DFWB_CAR download] | Classic5 + LIVE1 [download] | Google Drive |
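The training and testing instructions in this README expect each dataset in a specific top-level folder. The hypothetical helper below (not part of this repository) lists which of those folders are still missing for a given task; the folder names come from the instructions, while the exact sub-folders each config reads are set in its .yml file. Real image denoising is run from realDenoising/, so its paths are left out here.

```python
from pathlib import Path
from typing import Dict, List

# Top-level folders named in this README's training/testing instructions.
# Sub-folders (e.g. HR/LR splits) are defined by each .yml config, not here.
EXPECTED_DIRS: Dict[str, List[str]] = {
    "sr": ["datasets/DF2K", "datasets/SR"],
    "color_dn": ["datasets/DFWB_RGB", "datasets/ColorDN"],
    "jpeg_car": ["datasets/DFWB_CAR", "datasets/JPEG_CAR"],
}


def missing_dirs(repo_root: Path, task: str) -> List[str]:
    """Return the expected dataset folders for `task` that do not exist yet."""
    return [d for d in EXPECTED_DIRS[task] if not (repo_root / d).is_dir()]


if __name__ == "__main__":
    root = Path(".")
    for task in EXPECTED_DIRS:
        gaps = missing_dirs(root, task)
        print(f"{task}: {'ready' if not gaps else 'missing ' + ', '.join(gaps)}")
```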
- Please download the corresponding training datasets and put them in the folder datasets/DF2K. Download the testing datasets and put them in the folder datasets/SR.
- Follow the instructions below to begin training our model.
# Classic SR task, cropped input=64×64, 8 GPUs, batch size=4 per GPU
python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 basicsr/train.py -opt options/train/train_MambaIR_SR_x2.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 basicsr/train.py -opt options/train/train_MambaIR_SR_x3.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 basicsr/train.py -opt options/train/train_MambaIR_SR_x4.yml --launcher pytorch
# Lightweight SR task, cropped input=64×64, 2 GPUs, batch size=16 per GPU
python -m torch.distributed.launch --nproc_per_node=2 --master_port=1234 basicsr/train.py -opt options/train/train_MambaIR_lightSR_x2.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=2 --master_port=1234 basicsr/train.py -opt options/train/train_MambaIR_lightSR_x3.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=2 --master_port=1234 basicsr/train.py -opt options/train/train_MambaIR_lightSR_x4.yml --launcher pytorch
- Run the script, and the generated experimental logs can be found in the folder experiments.
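Each launch line above starts --nproc_per_node processes (one per GPU), so the number of images per optimization step is GPUs × per-GPU batch size: 8 × 4 = 32 for classic SR and 2 × 16 = 32 for lightweight SR. The hypothetical helper below rebuilds the same commands in Python so that all scales can be queued in sequence; it defaults to a dry run that only prints them, and it is a convenience sketch rather than a script shipped with the repository.

```python
import shlex
import subprocess
from typing import List


def train_command(opt_yml: str, gpus: int, port: int = 1234) -> List[str]:
    """Build the distributed training command used in the README for one config."""
    return [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={gpus}", f"--master_port={port}",
        "basicsr/train.py", "-opt", opt_yml, "--launcher", "pytorch",
    ]


def effective_batch(gpus: int, per_gpu: int) -> int:
    """Images per optimization step across all GPUs."""
    return gpus * per_gpu


def run_all(configs: List[str], gpus: int, dry_run: bool = True) -> None:
    """Run the configs one after another; with dry_run, only print the commands."""
    for yml in configs:
        cmd = train_command(yml, gpus)
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    sr_cfgs = [f"options/train/train_MambaIR_SR_x{s}.yml" for s in (2, 3, 4)]
    run_all(sr_cfgs, gpus=8)  # classic SR: 8 GPUs x 4 images = 32 per step
```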
- Download the corresponding training datasets here and put them in the folder ./datasets/DFWB_RGB. Download the testing datasets and put them in the folder ./datasets/ColorDN.
- Follow the instructions below to begin training:
# train on denoising15
python -m torch.distributed.launch --nproc_per_node=8 --master_port=2414 basicsr/train.py -opt options/train/train_MambaIR_ColorDN_level15.yml --launcher pytorch
# train on denoising25
python -m torch.distributed.launch --nproc_per_node=8 --master_port=2414 basicsr/train.py -opt options/train/train_MambaIR_ColorDN_level25.yml --launcher pytorch
# train on denoising50
python -m torch.distributed.launch --nproc_per_node=8 --master_port=2414 basicsr/train.py -opt options/train/train_MambaIR_ColorDN_level50.yml --launcher pytorch
- Run the script, and the generated experimental logs can be found in the folder ./experiments.
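The three configs train separate models for noise levels 15, 25 and 50, which in Gaussian color denoising conventionally denote the standard deviation σ of additive white Gaussian noise on the 0-255 pixel scale. The sketch below shows that degradation for a flat list of pixel values; it is an illustrative assumption, and the repository's own dataset class (selected by the .yml config) determines the actual sampling, value range and any clipping.

```python
import random
from typing import List, Optional, Sequence


def add_gaussian_noise(
    pixels: Sequence[float], sigma: float, seed: Optional[int] = None
) -> List[float]:
    """Add zero-mean white Gaussian noise with std `sigma` (0-255 pixel scale).

    No clipping is applied; a seed makes the noise reproducible.
    """
    rng = random.Random(seed)
    return [p + rng.gauss(0.0, sigma) for p in pixels]


if __name__ == "__main__":
    clean = [128.0] * 5
    for level in (15, 25, 50):  # the three noise levels trained above
        print(level, [round(v, 1) for v in add_gaussian_noise(clean, level, seed=0)])
```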
- Download the corresponding training datasets here and put them in the folder ./datasets/DFWB_CAR. Download the testing datasets and put them in the folder ./datasets/JPEG_CAR.
- Follow the instructions below to begin training:
# train on jpeg10
python -m torch.distributed.launch --nproc_per_node=8 --master_port=2414 basicsr/train.py -opt options/train/train_MambaIR_CAR_q10.yml --launcher pytorch
# train on jpeg30
python -m torch.distributed.launch --nproc_per_node=8 --master_port=2414 basicsr/train.py -opt options/train/train_MambaIR_CAR_q30.yml --launcher pytorch
# train on jpeg40
python -m torch.distributed.launch --nproc_per_node=8 --master_port=2414 basicsr/train.py -opt options/train/train_MambaIR_CAR_q40.yml --launcher pytorch
- Run the script, and the generated experimental logs can be found in the folder ./experiments.
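The q10, q30 and q40 configs refer to the JPEG quality factor used to create the compressed inputs: lower quality means coarser quantization and stronger artifacts. Assuming an encoder that follows the IJG libjpeg quality scaling (an assumption; the repository's data pipeline defines the actual encoder), the sketch below shows how a quality factor rescales the first row of the example luminance quantization table from the JPEG standard, clamped to the baseline range [1, 255].

```python
from typing import List

# First row of the example luminance quantization table in the JPEG standard (Annex K).
BASE_LUMA_ROW = [16, 11, 10, 16, 24, 40, 51, 61]


def quality_scale(quality: int) -> int:
    """IJG-style mapping from quality (1-100) to a percentage scale factor."""
    quality = min(max(quality, 1), 100)
    return 5000 // quality if quality < 50 else 200 - 2 * quality


def scaled_row(quality: int, base: List[int] = BASE_LUMA_ROW) -> List[int]:
    """Scale base quantizer step sizes, clamped to the baseline range [1, 255]."""
    s = quality_scale(quality)
    return [min(max((b * s + 50) // 100, 1), 255) for b in base]


if __name__ == "__main__":
    for q in (10, 30, 40):  # the three quality factors trained above
        print(q, scaled_row(q))
```

For example, the first quantizer step of 16 grows to 80 at quality 10, 27 at quality 30 and 20 at quality 40, which is why q10 inputs show much heavier blocking than q40 inputs.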
- Please download the corresponding training datasets and put them in the folder datasets/SIDD. Note that we provide both training and validation files, which are already processed.
- Go to the folder 'realDenoising' and follow the instructions below to train our model.
# go to the folder
cd realDenoising
# set the new environment (BasicSRv1.2.0), which is the same as Restormer's, for training.
python setup.py develop --no_cuda_ext
# train for RealDN task, 8 GPUs
python -m torch.distributed.launch --nproc_per_node=8 --master_port=2414 basicsr/train.py -opt options/train_MambaIR_RealDN.yml --launcher pytorch
Run the script, and the generated experimental logs can be found in the folder realDenoising/experiments.
- Remember to switch back to the original environment once you have finished all training or testing for the real image denoising task. This is a friendly hint to prevent confusion in the training environment.
# Tip: go back to the original environment (BasicSRv1.3.5) after finishing all training or testing for real image denoising.
cd ..
python setup.py develop
- Please download the corresponding testing datasets and put them in the folder datasets/SR. Download the corresponding models and put them in the folder experiments/pretrained_models.
- Follow the instructions below to begin testing our MambaIR model.
# test for image SR.
python basicsr/test.py -opt options/test/test_MambaIR_SR_x2.yml
python basicsr/test.py -opt options/test/test_MambaIR_SR_x3.yml
python basicsr/test.py -opt options/test/test_MambaIR_SR_x4.yml
# test for lightweight image SR.
python basicsr/test.py -opt options/test/test_MambaIR_lightSR_x2.yml
python basicsr/test.py -opt options/test/test_MambaIR_lightSR_x3.yml
python basicsr/test.py -opt options/test/test_MambaIR_lightSR_x4.yml
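Image SR results are conventionally reported as PSNR on the luma (Y) channel of YCbCr after cropping a border of `scale` pixels, and BasicSR exposes this through the crop_border and test_y_channel metric options; check the metrics section of each test .yml to confirm what these configs use. The sketch below implements that convention in plain Python with BT.601 luma coefficients, on images stored as nested lists of RGB tuples; it is illustrative and not the BasicSR metric code.

```python
import math
from typing import List, Tuple

Pixel = Tuple[float, float, float]  # (R, G, B) in [0, 255]
Image = List[List[Pixel]]           # rows of pixels


def rgb_to_y(p: Pixel) -> float:
    """ITU-R BT.601 luma with studio-range output (16 for black, 235 for white)."""
    r, g, b = p
    return 16.0 + (65.481 * r + 128.553 * g + 24.966 * b) / 255.0


def psnr_y(pred: Image, target: Image, crop_border: int) -> float:
    """PSNR on the Y channel after removing `crop_border` pixels on each side."""
    h, w = len(target), len(target[0])
    rows = range(crop_border, h - crop_border)
    cols = range(crop_border, w - crop_border)
    diffs = [
        (rgb_to_y(pred[i][j]) - rgb_to_y(target[i][j])) ** 2
        for i in rows for j in cols
    ]
    if not diffs:
        raise ValueError("crop_border leaves no pixels")
    mse = sum(diffs) / len(diffs)
    return math.inf if mse == 0 else 10.0 * math.log10(255.0 ** 2 / mse)
```

Cropping matters because errors concentrated at image borders (for example from padding) are excluded, so the same output can score differently with and without crop_border.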
- Please download the corresponding testing datasets and put them in the folder datasets/ColorDN.
- Download the corresponding models and put them in the folder experiments/pretrained_models.
- Follow the instructions below to begin testing our model.
# test on denoising15
python basicsr/test.py -opt options/test/test_MambaIR_ColorDN_level15.yml
# test on denoising25
python basicsr/test.py -opt options/test/test_MambaIR_ColorDN_level25.yml
# test on denoising50
python basicsr/test.py -opt options/test/test_MambaIR_ColorDN_level50.yml
- Please download the corresponding testing datasets and put them in the folder datasets/JPEG_CAR.
- Download the corresponding models and put them in the folder experiments/pretrained_models.
- Follow the instructions below to begin testing our model.
# test on jpeg10
python basicsr/test.py -opt options/test/test_MambaIR_JPEG_q10.yml
# test on jpeg30
python basicsr/test.py -opt options/test/test_MambaIR_JPEG_q30.yml
# test on jpeg40
python basicsr/test.py -opt options/test/test_MambaIR_JPEG_q40.yml
- Download the SIDD and DND test sets and place them in datasets/RealDN. Download the corresponding models and put them in the folder experiments/pretrained_models.
- Go to the folder 'realDenoising' and follow the instructions below to test our model. The output is saved in realDenoising/results/Real_Denoising.
# go to the folder
cd realDenoising
# set the new environment (BasicSRv1.2.0), which is the same as Restormer's, for testing.
python setup.py develop --no_cuda_ext
# test MambaIR (total training iterations = 300K) on SIDD
python test_real_denoising_sidd.py
# test MambaIR (total training iterations = 300K) on DND
python test_real_denoising_dnd.py
- Run the script below in MATLAB to reproduce PSNR/SSIM on SIDD.
run evaluate_sidd.m
- For PSNR/SSIM scores on DND, upload the generated DND .mat files to the online server to get the results.
- Remember to switch back to the original environment once you have finished all training or testing for the real image denoising task. This is a friendly hint to prevent confusion in the training environment.
# Tip: go back to the original environment (BasicSRv1.3.5) after finishing all training or testing for real image denoising.
cd ..
python setup.py develop
Please cite us if our work is useful for your research.
@inproceedings{guo2024mambair,
title={MambaIR: A Simple Baseline for Image Restoration with State-Space Model},
author={Guo, Hang and Li, Jinmin and Dai, Tao and Ouyang, Zhihao and Ren, Xudong and Xia, Shu-Tao},
booktitle={ECCV},
year={2024}
}
This project is released under the Apache 2.0 license.
This code is based on BasicSR, ART, and VMamba. Thanks for their awesome work.
If you have any questions, feel free to contact me at cshguo@gmail.com