In [21]:
from PIL import Image
import torch
import numpy as np
import hydra
from hydra import compose
from hydra.core.global_hydra import GlobalHydra

# Initialize Hydra and load configuration
GlobalHydra.instance().clear()
hydra.initialize(config_path="configs", job_name="example_prediction")

model_name = "boltzformer_focal-l" # other options: "boltzformer_hiera-s", "boltzformer_hiera-bp"
cfg = compose(config_name=model_name)

# Instantiate the model from the configuration
model = hydra.utils.instantiate(cfg, _convert_="object")
if model_name == "boltzformer_focal-l":
    # initialize the FocalNet backbone with SEEM pretrained weights for easier finetuning
    model.load_pretrained("PretrainedModels/xdecoder_focall_last_oq101.pt")

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  hydra.initialize(config_path="configs", job_name="example_prediction")


Checkpoint loaded successfully!


### Example Input and Output

In [22]:
# Load an example input image
image = Image.open("examples/01_CT_lung.png").convert("RGB")
image_array = np.array(image, dtype=np.float32)
image_tensor = torch.from_numpy(image_array).permute(2, 0, 1)
print(image_tensor.shape, image_tensor.max(), image_tensor.min())

# ground truth mask
mask = Image.open("examples/01_CT_lung_nodule.png").convert("L")
mask_array = np.array(mask, dtype=np.float32)
mask_tensor = torch.from_numpy(1.0*(mask_array>0)).unsqueeze(0)
print(mask_tensor.shape, mask_tensor.max(), mask_tensor.min())

# example text prompt
text = 'lung nodule in CT scan'

torch.Size([3, 1024, 1024]) tensor(255.) tensor(0.)
torch.Size([1, 1024, 1024]) tensor(1., dtype=torch.float64) tensor(0., dtype=torch.float64)


In [23]:
from utils.loss import BoltzFormerLoss

loss_fn = BoltzFormerLoss()

# example input
input = {
    "image": image_tensor.unsqueeze(0),
    "text": [text]
}

# model forward pass
output = model(input, mode="train")

loss = loss_fn(output['predictions']['pred_gmasks'], mask_tensor.unsqueeze(0))
print("Loss:", loss.item())

Loss: 1.7780694066315534
