# Tutorial for Pancakes

This notebook provides inference examples for _Pancakes: Consistent Multi-Protocol Image Segmentation Across Biomedical Domains_, accepted at NeurIPS 2025.

## Imports

In [None]:
import math
import itertools

from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt

import einops as E
import torch
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

# Imports for the interactive session
import neurite as ne
import ipywidgets as widgets
from ipywidgets import interact, interactive, VBox, HBox
import matplotlib.pyplot as plt
import numpy as np

## Load Model

In [None]:
import pancakes

# Build the pretrained Pancakes network, move it to the selected device and
# switch it to inference mode (the cell displays the returned module).
model = pancakes.model.pancakesmodel(pretrained=True)
model = model.to(device)
model.eval()

## Load data

In [None]:
from pancakes.data.wbc import WBCDataset
from torch.utils.data import DataLoader

num_sets = 3
M = 8  # number of label maps to predict
K = 8  # maximum number of labels per map
label_cmap = 'turbo'
SEED = 0  # seeds the shuffled loader so re-runs draw the same example images

dataset = WBCDataset(dataset="CV", label=None)
# A seeded generator keeps `shuffle=True` reproducible across Restart & Run All.
dataloader = DataLoader(
    dataset,
    batch_size=num_sets,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

## Predictions

In [None]:
# Take the first batch from the loader; `y` is only forwarded to plotpreds below.
first_batch = next(iter(dataloader))
x, y = first_batch
x = x.to(device)

# Inference only, so autograd tracking is disabled. `x[None]` adds a leading
# axis of size 1 in front of the batch (the model's expected layout — TODO confirm).
with torch.no_grad():
    logits = model(x[None], M=M, K=K)

In [None]:
# Turn the per-label scores into an integer label per pixel by taking the
# argmax over axis 3 (the K axis). Presumably the logits are laid out as
# (1, num_sets, M, K, H, W), where the leading 1 comes from `x[None]` — TODO confirm.
# Only that leading axis is squeezed. A bare .squeeze() would also drop the
# set or map axis whenever num_sets == 1 or M == 1, and downstream indexing
# of `predictions[i]` expects a (num_sets, M, H, W) tensor.
predictions = torch.argmax(logits, dim=3).squeeze(0).cpu()

## Visualize predictions

In [None]:
from pancakes.vis.vis import plotpreds

# Hand the images, predicted label maps and `y` to the package's plotting helper.
plotpreds(
    x=x,
    y=y,
    predictions=predictions,
    M=M,
    num_sets=num_sets,
    label_cmap=label_cmap,
)

## Interactive session: exploring M and K
Suitable values of M (the number of label maps) and K (the maximum number of labels per map) depend on the application, so choosing them is left to the user. To see how they affect the predictions, run the cell below and adjust the M and K sliders.

In [None]:
# Define the interactive update function
def update(M=8, K=5):
    """Re-run the model with the given M and K and plot every set's label maps.

    Called by ``widgets.interactive_output`` whenever a slider moves. Reads the
    notebook globals ``model``, ``x`` and ``label_cmap``, and draws one grid row
    per set: ``x[i, 0]`` followed by that set's M predicted maps.

    Args:
        M: number of label maps to predict.
        K: maximum number of labels per map.
    """
    num_sets = x.shape[0]
    with torch.no_grad():
        logits = model(x[None], M=M, K=K)
    # Argmax over axis 3 (K). Squeeze only the leading axis added by x[None];
    # presumably the logits are (1, num_sets, M, K, H, W) — TODO confirm.
    # A bare .squeeze() would also drop the set/map axes when num_sets == 1
    # or M == 1 (the M slider goes down to 1). Then predictions[i] would be a
    # single 2-D map, and unpacking it would emit image rows instead of maps.
    predictions = torch.argmax(logits, dim=3).squeeze(0).cpu()

    img2plot = []
    cmaps = []
    titles = []
    for i in range(num_sets):
        img2plot.extend([x[i, 0].cpu().numpy(), *predictions[i].numpy()])
        cmaps.extend(['gray', *([label_cmap] * M)])
        titles.extend([f'Image {i+1}', *[f'Map {m+1}' for m in range(M)]])

    ne.plot.slices(img2plot, cmaps=cmaps, grid=(num_sets, M + 1), titles=titles)



# Slider controls: M = number of label maps, K = maximum labels per map.
M_slider = widgets.IntSlider(description='M', value=8, min=1, max=16, step=1)
K_slider = widgets.IntSlider(description='K', value=20, min=3, max=30, step=1)

# Stack the sliders vertically and wire them to `update`'s M and K arguments.
ui = VBox([M_slider, K_slider])
out = widgets.interactive_output(update, dict(M=M_slider, K=K_slider))

display(ui, out)

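## Interactive session: location-specific labels
To focus on a specific region, choose a pixel with the c1 (row) and c2 (column) sliders. Each map then keeps only the labels found at that pixel (in any of the displayed images), and a small square marks the selected location.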

In [None]:
# Define the interactive update function
def _keep_labels_at(pred_plot, c1, c2):
    """Hide, map by map, every label that does not occur at pixel (c1, c2).

    Args:
        pred_plot: integer array laid out as (num_sets, M, H, W); modified in place.
        c1: row index of the query pixel.
        c2: column index of the query pixel.

    Returns:
        ``pred_plot``, where each map keeps only the labels found at (c1, c2)
        in any of the sets, and every other pixel is 0.
    """
    # 0 doubles as the "hidden" value. A selected label 0 is therefore shown
    # with one stand-in value that is above every real label.
    zero_stand_in = pred_plot.max() + 1
    for m in range(pred_plot.shape[1]):
        maps_m = pred_plot[:, m]
        selected = np.unique(maps_m[:, c1, c2])
        # Decide what to keep from this map's original labels, before any
        # remapping. Remap only this map, so a selected label 0 survives and
        # the hidden background of the other maps is never turned into a label.
        keep = np.isin(maps_m, selected)
        if 0 in selected:
            maps_m = np.where(maps_m == 0, zero_stand_in, maps_m)
        pred_plot[:, m] = np.where(keep, maps_m, 0)
    return pred_plot


def update(M=8, K=5, c1=64, c2=64):
    """Predict M label maps and show only the labels present at pixel (c1, c2).

    Called by ``widgets.interactive_output`` whenever a slider moves. Reads the
    notebook globals ``model``, ``x`` and ``label_cmap``. Each grid row shows
    ``x[i, 0]`` followed by that set's M masked maps, with a small square
    marking the query pixel.

    Args:
        M: number of label maps to predict.
        K: maximum number of labels per map.
        c1: row of the query pixel.
        c2: column of the query pixel.
    """
    num_sets = x.shape[0]
    with torch.no_grad():
        logits = model(x[None], M=M, K=K)
    # Argmax over axis 3 (K). Squeeze only the leading axis added by x[None];
    # presumably the logits are (1, num_sets, M, K, H, W) — TODO confirm.
    # A bare .squeeze() would also drop the set/map axes when num_sets == 1 or M == 1.
    predictions = torch.argmax(logits, dim=3).squeeze(0).cpu()

    # Work on a copy so masking never writes through to the tensor's storage.
    pred_plot = _keep_labels_at(predictions.numpy().copy(), c1, c2)

    # Mark the query pixel with a square valued above every label. The lower
    # bounds are clamped at 0: a negative start (e.g. c1 - 2 == -1) would wrap
    # around and give an empty slice, hiding the marker near the top/left edges.
    row_start, col_start = max(c1 - 2, 0), max(c2 - 2, 0)
    pred_plot[:, :, row_start:c1 + 2, col_start:c2 + 2] = pred_plot.max() + 10

    img2plot = []
    cmaps = []
    titles = []
    for i in range(num_sets):
        img2plot.extend([x[i, 0].cpu().numpy(), *pred_plot[i]])
        cmaps.extend(['gray', *([label_cmap] * M)])
        titles.extend([f'Image {i+1}', *[f'Map {m+1}' for m in range(M)]])

    ne.plot.slices(img2plot, cmaps=cmaps, grid=(num_sets, M + 1), titles=titles)


# Create sliders for M, K and the query location (c1 = row, c2 = column).
# The location ranges are derived from the loaded images rather than a
# hard-coded 0..127. That constant assumed 128x128 inputs and let c1/c2 index
# out of bounds for smaller images. Assumes the predicted maps share the
# images' H x W — TODO confirm. For 128x128 inputs this reproduces value=64, max=127.
height, width = x.shape[-2], x.shape[-1]
M_slider = widgets.IntSlider(value=8, min=1, max=16, step=1, description='M')
K_slider = widgets.IntSlider(value=20, min=3, max=30, step=1, description='K')
c1_slider = widgets.IntSlider(value=height // 2, min=0, max=height - 1, step=1, description='c1')
c2_slider = widgets.IntSlider(value=width // 2, min=0, max=width - 1, step=1, description='c2')

# Combine everything into a UI; `update` is called again whenever a slider changes.
ui = VBox([M_slider, K_slider, c1_slider, c2_slider])
out = widgets.interactive_output(update, {'M': M_slider, 'K': K_slider, 'c1': c1_slider, 'c2': c2_slider})

display(ui, out)