In [1]:
import os
import time
import tfplot
import warnings
import numpy as np
import tensorflow as tf

from tensorflow import keras as k
from tensorflow import layers as ly

from tqdm import tqdm_notebook as pbar

from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec

In [2]:
plt.style.use('seaborn')
warnings.filterwarnings('ignore')

In [3]:
def plot_spectrum(spec):
    fig, ax = tfplot.subplots(figsize=(4, 3))
    im = ax.plot(np.arange(1180, 1280, 0.25), spec)
    return fig

In [4]:
def get_total_parameters():
    total_parameters = 0
    for variable in tf.trainable_variables():
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    return total_parameters

In [5]:
def residual_block(x):
    
    y = ly.conv1d(
        inputs=x,
        filters=64,
        kernel_size=3,
        strides=1,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(stddev=0.02),
        activation=tf.identity
        )
    y = k.layers.PReLU()(y)
    y = ly.conv1d(
        inputs=y,
        filters=64,
        kernel_size=3,
        strides=1,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(stddev=0.02),
        activation=tf.identity
        )
    
    return tf.add(x, y)

In [6]:
def subpixel_convolution(x, block_size):
    
    y = ly.conv1d(
    inputs=x,
    filters=256,
    kernel_size=3,
    strides=1,
    padding='same',
    kernel_initializer=tf.random_normal_initializer(stddev=0.02),
    activation=tf.identity
    )
    y = tf.depth_to_space(y, block_size=block_size)
    
    return k.layers.PReLU()(y)

In [7]:
def convolution_block(x, filters, stride):
    
    y = ly.conv1d(
    inputs=x,
    filters=filters,
    kernel_size=3,
    strides=stride,
    padding='same',
    kernel_initializer=tf.random_normal_initializer(stddev=0.02),
    activation=tf.identity
    )
    y = ly.batch_normalization(y)
    
    return k.layers.LeakyReLU(alpha=0.2)(y)

In [8]:
class ReInitDataSampler:
    def __init__(self, train_filepath, batch_size, valid_filepath=None, test_filepath=None, shuffle=True):
        self.batch_size = batch_size
        self.shuffle = shuffle

        train_files = [os.path.join(train_filepath, file) for file in os.listdir(train_filepath) if file.endswith('.tfrecords')]
        train_dataset = self.make_dataset(train_files)
        self.iter = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
        self.train_init_op = self.iter.make_initializer(train_dataset)

        if valid_filepath is not None:
            valid_files = [os.path.join(valid_filepath, file) for file in os.listdir(valid_filepath) if file.endswith('.tfrecords')]
            valid_dataset = self.make_dataset(valid_files)
            self.valid_init_op = self.iter.make_initializer(valid_dataset)

        if test_filepath is not None:
            test_files = [os.path.join(test_filepath, file) for file in os.listdir(test_filepath) if file.endswith('.tfrecords')]
            test_dataset = self.make_test_dataset(test_files)
            self.test_init_op = self.iter.make_initializer(test_dataset)
    
    def make_dataset(self, files):
        dataset = tf.data.TFRecordDataset(files).map(self.decoder)
        
        if self.shuffle:
            dataset = dataset.shuffle(buffer_size=10000)
        
        return dataset.repeat().batch(self.batch_size)
    
    def make_test_dataset(self, files):
        dataset = tf.data.TFRecordDataset(files).map(self.decoder)
        return dataset.batch(self.batch_size)
    
    def initialize(self, dataset='train'):
        if dataset == 'train':
            return self.train_init_op
        elif dataset == 'valid':
            return self.valid_init_op
        elif dataset == 'test':
            return self.test_init_op
        else:
            raise ValueError('Dataset unknown or unavailable.')

    def decoder(self, example_proto):
        keys_to_features = {'latent' : tf.FixedLenFeature(4000, tf.float32),
                            'target' : tf.FixedLenFeature(400, tf.float32),
                            'metadata' : tf.FixedLenFeature(4, tf.float32)}
        parsed_features = tf.parse_single_example(example_proto, keys_to_features)
        return parsed_features['latent'], parsed_features['target'], parsed_features['metadata']

    def get_batch(self):
        x, y, z = self.iter.get_next()
        x = tf.reshape(x, [-1, 4000, 1])
        y = tf.reshape(y, [-1, 400, 1])
        z = tf.reshape(z, [-1, 4])
        return x, y, z

In [9]:
class Discriminator:
    def __init__(self, name='discriminator'):
        self.name = name

    def __call__(self, x, reuse=False):
        with tf.variable_scope(self.name) as vs:
            if reuse:
                vs.reuse_variables()
            
            y = ly.conv1d(
                inputs=x,
                filters=64,
                kernel_size=3,
                strides=1,
                padding='same',
                kernel_initializer=tf.random_normal_initializer(stddev=0.02),
                activation=tf.identity
                )
            y = k.layers.LeakyReLU(alpha=0.2)(y)
            
            y = convolution_block(y, filters=64, stride=2)
            y = convolution_block(y, filters=128, stride=1)
            y = convolution_block(y, filters=128, stride=2)
            y = convolution_block(y, filters=256, stride=1)
            y = convolution_block(y, filters=256, stride=2)
            y = convolution_block(y, filters=512, stride=1)
            y = convolution_block(y, filters=512, stride=2)
            
            y = ly.Flatten()(y)
    
            y = ly.dense(
                inputs=y,
                units=512,
                activation=tf.identity,
                )
            y = k.layers.LeakyReLU(alpha=0.2)(y)
            y = ly.dense(
                inputs=y,
                units=1,
                activation=tf.identity,
                )
            
            return y
            
    @property
    def vars(self):
        return [var for var in tf.global_variables() if self.name in var.name]

In [10]:
class Generator:
    def __init__(self, num_blocks=8, name='generator'):
        self.num_blocks = num_blocks
        self.name = name

    def __call__(self, x):
        with tf.variable_scope(self.name) as vs:

            y = ly.conv1d(
                inputs=x,
                filters=64,
                kernel_size=9,
                strides=1,
                padding='same',
                kernel_initializer=tf.random_normal_initializer(stddev=0.02),
                activation=tf.identity
                )
            y = k.layers.PReLU()(y)
            y_ = tf.identity(y)

            for i in range(self.num_blocks):
                y = residual_block(y)
            
            y = tf.add(y, y_)
            
            y = ly.conv1d(
                inputs=y,
                filters=64,
                kernel_size=5,
                strides=4,
                padding='same',
                kernel_initializer=tf.random_normal_initializer(stddev=0.02),
                activation=tf.identity
                )
            
            y = ly.conv1d(
                inputs=y,
                filters=64,
                kernel_size=5,
                strides=4,
                padding='same',
                kernel_initializer=tf.random_normal_initializer(stddev=0.02),
                activation=tf.identity
                )
            
            y = ly.conv1d(
                inputs=y,
                filters=64,
                kernel_size=3,
                strides=2,
                padding='same',
                kernel_initializer=tf.random_normal_initializer(stddev=0.02),
                activation=tf.identity
                )

            y = ly.Flatten()(y)
            
            y = ly.dense(
                inputs=y,
                units=512,
                activation=tf.identity,
                )
            y = k.layers.PReLU()(y)
            y = ly.dense(
                inputs=y,
                units=400,
                activation=tf.tanh,
                )
            
            y = tf.reshape(y, [-1, 400, 1])

            return y

    @property
    def vars(self):
        return [var for var in tf.global_variables() if self.name in var.name]

In [11]:
class GAN:
    def __init__(self, generator, discriminator, data_sampler, logdir, ckptdir):
        self.g = generator
        self.d = discriminator
        self.ds = data_sampler
        self.logdir = logdir
        self.ckptdir = ckptdir

        self.latent, self.target, self.metadata = data_sampler.get_batch()

        self.target_ = self.g(self.latent)

        logits_real = self.d(self.target)
        logits_fake = self.d(self.target_, reuse=True)

        d_loss_real = tf.losses.sigmoid_cross_entropy(tf.ones_like(logits_real), logits_real)
        d_loss_fake = tf.losses.sigmoid_cross_entropy(tf.zeros_like(logits_fake), logits_fake)
        self.d_loss = d_loss_real + d_loss_fake

        self.adv_loss = 1e-2*tf.losses.sigmoid_cross_entropy(tf.ones_like(logits_fake), logits_fake)
        self.mse_loss = tf.losses.mean_squared_error(self.target_, self.target)
        self.g_loss = self.mse_loss + self.adv_loss

        self.lr = tf.Variable(1e-4, trainable=False)

        self.d_adam, self.g_adam = None, None
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_adam = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.9).minimize(self.d_loss, var_list=self.d.vars)
            self.g_adam = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.9).minimize(self.g_loss, var_list=self.g.vars)

        if not os.path.isdir(logdir):
            os.makedirs(logdir)

        if not os.path.isdir(ckptdir):
            os.makedirs(ckptdir)

        tfplot.summary.plot("Generated_Spectrum", plot_spectrum, [tf.reshape(self.target_[0], [400])])
        d_loss_summary = tf.summary.scalar("D_Loss", self.d_loss)
        g_loss_summary = tf.summary.scalar("G_Total_Loss", self.g_loss)
        adv_loss_summary = tf.summary.scalar("G_Adv_Loss", self.adv_loss)
        mse_loss_summary = tf.summary.scalar("G_MSE_Loss", self.mse_loss)
        lr_summary = tf.summary.scalar("Learning_Rate", self.lr)

        self.merged_summary = tf.summary.merge_all()

        self.summary_writer = tf.summary.FileWriter(logdir)
        self.saver = tf.train.Saver(max_to_keep=1)

        self.config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))

    def train(self, batches=int(1e6), restore=False):
        
        with tf.Session(config=self.config) as sess:
            if restore:
                meta_graph = [os.path.join(self.ckptdir, file) for file in \
                              os.listdir(self.ckptdir) if file.endswith('.meta')][0]
                restorer = tf.train.import_meta_graph(meta_graph)
                restorer.restore(sess, tf.train.latest_checkpoint(self.ckptdir))
            
            else:
                sess.run(tf.global_variables_initializer())
                print("{:,} global variables initialized.".format(get_total_parameters()))
                
            sess.run(self.ds.initialize('train'))
            print("Data sampler initialized on train dataset.")
            
            try:
                for batch in pbar(range(batches), unit='batch'):
                    sess.run(self.d_adam)
                    sess.run(self.g_adam)
                    
                    if batch % 100 == 0:
                        summaries = sess.run(self.merged_summary)
                        self.summary_writer.add_summary(summaries, batch)

                    if batch % 10000 == 0 or batch + 1 == batches:
                        self.saver.save(sess, os.path.join(self.ckptdir, 'ckpt'), global_step=batch+1)
            
            except KeyboardInterrupt:
                print("Saving model before quitting...")
                self.saver.save(sess, os.path.join(self.ckptdir, 'ckpt'), global_step=batch+1)
                print("Save complete. Training stopped.")
                    
    def infer(self, n, savedir='/vol/data/spectralgan/predictions-latent'):

        def plot(flux1, flux2, latent, metadata, i):
            flux1, flux2 = flux1.reshape([-1, 400]), flux2.reshape([-1, 400])
            latent, metadata = latent.reshape([-1, 4000]), metadata.reshape([-1, 4])
            wave = np.arange(1180, 1280, 0.25)
            wave2 = np.arange(1280, 2280, 0.25)
            
            for j in range(flux1.shape[0]):
                plt.clf()
                a, b, c, d = metadata[j]
                f1, f2, l = flux1[j], flux2[j], latent[j]
                
                f1 = (d-c)*(f1 + 1)/2 + c
                f2 = (d-c)*(f2 + 1)/2 + c
                
                l = b*l + a
                
                plt.plot(wave, f1, label='True Emission')
                plt.plot(wave, f2, label='Inferred Emission')
                plt.plot(wave2, l, label='Latent Spectrum')
                plt.axvline(1215.67, color='r', linestyle='--')
                plt.legend()
                savepath = os.path.join(savedir, 'pred-batch{}-sample{}.png'.format(i, j))
                plt.savefig(savepath, dpi=300, bbox_inches='tight')
            
        if not os.path.isdir(savedir):
            os.makedirs(savedir)
            
        with tf.Session(config=self.config) as sess:
            sess.run(self.ds.initialize('test'))
            print("Data sampler initialized on test dataset.")

            meta_graph = [os.path.join(self.ckptdir, file) for file in os.listdir(self.ckptdir) if file.endswith('.meta')][0]
            restorer = tf.train.import_meta_graph(meta_graph)
            restorer.restore(sess, tf.train.latest_checkpoint(self.ckptdir))
            print("Restored {:,} global parameters.".format(get_total_parameters()))
            
            for i in pbar(range(n)):
                latent, target, metadata, target_ = sess.run([self.latent, self.target, self.metadata, self.target_])
                plot(target, target_, latent, metadata, i)
                

In [12]:
generator = Generator()
discriminator = Discriminator()

In [13]:
sampler = ReInitDataSampler(
    train_filepath='/vol/data/spectralgan/train',
    test_filepath='/vol/data/spectralgan/test',
    batch_size=16)

In [14]:
logdir = '/vol/projects/spectralgan/logdir'
ckptdir = '/vol/projects/spectralgan/ckptdir'

In [15]:
gan = GAN(generator, discriminator, sampler, logdir, ckptdir)

In [None]:
gan.train(restore=True)

INFO:tensorflow:Restoring parameters from /vol/projects/spectralgan/ckptdir/ckpt-40001
Data sampler initialized on train dataset.


HBox(children=(IntProgress(value=0, max=1000000), HTML(value='')))

In [16]:
gan.infer(n=5)

Data sampler initialized on test dataset.
INFO:tensorflow:Restoring parameters from /vol/projects/spectralgan/ckptdir/ckpt-30001
Restored 44,936,883 global parameters.


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))


