In [1]:
# default_exp data.dataset_factory

In [2]:
from nbdev.export import *

## Utility Functions for Loading Data from OmegaConf config

In [3]:
# export
import albumentations as A
import torchvision.transforms as T
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader

from src import _logger
from src.data.datasets import CassavaDataset, load_data

In [4]:
# export
def create_transform(cfg: DictConfig, verbose=False):
    """Create the transformation pipelines to be used in datasets.

    `cfg` is the augmentations config. It must provide:
      - `cfg.backend`: either "torchvision" or "albumentations"
      - `cfg.train.before_mix`, `cfg.train.after_mix`, `cfg.valid`: lists of
        hydra-instantiable transform configs

    Returns a tuple `(train_augs_initial, train_augs_final, valid_augs)`.
    Each element is a `Compose` of the instantiated transforms from the
    selected backend. `train_augs_final` holds the transforms applied after the
    mix method.

    If `verbose` is True, the composed pipelines are logged.

    Raises `NameError` for an unknown backend. This keeps compatibility with
    the `UnboundLocalError` (a `NameError` subclass) previously raised in that case.
    """
    # Resolve the backend first so a bad config fails before any transform is instantiated.
    backend = cfg.backend
    if backend == "torchvision":
        compose_func = T.Compose
    elif backend == "albumentations":
        compose_func = A.Compose
    else:
        raise NameError(
            f"Unknown augmentation backend: {backend!r}; "
            "expected 'torchvision' or 'albumentations'"
        )

    train_augs_initial = compose_func([instantiate(t) for t in cfg.train.before_mix])
    train_augs_final   = compose_func([instantiate(t) for t in cfg.train.after_mix])
    valid_augs = compose_func([instantiate(t) for t in cfg.valid])

    if verbose:
        _logger.info(f"Train transforms (before mix): {train_augs_initial}")
        _logger.info(f"Train transforms (after mix): {train_augs_final}")
        _logger.info(f"Valid transforms: {valid_augs}")

    return train_augs_initial, train_augs_final, valid_augs

## Dataset Mapper

`DatasetMapper` builds the train/validation datasets for one fold from the global Hydra config.

In [5]:
# export
class DatasetMapper:
    "A convenience class that builds the train/valid/test datasets for the CassavaImageClassification task"
    def __init__(self, cfg: DictConfig):
        "Note: `cfg` has to be the global hydra config (reads `cfg.data.dataset` and `cfg.augmentations`)"
        self.dset_cfg = cfg.data.dataset
        self.tfms_cfg = cfg.augmentations
        self.fold = self.dset_cfg.fold

    def _dataset_factory(self):
        "returns the `CassavaDataset` constructor matching `self.tfms_cfg.backend`; raises `NameError` if unknown"
        backend = self.tfms_cfg.backend
        if backend == "torchvision":
            return CassavaDataset.from_torchvision_tfms
        if backend == "albumentations":
            return CassavaDataset.from_albu_tfms
        raise NameError(
            f"Unknown augmentation backend: {backend!r}; "
            "expected 'torchvision' or 'albumentations'"
        )

    def generate_datasets(self):
        """Load the data for the current fold, split it by the `is_valid` column and build datasets.

        Sets `data`, `train_data`, `valid_data`, `augs_initial`, `augs_final`,
        `augs_valid`, `train_ds`, `valid_ds` and `test_ds` on the instance.
        `test_ds` is the same object as `valid_ds`.
        """
        _logger.info(f"Generating Datasets for FOLD :{self.fold}")

        # Pick the dataset constructor up front so an invalid backend fails before loading data.
        make_dataset = self._dataset_factory()

        # loads the data corresponding to the current fold
        # and does some data preprocessing
        self.data = load_data(self.dset_cfg.csv, self.dset_cfg.image_dir, self.fold, shuffle=True)

        # `.eq(...)` is the same comparison as `== False` / `== True`, without the E712 lint.
        self.train_data = self.data.loc[self.data["is_valid"].eq(False)]
        self.valid_data = self.data.loc[self.data["is_valid"].eq(True)]

        # Shuffle each split and give it a fresh 0..n-1 index.
        self.train_data = self.train_data.sample(frac=1).reset_index(drop=True)
        self.valid_data = self.valid_data.sample(frac=1).reset_index(drop=True)

        self.augs_initial, self.augs_final, self.augs_valid = create_transform(self.tfms_cfg)

        # Train uses the pre-mix transforms; the post-mix ones are exposed via `get_transforms`.
        self.train_ds = make_dataset(
            self.train_data,
            fn_col="filePath",
            label_col="label",
            transform=self.augs_initial)

        self.valid_ds = make_dataset(
            self.valid_data,
            fn_col="filePath",
            label_col="label",
            transform=self.augs_valid)

        self.test_ds = self.valid_ds
        _logger.info(f"Train Dataset has {len(self.train_ds)}, Validation Dataset has {len(self.valid_ds)} instances.")

    def get_train_dataset(self):
        "returns the train dataset (requires `generate_datasets` to have been called)"
        return self.train_ds

    def get_valid_dataset(self):
        "returns the validation dataset (requires `generate_datasets` to have been called)"
        return self.valid_ds

    def get_test_dataset(self):
        "returns the test dataset, which is currently the validation dataset"
        return self.test_ds

    def get_transforms(self):
        "returns the transformations to be applied after mixmethod"
        return self.augs_final

In [6]:
# nbdev: convert the project's notebooks to library modules; the output lists each converted notebook.
# This notebook's target module is set by the `# default_exp` directive in the first cell.
notebook2script()

Converted 00_core.ipynb.
Converted 01a_data.datasets.ipynb.
Converted 01b_data.datasests_factory.ipynb.
Converted 01c_data.mixmethods.ipynb.
Converted 02_losses.ipynb.
Converted 03a_optimizers.ipynb.
Converted 03b_schedulers.ipynb.
Converted 04a_models.utils.ipynb.
Converted 04b_models.layers.ipynb.
Converted 04c_models.classifiers.ipynb.
Converted 04d_models.builder.ipynb.
Converted 04e_models.task.ipynb.
Converted 05_callbacks.ipynb.
Converted index.ipynb.
