In [1]:
# Load IPython's autoreload extension so edits to locally imported modules are picked up
# without restarting the kernel.
%load_ext autoreload

# Mode 2: reload all (non-excluded) modules before executing each cell.
%autoreload 2

# 1. Setup

In [2]:
import copy
import random

import albumentations as A
import numpy as np
import torch
from matplotlib import pyplot as plt

import mermaidseg.datasets.dataset
from mermaidseg.io import setup_config, get_parser, update_config_with_args

In [3]:
# Report visible CUDA devices and pick the compute device for the rest of the notebook.
for cuda_idx in range(torch.cuda.device_count()):
    print(f"CUDA Device {cuda_idx}: {torch.cuda.get_device_name(cuda_idx)}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed every RNG that can influence splitting, augmentation and training:
# Python's `random` and NumPy's global RNG are used by many augmentation libraries,
# and torch.manual_seed seeds the CPU and all CUDA devices.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# For bit-exact GPU reproducibility also enable the lines below. Note that
# cudnn.benchmark=True selects kernels non-deterministically, so it conflicts with deterministic=True.
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = True

CUDA Device 0: Tesla T4


In [4]:
from torch.utils.data import DataLoader, random_split
from mermaidseg.model.meta import MetaModel
from mermaidseg.model.eval import EvaluatorSemanticSegmentation
from mermaidseg.logger import Logger
from mermaidseg.model.train import train_model

2025-12-18 06:36:37.965271: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-12-18 06:36:37.979140: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1766039797.997780  128293 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1766039798.003532  128293 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-12-18 06:36:38.021516: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

## Config

In [5]:
# Load the experiment configuration: the model-specific YAML layered over the shared base YAML.
cfg = setup_config(
    config_path='../configs/linear-dinov3-concept-bottleneck.yaml',
    config_base_path='../configs/concept_mermaid.yaml',
)

# Command-line style overrides. In a notebook they are written out here instead of
# coming from sys.argv, then parsed with the project's argument parser.
args_input = "--run-name=dinov3-test-concept-bottleneck-run_2 --batch-size=4 --epochs=5 --log-epochs=1".split(" ")
args = get_parser().parse_args(args_input)
cfg = update_config_with_args(cfg, args)

# Independent deep copy of the resolved config. Later cells mutate `cfg` (e.g. popping keys
# from cfg.data), so this copy is what gets handed to the logger.
cfg_logger = copy.deepcopy(cfg)

# 2. Data

In [7]:
# Build one albumentations pipeline per split listed in cfg.augmentation. Each split's entry
# maps an albumentations class name to the kwargs used to construct it.
# Note: `split` stays bound to the last key after this loop.
transforms = {}
for split in cfg.augmentation:
    pipeline_spec = cfg.augmentation[split]
    pipeline_steps = [getattr(A, step_name)(**step_kwargs) for step_name, step_kwargs in pipeline_spec.items()]
    transforms[split] = A.Compose(pipeline_steps)

In [6]:
# Strip the non-constructor keys out of cfg.data so the remaining entries can be passed
# straight to the dataset class as **kwargs.
# NOTE: pop() mutates cfg.data. Re-running this cell yields the defaults (None / 4 / None).
dataset_name, batch_size, whitelist_sources = (
    cfg.data.pop(key, default)
    for key, default in (("name", None), ("batch_size", 4), ("whitelist_sources", None))
)

In [8]:
# Container for the dataset instance(s), keyed by split name.
dataset_dict = dict()

In [9]:
# Build the full labelled dataset with the *training* augmentation pipeline.
# Use the explicit "train" key rather than a leftover loop variable, so the chosen pipeline
# does not depend on the iteration order of cfg.augmentation.
dataset_dict["train"] = getattr(mermaidseg.datasets.dataset, dataset_name)(transform=transforms["train"], **cfg.data)

In [10]:
# Size of the full dataset before it is split into train/val/test below.
len(dataset_dict["train"])

8180

In [11]:
# for split in ["train", "val", "test"]:
#     dataset_dict[split] = getattr(mermaidseg.datasets.dataset, dataset_name)(transform = transforms[split], whitelist_sources=whitelist_sources[split], **cfg.data)

In [12]:
# for split in ["train", "val", "test"]:
#     if split in dataset_dict:
#         print(split, len(dataset_dict[split]), dataset_dict[split].num_classes)

In [13]:
# fig, ax = plt.subplots(figsize= (13,6), ncols = 2, nrows = 2)

# image, mask, annotations = dataset[0]
# print(image.shape, mask.shape)

# ax[0, 0].imshow(image.transpose(1,2,0))
# ax[0, 1].imshow(np.where(mask>0, mask, np.nan), cmap = "tab10", vmin=1, vmax=15)

# image, mask, annotations = dataset[0]

# ax[1, 0].imshow(image.transpose(1,2,0))
# ax[1, 1].imshow(np.where(mask>0, mask, np.nan), cmap = "tab10", vmin=1, vmax=15)

# plt.show()

In [11]:
# Split the full dataset 70/15/15 with a seeded generator so the split is reproducible.
full_dataset = dataset_dict["train"]
total_size = len(full_dataset)
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_size

generator = torch.Generator().manual_seed(seed)
train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size], generator=generator)

# Cap each split for faster experiments. min() keeps the Subset indices in range when a
# split is smaller than its cap (otherwise indexing would fail later, mid-epoch).
MAX_TRAIN_SAMPLES, MAX_VAL_SAMPLES, MAX_TEST_SAMPLES = 5000, 1000, 1000
train_dataset = torch.utils.data.Subset(train_dataset, range(min(MAX_TRAIN_SAMPLES, len(train_dataset))))
val_dataset = torch.utils.data.Subset(val_dataset, range(min(MAX_VAL_SAMPLES, len(val_dataset))))
test_dataset = torch.utils.data.Subset(test_dataset, range(min(MAX_TEST_SAMPLES, len(test_dataset))))
# NOTE(review): all three splits index into `full_dataset`, which was built with a single
# transform pipeline, so val/test samples get the same augmentations as training.
# TODO: build a separate eval-time dataset using the val/test transforms.
# train_dataset = torch.utils.data.Subset(dataset_dict["train"], range(3000))
# val_dataset = torch.utils.data.Subset(dataset_dict["val"], range(500))
# test_dataset = torch.utils.data.Subset(dataset_dict["test"], range(500))

collate_fn = full_dataset.collate_fn
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True, collate_fn=collate_fn)
# Evaluation loaders keep the final partial batch so every held-out sample is scored.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, drop_last=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, drop_last=False, collate_fn=collate_fn)
# train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True, collate_fn = dataset_dict["train"].collate_fn)
# val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, drop_last=True, collate_fn = dataset_dict["val"].collate_fn)
# test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, drop_last=True, collate_fn = dataset_dict["test"].collate_fn)

In [12]:
# Report how many batches each loader yields per pass.
for split_label, loader in (("training", train_loader), ("validation", val_loader), ("test", test_loader)):
    print(f"Number of {split_label} batches: {len(loader)}")

Number of training batches: 1250
Number of validation batches: 250
Number of test batches: 250


In [13]:
# Number of classes and number of concepts; both are used below to size the model and evaluator.
dataset_dict["train"].num_classes, dataset_dict["train"].num_concepts

(16, 20)

# 3. Model

In [14]:
# Build the model wrapper. The class/concept counts and the concept-to-label mappings are
# taken from the dataset; everything else comes from the resolved config.
base_dataset = dataset_dict["train"]
meta_model = MetaModel(
    run_name=cfg.run_name,
    num_classes=base_dataset.num_classes,
    num_concepts=base_dataset.num_concepts,
    device=device,
    model_kwargs=cfg.model,
    training_mode=cfg.training_mode,
    training_kwargs=cfg.training,
    concept_matrix=base_dataset.benthic_concept_matrix,
    conceptid2labelid=base_dataset.conceptid2labelid,
)

In [15]:
# Segmentation evaluator over the dataset's classes, with concept-level metrics enabled as well.
n_eval_classes = dataset_dict["train"].num_classes
evaluator = EvaluatorSemanticSegmentation(num_classes=n_eval_classes, device=device, calculate_concept_metrics=True)

# from torchmetrics.classification import F1Score, JaccardIndex

# metric_dict = {
#             "f1_class": F1Score(task="multiclass", average = "none", num_classes=3, ignore_index = 0).to(device),
#             "mean_iou": JaccardIndex(task="multiclass", num_classes=3, ignore_index = 0).to(device),
#             "iou": JaccardIndex(task="multiclass", num_classes=3, ignore_index = 0, average='none').to(device)
#             }

# evaluator = EvaluatorSemanticSegmentation(num_classes=dataset.num_concepts,
#                                             device=device,
#                                             metric_dict = metric_dict
#                                             )

In [19]:
# Route both the working config and the logged snapshot to the same experiment name.
cfg.logger.experiment_name = "mermaid"
cfg_logger.logger.experiment_name = "mermaid"

# Deliberate override of cfg.logger.log_checkpoint for this run.
# NOTE(review): presumably the checkpoint-logging period in epochs; confirm against Logger.
LOG_CHECKPOINT = 2

# `Logger` is imported in the imports cell near the top of the notebook.
logger = Logger(
    config=cfg_logger,
    meta_model=meta_model,
    log_epochs=cfg.logger.log_epochs,
    log_checkpoint=LOG_CHECKPOINT,
    checkpoint_dir=".",
    enable_mlflow=False,
    enable_wandb=True,
)

[34m[1mwandb[0m: Currently logged in as: [33mviktor-domazetoski[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Train the model, evaluating on `val_loader` each epoch and logging via `logger`.
# `train_model` is imported in the imports cell near the top of the notebook.
train_model(meta_model, evaluator, train_loader, val_loader, logger=logger)

EPOCH: 0


100%|██████████| 1250/1250 [20:18<00:00,  1.03it/s]


LOSS train 0.3576062069684267
TRAIN METRICS: {'accuracy': 0.9245154857635498, 'mean_iou': 0.7320473194122314, 'f1_concept': 0.8983958959579468}


100%|██████████| 250/250 [03:26<00:00,  1.21it/s]


LOSS valid 1.301874003648758
VALID METRICS: {'accuracy': 0.6980696320533752, 'mean_iou': 0.44179531931877136, 'f1_concept': 0.6968477368354797}
EPOCH: 1


 44%|████▍     | 547/1250 [08:54<11:27,  1.02it/s]


KeyboardInterrupt: 

In [75]:
# Evaluate the current model state on the validation loader.
final_val_results = evaluator.evaluate_model(
    meta_model=meta_model,
    dataloader=val_loader,
)

{'accuracy': MulticlassAccuracy(), 'mean_iou': MulticlassJaccardIndex()}
{'f1_concept': MulticlassF1Score()}
concept-bottleneck


100%|██████████| 250/250 [03:21<00:00,  1.24it/s]


In [76]:
# Display the metric dict returned by the evaluator (per-concept F1 is returned as an array).
final_val_results

{'accuracy': 0.700265645980835,
 'mean_iou': 0.4505966901779175,
 'f1_concept': array([0.       , 0.9828919, 0.6998219], dtype=float32)}