From 6fbae36668f2a0826b70ce808f8bae8b0c0e717a Mon Sep 17 00:00:00 2001
From: chaofengc
Date: Sun, 6 Aug 2023 23:59:56 +0800
Subject: [PATCH] feat: :children_crossing: improve dataset api

---
 Makefile | 10 +++
 README.md | 3 +-
 ResultsCalibra/calibration_summary.csv | 2 +-
 options/default_dataset_opt.yml | 35 +++++++-
 options/train/CNNIQA/train_CNNIQA.yml | 5 +-
 options/train/DBCNN/train_DBCNN_koniq10k.yml | 9 +-
 options/train/HyperNet/train_HyperNet.yml | 4 +-
 pyiqa/archs/ahiq_arch.py | 25 +-----
 pyiqa/archs/arch_util.py | 27 ++++++
 pyiqa/archs/cnniqa_arch.py | 2 +-
 pyiqa/archs/dbcnn_arch.py | 4 +-
 pyiqa/archs/hypernet_arch.py | 2 +-
 pyiqa/archs/maniqa_arch.py | 21 +----
 pyiqa/archs/nima_arch.py | 2 +-
 pyiqa/archs/topiq_arch.py | 23 +----
 pyiqa/archs/tres_arch.py | 28 ++----
 pyiqa/data/ava_dataset.py | 50 ++---------
 pyiqa/data/bapps_dataset.py | 38 ++------
 pyiqa/data/base_iqa_dataset.py | 88 +++++++++++++++++++
 pyiqa/data/flive_dataset.py | 82 ------------------
 pyiqa/data/general_fr_dataset.py | 67 +++++---------
 pyiqa/data/general_nr_dataset.py | 65 +-------------
 pyiqa/data/livechallenge_dataset.py | 56 +-----------
 pyiqa/data/pieapp_dataset.py | 45 ++--------
 pyiqa/data/pipal_dataset.py | 91 --------------------
 pyiqa/default_model_configs.py | 2 +-
 tests/NR_benchmark_results.csv | 4 +-
 tests/test_datasets_general.py | 85 +-----------------
 tests/test_metric_general.py | 25 ++++--
 29 files changed, 266 insertions(+), 634 deletions(-)
 create mode 100644 pyiqa/data/base_iqa_dataset.py
 delete mode 100644 pyiqa/data/flive_dataset.py
 delete mode 100644 pyiqa/data/pipal_dataset.py

diff --git a/Makefile b/Makefile
index aa4c00f..f8e930f 100644
--- a/Makefile
+++ b/Makefile
@@ -23,6 +23,16 @@ test:
 test_general:
 	pytest tests/test_metric_general.py::test_cpu_gpu_consistency -v
 
+test_gradient:
+	pytest tests/test_metric_general.py::test_gradient_backward -v
+
+
+test_dataset:
+	pytest tests/test_datasets_general.py -v
+
+test_all:
+	pytest tests/ -v
+
 clean:
 	rm -rf __pycache__
 	rm -rf pyiqa/__pycache__

diff --git a/README.md b/README.md
index a235863..f023c01 100644
--- a/README.md
+++ b/README.md
@@ -139,6 +139,7 @@ Basically, we use the largest existing datasets for training, and cross dataset
 | Aesthetic IQA | `nima`, `nima-vgg16-ava` |
 
 Notes:
+- **Results of all retrained models are normalized to [0, 1] and converted to higher-is-better for convenience.**
 - Due to optimized training process, performance of some retrained approaches may be higher than original paper.
 - Results of KonIQ-10k, AVA are both tested with official split.
 - NIMA is only applicable to AVA dataset now. We use `inception_resnet_v2` for default `nima`.
@@ -176,7 +177,7 @@ mkdir datasets && cd datasets ln -sf your/dataset/path datasetname # download meta info files and train split files -wget https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/data_info_files.tgz +wget https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/meta_info.tgz tar -xvf data_info_files.tgz ``` diff --git a/ResultsCalibra/calibration_summary.csv b/ResultsCalibra/calibration_summary.csv index 72abd18..b0e96ca 100644 --- a/ResultsCalibra/calibration_summary.csv +++ b/ResultsCalibra/calibration_summary.csv @@ -27,7 +27,7 @@ musiq-ava(ours),3.4084,5.6934,4.6968,5.1963,4.1955 musiq-koniq,12.494,75.332,73.429,75.188,36.938 musiq-koniq(ours),12.4773,75.7764,73.7459,75.4604,38.0248 musiq-paq2piq,46.035,72.66,73.625,74.361,69.006 -musiq-paq2piq(ours),46.0187,72.6657,73.7655,74.388,69.7218 +musiq-paq2piq(ours),46.0187,72.6657,73.7656,74.388,69.7218 musiq-spaq,17.685,70.492,78.74,79.015,49.105 musiq-spaq(ours),17.6804,70.6531,79.0364,79.3189,50.4526 niqe,15.7536,3.6549,3.2355,3.184,8.6352 diff --git a/options/default_dataset_opt.yml b/options/default_dataset_opt.yml index 48b5adf..e1ef2f4 100644 --- a/options/default_dataset_opt.yml +++ b/options/default_dataset_opt.yml @@ -4,7 +4,8 @@ csiq: dataroot_target: ./datasets/CSIQ/dst_imgs dataroot_ref: ./datasets/CSIQ/src_imgs meta_info_file: ./datasets/meta_info/meta_info_CSIQDataset.csv - dmos_max: 1 + mos_range: [0, 1] + lower_better: true tid2008: name: TID2008 @@ -12,6 +13,8 @@ tid2008: dataroot_target: ./datasets/tid2008/distorted_images dataroot_ref: ./datasets/tid2008/reference_images meta_info_file: ./datasets/meta_info/meta_info_TID2008Dataset.csv + mos_range: [0, 9] + lower_better: false tid2013: name: TID2013 @@ -19,25 +22,32 @@ tid2013: dataroot_target: ./datasets/tid2013/distorted_images dataroot_ref: ./datasets/tid2013/reference_images meta_info_file: ./datasets/meta_info/meta_info_TID2013Dataset.csv + mos_range: [0, 9] + lower_better: false live: name: LIVE type: GeneralFRDataset dataroot_target: './datasets/LIVEIQA_release2' meta_info_file: './datasets/meta_info/meta_info_LIVEIQADataset.csv' - dmos_max: 100 + mos_range: [1, 100] + lower_better: true livem: name: LIVEM type: GeneralFRDataset dataroot_target: './datasets/LIVEmultidistortiondatabase' meta_info_file: './datasets/meta_info/meta_info_LIVEMDDataset.csv' + mos_range: [1, 100] + lower_better: true livec: name: LIVEC type: LIVEChallengeDataset dataroot_target: ./datasets/LIVEC/ meta_info_file: ./datasets/meta_info/meta_info_LIVEChallengeDataset.csv + mos_range: [1, 100] + lower_better: false koniq10k: name: KonIQ10k @@ -46,6 +56,8 @@ koniq10k: meta_info_file: './datasets/meta_info/meta_info_KonIQ10kDataset.csv' split_file: './datasets/meta_info/koniq10k_official.pkl' phase: 'test' + mos_range: [0, 100] + lower_better: false koniq10k-1024: name: KonIQ10k @@ -54,6 +66,8 @@ koniq10k-1024: meta_info_file: './datasets/meta_info/meta_info_KonIQ10kDataset.csv' split_file: './datasets/meta_info/koniq10k_official.pkl' phase: 'test' + mos_range: [0, 100] + lower_better: false koniq10k++: name: KonIQ10k++ @@ -62,12 +76,16 @@ koniq10k++: meta_info_file: './datasets/meta_info/meta_info_KonIQ10k++Dataset.csv' split_file: './datasets/meta_info/koniq10k_official.pkl' phase: 'test' + mos_range: [1, 5] + lower_better: false kadid10k: name: KADID10k type: GeneralFRDataset dataroot_target: './datasets/kadid10k/images' meta_info_file: './datasets/meta_info/meta_info_KADID10kDataset.csv' + mos_range: [1, 5] + lower_better: false spaq: name: 
SPAQ @@ -76,6 +94,8 @@ spaq: meta_info_file: './datasets/meta_info/meta_info_SPAQDataset.csv' augment: resize: 448 + mos_range: [0, 100] + lower_better: false ava: name: AVA @@ -84,21 +104,28 @@ ava: meta_info_file: './datasets/meta_info/meta_info_AVADataset.csv' split_file: './datasets/meta_info/ava_official_ilgnet.pkl' split_index: 1 # use official split + mos_range: [1, 10] + lower_better: false pipal: name: PIPAL - type: PIPALDataset + type: GeneralFRDataset dataroot_target: './datasets/PIPAL/Dist_Imgs' dataroot_ref: './datasets/PIPAL/Train_Ref' meta_info_file: './datasets/meta_info/meta_info_PIPALDataset.csv' + split_file: './datasets/meta_info/pipal_official.pkl' + mos_range: [0, 1] + lower_better: false flive: name: FLIVE - type: FLIVEDataset + type: GeneralNRDataset dataroot_target: './datasets/FLIVE_Database/database' meta_info_file: './datasets/meta_info/meta_info_FLIVEDataset.csv' split_file: './datasets/meta_info/flive_official.pkl' phase: test + mos_range: [0, 100] + lower_better: false pieapp: name: PieAPPDataset diff --git a/options/train/CNNIQA/train_CNNIQA.yml b/options/train/CNNIQA/train_CNNIQA.yml index f1ec3ac..23d6fa2 100644 --- a/options/train/CNNIQA/train_CNNIQA.yml +++ b/options/train/CNNIQA/train_CNNIQA.yml @@ -13,6 +13,9 @@ datasets: dataroot_target: ./datasets/koniq10k/512x384 meta_info_file: ./datasets/meta_info/meta_info_KonIQ10kDataset.csv split_file: ./datasets/meta_info/koniq10k_official.pkl + mos_range: [0, 100] + lower_better: false + mos_normalize: true augment: hflip: true @@ -62,7 +65,7 @@ train: # losses mos_loss_opt: - type: PLCCLoss + type: MSELoss loss_weight: !!float 1.0 # validation settings diff --git a/options/train/DBCNN/train_DBCNN_koniq10k.yml b/options/train/DBCNN/train_DBCNN_koniq10k.yml index 85516ec..a887bf3 100644 --- a/options/train/DBCNN/train_DBCNN_koniq10k.yml +++ b/options/train/DBCNN/train_DBCNN_koniq10k.yml @@ -13,6 +13,9 @@ datasets: dataroot_target: ./datasets/koniq10k/512x384 meta_info_file: ./datasets/meta_info/meta_info_KonIQ10kDataset.csv split_file: ./datasets/meta_info/koniq10k_official.pkl + mos_range: [0, 100] + lower_better: false + mos_normalize: true augment: hflip: true @@ -75,11 +78,7 @@ train: mos_loss_opt: type: MSELoss loss_weight: !!float 1.0 - - metric_loss_opt: - type: PLCCLoss - loss_weight: !!float 1.0 - + # validation settings val: val_freq: !!float 800 diff --git a/options/train/HyperNet/train_HyperNet.yml b/options/train/HyperNet/train_HyperNet.yml index 6c48617..ffd6db4 100644 --- a/options/train/HyperNet/train_HyperNet.yml +++ b/options/train/HyperNet/train_HyperNet.yml @@ -14,7 +14,9 @@ datasets: dataroot_target: ./datasets/koniq10k/512x384 meta_info_file: ./datasets/meta_info/meta_info_KonIQ10kDataset.csv split_file: ./datasets/meta_info/koniq10k_official.pkl - mos_max: 100 + mos_range: [0, 100] + lower_better: false + mos_normalize: true augment: hflip: true diff --git a/pyiqa/archs/ahiq_arch.py b/pyiqa/archs/ahiq_arch.py index 4b4bba1..44129fc 100644 --- a/pyiqa/archs/ahiq_arch.py +++ b/pyiqa/archs/ahiq_arch.py @@ -1,4 +1,3 @@ -from pyexpat import model import torch import torch.nn as nn import torch.nn.functional as F @@ -7,10 +6,10 @@ import timm from timm.models.vision_transformer import Block -from timm.models.resnet import BasicBlock, Bottleneck +from timm.models.resnet import Bottleneck from pyiqa.utils.registry import ARCH_REGISTRY -from pyiqa.archs.arch_util import load_pretrained_network, default_init_weights, to_2tuple, ExactPadding2d, load_file_from_url +from pyiqa.archs.arch_util import 
load_pretrained_network, to_2tuple, load_file_from_url, random_crop default_model_urls = { @@ -18,22 +17,6 @@ } -def random_crop(x, y, crop_size, crop_num): - b, c, h, w = x.shape - ch, cw = to_2tuple(crop_size) - - crops_x = [] - crops_y = [] - for i in range(crop_num): - sh = np.random.randint(0, h - ch) - sw = np.random.randint(0, w - cw) - crops_x.append(x[..., sh: sh + ch, sw: sw + cw]) - crops_y.append(y[..., sh: sh + ch, sw: sw + cw]) - crops_x = torch.stack(crops_x, dim=1) - crops_y = torch.stack(crops_y, dim=1) - return crops_x.reshape(b * crop_num, c, ch, cw), crops_y.reshape(b * crop_num, c, ch, cw) - - class SaveOutput: def __init__(self): self.outputs = {} @@ -51,7 +34,7 @@ def clear(self, device): class DeformFusion(nn.Module): def __init__(self, patch_size=8, in_channels=768 * 5, cnn_channels=256 * 3, out_channels=256 * 3): super().__init__() - #in_channels, out_channels, kernel_size, stride, padding + # in_channels, out_channels, kernel_size, stride, padding self.d_hidn = 512 if patch_size == 8: stride = 1 @@ -227,7 +210,7 @@ def forward(self, x, y): bsz = x.shape[0] if self.crops > 1 and not self.training: - x, y = random_crop(x, y, self.crop_size, self.crops) + x, y = random_crop([x, y], self.crop_size, self.crops) score = self.regress_score(x, y) score = score.reshape(bsz, self.crops, 1) score = score.mean(dim=1) diff --git a/pyiqa/archs/arch_util.py b/pyiqa/archs/arch_util.py index b975eb6..a877f78 100644 --- a/pyiqa/archs/arch_util.py +++ b/pyiqa/archs/arch_util.py @@ -34,6 +34,33 @@ def dist_to_mos(dist_score: torch.Tensor) -> torch.Tensor: return mos_score +def random_crop(input_list, crop_size, crop_num): + if not isinstance(input_list, collections.abc.Sequence): + input_list = [input_list] + + b, c, h, w = input_list[0].shape + ch, cw = to_2tuple(crop_size) + + if min(h, w) <= crop_size: + scale_factor = (crop_size + 1) / min(h, w) + input_list = [F.interpolate(x, scale_factor=scale_factor, mode='bilinear') for x in input_list] + b, c, h, w = input_list[0].shape + + crops_list = [[] for i in range(len(input_list))] + for i in range(crop_num): + sh = np.random.randint(0, h - ch + 1) + sw = np.random.randint(0, w - cw + 1) + for j in range(len(input_list)): + crops_list[j].append(input_list[j][..., sh: sh + ch, sw: sw + cw]) + + for i in range(len(crops_list)): + crops_list[i] = torch.stack(crops_list[i], dim=1).reshape(b * crop_num, c, ch, cw) + + if len(crops_list) == 1: + crops_list = crops_list[0] + return crops_list + + # -------------------------------------------- # Common utils # -------------------------------------------- diff --git a/pyiqa/archs/cnniqa_arch.py b/pyiqa/archs/cnniqa_arch.py index 4307ecd..a4a2635 100644 --- a/pyiqa/archs/cnniqa_arch.py +++ b/pyiqa/archs/cnniqa_arch.py @@ -17,7 +17,7 @@ default_model_urls = { - 'koniq10k': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/CNNIQA_koniq10k-fd89516f.pth' + 'koniq10k': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/CNNIQA_koniq10k-e6f14c91.pth' } diff --git a/pyiqa/archs/dbcnn_arch.py b/pyiqa/archs/dbcnn_arch.py index 9def694..4e68190 100644 --- a/pyiqa/archs/dbcnn_arch.py +++ b/pyiqa/archs/dbcnn_arch.py @@ -21,6 +21,7 @@ 'livec': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/DBCNN_LIVEC-83f6dad3.pth', 'livem': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/DBCNN_LIVEM-698474e3.pth', 'koniq': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/DBCNN_KonIQ10k-254e8241.pth', + 
'scnn': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/DBCNN_scnn-7ea73d75.pth', } @@ -117,8 +118,7 @@ def __init__( self.features1 = torchvision.models.vgg16(weights='IMAGENET1K_V1').features self.features1 = nn.Sequential(*list(self.features1.children())[:-1]) scnn = SCNN(use_bn=use_bn) - if pretrained_scnn_path is not None: - load_pretrained_network(scnn, pretrained_scnn_path) + load_pretrained_network(scnn, default_model_urls['scnn']) self.features2 = scnn.features diff --git a/pyiqa/archs/hypernet_arch.py b/pyiqa/archs/hypernet_arch.py index 754da15..d9d9835 100644 --- a/pyiqa/archs/hypernet_arch.py +++ b/pyiqa/archs/hypernet_arch.py @@ -14,7 +14,7 @@ default_model_urls = { - 'resnet50-koniq': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/HyperIQA-resnet50-koniq10k-48579ec9.pth', + 'resnet50-koniq': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/HyperIQA-resnet50-koniq10k-c96c41b1.pth', } diff --git a/pyiqa/archs/maniqa_arch.py b/pyiqa/archs/maniqa_arch.py index 2c0970f..3b12558 100644 --- a/pyiqa/archs/maniqa_arch.py +++ b/pyiqa/archs/maniqa_arch.py @@ -19,7 +19,7 @@ from einops import rearrange from pyiqa.utils.registry import ARCH_REGISTRY -from pyiqa.archs.arch_util import load_pretrained_network +from pyiqa.archs.arch_util import load_pretrained_network, random_crop default_model_urls = { 'pipal': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/MANIQA_PIPAL-ae6d356b.pth', @@ -28,17 +28,6 @@ } -def random_crop(x, sample_size=224, sample_num=8): - b, c, h, w = x.shape - th = tw = sample_size - cropped_x = [] - for s in range(sample_num): - i = torch.randint(0, h - th + 1, size=(1, )).item() - j = torch.randint(0, w - tw + 1, size=(1, )).item() - cropped_x.append(x[:, :, i:i + th, j:j + tw]) - cropped_x = torch.stack(cropped_x, dim=1) - return cropped_x - class TABlock(nn.Module): def __init__(self, dim, drop=0.1): @@ -169,14 +158,12 @@ def extract_feature(self, save_output): def forward(self, x): x = (x - self.default_mean.to(x)) / self.default_std.to(x) + bsz = x.shape[0] if self.training: - x_patches = random_crop(x, sample_size=224, sample_num=1) + x = random_crop(x, crop_size=224, crop_num=1) else: - x_patches = random_crop(x, sample_size=224, sample_num=self.test_sample) - - bsz, num_patches, c, psz, psz = x_patches.shape - x = x_patches.reshape(bsz * num_patches, c, psz, psz) + x = random_crop(x, crop_size=224, crop_num=self.test_sample) _x = self.vit(x) x = self.extract_feature(self.save_output) diff --git a/pyiqa/archs/nima_arch.py b/pyiqa/archs/nima_arch.py index 04a7bd1..b51e4e8 100644 --- a/pyiqa/archs/nima_arch.py +++ b/pyiqa/archs/nima_arch.py @@ -68,7 +68,7 @@ def __init__( self.default_std = torch.Tensor(default_std).view(1, 3, 1, 1) if pretrained and pretrained_model_path is None: - url_key = f'{base_model_name}-{pretrained}' + url_key = f'{base_model_name}-ava' load_pretrained_network(self, default_model_urls[url_key], True, weight_keys='params') elif pretrained_model_path is not None: load_pretrained_network(self, pretrained_model_path, True, weight_keys='params') diff --git a/pyiqa/archs/topiq_arch.py b/pyiqa/archs/topiq_arch.py index ea06956..a956e80 100644 --- a/pyiqa/archs/topiq_arch.py +++ b/pyiqa/archs/topiq_arch.py @@ -17,7 +17,7 @@ import timm from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from pyiqa.utils.registry import ARCH_REGISTRY -from pyiqa.archs.arch_util import dist_to_mos, 
load_pretrained_network, to_2tuple +from pyiqa.archs.arch_util import dist_to_mos, load_pretrained_network, random_crop import copy from .clip_model import load @@ -141,27 +141,6 @@ def forward(self, tgt, memory): return output -def random_crop(input_list, crop_size, crop_num): - b, c, h, w = input_list[0].shape - ch, cw = to_2tuple(crop_size) - - if min(h, w) <= crop_size: - scale_factor = (crop_size + 1) / min(h, w) - input_list = [F.interpolate(x, scale_factor=scale_factor, mode='bilinear') for x in input_list] - b, c, h, w = input_list[0].shape - - crops_list = [[] for i in range(len(input_list))] - for i in range(crop_num): - sh = np.random.randint(0, h - ch + 1) - sw = np.random.randint(0, w - cw + 1) - for j in range(len(input_list)): - crops_list[j].append(input_list[j][..., sh: sh + ch, sw: sw + cw]) - - for i in range(len(crops_list)): - crops_list[i] = torch.stack(crops_list[i], dim=1).reshape(b * crop_num, c, ch, cw) - - return crops_list - class GatedConv(nn.Module): def __init__(self, weightdim, ksz=3): diff --git a/pyiqa/archs/tres_arch.py b/pyiqa/archs/tres_arch.py index 94049d5..7a05891 100644 --- a/pyiqa/archs/tres_arch.py +++ b/pyiqa/archs/tres_arch.py @@ -20,6 +20,7 @@ from .arch_util import load_pretrained_network from pyiqa.utils.registry import ARCH_REGISTRY +from .arch_util import random_crop default_model_urls = { @@ -28,23 +29,6 @@ } -def random_crop(x, sample_size=224, sample_num=8): - b, c, h, w = x.shape - if min(h, w) <= sample_size: - scale_factor = (sample_size + 1) / min(h, w) - x = F.interpolate(x, scale_factor=scale_factor, mode='bicubic') - b, c, h, w = x.shape - - th = tw = sample_size - cropped_x = [] - for s in range(sample_num): - i = torch.randint(0, h - th + 1, size=(1, )).item() - j = torch.randint(0, w - tw + 1, size=(1, )).item() - cropped_x.append(x[:, :, i:i + th, j:j + tw]) - cropped_x = torch.stack(cropped_x, dim=1) - return cropped_x - - def _get_activation_fn(activation): """Return an activation function given a string""" if activation == "relu": @@ -337,14 +321,14 @@ def forward_backbone(self, model, x): def forward(self, x): x = (x - self.default_mean.to(x)) / self.default_std.to(x) + bsz = x.shape[0] if self.training: - x = random_crop(x, sample_size=224, sample_num=1) + x = random_crop(x, 224, 1) + num_patches = 1 else: - x = random_crop(x, sample_size=224, sample_num=self.test_sample) - - bsz, num_patches, c, psz, psz = x.shape - x = x.reshape(bsz * num_patches, c, psz, psz) + x = random_crop(x, 224, self.test_sample) + num_patches = self.test_sample self.pos_enc_1 = self.position_embedding(torch.ones(1, self.dim_modelt, 7, 7).to(x)) self.pos_enc = self.pos_enc_1.repeat(x.shape[0], 1, 1, 1).contiguous() diff --git a/pyiqa/data/ava_dataset.py b/pyiqa/data/ava_dataset.py index c3429ea..c572c3f 100644 --- a/pyiqa/data/ava_dataset.py +++ b/pyiqa/data/ava_dataset.py @@ -1,25 +1,22 @@ import numpy as np import pickle from PIL import Image -import cv2 import os -import random -import itertools import torch from torch.utils import data as data -import torchvision.transforms as tf -from pyiqa.data.transforms import transform_mapping from pyiqa.utils.registry import DATASET_REGISTRY import pandas as pd +from .base_iqa_dataset import BaseIQADataset + # avoid possible image read error in AVA dataset from PIL import ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True @DATASET_REGISTRY.register() -class AVADataset(data.Dataset): +class AVADataset(BaseIQADataset): """AVA dataset, proposed by Murray, Naila, Luca Marchesotti, and Florent Perronnin. 
@@ -31,14 +28,12 @@ class AVADataset(data.Dataset): phase (str): 'train' or 'val'. """ - def __init__(self, opt): - super(AVADataset, self).__init__() - self.opt = opt - + def init_path_mos(self, opt): target_img_folder = opt['dataroot_target'] self.dataroot = target_img_folder self.paths_mos = pd.read_csv(opt['meta_info_file']).values.tolist() - + + def get_split(self, opt): # read train/val/test splits split_file_path = opt.get('split_file', None) if split_file_path: @@ -47,7 +42,7 @@ def __init__(self, opt): split_dict = pickle.load(f) # use val_num for validation - val_num = 2000 + val_num = opt.get('val_num', 2000) train_split = split_dict[split_index]['train'] val_split = split_dict[split_index]['val'] train_split = train_split + val_split[:-val_num] @@ -55,34 +50,10 @@ def __init__(self, opt): split_dict[split_index]['train'] = train_split split_dict[split_index]['val'] = val_split - if opt.get('override_phase', None) is None: - splits = split_dict[split_index][opt['phase']] - else: - splits = split_dict[split_index][opt['override_phase']] - + splits = split_dict[split_index][self.phase] self.paths_mos = [self.paths_mos[i] for i in splits] - - self.mean_mos = np.array([item[1] for item in self.paths_mos]).mean() - - # self.paths_mos.sort(key=lambda x: x[1]) - # n = 32 - # n = 4 - # tmp_list = [self.paths_mos[i: i + n] for i in range(0, len(self.paths_mos), n)] - # random.shuffle(tmp_list) - # self.paths_mos = list(itertools.chain.from_iterable(tmp_list)) - - transform_list = [] - augment_dict = opt.get('augment', None) - if augment_dict is not None: - for k, v in augment_dict.items(): - transform_list += transform_mapping(k, v) - img_range = opt.get('img_range', 1.0) - transform_list += [ - tf.ToTensor(), - tf.Lambda(lambda x: x * img_range), - ] - self.trans = tf.Compose(transform_list) + self.mean_mos = np.array([item[1] for item in self.paths_mos]).mean() def __getitem__(self, index): @@ -104,6 +75,3 @@ def __getitem__(self, index): return {'img': tmp_tensor, 'mos_label': mos_label_tensor, 'mos_dist': mos_dist_tensor, 'org_size': torch.tensor([height, width]), 'img_path': img_path, 'mean_mos': torch.tensor(self.mean_mos)} else: return {'img': img_tensor, 'img2': img_tensor2, 'mos_label': mos_label_tensor, 'mos_dist': mos_dist_tensor, 'org_size': torch.tensor([height, width]), 'img_path': img_path, 'mean_mos': torch.tensor(self.mean_mos)} - - def __len__(self): - return len(self.paths_mos) diff --git a/pyiqa/data/bapps_dataset.py b/pyiqa/data/bapps_dataset.py index 0cf7161..a056a65 100644 --- a/pyiqa/data/bapps_dataset.py +++ b/pyiqa/data/bapps_dataset.py @@ -1,23 +1,18 @@ -import numpy as np import pickle from PIL import Image import os import torch from torch.utils import data as data -import torchvision.transforms as tf -from torchvision.transforms.functional import normalize -from pyiqa.data.data_util import read_meta_info_file -from pyiqa.data.transforms import transform_mapping, augment, PairedToTensor -from pyiqa.utils import FileClient, imfrombytes, img2tensor from pyiqa.utils.registry import DATASET_REGISTRY +from .base_iqa_dataset import BaseIQADataset import pandas as pd @DATASET_REGISTRY.register() -class BAPPSDataset(data.Dataset): +class BAPPSDataset(BaseIQADataset): """The BAPPS Dataset introduced by: Zhang, Richard and Isola, Phillip and Efros, Alexei A and Shechtman, Eli and Wang, Oliver @@ -33,24 +28,21 @@ class BAPPSDataset(data.Dataset): - jnd: load jnd pair data """ - def __init__(self, opt): - super(BAPPSDataset, self).__init__() - self.opt = opt - + def 
init_path_mos(self, opt): if opt.get('override_phase', None) is None: self.phase = opt['phase'] else: self.phase = opt['override_phase'] self.dataset_mode = opt.get('mode', '2afc') - val_types = opt.get('val_types', None) target_img_folder = opt['dataroot_target'] self.dataroot = target_img_folder - ref_img_folder = opt.get('dataroot_ref', None) self.paths_mos = pd.read_csv(opt['meta_info_file']).values.tolist() - + + def get_split(self, opt): + val_types = opt.get('val_types', None) # read train/val/test splits split_file_path = opt.get('split_file', None) if split_file_path: @@ -72,20 +64,7 @@ def __init__(self, opt): if vt in item[1]: tmp_paths_mos.append(item) self.paths_mos = tmp_paths_mos - - # TODO: paired transform - transform_list = [] - augment_dict = opt.get('augment', None) - if augment_dict is not None: - for k, v in augment_dict.items(): - transform_list += transform_mapping(k, v) - - img_range = opt.get('img_range', 1.0) - transform_list += [ - PairedToTensor(), - ] - self.trans = tf.Compose(transform_list) - + def __getitem__(self, index): is_jnd_data = self.paths_mos[index][0] == 'jnd' distA_path = os.path.join(self.dataroot, self.paths_mos[index][1]) @@ -117,6 +96,3 @@ def __getitem__(self, index): 'mos_label': mos_label_tensor, 'distB_path': distB_path, 'distA_path': distA_path} - - def __len__(self): - return len(self.paths_mos) diff --git a/pyiqa/data/base_iqa_dataset.py b/pyiqa/data/base_iqa_dataset.py new file mode 100644 index 0000000..e33ddcb --- /dev/null +++ b/pyiqa/data/base_iqa_dataset.py @@ -0,0 +1,88 @@ +import pickle + +from torch.utils import data as data +import torchvision.transforms as tf + +from pyiqa.data.data_util import read_meta_info_file +from pyiqa.data.transforms import transform_mapping, PairedToTensor +from pyiqa.utils import get_root_logger + + +class BaseIQADataset(data.Dataset): + """General No Reference dataset with meta info file. + + Args: + opt (dict): Config for train datasets with the following keys: + phase (str): 'train' or 'val'. 
+ """ + + def __init__(self, opt): + self.opt = opt + self.logger = get_root_logger() + + if opt.get('override_phase', None) is None: + self.phase = opt['phase'] + else: + self.phase = opt['override_phase'] + + # initialize datasets + self.init_path_mos(opt) + + # mos normalization + self.mos_normalize(opt) + + # read train/val/test splits + self.get_split(opt) + + # get transforms + self.get_transforms(opt) + + def init_path_mos(self, opt): + target_img_folder = opt['dataroot_target'] + self.paths_mos = read_meta_info_file(target_img_folder, opt['meta_info_file']) + + def get_split(self, opt): + # read train/val/test splits + split_file_path = opt.get('split_file', None) + if split_file_path: + split_index = opt.get('split_index', 1) + with open(opt['split_file'], 'rb') as f: + split_dict = pickle.load(f) + splits = split_dict[split_index][self.phase] + self.paths_mos = [self.paths_mos[i] for i in splits] + + def mos_normalize(self, opt): + mos_range = opt.get('mos_range', None) + mos_lower_better = opt.get('lower_better', None) + mos_normalize = opt.get('mos_normalize', False) + + if mos_normalize: + assert mos_range is not None and mos_lower_better is not None, 'mos_range and mos_lower_better should be provided when mos_normalize is True' + + def normalize(mos_label): + mos_label = (mos_label - mos_range[0]) / (mos_range[1] - mos_range[0]) + if mos_lower_better: + mos_label = 1 - mos_label + return mos_label + + self.paths_mos = [(p, normalize(m)) for p, m in self.paths_mos] + self.logger.info(f'mos_label is normalized from {mos_range}, lower_better[{mos_lower_better}] to [0, 1], higher better.') + + def get_transforms(self, opt): + transform_list = [] + augment_dict = opt.get('augment', None) + if augment_dict is not None: + for k, v in augment_dict.items(): + transform_list += transform_mapping(k, v) + + self.img_range = opt.get('img_range', 1.0) + transform_list += [ + PairedToTensor(), + ] + self.trans = tf.Compose(transform_list) + + def __getitem__(self, index): + pass + + def __len__(self): + return len(self.paths_mos) diff --git a/pyiqa/data/flive_dataset.py b/pyiqa/data/flive_dataset.py deleted file mode 100644 index db82cf5..0000000 --- a/pyiqa/data/flive_dataset.py +++ /dev/null @@ -1,82 +0,0 @@ -import numpy as np -import pickle -from PIL import Image -import cv2 - -import torch -from torch.utils import data as data -import torchvision.transforms as tf -from torchvision.transforms.functional import normalize - -from pyiqa.data.data_util import read_meta_info_file -from pyiqa.data.transforms import transform_mapping -from pyiqa.utils import FileClient, imfrombytes, img2tensor -from pyiqa.utils.registry import DATASET_REGISTRY - - -@DATASET_REGISTRY.register() -class FLIVEDataset(data.Dataset): - """General No Reference dataset with meta info file. - - Args: - opt (dict): Config for train datasets with the following keys: - phase (str): 'train' or 'val'. 
- """ - - def __init__(self, opt): - super(FLIVEDataset, self).__init__() - self.opt = opt - - target_img_folder = opt['dataroot_target'] - self.paths_mos = read_meta_info_file(target_img_folder, opt['meta_info_file']) - - # read train/val/test splits - split_file_path = opt.get('split_file', None) - if split_file_path: - split_index = opt.get('split_index', 1) - with open(opt['split_file'], 'rb') as f: - split_dict = pickle.load(f) - if opt.get('override_phase', None) is None: - splits = split_dict[split_index][opt['phase']] - else: - splits = split_dict[split_index][opt['override_phase']] - - self.paths_mos = [self.paths_mos[i] for i in splits] - - dmos_max = opt.get('dmos_max', 0.) - if dmos_max: - self.use_dmos = True - self.dmos_max = opt.get('dmos_max') - else: - self.use_dmos = False - self.mos_max = opt.get('mos_max', 1.) - - transform_list = [] - augment_dict = opt.get('augment', None) - if augment_dict is not None: - for k, v in augment_dict.items(): - transform_list += transform_mapping(k, v) - - self.img_range = opt.get('img_range', 1.0) - transform_list += [ - tf.ToTensor(), - ] - self.trans = tf.Compose(transform_list) - - def __getitem__(self, index): - - img_path = self.paths_mos[index][0] - mos_label = self.paths_mos[index][1] - img_pil = Image.open(img_path).convert('RGB') - - img_tensor = self.trans(img_pil) * self.img_range - if self.use_dmos: - mos_label = self.dmos_max - mos_label - else: - mos_label = mos_label / self.mos_max - mos_label_tensor = torch.Tensor([mos_label]) - - return {'img': img_tensor, 'mos_label': mos_label_tensor, 'img_path': img_path} - - def __len__(self): - return len(self.paths_mos) diff --git a/pyiqa/data/general_fr_dataset.py b/pyiqa/data/general_fr_dataset.py index 4906e73..d73ee27 100644 --- a/pyiqa/data/general_fr_dataset.py +++ b/pyiqa/data/general_fr_dataset.py @@ -1,58 +1,26 @@ -import numpy as np -import pickle from PIL import Image import torch from torch.utils import data as data import torchvision.transforms as tf -from torchvision.transforms.functional import normalize from pyiqa.data.data_util import read_meta_info_file -from pyiqa.data.transforms import transform_mapping, augment, PairedToTensor -from pyiqa.utils import FileClient, imfrombytes, img2tensor +from pyiqa.data.transforms import transform_mapping, PairedToTensor from pyiqa.utils.registry import DATASET_REGISTRY +from .base_iqa_dataset import BaseIQADataset @DATASET_REGISTRY.register() -class GeneralFRDataset(data.Dataset): +class GeneralFRDataset(BaseIQADataset): """General Full Reference dataset with meta info file. - - Args: - opt (dict): Config for train datasets with the following keys: - phase (str): 'train' or 'val'. """ - - def __init__(self, opt): - super(GeneralFRDataset, self).__init__() - self.opt = opt - - if opt.get('override_phase', None) is None: - self.phase = opt['phase'] - else: - self.phase = opt['override_phase'] - + + def init_path_mos(self, opt): target_img_folder = opt['dataroot_target'] ref_img_folder = opt.get('dataroot_ref', None) self.paths_mos = read_meta_info_file(target_img_folder, opt['meta_info_file'], mode='fr', ref_dir=ref_img_folder) - # read train/val/test splits - split_file_path = opt.get('split_file', None) - if split_file_path: - split_index = opt.get('split_index', 1) - with open(opt['split_file'], 'rb') as f: - split_dict = pickle.load(f) - splits = split_dict[split_index][self.phase] - self.paths_mos = [self.paths_mos[i] for i in splits] - - dmos_max = opt.get('dmos_max', 0.) 
- if dmos_max: - self.use_dmos = True - self.dmos_max = opt.get('dmos_max') - else: - self.use_dmos = False - - self.mos_max = opt.get('mos_max', 1.) - + def get_transforms(self, opt): # do paired transform first and then do common transform paired_transform_list = [] augment_dict = opt.get('augment', None) @@ -67,7 +35,23 @@ def __init__(self, opt): PairedToTensor(), ] self.common_trans = tf.Compose(common_transform_list) + + def mos_normalize(self, opt): + mos_range = opt.get('mos_range', None) + mos_lower_better = opt.get('lower_better', None) + mos_normalize = opt.get('mos_normalize', False) + if mos_normalize: + assert mos_range is not None and mos_lower_better is not None, 'mos_range and mos_lower_better should be provided when mos_normalize is True' + + def normalize(mos_label): + mos_label = (mos_label - mos_range[0]) / (mos_range[1] - mos_range[0]) + if mos_lower_better: + mos_label = 1 - mos_label + return mos_label + + self.paths_mos = [item[:2] + [normalize(item[2])] for item in self.paths_mos] + self.logger.info(f'mos_label is normalized from {mos_range}, lower_better[{mos_lower_better}] to [0, 1], higher better.') def __getitem__(self, index): @@ -81,13 +65,6 @@ def __getitem__(self, index): img_tensor = self.common_trans(img_pil) * self.img_range ref_tensor = self.common_trans(ref_pil) * self.img_range - if self.use_dmos: - mos_label = (self.dmos_max - mos_label) / self.dmos_max - else: - mos_label /= self.mos_max mos_label_tensor = torch.Tensor([mos_label]) return {'img': img_tensor, 'ref_img': ref_tensor, 'mos_label': mos_label_tensor, 'img_path': img_path, 'ref_img_path': ref_path} - - def __len__(self): - return len(self.paths_mos) diff --git a/pyiqa/data/general_nr_dataset.py b/pyiqa/data/general_nr_dataset.py index c016437..7ca0c7f 100644 --- a/pyiqa/data/general_nr_dataset.py +++ b/pyiqa/data/general_nr_dataset.py @@ -1,69 +1,19 @@ -import numpy as np -import pickle from PIL import Image -import cv2 - import torch from torch.utils import data as data -import torchvision.transforms as tf -from torchvision.transforms.functional import normalize from pyiqa.data.data_util import read_meta_info_file -from pyiqa.data.transforms import transform_mapping, augment, PairedToTensor -from pyiqa.utils import FileClient, imfrombytes, img2tensor from pyiqa.utils.registry import DATASET_REGISTRY - +from .base_iqa_dataset import BaseIQADataset @DATASET_REGISTRY.register() -class GeneralNRDataset(data.Dataset): +class GeneralNRDataset(BaseIQADataset): """General No Reference dataset with meta info file. - - Args: - opt (dict): Config for train datasets with the following keys: - phase (str): 'train' or 'val'. """ - - def __init__(self, opt): - super(GeneralNRDataset, self).__init__() - self.opt = opt - - if opt.get('override_phase', None) is None: - self.phase = opt['phase'] - else: - self.phase = opt['override_phase'] - + def init_path_mos(self, opt): target_img_folder = opt['dataroot_target'] self.paths_mos = read_meta_info_file(target_img_folder, opt['meta_info_file']) - # read train/val/test splits - split_file_path = opt.get('split_file', None) - if split_file_path: - split_index = opt.get('split_index', 1) - with open(opt['split_file'], 'rb') as f: - split_dict = pickle.load(f) - splits = split_dict[split_index][self.phase] - self.paths_mos = [self.paths_mos[i] for i in splits] - - dmos_max = opt.get('dmos_max', 0.) - if dmos_max: - self.use_dmos = True - self.dmos_max = opt.get('dmos_max') - else: - self.use_dmos = False - self.mos_max = opt.get('mos_max', 1.) 
- - transform_list = [] - augment_dict = opt.get('augment', None) - if augment_dict is not None: - for k, v in augment_dict.items(): - transform_list += transform_mapping(k, v) - - self.img_range = opt.get('img_range', 1.0) - transform_list += [ - PairedToTensor(), - ] - self.trans = tf.Compose(transform_list) - def __getitem__(self, index): img_path = self.paths_mos[index][0] @@ -71,13 +21,6 @@ def __getitem__(self, index): img_pil = Image.open(img_path).convert('RGB') img_tensor = self.trans(img_pil) * self.img_range - if self.use_dmos: - mos_label = self.dmos_max - mos_label - else: - mos_label /= self.mos_max mos_label_tensor = torch.Tensor([mos_label]) - + return {'img': img_tensor, 'mos_label': mos_label_tensor, 'img_path': img_path} - - def __len__(self): - return len(self.paths_mos) diff --git a/pyiqa/data/livechallenge_dataset.py b/pyiqa/data/livechallenge_dataset.py index 55c22de..39949e5 100644 --- a/pyiqa/data/livechallenge_dataset.py +++ b/pyiqa/data/livechallenge_dataset.py @@ -1,21 +1,12 @@ -import numpy as np -import pickle -from PIL import Image import os -import torch -from torch.utils import data as data -import torchvision.transforms as tf -from torchvision.transforms.functional import normalize - from pyiqa.data.data_util import read_meta_info_file -from pyiqa.data.transforms import transform_mapping, augment -from pyiqa.utils import FileClient, imfrombytes, img2tensor from pyiqa.utils.registry import DATASET_REGISTRY +from .general_nr_dataset import GeneralNRDataset @DATASET_REGISTRY.register() -class LIVEChallengeDataset(data.Dataset): +class LIVEChallengeDataset(GeneralNRDataset): """The LIVE Challenge Dataset introduced by D. Ghadiyaram and A.C. Bovik, @@ -28,47 +19,8 @@ class LIVEChallengeDataset(data.Dataset): phase (str): 'train' or 'val'. 
""" - def __init__(self, opt): - super(LIVEChallengeDataset, self).__init__() - self.opt = opt - + def init_path_mos(self, opt): target_img_folder = os.path.join(opt['dataroot_target'], 'Images') self.paths_mos = read_meta_info_file(target_img_folder, opt['meta_info_file']) # remove first 7 training images as previous works - self.paths_mos = self.paths_mos[7:] - - # read train/val/test splits - split_file_path = opt.get('split_file', None) - if split_file_path: - split_index = opt.get('split_index', 1) - with open(opt['split_file'], 'rb') as f: - split_dict = pickle.load(f) - splits = split_dict[split_index][opt['phase']] - self.paths_mos = [self.paths_mos[i] for i in splits] - - transform_list = [] - augment_dict = opt.get('augment', None) - if augment_dict is not None: - for k, v in augment_dict.items(): - transform_list += transform_mapping(k, v) - - img_range = opt.get('img_range', 1.0) - transform_list += [ - tf.ToTensor(), - tf.Lambda(lambda x: x * img_range), - ] - self.trans = tf.Compose(transform_list) - - def __getitem__(self, index): - - img_path = self.paths_mos[index][0] - mos_label = self.paths_mos[index][1] - img_pil = Image.open(img_path) - - img_tensor = self.trans(img_pil) - mos_label_tensor = torch.Tensor([mos_label]) - - return {'img': img_tensor, 'mos_label': mos_label_tensor, 'img_path': img_path} - - def __len__(self): - return len(self.paths_mos) + self.paths_mos = self.paths_mos[7:] diff --git a/pyiqa/data/pieapp_dataset.py b/pyiqa/data/pieapp_dataset.py index 122f8bf..e16dabe 100644 --- a/pyiqa/data/pieapp_dataset.py +++ b/pyiqa/data/pieapp_dataset.py @@ -1,23 +1,17 @@ -import numpy as np import pickle from PIL import Image import os import torch from torch.utils import data as data -import torchvision.transforms as tf -from torchvision.transforms.functional import normalize - -from pyiqa.data.data_util import read_meta_info_file -from pyiqa.data.transforms import transform_mapping, augment, PairedToTensor -from pyiqa.utils import FileClient, imfrombytes, img2tensor from pyiqa.utils.registry import DATASET_REGISTRY import pandas as pd +from .general_fr_dataset import GeneralFRDataset @DATASET_REGISTRY.register() -class PieAPPDataset(data.Dataset): +class PieAPPDataset(GeneralFRDataset): """The PieAPP Dataset introduced by: Prashnani, Ekta and Cai, Hong and Mostofi, Yasamin and Sen, Pradeep @@ -29,25 +23,15 @@ class PieAPPDataset(data.Dataset): opt (dict): Config for train datasets with the following keys: phase (str): 'train' or 'val'. 
""" - - def __init__(self, opt): - super(PieAPPDataset, self).__init__() - self.opt = opt - - target_img_folder = opt['dataroot_target'] - self.dataroot = target_img_folder - - if opt.get('override_phase', None) is None: - self.phase = opt['phase'] - else: - self.phase = opt['override_phase'] - + def init_path_mos(self, opt): + self.dataroot = opt['dataroot_target'] if self.phase == "test": metadata = pd.read_csv(opt['meta_info_file'], usecols=['ref_img_path', 'dist_imgB_path', 'per_img score for dist_imgB']) else: metadata = pd.read_csv(opt['meta_info_file']) self.paths_mos = metadata.values.tolist() + def get_split(self, opt): # read train/val/test splits split_file_path = opt.get('split_file', None) if split_file_path: @@ -63,22 +47,6 @@ def __init__(self, opt): [temp.append(item) for item in self.paths_mos if not item in temp] self.paths_mos = temp - # do paired transform first and then do common transform - paired_transform_list = [] - augment_dict = opt.get('augment', None) - if augment_dict is not None: - for k, v in augment_dict.items(): - paired_transform_list += transform_mapping(k, v) - self.paired_trans = tf.Compose(paired_transform_list) - - common_transform_list = [] - self.img_range = opt.get('img_range', 1.0) - common_transform_list += [ - PairedToTensor(), - ] - self.common_trans = tf.Compose(common_transform_list) - - def __getitem__(self, index): ref_path = os.path.join(self.dataroot, self.paths_mos[index][0]) @@ -120,6 +88,3 @@ def __getitem__(self, index): return {'distB_img': distB_tensor, 'ref_img': ref_tensor, 'distA_img': distA_tensor, 'mos_label': mos_label_tensor, 'distB_per_img_score': distB_score, 'distB_path': distB_path, 'ref_img_path': ref_path, 'distA_path': distA_path} - - def __len__(self): - return len(self.paths_mos) diff --git a/pyiqa/data/pipal_dataset.py b/pyiqa/data/pipal_dataset.py deleted file mode 100644 index 20a17c8..0000000 --- a/pyiqa/data/pipal_dataset.py +++ /dev/null @@ -1,91 +0,0 @@ -import numpy as np -import pickle -from PIL import Image - -import torch -from torch.utils import data as data -import torchvision.transforms as tf -from torchvision.transforms.functional import normalize - -from pyiqa.data.data_util import read_meta_info_file -from pyiqa.data.transforms import transform_mapping, augment, PairedToTensor -from pyiqa.utils import FileClient, imfrombytes, img2tensor -from pyiqa.utils.registry import DATASET_REGISTRY - - -@DATASET_REGISTRY.register() -class PIPALDataset(data.Dataset): - """General Full Reference dataset with meta info file. - - Args: - opt (dict): Config for train datasets with the following keys: - phase (str): 'train' or 'val'. - """ - - def __init__(self, opt): - super(PIPALDataset, self).__init__() - self.opt = opt - - if opt.get('override_phase', None) is None: - self.phase = opt['phase'] - else: - self.phase = opt['override_phase'] - - target_img_folder = opt['dataroot_target'] - ref_img_folder = opt.get('dataroot_ref', None) - self.paths_mos = read_meta_info_file(target_img_folder, opt['meta_info_file'], mode='fr', ref_dir=ref_img_folder) - - # read train/val/test splits - split_file_path = opt.get('split_file', None) - if split_file_path: - split_index = opt.get('split_index', 1) - with open(opt['split_file'], 'rb') as f: - split_dict = pickle.load(f) - splits = split_dict[split_index][self.phase] - - self.paths_mos = [self.paths_mos[i] for i in splits] - - dmos_max = opt.get('dmos_max', 0.) 
- if dmos_max: - self.use_dmos = True - self.dmos_max = opt.get('dmos_max') - else: - self.use_dmos = False - - # do paired transform first and then do common transform - paired_transform_list = [] - augment_dict = opt.get('augment', None) - if augment_dict is not None: - for k, v in augment_dict.items(): - paired_transform_list += transform_mapping(k, v) - self.paired_trans = tf.Compose(paired_transform_list) - - common_transform_list = [] - self.img_range = opt.get('img_range', 1.0) - common_transform_list += [ - PairedToTensor(), - ] - self.common_trans = tf.Compose(common_transform_list) - - - def __getitem__(self, index): - - ref_path = self.paths_mos[index][0] - img_path = self.paths_mos[index][1] - mos_label = self.paths_mos[index][2] - - img_pil = Image.open(img_path).convert('RGB') - ref_pil = Image.open(ref_path).convert('RGB') - - img_pil, ref_pil = self.paired_trans([img_pil, ref_pil]) - - img_tensor = self.common_trans(img_pil) * self.img_range - ref_tensor = self.common_trans(ref_pil) * self.img_range - if self.use_dmos: - mos_label = self.dmos_max - mos_label - mos_label_tensor = torch.Tensor([mos_label]) - - return {'img': img_tensor, 'ref_img': ref_tensor, 'mos_label': mos_label_tensor, 'img_path': img_path, 'ref_img_path': ref_path} - - def __len__(self): - return len(self.paths_mos) diff --git a/pyiqa/default_model_configs.py b/pyiqa/default_model_configs.py index cf9fe0e..cff0815 100644 --- a/pyiqa/default_model_configs.py +++ b/pyiqa/default_model_configs.py @@ -348,7 +348,7 @@ 'metric_opts': { 'type': 'CLIPScore', }, - 'metric_mode': 'NR', # Caption image similarity + 'metric_mode': 'NR', # Caption image similarity }, 'entropy': { 'metric_opts': { diff --git a/tests/NR_benchmark_results.csv b/tests/NR_benchmark_results.csv index 7c4100f..c1ea341 100644 --- a/tests/NR_benchmark_results.csv +++ b/tests/NR_benchmark_results.csv @@ -6,7 +6,7 @@ nrqm,0.4122/0.3012/0.2013,0.4756/0.3715/0.2517,0.4591/0.3317/0.2285,,, pi,0.5201/0.4615/0.3139,0.4688/0.4573/0.3132,0.4512/0.3444/0.2373,,, nima,0.4993/0.5071/0.348,0.7156/0.6662/0.4816,,,, paq2piq,0.7542/0.7188/0.5302,0.7062/0.643/0.4622,0.5776/0.4011/0.2838,,, -cnniqa,0.6338/0.6126/0.4285,0.7971/0.7603/0.562,0.3932/0.162/0.1063,,, +cnniqa,0.6372/0.6089/0.4257,0.7934/0.7551/0.558,0.398/0.1769/0.117,,, dbcnn,0.7864/0.7642/0.5667,0.9269/0.9122/0.7451,0.5454/0.4348/0.3057,,, musiq-ava,0.6001/0.5954/0.4235,0.589/0.5273/0.3714,,,, musiq-koniq,0.8295/0.7889/0.5986,0.8958/0.8654/0.6817,0.6814/0.575/0.4131,0.5128/0.4978/0.3437,0.8626/0.8676/0.6649, @@ -19,5 +19,5 @@ clipiqa+_rn50_512,0.8181/0.818/0.6231,0.9012/0.8847/0.7033,0.6577/0.5949/0.4241, clipiqa+_vitL14_512,0.7679/0.7729/0.5733,0.8747/0.861/0.6721,0.6063/0.5259/0.3709,,, tres-koniq,0.8118/0.7771/0.5808,,,0.513/0.4919/0.3391,0.8624/0.8619/0.66, tres-flive,0.7213/0.7336/0.5373,0.7507/0.7068/0.516,,,0.6137/0.7269/0.533, -hyperiqa,0.7688/0.744/0.5448,0.9196/0.8991/0.7283,0.5646/0.4502/0.3142,,, +hyperiqa,0.7779/0.7546/0.5562,0.9233/0.904/0.7336,0.5627/0.4537/0.3177,,, topiq_nr,0.8261/0.8106/0.6165,0.9436/0.9299/0.7727,0.5625/0.4452/0.3143,0.6289/0.5819/0.4105,0.8744/0.8704/0.6716, diff --git a/tests/test_datasets_general.py b/tests/test_datasets_general.py index c7e23c4..90caacc 100644 --- a/tests/test_datasets_general.py +++ b/tests/test_datasets_general.py @@ -2,89 +2,10 @@ import torch import pytest import os +import yaml -options = { - 'BAPPS': { - 'type': 'BAPPSDataset', - 'dataroot_target': './datasets/PerceptualSimilarity/dataset', - 'meta_info_file': 
'./datasets/meta_info/meta_info_BAPPSDataset.csv', - }, - 'PieAPP': { - 'type': 'PieAPPDataset', - 'dataroot_target': './datasets/PieAPP_dataset_CVPR_2018/', - 'meta_info_file': './datasets/meta_info/meta_info_PieAPPDataset.csv', - }, - 'FLIVE': { - 'type': 'GeneralNRDataset', - 'dataroot_target': './datasets/FLIVE_Database/database', - 'meta_info_file': './datasets/meta_info/meta_info_FLIVEDataset.csv', - }, - 'PIPAL': { - 'type': 'PIPALDataset', - 'dataroot_target': './datasets/PIPAL/Dist_Imgs', - 'dataroot_ref': './datasets/PIPAL/Train_Ref', - 'meta_info_file': './datasets/meta_info/meta_info_PIPALDataset.csv', - 'split_file': './datasets/meta_info/pipal_official.pkl' - }, - 'KonIQ10k++': { - 'type': 'GeneralNRDataset', - 'dataroot_target': './datasets/koniq10k/512x384', - 'meta_info_file': './datasets/meta_info/meta_info_KonIQ10k++Dataset.csv', - }, - 'AVA': { - 'type': 'AVADataset', - 'dataroot_target': './datasets/AVA_dataset/ava_images/', - 'meta_info_file': './datasets/meta_info/meta_info_AVADataset.csv', - }, - 'SPAQ': { - 'type': 'GeneralNRDataset', - 'dataroot_target': './datasets/SPAQ/TestImage', - 'meta_info_file': './datasets/meta_info/meta_info_SPAQDataset.csv', - }, - 'KADID10k': { - 'type': 'GeneralFRDataset', - 'dataroot_target': './datasets/kadid10k/images', - 'meta_info_file': './datasets/meta_info/meta_info_KADID10kDataset.csv', - }, - 'KonIQ10k': { - 'type': 'GeneralNRDataset', - 'dataroot_target': './datasets/koniq10k/512x384', - 'meta_info_file': './datasets/meta_info/meta_info_KonIQ10kDataset.csv', - }, - 'LIVEC': { - 'type': 'LIVEChallengeDataset', - 'dataroot_target': './datasets/LIVEC', - 'meta_info_file': './datasets/meta_info/meta_info_LIVEChallengeDataset.csv', - }, - 'LIVEM': { - 'type': 'GeneralFRDataset', - 'dataroot_target': './datasets/LIVEmultidistortiondatabase', - 'meta_info_file': './datasets/meta_info/meta_info_LIVEMDDataset.csv', - }, - 'LIVE': { - 'type': 'GeneralFRDataset', - 'dataroot_target': './datasets/LIVEIQA_release2', - 'meta_info_file': './datasets/meta_info/meta_info_LIVEIQADataset.csv', - }, - 'TID2013': { - 'type': 'GeneralFRDataset', - 'dataroot_target': './datasets/tid2013/distorted_images', - 'dataroot_ref': './datasets/tid2013/reference_images', - 'meta_info_file': './datasets/meta_info/meta_info_TID2013Dataset.csv', - }, - 'TID2008': { - 'type': 'GeneralFRDataset', - 'dataroot_target': './datasets/tid2008/distorted_images', - 'dataroot_ref': './datasets/tid2008/reference_images', - 'meta_info_file': './datasets/meta_info/meta_info_TID2008Dataset.csv', - }, - 'CSIQ': { - 'type': 'GeneralFRDataset', - 'dataroot_target': './datasets/CSIQ/dst_imgs', - 'dataroot_ref': './datasets/CSIQ/src_imgs', - 'meta_info_file': './datasets/meta_info/meta_info_CSIQDataset.csv', - }, -} +with open('./options/default_dataset_opt.yml') as f: + options = yaml.safe_load(f) common_opt = { 'name': 'test', diff --git a/tests/test_metric_general.py b/tests/test_metric_general.py index 7f479b7..c3596fa 100644 --- a/tests/test_metric_general.py +++ b/tests/test_metric_general.py @@ -96,13 +96,20 @@ def test_cpu_gpu_consistency(metric_name): 2. fid requires directory inputs; 3. vsi will output NaN with random input. 
""" - x_cpu = torch.rand(1, 3, 256, 256) + x_cpu = torch.rand(1, 3, 224, 224) x_gpu = x_cpu.cuda() - y_cpu = torch.rand(1, 3, 256, 256) + y_cpu = torch.rand(1, 3, 224, 224) y_gpu = y_cpu.cuda() metric_cpu = pyiqa.create_metric(metric_name, device='cpu') metric_gpu = pyiqa.create_metric(metric_name, device='cuda') + metric_cpu.eval() + metric_gpu.eval() + + if hasattr(metric_cpu.net, 'test_sample'): + metric_cpu.net.test_sample = 1 + if hasattr(metric_gpu.net, 'test_sample'): + metric_gpu.net.test_sample = 1 score_cpu = metric_cpu(x_cpu, y_cpu) score_gpu = metric_gpu(x_gpu, y_gpu) @@ -113,19 +120,25 @@ def test_cpu_gpu_consistency(metric_name): @pytest.mark.parametrize( ("metric_name"), - [(k) for k in pyiqa.list_models() if k not in ['pi', 'nrqm', 'fid', 'mad', 'vsi']] + [(k) for k in pyiqa.list_models() if k not in ['pi', 'nrqm', 'fid', 'mad', 'vsi', 'clipscore', 'entropy']] ) def test_gradient_backward(metric_name, device): """Test if the metric can be used in a gradient descent process. pi, nrqm and fid are not tested because they are not differentiable. mad and vsi give NaN with random input. """ - x = torch.randn(2, 3, 224, 224).to(device) - y = torch.randn(2, 3, 224, 224).to(device) + size = (2, 3, 224, 224) + if 'swin' in metric_name: + size = (2, 3, 384, 384) + + x = torch.randn(*size).to(device) + y = torch.randn(*size).to(device) x.requires_grad_() metric = pyiqa.create_metric(metric_name, as_loss=True, device=device) - metric.train() + metric.eval() + if hasattr(metric.net, 'test_sample'): + metric.net.test_sample = 1 score = metric(x, y) if isinstance(score, tuple):