## Training progress

In [None]:
import sys
import os
import pandas as pd

# Plot the training history written by the training run: one curve for the
# training loss and one for the validation loss, averaged per epoch.
working_dir = '../output'
history_file = f'{working_dir}/history.csv'
if not os.path.exists(history_file):
    print(f'{history_file} not found')
    print('must train a model before running this')
    # Non-zero status so the missing history is reported as a failure rather
    # than a clean exit.
    sys.exit(1)

df = pd.read_csv(history_file)

# Average the per-row losses within each epoch. Keeping the epoch values as the
# index (instead of 0-based list positions) makes the x axis show the actual
# epoch numbers from the history file.
loss_df = (
    df.groupby('epoch')[['train_loss', 'epoch_val_loss']]
    .mean()
    .rename(columns={'train_loss': 'Training loss', 'epoch_val_loss': 'Validation loss'})
)
# Take the minimum of the plotted per-epoch validation curve so the title
# agrees with what is drawn.
min_val_loss = loss_df['Validation loss'].min()
plots = loss_df.plot(xlabel='Epoch', grid=True, title=f'Minimum validation loss {min_val_loss:.4f}')

## Sample predictions versus ground truth

In [None]:
import warnings
# Blanket suppression keeps the notebook output readable; note it also hides
# any warning the libraries below would emit.
warnings.filterwarnings("ignore")

import random
import cv2
from matplotlib import pyplot as plt

from detectron2.config import CfgNode as CN
from detectron2.engine import DefaultPredictor
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.utils.visualizer import Visualizer


# Show a few random validation images with model predictions next to the
# ground-truth annotations. `working_dir` and `os` come from the cell above.
input_dir = '../input'
img_dir = f'{input_dir}/train'
set_name = 'val'
json_file = f'{working_dir}/{set_name}-dicts-coco.json'
if not os.path.exists(json_file):
    raise FileNotFoundError(f'{json_file} not found')

# Re-running this cell would try to register the same dataset name again;
# only register it the first time instead of swallowing every error.
if set_name not in DatasetCatalog.list():
    register_coco_instances(set_name, {}, json_file, img_dir)

with open(f'{working_dir}/cfg.yaml') as cfg_file:
    cfg = CN.load_cfg(cfg_file)
cfg.MODEL.WEIGHTS = f'{working_dir}/model.pth'
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
predictor = DefaultPredictor(cfg)
dataset_dicts = DatasetCatalog.get(set_name)
metadata = MetadataCatalog.get(set_name)

# display this many samples
count = 3
# random.sample raises if asked for more items than the dataset holds.
count = min(count, len(dataset_dicts))

fig = plt.figure(figsize=(20, 25))
# squeeze=False keeps `ax` two-dimensional even when count == 1.
ax = fig.subplots(count, 2, squeeze=False)
ax[0, 0].set_title('Prediction')
ax[0, 1].set_title('Ground Truth')
for i, dd in enumerate(random.sample(dataset_dicts, count)):
    # cv2.imread returns BGR, which is the channel order DefaultPredictor
    # takes as input (it converts to cfg.INPUT.FORMAT internally).
    im_bgr = cv2.imread(dd['file_name'])
    outputs = predictor(im_bgr)
    # Visualizer takes an RGB image and VisImage.get_image() returns RGB,
    # which is what imshow expects, so no further conversion is needed.
    im_rgb = cv2.cvtColor(im_bgr, cv2.COLOR_BGR2RGB)
    pred = Visualizer(im_rgb, metadata=metadata).draw_instance_predictions(outputs['instances'].to('cpu'))
    target = Visualizer(im_rgb, metadata=metadata).draw_dataset_dict(dd)
    ax[i, 0].imshow(pred.get_image())
    ax[i, 1].imshow(target.get_image())