In [None]:
import torch
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import wandb

# Choose the compute device first; the model and its inputs are placed on it below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fetch the fine-tuned checkpoint from Weights & Biases. `download()` yields the
# local directory that holds the artifact's files (model.pth is read from it below).
api = wandb.Api()
artifact_path = 'nadjaflechner/Finetune_segformer_sweep/finetuned_segformer:v44'
artifact = api.artifact(artifact_path)
artifact_dir = artifact.download()

# Build the SegFormer-B5 architecture with a 2-class head. ignore_mismatched_sizes
# lets from_pretrained skip pretrained head weights whose shape no longer fits
# num_labels=2; those are replaced by the fine-tuned weights loaded next.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b5-finetuned-ade-640-640",
    num_labels=2,
    ignore_mismatched_sizes=True,
)
model = model.to(device)

# Load the fine-tuned weights (load_state_dict is strict by default, so every key
# must match). NOTE(review): torch.load unpickles the file -- only load artifacts
# you trust; consider weights_only=True if the installed torch supports it.
checkpoint_file = f"{artifact_dir}/model.pth"
state_dict = torch.load(checkpoint_file, map_location=device)
model.load_state_dict(state_dict)
model.eval()  # inference behaviour (e.g. dropout off)

# Preprocessing used at inference time.
# NOTE(review): these statistics look like 0-255 pixel-scale values, while
# SegformerImageProcessor rescales pixels by 1/255 by default (do_rescale=True)
# before normalizing -- confirm this matches the preprocessing used in training.
image_processor = SegformerImageProcessor(
    do_reduce_labels=False,
    image_mean=[74.90, 85.26, 80.06],
    image_std=[15.05, 13.88, 12.01],
)

def plot_segmentation(image_path, target_class=0, color=(4, 250, 7), alpha=0.5):
    """Segment one image with the fine-tuned SegFormer and show the mask as an overlay.

    Relies on the module-level ``model``, ``image_processor`` and ``device`` created
    in the setup code of this cell.

    Args:
        image_path: Path to the input image file.
        target_class: Label id whose pixels are painted (default 0, as in this cell's
            original version).
        color: RGB colour painted over ``target_class`` pixels.
        alpha: Weight of the colour layer in the blend; the image gets ``1 - alpha``.
            Pixels outside the mask are blended with black, so the whole image is
            dimmed by this factor (same as the original 0.5/0.5 blend).
    """
    # Force 3-channel RGB: RGBA, grayscale or palette images would otherwise yield an
    # array that does not broadcast against the (H, W, 3) colour layer below.
    # The context manager also releases the file handle PIL keeps open lazily.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device)

    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)

    # One label map per image, at the requested target size. PIL's size is
    # (width, height), hence the reversal to (height, width).
    predicted_segmentation_map = image_processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0].cpu().numpy()

    # Colour layer: target-class pixels get `color`, everything else stays black.
    color_seg = np.zeros((*predicted_segmentation_map.shape, 3), dtype=np.uint8)
    color_seg[predicted_segmentation_map == target_class] = np.asarray(color, dtype=np.uint8)
    # No BGR flip: PIL images and matplotlib's imshow are both RGB, so `color` is
    # displayed exactly as given.

    img = (np.asarray(image) * (1 - alpha) + color_seg * alpha).astype(np.uint8)

    fig, ax = plt.subplots(figsize=(15, 10))
    ax.imshow(img)
    ax.set_title("Segmentation Result")
    ax.axis("off")
    plt.show()

# Example usage -- replace the placeholder below with a real image file; the
# placeholder path presumably does not exist, so the call would fail as written.
image_path = "path/to/your/image.jpg"
plot_segmentation(image_path=image_path)

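# The next cell is the original Claude-suggested version of this code. It repeats the
# setup above and redefines `plot_segmentation`, highlighting class 1 instead of class 0.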

In [None]:
import torch
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import wandb

# NOTE: this repeats the previous cell's setup (artifact download + model build).
# Choose the compute device first; the model and its inputs are placed on it below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fetch the fine-tuned checkpoint from Weights & Biases. `download()` yields the
# local directory that holds the artifact's files (model.pth is read from it below).
api = wandb.Api()
artifact_path = 'nadjaflechner/Finetune_segformer_sweep/finetuned_segformer:v44'
artifact = api.artifact(artifact_path)
artifact_dir = artifact.download()

# Build the SegFormer-B5 architecture with a 2-class head. ignore_mismatched_sizes
# lets from_pretrained skip pretrained head weights whose shape no longer fits
# num_labels=2; those are replaced by the fine-tuned weights loaded next.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b5-finetuned-ade-640-640",
    num_labels=2,
    ignore_mismatched_sizes=True,
)
model = model.to(device)

# Load the fine-tuned weights (load_state_dict is strict by default, so every key
# must match). NOTE(review): torch.load unpickles the file -- only load artifacts
# you trust; consider weights_only=True if the installed torch supports it.
checkpoint_file = f"{artifact_dir}/model.pth"
state_dict = torch.load(checkpoint_file, map_location=device)
model.load_state_dict(state_dict)
model.eval()  # inference behaviour (e.g. dropout off)

# Preprocessing used at inference time.
# NOTE(review): these statistics look like 0-255 pixel-scale values, while
# SegformerImageProcessor rescales pixels by 1/255 by default (do_rescale=True)
# before normalizing -- confirm this matches the preprocessing used in training.
image_processor = SegformerImageProcessor(
    do_reduce_labels=False,
    image_mean=[74.90, 85.26, 80.06],
    image_std=[15.05, 13.88, 12.01],
)

def plot_segmentation(image_path, target_class=1, color=(4, 250, 7), alpha=0.5):
    """Segment one image with the fine-tuned SegFormer and show the mask as an overlay.

    Relies on the module-level ``model``, ``image_processor`` and ``device`` created
    in the setup code of this cell.

    Args:
        image_path: Path to the input image file.
        target_class: Label id whose pixels are painted (default 1, the class this
            version originally assumed to be the target).
        color: RGB colour painted over ``target_class`` pixels.
        alpha: Weight of the colour layer in the blend; the image gets ``1 - alpha``.
            Pixels outside the mask are blended with black, so the whole image is
            dimmed by this factor (same as the original 0.5/0.5 blend).
    """
    # Force 3-channel RGB: RGBA, grayscale or palette images would otherwise yield an
    # array that does not broadcast against the (H, W, 3) colour layer below.
    # The context manager also releases the file handle PIL keeps open lazily.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device)

    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)

    # The logits are not at the input image's resolution (per the HF docs SegFormer
    # predicts at a reduced scale), so upsample to the original (height, width)
    # before the per-pixel argmax. PIL's size is (width, height), hence the reversal.
    upsampled_logits = torch.nn.functional.interpolate(
        outputs.logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    # Take the single batch element explicitly: squeeze() would also drop a height
    # or width of 1 and break the 2-D indexing below.
    predicted_segmentation_map = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

    # Colour layer: target-class pixels get `color`, everything else stays black.
    color_seg = np.zeros((*predicted_segmentation_map.shape, 3), dtype=np.uint8)
    color_seg[predicted_segmentation_map == target_class] = np.asarray(color, dtype=np.uint8)

    img = (np.asarray(image) * (1 - alpha) + color_seg * alpha).astype(np.uint8)

    fig, ax = plt.subplots(figsize=(15, 10))
    ax.imshow(img)
    ax.set_title("Segmentation Result")
    ax.axis("off")
    plt.show()

# Example usage -- replace the placeholder below with a real image file; the
# placeholder path presumably does not exist, so the call would fail as written.
image_path = "path/to/your/image.jpg"
plot_segmentation(image_path=image_path)