In [4]:
%run unet_encoding.ipynb

import os
import cv2
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import functional as tf

class train_dataset(Dataset):
    """Fashionista training split as (image tensor, annotation array) pairs.

    Image files and annotation files are listed from two sibling directories
    and paired by position after sorting each listing by path, so every
    annotation must sort to the same index as its image (e.g. matching base
    names).

    Args:
        transform: optional callable applied to the image PIL image and,
            separately, to the annotation PIL image (mode 'I'). NOTE: it is
            called twice independently, so a random transform will not be
            applied identically to both, and an interpolating transform may
            blend label values in the annotation.
        root: dataset root that contains ``training/images`` and
            ``training/annotations``. Defaults to ``./Fashionista_Revised``.

    Raises:
        ValueError: if the image and annotation directories contain a
            different number of entries (pairing by index would be wrong).
    """

    # Constant zero padding: 4 rows top/bottom, 8 columns left/right, no
    # channel padding. Presumably chosen to make H/W divisible by the U-Net's
    # downsampling factor — TODO confirm against the model.
    _PAD = ((4, 4), (8, 8), (0, 0))

    def __init__(self, transform=None, root='./Fashionista_Revised'):
        self.image_path = os.path.join(root, 'training', 'images')
        # os.listdir order is arbitrary; sort so image i and annotation i match.
        self.image_file = sorted(
            os.path.join(self.image_path, name) for name in os.listdir(self.image_path)
        )

        self.annotation_path = os.path.join(root, 'training', 'annotations')
        self.annotation_file = sorted(
            os.path.join(self.annotation_path, name)
            for name in os.listdir(self.annotation_path)
        )

        if len(self.image_file) != len(self.annotation_file):
            raise ValueError(
                f'{len(self.image_file)} images in {self.image_path!r} but '
                f'{len(self.annotation_file)} annotations in {self.annotation_path!r}'
            )

        self.transform = transform

    @staticmethod
    def _imread(path):
        """Read *path* with OpenCV (3-channel BGR); raise if it cannot be decoded."""
        img = cv2.imread(path)
        if img is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f'could not read image: {path}')
        return img

    def __getitem__(self, index):
        """Return ``(image, annotation)`` for sample *index*.

        ``image`` is a float tensor normalized with mean/std 0.5 per channel,
        i.e. mapped from [0, 1] to [-1, 1]. ``annotation`` is ``np.array`` of
        the (possibly transformed) mode-'I' annotation image; without a
        transform it is an int32 (H + 8, W + 16) map.
        """
        image = self._imread(self.image_file[index])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = np.pad(image, self._PAD, 'constant')

        annotation = self._imread(self.annotation_file[index])
        # Padded border gets value 0 (presumably the background label — confirm).
        annotation = np.pad(annotation, self._PAD, 'constant')
        # Collapse to a single channel; assumes labels are stored identically
        # in all three channels — TODO confirm for this dataset.
        annotation = cv2.cvtColor(annotation, cv2.COLOR_BGR2GRAY).astype(np.int32)

        image = tf.to_pil_image(image, mode='RGB')
        annotation = tf.to_pil_image(annotation, mode='I')

        if self.transform is not None:
            image = self.transform(image)
            annotation = self.transform(annotation)

        image = tf.to_tensor(image)
        image = tf.normalize(image, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        annotation = np.array(annotation)

        return image, annotation

    def __len__(self):
        """Number of image/annotation pairs."""
        return len(self.image_file)

class test_dataset(Dataset):
    """Fashionista validation split as (image tensor, annotation array) pairs.

    Image files and annotation files are listed from two sibling directories
    and paired by position after sorting each listing by path, so every
    annotation must sort to the same index as its image (e.g. matching base
    names).

    Args:
        transform: optional callable applied to the image PIL image and,
            separately, to the annotation PIL image (mode 'I'). NOTE: it is
            called twice independently, so a random transform will not be
            applied identically to both, and an interpolating transform may
            blend label values in the annotation.
        root: dataset root that contains ``validation/images`` and
            ``validation/annotations``. Defaults to ``./Fashionista_Revised``.

    Raises:
        ValueError: if the image and annotation directories contain a
            different number of entries (pairing by index would be wrong).
    """

    # Constant zero padding: 4 rows top/bottom, 8 columns left/right, no
    # channel padding. Presumably chosen to make H/W divisible by the U-Net's
    # downsampling factor — TODO confirm against the model.
    _PAD = ((4, 4), (8, 8), (0, 0))

    def __init__(self, transform=None, root='./Fashionista_Revised'):
        self.image_path = os.path.join(root, 'validation', 'images')
        # os.listdir order is arbitrary; sort so image i and annotation i match.
        self.image_file = sorted(
            os.path.join(self.image_path, name) for name in os.listdir(self.image_path)
        )

        self.annotation_path = os.path.join(root, 'validation', 'annotations')
        self.annotation_file = sorted(
            os.path.join(self.annotation_path, name)
            for name in os.listdir(self.annotation_path)
        )

        if len(self.image_file) != len(self.annotation_file):
            raise ValueError(
                f'{len(self.image_file)} images in {self.image_path!r} but '
                f'{len(self.annotation_file)} annotations in {self.annotation_path!r}'
            )

        self.transform = transform

    @staticmethod
    def _imread(path):
        """Read *path* with OpenCV (3-channel BGR); raise if it cannot be decoded."""
        img = cv2.imread(path)
        if img is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f'could not read image: {path}')
        return img

    def __getitem__(self, index):
        """Return ``(image, annotation)`` for sample *index*.

        ``image`` is a float tensor normalized with mean/std 0.5 per channel,
        i.e. mapped from [0, 1] to [-1, 1]. ``annotation`` is ``np.array`` of
        the (possibly transformed) mode-'I' annotation image; without a
        transform it is an int32 (H + 8, W + 16) map.
        """
        image = self._imread(self.image_file[index])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = np.pad(image, self._PAD, 'constant')

        annotation = self._imread(self.annotation_file[index])
        # Padded border gets value 0 (presumably the background label — confirm).
        annotation = np.pad(annotation, self._PAD, 'constant')
        # Collapse to a single channel; assumes labels are stored identically
        # in all three channels — TODO confirm for this dataset.
        annotation = cv2.cvtColor(annotation, cv2.COLOR_BGR2GRAY).astype(np.int32)

        image = tf.to_pil_image(image, mode='RGB')
        annotation = tf.to_pil_image(annotation, mode='I')

        if self.transform is not None:
            image = self.transform(image)
            annotation = self.transform(annotation)

        image = tf.to_tensor(image)
        image = tf.normalize(image, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        annotation = np.array(annotation)

        return image, annotation

    def __len__(self):
        """Number of image/annotation pairs."""
        return len(self.image_file)