# Cyber2A Workshop: Toy Model for RTS Training and Inference
This repository demonstrates a simplified example of training and running inference on a toy model using the Retrogressive Thaw Slumps (RTS) dataset. It is adapted from the official [PyTorch Vision Tutorial](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html).

The copyright for the tutorial content belongs to PyTorch. © Copyright 2024, PyTorch.
 

In [None]:
from dotenv import load_dotenv
# Load environment variables from the local "env.txt" file. The return value
# is bound to `_` so it is not echoed as the cell's output.
# NOTE(review): presumably this supplies the tracking-server settings — confirm.
_ = load_dotenv("env.txt")

## Imports
Helper code lives in a local module directory (`toy_model`) to keep this notebook
focused on the training concepts.

In [None]:
import sys

sys.path.append("./toy_model")

import torch
from dataset import RTSDataset
from transforms import get_transform
from utils import collate_fn
from model import get_model_instance_segmentation
from engine import evaluate, train_one_epoch

In [None]:
# Select the compute device. The original cell asserted that a CUDA GPU is
# present; keep that requirement, but fail with an explicit, descriptive
# error instead of a bare `assert` (asserts are stripped under `python -O`
# and give no hint about what is wrong).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    raise RuntimeError(
        "A CUDA-capable GPU is required to run this notebook, but "
        "torch.cuda.is_available() returned False."
    )

## Hyperparameters
These are the model hyperparameters. They are pulled out of the code so they can
easily be set here. They will be recorded along with the run in the tracking server.

In [None]:
# Hyperparameters for the run, collected in one mapping so they can be set in
# a single place and logged together with the run.
params = dict(
    lr=0.005,  # learning rate passed to the SGD optimizer
    momentum=0.9,  # momentum passed to the SGD optimizer
    weight_decay=0.0005,  # weight decay passed to the SGD optimizer
    step_size=3,  # StepLR: scheduler steps between learning-rate decays
    gamma=0.1,  # StepLR: multiplicative learning-rate decay factor
    epochs=1,  # number of passes of the training loop
)

## Set up for Training
Find the training data and set up the optimizer.

In [None]:
# Training and validation/test datasets, each built from a JSON annotation
# file plus the transform pipeline for that split.
# NOTE(review): the file names suggest COCO-format annotations — confirm.
dataset = RTSDataset("data/coco_rts_train.json", get_transform(train=True))
dataset_test = RTSDataset("data/coco_rts_valtest.json", get_transform(train=False))

# Data loaders: the training loader shuffles and batches two samples at a
# time; the evaluation loader walks the data in order, one sample per batch.
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn,
)
data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
)

# Build the model and move it to the selected device.
# NOTE(review): num_classes=2 is presumably background + RTS — confirm.
model = get_model_instance_segmentation(num_classes=2)
model.to(device)

# Optimize only the parameters that require gradients, using SGD with a
# step-wise learning-rate schedule; all settings come from `params`.
opt_params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    opt_params,
    lr=params["lr"],
    momentum=params["momentum"],
    weight_decay=params["weight_decay"],
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=params["step_size"],
    gamma=params["gamma"],
)




## Model Input and Output Schemas
MLflow requires published models to include schemas for the model input and the
inference output.

In [None]:
from mlflow.types import Schema, TensorSpec
import mlflow
import numpy as np

# Input schema: a single float32 tensor with the fixed shape used in this
# notebook.
# NOTE(review): presumably laid out as (batch, channels, height, width) — confirm.
input_schema = Schema([
    TensorSpec(
        np.dtype(np.float32),
        shape=(1, 3, 226, 288)
    )
])

# Output schema: one named tensor per prediction field. The number of
# detections differs from image to image, so the leading dimension is
# declared as -1 (MLflow's marker for a variable-sized dimension) instead of
# the count observed on one particular image.
output_schema = Schema([
    TensorSpec(
        np.dtype(np.float32),
        shape=(-1, 4),  # Assuming boxes are in (x1, y1, x2, y2) format
        name="boxes"
    ),
    TensorSpec(
        np.dtype(np.int64),
        shape=(-1,),  # One label per detection
        name="labels"
    ),
    TensorSpec(
        np.dtype(np.float32),
        shape=(-1,),  # One confidence score per detection
        name="scores"
    ),
    TensorSpec(
        np.dtype(np.float32),
        shape=(-1, 1, 226, 288),  # One single-channel mask per detection
        name="masks"
    )
])

# Bundle both schemas into the signature that is stored with the logged model.
signature = mlflow.models.signature.ModelSignature(
    inputs=input_schema,
    outputs=output_schema
)



## Start the Training Run
This starts the run, logs the hyperparameters, and then logs metrics from the model
as training progresses.

When training is complete, the model is uploaded to the tracking server and associated with the training run.

In [None]:
import mlflow

with mlflow.start_run() as run:
    # Record the hyperparameters and tag the run with its author.
    mlflow.log_params(params)
    mlflow.set_tag("scientist", "Ben")

    # Stays None when params["epochs"] is 0, so the print after the loop
    # cannot fail on an unbound name.
    metrics = None

    # Training loop: one training pass, one scheduler step and one
    # evaluation pass per epoch.
    for epoch in range(params["epochs"]):
        metrics = train_one_epoch(
            model,
            optimizer,
            data_loader,
            device,
            epoch,
            print_freq=10,
        )
        lr_scheduler.step()
        evaluate(model, data_loader_test, device=device)

    # Show whatever the last training epoch returned.
    print(metrics)

    # Store the trained model, together with its signature, as an artifact
    # of this run.
    model_info = mlflow.pytorch.log_model(
        pytorch_model=model,
        artifact_path="model",
        signature=signature,
    )

# A bare expression is only echoed by the notebook when it is the last
# top-level statement of the cell; inside the `with` body it did nothing.
model_info

#  Inference

## Prepare an image to Perform Inference on

In [None]:
from toy_model.transforms import get_transform
from torchvision.io import read_image

def preprocess_image(image_path, transform, device):
    """Read an image from disk and prepare it for inference.

    The pixel values are min-max rescaled to the 0-255 range, stored as
    ``torch.uint8`` and cut down to the first three channels before
    ``transform`` is applied.

    Args:
        image_path: Path of the image file; passed to ``read_image``.
        transform: Callable applied to the rescaled image tensor.
        device: Device the transformed tensor is moved to.

    Returns:
        A ``(image, transformed_image)`` tuple: the rescaled uint8 image
        (at most three channels) and the transformed tensor, also limited
        to its first three channels, on ``device``.
    """
    image = read_image(image_path)

    lowest = image.min()
    value_range = image.max() - lowest
    if value_range == 0:
        # A constant image would make the rescaling below compute 0/0 (NaN),
        # whose cast to uint8 is undefined; map it to all zeros instead.
        image = torch.zeros_like(image, dtype=torch.uint8)
    else:
        image = (255.0 * (image - lowest) / value_range).to(torch.uint8)

    # Keep at most three channels (drops e.g. an alpha channel).
    image = image[:3, ...]
    transformed_image = transform(image)
    return image, transformed_image[:3, ...].to(device)


# Preprocess one validation/test image with the non-training transform,
# keeping both the displayable image and the model-ready tensor.
image, transformed_image = preprocess_image(
    image_path="data/images/valtest_yg_055.jpg",
    transform=get_transform(train=False),
    device=device,
)



In [None]:
# Echo the URI of the model logged by the training run.
model_info.model_uri

In [None]:
import mlflow
# URI of the model that the training run logged.
model_uri = model_info.model_uri

# Load model as a PyFuncModel, placing it on the device type selected at the
# top of the notebook instead of repeating a hard-coded "cuda" string.
loaded_model = mlflow.pyfunc.load_model(
    model_uri, model_config={"device": device.type}
)


In [None]:
# Echo the signature stored in the loaded model's metadata.
loaded_model.metadata.signature

In [None]:
import numpy as np
import pandas as pd

# The pyfunc flavor doesn't handle our input data,
# we could write a custom flavor wrapper, for now
# just get our hands on the raw torch model.
# NOTE(review): `_model_impl` is a private attribute of the loaded pyfunc
# model and may change between MLflow releases.
pytorch_model = loaded_model._model_impl.pytorch_model

# Put the model in evaluation mode explicitly rather than relying on
# whatever training flag was stored with the model.
pytorch_model.eval()

# Add a leading batch dimension to the single preprocessed image.
x = torch.unsqueeze(transformed_image, 0)

# No gradients are needed for inference; skipping autograd bookkeeping saves
# memory and yields outputs that do not require grad.
with torch.no_grad():
    pred = pytorch_model(x)[0]

# Move every prediction tensor to the CPU so the drawing/plotting steps,
# which work on the CPU-side image, do not depend on the inference device.
pred = {key: value.cpu() for key, value in pred.items()}

In [None]:
def filter_predictions(p, score_threshold=0.5):
    """Keep only the detections whose score exceeds ``score_threshold``.

    Args:
        p: Mapping of prediction fields. It must contain ``"scores"``, and
            every value must be indexable by the boolean mask derived from
            the scores.
        score_threshold: Exclusive lower bound a score has to exceed for
            its detection to be kept.

    Returns:
        A new dict with the same keys as ``p``, each value reduced to the
        entries of the kept detections.
    """
    confident = p["scores"] > score_threshold
    kept = {}
    for field, values in p.items():
        kept[field] = values[confident]
    return kept

# Drop low-scoring detections using the function's default score threshold.
filtered = filter_predictions(pred)

## Draw the Predicted Boxes on Top of Image

In [None]:
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
import cv2
def draw_predictions(image, pred):
    """Overlay predicted masks and score labels on an image.

    Args:
        image: Image tensor handed to the torchvision drawing utilities.
        pred: Mapping with ``"scores"``, ``"masks"`` and ``"boxes"`` entries.
            Masks are binarized at 0.7 and have dimension 1 squeezed out.

    Returns:
        The image returned by ``draw_bounding_boxes`` after the masks have
        been drawn in red (alpha 0.5) and the boxes/labels in black.
    """
    # One caption per detection, showing its score with three decimals.
    captions = []
    for score in pred["scores"]:
        captions.append(f"RTS: {score:.3f}")

    binary_masks = (pred["masks"] > 0.7).squeeze(1)
    integer_boxes = pred["boxes"].long()

    with_masks = draw_segmentation_masks(
        image, binary_masks, alpha=0.5, colors="red"
    )
    # NOTE(review): width=0 presumably suppresses the box outline so that
    # only the labels are visible — confirm this is intended.
    return draw_bounding_boxes(
        with_masks, integer_boxes, captions, colors="black", width=0
    )
    
# Render the filtered predictions on top of the preprocessed image.
output_image = draw_predictions(image, filtered)

# i1: channel-last NumPy view of the rendered image (channels moved from the
# first to the last axis).
i1=output_image.permute(1, 2, 0).numpy()
# i2: the same picture with the RGB -> BGR channel swap applied by OpenCV.
i2=cv2.cvtColor(i1, cv2.COLOR_RGB2BGR)

In [None]:
import matplotlib.pyplot as plt
# Matplotlib interprets 3-channel arrays as RGB, so show the RGB array `i1`;
# plotting the BGR copy `i2` would swap red and blue (red masks look blue).
plt.imshow(i1)
plt.axis('off')  # Optional: Hides the axis ticks and labels
plt.show()


In [None]:
import matplotlib.pyplot as plt
# Show the preprocessed input image for comparison; the channel axis is moved
# to the end because imshow expects channel-last arrays.
plt.imshow(image.numpy().transpose(1, 2, 0))
plt.axis("off")  # Optional: hides the axis ticks and labels
plt.show()
