In [None]:
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torch.autograd import Variable

import torchvision
from torchvision import datasets, models, transforms

import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook

In [None]:
# Per-channel normalization statistics (the standard ImageNet values used by
# torchvision's pretrained models). Defined once so the train and val
# pipelines cannot drift apart.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = (224, 224)

data_transforms = {
    # Training: resize, then light geometric augmentation
    # (shear up to 10 degrees, 0.8-1.2x scale, random horizontal flip).
    'train': transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]),
    # Validation: deterministic preprocessing only -- no augmentation.
    'val': transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]),
}

In [None]:
# Root of the dataset. Override with the VADA_PAV_DATA environment variable
# instead of editing an absolute, machine-specific path.
DATA_ROOT = os.environ.get('VADA_PAV_DATA',
                           r'C:\Users\RS_Vulcan\Documents\vada_pav_data')

# One ImageFolder per split; each class is a sub-folder of the split folder.
# Note the 'val' split is read from the 'valid' directory on disk.
image_datasets = {
    'train': datasets.ImageFolder(os.path.join(DATA_ROOT, 'train'),
                                  data_transforms['train']),
    'val': datasets.ImageFolder(os.path.join(DATA_ROOT, 'valid'),
                                data_transforms['val']),
}

In [None]:
BATCH_SIZE = 16
NUM_WORKERS = 4

# Shuffle only the training split. Validation metrics are sums over the whole
# split, so order does not matter there, and a fixed order keeps evaluation
# and the prediction visualisation reproducible between runs.
dataloaders = {
    phase: DataLoader(image_datasets[phase],
                      batch_size=BATCH_SIZE,
                      shuffle=(phase == 'train'),
                      num_workers=NUM_WORKERS)
    for phase in ['train', 'val']
}

# Class names come from the training folder's sub-directory names.
class_names = image_datasets['train'].classes
dataset_sizes = {phase: len(image_datasets[phase]) for phase in ['train', 'val']}

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transfer learning: pretrained ResNet-18 used as a frozen feature extractor
# with a freshly initialised linear classification head.
# NOTE: newer torchvision releases deprecate `pretrained=True` in favour of
# `weights=models.ResNet18_Weights.DEFAULT`; kept here for compatibility.
model_ft = models.resnet18(pretrained=True)

# Freeze the backbone. The optimizer below only receives the head's
# parameters, so backbone gradients would be computed but never used (and,
# since optimizer.zero_grad() only touches optimized params, never cleared).
# BatchNorm running statistics still update in train mode; requires_grad
# does not affect them.
for param in model_ft.parameters():
    param.requires_grad = False

# Replace the head with one output per class folder. A newly created layer
# has requires_grad=True, so it is the only trainable part of the model.
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_ft.fc.parameters())

In [None]:
def imshow(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalized image tensor on the current matplotlib axes.

    Undoes ``transforms.Normalize`` (``x * std + mean`` per channel), clips
    the result to [0, 1] and draws it with ``plt.imshow``.

    Args:
        inp: image tensor laid out as (C, H, W); it is transposed to
            (H, W, C) for matplotlib. May live on any device and may require
            grad -- it is detached and copied to the CPU first.
        title: optional title for the current axes.
        mean: per-channel mean that was used to normalize ``inp``.
        std: per-channel std that was used to normalize ``inp``.
    """
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))
    # Broadcasts over the last (channel) axis.
    img = np.asarray(std) * img + np.asarray(mean)
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

In [None]:
def train_model(model, criterion, optimizer, num_epochs=3, loaders=None, device=None):
    """Train ``model`` and print per-epoch loss/accuracy for train and val.

    Each epoch runs a 'train' phase (forward, backward, optimizer step) and
    then a 'val' phase (forward only, with autograd disabled so no graph is
    built).

    Args:
        model: classifier module; its outputs are assumed to be per-class
            scores of shape (batch, num_classes), since the prediction is the
            arg-max over dim 1.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over the parameters that should be trained.
        num_epochs: number of passes over the training split.
        loaders: dict with 'train' and 'val' iterables of (inputs, labels)
            whose ``.dataset`` has a length; defaults to the notebook's global
            ``dataloaders``.
        device: device batches are moved to; defaults to the device of the
            model's first parameter.

    Returns:
        The same ``model`` object, trained in place. It is left in eval mode
        because the last phase run is 'val'.
    """
    if loaders is None:
        loaders = dataloaders
    if device is None:
        device = next(model.parameters()).device

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train(False) is equivalent to eval()
            loader = loaders[phase]

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Only record the autograd graph when we will backprop through
                # it; in the val phase this saves memory and compute.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                _, preds = torch.max(outputs, 1)
                # Loss is a per-batch mean; weight by batch size so the epoch
                # value is a per-sample mean even with a short last batch.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            num_samples = len(loader.dataset)
            epoch_loss = running_loss / num_samples
            epoch_acc = running_corrects / num_samples

            print('{} loss: {:.4f}, acc: {:.4f}'.format(phase,
                                                        epoch_loss,
                                                        epoch_acc))
    return model
 
# Fine-tune for 50 epochs. train_model updates model_ft in place and returns
# that same object, so model_trained and model_ft are aliases.
model_trained = train_model(model=model_ft,
                            criterion=criterion,
                            optimizer=optimizer,
                            num_epochs=50)

In [None]:
# Save the trained weights (state_dict only, not the module object).
# torch.save writes PyTorch's own serialization format; the '.h5' extension
# does not mean HDF5 -- the name is kept so existing files keep loading.
WEIGHTS_PATH = os.path.join(
    os.environ.get('VADA_PAV_DATA', r'C:\Users\RS_Vulcan\Documents\vada_pav_data'),
    'models', 'pytorch_res18_weights_2.h5')
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(WEIGHTS_PATH), exist_ok=True)
torch.save(model_trained.state_dict(), WEIGHTS_PATH)

In [None]:
# model_ft was already built (and moved to `device`) in the model-setup cell,
# so only its weights need restoring here.
WEIGHTS_PATH = os.path.join(
    os.environ.get('VADA_PAV_DATA', r'C:\Users\RS_Vulcan\Documents\vada_pav_data'),
    'models', 'pytorch_res18_weights_2.h5')
# map_location lets weights saved from a GPU session load on a CPU-only
# machine (and vice versa).
model_ft.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))

In [None]:
def visualize_model(model, num_images=6, ncols=5):
    """Plot up to ``num_images`` validation images with the model's predictions.

    Images are drawn from ``dataloaders['val']`` into a single grid figure
    with ``ncols`` columns, each titled with its predicted class name. The
    model is run in eval mode without gradients, and its original
    train/eval mode is restored afterwards, even if an error occurs.

    Args:
        model: classifier whose outputs are arg-maxed over dim 1.
        num_images: maximum number of images to show; nothing is drawn if <= 0.
        ncols: number of grid columns (capped at ``num_images``).
    """
    if num_images <= 0:
        return

    was_training = model.training
    model.eval()

    ncols = max(1, min(ncols, num_images))
    nrows = (num_images + ncols - 1) // ncols  # ceiling division
    plt.figure(figsize=(4 * ncols, 4 * nrows))

    images_so_far = 0
    try:
        with torch.no_grad():
            for inputs, _ in dataloaders['val']:
                inputs = inputs.to(device)
                _, preds = torch.max(model(inputs), 1)

                for j in range(inputs.size(0)):
                    if images_so_far == num_images:
                        break
                    images_so_far += 1
                    # Use the running count (not the in-batch index) so images
                    # from later batches do not overwrite earlier subplots.
                    ax = plt.subplot(nrows, ncols, images_so_far)
                    ax.axis('off')
                    ax.set_title('predicted: {}'.format(class_names[preds[j].item()]))
                    imshow(inputs[j].cpu())

                if images_so_far == num_images:
                    break
    finally:
        model.train(mode=was_training)

    plt.show()

In [None]:
# Show predictions on a few validation images. model_trained and model_ft are
# the same object (train_model trains in place and returns its argument), so
# either name works whether or not training ran in this session.
visualize_model(model_ft, num_images=6)