In [1]:
!wget http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
!tar -xf ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
!mv ssd_resnet50_v1_fpn_640x640_coco17_tpu-8/checkpoint models/research/object_detection/test_data/

--2020-11-03 17:14:07--  http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
Resolving download.tensorflow.org (download.tensorflow.org)... 216.58.206.144, 2a00:1450:4009:811::2010
Connecting to download.tensorflow.org (download.tensorflow.org)|216.58.206.144|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 244817203 (233M) [application/x-tar]
Saving to: ‘ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz’


2020-11-03 17:14:14 (42.8 MB/s) - ‘ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz’ saved [244817203/244817203]



In [15]:
import gdal
import sys
sys.path.append('/home/dan/utils/dspix/geopix/')
import geopix as gp
import ogr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import tensorflow as tf
import random
import glob
import re
%matplotlib notebook

from object_detection.utils import label_map_util
from object_detection.utils import config_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.builders import model_builder

In [5]:
img = gdal.Open('AOI_5_Khartoum_Train/RGB-PanSharpen/RGB-PanSharpen_AOI_5_Khartoum_img99.tif')
pixels = img.ReadAsArray()
pixels = pixels/pixels.max()
print(pixels.shape)

(3, 650, 650)


In [6]:
geo = ogr.Open('AOI_5_Khartoum_Train/geojson/buildings/buildings_AOI_5_Khartoum_img99.geojson')
geo

<osgeo.ogr.DataSource; proxy of <Swig Object of type 'OGRDataSourceShadow *' at 0x7f1d02acf0f0> >

In [2]:
tifs = glob.glob('AOI_5_Khartoum_Train/RGB-PanSharpen/*.tif')
geojsons = glob.glob('AOI_5_Khartoum_Train/geojson/buildings/*.geojson')

In [3]:
tif_ids = [int(re.findall(r'\d+', tif)[2]) for tif in tifs]
geojson_ids = [int(re.findall(r'\d+', js)[2]) for js in geojsons]

files = [[tifs[i], geojsons[geojson_ids.index(id_)]] for i, id_ in enumerate(tif_ids)]

In [21]:
def parse_boxes(geojson, gt, shape):
    nrows, ncols = shape
    geo = ogr.Open(geojson)
    gref = gp.Georeference(gt)
    
    layer = geo.GetLayer()
    layer.ResetReading()

    rel_boxes = []

    for feat in layer:
        geom = feat.GetGeometryRef()
        bbox = np.array(geom.GetEnvelope(), dtype=np.float32) # minX, maxX, minY, maxY
        bbox = gref.world2pix([bbox[::2], bbox[1::2]]).ravel()
        rel_boxes.append((bbox[[1, 0, 3, 2]] / np.array([nrows, ncols, nrows, ncols])))
    
    if len(rel_boxes)>0:   
        return np.array(rel_boxes)

In [22]:
images_np = []
boxes_np = []

for img_fn, json_fn in files:

    img = gdal.Open(img_fn)
    pixels = img.ReadAsArray()
    pixels = pixels/pixels.max()
    
    gt = img.GetGeoTransform()
    boxes = parse_boxes(json_fn, gt, pixels[0].shape)
    
    if boxes is not None:
        boxes_np.append(boxes)
        images_np.append((pixels.transpose(1, 2, 0)*255).astype(np.uint8))

In [23]:
len(boxes_np), len(images_np)

(924, 924)

In [7]:
def rasterise(rows, cols, gt, dtype, src):
    driver = gdal.GetDriverByName('MEM')
    mem_raster = driver.Create(
        '',
        cols,
        rows,
        dtype
    )
    mem_raster.SetGeoTransform(gt)
    gdal.RasterizeLayer(mem_raster, [1], src.GetLayer(), burn_values=[2])
    
    return mem_raster
    
gt = img.GetGeoTransform()
nrows = img.RasterYSize
ncols = img.RasterXSize

labels = rasterise(nrows, ncols, gt, gdal.GDT_Byte, geo).ReadAsArray()

In [8]:
gref = gp.Georeference(gt)
gref.world2pix(np.array([[32.5016809, 15.54392933], [32.50181189, 15.54398237]]))

array([[  17.59259259, -104.863     ],
       [  66.10740741, -124.50744445]])

In [9]:
layer = geo.GetLayer()
layer.ResetReading()

boxes = []
rel_boxes = []

for feat in geo.GetLayer():
    geom = feat.GetGeometryRef()
    bbox = np.array(geom.GetEnvelope(), dtype=np.float32) # minX, maxX, minY, maxY
    bbox = gref.world2pix([bbox[::2], bbox[1::2]]).ravel()
    boxes.append((bbox[[1, 0, 3, 2]])) # Assume ymin, xmin, ymax, xmax
    rel_boxes.append((bbox[[1, 0, 3, 2]] / np.array([nrows, ncols, nrows, ncols])))

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2)

ax1.imshow(images_np[7])
ax1.axis('off');

# ax2.imshow(labels)
# ax2.axis('off');

# for box in boxes:
#     xy = box[[1, 2]]
#     w = box[3] - box[1]
#     h = box[0] - box[2]
#     rect = patches.Rectangle(
#         xy,
#         w,
#         h,
#         linewidth=1,
#         edgecolor='r',
#         facecolor='none'
#     )
#     ax2.add_patch(rect)

for rbox in boxes_np[7]:
    box = rbox*np.array([nrows, ncols, nrows, ncols])
    xy = box[[3, 0]]
    w = box[1] - box[3]
    h = box[2] - box[0]
    rect = patches.Rectangle(
        xy,
        w,
        h,
        linewidth=1,
        edgecolor='r',
        facecolor='none'
    )
    ax1.add_patch(rect)


In [89]:
len(rel_boxes), len(boxes_np[0])

(20, 49)

In [11]:
images = [(pixels.transpose(1, 2, 0)*255).astype(np.uint8)] # uint8 numpy arrays with shape (img_height, img_width, 3)
bboxes = [np.array(rel_boxes)] # List of relative pixel-coord boxes (n, 4)

In [24]:
images = images_np[:10]
bboxes = boxes_np[:10]

In [37]:
images[0].shape

(650, 650, 3)

In [25]:
print(len(images), len(bboxes))

10 10


In [26]:
for box in bboxes:
    print(box.shape)

(49, 4)
(11, 4)
(16, 4)
(10, 4)
(11, 4)
(29, 4)
(9, 4)
(35, 4)
(20, 4)
(27, 4)


In [27]:
num_classes = 1

category_index = {1: {'id': 1, 'name': 'buildings'}} # Only one class
label_id_offset = 1
train_tensors = []
one_hot_tensors = []
box_tensors = []

train_data = zip(images, bboxes)

for rgb, box in train_data:
    train_tensors.append(
      tf.expand_dims(
          tf.convert_to_tensor(
              rgb,
              dtype=tf.float32),
          axis=0
      )
    )
    box_tensors.append(tf.convert_to_tensor(box, dtype=tf.float32))
    
    zero_indexed_classes = tf.convert_to_tensor(
      np.ones(shape=[box.shape[0]], dtype=np.int32) - label_id_offset
    )
    one_hot_tensors.append(
        tf.one_hot(zero_indexed_classes, num_classes)
    )

In [38]:
tf.keras.backend.clear_session()
num_classes = 1
pipeline_config = 'models/research/object_detection/configs/tf2/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.config'
checkpoint_path = 'models/research/object_detection/test_data/checkpoint/ckpt-0'

configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
model_config.ssd.num_classes = num_classes
model_config.ssd.freeze_batchnorm = True
detection_model = model_builder.build(
      model_config=model_config, is_training=True)

# Set up object-based checkpoint restore --- RetinaNet has two prediction
# `heads` --- one for classification, the other for box regression.  We will
# restore the box regression head but initialize the classification head
# from scratch (we show the omission below by commenting out the line that
# we would add if we wanted to restore both heads)
fake_box_predictor = tf.compat.v2.train.Checkpoint(
    _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
    # _prediction_heads=detection_model._box_predictor._prediction_heads,
    #    (i.e., the classification head that we *will not* restore)
    _box_prediction_head=detection_model._box_predictor._box_prediction_head,
    )
fake_model = tf.compat.v2.train.Checkpoint(
          _feature_extractor=detection_model._feature_extractor,
          _box_predictor=fake_box_predictor)
ckpt = tf.compat.v2.train.Checkpoint(model=fake_model)
ckpt.restore(checkpoint_path).expect_partial()

# Run model through a dummy image so that variables are created
image, shapes = detection_model.preprocess(tf.zeros([1, 650, 650, 3]))
prediction_dict = detection_model.predict(image, shapes)
_ = detection_model.postprocess(prediction_dict, shapes)

In [39]:
tf.keras.backend.set_learning_phase(True)

# These parameters can be tuned; since our training set has 5 images
# it doesn't make sense to have a much larger batch size, though we could
# fit more examples in memory if we wanted to.
batch_size = 4
learning_rate = 0.01
num_batches = 100

# Select variables in top layers to fine-tune.
trainable_variables = detection_model.trainable_variables
to_fine_tune = []
prefixes_to_train = [
  'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalBoxHead',
  'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalClassHead'
]

for var in trainable_variables:
    if any([var.name.startswith(prefix) for prefix in prefixes_to_train]):
        to_fine_tune.append(var)

# Set up forward + backward pass for a single train step.
def get_model_train_step_function(model, optimizer, vars_to_fine_tune):
    """Get a tf.function for training step."""

    # Use tf.function for a bit of speed.
    # Comment out the tf.function decorator if you want the inside of the
    # function to run eagerly.
    @tf.function
    def train_step_fn(image_tensors,
                    groundtruth_boxes_list,
                    groundtruth_classes_list):
        """A single training iteration.

        Args:
          image_tensors: A list of [1, height, width, 3] Tensor of type tf.float32.
            Note that the height and width can vary across images, as they are
            reshaped within this function to be 640x640.
          groundtruth_boxes_list: A list of Tensors of shape [N_i, 4] with type
            tf.float32 representing groundtruth boxes for each image in the batch.
          groundtruth_classes_list: A list of Tensors of shape [N_i, num_classes]
            with type tf.float32 representing groundtruth boxes for each image in
            the batch.

        Returns:
          A scalar tensor representing the total loss for the input batch.
        """
        shapes = tf.constant(batch_size * [[650, 650, 3]], dtype=tf.int32)
        model.provide_groundtruth(
            groundtruth_boxes_list=groundtruth_boxes_list,
            groundtruth_classes_list=groundtruth_classes_list)
        with tf.GradientTape() as tape:
            preprocessed_images = tf.concat(
              [detection_model.preprocess(image_tensor)[0]
               for image_tensor in image_tensors], axis=0)
            prediction_dict = model.predict(preprocessed_images, shapes)
            losses_dict = model.loss(prediction_dict, shapes)
            total_loss = losses_dict['Loss/localization_loss'] + losses_dict['Loss/classification_loss']
            gradients = tape.gradient(total_loss, vars_to_fine_tune)
            optimizer.apply_gradients(zip(gradients, vars_to_fine_tune))
        return total_loss

    return train_step_fn

optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
train_step_fn = get_model_train_step_function(
    detection_model, optimizer, to_fine_tune)

# Fine-tune
for idx in range(num_batches):
    # Grab keys for a random subset of examples
    all_keys = list(range(len(images)))
    random.shuffle(all_keys)
    example_keys = all_keys[:batch_size]

    # Note that we do not do data augmentation in this demo.  If you want a
    # a fun exercise, we recommend experimenting with random horizontal flipping
    # and random cropping :)
    boxes_list = [box_tensors[key] for key in example_keys]
    classes_list = [one_hot_tensors[key] for key in example_keys]
    image_tensors = [train_tensors[key] for key in example_keys]

    # Training step (forward pass + backwards pass)
    total_loss = train_step_fn(image_tensors, boxes_list, classes_list)

    if idx % 10 == 0:
        print('batch ' + str(idx) + ' of ' + str(num_batches)
        + ', loss=' +  str(total_loss.numpy()), flush=True)


batch 0 of 100, loss=63.40186


KeyboardInterrupt: 

In [30]:
def detect(input_tensor):
    """Run detection on an input image.

    Args:
    input_tensor: A [1, height, width, 3] Tensor of type tf.float32.
      Note that height and width can be anything since the image will be
      immediately resized according to the needs of the model within this
      function.

    Returns:
    A dict containing 3 Tensors (`detection_boxes`, `detection_classes`,
      and `detection_scores`).
    """
    preprocessed_image, shapes = detection_model.preprocess(input_tensor)
    prediction_dict = detection_model.predict(preprocessed_image, shapes)
    
    return detection_model.postprocess(prediction_dict, shapes)

detections = detect(image_tensors[0])

In [36]:
detections

{'detection_boxes': <tf.Tensor: shape=(1, 100, 4), dtype=float32, numpy=
 array([[[9.54884112e-01, 5.80020130e-01, 9.72209632e-01, 6.26446903e-01],
         [6.64004505e-01, 8.45370114e-01, 1.00000000e+00, 1.00000000e+00],
         [1.92102671e-01, 7.21002519e-01, 1.00000000e+00, 1.00000000e+00],
         [6.97399616e-01, 0.00000000e+00, 1.00000000e+00, 2.96739846e-01],
         [2.30277419e-01, 0.00000000e+00, 1.00000000e+00, 7.16237545e-01],
         [0.00000000e+00, 4.65444028e-01, 1.00000000e+00, 1.00000000e+00],
         [6.20131016e-01, 7.08733261e-01, 1.00000000e+00, 1.00000000e+00],
         [0.00000000e+00, 0.00000000e+00, 7.76874542e-01, 3.44459713e-01],
         [0.00000000e+00, 0.00000000e+00, 1.98714554e-01, 7.36167431e-02],
         [4.29542542e-01, 7.36446977e-01, 1.00000000e+00, 1.00000000e+00],
         [4.00954962e-01, 0.00000000e+00, 1.00000000e+00, 9.97426689e-01],
         [0.00000000e+00, 0.00000000e+00, 1.18120864e-01, 3.70646715e-02],
         [0.00000000e+00, 9

In [33]:
fig, (ax1, ax2) = plt.subplots(1, 2)

ax1.imshow(images[0])
ax1.axis('off');

#ax2.imshow(labels)
ax2.axis('off');

# for box in boxes:
#     xy = box[[1, 2]]
#     w = box[3] - box[1]
#     h = box[0] - box[2]
#     rect = patches.Rectangle(
#         xy,
#         w,
#         h,
#         linewidth=1,
#         edgecolor='r',
#         facecolor='none'
#     )
#     ax2.add_patch(rect)
nrows, ncols = images[0].shape[:-1]
for rbox in detections['detection_boxes'][0].numpy():
    box = rbox*np.array([nrows, ncols, nrows, ncols])
    xy = box[[3, 0]]
    w = box[1] - box[3]
    h = box[2] - box[0]
    rect = patches.Rectangle(
        xy,
        w,
        h,
        linewidth=1,
        edgecolor='r',
        facecolor='none'
    )
    ax2.add_patch(rect)
    


<IPython.core.display.Javascript object>