In [1]:
from IPython.display import display, HTML
# Inject CSS that stretches the notebook's `.container` element to 98% width.
display(HTML("<style>.container { width:98% !important; }</style>"))

# Enable IPython's autoreload extension in mode 2, so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os
from glob import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import json
from pprint import pprint

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms

import training
from training import load_config, generate_dataloader, get_model, train_loop, cal_MeanIoU_score
from utils.preprocessing import load_image, apply_img_preprocessing

from inference import load_saved_model, pred_segmentation_mask

  warn(f"Failed to load image Python extension: {e}")


In [2]:
# Which architecture to build and where its trained weights live.
model_name = "unet"
saved_weight_path = "experiment_results/checkpoints/unet_checkpoint_epoch_1.pth"

# Keyword arguments forwarded when the model is constructed.
model_config = dict(in_channels=3, out_channels=1)
# Settings applied when turning model output into a mask.
inference_config = dict(foreground_threshold=0.5)

In [3]:
# Preprocessing has to mirror what the loaded checkpoint was trained with,
# so update these values whenever the training configuration changes.
resize_height = 360
resize_width = 640

# Steps applied in order: convert to a PIL image, resize to the training
# resolution, then convert to a tensor.
_preprocessing_steps = [
    transforms.ToPILImage(),
    transforms.Resize((resize_height, resize_width)),
    transforms.ToTensor(),
]
img_transform = transforms.Compose(_preprocessing_steps)

In [4]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"Using DEVICE: {DEVICE}")

Using DEVICE: cuda


In [10]:
# Initialize one model per saved checkpoint (files named epoch 1 through 6)
# and move each of them onto DEVICE.
checkpoint_path_template = "experiment_results/checkpoints/unet_checkpoint_epoch_{epoch}.pth"
checkpoint_epochs = range(1, 7)

# `.to(DEVICE)` is applied to the network that was just loaded for that epoch,
# not to a separate `model` variable, so every entry keeps its own weights and
# the cell runs on a fresh kernel.
models_by_epoch = {
    epoch: load_saved_model(
        model_name=model_name,
        saved_weight_path=checkpoint_path_template.format(epoch=epoch),
        **model_config,
    ).to(DEVICE)
    for epoch in checkpoint_epochs
}

# Keep the individual names that other cells refer to.
model1, model2, model3, model4, model5, model6 = (
    models_by_epoch[epoch] for epoch in checkpoint_epochs
)

Saved weights loaded
Saved weights loaded
Saved weights loaded
Saved weights loaded
Saved weights loaded
Saved weights loaded


In [6]:
# Validation images to browse. The order glob returns depends on the file
# system, so the list is sorted: anything that addresses an image by position
# (a fixed start index, stepping to the previous/next entry) then resolves to
# the same file on every run and every machine.
img_dir = "/cs6945share/retro_project/bdd100k/images/val/"
img_paths = sorted(glob(img_dir + "*.jpg"))

In [14]:
import ipywidgets as widgets
from IPython.display import display, clear_output
from PIL import Image
import io

# Position in img_paths of the image selected when the viewer first opens.
img_indice = 30

# Dropdown listing every image path; its value is the file currently shown.
dropdown_layout = widgets.Layout(width='600px')  # wide enough for long paths; adjust to taste
file_dropdown = widgets.Dropdown(
    description='File:',
    options=img_paths,
    value=img_paths[img_indice],
    layout=dropdown_layout,
)

# Output widget that holds the matplotlib figure (two images side by side).
image_output = widgets.Output()

# === Image Update Function ===
def _predict_mask_image(model, loaded_img, target_size):
    """Predict a segmentation mask with `model` and return it as a resized PIL image.

    Args:
        model: Network handed to `pred_segmentation_mask`.
        loaded_img: Image as returned by `load_image`.
        target_size: `(width, height)` to resize the mask to, typically the
            `size` of the PIL image the mask will be drawn over.

    Returns:
        PIL.Image.Image: the predicted mask as an 8-bit image of `target_size`.
    """
    pred_mask = pred_segmentation_mask(
        model=model, test_img=loaded_img, img_transform=img_transform, add_batch_dim=True, device=DEVICE,
        pos_threshold=inference_config["foreground_threshold"])
    # Scale to 0..255 for an 8-bit image
    # (assumes the prediction holds values in [0, 1] -- TODO confirm).
    mask_img = Image.fromarray((np.squeeze(pred_mask) * 255).astype(np.uint8))
    # Nearest-neighbour resampling avoids introducing interpolated mask values.
    return mask_img.resize(target_size, resample=Image.NEAREST)


def _draw_overlay(ax, img, overlay, title):
    """Show `img` on `ax` with `overlay` drawn semi-transparently in red on top."""
    ax.imshow(img)
    ax.imshow(overlay, cmap="Reds", alpha=0.5)
    ax.set_title(title)
    ax.axis('off')


# === Image Update Function ===
def update_images(change=None):
    """Redraw the annotation and prediction panels for the selected file.

    Used both as a direct call and as an observer callback on the dropdown.

    Args:
        change: Change notification supplied by the widget observer; unused.
            Defaults to None so the function can also be called directly.
    """
    selected_file = file_dropdown.value
    # The annotation mask lives in a parallel directory tree, stored as PNG.
    mask_path = selected_file.replace("images", "generated_masks_v0_2").replace(".jpg", ".png")

    # copy() forces the pixel data to load, so the files can be closed right away.
    with Image.open(selected_file) as img_file:
        img = img_file.copy()
    with Image.open(mask_path) as mask_file:
        mask = mask_file.copy()

    loaded_img = load_image(selected_file)
    # To compare more checkpoints, call the helper again with another model
    # and add a matching subplot below.
    pred_mask_resized1 = _predict_mask_image(model1, loaded_img, img.size)

    with image_output:
        clear_output(wait=True)
        scale = 2
        fig, axs = plt.subplots(1, 2, figsize=(16*scale, 8*scale))  # side-by-side
        _draw_overlay(axs[0], img, mask, 'Annotations')
        _draw_overlay(axs[1], img, pred_mask_resized1, 'Segmentation Model ChkPt #1 Predictions')
        plt.show()
        # Release the figure so repeated refreshes do not accumulate open figures.
        plt.close(fig)

# Re-render whenever the dropdown selection changes, and draw the initial view once.
file_dropdown.observe(update_images, names="value")
update_images()

# Navigation buttons
prev_button, next_button = (
    widgets.Button(description=label, button_style='') for label in ('Prev', 'Next')
)

def on_prev_clicked(b):
    """Select the file before the current one in img_paths, wrapping to the end.

    Args:
        b: The clicked button, as passed by the click callback; unused.
    """
    position = img_paths.index(file_dropdown.value) - 1
    # Modulo maps -1 back to the last entry.
    file_dropdown.value = img_paths[position % len(img_paths)]

def on_next_clicked(b):
    """Select the file after the current one in img_paths, wrapping to the start.

    Args:
        b: The clicked button, as passed by the click callback; unused.
    """
    position = img_paths.index(file_dropdown.value) + 1
    # Modulo maps one-past-the-end back to the first entry.
    file_dropdown.value = img_paths[position % len(img_paths)]

# Attach each navigation handler to its button.
next_button.on_click(on_next_clicked)
prev_button.on_click(on_prev_clicked)


def _centered_bar(children):
    """Return an HBox holding `children`, centred, with a 10px vertical margin."""
    bar_layout = widgets.Layout(justify_content='center', margin='10px 0')
    return widgets.HBox(children, layout=bar_layout)


# Lay out the widgets: file dropdown on top, figure in the middle,
# Prev/Next buttons underneath.
top_bar = _centered_bar([file_dropdown])
bottom_bar = _centered_bar([prev_button, next_button])
gui = widgets.VBox([top_bar, image_output, bottom_bar])

display(gui)

VBox(children=(HBox(children=(Dropdown(description='File:', index=30, layout=Layout(width='600px'), options=('…