In [None]:
from google.colab import drive
# Attach the user's Google Drive to the Colab filesystem (prompts for authorisation).
drive.mount(mountpoint="/content/gdrive")

In [None]:
import torch
import torchvision
from torchvision import utils
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim

In [None]:
# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Fix the PyTorch RNG seed so weight initialisation and shuffling are repeatable.
seed = 777
torch.manual_seed(seed)
if device == "cuda":
    # Seed the generators of all CUDA devices as well.
    torch.cuda.manual_seed_all(seed)

In [None]:
# Convert images to tensors, then normalise each of the three channels with
# mean 0.5 and std 0.5 (imshow later undoes this with img/2 + 0.5).
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
my_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# Both CIFAR-10 splits share the same storage root, download flag and transform.
cifar_kwargs = dict(root='./data', download=True, transform=my_transform)
train_data = torchvision.datasets.CIFAR10(train=True, **cifar_kwargs)
val_data = torchvision.datasets.CIFAR10(train=False, **cifar_kwargs)

# Large shuffled batches for training; small unshuffled batches for validation.
train_loader = DataLoader(train_data, batch_size=512, shuffle=True, num_workers=2)
val_loader = DataLoader(val_data, batch_size=4, num_workers=2)


In [None]:
# Label names used to turn integer class indices into readable strings.
classes = (
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

def imshow(img):
    """Display a normalised image tensor with matplotlib.

    Args:
        img: tensor laid out channel-first (dim 0 is moved to the last axis
            before plotting). Values are rescaled with ``img / 2 + 0.5``,
            which reverses the mean-0.5 / std-0.5 normalisation applied by
            ``my_transform``.
    """
    unnormalised = img / 2 + 0.5
    # matplotlib wants the channel axis last, so reorder (0, 1, 2) -> (1, 2, 0).
    channel_last = unnormalised.permute(1, 2, 0)
    plt.imshow(channel_last.numpy())
    plt.show()

# Pull a single shuffled training batch and render it as one image grid.
dataiter = iter(train_loader)
# Use the built-in next(): DataLoader iterators in current PyTorch releases
# implement __next__ only and no longer expose a .next() method.
images, labels = next(dataiter)
# 32 images per row; the slice caps the grid at 512 images.
x_grid = utils.make_grid(images[:512], nrow=32, padding=2)

imshow(x_grid)


In [None]:
from ipywidgets import interact

@interact(idx=(0, train_data.data.shape[0] - 1))
def showImg(idx):
    """Show the raw (untransformed) training image stored at position ``idx``.

    The slider created by ``interact`` spans every valid index of
    ``train_data.data``.
    """
    sample = train_data.data[idx]
    plt.imshow(sample)
    plt.show()

In [None]:
import torchvision.models.vgg as vgg

In [None]:
# Layer spec consumed by torchvision's vgg.make_layers: an int adds a conv layer
# with that many output channels, 'M' adds a max-pool layer.
cfg = [
    32, 32, 'M',
    64, 64, 128, 128, 128, 'M',
    256, 256, 256, 512, 512, 512, 'M',
]

In [None]:
class VGG(nn.Module):
    """VGG-style classifier: a feature extractor followed by a three-layer MLP head.

    Args:
        features: module applied to the input batch first. Its output, once
            flattened per sample, must contain 512*4*4 values to match the
            first linear layer.
        num_classes: width of the final linear layer (default 1000).
        init_weights: when True, parameters of Conv2d / BatchNorm2d / Linear
            sub-modules are re-initialised by ``_initialize_weights``.
    """

    def __init__(self, features, num_classes=1000,init_weights=True):
        super().__init__()
        self.features = features

        hidden_units = 4096
        # Layers are created in the same order they run in, so the Sequential
        # indices (0, 3, 6 for the Linear layers) define the state_dict keys.
        head_layers = [
            nn.Linear(512*4*4, hidden_units),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(hidden_units, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        if init_weights:
            self._initialize_weights()

    def forward(self,x):
        """Run the feature extractor, flatten per sample, and classify."""
        feature_maps = self.features(x)
        batch_size = feature_maps.size(0)
        flattened = feature_maps.view(batch_size, -1)
        return self.classifier(flattened)

    def _initialize_weights(self):
        """Re-initialise every Conv2d, BatchNorm2d and Linear sub-module in place."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                # Kaiming-normal weights (fan-out, ReLU gain); zero bias when present.
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
                continue
            if isinstance(layer, nn.BatchNorm2d):
                # Identity scale and zero shift.
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)
                continue
            if isinstance(layer, nn.Linear):
                # Small Gaussian weights (mean 0, std 0.01); zero bias.
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.zeros_(layer.bias)

In [None]:
# Build the network from the custom layer spec with a 10-way output,
# apply the custom weight initialisation, and move it to the selected device.
vgg16 = VGG(
    features=vgg.make_layers(cfg),
    num_classes=10,
    init_weights=True,
).to(device)

In [None]:
# Smoke test: push one random 3x32x32 input through the network to confirm
# the layer shapes line up. torch.randn gives well-defined values, whereas
# torch.Tensor(1, 3, 32, 32) only allocates uninitialised memory, which may
# contain NaN/Inf and differ between runs.
a = torch.randn(1, 3, 32, 32, device=device)
# No gradients are needed for a shape check.
with torch.no_grad():
    out = vgg16(a)
print(out)

In [None]:
# Cross-entropy loss for classification, placed on the same device as the model.
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

# SGD with momentum over all model parameters.
optimizer = optim.SGD(vgg16.parameters(), momentum=0.9, lr = 0.005)

# Every 5 scheduler steps, the learning rate is multiplied by gamma.
lr_sche = optim.lr_scheduler.StepLR(optimizer, gamma=0.9, step_size=5)

In [None]:
import copy
import os
from tqdm import tqdm
import time
epochs = 50

# Checkpoint location. Create the directory up front so torch.save cannot fail
# on a fresh runtime where ./weight does not exist yet.
weights_path = './weight/weights.pth'
os.makedirs(os.path.dirname(weights_path), exist_ok=True)

# Track the lowest summed training loss seen so far and the weights that produced it.
best_loss = float('inf')
best_model_wts = copy.deepcopy(vgg16.state_dict())

for epoch in range(epochs):
    # Make sure Dropout is active even if an earlier cell switched the model to eval mode.
    vgg16.train()
    running_loss = 0.0

    # Wrap the loader itself (not enumerate) so tqdm knows the number of batches.
    for inputs, labels in tqdm(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = vgg16(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Sum of per-batch loss values over the epoch.
        running_loss += loss.item()

    # Step the scheduler once per epoch, AFTER the optimizer updates. Calling it
    # first (as before) skips the first value of the learning-rate schedule.
    lr_sche.step()

    # Checkpoint whenever this epoch's summed training loss is the best so far.
    # NOTE: selection uses training loss, not a validation metric.
    if running_loss < best_loss:
        best_loss = running_loss
        best_model_wts = copy.deepcopy(vgg16.state_dict())
        torch.save(best_model_wts, weights_path)
        

In [None]:
# Rebuild the same architecture and restore the best checkpoint saved during training.
new_model = VGG(vgg.make_layers(cfg),10,True)
new_model.to(device)
# Inference mode: without this Dropout stays active and predictions are randomised.
new_model.eval()
# map_location lets a checkpoint saved on the GPU load on a CPU-only runtime (and vice versa).
new_model.load_state_dict(torch.load('./weight/weights.pth', map_location=device))

In [None]:
# Take the first validation batch and compare true labels with predictions.
val_iter = iter(val_loader)
# Built-in next(): DataLoader iterators in current PyTorch have no .next() method.
images, labels = next(val_iter)

# Evaluate with Dropout disabled and without building an autograd graph.
new_model.eval()
with torch.no_grad():
    outputs = new_model(images.to(device))

# Index of the highest-scoring class for each sample.
outputs = torch.argmax(outputs, dim=1)
# Element-wise correctness mask for the batch (wrapped in a list, as before).
list_answers = [ outputs == labels.to(device).view_as(outputs) ]
# Print "<true label> <predicted label>" for each sample.
for i in range(outputs.size(0)):
    print( classes[labels[i]], classes[outputs[i]])
