In [1]:
## Global params
# Identifiers for the survey and the satellite this run is configured for.
SURVEY_NAME = 'DHS'
SATELLITE = 's2'

# Per-satellite settings. Fail fast on an unrecognised SATELLITE: otherwise
# IMAGE_SIZE / NUM_GROUPS would be left undefined (NameError in a later cell)
# or, in a live kernel, silently keep the values from a previous run.
if SATELLITE == 's2':
    IMAGE_SIZE = [224, 224]
    NUM_GROUPS = 5
elif SATELLITE == 'l8':
    IMAGE_SIZE = [167, 167]
    NUM_GROUPS = 3
else:
    raise ValueError(
        "Unsupported SATELLITE " + repr(SATELLITE) + "; expected 's2' or 'l8'"
    )

# Run settings (epochs, batch size, validation fraction).
EPOCHS = 10
BATCH_SIZE = 32 # 16
VALIDATION_SPLIT = 0.2

In [2]:
import os, sys, math
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
import random
from skimage import exposure
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

import tensorflow as tf
print("Tensorflow version " + tf.__version__)
AUTOTUNE = tf.data.AUTOTUNE

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.models import load_model, Model

Tensorflow version 2.4.0


In [3]:
# Functions for reading the actual values back out of TFRecord files
# https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset

#### NTL Group
def decode_fn_ntl_group(record_bytes):
    """Parse one serialized example, reading only its `viirs_ntl_group` field.

    Args:
        record_bytes: The serialized example to parse; forwarded unchanged to
            `tf.io.parse_single_example`.

    Returns:
        Whatever `tf.io.parse_single_example` returns for the schema below
        (a mapping keyed by "viirs_ntl_group").
    """
    # Schema: a single scalar (shape []) int64 feature.
    ntl_group_schema = {
        "viirs_ntl_group": tf.io.FixedLenFeature([], dtype=tf.int64),
    }
    parsed_example = tf.io.parse_single_example(record_bytes, ntl_group_schema)
    return parsed_example

def extract_ntl_group(TF_FILES):
    """Collect the `viirs_ntl_group` value of every record read from TF_FILES.

    Args:
        TF_FILES: Passed, wrapped in a list, to `tf.data.TFRecordDataset`
            (presumably a TFRecord path or paths; confirm against the caller).

    Returns:
        A list with one `.numpy()` value per record, in dataset iteration
        order.
    """
    parsed_records = tf.data.TFRecordDataset([TF_FILES]).map(decode_fn_ntl_group)
    return [record['viirs_ntl_group'].numpy() for record in parsed_records]

#### UID
def decode_fn_uid(record_bytes):
    """Parse one serialized example, reading only its `uid` field.

    Args:
        record_bytes: The serialized example to parse; forwarded unchanged to
            `tf.io.parse_single_example`.

    Returns:
        Whatever `tf.io.parse_single_example` returns for the schema below
        (a mapping keyed by "uid").
    """
    # Schema: a single scalar (shape []) string feature.
    uid_schema = {
        "uid": tf.io.FixedLenFeature([], dtype=tf.string),
    }
    parsed_example = tf.io.parse_single_example(record_bytes, uid_schema)
    return parsed_example

def extract_uid(TF_FILES):
    """Collect the `uid` value of every record read from TF_FILES.

    Args:
        TF_FILES: Passed, wrapped in a list, to `tf.data.TFRecordDataset`
            (presumably a TFRecord path or paths; confirm against the caller).

    Returns:
        A list with one `.numpy()` value per record, in dataset iteration
        order. NOTE(review): for a tf.string scalar these are expected to be
        `bytes` objects, so callers may need to decode them.
    """
    parsed_records = tf.data.TFRecordDataset([TF_FILES]).map(decode_fn_uid)
    return [record['uid'].numpy() for record in parsed_records]

In [4]:
def dataset_to_numpy_util(dataset, N, process_image = True):
  """Return the first batch of up to N (images, labels) as NumPy arrays.

  Args:
    dataset: Object providing `.batch(N)` whose batches iterate as
      `(images, labels)` pairs, each exposing `.numpy()`.
    N: Batch size. Only the first batch is read, so at most N elements are
      returned.
    process_image: If True, stretch the image intensities between the 2nd and
      98th percentiles, computed over the whole batch, with
      `exposure.rescale_intensity`.

  Returns:
    Tuple `(numpy_images, numpy_labels)` for the first batch.

  Raises:
    ValueError: If `dataset` yields no elements. Previously this case
      surfaced as an UnboundLocalError at the return statement.
  """
  first_batch = next(iter(dataset.batch(N)), None)
  if first_batch is None:
    raise ValueError("dataset is empty: no elements to convert to NumPy")

  images, labels = first_batch
  numpy_images = images.numpy()
  numpy_labels = labels.numpy()

  if process_image:
    # No axis is given, so the percentiles are taken over all values in the batch.
    p2, p98 = np.percentile(numpy_images, (2, 98))
    numpy_images = exposure.rescale_intensity(numpy_images, in_range=(p2, p98))

  return numpy_images, numpy_labels

def display_one_image(image, title, subplot, red=False):
    """Draw `image` in the given subplot position and return the next position.

    Args:
        image: Image data handed to `plt.imshow`.
        title: Title shown above the image.
        subplot: Subplot code passed to `plt.subplot` (e.g. 331).
        red: If True the title is drawn in red, otherwise in black.

    Returns:
        `subplot + 1`, i.e. the code for the following grid position.
    """
    title_color = 'red' if red else 'black'

    plt.subplot(subplot)
    plt.axis('off')  # images only, no ticks or frame
    plt.imshow(image)
    plt.title(title, fontsize=16, color=title_color)

    next_subplot = subplot + 1
    return next_subplot

def display_9_images_from_dataset(dataset):
  """Show up to the first nine images of `dataset` in a 3x3 grid.

  Each image is titled with its raw label value (no class-name lookup).

  Args:
    dataset: Passed to `dataset_to_numpy_util(dataset, 9)`, which supplies
      the images and labels to plot.
  """
  plt.figure(figsize=(13,13))
  images, labels = dataset_to_numpy_util(dataset, 9)

  next_slot = 331  # 3x3 grid, starting at its first position
  shown = min(len(images), 9)  # never draw more than the nine grid positions
  for idx in range(shown):
    next_slot = display_one_image(images[idx], labels[idx], next_slot)

  plt.subplots_adjust(wspace=0.1, hspace=0.1)
  plt.show()

def display_training_curves(training, validation, title, subplot):
  """Plot a training series and a validation series in one subplot panel.

  Args:
    training: Sequence of per-epoch values for the training set.
    validation: Sequence of per-epoch values for the validation set.
    title: Metric name; used as the y-axis label and in the panel title
      ("model <title>").
    subplot: Subplot code passed to `plt.subplot` (e.g. 211). When its last
      digit is 1, a new figure is created first, so call the first panel
      before the others.
  """
  if subplot%10==1: # set up the figure on the first call
    # Use plt.figure rather than plt.subplots: the latter also creates a
    # full-figure Axes. Recent Matplotlib no longer auto-removes Axes that
    # plt.subplot overlaps, so that Axes would stay visible behind the panels.
    plt.figure(figsize=(10,10), facecolor='#F0F0F0')
  ax = plt.subplot(subplot)
  ax.set_facecolor('#F8F8F8')
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model '+ title)
  ax.set_ylabel(title)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])  # order matches the two plot() calls above