In [None]:
import sys

# Local paths, all relative to the notebook's working directory.
dir_pfx = './'
data_dir = dir_pfx + '../../data/Vehicules1024/'

# Append these directories (in this order) to Python's module search path so
# project-local modules can be imported.
sys.path.extend([dir_pfx, '../', '../../'])

In [None]:
!pip install --user imageio

In [None]:
#!pip install -q git+https://www.github.com/keras-team/keras-contrib.git
#!pip install --user scikit-image
#!pip install --user imageio
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from importlib import reload
import utils
import DenseSRGAN


In [None]:
num_images = 50  # passed to utils.scan_dataset as the number of images to scan
files = utils.scan_dataset(data_dir, num_images)
training_set, testing_set = utils.create_subsets(files, data_dir, use_validation=False)
# NOTE(review): presumably loads all of training_set as one batch
# (start 0, size len(training_set)); confirm the positional 0 / True
# arguments against utils.load_data.
im_hr, im_lr, batch_idx = utils.load_data(0, training_set, data_dir, True, len(training_set))

# Later cells index im_hr and im_lr with the same sample index, so they must be paired.
assert len(im_hr) == len(im_lr), 'HR/LR arrays must have the same number of samples'


In [None]:
ix = 2560                       # index of the HR/LR patch pair to inspect
normalized_around_zero = False  # True -> patches must be passed through utils.un_normalize


def show_hr_lr_pair(hr, lr, title, to_display=lambda a: a):
    """Plot an HR patch (left) next to its LR counterpart (right) in a new figure.

    `to_display` is applied to each patch before `imshow`, e.g. to undo normalization.
    """
    plt.figure().suptitle(title, fontsize=20)
    plt.subplot(1, 2, 1)
    plt.imshow(to_display(hr))
    plt.subplot(1, 2, 2)
    plt.imshow(to_display(lr))


# Undo the normalization for display only when it was applied to the data.
to_display = utils.un_normalize if normalized_around_zero else (lambda a: a)

# NOTE(review): assuming the last axis holds 4 channels (RGB + infrared), imshow
# treats the array as RGBA, so the infrared band is rendered as transparency.
show_hr_lr_pair(im_hr[ix, :, :, :], im_lr[ix, :, :, :], 'RGB+Infra', to_display)
show_hr_lr_pair(im_hr[ix, :, :, 0:3], im_lr[ix, :, :, 0:3], 'RGB', to_display)

## Load GAN Model for Training

In [None]:
# Constructor options for a fresh training run (no pretrained weights loaded).
gan_options = dict(
    proj_pfx="OH",          # prefix for saving
    gpu_list=[1, 3, 5, 7],  # presumably GPU device ids (machine-specific)
    dropout_rate=0.3,       # dropout rate
    num_epochs_trained=0,   # number of epochs already trained, if continuing training
)
# Positional arguments: working directory, high-res images, low-res images.
gan = DenseSRGAN.DenseSRGAN(dir_pfx, im_hr, im_lr, **gan_options)


## Train the model 

In [None]:
train_options = {
    'epochs': 1000,       # number of epochs
    'batch_size': 16,     # minibatch size
    'bench_idx': 2560,    # index of the image used for the benchmark
    'save_interval': 10,  # save weights/benchmark every N epochs
    'view_interval': 2,   # print the loss every N epochs
    'verbose': False,     # True prints the time taken by each step
}
gan.train(**train_options)

## Load a pretrained model from saved weights

In [None]:
# Rebuild the model and load previously saved weights, either to resume
# training or to run predictions (gpu_list=None here, unlike the training cell).
gan = DenseSRGAN.DenseSRGAN(dir_pfx,                              # Working directory
                            im_hr, im_lr,                         # High Res / Low Res Images
                            proj_pfx="OH", gpu_list=None,         # Prefix for saving
                            dropout_rate=0.3,                     # Dropout rate
                            weights_path=dir_pfx + 'weights/OH/', # Load weights to resume training or predict
                            num_epochs_trained=0)                 # Number of epochs if continuing training
# NOTE(review): when resuming training, num_epochs_trained should presumably
# match the epoch of the saved weights; confirm against DenseSRGAN.

## Compare generator predictions with the actual HR patches

In [None]:
ix = 10
ix2 = 11

for sample_ix in (ix, ix2):
    # Run the generator on one LR patch; the slice keeps the batch axis.
    img = gan.gen.predict(im_lr[sample_ix:sample_ix + 1, :, :, :]).squeeze()
    # Rescale from [-1, 1] to [0, 1] for imshow (presumably a tanh output; confirm).
    img = (img + 1) / 2
    plt.figure().suptitle('RGB+Infra', fontsize=20)
    plt.subplot(1, 2, 1)
    plt.title('Actual HR')
    plt.imshow(im_hr[sample_ix, :, :, :])
    plt.subplot(1, 2, 2)
    plt.title('Generated')
    plt.imshow(img)

## Show the discriminator's output on some real vs. generated images

In [None]:
num_samples = 500  # random patches to score (drawn with replacement)
eval_idx = np.random.randint(low=0, high=len(im_lr), size=num_samples)

# One batched predict() per input set instead of 2 * num_samples single-image calls.
real_out = gan.disc.predict(im_hr[eval_idx])
fake_out = gan.disc.predict(gan.gen.predict(im_lr[eval_idx]))

# One score per sample: average over any non-batch output axes. This is a no-op
# when the discriminator emits a single value per image. The names tloss/floss
# are kept for compatibility, although these are discriminator outputs, not losses.
tloss = real_out.reshape(num_samples, -1).mean(axis=1)
floss = fake_out.reshape(num_samples, -1).mean(axis=1)

print('Mean discriminator output, real images: {0}'.format(np.mean(tloss)))
print('Mean discriminator output, generated images: {0}'.format(np.mean(floss)))

plt.figure()
plt.plot(np.arange(num_samples), tloss)
plt.plot(np.arange(num_samples), floss)
plt.legend(['Real images', 'Generated images'])
plt.xlabel('Random Sample Number')
plt.ylabel('Discriminator output')
plt.show()