In [None]:
# --- Environment setup (Kaggle GPU image) ---
# TensorRT Python bindings from NVIDIA's package index, plus the system TensorRT package.
# NOTE(review): later cells call /usr/src/tensorrt/bin/trtexec, which presumably comes from the
# apt package. `!pip` uses whatever `pip` is first on PATH; `%pip` installs into the kernel's own
# environment. Versions are unpinned, so re-running may pull newer, incompatible releases.
!pip install --upgrade --index-url https://pypi.ngc.nvidia.com nvidia-tensorrt
!sudo apt-get install -y tensorrt

In [None]:
# Sanity check: print the TensorRT version and make sure a Builder can be constructed.
import tensorrt
print(tensorrt.__version__)
assert tensorrt.Builder(tensorrt.Logger())

In [None]:
# apt/pip packages: python3-libnvinfer-dev, numpy + onnx, onnx-graphsurgeon.
!sudo apt-get install python3-libnvinfer-dev -y
!python3 -m pip install numpy onnx
!sudo apt-get install onnx-graphsurgeon -y

In [None]:
# Show which TensorRT / libnvinfer packages (and versions) are installed.
!dpkg-query -W tensorrt
!dpkg -l | grep nvinfer

In [None]:
# Python tooling imported later: pytorch-quantization, wget, onnxsim, onnxruntime.
# NOTE(review): pip is upgraded twice (here and in the onnxsim line), and sphinx-glpi-theme is
# not imported anywhere in the visible notebook.
!pip install --upgrade pip
!pip install --upgrade sphinx-glpi-theme
!pip install --no-cache-dir --index-url https://pypi.nvidia.com pytorch-quantization
!pip install wget
!pip3 install -U pip && pip3 install onnxsim
!pip install onnxruntime

In [None]:
import os

# CUDA_LAUNCH_BLOCKING=1 makes every CUDA kernel launch synchronous. That is a debugging aid
# (errors surface at the op that caused them), but it slows training and inference and skews
# any PyTorch-side latency numbers. Set DEBUG_CUDA_SYNC = True only while chasing a CUDA error;
# the variable must be set before CUDA is initialised (i.e. before the first CUDA call).
DEBUG_CUDA_SYNC = False
if DEBUG_CUDA_SYNC:
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


In [None]:
# --- Core imports ---
# PyTorch / torchvision.
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from torchvision import models, datasets
from torch.utils.tensorboard import SummaryWriter
# ONNX tooling: onnxruntime runs ONNX graphs, onnxsim simplifies them.
import onnxruntime
import onnxsim


# NVIDIA pytorch-quantization: QuantDescriptor / quant_modules are used later to build q_model.
import pytorch_quantization
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules
from pytorch_quantization import calib
from pytorch_quantization.tensor_quant import QuantDescriptor
from pytorch_quantization.nn.modules.tensor_quantizer import TensorQuantizer
from pytorch_quantization.nn.modules import quant_pooling
from tqdm import tqdm

# Log library versions so results can be tied to the environment that produced them.
print(pytorch_quantization.__version__)
# print(torch_tensorrt.__version__)
print(onnxruntime.__version__)
print(onnxsim.__version__)
#print(torchaudio.__version__)

import os
import sys
import warnings
import time
import numpy as np
import wget
import tarfile
import shutil
# NOTE(review): this silences *every* warning for the rest of the session (including deprecation
# warnings such as torchvision's `pretrained=` argument), which hides useful signals. Consider
# filtering specific categories or modules instead.
warnings.simplefilter('ignore')
import collections

# Accelerate parts
from accelerate import Accelerator, notebook_launcher # main interface, distributed launcher
from accelerate.utils import set_seed # reproducibility across devices
# cuDNN autotuner: benchmarks convolution algorithms per input shape and caches the fastest one.
# This helps with fixed-size inputs (512x1024 here). It may select non-deterministic kernels, so
# runs are not guaranteed to be bit-exact even with set_seed.
torch.backends.cudnn.benchmark = True  # Enables auto-tuning for speed optimization

In [None]:
# Accelerate's main entry point.
# NOTE(review): not used by the visible training code, which moves tensors with `.to(device)`
# and never calls `accelerator.prepare(...)`.
accelerator = Accelerator()

# Seed the RNGs through Accelerate's helper so runs are repeatable.
SEED = 42
set_seed(SEED)


In [None]:
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os
import random

class CityscapesDataset(Dataset):
    """Cityscapes fine-annotation semantic-segmentation dataset.

    Pairs every ``*_leftImg8bit.png`` under ``<root_dir>/leftImg8bit/<split>`` with the matching
    ``*_gtFine_labelIds.png`` under ``<root_dir>/gtFine/<split>/<city>``. The image and mask go
    through one Albumentations pipeline together. Raw Cityscapes label ids are then remapped to
    19 contiguous train ids (0-18); every other id becomes ``ignore_index`` (255).

    Args:
        root_dir: Dataset root containing ``leftImg8bit`` and ``gtFine``.
        split: Sub-directory name, e.g. ``'train'`` or ``'val'``.
        resize_shape: Output ``(height, width)``.
        is_train: Use the augmenting pipeline when True; resize + normalize only otherwise.

    Items are ``(image, target)``:
    - ``image`` is the normalized ``[3, H, W]`` tensor produced by ``ToTensorV2``.
    - ``target`` is an ``[H, W]`` tensor of train ids with the mask's dtype (presumably uint8 for
      8-bit label PNGs).
    """

    def __init__(self, root_dir, split='train', resize_shape=(512, 1024), is_train=True):
        super().__init__()
        self.inputs, self.targets = [], []
        self.resize_shape = resize_shape
        self.ignore_index = 255
        # Raw labelIds of the 19 evaluated classes; the position in this list is the train id.
        self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
        self.class_map = {k: v for v, k in enumerate(self.valid_classes)}
        self.is_train = is_train

        img_dir = os.path.join(root_dir, 'leftImg8bit', split)
        mask_dir = os.path.join(root_dir, 'gtFine', split)

        # os.walk yields entries in filesystem order. Sort both directories and files so dataset
        # indices (and the seeded calibration subset drawn from them) are reproducible.
        for root, subdirs, filenames in os.walk(img_dir):
            subdirs.sort()  # in-place sort makes os.walk descend in sorted order
            for filename in sorted(filenames):
                if not filename.endswith('.png'):
                    continue
                filename_base = '_'.join(filename.split('_')[:-1])
                input_path = os.path.join(root, filename_base + '_leftImg8bit.png')
                target_path = os.path.join(mask_dir, os.path.basename(root), filename_base + '_gtFine_labelIds.png')

                if os.path.exists(input_path) and os.path.exists(target_path):
                    self.inputs.append(input_path)
                    self.targets.append(target_path)

        print(f"Found {len(self.inputs)} images and {len(self.targets)} targets for split '{split}'")

        # Albumentations pipelines (applied to image and mask together).
        # NOTE(review): newer albumentations releases expect RandomResizedCrop(size=(h, w), ...);
        # the positional height/width form below targets older versions.
        if self.is_train:
            self.transform = A.Compose([
                A.RandomResizedCrop(self.resize_shape[0], self.resize_shape[1], scale=(0.7, 1.0), ratio=(1.0, 1.0), p=0.5),
                A.RandomScale(scale_limit=0.3, p=0.5),  # Scale images to enhance small objects
                A.HorizontalFlip(p=0.5),  # Useful for symmetrical structures
                A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, p=0.5),  # More aggressive color variation
                A.GaussianBlur(blur_limit=(3, 7), p=0.3),
                A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),  # Helps with perspective issues
                A.CoarseDropout(max_holes=12, max_height=75, max_width=75, p=0.6),  # Simulate occlusions
                A.Resize(self.resize_shape[0], self.resize_shape[1], p=1.0),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2()
            ])
        else:
            self.transform = A.Compose([
                A.Resize(self.resize_shape[0], self.resize_shape[1]),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2()
            ])

    def __len__(self):
        return len(self.inputs)

    def _remap_labels(self, mask):
        """Map raw label ids in ``mask`` to train ids; every unlisted id becomes ``ignore_index``.

        A single lookup-table gather replaces one masked assignment per class. The result keeps
        ``mask``'s dtype and device. Assumes non-negative ids, as decoded from PNG label maps.
        """
        ids = mask.long()
        lut_size = max(self.class_map) + 1
        if ids.numel():
            lut_size = max(lut_size, int(ids.max()) + 1)
        lut = torch.full((lut_size,), self.ignore_index, dtype=mask.dtype, device=mask.device)
        for raw_id, train_id in self.class_map.items():
            lut[raw_id] = train_id
        return lut[ids]

    def __getitem__(self, i):
        # Context managers close the PNG files as soon as the pixels are loaded.
        with Image.open(self.inputs[i]) as img:
            input_img = np.array(img.convert('RGB'))
        with Image.open(self.targets[i]) as tgt:
            target_img = np.array(tgt)

        # Apply the same transformation to both image and mask.
        transformed = self.transform(image=input_img, mask=target_img)
        return transformed['image'], self._remap_labels(transformed['mask'])

# DataLoader setup
root_dir = '/kaggle/input/cityscapes/Cityscape'
resize_shape = (512, 1024)

train_dataset = CityscapesDataset(root_dir=root_dir, split='train', resize_shape=resize_shape, is_train=True)
val_dataset = CityscapesDataset(root_dir=root_dir, split='val', resize_shape=resize_shape, is_train=False)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4, pin_memory=True)
# Batch size 1 for the visualization / per-image inspection cells.
inference_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

# INT8 calibration subset: a seeded random sample of validation images (validation transforms,
# no augmentation). The size is capped at the dataset length so random.sample cannot raise
# when fewer images are found.
CALIBRATION_SIZE = 500
CALIBRATION_SEED = 42
# A local RNG gives the same sample as random.seed(42) + random.sample, without reseeding the
# global `random` module as a side effect.
calib_rng = random.Random(CALIBRATION_SEED)
calib_indices = calib_rng.sample(range(len(val_dataset)), min(CALIBRATION_SIZE, len(val_dataset)))
calib_subset = torch.utils.data.Subset(val_dataset, calib_indices)
# drop_last=True: every calibration batch is full (the TensorRT calibrator uses a fixed batch size).
calib_dataloader = torch.utils.data.DataLoader(
    calib_subset, batch_size=8, shuffle=False, drop_last=True, pin_memory=True, num_workers=4
)

In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision.models import resnet18
from torchvision.models._utils import IntermediateLayerGetter
from collections import OrderedDict

num_classes = 19  # Cityscapes has 19 classes

# torchvision's FCN container (backbone -> classifier -> bilinear upsample to the input size).
from torchvision.models.segmentation.fcn import FCN


# FCN head defined here so its input width can match ResNet-18's 512-channel output.
class FCNHead(nn.Sequential):
    """conv3x3(in -> 256) -> BN -> ReLU -> Dropout(0.3) -> conv1x1(256 -> num_classes)."""

    def __init__(self, in_channels, num_classes):
        super().__init__(
            nn.Conv2d(in_channels, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),  # Prevents overfitting
            nn.Conv2d(256, num_classes, kernel_size=1)
        )

# ImageNet-pretrained ResNet-18 backbone
try:
    backbone = resnet18(weights="IMAGENET1K_V1")  # Torchvision >= 0.13
except TypeError:
    backbone = resnet18(pretrained=True)  # Older torchvision versions

# Keep the convolutional trunk only: layer4's feature map is returned under the key "out"
# (avgpool and the final FC layer come after layer4 and are dropped).
return_layers = {"layer4": "out"}
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

# Build the FCN directly instead of instantiating fcn_resnet50 and swapping its parts.
# The swap approach constructed, and downloaded weights for, ResNet-50 models whose parameters
# were immediately thrown away. The module layout and state_dict keys are the same
# (backbone.*, classifier.*, no aux head), so existing checkpoints still load.
model = FCN(backbone, FCNHead(512, num_classes), aux_classifier=None)  # ResNet-18 outputs 512 channels

# Move model to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

print("✅ FCN-ResNet18 Model loaded successfully!")


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# Decode function to convert mask to RGB (Cityscapes)
def decode_segmap(mask, n_classes=19, label_colours=None):
    """Convert a train-id mask into an RGB image using the Cityscapes palette.

    Args:
        mask: Integer array or tensor of train ids (any device). Values outside
            ``[0, n_classes)``, e.g. the 255 ignore label, are drawn black.
        n_classes: Number of classes to colour (at most ``len(label_colours)``).
        label_colours: Optional sequence of ``(r, g, b)`` tuples; defaults to the 19-class
            Cityscapes palette.

    Returns:
        ``uint8`` array of shape ``mask.shape + (3,)``.

    The input is never modified. Overwriting ignore pixels in place with 0 would mutate the
    caller's data (CPU tensors share memory with ``.numpy()``) and paint unlabeled regions
    in the class-0 "road" colour.
    """
    if label_colours is None:
        label_colours = [
            (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),
            (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
            (107, 142, 35), (152, 251, 152), (0, 130, 180), (220, 20, 60),
            (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100),
            (0, 80, 100), (0, 0, 230), (119, 11, 32)
        ]  # 19 valid classes

    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    labels = np.asarray(mask).astype(np.int64)  # astype copies: the caller's data stays intact

    palette = np.asarray(label_colours[:n_classes], dtype=np.uint8).reshape(-1, 3)
    rgb = np.zeros(labels.shape + (3,), dtype=np.uint8)
    valid = (labels >= 0) & (labels < len(palette))
    rgb[valid] = palette[labels[valid]]  # one gather instead of a per-class loop
    return rgb

# Function to visualize input, ground truth, and predicted masks
def visualize_segmentation(input_image, decoded_target=None, decoded_prediction=None):
    """Show the input image, ground-truth mask and predicted mask side by side.

    Args:
        input_image: ``[C, H, W]`` tensor (any device). It is min-max rescaled for display.
        decoded_target: RGB ground-truth mask (e.g. from ``decode_segmap``) or None.
        decoded_prediction: RGB predicted mask or None.

    A missing mask is replaced by a black placeholder titled "... Unavailable".
    """
    fig, axs = plt.subplots(1, 3, figsize=(18, 6))

    # [C, H, W] tensor -> [H, W, C] numpy array.
    img = input_image.detach().cpu().numpy().transpose(1, 2, 0)
    lo, hi = img.min(), img.max()
    # Min-max rescale for display. A constant image would otherwise divide by zero (NaNs).
    img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)

    placeholder = np.zeros_like(img)
    panels = [
        (img, "Input Image"),
        (decoded_target, "Ground Truth Mask") if decoded_target is not None
        else (placeholder, "Ground Truth Unavailable"),
        (decoded_prediction, "Predicted Mask") if decoded_prediction is not None
        else (placeholder, "Prediction Unavailable"),
    ]
    for ax, (image, title) in zip(axs, panels):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')

    plt.show()

# Function to check if the model is trained
def is_model_trained(model, threshold=0.001):
    """Heuristically decide whether the segmentation head has been trained.

    Takes the last module of ``model.classifier``. Returns True when the (unbiased) variance of
    its weights exceeds ``threshold``.

    NOTE(review): PyTorch's default Conv2d init draws weights from U(-sqrt(k), sqrt(k)) with
    k = 1 / (in_channels * kernel_area). For the 256-in 1x1 final conv of FCNHead, the variance
    is about 1 / (3 * 256) ~= 0.0013. That is already above the 0.001 default, so a freshly
    initialised head probably also returns True. Confirm before relying on this check.
    """
    last_layer = model.classifier[-1]
    with torch.no_grad():
        weight_variance = last_layer.weight.var().item()
    return weight_variance > threshold


# Function to perform inference and visualize predictions
def visualize_predictions(model, data_loader, device, n_classes=19):
    """Plot the first sample of the first batch with its ground truth and, if trained, prediction.

    Args:
        model: Segmentation model; its output may be a tensor or a dict with an ``"out"`` key.
        data_loader: Yields ``(images, masks)`` batches.
        device: Device used for inference.
        n_classes: Number of classes passed to ``decode_segmap``.

    The prediction panel is skipped when ``is_model_trained`` reports an untrained head.
    """
    model.eval()

    show_prediction = is_model_trained(model)
    print(f"Model trained: {show_prediction}")

    with torch.no_grad():
        for batch_images, batch_masks in data_loader:
            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device)

            prediction = None
            if show_prediction:
                logits = model(batch_images)
                if isinstance(logits, dict):  # torchvision segmentation models return {"out": ...}
                    logits = logits["out"]
                prediction = logits.max(dim=1).indices[0]

            visualize_segmentation(
                batch_images[0],
                decode_segmap(batch_masks[0], n_classes=n_classes),
                None if prediction is None else decode_segmap(prediction, n_classes=n_classes),
            )
            return  # only the first batch is shown

# Pick the GPU when one is available and put the model on it before plotting.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
trained_model = model.to(device)  # nn.Module.to moves parameters in place and returns the module

visualize_predictions(trained_model, inference_loader, device, n_classes=19)


In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Baseline loss / optimizer / scheduler.
# NOTE(review): the next cell rebinds criterion, optimizer and scheduler (class-weighted loss,
# Adam lr=3e-4, patience=2), so these objects only matter if that cell is skipped.
criterion = nn.CrossEntropyLoss(ignore_index=255)  # 255 = unlabeled pixels, excluded from the loss

optimizer = optim.Adam(
    model.parameters(),
    lr=3e-5,  # 2e-5 was also tried
    weight_decay=1e-4,
)
# Alternatives tried earlier:
#   optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
#   CosineAnnealingLR(optimizer, T_max=10)  # T_max = epochs per cosine cycle

# Scale the LR by 0.3 when the maximised metric (validation IoU in the training loop) stops
# improving for longer than `patience` epochs.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.3,
    patience=1,
    verbose=True,
)


In [None]:
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.lr_scheduler import CosineAnnealingLR, SequentialLR, LinearLR

# Hand-tuned per-class loss weights (index = train id, 19 classes) to counter class imbalance.
class_weights = torch.tensor(
    [
        1.0, 1.5, 1.0, 7.0, 4.0, 6.0, 16.0, 5.0, 1.0, 4.5, 1.0,
        22.0, 22.0, 1.0, 20.0, 22.0, 24.0, 24.0, 24.0,  # ids 11-18 boosted (id 13 kept at 1.0)
    ],
    device=device,  # created directly on the target device, so no extra .to(device) is needed
)

# Alternative tried: FocalLoss(weight=class_weights, gamma=1.5, reduction="mean", ignore_index=255)
# (FocalLoss is not defined in the visible cells.)
criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=255)  # 255 = unlabeled pixels
optimizer = optim.Adam(model.parameters(), lr=3e-4, weight_decay=5e-4)
# LR x0.3 when validation IoU (mode="max") has not improved for more than 2 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",
    factor=0.3,
    patience=2,
    verbose=True,
)



In [None]:
# Load Best Model Checkpoint
# map_location restores tensors saved during a GPU run directly onto this session's `device`
# (so loading also works on a CPU-only machine).
checkpoint = torch.load(
    "/kaggle/input/fcnresnet18/pytorch/default/2/best_fcn_resnet18_ckptv4.pth",
    map_location=device,
)
model.load_state_dict(checkpoint['model_state_dict'])
# Restoring the optimizer also restores the learning rate stored in its param groups,
# overriding the lr passed to Adam when the optimizer was created.
optimizer.load_state_dict(checkpoint['opt_state_dict'])
# NOTE(review): start_epoch is not used by the training loop, which always counts from epoch 1.
start_epoch = checkpoint['epoch']


In [None]:
import matplotlib.pyplot as plt
import torch

# Updated visualize_sample function
def visualize_sample(model, data_loader, device):
    model.eval()  # Set model to evaluation mode

    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Get model predictions - access the 'out' key for the main output
            outputs = model(inputs)
            main_output = outputs['out']  # Access the main prediction
            predictions = main_output.argmax(dim=1)  # Get predicted class labels

            # Move data back to CPU for visualization
            inputs = inputs.cpu()
            predictions = predictions.cpu()
            targets = targets.cpu()

            # Visualize each image in the batch
            for i in range(len(inputs)):
                input_img = inputs[i].permute(1, 2, 0).numpy()  # Convert to HWC format for display
                input_img = np.clip(input_img, 0, 1)  # Clip input image to [0, 1] range

                target_mask = decode_segmap(targets[i])  # Decode ground truth mask
                predicted_mask = decode_segmap(predictions[i])  # Decode predicted mask

                # Clip masks to [0, 255] and convert to uint8 if needed
                target_mask = np.clip(target_mask * 255, 0, 255).astype(np.uint8)
                predicted_mask = np.clip(predicted_mask * 255, 0, 255).astype(np.uint8)

                fig, axs = plt.subplots(1, 3, figsize=(15, 5))

                # Display input image
                axs[0].imshow(input_img)
                axs[0].set_title("Input Image")
                axs[0].axis("off")

                # Display ground truth mask
                axs[1].imshow(target_mask)
                axs[1].set_title("Ground Truth Mask")
                axs[1].axis("off")

                # Display predicted mask
                axs[2].imshow(predicted_mask)
                axs[2].set_title("Predicted Mask")
                axs[2].axis("off")

                plt.show()

            break  # Display only the first batch for visualization

# Show the first validation batch with ground truth and prediction.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
visualize_sample(model, inference_loader, device)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.metrics import confusion_matrix
from torch.nn import CrossEntropyLoss

# Helper function: Compute pixel accuracy
def calculate_accuracy(outputs, targets, ignore_index=255):
    """Pixel accuracy over labelled pixels.

    Args:
        outputs: Logits ``[N, C, H, W]``; the predicted class is the argmax over dim 1.
        targets: ``[N, H, W]`` integer labels.
        ignore_index: Label excluded from both numerator and denominator.

    Returns:
        0-dim float tensor, ``0.0`` when no pixel is labelled. It is always a tensor so callers
        can use ``.item()``; a plain int ``0`` would break that.
    """
    _, preds = torch.max(outputs, 1)
    valid = targets != ignore_index
    total = valid.sum()
    if total == 0:
        return torch.zeros((), device=outputs.device)
    correct = (preds[valid] == targets[valid]).sum()
    return correct.float() / total.float()

# Helper function: Compute mean IoU
def calculate_mean_iou(preds, targets, n_classes, ignore_index=255):
    """Mean IoU over the classes that occur in ``targets`` or ``preds``.

    Args:
        preds: Predicted class ids (any shape, any device).
        targets: Ground-truth class ids, same shape as ``preds``.
        n_classes: Number of classes; ids outside ``[0, n_classes)`` are ignored.
        ignore_index: Target label excluded from the computation.

    Returns:
        Python float. Classes with an empty union (absent from both tensors) are skipped.
        Classes that are present but never correctly predicted count as IoU 0; filtering on
        ``iou > 0`` would drop them and inflate the score. Returns 0.0 when no pixel is valid.
    """
    preds = preds.detach().cpu().numpy().ravel().astype(np.int64)
    targets = targets.detach().cpu().numpy().ravel().astype(np.int64)
    valid = (
        (targets != ignore_index)
        & (targets >= 0) & (targets < n_classes)
        & (preds >= 0) & (preds < n_classes)
    )
    if not valid.any():
        return 0.0  # No valid pixels

    # Confusion matrix via bincount: rows = ground truth, columns = prediction.
    cm = np.bincount(
        n_classes * targets[valid] + preds[valid], minlength=n_classes * n_classes
    ).reshape(n_classes, n_classes)
    intersection = np.diag(cm)
    union = cm.sum(axis=1) + cm.sum(axis=0) - intersection
    present = union > 0
    return float(np.mean(intersection[present] / union[present]))

# Training function
def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device):
    """Run one training epoch.

    Args:
        model: Segmentation model returning a dict with ``'out'`` (and optionally ``'aux'``) logits.
        train_loader: Yields ``(images, masks)`` batches.
        criterion: Loss applied to logits and integer masks.
        optimizer: Stepped once per batch.
        scheduler: Not used here; the caller steps the plateau scheduler once per epoch.
            Kept for signature compatibility.
        device: Device to train on.

    Returns:
        ``(mean_loss, mean_pixel_accuracy)`` averaged over batches; ``(0.0, 0.0)`` for an empty loader.

    There is deliberately no per-batch ``torch.cuda.empty_cache()``. It only hands cached blocks
    back to the driver (it does not lower peak usage) and forces fresh allocations every step.
    """
    model.train()
    running_loss = 0.0
    running_accuracy = 0.0
    num_batches = 0

    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device).long()
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        main_output = outputs['out']
        aux_output = outputs.get('aux')  # only present when the model has an auxiliary head

        # Main loss plus a 0.4-weighted auxiliary loss when available
        loss = criterion(main_output, masks)
        if aux_output is not None:
            loss = loss + criterion(aux_output, masks) * 0.4

        # Backward pass & optimization
        loss.backward()
        optimizer.step()

        # float() accepts both a tensor and a plain number, so a 0 fallback cannot crash here.
        running_accuracy += float(calculate_accuracy(main_output.detach(), masks))
        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return 0.0, 0.0
    return running_loss / num_batches, running_accuracy / num_batches


# Validation loop with IoU calculation
def evaluate_model_with_metrics(model, val_loader, criterion, device, n_classes, ignore_index=255):
    """Evaluate ``model`` on ``val_loader``.

    Args:
        model: Segmentation model returning a dict with an ``'out'`` logits tensor.
        val_loader: Yields ``(images, masks)`` batches.
        criterion: Loss applied to logits and integer masks.
        device: Device to evaluate on.
        n_classes: Number of classes for the IoU computation.
        ignore_index: Label excluded from the IoU computation.

    Returns:
        ``(mean_loss, mean_pixel_accuracy, mean_iou)``. Loss and accuracy are averaged over
        batches. Mean IoU comes from one confusion matrix accumulated over the whole loader
        (dataset-level mIoU). Averaging per-batch mIoUs would make the score depend on batch
        composition. Classes absent from both predictions and labels are skipped.
        Returns ``(0.0, 0.0, 0.0)`` for an empty loader.
    """
    model.eval()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    confusion = torch.zeros(n_classes * n_classes, dtype=torch.long, device=device)

    with torch.no_grad():  # No gradient tracking
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device).long()
            logits = model(images)['out']

            total_loss += criterion(logits, masks).item()
            total_accuracy += float(calculate_accuracy(logits, masks))

            # Accumulate the confusion matrix (rows = ground truth, cols = prediction) on-device.
            preds = logits.argmax(dim=1)
            valid = (masks != ignore_index) & (masks >= 0) & (masks < n_classes) & (preds < n_classes)
            confusion += torch.bincount(
                n_classes * masks[valid] + preds[valid], minlength=n_classes * n_classes
            )
            num_batches += 1

    if num_batches == 0:
        return 0.0, 0.0, 0.0

    cm = confusion.view(n_classes, n_classes).double().cpu()
    intersection = cm.diag()
    union = cm.sum(dim=0) + cm.sum(dim=1) - intersection
    present = union > 0
    mean_iou = (intersection[present] / union[present]).mean().item() if bool(present.any()) else 0.0

    return total_loss / num_batches, total_accuracy / num_batches, mean_iou


# EarlyStopping class
class EarlyStopping:
    """Flag that training should stop once the monitored score stops improving.

    Higher scores are better (the training loop feeds validation IoU). A call counts as an
    improvement unless ``val_score < best_score + min_delta``. After ``patience`` consecutive
    non-improving calls, ``early_stop`` becomes True and stays True.

    Attributes:
        patience, min_delta: As passed to the constructor.
        best_score: Best score so far (None before the first call).
        counter: Consecutive calls without improvement.
        early_stop: True once ``counter`` has reached ``patience``.
    """

    def __init__(self, patience=5, min_delta=0.01):
        self.patience = patience
        self.min_delta = min_delta
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_score):
        has_baseline = self.best_score is not None
        if has_baseline and val_score < self.best_score + self.min_delta:
            # Not enough improvement: count it and trip the flag once patience runs out.
            self.counter += 1
            self.early_stop = self.early_stop or self.counter >= self.patience
            return
        if has_baseline:
            self.counter = 0
        self.best_score = val_score

# Function to save model checkpoint
def save_checkpoint(state, ckpt_path):
    """Serialize ``state`` with ``torch.save`` and report where it went.

    When ``ckpt_path`` is a filesystem path, the checkpoint is first written to
    ``<ckpt_path>.tmp`` and then renamed over the target with ``os.replace``. That rename is
    atomic on POSIX within one filesystem, so an interrupted save cannot leave a truncated
    "best" checkpoint behind. Other targets (e.g. file-like objects) are passed to
    ``torch.save`` unchanged.
    """
    if isinstance(ckpt_path, (str, os.PathLike)):
        tmp_path = f"{os.fspath(ckpt_path)}.tmp"
        try:
            torch.save(state, tmp_path)
            os.replace(tmp_path, ckpt_path)
        finally:
            # Only left behind if torch.save or os.replace failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    else:
        torch.save(state, ckpt_path)
    print(f"✅ Checkpoint saved: {ckpt_path}")




In [None]:
# Hyperparameters
n_classes = 19
num_epochs = 90
patience = 10  # Stop training if no improvement for `patience` epochs
# Monitors validation IoU; an epoch counts as an improvement only if it is not below
# best IoU + min_delta (0.005).
early_stopping = EarlyStopping(patience=patience, min_delta=0.005)
# Starts at 0.0, so the first epoch always writes ckpt_path.
best_val_iou = 0.0
log_path = '/kaggle/working/log_segmentation.txt'
ckpt_path = "/kaggle/working/best_segmentation.pth"

print(f"📌 Logging to {log_path}")

# Write headers to log file (mode 'w' truncates the log of any previous run)
with open(log_path, 'w') as log_file:
    log_file.write("Epoch, Train Loss, Train Acc, Val Loss, Val Acc, Val IoU\n")

# Training loop with early stopping.
# NOTE(review): epochs are numbered from 1 regardless of `start_epoch` loaded from the checkpoint.
for epoch in range(num_epochs):
    # Train the model for one epoch, then evaluate on the validation set
    train_loss, train_accuracy = train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device)
    val_loss, val_accuracy, val_iou = evaluate_model_with_metrics(model, val_loader, criterion, device, n_classes)
    # The plateau scheduler (mode="max") steps once per epoch on validation IoU.
    scheduler.step(val_iou)

    # Print metrics for the current epoch
    print(f"🔹 Epoch {epoch + 1}/{num_epochs}")
    print(f"   📉 Train Loss: {train_loss:.4f}, 🎯 Train Acc: {train_accuracy:.4f}")
    print(f"   📉 Val Loss: {val_loss:.4f}, 🎯 Val Acc: {val_accuracy:.4f}, 🔥 Val IoU: {val_iou:.4f}")

    # Append this epoch's metrics to the log file
    with open(log_path, 'a') as log_file:
        log_file.write(f"{epoch + 1}, {train_loss:.4f}, {train_accuracy:.4f}, {val_loss:.4f}, {val_accuracy:.4f}, {val_iou:.4f}\n")

    # Save the best model checkpoint; the keys match what the checkpoint-loading cell reads
    # ('model_state_dict', 'opt_state_dict', 'epoch').
    if val_iou > best_val_iou:
        best_val_iou = val_iou
        save_checkpoint({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'acc': val_accuracy,
            'val_iou': val_iou,
            'opt_state_dict': optimizer.state_dict()
        }, ckpt_path)
        
        print(f"✅ Best model saved at epoch {epoch + 1} (IoU: {val_iou:.4f})")

    # Early stopping check
    early_stopping(val_iou)
    if early_stopping.early_stop:
        print("🚨 Early stopping triggered. Training stopped.")
        break

# After the loop, `model` holds the LAST epoch's weights; the best ones are in ckpt_path.
print("🎉 Training complete! Best IoU:", best_val_iou)


In [None]:
# Evaluate the model currently in memory. After training this is the last epoch's weights,
# not necessarily the best checkpoint.
val_loss, val_accuracy, val_iou = evaluate_model_with_metrics(
    model, val_loader, criterion, device, num_classes
)
print(
    f"Validation Loss: {val_loss:.4f}, Validation Acc: {val_accuracy:.4f}, Validation IoU: {val_iou:.4f}"
)

In [None]:
# Runs trtexec with no arguments, which builds nothing; presumably a check that the binary
# exists at this path.
!/usr/src/tensorrt/bin/trtexec

In [None]:
# Export the current `model` (FCN-ResNet18) to ONNX with a fixed 1x3x512x1024 input, the
# resize_shape used by the datasets. The dummy tensor is created on 'cuda', so the model must
# already be on the GPU.
# NOTE(review): the "deeplab_resnet50" file names below do not match the exported model.
dummy_input = torch.randn(1, 3, 512, 1024, device='cuda')
input_names = ["actual_input_1"]
output_names = ["output1"]
torch.onnx.export(
    model,
    dummy_input,
    "/kaggle/working/deeplab_resnet50_base.onnx",  # Single .onnx extension
    verbose=False,
    opset_version=11,
    do_constant_folding=False,
    input_names=input_names,
    output_names=output_names
)

# Build a TensorRT engine at default precision (no --fp16/--int8) and time it over 100 runs.
! /usr/src/tensorrt/bin/trtexec --onnx=/kaggle/working/deeplab_resnet50_base.onnx --useSpinWait --avgRuns=100  --saveEngine=/kaggle/working/deeplab_resnet50_base.trt

In [None]:
import torch

# Export at 1x3x1024x2048 (twice the 512x1024 training resize).
# NOTE(review): `dynamic_axes` is built below but not passed to torch.onnx.export (that argument
# is commented out). The ONNX graph, and the engine built from it, therefore have a FIXED input
# shape despite the "dynamic" file names.
dummy_input = torch.randn(1, 3, 1024, 2048, device='cuda')  # fixed export shape (dynamic_axes is not passed)

# Specify input and output names
input_names = ["actual_input_1"]
output_names = ["output1"]

# Dynamic axes for batch size, height, and width (currently unused, see the note above)
dynamic_axes = {
    "actual_input_1": {0: "batch_size", 2: "height", 3: "width"},  # Batch, height, and width would be dynamic if passed
    "output1": {0: "batch_size", 2: "height", 3: "width"}
}

# Export the model (static shape, see the note above)
torch.onnx.export(
    model,
    dummy_input,
    "/kaggle/working/deeplab_resnet50_dynamic.onnx",  # Save as .onnx file (name says deeplab_resnet50; the model is FCN-ResNet18)
    verbose=False,
    opset_version=11,
    do_constant_folding=False,
    input_names=input_names,
    output_names=output_names,
    # dynamic_axes=dynamic_axes  # Specify dynamic axes
)

# Build a TensorRT engine from the (static-shape) ONNX file at default precision.
# NOTE(review): --explicitBatch and --workspace are deprecated in newer TensorRT releases
# (--memPoolSize=workspace:<MiB> replaces --workspace); check against the installed version.
!/usr/src/tensorrt/bin/trtexec --onnx=/kaggle/working/deeplab_resnet50_dynamic.onnx --explicitBatch  --useSpinWait --workspace=512 --avgRuns=100 --saveEngine=/kaggle/working/deeplab_resnet50_dynamic.trt


In [None]:
import torch

# Export at the 512x1024 training resolution for an FP16 TensorRT engine.
# NOTE(review): `dynamic_axes` is built below but not passed to torch.onnx.export (that argument
# is commented out), so the exported graph has a FIXED 1x3x512x1024 input.
dummy_input = torch.randn(1, 3, 512, 1024, device='cuda')  # fixed export shape (dynamic_axes is not passed)

# Specify input and output names
input_names = ["actual_input_1"]
output_names = ["output1"]

# Dynamic axes for batch size, height, and width (currently unused, see the note above)
dynamic_axes = {
    "actual_input_1": {0: "batch_size", 2: "height", 3: "width"},  # Batch, height, and width would be dynamic if passed
    "output1": {0: "batch_size", 2: "height", 3: "width"}
}

# Export the model (static shape, see the note above)
torch.onnx.export(
    model,
    dummy_input,
    "/kaggle/working/fcn_resnet18_fp16.onnx",  # Save as .onnx file (the ONNX graph itself is FP32; FP16 is chosen by trtexec)
    verbose=False,
    opset_version=11,
    do_constant_folding=False,
    input_names=input_names,
    output_names=output_names,
    # dynamic_axes=dynamic_axes  # Specify dynamic axes
)

# Build an engine with --fp16, which allows TensorRT to pick half-precision kernels.
# NOTE(review): --explicitBatch and --workspace are deprecated in newer TensorRT releases;
# check against the installed version.
!/usr/src/tensorrt/bin/trtexec --onnx=/kaggle/working/fcn_resnet18_fp16.onnx --explicitBatch  --fp16 --useSpinWait --workspace=512 --avgRuns=100 --saveEngine=/kaggle/working/fcn_resnet18_fp16.trt


In [None]:
from sklearn.metrics import jaccard_score

def compute_class_iou(preds, targets, n_classes=19, ignore_index=255):
    """
    Computes class-wise IoU for segmentation predictions.

    Args:
        preds (torch.Tensor): Predicted segmentation mask (H, W) or (batch, H, W)
        targets (torch.Tensor): Ground truth segmentation mask (H, W) or (batch, H, W)
        n_classes (int): Number of segmentation classes
        ignore_index (int): Target value excluded from the computation (255 = unlabeled)

    Returns:
        dict: ``{"Class i": iou}`` rounded to 4 decimals for every i in range(n_classes).
        Classes absent from both preds and targets get 0.0.

    Pixels labelled ``ignore_index`` are removed first; otherwise every prediction made on an
    unlabeled pixel counts as a false positive. Target/prediction ids outside
    ``[0, n_classes)`` are dropped as well.
    """
    p = preds.detach().cpu().numpy().ravel().astype(np.int64)
    t = targets.detach().cpu().numpy().ravel().astype(np.int64)

    keep = (t != ignore_index) & (t >= 0) & (t < n_classes) & (p >= 0) & (p < n_classes)
    p, t = p[keep], t[keep]

    # Confusion matrix via bincount: rows = ground truth, columns = prediction.
    cm = np.bincount(n_classes * t + p, minlength=n_classes * n_classes).reshape(n_classes, n_classes)
    intersection = np.diag(cm).astype(np.float64)
    union = cm.sum(axis=0) + cm.sum(axis=1) - np.diag(cm)
    iou = np.divide(intersection, union, out=np.zeros(n_classes), where=union > 0)

    # float() so the dict prints plain numbers rather than numpy scalar reprs.
    return {f"Class {i}": round(float(v), 4) for i, v in enumerate(iou)}

# Per-class IoU on the first validation image only (a quick sanity check, not a full evaluation).
model.eval()

with torch.no_grad():
    # Equivalent to looping over inference_loader once and breaking.
    first_batch = next(iter(inference_loader), None)
    if first_batch is not None:
        input_img, target_mask = first_batch
        input_img = input_img.to(device)
        target_mask = target_mask.to(device)

        logits = model(input_img)["out"]
        predicted_mask = logits.argmax(dim=1)  # logits -> class indices

        iou_per_class = compute_class_iou(predicted_mask[0], target_mask[0], n_classes=19)

        print("Class-wise IoU:", iou_per_class)

def compute_class_weights(train_loader, n_classes=19, out_device=None):
    """
    Computes class weights based on the frequency of each class in the dataset.

    Args:
        train_loader (DataLoader): Yields ``(inputs, target)`` batches; only ``target`` is read.
        n_classes (int): Number of classes; other label values (e.g. 255) are not counted.
        out_device: Device of the returned tensor. Defaults to the notebook-level ``device``.

    Returns:
        torch.Tensor: float32 weights of shape ``(n_classes,)``, proportional to ``1 / count``
        and normalized to sum to 1.

    Counts are kept as int64: float32 cannot represent pixel counts beyond ~1.6e7 exactly.
    NOTE: batches come through the loader's transforms. With the training loader, the counts
    therefore reflect augmented crops, and the call costs a full epoch of data loading.
    """
    class_counts = torch.zeros(n_classes, dtype=torch.long)

    for _, target in train_loader:
        labels = target.reshape(-1).long().cpu()  # reshape also handles non-contiguous masks
        labels = labels[(labels >= 0) & (labels < n_classes)]
        class_counts += torch.bincount(labels, minlength=n_classes)

    class_weights = 1.0 / (class_counts.double() + 1e-10)  # Avoid division by zero
    class_weights = (class_weights / class_weights.sum()).float()  # Normalize

    return class_weights.to(device if out_device is None else out_device)

# Frequency-based (inverse pixel count) weights. They are stored under their own name so the
# hand-tuned `class_weights` used by the active criterion is not overwritten.
# NOTE: this makes one full pass over the augmented training loader.
freq_class_weights = compute_class_weights(train_loader, n_classes=19)
# To train with them instead, keep ignore_index=255; without it the 255 labels are out of range
# for a 19-class loss:
# criterion = nn.CrossEntropyLoss(weight=freq_class_weights, ignore_index=255)

In [None]:
import tensorrt as trt
import numpy as np
import os

class Int8Calibrator(trt.IInt8EntropyCalibrator2):
    """TensorRT INT8 entropy calibrator fed from a PyTorch DataLoader.

    TensorRT pulls calibration data through ``get_batch``. For each network input it expects an
    int *device* pointer to a contiguous float32 buffer. DataLoaders cannot be indexed, so
    batches are iterated here, moved to CUDA with PyTorch, and ``data_ptr()`` is returned.
    The current CUDA tensor is kept on ``self.device_input`` so its memory stays valid while
    TensorRT reads it.

    Args:
        data_loader: Iterable of ``(images, labels)`` (or bare ``images``) batches shaped
            ``[batch_size, 3, H, W]``. Every batch must be full (e.g. ``drop_last=True``).
        cache_file: Path used to read/write the calibration cache.
    """

    def __init__(self, data_loader, cache_file="calibration_data.cache"):
        super().__init__()
        self.data_loader = data_loader
        self.cache_file = cache_file
        # Use the loader's own batch size when it has one; 8 matches calib_dataloader.
        self.batch_size = getattr(data_loader, "batch_size", None) or 8
        self.device_input = None  # CUDA tensor currently handed to TensorRT (kept alive here)
        self.current_index = 0  # number of batches served so far
        self.data = None  # iterator over data_loader, created on the first get_batch call

    def get_batch_size(self):
        return self.batch_size

    def get_batch(self, names):
        """Return ``[device_pointer]`` for the next batch, or None once the data is exhausted."""
        if self.data is None:
            self.data = iter(self.data_loader)
        try:
            batch = next(self.data)
        except StopIteration:
            return None

        # Batches from the segmentation loaders are (images, masks); only images feed the network.
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        self.device_input = torch.as_tensor(images).to(device="cuda", dtype=torch.float32).contiguous()
        self.current_index += 1

        return [int(self.device_input.data_ptr())]

    def read_calibration_cache(self):
        # Reuse a previous calibration run if its cache exists; None makes TensorRT calibrate.
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()
        return None

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)


In [None]:
# Calibrator over the seeded validation subset (calib_dataloader), not the full dataset.
# NOTE(review): in the visible cells this object is never passed to a TensorRT builder, so it
# does not produce the cache file that trtexec's --calib reads.
calibrator = Int8Calibrator(
    data_loader=calib_dataloader,
    # Forward slashes: "\kaggle\working\..." is not a valid Linux path ("\k", "\w", "\c" are
    # invalid string escapes kept literally). It would create a file literally named
    # "\kaggle\working\calibration_data.cache" in the working directory.
    cache_file="/kaggle/working/calibration_data.cache",
)


In [None]:

# Define dummy input with a fixed size for export, but specify dynamic dimensions
dummy_input = torch.randn(1, 3, 512, 1024, device='cuda')  # Example size; actual size will be dynamic

# Specify input and output names
input_names = ["actual_input_1"]
output_names = ["output1"]

# Define dynamic axes for batch size, height, and width
dynamic_axes = {
    "actual_input_1": {0: "batch_size", 2: "height", 3: "width"},  # Batch, height, and width are dynamic
    "output1": {0: "batch_size", 2: "height", 3: "width"}
}

# Export the model with dynamic input sizes
torch.onnx.export(
    model,
    dummy_input,
    "/kaggle/working/fcn_resnet18_int8.onnx",  # Save as .onnx file
    verbose=False,
    opset_version=11,
    do_constant_folding=False,
    input_names=input_names,
    output_names=output_names,
    # dynamic_axes=dynamic_axes  # Specify dynamic axes
)

# Convert ONNX model to TensorRT with dynamic shapes
!/usr/src/tensorrt/bin/trtexec --onnx=/kaggle/working/fcn_resnet18_int8v1.onnx --calib=/kaggle/working/calibration_data.cache --explicitBatch --int8 --useSpinWait --workspace=1024 --avgRuns=100 --saveEngine=/kaggle/working/fcn_resnet18_int8v1.trt


In [None]:
# ----------- per-channel weights (library default) -----------
# Layer inputs: 8-bit quantizers whose ranges come from histogram calibration, for every layer
# type listed. Weight quantizers keep pytorch-quantization's defaults (per-channel / per-row).
quant_desc_input = QuantDescriptor(num_bits=8, calib_method='histogram')
for quant_layer_cls in (
    quant_nn.QuantConv2d,
    quant_nn.QuantMaxPool2d,
    quant_nn.QuantAdaptiveAvgPool2d,
    quant_nn.QuantLinear,
):
    quant_layer_cls.set_default_quant_desc_input(quant_desc_input)

# #-----------per tensor weight--------------------------------------
# quant_desc_input = QuantDescriptor(num_bits=4, calib_method='histogram', axis=None)
# quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
# quant_nn.QuantMaxPool2d.set_default_quant_desc_input(quant_desc_input)
# quant_nn.QuantConvTranspose2d.set_default_quant_desc_input(quant_desc_input)
# quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

# quant_desc_weight = QuantDescriptor(num_bits=4, calib_method='histogram', axis=None)
# quant_nn.QuantConv2d.set_default_quant_desc_weight(quant_desc_weight)
# quant_nn.QuantConvTranspose2d.set_default_quant_desc_weight(quant_desc_weight)
# quant_nn.QuantLinear.set_default_quant_desc_weight(quant_desc_weight)


In [None]:
# Monkey-patch torch.nn layer classes (e.g. Conv2d, Linear, pooling) with their
# pytorch-quantization Quant* counterparts. Models constructed after this call (q_model below)
# get quantizers that use the default descriptors set above.
# NOTE(review): the patch stays active for the rest of the kernel session. If later cells must
# build plain float layers, call quant_modules.deactivate() once q_model is constructed.
quant_modules.initialize()

In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision.models import resnet18
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.segmentation.fcn import FCN
from collections import OrderedDict

num_classes = 19  # Cityscapes has 19 classes


class FCNHead(nn.Sequential):
    """FCN segmentation head: 3x3 conv (-> 256 ch) -> BN -> ReLU -> Dropout(0.3) -> 1x1 conv.

    torchvision has its own ``torchvision.models.segmentation.fcn.FCNHead``, but that one
    uses ``in_channels // 4`` intermediate channels and dropout 0.1. This variant is kept
    because the checkpoints loaded later in the notebook must match this exact layout.
    When built after ``quant_modules.initialize()``, its convolutions are created as
    quantized convolutions.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__(
            nn.Conv2d(in_channels, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),  # Prevents overfitting
            nn.Conv2d(256, num_classes, kernel_size=1)
        )


# ImageNet-pretrained ResNet-18; the `weights=` keyword only exists in torchvision >= 0.13.
try:
    backbone = resnet18(weights="IMAGENET1K_V1")
except TypeError:
    backbone = resnet18(pretrained=True)  # Older torchvision versions

# Keep only the convolutional trunk: layer4's feature map (512 channels for ResNet-18)
# is returned under the "out" key; avgpool/fc are not part of the getter's output.
return_layers = {"layer4": "out"}
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

# Assemble the FCN wrapper directly. Building fcn_resnet50(...) and then swapping its
# backbone/head would instantiate (and download weights for) a ResNet-50 only to throw it
# away. The resulting module tree and state-dict keys (backbone.*, classifier.*) are the same.
q_model = FCN(backbone, FCNHead(512, num_classes), aux_classifier=None)

# Move model to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
q_model = q_model.to(device)

print("✅ FCN-ResNet18 Model loaded successfully!")


In [None]:
# Display the module tree. Layers created after quant_modules.initialize() should appear
# as their Quant* counterparts (e.g. QuantConv2d) carrying TensorQuantizer submodules.
q_model

In [None]:
# Load Best Model Checkpoint: weights (incl. quantizer amax buffers) and optimizer state.
# This is the same path the QAT training loop writes to, so on a fresh kernel the file
# must already exist from an earlier run.
checkpoint = torch.load("/kaggle/working/fcnresnet18_qat.pth", map_location=device)
q_model.load_state_dict(checkpoint['model_state_dict'])

# `optimizer` (and any scheduler bound to it) is created in an earlier cell. If it was built
# over another model's parameters, optimizer.step() in the QAT loop would never update
# q_model. print() is used because warnings are globally silenced at the top of the notebook.
_q_model_param_ids = {id(p) for p in q_model.parameters()}
_optimizer_param_ids = {id(p) for group in optimizer.param_groups for p in group['params']}
if not _optimizer_param_ids <= _q_model_param_ids:
    print("⚠️ optimizer does not track q_model's parameters; rebuild it (and its scheduler) "
          "for q_model before running the QAT training loop.")
del _q_model_param_ids, _optimizer_param_ids

optimizer.load_state_dict(checkpoint['opt_state_dict'])
start_epoch = checkpoint['epoch']


In [None]:
def compute_amax(model, **kwargs):
    """Convert collected calibration statistics into amax values for every quantizer.

    Quantizers backed by a ``calib.MaxCalibrator`` are finalized without arguments. All
    other calibrators receive ``kwargs`` (e.g. ``method="percentile", percentile=99.9``).
    Every TensorQuantizer is printed, and finally the whole model is moved to CUDA
    (presumably so any newly created amax buffers live on the GPU).
    """
    for qname, quantizer in model.named_modules():
        if not isinstance(quantizer, quant_nn.TensorQuantizer):
            continue
        calibrator = quantizer._calibrator
        if calibrator is not None:
            if isinstance(calibrator, calib.MaxCalibrator):
                quantizer.load_calib_amax()
            else:
                quantizer.load_calib_amax(**kwargs)
        print(F"{qname:40}: {quantizer}")
    model.cuda()

def collect_stats(model, data_loader, num_batches):
    """Feed exactly ``num_batches`` batches through ``model`` to collect calibration stats.

    While data flows, every TensorQuantizer that owns a calibrator has quantization
    disabled and calibration enabled, and quantizers without a calibrator are disabled.
    Both groups are restored afterwards, including when a forward pass raises.

    Args:
        model: network containing ``quant_nn.TensorQuantizer`` modules.
        data_loader: iterable of ``(image, target)`` pairs. Targets are ignored, and
            images are moved with ``.cuda()``.
        num_batches: number of batches to run through the model.
    """
    from itertools import islice

    quantizers = [m for m in model.modules() if isinstance(m, quant_nn.TensorQuantizer)]

    # Enable calibrators
    for quantizer in quantizers:
        if quantizer._calibrator is not None:
            quantizer.disable_quant()
            quantizer.enable_calib()
        else:
            quantizer.disable()

    try:
        # islice stops after exactly num_batches batches. A break check placed after the
        # forward call would run the model on one extra batch.
        for image, _ in tqdm(islice(data_loader, num_batches), total=num_batches):
            model(image.cuda())
    finally:
        # Disable calibrators and re-enable quantization
        for quantizer in quantizers:
            if quantizer._calibrator is not None:
                quantizer.enable_quant()
                quantizer.disable_calib()
            else:
                quantizer.enable()

def calibrate_model(model, model_name, data_loader, num_calib_batch, calibrator, hist_percentile, out_dir):
    """
        Feed data to the network, compute amax for each calibration method, and save
        one state_dict per method.

        Arguments:
            model: quantized model (contains TensorQuantizer modules)
            model_name: name to use when creating state files
            data_loader: calibration data set (its ``batch_size`` is used in file names)
            num_calib_batch: amount of calibration passes to perform; <= 0 does nothing
            calibrator: "histogram", or anything else for max calibration
            hist_percentile: percentiles to be used for histogram calibration
            out_dir: dir to save state files in

        Files are named ``{model_name}-{tag}-{num_calib_batch * batch_size}.pth`` with tag
        ``max``, ``percentile-<p>``, ``mse`` or ``entropy``. For histogram calibration the
        in-memory model ends up holding the amax of the LAST method computed (entropy),
        so reload the desired file afterwards.
    """
    if num_calib_batch <= 0:
        return

    print("Calibrating model")
    with torch.no_grad():
        collect_stats(model, data_loader, num_calib_batch)

    num_images = num_calib_batch * data_loader.batch_size

    def _save(tag):
        # Persist the current amax values under a method-specific file name.
        calib_output = os.path.join(out_dir, F"{model_name}-{tag}-{num_images}.pth")
        torch.save(model.state_dict(), calib_output)

    if calibrator != "histogram":
        compute_amax(model, method="max")
        _save("max")
        return

    for percentile in hist_percentile:
        print(F"{percentile} percentile calibration")
        # The percentile must be forwarded explicitly; otherwise the histogram
        # calibrator silently uses its own default percentile for every iteration.
        compute_amax(model, method="percentile", percentile=percentile)
        _save(F"percentile-{percentile}")

    for method in ["mse", "entropy"]:
        print(F"{method} calibration")
        compute_amax(model, method=method)
        _save(method)

In [None]:
# Eval mode before calibration/inference: BatchNorm uses its running statistics and
# Dropout is disabled, so collected activation ranges match inference behaviour.
# (As the cell's last expression, the model repr is also displayed.)
q_model.eval()

In [None]:
# Output directory for calibrate_model's state files. It uses the same absolute path as
# `out_dir` in the calibration call, and exist_ok makes this cell safe to re-run.
os.makedirs("/kaggle/working/calib_cache", exist_ok=True)

In [None]:
# Calibrate the model with histogram statistics. calibrate_model saves one state file per
# method (percentile-<p>, mse, entropy) into out_dir; the next cell reloads the
# 99.9-percentile result.
with torch.no_grad():
    calibrate_model(
        model=q_model,
        model_name="fcnresnet18_qat",
        data_loader=calib_dataloader,
        num_calib_batch=8,
        calibrator="histogram",
        hist_percentile=[99.9],  # other candidates tried: [99.99, 99.999, 99.9999]
        out_dir="/kaggle/working/calib_cache/")

In [None]:
# Reload the 99.9-percentile calibration result. The "-64" suffix is
# num_calib_batch * calib_dataloader.batch_size from the calibration cell.
calib_ckpt_path = "/kaggle/working/calib_cache/fcnresnet18_qat-percentile-99.9-64.pth"
checkpoint = torch.load(calib_ckpt_path)
q_model.load_state_dict(checkpoint)

In [None]:
# Validation metrics of the calibrated model (quantizers enabled) before QAT fine-tuning.
val_loss, val_accuracy, val_iou = evaluate_model_with_metrics(q_model, val_loader, criterion, device, num_classes)
print(f"Validation Loss: {val_loss:.4f}, Validation Acc: {val_accuracy:.4f}, Validation IoU: {val_iou:.4f}")

In [None]:
# QAT fine-tuning of the calibrated q_model with per-epoch logging, best-IoU
# checkpointing, and early stopping.
# NOTE(review): `optimizer` and `scheduler` come from an earlier cell. TODO confirm they
# were created over q_model's parameters, otherwise these epochs never update q_model.
# Hyperparameters
n_classes = 19
num_epochs = 5
# NOTE(review): patience (10) exceeds num_epochs (5), so early stopping cannot trigger here.
patience = 10  # Stop training if no improvement for `patience` epochs
early_stopping = EarlyStopping(patience=patience, min_delta=0.005)
# Starts at 0.0, so the first epoch with IoU > 0 overwrites ckpt_path, i.e. the same file
# loaded earlier in this section.
best_val_iou = 0.0
log_path = '/kaggle/working/fcnresnet18_qat.txt'
ckpt_path = "/kaggle/working/fcnresnet18_qat.pth"

print(f"📌 Logging to {log_path}")

# Write headers to log file ('w' truncates any previous log)
with open(log_path, 'w') as log_file:
    log_file.write("Epoch, Train Loss, Train Acc, Val Loss, Val Acc, Val IoU\n")

# Training loop with early stopping
for epoch in range(num_epochs):
    # Train the model for one epoch
    train_loss, train_accuracy = train_one_epoch(q_model, train_loader, criterion, optimizer, scheduler, device)
    val_loss, val_accuracy, val_iou = evaluate_model_with_metrics(q_model, val_loader, criterion, device, n_classes)
    # Scheduler is stepped with validation IoU (presumably a ReduceLROnPlateau in 'max'
    # mode; TODO confirm, since it is also passed to train_one_epoch).
    scheduler.step(val_iou)

    # Print metrics for the current epoch
    print(f"🔹 Epoch {epoch + 1}/{num_epochs}")
    print(f"   📉 Train Loss: {train_loss:.4f}, 🎯 Train Acc: {train_accuracy:.4f}")
    print(f"   📉 Val Loss: {val_loss:.4f}, 🎯 Val Acc: {val_accuracy:.4f}, 🔥 Val IoU: {val_iou:.4f}")

    # Save metrics to log file
    with open(log_path, 'a') as log_file:
        log_file.write(f"{epoch + 1}, {train_loss:.4f}, {train_accuracy:.4f}, {val_loss:.4f}, {val_accuracy:.4f}, {val_iou:.4f}\n")

    # Save the best model checkpoint (same keys the checkpoint-loading cell reads)
    if val_iou > best_val_iou:
        best_val_iou = val_iou
        save_checkpoint({
            'epoch': epoch + 1,
            'model_state_dict': q_model.state_dict(),
            'acc': val_accuracy,
            'val_iou': val_iou,
            'opt_state_dict': optimizer.state_dict()
        }, ckpt_path)
        
        print(f"✅ Best model saved at epoch {epoch + 1} (IoU: {val_iou:.4f})")

    # Early stopping check (fed validation IoU)
    early_stopping(val_iou)
    if early_stopping.early_stop:
        print("🚨 Early stopping triggered. Training stopped.")
        break

print("🎉 Training complete! Best IoU:", best_val_iou)


In [None]:
import torch
quant_nn.TensorQuantizer.use_fb_fake_quant = True
# Define dummy input with a fixed size for export, but specify dynamic dimensions
dummy_input = torch.randn(1, 3, 512, 1024, device='cuda')  # Example size; actual size will be dynamic

# Specify input and output names
input_names = ["actual_input_1"]
output_names = ["output1"]

# Define dynamic axes for batch size, height, and width
dynamic_axes = {
    "actual_input_1": {0: "batch_size", 2: "height", 3: "width"},  # Batch, height, and width are dynamic
    "output1": {0: "batch_size", 2: "height", 3: "width"}
}

# Export the model with dynamic input sizes
torch.onnx.export(
    q_model.eval(),
    dummy_input,
    "/kaggle/working/fcn_resnet18_qat.onnx",  # Save as .onnx file
    verbose=False,
    opset_version=13,
    do_constant_folding=False,
    input_names=input_names,
    output_names=output_names,
    # dynamic_axes=dynamic_axes  # Specify dynamic axes
)

# Convert ONNX model to TensorRT with dynamic shapes
!/usr/src/tensorrt/bin/trtexec --onnx=/kaggle/working/fcn_resnet18_qat.onnx  --explicitBatch --int8 --useSpinWait --workspace=1024 --avgRuns=100 --saveEngine=/kaggle/working/fcn_resnet18_qat.trt


In [None]:
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit
import tensorrt as trt
import torch
import time
from sklearn.metrics import jaccard_score

def allocate_buffers(engine, context, batch_size, height, width):
    """Allocate page-locked host memory and device memory for every engine binding.

    Input bindings are sized as (batch_size, 3, height, width). Output bindings are sized
    as (batch_size, C, height, width), where C is dim 1 of the engine's binding shape, so
    the network is assumed to preserve spatial size. ``context`` is accepted but unused.

    Returns:
        (inputs, outputs, bindings, stream): lists of ``{'host', 'device'}`` dicts for the
        input and output bindings (in binding order), the device pointers of all bindings
        in binding order, and a newly created CUDA stream.
    """
    stream = cuda.Stream()
    inputs, outputs, bindings = [], [], []
    for idx in range(engine.num_bindings):
        is_input = engine.binding_is_input(idx)
        if is_input:
            shape = (batch_size, 3, height, width)
        else:
            shape = (batch_size, engine.get_binding_shape(idx)[1], height, width)
        host_buf = cuda.pagelocked_empty(trt.volume(shape), trt.nptype(engine.get_binding_dtype(idx)))
        device_buf = cuda.mem_alloc(host_buf.nbytes)
        bindings.append(int(device_buf))
        target = inputs if is_input else outputs
        target.append({'host': host_buf, 'device': device_buf})
    return inputs, outputs, bindings, stream

def infer_and_evaluate(engine, context, val_loader, criterion, device, n_classes):
    """Evaluate a TensorRT segmentation engine on ``val_loader`` and time it.

    For every batch, the input binding shape is set from the batch itself and fresh
    host/device buffers are allocated. The logits are then copied back and scored with
    ``criterion``, ``calculate_accuracy`` and ``calculate_mean_iou``.

    Args:
        engine, context: deserialized TensorRT engine and its execution context.
        val_loader: iterable of ``(images, masks)`` where images are (B, C, H, W).
        criterion: loss applied to the (B, n_classes, H, W) logits and long masks.
        device: torch device on which logits and masks are scored.
        n_classes: number of output channels produced by the engine.

    Returns:
        ``(avg_loss, avg_accuracy, avg_iou, avg_latency, throughput)``. Loss, accuracy and
        IoU are per-batch means. Latency (seconds per batch) and throughput (images/s)
        cover the host->device copy, execution and device->host copy, not buffer allocation.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    total_loss = total_accuracy = total_iou = 0.0
    latencies = []
    total_images = 0
    num_batches = 0
    with torch.no_grad():
        for images, masks in val_loader:
            batch_size, channels, height, width = images.shape
            masks = masks.to(device).long()

            # Set dynamic shape for context, then size the buffers to match this batch.
            context.set_binding_shape(0, (batch_size, channels, height, width))
            inputs, outputs, bindings, stream = allocate_buffers(engine, context, batch_size, height, width)

            # TensorRT reads from the pinned host buffer, so the images are staged straight
            # from the CPU; there is no need to move them to the GPU through torch first.
            images_np = images.cpu().numpy().astype(np.float32)
            np.copyto(inputs[0]['host'], images_np.ravel())

            # Measure latency with a monotonic high-resolution clock.
            start_time = time.perf_counter()
            cuda.memcpy_htod_async(inputs[0]['device'], inputs[0]['host'], stream)
            context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
            cuda.memcpy_dtoh_async(outputs[0]['host'], outputs[0]['device'], stream)
            stream.synchronize()
            latencies.append(time.perf_counter() - start_time)
            total_images += batch_size
            num_batches += 1

            # Assumes a single output binding holding (B, n_classes, H, W) logits.
            main_output = torch.from_numpy(outputs[0]['host']).view(batch_size, n_classes, height, width).to(device)
            loss = criterion(main_output, masks)
            accuracy = calculate_accuracy(main_output, masks)
            _, preds = torch.max(main_output, 1)
            iou = calculate_mean_iou(preds, masks, n_classes)

            total_loss += loss.item()
            total_accuracy += accuracy.item()
            total_iou += iou

    if num_batches == 0:
        raise ValueError("val_loader yielded no batches; nothing to evaluate")

    # Calculate average metrics (per batch, as in the PyTorch evaluation)
    avg_loss = total_loss / num_batches
    avg_accuracy = total_accuracy / num_batches
    avg_iou = total_iou / num_batches
    avg_latency = np.mean(latencies)
    throughput = total_images / sum(latencies)

    return avg_loss, avg_accuracy, avg_iou, avg_latency, throughput


In [None]:
# Evaluation settings for the TensorRT QAT engine.
n_classes = 19  # number of segmentation classes (19 for Cityscapes)
device = torch.device("cuda")  # where logits/masks are scored; torch.device("cpu") also works
engine_path = "/kaggle/working/fcn_resnet18_qat.trt"

# Logger shared by the TensorRT runtime; severity WARNING reports warnings and errors only.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
def load_engine(engine_path):
    """Deserialize a serialized TensorRT engine file.

    Args:
        engine_path: path to a ``.trt``/``.engine`` file produced by trtexec or the builder.

    Returns:
        The deserialized engine.

    Raises:
        RuntimeError: if TensorRT cannot deserialize the file. TensorRT returns None and
            only logs the cause, which would otherwise surface later as an AttributeError.
    """
    with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
    if engine is None:
        raise RuntimeError(
            f"Failed to deserialize TensorRT engine from {engine_path!r} "
            "(see TensorRT log above; e.g. version/GPU mismatch or a corrupt file)."
        )
    return engine

# Deserialize the engine and create the execution context that holds per-run binding shapes.
engine = load_engine(engine_path)
context = engine.create_execution_context()

# Score and time the engine on the inference loader.
avg_loss, avg_accuracy, avg_iou, avg_latency, throughput = infer_and_evaluate(
    engine, context, inference_loader, criterion, device, n_classes
)

# Print results
for _label, _value in (
    ("Average Loss:", avg_loss),
    ("Average Accuracy:", avg_accuracy),
    ("Average Mean IoU:", avg_iou),
    ("Average Latency (seconds):", avg_latency),
    ("Throughput (images/second):", throughput),
):
    print(_label, _value)
del _label, _value

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# Decode function to convert mask to RGB (Cityscapes)
def decode_segmap(mask, n_classes=19, label_colours=None):
    """Convert a 2-D class-index mask into an RGB image using the Cityscapes palette.

    Args:
        mask: 2-D torch.Tensor (any device) or array-like of class indices. It is never
            modified; the remapping below works on a private copy.
        n_classes: number of leading palette entries to paint (default 19).
        label_colours: optional list of (R, G, B) tuples; defaults to the 19 Cityscapes colours.

    Returns:
        uint8 array of shape (H, W, 3) for a 2-D mask. Pixels equal to 255 (ignore label)
        are painted with class 0's colour; any other index outside [0, n_classes) stays black.
    """
    if label_colours is None:
        label_colours = [
            (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),
            (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
            (107, 142, 35), (152, 251, 152), (0, 130, 180), (220, 20, 60),
            (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100),
            (0, 80, 100), (0, 0, 230), (119, 11, 32)
        ]  # 19 valid classes

    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    # Copy before remapping: .numpy() on a CPU tensor shares its storage, and an ndarray
    # argument would otherwise be overwritten in place.
    mask = np.array(mask, copy=True)

    # Ignored pixels (255) are remapped to class 0
    mask[mask == 255] = 0

    r, g, b = np.zeros_like(mask, dtype=np.uint8), np.zeros_like(mask, dtype=np.uint8), np.zeros_like(mask, dtype=np.uint8)

    for l in range(n_classes):
        r[mask == l] = label_colours[l][0]
        g[mask == l] = label_colours[l][1]
        b[mask == l] = label_colours[l][2]

    return np.stack([r, g, b], axis=2)  # Convert to RGB image

# Function to visualize input, ground truth, and predicted masks
def visualize_segmentation(input_image, decoded_target=None, decoded_prediction=None):
    """Plot input image, ground-truth mask and predicted mask side by side.

    Args:
        input_image: tensor laid out as [C, H, W]. It is transposed to [H, W, C] and
            min-max rescaled to [0, 1] for display only; a constant image is shown as black.
        decoded_target: RGB array for the ground truth (e.g. from decode_segmap) or None.
        decoded_prediction: RGB array for the prediction or None.
            A None panel is drawn as a black placeholder titled "... Unavailable".
    """
    fig, axs = plt.subplots(1, 3, figsize=(18, 6))

    # Convert [C, H, W] to [H, W, C] and normalize for display.
    img = input_image.cpu().numpy().transpose(1, 2, 0)
    value_range = img.max() - img.min()
    # A constant image has zero range; plain min-max division would produce NaNs.
    img = (img - img.min()) / value_range if value_range > 0 else np.zeros_like(img)

    placeholder = np.zeros_like(img)
    panels = (
        (img, "Input Image", None),
        (decoded_target, "Ground Truth Mask", "Ground Truth Unavailable"),
        (decoded_prediction, "Predicted Mask", "Prediction Unavailable"),
    )
    for ax, (data, title, missing_title) in zip(axs, panels):
        if data is None:
            ax.imshow(placeholder)
            ax.set_title(missing_title)
        else:
            ax.imshow(data)
            ax.set_title(title)
        ax.axis('off')

    plt.show()

# Function to perform inference using TensorRT and visualize predictions
def visualize_trt_predictions(engine, context, inference_loader, device, n_classes=19):
    """Run the TensorRT engine on the first batch of ``inference_loader`` and plot sample 0.

    The binding shape and buffers are sized from the batch itself, so the host buffer
    always matches the images. If the engine's input shape is static (no -1 dims), e.g.
    the one built from the fixed 1x3x512x1024 dummy input in this notebook, the batch
    must have exactly that shape, and a ValueError is raised otherwise. ``device`` is where
    the engine's logits are moved before the argmax. Does nothing if the loader is empty.
    """
    first_batch = next(iter(inference_loader), None)
    if first_batch is None:
        return  # empty loader: nothing to show

    input_img, target_mask = first_batch
    target_mask = target_mask.to(device)
    batch_size, channels, height, width = input_img.shape
    requested_shape = (batch_size, channels, height, width)

    # Fail loudly rather than feeding a static engine a shape it was not built for.
    engine_shape = tuple(engine.get_binding_shape(0))
    if any(dim != -1 and dim != want for dim, want in zip(engine_shape, requested_shape)):
        raise ValueError(f"Engine input shape {engine_shape} does not accept batch shape {requested_shape}")

    with torch.no_grad():
        # Set TensorRT input shape dynamically and allocate matching buffers.
        context.set_binding_shape(0, requested_shape)
        inputs, outputs, bindings, stream = allocate_buffers(engine, context, batch_size, height, width)

        # TensorRT reads from the pinned host buffer, so stage the images straight from the CPU.
        np.copyto(inputs[0]['host'], input_img.cpu().numpy().astype(np.float32).ravel())

        # Perform inference
        cuda.memcpy_htod_async(inputs[0]['device'], inputs[0]['host'], stream)
        context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
        cuda.memcpy_dtoh_async(outputs[0]['host'], outputs[0]['device'], stream)
        stream.synchronize()

        # Single output binding assumed: (B, n_classes, H, W) logits -> class indices.
        main_output = torch.from_numpy(outputs[0]['host']).view(batch_size, n_classes, height, width).to(device)
        _, predicted_mask = torch.max(main_output, 1)

    # Decode and show only the first sample of the batch.
    decoded_target = decode_segmap(target_mask[0], n_classes=n_classes)
    decoded_prediction = decode_segmap(predicted_mask[0], n_classes=n_classes)
    visualize_segmentation(input_img[0], decoded_target, decoded_prediction)

# Call the visualization function with the TensorRT engine: plots input, ground truth and
# prediction for the first sample of the first inference batch.
visualize_trt_predictions(engine, context, inference_loader, device, n_classes=19)
