In [3]:
import sys
from pathlib import Path
# Prepend the parent of the kernel's working directory (normally the notebook's
# folder) to sys.path, so the local packages imported later in this notebook
# (models, datasets, train) resolve from that directory first.
sys.path.insert(0, str(Path("..").resolve()))

In [4]:
import os

import mlflow

# Tracking store for runs and logged models. An MLFLOW_TRACKING_URI already set in
# the environment takes precedence; otherwise fall back to the local file store
# this notebook was developed against (previously this path always overrode it).
MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "file:///media/sdb1/mlflow")
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)

In [5]:
import torch
from torch.utils.data import DataLoader, ConcatDataset
from models.models import build_model     
from datasets.loader import DataModule, DataConfig  
from train.trainer import Trainer, TrainConfig  
from train.eval import Evaluator             
from datasets.base import collate_bb
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NUM_CLASSES = 1 + 24  # presumably background + 24 foreground classes -- confirm against the label map
# Logged model artifact of the run this fine-tuning starts from.
model_uri = "file:///media/sdb1/mlflow/753485487056022103/2e19afb3d8e34c7fa8b50505a7dd259e/artifacts/model"
# Extra kwargs of mlflow.pytorch.load_model are forwarded to torch.load. Without
# map_location, a checkpoint holding CUDA tensors cannot be loaded on a CPU-only
# host, even though DEVICE above falls back to CPU.
model = mlflow.pytorch.load_model(model_uri, map_location=DEVICE)


def allow_missing_masks(model):
    """Let a Mask R-CNN-style model train on batches whose targets lack ``"masks"``.

    Replaces ``model.forward`` with a wrapper:

    * In training mode, when ``targets`` is given and at least one target dict has
      no ``"masks"`` key, it temporarily sets ``model.roi_heads.mask_roi_pool``,
      ``mask_head`` and ``mask_predictor`` to ``None`` for that one call.
    * Those attributes are restored afterwards, even if the call raises.
    * Every other call goes straight to the original ``forward``.

    Applying this to an already-patched model is a no-op, so re-running a cell
    does not stack wrapper layers.

    Args:
        model: object exposing ``forward(images, targets=None)``, a ``training``
            flag and a ``roi_heads`` object with the three mask attributes above
            (assumed to be a torchvision-style Mask R-CNN -- not checked).

    Returns:
        The same ``model`` object, patched in place.
    """
    if getattr(model.forward, "_allows_missing_masks", False):
        return model  # already patched; wrapping again would only nest identical layers

    orig_forward = model.forward

    def forward(images, targets=None):
        if model.training and targets is not None and not all("masks" in t for t in targets):
            rh = model.roi_heads
            saved = (rh.mask_roi_pool, rh.mask_head, rh.mask_predictor)
            try:
                # Presumably the ROI heads skip the mask branch when these are None
                # (torchvision's RoIHeads.has_mask()) -- confirm for custom heads.
                rh.mask_roi_pool, rh.mask_head, rh.mask_predictor = None, None, None
                return orig_forward(images, targets)
            finally:
                rh.mask_roi_pool, rh.mask_head, rh.mask_predictor = saved
        return orig_forward(images, targets)

    forward._allows_missing_masks = True  # marker checked above for idempotence
    model.forward = forward
    return model
# Patch the freshly loaded model so training batches whose targets carry no
# "masks" key run with the mask branch switched off for that call.
model = allow_missing_masks(model=model)

  param_schemas = callee.param_schemas()
  param_schemas = callee.param_schemas()


In [None]:
# Fine-tuning hyperparameters, applied on top of TrainConfig's defaults.
FT_OVERRIDES = {
    "num_epochs": 10,
    "batch_size": 4,
    "num_workers": 4,
    "lr": 5e-4,
    "weight_decay": 1e-4,
    "momentum": 0.9,
}
conf = TrainConfig()
for _field, _value in FT_OVERRIDES.items():
    setattr(conf, _field, _value)
del _field, _value

def infinite(loader):
    """Yield items from ``loader`` forever, restarting it each time it is exhausted.

    Each restart iterates ``loader`` afresh, so a shuffling DataLoader reshuffles
    on every pass. ``itertools.cycle`` is deliberately not used: it would cache
    and replay the first pass's batches.

    Args:
        loader: any re-iterable, e.g. a ``torch.utils.data.DataLoader``.

    Yields:
        The items of ``loader``, cycled indefinitely.

    Raises:
        ValueError: if a full pass over ``loader`` yields nothing. Without this
            check an empty loader would make the ``while True`` loop spin forever.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("infinite(): loader yielded no batches; cannot cycle an empty loader")

class AlternatingLoader:
    """Interleave batches from two loaders: a, b, a, b, ... (or b, a, ... with ``start="b"``).

    Each loader is cycled with ``infinite()``. The shorter one is therefore repeated
    (and reshuffled, if it shuffles) to keep the alternation going for ``steps``
    batches per pass.

    Args:
        loader_a: loader for dataset A (box-only targets in this notebook).
        loader_b: loader for dataset B (box + mask targets in this notebook).
        steps: total number of batches per pass. Defaults to
            ``2 * max(len(loader_a), len(loader_b))``, so the longer loader is
            consumed exactly once per pass.
        start: ``"a"`` or ``"b"`` -- which loader supplies the even-indexed batches.

    Raises:
        ValueError: if ``start`` is not ``"a"``/``"b"`` or ``steps`` is negative.
            Previously any other ``start`` value (e.g. ``"A"``) silently behaved
            like ``"b"``.
    """

    def __init__(self, loader_a, loader_b, steps=None, start="a"):
        if start not in ("a", "b"):
            raise ValueError(f"start must be 'a' or 'b', got {start!r}")
        self.a = loader_a  # dataset A (box-only)
        self.b = loader_b  # dataset B (box+mask)
        self.steps = steps if steps is not None else 2 * max(len(loader_a), len(loader_b))
        if self.steps < 0:
            raise ValueError(f"steps must be non-negative, got {self.steps}")
        self.start = start

    def __len__(self):
        return self.steps

    def __iter__(self):
        a_first = self.start == "a"
        batches_a, batches_b = infinite(self.a), infinite(self.b)
        for step in range(self.steps):
            # Even steps come from the starting loader, odd steps from the other one.
            if (step % 2 == 0) == a_first:
                yield next(batches_a)
            else:
                yield next(batches_b)

dm = DataModule(
    DataConfig(
        val_frac=0.1,  # validation split carved out of dataset A
        batch_size=conf.batch_size,
        num_workers=conf.num_workers,
    ),
    with_masks=False,  # no fake masks for dataset A; dataset B still has real masks
)

b_train_loader, b_val_loader = dm.make_loaders_b()
a_train_loader = DataLoader(
    dm.ds_a_train,
    batch_size=conf.batch_size,
    shuffle=True,
    num_workers=conf.num_workers,
    collate_fn=collate_bb,
)
a_val_loader = DataLoader(
    dm.ds_a_val,
    batch_size=conf.batch_size,
    shuffle=False,
    num_workers=conf.num_workers,
    collate_fn=collate_bb,
)

# `model` was already wrapped by allow_missing_masks() right after loading. Wrapping
# it again here would only stack another identical layer, so it is not repeated.


class SequentialLoader:
    """Iterate several loaders back to back, each exactly once per pass."""

    def __init__(self, *loaders):
        self.loaders = loaders

    def __len__(self):
        return sum(len(loader) for loader in self.loaders)

    def __iter__(self):
        for loader in self.loaders:
            yield from loader


# Training alternates datasets:
#   A batch (no masks)  -> mask head disabled for that step by allow_missing_masks()
#   B batch (has masks) -> mask head enabled and trained on that step
mix_train = AlternatingLoader(a_train_loader, b_train_loader)
# Validation visits every batch of both validation splits exactly once.
# AlternatingLoader would instead cycle the smaller split to match the larger one,
# re-using some of its batches and skewing the validation numbers.
mix_val = SequentialLoader(a_val_loader, b_val_loader)
trainer = Trainer(model, conf)
hist_ft = trainer.run(mix_train, mix_val, experiment_name="Att_FT2_Train6")

[epoch 001/010] step 50/2250 loss 0.5184
[epoch 001/010] step 100/2250 loss 0.4444
[epoch 001/010] step 150/2250 loss 0.7280
[epoch 001/010] step 200/2250 loss 0.4851
[epoch 001/010] step 250/2250 loss 0.5214
[epoch 001/010] step 300/2250 loss 0.4425
[epoch 001/010] step 350/2250 loss 0.5715
[epoch 001/010] step 400/2250 loss 0.4132
[epoch 001/010] step 450/2250 loss 0.4593
[epoch 001/010] step 500/2250 loss 0.4837
[epoch 001/010] step 550/2250 loss 0.4271
[epoch 001/010] step 600/2250 loss 0.4830
[epoch 001/010] step 650/2250 loss 0.4529
[epoch 001/010] step 700/2250 loss 0.4823
[epoch 001/010] step 750/2250 loss 0.5911
[epoch 001/010] step 800/2250 loss 0.4443
[epoch 001/010] step 850/2250 loss 0.4036
[epoch 001/010] step 900/2250 loss 0.5260
[epoch 001/010] step 950/2250 loss 0.4160
[epoch 001/010] step 1000/2250 loss 0.4628
[epoch 001/010] step 1050/2250 loss 0.4035
[epoch 001/010] step 1100/2250 loss 0.4547
[epoch 001/010] step 1150/2250 loss 0.4061
[epoch 001/010] step 1200/2250 

In [2]:
from pathlib import Path

import torch

# Run this in the same kernel session as training: `model` exists only after the
# model-loading and training cells have run (a restarted kernel raises NameError).
# Only the state_dict is saved; the patched `forward` closure is not part of it.
WEIGHTS_PATH = Path("weights") / "maskrcnn_attfpn_warmB_ft_Bmasks_Aboxes.pth"
WEIGHTS_PATH.parent.mkdir(parents=True, exist_ok=True)  # torch.save does not create missing directories
torch.save(model.state_dict(), WEIGHTS_PATH)

NameError: name 'model' is not defined