In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import timm
import random
import numpy as np
import os
import time
import datetime

import albumentations as A
import ttach as tta
from albumentations.pytorch import ToTensorV2

# Fix all random seeds
def seed_all(seed):
    """Seed every random number generator this notebook touches.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch (CPU, the
    current CUDA device and all CUDA devices), exports ``PYTHONHASHSEED``
    into ``os.environ``, and sets the cuDNN flags to
    ``deterministic=True`` / ``benchmark=False``.

    Args:
        seed: Integer seed passed unchanged to every seeding function.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # All seeders accept the same integer; order between them does not matter.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose transform is called with an ``image=`` keyword.

    The transform is expected to return a mapping with an ``"image"`` entry
    (the calling convention used by the albumentations pipelines in this
    notebook), rather than being called positionally on the image.
    """

    def __init__(self, root="./data/cifar10", train=True, download=True, transform=None):
        """Forward all arguments unchanged to ``torchvision.datasets.CIFAR10``."""
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``.

        The image is taken from ``self.data`` and the label from
        ``self.targets``. When a transform is set, the returned image is
        ``transform(image=raw)["image"]``; otherwise the raw entry is returned.
        """
        raw_image = self.data[index]
        label = self.targets[index]

        # No transform configured: hand back the stored sample untouched.
        if self.transform is None:
            return raw_image, label

        return self.transform(image=raw_image)["image"], label
    
# Prefer the second GPU (index 1). "cuda:1" is only a valid device when at
# least two CUDA devices are visible to this process, so fall back to
# "cuda:0" on single-GPU machines and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:1


In [None]:
# Fixed hyperparameters shared by every experiment
seed_number = 0
epochs = 100
lr = 0.01

# Test-time transform: resize to 224x224 and convert to a tensor, no augmentation
transform_test = A.Compose([A.Resize(224, 224), ToTensorV2()])

testset = Cifar10SearchDataset(root='./data', train=False, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=16)

model_names = ['resnet18']


# --- Factories for the augmentation groups shared by several pipelines -------
# Each call builds a fresh transform object, so no instance is shared between
# the pipelines below (see the albumentations documentation for each transform).

def _noise_choice():
    """One of: no-op, multiplicative noise, Gaussian noise, ISO noise."""
    return A.OneOf([A.NoOp(), A.MultiplicativeNoise(), A.GaussNoise(), A.ISONoise()])


def _color_shift():
    """With p=0.2, one of: no-op (p=0.8), hue/saturation/value shift, RGB shift."""
    return A.OneOf(
        [
            A.NoOp(p=0.8),
            A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=10),
            A.RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10),
        ],
        p=0.2,
    )


def _blur_distort_noise():
    """With p=0.2, one of: motion blur, optical distortion, Gaussian noise."""
    return A.OneOf(
        [
            A.MotionBlur(),
            A.OpticalDistortion(),
            A.GaussNoise(),
        ],
        p=0.2,
    )


def _warp_choice():
    """One of: elastic transform, grid distortion, no-op."""
    return A.OneOf([A.ElasticTransform(), A.GridDistortion(), A.NoOp()])


def _train_pipeline(*middle):
    """Wrap ``middle`` between Resize(224, 224) + HorizontalFlip(p=0.5) and ToTensorV2."""
    return A.Compose([
        A.Resize(224, 224),
        A.HorizontalFlip(p=0.5),
        *middle,
        ToTensorV2(),
    ])


# Candidate training augmentations; the position in this list is the `aug` index
augmentation_list = [
    # 0: brightness/contrast + gamma + noise + color shift + warp
    _train_pipeline(
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1),
        A.RandomGamma(gamma_limit=(90, 110)),
        _noise_choice(),
        _color_shift(),
        _warp_choice(),
    ),
    # 1: noise + color shift + warp
    _train_pipeline(
        _noise_choice(),
        _color_shift(),
        _warp_choice(),
    ),
    # 2: color shift + blur/distortion/noise + warp
    _train_pipeline(
        _color_shift(),
        _blur_distort_noise(),
        _warp_choice(),
    ),
]

# Remaining search dimensions: optimizer variant and training batch size
optimizer_list = ['Adam', 'SGD', 'SGD_weight_decay']
batch_size_list = [64, 128]

# Loop over every experimental condition, train a model and save its weights


def _build_optimizer(optimizer_type, params, lr):
    """Create the optimizer named by ``optimizer_type``.

    Args:
        optimizer_type: One of 'Adam', 'SGD', 'SGD_weight_decay'.
        params: Iterable of parameters to optimize.
        lr: Learning rate.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: If ``optimizer_type`` is not a known name. Failing loudly
            prevents a run from silently reusing the previous optimizer.
    """
    if optimizer_type == 'Adam':
        return optim.Adam(params, lr=lr)
    if optimizer_type == 'SGD':
        return optim.SGD(params, lr=lr, momentum=0.9)
    if optimizer_type == 'SGD_weight_decay':
        return optim.SGD(params, lr=lr, momentum=0.9, weight_decay=5e-4)
    raise ValueError(f"Unknown optimizer type: {optimizer_type!r}")


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one training pass over ``loader`` and return the mean loss per sample."""
    model.train()
    loss_sum = 0.0
    n_seen = 0
    for inputs, labels in loader:
        # NOTE(review): inputs are only cast to float here and no normalization
        # step is visible in the transforms of this notebook — confirm this is
        # intended for the pretrained weights.
        inputs, labels = inputs.to(device).float(), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # The criterion returns the batch mean, so weight it by the batch size;
        # otherwise the reported loss shrinks with the batch size and runs with
        # different batch sizes cannot be compared.
        n_batch = labels.size(0)
        loss_sum += loss.item() * n_batch
        n_seen += n_batch
    return loss_sum / max(n_seen, 1)


def _evaluate(model, loader, device):
    """Return the top-1 accuracy (in percent) of ``model`` over ``loader``."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device).float(), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


# Create the checkpoint directory up front so a finished run is never lost
# to a missing folder at save time.
os.makedirs('./models', exist_ok=True)

for model_name in model_names:
    for aug_idx, transform_train in enumerate(augmentation_list):
        for opt_idx, optimizer_type in enumerate(optimizer_list):
            for batch_size in batch_size_list:

                # Re-seed so every condition starts from the same random state
                seed_all(seed_number)

                # Load the pretrained model with a 10-class head
                model = timm.create_model(model_name, pretrained=True, num_classes=10)
                model = model.to(device)

                # Build the training set/loader for this augmentation and batch size
                trainset = Cifar10SearchDataset(root='./data', train=True, download=False, transform=transform_train)
                trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=16)

                criterion = nn.CrossEntropyLoss()

                # Optimizer and step learning-rate schedule (step_size=30, gamma=0.1)
                optimizer = _build_optimizer(optimizer_type, model.parameters(), lr)
                scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

                # Start training
                start_time = time.time()
                for epoch in range(epochs):
                    epoch_start = time.time()

                    train_loss = _train_one_epoch(model, trainloader, criterion, optimizer, device)

                    # Per-epoch evaluation; note this is measured on `testloader`
                    test_acc = _evaluate(model, testloader, device)

                    scheduler.step()

                    now = time.time()
                    elapsed = now - start_time
                    # Epochs still to run after this one is `epochs - epoch - 1`
                    # (epoch is 0-based), so the estimate converges to the true
                    # total on the last epoch.
                    expected_time = elapsed + (now - epoch_start) * (epochs - epoch - 1)
                    elapsed_str = str(datetime.timedelta(seconds=elapsed)).split('.')[0]
                    expected_str = str(datetime.timedelta(seconds=expected_time)).split('.')[0]
                    print(
                        f"[{epoch + 1:03d}/{epochs}], Loss: {train_loss:4.3f}, Test accuracy: {test_acc:4.2f} % --- "
                        f"    {elapsed_str} / {expected_str}"
                    )

                # Report the accuracy of the last epoch
                print(f'Finished Training\nLast Accuracy of {aug_idx}_{opt_idx}_{batch_size:03d}: {test_acc:4.2f}!!')

                # Save the final weights under a name unique to this condition:
                # {model}_aug{augmentation index}_opt{optimizer index}_{batch size}_{test accuracy * 100}
                PATH = f'./models/{model_name}_aug{aug_idx}_opt{opt_idx}_{batch_size:03d}_{int(test_acc*100)}.pth'
                torch.save(model.state_dict(), PATH)