In [1]:
import torch
import numpy as np 
import json
import pickle
import cv2
import re
from pprint import pprint

from PIL import Image, ImageDraw

import os

# Environment check: torch version, whether CUDA is usable, and how many GPUs are visible.
# (A bare `torch.cuda.device_count()` mid-cell would be evaluated and silently discarded,
# so it is reported explicitly here.)
print(torch.__version__, torch.cuda.is_available(), torch.cuda.device_count())
# Release cached, currently unused GPU memory held by the PyTorch allocator.
torch.cuda.empty_cache()

import detectron2
from detectron2.utils.logger import setup_logger
from detectron2.data.datasets import register_coco_instances, load_coco_json
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor, default_argument_parser, \
                              default_setup, hooks, launch, DefaultTrainer, HookBase, default_writers
from detectron2.config import CfgNode as CN, get_cfg
from detectron2.modeling import build_model
import detectron2.utils.comm as comm
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.solver import build_lr_scheduler, build_optimizer
from detectron2.data import detection_utils as utils, build_detection_test_loader, DatasetMapper, \
                            build_detection_train_loader, MetadataCatalog, DatasetCatalog
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer
from detectron2.utils.events import EventStorage

logger = setup_logger()

import sys
# Make the parent directory importable so the `modules` package imported below can be found.
# The path is relative to the kernel's current working directory. The membership check keeps
# re-runs of this cell from appending the same entry to sys.path again and again.
if "../" not in sys.path:
    sys.path.append("../")

from modules import CBMCOCOEvaluator, CustomMapper, add_cbm_config, add_uhcc_config, MyVisualizer

  from .autonotebook import tqdm as notebook_tqdm


1.10.1 True


In [2]:
# Register the sample dataset in Detectron2 style.
# As per Detectron2, the first argument of load_coco_json (the COCO JSON file) has to be a fully-qualified path.
SAMPLE_DATASET = "sample"
SAMPLE_JSON = "/YOUR_FILE_PATH/sample_data/sample_data.json"
# extra per-annotation keys copied from the COCO JSON into each Detectron2 record
SAMPLE_EXTRA_KEYS = ('region_shape', 'region_orientation', 'region_margin',
                     'region_echo', 'region_posterior', 'region_cancer')

# DatasetCatalog.register rejects a name that is already registered, so drop any previous
# registration first; this keeps the cell re-runnable without restarting the kernel.
if SAMPLE_DATASET in DatasetCatalog.list():
    DatasetCatalog.remove(SAMPLE_DATASET)

# The loader is called lazily by Detectron2; default arguments bind the current values now,
# so later edits to the constants above cannot silently change what gets loaded.
DatasetCatalog.register(
    SAMPLE_DATASET,
    lambda json_file=SAMPLE_JSON, name=SAMPLE_DATASET, keys=SAMPLE_EXTRA_KEYS:
        load_coco_json(json_file, "", name, list(keys)),
)

# Class names for each concept head and for the cancer prediction, plus lesion visualization settings.
sample_metadata = MetadataCatalog.get(SAMPLE_DATASET).set(
    shape_classes=['oval', 'not oval'],
    orientation_classes=['parallel', 'not parallel'],
    margin_classes=['circumscribed', 'not circumscribed'],
    echo_classes=['anechoic', 'not anechoic'],
    posterior_classes=['no features', 'features'],
    cancer_classes=['benign', 'malignant'],
    thing_colors=["g", "m"],
    thing_classes=['lesion'],
)

#### Training Cancer Model(s)
Because the lesion-only and concept-prediction models can be trained with the standard Detectron2 training loop (just remember to freeze the backbone when training the concept predictors on top of it), we only provide an example for training the different cancer-head configurations.

In [4]:
# Stage-3a config: Mask R-CNN R101-FPN base config, extended with the project's UHCC/CBM options,
# then overridden by configs/stage_3a.yaml.
cfg3a = get_cfg()
cfg3a.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"))
# presumably these add the custom keys (e.g. MODEL.CBM.*) that stage_3a.yaml sets, so they
# must run before that yaml is merged — TODO confirm against modules.add_*_config
add_uhcc_config(cfg3a)
add_cbm_config(cfg3a)
# clear the model-zoo weights path; stage_3a.yaml may set its own (MODEL.WEIGHTS is what the
# training cell passes to resume_or_load)
cfg3a.MODEL.WEIGHTS = None
cfg3a.merge_from_file("configs/stage_3a.yaml")
cfg3a.DATALOADER.NUM_WORKERS = 0
# reduce the number of detections for visualization purposes
cfg3a.TEST.DETECTIONS_PER_IMAGE = 1

# Build the model from the config; the training cell below uses and updates model3a.
model3a = build_model(cfg3a)
model3a.train()

GeneralizedRCNN(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelMaxPool()
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
      )
      (res2): Sequential(
        (0): BottleneckBlock

In [11]:
# The structure of this train loop follows the plain training loop in the Detectron2 GitHub repository.
# To implement validation loss computation during the training loop, use the ValidationMapper class.

# Weights of the first module of cancer_head.second_model taken from an existing concept-only
# model; set them when training the side channel on top of that model, otherwise leave None.
# NOTE(review): presumably a tensor of shape (out_features, n_concepts) — TODO confirm.
existing_weights = None
# resume from a previous training run (True) or start from cfg3a.MODEL.WEIGHTS (False)
resume = False
max_iter = 25
batch_size = 8

# Losses that are logged but NOT back-propagated: only the cancer-head loss(es) are optimized here.
non_cancer_losses = {'loss_cls', 'loss_box_reg', 'loss_mask', 'shape_loss', 'margin_loss',
                     'orientation_loss', 'echo_loss', 'posterior_loss', 'loss_rpn_cls', 'loss_rpn_loc'}

# Validate before any optimizer step touches the model: existing weights are expected exactly
# when the side channel is enabled.
if bool(cfg3a.MODEL.CBM.SIDE_CHANNEL) != (existing_weights is not None):
    raise ValueError('Mismatch between side channel and existing weights found')

optimizer = build_optimizer(cfg3a, model3a)
scheduler = build_lr_scheduler(cfg3a, optimizer)

# Checkpoints go to the same directory as the event writers below.
checkpointer = DetectionCheckpointer(
    model3a, save_dir=cfg3a.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler
)
start_iter = checkpointer.resume_or_load(cfg3a.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1

periodic_checkpointer = PeriodicCheckpointer(
    checkpointer, cfg3a.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
)

writers = default_writers(cfg3a.OUTPUT_DIR, max_iter) if comm.is_main_process() else []

data_loader = build_detection_train_loader(
    DatasetCatalog.get("sample"),
    mapper=CustomMapper(cfg3a, is_train=True, augmentations=[]),
    total_batch_size=batch_size,
)

model3a.train()
logger.info("Starting training from iteration {}".format(start_iter))
with EventStorage(start_iter) as storage:
    for data, iteration in zip(data_loader, range(start_iter, max_iter)):
        storage.iter = iteration

        loss_dict = model3a(data)
        losses = sum(loss_dict.values())
        assert torch.isfinite(losses).all(), loss_dict

        # Keep only the cancer-head terms for back-propagation.
        cancer_loss_dict = {k: v for k, v in loss_dict.items() if k not in non_cancer_losses}
        if not cancer_loss_dict:
            # sum() of nothing is the int 0, which has no .backward(); fail with a clear message instead
            raise RuntimeError(
                "No cancer loss to optimize; model returned losses {}".format(sorted(loss_dict))
            )
        losses_back = sum(cancer_loss_dict.values())

        # All losses (not just the optimized ones) are reduced across workers and logged.
        loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(loss_dict).items()}
        losses_reduced = sum(loss_dict_reduced.values())
        if comm.is_main_process():
            storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        losses_back.backward()
        optimizer.step()

        if existing_weights is not None:
            # Reset the leading columns (from the concept-only model) to their original values so
            # only the trailing side-channel columns effectively change. The update is done in
            # place: replacing `.weight` with a new Parameter would leave the optimizer stepping
            # an orphaned tensor, so the layer would stop training after the first iteration.
            side_layer_weight = model3a.roi_heads.cancer_head.second_model[0].weight
            n_existing = existing_weights.shape[1]
            with torch.no_grad():
                side_layer_weight[:, :n_existing].copy_(existing_weights)

        storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
        scheduler.step()

        if iteration - start_iter > 5 and (
            (iteration + 1) % 20 == 0 or iteration == max_iter - 1
        ):
            for writer in writers:
                writer.write()
        periodic_checkpointer.step(iteration)

[32m[03/06 21:49:30 d2.data.datasets.coco]: [0mLoaded 5 images in COCO format from /YOUR_FILE_PATH/sample_data/sample_data.json
[32m[03/06 21:49:30 d2.data.dataset_mapper]: [0m[DatasetMapper] Augmentations used in training: [ResizeShortestEdge(short_edge_length=(640, 672, 704, 736, 768, 800), max_size=1333, sample_style='choice'), RandomFlip()]
[32m[03/06 21:49:30 d2.data.common]: [0mSerializing 5 elements to byte tensors and concatenating them all ...
[32m[03/06 21:49:30 d2.data.common]: [0mSerialized dataset takes 0.00 MiB
[32m[03/06 21:49:30 detectron2]: [0mStarting training from iteration 0
[32m[03/06 21:49:30 d2.data.dataset_mapper]: [0m[DatasetMapper] Augmentations used in training: [ResizeShortestEdge(short_edge_length=(640, 672, 704, 736, 768, 800), max_size=1333, sample_style='choice'), RandomFlip()]
[32m[03/06 21:49:30 d2.data.common]: [0mSerializing 5 elements to byte tensors and concatenating them all ...
[32m[03/06 21:49:30 d2.data.common]: [0mSerialized da