In [None]:
%%capture
# Remember to set runtime to GPU acceleration

# Mount files
from google.colab import drive
drive.mount('/content/drive')

# Set up Kaggle
!pip uninstall -y kaggle
!pip install --upgrade pip
!pip install kaggle==1.5.6
!mkdir ~/.kaggle

# tfa
!pip install tensorflow-addons

import json
token = {"username":"neilgoecknerwald","key":"KEY"}
with open('/root/.kaggle/kaggle.json', 'w') as file:
    json.dump(token, file)

!chmod 600 /root/.kaggle/kaggle.json

In [None]:
# Download files

!kaggle competitions download -c tensorflow-great-barrier-reef

In [None]:
%%capture
!unzip tensorflow-great-barrier-reef.zip
!rm tensorflow-great-barrier-reef.zip

In [None]:
# Clone and pull in python files

!git config --global user.email "ngoecknerwald@gmail.com"
!git config --global user.name "Neil Goeckner-Wald"

!git clone https://git@github.com/ngoecknerwald/starfish-perception-telescope.git

In [None]:
###########
#
# File and environment setup
#
###########

!rsync starfish-perception-telescope/great-barrier-reef/*.py .
!ls

In [None]:
# Global behavior: True when running on Google Colab with Drive mounted
is_colab = True

# Output version; also the Drive subdirectory checkpoints are synced into
version = 'production_v0'

# Where the unpacked competition data lives
datapath = '/content' if is_colab else 'tensorflow-great-barrier-reef'

# Checkpoint file names for the baseline (separately trained) models
backbone_weights, rpn_weights, class_weights = (
    'trained_backbone.ckpt',
    'trained_rpn.ckpt',
    'trained_classifier.ckpt',
)

# Checkpoint file names for the jointly fine-tuned models
backbone_tuned, rpn_tuned, class_tuned = (
    'tuned_backbone.ckpt',
    'tuned_rpn.ckpt',
    'tuned_classifier.ckpt',
)

In [None]:
# Debug level handed to every FasterRCNNWrapper constructed below:
#   0 -> (almost) no validation set; maximal data used for training
#   1 -> 0.2 validation split for diagnostics (the default)
#   2 -> only 1% of the data, for quickly debugging the python code
debug = 1

In [None]:
# Usual imports

import sys, os
import numpy as np
from matplotlib.pyplot import figure, imshow, gca, tight_layout, show
from matplotlib import patches
from importlib import import_module, reload
import tensorflow as tf

# Project modules copied in from the repository. A name not yet bound in the
# notebook is imported; a bound one is reloaded so edits to the .py files are
# picked up. The decision is made per module. Previously one sentinel module
# decided for all of them, so a first run that failed part-way through the
# imports left later runs calling reload() on names that were never bound
# (NameError).
# NOTE(review): reloads run in list order; if a module does
# `from <dependency> import name`, that dependency should be reloaded first — confirm.
project_modules = [
    'backbone', 'classifier', 'data_utils', 'faster_rcnn', 'rpn',
    'roi_pooling', 'geometry', 'evaluation', 'callback', 'jointmodel',
]
for module_name in project_modules:
    if module_name in globals():
        globals()[module_name] = reload(globals()[module_name])
    else:
        globals()[module_name] = import_module(module_name)

In [None]:
###########
#
# Base training
#
###########

# Stage 1: constructing the wrapper this way trains only the ResNet50 backbone
# ('finetune'); the RPN and classifier weights are set to 'skip'.
frcnn = faster_rcnn.FasterRCNNWrapper(
    datapath=datapath,
    debug=debug,
    backbone_type='ResNet50',
    backbone_weights='finetune',
    rpn_weights='skip',
    classifier_weights='skip',
)

In [None]:
# Save the baseline backbone and, on Colab, copy it into this version's
# directory on Google Drive.

frcnn.backbone.save_backbone(backbone_weights)

if is_colab:
    # Sync the configured path instead of a repeated literal, so the command
    # always targets the file that was just written.
    sync_cmd = 'rsync -rv %s drive/MyDrive/%s/' % (backbone_weights, version)
    # os.system returns the exit status; a non-zero value means the copy failed
    if os.system(sync_cmd) != 0:
        print('WARNING: failed to sync %s to Google Drive' % backbone_weights)

In [None]:
# Stage 2: load the baseline backbone from its checkpoint and train only the
# region proposal network; the classifier weights are set to 'skip'.
frcnn = faster_rcnn.FasterRCNNWrapper(
    datapath=datapath,
    debug=debug,
    backbone_type='ResNet50',
    backbone_weights=backbone_weights,
    rpn_weights='train',
    classifier_weights='skip',
)

In [None]:
# Save the baseline RPN and, on Colab, copy it into this version's directory
# on Google Drive.

frcnn.rpnwrapper.save_rpn_state(rpn_weights)

if is_colab:
    # Sync the configured path instead of a repeated literal, so the command
    # always targets the file that was just written.
    sync_cmd = 'rsync -rv %s drive/MyDrive/%s/' % (rpn_weights, version)
    # os.system returns the exit status; a non-zero value means the copy failed
    if os.system(sync_cmd) != 0:
        print('WARNING: failed to sync %s to Google Drive' % rpn_weights)

In [None]:
# Stage 3: load the baseline backbone and RPN from their checkpoints and
# train only the classifier head.
frcnn = faster_rcnn.FasterRCNNWrapper(
    datapath=datapath,
    debug=debug,
    backbone_type='ResNet50',
    backbone_weights=backbone_weights,
    rpn_weights=rpn_weights,
    classifier_weights='train',
)

In [None]:
# Save the baseline classifier and, on Colab, copy it into this version's
# directory on Google Drive.

frcnn.classmodel.save_classifier_state(class_weights)

if is_colab:
    # Sync the configured path instead of a repeated literal, so the command
    # always targets the file that was just written.
    sync_cmd = 'rsync -rv %s drive/MyDrive/%s/' % (class_weights, version)
    # os.system returns the exit status; a non-zero value means the copy failed
    if os.system(sync_cmd) != 0:
        print('WARNING: failed to sync %s to Google Drive' % class_weights)

In [None]:
###########
#
# Fine tuning loop
#
###########

# Rebuild the full model with all three baseline checkpoints loaded, as the
# starting point for joint fine tuning.
frcnn = faster_rcnn.FasterRCNNWrapper(
    datapath=datapath,
    debug=debug,
    backbone_type='ResNet50',
    backbone_weights=backbone_weights,
    rpn_weights=rpn_weights,
    classifier_weights=class_weights,
)

In [None]:
# Do additional passes of fine tuning if requested. After every pass, the tuned
# weights are saved (overwriting the previous pass's files) and, on Colab,
# synced to Google Drive.

fine_tuning_passes = 4
epochs_per_pass = 2

# Named loop variable: assigning to `_` would clobber IPython's last-output name
for pass_index in range(fine_tuning_passes):

    # Run fine tuning
    frcnn.do_fine_tuning(epochs_per_pass)

    # Save weights
    frcnn.backbone.save_backbone(backbone_tuned)
    frcnn.rpnwrapper.save_rpn_state(rpn_tuned)
    frcnn.classmodel.save_classifier_state(class_tuned)

    if is_colab:
        for checkpoint in (backbone_tuned, rpn_tuned, class_tuned):
            # os.system returns the exit status; non-zero means the copy failed
            if os.system('rsync -rv %s drive/MyDrive/%s/' % (checkpoint, version)) != 0:
                print('WARNING: failed to sync %s to Google Drive' % checkpoint)

    print('Finished fine tuning pass %d of %d' % (pass_index + 1, fine_tuning_passes))

In [None]:
###########
#
# Demonstration phase
#
###########

# Iterator over the validation dataset; the next cell draws minibatches from it
validation = iter(frcnn.data_loader_full.get_validation())

In [None]:
# Find a minibatch with a positive example: keep drawing from the validation
# iterator until at least one image's decoded label sums to >= 1, so the plots
# below have ground truth to show. Re-running continues from where the
# iterator stopped.
for images, labels in validation:
    all_decoded = [frcnn.data_loader_full.decode_label(label) for label in labels]
    if any(tf.reduce_sum(decoded) >= 1. for decoded in all_decoded):
        break
else:
    # Without this, running out of data surfaced as a bare StopIteration
    raise RuntimeError(
        'Validation set exhausted without a minibatch containing a positive example; '
        'with debug=0 the validation set is nearly empty'
    )

In [None]:
# Generate intermediate data products for diagnostics.

# This cell makes the RPN outputs as well as the pooled RoI
# in image coordinates to be added to the plot.
# Inputs: `images` and `all_decoded` from the minibatch-search cell above.

# Raw RPN proposals, requested with image-space input and output
# (input_images=True, output_images=True). The plotting cell reads the last
# axis as (x, y, width, height).
roi_unpool = frcnn.rpnwrapper.propose_regions(images, input_images=True, output_images=True)

# Next show the pooled RoI. Note that we have to do this in feature space
# because that is what the pooling class understands
features=frcnn.backbone.extractor(images) 
regions = frcnn.rpnwrapper.propose_regions(features, input_images=False, output_images=False)
# Keep only the second output of RoI_pool (the pooled boxes). Note that this
# assignment also overwrites IPython's `_` (last-output) name.
_, roi_pool = frcnn.RoI_pool((features, regions))

# Convert these pooled RoI back to image space for plotting and diagnostic purposes
# astype() returns a copy, so the in-place column writes below do not modify roi_pool.
roi_numpy= roi_pool.numpy().astype('float32')
# Columns are converted in pairs. Each right-hand side is fully evaluated before
# the tuple assignment, so every call sees the original feature-space values.
# NOTE(review): the columns are passed as (1, 0) and (3, 2), whereas the plotting
# cell reads column 0 as x and column 1 as y. Confirm this matches the argument
# order of feature_coords_to_image_coords and the layout of RoI_pool's output.
# Also confirm the conversion is valid for the size columns (2, 3) if it applies
# an offset as well as a scale.
roi_numpy[:,:,1], roi_numpy[:,:,0]  = frcnn.backbone.feature_coords_to_image_coords(
    roi_numpy[:,:,1], roi_numpy[:,:,0]
)
roi_numpy[:,:,3], roi_numpy[:,:,2]  = frcnn.backbone.feature_coords_to_image_coords(
    roi_numpy[:,:,3], roi_numpy[:,:,2]
)
roi_pool_image = tf.convert_to_tensor(roi_numpy)

# Convert the labels to tensor
# NOTE(review): presumably requires every decoded label in the minibatch to have
# the same shape (a ragged list would fail to convert) — confirm.
all_decoded = tf.convert_to_tensor(all_decoded)

In [None]:
# Run the Faster R-CNN in prediction mode on the selected minibatch and list
# each image's detections (dicts; the plotting cell reads their 'x', 'y',
# 'width', 'height' and 'score' keys).

predictions = frcnn.predict(images, return_mode='dict')
for image_index, image_predictions in enumerate(predictions):
    print('Predictions for image %d' % image_index)
    for detection in image_predictions:
        print(detection)

In [None]:
# Plot everything up: one figure per image in the minibatch, overlaying
#   yellow solid  - ground-truth boxes
#   green dotted  - raw RPN proposals (thicker = earlier in the proposal list)
#   red dotted    - IoU-suppressed, pooled RoI (thicker = earlier in the list)
#   black dashed  - final classifier predictions (thicker = higher score)
# Every box is drawn from (x, y, width, height) with (x, y) as the anchor corner.


def add_box(ax, x, y, width, height, **style):
    """Add an unfilled rectangle at (x, y) of the given size to `ax`, with `style` kwargs."""
    ax.add_patch(patches.Rectangle((x, y), width, height, facecolor='none', **style))


def rank_linewidth(rank, count):
    """Line width falling linearly from 5 at rank 0 towards 1 at the last rank."""
    return 4 * ((count - rank) / count) + 1


for i in range(all_decoded.shape[0]):

    figure(figsize=(16, 9))
    imshow(images[i, :, :, :].numpy() / 255.0)
    ax = gca()

    # Draw the ground truth
    for annotation in all_decoded[i]:
        add_box(ax, annotation[0], annotation[1], annotation[2], annotation[3],
                linewidth=4, edgecolor='y')

    # Draw the RPN outputs
    n_proposals = roi_unpool.shape[1]
    for j in range(n_proposals):
        add_box(ax, roi_unpool[i, j, 0], roi_unpool[i, j, 1],
                roi_unpool[i, j, 2], roi_unpool[i, j, 3],
                linewidth=rank_linewidth(j, n_proposals), edgecolor='g', linestyle=':')

    # Draw the IoU suppressed and pooled areas
    n_pooled = roi_pool.shape[1]
    for j in range(n_pooled):
        add_box(ax, roi_pool_image[i, j, 0], roi_pool_image[i, j, 1],
                roi_pool_image[i, j, 2], roi_pool_image[i, j, 3],
                linewidth=rank_linewidth(j, n_pooled), edgecolor='r', linestyle=':')

    # Draw the final classifier outputs
    for annotation in predictions[i]:
        add_box(ax, annotation['x'], annotation['y'], annotation['width'], annotation['height'],
                linewidth=4 * annotation['score'] + 1, edgecolor='k', linestyle='--')

    # Plot it up. grid() takes a bool; the previous grid("True") only worked
    # because a non-empty string is truthy (grid("False") would also turn it on).
    ax.grid(True)
    ax.set_title('Validation image %d' % i)
    tight_layout()
    # Render each figure as it is finished. With the inline backend this also
    # releases the figure instead of keeping every figure open until the cell ends.
    show()