In [1]:
import os
from detectron2.data.datasets import register_coco_instances
from detectron2.data.catalog import DatasetCatalog
from detectron2.data.catalog import MetadataCatalog
import cv2
import numpy as np
from pathlib import Path
import random
from detectron2.utils.visualizer import Visualizer
import matplotlib.pyplot as plt
from tensorboard.plugins.hparams import api as hp
import tensorflow as tf

from detectron2.engine.hooks import HookBase
from detectron2.evaluation import inference_context
from detectron2.utils.logger import log_every_n_seconds
from detectron2.data import DatasetMapper, build_detection_test_loader
import detectron2.utils.comm as comm
import torch
import time
import datetime
import logging
from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator, inference_on_dataset

from GPUtil import showUtilization as gpu_usage

from detectron2.config import get_cfg
cfg = get_cfg()

#torch.cuda.set_per_process_memory_fraction(0.2)

def train_hail_model(run_dir, base_lr, gamma, batch_size):
    cfg.merge_from_file("/home/appuser/detectron2_repo/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
    
    cfg.INPUT.MIN_SIZE_TRAIN = (500,)
    
    cfg.OUTPUT_DIR = run_dir
    
    cfg.DATASETS.TRAIN = ("train_hail",)
    cfg.DATASETS.TEST = ("val_hail",)
    
    cfg.DATALOADER.NUM_WORKERS = 1

    #cfg.INPUT.RANDOM_FLIP = "horizontal"
    #cfg.SOLVER.CHECKPOINT_PERIOD = 2000

    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.BASE_LR = base_lr
    cfg.SOLVER.WARMUP_ITERS = 50
    cfg.SOLVER.MAX_ITER = 200 #adjust up if val AP is still rising, adjust down if overfit
    cfg.SOLVER.STEPS = (80, 90)
    cfg.SOLVER.GAMMA = gamma

    # Test
    cfg.TEST.EVAL_PERIOD = 100
    cfg.TEST.DETECTIONS_PER_IMAGE = 50
    
    cfg.MODEL.DEVICE = "cuda"
    cfg.MODEL.WEIGHTS = "model_final_f10217.pkl"  # initialize from model zoo
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = batch_size
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # 1 class (hail)
    cfg.MODEL.MASK_ON = True
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.05

    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    trainer = CocoTrainer(cfg)
    trainer.resume_or_load(resume=False)
    trainer.train()
    
    output_dir = cfg.OUTPUT_DIR + '/eval/'
    os.makedirs(output_dir, exist_ok=True)
    evaluator = COCOEvaluator("val_hail", cfg, False, output_dir=output_dir)
    val_loader = build_detection_test_loader(cfg, "val_hail")
    result = inference_on_dataset(trainer.model, val_loader, evaluator)


def run(run_dir, hparams, base_lr, gamma, batch_size):
    with tf.summary.create_file_writer(run_dir).as_default():
        hp.hparams(hparams)  # record the values used in this trial
        
        cfg.merge_from_file("/home/appuser/detectron2_repo/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
    
        cfg.INPUT.MIN_SIZE_TRAIN = (500,)

        cfg.OUTPUT_DIR = run_dir

        cfg.DATASETS.TRAIN = ("train_hail",)
        cfg.DATASETS.TEST = ("val_hail",)

        cfg.DATALOADER.NUM_WORKERS = 1

        #cfg.INPUT.RANDOM_FLIP = "horizontal"
        #cfg.SOLVER.CHECKPOINT_PERIOD = 2000

        cfg.SOLVER.IMS_PER_BATCH = 2
        cfg.SOLVER.BASE_LR = base_lr
        cfg.SOLVER.WARMUP_ITERS = 50
        cfg.SOLVER.MAX_ITER = 200 #adjust up if val AP is still rising, adjust down if overfit
        cfg.SOLVER.STEPS = (80, 90)
        cfg.SOLVER.GAMMA = gamma

        # Test
        cfg.TEST.EVAL_PERIOD = 100
        cfg.TEST.DETECTIONS_PER_IMAGE = 50

        cfg.MODEL.DEVICE = "cuda"
        cfg.MODEL.WEIGHTS = "model_final_f10217.pkl"  # initialize from model zoo
        cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = batch_size
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # 1 class (hail)
        cfg.MODEL.MASK_ON = True
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.05

        os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
        trainer = CocoTrainer(cfg)
        trainer.resume_or_load(resume=False)
        trainer.train()

        output_dir = cfg.OUTPUT_DIR + '/eval/'
        os.makedirs(output_dir, exist_ok=True)
        evaluator = COCOEvaluator("val_hail", cfg, False, output_dir=output_dir)
        val_loader = build_detection_test_loader(cfg, "val_hail")
        result = inference_on_dataset(trainer.model, val_loader, evaluator)
        
        AP_bbox = result['bbox']['AP']
        AP50_bbox = result['bbox']['AP50']
        AP75_bbox = result['bbox']['AP75']
        AP_segm = result['segm']['AP']
        AP50_segm = result['segm']['AP50']
        AP75_segm = result['segm']['AP75']
        
        del result
        
        tf.summary.scalar(AP_BBOX, AP_bbox, step=1)
        tf.summary.scalar(AP50_BBOX, AP50_bbox, step=1)
        tf.summary.scalar(AP75_BBOX, AP75_bbox, step=1)
        tf.summary.scalar(AP_SEGM, AP_segm, step=1)
        tf.summary.scalar(AP50_SEGM, AP50_segm, step=1)
        tf.summary.scalar(AP75_SEGM, AP75_segm, step=1)
        
        del AP_bbox, AP50_bbox, AP75_bbox, AP_segm, AP50_segm, AP75_segm
        
        
class CocoTrainer(DefaultTrainer):

  @classmethod
  def build_evaluator(cls, cfg, dataset_name, output_folder=cfg.OUTPUT_DIR):
    
    return COCOEvaluator(dataset_name, cfg, False, output_folder)

  def build_hooks(self):
        
    hooks = super().build_hooks()
    hooks.insert(-1,LossEvalHook(
        cfg.TEST.EVAL_PERIOD,
        self.model,
        build_detection_test_loader(
            self.cfg,
            self.cfg.DATASETS.TEST[0],
            DatasetMapper(self.cfg,True)
        )
    ))
    return hooks

class LossEvalHook(HookBase):
    def __init__(self, eval_period, model, data_loader):
        self._model = model
        self._period = eval_period
        self._data_loader = data_loader
    
    def _do_loss_eval(self):
        # Copying inference_on_dataset from evaluator.py
        total = len(self._data_loader)
        num_warmup = min(5, total - 1)
            
        start_time = time.perf_counter()
        total_compute_time = 0
        losses = []
        for idx, inputs in enumerate(self._data_loader):            
            if idx == num_warmup:
                start_time = time.perf_counter()
                total_compute_time = 0
            start_compute_time = time.perf_counter()
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - start_compute_time
            iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
            seconds_per_img = total_compute_time / iters_after_start
            if idx >= num_warmup * 2 or seconds_per_img > 5:
                total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
                eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
                log_every_n_seconds(
                    logging.INFO,
                    "Loss on Validation  done {}/{}. {:.4f} s / img. ETA={}".format(
                        idx + 1, total, seconds_per_img, str(eta)
                    ),
                    n=5,
                )
            loss_batch = self._get_loss(inputs)
            losses.append(loss_batch)
        mean_loss = np.mean(losses)
        self.trainer.storage.put_scalar('validation_loss', mean_loss)
        comm.synchronize()

        return losses
            
    def _get_loss(self, data):
        # How loss is calculated on train_loop 
        metrics_dict = self._model(data)
        metrics_dict = {
            k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
            for k, v in metrics_dict.items()
        }
        total_losses_reduced = sum(loss for loss in metrics_dict.values())
        return total_losses_reduced
             
    def after_step(self):
        next_iter = self.trainer.iter + 1
        is_final = next_iter == self.trainer.max_iter
        if is_final or (self._period > 0 and next_iter % self._period == 0):
            self._do_loss_eval()
        self.trainer.storage.put_scalars(timetest=12)

In [2]:
register = 1
if register == 1:
    #Register train data set
    register_coco_instances("train_hail", {}, "./data/hail_20210620_r1/train/annotations/instances_default.json", "./data/hail_20210620_r1/train/images")
    dataset_dicts_train = DatasetCatalog.get("train_hail")
    hail_metadata_train = MetadataCatalog.get("train_hail")

    #Register validation data set
    register_coco_instances("val_hail", {}, "./data/hail_20210620_r1/val/annotations/instances_default.json", "./data/hail_20210620_r1/val/images")
    dataset_dicts_val = DatasetCatalog.get("val_hail")
    hail_metadata_val = MetadataCatalog.get("val_hail")

    #Register test data set
    register_coco_instances("test_hail", {}, "./data/hail_20210620_r1/test/annotations/instances_default.json", "./data/hail_20210620_r1/test/images")
    dataset_dicts_test = DatasetCatalog.get("test_hail")
    hail_metadata_test = MetadataCatalog.get("test_hail")

HP_BASE_LR = hp.HParam('base_lr', hp.Discrete([0.0001, 0.00025, 0.0005, 0.001]))
HP_GAMMA = hp.HParam('gamma', hp.Discrete([0.1,0.5]))
HP_BATCH_SIZE_PER_IMAGE = hp.HParam('batch_size_per_image', hp.Discrete([128, 256]))

AP_BBOX = 'AP_bbox'
AP50_BBOX = 'AP50_bbox'
AP75_BBOX = 'AP75_bbox'

AP_SEGM = 'AP_segm'
AP50_SEGM = 'AP50_segm'
AP75_SEGM = 'AP75_segm'

with tf.summary.create_file_writer('output/logs/hparam_tuning').as_default():
    hp.hparams_config(
        hparams=[HP_BASE_LR, HP_GAMMA, HP_BATCH_SIZE_PER_IMAGE],
        metrics=[hp.Metric(AP_BBOX, display_name='AP_bbox'), 
                 hp.Metric(AP50_BBOX, display_name='AP50_bbox'), 
                 hp.Metric(AP75_BBOX, display_name='AP75_bbox'),
                 hp.Metric(AP_SEGM, display_name='AP_segm'), 
                 hp.Metric(AP50_SEGM, display_name='AP50_segm'), 
                 hp.Metric(AP75_SEGM, display_name='AP75_segm')
                 ],
    ) 

session_num = 0
for base_lr in HP_BASE_LR.domain.values:
    for gamma in HP_GAMMA.domain.values:
        for batch_size_per_image in HP_BATCH_SIZE_PER_IMAGE.domain.values:
            hparams = {
                HP_BASE_LR: base_lr,
                HP_GAMMA: gamma,
                HP_BATCH_SIZE_PER_IMAGE: batch_size_per_image,
                }
            run_name = "run-%d" % session_num
            print('--- Starting trial: %s' % run_name)
            print({h.name: hparams[h] for h in hparams})
            run('output/logs/hparam_tuning/' + run_name, hparams, base_lr, gamma, batch_size_per_image)
            session_num += 1

--- Starting trial: run-0
{'base_lr': 0.0001, 'gamma': 0.1, 'batch_size_per_image': 128}
[32m[07/22 11:13:45 d2.engine.defaults]: [0mModel:
GeneralizedRCNN(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelMaxPool()
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=

[32m[07/22 11:13:45 d2.data.datasets.coco]: [0mLoaded 150 images in COCO format from ./data/hail_20210620_r1/train/annotations/instances_default.json
[32m[07/22 11:13:45 d2.data.build]: [0mRemoved 2 images with no usable annotations. 148 images left.
[32m[07/22 11:13:45 d2.data.build]: [0mDistribution of instances among all 1 categories:
[36m|  category  | #instances   |
|:----------:|:-------------|
|    hail    | 937          |
|            |              |[0m
[32m[07/22 11:13:45 d2.data.dataset_mapper]: [0m[DatasetMapper] Augmentations used in training: [ResizeShortestEdge(short_edge_length=(500,), max_size=1333, sample_style='choice'), RandomFlip()]
[32m[07/22 11:13:45 d2.data.build]: [0mUsing training sampler TrainingSampler
[32m[07/22 11:13:45 d2.data.common]: [0mSerializing 148 elements to byte tensors and concatenating them all ...
[32m[07/22 11:13:45 d2.data.common]: [0mSerialized dataset takes 0.34 MiB
[32m[07/22 11:13:45 d2.data.dataset_mapper]: [0m[Dataset

Skip loading parameter 'roi_heads.box_predictor.cls_score.weight' to the model due to incompatible shapes: (81, 1024) in the checkpoint but (2, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.cls_score.bias' to the model due to incompatible shapes: (81,) in the checkpoint but (2,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.weight' to the model due to incompatible shapes: (320, 1024) in the checkpoint but (4, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.bias' to the model due to incompatible shapes: (320,) in the checkpoint but (4,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.mask_head.predictor.weight' to the model due to incompatible shapes: (80, 256, 1, 1) in the checkpoint but (1, 256, 1, 1) in

[32m[07/22 11:13:45 d2.engine.train_loop]: [0mStarting training from iteration 0


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[32m[07/22 11:13:48 d2.utils.events]: [0m eta: 0:00:19  iter: 19  total_loss: 5.619  loss_cls: 0.8687  loss_box_reg: 0.3974  loss_mask: 0.6954  loss_rpn_cls: 3.178  loss_rpn_loc: 0.5747  time: 0.1153  data_time: 0.0076  lr: 3.8062e-05  max_mem: 1083M
[32m[07/22 11:13:50 d2.utils.events]: [0m eta: 0:00:18  iter: 39  total_loss: 2.388  loss_cls: 0.4649  loss_box_reg: 0.7612  loss_mask: 0.6167  loss_rpn_cls: 0.25  loss_rpn_loc: 0.2425  time: 0.1150  data_time: 0.0028  lr: 7.8022e-05  max_mem: 1083M
[32m[07/22 11:13:52 d2.utils.events]: [0m eta: 0:00:16  iter: 59  total_loss: 2.007  loss_cls: 0.3656  loss_box_reg: 0.7932  loss_mask: 0.5135  loss_rpn_cls: 0.09244  loss_rpn_loc: 0.2186  time: 0.1157  data_time: 0.0027  lr: 0.0001  max_mem: 1083M
[32m[07/22 11:13:55 d2.utils.events]: [0m eta: 0:00:13  iter: 79  total_loss: 1.874  loss_cls: 0.3041  loss_box_reg: 0.7779  loss_mask: 0.4753  loss_rpn_cls: 0.09315  loss_rpn_loc: 0.2078  time: 0.1162  data_time: 0.0034  lr: 0.0001  max_mem:

[32m[07/22 11:14:16 d2.engine.hooks]: [0mOverall training speed: 198 iterations in 0:00:23 (0.1173 s / it)
[32m[07/22 11:14:16 d2.engine.hooks]: [0mTotal training time: 0:00:30 (0:00:06 on hooks)
[32m[07/22 11:14:16 d2.data.datasets.coco]: [0mLoaded 33 images in COCO format from ./data/hail_20210620_r1/val/annotations/instances_default.json
[32m[07/22 11:14:16 d2.data.dataset_mapper]: [0m[DatasetMapper] Augmentations used in inference: [ResizeShortestEdge(short_edge_length=(800, 800), max_size=1333, sample_style='choice')]
[32m[07/22 11:14:16 d2.data.common]: [0mSerializing 33 elements to byte tensors and concatenating them all ...
[32m[07/22 11:14:16 d2.data.common]: [0mSerialized dataset takes 0.09 MiB
[32m[07/22 11:14:16 d2.evaluation.coco_evaluation]: [0mFast COCO eval is not built. Falling back to official COCO eval.
[32m[07/22 11:14:16 d2.evaluation.evaluator]: [0mStart inference on 33 batches
[32m[07/22 11:14:16 d2.evaluation.evaluator]: [0mInference done 11/33

[32m[07/22 11:14:22 d2.evaluation.coco_evaluation]: [0mSome metrics cannot be computed and is shown as NaN.
Loading and preparing results...
DONE (t=0.01s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *segm*
DONE (t=0.34s).
Accumulating evaluation results...
DONE (t=0.01s).
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.289
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.756
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.127
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.289
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.062
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.360
 Average Recall     (AR) @[ IoU=0.50

[32m[07/22 11:14:23 d2.data.datasets.coco]: [0mLoaded 150 images in COCO format from ./data/hail_20210620_r1/train/annotations/instances_default.json
[32m[07/22 11:14:23 d2.data.build]: [0mRemoved 2 images with no usable annotations. 148 images left.
[32m[07/22 11:14:23 d2.data.dataset_mapper]: [0m[DatasetMapper] Augmentations used in training: [ResizeShortestEdge(short_edge_length=(500,), max_size=1333, sample_style='choice'), RandomFlip()]
[32m[07/22 11:14:23 d2.data.build]: [0mUsing training sampler TrainingSampler
[32m[07/22 11:14:23 d2.data.common]: [0mSerializing 148 elements to byte tensors and concatenating them all ...
[32m[07/22 11:14:23 d2.data.common]: [0mSerialized dataset takes 0.34 MiB
[32m[07/22 11:14:23 d2.data.dataset_mapper]: [0m[DatasetMapper] Augmentations used in training: [ResizeShortestEdge(short_edge_length=(500,), max_size=1333, sample_style='choice'), RandomFlip()]
[32m[07/22 11:14:23 d2.data.datasets.coco]: [0mLoaded 33 images in COCO format 

Skip loading parameter 'roi_heads.box_predictor.cls_score.weight' to the model due to incompatible shapes: (81, 1024) in the checkpoint but (2, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.cls_score.bias' to the model due to incompatible shapes: (81,) in the checkpoint but (2,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.weight' to the model due to incompatible shapes: (320, 1024) in the checkpoint but (4, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.bias' to the model due to incompatible shapes: (320,) in the checkpoint but (4,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.mask_head.predictor.weight' to the model due to incompatible shapes: (80, 256, 1, 1) in the checkpoint but (1, 256, 1, 1) in

[32m[07/22 11:14:23 d2.engine.train_loop]: [0mStarting training from iteration 0
[32m[07/22 11:14:25 d2.utils.events]: [0m eta: 0:00:20  iter: 19  total_loss: 5.356  loss_cls: 0.736  loss_box_reg: 0.2042  loss_mask: 0.688  loss_rpn_cls: 3.283  loss_rpn_loc: 0.5339  time: 0.1176  data_time: 0.0076  lr: 3.8062e-05  max_mem: 1199M
[4m[5m[31mERROR[0m [32m[07/22 11:14:26 d2.engine.train_loop]: [0mException during training:
Traceback (most recent call last):
  File "/container/hail/detectron2/detectron2/engine/train_loop.py", line 149, in train
    self.run_step()
  File "/container/hail/detectron2/detectron2/engine/defaults.py", line 494, in run_step
    self._trainer.run_step()
  File "/container/hail/detectron2/detectron2/engine/train_loop.py", line 285, in run_step
    losses.backward()
  File "/home/appuser/.local/lib/python3.6/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/

RuntimeError: CUDA out of memory. Tried to allocate 96.00 MiB (GPU 0; 7.80 GiB total capacity; 1.17 GiB already allocated; 36.62 MiB free; 1.43 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [6]:
!pip install numba

Defaulting to user installation because normal site-packages is not writeable
Collecting numba
  Downloading numba-0.53.1-cp36-cp36m-manylinux2014_x86_64.whl (3.4 MB)
     |################################| 3.4 MB 1.8 MB/s            
[?25hCollecting llvmlite<0.37,>=0.36.0rc1
  Downloading llvmlite-0.36.0-cp36-cp36m-manylinux2010_x86_64.whl (25.3 MB)
     |################################| 25.3 MB 6.8 MB/s            
Installing collected packages: llvmlite, numba
Successfully installed llvmlite-0.36.0 numba-0.53.1


In [8]:
torch.cuda.set_device(0)

In [11]:
free_gpu_cache()

Initial GPU Usage
| ID | GPU | MEM |
------------------
|  0 |  0% | 21% |
GPU Usage after emptying the cache
| ID | GPU | MEM |
------------------
|  0 | 17% |  2% |


In [18]:
#Detect hail and save segmentation masks
from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_test_loader
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.engine import DefaultPredictor
import numpy as np

from detectron2.utils.visualizer import ColorMode
import glob
from pathlib import Path
import pickle
import numpy as np

experiment_folder = './output/hail_20210620_r1_test_4/'

cfg.MODEL.WEIGHTS = os.path.join(experiment_folder, "model_0014999.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.80
predictor = DefaultPredictor(cfg)

images_path = 'data/hail_20210620_r1/all_images/'
mask_array_path = 'products/hail_20210620_r1_test_4/pkl/'

if not os.path.exists(mask_array_path):
    os.makedirs(mask_array_path)  

all_images = glob.glob(images_path+'*.png')
all_images.sort()

for file in all_images:
    im = cv2.imread(file)
    outputs = predictor(im)
    masks = outputs['instances'][outputs['instances'].pred_classes==0].pred_masks.cpu().numpy()

    image_name = Path(file).stem

    mask_array = []
    #hail_count = 1 # For saving all individual masks to a figure
    for i in range(masks.shape[0]):
        mask_int = masks[i,:,:]*1
        mask_array.append(mask_int)
        
        #hail_area = (mask_int == 1).sum() #1 pixel = 1mm^2
        #hail_areas.append(hail_area)
        #title='Hailstone area: '+str(hail_area)+'mm$^2$'
        #display_images([masks[:,:,i]],mask_path, image_name, hail_count, titles=[title])
        #plt.close('all')
        #hail_count = hail_count + 1

    with open(mask_array_path+'mask_array_'+image_name+'.pkl','wb') as f:
        pickle.dump(mask_array, f)