In [1]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import transforms
from torchvision.models import vgg11
from torch.utils.data import Dataset, DataLoader,random_split
import numpy as np

from dataset import TrainDataset, TestDataset, img_transform, TrainDatasetAgeAugmentation
from avgMeter import AverageMeter

import copy
import random
import os 

torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
random.seed(0)



In [2]:
transform = img_transform()


In [3]:
train_root = '/opt/ml/input/data/train/images'
test_root = '/opt/ml/input/data/eval'


In [4]:
data = TrainDataset(train_root, input_size = 224, transform = transform)

In [5]:
train, val = random_split(data, [int(len(data)*0.8), len(data) - int(len(data)*0.8)])

In [6]:
batch_size = 128
lr = 0.001
num_epochs = 60
model_name = 'vgg11_batchsize' + str(batch_size) + '_lr' + str(lr).split('.')[1] + '_epoch' + str(num_epochs) + '_CenterCrop' + '_scheduler' + '_nofreeze'
log_dir = '/opt/ml/code/log/' + model_name + '.txt' 
save_dir = '/opt/ml/code/trained_models/' + model_name + '.pt'

In [7]:
print(model_name)

vgg11_batchsize128_lr001_epoch60_CenterCrop_scheduler_nofreeze


In [8]:
print(save_dir)

/opt/ml/code/trained_models/vgg11_batchsize128_lr001_epoch60_CenterCrop_scheduler_nofreeze.pt


In [9]:
train_loader = DataLoader(train, batch_size=batch_size, num_workers = 4, pin_memory=True, shuffle=True)
valid_loader = DataLoader(val, batch_size=batch_size, num_workers = 4, pin_memory=True, shuffle=False)

In [10]:
model = vgg11(True)
model.classifier[6] = nn.Linear(in_features=4096, out_features=18, bias=True)
model.cuda()

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 

In [11]:
# for param in model.features.parameters():
#     param.requires_grad = False

In [12]:
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=lr)
scheduler = ReduceLROnPlateau(optimizer, mode = 'min', factor = 0.1, patience = 5)

In [None]:
best_val_acc = 0.0
best_model_wts = copy.deepcopy(model.state_dict())

with open(log_dir, 'w') as log:
    for epoch in range(num_epochs):
        for iter, (img, label) in enumerate(train_loader):
            optimizer.zero_grad()

            img, label = img.float().cuda(), label.cuda()

            pred_logit = model(img)

            loss = criterion(pred_logit, label)

            loss.backward()
            optimizer.step()

            pred_label = pred_logit.argmax(-1)
            acc = (pred_label == label).sum().float() / img.size(0)

            train_loss = loss.item()
            train_acc = acc
            
            
        valid_loss, valid_acc = AverageMeter(), AverageMeter()
                
        for img, label in valid_loader:
            img, label = img.float().cuda(), label.cuda()

            with torch.no_grad():
                pred_logit = model(img)


            loss = criterion(pred_logit, label)

            pred_label = pred_logit.argmax(-1)
            acc = (pred_label == label).sum().float() / img.size(0)

            valid_loss.update(loss.item(), len(img))
            valid_acc.update(acc, len(img))
        
        valid_loss = valid_loss.avg
        valid_acc = valid_acc.avg
        
        print("epoch [%3d/%3d] | Train Loss %.4f | Train Acc %.4f | Valid Loss %.4f | Valid Acc %.4f" %
            (epoch+1, num_epochs, train_loss, train_acc, valid_loss, valid_acc))
        
        if valid_acc > best_val_acc:
            best_val_loss = valid_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            
        scheduler.step(valid_loss)
        
        # Train Log Writing
#         log.write('%d,%.4f,%.4f,%.4f,%.4f\n'%(iter, train_loss, train_acc, valid_loss, valid_acc))
        log.write("epoch [%3d/%3d] | Train Loss %.4f | Train Acc %.4f | Valid Loss %.4f | Valid Acc %.4f \n" %
            (epoch+1, num_epochs, train_loss, train_acc, valid_loss, valid_acc))        

epoch [  1/ 60] | Train Loss 0.6847 | Train Acc 0.8125 | Valid Loss 1.0994 | Valid Acc 0.6312
epoch [  2/ 60] | Train Loss 0.3754 | Train Acc 0.8750 | Valid Loss 0.6133 | Valid Acc 0.8003
epoch [  3/ 60] | Train Loss 0.2901 | Train Acc 0.9375 | Valid Loss 0.3997 | Valid Acc 0.8698


In [None]:
model.__class__.__name__

In [None]:
model.load_state_dict(best_model_wts)
torch.save(model, save_dir)