# Imports

In [1]:
import kagglehub
import kagglehub.config
import json
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import gradio as gr

from RoadDataLoader import RoadDataLoader
from RoadDataset import RoadDataset

from baseline_models.DeepLabV3Model import DeepLabV3Model
from baseline_models.UNET2D import UNET2D
from Swin_UNET.swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys

from wrapper_modules.RoadSegmentationModule import RoadSegmentationModule

from loss_and_eval_functions import dice_score, combined_loss, iou_score

  from .autonotebook import tqdm as notebook_tqdm


# Download

In [2]:
# get kaggle credentials file from ./kaggle.json
with open("./kaggle.json", "r") as f:
    kaggle_json = json.load(f)
kaggel_username = kaggle_json["username"]
kaggel_key = kaggle_json["key"]

In [3]:
kagglehub.config.set_kaggle_credentials(kaggel_username, kaggel_key)

Kaggle credentials set.


In [4]:
# Download latest version
path = kagglehub.dataset_download("payne18/road-detection-dataset-with-masks")
# make data folder if it does not exist
if not os.path.exists("./data"):
    os.mkdir("./data")
# Move data folder to ./data
os.system(f"mv {path} ./data/road-detection-dataset-with-masks")
# remove empty folder
folder_to_remove = path.split("payne18/road-detection-dataset-with-masks")[0] 
os.system((f"rm -r {folder_to_remove}"))



1

In [5]:
data_path = "./data/road-detection-dataset-with-masks/deepglobe-road-extraction-dataset"
metadata_path = "./data/road-detection-dataset-with-masks/deepglobe-road-extraction-dataset/metadata.csv"

In [6]:
# open metadata
metadata = pd.read_csv(metadata_path)
metadata = metadata[metadata["split"] == "train"]
metadata["sat_image_path"] = metadata["sat_image_path"].apply(lambda x: os.path.join(data_path, x))
metadata["mask_path"] = metadata["mask_path"].apply(lambda x: os.path.join(data_path, x))


In [7]:
# hyperparameters
batch_size = 8
optimizer = "Adam"
lr = 1e-3
weight_decay = 0.01
epochs = 40
loss_fn = combined_loss
accelerator = "auto"
pretrained = False
image_size = 512
num_workers = 8

In [8]:
deeplabv3_model_path = "./models/DeepLabV3_best_model.cpkt"
unet2d_model_path = "./models/UNET2D_best_model.cpkt"
swin_model_path = "./models/Swin_UNET.cpkt"

if not os.path.exists("./models"):
    os.mkdir("./models")
#check if file exists
if not os.path.exists(deeplabv3_model_path):
    !curl -L -o ./models/DeepLabV3_best_model.cpkt https://huggingface.co/beboi0122/Vision_transformers_for_image_segmentation_HF/resolve/main/DeepLabV3_best_model.cpkt
if not os.path.exists(unet2d_model_path):
    !curl -L -o ./models/UNET2D_best_model.cpkt https://huggingface.co/beboi0122/Vision_transformers_for_image_segmentation_HF/resolve/main/UNET2D_best_model.cpkt
if not os.path.exists(swin_model_path):
    !curl -L -o ./models/Swin_UNET.cpkt https://huggingface.co/beboi0122/Vision_transformers_for_image_segmentation_HF/resolve/main/Swin_UNET.cpkt

# Load Models

In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
deeplabv3_state_dict = torch.load(deeplabv3_model_path)["state_dict"]
deeplabv3_state_dict = {k.replace("model.model.", "model."): v for k, v in deeplabv3_state_dict.items()}
deeplabv3_model = DeepLabV3Model(pretrained=False)
deeplabv3_model.load_state_dict(deeplabv3_state_dict, strict=False)
deeplabv3_module = RoadSegmentationModule(deeplabv3_model, combined_loss, optimizer)
deeplabv3_model.to(device)
_ = deeplabv3_model.eval()


  deeplabv3_state_dict = torch.load(deeplabv3_model_path)["state_dict"]


In [None]:
unet2d_state_dict = torch.load(unet2d_model_path)["state_dict"]
unet2d_state_dict = {k.replace("model.", ""): v for k, v in unet2d_state_dict.items()}
unet2d_model = UNET2D(3, 1, chanel_list=[8, 16, 32, 64])
unet2d_model.load_state_dict(unet2d_state_dict, strict=False)
unet2d_model.to(device)
_ = unet2d_model.eval()


  unet2d_state_dict = torch.load(unet2d_model_path)["state_dict"]


UNET2D(
  (down_blocks): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
    (2): Sequential(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_run

In [12]:
swin_model = SwinTransformerSys(
    img_size=image_size,           # Input méret
    patch_size=4,           # Patch méret
    in_chans=3,             # RGB képek
    num_classes=1,          # Bináris szegmentáció
    embed_dim=8,
    num_heads = (1, 2, 4, 8),
    depths = (2, 2, 6, 2),
    window_size=8,          # Ablak méret
    mlp_ratio=4.0,          # MLP arány
    qkv_bias=True,          # QKV bias
    drop_rate=0.1,          # Dropout ráta
    attn_drop_rate=0.1,     # Attention dropout
    drop_path_rate=0.1,     # Drop path
    norm_layer=nn.LayerNorm,# Rétegnormálás
    ape=False,              # Absolute positional embedding
    patch_norm=True,        # Patch normálás
    use_checkpoint=False    # Checkpoint
)
_ = swin_model.to(device)
module = RoadSegmentationModule(swin_model, combined_loss, optimizer)
module.load_state_dict(torch.load(swin_model_path)["state_dict"])


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


SwinTransformerSys expand initial----depths:(2, 2, 6, 2);depths_decoder:[1, 2, 2, 2];drop_path_rate:0.1;num_classes:1
---final upsample expand_first---


  module.load_state_dict(torch.load(swin_model_path)["state_dict"])


<All keys matched successfully>

# Start Frontend

In [21]:
dataSet = RoadDataset(metadata=metadata, train=False)
max_images = metadata.shape[0]

#image evaluation based on model
def predict_image(input_int, model):
    img, mask = dataSet.__getitem__(input_int-1)
    print(img.shape)
    if model == "UNet2D":
        with torch.no_grad():
            output = unet2d_model(img.unsqueeze(0).to(device))
            iou = iou_score(output.to(device),mask.to(device)).cpu()
            pred = torch.sigmoid(output).cpu()
    elif model == "DeeplabV3":
        with torch.no_grad():
            output = deeplabv3_model(img.unsqueeze(0).to(device))
            print('bbbb')
            iou = iou_score(output.to(device),mask.to(device)).cpu()
            pred = torch.sigmoid(output).cpu()
    else:
        with torch.no_grad():
            output = swin_model(img.unsqueeze(0).to(device))
            iou = iou_score(output.to(device),mask.to(device)).cpu()
            pred = torch.sigmoid(output).cpu()
        
    pred = pred.squeeze(0)
    pred = pred.squeeze(0)
    pred_numpy = pred.numpy()
    mask = mask.cpu().squeeze(0).numpy()
    
    intersection = np.minimum(mask, pred_numpy)
    union = np.maximum(mask, pred_numpy)
    error = union - intersection
    return (img*0.5+0.5).cpu().permute(1,2,0).numpy(), mask, pred_numpy, intersection, union, error, iou.item()*100




In [22]:
# Gradio frontend
with gr.Blocks() as demo:
# Inputs
    with gr.Row():
        slider = gr.Slider(label="Select an Integer", minimum=1, maximum=max_images, step=1)
        method = gr.Radio(
            choices=["UNet2D", "DeeplabV3", "Swin-Unet"],
            value="UNet2D",
            label="Select the method"
        )
        exec = gr.Button("See")
# Outputs
    with gr.Row():
        img = gr.Image(type="pil", label="Original")
        img2 = gr.Image(type="pil", label="Mask")
        img3 = gr.Image(type="pil", label="Prediction")

    with gr.Row():
        intersection = gr.Image(type="pil", label="Intersection")
        union =  gr.Image(type="pil", label="Union")
        error =  gr.Image(type="pil", label="Error")
    iou = gr.Textbox(label="IoU Percentage")
    exec.click(fn=predict_image, inputs=[slider, method], outputs=[img, img2, img3, intersection, union, error,  iou])
demo.launch(share=True)

* Running on local URL:  http://127.0.0.1:7864

To create a public link, set `share=True` in `launch()`.




torch.Size([3, 512, 512])


Traceback (most recent call last):
  File "c:\Python3.12\Lib\site-packages\gradio\queueing.py", line 624, in process_events
    response = await route_utils.call_process_api(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Python3.12\Lib\site-packages\gradio\route_utils.py", line 323, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Python3.12\Lib\site-packages\gradio\blocks.py", line 2043, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Python3.12\Lib\site-packages\gradio\blocks.py", line 1590, in call_function
    prediction = await anyio.to_thread.run_sync(  # type: ignore
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Python3.12\Lib\site-packages\anyio\to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^