diff --git a/options/default_dataset_opt.yml b/options/default_dataset_opt.yml index 475afc4..e26e96a 100644 --- a/options/default_dataset_opt.yml +++ b/options/default_dataset_opt.yml @@ -79,14 +79,14 @@ ava: pipal: name: PIPAL - type: GeneralFRDataset + type: PIPALDataset dataroot_target: './datasets/PIPAL/Dist_Imgs' dataroot_ref: './datasets/PIPAL/Train_Ref' meta_info_file: './datasets/meta_info/meta_info_PIPALDataset.csv' flive: name: FLIVE - type: GeneralNRDataset + type: FLIVEDataset dataroot_target: './datasets/FLIVE_Database/database' meta_info_file: './datasets/meta_info/meta_info_FLIVEDataset.csv' split_file: './datasets/train_split_info/flive_official.pkl' diff --git a/pyiqa/data/flive_dataset.py b/pyiqa/data/flive_dataset.py index ff51fb8..db82cf5 100644 --- a/pyiqa/data/flive_dataset.py +++ b/pyiqa/data/flive_dataset.py @@ -9,7 +9,7 @@ from torchvision.transforms.functional import normalize from pyiqa.data.data_util import read_meta_info_file -from pyiqa.data.transforms import transform_mapping +from pyiqa.data.transforms import transform_mapping from pyiqa.utils import FileClient, imfrombytes, img2tensor from pyiqa.utils.registry import DATASET_REGISTRY @@ -41,12 +41,7 @@ def __init__(self, opt): else: splits = split_dict[split_index][opt['override_phase']] - if opt['phase'] == 'train': - self.paths_mos = [self.paths_mos[i] for i in splits] - else: - # remove patches during validation and test - self.paths_mos = [self.paths_mos[i] for i in splits] - self.paths_mos = [[p, m] for p, m in self.paths_mos if not 'patches/' in p] + self.paths_mos = [self.paths_mos[i] for i in splits] dmos_max = opt.get('dmos_max', 0.) if dmos_max: