In [4]:
!pip install safetensors




In [1]:
from torchvision.models.segmentation import FCN_ResNet50_Weights
import gradio as gr
from PIL import Image
import torch
from transformers import AutoModelForSemanticSegmentation, AutoImageProcessor
import numpy as np
from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101
from torchvision import transforms
from safetensors.torch import load_file

checkpoint = "nvidia/mit-b0"
id2label = {i: str(i) for i in range(20)}
id2label[255] = "255"
label2id = {str(i): i for i in range(20)}
label2id["255"] = 255

# Load the model architecture
model = AutoModelForSemanticSegmentation.from_pretrained(
    checkpoint, id2label=id2label, label2id=label2id
)

# Load custom trained weights using Safetensors
from safetensors.torch import load_file
custom_checkpoint_path = "../data/vit/model.safetensors"  # Path to your trained weights
state_dict = load_file(custom_checkpoint_path)  # Load the Safetensors file

# Load the state dict into the model
model.load_state_dict(state_dict, strict=False)
model.eval()

# Define the image processor
image_processor = AutoImageProcessor.from_pretrained(checkpoint, do_reduce_labels=True)

def model_vit(image):
    inputs = image_processor(images=[image], return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        upsampled_logits = torch.nn.functional.interpolate(
            logits,
            size=image.size[::-1],  # (height, width)
            mode="bilinear",
            align_corners=False,
        )
        predicted_segmentation = upsampled_logits.argmax(dim=1).squeeze().cpu().numpy()
    return Image.fromarray((predicted_segmentation * 255).astype(np.uint8))

def model_cnn1(image):
    # Load the FCN-ResNet50 model
    model = fcn_resnet50(weights=None, num_classes=21)  # Use `weights=None`
    
    # Adjust the keys of the checkpoint to match the model
    checkpoint = torch.load('../data/cnn.pth', map_location=torch.device("cpu"))
    state_dict = {key.replace("model.", ""): value for key, value in checkpoint.items()}
    model.load_state_dict(state_dict)
    model.eval()
    
    # Define preprocessing steps
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    # Preprocess the input image
    input_tensor = preprocess(image).unsqueeze(0)  # Add batch dimension
    
    with torch.no_grad():
        # Forward pass through the model
        output = model(input_tensor)['out']
        output = torch.nn.functional.interpolate(
            output,
            size=image.size[::-1],  # (height, width)
            mode="bilinear",
            align_corners=False,
        )
        predicted_segmentation = output.argmax(dim=1).squeeze().cpu().numpy()
    
    return Image.fromarray((predicted_segmentation * 255).astype(np.uint8))


def model_cnn2(image):
    model = fcn_resnet101(weights=None, num_classes=21)
    # Adjust the keys of the checkpoint to match the model
    checkpoint = torch.load('../data/cnn_v2.pth', map_location=torch.device("cpu"))
    state_dict = {key.replace("model.", ""): value for key, value in checkpoint.items()}
    model.load_state_dict(state_dict)
    model.eval()
    
    # Define preprocessing steps
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    # Preprocess the input image
    input_tensor = preprocess(image).unsqueeze(0)  # Add batch dimension
    
    with torch.no_grad():
        # Forward pass through the model
        output = model(input_tensor)['out']
        output = torch.nn.functional.interpolate(
            output,
            size=image.size[::-1],  # (height, width)
            mode="bilinear",
            align_corners=False,
        )
        predicted_segmentation = output.argmax(dim=1).squeeze().cpu().numpy()
    
    return Image.fromarray((predicted_segmentation * 255).astype(np.uint8))

def segment_image(image, model_choice):
    if model_choice == "MiT-B0":
        return model_vit(image)
    elif model_choice == "ResNet-50":
        return model_cnn1(image)
    elif model_choice == "ResNet-101":
        return model_cnn2(image)

demo = gr.Interface(
    fn=segment_image,
    inputs=[gr.Image(type="pil"), gr.Dropdown(choices=["MiT-B0", "ResNet-50", "ResNet-101"], label="Select Model")],
    outputs=gr.Image(type="pil")
)

demo.launch()


Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


* Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




Using existing dataset file at: .gradio\flagged\dataset2.csv


  checkpoint = torch.load('../data/cnn_v2.pth', map_location=torch.device("cpu"))
