# Modules and Library

In [1]:
from __future__ import print_function, division
import os
import glob
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image
from torchvision.transforms.transforms import Normalize, ToTensor
#imoprt labels.py
from labels import Labels

# 2. Classes

## 2.1 Data Loader

In [2]:
class CustomDataset(Dataset):
    def __init__(self, root="dataset", transform=None):
        self.imagespath  = glob.glob('dataset/images/*.png') #get the image's directory, I'm familiar with glob
        self.labelspath = glob.glob('dataset/labels/*.png')  #get the label's directory, I'm familiar with glob
        self.root = root
        
        # try to print the path
        print("Img Path: ", self.imagespath)
        print("Label Path: ", self.labelspath)
        self.transform = transform
    
    def __len__(self) -> int:        
        return len(self.images)
    
    def __getitem__(self,index):
        if torch.is_tensor(index):
            index = index.tolist()
        
        #load image Dataset
        img_data = self.imagespath[index]
        image = Image.open(img_data)
        
        #load label Dataset
        label_data = self.labelspath[index]
        label = Image.open(label_data)
        
        #dictionary between image and label
        sample_data = {"image":image, "label":label}
        
        if self.transform:
            sample_data = self.transform(sample_data)
        return sample_data
            
    
    def shows(self, image, label):
        # Complete the implementation.
        pass

## 2.2 Horizontal Flip

In [3]:
class HorizontalFlip():
    def __init__(self)-> None:
        pass
    
    def __call__(self, sample) -> dict:
        image = sample["image"]
        label = sample["label"]
        transformedImage = transforms.RandomHorizontalFlip(p=1)(image) #do a random horizontal flip (probability 1, means all image will be flipped)
        transformedLabel = transforms.RandomHorizontalFlip(p=1)(label) #do a random horizontal flip (probability 1, means all label will be flipped)
        
        return {"image":transformedImage, "label":transformedLabel}
        

## 2.3 Random Crop

In [4]:
class RandomCrop():
    def __init__(self) -> None:
            self.output_size = (512, 512)
    
    def __call__(self,sample) -> dict:
        image = sample["image"]
        label = sample["label"]

        w, h = image.size               #get image size
        new_h, new_w = self.output_size # determine new image size
        
        left = np.random.randint(0, w - new_w)  #create horizontal random location for cropping
        top = np.random.randint(0, h - new_h)   #create vertical random location for cropping
        print("image height :",h,"image weight :",w)
        print("height crop location :",top,"weight crop location :",left)
        
        imagecropped = transforms.functional.crop(image, top,left,new_h,new_w) #crop image using the defined parameter before
        labelcropped = transforms.functional.crop(label, top,left,new_h,new_w) #crop label using the same defined parameter before

        return {"image": imagecropped, "label": labelcropped} #return the results

## 2.4. Normalize

In [5]:
class imageNormalized():
    def __init__(self) -> None:
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)
        
    def __call__(self,sample) -> dict:
        image = sample["image"]
        label = sample["label"]
        trans = transforms.Compose([
            transforms.ToTensor(),
            Normalize(self.mean, self.std)
        ])
        imageNormalized = trans(image)
        labelTensored = transforms.ToTensor()(label)
        
        return {"image":imageNormalized, "label":labelTensored}

## Denormalize

In [6]:
class imageDenormalize():    
    def __init__(self) -> None:
        self.mean = (-0.485/0.229, -0.456/0.224, -0.406/0.255)
        self.std = (1/0.229, 1/0.224, 1/0.255)

    def __call__(self, sample) -> dict:
        image = sample["image"]
        label = sample["label"]
        trans = Normalize(self.mean, self.std)
        imagedenorm = trans(image)
        return {"image":imagedenorm, "label":label}

# 3. Calling The Function Class

In [7]:
alltransformation = transforms.Compose([imageNormalized(),imageDenormalize])
load = CustomDataset(transform = alltransformation)

# load = CustomDataset(transform = imageDenormalize())

trial = load[1]

# plt.imshow(transforms.ToPILImage()(trial["image"]))
# plt.figure()
# plt.imshow(transforms.ToPILImage()(trial["label"]))


Img Path:  ['dataset/images/78_iff_12.png', 'dataset/images/udu112_st55a.png', 'dataset/images/xz_77i.png', 'dataset/images/uff_987_stw.png', 'dataset/images/a_4564.png']
Label Path:  ['dataset/labels/78_iff_12.png', 'dataset/labels/udu112_st55a.png', 'dataset/labels/xz_77i.png', 'dataset/labels/uff_987_stw.png', 'dataset/labels/a_4564.png']


TypeError: __init__() takes 1 positional argument but 2 were given