In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
from pathlib import Path
from glob import glob
from os import walk

import shutil
import random


import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, models
import torchvision.transforms as transforms

import imgaug as ia
from imgaug import augmenters as iaa

In [3]:
# Class labels (Roman numerals); each one is a sub-folder name under train/ and val/.
classes = ['i', 'ii', 'iii', 'iv', 'ix', 'v', 'vi', 'vii', 'viii', 'x']
CLEAN_DATA_FOLDER = 'manual_clean'
TRAIN_DATA = f'{CLEAN_DATA_FOLDER}/train'
VALID_DATA = f'{CLEAN_DATA_FOLDER}/val'


# Merge every validation image into the matching train class folder so the
# data can be re-split later. The cell is safe to re-run: once the val folder
# has been removed there is simply nothing left to merge.
for label in classes:
    val_label_folder = os.path.join(VALID_DATA, label)
    train_label_folder = os.path.join(TRAIN_DATA, label)

    # Already merged on a previous run (or this class never had a val folder).
    if not os.path.isdir(val_label_folder):
        continue

    # shutil.copy treats a non-existent destination as a *file* name, which
    # would collapse all images of the class into one file - make sure the
    # destination directory exists first.
    os.makedirs(train_label_folder, exist_ok=True)

    # NOTE: a train image with the same file name is silently overwritten.
    for image in os.listdir(val_label_folder):
        shutil.copy(os.path.join(val_label_folder, image), train_label_folder)

# Delete val folder
shutil.rmtree(VALID_DATA, ignore_errors=True)

In [4]:
# Re-split the merged data into data/train and data/val, per class.
SPLIT_SEED = 42        # fixed seed so Restart & Run All reproduces the same split
TRAIN_FRACTION = 0.8   # share of each class that goes to the training set
split_rng = random.Random(SPLIT_SEED)  # local RNG: leaves the global `random` state untouched

for label in classes:
    # copy from
    train_copy_folder = os.path.join(CLEAN_DATA_FOLDER, 'train', label)

    # Sorted so the seeded sample does not depend on filesystem listing order.
    images_list = sorted(glob(os.path.join(train_copy_folder, "*.png")))

    train_label_folder = os.path.join(CLEAN_DATA_FOLDER, 'data', 'train', label)
    val_label_folder = os.path.join(CLEAN_DATA_FOLDER, 'data', 'val', label)

    os.makedirs(train_label_folder, exist_ok=True)
    os.makedirs(val_label_folder, exist_ok=True)

    train_images_list = split_rng.sample(images_list, int(len(images_list) * TRAIN_FRACTION))
    # Set membership keeps the complement computation linear in the class size.
    train_images_set = set(train_images_list)
    val_images_list = [path for path in images_list if path not in train_images_set]

    for image_path in train_images_list:
        shutil.copy(image_path, train_label_folder)
    for image_path in val_images_list:
        shutil.copy(image_path, val_label_folder)

# Delete old train folder
# NOTE(review): only *.png files were copied above; anything else still in the
# old train folder is removed here without having been copied.
shutil.rmtree(os.path.join(CLEAN_DATA_FOLDER, 'train'), ignore_errors=True)

In [5]:
def _build_augmenter(*leading_flips):
    """Build one augmentation pipeline: optional flips followed by the shared steps.

    Args:
        leading_flips: zero or more flip augmenters (e.g. ``iaa.Fliplr(0.5)``)
            placed, in the given order, at the start of the sequence.

    Returns:
        An ``iaa.Sequential`` (fixed order) made of the given flips, then
        a 50%-chance Gaussian blur, a 50%-chance affine scale/rotate and an
        alpha-blended all-channels histogram equalization.
    """
    # `iaa.Alpha` is deprecated in favour of `iaa.BlendAlpha` (same positional
    # arguments: factor, then the augmenter to blend in). Fall back to the old
    # name on imgaug releases that do not provide BlendAlpha yet.
    blend_alpha = getattr(iaa, "BlendAlpha", None) or iaa.Alpha

    # Fresh augmenter instances are created on every call so the resulting
    # pipelines do not share any child augmenter objects.
    return iaa.Sequential(
        [
            *leading_flips,
            iaa.Sometimes(0.5, iaa.GaussianBlur(sigma=(0, 0.5))),
            iaa.Sometimes(
                0.5,
                iaa.Affine(
                    scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
                    rotate=(-30, 30),
                    cval=255,  # constant used to fill pixels exposed by the transform
                ),
            ),
            blend_alpha((0.0, 1.0), iaa.AllChannelsHistogramEqualization()),
        ],
        random_order=False,
    )


# The four pipelines differ only in which flips are allowed.
seq1 = _build_augmenter()                                   # no flips
seq2 = _build_augmenter(iaa.Fliplr(0.5))                    # horizontal flip
seq3 = _build_augmenter(iaa.Flipud(0.5))                    # vertical flip
seq4 = _build_augmenter(iaa.Flipud(0.5), iaa.Fliplr(0.5))   # vertical, then horizontal flip




  warn_deprecated(msg, stacklevel=3)


In [6]:
def _as_tensor_pipeline(augmenter):
    """Wrap an imgaug augmenter into a torchvision transform that yields a tensor.

    Steps, in order: convert the input to a NumPy array, apply the augmenter
    to that single image, copy the resulting array, convert it with ToTensor.
    """
    steps = [
        np.asarray,
        augmenter.augment_image,
        # NOTE(review): the copy presumably guarantees ToTensor receives a
        # freshly allocated array after flips - confirm.
        np.copy,
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)


# One torchvision transform per augmentation sequence.
tf1 = _as_tensor_pipeline(seq1)
tf2 = _as_tensor_pipeline(seq2)
tf3 = _as_tensor_pipeline(seq3)
tf4 = _as_tensor_pipeline(seq4)

In [8]:
# Locations of the re-split dataset.
train_path = "manual_clean/data/train"
val_path = "manual_clean/data/val"

# Which transform pipeline is applied to each class label.
classes_transforms = {
    'i': tf4,
    'ii': tf4,
    'iii': tf4,
    'iv': tf1,
    'ix': tf3,
    'v': tf2,
    'vi': tf1,
    'vii': tf1,
    'viii': tf1,
    'x': tf4,
}

In [9]:
# custom dataset that allows us to control which augmentation will be done on each class
class MyDataset(Dataset):
    """Dataset wrapper that applies a class-specific transform to every image.

    Args:
        data: indexable dataset whose items are ``(image, class_index)`` pairs
            and which exposes a ``classes`` sequence mapping index -> label.
        target: unused; kept so existing call sites keep working.
        transforms: mapping from class label to a callable applied to the
            image, or ``None`` to return images untouched.
    """

    def __init__(self, data, target=None, transforms=classes_transforms):
        self.data = data
        self.transforms = transforms
        self.classes = data.classes

    def __getitem__(self, index):
        """Return ``(image, class_index)`` with the class's transform applied.

        Raises:
            KeyError: if ``transforms`` has no entry for the item's label.
        """
        img, y = self.data[index]
        if self.transforms is not None:
            # Pick the transform registered for this item's label.
            label = self.classes[y]
            img = self.transforms[label](img)
        return img, y

    def __len__(self):
        """Number of items in the wrapped dataset."""
        return len(self.data)

In [10]:
# Training images: folder dataset wrapped with the per-class augmentations,
# served one image at a time in random order.
train_image_folder = datasets.ImageFolder(train_path)
my_augmented_train = MyDataset(train_image_folder, transforms=classes_transforms)
my_augmented_train_dataloader = DataLoader(
    my_augmented_train,
    batch_size=1,
    shuffle=True,
)

# Validation images: same wrapping and loader settings.
val_image_folder = datasets.ImageFolder(val_path)
my_augmented_val = MyDataset(val_image_folder, transforms=classes_transforms)
my_augmented_val_dataloader = DataLoader(
    my_augmented_val,
    batch_size=1,
    shuffle=True,
)

In [11]:
classes_names = my_augmented_train.classes

# Desired number of images per class after augmentation.
TRAIN_TARGET_PER_CLASS = 800
VAL_TARGET_PER_CLASS = 200

# File names currently on disk, per class label. Keys come from the dataset's
# own class list so they can never drift from the folders actually present.
train_labels_dict = {name: [] for name in classes_names}
val_labels_dict = {name: [] for name in classes_names}

folders = [train_path, val_path]
for mypath, labels_dict in zip(folders, (train_labels_dict, val_labels_dict)):
    for dirpath, _dirnames, filenames in walk(mypath):
        # basename handles whichever path separator the OS uses.
        label = os.path.basename(dirpath)
        if label in labels_dict:
            labels_dict[label].extend(filenames)

# Number of images still to generate per class. Clamped at zero so a class
# that already exceeds its target does not produce a negative deficit.
train_counter_dict = {
    key: max(0, TRAIN_TARGET_PER_CLASS - len(value))
    for key, value in train_labels_dict.items()
}
val_counter_dict = {
    key: max(0, VAL_TARGET_PER_CLASS - len(value))
    for key, value in val_labels_dict.items()
}

In [12]:
def _generate_until_full(dataloader, counter_dict, out_root):
    """Save augmented images until every class deficit in ``counter_dict`` is met.

    Args:
        dataloader: yields ``(X, y)`` with a single image and its class index
            per batch (the loaders in this notebook use ``batch_size=1``).
        counter_dict: label -> number of images still to generate; decremented
            in place as images are written.
        out_root: directory under which ``<label>/<remaining>--<batch_idx>.png``
            files are written.

    Raises:
        RuntimeError: if a full pass over the loader writes nothing while some
            class still needs images (that class has no source images, so
            further passes could never finish).
    """
    def _pending():
        # "> 0" rather than a plain sum: a negative entry must not keep the
        # loop alive, nor cancel out another class's positive deficit.
        return any(remaining > 0 for remaining in counter_dict.values())

    while _pending():
        saved_this_pass = 0
        for batch_idx, (X, y) in enumerate(dataloader):
            label = classes_names[y.item()]  # y holds exactly one class index
            if counter_dict[label] <= 0:
                continue
            labeled_dir = os.path.join(out_root, label)
            os.makedirs(labeled_dir, exist_ok=True)
            filepath = os.path.join(labeled_dir, f"{counter_dict[label]}--{batch_idx}.png")
            torchvision.utils.save_image(X, filepath)
            counter_dict[label] -= 1
            saved_this_pass += 1
            if not _pending():
                # All targets reached; skip augmenting the rest of the pass.
                return
        if saved_this_pass == 0:
            starving = sorted(k for k, v in counter_dict.items() if v > 0)
            raise RuntimeError(
                f"No source images available for classes still needing augmentation: {starving}"
            )


# Generate images for train data
_generate_until_full(my_augmented_train_dataloader, train_counter_dict, train_path)

# Generate images for val data
_generate_until_full(my_augmented_val_dataloader, val_counter_dict, val_path)