Controlling Vision-Language Models for Universal Image Restoration
Official PyTorch Implementation of DA-CLIP.
Project Page | Paper | Model Card 🤗
[2024.01.20] 🎉🎉🎉 Our DA-CLIP paper was accepted by ICLR 2024 🎉🎉🎉. We also provide a more robust model in the model card.
[2023.10.25] Added dataset links for training and testing.
[2023.10.13] Added the Replicate demo and API 🔥. Thanks to @chenxwh! We updated the Hugging Face demo 🔥 and the online Colab demo 🔥. Thanks to @fffiloni and @camenduru! We also made a Model Card on Hugging Face 🤗 and provided more examples for testing.
[2023.10.09] The pretrained weights of DA-CLIP and the Universal IR model are released in link1 and link2, respectively. We also provide a Gradio app file for testing your own images.
- OS: Ubuntu 20.04
- nvidia:
  - cuda: 11.4
- python: 3.8
We advise you to first create a virtual environment:

```bash
python3 -m venv .env
source .env/bin/activate
pip install -U pip
pip install -r requirements.txt
```
Go to the `universal-image-restoration` directory and run:
```python
import torch
from PIL import Image
import open_clip  # the DA-CLIP-enabled open_clip bundled in this directory

checkpoint = 'pretrained/daclip_ViT-B-32.pt'
model, preprocess = open_clip.create_model_from_pretrained('daclip_ViT-B-32', pretrained=checkpoint)
tokenizer = open_clip.get_tokenizer('ViT-B-32')

# Preprocess a single input image and add a batch dimension
image = preprocess(Image.open("haze_01.png")).unsqueeze(0)

# Candidate degradation prompts ('uncompleted' means inpainting)
degradations = ['motion-blurry', 'hazy', 'jpeg-compressed', 'low-light', 'noisy',
                'raindrop', 'rainy', 'shadowed', 'snowy', 'uncompleted']
text = tokenizer(degradations)

with torch.no_grad(), torch.cuda.amp.autocast():
    text_features = model.encode_text(text)
    # control=True also returns the degradation features
    image_features, degra_features = model.encode_image(image, control=True)
    # L2-normalise so the dot product below is a cosine similarity
    degra_features /= degra_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    text_probs = (100.0 * degra_features @ text_features.T).softmax(dim=-1)

index = torch.argmax(text_probs[0]).item()
print(f"Task: {degradations[index]} - {text_probs[0][index].item():.4f}")
```
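The classification step in the snippet above scores each degradation prompt by the cosine similarity between the L2-normalised degradation feature and each text feature, scales the scores by 100 and applies a softmax. The pure-Python sketch below reproduces only that scoring step so it can be followed without a GPU or checkpoint. The 3-dimensional toy vectors and the helper name `classify_degradation` are hypothetical and are not real DA-CLIP features.

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit length (mirrors x / x.norm(dim=-1, keepdim=True))."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def classify_degradation(degra_feature, text_features, labels, scale=100.0):
    """Hypothetical helper: return (label, probability) of the best-matching prompt."""
    d = l2_normalize(degra_feature)
    # Scaled cosine similarity between the degradation feature and every prompt
    logits = [scale * sum(a * b for a, b in zip(d, l2_normalize(t))) for t in text_features]
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return labels[best], probs[best]

# Toy example: the degradation feature points mostly along the "hazy" text direction.
labels = ["motion-blurry", "hazy", "noisy"]
text_features = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
label, prob = classify_degradation([0.1, 0.9, 0.2], text_features, labels)
print(label, round(prob, 4))
```

The factor of 100 makes the softmax very peaked, so even moderate differences in cosine similarity yield a near-certain prediction.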
Prepare the training and test datasets following the Dataset Construction section of our paper:
```
#### for training dataset ####
#### (uncompleted means inpainting) ####
datasets/universal/train
|--motion-blurry
|  |--LQ/*.png
|  |--GT/*.png
|--hazy
|--jpeg-compressed
|--low-light
|--noisy
|--raindrop
|--rainy
|--shadowed
|--snowy
|--uncompleted

#### for testing dataset ####
#### (the same structure as train) ####
datasets/universal/val
...

#### for clean captions ####
datasets/universal/daclip_train.csv
datasets/universal/daclip_val.csv
```
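Before training, it can help to confirm that every degradation folder follows the layout above. The sketch below is a hypothetical helper (`check_universal_split` is not part of the repository) and assumes that each LQ image and its GT counterpart share the same file name, which the layout suggests but does not state.

```python
from pathlib import Path

def check_universal_split(root):
    """Hypothetical helper: verify that every degradation folder under `root`
    has LQ/ and GT/ subfolders containing the same set of PNG file names.

    Assumes paired LQ/GT images share a file name.
    Returns a dict mapping degradation name -> list of problems (empty if OK).
    """
    report = {}
    for task_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        problems = []
        lq_dir, gt_dir = task_dir / "LQ", task_dir / "GT"
        if not lq_dir.is_dir() or not gt_dir.is_dir():
            problems.append("missing LQ/ or GT/ folder")
        else:
            lq = {p.name for p in lq_dir.glob("*.png")}
            gt = {p.name for p in gt_dir.glob("*.png")}
            if not lq:
                problems.append("no LQ images")
            for name in sorted(lq - gt):
                problems.append(f"{name}: LQ without GT")
            for name in sorted(gt - lq):
                problems.append(f"{name}: GT without LQ")
        report[task_dir.name] = problems
    return report

# Usage with the layout above:
# print(check_universal_split("datasets/universal/train"))
```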
Then go to the `universal-image-restoration/config/daclip-sde` directory and modify the dataset paths in the option files `options/train.yml` and `options/test.yml`.
You can add more tasks or datasets to both the `train` and `val` directories and add the corresponding degradation word to `distortion`.
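When adding a task, the folder has to appear in both `train` and `val`, and its name has to be listed in `distortion`. A minimal consistency check is sketched below. It assumes that the option file's `distortion` entry is a plain list of degradation words matching the folder names. That assumption, and the helper name `find_unregistered_tasks`, are illustrative and not taken from the repository.

```python
from pathlib import Path

def find_unregistered_tasks(train_root, val_root, distortion):
    """Hypothetical helper: compare task folders on disk with a `distortion` list.

    Assumes `distortion` is a list of degradation words equal to folder names.
    """
    train = {p.name for p in Path(train_root).iterdir() if p.is_dir()}
    val = {p.name for p in Path(val_root).iterdir() if p.is_dir()}
    listed = set(distortion)
    on_disk = train | val
    return {
        "missing_in_val": sorted(train - val),        # task has no validation folder
        "missing_in_train": sorted(val - train),      # task has no training folder
        "not_in_distortion": sorted(on_disk - listed),  # folder exists but word not registered
        "no_folder": sorted(listed - on_disk),        # word registered but no data
    }

# Usage (word list copied from the option file by hand):
# find_unregistered_tasks("datasets/universal/train", "datasets/universal/val",
#                         ["motion-blurry", "hazy", "jpeg-compressed"])
```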
Degradation | motion-blurry | hazy | jpeg-compressed* | low-light | noisy* (same as jpeg) |
---|---|---|---|---|---|
Datasets | Gopro | RESIDE-6k | DIV2K+Flickr2K | LOL | DIV2K+Flickr2K |
Degradation | raindrop | rainy | shadowed | snowy | uncompleted |
---|---|---|---|---|---|
Datasets | RainDrop | Rain100H | SRD | Snow100K | CelebaHQ-256 |
Extract only the training splits of these datasets for training; all validation datasets can be downloaded from the Google Drive. For the jpeg and noisy datasets (marked *), you can generate LQ images using this script.
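The generation script is not reproduced here. As a rough illustration only, additive Gaussian noise for the noisy set could be produced with the numpy sketch below. The noise level, the seed and the helper name `add_gaussian_noise` are assumptions, not the settings used in the paper. For the jpeg set, Pillow's `Image.save(fp, format="JPEG", quality=q)` is the usual way to re-encode an image at a chosen quality.

```python
import numpy as np

def add_gaussian_noise(gt, sigma=50.0, seed=0):
    """Hypothetical helper: return a noisy uint8 copy of a uint8 image.

    sigma is the noise standard deviation on the 0-255 scale (assumed value).
    A fixed seed makes the generated LQ set reproducible.
    """
    rng = np.random.default_rng(seed)
    noise = rng.normal(0.0, sigma, size=gt.shape)
    noisy = gt.astype(np.float64) + noise
    # Round and clip back into the valid 8-bit range
    return np.clip(np.rint(noisy), 0, 255).astype(np.uint8)

# Example on a synthetic 64x64 RGB image (mid-grey).
gt = np.full((64, 64, 3), 128, dtype=np.uint8)
lq = add_gaussian_noise(gt, sigma=25.0)
print(lq.dtype, lq.shape, float(np.std(lq.astype(np.float64) - gt)))
```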
See DA-CLIP.md for details.
The main code for training is in `universal-image-restoration/config/daclip-sde`, and the core network for DA-CLIP is in `universal-image-restoration/open_clip/daclip_model.py`.
- Put the pretrained DA-CLIP weights into the `pretrained` directory and check the `daclip` path.
- You can then train the model with the following bash scripts:
```bash
cd universal-image-restoration/config/daclip-sde

# For single GPU:
python3 train.py -opt=options/train.yml

# For distributed training (change gpu_ids in the option file first):
python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=4321 train.py -opt=options/train.yml --launcher pytorch
```
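For multi-GPU training, `--nproc_per_node` should presumably match the number of GPUs listed in `gpu_ids`. The sketch below assembles the launch command from such a list, using only the flags shown above. It assumes `gpu_ids` is a list of integer device indices. The helper name `build_launch_cmd` is hypothetical.

```python
import shlex

def build_launch_cmd(gpu_ids, opt="options/train.yml", master_port=4321):
    """Hypothetical helper: assemble the training command for the given GPUs.

    Assumes one process per GPU listed in gpu_ids.
    """
    if not gpu_ids:
        raise ValueError("gpu_ids must list at least one GPU")
    if len(gpu_ids) == 1:
        # Single GPU: plain invocation, no launcher
        return ["python3", "train.py", f"-opt={opt}"]
    return [
        "python3", "-m", "torch.distributed.launch",
        f"--nproc_per_node={len(gpu_ids)}",
        f"--master_port={master_port}",
        "train.py", f"-opt={opt}", "--launcher", "pytorch",
    ]

cmd = build_launch_cmd([0, 1])
print(shlex.join(cmd))
```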
The models and training logs will be saved in `log/universal-ir`. You can monitor the log in real time by running `tail -f log/universal-ir/train_universal-ir_***.log -n 100`.
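If several runs have left logs in the same folder, the sketch below picks the most recently modified one and returns its last lines. This is a one-shot equivalent of `tail -n 100`. Unlike `tail -f`, it does not keep following the file. The glob pattern mirrors the command above, and the helper name `tail_latest_log` is hypothetical.

```python
from collections import deque
from pathlib import Path

def tail_latest_log(log_dir="log/universal-ir", pattern="train_universal-ir_*.log", n=100):
    """Hypothetical helper: return the last `n` lines of the newest matching log."""
    logs = sorted(Path(log_dir).glob(pattern), key=lambda p: p.stat().st_mtime)
    if not logs:
        raise FileNotFoundError(f"no log matching {pattern!r} in {log_dir}")
    with logs[-1].open("r", errors="replace") as f:
        # deque with maxlen keeps only the final n lines while streaming the file
        return list(deque(f, maxlen=n))

# Usage:
# print("".join(tail_latest_log()))
```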
Model Name | Description | GoogleDrive | HuggingFace |
---|---|---|---|
DA-CLIP | Degradation-aware CLIP model | download | download |
Universal-IR | DA-CLIP based universal image restoration model | download | download |
DA-CLIP-mix | Degradation-aware CLIP model (adds Gaussian blur + face inpainting and Gaussian blur + rainy) | download | download |
Universal-IR-mix | DA-CLIP based universal image restoration model (adds robust training and mixed degradations) | download | download |
To evaluate our method on image restoration, modify the benchmark path and model path, then run:

```bash
cd universal-image-restoration/config/universal-ir
python test.py -opt=options/test.yml
```
Here we provide an `app.py` file for testing your own images. Before running it, download the pretrained weights (DA-CLIP and UIR) and modify the model path in `options/test.yml`. Then simply run `python app.py` and open http://localhost:7860 to test the model. (We also provide several images with different degradations in the `images` directory.) More examples from our test dataset are available in the Google Drive.
In testing, we found that the current pretrained model still has difficulty processing some real-world images, which may exhibit distribution shifts relative to our training dataset (e.g., images captured with different devices, at different resolutions, or with different degradations). We regard this as future work and will try to make our model more practical. We also encourage users interested in our work to train their own models with larger datasets and more degradation types.
We also found that directly resizing input images leads to poor performance for most tasks. We could add a resize step to training, but it always degrades image quality due to interpolation.
For the inpainting task, our current model only supports face inpainting due to dataset limitations. We provide our mask examples, and you can use the generate_masked_face script to generate uncompleted faces.
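The `generate_masked_face` script itself is not shown here. As a hedged illustration of the general idea, an uncompleted input can be formed by filling the masked region of a clean face image. Several details are assumptions and may not match the repository's script: the convention that nonzero mask values mark missing pixels, the white fill value, and the helper name `apply_mask`.

```python
import numpy as np

def apply_mask(gt, mask, fill=255):
    """Hypothetical helper: blank out masked pixels of a uint8 image.

    gt:   H x W x C uint8 image
    mask: H x W array, nonzero where pixels are removed (assumed convention)
    """
    hole = mask.astype(bool)[..., None]  # add a channel axis so it broadcasts over C
    return np.where(hole, np.uint8(fill), gt).astype(np.uint8)

# Example: remove a 16x16 square from a synthetic 256x256 image (CelebaHQ-256 size).
gt = np.full((256, 256, 3), 100, dtype=np.uint8)
mask = np.zeros((256, 256), dtype=np.uint8)
mask[120:136, 120:136] = 1
lq = apply_mask(gt, mask)
```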
Acknowledgment: Our DA-CLIP is based on IR-SDE and open_clip. Thanks for their code!
If you have any questions, please contact: ziwei.luo@it.uu.se
If our code helps your research or work, please consider citing our paper. The BibTeX reference is:
```
@article{luo2023controlling,
  title={Controlling Vision-Language Models for Universal Image Restoration},
  author={Luo, Ziwei and Gustafsson, Fredrik K and Zhao, Zheng and Sj{\"o}lund, Jens and Sch{\"o}n, Thomas B},
  journal={arXiv preprint arXiv:2310.01018},
  year={2023}
}
```