In [3]:
# Class count handed to build_model: 24 + 1 (the extra slot is presumably the background class — confirm).
NUM_CLASSES = 24 + 1

In [4]:
import torch
from torch.utils.data import DataLoader, ConcatDataset
from models.models import build_model     
from datasets.loader import DataModule, DataConfig  
from train.trainer import Trainer, TrainConfig  
from train.eval import Evaluator             
from datasets.base import collate_bb
# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Same value as the earlier NUM_CLASSES cell; repeated so this cell stands alone.
NUM_CLASSES = 25
# Checkpoint trained on dataset B that initialises the attention model.
B_CKPT = "./weights/maskrcnn_B_ep40.pth"

  param_schemas = callee.param_schemas()
  param_schemas = callee.param_schemas()


In [5]:
# Schedule: phase 1 trains only the backbone attention blocks on B,
# phase 2 fine-tunes on the concatenated A + B data.
WARMUP_EPOCHS = 3
FINETUNE_EPOCHS = 10

# Data loading
BATCH_SIZE = 4
NUM_WORKERS = 4

# Optimiser hyper-parameters
LR_WARMUP = 0.0001
LR_FINETUNE = 0.0005
WEIGHT_DECAY = 0.0001
MOMENTUM = 0.9

# When True, the mask head and mask predictor stay frozen during fine-tuning.
FREEZE_MASK_HEAD_IN_FINETUNE = True

In [None]:
# Build the attention-FPN Mask R-CNN and initialise it from the B-only checkpoint.
model = build_model("maskrcnn_attfpn", NUM_CLASSES).to(DEVICE)
# The loaded object is fed straight into load_state_dict, so it only needs tensors;
# weights_only=True avoids unpickling arbitrary Python objects from the file.
sd = torch.load(B_CKPT, map_location="cpu", weights_only=True)
# strict=False because the checkpoint has no weights for the attention blocks
# (backbone.ca / backbone.sa keep their fresh initialisation). Any other missing or
# unexpected key means checkpoint and architecture disagree, which strict=False
# alone would silently hide — fail loudly instead.
load_report = model.load_state_dict(sd, strict=False)
unexpected_missing = [
    k for k in load_report.missing_keys
    if not k.startswith(("backbone.ca.", "backbone.sa."))
]
if unexpected_missing or load_report.unexpected_keys:
    raise RuntimeError(
        f"B checkpoint does not match the model: missing={unexpected_missing}, "
        f"unexpected={load_report.unexpected_keys}"
    )


def allow_missing_masks(model):
    """Patch ``model.forward`` in place so training batches without masks skip the mask branch.

    While the model is in training mode and at least one target dict has no
    ``"masks"`` key, ``model.roi_heads.mask_roi_pool``, ``mask_head`` and
    ``mask_predictor`` are set to ``None`` for the duration of that forward call.
    They are restored afterwards, even if the call raises. This presumably makes the
    ROI heads skip the mask loss, which needs ``target["masks"]``. Every other call
    is passed through unchanged.

    Returns the same (now patched) ``model`` object.
    """
    mask_attrs = ("mask_roi_pool", "mask_head", "mask_predictor")
    wrapped_forward = model.forward

    def forward(images, targets=None):
        box_only_batch = (
            model.training
            and targets is not None
            and any("masks" not in target for target in targets)
        )
        if not box_only_batch:
            return wrapped_forward(images, targets)

        heads = model.roi_heads
        stashed = [getattr(heads, name) for name in mask_attrs]
        try:
            for name in mask_attrs:
                setattr(heads, name, None)
            return wrapped_forward(images, targets)
        finally:
            # Put the real modules back so parameters/state_dict are unaffected.
            for name, module in zip(mask_attrs, stashed):
                setattr(heads, name, module)

    model.forward = forward
    return model
# Patch the freshly built model so box-only (mask-less) targets can be trained on.
model = allow_missing_masks(model)

In [7]:
def req(mod, flag: bool):
    """Freeze (``flag=False``) or unfreeze (``flag=True``) every parameter of ``mod``, recursively."""
    for param in mod.parameters():
        param.requires_grad_(flag)

In [8]:
# Data: B-only loaders for warm-up, concatenated A + B loaders for fine-tuning,
# and B / C / D test loaders for evaluation.
# NOTE(review): with_masks=False is passed here, while the fine-tune save name refers to
# "Bmasks" — confirm the B splits still carry masks if the mask branch should be trained.
dm = DataModule(
    DataConfig(val_frac=0.1, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS),
    with_masks=False,
)
b_train_loader, b_val_loader = dm.make_loaders_b()

a_train_box = dm.ds_a_train
a_val_box = dm.ds_a_val
ab_train_ds = ConcatDataset([a_train_box, dm.ds_b_train])
ab_val_ds = ConcatDataset([a_val_box, dm.ds_b_val])

ab_train_loader = DataLoader(
    ab_train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    collate_fn=collate_bb,
)
ab_val_loader = DataLoader(
    ab_val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    collate_fn=collate_bb,
)

# Held-out test loaders used by the evaluation cell.
b_test_loader = dm.make_loader_b_test()
c_test_loader = dm.make_loader_c_test()
d_test_loader = dm.make_loader_d_test()

In [9]:
import mlflow
import os

# The default store path is machine-specific; MLFLOW_TRACKING_URI (when set) overrides it
# so the notebook can run on other hosts without editing this cell.
mlflow.set_tracking_uri(os.environ.get("MLFLOW_TRACKING_URI", "file:///media/sdb1/mlflow"))

In [10]:
# Warm-up: freeze the whole model, then re-enable only the backbone attention blocks.
req(model, False)
for attention_block in (model.backbone.ca, model.backbone.sa):
    req(attention_block, True)

# Training config for the warm-up phase.
conf = TrainConfig()
conf.num_epochs = WARMUP_EPOCHS
conf.batch_size = BATCH_SIZE
conf.num_workers = NUM_WORKERS
conf.lr = LR_WARMUP
conf.weight_decay = WEIGHT_DECAY
conf.momentum = MOMENTUM

In [11]:
# Phase 1: warm-up on B only (only the attention blocks are trainable at this point).
trainer = Trainer(model, conf)
hist_warm = trainer.run(b_train_loader, b_val_loader, experiment_name="Att_Train6")

# Prepare phase 2: unfreeze everything, optionally keeping the mask branch frozen.
# NOTE(review): req(model, True) re-enables gradients on every parameter, including any
# that build_model may have frozen on purpose (e.g. early backbone stages) — confirm intended.
req(model, True)
if FREEZE_MASK_HEAD_IN_FINETUNE:
    for frozen_part in (model.roi_heads.mask_head, model.roi_heads.mask_predictor):
        req(frozen_part, False)

[epoch 001/003] step 50/119 loss 1.0387
[epoch 001/003] step 100/119 loss 0.9133
[epoch 001/003] step 119/119 loss 1.1002
epoch 001/003  train=1.1486  val=1.1030
[epoch 002/003] step 50/119 loss 1.0594
[epoch 002/003] step 100/119 loss 0.9256
[epoch 002/003] step 119/119 loss 1.0038
epoch 002/003  train=0.9435  val=1.0371
[epoch 003/003] step 50/119 loss 0.8514
[epoch 003/003] step 100/119 loss 0.9803
[epoch 003/003] step 119/119 loss 0.9091
epoch 003/003  train=0.8965  val=1.0077




In [12]:
# Close any MLflow run still active after the warm-up Trainer.run call (presumably started
# by Trainer.run — TODO confirm), so the next experiment starts with no run open.
mlflow.end_run()

In [None]:
# Phase 2: fine-tune on the concatenated A + B data with a fresh training config.
conf = TrainConfig()
conf.num_epochs = FINETUNE_EPOCHS
conf.batch_size = BATCH_SIZE
conf.num_workers = NUM_WORKERS
conf.lr = LR_FINETUNE
conf.weight_decay = WEIGHT_DECAY
conf.momentum = MOMENTUM

trainer = Trainer(model, conf)
hist_ft = trainer.run(
    ab_train_loader,
    ab_val_loader,
    experiment_name="Att_FT2_Train6",
)

[epoch 001/010] step 50/1244 loss 1.0411
[epoch 001/010] step 100/1244 loss 1.1805


In [None]:




# Evaluation on the B test split, plus sanity checks on the C and D test loaders.
ev = Evaluator(DEVICE)

b_map50 = ev.map50(model, b_test_loader)
print("B test map50:", b_map50)
b_mask_metrics = ev.metrics_masks(model, b_test_loader, num_classes=NUM_CLASSES)
print("B test mask metrics:", b_mask_metrics)
c_sanity = ev.sanity(model, c_test_loader)
print("C sanity:", c_sanity)
d_sanity = ev.sanity(model, d_test_loader)
print("D sanity:", d_sanity)

# Visual spot-check; show_random=True presumably samples random images, so the
# displayed examples may differ between runs.
ev.show_examples(dm.ds_b_test, model, n=3, score_thresh=0.5, title="B test", show_random=True)

# %% [9] save
torch.save(model.state_dict(), "./weights/maskrcnn_attfpn_warmB_ft_Bmasks_Aboxes.pth")

In [4]:
B_CKPT = "./weights/maskrcnn_B_ep40.pth"
# Rebuild the attention-FPN Mask R-CNN and initialise it from the B-only checkpoint.
model = build_model("maskrcnn_attfpn", NUM_CLASSES).to(DEVICE)
# The loaded object is fed straight into load_state_dict, so it only needs tensors;
# weights_only=True avoids unpickling arbitrary Python objects from the file.
sd = torch.load(B_CKPT, map_location="cpu", weights_only=True)
# strict=False because the checkpoint has no weights for the attention blocks
# (backbone.ca / backbone.sa). Any other missing or unexpected key means checkpoint and
# architecture disagree, which strict=False alone would silently hide — fail loudly instead.
load_report = model.load_state_dict(sd, strict=False)
unexpected_missing = [
    k for k in load_report.missing_keys
    if not k.startswith(("backbone.ca.", "backbone.sa."))
]
if unexpected_missing or load_report.unexpected_keys:
    raise RuntimeError(
        f"B checkpoint does not match the model: missing={unexpected_missing}, "
        f"unexpected={load_report.unexpected_keys}"
    )
# Display the load report (expected: only attention-block keys missing).
load_report


  sd = torch.load(B_CKPT, map_location="cpu")


_IncompatibleKeys(missing_keys=['backbone.ca.mlp.0.weight', 'backbone.ca.mlp.0.bias', 'backbone.ca.mlp.2.weight', 'backbone.ca.mlp.2.bias', 'backbone.sa.conv.weight', 'backbone.sa.conv.bias'], unexpected_keys=[])

In [None]:
# Default-configured data module and training config (unlike the earlier data cell, no
# with_masks override is passed to DataModule here).
# NOTE(review): data_conf, data and train_conf are not used by any visible cell — confirm
# they are still needed.
data_conf = DataConfig()
data = DataModule(data_conf)
train_conf = TrainConfig()

In [None]:

def collate_fn(batch):
    """Split a batch of ``(image, target)`` pairs into ``(list_of_images, list_of_targets)``.

    Images are kept as a list rather than stacked, so they may differ in size.
    """
    images, targets = map(list, zip(*batch))
    return images, targets

class BoxOnly(torch.utils.data.Dataset):
    """Dataset wrapper that drops the ``"masks"`` entry from every target.

    Each item of the wrapped dataset must be an ``(image, target)`` pair. The target is
    copied into a new dict, so the wrapped dataset's own target objects are never mutated.
    """

    def __init__(self, ds):
        self.ds = ds

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, i):
        image, target = self.ds[i]
        boxes_only = {key: value for key, value in dict(target).items() if key != "masks"}
        return image, boxes_only

def set_requires_grad(module, flag: bool):
    """Set ``requires_grad`` to ``flag`` on all parameters of ``module`` (including submodules)."""
    for weight in module.parameters():
        weight.requires_grad_(flag)

def train_one_epoch(model, loader, optim):
    """Run one optimisation pass over ``loader`` and return the mean training loss.

    Args:
        model: detection model that, in training mode, returns a dict of scalar losses
            when called as ``model(images, targets)``.
        loader: iterable of ``(images, targets)`` batches, i.e. a list of image tensors
            and a list of target dicts (as built by ``collate_fn``).
        optim: optimizer over the parameters to update.

    Returns:
        Mean of the summed per-batch loss (0.0 for an empty loader).
    """
    model.train()
    total = 0.0
    batches = 0
    for imgs, targs in loader:
        imgs = [x.to(DEVICE) for x in imgs]
        # Only tensors can be moved; non-tensor target values (e.g. integer ids, if a
        # dataset provides them) are passed through instead of crashing on `.to`.
        targs = [
            {k: (v.to(DEVICE) if torch.is_tensor(v) else v) for k, v in t.items()}
            for t in targs
        ]

        loss_dict = model(imgs, targs)
        loss = sum(loss_dict.values())

        optim.zero_grad()
        loss.backward()
        optim.step()

        total += loss.item()
        batches += 1
    return total / max(batches, 1)

@torch.no_grad()
def eval_one_epoch(model, loader):
    model.eval()
    s = 0.0
    n = 0
    for imgs, targs in loader:
        imgs = [x.to(DEVICE) for x in imgs]
        targs = [{k: v.to(DEVICE) for k, v in t.items()} for t in targs]
        loss_dict = model(imgs, targs)
        loss = sum(loss_dict.values())
        s += float(loss.item())
        n += 1
    return s / max(n, 1)

# %% [4] datasets
# NOTE(review): DatasetA, DatasetB, B_ROOT, A_XML, A_IMG, LABEL_MAP, MAX_SIZE and
# NUM_CHANNELS are not defined in the visible notebook — confirm they are defined or
# imported elsewhere, otherwise this cell fails on a fresh kernel.
# NOTE(review): the B "validation" set is built from the "test" split; if its loss drives
# any decisions, test data leaks into model selection.

# B: full targets (boxes + masks).
ds_b_train = DatasetB(B_ROOT, "train", LABEL_MAP, MAX_SIZE, NUM_CHANNELS)
ds_b_val = DatasetB(B_ROOT, "test", LABEL_MAP, MAX_SIZE, NUM_CHANNELS)

# A: boxes only — BoxOnly strips any "masks" entry from its targets.
ds_a_train = DatasetA(A_XML, A_IMG, LABEL_MAP, MAX_SIZE, NUM_CHANNELS)
ds_a_box = BoxOnly(ds_a_train)

dl_b_train = DataLoader(
    ds_b_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,
)
dl_b_val = DataLoader(
    ds_b_val,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,
)

# Mixed B + box-only A data used for fine-tuning.
ds_mix = ConcatDataset([ds_b_train, ds_a_box])
dl_mix = DataLoader(
    ds_mix,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,
)

# %% [5] model: attention + load B checkpoint
# (`model` is built and loaded from the B checkpoint in an earlier cell.)
# NOTE(review): the fine-tune loop below trains on ds_mix, whose A samples have no "masks".
# Confirm `model` has been wrapped with allow_missing_masks (or otherwise tolerates
# mask-less targets); the reload cell does not wrap it.


# %% [6] warmup: B-only, train attention blocks (optional but fast)
set_requires_grad(model, False)
set_requires_grad(model.backbone.ca, True)
set_requires_grad(model.backbone.sa, True)

# Use the MOMENTUM config constant instead of a hard-coded 0.9 so both phases follow the config cell.
opt = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad],
    lr=LR_WARMUP, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY
)

for e in range(WARMUP_EPOCHS):
    tr = train_one_epoch(model, dl_b_train, opt)
    va = eval_one_epoch(model, dl_b_val)
    print("warmup", e, tr, va)

# %% [7] finetune: B(masks) + A(boxes), masks dropped for A
set_requires_grad(model, True)

if FREEZE_MASK_HEAD_IN_FINETUNE:
    set_requires_grad(model.roi_heads.mask_head, False)
    set_requires_grad(model.roi_heads.mask_predictor, False)

# Fresh optimizer: the set of trainable parameters changed after unfreezing.
opt = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad],
    lr=LR_FINETUNE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY
)

for e in range(FINETUNE_EPOCHS):
    tr = train_one_epoch(model, dl_mix, opt)
    va = eval_one_epoch(model, dl_b_val)
    print("finetune", e, tr, va)

# %% [8] save
torch.save(model.state_dict(), "maskrcnn_attfpn_Bwarmup_Aboxes_finetune.pt")
