In [None]:
#!/usr/bin/env python3

import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

import os
import random
import cv2
import json
import torch
import copy
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import torchio as tio

# common detectron2 utilities
from detectron2 import model_zoo
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.utils.visualizer import Visualizer
from detectron2.engine import DefaultPredictor, DefaultTrainer
from detectron2.evaluation import COCOEvaluator
from detectron2.config import get_cfg
from detectron2.structures import BoxMode

import detectron2.data.transforms as T
from detectron2.data import detection_utils as utils

In [None]:
# Root of the dataset; each split ("train", "test") lives in its own subfolder
# that contains a dataset.json describing the images.
# Alternative dataset (axial slices): Path("../data/detectron2/axial")
data_path = Path("..") / "data" / "synthetic"

In [None]:
def get_board_dicts(imgdir):
    """Load the detectron2 dataset records for one split.

    Reads ``imgdir / "dataset.json"`` (a JSON list of records in detectron2's
    standard dataset-dict layout) and normalizes the field types:

    * ``file_name`` is resolved against ``imgdir`` and stored as a ``str``
      (detectron2 documents ``file_name`` as a path string);
    * ``width``, ``height``, ``bbox_mode`` and ``category_id`` become ``int``;
    * every ``bbox`` coordinate becomes ``float``.

    Args:
        imgdir: split directory (``str`` or ``pathlib.Path``) holding
            ``dataset.json`` and the image files it references.

    Returns:
        list[dict]: the parsed records, one per image. Records without an
        ``"annotations"`` key are returned unchanged in that respect.
    """
    # Accept plain strings as well as Path objects.
    imgdir = Path(imgdir)

    with open(imgdir / "dataset.json") as f:
        dataset_dicts = json.load(f)

    for record in dataset_dicts:
        # JSON stores file names relative to the split folder.
        record["file_name"] = str(imgdir / record["file_name"])
        record["width"] = int(record["width"])
        record["height"] = int(record["height"])

        for anno in record.get("annotations", []):
            anno["bbox"] = [float(num) for num in anno["bbox"]]
            # detectron2 BoxMode value; the data is expected to use BoxMode.XYWH_ABS.
            anno["bbox_mode"] = int(anno["bbox_mode"])
            anno["category_id"] = int(anno["category_id"])

    return dataset_dicts

def mapper(dataset_dict):
    """Turn one dataset record into the input dict the detectron2 model expects.

    The image at ``dataset_dict["file_name"]`` is loaded with ``np.load`` and
    passed through at its native resolution: no augmentation or resizing is
    applied.

    Args:
        dataset_dict: one detectron2 dataset record. It is deep-copied, so the
            caller's record is left untouched.

    Returns:
        dict with ``"image"`` (contiguous float32 tensor, channels first),
        ``"image_id"``, ``"width"``, ``"height"`` and ``"instances"`` (the
        ground-truth boxes/classes; empty when the record has no annotations).
    """
    dataset_dict = copy.deepcopy(dataset_dict)

    # NOTE(review): assumes the .npy holds an (H, W, C) array whose channel
    # order matches cfg.INPUT.FORMAT — confirm against the data generator.
    image = np.load(dataset_dict["file_name"])

    # Kept so detectron2 can sanity-check the loaded array.
    auginput = T.AugInput(image)
    # Channels-first, contiguous float32: the same layout/dtype DefaultPredictor
    # feeds the model, so training/eval and inference inputs agree. Without the
    # cast, e.g. a float64 .npy would reach the float32 backbone unconverted.
    image = torch.from_numpy(
        np.ascontiguousarray(auginput.image.transpose(2, 0, 1), dtype=np.float32)
    )
    image_hw = image.shape[1:]

    # No geometric transforms are applied (empty transform list); records
    # without annotations produce empty Instances instead of a KeyError.
    annos = [
        utils.transform_instance_annotations(annotation, [], image_hw)
        for annotation in dataset_dict.pop("annotations", [])
    ]

    return {
        "image": image,
        "image_id": dataset_dict["image_id"],
        "width": dataset_dict["width"],
        "height": dataset_dict["height"],
        "instances": utils.annotations_to_instances(annos, image_hw),
    }

In [None]:
# Trainer pattern adapted from:
# https://github.com/facebookresearch/detectron2/blob/main/projects/DeepLab/train_net.py

class SyntheticTrainer(DefaultTrainer):
    """DefaultTrainer that routes every data loader through the custom ``mapper``.

    Both the training and the test loaders use ``mapper`` (images read with
    ``np.load``). Evaluation uses ``COCOEvaluator``, writing its results
    to ``cfg.OUTPUT_DIR``.
    """

    @classmethod
    def build_train_loader(cls, cfg):
        """Build the training loader from ``cfg``, mapping records with ``mapper``."""
        loader = detectron2.data.build_detection_train_loader(cfg, mapper=mapper)
        return loader

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """Build the loader for ``dataset_name``, mapping records with ``mapper``."""
        loader = detectron2.data.build_detection_test_loader(
            cfg, dataset_name, mapper=mapper
        )
        return loader

    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        """Return a COCOEvaluator for ``dataset_name`` that writes to ``cfg.OUTPUT_DIR``."""
        evaluator = COCOEvaluator(dataset_name, output_dir=cfg.OUTPUT_DIR)
        return evaluator

In [None]:
# Registering the data sets

def register_data(set_name, set_path, thing_classes=("TUMOR",)):
    """Register one split of the brain-metastasis data with detectron2.

    The split is registered as ``"brain_metastasis_" + set_name``. Its records
    are produced lazily by ``get_board_dicts(set_path)``.

    detectron2 refuses to register the same name twice, which would make this
    cell fail on re-run. Any existing dataset/metadata entry for the name is
    therefore removed first, so the cell is idempotent.

    Args:
        set_name: split suffix, e.g. "train" or "test".
        set_path: split directory passed to get_board_dicts.
        thing_classes: class names stored in the split's metadata.
    """
    name = "brain_metastasis_" + set_name

    if name in DatasetCatalog.list():
        DatasetCatalog.remove(name)
    if name in MetadataCatalog.list():
        MetadataCatalog.remove(name)

    # Default argument binds set_path now rather than at call time.
    DatasetCatalog.register(name, lambda d=set_path: get_board_dicts(d))
    MetadataCatalog.get(name).set(thing_classes=list(thing_classes))

register_data('train', data_path / 'train')
register_data('test', data_path / 'test')

In [None]:
# Training model
# Fine-tune a COCO-pretrained Faster R-CNN (R50-FPN, 3x schedule) on the
# registered brain-metastasis splits.

BASE_MODEL = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(BASE_MODEL))

# Splits registered by register_data().
cfg.DATASETS.TRAIN = ("brain_metastasis_train",)
cfg.DATASETS.TEST = ("brain_metastasis_test",)

# Number of data loading threads.
cfg.DATALOADER.NUM_WORKERS = 4

# Start from detectron2 model-zoo weights. To continue from an earlier run
# instead, use os.path.join('./output_first_model', "model_final.pth").
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(BASE_MODEL)

# ROI heads: a single foreground class ("TUMOR").
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 256
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1

# Solver. IMS_PER_BATCH is the number of images per batch across all machines.
cfg.SOLVER.IMS_PER_BATCH = 4
cfg.SOLVER.BASE_LR = 0.01
cfg.SOLVER.MAX_ITER = 1000

# Evaluate on the test split once, at the end of training.
cfg.TEST.EVAL_PERIOD = cfg.SOLVER.MAX_ITER

cfg.OUTPUT_DIR = './output_second_model'
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

trainer = SyntheticTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()

In [None]:
# visualizing predictions
# Run the trained detector on a few test images and save the rendered
# predictions as ./inference{i}.png.

# Work on a copy so the training config above is not mutated.
predict_cfg = cfg.clone()
# Use the weights written by the training cell above (model_final.pth in
# cfg.OUTPUT_DIR), so Restart & Run All evaluates the model it just trained.
predict_cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
predict_cfg.DATASETS.TEST = ("brain_metastasis_test", )
# `mapper` feeds images at native resolution during training and evaluation.
# MIN_SIZE_TEST = 0 disables DefaultPredictor's shortest-edge resize so that
# inference sees the same scale.
predict_cfg.INPUT.MIN_SIZE_TEST = 0
predictor = DefaultPredictor(predict_cfg)
test_metadata = MetadataCatalog.get("brain_metastasis_test")

# Number of test images to visualize.
num_images = 1

# Iterate over the registered test records rather than globbing the folder:
# the folder also contains dataset.json, which np.load cannot read.
test_records = DatasetCatalog.get("brain_metastasis_test")
for idx, record in enumerate(test_records[:num_images]):
    im = np.load(record["file_name"])
    outputs = predictor(im)
    # Visualizer expects RGB; the channel axis is reversed on the assumption
    # that the arrays are BGR (cfg.INPUT.FORMAT) — confirm for this data.
    v = Visualizer(im[:, :, ::-1], metadata=test_metadata, scale=0.8)
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    # get_image() returns RGB; cv2.imwrite expects BGR.
    cv2.imwrite(f"./inference{idx}.png", out.get_image()[:, :, ::-1])