# VGG

In [1]:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
import torch.nn as nn
from torchvision.models.vgg import vgg16

model = vgg16(pretrained=True)

fc = nn.Sequential(
    nn.Linear(512*7*7, 4096),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(4096, 4096),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(4096, 10)
)

model.classifier = fc
model.to(device)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 192MB/s]


VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [5]:
import tqdm
from torchvision.datasets.cifar import CIFAR10
from torchvision.transforms import Compose, ToTensor, Resize, RandomHorizontalFlip, RandomCrop, Normalize
from torch.utils.data.dataloader import DataLoader
from torch.optim import Adam

transformers = Compose([
    Resize(224),
    RandomCrop((224, 224), padding=4),
    RandomHorizontalFlip(p=0.5),
    ToTensor(),
    Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261))
])

train_data = CIFAR10(root='./', train=True, download=True, transform=transformers)
test_data = CIFAR10(root='./', train=False, download=True, transform=transformers)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

100%|██████████| 170M/170M [02:01<00:00, 1.40MB/s]


In [None]:
# 학습
learning_rate = 1e-3
optim = Adam(model.parameters(), lr=learning_rate)

for epoch in range(50):
  for data, label in tqdm.tqdm(train_loader):
    optim.zero_grad()
    preds = model(data.to(device))

    loss = nn.CrossEntropyLoss()(preds, label.to(device))
    loss.backward()
    optim.step()

  if (epoch+1) % 10 == 0:
    print(f'epoch: {epoch+1}, loss: {loss.item():.4f}')

torch.save(model.state_dict(), 'CIFA10_pretrained.pt')

100%|██████████| 1563/1563 [11:48<00:00,  2.21it/s]
100%|██████████| 1563/1563 [11:47<00:00,  2.21it/s]
100%|██████████| 1563/1563 [11:47<00:00,  2.21it/s]
100%|██████████| 1563/1563 [11:47<00:00,  2.21it/s]
100%|██████████| 1563/1563 [11:46<00:00,  2.21it/s]
100%|██████████| 1563/1563 [11:45<00:00,  2.22it/s]
100%|██████████| 1563/1563 [11:47<00:00,  2.21it/s]
100%|██████████| 1563/1563 [11:47<00:00,  2.21it/s]
100%|██████████| 1563/1563 [11:45<00:00,  2.21it/s]
100%|██████████| 1563/1563 [11:46<00:00,  2.21it/s]


epoch: 10, loss: 2.3023


100%|██████████| 1563/1563 [11:46<00:00,  2.21it/s]
100%|██████████| 1563/1563 [11:47<00:00,  2.21it/s]
100%|██████████| 1563/1563 [11:47<00:00,  2.21it/s]
100%|██████████| 1563/1563 [11:47<00:00,  2.21it/s]
100%|██████████| 1563/1563 [11:47<00:00,  2.21it/s]
100%|██████████| 1563/1563 [11:47<00:00,  2.21it/s]
 85%|████████▌ | 1335/1563 [10:04<01:43,  2.21it/s]

In [None]:
# 평가
model.load_state_dict(torch.load("CIFAR10_pretrained.pt", map_location=device))

cnt_corr = 0

with torch.no_grad():
  for data, label in test_loader:
    output = model(data.to(device))
    preds = output.data.max(1)[1]
    corr = preds.eq(label.to(device).data).sum().item ()
    cnt_corr += corr

print(f'accuracy: {cnt_corr / len(test_data):.4f}')