In [None]:
# Re-import edited project modules (e.g. the local `model` package) before each
# cell runs, so changes to model code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import fiftyone as fo

from PIL import Image
from torchvision import transforms

from datetime import datetime

import torch
import lightning as L

from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import LearningRateFinder

from lightning.pytorch.utilities.model_summary import ModelSummary
from lightning.pytorch.tuner import Tuner

from torch.utils.data import DataLoader

from model.model_light import FishSeg
from model.dataset import SimpleFishialFishDataset
from model.utils import *
from argparse import ArgumentParser

In [None]:
# Checkpoint of the trained FishSeg model to export. The default '.ckpt' is only a
# placeholder: point it at a real checkpoint, or set the FISHSEG_CHECKPOINT
# environment variable so the notebook can run without editing this cell.
MODEL_PATH = os.environ.get("FISHSEG_CHECKPOINT", ".ckpt")

# Architecture settings passed to FishSeg; presumably they must match the
# configuration the checkpoint was trained with — TODO confirm.
encoder_type = "FPN"
backbone = "resnet18"
in_channels = 3   # RGB input (the preprocessing normalizes 3 channels)
out_classes = 1   # single output channel, turned into a mask via sigmoid below

In [None]:
# FiftyOne dataset providing the images (`filepath`) and ground-truth polylines
# used by the visual checks below.
DATASET_NAME = 'classification-v0.8'
data = fo.load_dataset(DATASET_NAME)

In [None]:
# Build FishSeg from MODEL_PATH, reusing the architecture constants from the config
# cell so the declared values and the constructed model cannot drift apart.
model = FishSeg(
    encoder_type,
    backbone,
    in_channels=in_channels,
    out_classes=out_classes,
    load_checkpoint=MODEL_PATH,
)
# Inference only, on CPU: the tracing / ONNX export and comparisons below all
# feed CPU tensors.
model.eval()
model.cpu()

In [None]:
IMAGE_SIZE = 416  # square side length used for inference and for the exported models

# Per-channel normalization statistics (the standard ImageNet values); presumably
# the same ones used during training — TODO confirm.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# PIL image -> normalized float tensor (3 x IMAGE_SIZE x IMAGE_SIZE for an RGB input).
# InterpolationMode is torchvision's documented type for `interpolation`; the PIL
# integer constant is only accepted through a compatibility conversion.
loader = transforms.Compose([
    transforms.Resize(
        (IMAGE_SIZE, IMAGE_SIZE),
        interpolation=transforms.InterpolationMode.BILINEAR,
    ),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])


In [None]:
# Qualitative check of the eager model: show ground truth and predicted mask for
# a few dataset samples.
count_to_visualize = 10
view = data.take(count_to_visualize)

for sample in view:
    filepath = sample.filepath
    # The first polygon of the sample's polyline label serves as ground truth.
    polygon = sample.polyline.points[0]

    # Close the file handle promptly and force 3 channels: grayscale / RGBA /
    # palette images would otherwise break the 3-channel Normalize in `loader`.
    with Image.open(filepath) as raw_image:
        pil_image = raw_image.convert("RGB")
    width, height = pil_image.size

    gt_mask = create_mask(polygon, height, width, color = (100,123,234))

    x_tensor = loader(pil_image).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        logits = model(x_tensor)
    # Sigmoid probabilities for image 0, channel 0, mapped back to the original
    # image size for display.
    pr_mask = logits.sigmoid()[0][0].numpy()
    pr_mask = resize_logits_mask_pil(pr_mask, width, height)

    visualize(
        image=pil_image,
        ground_truth_mask=gt_mask,
        predicted=pr_mask
    )


In [None]:
# Export target: the network stored on the FishSeg object as `.model` (presumably
# the bare segmentation network without the Lightning wrapper — TODO confirm).
# nn.Module.eval() and .cpu() both return the module itself, so they chain.
croped_model = model.model.eval().cpu()

# Random example input in (batch, channels, height, width) layout; a random tensor
# is enough for tracing as long as forward has no data-dependent control flow.
example_forward_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE)

# torch.jit.trace recognises a bound `forward`, traces the owning module and
# returns a ScriptModule with a single traced `forward` method.
module = torch.jit.trace(croped_model.forward, example_forward_input)

In [None]:
import torch.onnx

# Destination of the exported ONNX graph.
onnx_model_path = f"saved_models/segmentator/model_{backbone}_{IMAGE_SIZE}.onnx"
# torch.onnx.export does not create missing parent directories.
os.makedirs(os.path.dirname(onnx_model_path), exist_ok=True)

# Example input of shape (batch_size, channels, height, width). The batch axis is
# declared dynamic below, so the exported graph is not tied to batch size 1.
dummy_input = torch.randn(1, in_channels, IMAGE_SIZE, IMAGE_SIZE)

torch.onnx.export(
    croped_model,              # the module to export
    dummy_input,               # example input used to trace the graph
    onnx_model_path,           # output file
    export_params=True,        # store the trained weights inside the ONNX file
    opset_version=11,          # ONNX operator-set version
    do_constant_folding=False, # constant folding is DISABLED for this export
    input_names=['input'],     # graph input name (feed key for ONNX Runtime)
    output_names=['output'],   # graph output name
    dynamic_axes={'input': {0: 'batch_size'},    # dynamic batch dimension
                  'output': {0: 'batch_size'}},
)

print(f"Model successfully exported to {onnx_model_path}")

In [None]:
# Persist the traced TorchScript module. ScriptModule.save does not create
# missing directories, so make sure the target folder exists first.
SAVE_PATH = f"saved_models/segmentator/model_{backbone}_{IMAGE_SIZE}.ts"
os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True)
module.save(SAVE_PATH)

In [None]:
# Reload the saved TorchScript module onto the CPU; every input fed to it below
# comes from `loader` / torch.randn and therefore lives on the CPU.
new_model = torch.jit.load(SAVE_PATH, map_location="cpu")

In [None]:
import onnxruntime

# Since ONNX Runtime 1.9, builds with several execution providers (e.g. the GPU
# package) raise unless `providers` is passed explicitly. Pin the CPU provider so
# the session works on any build and matches the CPU PyTorch reference.
ort_session = onnxruntime.InferenceSession(
    onnx_model_path, providers=["CPUExecutionProvider"]
)

In [None]:
import numpy as np

# Numerical parity check: push the same random input through ONNX Runtime and the
# eager PyTorch network and compare the raw outputs.
batch_size = 1
x = torch.randn(batch_size, in_channels, IMAGE_SIZE, IMAGE_SIZE, requires_grad=False)

ort_inputs = {ort_session.get_inputs()[0].name: x.numpy()}
ort_outs = ort_session.run(None, ort_inputs)

with torch.no_grad():
    torch_out = croped_model(x)

# run() returns a list with one array per graph output; the graph has one output.
# Taking [0] (instead of np.array(ort_outs)) avoids a spurious leading axis that
# broadcasting would silently hide.
onnx_output = ort_outs[0]
torch_output = torch_out.numpy()

assert onnx_output.shape == torch_output.shape, (onnx_output.shape, torch_output.shape)

# A signed sum of differences can cancel out; report absolute deviations instead.
abs_diff = np.abs(onnx_output - torch_output)
print(f"max abs diff: {abs_diff.max():.3e}, mean abs diff: {abs_diff.mean():.3e}")
np.testing.assert_allclose(torch_output, onnx_output, rtol=1e-03, atol=1e-05)

In [None]:
# Put the reloaded TorchScript module in inference mode before the comparison loop.
new_model.eval()

In [None]:
# Rebuilds the same preprocessing pipeline as the earlier `loader` cell (identical
# arguments): bilinear resize to IMAGE_SIZE x IMAGE_SIZE, conversion to a float
# tensor, then per-channel normalization.
_resize = transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), Image.BILINEAR)
_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
loader = transforms.Compose([_resize, transforms.ToTensor(), _normalize])


In [None]:
# Inspect the last preprocessed input produced by the visualization loop: show its
# shape and dtype rather than dumping the whole tensor into the notebook.
x_tensor.shape, x_tensor.dtype

In [None]:
# Side-by-side check of every inference path on the same samples: the FishSeg
# object (`model`), its bare network (`croped_model`), the reloaded TorchScript
# module (`new_model`) and the ONNX Runtime session.
count_to_visualize = 10
view = data.take(count_to_visualize)

# The ORT input name is fixed for the session; look it up once, not per sample.
ort_input_name = ort_session.get_inputs()[0].name


def _to_display_mask(mask_logits, width, height):
    """Sigmoid of image 0 / channel 0 of `mask_logits`, resized to (width, height)."""
    return resize_logits_mask_pil(mask_logits.sigmoid()[0][0].numpy(), width, height)


for sample in view:
    filepath = sample.filepath
    polygon = sample.polyline.points[0]

    # Close the file promptly and force 3 channels so the 3-channel Normalize in
    # `loader` also works for grayscale / RGBA / palette images.
    with Image.open(filepath) as raw_image:
        pil_image = raw_image.convert("RGB")
    width, height = pil_image.size

    gt_mask = create_mask(polygon, height, width, color = (100,123,234))

    x_tensor = loader(pil_image).unsqueeze(0)

    with torch.no_grad():
        logits = new_model(x_tensor)
        logits_std = model(x_tensor)
        logits_croped = croped_model(x_tensor)

        ort_inputs = {ort_input_name: x_tensor.numpy()}
        # run() returns one array per graph output; wrap the single output as a
        # tensor so it goes through the same post-processing as the torch paths.
        ort_outs = torch.tensor(ort_session.run(None, ort_inputs)[0])

    pr_mask_onnx = _to_display_mask(ort_outs, width, height)
    pr_mask_croped = _to_display_mask(logits_croped, width, height)
    pr_mask_std = _to_display_mask(logits_std, width, height)
    pr_mask = _to_display_mask(logits, width, height)

    visualize(
        image=pil_image,
        ground_truth_mask=gt_mask,
        predicted_full=pr_mask_std,
        pr_mask_croped = pr_mask_croped,
        predicted_ts=pr_mask,
        pr_mask_onnx=pr_mask_onnx
    )

In [None]:
# Debug check. After the comparison loop, `ort_outs` is a torch tensor (the loop
# wraps the ORT result in torch.tensor); after the parity-check cell it is the raw
# list returned by ort_session.run. The result depends on which cell ran last.
type(ort_outs[0])

In [None]:
with torch.no_grad():
#         %timeit logits = new_model(x_tensor)
        %timeit logits_std = model(x_tensor)
#         logits_croped = croped_model(x_tensor)