In [1]:
import sys
sys.path.append('src')

In [451]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Imports

In [230]:
#export
from pathlib import Path
from functools import lru_cache, partial

# from tqdm.auto import tqdm
from tqdm.notebook import tqdm

import os
import cv2
import yaml
from PIL import Image
import numpy as np
import albumentations as albu
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import ConcatDataset as ConcatDataset

from . import nb_utils

In [201]:
from pprint import pprint

# Code

## Datasets

In [285]:
#export
    
class Dataset:
    """Base dataset over the files under ``root`` that match a glob ``pattern``.

    The directory tree is searched recursively once, at construction time, and
    the matching paths are kept sorted in ``self.files``. Subclasses implement
    ``load_item`` to turn an index into an actual item.
    """

    def __init__(self, root, pattern):
        """Collect and sort every path below ``root`` matching ``pattern``.

        An ``AssertionError`` is raised when nothing matches.
        """
        self.pattern = pattern
        self.root = Path(root)
        matches = sorted(self.root.rglob(self.pattern))
        assert len(matches) > 0, 'There is no matching files'
        self.files = matches

    def load_item(self, idx):
        """Load the item at position ``idx``; must be provided by subclasses."""
        raise NotImplementedError

    def __getitem__(self, idx):
        """Indexing simply delegates to ``load_item``."""
        return self.load_item(idx)

    def __len__(self):
        """Number of matched files."""
        return len(self.files)
        
#     def __add__(self, other):
#         return ConcatDataset([self, other])

    
class ImageDataset(Dataset):
    """File dataset whose items are images opened with PIL."""

    def load_item(self, idx):
        """Open the file at position ``idx`` of ``self.files`` and return the
        object produced by ``Image.open``."""
        path_str = str(self.files[idx])
        return Image.open(path_str)
    
class PairDataset:
    """Zips two equally long datasets: item ``idx`` is ``(ds1[idx], ds2[idx])``."""

    def __init__(self, ds1, ds2):
        """Store both datasets; an ``AssertionError`` is raised if their lengths differ."""
        self.ds1 = ds1
        self.ds2 = ds2
        assert len(self.ds1) == len(self.ds2)

    def __getitem__(self, idx):
        """Return a 2-tuple with the ``idx``-th item of each dataset."""
        first = self.ds1.__getitem__(idx)
        second = self.ds2.__getitem__(idx)
        return first, second

    def __len__(self):
        """Length of the pair, taken from the first dataset."""
        return len(self.ds1)
    
class TransformDataset:
    """Applies a transform callable to every item fetched from ``dataset``.

    ``transforms`` is called with ``image=`` and ``mask=`` keyword arguments and
    must return a mapping with an ``"image"`` key (and a ``"mask"`` key when
    ``is_masked`` is true). When ``transforms`` is ``None`` an empty
    ``albu.Compose`` is used instead.
    """

    def __init__(self, dataset, transforms, is_masked=False):
        self.dataset = dataset
        if transforms is None:
            self.transforms = albu.Compose([])
        else:
            self.transforms = transforms
        self.is_masked = is_masked

    def __getitem__(self, idx):
        """Return the transformed item.

        Masked mode: the item is unpacked as ``(img, mask)`` and the transformed
        ``(image, mask)`` pair is returned. Otherwise only the transformed image
        is returned.
        """
        item = self.dataset.__getitem__(idx)
        if not self.is_masked:
            # NOTE(review): this branch indexes item[0], so it assumes every item
            # is a sequence whose first element is the image -- confirm against
            # the datasets actually wrapped in unmasked mode.
            return self.transforms(image=item[0], mask=None)['image']
        img, mask = item
        augmented = self.transforms(image=img, mask=mask)
        return augmented["image"], augmented["mask"]

    def __len__(self):
        """Same length as the wrapped dataset."""
        return len(self.dataset)
    
class MultiplyDataset:
    """Repeats ``dataset`` ``rate`` times back to back.

    The repetitions are held in a single flat ``ConcatDataset``. Building the
    result by repeatedly adding ``ConcatDataset`` objects nests one level per
    repetition, so item lookup recurses ``rate`` levels deep and can exceed the
    interpreter's recursion limit for large rates.

    Args:
        dataset: any indexable object with ``__len__``.
        rate: integer number of repetitions; values below 1 yield a single copy.
    """

    def __init__(self, dataset, rate):
        copies = max(rate, 1)
        self.dataset = ConcatDataset([dataset] * copies)

    def __getitem__(self, idx):
        """Return item ``idx`` of the repeated dataset."""
        return self.dataset.__getitem__(idx)

    def __len__(self):
        """Total length, i.e. ``len(dataset)`` times the number of copies."""
        return len(self.dataset)
    
class CachingDataset:
    """Memoizes items of ``dataset`` so each index is loaded at most once.

    The cache is a plain dict stored on the instance. Decorating ``__getitem__``
    with ``lru_cache`` would instead keep one cache on the class, keyed by
    ``(self, idx)``, which holds a strong reference to every instance (and all
    of its cached items) for the lifetime of the process.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self._cache = {}  # idx -> item already loaded from self.dataset

    def __getitem__(self, idx):
        """Return the item for ``idx``, loading it from the wrapped dataset on first access.

        ``idx`` must be hashable; an unhashable index raises ``TypeError``.
        """
        try:
            return self._cache[idx]
        except KeyError:
            item = self.dataset.__getitem__(idx)
            self._cache[idx] = item
            return item

    def __len__(self):
        """Same length as the wrapped dataset."""
        return len(self.dataset)

    
class PreloadingDataset:
    """Loads every item of ``dataset`` at construction time and serves them from ``self.data``.

    Args:
        dataset: indexable object with ``__len__``.
        num_proc: when truthy, loading is delegated to
            ``nb_utils.mp_func_gen(self.preload_data, indices, num_proc)``;
            otherwise all items are loaded in the current process.
    """

    def __init__(self, dataset, num_proc=False):
        self.dataset = dataset
        self.num_proc = num_proc
        all_idxs = range(len(self.dataset))
        if self.num_proc:
            # NOTE(review): mp_func_gen is defined outside this file; presumably it
            # hands index chunks to preload_data and returns one indexable
            # collection of items -- confirm.
            self.data = nb_utils.mp_func_gen(self.preload_data, all_idxs, self.num_proc)
        else:
            # preload_data expects an iterable of indices (not a (dataset, indices) tuple)
            self.data = self.preload_data(all_idxs)

    def preload_data(self, args):
        """Load and return, as a list, the items at the indices in ``args``."""
        return [self.dataset.__getitem__(i) for i in args]

    def __getitem__(self, idx):
        """Return the preloaded item at ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Number of preloaded items."""
        return len(self.data)
    
    
class GpuPreloadingDataset:
    """Eagerly loads every ``(item, idx)`` pair of ``dataset`` and moves the item
    to the first entry of ``devices`` via ``item.to(...)``.

    Only ``devices[0]`` is used; any further entries are ignored.
    """

    def __init__(self, dataset, devices):
        self.dataset = dataset
        self.devices = devices
        self.data = self.preload_data()

    def _moved_pairs(self):
        """Yield ``(item_on_device, idx)`` for every position of the wrapped dataset."""
        for position in range(len(self.dataset)):
            item, idx = self.dataset.__getitem__(position)
            yield item.to(self.devices[0]), idx

    def preload_data(self):
        """Return the list of all ``(item_on_device, idx)`` pairs."""
        return list(self._moved_pairs())

    def __getitem__(self, idx):
        """Return the preloaded pair at ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Length of the wrapped dataset."""
        return len(self.dataset)

## Dataset catalog

In [286]:
#export
class DatasetCatalog():
    """Registry mapping dataset names to the attributes needed to build them.

    Subclasses override ``DATA_DIR`` / ``DATA_DIR_MNT`` (candidate data roots,
    tried in that order), ``DATASETS`` (name -> attribute dict) and
    ``create_factory_dict``.
    """
    DATA_DIR = "/tmp/"
    DATA_DIR_MNT = "/mnt/tmp"

    DATASETS = {
        "default": {
            'factory':'default',
            "root": "def_root",
        }
    }

    @staticmethod
    def create_factory_dict(data_dir, dataset_attrs):
        """Turn a ``DATASETS`` entry into a ``{factory: ..., args: ...}`` dict.

        Must be implemented by subclasses.
        """
        #{factory:Dataset, args:args}
        raise NotImplementedError

    @classmethod
    def get(cls, name):
        """Look up ``name`` in ``DATASETS`` and build its factory dict.

        The first existing directory among ``DATA_DIR`` and ``DATA_DIR_MNT`` is
        passed to ``create_factory_dict`` as the data root.

        Raises:
            RuntimeError: if ``name`` is not registered, or if neither data
                directory exists.
        """
        try:
            attrs = cls.DATASETS[name]
        except (KeyError, TypeError) as err:
            # Only lookup failures are translated; anything else propagates as-is.
            print(cls.DATASETS)
            raise RuntimeError("Dataset not available: {}".format(name)) from err

        if os.path.exists(cls.DATA_DIR):
            data_dir = cls.DATA_DIR
        elif os.path.exists(cls.DATA_DIR_MNT):
            data_dir = cls.DATA_DIR_MNT
        else:
            # Without this branch data_dir would be unbound further down.
            raise RuntimeError(
                "Data directory not found: neither {!r} nor {!r} exists".format(
                    cls.DATA_DIR, cls.DATA_DIR_MNT))

        return cls.create_factory_dict(data_dir, attrs)
        

## Builders

In [405]:
#export

# dataset_factories = {'termit':TermitDataset}
# transform_factories = {'TRAIN':{'factory':TransformDataset_Partial_HARD, 'transform_getter':get_aug}}
# extend_factories = {'GPU_PRELOAD':GpuPreloadingDataset_Partial_GPU0}
# dataset_types = ['TRAIN', 'VALID', 'TEST']
# datasets = {'TRAIN': dataset1, 'VALID': ...}
    
def extend_dataset(ds, data_field, extend_factories):
    """Wrap ``ds`` with every extension that ``data_field`` enables.

    For each ``key -> factory`` in ``extend_factories`` (in iteration order),
    ``data_field.get(key)`` is inspected: a falsy value skips the extension, a
    dict is forwarded to the factory as keyword arguments, and any other truthy
    value calls the factory with no extra arguments. Each factory receives the
    dataset produced by the previous one.
    """
    for key, make_wrapper in extend_factories.items():
        setting = data_field.get(key, None)
        if not setting:
            continue
        kwargs = dict(setting) if isinstance(setting, dict) else {}
        ds = make_wrapper(ds, **kwargs)
    return ds

def extend_all_datasets(cfg, datasets, extend_factories):
    """Apply ``extend_dataset`` to every entry of ``datasets``.

    ``datasets`` maps a kind (the keys are used to index ``cfg.DATA``) to a
    dataset; the result is a new dict with the same keys and the extended
    datasets as values.
    """
    return {
        kind: extend_dataset(ds, cfg.DATA[kind], extend_factories)
        for kind, ds in datasets.items()
    }

class DatasetBuilder:
    """Builds one transformed dataset per dataset type from a config.

    For every type in ``dataset_types`` the names in ``cfg.DATA[type].DATASETS``
    are resolved through ``catalog.get`` and ``dataset_factories``, concatenated
    when there is more than one, and wrapped by the transformer built from
    ``transform_factory[type]``.
    """

    def __init__(self, cfg,
                       catalog,
                       dataset_factories,
                       transform_factory,
                       dataset_types=['TRAIN', 'VALID', 'TEST']):
        # store_attr receives locals(), so no other local may be created before this call.
        # NOTE(review): store_attr lives in nb_utils; presumably it copies the
        # constructor arguments onto self -- confirm.
        nb_utils.store_attr(self, locals())

    def build_datasets(self):
        """Return ``{dataset_type: transformed_dataset}`` for every type that lists datasets."""
        wrap_by_type = self._build_transformers()
        built = {}
        for split in self.dataset_types:
            names = self.cfg.DATA[split].DATASETS
            if not names:
                # nothing configured for this type: it is left out of the result
                continue
            parts = [self._create_dataset_fact(name) for name in names]
            merged = parts[0] if len(parts) <= 1 else ConcatDataset(parts)
            built[split] = wrap_by_type[split](merged)
        return built

    def _create_dataset_fact(self, ds):
        """Instantiate the dataset registered in the catalog under the name ``ds``."""
        spec = self.catalog.get(ds)
        make = self.dataset_factories[spec['factory']]
        return make(**spec['args'])

    def _build_transformers(self):
        """Return ``{dataset_type: callable}``; each callable wraps a dataset with
        that type's transform factory, its ``transforms`` argument pre-bound."""
        wrappers = {}
        for split in self.dataset_types:
            aug = self.cfg.TRANSFORMERS[split]['AUG']
            crop = self.cfg.TRANSFORMERS.CROP_SIZE
            entry = self.transform_factory[split]
            composed = entry['transform_getter'](aug_type=aug, size=crop)
            wrappers[split] = partial(entry['factory'], transforms=composed)
        return wrappers

In [None]:
#export
def build_dataloader(cfg, dataset, sampler=None, batch_size=1, num_workers=0, drop_last=False, pin=False):
    """Create a ``DataLoader`` over ``dataset``.

    Shuffling is enabled only when no ``sampler`` is supplied; the default
    collate function is used. ``pin`` is forwarded as ``pin_memory``. ``cfg`` is
    accepted but not read by this function.
    """
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=sampler is None,
        num_workers=num_workers,
        pin_memory=pin,
        drop_last=drop_last,
        collate_fn=None,
        sampler=sampler,
    )
    return DataLoader(dataset, **loader_kwargs)

In [None]:

# Under DDP a DistributedSampler (shuffling enabled) is created from
# cfg.PARALLEL.WORLD_SIZE / cfg.PARALLEL.LOCAL_RANK; otherwise no sampler is used.
# NOTE(review): neither `dataset` nor `cfg` is defined in a visible cell above
# this one -- confirm where they are expected to come from before running it.
sampler = None
if cfg.PARALLEL.DDP:
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=cfg.PARALLEL.WORLD_SIZE,
        rank=cfg.PARALLEL.LOCAL_RANK,
        shuffle=True,
    )

# Tests

## test datasets

In [4]:
imgs_path = './test_data/validation_1_1/'

### 1. Base `Dataset`: file discovery, `load_item` not implemented

In [5]:
# The base class only discovers files; indexing it must end in NotImplementedError.
# Any other exception propagates unchanged.
d1 = Dataset(imgs_path, 'aimg*.png')
assert len(d1) == 4612
try:
    d1[0]
except NotImplementedError:
    pass

### 2. `ImageDataset`: images and masks

In [6]:
# Image and mask datasets over the same folder, selected by file-name prefix.
d2 = ImageDataset(imgs_path, 'aimg_*')
assert len(d2) == 4612
d3 = ImageDataset(imgs_path, 'mask_*')
assert len(d3) == 4612
# ImageDataset.load_item returns the object from PIL's Image.open, which has no
# `.shape`; converting with np.asarray works for PIL images and arrays alike.
np.asarray(d2[0]).shape, np.asarray(d3[0]).shape

((4096, 4096, 3), (64, 64, 3))

### 3. `PairDataset`: image/mask pairs

In [7]:
# Pair the image dataset with the mask dataset, then check that the pair at
# index 0 matches what each underlying dataset returns for index 0.
d4 = PairDataset(d2, d3)
assert len(d4) == 4612
i,ii = d4[0]
j, jj = d2[0], d3[0]
np.allclose(i,j), np.allclose(ii,jj)

(True, True)

### transforms dataset

In [9]:
# Wrap the paired dataset in masked mode so the same 50x50 center crop is
# applied to image and mask, then inspect the resulting shapes.
# NOTE(review): d4's items come from ImageDataset, which returns the object
# produced by PIL's Image.open; albumentations presumably expects numpy arrays --
# confirm this cell still runs with the current ImageDataset.
transforms = albu.Compose([albu.CenterCrop(50, 50)])
d5 = TransformDataset(d4, transforms=transforms, is_masked=True)
i = d5[0]
i[0].shape, i[1].shape

((50, 50, 3), (50, 50, 3))

### multiply

In [10]:
# Repeat d2 `mult` times and check the length grew by that factor.
mult = 2
d6 = MultiplyDataset(d2, mult)
assert len(d6) // mult == len(d2)

### cache

In [10]:
d7 = CachingDataset(d2)

In [11]:
%%timeit -r 10 -n 100
d7[0]

The slowest run took 26989.51 times longer than the fastest. This could mean that an intermediate result is being cached.
339 µs ± 1.02 ms per loop (mean ± std. dev. of 10 runs, 100 loops each)


In [13]:
%%timeit -r 1 -n 5
d2[1]

351 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 5 loops each)


### preloading

In [7]:
# Preload only the files matching 'aimg_9*.png'. A truthy num_proc makes
# PreloadingDataset load through nb_utils.mp_func_gen rather than in-process.
_d8 = ImageDataset(imgs_path, 'aimg_9*.png')
d8 = PreloadingDataset(_d8, num_proc=8)
assert len(_d8) == len(d8)

111

In [11]:
%%timeit -r 10 -n 100
d8[18]

140 ns ± 32.6 ns per loop (mean ± std. dev. of 10 runs, 100 loops each)


## test catalog

In [424]:
class MyDatasetCatalog(DatasetCatalog):
    """Catalog used by the tests below.

    Each entry names a factory, the path-like arguments that get prefixed with
    the data directory (``path_args``) and the arguments passed through
    untouched (``kwargs``).
    """
    DATA_DIR = "test_data/"
    DATA_DIR_MNT = "/mnt/input/term"

    DATASETS = {
        "test_data": {
            'factory': 'factory_test',
            'path_args': {"root": "validation_1_1"},
            'kwargs': {"pattern": 'aimg*.png'},
        },
        "test_data_masks": {
            'factory': 'factory_test_masks',
            'path_args': {"root": "validation_1_1"},
            'kwargs': {"pattern": 'aimg*.png'},
        },
        "test_data_joined": {
            'factory': 'factory_test_joined',
            'path_args': {
                "root1": "validation_1_1",
                "root2": "validation_1_1",
            },
            'kwargs': {
                "pattern1": 'aimg*.png',
                "pattern2": 'mask*.png',
            },
        },
    }

    @staticmethod
    def get(name):
        """Delegate to ``DatasetCatalog.get`` with this class bound as ``cls``."""
        return super(MyDatasetCatalog, MyDatasetCatalog).get(name)

    @staticmethod
    def create_factory_dict(data_dir, dataset_attrs):
        """Return ``{'factory': name, 'args': {...}}`` for one ``DATASETS`` entry.

        Every ``path_args`` value is joined onto ``data_dir``; ``kwargs`` are
        merged in afterwards. A ``RuntimeError`` is raised when the entry's
        factory is not one used by ``MyDatasetCatalog.DATASETS``.
        """
        factory = dataset_attrs['factory']
        known_factories = [entry['factory'] for entry in MyDatasetCatalog.DATASETS.values()]
        if factory not in known_factories:
            raise RuntimeError(f' Uknnown factory type: {factory}' )

        factory_args = {}
        for key, relative in dataset_attrs['path_args'].items():
            factory_args[key] = os.path.join(data_dir, relative)
        factory_args.update(dataset_attrs['kwargs'])
        return dict(factory=factory, args=factory_args)

In [425]:
test_fact_args = MyDatasetCatalog.get(name='test_data_masks')
test_fact_args

{'factory': 'factory_test_masks',
 'args': {'root': 'test_data/validation_1_1', 'pattern': 'aimg*.png'}}

In [426]:
test_fact_args = MyDatasetCatalog.get(name='test_data_joined')
test_fact_args

{'factory': 'factory_test_joined',
 'args': {'root1': 'test_data/validation_1_1',
  'root2': 'test_data/validation_1_1',
  'pattern1': 'aimg*.png',
  'pattern2': 'mask*.png'}}

## builders

In [427]:
from nb_configer import cfg

In [440]:
yaml_str = '''
    DATA:
      TRAIN:
        DATASETS: ['test_data_joined', 'test_data_joined']
        GPU_PRELOAD: False
        PRELOAD: True
        CACHE: False
      VALID:
        DATASETS: ['test_data']
      TEST:
        DATASETS: ['test_data']

    TRANSFORMERS:
      TRAIN:
        AUG: 'test'
      VALID:
        AUG: 'val'
      TEST:
        AUG: 'test'

      CROP_SIZE: 64

    TRAIN:
      NUM_WORKERS: 0
      BATCH_SIZE: 128

    VALID:
      NUM_WORKERS: 4
      BATCH_SIZE: 1
    '''
# Parse the YAML text, dump it to a file and merge that file into the global cfg.
# NOTE(review): the hard-coded '/tmp/t.yaml' path is overwritten on every run and
# assumes a writable /tmp -- confirm this is acceptable for the test environment.
yd = yaml.safe_load(yaml_str)
with open('/tmp/t.yaml', 'w') as f:
    yaml.safe_dump(yd, f)
cfg.merge_from_file('/tmp/t.yaml')

In [441]:
pprint(cfg.DATA)

{'TEST': {'CACHE': False,
          'DATASETS': ('test_data',),
          'GPU_PRELOAD': False,
          'MULTIPLY': 1,
          'PRELOAD': False},
 'TRAIN': {'CACHE': False,
           'DATASETS': ('test_data_joined', 'test_data_joined'),
           'GPU_PRELOAD': False,
           'MULTIPLY': 1,
           'PRELOAD': True},
 'VALID': {'CACHE': False,
           'DATASETS': ('test_data',),
           'GPU_PRELOAD': False,
           'MULTIPLY': 1,
           'PRELOAD': False}}


In [442]:
def test_trans_get(aug_type, size):
    """Transform getter used by the builder test.

    Returns an ``albu.Compose`` holding a single ``size`` x ``size`` center
    crop; ``aug_type`` is accepted but not used.
    """
    center_crop = albu.CenterCrop(size, size)
    return albu.Compose([center_crop])

In [443]:
class PairImageDataset(PairDataset):
    """``PairDataset`` of two ``ImageDataset`` objects built from (root, pattern) pairs."""

    def __init__(self, root1, pattern1, root2, pattern2):
        """Create both image datasets and pair them; their lengths must match."""
        first = ImageDataset(root1, pattern1)
        second = ImageDataset(root2, pattern2)
        # PairDataset.__init__ stores ds1/ds2 and asserts that the lengths are equal
        super().__init__(first, second)

In [444]:
# Factory name (as used in MyDatasetCatalog.DATASETS) -> dataset class.
# NOTE(review): 'factory_test_masks' is used by the catalog but has no entry
# here -- confirm whether that is intentional.
dataset_factories = {'factory_test': ImageDataset, 'factory_test_joined': PairImageDataset}
# Per dataset type: the wrapper class applied to the built dataset and the
# function that produces its transforms. Only TRAIN runs in masked mode.
transform_factory = {
    'TRAIN':{'factory':partial(TransformDataset, is_masked=True), 'transform_getter':test_trans_get},
    'TEST':{'factory':TransformDataset, 'transform_getter':test_trans_get},
    'VALID':{'factory':TransformDataset, 'transform_getter':test_trans_get},
}
# Config key under cfg.DATA[<type>] -> wrapper applied when that key is truthy.
extend_factories = {
    'GPU_PRELOAD':GpuPreloadingDataset,
    'PRELOAD':partial(PreloadingDataset, num_proc=8),
    'CACHE':CachingDataset,
}

In [445]:
builder = DatasetBuilder(cfg, MyDatasetCatalog, dataset_factories=dataset_factories, transform_factory=transform_factory)

In [446]:
datasets = builder.build_datasets()
datasets

{'TRAIN': <__main__.TransformDataset at 0x7f95a16b7da0>,
 'VALID': <__main__.TransformDataset at 0x7f95a16b75c0>,
 'TEST': <__main__.TransformDataset at 0x7f95a16b7f28>}

In [447]:
tds = datasets['TRAIN']

In [448]:
len(tds)
tds[0][0].shape, tds[0][1].shape

((64, 64, 3), (64, 64, 3))

In [449]:
%%timeit -r 2 -n 5
tds[17]

340 ms ± 2.67 ms per loop (mean ± std. dev. of 2 runs, 5 loops each)


In [None]:
# extend_all_datasets takes (cfg, datasets, extend_factories), in that order
datasets = extend_all_datasets(cfg, datasets, extend_factories)

In [None]:
%%timeit -r 2 -n 5
datasets['TRAIN'][17]