In [10]:
from IPython.display import display, HTML
# Widen the notebook's main container so side-by-side images have room.
display(
    HTML("<style>.container { width:98% !important; }</style>")
)

%load_ext autoreload
%autoreload 2

import os
from glob import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import json
from pprint import pprint

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms

import training
from training import load_config, generate_dataloader, get_model, train_loop, cal_MeanIoU_score
from utils.preprocessing import load_image, apply_img_preprocessing

from inference import load_saved_model, pred_segmentation_mask

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
# Checkpoint to evaluate and the settings needed to rebuild its model.
# NOTE(review): in_channels/out_channels presumably have to match the
# architecture the checkpoint was trained with — confirm against the run.
saved_weight_path = "experiment_results/checkpoints/unet_final_2025-03-23_14-17-47.pth"
model_name = "unet"
model_config = dict(in_channels=3, out_channels=1)

# Post-processing settings used at inference time.
inference_config = dict(foreground_threshold=0.5)

In [3]:
# Inference preprocessing. The resize target must be the same (height, width)
# that was used when training the checkpoint loaded below.
resize_height = 360
resize_width = 640

img_transform = transforms.Compose(
    [
        # NOTE(review): input presumably an HWC uint8 array or CHW tensor — confirm at call site
        transforms.ToPILImage(),
        # (height, width) order, matching training
        transforms.Resize((resize_height, resize_width)),
        # back to a float CHW tensor
        transforms.ToTensor(),
    ]
)

In [4]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"Using DEVICE: {DEVICE}")

Using DEVICE: cuda


In [9]:
# Build the architecture named by model_name and load the checkpoint weights.
model = load_saved_model(model_name=model_name, saved_weight_path=saved_weight_path, **model_config)

# Move to the chosen device and switch to evaluation mode. In eval mode,
# dropout is disabled and batch-norm layers (if any) use running statistics
# instead of per-batch ones. eval() is idempotent and returns the module, so
# this is harmless if load_saved_model already called it.
# NOTE(review): confirm whether load_saved_model sets eval mode itself.
model = model.to(DEVICE).eval()

Saved weights loaded


In [14]:
# Directory of validation images to browse. It can be overridden through the
# BDD100K_VAL_IMG_DIR environment variable; the default is the shared path.
img_dir = os.environ.get("BDD100K_VAL_IMG_DIR", "/cs6945share/retro_project/bdd100k/images/val/")

# glob() returns files in arbitrary, filesystem-dependent order. Sort the list
# so that index-based selection (e.g. img_indice) and Prev/Next navigation are
# reproducible across runs and machines. os.path.join also works when img_dir
# has no trailing slash; plain concatenation would glob the parent directory.
img_paths = sorted(glob(os.path.join(img_dir, "*.jpg")))

In [32]:
import ipywidgets as widgets
from IPython.display import display
from PIL import Image
import io

def load_image(filename, max_width=800, max_height=800):
    """Read an image file and return it as JPEG-encoded bytes for display.

    The image is shrunk, never enlarged, to fit within
    ``max_width`` x ``max_height`` while keeping its aspect ratio. This keeps
    the payload sent to the ``widgets.Image`` front end small.

    NOTE(review): this shadows ``load_image`` imported from
    ``utils.preprocessing`` in the first cell. Any cell run afterwards that
    calls ``load_image`` gets this bytes-returning version instead.
    Consider renaming it (e.g. ``load_thumbnail_bytes``).

    Args:
        filename: Path to the image file.
        max_width: Maximum width, in pixels, of the returned image.
        max_height: Maximum height, in pixels, of the returned image.

    Returns:
        bytes: The (possibly downscaled) image encoded as JPEG.
    """
    # The context manager closes the underlying file handle. A bare
    # Image.open() leaves it open until garbage collection.
    with Image.open(filename) as img:
        img.thumbnail((max_width, max_height))  # in-place, aspect ratio preserved

        # JPEG cannot store alpha or palette modes (RGBA, LA, P, ...), so
        # save() would raise OSError for them. Convert to plain RGB first.
        if img.mode not in ("RGB", "L"):
            img = img.convert("RGB")

        with io.BytesIO() as output:
            img.save(output, format='JPEG')
            return output.getvalue()

# Position in img_paths of the image shown when the viewer first opens.
img_indice = 30

# Fail with a clear message instead of an opaque IndexError when the glob
# matched nothing (e.g. a wrong img_dir).
if not img_paths:
    raise FileNotFoundError(f"No .jpg images found in {img_dir!r}")
# Clamp so that a directory with fewer than img_indice + 1 images still opens.
initial_path = img_paths[min(img_indice, len(img_paths) - 1)]

# Dropdown for file selection
file_dropdown = widgets.Dropdown(
    options=img_paths,
    value=initial_path,
    description='File:',
    layout=widgets.Layout(width='600px'),  # adjust width to suit your preference
)

# Thumbnail of the initially selected file, used to seed both panes.
img_bytes = load_image(file_dropdown.value)

# Two images side by side.
# NOTE(review): the right pane currently mirrors the input image. It is
# presumably meant to show the predicted mask (pred_segmentation_mask is
# imported in the first cell but not used in this viewer).
left_image = widgets.Image(
    value=img_bytes,
    format='jpg',
)
right_image = widgets.Image(
    value=img_bytes,
    format='jpg',
)

def update_images(change=None):
    """Reload both image panes from the dropdown's current selection.

    Can be called directly or registered as a ``file_dropdown.observe``
    callback. The ``change`` payload is ignored because the dropdown's value
    is read directly.
    """
    thumbnail = load_image(file_dropdown.value)
    for pane in (left_image, right_image):
        pane.value = thumbnail


# Fill the panes once now, then keep them in sync with the dropdown.
update_images()
file_dropdown.observe(update_images, names="value")

# Navigation buttons that step through the dropdown's options, wrapping at
# either end.
prev_button = widgets.Button(description='Prev', button_style='')
next_button = widgets.Button(description='Next', button_style='')


def _step_selection(offset):
    """Move the dropdown selection by ``offset`` positions, wrapping around.

    Uses the dropdown's own ``index`` rather than ``img_paths.index(value)``.
    That avoids a linear search on every click and stays correct even if
    ``img_paths`` is rebound after the dropdown was built. Setting ``index``
    also updates ``value``, so the ``observe(..., names="value")`` callbacks
    still fire.
    """
    n_options = len(file_dropdown.options)
    if n_options == 0 or file_dropdown.index is None:
        return  # nothing selected / nothing to navigate
    file_dropdown.index = (file_dropdown.index + offset) % n_options


def on_prev_clicked(b):
    """Button callback: select the previous image (wraps to the last one)."""
    _step_selection(-1)


def on_next_clicked(b):
    """Button callback: select the next image (wraps to the first one)."""
    _step_selection(1)


prev_button.on_click(on_prev_clicked)
next_button.on_click(on_next_clicked)

# Assemble the viewer: file picker on top, the two image panes in the
# middle, and Prev/Next buttons underneath, all horizontally centred.
top_bar = widgets.HBox(
    [file_dropdown],
    layout=widgets.Layout(justify_content='center', margin='10px 0'),
)
images_box = widgets.HBox(
    [left_image, right_image],
    layout=widgets.Layout(justify_content='center'),
)
bottom_bar = widgets.HBox(
    [prev_button, next_button],
    layout=widgets.Layout(justify_content='center', margin='10px 0'),
)

gui = widgets.VBox([top_bar, images_box, bottom_bar])
display(gui)

VBox(children=(HBox(children=(Dropdown(description='File:', index=30, layout=Layout(width='600px'), options=('…