Update to FeMaSR
chaofengc committed Jul 2, 2022
1 parent 3624dd1 commit 93ef60d
Showing 14 changed files with 551 additions and 106 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -5,7 +5,7 @@ experiments/*
# results/*
tb_logger*/*
wandb/*
-tmp/*
+tmp*/*
*.sh
.vscode*
.github
51 changes: 18 additions & 33 deletions README.md
@@ -1,15 +1,17 @@
-# QuanTexSR
+# FeMaSR

This is the official PyTorch codes for the paper
-[Blind Image Super Resolution with Semantic-Aware Quantized Texture Prior](https://arxiv.org/abs/2202.13142)
+[Real-World Blind Super-Resolution via Feature Matching with Implicit High-Resolution Priors](https://arxiv.org/abs/2202.13142)
[Chaofeng Chen\*](https://chaofengc.github.io), [Xinyu Shi\*](https://github.com/Xinyu-Shi), [Yipeng Qin](http://yipengqin.github.io/), [Xiaoming Li](https://csxmli2016.github.io/), [Xiaoguang Han](https://mypage.cuhk.edu.cn/academics/hanxiaoguang/), [Tao Yang](https://github.com/yangxy), [Shihui Guo](http://guoshihui.net/)
(\* indicates equal contribution)

![framework_img](framework_overview.png)

### Update

-- **2022.03.02**: Add onedrive download link for pretrained weights.
+- **2022.07.02**
+- Update codes of the new version FeMaSR
+- Please find the old QuanTexSR in the `quantexsr` branch

Here are some example results on test images from [BSRGAN](https://github.com/cszn/BSRGAN) and [RealESRGAN](https://github.com/xinntao/Real-ESRGAN).

@@ -37,12 +39,12 @@ Here are some example results on test images from [BSRGAN](https://github.com/cs
- Other required packages in `requirements.txt`
```
# git clone this repository
-git clone https://github.com/chaofengc/QuanTexSR.git
-cd QuanTexSR
+git clone https://github.com/chaofengc/FeMaSR.git
+cd FeMaSR
# create new anaconda env
-conda create -n quantexsr python=3.8
-source activate quantexsr
+conda create -n femasr python=3.8
+source activate femasr
# install python dependencies
pip3 install -r requirements.txt
@@ -51,13 +53,9 @@ python setup.py develop

## Quick Inference

-Download pretrained model (**only provide x4 model now**) from
-- [BaiduNetdisk](https://pan.baidu.com/s/1H_9TIJUHEgAe75VToknbIA ), extract code `qtsr` .
-- [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/chaofeng_chen_staff_main_ntu_edu_sg/EuqbHtP9-f9OjzLpyIftKH0Bp8WVlT-8FNX6-boTeqE47w)

Test the model with the following script
```
-python inference_quantexsr.py -w ./path/to/model/weight -i ./path/to/test/image[or folder]
+python inference_quantexsr.py -s 4 -i ./testset -o results_x4/
+python inference_quantexsr.py -s 2 -i ./testset -o results_x2/
```

## Train the model
@@ -70,36 +68,23 @@ Please prepare the training and testing data follow descriptions in the main pap

#### Model preparation

-Before training, you need to put the following pretrained models in `experiments/pretrained_models` and specify their path in the corresponding option file.
-
-- HQ pretrain stage: pretrained semantic cluster codebook
-- LQ stage (SR model training): pretrained semantic aware vqgan, pretrained PSNR oriented RRDB model
-- lpips weight for validation
-
-The above models can be downloaded from the BaiduNetDisk.
+Before training, you need to
+- Download the pretrained HRP model [here]()
+- Put the pretrained models in `experiments/pretrained_models`
+- Specify their path in the corresponding option file.

### Train SR model

```
-python basicsr/train.py -opt options/train_QuanTexSR_LQ_stage.yml
+python basicsr/train.py -opt options/train_FeMaSR_LQ_stage.yml
```

### Model pretrain

-In case you want to pretrain your own VQGAN prior, we also provide the training instructions below.
-
-#### Pretrain semantic codebook
-
-The semantic-aware codebook is obtained with VGG19 features using a mini-batch version of K-means, optimized with Adam. This script will give three levels of codebooks from `relu3_4`, `relu4_4` and `relu5_4` features. We use `relu4_4` for this project.
-
-```
-python basicsr/train.py -opt options/train_QuanTexSR_semantic_cluster_stage.yml
-```
-
-#### Pretrain of semantic-aware VQGAN
+In case you want to pretrain your own HRP model, we also provide the training option file:

```
-python basicsr/train.py -opt options/train_QuanTexSR_HQ_pretrain_stage.yml
+python basicsr/train.py -opt options/train_FeMaSR_HQ_pretrain_stage.yml
```

## Citation
2 changes: 2 additions & 0 deletions basicsr/archs/femasr_arch.py
@@ -87,6 +87,8 @@ def forward(self, z, gt_indices=None, current_iter=None):
         q_latent_loss = torch.mean((z_q - z.detach())**2)

         if self.LQ_stage and gt_indices is not None:
+            # codebook_loss = self.dist(z_q, z_q_gt.detach()).mean() \
+            # + self.beta * self.dist(z_q_gt.detach(), z)
             codebook_loss = self.beta * self.dist(z_q_gt.detach(), z)
             texture_loss = self.gram_loss(z, z_q_gt.detach())
             codebook_loss = codebook_loss + texture_loss
25 changes: 2 additions & 23 deletions basicsr/data/bsrgan_train_dataset.py
@@ -6,32 +6,11 @@
 from basicsr.utils import FileClient, img2tensor
 from basicsr.utils.registry import DATASET_REGISTRY

-import os
+from .data_util import make_dataset

 import cv2
 import random

-IMG_EXTENSIONS = [
-    '.jpg', '.JPG', '.jpeg', '.JPEG',
-    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
-    '.tif', '.TIF', '.tiff', '.TIFF',
-]
-
-
-def is_image_file(filename):
-    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
-
-
-def make_dataset(dir, max_dataset_size=float("inf")):
-    images = []
-    assert os.path.isdir(dir), '%s is not a valid directory' % dir
-
-    for root, _, fnames in sorted(os.walk(dir)):
-        for fname in fnames:
-            if is_image_file(fname):
-                path = os.path.join(root, fname)
-                images.append(path)
-    return images[:min(max_dataset_size, len(images))]


 def random_resize(img, scale_factor=1.):
     return cv2.resize(img, None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_CUBIC)
24 changes: 24 additions & 0 deletions basicsr/data/data_util.py
@@ -1,13 +1,37 @@
 import cv2
 import numpy as np
 import torch
+import os
 from os import path as osp
 from torch.nn import functional as F

 from basicsr.data.transforms import mod_crop
 from basicsr.utils import img2tensor, scandir


+IMG_EXTENSIONS = [
+    '.jpg', '.JPG', '.jpeg', '.JPEG',
+    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+    '.tif', '.TIF', '.tiff', '.TIFF',
+]
+
+
+def is_image_file(filename):
+    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def make_dataset(dir, max_dataset_size=float("inf"), followlinks=True):
+    images = []
+    assert os.path.isdir(dir), '%s is not a valid directory' % dir
+
+    for root, _, fnames in sorted(os.walk(dir, followlinks=followlinks)):
+        for fname in fnames:
+            if is_image_file(fname):
+                path = os.path.join(root, fname)
+                images.append(path)
+    return images[:min(max_dataset_size, len(images))]
+
+
 def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
     """Read a sequence of images from a given folder path.
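Note (not part of the commit): the shared `make_dataset` now takes a `followlinks` argument and passes it to `os.walk`. The sketch below is a minimal, self-contained illustration of what that flag changes. The helper bodies follow the diff above; the temporary directory layout (`storage`, `dataset`, `linked`) and the file names are made up for the demonstration, and it assumes a platform where `os.symlink` is permitted.

```python
import os
import tempfile

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tif', '.TIF', '.tiff', '.TIFF',
]


def is_image_file(filename):
    # Case-sensitive suffix match, exactly as in the diff.
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir, max_dataset_size=float("inf"), followlinks=True):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir, followlinks=followlinks)):
        for fname in fnames:
            if is_image_file(fname):
                images.append(os.path.join(root, fname))
    return images[:min(max_dataset_size, len(images))]


def demo():
    """Build a throwaway tree with one symlinked sub-directory (hypothetical layout)."""
    with tempfile.TemporaryDirectory() as tmp:
        storage = os.path.join(tmp, 'storage')   # real directory holding a.png
        dataset = os.path.join(tmp, 'dataset')   # dataset root holding b.png
        os.makedirs(storage)
        os.makedirs(dataset)
        open(os.path.join(storage, 'a.png'), 'wb').close()
        open(os.path.join(dataset, 'b.png'), 'wb').close()
        # dataset/linked -> storage
        os.symlink(storage, os.path.join(dataset, 'linked'), target_is_directory=True)

        with_links = [os.path.relpath(p, dataset)
                      for p in make_dataset(dataset, followlinks=True)]
        without_links = [os.path.relpath(p, dataset)
                         for p in make_dataset(dataset, followlinks=False)]
    return with_links, without_links


if __name__ == '__main__':
    with_links, without_links = demo()
    print('followlinks=True :', with_links)
    print('followlinks=False:', without_links)
```

With `followlinks=True` the image behind the symlinked directory is included; with `followlinks=False` only files physically under the root are returned. The `os.walk` documentation warns that following links can recurse forever if a link points at one of its own parent directories, so a dataset root with such links would need `followlinks=False`.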
24 changes: 1 addition & 23 deletions basicsr/data/paired_image_dataset.py
@@ -8,34 +8,12 @@
 from basicsr.utils import FileClient, img2tensor
 from basicsr.utils.registry import DATASET_REGISTRY


-IMG_EXTENSIONS = [
-    '.jpg', '.JPG', '.jpeg', '.JPEG',
-    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
-    '.tif', '.TIF', '.tiff', '.TIFF',
-]
-
-
-def is_image_file(filename):
-    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
-
-
-def make_dataset(dir, max_dataset_size=float("inf")):
-    images = []
-    assert os.path.isdir(dir), '%s is not a valid directory' % dir
-
-    for root, _, fnames in sorted(os.walk(dir)):
-        for fname in fnames:
-            if is_image_file(fname):
-                path = os.path.join(root, fname)
-                images.append(path)
-    return images[:min(max_dataset_size, len(images))]
+from .data_util import make_dataset


 def random_resize(img, scale_factor=1.):
     return cv2.resize(img, None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_CUBIC)


 @DATASET_REGISTRY.register()
 class PairedImageDataset(data.Dataset):
     """Paired image dataset for image restoration.
2 changes: 1 addition & 1 deletion basicsr/models/base_model.py
@@ -206,7 +206,7 @@ def update_learning_rate(self, current_iter, warmup_iter=-1):
             self._set_lr(warm_up_lr_l)

     def get_current_learning_rate(self):
-        return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
+        return [optim.param_groups[0]['lr'] for optim in self.optimizers]

     @master_only
     def save_network(self, net, net_label, current_iter, param_key='params'):
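Note (not part of the commit): the one-line change to `get_current_learning_rate` alters which learning rates are reported. A minimal sketch of the difference follows; `FakeOptimizer` is a hypothetical stand-in that only mimics the `param_groups` attribute of a `torch.optim` optimizer, and the learning-rate values are invented.

```python
class FakeOptimizer:
    """Hypothetical stand-in: exposes `param_groups` like a torch.optim optimizer."""

    def __init__(self, lrs):
        self.param_groups = [{'lr': lr} for lr in lrs]


def lr_before(optimizers):
    # Pre-commit expression: every param group, but of the first optimizer only.
    return [param_group['lr'] for param_group in optimizers[0].param_groups]


def lr_after(optimizers):
    # Post-commit expression: the first param group of every optimizer.
    return [optim.param_groups[0]['lr'] for optim in optimizers]


# A generator optimizer with two param groups and a discriminator optimizer with one.
optimizer_g = FakeOptimizer([1e-4, 1e-5])
optimizer_d = FakeOptimizer([4e-4])
optimizers = [optimizer_g, optimizer_d]

if __name__ == '__main__':
    print('before:', lr_before(optimizers))
    print('after: ', lr_after(optimizers))
```

The old expression never reports the discriminator's learning rate; the new one does, at the cost of reporting only the first param group per optimizer.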
18 changes: 5 additions & 13 deletions basicsr/models/femasr_model.py
@@ -77,7 +77,6 @@ def init_training_settings(self):
         # define network net_d
         self.net_d = build_network(self.opt['network_d'])
         self.net_d = self.model_to_device(self.net_d)
-        # self.print_network(self.net_d)
         # load pretrained d models
         load_path = self.opt['path'].get('pretrain_network_d', None)
         # print(load_path)
@@ -118,23 +117,16 @@ def setup_optimizers(self):
                 logger = get_root_logger()
                 logger.warning(f'Params {k} will not be optimized.')

         # optimizer g
         optim_type = train_opt['optim_g'].pop('type')
-        if optim_type == 'Adam':
-            self.optimizer_g = torch.optim.Adam(optim_params,
-                                                **train_opt['optim_g'])
-        else:
-            raise NotImplementedError(
-                f'optimizer {optim_type} is not supperted yet.')
+        optim_class = getattr(torch.optim, optim_type)
+        self.optimizer_g = optim_class(optim_params, **train_opt['optim_g'])
         self.optimizers.append(self.optimizer_g)

         # optimizer d
         optim_type = train_opt['optim_d'].pop('type')
-        if optim_type == 'Adam':
-            self.optimizer_d = torch.optim.Adam(self.net_d.parameters(),
-                                                **train_opt['optim_d'])
-        else:
-            raise NotImplementedError(
-                f'optimizer {optim_type} is not supperted yet.')
+        optim_class = getattr(torch.optim, optim_type)
+        self.optimizer_d = optim_class(self.net_d.parameters(), **train_opt['optim_d'])
         self.optimizers.append(self.optimizer_d)

     def feed_data(self, data):
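Note (not part of the commit): `setup_optimizers` now resolves the optimizer class by name with `getattr(torch.optim, optim_type)` instead of hard-coding `Adam`. The sketch below shows the same lookup pattern without requiring PyTorch. `fake_optim`, `FakeAdam`, `FakeSGD` and `build_optimizer` are hypothetical names introduced only for illustration. Unlike the committed code, this variant copies the option dict before popping `type` and converts an unknown name into `NotImplementedError`; a plain `getattr` without a default raises `AttributeError` instead.

```python
from types import SimpleNamespace


class FakeAdam:
    """Hypothetical stand-in for an optimizer class such as torch.optim.Adam."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999)):
        self.params, self.lr, self.betas = list(params), lr, betas


class FakeSGD:
    """Hypothetical stand-in for an optimizer class such as torch.optim.SGD."""

    def __init__(self, params, lr=1e-2, momentum=0.0):
        self.params, self.lr, self.momentum = list(params), lr, momentum


# Plays the role of the `torch.optim` module in this sketch.
fake_optim = SimpleNamespace(Adam=FakeAdam, SGD=FakeSGD)


def build_optimizer(optim_module, params, optim_cfg):
    """Look up the optimizer class by its `type` entry and build it from the remaining keys."""
    cfg = dict(optim_cfg)            # copy, so the caller's options keep their 'type' key
    optim_type = cfg.pop('type')
    optim_class = getattr(optim_module, optim_type, None)
    if optim_class is None:
        raise NotImplementedError(f'optimizer {optim_type} is not supported yet.')
    return optim_class(params, **cfg)


if __name__ == '__main__':
    opt = build_optimizer(fake_optim, [0.0, 1.0], {'type': 'Adam', 'lr': 1e-4})
    print(type(opt).__name__, opt.lr)
```

The benefit of the name-based lookup is that any optimizer the module exposes can be selected from the option file without touching the model code; the remaining keys of the option block must then match that optimizer's constructor arguments.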
Binary file modified framework_overview.png
68 changes: 68 additions & 0 deletions generate_dataset.py
@@ -0,0 +1,68 @@
import os
import cv2
import numpy as np
import random
from tqdm import tqdm
from multiprocessing import Pool

from basicsr.data.bsrgan_util import degradation_bsrgan

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tif', '.TIF', '.tiff', '.TIFF',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir, max_dataset_size=float("inf"), followlinks=True):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir, followlinks=followlinks)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]

def degrade_img(hr_path, save_path):
    img_gt = cv2.imread(hr_path).astype(np.float32) / 255.
    img_gt = img_gt[:, :, [2, 1, 0]] # BGR to RGB
    img_lq, img_gt = degradation_bsrgan(img_gt, sf=scale, use_crop=False)
    img_lq = (img_lq[:, :, [2, 1, 0]] * 255).astype(np.uint8)
    print(f'Save {save_path}')
    cv2.imwrite(save_path, img_lq)


seed = 123
random.seed(seed)
np.random.seed(seed)

# scale = 2
scale = 4
hr_img_list = make_dataset('../datasets/HQ_sub')
pool = Pool(processes=40)

# hr_img_list = ['../datasets/HQ_sub_samename/DIV8K_train_HR_sub/div8k_1383_s021.png']

# scale = 2
# hr_img_list = ['../datasets/HQ_sub_samename/DIV8K_train_HR_sub/div8k_0903_s056.png']

# scale = 4
# hr_img_list = make_dataset('../datasets/LQ_sub_samename_X4')

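for hr_path in hr_img_list:
    save_path = hr_path.replace('HQ_sub', f'LQ_sub_X{scale}')
    save_path = save_path.replace('HR', 'LR')
    save_dir = os.path.dirname(save_path)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    # Pass the callable and its arguments separately so the work runs in the pool;
    # apply_async(degrade_img(...)) would run it in the parent and submit None.
    pool.apply_async(degrade_img, args=(hr_path, save_path))

pool.close()
pool.join()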
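Note (not part of the commit): two details of `generate_dataset.py` are easy to check in isolation. The output path is derived purely by string replacement, and `Pool.apply_async` expects the callable and its arguments separately (writing `apply_async(f(x))` would call `f` in the parent process and hand its return value to the pool). The sketch below illustrates both. `lq_path_for`, `fake_degrade` and `run_jobs` are hypothetical helpers, `fake_degrade` does no image I/O, and the example path is invented.

```python
import os
from multiprocessing import Pool


def lq_path_for(hr_path, scale):
    """Mirror the script's two str.replace calls that map an HQ path to its LQ path."""
    save_path = hr_path.replace('HQ_sub', f'LQ_sub_X{scale}')
    # Caveat: this replaces every 'HR' substring, including one inside a file name.
    return save_path.replace('HR', 'LR')


def fake_degrade(hr_path, save_path):
    # Hypothetical stand-in for degrade_img: reports the job instead of writing an image.
    return f'{os.path.basename(hr_path)} -> {save_path}'


def run_jobs(hr_paths, scale=4, processes=2):
    """Submit one task per image and collect the results."""
    with Pool(processes=processes) as pool:
        pending = [
            # The function object and its argument tuple are passed separately.
            pool.apply_async(fake_degrade, args=(path, lq_path_for(path, scale)))
            for path in hr_paths
        ]
        # .get() waits for each task and re-raises any exception from the worker.
        return [job.get() for job in pending]


if __name__ == '__main__':
    example = ['../datasets/HQ_sub/DIV8K_train_HR_sub/0001_s001.png']
    for line in run_jobs(example):
        print(line)
```

Calling `.get()` on the returned `AsyncResult` objects is optional, but without it an exception raised inside a worker goes unnoticed.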

15 changes: 8 additions & 7 deletions options/train_FeMaSR_HQ_pretrain_stage.yml
@@ -1,5 +1,5 @@
# general settings
-name: 004_FeMaSR_HQ_stage
+name: 008_FeMaSR_HQ_stage
# name: debug_FeMaSR
model_type: FeMaSRModel
scale: 4
@@ -22,11 +22,12 @@ datasets:

     # data loader
     use_shuffle: true
-    num_worker_per_gpu: 2
-    batch_size_per_gpu: 12
+    batch_size_per_gpu: &bsz 12
+    num_worker_per_gpu: *bsz
     dataset_enlarge_ratio: 1

-    prefetch_mode: ~
+    prefetch_mode: cpu
+    num_prefetch_queue: *bsz

   val:
     name: General_Image_Valid
@@ -58,8 +59,8 @@ network_d:
 path:
   # pretrain_network_g: ./experiments/pretrained_models/QuanTexSR/pretrain_semantic_vqgan_net_g_latest.pth
   # pretrain_network_d: ~
-  pretrain_network_g: ./experiments/003_FeMaSR_HQ_stage/models/net_g_best_.pth
-  pretrain_network_d: ./experiments/003_FeMaSR_HQ_stage/models/net_d_best_.pth
+  pretrain_network_g: ./experiments/004_FeMaSR_HQ_stage/models/net_g_best_.pth
+  # pretrain_network_d: ./experiments/004_FeMaSR_HQ_stage/models/net_d_best_.pth
   strict_load: false
   # resume_state: ~

@@ -117,7 +118,7 @@ val:

   key_metric: lpips
   metrics:
-    psnr: # metric name, can be arbitrary
+    psnr: # metric name, not used in this codebase
       type: psnr
       crop_border: 4
       test_y_channel: true