# Imports

In [None]:
import json
import os
from functools import partial
from pathlib import Path
import datetime
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import ipywidgets as widgets
import pandas
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchinfo import summary
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim import AdamW
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping, TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
import ramses2

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

%matplotlib widget

# Class properties

In [2]:
# min cls index must be 1 (0 is always bg, it is not needed in the dict)
cls_to_idx = {
    "Ra":1,
    "Rc":2,
    "Rb01":3,
    "Rb02": 4,
    "Rcu01":5,
    "Ru01": 6,
    "Ru02": 7,
    "Ru04": 8,
    "Ru05": 9,
    "Ru06": 10,
    "X01": 11,
    "Coin":12,
    "X02":13,
    "X03":14,
    "Rg":15
    }

idx_to_cls = {v: k for k, v in cls_to_idx.items()}
ncls = len(cls_to_idx)
print(f"Number of classes: {len(cls_to_idx)}")
print(f"{cls_to_idx}")

Number of classes: 15
{'Ra': 1, 'Rc': 2, 'Rb01': 3, 'Rb02': 4, 'Rcu01': 5, 'Ru01': 6, 'Ru02': 7, 'Ru04': 8, 'Ru05': 9, 'Ru06': 10, 'X01': 11, 'Coin': 12, 'X02': 13, 'X03': 14, 'Rg': 15}


# Building Model

## Setting config

In [3]:
target_shape = (4096 // 2 , 6144 // 2)  # image are resized to input_shape
input_shape = target_shape + (3,)
STEM_STRIDE = 4
backbone_strides = [4, 8, 16, 32]  # backbone level C2 etc...
FPN_input_strides = [8, 16, 32] # # C3 -> C5
FPN_output_strides = [8, 16, 32, 64] # P3 -> P6
# strides of the FPN levels used in the SOLO head
fpn_fusion_level = 0 # Merge all FPN level to level FPN level 0
first_grid_level = 0 # first FPN level used in the heads
upscale_FPN_output = False  # if True, the FPN output is upsampled one time to produce the unified mask representation.
mask_stride = FPN_output_strides[fpn_fusion_level]//2 if upscale_FPN_output else FPN_output_strides[fpn_fusion_level]
mask_shape = np.array(target_shape) // mask_stride
BACKBONE_LEVELS = 4  # number of levels in backbone
offset_factor = 0.75  # reduce the size of the target box in SOLO class head (1 = full box, 0.5 = half box)

# Grid sizes of the solo heads
grid_sizes = [144, 72, 36, 18] #

H, W = target_shape
maxdim = max(target_shape)
mindim = min(target_shape)
grid_HW = [(int(round(gs * H / maxdim)), (int(round(gs * W/maxdim)))) for gs in grid_sizes]
print("Grid sizes (H, W):", grid_HW)
nloc = sum([nx * ny for nx, ny in grid_HW])
print("Number of locations:", nloc)

scale_ranges = [[1, 128]]   # for 1280 * 1920 and grid size 144

factor = 0.5
for i in range(len(FPN_output_strides) - 1):
    scale_ranges.append([int(factor * scale_ranges[-1][1]), scale_ranges[-1][1] * 2])
scale_ranges[-1][1] = int(np.max(target_shape))

print("Mask shape:", mask_shape, "stride", mask_stride)
print("Backbone strides", backbone_strides)
print("Scale Ranges:", scale_ranges)
print("Grid sizes", grid_sizes)
print("Approximate dmin and dmax in pixels for each level")
for i, (dmin, dmax) in enumerate(scale_ranges):
    level = 1 + np.log2(FPN_output_strides[i])
    print(
        f"FPN Stride {FPN_output_strides[i]} (FPN level {int(level)}) \
          {offset_factor * dmin * grid_sizes[i] / (target_shape[1] ):.2f}, \
          {offset_factor * dmax * grid_sizes[i] / (target_shape[1] ):.2f}"
    )

# Config for Convnext - pretrained model
activation = "GELU"
normalization = "LayerNorm2d"
normalization_kw = {"eps":1e-6}
bb = "convnextv2_femto"
backbone_feature_nodes = {
    "stages.0": "C2",
    "stages.1": "C3",
    "stages.2": "C4",
    "stages.3": "C5",
}
backbone_params = {"build_head":True} # we need the head to load the full state_dict
load_backbone = "../checkpoints/backbones/convnextv2_femto_1k_224_ema.pt"
backbone_source="local"
# dims=[48, 96, 192, 384]
connection_layers = {"C3": 96, "C4": 192, "C5": 384}


Grid sizes (H, W): [(96, 144), (48, 72), (24, 36), (12, 18)]
Number of locations: 18360
Mask shape: [256 384] stride 8
Backbone strides [4, 8, 16, 32]
Scale Ranges: [[1, 128], [64, 256], [128, 512], [256, 3072]]
Grid sizes [144, 72, 36, 18]
Approximate dmin and dmax in pixels for each level
FPN Stride 8 (FPN level 4)           0.04,           4.50
FPN Stride 16 (FPN level 5)           1.12,           4.50
FPN Stride 32 (FPN level 6)           1.12,           4.50
FPN Stride 64 (FPN level 7)           1.12,           13.50


In [4]:
from sympy import true

params = {
    # backbone
    "load_backbone": load_backbone,
    "backbone_source": backbone_source,
    "backbone_feature_nodes": backbone_feature_nodes,
    "backbone_params": backbone_params,
    "backbone": bb,
    # Specific params
    "ncls": ncls,
    "imshape": input_shape,
    "mask_stride": mask_stride,  # It must match with the backbone and the param 'output_level'
    # General layers params
    "activation": activation,
    "normalization": normalization,
    "normalization_kw": normalization_kw,
    # FPN connection_layer smust be a dict with the id of the layer as key and its output channels as value
    "connection_layers": connection_layers,  # backbone connection layers,
    "FPN_filters": 256,
    "extra_FPN_layers": 1,  # layers after P5. Strides must correspond to the number of FPN layers !
    # SOLO head
    "strides": FPN_output_strides,  # strides of FPN levels used in the heads [used to compute targets]
    "head_layers": 4,  # Number of repeats of head conv layers
    "head_filters": 256,
    "kernel_size": 1,
    "grid_sizes": grid_sizes,
    # SOLO MASK head
    "point_nms": False,
    "mask_mid_filters": 128,
    "mask_output_filters": 256,
    "geom_feat_convs": 4,  # number of convs in the geometry factor branch
    "geom_feats_filters": 128,
    "mask_output_level": 0,  # size of the unified mask (in level of the FPN)
    "FPN_output_upscaling": upscale_FPN_output, #if True, upscale the FPN fusion
    # For inference
    "use_binary_masks": True,
    "sigma_nms": 0.5,
    "min_area": 0,
    # target allocation
    "scale_ranges":scale_ranges,
    "offset_factor":offset_factor
}

config = ramses2.Config(**params)

# loss and training parameters
lossweights = [1.0, 1.0, 1.0]
max_pos_samples = 512  # limit the number of positive gt samples when computing loss to limit memory footprint
print(config)

load_backbone:../checkpoints/backbones/convnextv2_femto_1k_224_ema.pt
backbone_params:{'build_head': True}
backbone:convnextv2_femto
backbone_source:local
backbone_feature_nodes:{'stages.0': 'C2', 'stages.1': 'C3', 'stages.2': 'C4', 'stages.3': 'C5'}
ncls:15
imshape:(2048, 3072, 3)
mask_stride:8
activation:GELU
normalization:LayerNorm2d
normalization_kw:{'eps': 1e-06}
connection_layers:{'C3': 96, 'C4': 192, 'C5': 384}
FPN_filters:256
extra_FPN_layers:1
strides:[8, 16, 32, 64]
head_layers:4
head_filters:256
kernel_size:1
grid_sizes:[144, 72, 36, 18]
point_nms:False
mask_mid_filters:128
mask_output_filters:256
geom_feat_convs:4
geom_feats_filters:128
mask_output_level:0
FPN_output_upscaling:False
sigma_nms:0.5
min_area:0
use_binary_masks:True
lossweights:[1.0, 1.0, 1.0]
max_pos_samples:512
scale_ranges:[[1, 128], [64, 256], [128, 512], [256, 3072]]
offset_factor:0.75



## Create network from existing config

In [12]:
#Either using the previous config or loading an existing one
config_path = Path("../checkpoints/2048x3072")
config_name = Path("config.json")
with open(config_path / config_name, 'r') as jsonfile:
    params = json.load(jsonfile)
config = ramses2.Config(**params)
model = ramses2.RAMSESModel(config)
# config.save(os.path.join(config_path, config_name))

Building model with config: {'build_head': True}


# Load Weights

In [None]:
checkpoint_path = Path("../checkpoints/2048x3072/best-val-loss.pt")
state_dict = torch.load(checkpoint_path)
model.load_state_dict(state_dict, strict=False)

# Caution: if you want to load a *.ckpt from lightning use :
# state_dict_cleaned = {}
# state_dict_raw = torch.load(checkpoint_path)["state_dict"]
# for key, value in state_dict_raw.items():
#     # Get the model
#     if key.startswith("model."):
#         new_key = key[6:]  # On retire le pr√©fixe 'model.'
#         state_dict_cleaned[new_key] = value
#     else:
#         state_dict_cleaned[key] = value
# model.load_state_dict(state_dict_cleaned, strict=False)

<All keys matched successfully>

# Load DatasetManager

The DatasetManager is useful to generate training and valid sets. </br>
It has methods to filter the images based on various criterion. </br>
It contains annotations and train/valid filenames to build a torch dataset

In [None]:
ds_path = Path("/PATH/TO/DATASET/FILES")    # csv and json files generated using DatasetManager
ds_name = "15CLS_20250723-173206_MASS_ONLY" # basename o fthe files
# ds_name = "15CLS_20250723-173206"

dataloader = ramses2.DatasetManager.from_file(
    annfile=ds_path / Path(ds_name + ".csv"), filename=ds_path / Path(ds_name + ".json"),
)
if len(dataloader.train_basenames) > 0:
    print("\nTrain dataset stats")
    print("number of images in training set:", len(dataloader.train_basenames))
    print("number of unique images in training set:", np.unique(dataloader.train_basenames).size)
    print("number of instances per class")
    print(dataloader.train_class_counts)

if len(dataloader.valid_basenames) > 0:
    print("Valid dataset stats")
    print("number of images in valid set:", len(dataloader.valid_basenames))
    print("number of unique images in valid set:", np.unique(dataloader.valid_basenames).size)
    print("number of instances per class")
    print(dataloader.valid_class_counts)


Train dataset stats
number of images in training set: 2897
number of unique images in training set: 1750
number of instances per class
{'Coin': 549, 'Pl': 0, 'Ra': 13539, 'Rb01': 14270, 'Rb02': 6348, 'Rc': 19749, 'Rcu01': 1056, 'Rg': 271, 'Ru01': 19843, 'Ru02': 14246, 'Ru04': 10193, 'Ru05': 10876, 'Ru06': 4734, 'SHELLS': 0, 'UNKNOWN': 0, 'X01': 2506, 'X02': 1148, 'X03': 1492, 'X04': 0}
Valid dataset stats
number of images in valid set: 258
number of unique images in valid set: 258
number of instances per class
{'Coin': 0, 'Pl': 0, 'Ra': 500, 'Rb01': 500, 'Rb02': 499, 'Rc': 500, 'Rcu01': 0, 'Rg': 0, 'Ru01': 500, 'Ru02': 500, 'Ru04': 500, 'Ru05': 500, 'Ru06': 499, 'SHELLS': 0, 'UNKNOWN': 0, 'X01': 502, 'X02': 0, 'X03': 0, 'X04': 0}


# Create Torch Dataset
TorchDataset creates a torch train/test Datasets using the DatasetManager annotations and filenames.</br>
Some augmentations are implemented (brightness, noise, rotation) </br>
CutMix is implemented in the collate function used in the Torch Dataloader </br>

In [6]:
print(mask_stride, target_shape)
train_dataset = ramses2.torchDataset(
    dataloader.annotations,
    dataloader.train_basenames,
    input_shape=target_shape,
    mask_stride=mask_stride,
    cls_to_idx=cls_to_idx,
    transform=ramses2.TorchAugmentations(probability=[0.33, 0.33, 0.33], factor=0.3, seed=0),
    crop_to_aspect_ratio=True,
    random_resize_method=True,
    seed=1,
)

valid_dataset = ramses2.torchDataset(
    dataloader.annotations,
    dataloader.valid_basenames,
    input_shape=target_shape,
    mask_stride=mask_stride,
    cls_to_idx=cls_to_idx,
    transform=ramses2.TorchAugmentations(probability=[0.0, 0.0, 0.0], factor=0.0, seed=0),
    crop_to_aspect_ratio=True,
    random_resize_method=True,
    seed=1,
)

8 (2048, 3072)


# Setting training parameters

## Freeze layers

In [12]:
w1 = widgets.Checkbox(value=False, description='Freeze backbone')
w2 = widgets.Checkbox(value=False, description='Freeze FPN')
w3 = widgets.Checkbox(value=False, description='Freeze cls head')
w4 = widgets.Checkbox(value=False, description='Freeze kernel head & masks')
w5 = widgets.Checkbox(value=False, description='Freeze all density layers')
print("")
display(w1, w2, w3, w4, w5)




Checkbox(value=False, description='Freeze backbone')

Checkbox(value=False, description='Freeze FPN')

Checkbox(value=False, description='Freeze cls head')

Checkbox(value=False, description='Freeze kernel head & masks')

Checkbox(value=False, description='Freeze all density layers')

In [13]:
freeze_backbone = w1.value
freeze_FPN = w2.value
freeze_cls_head = w3.value
freeze_kernel_mask_head = w4.value
freeze_density_layers = w5.value

for  param in model.parameters():
    param.requires_grad = True

if freeze_backbone:
    for param in model.backbone.parameters():
        param.requires_grad = False

if freeze_FPN:
    for param in model.FPN.parameters():
        param.requires_grad = False

if freeze_cls_head:
    for param in model.shared_heads.class_head.parameters():
        param.requires_grad = False
    for param in model.shared_heads.class_logits.parameters():
        param.requires_grad = False

if freeze_kernel_mask_head:
    # kernel head
    for param in model.shared_heads.kernel_head.parameters():
        param.requires_grad = False
    for param in model.shared_heads.kernel_out.parameters():
        param.requires_grad = False
    # mask head
    for param in model.mask_head.level_pipelines.parameters():
        param.requires_grad = False
    for param in model.mask_head.mask_out.parameters():
        param.requires_grad = False

if freeze_density_layers:
    # class factor head
    for param in model.shared_heads.class_factor_head.parameters():
        param.requires_grad = False
    for param in model.shared_heads.class_factor_out.parameters():
        param.requires_grad = False
    # geometry factor head
    for param in model.mask_head.geom_convs.parameters():
        param.requires_grad = False
    for param in model.mask_head.geom_final.parameters():
        param.requires_grad = False

for name, param in model.named_parameters():
    if param.requires_grad:
        status = "TRAINABLE"
    else:
        status = "FROZEN"
        print(f"{name:<50} | {status}")

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
total_params = trainable_params + frozen_params
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Frozen parameters: {frozen_params:,}")

shared_heads.class_factor_head.0.ops.1.weight      | FROZEN
shared_heads.class_factor_head.0.ops.2.weight      | FROZEN
shared_heads.class_factor_head.0.ops.2.bias        | FROZEN
shared_heads.class_factor_head.1.ops.1.weight      | FROZEN
shared_heads.class_factor_head.1.ops.2.weight      | FROZEN
shared_heads.class_factor_head.1.ops.2.bias        | FROZEN
shared_heads.class_factor_out.ops.1.weight         | FROZEN
shared_heads.class_factor_out.ops.1.bias           | FROZEN
mask_head.geom_convs.0.ops.1.weight                | FROZEN
mask_head.geom_convs.0.ops.2.weight                | FROZEN
mask_head.geom_convs.0.ops.2.bias                  | FROZEN
mask_head.geom_convs.1.ops.1.weight                | FROZEN
mask_head.geom_convs.1.ops.2.weight                | FROZEN
mask_head.geom_convs.1.ops.2.bias                  | FROZEN
mask_head.geom_convs.2.ops.1.weight                | FROZEN
mask_head.geom_convs.2.ops.2.weight                | FROZEN
mask_head.geom_convs.2.ops.2.bias       

## TRAINING parameters
Set Callbacks, learning rates, dataloader...

In [None]:
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
nx, ny = target_shape
# set your own log directory !
logdir = Path(f"PATH/TO/YOUR_LOGDIR_{now}")

print("Run will be saved in", logdir)

callbacks = []

loggers = [TensorBoardLogger(logdir / Path("tb_logs"), name="RAMSES2"), CSVLogger(logdir, name="RAMSES2")]

callbacks.append(ModelCheckpoint(
    monitor="val_total_loss", filename=os.path.join(logdir, "best-val_loss"), mode="min", save_top_k=2, save_last=False
))

callbacks.append(ModelCheckpoint(
    monitor="val_precision", filename=os.path.join(logdir, "best-val-prec_loss"), mode="max", save_top_k=1, save_last=False
))

callbacks.append(ModelCheckpoint(
    monitor="train_total_loss",
    filename=os.path.join(logdir, "best-train_loss"),
    mode="min",
    save_top_k=2,
    save_last=False,
))

callbacks.append(LearningRateMonitor(logging_interval="epoch"))
# callbacks.append(EarlyStopping(monitor="val_density_loss", min_delta=0.00, patience=10, verbose=False, mode="min"))
callbacks.append(EarlyStopping(monitor="val_precision", min_delta=0.00, patience=10, verbose=False, mode="max"))
callbacks.append(TQDMProgressBar(leave=True))

train_config = ramses2.TrainConfig(
    lr=1e-5,
    mask_quality_weighting=True,
    losses={"cls": True, "seg": True, "mass": False}, seg_loss="dice", label_smoothing=0, cls_threshold=0.5
)

scheduler_config = [
    {
        "scheduler": ReduceLROnPlateau,
        "mode": "min",
        "patience": 3,
        "factor": 0.5,
        "monitor": "train_total_loss",
        "interval": "epoch",
        "frequency": 1,
    },
]

# Set a small lr for kernel and mask head to avoid divergence
kernel_params = set()
for name, param in model.named_parameters():
    if ("kernel_head" in name or "kernel_out" in name or name.startswith("mask_head")) and param.requires_grad:
        kernel_params.add(param)

bb_params = set()
for name, param in model.named_parameters():
    if "backbone" in name and param.requires_grad:
        bb_params.add(param)

other_params = [p for p in model.parameters() if p.requires_grad and ((p not in kernel_params) and (p not in bb_params))]

optimizer = torch.optim.AdamW(
    [
        {"params": other_params, "lr": 1e-4, "weight_decay":5e-3},
        {"params": list(kernel_params), "lr": 1e-5, "weight_decay":5e-3},
        {"params": list(bb_params), "lr": 1e-5, "weight_decay":5e-3},
    ]
)

# optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5, weight_decay=5e-4)

model.config.max_pos_samples = 768

lightning_model = ramses2.RAMSESLightning(
    config, train_config=train_config, optimizer=optimizer, scheduler_config=scheduler_config
)

# Here we use a lightning model
trainer = pl.Trainer(
    max_epochs=100,
    callbacks=callbacks,
    accelerator="gpu",
    logger=loggers,
    num_sanity_val_steps=0,
    gradient_clip_val=1,
)
batch_size = 8

# Use collate_fn_cutmix only for instance segmentation training - not for mass prediction
collate_fn = partial(ramses2.collate_fn_cutmix, p=0.5, mask_stride=config.mask_stride, max_patches=3, min_patch_ratio=0.1, max_patch_ratio=0.25)
# just use collate_fn whent training mass prediction

train_dataloader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=collate_fn, pin_memory=True
)
valid_dataloader = DataLoader(
    valid_dataset, batch_size=batch_size, shuffle=False, num_workers=8, collate_fn=ramses2.collate_fn, pin_memory=True
)

ntrain = len(train_dataset)
nvalid = len(valid_dataset)

Run will be saved in /media/jlux/SSD1/Projets/ARCADE/RAMSES2/runs/1024x1536/SOLO__newgrid_20251003-133303


NameError: name 'model' is not defined

## Training


In [None]:
os.makedirs(logdir, exist_ok=True)
print("save folder", logdir)
torch.set_float32_matmul_precision("high")
# Here we save the database in the config file saved in the logdir
try:
    config.database = os.path.join(ds_path, ds_name)
    config.save(os.path.join(logdir, "config.json"))
except AttributeError:
    with open(Path(logdir) / Path("config.json"), "w") as jsonfile:
        json.dump(config, jsonfile)

print("Losses:", train_config.losses)

trainer.fit(lightning_model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)
# trainer.fit(
#     lightning_model,
#     train_dataloaders=train_dataloader,
#     val_dataloaders=valid_dataloader,
#     ckpt_path=Path(
#         "/media/jlux/SSD1/Projets/ARCADE/RAMSES2/runs/logs/run_20250711-221033/best-train_loss-epoch=31.ckpt"
#     ),
# )









# View predictions

In [None]:
iterator = iter(valid_dataset)
# model = model.to(torch.device('cuda:0'))
model.eval()
print("")
results_dataframe = pandas.DataFrame(columns=['Filename', 'PredictedMass', 'GTMass', 'RE'])
# results_dataframe.set_index('Filename', inplace=True)




In [None]:
%matplotlib widget

for i in range(len(valid_dataset)):
#inputs = next(iterator)
    inputs = valid_dataset[i]
    # print(inputs)
    fn = inputs["filename"]
    gt_img = inputs["image"]
    gt_mask_img = inputs["masks"]
    gt_cls_ids = inputs["category_id"]
    gt_labels = inputs["label"]
    gt_mass = inputs["mass"]
    mask_res = inputs["res"][0]
    gt_cls_labels = [idx_to_cls[id] for id in list(gt_cls_ids.numpy())]
    gt_img = inputs["image"]
    nx, ny = gt_img.shape[-2:]

    print("Propcessing Image", fn, gt_img.shape)
    folder = dataloader.annotations.loc[dataloader.annotations["baseimg"] == fn]["folder"].iloc[0]
    print("folder", folder)

    res_ini = dataloader.annotations.loc[dataloader.annotations["baseimg"] == fn]["res"].to_numpy()[0]
    # The image may have been cropped, so we need to adjust the resolution
    im = Image.open(os.path.join(folder, "images", fn + ".jpg"))
    width, height = im.size
    res_input = res_ini * gt_img.shape[1] / height

    print(f"original resolution {res_ini} input resolution: {res_input}  mask resolution: {mask_res} ({res_input / config.mask_stride}) ")
    model.config.sigma_nms = 0.5
    print("number of gt instances", gt_cls_ids.shape[-1])
    # gt_img=gt_img.to(torch.device('cuda:0'))
    with torch.no_grad():
        results = model(gt_img.unsqueeze(0),
                    training=False,
                    cls_threshold=0.5,
                    nms_threshold=0.3, # iou threshold or cls threshold in MatrixNMX
                    mask_threshold=0.5,
                    max_detections=768,
                    scale_by_mask_scores=False,
                    min_area=32,
                    nms_mode="greedy")[0]

    # [{"masks": seg_preds, "scores": scores, "cls_labels": cls_labels_pos, "masses": masses}]

    processed_masks = ramses2.decode_predictions(results['masks'], results["scores"], threshold=0.5, by_mask_scores=False)
    pred_masks = results['masks'].detach().cpu().numpy()
    pred_masses = results['masses'].detach().cpu().numpy()
    pred_cls_ids = results["cls_labels"].detach().cpu().numpy()
    pred_scores = results["scores"].detach().cpu().numpy()

    print("number of detected instances ", pred_masks.shape[0])
    print("scores", pred_scores)

    plt.close()

    # gt_mass is not always defined
    finite_indexes = np.where(np.isfinite(gt_mass.numpy()))
    # unormalize mass predictions
    gt_mass = gt_mass[finite_indexes] / (10*mask_res)**2
    pred_masses = pred_masses / (10*mask_res)**2

    print("GT total mass", gt_mass.sum().item())
    print("PRED total masses", pred_masses.sum().item())
    # print("GT masses", np.sort(gt_mass))
    # print("PRED masses", np.sort(pred_masses))
    pred_cls_labels = [idx_to_cls[id + 1] for id in list(pred_cls_ids)]
    print(f"GT class labels \n{gt_cls_labels}")
    print(f"PRED class labels \n{pred_cls_labels}")

    new_row = [{'Filename': fn,
            'PredictedMass': pred_masses.sum().item(),
            'GTMass': gt_mass.sum().item(),
            'RE': abs(pred_masses.sum().item() - gt_mass.sum().item()) / gt_mass.sum().item() if gt_mass.sum().item() > 0 else np.nan}]

    results_dataframe = pandas.concat([results_dataframe, pandas.DataFrame(new_row)])


    image = gt_img.permute(1,2,0).cpu().numpy()

    fig = ramses2.plot_instances(
        image,
        processed_masks.cpu().detach().numpy(),
        cls_ids=pred_cls_labels,
        cls_scores=pred_scores,
        alpha=0.4,
        fontsize=3,
        fontcolor="black",
        draw_boundaries=True,
        boundary_mode="inner",
        dpi=200,
        show=False,
        x_offset=20,
        y_offset=10,
    )

    plt.show()
    plt.savefig(f"/media/jlux/SSD1/Projets/ARCADE/RAMSES2/segmentations/SEG_{fn}.png", dpi=400)

In [None]:
results_dataframe.to_excel("mass_res_per_img.xlsx")