In [None]:
import torch, torchvision
import mmdet
from mmcv.ops import get_compiling_cuda_version, get_compiler_version
import mmengine
import os
from mmengine import Config
from mmengine.runner import Runner


In [None]:
# Load the base Soft-Teacher / Mask R-CNN configuration that the following
# cells customise. The path is relative to the notebook's working directory;
# adjust it to wherever the config file lives in your mmdetection checkout.
baseConfigPath = r'./configs/soft_teach_pytorch_mask_rcnn.py'
cfg = Config.fromfile(baseConfigPath)

# model training parameters
# Everything a user is expected to edit lives in this cell; the next cell
# copies these values into the mmdetection config object.

### samples
# Each split has its own image folder and a COCO-format annotation JSON that
# is expected to sit inside that folder.
trainDataFolder = r""
trainDataJson = rf"{trainDataFolder}/train.json"  # change to the path of the COCO annotation json
testDataFolder = r""
testDataJson = rf"{testDataFolder}/test.json"  # built from the test folder (previously used the train folder)
valDataFolder = r""
valDataJson = rf"{valDataFolder}/val.json"  # built from the val folder (previously used the train folder)

unlabelDataFolder = r""  # unlabelled images
# This annotation file only contains 'images': [...] and an empty 'annotations' list.
unlabelDataJson = rf"{unlabelDataFolder}/unlabelled.json"

### training parameters
batchSize = 8  # change to fit GPU memory
numWorkers = 2  # worker processes for loading data
learningRate = 0.0001
epoch = 24  # total number of training epochs
valInterval = 1  # run validation every N epochs
modelSaveInterval = 2  # save weights every N epochs
modelSaveMax = 10  # maximum number of checkpoints kept during training
workDir = r""  # output folder for the trained model and logs
pretrainedWeights = r""  # .pth file with pretrained weights
logPrintInterval = 100  # iteration interval (not epoch interval) for printing logs during training

# Parameter scheduler for training; see the MMEngine parameter scheduler
# documentation for the available types and options.
# `end` follows `epoch` so the schedule always spans the whole run instead of
# silently drifting when the epoch count above is changed.
# NOTE(review): milestones are absolute epoch indices; [8, 11] looks like a
# 12-epoch schedule -- confirm it is still what you want for `epoch` epochs.
parameterScheduler = [
    dict(
        type='MultiStepLR',
        begin=0,
        end=epoch,
        by_epoch=True,
        milestones=[8, 11],
        gamma=0.1,
    )
]

### Soft-Teacher training parameters
# The variable name keeps its original spelling so existing references keep working.
pseudoLabelIntialScore = 0.5  # initial score threshold for pseudo labels
rpnPseudoThr = 0.95  # pseudo-label threshold for the RPN branch
clsPseudoThr = 0.95  # pseudo-label threshold for the classification branch
# Sampling ratio per training batch, intended for the sampler's `source_ratio`.
# NOTE(review): entries presumably follow the dataset order used for training,
# i.e. [labelled, unlabelled] -- confirm against the sampler documentation.
sampleRatio = [1, 1]

In [None]:
# Copy the notebook parameters onto the base config loaded from file.

### Soft-Teacher pseudo-label thresholds
# Use the user-facing parameter instead of a hard-coded 0.5 so that editing
# the parameter cell actually takes effect.
cfg.model.semi_train_cfg.pseudo_label_initial_score_thr = pseudoLabelIntialScore
cfg.model.semi_train_cfg.rpn_pseudo_thr = rpnPseudoThr
cfg.model.semi_train_cfg.cls_pseudo_thr = clsPseudoThr

### logging and checkpoint hooks
cfg.default_hooks.logger.interval = logPrintInterval
cfg.default_hooks.checkpoint.interval = modelSaveInterval
cfg.default_hooks.checkpoint.by_epoch = True
cfg.default_hooks.checkpoint.max_keep_ckpts = modelSaveMax
cfg.batch_size = batchSize
cfg.num_workers = numWorkers

# The trailing comma is required: ('tree') is just the string 'tree', which
# would be iterated character by character wherever a sequence of class names
# is expected. ('tree',) is a one-element tuple.
classes = ('tree',)

### labelled / unlabelled training datasets
# NOTE(review): the annotation path and image prefix already include the data
# folder; if relative folders are used, the dataset may join them onto
# data_root a second time -- prefer absolute folders, and confirm.
cfg.labeled_dataset.data_root = trainDataFolder
cfg.labeled_dataset.ann_file = trainDataJson
cfg.labeled_dataset.data_prefix = dict(img=trainDataFolder)
cfg.labeled_dataset.metainfo = dict(classes=classes)

cfg.unlabeled_dataset.data_root = unlabelDataFolder
cfg.unlabeled_dataset.ann_file = unlabelDataJson
cfg.unlabeled_dataset.data_prefix = dict(img=unlabelDataFolder)
cfg.unlabeled_dataset.metainfo = dict(classes=classes)

# Re-attach the edited dataset configs: order here is [labelled, unlabelled].
cfg.train_dataloader.dataset.datasets = [cfg.labeled_dataset, cfg.unlabeled_dataset]
cfg.train_dataloader.batch_size = cfg.batch_size
cfg.train_dataloader.num_workers = cfg.num_workers
# Honour the user-facing ratio rather than a hard-coded [1, 1]; copy the list
# so later edits to the config do not mutate the notebook variable.
cfg.train_dataloader.sampler.source_ratio = list(sampleRatio)
# If the sampler carries its own batch size, keep it in sync with the
# dataloader; otherwise the value from the base config would go stale.
if 'batch_size' in cfg.train_dataloader.sampler:
    cfg.train_dataloader.sampler.batch_size = cfg.batch_size

### validation dataset
# NOTE(review): metainfo/classes is not overridden for the val and test
# datasets here -- presumably the base config already sets it; verify.
cfg.val_dataloader.dataset.data_root = valDataFolder
cfg.val_dataloader.dataset.ann_file = valDataJson
cfg.val_dataloader.dataset.data_prefix = dict(img=valDataFolder)
cfg.val_dataloader.batch_size = 1
cfg.val_dataloader.num_workers = 1

### test dataset
cfg.test_dataloader.dataset.data_root = testDataFolder
cfg.test_dataloader.dataset.ann_file = testDataJson
cfg.test_dataloader.dataset.data_prefix = dict(img=testDataFolder)
cfg.test_dataloader.batch_size = 1
cfg.test_dataloader.num_workers = 1

# The evaluators read the same annotation files as their dataloaders.
cfg.val_evaluator.ann_file = cfg.val_dataloader.dataset.ann_file
cfg.test_evaluator.ann_file = cfg.test_dataloader.dataset.ann_file

### training loop, optimiser and schedule
cfg.train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=epoch, val_interval=valInterval)
# Iteration-based alternative:
# cfg.train_cfg = dict(type='IterBasedTrainLoop', max_iters=36000, val_interval=2000)

cfg.optim_wrapper.optimizer.lr = learningRate
cfg.work_dir = workDir
cfg.log_processor.by_epoch = False
cfg.randomness = dict(seed=1, deterministic=False, diff_rank_seed=False)
cfg.param_scheduler = parameterScheduler
# Iteration-based scheduler alternative (pair it with IterBasedTrainLoop):
# cfg.param_scheduler = [
#     dict(type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
#     dict(type='MultiStepLR', begin=0, end=18000, by_epoch=False,
#          milestones=[12000, 16000], gamma=0.1),
# ]

# An empty string is the "not set" placeholder; pass None in that case so no
# checkpoint load is attempted from an empty path.
cfg.load_from = pretrainedWeights or None

In [None]:
# Write the assembled configuration into the work directory and delete any
# existing 'None.log.json' found there.
if not cfg.work_dir:
    # Fail early with a clear message instead of an obscure filesystem error.
    raise ValueError("workDir is empty: set an output folder before writing the config")

# makedirs creates missing parent folders too and tolerates an existing
# directory, so nested output paths work and re-running the cell is safe.
os.makedirs(cfg.work_dir, exist_ok=True)

configFile = os.path.join(cfg.work_dir, "config_file.py")
# Explicit encoding so non-ASCII paths in the config are written consistently
# regardless of the platform default.
with open(configFile, "w", encoding="utf-8") as configHandle:
    configHandle.write(cfg.pretty_text)

logFile = os.path.join(cfg.work_dir, "None.log.json")
if os.path.exists(logFile):
    os.remove(logFile)

In [None]:
# Build the runner from the assembled config and start training.
runner = Runner.from_cfg(cfg)
# Bind the return value to a name so the notebook does not echo it as the
# cell's output (per the MMEngine docs, train() returns the trained model,
# whose repr is very long).
trainedModel = runner.train()