In [None]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

import torchvision
import torchvision.transforms as transforms
from torchvision import datasets, models

from imutils import paths
from pathlib import Path
import os, sys
import time
import copy

import pandas as pd
import matplotlib.pylab as plt
import numpy as np

# Local modules

from cub_tools.train import train_model
from cub_tools.visualize import imshow, visualize_model

In [None]:
# Script runtime options
root_dir = '../data'  # data root: holds images.txt, train_test_split.txt and the images/ folder
data_dir = os.path.join(root_dir, 'images')

# Seed numpy and torch up front: augmentation, loader shuffling and the new classifier
# layer's initialisation are all random. This makes re-runs more repeatable; bit-exact GPU
# determinism would additionally need cuDNN settings (not configured here).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Index the image files under data_dir with torchvision's ImageFolder dataset
# (a dataset object, not a generator).
# NOTE(review): the loaders built later read data_dir/train and data_dir/test. If those are the
# only sub-directories here, this dataset's classes would be the split names rather than
# bird species — confirm the intended directory layout.
imageFolder = datasets.ImageFolder(root=data_dir)

In [None]:
# Display the ImageFolder dataset object (its repr) as this cell's output.
imageFolder

In [None]:
# Per-image metadata files in root_dir (space-delimited, no header row):
#   images.txt            -> 'Img ID', 'file path'
#   train_test_split.txt  -> 'Img ID', 'is training image?'
image_list = pd.read_csv(os.path.join(root_dir, 'images.txt'), header=None, delimiter=' ',
                         names=['Img ID', 'file path'])
split_flags = pd.read_csv(os.path.join(root_dir, 'train_test_split.txt'), header=None, delimiter=' ',
                          names=['Img ID', 'is training image?'])
# Join on the image id instead of assuming both files list images in the same row order;
# validate= raises if either file repeats an id.
image_fnames = image_list.merge(split_flags, on='Img ID', how='left', validate='one_to_one')
assert image_fnames['is training image?'].notna().all(), 'some images have no train/test flag'

In [None]:
# Per-split pre-processing. Training images get random resized-crop + horizontal-flip
# augmentation. Test images are only resized and centre-cropped to the same 224x224 size.
# Both splits are normalised with the ImageNet channel statistics, matching the
# ImageNet-pretrained backbone loaded further down.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
data_transforms = {'train': train_transform, 'test': test_transform}

In [None]:
# For each split, build an ImageFolder dataset over data_dir/<split> using that split's
# transforms, plus a DataLoader that batches it.
# NOTE(review): the test loader is shuffled too. That does not affect aggregate metrics, but any
# test batches drawn from it will differ between runs — confirm this is intended.
SPLITS = ('train', 'test')

image_datasets = {
    split: datasets.ImageFolder(root=os.path.join(data_dir, split), transform=data_transforms[split])
    for split in SPLITS
}
dataloaders = {
    split: torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
    for split, dataset in image_datasets.items()
}
dataset_sizes = {split: len(dataset) for split, dataset in image_datasets.items()}
class_names = image_datasets['train'].classes

In [None]:
# Display the training dataset's class names (the same list stored in class_names above).
image_datasets['train'].classes

In [None]:
# Pick the compute device: the first CUDA GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('Device::', device)

In [None]:
# Visual sanity check of the training augmentations: take one (shuffled) training batch,
# tile its images into a single grid and title it with each image's class name.
# The iterator is used inline so its DataLoader workers are released right away.
inputs, classes = next(iter(dataloaders['train']))
out = torchvision.utils.make_grid(inputs)
batch_titles = [class_names[label] for label in classes.tolist()]
imshow(out, title=batch_titles)

# Load an ImageNet-pretrained model and fine-tune it on the bird images

In [None]:
# Setup the model and optimiser

# Start from ResNet-152 with ImageNet-pretrained weights. Replace its final fully-connected
# layer with a freshly initialised one that has one output per class in the training set.
model_ft = models.resnet152(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Optimiser / learning-rate schedule hyper-parameters.
LEARNING_RATE = 0.001
MOMENTUM = 0.9
LR_STEP_SIZE = 7   # scheduler steps between learning-rate decays
LR_GAMMA = 0.1     # multiplicative decay factor

# Every parameter (pretrained backbone and new head) goes to the optimiser. This is full
# fine-tuning, not training of the new classifier layer alone.
optimizer_ft = optim.SGD(model_ft.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

# Multiply the LR by LR_GAMMA every LR_STEP_SIZE scheduler steps. That means every 7 epochs
# only if train_model calls scheduler.step() once per epoch (TODO confirm in cub_tools.train).
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)

In [None]:
# Fine-tune the network with cub_tools.train.train_model. The model it returns replaces
# model_ft and is what the following cells save and visualise.
# NOTE: long-running — NUM_EPOCHS epochs of ResNet-152.
NUM_EPOCHS = 40
model_ft = train_model(
    model=model_ft,
    criterion=criterion,
    optimizer=optimizer_ft,
    scheduler=exp_lr_scheduler,
    device=device,
    dataloaders=dataloaders,
    dataset_sizes=dataset_sizes,
    num_epochs=NUM_EPOCHS,
)

In [None]:
# Persist the fine-tuned network in two forms. save_model_full / save_model_dict were never
# imported or defined in this notebook, so torch.save is called directly. The directory is
# created first so a fresh checkout does not fail on a missing path.
MODEL_DIR = os.path.join('models', 'classification')
os.makedirs(MODEL_DIR, exist_ok=True)
# 1) The whole pickled module: convenient, but loading it requires the same code layout.
torch.save(model_ft, os.path.join(MODEL_DIR, 'caltech_birds_resnet152_full.pth'))
# 2) The state_dict only: the portable form, and the file reloaded later in this notebook.
torch.save(model_ft.state_dict(), os.path.join(MODEL_DIR, 'caltech_birds_resnet152_dict.pth'))

In [None]:
# Show the freshly fine-tuned model's predictions via cub_tools.visualize.visualize_model.
visualize_model(model=model_ft, dataloaders=dataloaders, class_names=class_names, device=device)

# Load the Fine-Tuned Model from File

In [None]:
# Rebuild the architecture used for training: ResNet-152 with a len(class_names)-way final layer.
# The ImageNet weights are not requested, because the saved state_dict loaded in the next cell
# overwrites every parameter. Fetching them would only cost a download.
model_ft = models.resnet152(pretrained=False)
model_ft.fc = nn.Linear(model_ft.fc.in_features, len(class_names))

In [None]:
# Load the fine-tuned weights saved earlier. map_location='cpu' lets a checkpoint written on a
# GPU be opened on a CPU-only machine. The model is then moved to `device`, because
# visualize_model is called with that same device (presumably it moves batches there too).
state_dict = torch.load('models/classification/caltech_birds_resnet152_dict.pth', map_location='cpu')
model_ft.load_state_dict(state_dict)
model_ft = model_ft.to(device)
# Put dropout/batch-norm layers into inference mode. As the last expression, this also displays the model.
model_ft.eval()

In [None]:
# Show predictions of the model reloaded from disk, for comparison with the post-training view.
visualize_model(model=model_ft, dataloaders=dataloaders, class_names=class_names, device=device)