In [1]:
import json
import argparse
from easydict import EasyDict
from importlib import import_module

import gc
import warnings
from tqdm import tqdm
import os
import torch
import numpy as np
import random
import time
from torch.utils.data import DataLoader
import wandb

from mapcalc import calculate_map
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

In [2]:
def get_args(config):
    """Load ``./config/<config>.json`` into an attribute-accessible EasyDict.

    Args:
        config: Base name (without the ``.json`` extension) of a file that
            lives in the ``./config`` directory.

    Returns:
        An ``EasyDict`` updated with the top-level entries of the JSON file.
    """
    args = EasyDict()
    config_path = f'./config/{config}.json'

    with open(config_path, 'r') as config_file:
        loaded = json.load(config_file)
    args.update(loaded)

    return args
        
def killmemory():
    """Free memory before a run: garbage-collect, then empty the CUDA cache.

    The collector runs first so that objects it frees are already gone when
    ``torch.cuda.empty_cache()`` is called.
    """
    gc.collect()
    torch.cuda.empty_cache()

def seed_everything(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    CUDA), and switches cuDNN to deterministic mode with benchmarking off.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Python / NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch: CPU generator, current CUDA device, then every CUDA device
    # (the last call matters when several GPUs are used).
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN flags: no autotuned kernels, deterministic algorithms only.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def create_dir(path):
    """Create directory ``path`` if it does not exist yet.

    Missing parent directories are created as well, and calling this on an
    already existing directory is a no-op. ``exist_ok=True`` replaces the
    previous "check, then mkdir" sequence, which could fail if the directory
    appeared between the check and the creation.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
        
def collate_fn(batch):
    """Transpose a batch of samples into per-field tuples.

    A list of ``(a_i, b_i, ...)`` samples becomes ``((a_0, a_1, ...),
    (b_0, b_1, ...), ...)``; nothing is stacked or padded.

    Args:
        batch: Iterable of equally structured sample tuples.

    Returns:
        A tuple with one tuple per sample field.
    """
    per_field = zip(*batch)
    return tuple(per_field)

class Averager:
    """Running mean of the values passed to :meth:`send`."""

    def __init__(self):
        # Start from the same zeroed state that reset() produces.
        self.reset()

    def reset(self):
        """Forget everything accumulated so far."""
        self.current_total = 0.0
        self.iterations = 0.0

    def send(self, value):
        """Add one observation to the running total."""
        self.current_total += value
        self.iterations += 1

    @property
    def value(self):
        """Mean of the observations so far, or ``0`` if there are none."""
        if self.iterations != 0:
            return 1.0 * self.current_total / self.iterations
        return 0

def set_valid_gt(valid_dataset):
    """Collect the ground-truth boxes and labels of every dataset sample.

    Args:
        valid_dataset: Indexable dataset with ``len()``; element ``[1]`` of
            each sample is a mapping whose ``'boxes'`` and ``'labels'``
            values provide ``.tolist()``.

    Returns:
        A list with one ``{'boxes': [...], 'labels': [...]}`` dict of plain
        Python lists per sample, in dataset order.
    """
    def _as_lists(target):
        # Convert the array-like annotations into plain nested lists.
        return {'boxes': target['boxes'].tolist(), 'labels': target['labels'].tolist()}

    return [_as_lists(valid_dataset[sample_idx][1]) for sample_idx in range(len(valid_dataset))]

def validation(valid_dataloader, model, device, resize, val_annotation=None):
    """Run the model over the validation loader and score it with COCO bbox metrics.

    Args:
        valid_dataloader: Loader yielding ``(images, targets, image_ids)``
            batches. Detections are assigned COCO image ids by their position
            in iteration order, so the loader must not shuffle.
            NOTE(review): this assumes the annotation file numbers its images
            0..N-1 in dataset order -- confirm against the dataset class.
        model: Detection model. It is called as ``model(images)`` in eval mode
            and must return, per image, a dict with ``'boxes'``, ``'labels'``
            and ``'scores'`` tensors. It is put back into train mode before
            this function returns, even if inference raises.
        device: Device the images are moved to before the forward pass.
        resize: Side length the images were resized to. Predicted boxes are
            divided by ``resize / 512`` before evaluation (presumably 512 is
            the original image size -- TODO confirm).
        val_annotation: Path of the COCO annotation JSON used as ground truth.
            Defaults to the notebook-level ``args.val_annotation``.

    Returns:
        Tuple ``(stats, outputs)``: ``stats`` is ``COCOeval.stats`` (all zeros
        when the model produced no detections at all) and ``outputs`` is a
        list with one ``{'boxes', 'labels', 'scores'}`` dict of plain lists
        per image, in the coordinates of the resized input.

    References:
        https://www.programcreek.com/python/?code=potterhsu%2Feasy-faster-rcnn.pytorch%2Feasy-faster-rcnn.pytorch-master%2Fdataset%2Fcoco2017.py#
        https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py
    """
    if val_annotation is None:
        val_annotation = args.val_annotation

    model.eval()
    outputs = []
    try:
        # No gradients are needed for evaluation; skipping autograd
        # bookkeeping saves memory and time.
        with torch.no_grad():
            for images, targets, image_ids in valid_dataloader:
                images = [image.float().to(device) for image in images]
                for out in model(images):
                    outputs.append({
                        'boxes': out['boxes'].tolist(),
                        'labels': out['labels'].tolist(),
                        'scores': out['scores'].tolist(),
                    })
    finally:
        # Always hand the model back in training mode.
        model.train()

    # Factor that maps resized-image coordinates back to annotation coordinates.
    scale = resize / 512

    # Convert corner boxes (x1, y1, x2, y2) into COCO's (x, y, width, height).
    results = []
    for i, output in enumerate(outputs):
        for bbox, score, label in zip(output['boxes'], output['scores'], output['labels']):
            results.append(
                {
                    'image_id': int(i),
                    'category_id': label,
                    'bbox': [
                        bbox[0] / scale,
                        bbox[1] / scale,
                        (bbox[2] - bbox[0]) / scale,
                        (bbox[3] - bbox[1]) / scale
                    ],
                    'score': score
                }
            )

    if not results:
        # COCO.loadRes cannot handle an empty detection list, so report zero
        # for each of the 12 bbox summary metrics instead of crashing.
        return np.zeros(12), outputs

    cocoGt = COCO(val_annotation)
    cocoDt = cocoGt.loadRes(results)
    cocoEval = COCOeval(cocoGt, cocoDt, "bbox")
    cocoEval.evaluate()
    cocoEval.accumulate()
    cocoEval.summarize()

    return cocoEval.stats, outputs


def get_logdata(valid_outputs, valid_dataset, idx):
    """Build a ``wandb.Image`` showing predicted and ground-truth boxes.

    Args:
        valid_outputs: Per-image prediction dicts holding plain lists under
            ``'boxes'``, ``'labels'`` and ``'scores'`` (the second value
            returned by ``validation``).
        valid_dataset: Dataset whose samples start with ``(image, target)``;
            ``target['boxes']`` and ``target['labels']`` are indexable
            tensors.
        idx: Index of the sample to visualise.

    Returns:
        A ``wandb.Image`` of the sample with two box layers,
        ``"predictions"`` and ``"ground_truth"``. Class names come from the
        notebook-level ``label2class`` mapping.
    """
    def _box_entry(box, label, caption, score):
        # One entry of wandb's box_data list; box is (x1, y1, x2, y2) in pixels.
        return {'position': {'minX': box[0], 'maxX': box[2], 'minY': box[1], 'maxY': box[3]},
                'class_id': label,
                'box_caption': caption,
                'domain': 'pixel',
                'scores': {'pred_score': score}}

    output = valid_outputs[idx]
    valid_box_data = [
        _box_entry(box, label, f'{label2class[label]}({score:.4f})', score)
        for box, label, score in zip(output['boxes'], output['labels'], output['scores'])
    ]

    # Fetch the sample once so the image and its boxes are guaranteed to come
    # from the same __getitem__ call (and the image is only loaded once).
    sample = valid_dataset[idx]
    gt_img, gt_data = sample[0], sample[1]

    gt_box_data = []
    for box, label in zip(gt_data['boxes'], gt_data['labels']):
        label = label.item()
        # Ground-truth boxes are logged with a fixed score of 1.
        gt_box_data.append(_box_entry(box.tolist(), label, f'{label2class[label]}', 1))

    image_log = {"predictions": {'box_data': valid_box_data, 'class_labels': label2class},
                 "ground_truth": {'box_data': gt_box_data, 'class_labels': label2class}}

    return wandb.Image(gt_img, boxes=image_log)

In [3]:
# Base name of the JSON file in ./config that get_args() loads; it is also
# embedded in the wandb run name and in the saved checkpoint file names.
config = 'config10'

In [4]:
# --- Run configuration and environment -------------------------------------
args = get_args(config)

seed_everything(args.seed)
warnings.filterwarnings(action='ignore')
killmemory()
create_dir('./saved_model')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# --- Experiment tracking ---------------------------------------------------
wandb.init(project='BC_stage3_ObjectDetection', entity='doooom')
wandb.run.name = f'({config}){args.config_name}'
wandb.config.config = config
wandb.config.update(args)

# --- Class names (index == label id) ---------------------------------------
classes = (
    "UNKNOWN", "General trash", "Paper", "Paper pack", "Metal", "Glass",
    "Plastic", "Styrofoam", "Plastic bag", "Battery", "Clothing",
)
label2class = dict(enumerate(classes))

# --- Augmentations: the train variant is chosen by the config --------------
augmentation_lib = import_module("augmentation")
train_augmentation_module = getattr(augmentation_lib, args.augmentation)
valid_augmentation_module = augmentation_lib.ValidAugmentation
train_augmentation = train_augmentation_module(augp=args.augp, resize=args.resize)
valid_augmentation = valid_augmentation_module(resize=args.resize)

# --- Datasets --------------------------------------------------------------
dataset_module = getattr(import_module("dataset"), args.dataset)
train_dataset = dataset_module(args.annotation, args.data_dir, train_augmentation)
valid_dataset = dataset_module(args.val_annotation, args.data_dir, valid_augmentation)
valid_gt = set_valid_gt(valid_dataset)

# --- Data loaders: validation goes one image at a time, in fixed order -----
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=4, collate_fn=collate_fn)
valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False,
                              num_workers=4, collate_fn=collate_fn)

# --- Model (11 classes, one per entry of `classes`) ------------------------
model_module = getattr(import_module("model"), args.model)
model = model_module(num_classes=11, args=args)
model.to(device)

# --- Optimizer: class looked up by name in torch.optim ---------------------
optimizer_module = getattr(torch.optim, args.optimizer)
optimizer = optimizer_module(params=model.parameters(), lr=args.learning_rate)

[34m[1mwandb[0m: Currently logged in as: [33mdoooom[0m (use `wandb login --relogin` to force relogin)


In [5]:
# Checkpoints are written only for epochs after this one (1-based).
SAVE_AFTER_EPOCH = 14

best_mAP50 = 0.3
loss_hist = Averager()

for epoch in range(args.epoch):
    print(f'* Epoch {epoch+1}')
    start_time = time.time()
    loss_hist.reset()
    # Holds this epoch's COCO stats; stays None when validation is skipped so
    # that a stale (or undefined) value is never read further down.
    mAP = None

    model.train()
    for step, (images, targets, image_ids) in enumerate(train_dataloader):

        images = list(image.float().to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # In train mode the model returns a dict of loss terms.
        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()

        loss_hist.send(loss_value)

        # backward
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # "\r" keeps the per-step progress on a single, overwritten line.
        print(f'  Step [{step+1}/{len(train_dataloader)}], Loss : {loss_value:.4f}', end="\r")

    train_time = time.time()-start_time
    print(f'  training time : {train_time:.4f}, train loss : {loss_hist.value} \n')

    if (epoch+1) % args.val_every == 0:
        print('* Start validation...', end="\r")
        start_time = time.time()
        mAP, valid_outputs = validation(valid_dataloader, model, device, args.resize[0])
        print(f"  validation time : {time.time()-start_time:.4f}")
        wandb.log({"time": train_time, "train_loss": loss_hist.value, "valid_mAP": mAP[0], "valid_mAP50": mAP[1], "valid_mAP75": mAP[2],
                   "valid_mAP(S)": mAP[3], "valid_mAP(M)": mAP[4], "valid_mAP(L)": mAP[5]})

    # Every epoch past the threshold is checkpointed, regardless of its score.
    # (get_logdata can be used here to also log prediction images to wandb.)
    if epoch+1 > SAVE_AFTER_EPOCH:
        print("  Model saved")
        torch.save(model.state_dict(), f'./saved_model/({config}){args.config_name}_{epoch+1}.pth')
        if mAP is not None:
            # Track the best mAP50 seen among the saved, validated epochs.
            best_mAP50 = max(best_mAP50, mAP[1])

    print()

wandb.finish()

* Epoch 1
  training time : 591.5261, train loss : 0.9435397256848864 

 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.070
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.152
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.052
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.023
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.062
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.130
  validation time : 62.4690

* Epoch 2
  training time : 591.8003, train loss : 0.7690902341403122 

 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.139
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.240
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.142
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.066
 Average Precision  (AP) @[

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
time,591.007
train_loss,0.3568
valid_mAP,0.57507
valid_mAP50,0.79171
valid_mAP75,0.66365
valid_mAP(S),0.43662
valid_mAP(M),0.56497
valid_mAP(L),0.71552
_runtime,12361.0
_timestamp,1621336027.0


0,1
time,▆▇▅▄▆▆▄▃▁▅█▇▄▅▆▃▅▃▅
train_loss,█▆▅▅▄▄▄▃▃▃▂▂▂▂▂▁▁▁▁
valid_mAP,▁▂▃▃▄▄▄▅▆▆▆▇▇▇▇▇███
valid_mAP50,▁▂▃▃▄▅▅▆▆▆▇▇▇▇▇████
valid_mAP75,▁▂▃▃▄▄▄▅▅▆▆▆▇▇▇▇███
valid_mAP(S),▁▂▂▂▃▃▃▄▄▅▅▆▆▇▇▇███
valid_mAP(M),▁▂▃▃▄▅▄▅▆▆▆▇▇▇▇████
valid_mAP(L),▁▂▃▃▅▅▅▆▆▆▇▇▇▇▇████
_runtime,▁▁▂▂▃▃▃▄▄▄▅▅▆▆▆▇▇██
_timestamp,▁▁▂▂▃▃▃▄▄▄▅▅▆▆▆▇▇██


In [6]:
# TODO: dig into Faster R-CNN and look into soft-NMS
# TODO: try different NMS IoU thresholds and score thresholds