# In [13]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
import os

# 1. Download the CIFAR-10 dataset and set up the shared preprocessing / loaders.
# ToTensor scales pixels to [0, 1]; Normalize with mean = std = 0.5 per channel
# then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

batch_size = 4

# Training split: reshuffled on every pass through the loader.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

# Test split: iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

# Human-readable class names; presumably indexed by CIFAR-10 label id (confirm).
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# 2. Data augmentation: augment each image `augmentation_times` times and save it.
# `trainset` already applies ToTensor + Normalize(mean=0.5, std=0.5), so every
# `img` it yields is a float tensor in [-1, 1]. ToPILImage (float -> x*255 -> uint8)
# and save_image (expects [0, 1], clamps) both assume un-normalized data, so the
# normalization is undone (x * 0.5 + 0.5) before each conversion; otherwise the
# negative half of the range is wrapped/clamped and the images are corrupted.
augmented_data_path = "./augmented_data"
os.makedirs(augmented_data_path, exist_ok=True)

transform_augment = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

augmentation_times = 5
# Only a prefix of the training set is augmented (CIFAR-10 has 50,000 training
# images; this processes the first 10,000 of them).
num_images_to_augment = 10000

to_pil = transforms.ToPILImage()  # loop-invariant, build once

# NOTE(review): the class label is not stored in the filenames, and the training
# loop below reads `trainloader`, not these files, so the saved images are not
# currently used for training.
print("Augmenting data...")
for i in range(min(num_images_to_augment, len(trainset))):
    img, _label = trainset[i]
    # Undo Normalize and round back to exact uint8 pixels so ToPILImage does not
    # rescale (or wrap) the values.
    pixels = (img * 0.5 + 0.5).mul(255).round().clamp(0, 255).to(torch.uint8)
    img_pil = to_pil(pixels)
    for j in range(augmentation_times):
        augmented_img = transform_augment(img_pil)
        # transform_augment re-normalizes to [-1, 1]; save_image wants [0, 1].
        save_image(augmented_img * 0.5 + 0.5,
                   os.path.join(augmented_data_path, f"{i}_{j}_augmented.png"))
print("Data augmentation complete.")

# 3. Basic CNN model definition
class SimpleCNN(nn.Module):
    """Small two-convolution CNN classifier for 3-channel 32x32 images.

    Layers: conv(3->32) -> ReLU -> maxpool -> conv(32->64) -> ReLU -> maxpool
    -> flatten -> fc(64*8*8 -> 128) -> ReLU -> dropout(0.5) -> fc(128 -> num_classes).

    Args:
        num_classes: Number of output logits. Defaults to 10 (CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # padding=1 convs keep the spatial size; two 2x2 pools take 32x32 -> 8x8,
        # and conv2 outputs 64 channels, hence 64 * 8 * 8 input features.
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) tensor to (batch, num_classes) logits."""
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension. Unlike
        # view(-1, 64 * 8 * 8), this cannot silently fold extra spatial size into
        # a bigger "batch"; a wrong input size makes fc1 raise a shape error.
        x = torch.flatten(x, 1)
        x = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(x)

# Choose the device and move the model there *before* constructing the optimizer,
# as the torch.optim documentation recommends, so the optimizer is built over the
# parameters that will actually be trained.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN().to(device)

# 4. Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 5. Model training
epochs = 10  # number of training epochs
for epoch in range(epochs):
    model.train()  # training mode: dropout active
    running_loss = 0.0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean per-batch training loss for this epoch.
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(trainloader):.4f}")

# 6. Model evaluation on the test split
model.eval()  # inference mode: dropout disabled
correct = 0
total = 0

# Gradients are not needed to score the model.
with torch.no_grad():
    for batch_x, batch_y in testloader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        # Predicted class = index of the largest logit per sample.
        preds = model(batch_x).argmax(dim=1)
        total += batch_y.size(0)
        correct += (preds == batch_y).sum().item()

print(f"Accuracy on test dataset: {100 * correct / total:.2f}%")


# --- Cell output (recorded from a previous run) ---
# Files already downloaded and verified
# Files already downloaded and verified
# Augmenting data...
# Data augmentation complete.
# Epoch [1/10], Loss: 1.5336
# Epoch [2/10], Loss: 1.2464
# Epoch [3/10], Loss: 1.1459
# Epoch [4/10], Loss: 1.0839
# Epoch [5/10], Loss: 1.0506
# Epoch [6/10], Loss: 1.0149
# Epoch [7/10], Loss: 0.9830
# Epoch [8/10], Loss: 0.9539
# Epoch [9/10], Loss: 0.9378
# Epoch [10/10], Loss: 0.9167
# Accuracy on test dataset: 68.24%
