In [None]:
import gradio as gr
from PIL import Image
import torch
from transformers import AutoModelForSemanticSegmentation, AutoImageProcessor
import numpy as np
from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101
from torchvision import transforms

# Hugging Face Hub id of the SegFormer (MiT-B0) weights to load.
checkpoint = "nvidia/mit-b0"

# Class ids 0-19 plus 255, each named by its own decimal string; the two
# dicts are exact inverses of one another.
id2label = {class_id: str(class_id) for class_id in (*range(20), 255)}
label2id = {name: class_id for class_id, name in id2label.items()}

# Build the segmentation model from the checkpoint, overriding its label maps.
# As logged when this runs, the decode-head weights are absent from
# nvidia/mit-b0 and therefore start out newly (randomly) initialised.
model = AutoModelForSemanticSegmentation.from_pretrained(
    checkpoint,
    id2label=id2label,
    label2id=label2id,
)

# Matching pre-processor that turns PIL images into model-ready tensors.
image_processor = AutoImageProcessor.from_pretrained(
    checkpoint,
    do_reduce_labels=True,
)

def model_vit(image):
    """Segment *image* with the module-level SegFormer model.

    Args:
        image: PIL image to segment.

    Returns:
        A single-channel 8-bit PIL image with the same size as *image* in
        which every pixel holds its predicted class id, scaled so that the
        class ids are spread evenly over the 0-255 grey range.
    """
    # Normalisation expects three channels; drop alpha / expand grayscale.
    image = image.convert("RGB")
    inputs = image_processor(images=[image], return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
        # Resize the logits to the input's (height, width); PIL reports
        # size as (width, height), hence the reversal.
        upsampled_logits = torch.nn.functional.interpolate(
            logits,
            size=image.size[::-1],
            mode="bilinear",
            align_corners=False,
        )
        # Index the single batch item instead of squeeze(), which would also
        # drop a height or width of 1.
        class_map = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

    # Multiplying class ids by 255 (a binary-mask idiom) wraps around in
    # uint8 for every id above 1, making the classes nearly indistinguishable.
    # Scale by the largest step that keeps the highest id within 0-255.
    num_classes = logits.shape[1]
    gray_step = 255 // max(num_classes - 1, 1)
    return Image.fromarray((class_map * gray_step).astype(np.uint8))

# Networks that have already been built and loaded, keyed by
# (builder function, weights path), so each checkpoint is read from disk once
# rather than on every request.
_FCN_CACHE = {}

# Number of output classes the saved FCN checkpoints are built with.
_FCN_NUM_CLASSES = 21


def _load_fcn(builder, weights_path):
    """Return the eval-mode FCN from *builder* with *weights_path* loaded (cached)."""
    key = (builder, weights_path)
    net = _FCN_CACHE.get(key)
    if net is None:
        # `pretrained=False` is kept for compatibility with older torchvision
        # releases; newer ones spell this `weights=None`.
        net = builder(pretrained=False, num_classes=_FCN_NUM_CLASSES)
        # NOTE: torch.load unpickles the file - only use trusted checkpoints.
        state = torch.load(weights_path, map_location=torch.device("cpu"))
        net.load_state_dict(state)
        net.eval()
        _FCN_CACHE[key] = net
    return net


def _class_map_to_image(class_map, num_classes):
    """Convert a 2-D array of class ids into an 8-bit grayscale PIL image."""
    # Multiplying ids by 255 wraps around in uint8 for ids above 1; use the
    # largest step that keeps the highest id within 0-255 instead.
    gray_step = 255 // max(num_classes - 1, 1)
    return Image.fromarray((class_map * gray_step).astype(np.uint8))


def _segment_with_fcn(image, builder, weights_path):
    """Run the cached FCN on *image* and return its class map as an image."""
    net = _load_fcn(builder, weights_path)

    # The three-channel normalisation below requires an RGB input.
    image = image.convert("RGB")
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    batch = preprocess(image).unsqueeze(0)

    with torch.no_grad():
        scores = net(batch)['out']
        # Resize to the input's (height, width); PIL size is (width, height).
        scores = torch.nn.functional.interpolate(
            scores,
            size=image.size[::-1],
            mode="bilinear",
            align_corners=False,
        )
        # Index the single batch item instead of squeeze(), which would also
        # drop a height or width of 1.
        class_map = scores.argmax(dim=1)[0].cpu().numpy()

    return _class_map_to_image(class_map, scores.shape[1])


def model_cnn1(image):
    """Segment *image* with the FCN-ResNet50 weights in ``../data/cnn.pth``.

    Args:
        image: PIL image to segment.

    Returns:
        A single-channel 8-bit PIL image, the same size as *image*, whose
        pixel values are the predicted class ids spread evenly over 0-255.
    """
    return _segment_with_fcn(image, fcn_resnet50, '../data/cnn.pth')


def model_cnn2(image):
    """Segment *image* with the FCN-ResNet101 weights in ``../data/cnn_v2.pth``.

    Args:
        image: PIL image to segment.

    Returns:
        A single-channel 8-bit PIL image, the same size as *image*, whose
        pixel values are the predicted class ids spread evenly over 0-255.
    """
    return _segment_with_fcn(image, fcn_resnet101, '../data/cnn_v2.pth')

def segment_image(image, model_choice):
    """Segment *image* with the model named by *model_choice*.

    Args:
        image: PIL image supplied by the Gradio image input, or None when
            nothing was uploaded.
        model_choice: One of "MiT-B0", "ResNet-50" or "ResNet-101".

    Returns:
        The PIL image produced by the selected model function.

    Raises:
        gr.Error: If no image was provided or *model_choice* is not one of
            the supported names, so the UI shows a message instead of an
            empty output or an opaque traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image to segment.")

    runners = {
        "MiT-B0": model_vit,
        "ResNet-50": model_cnn1,
        "ResNet-101": model_cnn2,
    }
    runner = runners.get(model_choice)
    if runner is None:
        raise gr.Error(
            f"Unknown model {model_choice!r}; select one of: {', '.join(runners)}."
        )
    return runner(image)

# Web UI: an uploaded image plus a model selector go in, the image returned
# by segment_image comes out.
demo = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(
            choices=["MiT-B0", "ResNet-50", "ResNet-101"],
            label="Select Model",
        ),
    ],
    outputs=gr.Image(type="pil"),
)

# Start the local Gradio server for the interface.
demo.launch()


Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.linear_fuse.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.1.proj.weight', 'decode_head.batch_norm.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.batch_norm.running_var', 'decode_head.classifier.bias', 'decode_head.batch_norm.weight', 'decode_head.batch_norm.running_mean']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.

* Running on local URL:  http://127.0.0.1:7869

To create a public link, set `share=True` in `launch()`.




Traceback (most recent call last):
  File "c:\Users\naang\anaconda3\envs\DL_test\lib\site-packages\gradio\queueing.py", line 624, in process_events
    response = await route_utils.call_process_api(
  File "c:\Users\naang\anaconda3\envs\DL_test\lib\site-packages\gradio\route_utils.py", line 323, in call_process_api
    output = await app.get_blocks().process_api(
  File "c:\Users\naang\anaconda3\envs\DL_test\lib\site-packages\gradio\blocks.py", line 2043, in process_api
    result = await self.call_function(
  File "c:\Users\naang\anaconda3\envs\DL_test\lib\site-packages\gradio\blocks.py", line 1590, in call_function
    prediction = await anyio.to_thread.run_sync(  # type: ignore
  File "c:\Users\naang\anaconda3\envs\DL_test\lib\site-packages\anyio\to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
  File "c:\Users\naang\anaconda3\envs\DL_test\lib\site-packages\anyio\_backends\_asyncio.py", line 2441, in run_sync_in_worker_thread
    re