To open this notebook in Google Colab, follow this [tutorial](https://www.endtoend.ai/blog/githubtocolab/).

# CNN-CAPS Inference

This notebook uses a trained cnn-caps model to perform inference on test images of CAPS ELS energy spectrograms.

To run this notebook, upload the test images into a `test` folder and set the `filename` variable in the 'Predict single capsplot' section. Then run that cell to obtain the prediction.

In [2]:
!pip install grad-cam
!pip install ttach

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting grad-cam
  Downloading grad-cam-1.4.3.tar.gz (7.8 MB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
    Preparing wheel metadata ... done
Collecting ttach
  Downloading ttach-0.0.3-py3-none-any.whl (9.8 kB)
Building wheels for collected packages: grad-cam
  Building wheel for grad-cam (PEP 517) ... done
  Created wheel for grad-cam: filename=grad_cam-1.4.3-py3-none-any.whl size=32263 sha256=9d04441bf532c8368aa9576f0911cfab871ca7bb61f92d942b3dc519976ead9b
  Stored in directory: /root/.cache/pip/wheels/77/05/47/36e06c7cdf46685b5a9e30686a1f93bff3e95a91bf1404c75d
Successfully built grad-cam
Installing collected packages: ttach, grad-cam
Successfully installed grad-cam-1.4.3 ttach-0.0.3
Looking in indexes: https://pypi.org/simple

# Dependencies

In [3]:
# general libraries
import time
import gc
import numpy as np
import pickle
import os
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
%matplotlib inline
import scipy
from tqdm import trange
from glob import glob
import h5py
from IPython import display
import pandas as pd
import itertools


# ML metrics
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, precision_recall_curve, average_precision_score, roc_curve, auc, classification_report

# tensorflow
import tensorflow as tf

# pytorch deep learning
import torch
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision.models import resnet18
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F


# Helper functions

In [4]:
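def load_model(path, model_filename="best.pt"):
    """ Load a pretrained model (a fully pickled nn.Module) onto `device`.

    `device` is defined in the CONSTANTS cell, and the model classes
    (`Identity`, `TransferredNet`) must be defined before unpickling.
    """
    with open(os.path.join(path, model_filename), 'rb') as f:
        # map_location remaps tensors saved on a GPU when only a CPU is available
        model = torch.load(f, map_location=device)
    return model.eval().to(device)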


def save_pkl(result, filename):
    """ Save an object to a pickle file. """
    with open(filename, "wb") as f:
        pickle.dump(result, f)


def load_pkl(filename):
    """ Load pickle file. """
    with open(filename, "rb") as f:
        result = pickle.load(f)
    return result


def load_img(filename):
    """ Load a .jpg or .png image as a (1, 3, 224, 224) float tensor scaled to [0, 1]. """
    img = tf.keras.preprocessing.image.load_img(filename, color_mode='rgb', target_size=(224,224))
    img_arr = tf.keras.preprocessing.image.img_to_array(img)
    img_arr = np.array([img_arr])
    img_arr /= 255
    img_tensor = torch.FloatTensor(img_arr).permute([0,3,1,2])
    return img_tensor

def plot_img_tensor(img_tensor, label=None):
    """ Assume shape of img_tensor is (1,3,224,224) """
    plt.imshow(img_tensor.squeeze().permute(1,2,0))
    if label is not None:
        print('label: ', label.numpy())
    plt.show()

def preprocess_capsplot(filename, transform=None, plot=False):
    """ Preprocess image for single capsplot prediction with model. """
    # load image
    img = load_img(filename)

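    # resize (effectively a no-op, since load_img already returns 224x224) and
    # standardise each channel with a fixed per-channel mean and std
    resize_norm_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.Normalize(mean=[0.403, 0.647, 0.577], std=[0.301, 0.183, 0.312]),
    ])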
    img_stdscaled = resize_norm_transform(img)

    if transform is not None:
        img_stdscaled = transform(img_stdscaled)
    
    if plot:
        plot_img_tensor(img)

    return img_stdscaled

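def predict(filename, model, transform=None, plot=False):
    """ Predict single capsplot with cnn-caps model.

    Returns the softmax probabilities as a NumPy array of shape (1, N_CLASSES).
    """
    img = preprocess_capsplot(filename, transform=transform, plot=plot)
    with torch.no_grad():  # inference only, so skip gradient tracking
        probs = model(img.to(device))[1]  # model returns (logits, probs)
    return probs.cpu().numpy()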
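`load_img` relies on TensorFlow only to read and resize the image, and `transforms.Normalize` computes (x - mean) / std per channel. The sketch below is a TensorFlow-free alternative built on Pillow and NumPy, under the assumption that nearest-neighbour resizing (the Keras `load_img` default) is what the model expects; `load_img_pil`, `normalize` and `denormalize` are hypothetical helper names, not part of the notebook. The mean/std values are copied from `preprocess_capsplot`.

```python
import numpy as np
from PIL import Image

# Per-channel statistics copied from preprocess_capsplot (R, G, B order)
MEAN = np.array([0.403, 0.647, 0.577], dtype=np.float32).reshape(1, 3, 1, 1)
STD = np.array([0.301, 0.183, 0.312], dtype=np.float32).reshape(1, 3, 1, 1)


def load_img_pil(filename, size=224):
    """Load an image as a (1, 3, size, size) float32 array scaled to [0, 1]."""
    img = Image.open(filename).convert("RGB")
    # Assumption: match Keras load_img, which resizes with nearest-neighbour by default
    img = img.resize((size, size), resample=Image.NEAREST)
    arr = np.asarray(img, dtype=np.float32) / 255.0  # (H, W, 3)
    return arr.transpose(2, 0, 1)[np.newaxis]  # -> (1, 3, H, W), like load_img


def normalize(img):
    """(x - mean) / std per channel, the arithmetic of transforms.Normalize."""
    return (img - MEAN) / STD


def denormalize(img_std):
    """Undo normalize(), e.g. before displaying an image with plt.imshow."""
    return img_std * STD + MEAN
```

`torch.from_numpy(normalize(load_img_pil(path)))` should then give a tensor with the same layout as the output of `preprocess_capsplot`; pixel values may differ slightly if the two libraries resize differently.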

## CONSTANTS

In [5]:
# CONSTANTS
N_CLASSES = 3
BATCH_SIZE = 128
LEARNING_RATE = 0.00016 # 0.00025 for softmax cross entropy, 0.00016 for EDL training
RANDOM_SEED = 42
#CLASS_LABELS = ['0_notCrossing', '1_MP', '2_BS']
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet Model

In [6]:
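# create a layer that returns its input unchanged (equivalent to torch.nn.Identity;
# kept as a custom class because a saved model may reference it by name)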
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

In [7]:
class TransferredNet(nn.Module):
    def __init__(self, pretrained_model):
        super().__init__()
        self.pretrained_model = pretrained_model
        self.pretrained_model.fc = Identity()

        self.head = nn.Sequential(
            nn.Linear(512, 16),
            #nn.ReLU(),
            nn.Tanh(),
            nn.Linear(16, N_CLASSES),
            #nn.Sigmoid()
        )
        
    def forward(self, input):
        logits = self.head(self.pretrained_model(input))
        probs = F.softmax(logits, dim=1)
        return logits, probs
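`TransferredNet.forward` returns a `(logits, probs)` tuple, and `predict` keeps element `[1]`, the softmax probabilities. The NumPy sketch below shows what `F.softmax(logits, dim=1)` computes for a batch and how the most probable class could be mapped back to a name. The label order is an assumption taken from the commented-out `CLASS_LABELS` constant, and `softmax`/`top_label` are hypothetical helpers; the logits in the usage line are made up.

```python
import numpy as np

# Assumed class order, taken from the commented-out CLASS_LABELS constant
CLASS_LABELS = ['0_notCrossing', '1_MP', '2_BS']


def softmax(logits):
    """Row-wise softmax over the class axis, matching F.softmax(logits, dim=1)."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def top_label(probs):
    """Return the name of the most probable class for each row of probs."""
    return [CLASS_LABELS[i] for i in probs.argmax(axis=1)]


# Usage with made-up logits for a batch of two images
logits = np.array([[2.0, 1.0, 0.1], [0.0, 0.0, 3.0]])
probs = softmax(logits)
print(top_label(probs))  # ['0_notCrossing', '2_BS']
```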

In [8]:
# load pretrained model and (optionally) freeze some of its parameters
pretrained_model = resnet18(pretrained=True)

# UNCOMMENT THE LOOP BELOW TO FREEZE PART OF THE BACKBONE (leave it commented out to train the whole network)
# for i, child in enumerate(pretrained_model.children()):
    
#     # resnet18 has 10 child modules (indices 0-9): conv1, bn1, relu, maxpool,
#     # layer1-layer4, avgpool and fc (fc is later replaced with Identity)
#     #print(i, child)

#     # Let's freeze the weights of the first 2/3 of the model (children 0-6)
#     if i <= 6:
#         for param in child.parameters():
#             param.requires_grad = False

model = TransferredNet(pretrained_model).to(device)

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth



In [9]:
# Show the blocks whose parameters require gradients (i.e. will be fine-tuned in training)

print('PRETRAINED MODEL:\n')
for i, child in enumerate(model.pretrained_model.children()):
    for param in child.parameters():
        if param.requires_grad:
            print(i, child)
            break

print('HEAD:\n')
for i, child in enumerate(model.head.children()):
    for param in child.parameters():
        if param.requires_grad:
            print(i, child)
            break

PRETRAINED MODEL:

0 Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
1 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
4 Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, 

# Predict single capsplot


## Predictions with Cross Entropy Model (Multiclass)

In [14]:
# Load the trained model (upload the .pt file to the current runtime's working directory first)
MODEL_NAME = "resnet18_custom_softmaxCrossEntropy_train_full_multiclass.pt"
model = load_model(path="./", model_filename=MODEL_NAME)

In [15]:
%%time
# Path to a capsplot image uploaded to the current runtime (e.g. from datasets/images/).
# Only the last assignment takes effect; comment out the one you do not want.
filename='/content/2007-02-02T220000_2007-02-03T010000.png' # correct prediction
filename='/content/2007-03-16T130300_2007-03-16T150300.png' # incorrect prediction (however, the CAPS data do not look strikingly clear)

# predict
result = predict(filename, model, plot=False)
print('Prediction', result)

Prediction [[0.63766557 0.2857451  0.07658938]]
CPU times: user 15.9 ms, sys: 0 ns, total: 15.9 ms
Wall time: 16.4 ms
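Rather than editing `filename` and re-running the cell for each image, every image in the upload folder can be scored in one pass. The sketch below is an assumption, not part of the original notebook: `predict_folder` is a hypothetical helper that globs `.png`/`.jpg` files in a folder (for example the `test` folder mentioned at the top) and calls any function that maps a path to a `(1, n_classes)` probability array, such as `lambda f: predict(f, model)`.

```python
import os
from glob import glob

import numpy as np
import pandas as pd


def predict_folder(folder, predict_fn, patterns=("*.png", "*.jpg")):
    """Apply predict_fn(path) -> (1, n_classes) array to every image in folder."""
    paths = sorted(p for pat in patterns for p in glob(os.path.join(folder, pat)))
    rows = []
    for path in paths:
        probs = np.asarray(predict_fn(path)).reshape(-1)  # flatten (1, n_classes)
        rows.append({"file": os.path.basename(path),
                     **{f"p{i}": float(p) for i, p in enumerate(probs)},
                     "pred": int(probs.argmax())})  # index of the most probable class
    return pd.DataFrame(rows)


# Usage in this notebook (assumes the images were uploaded to /content/test):
# df = predict_folder('/content/test', lambda f: predict(f, model))
```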


Some predictions are uncertain. For example, `/content/test/2007-03-16T130300_2007-03-16T150300.png` is predicted incorrectly (although the CAPS data in this plot are not strikingly clear), with probabilities `[[0.63766736 0.2857435 0.07658914]]`. One option is a confidence threshold: if the maximum predicted probability is below 0.75, treat the prediction as uncertain and assign the image to the 'human-in-the-loop' pile for manual review. See the 'Define Data and Establish Baseline' page in the MLOps course.
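The thresholding rule above can be sketched as follows. The 0.75 cut-off comes from the text; the function name `triage` and the `'human-in-the-loop'` return value are illustrative assumptions, and the second probability row in the usage line is made up.

```python
import numpy as np


def triage(probs, threshold=0.75):
    """Return the predicted class index per row, or 'human-in-the-loop'
    when the maximum probability falls below threshold."""
    probs = np.atleast_2d(probs)
    decisions = []
    for row in probs:
        if row.max() < threshold:
            decisions.append("human-in-the-loop")  # uncertain: route to manual review
        else:
            decisions.append(int(row.argmax()))
    return decisions


# Usage with the uncertain prediction quoted above and a confident, made-up one
print(triage([[0.63766736, 0.2857435, 0.07658914],
              [0.05, 0.90, 0.05]]))  # ['human-in-the-loop', 1]
```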

