In [None]:
# Some basic setup:
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random
#from google.colab.patches import cv2_imshow
import matplotlib.pyplot as plt

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data import build_detection_train_loader
from detectron2.engine import DefaultTrainer
from detectron2.engine import SimpleTrainer
from detectron2.engine import HookBase
from typing import Dict, List, Optional
import detectron2.solver as solver
import detectron2.modeling as modeler
import detectron2.data as data
import detectron2.data.transforms as T
import detectron2.checkpoint as checkpointer
from detectron2.data import detection_utils as utils
import weakref
import copy
import torch
import time

from astrodet import astrodet as toolkit
from astrodet import detectron as detectron_addons


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Print the library versions: this doubles as an import smoke test and records
# which versions this notebook was run with.
for module in (detectron2, np, cv2):
    print(module.__version__)


In [None]:
from astrodet.astrodet import set_mpl_style

# Prettify the plotting: apply astrodet's matplotlib style before any figure is drawn
set_mpl_style()

### Register HSC training data

First, format the HSC data using `training_data.ipynb`. It needs to be partitioned into `train`, `test`, and `val` directories.

The file metadata for each dataset is specified with the `filenames_dict`. We will specify the filters first, then populate the filenames in the dataset directory.

For a custom dataset, this dictionary needs to be populated correctly for your data.

You will need to change directory paths!

In [None]:
# Root of the HSC data (expects one subdirectory per split, e.g. train/ and test/)
# and where checkpoints / evaluation output are written. Set the HSC_DATA_DIR /
# HSC_OUTPUT_DIR environment variables to point elsewhere without editing the notebook.
dirpath = os.environ.get('HSC_DATA_DIR', '/home/shared/hsc/HSC/HSC_DR3/data/')
output_dir = os.environ.get('HSC_OUTPUT_DIR', './output/hsc')

In [None]:
# Debug switch: how many samples to load from each split.
# Use -1 to include every sample.
sampleNumbers = 20

In [None]:
from detectron2.structures import BoxMode
from astropy.io import fits
import glob

# Yufeng Jun19 add test here
dataset_names = ['train', 'test']  # , 'val']
filenames_dict_list = []  # One filenames_dict per entry of dataset_names

for d in dataset_names:
    data_path = os.path.join(dirpath, d)

    # Get dataset dict info
    filenames_dict = {'filters': ['g', 'r', 'i']}

    # Each unique tract-patch is identified by its segmentation-mask file.
    # glob returns files in filesystem-dependent order; sort so the
    # [:sampleNumbers] subset (and hence image_ids) is the same on every run/machine.
    files = sorted(glob.glob(os.path.join(data_path, '*_scarlet_segmask.fits')))
    if not files:
        print(f'WARNING: no *_scarlet_segmask.fits files found in {data_path}')
    if sampleNumbers != -1:
        files = files[:sampleNumbers]
    # s = sample names: mask file names without the '_scarlet_segmask.fits' suffix
    s = [os.path.basename(f).split('_scarlet_segmask.fits')[0] for f in files]
    print(f'Tract-patch List: {s}')
    for f in filenames_dict['filters']:
        filenames_dict[f] = {
            # Image file per band: f.upper() is the band letter, and tract_patch[1:]
            # drops the leading (default I band) letter of the sample name.
            'img': [os.path.join(data_path, f.upper() + f'{tract_patch[1:]}_scarlet_img.fits') for tract_patch in s],
            # All mask files are in the I band, so every filter points at the same masks
            'mask': [os.path.join(data_path, f'{tract_patch}_scarlet_segmask.fits') for tract_patch in s],
        }

    filenames_dict_list.append(filenames_dict)

In [None]:
# Number of samples loaded per split (counted from the g-band image list)
for split_idx, split_name in enumerate(['train', 'test']):
    print(f'# of {split_name} sample: ', len(filenames_dict_list[split_idx]['g']['img']))

For detectron2 to read the data, it must be in a dictionary format. The function `get_astro_dicts` reads in the FITS files and formats them into a dictionary.

However, this step can take a few minutes, so we recommend running it only once and saving the dictionary data as a JSON file that can be read in at the beginning of your code. Check out the `format_hsc_data.ipynb` notebook for how to do that.


In [None]:
def get_astro_dicts(filename_dict):
    """
    Read HSC FITS images/masks and format them as detectron2 dataset records.

    This needs to be customized to your training data format.

    Parameters
    ----------
    filename_dict : dict
        ``filename_dict['filters']`` is a list of band names; for each band ``f``,
        ``filename_dict[f]['img']`` and ``filename_dict[f]['mask']`` are parallel
        lists of file paths, one entry per image.

    Returns
    -------
    list of dict
        One record per image with keys ``filename_<BAND>`` (one per filter),
        ``file_name``, ``image_id``, ``height``, ``width`` and ``annotations``.
        Each annotation holds ``bbox`` (XYWH_ABS), ``area``, ``bbox_mode``,
        ``segmentation`` (polygon list), ``category_id`` and ``ellipse_pars``.
    """
    dataset_dicts = []
    filters = filename_dict['filters']
    # Image shape and masks are read through the first filter only; the other
    # filters are assumed to share the same shape.
    first = filters[0]

    for idx, (filename_img, filename_mask) in enumerate(
            zip(filename_dict[first]['img'], filename_dict[first]['mask'])):
        # Image shape from the first filter's FITS file
        with fits.open(filename_img, memmap=False, lazy_load_hdus=False) as hdul:
            height, width = hdul[0].data.shape

        # Mask file: HDU 0 is skipped; every later HDU is one source's mask
        # cutout, with its metadata in the header.
        with fits.open(filename_mask, memmap=False, lazy_load_hdus=False) as hdul:
            source_hdus = hdul[1:]
            masks = [hdu.data for hdu in source_hdus]
            # Single-class demo: every source is category 0.
            # For catalog classes use: [hdu.header["NEW_ID"] for hdu in source_hdus]
            category_ids = [0] * len(source_hdus)
            ellipse_pars = [hdu.header["ELL_PARM"] for hdu in source_hdus]
            bboxes = [list(map(int, hdu.header["BBOX"].split(','))) for hdu in source_hdus]

        # Image metadata (should be the same for each filter)
        record = {f"filename_{band.upper()}": filename_dict[band]['img'][idx] for band in filters}
        record["file_name"] = filename_img
        record["image_id"] = idx
        record["height"] = height
        record["width"] = width

        objs = []
        for mask, bbox, category_id, ellipse in zip(masks, bboxes, category_ids, ellipse_pars):
            # Skip HDUs that carry no data or are not a single 2-D mask
            if mask is None or mask.ndim != 2:
                continue
            x, y, w, h = bbox
            # (x, y) from the BBOX header is treated as the box centre, so the
            # cutout's top-left corner in the full image is (x0, y0).
            # NOTE(review): confirm this convention against the data-prep code.
            x0, y0 = x - w // 2, y - h // 2

            # Polygon outlines of the mask cutout, see
            # https://github.com/facebookresearch/Detectron/issues/100
            # [-2] selects the contour list under both the OpenCV 3 (three return
            # values) and OpenCV 4 (two return values) signatures.
            contours = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE,
                                        cv2.CHAIN_APPROX_SIMPLE)[-2]
            segmentation = []
            for contour in contours:
                # Flatten to [x1, y1, ..., xn, yn]
                contour = contour.flatten()
                # Keep only polygons with at least 3 points (6 coordinates)
                if len(contour) > 4:
                    # Shift cutout coordinates into full-image coordinates
                    contour[::2] += x0
                    contour[1::2] += y0
                    segmentation.append(contour.tolist())
            # No valid contours
            if not segmentation:
                continue

            objs.append({
                "bbox": [x0, y0, w, h],
                "area": w * h,
                "bbox_mode": BoxMode.XYWH_ABS,
                "segmentation": segmentation,
                "category_id": category_id,
                "ellipse_pars": ellipse,
            })

        record["annotations"] = objs
        dataset_dicts.append(record)

    return dataset_dicts

Now, we register the dataset following the detectron2 documentation.

In [None]:

# Optional clean-up if you want to change something (e.g. get_astro_dicts or the
# file lists): set to True to unregister previously registered splits.
# MetadataCatalog entries are left in place.
unregister_datasets = False
if unregister_datasets:
    for name in ("astro_train", "astro_test", "astro_val"):
        if name in DatasetCatalog.list():
            print(f'removing {name}')
            DatasetCatalog.remove(name)

### A note on classes

In this demo, we assume one class for all objects.  To see how we assign classes based on external HSC catalogs, check out the `hsc_class_assign.ipynb` notebook 

In [None]:
# DatasetCatalog.register refuses duplicate names, so drop any earlier
# registration first; this keeps the cell safe to re-run.
for name in ("astro_train", "astro_test"):
    if name in DatasetCatalog.list():
        DatasetCatalog.remove(name)

DatasetCatalog.register("astro_train", lambda: get_astro_dicts(filenames_dict_list[0]))
astrotrain_metadata = MetadataCatalog.get("astro_train").set(thing_classes=["object"])
DatasetCatalog.register("astro_test", lambda: get_astro_dicts(filenames_dict_list[1]))
astrotest_metadata = MetadataCatalog.get("astro_test").set(thing_classes=["object"])


# Also build the records eagerly for plotting below
dataset_dicts = {}
for i, d in enumerate(dataset_names):
    print(f'Loading {d}')
    dataset_dicts[d] = get_astro_dicts(filenames_dict_list[i])

### Visualize Ground Truth Examples

In [None]:
# Show `nsample` random training images: annotations on top, plain image below.
nsample = 1
# Draw indices from the records actually loaded. sampleNumbers can exceed the
# number of files found, and is -1 when every file is used.
maxInd = len(dataset_dicts['train'])
if maxInd == 0:
    raise ValueError('No training records were loaded; check dirpath and the data layout.')
randInd = np.random.randint(0, maxInd, nsample)
fig = plt.figure(figsize=(15, 15*nsample*2))
for i, ind in enumerate(randInd):
    # Need to increase ceil_percentile if the data are saturating!
    d = dataset_dicts['train'][ind]
    filenames = [d['filename_G'], d['filename_R'], d['filename_I']]
    img = toolkit.read_image_hsc(filenames, normalize="astrolupton", stretch=.5, Q=10)
    out = Visualizer(img, metadata=astrotrain_metadata).draw_dataset_dict(d)
    ax1 = plt.subplot(nsample*2, 1, 2*i+1)
    ax1.imshow(out.get_image(), origin='upper')
    ax1.axis('off')
    ax2 = plt.subplot(nsample*2, 1, 2*i+2)
    ax2.imshow(img)
    ax2.axis('off')

plt.tight_layout()

In [None]:
# Pixel-value histogram of each channel of one training image
fig, ax = plt.subplots(1, 1, figsize=(6, 5))
bins = 25
# 91, 28, 38 are bad examples
d = dataset_dicts['train'][1]
filenames = [d['filename_G'], d['filename_R'], d['filename_I']]

img = toolkit.read_image_hsc(filenames, normalize="astrolupton", stretch=.5, Q=10)
# Per-channel line styling, indexed by channel (0, 1, 2)
channel_styles = [
    dict(color="r", zorder=1, label='i'),
    dict(color="g", linestyle='-.', zorder=2, label='r'),
    dict(color="b", linestyle='dashed', zorder=3, label='g'),
]
for channel, style in enumerate(channel_styles):
    ax.hist(img[:, :, channel].flatten(), histtype="step", bins=bins, log=True, lw=2, **style)
ax.set_xlabel('Value', fontsize=20)
ax.set_ylabel('Count', fontsize=20)
ax.legend(fontsize=18)

fig.tight_layout()

### Data Augmentation

In [None]:
from astrodet.detectron import _transform_to_aug


def train_mapper(dataset_dict):
    """Turn one dataset record into a randomly augmented training sample.

    Reads the record's G/R/I FITS files with astrolupton scaling, runs the
    augmentation pipeline below, and maps the annotations through the same
    transforms.

    Returns a dict with the CHW tensor ("image"), the augmented HWC array
    ("image_shaped"), fixed "height"/"width" of 1050, "image_id", and the
    "instances" left after empty ones are filtered out.
    """
    # Deep copy, because the annotations list is popped from the record below
    record = copy.deepcopy(dataset_dict)
    band_files = [record['filename_G'], record['filename_R'], record['filename_I']]
    image = toolkit.read_image_hsc(band_files, normalize="astrolupton", stretch=0.5, Q=10)

    # Random rotation and horizontal / vertical flips
    random_augs = [
        T.RandomRotation([-90, 90, 180], sample_style='choice'),
        T.RandomFlip(prob=0.5),
        T.RandomFlip(prob=0.5, horizontal=False, vertical=True),
        # Other custom augs that can be enabled:
        # detectron_addons.CustomAug(gaussblur, prob=1.0),
        # detectron_addons.CustomAug(addelementwise, prob=1.0),
        # CustomAug(white),
    ]
    img_h, img_w = image.shape[0], image.shape[1]
    # Crop augmentation: half the width/height, starting at (w//4, h//4)
    center_crop = _transform_to_aug(T.CropTransform(img_w // 4, img_h // 4, img_w // 2, img_h // 2))
    augs = detectron_addons.KRandomAugmentationList(random_augs, k=-1, cropaug=center_crop)

    # Run the augmentations on an AugInput and keep the applied transforms
    aug_input = T.AugInput(image)
    transforms = augs(aug_input)
    image_chw = torch.from_numpy(aug_input.image.copy().transpose(2, 0, 1))
    out_hw = image_chw.shape[1:]

    annos = [
        utils.transform_instance_annotations(anno, [transforms], out_hw)
        for anno in record.pop("annotations")
    ]
    instances = utils.filter_empty_instances(utils.annotations_to_instances(annos, out_hw))

    return {
        # create the format that the model expects
        "image": image_chw,
        "image_shaped": aug_input.image,
        "height": 1050,
        "width": 1050,
        "image_id": record["image_id"],
        "instances": instances,
    }

In [None]:
from detectron2.structures import BoxMode

# Compare the first training image before (left) and after (right) train_mapper
fig, axs = plt.subplots(1, 2, figsize=(10*2, 10))

d = next(iter(dataset_dicts['train']))
filenames = [d['filename_G'], d['filename_R'], d['filename_I']]
img = toolkit.read_image_hsc(filenames, normalize="astrolupton", stretch=0.5, Q=10)

# Ground-truth boxes are stored as XYWH; the visualizer expects XYXY
gt_xywh = np.array([anno['bbox'] for anno in d['annotations']])
gt_boxes = BoxMode.convert(gt_xywh, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
visualizer = Visualizer(img, metadata=astrotrain_metadata, scale=1)
out = visualizer.overlay_instances(boxes=gt_boxes)
axs[0].imshow(out.get_image())
axs[0].axis('off')

# Augmented version: boxes come from the mapper's instances
aug_d = train_mapper(d)
img_aug = aug_d["image_shaped"]
visualizer = Visualizer(img_aug, metadata=astrotrain_metadata, scale=1)
print(img_aug.shape)
out = visualizer.overlay_instances(boxes=aug_d['instances'].gt_boxes)
axs[1].imshow(out.get_image())
axs[1].axis('off')

fig.tight_layout()
fig.show()

### Training

We prepare for training by initializing a config object and setting hyperparameters. Then we can take the initial weights from the pre-trained models in the model zoo. For a full list of available config options, check https://detectron2.readthedocs.io/en/latest/modules/config.html

This setup is for demo purposes, so it does not follow the full training schedule we use for the paper. You can check the `train_hsc_primary.py` script for the final training configurations.

In [None]:
cfg = get_cfg()
init_coco_weights = True  # Start training from MS COCO weights

cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"))  # Get model structure
# Dataset names are tuples; ("astro_train") without a trailing comma is just a string.
cfg.DATASETS.TRAIN = ("astro_train",)
cfg.DATASETS.TEST = ("astro_test",)

cfg.MODEL.RPN.POST_NMS_TOPK_TRAIN = 6000
cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION = 0.33


cfg.DATALOADER.NUM_WORKERS = 2
# Per-channel input mean.
# NOTE(review): presumably measured on the astrolupton-scaled training images — confirm.
cfg.MODEL.PIXEL_MEAN = [13.49794151, 9.11051305, 5.42995532]

cfg.INPUT.MIN_SIZE_TRAIN = 500
cfg.INPUT.MAX_SIZE_TRAIN = 525

cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[8, 16, 32, 64, 128]]
# Images per iteration; one epoch is roughly len(images) / IMS_PER_BATCH iterations.
cfg.SOLVER.IMS_PER_BATCH = 2

cfg.SOLVER.CLIP_GRADIENTS.ENABLED = True
# Type of gradient clipping, currently 2 values are supported:
# - "value": the absolute values of elements of each gradients are clipped
# - "norm": the norm of the gradient for each parameter is clipped thus
#   affecting all elements in the parameter
cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "norm"
# Floating point number p for the L-p norm used with the "norm" clipping type
# (for L-inf, specify .inf). The clipping threshold CLIP_VALUE is not set here,
# so its config default applies.
cfg.SOLVER.CLIP_GRADIENTS.NORM_TYPE = 5.0


e1 = 200  # number of training iterations
cfg.SOLVER.BASE_LR = 0.001
cfg.SOLVER.STEPS = []  # no multi-step LR decay points
cfg.SOLVER.LR_SCHEDULER_NAME = "WarmupMultiStepLR"
cfg.SOLVER.WARMUP_ITERS = 0
cfg.SOLVER.MAX_ITER = e1  # for DefaultTrainer

cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512  # RoIs sampled per image in the ROI heads (default: 512)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
cfg.OUTPUT_DIR = output_dir
cfg.TEST.DETECTIONS_PER_IMAGE = 500

cfg.MODEL.BACKBONE.FREEZE_AT = 4

if init_coco_weights:
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml")  # Initialize from MS COCO
else:
    cfg.MODEL.WEIGHTS = os.path.join(output_dir, 'model_temp.pth')  # Initialize from local weights

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
model = modeler.build_model(cfg)
optimizer = solver.build_optimizer(cfg, model)
loader = data.build_detection_train_loader(cfg, mapper=train_mapper)
schedulerHook = detectron_addons.CustomLRScheduler(optimizer=optimizer)
saveHook = detectron_addons.SaveHook()
saveHook.set_output_name("model_temp")
hookList = [saveHook, schedulerHook]


In [None]:
# Workaround if downloading the COCO weights fails with an SSL certificate error.
# WARNING: enabling it disables HTTPS certificate verification for every later
# connection in this kernel that relies on Python's default SSL context (e.g. urllib).
# Only flip the flag if you actually hit the error.
import ssl
DISABLE_SSL_VERIFICATION = False
if DISABLE_SSL_VERIFICATION:
    ssl._create_default_https_context = ssl._create_unverified_context

Ignore the deprecation warnings triggered by cropping.

In [None]:
import warnings
try:
    # ignore ShapelyDeprecationWarning from fvcore
    # This comes from the cropping
    from shapely.errors import ShapelyDeprecationWarning
    warnings.filterwarnings('ignore', category=ShapelyDeprecationWarning)

except:
    pass


In [None]:
# Build the custom trainer from the model / loader / optimizer prepared above,
# attach the save and LR-scheduler hooks, then train from 0 to e1.
trainer = toolkit.NewAstroTrainer(model, loader, optimizer, cfg)
trainer.register_hooks(hookList)
log_period = 10  # print loss every 10 iterations
trainer.set_period(log_period)
start_iter, end_iter = 0, e1
trainer.train(start_iter, end_iter)

In [None]:
# Total training loss recorded by the trainer
fig, ax = plt.subplots(1, 1, figsize=(6, 5))
ax.plot(trainer.lossList, label=r'$L_{\rm{tot}}$')
ax.legend(loc='upper right')
axis_label_kw = {'fontsize': 20}
# NOTE(review): the x axis is the index into trainer.lossList; if the loss is
# recorded per iteration rather than per epoch, the label should say so.
ax.set_xlabel('training epoch', **axis_label_kw)
ax.set_ylabel('loss', **axis_label_kw)
fig.tight_layout()

### Inference

In [None]:
roi_thresh = 0.5  # minimum score for a detection to be kept
nms_thresh = 0.3  # IoU threshold for non-maximum suppression in the ROI heads

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml")) # Get model structure
# Dataset names are tuples; ("astro_train") without a trailing comma is just a string.
cfg.DATASETS.TRAIN = ("astro_train",)  # Register metadata
cfg.DATASETS.TEST = ("astro_test",)  # Config calls this TEST, but it should be the val dataset
cfg.OUTPUT_DIR = output_dir

cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_temp.pth")  # path to the model we just trained
# Must match the training config: detectron2 keeps pixel_mean as a non-persistent
# buffer, so it is not restored from the checkpoint and would otherwise fall back
# to the COCO default.
cfg.MODEL.PIXEL_MEAN = [13.49794151, 9.11051305, 5.42995532]

cfg.TEST.DETECTIONS_PER_IMAGE = 1000

# NOTE(review): the *_TRAIN input sizes normally only affect training; kept as-is.
cfg.INPUT.MIN_SIZE_TRAIN = 1025
cfg.INPUT.MAX_SIZE_TRAIN = 1050
cfg.MODEL.RPN.POST_NMS_TOPK_TEST = 6000
cfg.MODEL.RPN.PRE_NMS_TOPK_TEST = 6000

cfg.MODEL.RPN.BATCH_SIZE_PER_IMAGE = 512
# Same anchor sizes as used for training
cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[8, 16, 32, 64, 128]]


cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = roi_thresh   # set a custom testing threshold
cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = nms_thresh
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1


#predictor = DefaultPredictor(cfg)
predictor = toolkit.AstroPredictor(cfg)



In [None]:
from detectron2.utils.visualizer import ColorMode


def _bw_visualizer(image):
    """Visualizer on `image` that removes the colors of unsegmented pixels
    (ColorMode.IMAGE_BW, only available for segmentation models)."""
    return Visualizer(image, metadata=astrotest_metadata, scale=1,
                      instance_mode=ColorMode.IMAGE_BW)


# Ground truth (left) vs. model predictions (right) for random training images
nsample = 1
fig = plt.figure(figsize=(30, 15*nsample))

for i, d in enumerate(random.sample(dataset_dicts['train'], nsample)):
    filenames = [d['filename_G'], d['filename_R'], d['filename_I']]
    img = toolkit.read_image_hsc(filenames, normalize="astrolupton", stretch=0.5, Q=10)
    print('total instances:', len(d['annotations']))

    ground_truth = _bw_visualizer(img).draw_dataset_dict(d)
    ax_gt = plt.subplot(nsample, 2, 2*i + 1)
    ax_gt.imshow(ground_truth.get_image())
    ax_gt.axis('off')

    # Output format: https://detectron2.readthedocs.io/tutorials/models.html#model-output-format
    outputs = predictor(img)
    prediction = _bw_visualizer(img).draw_instance_predictions(outputs["instances"].to("cpu"))
    print('detected instances:', len(outputs['instances'].pred_boxes))
    print('')
    ax_pred = plt.subplot(nsample, 2, 2*i + 2)
    ax_pred.imshow(prediction.get_image())
    ax_pred.axis('off')
    

### Evaluate

In [None]:
def test_mapper(dataset_dict, **read_image_args):
    """Turn one dataset record into an un-augmented evaluation sample.

    Parameters
    ----------
    dataset_dict : dict
        Record produced by get_astro_dicts (filename_G/R/I, image_id,
        annotations, and usually height/width).
    **read_image_args
        Optional overrides for the toolkit.read_image_hsc keyword arguments.
        The defaults match the scaling used by train_mapper:
        normalize="astrolupton", stretch=0.5, Q=10.

    Returns
    -------
    dict
        "image" (CHW tensor), "image_shaped" (HWC array), "height", "width",
        "image_id", "instances" and the transformed "annotations".
    """
    dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
    # Read this record's own files (not a leftover global `d` from another cell).
    filenames = [dataset_dict['filename_G'], dataset_dict['filename_R'], dataset_dict['filename_I']]
    # Use the same contrast scaling as training so the model sees the intensity
    # distribution it was trained on; callers may override any argument.
    image_kwargs = {"normalize": "astrolupton", "stretch": 0.5, "Q": 10}
    image_kwargs.update(read_image_args)
    image = toolkit.read_image_hsc(filenames, **image_kwargs)

    # No augmentation at test time; the empty list still yields the transform
    # object used to map the annotations.
    augs = T.AugmentationList([])
    auginput = T.AugInput(image)
    transform = augs(auginput)
    image = torch.from_numpy(auginput.image.copy().transpose(2, 0, 1))
    annos = [
        utils.transform_instance_annotations(annotation, [transform], image.shape[1:])
        for annotation in dataset_dict.pop("annotations")
    ]
    # No resizing happens here, so report the record's own size (falling back
    # to the image size when the record does not carry it).
    height = dataset_dict.get("height", image.shape[1])
    width = dataset_dict.get("width", image.shape[2])
    return {
        # create the format that the model expects
        "image": image,
        "image_shaped": auginput.image,
        "height": height,
        "width": width,
        "image_id": dataset_dict["image_id"],
        "instances": utils.annotations_to_instances(annos, image.shape[1:]),
        "annotations": annos
    }

In [None]:
from detectron2.evaluation import inference_on_dataset
from detectron2.data import build_detection_test_loader

# COCO-style evaluator (with recall) on the registered test split,
# with its output directory set to cfg.OUTPUT_DIR
evaluator = toolkit.COCOEvaluatorRecall(
    "astro_test",
    use_fast_impl=True,
    allow_cached_coco=False,
    output_dir=cfg.OUTPUT_DIR,
)

# Test loader that passes each record through test_mapper
test_loader = build_detection_test_loader(cfg, "astro_test", mapper=test_mapper)

In [None]:
# Run the predictor's model over the test loader and collect the evaluator's metrics
eval_model = predictor.model
results = inference_on_dataset(eval_model, test_loader, evaluator)

In [None]:
# Names of the bbox metrics returned by the evaluator
bbox_results = results['bbox']
print(bbox_results.keys())


In [None]:
ap_type = 'bbox'
# Legend titles come from the registered test metadata so they match the classes
# the model was trained on (a single "object" class in this demo).
cls_names = list(astrotest_metadata.thing_classes)
results_per_category = results[ap_type]['results_per_category']

fig = plt.figure(figsize=(7, 4))
axs = fig.add_subplot(111)

ious = np.linspace(0.50, 0.95, 10)
colors = plt.cm.viridis(np.linspace(0, 1, len(ious)))

# Plot precision-recall curves, one line per IoU threshold
for j, precision_class in enumerate(results_per_category):
    precision_shape = np.shape(precision_class)
    for i in range(precision_shape[0]):
        # precision has dims (iou, recall, cls, area range, max dets)
        # area range index 0: all area ranges
        # max dets index -1: typically 100 per image
        p_dat = precision_class[i, :, j, 0, -1]
        # Evenly spaced recall thresholds at which precision was sampled
        recall_grid = np.linspace(0, 1, len(p_dat))
        # Hide vanishing precisions, but keep the first non-positive entry so
        # the curve visibly drops where precision vanishes
        mask = (p_dat > 0)
        mask[np.cumsum(~mask) == 1] = True
        # COCO marks undefined precision with -1; count it as 0 here
        p = np.clip(p_dat[mask], 0, None)
        # Plot each kept precision at its own recall threshold; rescaling the
        # kept points to span [0, 1] would misplace the whole curve
        r = recall_grid[mask]
        dr = np.diff(recall_grid)[0]
        iou = np.around(ious[i], 2)
        AP = 100*np.sum(p*dr)
        axs.plot(r, p, label=r'${\rm{AP}}_{%.2f} = %.1f$' % (iou, AP), color=colors[i], lw=2)
    cls_title = cls_names[j] if j < len(cls_names) else str(j)
    axs.legend(fontsize=10, title=f'{cls_title}', bbox_to_anchor=(1.35, 1.0))

axs.set_xlabel('Recall', fontsize=20)
axs.set_ylabel('Precision', fontsize=20)
axs.set_xlim(0, 1.1)
axs.set_ylim(0, 1.1)

fig.tight_layout()

Real data has a lot more variation than simulations and requires more training for the networks to achieve good evaluation performance. This demo only shows how to set up the training. We encourage you to add object classes, try different contrast scalings, and train for longer!