[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/joconnor-ml/osm-ai-tools/blob/master/examples/custom_power_plants/mistag_classification_tpu.ipynb)

In [None]:
#@title Authenticate, Import, Download Data

# Authenticate the current Colab user before any gs:// path is touched.
from google.colab import auth
auth.authenticate_user()

# Quietly install fsspec and gcsfs at runtime (versions are not pinned).
# NOTE(review): presumably needed so pandas can write the gs:// CSVs below - confirm.
!pip install -q fsspec gcsfs

import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import math

In [None]:
# Common GCS prefix of the TFRecord shards; "*.tfrec" is appended when globbing.
TFRECORD_PREFIX = "gs://osm-object-detector/data/power_plants/tfrecords/shard-"

# Side length, in pixels, of the square RGB images fed to the model.
IMAGE_SIZE = 224

In [None]:
# Pick a distribution strategy: use the TPU when its setup succeeds, otherwise
# fall back to MirroredStrategy (GPU or multi-GPU machines).
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
    accelerator_status = "Found TPU"
except ValueError:
    # A ValueError raised anywhere in the TPU setup is treated as "no TPU".
    strategy = tf.distribute.MirroredStrategy()
    accelerator_status = "Didn't find TPU"
print(accelerator_status)

print("Number of accelerators: ", strategy.num_replicas_in_sync)

In [None]:
AUTO = tf.data.experimental.AUTOTUNE  # passed to tf.data calls as the parallelism setting


def read_tfrecord(example):
    """Parse one serialized TFRecord example into an image/label/bbox_id dict.

    Args:
        example: serialized example holding a JPEG bytestring under "image"
            and int64 scalars under "label" and "bbox_id".

    Returns:
        dict with keys "image" (float32 tensor reshaped to
        [IMAGE_SIZE, IMAGE_SIZE, 3]), "label" and "bbox_id".
    """
    # All three features are scalars (shape []); tf.string here means raw bytes.
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
        "bbox_id": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)

    # Decode to 3 channels, pin the static shape, then convert to float32.
    pixels = tf.image.decode_jpeg(parsed['image'], channels=3)
    pixels = tf.reshape(pixels, [IMAGE_SIZE, IMAGE_SIZE, 3])
    pixels = tf.cast(pixels, tf.float32)

    return {"image": pixels, "label": parsed['label'], "bbox_id": parsed["bbox_id"]}
    
# Build the input pipeline over every shard matching the prefix. Shards are
# read in parallel and the pipeline is marked non-deterministic so that
# tf.data is allowed to apply order-altering optimizations.
option_no_order = tf.data.Options()
option_no_order.experimental_deterministic = False

filenames = tf.io.gfile.glob(TFRECORD_PREFIX + "*.tfrec")
ds = (
    tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    .with_options(option_no_order)
    .map(read_tfrecord, num_parallel_calls=AUTO)
    .shuffle(300)  # shuffle buffer of 300 decoded examples
)

In [None]:
# Half of the estimated record count.
# NOTE(review): 512 looks like the number of records per shard - confirm.
total_records = len(filenames) * 512
num_positives = total_records // 2

In [None]:
# Split the data into two folds of roughly half the records each.
#
# The split is made on shard files rather than with take()/skip() on `ds`:
# `ds` is read non-deterministically and reshuffled on every iteration, so
# take(n)/skip(n) would select a different, overlapping set of examples each
# time the datasets are iterated (every epoch, and again at prediction time),
# leaking training examples into the held-out fold.
shard_files = sorted(filenames)
if len(shard_files) < 2:
    raise ValueError(
        "Need at least two TFRecord shards to build disjoint train/val folds, "
        f"found {len(shard_files)}"
    )


def _load_shards(shard_paths):
    """Return the decoded example dataset for a fixed list of shard files."""
    shards = tf.data.TFRecordDataset(shard_paths, num_parallel_reads=AUTO)
    shards = shards.with_options(option_no_order)
    return shards.map(read_tfrecord, num_parallel_calls=AUTO)


# Alternate shards between the folds so each fold samples the whole shard range.
train_ds = _load_shards(shard_files[0::2])
val_ds = _load_shards(shard_files[1::2])

In [None]:
BATCH_SIZE = 512


def get_model():
    """Build and compile the binary classifier on top of the pretrained module.

    Loads the saved model from GCS, marks it trainable, prints its summary,
    and stacks global average pooling, dropout (0.5) and a single sigmoid
    unit on its output. The model is compiled with SGD (momentum 0.9), a
    piecewise-constant learning-rate schedule, label-smoothed binary
    cross-entropy and accuracy as the metric.

    Returns:
        A compiled tf.keras.Model taking (IMAGE_SIZE, IMAGE_SIZE, 3) images
        and producing one sigmoid output per image.
    """
    module = tf.keras.models.load_model("gs://osm-object-detector/pretrained_models/resisc_224px_rgb_resnet50")
    module.trainable = True
    module.summary()

    images = tf.keras.layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3))
    features = module(images)
    features = tf.keras.layers.GlobalAveragePooling2D()(features)
    features = tf.keras.layers.Dropout(0.5)(features)
    output = tf.keras.layers.Dense(1, activation="sigmoid")(features)
    model = tf.keras.Model(inputs=images, outputs=output)

    # Base learning rate, scaled linearly with the batch size.
    lr = 0.003 * BATCH_SIZE / 512

    # Decay the learning rate by a factor of 10 at each boundary. Boundaries
    # are optimizer step counts, scaled with the batch size like the base rate.
    # (The previous values jumped from lr*0.1 straight to lr*0.001, a 100x
    # step that contradicted the stated factor-of-10 decay.)
    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=[int(50*BATCH_SIZE/512), int(75*BATCH_SIZE/512), int(100*BATCH_SIZE/512)],
        values=[lr, lr*0.1, lr*0.01, lr*0.001]
    )
    optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)

    model.compile(
      optimizer=optimizer,
      # use label smoothing since we know quite a few labels will be wrong
      loss=tf.keras.losses.BinaryCrossentropy(label_smoothing=0.05),
      metrics=['acc']
    )
    return model


In [None]:
def to_keras(row):
  """Convert a feature dict into the (image, label) pair Keras fit() expects.

  Any other keys in `row` (such as "bbox_id") are dropped.
  """
  image, label = row["image"], row["label"]
  return image, label

# Random augmentation pipeline: flip and contrast jitter, upscale to 256x256,
# translate / rotate / zoom, then centre-crop back down to 224x224.
augmentor = tf.keras.Sequential([
    tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
    tf.keras.layers.experimental.preprocessing.RandomContrast(0.1),
    tf.keras.layers.experimental.preprocessing.Resizing(256,256),
    tf.keras.layers.experimental.preprocessing.RandomTranslation(0.2, 0.2),
    # RandomRotation's factor is a fraction of a full turn (2*pi), not radians.
    # 0.5 samples the angle uniformly from [-pi, pi], i.e. any orientation;
    # the previous 2*math.pi asked for up to ~6.28 full turns either way.
    tf.keras.layers.experimental.preprocessing.RandomRotation(0.5),
    tf.keras.layers.experimental.preprocessing.RandomZoom(0.25),
    tf.keras.layers.experimental.preprocessing.CenterCrop(224,224),
])

def augment(image_batch, label_batch):
    """Apply the augmentation pipeline to a batch of images; labels pass through.

    Invokes `augmentor.call(...)` directly rather than `augmentor(...)`.
    NOTE(review): presumably so the random layers stay active inside a
    tf.data map - confirm.
    """
    augmented_images = augmentor.call(image_batch)
    return augmented_images, label_batch


In [None]:
# Build and compile the model under the distribution strategy's scope.
with strategy.scope():
    model = get_model()

# note: TPU training requires drop_remainder=True to keep batch sizes constant
train_batches = (
    train_ds.shuffle(500)
    .map(to_keras)
    .batch(BATCH_SIZE, drop_remainder=True)
    .map(augment, num_parallel_calls=AUTO)  # augmentation runs on whole batches
    .prefetch(-1)
)
val_batches = val_ds.map(to_keras).batch(BATCH_SIZE).prefetch(-1)

model.fit(train_batches, validation_data=val_batches, epochs=3)

In [None]:
# Score val_ds batch by batch and collect one frame per batch holding the
# prediction, the label and the bbox id (stored as "osm_id").
pred_dfs = []
for batch in val_ds.batch(BATCH_SIZE):
    scores = model.predict(batch["image"]).flatten()
    batch_frame = pd.DataFrame({
        "pred": scores,
        "label": batch["label"].numpy(),
        "osm_id": batch["bbox_id"].numpy(),
    })
    pred_dfs.append(batch_frame)
pred_df = pd.concat(pred_dfs).reset_index(drop=True)

In [None]:
def plot_one_object(object_id):
    """Draw the first image crop belonging to one object.

    Selects the rows of `patches` whose osm_id equals `object_id`, reads the
    PNG named by the first row's image_id from GCS, crops every bounding box
    of the object, resizes the crops to IMAGE_SIZE x IMAGE_SIZE and shows the
    first crop with plt.imshow.

    NOTE(review): `patches` is not defined in any cell shown in this
    notebook; presumably a DataFrame with osm_id, image_id, y_min, x_min,
    y_max and x_max columns - confirm where it is loaded.
    NOTE(review): the box columns are passed straight to
    tf.image.crop_and_resize - confirm they are in the form it expects.

    Args:
        object_id: value matched against patches.osm_id.
    """
    object_rows = patches.loc[patches.osm_id==object_id]
    filename = object_rows["image_id"].iloc[0]
    image_path = f"gs://osm-object-detector/data/custom_power_plants/images/{filename}.png"
    img = tf.io.decode_png(tf.io.read_file(image_path))

    bboxes = object_rows[["y_min", "x_min", "y_max", "x_max"]].values
    # Every box is cut from the single image in the batch, hence all-zero indices.
    box_indices = tf.zeros_like(bboxes[:, 0], dtype=tf.int32)
    crops = tf.image.crop_and_resize(
        tf.expand_dims(img, axis=0),
        bboxes,
        box_indices=box_indices,
        crop_size=[IMAGE_SIZE, IMAGE_SIZE],
        method='bilinear',
        extrapolation_value=0,
    )
    first_crop = crops[0].numpy()
    plt.imshow(first_crop.astype(np.uint8))

In [None]:
#@title Plot a few images the model disagrees with: if we succeeded, this should be primarily mistagged OSM data
# Show the five positively-labelled objects with the lowest predicted score.
# itertuples keeps each column's dtype; iterrows would upcast the integer
# label / osm_id columns to float64, so ids were looked up and displayed as floats.
for row in pred_df.query("label==1").nsmallest(5, "pred").itertuples(index=False):
  plot_one_object(row.osm_id)
  plt.title(f"{row.pred:.3f}, {row.label}, {row.osm_id}")
  plt.show()

In [None]:
#@title Complete the cross-val loop: train on the second half of objects, predict on the first
# Fresh model, with the roles of the two folds swapped: fit on val_ds and
# validate on train_ds.
with strategy.scope():
    model = get_model()

swapped_train_batches = (
    val_ds.shuffle(500)
    .map(to_keras)
    .batch(BATCH_SIZE, drop_remainder=True)
    .map(augment, num_parallel_calls=AUTO)
    .prefetch(-1)
)
swapped_val_batches = train_ds.map(to_keras).batch(BATCH_SIZE).prefetch(-1)

model.fit(swapped_train_batches, validation_data=swapped_val_batches, epochs=3)

In [None]:
# Score the fold the second model did NOT train on. That model was fitted on
# val_ds, so its held-out fold is train_ds. Iterating val_ds here (as before)
# produced in-sample predictions for objects already scored in pred_df and
# left the first half of the objects without any score.
pred_dfs = []
for row in train_ds.batch(BATCH_SIZE):
  preds = model.predict(row["image"])
  pred_dfs.append(pd.DataFrame({"pred": preds.flatten(), "label": row["label"].numpy(), "osm_id": row["bbox_id"].numpy()}))
pred_df2 = pd.concat(pred_dfs).reset_index(drop=True)

In [None]:
# Combine the positively-labelled rows of both prediction frames and score
# them: mislabel_score is 1 - pred, so a low prediction on a positive label
# gives a high score.
positive_preds = [fold.query("label==1") for fold in (pred_df, pred_df2)]
df = pd.concat(positive_preds)
df["mislabel_score"] = 1 - df["pred"]

df.to_csv("gs://osm-object-detector/data/custom_power_plants/mislabel_scores.csv")

In [None]:
#@title Hand-label surprising objects

# Quietly install the pigeon-jupyter package at runtime (version is not pinned).
!pip install -q pigeon-jupyter

# annotate() is called below with the ids to label, the label options and a display function.
from pigeon import annotate

def plot_and_show(x):
    """Display the crop for object id `x` on its own 8x8 figure.

    Passed to annotate() as its display function.
    """
    figure_size = (8, 8)
    plt.figure(figsize=figure_size)
    plot_one_object(x)
    plt.show()

# Hand-label only the objects whose predicted score is below 0.25.
surprising_ids = df.query("pred<0.25").osm_id
annotations = annotate(
    surprising_ids,
    options=['correct', 'mistagged'],
    display_fn=plot_and_show,
)

In [None]:
# Tabulate the collected annotations as (osm_id, status) and save them to GCS.
annotations = pd.DataFrame(data=annotations, columns=["osm_id", "status"])
annotations.to_csv(
    "gs://osm-object-detector/data/custom_power_plants/hand_labels.csv"
)