# Imports


In [1]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.models.segmentation as segmodels
import gradio as gr

# CONFIGURATION (device, input size, threshold, weight paths)

In [2]:

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMG_SIZE = 256   # side length of the square network input used by predict()
THRESH = 0.5     # default sigmoid-probability cut-off; initial value of the UI slider

# Directory holding the trained checkpoints. Defaults to the Kaggle dataset
# mount; set the ISIC_WEIGHTS_DIR environment variable to run the notebook
# elsewhere without editing it.
WEIGHTS_DIR = os.environ.get("ISIC_WEIGHTS_DIR", "/kaggle/input/wht543")

# Checkpoint file per architecture; keys must match MODEL_BUILDERS because
# get_model() looks up WEIGHTS[arch_name] for the dropdown selection.
WEIGHTS = {
    "VGG16_UNet":       os.path.join(WEIGHTS_DIR, "VGG16_UNet_BEST.pth"),
    "DeepLabV3_R50":    os.path.join(WEIGHTS_DIR, "DeepLabV3_R50_BEST.pth"),
    "FCN_R50":          os.path.join(WEIGHTS_DIR, "FCN_R50_fold1_dice0.8988.pth"),
    "ResNet34_SE_UNet": os.path.join(WEIGHTS_DIR, "ResNet34_SE_UNet_LR3e-05_B16_best_val0.9014.pth"),
    "SegNet":           os.path.join(WEIGHTS_DIR, "SegNet_fold3_dice0.8915.pth"),
}


# HELPERS

In [3]:
def preprocess_image(rgb_img, size=256):
    """Resize and ImageNet-normalise an image into a ``(1, 3, size, size)`` float32 tensor.

    Args:
        rgb_img: HxWx3 RGB image with 0-255 values. HxW or HxWx1 grayscale
            images are replicated to three channels, and an HxWx4 image has its
            fourth (alpha) channel dropped, so the 3-element mean/std below
            always broadcast correctly.
        size: side length of the square network input.

    Returns:
        CPU torch.Tensor of shape (1, 3, size, size), dtype float32.

    Raises:
        ValueError: if the image is not HxW or HxWx{1,3,4}.
    """
    img = np.asarray(rgb_img)
    if img.ndim == 2:
        img = img[..., None]
    if img.ndim != 3 or img.shape[2] not in (1, 3, 4):
        raise ValueError(
            f"expected an HxW, HxWx1, HxWx3 or HxWx4 image, got shape {img.shape}"
        )
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    # Drop a possible alpha channel. Dropping it creates a strided view, and
    # OpenCV's bindings may reject non-contiguous buffers, so make it contiguous.
    # For 3-channel input this is a no-op (no copy).
    img = np.ascontiguousarray(img[..., :3])

    img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
    img = img.astype(np.float32) / 255.0

    # ImageNet statistics, matching the ImageNet-pretrained encoders used below.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    img = (img - mean) / std
    img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
    return torch.from_numpy(img).unsqueeze(0)  # add batch dim -> 1CHW

def postprocess_mask(logits, out_h, out_w, thresh=0.5):
    """Turn raw network logits into a 0/255 uint8 mask of size out_h x out_w.

    Only the first image and first channel (``logits[0, 0]``) are used. A pixel
    is foreground when its sigmoid probability is strictly greater than
    ``thresh``. The binary map is resized with nearest-neighbour interpolation,
    so the result contains only the values 0 and 255.
    """
    first_map = logits[0, 0]
    prob_np = torch.sigmoid(first_map).detach().cpu().numpy()
    binary = np.where(prob_np > thresh, 255, 0).astype(np.uint8)
    return cv2.resize(binary, (out_w, out_h), interpolation=cv2.INTER_NEAREST)

def overlay_mask(rgb_img, mask_u8, alpha=0.45, color=(255, 0, 0)):
    """Alpha-blend a solid colour over the masked pixels of an RGB image.

    Pixels where ``mask_u8 > 0`` become ``(1 - alpha) * pixel + alpha * color``.
    All other pixels are returned unchanged.

    Args:
        rgb_img: HxWx3 RGB image (0-255). An HxW grayscale image is replicated
            to three channels, and an HxWx4 image has its fourth channel dropped.
        mask_u8: HxW mask; any non-zero value counts as foreground.
        alpha: blend weight of the overlay colour (0 = invisible, 1 = opaque).
        color: RGB overlay colour; defaults to pure red.

    Returns:
        HxWx3 uint8 image.

    Raises:
        ValueError: if the mask shape differs from the image's HxW.
    """
    img = np.asarray(rgb_img)
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    elif img.ndim == 3 and img.shape[2] == 4:
        img = img[..., :3]

    mask = np.asarray(mask_u8)
    if mask.shape != img.shape[:2]:
        raise ValueError(
            f"mask shape {mask.shape} does not match image size {img.shape[:2]}"
        )

    weight = alpha * (mask > 0).astype(np.float32)[..., None]  # HxWx1
    tint = np.asarray(color, dtype=np.float32)                  # (3,)
    blended = img.astype(np.float32) * (1.0 - weight) + tint * weight
    # Round (not truncate) so float32 error cannot turn e.g. 255 into 254, and
    # clip so an alpha outside [0, 1] cannot wrap around in the uint8 cast.
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)

def model_forward(model, x, arch_name=None):
    """Run ``model`` on ``x`` and return its raw logit tensor.

    The torchvision segmentation models used here (DeepLabV3, FCN) return a
    dict whose ``"out"`` entry holds the main logits. The hand-written models
    return the tensor directly. Dispatch is based on the output type instead of
    a hard-coded list of architecture names. Newly registered torchvision-style
    models therefore work without editing this function.

    Args:
        model: callable (nn.Module) producing either a tensor or a dict with "out".
        x: input batch tensor.
        arch_name: accepted for backward compatibility; no longer needed for dispatch.

    Returns:
        The logits tensor.
    """
    out = model(x)
    if isinstance(out, dict):  # also covers OrderedDict
        return out["out"]
    return out

def load_state_safely(model, weight_path, map_location=None):
    """Load the checkpoint at ``weight_path`` into ``model`` and return ``model``.

    Besides a bare ``state_dict``, two common checkpoint layouts are accepted:
    - a dict that nests the weights under ``"state_dict"``,
      ``"model_state_dict"`` or ``"model"``;
    - keys that all carry the ``"module."`` prefix that ``nn.DataParallel``
      adds when its state is saved.

    Loading stays strict, so a real architecture/checkpoint mismatch still raises.

    Args:
        model: nn.Module whose parameters are overwritten in place.
        weight_path: path to a file produced by ``torch.save``.
        map_location: device for the loaded tensors; defaults to ``DEVICE``.

    Returns:
        The same ``model`` object.

    Raises:
        FileNotFoundError: if ``weight_path`` is not an existing file.
        RuntimeError: from ``load_state_dict`` on missing/unexpected keys or shape mismatches.
    """
    if not os.path.isfile(weight_path):
        raise FileNotFoundError(f"Weight file not found: {weight_path}")
    if map_location is None:
        map_location = DEVICE
    # NOTE(security): torch.load unpickles the file, which can execute code.
    # Only load checkpoints from a trusted source.
    sd = torch.load(weight_path, map_location=map_location)

    if isinstance(sd, dict):
        for wrapper_key in ("state_dict", "model_state_dict", "model"):
            if isinstance(sd.get(wrapper_key), dict):
                sd = sd[wrapper_key]
                break
        prefix = "module."
        if sd and all(isinstance(k, str) and k.startswith(prefix) for k in sd):
            sd = {k[len(prefix):]: v for k, v in sd.items()}

    model.load_state_dict(sd)
    return model

# MODEL DEFINITIONS

In [4]:

class VGG16_UNet(nn.Module):
    """U-Net style segmenter: torchvision VGG16 encoder, skip connections, 1-channel logits.

    The encoder is VGG16 ``features`` (ImageNet weights requested at
    construction), split into five stages. Each stage halves H and W, so the
    decoder upsamples 2x before concatenating each encoder output (skip
    connection). A final 2x upsample restores the input resolution.
    NOTE(review): H and W presumably must be divisible by 32 so the upsampled
    maps line up with the skip tensors in ``torch.cat`` — confirm.
    Attribute names (enc*, d*, final) form the checkpoint's state_dict keys.
    """

    def __init__(self):
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
        # Five encoder stages; channel counts (64/128/256/512/512) are the ones
        # the decoder input sizes below are built for.
        self.enc1 = vgg[:5]
        self.enc2 = vgg[5:10]
        self.enc3 = vgg[10:17]
        self.enc4 = vgg[17:24]
        self.enc5 = vgg[24:31]
        # Parameter-free 2x upsampler shared by every decoder step.
        self.up = nn.Upsample(scale_factor=2, mode="nearest")

        def conv(i, o):
            """Two 3x3 conv + ReLU layers (i -> o channels); padding=1 keeps H, W."""
            return nn.Sequential(
                nn.Conv2d(i, o, 3, padding=1), nn.ReLU(inplace=True),
                nn.Conv2d(o, o, 3, padding=1), nn.ReLU(inplace=True)
            )

        # Input channels = upsampled decoder/encoder map + concatenated skip.
        self.d5 = conv(512 + 512, 512)  # up(e5) ++ e4
        self.d4 = conv(512 + 256, 256)  # up(d5) ++ e3
        self.d3 = conv(256 + 128, 128)  # up(d4) ++ e2
        self.d2 = conv(128 + 64,   64)  # up(d3) ++ e1
        self.d1 = conv(64, 64)          # up(d2), no skip at full resolution
        self.final = nn.Conv2d(64, 1, 1)  # 1x1 head -> one logit channel

    def forward(self, x):
        """Map a (B, 3, H, W) batch to (B, 1, H, W) raw logits (no sigmoid)."""
        # Shapes shown for a 256x256 input.
        e1 = self.enc1(x)   # (B,  64, 128, 128)
        e2 = self.enc2(e1)  # (B, 128,  64,  64)
        e3 = self.enc3(e2)  # (B, 256,  32,  32)
        e4 = self.enc4(e3)  # (B, 512,  16,  16)
        e5 = self.enc5(e4)  # (B, 512,   8,   8)

        d5 = self.d5(torch.cat([self.up(e5), e4], dim=1))  # 16x16
        d4 = self.d4(torch.cat([self.up(d5), e3], dim=1))  # 32x32
        d3 = self.d3(torch.cat([self.up(d4), e2], dim=1))  # 64x64
        d2 = self.d2(torch.cat([self.up(d3), e1], dim=1))  # 128x128
        d1 = self.d1(self.up(d2))                          # 256x256
        return self.final(d1)  # raw logits


In [5]:

class ConvBlock(nn.Module):
    """Two 3x3 conv + ReLU layers mapping ``i`` to ``o`` channels; padding 1 keeps H and W."""

    def __init__(self, i, o):
        super().__init__()
        # Kept as one nn.Sequential named ``net`` so checkpoint keys stay
        # ``net.0.*`` / ``net.2.*``.
        layers = [
            nn.Conv2d(i, o, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(o, o, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class SEBlock(nn.Module):
    """Squeeze-and-excitation channel gate.

    Each channel is global-average-pooled. The channel vector goes through a
    bias-free bottleneck MLP (``channel -> channel // reduction -> channel``,
    ReLU in between) and a sigmoid. The input channels are then rescaled by
    that gate.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        # Attribute names fc1 / fc2 / sigmoid are checkpoint state_dict keys.
        self.fc1 = nn.Linear(channel, hidden, bias=False)
        self.fc2 = nn.Linear(hidden, channel, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        batch, chans, _h, _w = x.size()
        squeezed = F.adaptive_avg_pool2d(x, (1, 1)).view(batch, chans)
        gate = self.sigmoid(self.fc2(F.relu(self.fc1(squeezed))))
        return x * gate.view(batch, chans, 1, 1).expand_as(x)

class ResNet34_SE_UNet(nn.Module):
    """U-Net with a torchvision ResNet-34 encoder, SE gates on the encoder stages, 1-channel logits.

    Encoder stages (channel counts as implied by the SEBlock and decoder sizes
    below): e0 = conv1+bn1+relu -> 64, e1 = maxpool+layer1 -> 64,
    e2 -> 128, e3 -> 256, e4 -> 512. The outputs of e1-e4 are recalibrated by an
    SEBlock before serving both as the next stage's input and as skip
    connections. Each decoder step upsamples 2x, concatenates a skip and applies
    a ConvBlock. A last 2x upsample before the 1x1 head restores full resolution.
    NOTE(review): assumes the standard ResNet strides (e0 and e1 each halve H/W,
    e2-e4 halve again), so H and W should be divisible by 32 — confirm.
    Attribute names form the checkpoint's state_dict keys.
    """

    def __init__(self):
        super().__init__()
        r = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)

        self.e0 = nn.Sequential(r.conv1, r.bn1, r.relu)  # stem, 64 ch
        self.e1 = nn.Sequential(r.maxpool, r.layer1)     # 64 ch
        self.e2 = r.layer2                               # 128 ch
        self.e3 = r.layer3                               # 256 ch
        self.e4 = r.layer4                               # 512 ch

        # One SE gate per encoder stage after the stem (channel counts must match).
        self.se1 = SEBlock(64)
        self.se2 = SEBlock(128)
        self.se3 = SEBlock(256)
        self.se4 = SEBlock(512)

        # Parameter-free 2x upsampler shared by every decoder step.
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        # Input channels = upsampled map + concatenated skip.
        self.d3 = ConvBlock(512 + 256, 256)  # up(e) ++ d
        self.d2 = ConvBlock(256 + 128, 128)  # ++ c
        self.d1 = ConvBlock(128 + 64,   64)  # ++ b
        self.d0 = ConvBlock(64  + 64,   64)  # ++ a (stem output)
        self.out = nn.Conv2d(64, 1, 1)       # 1x1 head -> one logit channel

    def forward(self, x):
        """Map a (B, 3, H, W) batch to (B, 1, H, W) raw logits (no sigmoid)."""
        # Expected shapes for a 256x256 input (standard ResNet strides assumed).
        a = self.e0(x)     # (B,  64, 128, 128)
        b = self.e1(a)     # (B,  64,  64,  64)
        b = self.se1(b)
        c = self.e2(b)     # (B, 128,  32,  32)
        c = self.se2(c)
        d = self.e3(c)     # (B, 256,  16,  16)
        d = self.se3(d)
        e = self.e4(d)     # (B, 512,   8,   8)
        e = self.se4(e)

        # ``x`` is reused for the decoder activations from here on.
        x = self.d3(torch.cat([self.up(e), d], 1))  # 16x16
        x = self.d2(torch.cat([self.up(x), c], 1))  # 32x32
        x = self.d1(torch.cat([self.up(x), b], 1))  # 64x64
        x = self.d0(torch.cat([self.up(x), a], 1))  # 128x128
        x = self.up(x)                              # back to input resolution
        return self.out(x)  # raw logits

In [6]:

def DeepLabV3_R50():
    """Build torchvision DeepLabV3-ResNet50 with a single-logit output layer.

    Starts from torchvision's default pretrained weights. The layer at index 4
    of the main ``classifier`` head is then replaced by a 1x1 conv mapping its
    256 input channels to one logit channel (binary segmentation).
    """
    net = segmodels.deeplabv3_resnet50(weights="DEFAULT")
    net.classifier[4] = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1)
    return net

In [7]:

def FCN_R50():
    """Build torchvision FCN-ResNet50 with a single-logit output layer.

    Starts from torchvision's default pretrained weights. The layer at index 4
    of the main ``classifier`` head is then replaced by a 1x1 conv mapping its
    512 input channels to one logit channel (binary segmentation).
    """
    net = segmodels.fcn_resnet50(weights="DEFAULT")
    net.classifier[4] = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1)
    return net


In [8]:

class SegNet(nn.Module):
    """VGG16-encoder / transposed-conv-decoder segmenter producing 1-channel logits.

    The encoder reuses the five stages of torchvision VGG16 ``features``. Each
    slice already ends with VGG's own 2x2 max-pool (feature indices 4, 9, 16,
    23 and 30), so the encoder downsamples by 32 in total. The decoder's five
    stride-2 transposed convolutions upsample by exactly 32. For H and W
    divisible by 32, the output therefore has the input's spatial size.

    Earlier code applied an extra ``F.max_pool2d`` after every stage. That
    halved the resolution twice per stage. For a 256x256 input, the map was
    already 1x1 before enc5's own pool, which then fails. Even for larger
    inputs, the 32x decoder could not restore the input size.
    Attribute names (enc, dec) form the checkpoint's state_dict keys.
    """

    def __init__(self):
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
        self.enc = nn.ModuleList([
            vgg[:5], vgg[5:10], vgg[10:17], vgg[17:24], vgg[24:31]
        ])
        # NOTE(review): no non-linearity between the transposed convs, so the
        # decoder is a purely linear upsampler. Left unchanged because the
        # saved checkpoint was trained with this structure.
        self.dec = nn.ModuleList([
            nn.ConvTranspose2d(512,512,2,2),
            nn.ConvTranspose2d(512,256,2,2),
            nn.ConvTranspose2d(256,128,2,2),
            nn.ConvTranspose2d(128,64,2,2),
            nn.ConvTranspose2d(64,1,2,2),
        ])

    def forward(self, x):
        """Map a (B, 3, H, W) batch to (B, 1, H, W) raw logits (H, W divisible by 32)."""
        for stage in self.enc:
            x = stage(x)  # conv/ReLU layers + the stage's own max-pool: /2
        for upsample in self.dec:
            x = upsample(x)  # stride-2 transposed conv: x2
        return x


# MODEL FACTORY + CACHE

In [9]:
# Architecture name -> zero-argument constructor. Insertion order is the
# order of the model dropdown choices in the UI.
MODEL_BUILDERS = dict(
    VGG16_UNet=VGG16_UNet,
    DeepLabV3_R50=DeepLabV3_R50,
    FCN_R50=FCN_R50,
    ResNet34_SE_UNet=ResNet34_SE_UNet,
    SegNet=SegNet,
)

# Cache of ready-to-use (weights loaded, on DEVICE, eval-mode) models keyed by architecture name.
_loaded_models = dict()

def get_model(arch_name):
    """Return the cached eval-mode model for ``arch_name``, loading it on first use.

    The first call builds the network via ``MODEL_BUILDERS`` and loads the
    checkpoint listed in ``WEIGHTS``. It then moves the model to ``DEVICE``,
    switches it to eval mode and stores it in ``_loaded_models``. Later calls
    return that same object. A failed load is not cached, so a retry (e.g.
    after fixing a weight path) builds again.

    Raises:
        KeyError: if ``arch_name`` is missing from ``MODEL_BUILDERS`` or
            ``WEIGHTS``. This is checked before building, so a bad name fails
            fast. Otherwise a network would be constructed first, and the
            builders request pretrained weights that may need downloading.
        FileNotFoundError, RuntimeError: propagated from ``load_state_safely``
            (missing file, state_dict mismatch).
    """
    cached = _loaded_models.get(arch_name)
    if cached is not None:
        return cached

    missing = [
        table_name
        for table_name, table in (("MODEL_BUILDERS", MODEL_BUILDERS), ("WEIGHTS", WEIGHTS))
        if arch_name not in table
    ]
    if missing:
        raise KeyError(
            f"Unknown model {arch_name!r} (not in {' / '.join(missing)}); "
            f"available: {', '.join(MODEL_BUILDERS)}"
        )

    model = MODEL_BUILDERS[arch_name]()
    model = load_state_safely(model, WEIGHTS[arch_name])
    model = model.to(DEVICE).eval()
    _loaded_models[arch_name] = model
    return model


# INFERENCE FUNCTION (for GUI)

In [10]:
@torch.no_grad()
def predict(image, arch_name, thresh):
    """Gradio callback: segment ``image`` with model ``arch_name`` at ``thresh``.

    Args:
        image: the uploaded image as a numpy array, or None when nothing has
            been uploaded.
        arch_name: model key chosen in the dropdown.
        thresh: probability cut-off from the slider.

    Returns:
        A ``(mask, overlay, status)`` tuple:
        - mask: uint8 mask (0/255) at the input's resolution;
        - overlay: the input image with the mask blended on top;
        - status: a short status string.
        When the input is missing or the model cannot be loaded, both images
        are None and the status explains why.
    """
    if image is None:
        return None, None, "No image provided."

    rgb = np.array(image).astype(np.uint8)
    h, w = rgb.shape[:2]

    try:
        model = get_model(arch_name)
    except (FileNotFoundError, KeyError, RuntimeError) as exc:
        # Loading problems (missing checkpoint, unknown model name,
        # checkpoint/architecture mismatch) are shown in the Status box, like
        # the "No image provided." case, instead of escaping the callback.
        return None, None, f"Could not load model {arch_name}: {exc}"

    x = preprocess_image(rgb, size=IMG_SIZE).to(DEVICE)
    logits = model_forward(model, x, arch_name)

    mask = postprocess_mask(logits, out_h=h, out_w=w, thresh=float(thresh))
    over = overlay_mask(rgb, mask, alpha=0.45)

    info = f"Device: {DEVICE} | Model: {arch_name} | Threshold: {thresh}"
    return mask, over, info


# GRADIO UI

In [None]:
# Build the Gradio Blocks UI. Components appear in the order they are created
# inside each ``with`` context, so the statements below are order-sensitive.
with gr.Blocks(title="ISIC Segmentation GUI") as demo:
    gr.Markdown("## Skin Lesion Segmentation (Upload Image → Choose Model → Get Mask)")

    # Top row: image upload on the left, model/threshold controls on the right.
    with gr.Row():
        inp = gr.Image(type="numpy", label="Upload Image (RGB)")
        with gr.Column():
            # Choices come from MODEL_BUILDERS; get_model() also needs a
            # matching WEIGHTS entry for each of them.
            model_dd = gr.Dropdown(
                choices=list(MODEL_BUILDERS.keys()),
                value="VGG16_UNet",
                label="Choose Model"
            )
            # Threshold slider from 0.1 to 0.9 in 0.05 steps, starting at THRESH.
            thr = gr.Slider(0.1, 0.9, value=THRESH, step=0.05, label="Mask Threshold")
            btn = gr.Button("Run Inference")

    # Bottom row: the two images returned by predict().
    with gr.Row():
        out_mask = gr.Image(type="numpy", label="Predicted Mask (binary)")
        out_overlay = gr.Image(type="numpy", label="Overlay (mask on image)")

    status = gr.Textbox(label="Status", interactive=False)

    # Inputs map positionally to predict(image, arch_name, thresh). Its
    # 3-tuple return maps to (mask image, overlay image, status text).
    btn.click(
        fn=predict,
        inputs=[inp, model_dd, thr],
        outputs=[out_mask, out_overlay, status]
    )

    gr.Markdown(
        "### Notes\n"
        "- Make sure your weight files exist at the paths in the `WEIGHTS` dictionary.\n"
        "- If you trained on different normalization/size, match those in `preprocess_image()`.\n"
    )

# Start the server; the captured output below shows the local and share URLs.
demo.launch(debug=True)

* Running on local URL:  http://127.0.0.1:7860
It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

* Running on public URL: https://78db607043283e1579.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth


100%|██████████| 83.3M/83.3M [00:00<00:00, 170MB/s] 
