In [None]:
!git clone https://github.com/corydambach/mmrotate-sandbox/
!pip install -q condacolab
import condacolab
condacolab.install() # expect a kernel restart

!mamba env update -n base -f mmrotate-sandbox/env.yml

fatal: destination path 'mmrotate-sandbox' already exists and is not an empty directory.
✨🍰✨ Everything looks OK!
Channels:
 - pytorch
 - nvidia
 - conda-forge
 - defaults
Platform: linux-64
Collecting package metadata (repodata.json): - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - 

In [None]:
#os
import os

#ml
import numpy as np
import torch
import mmcv
import cv2
from mmdet.apis    import init_detector
from mmrotate.apis import inference_detector_by_patches
from dataclasses import asdict
import json

from geo_util import Pnt, Vehicle, VehicleExport, rbbox_to_poly

In [None]:

def process_results( img, result, out_file=None, score_thr=None, class_ids=None, **kwargs ):
    """Filter rotated detections and export them to ``.npy`` and ``.json`` files.

    Parameters
    ----------
    img : str or ndarray
        Source image. Not used here (the detections already carry image
        coordinates); kept so existing calls keep working.
    result : list of ndarray, or tuple whose first element is that list
        Per-class detection arrays as returned by the detector. Each array is
        assumed to hold one detection per row with the score in the LAST
        column (mmrotate rbbox convention, e.g. ``(x, y, w, h, theta, score)``
        -- TODO confirm for other models).
    out_file : str
        Output path stem. Required. Writes ``out_file`` via ``np.save``, which
        appends ``.npy`` when missing, and ``{out_file}.json``.
    score_thr : float, optional
        If given, keep only detections whose score (last column) is strictly
        greater than this value. ``None`` keeps every detection.
    class_ids : iterable of int, optional
        If given, keep only detections whose class index (position in
        ``result``) is in this set. For example, ``[4]`` is 'small-vehicle' if
        the model uses mmrotate's DOTA class order (verify). ``None`` keeps
        all classes.
    **kwargs
        Ignored; accepted so callers may pass extra options.

    Raises
    ------
    TypeError
        If ``out_file`` is not provided.
    """
    if out_file is None:
        raise TypeError( "process_results() requires out_file (output path stem)" )

    # mmdet-style results may be (bbox_result, segm_result); only boxes are used.
    bbox_result = result[0] if isinstance( result, tuple ) else result

    bboxes = np.vstack( bbox_result )
    # The class label of a detection is the index of the per-class array it came from.
    labels = np.concatenate( [
        np.full( cls_boxes.shape[0], cls_idx, dtype=np.int32 )
        for cls_idx, cls_boxes in enumerate( bbox_result )
    ] )

    keep = np.ones( len( bboxes ), dtype=bool )
    if score_thr is not None:
        keep &= bboxes[:, -1] > score_thr
    if class_ids is not None:
        keep &= np.isin( labels, list( class_ids ) )
    bboxes = bboxes[keep]
    labels = labels[keep]

    print( f"Detected {len( bboxes )} vehicles." )

    # Each row: the bbox columns followed by the class label (labels upcast to float).
    concat = np.concatenate( (bboxes, labels.reshape( -1, 1 )), axis=1 )
    np.save( out_file, concat )

    vehicles = [Vehicle( r[0], r[1], r[2], r[3], r[4], r ) for r in concat]
    exports = [
        VehicleExport( Pnt( v.x, v.y ), v.width, v.height, v.theta,
                       rbbox_to_poly( v.arr ).tolist(), str( idx ) )
        for idx, v in enumerate( vehicles )
    ]
    with open( f"{out_file}.json", "w" ) as f:
        json.dump( [asdict( e ) for e in exports], f )

In [None]:
class Args:
    """Settings for one patch-wise rotated-detection run over a single image.

    ``batch_size``, ``patch_size`` and ``patch_step`` are required. Every other
    setting is an optional keyword-only argument. Its default is the value this
    notebook originally hard-coded, so ``Args(batch_size=..., patch_size=...,
    patch_step=...)`` behaves exactly as before, while other images or models
    can be configured without editing the class.

    Attributes
    ----------
    img : str
        Path of the input image.
    config, checkpoint : str
        Model config file and weights file.
    score_thr : float
        Detection score threshold.
    merge_iou_thr : float
        IoU threshold used when merging patch detections.
    img_ratios : list of float
        Image rescale ratios (a fresh list per instance).
    out_file : str
        Output path stem for exported detections. NOTE(review): the default
        writes the date as ``2020-6-12`` while the image folder uses
        ``2020-06-12``.
    batch_size : int
        Number of patches per inference batch.
    patch_sizes, patch_steps : list of int
        Single-element lists built from ``patch_size`` / ``patch_step``.
    device : str
        Torch device string for the model.
    palette : str
        Stored only; nothing in this notebook reads it.
    """
    def __init__( self, batch_size: int, patch_size: int, patch_step: int, *,
                  img: str = "./PA-SM-2020-06-12/satellite-sm.png",
                  config: str = "mmrotate/configs/redet/redet_re50_refpn_1x_dota_ms_rr_le90.py",
                  checkpoint: str = "model/redet_re50_fpn_1x_dota_ms_rr_le90-fc9217b5.pth",
                  score_thr: float = 0.0025,
                  merge_iou_thr: float = 0.85,
                  img_ratios=None,
                  out_file: str = "PA-SM-2020-6-12-SM",
                  device: str = 'cuda:0',
                  palette: str = 'dota' ):
        self.img = img
        self.config = config
        self.checkpoint = checkpoint
        self.score_thr = score_thr
        self.merge_iou_thr = merge_iou_thr
        # None sentinel avoids a shared mutable default list across instances.
        self.img_ratios = list( img_ratios ) if img_ratios is not None else [1.0]
        self.out_file = out_file
        self.batch_size = batch_size
        self.patch_sizes = [patch_size]
        self.patch_steps = [patch_step]
        self.device = device
        self.palette = palette

# Cap this process's CUDA allocator at 85% of the default CUDA device's memory;
# allocations beyond the cap raise an out-of-memory error instead of taking the rest.
torch.cuda.set_per_process_memory_fraction( fraction=0.85 )

In [None]:
### setup args here ###
# 1024-px patches with a 768-px step (step < size, so neighbouring patches
# presumably overlap by 256 px); batch_size is forwarded to the patch inference call.
args = Args( patch_size=1024, patch_step=768, batch_size=24 )

In [None]:
# Build the detector from its config file and load the checkpoint weights onto args.device.
model = init_detector( args.config, checkpoint=args.checkpoint, device=args.device )

In [None]:
# Load the full source image once (colour, mmcv's default); it feeds the patch inference below.
img = mmcv.imread( args.img, flag='color' )

In [None]:
# Patch-wise inference over the whole image. Positional order follows mmrotate's
# (model, img, sizes, steps, ratios, merge_iou_thr, bs) signature.
result = inference_detector_by_patches(
    model, img,
    args.patch_sizes, args.patch_steps, args.img_ratios,
    args.merge_iou_thr,
    args.batch_size,
)

In [None]:
# Export the detections for args.img under the args.out_file stem.
process_results(
    args.img,
    result,
    score_thr=args.score_thr,
    out_file=args.out_file,
)