In [None]:
import os
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.models as models
import torchvision.transforms as transforms
from torchvision import datasets  # was `dataset` (typo): torchvision exposes `datasets`, used below as datasets.CIFAR10

In [None]:
# Configuration: paths and training hyper-parameters.
data_path = './data/'
save_path = './parameters/'
epochs = 30
batch_size = 64
learning_rate = 0.001
# Seed for torch's RNG (shuffling, random augmentations, weight init) so runs
# are repeatable. GPU kernels may still be nondeterministic.
seed = 42

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(seed)

In [None]:
# Per-channel normalisation shared by train and test: (x - 0.5) / 0.5 maps the
# [0, 1] output of ToTensor to [-1, 1].
# NOTE(review): if resnet18.pth holds ImageNet-pretrained weights, those were
# presumably trained with ImageNet mean/std instead — confirm which is intended.
norm_mean = (0.5, 0.5, 0.5)
norm_std = (0.5, 0.5, 0.5)

train_transform = transforms.Compose([
    # Resize(32) scales the shorter side to 32; ConvNet's 8*8*32 head assumes 32x32 inputs.
    transforms.Resize(32),
    # Geometric augmentation on the PIL image, before ToTensor: the conventional
    # ordering, and it does not rely on tensor support in these transforms.
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Evaluation: deterministic preprocessing only (no augmentation).
test_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

In [None]:
# Load the CIFAR-10 train/test splits from data_path. With download=False the
# files must already be on disk.
# If you switch to download=True and hit certificate errors, fix the local CA
# store (e.g. install/update certifi) rather than disabling HTTPS certificate
# verification process-wide.
train_data = datasets.CIFAR10(
    root=data_path,
    train=True,
    download=False,
    transform=train_transform,
)
test_data = datasets.CIFAR10(
    root=data_path,
    train=False,
    download=False,
    transform=test_transform,
)

In [None]:
# Mini-batch loaders: the training split is reshuffled every epoch, while the
# test split keeps a fixed order so evaluation is comparable across epochs.
trainloader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
)
testloader = DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
class ConvNet(nn.Module):
    """Small two-stage CNN classifier.

    Each stage is Conv2d(5x5, stride 1, padding 2) -> ReLU -> MaxPool2d(2), so
    the spatial size is halved twice. The linear head's 8*8*32 input width
    therefore assumes 3-channel 32x32 images.

    Args:
        num_classes: Number of output scores produced by the head.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Attribute names and Sequential order define the state_dict keys
        # (layer1.0.weight, ..., fc.bias); keep them stable for checkpoints.
        self.layer1 = self._stage(3, 16)
        self.layer2 = self._stage(16, 32)
        self.fc = nn.Linear(32 * 8 * 8, num_classes)

    @staticmethod
    def _stage(in_ch, out_ch):
        """Build one Conv(5x5, size-preserving) -> ReLU -> 2x2 max-pool stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        """Return raw (un-softmaxed) class scores of shape (batch, num_classes)."""
        features = self.layer2(self.layer1(x))
        return self.fc(features.view(features.size(0), -1))

In [None]:
# Model: torchvision ResNet18 initialised from save_path/resnet18.pth, with its
# classification head replaced for 10 classes.
# Other options tried: ConvNet() (trained from scratch) or models.resnet18()
# without loading weights (random init).
num_classes = 10

# map_location='cpu' lets a checkpoint saved from GPU load on a CPU-only host;
# the model is moved to `device` afterwards.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model_weight = torch.load(os.path.join(save_path, 'resnet18.pth'), map_location='cpu')
model = models.resnet18()
model.load_state_dict(model_weight)

# New, freshly initialised head; read its input width from the existing layer
# instead of hard-coding it.
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

In [None]:
def train(epoch, net=None, loader=None, loss_fn=None, opt=None, dev=None):
    """Run one training epoch, print and return the mean loss and accuracy.

    By default this uses the notebook globals ``model``, ``trainloader``,
    ``criterion``, ``optimizer`` and ``device``. Pass any keyword argument to
    override the corresponding global.

    Args:
        epoch: Epoch index; used only in the printed messages.
        net: Model to train (default: global ``model``).
        loader: Iterable of ``(inputs, labels)`` batches (default: ``trainloader``).
        loss_fn: Loss callable (default: ``criterion``). Assumed to return the
            batch *mean*, as CrossEntropyLoss does by default.
        opt: Optimizer stepping ``net``'s parameters (default: ``optimizer``).
        dev: Device the batches are moved to (default: ``device``).

    Returns:
        ``(mean_loss, accuracy_percent)`` averaged over every sample seen this
        epoch; both are ``nan`` if the loader yields no samples.
    """
    net = model if net is None else net
    loader = trainloader if loader is None else loader
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt
    dev = device if dev is None else dev

    print('\nEpoch: %d' % epoch)
    # Switch the model to training mode (affects layers such as BatchNorm/Dropout).
    net.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(dev), labels.to(dev)
        opt.zero_grad()
        outputs = net(inputs)
        _, pred = torch.max(outputs, 1)
        loss = loss_fn(outputs, labels)
        loss.backward()
        opt.step()

        batch_n = outputs.size(0)
        # Weight each batch-mean loss by its size so a smaller final batch is
        # not over-weighted; this makes the loss per-sample like the accuracy.
        loss_sum += loss.item() * batch_n
        correct += (pred == labels).sum().item()
        total += batch_n

    if total == 0:
        # Empty loader: report NaN instead of dividing by zero.
        total_loss = total_acc = float('nan')
    else:
        total_loss = loss_sum / total
        total_acc = 100 * correct / total
    print(f'Train epoch : {epoch} loss : {total_loss} Acc : {total_acc}%')
    return total_loss, total_acc

In [None]:
def test(epoch):
    print('\nEpoch: %d'%epoch)
    # model eval mode로 전환
    model.eval()
    running_loss = 0.0
    running_acc = 0.0
    total = 0
    with torch.no_grad():
        for (inputs, labels) in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, pred = torch.max(outputs, 1)
            total += outputs.size(0)
            running_acc += (pred == labels).sum().item()
            loss = criterion(outputs, labels)
            running_loss += loss.item()
        total_loss = running_loss / len(testloader)
        total_acc = 100 * running_acc / total
        print(f'Test epoch : {epoch} loss : {total_loss} Acc : {total_acc}%')

In [None]:
# Train and evaluate for `epochs` epochs, saving the model's state_dict after each.
# Create the checkpoint directory if needed: it need not exist when the model
# was not loaded from it (e.g. training ConvNet from scratch).
os.makedirs(save_path, exist_ok=True)
for epoch in range(epochs):
    train(epoch)
    test(epoch)
    path = os.path.join(save_path, f'epoch_{epoch}.pth')
    torch.save(model.state_dict(), path)