In [None]:
import json
import os

import torch
from tqdm import tqdm

import modeling
from coco_eval import coco_eval
from data_process import CocoDataLoader
from evaluate import Pipeline, preds2coco_eval_result
from train import Trainer

In [None]:
class Args:
    """Hyper-parameters and filesystem locations for the Faster R-CNN run.

    Every value is a plain class attribute computed once, when this class
    body executes.  The derived paths at the bottom are therefore frozen at
    definition time: assigning ``args.best_epoch = 3`` on an instance later
    does NOT change ``best_model_path``, ``outfile_name`` or the result
    paths.  Edit ``best_epoch`` here and re-run this cell instead.
    """

    # --- run identity ---------------------------------------------------
    model_name = 'm24'
    device = 'cuda:0'  # also passed to Pipeline in the evaluation cell

    # --- optimisation (lr / momentum / weight_decay feed torch.optim.SGD) ---
    num_epochs = 10
    lr = 0.005
    momentum = 0.9
    weight_decay = 0.01

    # Checkpoint epoch used for evaluation; presumably picked by hand after
    # inspecting training results -- TODO confirm workflow.
    best_epoch = 0

    # --- where training writes checkpoints and logs -----------------------
    model_out = '../models/faster_rcnn/' + model_name
    log_dir = '../runs/faster_rcnn/' + model_name

    # --- evaluation artefacts (all keyed on model_name + best_epoch) --------
    best_model_path = model_out + '/e' + str(best_epoch) + '.pt'
    outfile_name = 'faster_rcnn_' + model_name + 'e' + str(best_epoch) + '.json'
    eval_coco_path = '../outputs/eval/coco_results/' + outfile_name
    coco_result_path = '../outputs/coco_results/' + outfile_name

In [None]:
# Train Faster R-CNN on the 95% train split, validating on the 5% dev split.

# Seed torch's RNGs (CPU and all CUDA devices) before anything random
# happens, so that freshly initialised layers and any torch-driven shuffling
# repeat across "Restart & Run All". Whether CocoDataLoader shuffles through
# torch is an assumption -- TODO confirm.
SEED = 42
torch.manual_seed(SEED)

args = Args()

# Make sure the output locations exist up front; otherwise a missing
# directory would only surface when the trainer first writes there.
os.makedirs(args.model_out, exist_ok=True)
os.makedirs(args.log_dir, exist_ok=True)

cdl = CocoDataLoader()
train_data = cdl.train_95
dev_data = cdl.dev_05

model = modeling.get_fasterrcnn_resnet50_model(num_classes=6)

# Optimise only parameters left trainable by the model factory (any frozen
# layers are skipped).
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

trainer = Trainer(args=args,
                  model=model,
                  optimizer=optimizer,
                  train_data=train_data,
                  dev_data=dev_data)
trainer.train()

In [None]:
# Run the checkpoint at args.best_model_path over the full test split,
# dump detections in COCO result format, and score them with coco_eval.
pipeline = Pipeline(args.best_model_path, args.device)

cdl = CocoDataLoader()
test_data = cdl.test_all
idx2str = cdl.idx2str
idx2label = cdl.idx2label

preds = []
image_ids = []

# Inference only: disable autograd so no computation graph is retained
# across batches (harmless if Pipeline already does this internally).
with torch.no_grad():
    for batched_imgs, batched_targets in tqdm(test_data, desc='test inference'):
        preds.extend(pipeline(batched_imgs))
        # NOTE(review): the image id is read from the first element of each
        # per-image target; an image with an empty target would raise
        # IndexError here -- confirm the test loader never yields one.
        image_ids.extend(target[0]['image_id'] for target in batched_targets)

# preds and image_ids are consumed as parallel sequences below; a mismatch
# would silently attribute detections to the wrong images.
assert len(preds) == len(image_ids), (
    f'{len(preds)} predictions vs {len(image_ids)} image ids')

coco_result = preds2coco_eval_result(preds, image_ids, idx2label)

# The results directory may not exist on a fresh checkout.
eval_dir = os.path.dirname(args.eval_coco_path)
if eval_dir:
    os.makedirs(eval_dir, exist_ok=True)

with open(args.eval_coco_path, 'w') as fout:
    json.dump(coco_result, fout)

ce = coco_eval(args.eval_coco_path)