## Imports

In [1]:
import itertools
from typing import List

import torch
from torch.utils import data
import torchvision
import torchmetrics
from torchvision import transforms, models
import torchvision.transforms.functional as F


import numpy as np
from sklearn import metrics
import plotly.express as px
from plotly.subplots import make_subplots
import pandas as pd
import matplotlib.pyplot as plt

import pytorch_lightning as pl

## Configure MLflow tracking

In [2]:
import mlflow
import os
from getpass import getpass

# S3-compatible artifact store used by the MLflow server (experiment artifacts
# live under s3://...), both served from localhost.
os.environ['MLFLOW_S3_ENDPOINT_URL'] = "http://localhost:9000"

# Never commit credentials with the notebook: take them from the environment
# and prompt (without echoing) only when one is missing.
for credential_var in ('AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY'):
    if not os.environ.get(credential_var):
        os.environ[credential_var] = getpass(f'{credential_var}: ')

mlflow.set_tracking_uri('http://localhost:5000/')
mlflow.set_experiment('Cats vs Possums')

<Experiment: artifact_location='s3://mlflow-dwh/4', creation_time=1693158451661, experiment_id='4', last_update_time=1693158451661, lifecycle_stage='active', name='Cats vs Possums', tags={}>

## Data preparation: transforms, datasets and loaders

In [3]:
# Seed Python, NumPy and torch RNGs so augmentation, shuffling and the random
# initialisation of the classifier head are reproducible across kernel restarts.
SEED = 42
pl.seed_everything(SEED, workers=True)

IMAGE_ROOT = './images'  # expects train/, validation/ and test/ ImageFolder trees
BATCH_SIZE = 10  # plot_10_predictions plots exactly one batch

# ImageNet channel statistics, as expected by the ImageNet-pretrained backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _eval_transform():
    """Build the deterministic resize / centre-crop / normalise pipeline for val and test."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Must stay last: show() reads mean/std from transform.transforms[-1].
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])


image_transforms = {
    'train': transforms.Compose([
        transforms.Resize(256),
        # NOTE(review): p=0.9 flips 90% of images, and degrees=(30, 70) /
        # scale=(0.5, 0.75) always rotate by 30-70 degrees and always shrink,
        # so training images never look like the upright val/test images.
        # Confirm these ranges are intended.
        transforms.RandomHorizontalFlip(p=0.9),
        transforms.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Must stay last: show() reads mean/std from transform.transforms[-1].
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]),
    'validation': _eval_transform(),
    'test': _eval_transform(),
}

dataset = {
    key: torchvision.datasets.ImageFolder(f'{IMAGE_ROOT}/{key}', transform=transform)
    for key, transform in image_transforms.items()
}

# Only the training loader is shuffled; evaluation order stays fixed.
loader = {
    key: data.DataLoader(ds, batch_size=BATCH_SIZE, shuffle=key == 'train')
    for key, ds in dataset.items()
}


## Visualisation utilities


In [4]:
def show(images: torch.Tensor, labels: List[int], dataset: torchvision.datasets.DatasetFolder):
    """Plot a row of normalised images, each titled with its class name.

    Args:
        images: batch of normalised image tensors, iterated along the first
            dimension (presumably (N, C, H, W) — confirm for new callers).
        labels: integer class indices, one per image. A tensor works too, as
            long as each entry holds a single element (shape (N,) or (N, 1)).
        dataset: dataset whose ``transform`` is a Compose ending in
            ``transforms.Normalize`` and whose ``classes`` maps index -> name.

    Returns:
        The matplotlib Figure. The caller is responsible for closing it.
    """
    classes = dataset.classes
    normalize = dataset.transform.transforms[-1]
    mean, std = normalize.mean, normalize.std
    # Invert Normalize: ((y - (-m/s)) / (1/s)) == y*s + m
    inverse_normalize = transforms.Normalize(
        mean=tuple(-m / s for m, s in zip(mean, std)),
        std=tuple(1.0 / s for s in std),
    )

    fig, axs = plt.subplots(ncols=len(images), squeeze=False, figsize=(18, 2))
    for i, (ax, img) in enumerate(zip(axs[0], images)):
        # to_pil_image multiplies floats by 255 and casts to uint8, so clamp
        # round-off excursions outside [0, 1] to avoid wrap-around.
        restored = inverse_normalize(img.detach().cpu()).clamp(0.0, 1.0)
        ax.imshow(np.asarray(F.to_pil_image(restored)))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        # int() accepts plain ints as well as single-element tensors.
        ax.set_title(classes[int(labels[i])])
    return fig

def plot_10_predictions(test_loader: data.DataLoader, classifier: torch.nn.Module,
                        threshold: float = 0.5):
    """Plot the first batch of ``test_loader``, titled with the predicted classes.

    Despite the name, one whole batch is plotted, so the number of images
    equals the loader's batch size.

    Args:
        test_loader: loader whose ``dataset`` exposes ``transform`` (ending in
            Normalize) and ``classes``, as ``show`` requires.
        classifier: model returning one probability per image.
        threshold: probability above which an image gets class index 1.

    Returns:
        The matplotlib Figure produced by ``show``.
    """
    images, _ = next(iter(test_loader))
    # Inference must not apply dropout or update BatchNorm running statistics,
    # and it needs no autograd graph. Restore the caller's train/eval mode afterwards.
    was_training = classifier.training
    classifier.eval()
    try:
        with torch.no_grad():
            probabilities = classifier(images)
    finally:
        classifier.train(was_training)
    # One probability per image (shape (N, 1) for PossumOrCat) -> flat 0/1 class indices.
    predicted = (probabilities > threshold).to(torch.int).ravel()
    return show(images, predicted, test_loader.dataset)


def _predict_scores(classifier: torch.nn.Module, test_loader):
    """Return (true labels, predicted probabilities) as flat arrays over all of ``test_loader``.

    The classifier runs in eval mode under ``torch.no_grad()``, so dropout is
    off and BatchNorm statistics are not updated. Its previous train/eval mode
    is restored afterwards.

    Raises:
        ValueError: if the loader yields no batches.
    """
    was_training = classifier.training
    classifier.eval()
    scores, targets = [], []
    try:
        with torch.no_grad():
            for images, labels in test_loader:
                scores.append(classifier(images).cpu().numpy().ravel())
                targets.append(labels.cpu().numpy().ravel())
    finally:
        classifier.train(was_training)
    if not scores:
        raise ValueError("test_loader yielded no batches")
    return np.concatenate(targets), np.concatenate(scores)


def plot_insights(classifier: torch.nn.Module, test_loader: "data.DataLoader | None" = None):
    """Plot the ROC curve and the FPR/TPR-vs-threshold curves on a test set.

    Args:
        classifier: model returning one probability per image.
        test_loader: loader to evaluate. Defaults to the notebook's ``loader['test']``.

    Returns:
        A plotly Figure with two subplots: the ROC curve (AUC in its title)
        with a chance diagonal, and FPR and TPR as functions of the decision
        threshold.
    """
    if test_loader is None:
        test_loader = loader['test']
    y_true, y_score = _predict_scores(classifier, test_loader)

    fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score)
    # Wide-form frame: px.line puts the index (thresholds) on x and draws one
    # trace per column.
    rates = pd.DataFrame({
        'False Positive Rate': fpr,
        'True Positive Rate': tpr
    }, index=thresholds)
    threshold_fig = px.line(rates, title='TPR and FPR at every threshold')

    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=[
            f'ROC Curve (AUC={metrics.auc(fpr, tpr):.4f})',
            'TPR and FPR at every threshold'
        ]
    )

    # ROC subplot. Shapes are not carried over when copying traces, so the
    # chance diagonal is drawn directly on this subplot.
    fig.add_trace(next(px.area(x=fpr, y=tpr).select_traces()), row=1, col=1)
    fig.add_shape(type='line', line=dict(dash='dash'),
                  x0=0, x1=1, y0=0, y1=1, row=1, col=1)
    fig.update_yaxes(scaleanchor="x", scaleratio=1, row=1, col=1, title_text="True Positive Rate")
    fig.update_xaxes(constrain='domain', row=1, col=1, title_text="False Positive Rate")

    # Threshold subplot: copy both the FPR and the TPR traces.
    for trace in threshold_fig.select_traces():
        fig.add_trace(trace, row=1, col=2)
    fig.update_yaxes(scaleanchor="x2", scaleratio=1, row=1, col=2, title_text="Rate")
    fig.update_xaxes(range=[0, 1], constrain='domain', row=1, col=2, title_text="Threshold")

    fig.update_layout(height=400, width=900,
                      title_text="ROC AUC and Decision threshold")
    return fig


## Model class

In [5]:
class PossumOrCat(pl.LightningModule):
    """Binary cat-vs-possum classifier built on an ImageNet-pretrained ResNet-18.

    The backbone's final ``fc`` layer is replaced by a small MLP head ending in
    a single sigmoid unit. ``forward`` therefore returns one probability per
    image (shape (N, 1)), and every step uses binary cross-entropy on those
    probabilities.

    Args:
        classifier_layers: hidden-layer widths for the new head. ``None``
            entries are skipped; ``None`` or empty gives a single linear layer.
    """

    def __init__(self, classifier_layers=None):
        super().__init__()
        # ResNet18_Weights.DEFAULT is the ImageNet-1k checkpoint that the
        # deprecated ``pretrained=True`` flag selected.
        self.pretrain = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.output_features = self.pretrain.fc.in_features
        self.pretrain.fc = self.create_new_output(classifier_layers or [])

        self.train_accuracy = torchmetrics.Accuracy('binary')
        self.validation_accuracy = torchmetrics.Accuracy('binary')

    def create_new_output(self, layers: List[int]):
        """Build the head that replaces the backbone's ``fc`` layer.

        Each non-None width in ``layers`` adds ``Linear -> ReLU``, preceded by
        ``Dropout(0.5)`` for every hidden layer except the first. The head
        always ends with ``Linear(..., 1) -> Sigmoid``. The last hidden layer
        keeps its ReLU; without it, that layer would compose linearly with the
        output layer and add no capacity.
        """
        new_head = []
        current_output_size = self.output_features
        for layer in layers:
            if layer is None:
                continue
            if new_head:
                new_head.append(torch.nn.Dropout(0.5))
            new_head.append(torch.nn.Linear(current_output_size, layer))
            new_head.append(torch.nn.ReLU())
            current_output_size = layer

        # Single-unit output with sigmoid -> probability of class index 1.
        new_head.append(torch.nn.Linear(current_output_size, 1))
        new_head.append(torch.nn.Sigmoid())
        return torch.nn.Sequential(*new_head)

    def configure_optimizers(self):
        """Adam over all parameters (backbone included) with a fixed 1e-3 learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def forward(self, x):
        """Return per-image probabilities from the backbone plus the custom head."""
        output = self.pretrain(x)
        return output

    def _shared_step(self, batch):
        """Forward one ``(x, y)`` batch; return flat probabilities, float targets and the BCE loss."""
        x, y = batch
        y_hat = self(x).ravel()
        target = y.to(torch.float)
        loss = torch.nn.functional.binary_cross_entropy(y_hat, target)
        return y_hat, target, loss

    def training_step(self, batch, batch_idx):
        y_hat, target, loss = self._shared_step(batch)
        self.log("train_loss", loss)

        # Log custom metric
        self.train_accuracy(y_hat, target)
        self.log("train_accuracy", self.train_accuracy)
        return loss

    def validation_step(self, batch, batch_idx):
        y_hat, target, loss = self._shared_step(batch)
        self.log("validation_loss", loss)

        # Log custom metric
        self.validation_accuracy(y_hat, target)
        self.log("validation_accuracy", self.validation_accuracy)
        return loss

    def test_step(self, batch, batch_idx):
        _, _, loss = self._shared_step(batch)
        self.log("test_loss", loss)
        return loss
    

In [None]:
mlflow.pytorch.autolog()


def run_modeling(layers):
    """Train, test and log one PossumOrCat variant as a single MLflow run.

    Args:
        layers: hidden-layer widths for the classifier head (``None`` entries
            are skipped by ``PossumOrCat.create_new_output``).
    """
    with mlflow.start_run():
        # Log the hyperparameter first so it is recorded even if training fails.
        mlflow.log_param("layers", layers)

        # Create and train classifier
        clf = PossumOrCat(classifier_layers=layers)
        trainer = pl.Trainer(max_epochs=10)
        trainer.fit(clf, train_dataloaders=loader['train'], val_dataloaders=loader['validation'])

        # Log custom metric
        test_result = trainer.test(clf, dataloaders=loader['test'])
        mlflow.log_metric('Total test loss', test_result[0]['test_loss'])

        # Log images with predictions. Close the matplotlib figure afterwards so
        # figures do not pile up across the many runs in this loop.
        predictions_vs_actuals = plot_10_predictions(loader['test'], classifier=clf)
        mlflow.log_figure(predictions_vs_actuals, '10test.png')
        plt.close(predictions_vs_actuals)

        # Log binary charts
        insights = plot_insights(classifier=clf)
        mlflow.log_figure(insights, 'roc_threshold.html')


# Up to three hidden layers of 16 or 64 units, where None means "no layer".
# Many of the 27 products give the same head once None is dropped (e.g.
# (16, None, None) and (None, 16, None)), so train each distinct architecture
# once. dict.fromkeys de-duplicates while keeping first-seen order.
candidate_layers = [16, 64, None]
candidate_architectures = dict.fromkeys(
    tuple(size for size in combo if size is not None)
    for combo in itertools.product(candidate_layers, repeat=3)
)
for candidate in candidate_architectures:
    run_modeling(candidate)
    
    

In [32]:
# Choose the finished run with the lowest latest-logged validation_loss in the
# active experiment, then load the model stored under its "model" artifact path.
# Only FINISHED runs are considered, because interrupted or failed runs may
# lack the model artifact.
finished_runs = mlflow.search_runs(
    filter_string="attributes.status = 'FINISHED'",
    order_by=['metrics.validation_loss ASC'],
    output_format='list',
)
if not finished_runs:
    raise RuntimeError("No finished runs found in the active MLflow experiment")
best_run = finished_runs[0]
uri = f"runs:/{best_run.info.run_id}/model"
best_model = mlflow.pytorch.load_model(uri)
best_model

PossumOrCat(
  (pretrain): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, trac