In [1]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms, utils
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import time
import os
import copy

from PIL import Image
import sys
import pandas as pd
from skimage import io, transform
from sklearn.metrics import confusion_matrix

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

In [2]:
class HabDataset(Dataset):
    """HAB images dataset backed by a CSV annotation file.

    Column 0 of the CSV holds an image path and column 1 holds that image's
    class label. Columns are read by position, so header names do not matter.

    Attributes:
        images_data_frame (pandas.DataFrame): the parsed annotation CSV.
        class_names (list): sorted distinct labels found in this CSV. A
            sample's integer target is the position of its label in this
            list. The lookup happens on every access, so assigning another
            split's ``class_names`` makes this dataset use that encoding.
        class_weights (torch.FloatTensor): inverse relative frequency of each
            label of this CSV, in sorted-label order (computed once here).
        transform (callable or None): applied to every loaded image.
    """

    def __init__(self, csv_file, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations (any
                source ``pandas.read_csv`` accepts).
            transform (callable, optional): Transform applied to each loaded
                PIL image, e.g. a ``torchvision.transforms.Compose``.
        """
        self.images_data_frame = pd.read_csv(csv_file)
        labels = self.images_data_frame.iloc[:, 1]
        self.class_names = sorted(labels.unique())
        # sort_index() orders the counts like the sorted class_names, so
        # weight i belongs to class i.
        self.class_weights = torch.FloatTensor(1 / labels.value_counts(normalize=True).sort_index())
        self.transform = transform

    def __len__(self):
        return len(self.images_data_frame)

    def _load_image(self, img_path):
        """Open ``img_path`` and return it as a fully loaded RGB PIL image.

        Mirrors torchvision's default image loader. Reading through a ``with``
        block releases the file handle right away, and ``convert('RGB')``
        (which also forces the pixel data to load) gives grayscale, palette
        or RGBA files the three channels a 3-channel ``Normalize`` expects.
        """
        with open(img_path, 'rb') as handle:
            image = Image.open(handle)
            return image.convert('RGB')

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        ``label`` is the int index of the row's label in ``class_names``.
        ``image`` is the RGB PIL image, passed through ``transform`` when one
        is set.

        Raises:
            ValueError: if the row's label is not in ``class_names``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.images_data_frame.iloc[idx, 0]
        label_name = self.images_data_frame.iloc[idx, 1]
        try:
            label = self.class_names.index(label_name)
        except ValueError:
            raise ValueError('Label {!r} of image {!r} is not in class_names'.format(
                label_name, img_path))

        image = self._load_image(img_path)
        if self.transform:
            image = self.transform(image)

        return image, label

In [3]:
# Standard ImageNet per-channel normalisation statistics (RGB order).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _resize_and_normalize_pipeline():
    """Build a fresh pipeline: resize to 224x224, to tensor, ImageNet-normalise."""
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return transforms.Compose(steps)


# Both phases use the same deterministic pipeline (no augmentation); each
# phase still gets its own Compose instance.
data_transforms = {phase: _resize_and_normalize_pipeline() for phase in ('train', 'val')}

In [4]:
# Directory holding train.csv / val.csv; set HAB_DATA_DIR to override it
# without editing the notebook.
data_dir = os.environ.get("HAB_DATA_DIR", "/data6/SuryaKrishnan/raw_data")

image_datasets = {x: HabDataset(os.path.join(data_dir, x + '.csv'), data_transforms[x])
                  for x in ['train', 'val']}

# Each HabDataset builds its label encoding from its own CSV. If the val split
# lacked a class, every later val target would shift by one and be scored
# against the wrong model output. Make val use the training encoding instead.
# An unseen val label then raises a ValueError instead of being mislabelled.
# The val split's class_weights are still computed from val.csv and are unused.
image_datasets['val'].class_names = image_datasets['train'].class_names

# Only the training split needs shuffling; evaluation order is kept fixed.
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32,
                                              shuffle=(x == 'train'), num_workers=4)
               for x in ['train', 'val']}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].class_names
print(class_names)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

['Acantharea', 'Aggregate', 'Akashiwo', 'BadImageSegmentation', 'Bubble', 'Ceratium falcatiforme fusus pair', 'Ceratium falcatiforme fusus single', 'Ceratium furca pair', 'Ceratium furca side', 'Ceratium furca single', 'Ceratium other pair', 'Ceratium other single', 'Chaetoceros socialis', 'Chattonella', 'Ciliates', 'Cochlodinium Alexandrium Gonyaulax Gymnodinium chain', 'Cochlodinium Alexandrium Gonyaulax Gymnodinium pair', 'Curved diatom chain chaetoceros', 'Curved diatom chain guinardia', 'Dinophysis pair', 'Dinophysis single', 'Eucampia chain', 'Eucampia pair', 'Gymnodinium', 'Gyrodinium', 'Kelp Fragment', 'Licmophora', 'Lingulodinium', 'Marine Lashes', 'Nauplii', 'Phaeocystis', 'Polykrikos', 'Prorocentrum', 'Prorocentrum gracile', 'Protoperidinium', 'Protoperidinium feeding', 'Protoperidinium flipped', 'Pseudo nitzschia chain', 'Rhizosolenia or Proboscia', 'Sand', 'Straight diatom chains chaetoceros', 'Straight diatom chains hemiaulus', 'Straight diatom chains leptocylindrus epiph

In [5]:
# Function definition to train our model


def mean_class_accuracy(labels, preds):
    """Mean of the per-class recalls over the classes present in ``labels``.

    Same as averaging ``confusion_matrix(labels, preds)``'s diagonal divided
    by its row sums, except for one case. A class that occurs only among the
    predictions has zero support, which would give a 0/0 = NaN row and make
    the whole mean NaN; such classes are skipped.

    Args:
        labels: 1-D array-like of non-negative integer ground-truth classes.
        preds: 1-D array-like of non-negative integer predicted classes, same
            length as ``labels``.

    Returns:
        float: the mean per-class recall, or ``nan`` if ``labels`` is empty.
    """
    labels = np.asarray(labels, dtype=np.int64)
    preds = np.asarray(preds, dtype=np.int64)
    if labels.size == 0:
        return float('nan')
    n_classes = int(max(labels.max(), preds.max())) + 1
    support = np.bincount(labels, minlength=n_classes)
    hits = np.bincount(labels[labels == preds], minlength=n_classes)
    present = support > 0
    return float((hits[present] / support[present]).mean())


def train_model(model, criterion, optimizer, scheduler, num_epochs=5,
                data_loaders=None, data_sizes=None, target_device=None):
    """Train ``model`` and return it loaded with its best validation weights.

    Each epoch runs a 'train' phase (forward, backward, ``optimizer.step()``,
    then one ``scheduler.step()``) and a 'val' phase (forward only). After a
    val phase the weights are snapshotted if its loss is the lowest so far.
    That snapshot is loaded back into ``model`` before returning.

    Args:
        model: the ``nn.Module`` to train; it is updated in place.
        criterion: loss, called as ``criterion(outputs, labels)``.
        optimizer: optimizer over the trainable parameters.
        scheduler: learning-rate scheduler, stepped once per training epoch.
        num_epochs: number of train+val epochs.
        data_loaders: dict with 'train' and 'val' iterables of
            ``(inputs, labels)`` batches. Defaults to the notebook-level
            ``dataloaders``.
        data_sizes: dict with the sample count of each phase, used to average
            loss and accuracy. Defaults to the notebook-level ``dataset_sizes``.
        target_device: device batches are moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        (model, stats_dict), where stats_dict holds the best epoch's val
        "Loss", "Accuracy" and "Mean Class Accuracy".
    """
    # Fall back to the notebook globals so existing calls keep working.
    loaders = dataloaders if data_loaders is None else data_loaders
    sizes = dataset_sizes if data_sizes is None else data_sizes
    run_device = device if target_device is None else target_device

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')
    best_acc = 0.0
    best_mean_class_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            # Collect per-batch arrays and concatenate once per phase. Calling
            # np.append per batch copies everything so far (quadratic cost).
            batch_preds = []
            batch_labels = []

            for inputs, labels in loaders[phase]:
                inputs = inputs.to(run_device)
                labels = labels.to(run_device)

                optimizer.zero_grad()

                # Track history only in the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                batch_preds.append(preds.cpu().numpy())
                batch_labels.append(labels.cpu().numpy())
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels)

            # Model performance statistics for this phase
            epoch_loss = running_loss / sizes[phase]
            epoch_acc = running_corrects.double() / sizes[phase]
            epoch_mean_class_acc = mean_class_accuracy(np.concatenate(batch_labels),
                                                       np.concatenate(batch_preds))

            # The scheduler advances once per training epoch.
            if phase == 'train':
                scheduler.step()

            print('{} Loss: {:.4f} Acc: {:.4f} Mean Class Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc, epoch_mean_class_acc))

            # Keep the weights of the epoch with the lowest validation loss.
            if phase == 'val' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_acc = epoch_acc
                best_mean_class_acc = epoch_mean_class_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Loss: {:4f}'.format(best_loss))
    print('Best val Acc: {:4f}'.format(best_acc))
    print('Best val Mean Class Acc: {:4f}'.format(best_mean_class_acc))

    stats_dict = {"Loss": best_loss, "Accuracy": best_acc, "Mean Class Accuracy": best_mean_class_acc}

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, stats_dict

In [6]:
# ImageNet-pretrained ResNet-18 used as a frozen feature extractor.
model_conv = torchvision.models.resnet18(pretrained=True)

# freeze weights: none of the pretrained parameters receive gradients
for param in model_conv.parameters():
    param.requires_grad = False

# Replace the final fully connected layer with a fresh (trainable) one with
# one output per class. The size comes from the data rather than a hard-coded
# 51, so it always matches the label encoding.
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, len(class_names))

model_conv = model_conv.to(device)

In [7]:
# Hyper-parameters for training the classification head.
params_dict = dict(
    lr_val=0.01,
    momentum_val=0.9,
    step_val=7,
    gamma_val=0.1,
    total_epochs=25,
    # Inverse class frequencies of the training split, on the training device.
    weights=image_datasets['train'].class_weights.to(device),
)

# Class imbalance: weighted cross-entropy gives rarer classes larger weights.
criterion = nn.CrossEntropyLoss(weight=params_dict["weights"])

# SGD with momentum. StepLR multiplies the learning rate by gamma_val every
# step_val scheduler steps.
optimizer_conv = optim.SGD(
    model_conv.parameters(),
    lr=params_dict["lr_val"],
    momentum=params_dict["momentum_val"],
)
exp_lr_scheduler = lr_scheduler.StepLR(
    optimizer_conv,
    step_size=params_dict["step_val"],
    gamma=params_dict["gamma_val"],
)

In [8]:
# Train the model; train_model returns it with the lowest-val-loss weights.
model_conv, stats_dict = train_model(model_conv, criterion, optimizer_conv,
                                     exp_lr_scheduler, num_epochs=params_dict["total_epochs"])

# Persist the whole module (architecture + best weights). The inference cell
# reloads it with torch.load, which would otherwise find no file after a fresh
# Restart & Run All.
torch.save(model_conv, "./awesome_model_full")

Epoch 0/24
----------
train Loss: 2.9780 Acc: 0.5333 Mean Class Acc: 0.4321
val Loss: 2.4379 Acc: 0.6188 Mean Class Acc: 0.5246

Epoch 1/24
----------
train Loss: 1.9093 Acc: 0.6684 Mean Class Acc: 0.6111
val Loss: 1.8971 Acc: 0.6949 Mean Class Acc: 0.6145

Epoch 2/24
----------
train Loss: 1.6794 Acc: 0.6928 Mean Class Acc: 0.6526
val Loss: 2.1372 Acc: 0.6639 Mean Class Acc: 0.6011

Epoch 3/24
----------
train Loss: 1.5537 Acc: 0.7147 Mean Class Acc: 0.6753
val Loss: 2.5204 Acc: 0.7079 Mean Class Acc: 0.5846

Epoch 4/24
----------
train Loss: 1.4470 Acc: 0.7279 Mean Class Acc: 0.7000
val Loss: 2.0480 Acc: 0.7478 Mean Class Acc: 0.6174

Epoch 5/24
----------
train Loss: 1.3110 Acc: 0.7385 Mean Class Acc: 0.7190
val Loss: 2.3484 Acc: 0.7127 Mean Class Acc: 0.6143

Epoch 6/24
----------
train Loss: 1.2642 Acc: 0.7469 Mean Class Acc: 0.7302
val Loss: 2.3727 Acc: 0.6948 Mean Class Acc: 0.6200

Epoch 7/24
----------
train Loss: 1.2221 Acc: 0.7559 Mean Class Acc: 0.7461
val Loss: 2.2420 Acc:

In [2]:
# run inference
# Inference mode Function (Prediction) for Validation Set- Completely Unseen Data

def run_inference(model, loader=None, loss_fn=None, target_device=None):
    """Evaluate ``model`` on a labelled loader and print loss/accuracy stats.

    The model is put in eval mode. Batches run under ``torch.no_grad()``, so
    no autograd graph is built.

    Args:
        model: trained ``nn.Module`` returning class logits.
        loader: iterable of ``(inputs, labels)`` batches. Defaults to the
            notebook-level ``dataloaders["val"]``.
        loss_fn: loss called as ``loss_fn(outputs, labels)``. Defaults to the
            notebook-level ``criterion``.
        target_device: device batches are moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        dict with float "Loss", "Accuracy" and "Mean Class Accuracy", averaged
        over the samples the loader yielded.

    Raises:
        ValueError: if the loader yields no samples.
    """
    # Fall back to the notebook globals so existing calls keep working.
    data = dataloaders["val"] if loader is None else loader
    criterion_fn = criterion if loss_fn is None else loss_fn
    run_device = device if target_device is None else target_device

    model.eval()
    running_loss = 0.0
    running_corrects = 0
    n_seen = 0
    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        for inputs, labels in data:
            inputs = inputs.to(run_device)
            labels = labels.to(run_device)

            outputs = model(inputs)
            loss = criterion_fn(outputs, labels)
            # argmax of the raw logits; (log-)softmax is monotonic, so
            # applying it first would not change the predicted class.
            _, preds = torch.max(outputs, 1)

            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            running_corrects += int(torch.sum(preds == labels).item())
            n_seen += batch_size
            pred_chunks.append(preds.cpu().numpy())
            label_chunks.append(labels.cpu().numpy())

    if n_seen == 0:
        raise ValueError("run_inference: the loader yielded no samples")

    epoch_loss = running_loss / n_seen
    epoch_acc = running_corrects / n_seen

    # Mean per-class recall over classes that occur in the labels. Classes
    # that are only ever predicted have zero support; skipping them avoids a
    # NaN mean (0/0 row of the confusion matrix).
    all_labels = np.concatenate(label_chunks).astype(np.int64)
    all_preds = np.concatenate(pred_chunks).astype(np.int64)
    n_classes = int(max(all_labels.max(), all_preds.max())) + 1
    support = np.bincount(all_labels, minlength=n_classes)
    hits = np.bincount(all_labels[all_labels == all_preds], minlength=n_classes)
    present = support > 0
    epoch_mean_class_acc = float((hits[present] / support[present]).mean())

    print('{} Loss: {:.4f} Acc: {:.4f} mean class acc: {:.4f}'.format(
                "Inference", epoch_loss, epoch_acc, epoch_mean_class_acc))

    return {"Loss": epoch_loss, "Accuracy": epoch_acc, "Mean Class Accuracy": epoch_mean_class_acc}

In [None]:
# Rebuild the validation loader and score the saved model on it. This cell is
# meant to work after a kernel restart. It therefore recreates the globals
# that run_inference reads by default: device, criterion,
# dataloaders["val"] and dataset_sizes["val"].
data_dir = os.environ.get("HAB_DATA_DIR", "/data6/SuryaKrishnan/raw_data")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The model's output indices follow the *training* split's label encoding, so
# the val split must use it too. The loss reuses the training class weights.
train_reference = HabDataset(os.path.join(data_dir, 'train.csv'))
val_dataset = HabDataset(os.path.join(data_dir, 'val.csv'), data_transforms["val"])
val_dataset.class_names = train_reference.class_names
criterion = nn.CrossEntropyLoss(weight=train_reference.class_weights.to(device))

val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32,
                                         shuffle=False, num_workers=4)

# Replace only the "val" entries, so any train loader built earlier survives.
if not isinstance(globals().get("dataloaders"), dict):
    dataloaders = {}
if not isinstance(globals().get("dataset_sizes"), dict):
    dataset_sizes = {}
dataloaders["val"] = val_loader
dataset_sizes["val"] = len(val_dataset)

# Run inference on the best model.
# SECURITY: torch.load unpickles the checkpoint, which can execute arbitrary
# code, so only load files you produced yourself. A full-module checkpoint
# needs weights_only=False on recent PyTorch (where it defaults to True).
# Older releases without that argument raise TypeError, hence the fallback.
# map_location lets a GPU-saved model load on a CPU-only machine.
MODEL_PATH = "./awesome_model_full"
try:
    good_model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
except TypeError:
    good_model = torch.load(MODEL_PATH, map_location=device)
good_model.eval()

run_inference(good_model)