<h1>Data Augmentation for Gastrointestinal Tract Images</h1>

In [1]:
import sys

# Base directory of the project; it is also used as the prefix of the dataset
# path built later in this notebook.
sys_dir = './'

# Put the base directory on the import path (presumably so modules stored next
# to the notebook can be imported -- none are imported in the visible cells).
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
if sys_dir not in sys.path:
    sys.path.append(sys_dir)

In [2]:
import os
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.utils import save_image
import random
import gc

# Side length, in pixels, passed as (img_size, img_size) to every
# transforms.Resize step in the augmentation pipelines defined below.
img_size = 224

def _augmented_image_path(class_path, class_name, file_index):
    """Return the output path of the ``file_index``-th image written for a class."""
    return os.path.join(class_path, f'GIT_{class_name}_({file_index}).jpg')


def augment_and_add_to_dataset(original_dataset, augmentation_transforms_list, output_dir,
                               num_augmented_images_per_original=1,
                               samples_per_augmentation=1084, seed=254):
    """Write an augmented copy of ``original_dataset`` to ``output_dir``.

    One sub-folder per class is created under ``output_dir``. Files are named
    ``GIT_<class>_(<n>).jpg`` with ``n`` counting up from 1 within each class.

    * The first transform (index 0) is applied to every image of the class and
      every result is saved.
    * Each later transform is saved only for a random subset of the class:
      a seeded permutation of the candidates is drawn and its first
      ``samples_per_augmentation`` entries are transformed and saved.

    The permutation is re-seeded with the same ``seed`` for every transform,
    so classes/transforms with the same number of candidates select the same
    positions (this matches the previous behaviour of the function).

    Only the selected images are loaded and transformed, so no augmented
    images are held in memory. For transforms that are themselves stochastic
    this means fewer draws are made from their random source than before.

    Args:
        original_dataset: Object exposing ``classes`` (list of class names),
            ``imgs`` (iterable of ``(path, label_index)``) and ``loader``
            (callable turning a path into an image), e.g. an ``ImageFolder``.
        augmentation_transforms_list: Sequence of callables applied to a
            loaded image; each must return something ``save_image`` accepts.
        output_dir: Root folder for the generated files (created if missing).
        num_augmented_images_per_original: How many times each transform is
            applied per source image. Values below 1 produce no output.
        samples_per_augmentation: Maximum number of images saved per class for
            each transform after the first. Clamped to the number of available
            candidates; ``None`` keeps all of them. Defaults to 1084, the value
            that used to be hard-coded.
        seed: Seed of the private random generator used for the selection.

    Raises:
        ValueError: If ``samples_per_augmentation`` is negative.
    """
    if samples_per_augmentation is not None and samples_per_augmentation < 0:
        raise ValueError("samples_per_augmentation must be non-negative or None")

    os.makedirs(output_dir, exist_ok=True)

    copies_per_image = max(num_augmented_images_per_original, 0)

    for class_name in original_dataset.classes:
        class_path = os.path.join(output_dir, class_name)
        os.makedirs(class_path, exist_ok=True)

        count_file = 1

        class_images = [img_path for img_path, label in original_dataset.imgs
                        if original_dataset.classes[label] == class_name]

        # A "slot" is one (image, copy) pair; slot s belongs to image
        # s // copies_per_image, the order in which candidates used to be
        # appended to the in-memory list.
        num_slots = len(class_images) * copies_per_image

        for aug_list_idx, transform in enumerate(augmentation_transforms_list):
            if aug_list_idx == 0:
                # First transform: apply to every image and save every result.
                for image_path in class_images:
                    original_image = original_dataset.loader(image_path)
                    for _ in range(copies_per_image):
                        save_image(transform(original_image),
                                   _augmented_image_path(class_path, class_name, count_file))
                        count_file += 1
                continue

            # A private generator gives the same sequence as
            # random.seed(seed) + random.sample(...) without resetting the
            # process-wide RNG. The full permutation is drawn and then sliced
            # because random.sample(pop, k) with a smaller k may follow a
            # different internal algorithm and pick different items.
            rng = random.Random(seed)
            rand_indices = rng.sample(range(num_slots), num_slots)
            print(f"Random Index ({aug_list_idx}) : {rand_indices[:5]}")

            # Slicing clamps automatically, so a small class no longer raises
            # IndexError when it has fewer candidates than requested.
            for slot in rand_indices[:samples_per_augmentation]:
                original_image = original_dataset.loader(class_images[slot // copies_per_image])
                save_image(transform(original_image),
                           _augmented_image_path(class_path, class_name, count_file))
                count_file += 1
                    

# Location of the source dataset folder
data_dir = sys_dir + "FinalDataset_Gastrointestinal_Tract_V1"

# Image augmentation pipelines. Every pipeline is Resize -> [extra steps] -> ToTensor;
# the table below lists only the extra steps of each pipeline, in output order.
# RandomRotation is given identical min/max degrees and the flips use p=1.0.
geometric_steps = [
    [],                                                # Original Image with Resize
    [transforms.RandomRotation(degrees=(90, 90))],     # 90 degrees
    [transforms.RandomRotation(degrees=(180, 180))],   # 180 degrees
    [transforms.RandomRotation(degrees=(270, 270))],   # 270 degrees
    [transforms.RandomVerticalFlip(p=1.0)],            # FlipV
    [transforms.RandomHorizontalFlip(p=1.0)],          # FlipH
]
augmentation_transforms_list = [
    transforms.Compose([
        transforms.Resize((img_size, img_size)),
        *steps,
        transforms.ToTensor(),
    ])
    for steps in geometric_steps
]

# Load the dataset with ImageFolder (no transform here; the pipelines above
# are applied inside augment_and_add_to_dataset)
original_dataset = ImageFolder(root=data_dir, transform=None)

# Run the image augmentation and write the enlarged dataset to disk
output_dir = './Dataset/AugmentationDataset_GIT_V1_(21680)'

if not os.path.exists(output_dir):
    os.makedirs(output_dir)

augment_and_add_to_dataset(original_dataset, augmentation_transforms_list, output_dir)


Random Index (1) : [3070, 4104, 2055, 3359, 2784]
Random Index (2) : [3070, 4104, 2055, 3359, 2784]
Random Index (3) : [3070, 4104, 2055, 3359, 2784]
Random Index (4) : [3070, 4104, 2055, 3359, 2784]
Random Index (5) : [3070, 4104, 2055, 3359, 2784]
Random Index (1) : [3070, 4104, 2055, 3359, 2784]
Random Index (2) : [3070, 4104, 2055, 3359, 2784]
Random Index (3) : [3070, 4104, 2055, 3359, 2784]
Random Index (4) : [3070, 4104, 2055, 3359, 2784]
Random Index (5) : [3070, 4104, 2055, 3359, 2784]
