In [1]:
import sys
from pathlib import Path
# Put the parent directory at the front of the import path so that the project
# packages imported by later cells (`models`, `datasets`, `train`) can be found
# when the notebook is run from this folder.
repo_root = Path("..").resolve()
sys.path.insert(0, str(repo_root))


# Train A → B (classic 2-stage)

Stage 1: pretrain on **Dataset A (boxes)**.  
Stage 2: fine-tune on **Dataset B (masks)**.

In [2]:
import torch

from models.models import build_model2
from datasets import cfg
from datasets.loader import DataModule, DataConfig
from train.trainer_v2 import Trainer, TrainConfig
from train.eval import Evaluator


  param_schemas = callee.param_schemas()
  param_schemas = callee.param_schemas()


In [3]:
# --- Runtime / model selection -------------------------------------------
# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
MODEL_NAME = "maskrcnn_attfpn"
# Original note: "1 + 24" (presumably background + 24 classes; confirm in datasets.cfg).
NUM_CLASSES = cfg.num_classes

# If True, Dataset A also yields rectangle masks (weak masks derived from its
# boxes). The flag is passed to DataModule(with_masks=...) when the loaders are built.
PRETRAIN_WITH_WEAK_MASKS = True  # kept on to follow the reference paper's setup

# --- MLflow tracking ------------------------------------------------------
# NOTE(review): absolute, machine-specific location; adjust when running elsewhere.
TRACKING_URI = "file:///media/sdb1/mlflow"
EXPERIMENT_PRE_A = "AB_classic_preA"
EXPERIMENT_FT_B = "AB_classic_ftB"

# --- Checkpoints ----------------------------------------------------------
# Directory that receives the state_dict files saved after each training stage.
WEIGHTS_DIR = Path("..") / "weights"

In [4]:
# One DataModule serves both datasets. Per the original note, Dataset B always
# has masks, while Dataset A only produces rectangle masks when `with_masks`
# asks for them.
dm = DataModule(
    DataConfig(val_frac=0.1, batch_size=4, num_workers=4),
    num_channels=3,
    with_masks=bool(PRETRAIN_WITH_WEAK_MASKS),
)

a_train, a_val = dm.make_loaders_a()
b_train, b_val = dm.make_loaders_b()

# The training portion is split into train and val (val_frac above); a separate
# test portion exists as well and is only used by the evaluation cells.
print("A:", len(dm.ds_a_train), "train |", len(dm.ds_a_val), "val")
print("B:", len(dm.ds_b_train), "train |", len(dm.ds_b_val), "val")


A: 3370 train | 374 val
B: 475 train | 52 val


In [5]:
# Build the detector, then move it to the selected device.
# NOTE(review): the previous comment here read "resnet 50 imagenet", but
# `weights_backbone=False` and the "_random" suffix of the checkpoints saved
# later suggest the backbone starts from random weights. Confirm against
# build_model2.
model = build_model2(
    MODEL_NAME,
    NUM_CLASSES,
    weights_backbone=False,
    trainable_backbone_layers=5,
)
model = model.to(DEVICE)



## Stage 1 — pretrain on DS A (boxes)

If `PRETRAIN_WITH_WEAK_MASKS=True`, DS A provides rectangle masks so the mask head also sees a weak signal.
Otherwise, the trainer automatically skips the mask head when masks are missing (detection-only).

In [6]:
# Stage-1 hyper-parameters. Per the original author's note, augmentation is
# always applied and is therefore not a switch in this config.
conf_a = TrainConfig(
    # learning-rate schedule
    num_epochs=20,
    scheduler="cosine",
    warmup_iters=1000,
    lr=0.005,
    min_lr=1e-6,
    # optimizer
    momentum=0.9,
    weight_decay=1e-4,
    grad_clip=1.0,
    # numerics / averaging / normalisation
    amp=True,
    ema_decay=0.999,
    freeze_bn=True,
    # data loading (same values as the DataConfig used to build the loaders)
    batch_size=4,
    num_workers=4,
    # logging
    print_every=50,
    tracking_uri=TRACKING_URI,
)

# Train the shared `model` object on Dataset A. The same object is fine-tuned
# in stage 2, so this cell must run before the stage-2 cell.
trainer_a = Trainer(model, conf_a)
hist_a = trainer_a.run(a_train, a_val, experiment_name=EXPERIMENT_PRE_A)



  self.scaler = torch.cuda.amp.GradScaler(enabled=bool(train_conf.amp and self.device.type == "cuda"))


[epoch 001/020] step 50/843 loss 15.6025
[epoch 001/020] step 100/843 loss 2.8024
[epoch 001/020] step 150/843 loss 2.6432
[epoch 001/020] step 200/843 loss 3.3272
[epoch 001/020] step 250/843 loss 3.0136
[epoch 001/020] step 300/843 loss 2.7734
[epoch 001/020] step 350/843 loss 3.2292
[epoch 001/020] step 400/843 loss 3.1817
[epoch 001/020] step 450/843 loss 3.1564
[epoch 001/020] step 500/843 loss 3.0853
[epoch 001/020] step 550/843 loss 3.1200
[epoch 001/020] step 600/843 loss 3.1480
[epoch 001/020] step 650/843 loss 3.1927
[epoch 001/020] step 700/843 loss 3.0340
[epoch 001/020] step 750/843 loss 3.2251
[epoch 001/020] step 800/843 loss 3.0389
[epoch 001/020] step 843/843 loss 3.6312
[epoch 001/020] train=5.5617  val=6.7116  lr=0.004215
[epoch 002/020] step 50/843 loss 3.1306
[epoch 002/020] step 100/843 loss 2.9910
[epoch 002/020] step 150/843 loss 2.9539
[epoch 002/020] step 200/843 loss 3.0815
[epoch 002/020] step 250/843 loss 2.9441
[epoch 002/020] step 300/843 loss 3.0405
[epo



In [7]:
# Persist the stage-1 weights as a plain state_dict.
# torch.save does not create missing parent directories, so make sure the
# target folder exists first (e.g. on a fresh checkout).
WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
ckpt_a = WEIGHTS_DIR / f"{MODEL_NAME}_preA_random.pth"
torch.save(model.state_dict(), ckpt_a)
print("saved:", ckpt_a)  # per the original note, MLflow stores checkpoints as well


saved: ../weights/maskrcnn_attfpn_preA_random.pth


## Stage 2 — fine-tune on DS B (masks)

In [8]:
# Stage-2 hyper-parameters. A new TrainConfig/Trainer pair is created for the
# fine-tuning run. Compared with stage 1, only `num_epochs` (40 vs 20) and
# `lr` (0.0004 vs 0.005) differ; every other setting is the same.
conf_b = TrainConfig(
    # learning-rate schedule
    num_epochs=40,
    scheduler="cosine",
    warmup_iters=1000,
    lr=0.0004,
    min_lr=1e-6,
    # optimizer
    momentum=0.9,
    weight_decay=1e-4,
    grad_clip=1.0,
    # numerics / averaging / normalisation
    amp=True,
    ema_decay=0.999,
    freeze_bn=True,
    # data loading
    batch_size=4,
    num_workers=4,
    # logging
    print_every=50,
    tracking_uri=TRACKING_URI,
)

# Continue training the same `model` object (already pretrained on A) on Dataset B.
trainer_b = Trainer(model, conf_b)
hist_b = trainer_b.run(b_train, b_val, experiment_name=EXPERIMENT_FT_B)

# Persist the A -> B fine-tuned weights as a plain state_dict.
# torch.save does not create missing parent directories, so make sure the
# target folder exists first (e.g. on a fresh checkout).
WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
ckpt_ab = WEIGHTS_DIR / f"{MODEL_NAME}_A2B_random.pth"
torch.save(model.state_dict(), ckpt_ab)
print("saved:", ckpt_ab)


  self.scaler = torch.cuda.amp.GradScaler(enabled=bool(train_conf.amp and self.device.type == "cuda"))


[epoch 001/040] step 50/119 loss 2.8353
[epoch 001/040] step 100/119 loss 2.4214
[epoch 001/040] step 119/119 loss 2.0971
[epoch 001/040] train=2.9200  val=3.3090  lr=4.76e-05
[epoch 002/040] step 50/119 loss 1.7504
[epoch 002/040] step 100/119 loss 1.8357
[epoch 002/040] step 119/119 loss 1.9666
[epoch 002/040] train=2.0260  val=3.0342  lr=9.52e-05
[epoch 003/040] step 50/119 loss 1.7681
[epoch 003/040] step 100/119 loss 1.8856
[epoch 003/040] step 119/119 loss 1.8172
[epoch 003/040] train=1.8662  val=2.7958  lr=0.0001428
[epoch 004/040] step 50/119 loss 1.9004
[epoch 004/040] step 100/119 loss 1.8543
[epoch 004/040] step 119/119 loss 1.9673
[epoch 004/040] train=1.7968  val=2.6043  lr=0.0001904
[epoch 005/040] step 50/119 loss 1.6957
[epoch 005/040] step 100/119 loss 1.7102
[epoch 005/040] step 119/119 loss 1.6801
[epoch 005/040] train=1.7511  val=2.4434  lr=0.000238
[epoch 006/040] step 50/119 loss 1.7732
[epoch 006/040] step 100/119 loss 1.5382
[epoch 006/040] step 119/119 loss 1.7



saved: ../weights/maskrcnn_attfpn_A2B_random.pth


## Quick sanity eval (optional)

In [9]:
ev = Evaluator(device=DEVICE)

# Test-split loaders for both datasets, built from the shared DataModule.
a_test = dm.make_loader_a_test()
b_test = dm.make_loader_b_test()

# `map50` on the A test loader, `metrics_masks` on the B test loader.
map50_a = ev.map50(model, a_test)
masks_b = ev.metrics_masks(model, b_test, num_classes=NUM_CLASSES)

print("A test mAP@50:", map50_a)
# Only the metric names are listed here (at most ten).
print("B test metrics keys:", list(masks_b)[:10])


A test mAP@50: 0.060342345386743546
B test metrics keys: ['mAP50', 'PQ_all', 'mPQ', 'PQ_per_class', 'AJI']


In [11]:

# Show the full metrics mapping that ev.metrics_masks returned for the B test loader.
print(masks_b)

{'mAP50': 0.2434900552034378, 'PQ_all': 0.2821934302688574, 'mPQ': 0.2664130457957197, 'PQ_per_class': array([0.35175117, 0.37283354, 0.27804926, 0.28696305, 0.25987666,
       0.28920128, 0.29996131, 0.26858043, 0.18557743, 0.19228072,
       0.32752773, 0.20500502, 0.33150099, 0.27721892, 0.24712891,
       0.29926734, 0.3093991 , 0.28093804, 0.25772101, 0.30437598,
       0.35673964, 0.26613559, 0.05417384, 0.09170611,        nan]), 'AJI': 0.3199680921519608}


In [14]:


import mlflow.pytorch

# Reload a model that an earlier run logged to MLflow, for comparison with the
# in-memory `model`.
# The artifact location is derived from TRACKING_URI instead of repeating the
# tracking root, so it keeps pointing at the same store if that constant changes.
# The two path components appear to be the MLflow experiment id and run id
# (file-store layout <root>/<experiment_id>/<run_id>/artifacts); verify which
# run this is before relying on the numbers.
MLFLOW_EXPERIMENT_ID = "195229938318171777"
MLFLOW_RUN_ID = "0bf1a57b3e0748fe805a5747839be81f"
model_uri = f"{TRACKING_URI}/{MLFLOW_EXPERIMENT_ID}/{MLFLOW_RUN_ID}/artifacts/model"
m2 = mlflow.pytorch.load_model(model_uri).to(DEVICE)

In [15]:

# Same two evaluations as the sanity-eval cell, this time on the MLflow-loaded `m2`.
# NOTE: this rebinds `map50_a` and `masks_b`, replacing the values that were
# computed for `model`.
map50_a, masks_b = (
    ev.map50(m2, a_test),
    ev.metrics_masks(m2, b_test, num_classes=NUM_CLASSES),
)

print("A test mAP@50:", map50_a)
print(masks_b)

A test mAP@50: 0.7488314509391785
{'mAP50': 0.10590457171201706, 'PQ_all': 0.02988920014479454, 'mPQ': 0.025426631659907026, 'PQ_per_class': array([0.00650751, 0.        , 0.00564667, 0.01814717, 0.0168132 ,
       0.01738305, 0.0064599 , 0.00307618, 0.02126989, 0.0130077 ,
       0.03837836, 0.00755604, 0.00957492, 0.01030654, 0.02581946,
       0.01400129, 0.03382844, 0.02724894, 0.09051408, 0.03974869,
       0.08590034, 0.07388909, 0.00604735, 0.03911434,        nan]), 'AJI': 0.16528589918892436}


In [20]:
import torch
from torch.utils.data import DataLoader

from datasets.base import collate_bb
from datasets import cfg
from train.metrics import get_metrics, compute_map50

# Device used by the evaluation helpers in this cell (same selection rule as DEVICE).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# A second DataModule dedicated to evaluation: batch size 1, masks always on.
DATA_CONF = DataConfig(batch_size=1, num_workers=2)
DATA = DataModule(DATA_CONF, num_channels=3, with_masks=True)

# Datasets evaluated below: the custom set plus the A and B test splits.
ds_custom, ds_a_test, ds_b_test = DATA.ds_custom, DATA.ds_a_test, DATA.ds_b_test

def make_loader(ds, nw=2):
    """Wrap dataset `ds` in a deterministic evaluation DataLoader.

    The loader yields one sample per batch, in dataset order (no shuffling),
    uses `nw` worker processes and batches samples with `collate_bb`.
    """
    loader = DataLoader(
        ds,
        batch_size=1,
        shuffle=False,
        num_workers=nw,
        collate_fn=collate_bb,
    )
    return loader

@torch.no_grad()
def collect(model, loader, need_masks: bool):
    """Run `model` over `loader` and gather per-image predictions and targets on the CPU.

    Args:
        model: torch module called with a list of image tensors. It must return
            one dict per image with "boxes", "scores" and "labels" tensors, plus
            a "masks" tensor when `need_masks` is True.
        loader: iterable of `(images, targets)` batches, where `targets` holds
            one dict per image with at least a "boxes" tensor.
        need_masks: when True, predicted masks (and ground-truth masks, if the
            target provides them as a tensor) are included as NumPy arrays.

    Returns:
        `(preds, targs)`: two lists of dicts, aligned image by image.

    Side effects:
        Switches `model` to eval mode and leaves it there.
    """
    model.eval()

    # Send inputs to wherever the model's weights live rather than relying only
    # on the module-level `device`; that global remains the fallback for a model
    # without parameters.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    preds, targs = [], []
    for images, targets in loader:
        outputs = model([im.to(run_device) for im in images])

        for out, gt in zip(outputs, targets):
            pred = {key: out[key].detach().cpu() for key in ("boxes", "scores", "labels")}
            if need_masks:
                # Index 0 along axis 1 is kept; original note: result is (N, H, W) numpy.
                pred["masks"] = out["masks"][:, 0].detach().cpu().numpy()

            targ = {"boxes": gt["boxes"].cpu()}
            # Labels and masks are optional in the target and copied only when
            # they are real tensors.
            if "labels" in gt and torch.is_tensor(gt["labels"]):
                targ["labels"] = gt["labels"].cpu()
            if need_masks and "masks" in gt and torch.is_tensor(gt["masks"]):
                targ["masks"] = gt["masks"].cpu().numpy()

            preds.append(pred)
            targs.append(targ)

    return preds, targs

def eval_custom_class_agnostic(model):
    """Evaluate `model` on the custom dataset while ignoring class identity.

    Every predicted and ground-truth instance is relabelled to class 1 before
    the metrics are computed with `num_classes=2`.
    """
    preds, targs = collect(model, make_loader(ds_custom), need_masks=True)
    # Overwrite labels on both sides with a vector of ones, one entry per box.
    for record in (*preds, *targs):
        record["labels"] = torch.ones((len(record["boxes"]),), dtype=torch.int64)
    return get_metrics(preds, targs, num_classes=2)

def eval_a_boxes(model):
    """Return `{"mAP50": ...}` for `model` on the Dataset A test split (masks not collected)."""
    loader = make_loader(ds_a_test)
    preds, targs = collect(model, loader, need_masks=False)
    score = compute_map50(preds, targs)
    return {"mAP50": score}

def eval_b_full(model):
    """Return the `get_metrics` result for `model` on the Dataset B test split, masks included."""
    loader = make_loader(ds_b_test)
    preds, targs = collect(model, loader, need_masks=True)
    return get_metrics(preds, targs, num_classes=cfg.num_classes)

def run(model, name):
    """Move `model` to `device` and print the three evaluation summaries under a `name` banner."""
    model = model.to(device)
    print(f"\n== {name} ==")
    # (label, evaluator) pairs, evaluated and printed in this order.
    suites = (
        ("custom (class-agnostic):", eval_custom_class_agnostic),
        ("A test (boxes only):    ", eval_a_boxes),
        ("B test (full):          ", eval_b_full),
    )
    for label, evaluate in suites:
        print(label, evaluate(model))

# Evaluate the in-memory `model` under the label "m1".
# NOTE(review): the output recorded for this call does not match the In [9]
# metrics for `model`; the kernel state or the metric code path may have
# differed when this ran. Re-check after Restart & Run All.
run(model, "m1")






== m1 ==
custom (class-agnostic): {'mAP50': 0.5262898206710815, 'PQ_all': 0.03672602001273298, 'mPQ': 0.03672602001273298, 'PQ_per_class': array([0.03672602,        nan]), 'AJI': 0.11215135616084955}
A test (boxes only):     {'mAP50': 0.18084953725337982}
B test (full):           {'mAP50': 0.7555909752845764, 'PQ_all': 0.210108580101161, 'mPQ': 0.211053354715712, 'PQ_per_class': array([0.07364656, 0.08869973, 0.09680295, 0.11778711, 0.12786023,
       0.10462205, 0.17621212, 0.16929189, 0.17586031, 0.1862569 ,
       0.20593897, 0.20849398, 0.27137837, 0.22846105, 0.26479355,
       0.32364428, 0.3005126 , 0.32247135, 0.32722444, 0.28884022,
       0.33605031, 0.28595396, 0.16713227, 0.21734532,        nan]), 'AJI': 0.2086153127503059}


In [None]:
# Comparison model loaded from a saved state_dict (presumably a B-only run of
# 40 epochs, judging by the file name; confirm).
# NOTE(review): `build_model` is not imported in any visible cell (only
# `build_model2` is), so this cell fails on a fresh kernel until that import is added.
# NOTE(review): the variable is called `m1` but the printed label is "m0",
# while the previous cell labels `model` as "m1".
m1 = build_model("maskrcnn_r50_fpn", int(cfg.num_classes)).to(device)
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs, and avoids executing arbitrary pickled code.
m1.load_state_dict(
    torch.load("../weights/maskrcnn_B_ep40.pth", map_location=device, weights_only=True)
)
run(m1, "m0")


== m0 ==


  m1.load_state_dict(torch.load("../weights/maskrcnn_B_ep40.pth", map_location=device))


custom (class-agnostic): {'mAP50': 0.3757989704608917, 'PQ_all': 0.025895231791482922, 'mPQ': 0.025895231791482922, 'PQ_per_class': array([0.02589523,        nan]), 'AJI': 0.10967215214188494}
