In [None]:
%%capture
# Remember to set runtime to GPU acceleration

# Mount files
from google.colab import drive
drive.mount('/content/drive')

# Set up Kaggle
!pip uninstall -y kaggle
!pip install --upgrade pip
!pip install kaggle==1.5.6

# Credentials come from the environment (KAGGLE_USERNAME / KAGGLE_KEY) or an
# interactive prompt -- never from the notebook source. The API key that used to
# be hardcoded here was shipped with the notebook and must be revoked on Kaggle.
# NOTE(review): if the getpass prompt is not visible under %%capture, set
# KAGGLE_KEY in the environment before running this cell.
import getpass
import json
import os

kaggle_dir = os.path.expanduser('~/.kaggle')
os.makedirs(kaggle_dir, exist_ok=True)  # idempotent, unlike a bare `mkdir`
kaggle_config = os.path.join(kaggle_dir, 'kaggle.json')

token = {
    "username": os.environ.get("KAGGLE_USERNAME", "neilgoecknerwald"),
    "key": os.environ.get("KAGGLE_KEY") or getpass.getpass("Kaggle API key: "),
}
# Create the file with owner-only permissions so the key is never world-readable,
# even briefly.
fd = os.open(kaggle_config, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(fd, 'w') as fh:
    json.dump(token, fh)
del token  # keep the secret out of the kernel namespace (%whos, variable viewers)

# Also tighten permissions on a pre-existing file (os.open's mode only applies on creation).
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
# Download files: fetch the competition archive (tensorflow-great-barrier-reef.zip)
# into the working directory. Uses the Kaggle credentials configured in the setup
# cell; the extraction cell below deletes the zip afterwards, so re-running this
# cell downloads it again.
!kaggle competitions download -c tensorflow-great-barrier-reef

In [None]:
%%capture
# Extract the competition archive. `-o` overwrites files left over from a previous
# run without prompting (a `!` command has no usable stdin to answer the prompt),
# and `&&` ensures the zip is deleted only if extraction actually succeeded.
!unzip -o -q tensorflow-great-barrier-reef.zip && rm tensorflow-great-barrier-reef.zip

In [None]:
# Clone and pull in python files
!git config --global user.email "ngoecknerwald@gmail.com"
!git config --global user.name "Neil Goeckner-Wald"

# The GitHub personal access token is read from the environment or prompted for;
# it is never written into the notebook. The token previously embedded in this
# URL was shipped with the notebook and must be revoked on GitHub.
# Leave the prompt blank if the repository is public.
import getpass
import os

repo_path = 'github.com/ngoecknerwald/tensorflow-experiment.git'
github_token = os.environ.get('GITHUB_TOKEN') or getpass.getpass('GitHub token (blank if public): ')
auth_url = f"https://{github_token + '@' if github_token else ''}{repo_path}"

if os.path.isdir('tensorflow-experiment/.git'):
    # Re-running the cell: update the existing checkout instead of failing on clone.
    !git -C tensorflow-experiment pull -q {auth_url}
else:
    !git clone -q {auth_url}
    # Store a credential-free remote so the token does not persist in .git/config.
    !git -C tensorflow-experiment remote set-url origin https://{repo_path}
del github_token, auth_url  # keep the secret out of the kernel namespace

!rsync tensorflow-experiment/great-barrier-reef/*.py .
!ls

In [None]:
# Boilerplate
import sys
import numpy as np
import tensorflow as tf
from matplotlib.pyplot import figure, imshow, gca, tight_layout, show
from matplotlib import patches
from importlib import reload

# Project modules copied into the working directory by the rsync cell.
# Snapshot which modules were already loaded before this cell ran: re-running the
# cell then reloads exactly those (picking up edits to the .py files) without
# executing a freshly imported module a second time.
_already_loaded = set(sys.modules)

import geometry
import roi_utils
import backbone
import classification
import data_utils
import rpn
import faster_rcnn

# Reload order matters: a module that does `from dep import name` keeps the old
# `name` unless `dep` was reloaded first. faster_rcnn (home of the high-level
# FasterRCNNWrapper used below) goes last.
# NOTE(review): the relative order of the other modules is an assumption --
# confirm it against the actual imports in these files.
for _module in (geometry, roi_utils, backbone, classification, data_utils, rpn, faster_rcnn):
    if _module.__name__ in _already_loaded:
        reload(_module)

In [None]:
# Detect the runtime from the Colab module itself rather than from GPU presence:
# otherwise a Colab CPU runtime would pick the local paths (where the data is not),
# and a local GPU machine would pick the Drive paths.
is_colab = 'google.colab' in sys.modules

# The setup cell asks for a GPU runtime; say so if we did not get one.
if not tf.config.list_physical_devices('GPU'):
    print('Warning: no GPU found -- check the runtime type (see the setup cell).')

# Data locations
if is_colab:
    datapath='/content'
    backbone_weights='drive/MyDrive/trained_inception.ckpt'
    rpn_weights='drive/MyDrive/trained_rpn.ckpt'
else:
    datapath='tensorflow-great-barrier-reef'
    backbone_weights='trained_inception.ckpt'
    rpn_weights='trained_rpn.ckpt'

In [None]:
# Build the high-level Faster R-CNN wrapper around an InceptionResNet-V2 backbone,
# handing it the data directory and checkpoint paths chosen in the configuration
# cell. No extra RPN options are passed (empty rpn_kwargs).
frcnn = faster_rcnn.FasterRCNNWrapper(
    datapath=datapath,
    input_shape=(720, 1280, 3),
    backbone_type='InceptionResNet-V2',
    backbone_weights=backbone_weights,
    rpn_weights=rpn_weights,
    rpn_kwargs=dict(),
)

In [None]:
# Save network states if requested. Off by default: when this ran unconditionally,
# every "Run All" overwrote the saved checkpoints with whatever weights the wrapper
# currently held.
save_network_state = False  # set True to write the current weights to disk

if save_network_state:
    # Reuse the per-environment paths from the configuration cell (identical to
    # the Drive paths on Colab) so saving also works off Colab.
    frcnn.backbone.save_backbone(backbone_weights)
    frcnn.rpnwrapper.save_rpn_state(rpn_weights)

In [None]:
# Test that the RPN components work as expected

number_boxes = 30
number_prune = 10
all_decoded = []
validation = iter(frcnn.data_loader_full.get_validation())

# Find a minibatch with at least one ground-truth box. all([]) is True, so the
# loop always pulls at least one minibatch.
while all([len(decoded) == 0 for decoded in all_decoded]):
    try:
        images, labels = next(validation)
    except StopIteration:
        raise RuntimeError(
            "Validation set exhausted without finding an annotated image"
        ) from None
    all_decoded = [frcnn.data_loader_full.decode_label(label) for label in labels]

# RPN in forward mode: the top proposals in image coordinates, then clipped
# (roi_utils.clip_RoI) and thinned by IoU suppression.
coords = frcnn.rpnwrapper.propose_regions(images, top=number_boxes, image_coords=True)
coords_clip = roi_utils.clip_RoI(coords, (720, 1280), (112, 112))
coords_prune = roi_utils.IoU_supression(coords_clip, n_regions=number_prune)


def draw_boxes(boxes, count, edgecolor, width_divisor):
    """Add the first ``count`` rows of ``boxes`` to the current axes as rectangles.

    Each row is read as (x, y, width, height). Earlier rows get thicker outlines:
    row j is drawn with linewidth max((count - j) / width_divisor, 1).
    """
    axes = gca()
    for j in range(count):
        axes.add_patch(
            patches.Rectangle(
                (boxes[j, 0], boxes[j, 1]),
                boxes[j, 2],
                boxes[j, 3],
                linewidth=np.maximum((count - j) / width_divisor, 1),
                edgecolor=edgecolor,
                facecolor="none",
            )
        )


# Plot everything up
for i, decoded in enumerate(all_decoded):

    figure(figsize=(16, 9))
    imshow(images[i, :, :, :].numpy() / 255.0)

    # Ground truth (yellow)
    for annotation in decoded:
        gca().add_patch(
            patches.Rectangle(
                (annotation["x"], annotation["y"]),
                annotation["width"],
                annotation["height"],
                linewidth=4,
                edgecolor="y",
                facecolor="none",
            )
        )

    # Raw RPN proposals (red) and the IoU-suppressed subset (green)
    draw_boxes(coords[i], number_boxes, "r", 6)
    draw_boxes(coords_prune[i], number_prune, "g", 2)

    # Plot it up
    gca().set_title(
        f"Validation image {i}: yellow = ground truth, red = top {number_boxes} "
        f"proposals, green = {number_prune} after IoU suppression"
    )
    gca().grid(True)
    tight_layout()

show()