In [1]:
!mkdir -p ~/work/data_augmentation/data
!wget "http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar" -P ~/work/data_augmentation/data
!tar -xf ~/work/data_augmentation/data/images.tar -C ~/work/data_augmentation/data/

--2026-01-25 12:17:25--  http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar
Resolving vision.stanford.edu (vision.stanford.edu)... 171.64.68.10
Connecting to vision.stanford.edu (vision.stanford.edu)|171.64.68.10|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 793579520 (757M) [application/x-tar]
Saving to: ‘/root/work/data_augmentation/data/images.tar.2’


2026-01-25 12:18:18 (14.5 MB/s) - ‘/root/work/data_augmentation/data/images.tar.2’ saved [793579520/793579520]



In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
import random

# Seed every RNG family the notebook draws from (Python `random` for sample
# visualisation, NumPy, and torch for splitting / augmentation / CutMix / MixUp)
# so that Restart Kernel -> Run All reproduces the same results.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"

# 데이터셋 불러오기

In [3]:
dataset_dir = "~/work/data_augmentation/data/Images/"

# ImageFolder without a transform: conversion to normalized tensors is done later
# by apply_normalize_on_dataset, so the raw images are kept untouched here.
full_dataset = ImageFolder(root=dataset_dir)
class_names, num_classes = full_dataset.classes, len(full_dataset.classes)

print(f"총 클래스 개수: {num_classes}")
print(f"첫 5개 클래스: {class_names[:5]}")

총 클래스 개수: 120
첫 5개 클래스: ['n02085620-Chihuahua', 'n02085782-Japanese_spaniel', 'n02085936-Maltese_dog', 'n02086079-Pekinese', 'n02086240-Shih-Tzu']


# 데이터셋 분할

In [4]:
# Fraction of images used for training. On the 20,580 images loaded above this gives
# an 11,998 / 8,582 split (presumably chosen to approximate the official Stanford Dogs
# 12,000 / 8,580 split — TODO confirm intent).
TRAIN_RATIO = 0.583
SPLIT_SEED = 42

total_size = len(full_dataset)
train_size = int(TRAIN_RATIO * total_size)
test_size = total_size - train_size

# A dedicated seeded generator makes the split identical across kernel restarts,
# independent of how much of the global torch RNG earlier cells consumed.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
ds_train_raw, ds_test_raw = random_split(
    full_dataset, [train_size, test_size], generator=split_generator
)

train_indices = set(ds_train_raw.indices)
test_indices = set(ds_test_raw.indices)

# 데이터셋 확인 (train/test 인덱스 중복 검사)

In [5]:
# Sanity check: random_split must not place the same sample index in both splits.
duplicates = train_indices & test_indices
n_duplicates = len(duplicates)

print(f"훈련 데이터 개수: {len(train_indices)}")
print(f"테스트 데이터 개수: {len(test_indices)}")
print(f"중복된 인덱스 개수: {n_duplicates}")

if n_duplicates:
    print(f"경고: {n_duplicates}개의 인덱스가 중복되었습니다.")
else:
    print("결과: 인덱스 수준에서 중복이 전혀 없습니다.")

훈련 데이터 개수: 11998
테스트 데이터 개수: 8582
중복된 인덱스 개수: 0
결과: 인덱스 수준에서 중복이 전혀 없습니다.


# 이미지 시각화

In [6]:
def show_random_samples(dataset, class_names, num_samples=5):
    """Plot up to `num_samples` randomly chosen images in a single row.

    Accepts two kinds of input:
      * a list of (images, labels) batches, as returned by
        apply_normalize_on_dataset: samples are drawn from the FIRST batch only,
        images are mapped back from the mean=0.5 / std=0.5 normalization to
        [0, 1], and the shown class is the argmax of the (possibly mixed) label;
      * an indexable dataset of (image, class index) pairs such as ds_train_raw,
        whose images are passed straight to plt.imshow.

    Args:
        dataset: list of batches, or an indexable (image, label) dataset.
        class_names: folder-style class names such as 'n02086240-Shih-Tzu'.
        num_samples: number of images to draw; capped at what is available.
    """

    def _display_name(class_name):
        # Strip only the leading id prefix ("n02086240-"). The previous
        # split('-')[-1] truncated hyphenated breeds: 'Shih-Tzu' -> 'Tzu'.
        return class_name.split('-', 1)[-1]

    plt.figure(figsize=(18, 5))

    if isinstance(dataset, list):
        images, labels = dataset[0]
        max_idx = len(images)
        indices = random.sample(range(max_idx), min(num_samples, max_idx))

        for i, idx in enumerate(indices):
            img = images[idx].permute(1, 2, 0).cpu().numpy()

            # Undo Normalize(mean=0.5, std=0.5) from basic_transform, then clip.
            img = (img * 0.5) + 0.5
            img = np.clip(img, 0, 1)

            # One-hot / mixed label rows -> dominant class; scalar labels -> index.
            lbl_data = labels[idx]
            lbl_idx = lbl_data.argmax().item() if lbl_data.dim() > 0 else int(lbl_data.item())

            plt.subplot(1, num_samples, i + 1)
            plt.imshow(img)
            plt.title(f"Tensor: {_display_name(class_names[lbl_idx])}", fontsize=10)
            plt.axis('off')

    else:
        max_idx = len(dataset)
        indices = random.sample(range(max_idx), min(num_samples, max_idx))

        for i, idx in enumerate(indices):
            img, lbl = dataset[idx]

            plt.subplot(1, num_samples, i + 1)
            plt.imshow(img)
            plt.title(f"PIL: {_display_name(class_names[lbl])}", fontsize=10)
            plt.axis('off')

    plt.tight_layout()
    plt.show()

# 기본 Transform 정의

In [7]:
# Resize to 224x224, convert to a float tensor, then normalize every channel with
# mean 0.5 / std 0.5; show_random_samples inverts this with x * 0.5 + 0.5.
basic_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

In [8]:
# Light photometric/geometric augmentation applied to the raw image before
# basic_transform (training splits only): a 50% horizontal flip plus brightness jitter.
augment1 = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2),
    ]
)

# Mixup & CutMix용 Transform 함수들

In [9]:
def onehot(label, num_classes=120):
    """One-hot encode `label` and return it as a float tensor.

    `label` may be an integer tensor or anything torch.tensor accepts (an int, a
    list of ints). The result gains a trailing dimension of size `num_classes`.
    """
    label_tensor = label if isinstance(label, torch.Tensor) else torch.tensor(label)
    encoded = F.one_hot(label_tensor, num_classes=num_classes)
    return encoded.float()

In [10]:
def get_clip_box(image_shape):
    """Sample a random CutMix box for an image of shape (C, H, W).

    Draws, in this order from torch's global RNG: the box centre x in [0, W),
    the centre y in [0, H), and a side scale sqrt(1 - u) with u ~ U(0, 1).
    The box spans about scale*W by scale*H around the centre and is clipped to
    the image, so boxes near the border come out smaller.

    Returns:
        (x_min, y_min, x_max, y_max) as Python ints, usable as slice bounds.
    """
    height = image_shape[1]
    width = image_shape[2]

    center_x = int(torch.randint(0, width, (1,)))
    center_y = int(torch.randint(0, height, (1,)))
    scale = float(torch.sqrt(1.0 - torch.rand(1)))

    half_w = int(width * scale) // 2
    half_h = int(height * scale) // 2

    return (
        max(0, center_x - half_w),
        max(0, center_y - half_h),
        min(width, center_x + half_w),
        min(height, center_y + half_h),
    )

In [11]:
def mix_2_images(image_a, image_b, x_min, y_min, x_max, y_max):
    """Return a copy of image_a with the box [y_min:y_max, x_min:x_max] taken from image_b.

    Images are indexed as (C, H, W); every channel is pasted. image_a itself is
    not modified.
    """
    rows = slice(y_min, y_max)
    cols = slice(x_min, x_max)
    pasted = image_a.clone()
    pasted[:, rows, cols] = image_b[:, rows, cols]
    return pasted

In [12]:
def mix_2_labels(label_a, label_b, x_min, y_min, x_max, y_max, img_size, num_classes):
    """Mix two labels in proportion to the CutMix box area.

    The weight of label_b is the box area divided by img_size**2 (a square image
    is assumed); label_a gets the remaining weight. Scalar labels, or rows whose
    length is not num_classes, are one-hot encoded first.

    Returns:
        Float tensor of length num_classes.
    """

    def _as_onehot(lbl):
        if lbl.ndim == 0 or lbl.size(0) != num_classes:
            return F.one_hot(lbl.to(torch.int64), num_classes=num_classes).float()
        return lbl

    box_area = (x_max - x_min) * (y_max - y_min)
    weight_b = box_area / (img_size * img_size)

    return (1 - weight_b) * _as_onehot(label_a) + weight_b * _as_onehot(label_b)

In [13]:
def cutmix(images, labels, prob=1.0, img_size=224, num_classes=120):
    """Apply CutMix independently to each sample of a batch.

    With probability `prob`, sample i gets a random box (get_clip_box) pasted
    from a randomly chosen sample j of the same batch (j may equal i). Its label
    becomes the area-weighted mix of both labels (mix_2_labels). Otherwise the
    sample is passed through unchanged.

    Args:
        images: batch tensor; each images[i] is indexed as (C, H, W).
        labels: per-sample class indices (0-d) or one-hot / soft rows of length
            num_classes.
        prob: per-sample probability of applying CutMix.
        img_size: side length mix_2_labels uses for the pasted-area ratio.
        num_classes: length of the returned label rows.

    Returns:
        (mixed_imgs, mixed_labels): stacked images, and float label rows of shape
        (B, num_classes).
    """

    def _as_onehot(label):
        # Labels may already be one-hot (apply_normalize_on_dataset passes one-hot
        # rows). Only encode scalar / mismatched labels, mirroring mix_2_labels;
        # encoding a one-hot row again would yield a (num_classes, num_classes)
        # matrix and break torch.stack below.
        if label.ndim == 0 or label.size(0) != num_classes:
            return F.one_hot(label.to(torch.int64), num_classes=num_classes).float()
        return label.float()

    current_batch_size = images.size(0)

    mixed_imgs = []
    mixed_labels = []

    for i in range(current_batch_size):
        image_a, label_a = images[i], labels[i]

        if torch.rand(1).item() < prob:
            j = torch.randint(0, current_batch_size, (1,)).item()
            image_b, label_b = images[j], labels[j]

            x_min, y_min, x_max, y_max = get_clip_box(image_a.shape)

            mixed_imgs.append(mix_2_images(image_a, image_b, x_min, y_min, x_max, y_max))
            mixed_labels.append(mix_2_labels(label_a, label_b, x_min, y_min, x_max, y_max, img_size, num_classes))

        else:
            mixed_imgs.append(image_a)
            mixed_labels.append(_as_onehot(label_a))

    mixed_imgs = torch.stack(mixed_imgs)
    mixed_labels = torch.stack(mixed_labels)

    return mixed_imgs, mixed_labels

In [14]:
def mixup_2_images(image_a, image_b, label_a, label_b, num_classes=120):
    """Blend two samples with one MixUp ratio.

    A ratio r ~ U(0, 1) is drawn from torch's global RNG. Both the image and the
    label are returned as (1 - r) * a + r * b.

    Args:
        image_a, image_b: image tensors of the same shape.
        label_a, label_b: class index (int or 0-d tensor), or a one-hot / soft
            label row of length num_classes.
        num_classes: size used when one-hot encoding index labels.

    Returns:
        (mixed_image, mixed_label).
    """
    ratio = torch.rand(1).item()

    def _as_onehot(label):
        if not isinstance(label, torch.Tensor) or label.ndim == 0:
            # torch.as_tensor accepts ints and tensors alike; torch.tensor(tensor)
            # emits a copy-construct UserWarning for every 0-d tensor label.
            return F.one_hot(torch.as_tensor(label).to(torch.int64), num_classes=num_classes).float()
        return label

    label_a = _as_onehot(label_a)
    label_b = _as_onehot(label_b)

    mixed_image = (1 - ratio) * image_a + ratio * image_b
    mixed_label = (1 - ratio) * label_a + ratio * label_b

    return mixed_image, mixed_label

In [15]:
def mixup(images, labels, img_size=224, num_classes=120):
    """Apply MixUp to every sample of a batch.

    Each sample i is blended with a randomly chosen sample j of the same batch
    (j may equal i) via mixup_2_images, which draws a fresh ratio per pair.

    Args:
        images: batch tensor whose first dimension is the batch.
        labels: per-sample class indices or one-hot / soft rows of length num_classes.
        img_size: kept for backward compatibility only. The output image shape now
            follows `images` instead of being forced to (B, 3, img_size, img_size),
            so other channel counts and resolutions work.
        num_classes: label row length.

    Returns:
        (mixed_imgs, mixed_labels): mixed_imgs has the same shape as `images`;
        mixed_labels has shape (B, num_classes).
    """
    current_batch_size = images.size(0)

    mixed_imgs = []
    mixed_labels = []

    for i in range(current_batch_size):
        image_a, label_a = images[i], labels[i]
        j = torch.randint(0, current_batch_size, (1,)).item()
        image_b, label_b = images[j], labels[j]

        mixed_img, mixed_label = mixup_2_images(image_a, image_b, label_a, label_b, num_classes)

        mixed_imgs.append(mixed_img)
        mixed_labels.append(mixed_label)

    # torch.stack already yields (B, *image_shape); no hard-coded reshape needed.
    mixed_imgs = torch.stack(mixed_imgs)
    mixed_labels = torch.stack(mixed_labels).view(current_batch_size, num_classes)

    return mixed_imgs, mixed_labels

# Transform 함수들 최종 적용

In [16]:
def apply_normalize_on_dataset(dataset, is_test=False, batch_size=16, with_aug=False, with_cutmix=False, with_mixup=False):
    """Materialize `dataset` as a list of preprocessed (images, one-hot labels) batches.

    Per sample: optional augment1 (training only) on the raw image, then
    basic_transform (resize to 224x224, ToTensor, normalize). Samples are batched
    by a DataLoader (shuffled for training) and labels are one-hot encoded. For
    training, CutMix (takes precedence) or MixUp is then applied per batch.

    NOTE(review): everything is computed once, up front, and kept in memory.
    The random flips/jitter and the CutMix/MixUp pairings are frozen: every epoch
    of `implementation` replays exactly the same batches in the same order. The
    whole split is also held as float tensors in RAM. Consider a Dataset with a
    transform plus an on-the-fly collate_fn if per-epoch randomness or memory
    becomes a concern.

    Args:
        dataset: iterable of (image, class index) pairs, e.g. a random_split subset.
        is_test: disables augmentation, CutMix/MixUp and shuffling.
        batch_size: samples per batch.
        with_aug: apply augment1 before basic_transform (training only).
        with_cutmix: apply cutmix to each training batch.
        with_mixup: apply mixup to each training batch (ignored if with_cutmix).

    Returns:
        list of (images, labels) tuples; labels are float rows of length num_classes.
    """
    data_list = []

    for img, lbl in dataset:
        if not is_test and with_aug:
            img = augment1(img)
        data_list.append((basic_transform(img), lbl))

    dataloader = DataLoader(data_list, batch_size=batch_size, shuffle=not is_test)

    processed_batches = []

    for imgs, lbls in dataloader:
        lbls_oh = onehot(lbls, num_classes=num_classes)

        # Pass the dataset's class count explicitly. cutmix/mixup default to 120 and
        # would mis-handle label rows of any other length.
        if not is_test and with_cutmix:
            imgs, lbls_oh = cutmix(imgs, lbls_oh, num_classes=num_classes)
        elif not is_test and with_mixup:
            imgs, lbls_oh = mixup(imgs, lbls_oh, num_classes=num_classes)

        processed_batches.append((imgs, lbls_oh))

    return processed_batches

In [None]:
# Preprocess each split once. Test batches are neither augmented nor shuffled; the
# four training variants differ only in the augmentations layered on top.
# NOTE(review): each call keeps a fully preprocessed copy of its split in memory
# (roughly 12k images x 3x224x224 float32 ~ 7 GB per training variant) — TODO
# confirm the machine has enough RAM for all five.
ds_test = apply_normalize_on_dataset(dataset=ds_test_raw, is_test=True)

train_split = dict(dataset=ds_train_raw, is_test=False)
ds_no_aug = apply_normalize_on_dataset(**train_split, with_aug=False)
ds_aug = apply_normalize_on_dataset(**train_split, with_aug=True)
ds_mixup = apply_normalize_on_dataset(**train_split, with_aug=True, with_mixup=True)
ds_cutmix = apply_normalize_on_dataset(**train_split, with_aug=True, with_cutmix=True)

# 변환 결과 확인

In [None]:
# Raw samples first, then one preprocessed batch from each augmented training variant.
for sample_source in (ds_train_raw, ds_aug, ds_mixup, ds_cutmix):
    show_random_samples(sample_source, class_names)

# 학습 함수 정의

In [None]:
def implementation(model, train_loader, test_loader, epochs=3, lr=0.001, device=None):
    """Fine-tune `model` with plain SGD and evaluate it after every epoch.

    Both loaders are iterables of (images, labels) batches where labels are
    one-hot or soft (mixed) class-probability rows. nn.CrossEntropyLoss consumes
    these probability targets directly. Accuracy compares the argmax of the
    logits with the argmax of the label row, so for MixUp/CutMix batches the
    train accuracy is measured against each mixed label's dominant class.

    Args:
        model: network mapping an image batch to class logits; moved to `device`.
        train_loader: training batches, iterated once per epoch.
        test_loader: evaluation batches, iterated once per epoch under no_grad.
        epochs: number of passes over `train_loader`.
        lr: SGD learning rate (default matches the previously hard-coded value).
        device: target device; defaults to "cuda" when available, else "cpu".

    Returns:
        (train_loss_list, test_loss_list, train_accuracy_list, test_accuracy_list),
        one entry per epoch. Losses are per-sample means; accuracies are in %.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    train_loss_list = []
    train_accuracy_list = []
    test_loss_list = []
    test_accuracy_list = []

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean; re-weight by batch size to accumulate a sum.
            train_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            _, targets = labels.max(1)
            train_total += labels.size(0)
            train_correct += predicted.eq(targets).sum().item()

        # Normalize by the samples actually seen rather than a global split size;
        # max(..., 1) keeps an empty loader from dividing by zero.
        train_loss = train_loss / max(train_total, 1)
        train_acc = 100.0 * train_correct / max(train_total, 1)

        model.eval()
        test_loss = 0.0
        test_correct = 0
        test_total = 0

        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                test_loss += loss.item() * images.size(0)
                _, predicted = outputs.max(1)
                _, targets = labels.max(1)
                test_total += labels.size(0)
                test_correct += predicted.eq(targets).sum().item()

        test_loss = test_loss / max(test_total, 1)
        test_acc = 100. * test_correct / max(test_total, 1)

        train_loss_list.append(train_loss)
        test_loss_list.append(test_loss)
        train_accuracy_list.append(train_acc)
        test_accuracy_list.append(test_acc)

        print(f"Epoch {epoch+1:>2d} - Train Loss: {train_loss:.3f}, Test Loss: {test_loss:.3f}, Train Acc: {train_acc:.3f}%, Val Acc: {test_acc:.3f}%")

    return (train_loss_list, test_loss_list, train_accuracy_list, test_accuracy_list)

In [None]:
EPOCHS: int = 20  # fine-tuning epochs shared by every experiment below

# 학습 시작

In [None]:
# No Augmentation: ImageNet-pretrained ResNet-50 with its classifier head replaced
# by a num_classes-way linear layer.
model_no_aug = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
head_in_features = model_no_aug.fc.in_features
model_no_aug.fc = nn.Linear(head_in_features, num_classes)
result_no_aug = implementation(model_no_aug, ds_no_aug, ds_test, epochs=EPOCHS)

In [None]:
# Basic Augmentation: same pretrained backbone and new head, trained on flip + brightness-jittered data.
model_aug = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
head_in_features = model_aug.fc.in_features
model_aug.fc = nn.Linear(head_in_features, num_classes)
result_basic_aug = implementation(model_aug, ds_aug, ds_test, epochs=EPOCHS)

In [None]:
# Basic Augmentation + MixUp: basic augmentation followed by per-batch MixUp.
model_mixup = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
head_in_features = model_mixup.fc.in_features
model_mixup.fc = nn.Linear(head_in_features, num_classes)
result_mixup = implementation(model_mixup, ds_mixup, ds_test, epochs=EPOCHS)

In [None]:
# Basic Augmentation + CutMix: basic augmentation followed by per-batch CutMix.
model_cutmix = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
head_in_features = model_cutmix.fc.in_features
model_cutmix.fc = nn.Linear(head_in_features, num_classes)
result_cutmix = implementation(model_cutmix, ds_cutmix, ds_test, epochs=EPOCHS)

# 결과 시각화

In [None]:
# Index 0 of each result tuple (train_loss, test_loss, train_acc, test_acc) is the
# per-epoch training loss, so the title and y-axis say "loss", not "accuracy".
plt.figure(figsize=(12, 8))
plt.plot(result_no_aug[0], '-o', label='No Augment')
plt.plot(result_basic_aug[0], '-o', label='Basic Augment')
plt.plot(result_mixup[0], '-o', label='Mixup')
plt.plot(result_cutmix[0], '-o', label='CutMix')
plt.title('Train Loss Comparison')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Index 1 of each result tuple is the per-epoch loss on the held-out (test) split.
# It was previously mislabelled as train accuracy.
plt.figure(figsize=(12, 8))
plt.plot(result_no_aug[1], '-o', label='No Augment')
plt.plot(result_basic_aug[1], '-o', label='Basic Augment')
plt.plot(result_mixup[1], '-o', label='Mixup')
plt.plot(result_cutmix[1], '-o', label='CutMix')
plt.title('Validation Loss Comparison')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Per-epoch training accuracy (index 2 of each result tuple) for every setting.
fig, ax = plt.subplots(figsize=(12, 8))
for run_result, run_label in [
    (result_no_aug, 'No Augment'),
    (result_basic_aug, 'Basic Augment'),
    (result_mixup, 'Mixup'),
    (result_cutmix, 'CutMix'),
]:
    ax.plot(run_result[2], '-o', label=run_label)
ax.set_title('Train Accuracy Comparison')
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy (%)')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Per-epoch accuracy on the held-out split (index 3 of each result tuple).
fig, ax = plt.subplots(figsize=(12, 8))
for run_result, run_label in [
    (result_no_aug, 'No Augment'),
    (result_basic_aug, 'Basic Augment'),
    (result_mixup, 'Mixup'),
    (result_cutmix, 'CutMix'),
]:
    ax.plot(run_result[3], '-o', label=run_label)
ax.set_title('Validation Accuracy Comparison')
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy (%)')
ax.legend()
ax.grid(True)
plt.show()