In [9]:
from typing import Tuple
import torch.utils.data as data
import torchvision.transforms as T
from pathlib import Path
import numpy as np
from PIL import Image
import torch

from max import locate_data, load_data
from bence import compare_images, augment_images_plain,augment_five_crop
from bence import augment_perspective

In [13]:

class SemanticSegmentationDataset(data.Dataset):
    def __init__(self, data_path: str):
        self.data_path = Path(data_path)
        (
            self.class_labels,
            self.X_train_paths,
            self.y_train_paths,
            self.X_test_paths,
            self.y_test_paths,
            self.X_val_paths,
            self.y_val_paths,
        ) = locate_data(self.data_path)


        # load original training set
        self.X_train_orig  = self.load_data(self.X_train_paths)
        self.y_train_orig = self.load_data(self.X_train_paths)

        print(type(self.X_train_orig))
        print(self.X_train_orig.shape)
        # flip all images and add to training set
        X_train_flipped, y_train_flipped = self.flip_images(self.X_train_orig, self.y_train_orig)
        self.X_train = np.concatenate((self.X_train_orig, X_train_flipped))
        self.y_train = np.concatenate((self.y_train_orig, y_train_flipped))

        # five crop images with no_augmentations = 10 and add to training set
        X_train_crop, y_train_crop = self.five_crop_images(self.X_train_orig, self.y_train_orig, no_augmentations=10)
        self.X_train = np.concatenate((self.X_train, X_train_crop))
        self.y_train = np.concatenate((self.y_train, y_train_crop))

        # augment_perspective no_aumentaions = 20 and add to training set
        X_train_persp, y_train_persp = self.augment_perspective(self.X_train_orig, self.y_train_orig, no_augmentations=20)
        self.X_train = np.concatenate((self.X_train, X_train_persp))
        self.y_train = np.concatenate((self.y_train, y_train_persp))

        # T.ColorJitter(brightness=(1,2),hue=(-0.2,0.5)) with augment_images_plain 5 times on 10 images and add to training set
        X_train_jittered, y_train_jittered = self.augment_images_plain(self.X_train_orig, self.y_train_orig, 
                                                                       T.ColorJitter(brightness=(1,2),hue=(-0.2,0.5)),
                                                                       no_augmentations=5)
        self.X_train = np.concatenate((self.X_train, X_train_jittered))
        self.y_train = np.concatenate((self.y_train, y_train_jittered))

        X_train_jittered, y_train_jittered = self.augment_images_plain(self.X_train_orig, self.y_train_orig, 
                                                                       T.ColorJitter(brightness=(1,2),hue=(-0.2,0.5)),
                                                                       no_augmentations=5)
        self.X_train = np.concatenate((self.X_train, X_train_jittered))
        self.y_train = np.concatenate((self.y_train, y_train_jittered))
        X_train_jittered, y_train_jittered = self.augment_images_plain(self.X_train_orig, self.y_train_orig, 
                                                                       T.ColorJitter(brightness=(1,2),hue=(-0.2,0.5)),
                                                                       no_augmentations=5)
        self.X_train = np.concatenate((self.X_train, X_train_jittered))
        self.y_train = np.concatenate((self.y_train, y_train_jittered))


    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
        X = Image.open(self.X_train[index])
        y = Image.open(self.y_train[index])
        X = self.transforms(X)
        y = np.array(y, dtype=np.int64)
        return X, y

    def __len__(self) -> int:
        return len(self.X_train)

    def load_data(self,file_paths: list[Path]) -> np.ndarray:
        """Load images into memory with uint8 dtype (format PyTorch likes"""
        array= np.array([np.array(Image.open(x), dtype=np.uint8) for x in file_paths])
        tensor = torch.from_numpy(array).permute(0,3,1,2)
        return tensor
    
    def flip_images(self,X: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Flip images horizontally"""
        X_flipped = torch.flip(X, dims=[3])
        y_flipped = torch.flip(y, dims=[3])
        return X_flipped, y_flipped
    
    def five_crop_images(self,X: torch.Tensor, y: torch.Tensor, no_augmentations = 10,crop_size = (360, 480)) -> Tuple[torch.Tensor, torch.Tensor]:
        """Five crop images"""
        five_cropper = T.FiveCrop(size=crop_size)
        resizer = T.Resize((720, 960))
        samples = np.random.randint(X.shape[0], size = no_augmentations)
        X_cropped = five_cropper(X[samples])
        y_cropped = five_cropper(y[samples])
        X_cropped = torch.cat([resizer(x) for x in X_cropped])
        y_cropped = torch.cat([resizer(x) for x in y_cropped])
        return X_cropped, y_cropped
    
    def augment_perspective(self,X: torch.Tensor, y: torch.Tensor, distortions_scale= 0.6,p = 1.0, no_augmentations = 20) -> Tuple[torch.Tensor, torch.Tensor]:
        """Augment perspective"""
        samples = np.random.randint(X.shape[0], size = no_augmentations)
        perspective = T.RandomPerspective(distortion_scale=distortions_scale, p=1.0)
        joint_tensor = torch.cat([X[samples], y[samples]])
        joint_tensor = perspective(joint_tensor)
        X_persp = joint_tensor[:no_augmentations]
        y_persp = joint_tensor[no_augmentations:]
        return X_persp, y_persp
    
    def augment_images_plain(self,X: torch.Tensor, y: torch.Tensor, transform: T, no_augmentations = 5) -> Tuple[torch.Tensor, torch.Tensor]:
        """Augment images with plain augmentation"""
        samples = np.random.randint(X.shape[0], size = no_augmentations)
        X_augmented = transform(X[samples])
        y_augmented = y[samples]

        return X_augmented, y_augmented
    


    



In [14]:
DATA_DIR = Path("../data/CamVid/")
dataset = SemanticSegmentationDataset(DATA_DIR)

<class 'torch.Tensor'>
torch.Size([369, 3, 720, 960])


NameError: name 'X_tensor' is not defined