In [3]:
import zenml
from zenml.steps import step
import gradio as gr
from pathlib import Path
import wandb

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torchvision.models import resnet18
from torchvision import transforms
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader

import pandas as pd
import ast

from typing import Union

# Record the library versions in the cell output so the environment is documented.
print(
    f'zenml=={zenml.__version__}',
    f'gradio=={gr.__version__}',
    f'wandb=={wandb.__version__}',
    sep='\n',
)

INFO:numexpr.utils:Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.
INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmptcqbmco1
INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmptcqbmco1/_remote_module_non_scriptable.py


zenml==0.20.5
gradio==3.6
wandb==0.13.4


In [4]:
# All paths are relative to the project root (the parent of the notebook's directory).
PROJECT_PATH = Path.cwd().parent

# Trained checkpoint and the exported (staged) TorchScript file.
WEIGHT_DIR = PROJECT_PATH / 'weights'
CHECKPOINT_PATH = WEIGHT_DIR / 'epoch=1-step=320.ckpt'
STAGED_MODEL_FILENAME = 'staged_mri.pt'

# Weights & Biases artifact metadata for the staged model.
STAGED_MODEL_TYPE = 'deployment-demo'
STAGE_MODEL_NAME = 'staged_mri_demo'

# Logs and the held-out test data.
LOG_DIR = PROJECT_PATH / 'logs'
MRI_DATA_DIR = PROJECT_PATH / 'data/processed/mri/test'
CSV_PATH = PROJECT_PATH / 'data/processed/csv/test_mri_patients.csv'


In [5]:
# Condition names; position i corresponds to output logit i of the MRI model.
LABELS = [
    'heart_failure',
    'coronary_heart',
    'myocardial_infarction',
    'stroke',
    'cardiac_arrest',
]

The MRI model definition (copied from `../../mri/src/train.py`; keep the two in sync so the checkpoint below still loads)

In [6]:
class MRIModel(pl.LightningModule):
    """ResNet-18 based multi-label classifier for single-channel MRI slices.

    The ResNet's ``fc`` output (``resnet.fc.out_features`` features) is fed through
    an extra linear layer producing ``num_classes`` logits, one per condition.
    Training uses ``binary_cross_entropy_with_logits``, i.e. every condition is an
    independent binary target.

    Args:
        lr: learning rate for the Adam optimizer.
        num_classes: number of output logits. Defaults to 5, the value previously
            hard-coded, so existing checkpoints (which only store ``lr``) still load.
    """

    def __init__(self, lr=0.001, num_classes=5):
        super().__init__()
        self.lr = lr
        self.resnet = resnet18()
        # change input channel to be 1 instead of 3 (slices are single-channel)
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7),
                                      stride=(2, 2), padding=(3, 3), bias=False)
        # add a linear layer at the end for transfer learning
        self.linear = nn.Linear(in_features=self.resnet.fc.out_features,
                                out_features=num_classes)
        self.save_hyperparameters()  # records the __init__ arguments (lr, num_classes) as hparams

    def forward(self, xs):
        """Return raw (pre-sigmoid) logits with one column per class.

        ``xs`` is assumed to be a batch of 1-channel images (batch, 1, H, W) — TODO confirm.
        """
        y_hats = self.resnet(xs)
        y_hats = self.linear(y_hats)
        return y_hats

    def _shared_step(self, batch):
        """Compute the multi-label BCE-with-logits loss for one ``(xs, ys)`` batch."""
        xs, ys = batch
        y_hats = self(xs)  # call the module (not .forward) so Lightning/nn hooks run
        return F.binary_cross_entropy_with_logits(y_hats, ys)

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True, on_epoch=True, on_step=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True, on_epoch=True, on_step=True)

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

Load the trained model from the checkpoint (copied from `../../mri/weights/epoch=*.ckpt`)

In [7]:
# Rebuild the LightningModule from the training checkpoint (weights + saved hparams).
model = MRIModel.load_from_checkpoint(checkpoint_path=CHECKPOINT_PATH)

Export the model to TorchScript and save it in the staging directory

In [20]:
def save_model_to_torchscript(model, directory, filename=None):
    """Script ``model`` with TorchScript and save it to ``directory / filename``.

    Args:
        model: a LightningModule (anything exposing ``to_torchscript``).
        directory: target directory; created (with parents) if it does not exist.
        filename: output file name; defaults to ``STAGED_MODEL_FILENAME``.

    Returns:
        The ``Path`` the scripted model was written to.
    """
    if filename is None:
        filename = STAGED_MODEL_FILENAME
    scripted_model = model.to_torchscript(method="script", file_path=None)
    directory = Path(directory)
    # writing the file fails if the staging directory does not exist yet
    directory.mkdir(parents=True, exist_ok=True)
    path = directory / filename
    torch.jit.save(scripted_model, path)
    return path

save_model_to_torchscript(model, WEIGHT_DIR)

Upload the staged model to Weights & Biases

In [21]:
def upload_staged_model(staged_at, from_directory, filename=None):
    """Add the staged TorchScript file to a W&B artifact and log the artifact.

    Meant to be called inside an active ``wandb.init(...)`` run, since
    ``wandb.log_artifact`` logs to the current run.

    Args:
        staged_at: the ``wandb.Artifact`` that receives the file.
        from_directory: directory containing the staged model file.
        filename: file to upload; defaults to ``STAGED_MODEL_FILENAME``.
    """
    if filename is None:
        filename = STAGED_MODEL_FILENAME
    staged_path = Path(from_directory) / filename
    # Artifact.add_file documents local_path as a str
    staged_at.add_file(str(staged_path))
    wandb.log_artifact(staged_at)

In [22]:
# Staging on W&B is opt-in (it presumably needs W&B credentials and network access);
# set this to True to log the staged model as an artifact.
UPLOAD_TO_WANDB = False

if UPLOAD_TO_WANDB:
    with wandb.init(
        job_type='stage', entity="multi-modal-fsdl2022", project='deployment', dir=LOG_DIR,
    ):
        staged_at = wandb.Artifact(STAGE_MODEL_NAME, type=STAGED_MODEL_TYPE)
        upload_staged_model(staged_at, from_directory=WEIGHT_DIR)

Run inference with our more portable TorchScript model, the way a CLI would (following `lab7`)

In [8]:
class TestMRIDataset(Dataset):
    """MRI slices and condition labels of test patients, keyed by patient id.

    Unlike a typical map-style ``Dataset``, ``__getitem__`` takes a patient id
    string (not an integer index) and returns all slices of that patient.

    Args:
        mri_path: directory with the MRI slices as ``.png`` files; a file belongs
            to a patient when the patient id occurs in its file name.
        csv_path: CSV with a ``patient`` column and a ``label`` column holding a
            Python list literal (parsed with ``ast.literal_eval``).
        img_size: every slice is resized to ``img_size x img_size``.
    """

    def __init__(self, mri_path: Path, csv_path: Path, img_size: int = 128) -> None:
        self.data_dir = Path(mri_path)
        self.df = pd.read_csv(csv_path)
        # sorted so the slice order does not depend on filesystem listing order
        self.mri_png_paths = sorted(self.data_dir.glob('*.png'))  # list of mri slices as png images
        self.img_size = img_size

    def find_mri_png_paths(self, patient_id: str) -> list[Path]:
        """Return the paths of all ``.png`` slices whose file name contains ``patient_id``."""
        # NOTE(review): substring match — an id contained in another id would also match its slices
        return [path for path in self.mri_png_paths if patient_id in path.name]

    def __getitem__(self, patient_id: str) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(slices, label)`` for one patient.

        ``slices`` stacks (along dim 0) every slice of the patient, each converted to
        float32 and resized to ``(img_size, img_size)``; ``label`` is a float32 tensor
        built from the patient's ``label`` CSV cell.

        Raises:
            KeyError: if the patient has no slices or no row in the CSV.
        """
        png_paths = self.find_mri_png_paths(patient_id)
        if not png_paths:
            raise KeyError(f'No MRI slices found for patient {patient_id!r} in {self.data_dir}')

        label_cells = self.df.loc[self.df['patient'] == patient_id, 'label']
        if label_cells.empty:
            raise KeyError(f'Patient {patient_id!r} not found in the label CSV')

        # one transform per call (not per image); still honours the current img_size
        resize = transforms.Resize((self.img_size, self.img_size))
        data = torch.stack([resize(read_image(str(path)).float()) for path in png_paths])

        # the label cell is a list literal; literal_eval parses it without executing code
        label = torch.tensor(ast.literal_eval(label_cells.values[0]), dtype=torch.float32)
        return data, label

    def __len__(self) -> int:
        """Number of distinct patients, i.e. the valid keys for ``__getitem__``."""
        return int(self.df['patient'].nunique())

# Patient-keyed view over the held-out MRI slices and their labels.
test_dataset = TestMRIDataset(mri_path=MRI_DATA_DIR, csv_path=CSV_PATH)

In [16]:
# Load the staged TorchScript artifact (no MRIModel class needed). The dataset yields CPU
# tensors, so map the weights to CPU; eval() guarantees inference behaviour (e.g. the
# ResNet's BatchNorm layers use running statistics) regardless of how the file was saved.
model = torch.jit.load(WEIGHT_DIR / STAGED_MODEL_FILENAME, map_location='cpu').eval()

def print_predictions(labels: list, predictions: list):
    """Print the names in ``labels`` whose matching entry in ``predictions`` equals 1."""
    flagged = [labels[idx] for idx, flag in enumerate(predictions) if flag == 1]
    print(f'The patient is having risk of {flagged}')

def predict(model: torch.nn.Module, patient_id: str, threshold: float = 0.5,
            dataset=None, labels=None):
    """Predict and print the conditions a patient is at risk of.

    Runs ``model`` on all MRI slices of the patient, turns the logits into
    per-condition probabilities with a sigmoid, averages them over the slices and
    flags every condition whose mean probability exceeds ``threshold``.

    Args:
        model: module mapping a stack of slices to per-condition logits.
        patient_id: key looked up in ``dataset``.
        threshold: mean probability above which a condition is flagged.
        dataset: object indexed by patient id returning ``(slices, label)``;
            defaults to the notebook's ``test_dataset``.
        labels: condition names in the model's output order; defaults to ``LABELS``.
    """
    if dataset is None:
        dataset = test_dataset
    if labels is None:
        labels = LABELS
    mri_data, _ = dataset[patient_id]
    # inference only: don't build an autograd graph
    with torch.no_grad():
        y_hats = torch.sigmoid(model(mri_data))
    # dim 0 indexes the patient's slices; average to one probability per condition
    mean_y_hats = torch.mean(y_hats, dim=0)
    pred = (mean_y_hats > threshold).int().tolist()
    print_predictions(labels, pred)

# Example: a patient from the held-out test CSV.
predict(model, patient_id='6dc8bd6b-e2a8-92bf-613d-8b477eb87d7c')

The patient is having risk of ['stroke']
