In [27]:
import numpy as np

import torch
from torch.nn import functional

from transformers import AutoFeatureExtractor

from PIL import Image
import matplotlib.pyplot as plt

## Load model

In [28]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.load(input("trained model path"), map_location=device)
model.eval()

model_name = model.config._name_or_path
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

FileNotFoundError: [Errno 2] No such file or directory: ''

## Set color palette

In [None]:
PALETTE = [
    [120, 120, 120], # background
    [180, 120, 120], # container_truck
    [6, 230, 230], # forklift
    [80, 50, 50], # reach_stacker
    [4, 200, 3] # ship
]
palette = np.array(PALETTE)

## Visualize

In [None]:
img_path = input("image path to visualize")
img = Image.open(img_path)

id2label = model.config.id2label

inputs = feature_extractor(images=img, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

logits = outputs.logits  # (batch_size, num_labels, height/4, width/4)
upsampled_logits = functional.interpolate(
    logits,
    size=img.size[::-1],  # (height, width)
    mode="bilinear",
    align_corners=False
)
seg = upsampled_logits.argmax(dim=1)[0]
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # (height, width, 3)

for label, color in enumerate(palette):
    color_seg[seg == label, :] = color

masked_img = np.array(img) * 0.2 + color_seg * 0.8
masked_img = masked_img.astype(np.uint8)

merged = np.concatenate(
    (
        np.concatenate((img, color_seg), axis=1),
        np.concatenate((np.zeros_like(img), masked_img), axis=1)
    ),
    axis=0
)

plt.figure(figsize=(15, 15))
plt.imshow(merged)