Skip to content

Commit

Permalink
feat: 🚸 improve dataset api
Browse files Browse the repository at this point in the history
  • Loading branch information
chaofengc committed Aug 6, 2023
1 parent 0d00073 commit 6fbae36
Show file tree
Hide file tree
Showing 29 changed files with 266 additions and 634 deletions.
10 changes: 10 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ test:
test_general:
pytest tests/test_metric_general.py::test_cpu_gpu_consistency -v

test_gradient:
pytest tests/test_metric_general.py::test_gradient_backward -v


test_dataset:
pytest tests/test_datasets_general.py -v

test_all:
pytest tests/ -v

clean:
rm -rf __pycache__
rm -rf pyiqa/__pycache__
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ Basically, we use the largest existing datasets for training, and cross dataset
| Aesthetic IQA | `nima`, `nima-vgg16-ava` |

Notes:
- **Results of all retrained models are normalized to [0, 1] and change to higher better for convenience.**
- Due to optimized training process, performance of some retrained approaches may be higher than original paper.
- Results of KonIQ-10k, AVA are both tested with official split.
- NIMA is only applicable to AVA dataset now. We use `inception_resnet_v2` for default `nima`.
Expand Down Expand Up @@ -176,7 +177,7 @@ mkdir datasets && cd datasets
ln -sf your/dataset/path datasetname
# download meta info files and train split files
wget https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/data_info_files.tgz
wget https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/meta_info.tgz
tar -xvf data_info_files.tgz
```

Expand Down
2 changes: 1 addition & 1 deletion ResultsCalibra/calibration_summary.csv
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ musiq-ava(ours),3.4084,5.6934,4.6968,5.1963,4.1955
musiq-koniq,12.494,75.332,73.429,75.188,36.938
musiq-koniq(ours),12.4773,75.7764,73.7459,75.4604,38.0248
musiq-paq2piq,46.035,72.66,73.625,74.361,69.006
musiq-paq2piq(ours),46.0187,72.6657,73.7655,74.388,69.7218
musiq-paq2piq(ours),46.0187,72.6657,73.7656,74.388,69.7218
musiq-spaq,17.685,70.492,78.74,79.015,49.105
musiq-spaq(ours),17.6804,70.6531,79.0364,79.3189,50.4526
niqe,15.7536,3.6549,3.2355,3.184,8.6352
Expand Down
35 changes: 31 additions & 4 deletions options/default_dataset_opt.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,40 +4,50 @@ csiq:
dataroot_target: ./datasets/CSIQ/dst_imgs
dataroot_ref: ./datasets/CSIQ/src_imgs
meta_info_file: ./datasets/meta_info/meta_info_CSIQDataset.csv
dmos_max: 1
mos_range: [0, 1]
lower_better: true

tid2008:
name: TID2008
type: GeneralFRDataset
dataroot_target: ./datasets/tid2008/distorted_images
dataroot_ref: ./datasets/tid2008/reference_images
meta_info_file: ./datasets/meta_info/meta_info_TID2008Dataset.csv
mos_range: [0, 9]
lower_better: false

tid2013:
name: TID2013
type: GeneralFRDataset
dataroot_target: ./datasets/tid2013/distorted_images
dataroot_ref: ./datasets/tid2013/reference_images
meta_info_file: ./datasets/meta_info/meta_info_TID2013Dataset.csv
mos_range: [0, 9]
lower_better: false

live:
name: LIVE
type: GeneralFRDataset
dataroot_target: './datasets/LIVEIQA_release2'
meta_info_file: './datasets/meta_info/meta_info_LIVEIQADataset.csv'
dmos_max: 100
mos_range: [1, 100]
lower_better: true

livem:
name: LIVEM
type: GeneralFRDataset
dataroot_target: './datasets/LIVEmultidistortiondatabase'
meta_info_file: './datasets/meta_info/meta_info_LIVEMDDataset.csv'
mos_range: [1, 100]
lower_better: true

livec:
name: LIVEC
type: LIVEChallengeDataset
dataroot_target: ./datasets/LIVEC/
meta_info_file: ./datasets/meta_info/meta_info_LIVEChallengeDataset.csv
mos_range: [1, 100]
lower_better: false

koniq10k:
name: KonIQ10k
Expand All @@ -46,6 +56,8 @@ koniq10k:
meta_info_file: './datasets/meta_info/meta_info_KonIQ10kDataset.csv'
split_file: './datasets/meta_info/koniq10k_official.pkl'
phase: 'test'
mos_range: [0, 100]
lower_better: false

koniq10k-1024:
name: KonIQ10k
Expand All @@ -54,6 +66,8 @@ koniq10k-1024:
meta_info_file: './datasets/meta_info/meta_info_KonIQ10kDataset.csv'
split_file: './datasets/meta_info/koniq10k_official.pkl'
phase: 'test'
mos_range: [0, 100]
lower_better: false

koniq10k++:
name: KonIQ10k++
Expand All @@ -62,12 +76,16 @@ koniq10k++:
meta_info_file: './datasets/meta_info/meta_info_KonIQ10k++Dataset.csv'
split_file: './datasets/meta_info/koniq10k_official.pkl'
phase: 'test'
mos_range: [1, 5]
lower_better: false

kadid10k:
name: KADID10k
type: GeneralFRDataset
dataroot_target: './datasets/kadid10k/images'
meta_info_file: './datasets/meta_info/meta_info_KADID10kDataset.csv'
mos_range: [1, 5]
lower_better: false

spaq:
name: SPAQ
Expand All @@ -76,6 +94,8 @@ spaq:
meta_info_file: './datasets/meta_info/meta_info_SPAQDataset.csv'
augment:
resize: 448
mos_range: [0, 100]
lower_better: false

ava:
name: AVA
Expand All @@ -84,21 +104,28 @@ ava:
meta_info_file: './datasets/meta_info/meta_info_AVADataset.csv'
split_file: './datasets/meta_info/ava_official_ilgnet.pkl'
split_index: 1 # use official split
mos_range: [1, 10]
lower_better: false

pipal:
name: PIPAL
type: PIPALDataset
type: GeneralFRDataset
dataroot_target: './datasets/PIPAL/Dist_Imgs'
dataroot_ref: './datasets/PIPAL/Train_Ref'
meta_info_file: './datasets/meta_info/meta_info_PIPALDataset.csv'
split_file: './datasets/meta_info/pipal_official.pkl'
mos_range: [0, 1]
lower_better: false

flive:
name: FLIVE
type: FLIVEDataset
type: GeneralNRDataset
dataroot_target: './datasets/FLIVE_Database/database'
meta_info_file: './datasets/meta_info/meta_info_FLIVEDataset.csv'
split_file: './datasets/meta_info/flive_official.pkl'
phase: test
mos_range: [0, 100]
lower_better: false

pieapp:
name: PieAPPDataset
Expand Down
5 changes: 4 additions & 1 deletion options/train/CNNIQA/train_CNNIQA.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ datasets:
dataroot_target: ./datasets/koniq10k/512x384
meta_info_file: ./datasets/meta_info/meta_info_KonIQ10kDataset.csv
split_file: ./datasets/meta_info/koniq10k_official.pkl
mos_range: [0, 100]
lower_better: false
mos_normalize: true

augment:
hflip: true
Expand Down Expand Up @@ -62,7 +65,7 @@ train:

# losses
mos_loss_opt:
type: PLCCLoss
type: MSELoss
loss_weight: !!float 1.0

# validation settings
Expand Down
9 changes: 4 additions & 5 deletions options/train/DBCNN/train_DBCNN_koniq10k.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ datasets:
dataroot_target: ./datasets/koniq10k/512x384
meta_info_file: ./datasets/meta_info/meta_info_KonIQ10kDataset.csv
split_file: ./datasets/meta_info/koniq10k_official.pkl
mos_range: [0, 100]
lower_better: false
mos_normalize: true

augment:
hflip: true
Expand Down Expand Up @@ -75,11 +78,7 @@ train:
mos_loss_opt:
type: MSELoss
loss_weight: !!float 1.0

metric_loss_opt:
type: PLCCLoss
loss_weight: !!float 1.0


# validation settings
val:
val_freq: !!float 800
Expand Down
4 changes: 3 additions & 1 deletion options/train/HyperNet/train_HyperNet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ datasets:
dataroot_target: ./datasets/koniq10k/512x384
meta_info_file: ./datasets/meta_info/meta_info_KonIQ10kDataset.csv
split_file: ./datasets/meta_info/koniq10k_official.pkl
mos_max: 100
mos_range: [0, 100]
lower_better: false
mos_normalize: true

augment:
hflip: true
Expand Down
25 changes: 4 additions & 21 deletions pyiqa/archs/ahiq_arch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from pyexpat import model
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -7,33 +6,17 @@

import timm
from timm.models.vision_transformer import Block
from timm.models.resnet import BasicBlock, Bottleneck
from timm.models.resnet import Bottleneck

from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.arch_util import load_pretrained_network, default_init_weights, to_2tuple, ExactPadding2d, load_file_from_url
from pyiqa.archs.arch_util import load_pretrained_network, to_2tuple, load_file_from_url, random_crop


default_model_urls = {
'pipal': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/AHIQ_vit_p8_epoch33-da3ea303.pth'
}


def random_crop(x, y, crop_size, crop_num):
b, c, h, w = x.shape
ch, cw = to_2tuple(crop_size)

crops_x = []
crops_y = []
for i in range(crop_num):
sh = np.random.randint(0, h - ch)
sw = np.random.randint(0, w - cw)
crops_x.append(x[..., sh: sh + ch, sw: sw + cw])
crops_y.append(y[..., sh: sh + ch, sw: sw + cw])
crops_x = torch.stack(crops_x, dim=1)
crops_y = torch.stack(crops_y, dim=1)
return crops_x.reshape(b * crop_num, c, ch, cw), crops_y.reshape(b * crop_num, c, ch, cw)


class SaveOutput:
def __init__(self):
self.outputs = {}
Expand All @@ -51,7 +34,7 @@ def clear(self, device):
class DeformFusion(nn.Module):
def __init__(self, patch_size=8, in_channels=768 * 5, cnn_channels=256 * 3, out_channels=256 * 3):
super().__init__()
#in_channels, out_channels, kernel_size, stride, padding
# in_channels, out_channels, kernel_size, stride, padding
self.d_hidn = 512
if patch_size == 8:
stride = 1
Expand Down Expand Up @@ -227,7 +210,7 @@ def forward(self, x, y):
bsz = x.shape[0]

if self.crops > 1 and not self.training:
x, y = random_crop(x, y, self.crop_size, self.crops)
x, y = random_crop([x, y], self.crop_size, self.crops)
score = self.regress_score(x, y)
score = score.reshape(bsz, self.crops, 1)
score = score.mean(dim=1)
Expand Down
27 changes: 27 additions & 0 deletions pyiqa/archs/arch_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,33 @@ def dist_to_mos(dist_score: torch.Tensor) -> torch.Tensor:
return mos_score


def random_crop(input_list, crop_size, crop_num):
if not isinstance(input_list, collections.abc.Sequence):
input_list = [input_list]

b, c, h, w = input_list[0].shape
ch, cw = to_2tuple(crop_size)

if min(h, w) <= crop_size:
scale_factor = (crop_size + 1) / min(h, w)
input_list = [F.interpolate(x, scale_factor=scale_factor, mode='bilinear') for x in input_list]
b, c, h, w = input_list[0].shape

crops_list = [[] for i in range(len(input_list))]
for i in range(crop_num):
sh = np.random.randint(0, h - ch + 1)
sw = np.random.randint(0, w - cw + 1)
for j in range(len(input_list)):
crops_list[j].append(input_list[j][..., sh: sh + ch, sw: sw + cw])

for i in range(len(crops_list)):
crops_list[i] = torch.stack(crops_list[i], dim=1).reshape(b * crop_num, c, ch, cw)

if len(crops_list) == 1:
crops_list = crops_list[0]
return crops_list


# --------------------------------------------
# Common utils
# --------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion pyiqa/archs/cnniqa_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


default_model_urls = {
'koniq10k': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/CNNIQA_koniq10k-fd89516f.pth'
'koniq10k': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/CNNIQA_koniq10k-e6f14c91.pth'
}


Expand Down
4 changes: 2 additions & 2 deletions pyiqa/archs/dbcnn_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
'livec': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/DBCNN_LIVEC-83f6dad3.pth',
'livem': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/DBCNN_LIVEM-698474e3.pth',
'koniq': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/DBCNN_KonIQ10k-254e8241.pth',
'scnn': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/DBCNN_scnn-7ea73d75.pth',
}


Expand Down Expand Up @@ -117,8 +118,7 @@ def __init__(
self.features1 = torchvision.models.vgg16(weights='IMAGENET1K_V1').features
self.features1 = nn.Sequential(*list(self.features1.children())[:-1])
scnn = SCNN(use_bn=use_bn)
if pretrained_scnn_path is not None:
load_pretrained_network(scnn, pretrained_scnn_path)
load_pretrained_network(scnn, default_model_urls['scnn'])

self.features2 = scnn.features

Expand Down
2 changes: 1 addition & 1 deletion pyiqa/archs/hypernet_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


default_model_urls = {
'resnet50-koniq': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/HyperIQA-resnet50-koniq10k-48579ec9.pth',
'resnet50-koniq': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/HyperIQA-resnet50-koniq10k-c96c41b1.pth',
}


Expand Down
21 changes: 4 additions & 17 deletions pyiqa/archs/maniqa_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from einops import rearrange

from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.arch_util import load_pretrained_network
from pyiqa.archs.arch_util import load_pretrained_network, random_crop

default_model_urls = {
'pipal': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/MANIQA_PIPAL-ae6d356b.pth',
Expand All @@ -28,17 +28,6 @@
}


def random_crop(x, sample_size=224, sample_num=8):
b, c, h, w = x.shape
th = tw = sample_size
cropped_x = []
for s in range(sample_num):
i = torch.randint(0, h - th + 1, size=(1, )).item()
j = torch.randint(0, w - tw + 1, size=(1, )).item()
cropped_x.append(x[:, :, i:i + th, j:j + tw])
cropped_x = torch.stack(cropped_x, dim=1)
return cropped_x


class TABlock(nn.Module):
def __init__(self, dim, drop=0.1):
Expand Down Expand Up @@ -169,14 +158,12 @@ def extract_feature(self, save_output):
def forward(self, x):

x = (x - self.default_mean.to(x)) / self.default_std.to(x)
bsz = x.shape[0]

if self.training:
x_patches = random_crop(x, sample_size=224, sample_num=1)
x = random_crop(x, crop_size=224, crop_num=1)
else:
x_patches = random_crop(x, sample_size=224, sample_num=self.test_sample)

bsz, num_patches, c, psz, psz = x_patches.shape
x = x_patches.reshape(bsz * num_patches, c, psz, psz)
x = random_crop(x, crop_size=224, crop_num=self.test_sample)

_x = self.vit(x)
x = self.extract_feature(self.save_output)
Expand Down

0 comments on commit 6fbae36

Please sign in to comment.