In [1]:
import os
import copy
import torch
from tqdm import tqdm
import pandas as pd
import detectron2
from detectron2.data import detection_utils as utils
from detectron2.utils.logger import setup_logger
setup_logger()

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.data import build_detection_test_loader

In [2]:
# Register the COCO-format test split under the name 'trash_test'.
# detectron2 refuses to register the same dataset name twice (it raises an
# AssertionError), so check the catalog first. This keeps the cell safe to re-run
# without swallowing unrelated AssertionErrors from a genuinely bad registration.
if 'trash_test' not in DatasetCatalog.list():
    register_coco_instances('trash_test', {}, '../dataset/test.json', '../dataset/')

In [3]:
# Load the base config: Faster R-CNN, ResNet-101 FPN backbone, 3x schedule.
# Model-zoo config paths are relative to detectron2's bundled configs/ directory
# and include the task sub-folder, so the 'COCO-Detection/' prefix is required;
# without it get_config_file raises "not available in Model Zoo!".
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file('COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml'))

In [4]:
# Adjust the base config for inference on the registered test split.
cfg.DATASETS.TEST = ('trash_test',)

# Number of data-loading worker processes. Note: a yacs CfgNode accepts unknown
# attribute names without complaint, so a misspelled key here would be silently
# ignored. The key must be spelled exactly NUM_WORKERS.
cfg.DATALOADER.NUM_WORKERS = 2

cfg.OUTPUT_DIR = './output'

# Trained checkpoint to load; presumably written by the training run in the same
# OUTPUT_DIR (NOTE(review): confirm this is the intended iteration).
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, 'model_0009999.pth')

# Per detectron2's config defaults this is the RoI minibatch size used during
# training; kept here to mirror the training config.
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
# 10 trash categories (General trash, Paper, ..., Clothing).
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 10
# Detections scoring below 0.3 are dropped at test time.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.3

In [5]:
# Build the single-image predictor from the config above (it loads cfg.MODEL.WEIGHTS).
predictor = DefaultPredictor(cfg)


def MyMapper(dataset_dict):
    """Mapper for the test loader: defines what each yielded record looks like.

    Returns a deep copy of ``dataset_dict`` (the input dict is left unmodified)
    with an added ``'image'`` entry holding the image at ``'file_name'`` read in
    BGR channel order. That array is what gets passed to ``predictor`` later on.
    """
    record = copy.deepcopy(dataset_dict)
    record['image'] = utils.read_image(record['file_name'], format='BGR')
    return record


# Loader over the 'trash_test' split whose records are produced by MyMapper.
test_loader = build_detection_test_loader(cfg, 'trash_test', mapper=MyMapper)

Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.

[32m[10/09 15:37:18 d2.data.datasets.coco]: [0mLoaded 4871 images in COCO format from ../dataset/test.json
[32m[10/09 15:37:18 d2.data.build]: [0mDistribution of instances among all 10 categories:
[36m|   category    | #instances   |  category   | #instances   |  category  | #instances   |
|:-------------:|:-------------|:-----------:|:-------------|:----------:|:-------------|
| General trash | 0            |    Paper    | 0            | Paper pack | 0            |
|     Metal     | 0            |    Glass    | 0            |  Plastic   | 0            |
|   Styrofoam   | 0            | Plastic bag | 0            |  Battery   | 0            |
|   Clothing    | 0            |             |              |            |              |
|     total     | 0            |             |              |            |              |[0m
[32m[10/09 15:37:18 d2.data.common]: [0mSerializing 4871 elements to 

In [6]:
# Run inference on every test image, then post-process the outputs into the
# submission format. There is one row per image:
#   PredictionString: "<class> <score> <x1> <y1> <x2> <y2> " repeated per detection
#   image_id:         the image path with the '../dataset/' prefix stripped


def _format_prediction_string(classes, scores, boxes):
    """Serialize one image's detections into a submission PredictionString.

    Args:
        classes: predicted class ids, one per detection.
        scores: confidence scores, aligned with ``classes``.
        boxes: box rows aligned with ``classes``; the first four values of each
            row are written in order.

    Returns:
        "<class> <score> <b0> <b1> <b2> <b3> " for each detection, concatenated.
        Every group, including the last, ends with a space. Returns an empty
        string when there are no detections.
    """
    # str() is applied explicitly on purpose: formatting a numpy float32 through an
    # f-string converts it to a Python float first and prints the long float64 repr.
    return ''.join(
        ' '.join(str(value) for value in (cls, score, *box[:4])) + ' '
        for cls, score, box in zip(classes, scores, boxes)
    )


prediction_strings = []
file_names = []

for batch in tqdm(test_loader):
    # The test loader yields lists holding a single mapped record.
    record = batch[0]

    instances = predictor(record['image'])['instances']

    classes = instances.pred_classes.cpu().tolist()
    scores = instances.scores.cpu().tolist()
    # Copy all boxes (one row per detection) to the CPU in a single transfer
    # instead of one device round-trip per detection.
    boxes = instances.pred_boxes.tensor.detach().cpu().numpy()

    prediction_strings.append(_format_prediction_string(classes, scores, boxes))
    file_names.append(record['file_name'].replace('../dataset/', ''))

submission = pd.DataFrame({
    'PredictionString': prediction_strings,
    'image_id': file_names,
})
submission.to_csv(os.path.join(cfg.OUTPUT_DIR, 'submission_det2.csv'), index=False)
submission.head()

	nonzero()
Consider using one of the following signatures instead:
	nonzero(*, bool as_tuple) (Triggered internally at  /opt/conda/conda-bld/pytorch_1607370156314/work/torch/csrc/utils/python_arg_parser.cpp:882.)
  filter_inds = filter_mask.nonzero()
100%|██████████| 4871/4871 [05:44<00:00, 14.13it/s]


Unnamed: 0,PredictionString,image_id
0,7 0.9992503523826599 221.61992 60.28158 455.74...,test/0000.jpg
1,5 0.9692215919494629 123.993355 6.259248 501.1...,test/0001.jpg
2,1 0.9994056224822998 424.00546 267.73508 646.5...,test/0002.jpg
3,9 0.9963656663894653 131.89699 288.0879 892.64...,test/0003.jpg
4,0 0.9969171285629272 213.44302 333.98788 867.0...,test/0004.jpg
