In [None]:
import os

import matplotlib.pyplot as plt
import torch
from lightning_fabric.utilities import cloud_io
from torch.nn import functional

from eva.core import metrics
from eva.core.models.wrappers import _utils
from eva.vision.callbacks.loggers.batch.segmentation import _draw_semantic_mask, _overlay_mask
from eva.vision.models import modules, wrappers
from eva.vision.models.networks import adapters
from eva.vision.models.networks.decoders.segmentation import ConvDecoderMS, PVTFormerDecoder

## Load models

In [2]:
# Hugging Face access token passed to the `pathology/mahmood_uni` registry models below.
# Read from the environment so a real token is never hard-coded in (or committed with)
# the notebook; falls back to "" when HF_TOKEN is unset.
hf_token = os.environ.get("HF_TOKEN", "")

In [3]:
checkpoint_path = "/mnt/localdisk/data/eva/models/consep/uni_convdecoder.ckpt"

# Model 1: UNI encoder (`pathology/mahmood_uni`) + multi-scale convolutional decoder.
# The module is built with `encoder=None` and the checkpoint is loaded *before* the
# encoder is attached, so (with `load_state_dict`'s default strict=True) the checkpoint
# must contain only the non-encoder weights.
model_1 = modules.SemanticSegmentationModule(
    encoder=None,
    decoder=ConvDecoderMS(in_features=1024, num_classes=5),
    criterion=None,
)

# map_location="cpu" (as in the model_2 cell) lets a GPU-saved checkpoint load on
# CPU-only machines instead of failing during deserialization.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model_1.load_state_dict(checkpoint["state_dict"])

model_1.encoder = wrappers.ModelFromRegistry(
    model_name="pathology/mahmood_uni",
    model_kwargs={"out_indices": 1, "hf_token": hf_token},
)

_ = model_1.eval()

In [None]:
checkpoint_path = "/mnt/localdisk/data/eva/models/consep/uni_adapter_pvtformer_last.ckpt"

# Model 2: UNI ViT backbone wrapped in a ViT-Adapter (`freeze_vit=True`, presumably
# keeping the backbone weights fixed) feeding a PVTFormer decoder with three
# 1024-channel inputs.
uni_backbone = wrappers.ModelFromRegistry(
    model_name="pathology/mahmood_uni",
    model_kwargs={"hf_token": hf_token},
)
adapter_encoder = adapters.ViTAdapter(
    vit_backbone=uni_backbone,
    deform_num_heads=8,
    freeze_vit=True,
)
pvt_decoder = PVTFormerDecoder(in_features=[1024, 1024, 1024], num_classes=5)

model_2 = modules.SemanticSegmentationModule(
    encoder=adapter_encoder,
    decoder=pvt_decoder,
    criterion=None,
)

# Load on CPU so the checkpoint does not require the device it was saved from.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model_2.load_state_dict(checkpoint["state_dict"])

_ = model_2.eval()

## Load Data

In [None]:
from eva.vision import datasets
from eva.vision.data.wsi.patching import samplers
from eva.vision.data import transforms

# CoNSeP validation split: up to 25 foreground patches per slide, resized/cropped to
# 224 px and normalized with the ImageNet mean/std.
consep_sampler = samplers.ForegroundGridSampler(max_samples=25)
consep_transforms = transforms.ResizeAndCrop(
    size=224,
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

dataset = datasets.CoNSeP(
    root="/mnt/localdisk/data/datasets/consep",
    split="val",
    sampler=consep_sampler,
    transforms=consep_transforms,
)
dataset.configure()

# Peek at the first sample's image tensor shape.
image, target, metadata = dataset[0]
image.shape

In [None]:
N_IMAGES = 10  # number of validation samples to visualize
MODELS = [model_1, model_2]  # one prediction column per model, in this order

# Linear-weighted generalized Dice over the 4 foreground classes (class 0 excluded),
# expecting class-index inputs.
dice_metric = metrics.GeneralizedDiceScore(
    num_classes=5,
    include_background=False,
    per_class=False,
    weight_type="linear",
    input_format="index",
)

def _preprocess_image(image: torch.Tensor):
    image = image - image.min()
    image = image / image.max()
    return image

def _one_hot(tensor: torch.Tensor, num_classes: int=5):
    return functional.one_hot(tensor, num_classes=5).permute(2, 0, 1)

# For each sample: rescaled input image | ground-truth mask | one prediction per model,
# with each prediction titled by its generalized Dice score against the target.
for i in range(N_IMAGES):
    image, target, metadata = dataset[i]
    # The models must see the dataset's normalized tensor; the min-max rescaled copy
    # is only for display (imshow expects values in [0, 1]).
    display_image = _preprocess_image(image)

    fig, axes = plt.subplots(1, 2 + len(MODELS), figsize=(10, 3))
    axes[0].imshow(display_image.permute(1, 2, 0))
    axes[0].set_title("Image")
    axes[1].imshow(_draw_semantic_mask(target).permute(1, 2, 0))
    axes[1].set_title("Ground truth")

    for j, model in enumerate(MODELS):
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            logits = model(image.unsqueeze(0), to_size=tuple(image.shape[-2:]))
        prediction = torch.argmax(logits, dim=1)  # (1, H, W) class indices

        axes[2 + j].imshow(_draw_semantic_mask(prediction).permute(1, 2, 0))

        # dice_metric was configured with input_format="index", so pass batched
        # class-index maps of shape (1, H, W), not one-hot tensors.
        dice_value = dice_metric(prediction, target.unsqueeze(0))
        axes[2 + j].set_title(f"Dice: {float(dice_value):.2f}")

    for panel in axes:
        panel.axis("off")
    # Render this figure now and release it, instead of accumulating N_IMAGES open figures.
    plt.show()