In [1]:
import argparse
import os

import numpy as np
import matplotlib
# matplotlib.use('Agg')
import matplotlib.pyplot as plt

import torch
import torch.autograd as autograd
import torch.optim as optim
from tqdm import tqdm

import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torchvision.models as models

from torch.utils.tensorboard import SummaryWriter

from dataset.dataset import AVADataset

from model.model import *

from dotenv import load_dotenv
load_dotenv()

True

In [2]:
IMAGE_PATH = os.getenv("AVA_IMAGE_PATH")
TRAIN_CSV = os.getenv("AVA_TRAIN_CSV")
TEST_CSV  = os.getenv("AVA_TEST_CSV")
VAL_CSV   = os.getenv("AVA_VAL_CSV")

BATCH_SIZE = 32
CONV_LEARNING_RATE = 5e-4
DENSE_LEARNING_RATE = 5e-3

DECAY = False

EPOCHS = 100
EARLY_STOPPING = 10

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
writer = SummaryWriter()

train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

In [4]:
trainset = AVADataset(TRAIN_CSV, IMAGE_PATH, transform=train_transform)
valset   = AVADataset(VAL_CSV, IMAGE_PATH, transform=val_transform)

train_loader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=os.cpu_count())
val_loader   = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False, num_workers=os.cpu_count())

In [5]:
def train_model(model, optimizer, epoch, train_losses, avg_loss):
    # Train
    batch_losses = []
    pbar = tqdm(train_loader)
    for i, data in enumerate(pbar):
        images = data['image'].to(device)
        labels = data['annotations'].to(device).float()
        outputs = model(images)
        outputs = outputs.view(-1, 10, 1)
        
        optimizer.zero_grad()
        
        loss = emd_loss(labels, outputs)
        batch_losses.append(loss.item())
        
        loss.backward()
        optimizer.step()
        
        writer.add_scalar('batch train loss', loss.data[0], i + epoch * (len(trainset) // BATCH_SIZE + 1))
        
        # tqdm description
        pbar.set_description('Epoch: [{0}][{1}/{2}]\t' 'Batch Loss {loss:.4f}\t'.format(epoch, i, len(trainset) // BATCH_SIZE + 1, loss=loss.data[0]))
    
    avg_loss = np.sum(batch_losses) / (len(trainset) // BATCH_SIZE + 1)
    train_losses.append(avg_loss)
    print('Epoch: [{0}]\t' 'Mean Training EMD Loss {loss:.4f}\t'.format(epoch + 1, loss=avg_loss))

In [6]:
def validation(model, val_losses, avg_val_loss, avg_loss, epoch):
    # Validation
    batch_val_losses = []
    pbar = tqdm(val_loader)
    for data in pbar:
        images = data['image'].to(device)
        labels = data['annotations'].to(device).float()
        with torch.no_grad():
            outputs = model(images)
        outputs = outputs.view(-1, 10, 1)
        val_loss = emd_loss(labels, outputs)
        batch_val_losses.append(val_loss.item())
        
        # tqdm description
        pbar.set_description('Epoch: [{0}] Validation Loss {loss:.4f}\t'.format(epoch, loss=val_loss.data[0]))
        
    avg_val_loss = np.sum(batch_val_losses) / (len(valset) // BATCH_SIZE + 1)
    val_losses.append(avg_val_loss)
    print('Epoch: [{0}]\t' 'Mean Validation EMD Loss {loss:.4f}\t'.format(epoch + 1, loss=avg_val_loss))
    writer.add_scalars('epoch loss', {'train': avg_loss, 'val': avg_val_loss}, epoch + 1)

In [7]:
base_model = models.vgg16(weights = models.VGG16_Weights.DEFAULT)
model = NIMA(base_model).to(device)
optimizer = optim.SGD([
    {'params': model.features.parameters(), 'lr': CONV_LEARNING_RATE},
    {'params': model.classifier.parameters(), 'lr': DENSE_LEARNING_RATE}],
    momentum=0.9)

param_num = 0
for param in model.parameters():
    if param.requires_grad:
        param_num += param.numel()
print('Total trainable parameters: %d' % param_num)

Total trainable parameters: 14965578


In [8]:
count = 0
init_val_loss = float('inf')
train_losses = []
val_losses = []

for epoch in range(EPOCHS):
    
    avg_loss = 0
    avg_val_loss = 0
    
    # train
    train_model(model, optimizer, epoch, train_losses, avg_loss)
    
    # validate
    validation(model, val_losses, avg_val_loss, avg_loss, epoch)
    
    if avg_val_loss < init_val_loss:
        init_val_loss = avg_val_loss
        
        # save model
        if not os.path.exists('checkpoints'):
            os.makedirs('checkpoints')
        torch.save(model.state_dict(), os.path.join('checkpoints', f'epoch_{epoch + 1}_val_loss_{avg_val_loss}.pth'))
        print('Model saved')
        count = 0
    elif avg_val_loss >= init_val_loss:
        count += 1
        if count == EARLY_STOPPING:
            print('Early stopping')
            break
    

  input = module(input)
Epoch: [0][7186/7187]	Batch Loss 0.0690	: 100%|██████████| 7187/7187 [22:06<00:00,  5.42it/s]


Epoch: [1]	Mean Training EMD Loss 0.0862	


Epoch: [0] Validation Loss 0.1070	: 100%|██████████| 397/397 [00:24<00:00, 16.12it/s]


Epoch: [1]	Mean Validation EMD Loss 0.0801	
Model saved


Epoch: [1][7186/7187]	Batch Loss 0.0622	: 100%|██████████| 7187/7187 [22:08<00:00,  5.41it/s]


Epoch: [2]	Mean Training EMD Loss 0.0783	


Epoch: [1] Validation Loss 0.1106	: 100%|██████████| 397/397 [00:24<00:00, 16.14it/s]


Epoch: [2]	Mean Validation EMD Loss 0.0774	


Epoch: [2][7186/7187]	Batch Loss 0.0659	: 100%|██████████| 7187/7187 [22:06<00:00,  5.42it/s]


Epoch: [3]	Mean Training EMD Loss 0.0765	


Epoch: [2] Validation Loss 0.0992	: 100%|██████████| 397/397 [00:24<00:00, 16.08it/s]


Epoch: [3]	Mean Validation EMD Loss 0.0759	


Epoch: [3][7186/7187]	Batch Loss 0.0757	: 100%|██████████| 7187/7187 [22:15<00:00,  5.38it/s]


Epoch: [4]	Mean Training EMD Loss 0.0755	


Epoch: [3] Validation Loss 0.0989	: 100%|██████████| 397/397 [00:24<00:00, 16.01it/s]


Epoch: [4]	Mean Validation EMD Loss 0.0764	


Epoch: [4][7186/7187]	Batch Loss 0.0605	: 100%|██████████| 7187/7187 [22:10<00:00,  5.40it/s]


Epoch: [5]	Mean Training EMD Loss 0.0747	


Epoch: [4] Validation Loss 0.1029	: 100%|██████████| 397/397 [00:24<00:00, 16.04it/s]


Epoch: [5]	Mean Validation EMD Loss 0.0744	


Epoch: [5][36/7187]	Batch Loss 0.0795	:   1%|          | 37/7187 [00:08<27:28,  4.34it/s]


KeyboardInterrupt: 