<a href="https://colab.research.google.com/github/convergencelab/LSHT-HSLT-MODIS-Landsat-Fusion/blob/master/VGG_19.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The SR-GAN uses the 2nd layer of the VGG-19 to include feature detection
in the perceptual loss function.
--rather than using a model pretrained on ImageNet, it may be more useful to use a model pre-trained on
  data more similar to the scenes we are using for Landsat-MODIS super resolution

  -> idea 1) train a binary classifier to differentiate Landsat from MODIS: this does not really achieve the goal
  of deriving meaningful features from the image. The major difference between Landsat and MODIS is the resolution,
  so this sort of classifier would likely produce a model that merely distinguishes high-res from low-res imagery.
  -> idea 2) explore different landcover/other feature classification approaches on both landsat and modis images:
          a) train both and then average weights
          b) scale up modis and train on same model ( may cause too much variance between scenes )


In [20]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import os
from tqdm import tqdm
import tensorflow_datasets as tfds
# BibTeX entry for the EuroSAT paper; this notebook loads the 'eurosat'
# dataset through tensorflow_datasets.
_CITATION = """
    @misc{helber2017eurosat,
    title={EuroSAT: A Novel Dataset and Deep Learning Benchmark for Land Use and Land Cover Classification},
    author={Patrick Helber and Benjamin Bischke and Andreas Dengel and Damian Borth},
    year={2017},
    eprint={1709.00029},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}"""

In [21]:
### Hyperparameters ###
# batch_size: number of samples grouped per dataset batch.
# EPOCHS: number of passes the training loop makes over the training split.
batch_size, EPOCHS = 5, 1000

In [48]:
### get data ###
# Using the EuroSAT dataset, which is built from Sentinel-2 satellite images.
# load train
data, info = tfds.load('eurosat', split="train", with_info=True)

# Number of examples in the training split (used to size an epoch).
ds_size = info.splits["train"].num_examples
# NOTE: despite the name, this is the number of label classes in the dataset.
num_features = info.features["label"].num_classes

# Batch the split and repeat it once per epoch so the stream lasts the full run.
train_data = data.batch(batch_size).repeat(EPOCHS)
# tfds.as_numpy yields an *iterable* of NumPy batches; the training loop pulls
# batches with next(), which needs an iterator. iter() is a no-op when the
# returned object is already a generator, so this is safe on every tfds version.
train_data = iter(tfds.as_numpy(train_data))




In [49]:

### initialize model ###
# VGG-19 with randomly initialised weights (weights=None) and the fully
# connected classifier head included, so it can be trained from scratch.
vgg = tf.keras.applications.VGG19(
    include_top=True,
    weights=None,
    input_tensor=None,
    input_shape=[224, 224, 3],
    pooling=None,
    # Size the softmax head to the dataset's label count instead of the
    # ImageNet default of 1000; the extra outputs could never be a valid label.
    classes=num_features,
    classifier_activation="softmax",
)

### loss function ###
# Use MSE loss:
#   ref -> "https://towardsdatascience.com/loss-functions-based-on-feature-activation-and-style-loss-2f0b72fd32a9"
m_loss = tf.keras.losses.MSE

### Adam optimizer (adaptive-learning-rate gradient descent) ###
optimizer = tf.keras.optimizers.Adam()


In [50]:
### initialize metrics ###
# Running means of the per-step loss, one tracker for training and one for testing.
train_loss, test_loss = (
    tf.keras.metrics.Mean(name=metric_name)
    for metric_name in ('train_loss', 'test_loss')
)

# Accuracy trackers that compare integer class labels against predicted
# class-probability vectors.
train_accuracy, test_accuracy = (
    tf.keras.metrics.SparseCategoricalAccuracy(name=metric_name)
    for metric_name in ('train_vgg-19_acc', 'test_vgg-19_acc')
)



In [62]:
### train step ###
@tf.function
def train_step(sample, label):
  """Run one optimisation step of the VGG-19 classifier.

  Args:
    sample: batch of images; resized to 224x224 and run through the VGG-19
      input preprocessing inside this function.
    label: integer class indices, one per image (sparse labels, as consumed
      by the SparseCategoricalAccuracy metric).

  Side effects:
    Applies one optimizer update to ``vgg`` and updates the ``train_loss``
    and ``train_accuracy`` metrics.
  """
  with tf.GradientTape() as tape:
    # preprocess for vgg-19
    sample = tf.image.resize(sample, (224, 224))
    sample = tf.keras.applications.vgg19.preprocess_input(sample)

    predictions = vgg(sample, training=True)
    # The model emits a probability per class, so the MSE target must be a
    # one-hot vector of the same width. Comparing against the raw class index
    # would broadcast that integer against every probability.
    target = tf.one_hot(tf.cast(label, tf.int32), depth=predictions.shape[-1])
    # mean squared error in prediction
    loss = tf.keras.losses.MSE(target, predictions)

  # apply gradients
  gradients = tape.gradient(loss, vgg.trainable_variables)
  optimizer.apply_gradients(zip(gradients, vgg.trainable_variables))

  # update metrics (accuracy still takes the sparse integer labels)
  train_loss(loss)
  train_accuracy(y_pred=predictions, y_true=label)
### test step ###
@tf.function
def test_step(idx, sample, label):
  """Evaluate the VGG-19 classifier on one batch without updating weights.

  Args:
    idx: unused; kept so existing call sites keep working.
      NOTE(review): passing a fresh Python int on every call makes
      tf.function retrace per value -- consider passing a tensor or
      dropping the argument at the call sites.
    sample: batch of images; resized to 224x224 and run through the VGG-19
      input preprocessing inside this function.
    label: integer class indices, one per image.

  Side effects:
    Updates the ``test_loss`` and ``test_accuracy`` metrics.
  """
  # preprocess for vgg-19
  sample = tf.image.resize(sample, (224, 224))
  sample = tf.keras.applications.vgg19.preprocess_input(sample)
  # Call the model directly: Model.predict() takes no `training` argument and
  # is not meant to be used inside a tf.function.
  predictions = vgg(sample, training=False)
  # MSE against a one-hot target matching the softmax width, rather than
  # against the raw integer class index.
  target = tf.one_hot(tf.cast(label, tf.int32), depth=predictions.shape[-1])
  t_loss = tf.keras.losses.MSE(target, predictions)

  # update metrics (accuracy still takes the sparse integer labels)
  test_loss(t_loss)
  test_accuracy(label, predictions)

### Weights Dir ###
# Create the checkpoint directory if it is missing. exist_ok=True avoids the
# check-then-create race of a separate isdir() test followed by mkdir().
os.makedirs('./checkpoints', exist_ok=True)


In [None]:
### TRAIN ###
# Checkpoint once every quarter of the run. Clamp to at least 1 so that
# EPOCHS < 4 does not yield a zero divisor in the modulo below.
NUM_CHECKPOINTS_DIV = max(1, EPOCHS // 4)
save_c = 1
# Number of batches that make up one pass over the training split.
steps_per_epoch = ds_size // batch_size

for epoch in range(EPOCHS):

    # Reset the metrics at the start of the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

    # train step
    for _ in tqdm(range(steps_per_epoch)):
        batch = next(train_data)

        # Each image of the batch is fed to train_step as its own batch of
        # one (a leading axis is added to both the image and its label).
        for image, label in zip(batch['image'], batch['label']):
            image = np.array(image)[np.newaxis, ...]
            label = np.array(label)[np.newaxis, ...]
            train_step(image, label)

        # NOTE: test_step is never invoked in this loop, so the test metrics
        # printed below keep the values they have right after the reset.

    ### save weights ###
    if epoch % NUM_CHECKPOINTS_DIV == 0:
        vgg.save_weights('./checkpoints/my_checkpoint_{}'.format(save_c))
        save_c += 1
    if epoch % 100 == 0:
        ### outputs every 100 epochs so .out file from slurm is not huge. ###
        template = 'Training VGG-19:\nEpoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
        print(template.format(epoch + 1,
                              train_loss.result(),
                              train_accuracy.result() * 100,
                              test_loss.result(),
                              test_accuracy.result() * 100))


  0%|          | 7/5400 [01:40<21:32:58, 14.39s/it]