In [1]:
import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import time
import os
from tqdm import tqdm

In [2]:
# Mount Google Drive (Colab only) so the dataset stored under MyDrive is reachable.
import google.colab.drive as drive

drive.mount('/content/drive')

Mounted at /content/drive


## Original Dataset Load

In [3]:
# Root folder of the dataset on the mounted Google Drive
data_dir = os.path.join('/content/drive', 'MyDrive', 'Parrot DL Project')

In [4]:
# Preprocessing applied when images are loaded. The 'train' split applies no
# random augmentation (images are used as-is); augmented copies are generated
# offline into the 'augmented' folder instead. Both splits therefore share the
# same pipeline: resize (shorter side 256), center-crop 224x224, convert to a
# tensor and normalize with the standard ImageNet channel mean/std.
data_transforms = {
    split: transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225)),
    ])
    for split in ('train', 'val')  # a separate Compose object per split
}

In [5]:
# Load the original training images; ImageFolder derives labels from sub-folder names.
full_dataset = datasets.ImageFolder(
    root=os.path.join(data_dir, 'train'),
    transform=data_transforms['train'],
)
class_names = full_dataset.classes  # class folder names; list index == label index

In [6]:
print(f"Total Data: {len(full_dataset)}")

Total Data: 42388


## Augmented Dataset Load

In [14]:
# Load the augmented dataset — only valid once the 'augmented' folder has already been generated.

# augmented_dataset = datasets.ImageFolder(os.path.join(data_dir, 'augmented'),
#                                           data_transforms['train'])
# class_names_aug = augmented_dataset.classes

In [16]:
# print("Total Data:",len(augmented_dataset))

Total Data: 54


In [20]:
import numpy as np
import os
import cv2
from PIL import Image
import albumentations as A
import matplotlib.pyplot as plt

def delete_augmented(data_dir, num_class, class_names, prefixes=('augmented_',)):
    """Delete previously generated augmented images.

    For each of the first ``num_class`` entries of ``class_names``, looks in
    ``<data_dir>/augmented/<class_name>`` and removes every regular file whose
    name starts with one of ``prefixes``. Other files are left untouched.

    Class folders that do not exist (augmentation never ran for that class, or a
    run was interrupted) are skipped instead of raising ``FileNotFoundError``.

    Args:
        data_dir: dataset root containing the ``augmented`` folder.
        num_class: number of leading ``class_names`` entries to process.
        class_names: class folder names, e.g. ``ImageFolder.classes``.
        prefixes: a prefix string or an iterable of prefixes that identify
            generated image files.
    """
    # str.startswith accepts a tuple; wrap a lone string so it is not split into chars.
    if isinstance(prefixes, str):
        prefixes = (prefixes,)
    prefixes = tuple(prefixes)

    deleted = 0
    for i in range(num_class):
        class_name = class_names[i]
        # To clean images written directly into the train folders instead, use
        # os.path.join(data_dir, 'train', class_name).
        class_dir = os.path.join(data_dir, 'augmented', class_name)
        print("current_directory:", class_name)  # progress indicator
        if not os.path.isdir(class_dir):
            continue
        for file_name in os.listdir(class_dir):
            path = os.path.join(class_dir, file_name)
            # Only remove files; a matching sub-directory would make os.remove fail.
            if file_name.startswith(prefixes) and os.path.isfile(path):
                os.remove(path)
                deleted += 1
    print("Deleted Data:", deleted)

# Offline augmentation pipeline applied to every original training image.
# Resizing/cropping is intentionally omitted here: it is done later by
# data_transforms when the dataset is loaded.

# Stage 1 (always applied): random shift, scale and rotation (up to ±90 degrees);
# pixels exposed at the border are filled by replicating the edge. No flip is applied.
# NOTE(review): albumentations biases scale_limit by 1, so (0.3, 0.7) should give a
# zoom factor in [1.3, 1.7] (always zooming in and cropping edges) — confirm this is
# intended rather than shrinking to 30-70%.
_geometric_aug = A.ShiftScaleRotate(
    shift_limit=0.2,
    scale_limit=(0.3, 0.7),
    rotate_limit=90,
    border_mode=cv2.BORDER_REPLICATE,
    p=1,
)

# Stage 2 (always applied): exactly one degradation, picked at random
# (all children have p=1, so they are equally likely).
_degradation_aug = A.OneOf(
    [
        # blur with a kernel size between 3 and 7
        A.Blur(blur_limit=(3, 7), p=1),
        # cut out 8 holes of 8x8 px (black with the default fill value)
        A.CoarseDropout(max_holes=8, max_height=8, max_width=8,
                        min_holes=8, min_height=8, min_width=8, p=1),
        # lower the resolution to 25% and back; interpolation=0 is cv2.INTER_NEAREST
        A.Downscale(scale_min=0.25, scale_max=0.25, interpolation=0, p=1),
        # additive Gaussian noise
        A.GaussNoise(mean=0, var_limit=(10.0, 50.0), p=1),
    ],
    p=1,
)

albumentations_transform = A.Compose([_geometric_aug, _degradation_aug])

def augmentation(data_dir, num_class, class_names, new_per_origin, transform=None):
    """Generate augmented copies of the training images.

    For each of the first ``num_class`` entries of ``class_names``, every image in
    ``<data_dir>/train/<class_name>`` is loaded, passed ``new_per_origin`` times
    through ``transform``, and each result is saved as
    ``<data_dir>/augmented/<class_name>/augmented_<k>.jpg``, where ``k`` counts up
    from 0 within each class folder. Existing files with the same name are
    overwritten. Prints per-class progress and the total number of images written.

    Args:
        data_dir: dataset root containing the ``train`` folder.
        num_class: number of leading ``class_names`` entries to process.
        class_names: class folder names, e.g. ``ImageFolder.classes``.
        new_per_origin: number of augmented images to create per original image.
        transform: albumentations-style callable used as
            ``transform(image=array)['image']``; defaults to the module-level
            ``albumentations_transform``.
    """
    if transform is None:
        transform = albumentations_transform
    # Same extensions torchvision's ImageFolder accepts; other entries are ignored.
    image_exts = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

    total = 0
    for i in range(num_class):
        class_name = class_names[i]
        src_dir = os.path.join(data_dir, 'train', class_name)
        dst_dir = os.path.join(data_dir, 'augmented', class_name)
        os.makedirs(dst_dir, exist_ok=True)

        print("current_directory:", class_name)  # progress indicator

        # Sorted for a deterministic order; non-image entries would crash the loader.
        file_names = sorted(f for f in os.listdir(src_dir) if f.lower().endswith(image_exts))

        augment_cnt = 0
        for file_name in file_names:
            # Decode with PIL and force 3-channel uint8. plt.imread returns floats in
            # [0, 1] for PNG files (np.uint8 would turn them black) and keeps
            # RGBA/CMYK channels, which cannot be saved as JPEG.
            with Image.open(os.path.join(src_dir, file_name)) as img:
                origin_image = np.array(img.convert('RGB'))
            for _ in range(new_per_origin):
                transformed_image = transform(image=origin_image)['image']
                out_path = os.path.join(dst_dir, 'augmented_' + str(augment_cnt) + '.jpg')
                Image.fromarray(np.uint8(transformed_image)).save(out_path)
                augment_cnt += 1
        total += augment_cnt
    print("Augmented Data:", total)

In [18]:
# Remove augmented images from a previous run before regenerating them
# (reuse data_dir and the real class count instead of repeating literals).
delete_augmented(data_dir, len(class_names), class_names)

current_directory: 001.Black_footed_Albatross


In [21]:
# Generate one augmented image per original image for every class.
augmentation(data_dir, len(class_names), class_names, 1)

current_directory: 001.Black_footed_Albatross
current_directory: 002.Laysan_Albatross
current_directory: 003.Sooty_Albatross
current_directory: 004.Groove_billed_Ani
current_directory: 005.Crested_Auklet
current_directory: 006.Least_Auklet
current_directory: 007.Parakeet_Auklet
current_directory: 008.Rhinoceros_Auklet
current_directory: 009.Brewer_Blackbird
current_directory: 010.Red_winged_Blackbird
current_directory: 011.Rusty_Blackbird
current_directory: 012.Yellow_headed_Blackbird
current_directory: 013.Bobolink
current_directory: 014.Indigo_Bunting
current_directory: 015.Lazuli_Bunting
current_directory: 016.Painted_Bunting
current_directory: 017.Cardinal
current_directory: 018.Spotted_Catbird
current_directory: 019.Gray_Catbird
current_directory: 020.Yellow_breasted_Chat
current_directory: 021.Eastern_Towhee
current_directory: 022.Chuck_will_Widow
current_directory: 023.Brandt_Cormorant
current_directory: 024.Red_faced_Cormorant
current_directory: 025.Pelagic_Cormorant
current_di

In [22]:
# Reload both datasets so the freshly generated augmented images are picked up.

# original dataset
full_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'),
                                    data_transforms['train'])
class_names = full_dataset.classes

# augmented dataset
augmented_dataset = datasets.ImageFolder(os.path.join(data_dir, 'augmented'),
                                         data_transforms['train'])
class_names_aug = augmented_dataset.classes

# ImageFolder assigns label indices from the sorted sub-folder names, so both
# datasets must expose exactly the same classes for their labels to line up
# (an interrupted augmentation run can leave class folders missing).
assert class_names_aug == class_names, (
    f"class mismatch: {len(class_names)} original vs {len(class_names_aug)} augmented classes"
)

# With one augmented image per original, the two counts should be equal,
# i.e. combining them doubles the training set.
print(len(full_dataset), len(augmented_dataset))

10597 10597
