In [1]:
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
TridentNet Training Script.

This script is a simplified version of the training script in detectron2/tools.
"""

import os

import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
from detectron2.evaluation import COCOEvaluator

from tridentnet import add_tridentnet_config


%matplotlib inline
from pycocotools.coco import COCO
import numpy as np
import skimage.io as io
import matplotlib.pyplot as plt
import pylab
# Root directory of the local COCO dataset copy, relative to the working directory.
dataDir = './datasets/coco'
# Name of the COCO split; it selects which instances_<split>.json file is opened below.
dataType='val2017'
# Path of the instance-annotation JSON, assembled from the two values above.
annFile='{}/annotations/instances_{}.json'.format(dataDir,dataType)
# Load the annotation file through pycocotools; `coco` is the handle that
# get_coco_image_by_random() queries for category and image ids.
coco=COCO(annFile)

loading annotations into memory...
Done (t=0.50s)
creating index...
index created!


In [2]:
from skimage import img_as_ubyte
def get_coco_image_by_random():
    """Download one randomly chosen COCO image matching person/dog/skateboard.

    The module-level ``coco`` index is asked for the ids of the three
    categories and for the images matching them. One image id is then drawn
    at random from *all* matches and the picture is fetched from the
    ``coco_url`` stored in its annotation record (this needs network access).

    Returns:
        The image read by ``skimage.io.imread`` and converted with
        ``img_as_ubyte``.

    Raises:
        ValueError: if the annotation file has no image matching the categories.
    """
    catIds = coco.getCatIds(catNms=['person', 'dog', 'skateboard'])
    print('category IDs:{}'.format(catIds))
    imgIds = coco.getImgIds(catIds=catIds)
    print('image Id length:{}'.format(len(imgIds)))
    if not imgIds:
        raise ValueError('no COCO image matches category IDs {}'.format(catIds))

    # Draw from every matching image. Narrowing the candidates to one fixed id
    # first would make the random draw a no-op.
    chosenId = imgIds[np.random.randint(0, len(imgIds))]
    print('image Ids:{}'.format([chosenId]))
    img = coco.loadImgs(chosenId)[0]
    print('image info:{}'.format(img))
    im = io.imread(img['coco_url'])
    return img_as_ubyte(im)
    
# get_coco_image_by_random()

In [3]:
class Trainer(DefaultTrainer):
    """DefaultTrainer subclass that evaluates with a COCOEvaluator."""

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Create the COCOEvaluator for ``dataset_name``.

        When ``output_folder`` is None, ``<cfg.OUTPUT_DIR>/inference`` is used
        as the folder handed to the evaluator.
        """
        results_dir = (
            os.path.join(cfg.OUTPUT_DIR, "inference") if output_folder is None else output_folder
        )
        return COCOEvaluator(dataset_name, cfg, True, results_dir)


def setup(args):
    """
    Build the frozen config for this run and apply detectron2's default setup.

    The TridentNet keys are added on top of the detectron2 defaults, then the
    file named by ``args.config_file`` and the ``args.opts`` overrides are
    merged, in that order, before the config is frozen and returned.
    """
    config = get_cfg()
    add_tridentnet_config(config)

    # Later merges win: command-line overrides are applied after the config file.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()

    default_setup(config, args)
    return config

def get_coco_image():
    """Fetch one COCO sample image.

    Thin wrapper around ``get_coco_image_by_random``. It relies only on names
    this notebook defines, so it can be called on a fresh kernel.

    Returns:
        Whatever ``get_coco_image_by_random`` returns for the sampled image.
    """
    return get_coco_image_by_random()


def _image_to_model_input(img, input_format):
    """Turn one H x W (x C) image array into a detectron2 model input dict.

    NOTE(review): assumes ``img`` is in RGB channel order, which is what
    ``skimage.io.imread`` produces for colour images; confirm if the image
    source changes.
    """
    if img.ndim == 2:
        # Grayscale picture: replicate the single channel so the model gets three.
        img = np.stack([img] * 3, axis=-1)
    if input_format == "BGR":
        img = img[:, :, ::-1]
    height, width = img.shape[:2]
    # The reversed view has negative strides, which torch cannot wrap, so make
    # a contiguous float32 copy in channel-first layout.
    chw = np.ascontiguousarray(img.transpose(2, 0, 1), dtype=np.float32)
    return {"image": torch.as_tensor(chw), "height": height, "width": width}


def main(args):
    """Train the model, or run inference on one sampled COCO image.

    Args:
        args: parsed command-line namespace; ``eval_only`` and ``resume`` are
            read here, and the namespace is also passed to ``setup``.

    Returns:
        With ``args.eval_only`` set: the model's output for the sampled image.
        Otherwise: the return value of ``trainer.train()``.
    """
    cfg = setup(args)

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        # A freshly built model is in training mode, where the forward pass
        # expects ground-truth instances. Switch to inference behaviour.
        model.eval()

        img = get_coco_image_by_random()
        print(img.shape)

        # The model takes a list of per-image dicts, not a raw numpy array.
        # No test-time resizing is applied here.
        inputs = _image_to_model_input(img, cfg.INPUT.FORMAT)
        with torch.no_grad():
            return model([inputs])[0]

    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()

In [4]:
# Notebook stand-in for the script's ``if __name__ == "__main__":`` entry point:
# the command line is supplied as an explicit list rather than read from sys.argv.
args = default_argument_parser().parse_args(
    args=["--config-file", "./configs/tridentnet_fast_R_50_C4_1x.yaml", "--num-gpus", "1"]
)
print("Command Line Args:", args)

# Hand control to detectron2's launcher, which calls main(args).
launch(
    main, args.num_gpus,
    num_machines=args.num_machines, machine_rank=args.machine_rank,
    dist_url=args.dist_url, args=(args,),
)

Command Line Args: Namespace(config_file='./configs/tridentnet_fast_R_50_C4_1x.yaml', dist_url='tcp://127.0.0.1:50152', eval_only=False, machine_rank=0, num_gpus=1, num_machines=1, opts=[], resume=False)
[32m[04/15 14:59:37 detectron2]: [0mRank of current process: 0. World size: 1
[32m[04/15 14:59:37 detectron2]: [0mEnvironment info:
------------------------  -------------------------------------------------------------------------------------
sys.platform              linux
Python                    3.7.6 (default, Jan  8 2020, 19:59:22) [GCC 7.3.0]
numpy                     1.18.1
detectron2                0.1.1 @/home/takuma/Documents/1_SemanticSegmentation/detectron2/detectron2
detectron2 compiler       GCC 7.5
detectron2 CUDA compiler  10.1
detectron2 arch flags     sm_75
DETECTRON2_ENV_MODULE     <not set>
PyTorch                   1.4.0 @/home/takuma/anaconda3/envs/detectron2/lib/python3.7/site-packages/torch
PyTorch debug build       False
CUDA available            True
GPU

KeyboardInterrupt: 