In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

import os
import time
import numpy as np
import glob
import matplotlib.pyplot as plt
import PIL
import imageio
from pathlib import Path
from began import build_discriminator, build_generator, build_adversarial_model, training_schedule

from IPython import display

In [2]:
# Output location for trained model files.
MODEL_DIR = Path("../models")

# Location of the preprocessed inputs, and the training array read from it.
data_dir = Path("../data/preprocessed")
training_data = np.load(data_dir.joinpath("GNILC_dust_map.npy"))

In [3]:
# Show the dimensions of the loaded training array.
np.shape(training_data)

(1033, 256, 256, 1)

In [4]:
# Shuffle buffer size: taken from the loaded array so that it always spans the
# whole training set (1033 samples in the recorded run) instead of relying on
# a hard-coded count that silently goes stale when the data file changes.
TRAIN_BUFF = len(training_data)
# Number of samples per training batch.
BATCH_SIZE = 32

In [5]:
# Input pipeline: slice the array along its first axis, shuffle, then batch.
train_images = (
    tf.data.Dataset.from_tensor_slices(training_data)
    .shuffle(buffer_size=TRAIN_BUFF)
    .batch(batch_size=BATCH_SIZE)
)

In [6]:
# Network architecture
# Image geometry and latent vector size.
IMG_DIM = 256
CHANNELS = 1
LATENT_DIM = 64

# Convolution stack: one kernel size and one stride per layer.
KERNELS = [5, 5, 5]
STRIDES = [2, 2, 2]

# Filter counts start at DEPTH and double with every layer.
DEPTH = 32
FILTERS = [DEPTH * (2 ** level) for level, _ in enumerate(KERNELS)]

In [7]:
# Derived parameters
# Square image shape with the channel count as the last entry.
SHAPE = (IMG_DIM,) * 2 + (CHANNELS,)

In [8]:
# Training parameters
# NOTE(review): BATCH_SIZE was already assigned this same value before
# train_images was batched, so changing it only here does not re-batch the
# dataset; consider keeping a single definition.
BATCH_SIZE = 32

In [9]:
# Build individual and joint models.
# The discriminator comes first because both the generator builder and the
# adversarial builder take it as an argument.
DIS = build_discriminator(FILTERS, KERNELS, STRIDES, SHAPE)
GEN = build_generator(DIS, FILTERS, KERNELS, STRIDES, LATENT_DIM, SHAPE)
ADV = build_adversarial_model(DIS, GEN)
# Keras' summary() prints the layer table itself and returns None, so wrapping
# it in print() would append a stray "None" line to the output.
GEN.summary()

Model: "Generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
Dense_G (Dense)              (None, 131072)            8519680   
_________________________________________________________________
Reshape (Reshape)            (None, 32, 32, 128)       0         
_________________________________________________________________
BNorm_G1 (BatchNormalization (None, 32, 32, 128)       512       
_________________________________________________________________
LRelu_G1 (LeakyReLU)         (None, 32, 32, 128)       0         
_________________________________________________________________
UpSample_1 (UpSampling2D)    (None, 64, 64, 128)       0         
_________________________________________________________________
Conv2D_G1 (Conv2D)           (None, 64, 64, 64)        204864    
_________________________________________________________________
BN_G2 (BatchNormalization)   (None, 64, 64, 64)        25

In [10]:
# Run the training schedule on the three models with the batched dataset and
# keep whatever it returns for saving.
trained_model = training_schedule(
    DIS,
    GEN,
    ADV,
    train_images,
    LATENT_DIM,
)

32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
9


In [11]:
# Create the output directory first: the save target is a file inside
# MODEL_DIR, and writing it fails when that directory does not exist yet.
MODEL_DIR.mkdir(parents=True, exist_ok=True)
trained_model.save(str(MODEL_DIR / "dust_dcgan.h5"))