In [1]:
# Load IPython's autoreload extension so that edits to imported modules
# can be picked up without restarting the kernel.
%load_ext autoreload

# Autoreload mode 2: reload all modules (except those excluded via %aimport)
# each time before code is executed.
%autoreload 2

In [2]:
import os
# TensorFlow-related environment switches, applied before the heavier imports
# in the cells below.
os.environ.update(
    {
        "TF_ENABLE_ONEDNN_OPTS": "0",  # turn the oneDNN optimisations off
        "TF_CPP_MIN_LOG_LEVEL": "2",  # suppress TF warnings
    }
)

In [3]:
import os
# Make CUDA kernel launches synchronous so device-side errors surface at the
# call that caused them.
# NOTE(review): this is normally a debugging switch and tends to slow GPU work;
# confirm it is still wanted for full training runs.
os.environ.update(CUDA_LAUNCH_BLOCKING="1")

# 1. Setup

In [4]:
import albumentations as A
import mermaidseg.datasets.dataset
import numpy as np
from mermaidseg.io import setup_config, get_parser, update_config_with_args
import copy
import torch
from matplotlib import pyplot as plt

In [5]:
import random  # needed only for seeding below

# List every CUDA device PyTorch can see (prints nothing on a CPU-only machine).
device_count = torch.cuda.device_count()
for i in range(device_count):
    print(f"CUDA Device {i}: {torch.cuda.get_device_name(i)}")

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed every global RNG this notebook may draw from, not just torch, so that
# data augmentation / sampling that relies on Python's `random` or NumPy's
# global generator is reproducible as well.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Seed all CUDA devices rather than only the current one (safe without CUDA).
torch.cuda.manual_seed_all(seed)

# cuDNN flags are left at their defaults. For stricter run-to-run determinism
# consider `torch.backends.cudnn.deterministic = True` together with
# `torch.backends.cudnn.benchmark = False` (this may cost speed).

CUDA Device 0: Tesla T4


In [6]:
from torch.utils.data import DataLoader, random_split
from mermaidseg.model.meta import MetaModel
from mermaidseg.model.eval import EvaluatorSemanticSegmentation
from mermaidseg.logger import Logger
from mermaidseg.model.train import train_model

2025-12-14 14:47:47.284630: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765723667.302911  137381 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765723667.308558  137381 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


## Config

In [7]:
# Build the run configuration from the experiment YAML and the shared base
# YAML (both paths are relative to the notebook's working directory).
cfg = setup_config(
    config_base_path="../configs/base_mermaid.yaml",
    config_path="../configs/linear-dinov3-base.yaml",
)

# A script would receive these overrides on the command line; in the notebook
# they are spelled out explicitly and fed through the same argument parser.
args_input = "--run-name=mermaid_base_run_dinov3 --log-epochs=1".split(" ")
parser = get_parser()
args = parser.parse_args(args_input)
cfg = update_config_with_args(cfg, args)

# Independent snapshot of the config for the logger, taken before later cells
# start popping entries out of `cfg`.
cfg_logger = copy.deepcopy(cfg)

# 2. Data

In [8]:
# One albumentations pipeline per split listed under `cfg.augmentation`.
# Each entry maps an albumentations class name to its constructor kwargs.
# NOTE: the loop variable `split` stays bound at notebook scope after this cell;
# keep the name stable in case later cells reference it.
transforms = {}
for split in cfg.augmentation:
    split_ops = []
    for transform_name, transform_params in cfg.augmentation[split].items():
        transform_cls = getattr(A, transform_name)
        split_ops.append(transform_cls(**transform_params))
    transforms[split] = A.Compose(split_ops)

In [9]:
# Pull the loader-level options out of `cfg.data`; what remains in `cfg.data`
# is forwarded to the dataset constructor as keyword arguments further down.
# NOTE: `pop` mutates `cfg.data`, so re-running this cell without rebuilding
# `cfg` yields the fallback values (None / 8 / None) instead of the configured ones.
# `whitelist_sources` is extracted here but not used in the cells visible below.
dataset_name, batch_size, whitelist_sources = (
    cfg.data.pop(option, fallback)
    for option, fallback in (
        ("name", None),
        ("batch_size", 8),
        ("whitelist_sources", None),
    )
)

In [27]:
# Instantiate the dataset class named in the config. The transform is selected
# by an explicit key instead of the `split` loop variable left over from the
# transform-building cell, which silently pointed at whichever split happened
# to be iterated last.
# NOTE(review): the train/val/test subsets are later carved out of this single
# dataset object, so they all share this transform — confirm that is intended.
dataset_dict = {}
dataset_cls = getattr(mermaidseg.datasets.dataset, dataset_name)
dataset_dict["train"] = dataset_cls(transform=transforms["train"], **cfg.data)
print(len(dataset_dict["train"]))

# 70 / 15 / 15 split; the test split absorbs the rounding remainder so the
# three sizes always add up to the dataset length.
total_size = len(dataset_dict["train"])
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_size
train_size, val_size, test_size

7891


(5523, 1183, 1185)

In [11]:
# Upper bounds on how many samples of each split are actually used.
MAX_TRAIN_SAMPLES = 5000
MAX_VAL_SAMPLES = 1000
MAX_TEST_SAMPLES = 1000

# Deterministic train/val/test partition of the full dataset, driven by the
# notebook-wide `seed` through a dedicated generator.
generator = torch.Generator().manual_seed(seed)
train_dataset, val_dataset, test_dataset = random_split(
    dataset_dict["train"], [train_size, val_size, test_size], generator=generator
)

# Cap each split. The cap is clamped to the split's real length so that a
# smaller dataset does not produce out-of-range indices (which would only
# surface as an IndexError once a loader iterates over them).
train_dataset = torch.utils.data.Subset(
    train_dataset, range(min(MAX_TRAIN_SAMPLES, len(train_dataset)))
)
val_dataset = torch.utils.data.Subset(
    val_dataset, range(min(MAX_VAL_SAMPLES, len(val_dataset)))
)
test_dataset = torch.utils.data.Subset(
    test_dataset, range(min(MAX_TEST_SAMPLES, len(test_dataset)))
)

# Settings shared by all three loaders; only `shuffle` differs per split.
# NOTE(review): drop_last=True is kept for val/test as well, which discards a
# trailing partial batch whenever the split size is not a multiple of batch_size.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=2,
    drop_last=True,
    collate_fn=dataset_dict["train"].collate_fn,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [12]:
# Sanity check: number of batches each loader will yield.
print(
    f"Number of training batches: {len(train_loader)}",
    f"Number of validation batches: {len(val_loader)}",
    f"Number of test batches: {len(test_loader)}",
    sep="\n",
)

Number of training batches: 625
Number of validation batches: 125
Number of test batches: 125


In [13]:
# Label-space summary of the full dataset, displayed as (num_classes, num_concepts).
full_dataset = dataset_dict["train"]
full_dataset.num_classes, full_dataset.num_concepts

(16, None)

# 3. Model

In [14]:
# The model wrapper and the evaluator must agree on the number of classes,
# so read it from the dataset once and pass it to both.
num_classes = dataset_dict["train"].num_classes

# Model wrapper configured from the `model` / `training` config sections.
meta_model = MetaModel(
    run_name=cfg.run_name,
    num_classes=num_classes,
    device=device,
    model_kwargs=cfg.model,
    training_kwargs=cfg.training,
)

# Semantic-segmentation evaluator placed on the same device as the model.
evaluator = EvaluatorSemanticSegmentation(num_classes=num_classes, device=device)

In [15]:
# Keep the working config and the logger's snapshot in sync on the experiment name.
for _cfg in (cfg, cfg_logger):
    _cfg.logger.experiment_name = "mermaid"

In [16]:
from mermaidseg.logger import Logger

# Experiment logger: receives the config snapshot and the model wrapper,
# writes checkpoints to the current directory, and has Weights & Biases
# enabled with MLflow disabled.
logger_options = dict(
    config=cfg_logger,
    meta_model=meta_model,
    log_epochs=cfg.logger.log_epochs,
    log_checkpoint=2,  # hard-coded here; cfg.logger.log_checkpoint is the config-driven alternative
    checkpoint_dir=".",
    enable_mlflow=False,
    enable_wandb=True,
)
logger = Logger(**logger_options)

[34m[1mwandb[0m: Currently logged in as: [33mviktor-domazetoski[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [1]:
from mermaidseg.model.train import train_model
# Launch training: the model wrapper is trained on `train_loader`, with
# `val_loader`, the evaluator and the logger handed to the training routine.
train_model(
    meta_model,
    evaluator,
    train_loader,
    val_loader,
    logger=logger,
)