In [1]:
from main.paths import ROOT_PATH  # isort:skip

In [2]:
import argparse
import os.path as osp
import sys

In [3]:
import cv2
from loguru import logger

import torch

import random
import os
import numpy as np

In [4]:
from videoanalyst.config.config import cfg as root_cfg
from videoanalyst.config.config import specify_task
from videoanalyst.data import builder as dataloader_builder
from videoanalyst.engine import builder as engine_builder
from videoanalyst.model import builder as model_builder
from videoanalyst.optim import builder as optim_builder
from videoanalyst.utils import Timer, complete_path_wt_root_in_cfg, ensure_dir
from videoanalyst.pipeline import builder as pipeline_builder
from videoanalyst.engine.builder import build as tester_builder
from videoanalyst.engine.monitor.monitor_impl.tensorboard_logger import TensorboardLogger

In [5]:
# Limit OpenCV to a single thread (presumably so it does not compete with the
# PyTorch / data-loader parallelism used during training).
CV2_NUM_THREADS = 1
cv2.setNumThreads(CV2_NUM_THREADS)


In [6]:
# PyTorch reproducibility: force deterministic cuDNN kernels and turn off the
# cuDNN benchmark autotuner, which may pick different algorithms between runs.
# https://pytorch.org/docs/stable/notes/randomness.html#cudnn
# (cuDNN itself stays enabled; torch.backends.cudnn.enabled = False would bypass it.)
cudnn_flags = torch.backends.cudnn
cudnn_flags.deterministic = True
cudnn_flags.benchmark = False
del cudnn_flags

In [7]:
def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs with the same value.

    Also exports ``PYTHONHASHSEED``. That variable is only read when a Python
    interpreter starts, so it affects processes launched afterwards, not the
    hash randomization of the already-running kernel.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

In [8]:
def make_parser():
    """Build the command-line parser for STMTrack training.

    Options:
        -cfg/--config: path to the experiment YAML configuration.
        -r/--resume: completed epoch number, "latest", or a snapshot path
            (the empty-string default presumably means "do not resume").
        -v/--validation: first epoch (int) from which the model is evaluated
            on the validation set.

    Returns:
        argparse.ArgumentParser: the configured, not yet parsed, parser.
    """
    parser = argparse.ArgumentParser(description='Train')
    parser.add_argument('-cfg',
                        '--config',
                        default='experiments/stmtrack/train/got10k/stmtrack-googlenet-trn-edited.yaml',
                        # alternative: experiments/stmtrack/train/fulldata/stmtrack-googlenet-trn-fulldata.yaml
                        type=str,
                        help='path to experiment configuration')
    parser.add_argument(
        '-r',
        '--resume',
        default="",
        help=r"completed epoch's number, latest or one model path")
    parser.add_argument(
        '-v',
        '--validation',
        # Reject non-integer epochs at parse time rather than at a later int(...).
        type=int,
        # argparse runs string defaults through `type`, so this parses to 15.
        default="15",
        help=r"Epoch's number to start to evaluate the model on the validation set")

    return parser

# main

In [9]:
# Global seed for this run; set_seed applies it to Python, NumPy and PyTorch.
SEED = 1000000007
set_seed(SEED)

In [10]:
# Build the CLI parser. Its arguments are not parsed in this notebook (see the
# commented-out parse_args below), presumably because sys.argv inside a Jupyter
# kernel holds the kernel's own arguments; the parsed_args_* values are set by hand.
parser = make_parser()

In [11]:
# parsed_args = parser.parse_args()

In [12]:
# Take the config path from the CLI parser's default, so the notebook and
# make_parser cannot silently drift apart.
# Alternative: experiments/stmtrack/train/fulldata/stmtrack-googlenet-trn-fulldata.yaml
parsed_args_config = parser.get_default("config")

In [13]:
# Absolute, symlink-resolved path of the experiment YAML. A relative config path
# is resolved against the current working directory, so this cell assumes the
# notebook runs from the repository root.
exp_cfg_path = osp.realpath(parsed_args_config)

In [14]:
# Overlay the experiment YAML onto the default config imported from
# videoanalyst.config.config (root_cfg); this mutates root_cfg in place.
root_cfg.merge_from_file(exp_cfg_path)

In [15]:
# Presumably completes relative paths inside the config using the repository
# root (ROOT_PATH) — inferred from the helper's name, confirm in videoanalyst.utils.
root_cfg = complete_path_wt_root_in_cfg(root_cfg, ROOT_PATH)
# NOTE: root_cfg is deliberately NOT narrowed to root_cfg.train here: the
# validation cell later still reads root_cfg.test.

In [16]:
# Split the training section into the task name (passed to every builder below;
# presumably "track" here — confirm) and that task's config subtree.
task, task_cfg = specify_task(root_cfg.train)

In [17]:
# Freeze task_cfg so the rest of the run cannot modify it
# (assumes yacs-style CfgNode semantics — confirm).
task_cfg.freeze()

In [18]:
# Training log and config backup both go to <exp_save>/<exp_name>/logs.
log_dir = osp.join(task_cfg.exp_save, task_cfg.exp_name, "logs")

In [19]:
# Create the log directory (presumably a no-op when it already exists).
ensure_dir(log_dir)

In [20]:
# Replace loguru's handlers: INFO and above go to stderr and, serialized
# (serialize=True), to <log_dir>/train_log.txt through an enqueued file sink.
stderr_handler = dict(sink=sys.stderr, level="INFO")
file_handler = dict(
    sink=osp.join(log_dir, "train_log.txt"),
    level="INFO",
    enqueue=True,
    serialize=True,
    diagnose=True,
    backtrace=True,
)
logger.configure(
    handlers=[stderr_handler, file_handler],
    extra={"common_to_all": "default"},
)

[1, 2]

In [21]:
# Choose the config-backup path inside the log directory and report where the
# configuration came from; the backup file itself is written in the next cell.
cfg_bak_name = "%s_bak.yaml" % task_cfg.exp_name
cfg_bak_file = osp.join(log_dir, cfg_bak_name)
logger.info("Load experiment configuration at: %s" % exp_cfg_path)
logger.info(
    "Merged with root_cfg imported from videoanalyst.config.config.cfg")

2024-08-12 15:32:58.354 | INFO     | __main__:<module>:2 - Load experiment configuration at: /home/meysam/test-apps/STMTrack/experiments/stmtrack/train/got10k/stmtrack-googlenet-trn-edited.yaml
2024-08-12 15:32:58.356 | INFO     | __main__:<module>:3 - Merged with root_cfg imported from videoanalyst.config.config.cfg


In [22]:
# Back up the merged task configuration next to the logs.
# Serialize first, so a failing dump() cannot truncate an existing backup to an
# empty file, and write with an explicit encoding rather than the platform default.
cfg_yaml = task_cfg.dump()
with open(cfg_bak_file, "w", encoding="utf-8") as cfg_bak_fh:
    cfg_bak_fh.write(cfg_yaml)

In [23]:
# Tell the user where the config backup was written.
backup_message = "Task configuration backed up at %s" % cfg_bak_file
logger.info(backup_message)

2024-08-12 15:33:01.563 | INFO     | __main__:<module>:1 - Task configuration backed up at /home/meysam/test-apps/STMTrack/snapshots/stmtrack-googlenet-got-train/logs/stmtrack-googlenet-got-train_bak.yaml


In [24]:
# Select training devices: one CUDA device per process ("cuda:0" ... "cuda:N-1")
# when task_cfg.device is "cuda", otherwise the CPU only.
if task_cfg.device != "cuda":
    devs = ["cpu"]
else:
    world_size = task_cfg.num_processes
    assert torch.cuda.is_available(), "please check your devices"
    n_cuda = torch.cuda.device_count()
    assert n_cuda >= world_size, "cuda device {} is less than {}".format(
        n_cuda, world_size)
    devs = ["cuda:{}".format(idx) for idx in range(world_size)]

In [25]:
# build model from task_cfg.model and assign it to the first selected device
# (set_device is a project API — presumably it moves the parameters there).
model = model_builder.build(task, task_cfg.model)
model.set_device(devs[0])

proj_fg_bg_label_map:['weight']

AuxLogits.conv0.conv:['weight']
AuxLogits.conv0.bn:['weight', 'bias', 'running_mean', 'running_var']
AuxLogits.conv1.conv:['weight']
AuxLogits.conv1.bn:['weight', 'bias', 'running_mean', 'running_var']
AuxLogits.fc:['weight', 'bias']
Mixed_7a.branch3x3_1.conv:['weight']
Mixed_7a.branch3x3_1.bn:['weight', 'bias', 'running_mean', 'running_var']
Mixed_7a.branch3x3_2.conv:['weight']
Mixed_7a.branch3x3_2.bn:['weight', 'bias', 'running_mean', 'running_var']
Mixed_7a.branch7x7x3_1.conv:['weight']
Mixed_7a.branch7x7x3_1.bn:['weight', 'bias', 'running_mean', 'running_var']
Mixed_7a.branch7x7x3_2.conv:['weight']
Mixed_7a.branch7x7x3_2.bn:['weight', 'bias', 'running_mean', 'running_var']
Mixed_7a.branch7x7x3_3.conv:['weight']
Mixed_7a.branch7x7x3_3.bn:['weight', 'bias', 'running_mean', 'running_var']
Mixed_7a.branch7x7x3_4.conv:['weight']
Mixed_7a.branch7x7x3_4.bn:['weight', 'bias', 'running_mean', 'running_var']
Mixed_7b.branch1x1.conv:['weight']
Mixed_7b.branch1

In [26]:
 # load data
# Build the training DataLoader from task_cfg.data inside a named, verbose Timer
# (presumably so the build duration is reported).
with Timer(name="Dataloader building", verbose=True):
    dataloader = dataloader_builder.build(task, task_cfg.data)

2024-08-12 15:33:11.917 | INFO     | videoanalyst.data.builder:build:34 - Build dummy AdaptorDataset
2024-08-12 15:33:11.919 | INFO     | videoanalyst.data.builder:build:42 - Read dummy training sample
2024-08-12 15:33:11.968 | INFO     | videoanalyst.evaluation.got_benchmark.datasets.got10k:_get_cache_path:182 - GOT10k: passed cache file None invalid, change to default cache path
2024-08-12 15:33:11.969 | INFO     | videoanalyst.evaluation.got_benchmark.datasets.got10k:_ensure_cache:147 - GOT10k: cache file exists: /media/meysam/hdd2/GOT-10K/GOT-10k/train.pkl 
2024-08-12 15:33:13.042 | INFO     | videoanalyst.evaluation.got_benchmark.datasets.got10k:_load_cache_for_current_subset:200 - GOT10k: loaded cache file /media/meysam/hdd2/GOT-10K/GOT-10k/train.pkl
2024-08-12 15:33:13.045 | INFO     | videoanalyst.evaluation.got_benchmark.datasets.got10k:_ensure_cache:151 - GOT10k: record check has been processed and validity is confirmed for cache file: /media/meysam/hdd2/GOT-10K/GOT-10k/train

In [28]:
# Display the DataLoader repr as a quick sanity check of the build.
dataloader

<torch.utils.data.dataloader.DataLoader at 0x7f93d8d91040>

In [29]:
# build optimizer from task_cfg.optim; the model is passed in, presumably so the
# optimizer can collect its parameters.
optimizer = optim_builder.build(task, task_cfg.optim, model)

In [30]:
# Assemble the trainer from task_cfg.trainer, wiring in the optimizer and the
# training DataLoader built above.
trainer = engine_builder.build(
    task,
    task_cfg.trainer,
    "trainer",
    optimizer,
    dataloader,
)

2024-08-12 15:34:27.266 | INFO     | videoanalyst.data.sampler.sampler_base:__init__:51 - Sampler's underlying datasets: GOT10kDataset, length 9335


In [31]:
# Hand the trainer the full device list: one "cuda:i" per process, or ["cpu"].
trainer.set_device(devs)

In [32]:
# trainer.resume receives what the `--resume` CLI option would supply: a string
# holding a completed epoch number, "latest", or a snapshot path. Use the option's
# default "" (presumably "do not resume") instead of the int 0, which is none of those.
parsed_args_resume = ""
trainer.resume(parsed_args_resume)

In [33]:
# Validator initialization: wrap the model being trained in a tracking pipeline
# and build the tester(s) configured under root_cfg.test.track.
root_cfg.test.track.freeze()
pipeline = pipeline_builder.build("track", root_cfg.test.track.pipeline, model)
testers = tester_builder("track", root_cfg.test.track.tester, "tester", pipeline)
# Mirrors the `--validation` CLI default. `parsed_args` is never created in this
# notebook (parse_args is commented out), so read the local value instead.
parsed_args_validation = 15
epoch_validation = int(parsed_args_validation)
logger.info("Start to evaluate the model on the validation set after the epoch #{}".format(epoch_validation))


In [34]:
# Display the tester collection; the training loop below iterates over it.
testers

In [None]:
def validate_on_testers(testers, trainer):
    """Run every tester once and report its main score.

    Each score is logged as ``<TesterClass>/<first subset>/AO: <score>``. The same
    ``{tag: score}`` dict is passed to every TensorboardLogger among
    ``trainer._monitors``. The CUDA caching allocator is emptied after each tester.
    """
    for tester in testers:
        result = tester.test()
        tag = '{}/{}/{}'.format(tester.__class__.__name__,
                                tester._hyper_params['subsets'][0],
                                'AO')
        logger.info('{}: {}'.format(tag, result['main_performance']))
        tb_log = {tag: result['main_performance']}
        tb_monitors = [mon for mon in trainer._monitors
                       if isinstance(mon, TensorboardLogger)]
        for monitor in tb_monitors:
            monitor.update(tb_log)
        torch.cuda.empty_cache()


logger.info("Training begins.")
while not trainer.is_completed():
    model.train()
    trainer.train()
    trainer.save_snapshot()
    if trainer._state['epoch'] < epoch_validation:
        continue  # too early in training: skip validation this epoch
    logger.info('Validation begins.')
    model.eval()
    validate_on_testers(testers, trainer)
    logger.info('Validation ends.')


In [34]:
# Final snapshot after training; model_param_only=True presumably stores only the
# model weights, without optimizer/trainer state — confirm in the trainer class.
trainer.save_snapshot(model_param_only=True)
logger.info("Training completed.")