In [1]:
%load_ext autoreload
%autoreload 2
%load_ext yamlmagic

In [2]:
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt

import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Helpful plotting functions
import os
os.environ['NEURITE_BACKEND'] = 'pytorch'
import neurite as ne

In [3]:
import sys, os
sys.path.append(os.path.abspath('..'))

In [4]:
import universeg
import scribbleprompt
import multiverseg

# Data Loader

In [5]:
from multiverseg.datasets.wbc import WBCDataset
import itertools

d_support = WBCDataset('JTSC', split='support', label='cytoplasm')
d_test = WBCDataset('JTSC', split='test', label='cytoplasm')

n_support = 10

support_images, support_labels = zip(*itertools.islice(d_support, n_support))
support_images = torch.stack(support_images).to(device)
support_labels = torch.stack(support_labels).to(device)
print(support_images.shape, support_labels.shape)

torch.Size([10, 1, 128, 128]) torch.Size([10, 1, 128, 128])


## Train

In [6]:
# --- Setup ---
%load_ext autoreload
%autoreload 2

import os
import torch
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

os.environ['NEURITE_BACKEND'] = 'pytorch'

import universeg
import scribbleprompt
import multiverseg
from multiverseg.datasets.wbc import WBCDataset
from multiverseg.models.sp_mvs import MultiverSeg
from scribbleprompt.interactions.prompt_generator import FlexiblePromptEmbed
from pylot.experiment.util import eval_config
import neurite as ne

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Using device: cuda


In [7]:
# Split datasets
d_support = WBCDataset('JTSC', split='support', label='cytoplasm')
# Load once from the available split
d_full = WBCDataset('JTSC', split='test', label='cytoplasm')

# Split manually (e.g., 80% train, 20% val)
n_total = len(d_full)
n_train = int(0.8 * n_total)
indices = np.random.permutation(n_total)

train_indices = indices[:n_train]
val_indices = indices[n_train:]

# Build subsets
d_train = [d_full[i] for i in train_indices]
d_val = [d_full[i] for i in val_indices]

print(f"Train set: {len(d_train)} | Val set: {len(d_val)}")


n_support = 10
support_images, support_labels = zip(*[d_support[i] for i in range(n_support)])
support_images = torch.stack(support_images).to(device)
support_labels = torch.stack(support_labels).to(device)
print("Support:", support_images.shape, support_labels.shape)

Train set: 72 | Val set: 18
Support: torch.Size([10, 1, 128, 128]) torch.Size([10, 1, 128, 128])


In [8]:
model = MultiverSeg(version="v1",use_ijepa=True,device=device
                    )
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.BCEWithLogitsLoss()


[INFO] Using IJepa backbone for MultiverSegNet
[Override] Using I-JEPA backbone for target encoder


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [9]:
%load_ext yamlmagic

The yamlmagic extension is already loaded. To reload it, use:
  %reload_ext yamlmagic


In [10]:
%%yaml random_warm_start
_class: scribbleprompt.interactions.prompt_generator.FlexiblePromptEmbed
click_embed:
  _fn: scribbleprompt.interactions.embed.click_onehot
init_pos_click_generators:
    - _class: scribbleprompt.interactions.clicks.RandomClick
      train: False
init_neg_click_generators:
    - _class: scribbleprompt.interactions.clicks.RandomClick
      train: False
correction_click_generators:
  - _class: scribbleprompt.interactions.clicks.ComponentCenterClick
    train: False
init_pos_click: 3
init_neg_click: 3
correction_clicks: 1
prob_bbox: 0.0
prob_click: 1.0
from_logits: True

<IPython.core.display.Javascript object>

In [11]:
prompt_generator = eval_config(random_warm_start)

In [12]:
from torch.utils.data import DataLoader

batch_size = 1  # start small; can increase if memory allows
train_loader = DataLoader(d_train, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

In [None]:
from scribbleprompt.models.unet import prepare_inputs

num_epochs = 100
train_losses = []

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        images, labels = images.to(device), labels.to(device)

        # Generate clicks (one prompt per image in batch)
        prompts = prompt_generator(images, labels)
        clicks = {k: prompts.get(k) for k in ['point_coords', 'point_labels']}

        # Build 5-channel input tensor
        prepared = prepare_inputs({
            'img': images,
            'point_coords': clicks['point_coords'],
            'point_labels': clicks['point_labels'],
            'scribbles': None,
            'box': None,
            'mask_input': None
        }).float().to(device)

        # Forward pass
        logits = model.forward(prepared,
                               context_images=support_images[None],
                               context_labels=support_labels[None])

        # Compute loss
        target = labels.float()  # [B, 1, H, W]
        # if target.ndim == 5:
        #     target = target.squeeze(1)  # now [B, 1, H, W]

        loss = criterion(logits, target)

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Free memory after each batch
        del images, labels, prepared, logits, target, loss
        torch.cuda.empty_cache()

    avg_loss = running_loss / len(train_loader)
    train_losses.append(avg_loss)
    print(f"Epoch {epoch+1}: avg loss = {avg_loss:.4f}")

    # Optional visualization
    if (epoch+1) % 5 == 0:
        with torch.no_grad():
            pred = torch.sigmoid(logits).cpu()
            ne.plot.slices([images[0].cpu(), labels[0].cpu(), pred[0] > 0.5],
                           width=10, titles=['Image','Label','Prediction'])

Epoch 1/100:   0%|          | 0/72 [00:00<?, ?it/s]

ValueError: Target size (torch.Size([1, 1, 1, 128, 128])) must be the same as input size (torch.Size([1, 1, 128, 128]))