In [0]:
'''

    Standard Modules

'''

import os
import pickle as pkl

from functools import reduce

''' 

    External Modules

'''

import cv2
import numpy as np
import pandas as pd

from tqdm import tqdm, trange

''' 

    PyTorch Modules

'''

import torch
from torch.utils.data import Dataset


class Specimen_Dataset(Dataset):
    """Map-style dataset yielding ``{'img', 'category_id'}`` samples.

    Args:
        dataset: metadata DataFrame with a ``'file_name'`` column (image path)
            and, unless ``set_value == 'test_submission'``, a
            ``'category_id_le_preprocessed'`` column (label).
        set_value: split name; ``'test_submission'`` means no labels exist.
        transform: optional callable applied to each sample dict.
    """

    def __init__(self, dataset, set_value, transform):
        self.metadata = dataset
        self.set_value = set_value
        self.transform = transform

    def __len__(self):
        return self.metadata.shape[0]

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and (if labelled) its category id.

        Raises:
            FileNotFoundError: if ``cv2.imread`` cannot read the image file.
        """
        # Select the column by scalar name, then the row by POSITION:
        # - a scalar selector yields the value itself, whereas ['file_name']
        #   yielded a one-element Series that cv2.imread cannot accept;
        # - samplers pass positions 0..len-1, which only match .loc labels
        #   when the frame has a default RangeIndex.
        file_name = self.metadata['file_name'].iloc[idx]
        img = cv2.imread(str(file_name))
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"cv2.imread could not read image: {file_name!r}")

        category_id = None
        if self.set_value != 'test_submission':
            category_id = self.metadata['category_id_le_preprocessed'].iloc[idx]

        sample = {
            'img': img,
            'category_id': category_id,
        }

        if self.transform:
            return self.transform(sample)

        return sample
    
    
class Data_Pipeline:
    """Compose callables into a single transform.

    Each step receives the previous step's output. The first step receives
    the original object. With no steps, the object is returned unchanged.
    """

    def __init__(self, *args):
        # Steps run in the same order they were given to the constructor.
        self.content_pipeline = list(args)

    def __call__(self, obj):
        current = obj
        for stage in self.content_pipeline:
            current = stage(current)
        return current
    

class Resizer:
    """Sample transform that resizes ``sample['img']`` with ``cv2.resize``.

    The sample dict is modified in place and also returned.
    """

    def __init__(self, output_size):
        # Forwarded unchanged as cv2.resize's ``dsize``
        # (OpenCV orders dsize as (width, height)).
        self.output_size = output_size

    def __call__(self, sample):
        sample['img'] = cv2.resize(src=sample['img'], dsize=self.output_size)
        return sample
    
    
class Normalizer:
    """Sample transform that standardizes ``sample['img']`` as ``(img - mean) / std``.

    ``mean`` and ``std`` may be scalars or anything that broadcasts against
    the image. The sample dict is modified in place and also returned.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, sample):
        shifted = sample['img'] - self.mean
        sample['img'] = shifted / self.std
        return sample
    
class ToTensor:
    """Sample transform converting a numpy image (and label, if any) to torch tensors.

    The image is converted from channels-last (H, W, C) to channels-first
    (C, H, W) as a ``float32`` tensor. A 2-D (H, W) image is treated as a
    single channel. The label is wrapped as a tensor holding its value.
    A ``None`` label (unlabelled split) is left as ``None``.
    """

    def __init__(self, label_encoder):
        # Stored for callers; the conversion below does not use it.
        self.label_encoder = label_encoder

    def __call__(self, sample):
        img, category_id = sample['img'], sample['category_id']

        img = np.asarray(img)
        if img.ndim == 2:
            # Grayscale image: add a channel axis so the HWC -> CHW
            # transpose below also applies.
            img = img[:, :, np.newaxis]
        # HWC -> CHW; copy into a contiguous float32 buffer (same dtype the
        # original torch.Tensor(...) call produced).
        sample['img'] = torch.as_tensor(
            np.ascontiguousarray(img.transpose((2, 0, 1))), dtype=torch.float32
        )

        if category_id is not None:
            # torch.Tensor(n) with an integer n allocates an UNINITIALIZED
            # tensor of length n; as_tensor wraps the label value itself
            # (integer labels keep an integer dtype).
            sample['category_id'] = torch.as_tensor(np.asarray(category_id))

        return sample

    # Backward-compatible alias for the original misspelled method name.
    __cal__ = __call__
    

class Model_Helper:
    """Bundle a model with its optimizer, loss function and data loaders.

    Args:
        model: the model; ``eval()``/``train()`` are called on it.
        optimizer: optimizer for ``model``'s parameters (stored only).
        loss_func: loss callable (stored only).
        loaders: mapping from split name (e.g. ``'train'``) to an iterable of
            batches. Presumably torch DataLoaders — confirm against callers.

    The per-batch work in ``test_model``/``train_model`` is still a stub: the
    loops iterate the loaders but do not compute anything yet.
    """

    def __init__(self, model, optimizer, loss_func, loaders):
        self.model = model
        self.optimizer = optimizer
        # Previously accepted but silently discarded.
        self.loss_func = loss_func
        self.loaders = loaders

    def __init_fine_tunning(self):
        """Placeholder for fine-tuning setup; not implemented."""
        pass

    def get_results(self):
        """Placeholder for collecting results; not implemented."""
        pass

    def test_model(self, loader_case):
        """Iterate ``self.loaders[loader_case]`` with the model in eval mode and no autograd."""
        self.model.eval()
        with torch.no_grad():
            for batch_index, batch in enumerate(self.loaders[loader_case]):
                pass  # TODO: forward pass and metric accumulation

    def train_model(self, epochs=10):
        """Iterate ``self.loaders['train']`` for ``epochs`` epochs in train mode.

        Args:
            epochs: number of passes over the training loader (default 10,
                the previously hard-coded value).
        """
        # test_model() leaves the model in eval mode; switch back so layers
        # such as dropout/batch-norm behave as in training.
        self.model.train()
        for epoch in trange(epochs, desc='epochs'):
            for batch_index, batch in enumerate(self.loaders['train']):
                pass  # TODO: forward, loss_func, backward, optimizer step
    
    

        