In [3]:
# Number of classes later handed to build_model. Written as 1 + 24 —
# presumably one background class plus 24 object classes (TODO confirm).
NUM_CLASSES = 1 + 24 

In [4]:
import torch
from torch.utils.data import DataLoader, ConcatDataset
from models.models import build_model     
from datasets.loader import DataModule, DataConfig  
from train.trainer import Trainer, TrainConfig  
from train.eval import Evaluator             
from datasets.base import collate_bb
# Checkpoint that the model is initialised from, and the class count given to build_model.
B_CKPT = "./weights/maskrcnn_B_ep40.pth"
NUM_CLASSES = 25

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

  param_schemas = callee.param_schemas()
  param_schemas = callee.param_schemas()


In [5]:
# Epoch budgets for the two training stages (warm-up run, then fine-tuning run).
WARMUP_EPOCHS, FINETUNE_EPOCHS = 3, 10

# Loader settings passed to DataConfig and to the combined A+B DataLoaders.
BATCH_SIZE, NUM_WORKERS = 4, 4

# Training hyper-parameters copied onto TrainConfig; each stage has its own learning rate.
LR_WARMUP, LR_FINETUNE = 1e-4, 5e-4
WEIGHT_DECAY = 1e-4
MOMENTUM = 0.9

# When True, the mask head and mask predictor are frozen again before fine-tuning.
FREEZE_MASK_HEAD_IN_FINETUNE = True

# Data module configured with the shared loader settings; masks are switched off here.
dm = DataModule(
    DataConfig(val_frac=0.1, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS),
    with_masks=False,
)
b_train_loader, b_val_loader = dm.make_loaders_b()

# Concatenate the A and B splits so one loader iterates over both datasets.
a_train_box, a_val_box = dm.ds_a_train, dm.ds_a_val
ab_train_ds = ConcatDataset([a_train_box, dm.ds_b_train])
ab_val_ds = ConcatDataset([a_val_box, dm.ds_b_val])

# Only the training loader shuffles; both use the project's collate function.
ab_train_loader = DataLoader(
    ab_train_ds,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    collate_fn=collate_bb,
    shuffle=True,
)
ab_val_loader = DataLoader(
    ab_val_ds,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    collate_fn=collate_bb,
    shuffle=False,
)

# Held-out test loaders for datasets B, C and D.
b_test_loader = dm.make_loader_b_test()
c_test_loader = dm.make_loader_c_test()
d_test_loader = dm.make_loader_d_test()

import os

import mlflow

# Where MLflow stores runs. The standard MLFLOW_TRACKING_URI environment
# variable takes precedence when it is set to a non-empty value, so the
# notebook is not tied to one machine's disk layout; otherwise the original
# local file store is used.
MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI") or "file:///media/sdb1/mlflow"
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)

In [None]:
# Build the "maskrcnn_attfpn" model and initialise it from the B_CKPT checkpoint.
model = build_model("maskrcnn_attfpn", NUM_CLASSES).to(DEVICE)

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
sd = torch.load(B_CKPT, map_location="cpu")

# strict=False tolerates keys that differ between checkpoint and model. Keep the
# returned report and print it, so a wrong or mismatched checkpoint is visible
# instead of being silently accepted.
load_result = model.load_state_dict(sd, strict=False)
print(
    f"[checkpoint] {B_CKPT}: "
    f"{len(load_result.missing_keys)} missing key(s), "
    f"{len(load_result.unexpected_keys)} unexpected key(s)"
)
if load_result.missing_keys:
    print("  missing:", list(load_result.missing_keys))
if load_result.unexpected_keys:
    print("  unexpected:", list(load_result.unexpected_keys))


def allow_missing_masks(model):
    """Patch ``model.forward`` so training batches without masks skip the mask branch.

    The replacement forward delegates to the original one. When the model is in
    training mode, targets are given, and at least one target dict has no
    ``"masks"`` key, the three mask attributes on ``model.roi_heads``
    (``mask_roi_pool``, ``mask_head``, ``mask_predictor``) are set to ``None``
    for the duration of that call and restored afterwards, even if the call
    raises.

    Returns the same ``model`` object, now carrying the patched ``forward``.
    """
    original_forward = model.forward
    mask_attrs = ("mask_roi_pool", "mask_head", "mask_predictor")

    def patched_forward(images, targets=None):
        # Inference calls and calls without targets pass straight through.
        if not model.training or targets is None:
            return original_forward(images, targets)
        # Every target carries masks: nothing to hide.
        if all("masks" in target for target in targets):
            return original_forward(images, targets)

        roi_heads = model.roi_heads
        stashed = tuple(getattr(roi_heads, attr) for attr in mask_attrs)
        try:
            for attr in mask_attrs:
                setattr(roi_heads, attr, None)
            return original_forward(images, targets)
        finally:
            # Put the mask modules back no matter how the call ended.
            for attr, module in zip(mask_attrs, stashed):
                setattr(roi_heads, attr, module)

    model.forward = patched_forward
    return model
# Patch this model instance so that training batches in which any target lacks
# "masks" run with the mask branch temporarily removed.
allow_missing_masks(model)

In [7]:
def req(mod, flag: bool):
    """Set ``requires_grad`` to ``flag`` on every parameter yielded by ``mod.parameters()``."""
    for param in mod.parameters():
        param.requires_grad = flag

# Warm-up stage: freeze every parameter, then re-enable gradients only for the
# backbone's `ca` and `sa` submodules.
req(model, False)
req(model.backbone.ca, True)
req(model.backbone.sa, True)

# Training configuration for the warm-up run.
conf = TrainConfig()
conf.num_epochs = WARMUP_EPOCHS
conf.batch_size, conf.num_workers = BATCH_SIZE, NUM_WORKERS
conf.lr, conf.weight_decay, conf.momentum = LR_WARMUP, WEIGHT_DECAY, MOMENTUM

In [11]:
# Warm-up run on the dataset-B loaders.
trainer = Trainer(model, conf)
hist_warm = trainer.run(
    b_train_loader,
    b_val_loader,
    experiment_name="Att_Train6",
)

# Prepare for fine-tuning: unfreeze everything, then optionally freeze the
# mask head and mask predictor again.
req(model, True)
if FREEZE_MASK_HEAD_IN_FINETUNE:
    req(model.roi_heads.mask_predictor, False)
    req(model.roi_heads.mask_head, False)

[epoch 001/003] step 50/119 loss 1.0387
[epoch 001/003] step 100/119 loss 0.9133
[epoch 001/003] step 119/119 loss 1.1002
epoch 001/003  train=1.1486  val=1.1030
[epoch 002/003] step 50/119 loss 1.0594
[epoch 002/003] step 100/119 loss 0.9256
[epoch 002/003] step 119/119 loss 1.0038
epoch 002/003  train=0.9435  val=1.0371
[epoch 003/003] step 50/119 loss 0.8514
[epoch 003/003] step 100/119 loss 0.9803
[epoch 003/003] step 119/119 loss 0.9091
epoch 003/003  train=0.8965  val=1.0077




In [12]:
# End the currently active MLflow run, if any, before the next training stage —
# presumably Trainer.run leaves its run open (TODO confirm).
mlflow.end_run()

In [13]:
# Fresh training configuration for the fine-tuning stage; it differs from the
# warm-up configuration in epoch count and learning rate.
conf = TrainConfig()
conf.num_epochs = FINETUNE_EPOCHS
conf.batch_size, conf.num_workers = BATCH_SIZE, NUM_WORKERS
conf.lr, conf.weight_decay, conf.momentum = LR_FINETUNE, WEIGHT_DECAY, MOMENTUM

# Fine-tune on the concatenated A+B loaders under a separate experiment name.
trainer = Trainer(model, conf)
hist_ft = trainer.run(
    ab_train_loader,
    ab_val_loader,
    experiment_name="Att_FT2_Train6",
)

[epoch 001/010] step 50/1244 loss 1.0411
[epoch 001/010] step 100/1244 loss 1.1805
[epoch 001/010] step 150/1244 loss 0.8446
[epoch 001/010] step 200/1244 loss 0.7666
[epoch 001/010] step 250/1244 loss 0.8738
[epoch 001/010] step 300/1244 loss 1.0213
[epoch 001/010] step 350/1244 loss 0.7507
[epoch 001/010] step 400/1244 loss 0.9404
[epoch 001/010] step 450/1244 loss 0.7451
[epoch 001/010] step 500/1244 loss 0.8671
[epoch 001/010] step 550/1244 loss 0.8172
[epoch 001/010] step 600/1244 loss 0.7147
[epoch 001/010] step 650/1244 loss 0.9449
[epoch 001/010] step 700/1244 loss 0.7148
[epoch 001/010] step 750/1244 loss 0.6044
[epoch 001/010] step 800/1244 loss 0.6489
[epoch 001/010] step 850/1244 loss 0.5952
[epoch 001/010] step 900/1244 loss 0.6872
[epoch 001/010] step 950/1244 loss 0.5851
[epoch 001/010] step 1000/1244 loss 0.8295
[epoch 001/010] step 1050/1244 loss 0.7705
[epoch 001/010] step 1100/1244 loss 0.7502
[epoch 001/010] step 1150/1244 loss 0.7296
[epoch 001/010] step 1200/1244 



In [2]:
# Save the model's state dict to ./weights. This cell depends on `model` from
# the training cells; the NameError recorded below came from executing it in a
# kernel where `model` had not been defined (note the execution count In [2]).
import torch
torch.save(model.state_dict(), "./weights/maskrcnn_attfpn_warmB_ft_Bmasks_Aboxes.pth")

NameError: name 'model' is not defined