In [1]:
import os
import time
import sherpa
import tfplot
import logging
import warnings
import numpy as np
import tensorflow as tf

from tensorflow import keras as k
from tensorflow import layers as ly

from tqdm import tqdm_notebook as pbar

from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec

DEBUG:matplotlib:CACHEDIR=/root/.cache/matplotlib
DEBUG:matplotlib.font_manager:Using fontManager instance from /root/.cache/matplotlib/fontList.json
DEBUG:matplotlib.backends:backend agg version v2.2


In [2]:
plt.style.use('seaborn')
# NOTE(review): this silences *every* Python warning (including TensorFlow
# deprecation warnings); narrow it with category=... if something useful is hidden.
warnings.filterwarnings('ignore')
# matplotlib's loggers write DEBUG records (cache dir, font lookup) into the
# cell outputs; keep only WARNING and above.
logging.getLogger('matplotlib').setLevel(logging.WARNING)

In [3]:
def plot_spectrum(spec, wave=None):
    """Render one spectrum as a tfplot figure (used for TensorBoard image summaries).

    Args:
        spec: 1-D array of flux values; must have the same length as ``wave``.
        wave: x-axis wavelength grid. Defaults to 1210..1280 in steps of 0.25
            (280 points, the length of the target spectra).

    Returns:
        The figure created by ``tfplot.subplots``.
    """
    if wave is None:
        wave = np.arange(1210, 1280, 0.25)
    fig, ax = tfplot.subplots(figsize=(4, 3))
    ax.plot(wave, spec)
    ax.set_xlabel('Wavelength')
    ax.set_ylabel('Flux')
    return fig

In [4]:
def get_total_parameters(variables=None):
    """Count the scalar parameters held by a collection of TensorFlow variables.

    Args:
        variables: iterable of variables (anything exposing ``get_shape()``).
            Defaults to ``tf.trainable_variables()`` of the current default graph,
            i.e. only trainable parameters are counted by default.

    Returns:
        int: the total number of elements across all variable shapes.

    Raises:
        ValueError: if a variable's shape is not fully defined.
    """
    if variables is None:
        variables = tf.trainable_variables()
    total_parameters = 0
    for variable in variables:
        # TensorShape.num_elements() is None when any dimension is unknown
        # (the old per-dimension product failed with a TypeError in that case).
        count = variable.get_shape().num_elements()
        if count is None:
            raise ValueError('Variable {} has a partially unknown shape.'.format(
                getattr(variable, 'name', variable)))
        total_parameters += count
    return total_parameters

In [None]:
def residual_block(x, kern, filters):
    """Residual unit: conv1d -> PReLU -> conv1d, with the result added back onto ``x``.

    Both convolutions are stride 1, 'same' padded and linear, so the output has
    the same length as ``x``; the shortcut is a plain ``tf.add``, which requires
    ``x`` to already have ``filters`` channels.
    """
    def conv(inputs):
        # Linear, length-preserving conv; each layer gets its own initializer.
        return ly.conv1d(
            inputs=inputs,
            filters=filters,
            kernel_size=kern,
            strides=1,
            padding='same',
            kernel_initializer=tf.random_normal_initializer(stddev=0.02),
            activation=tf.identity,
        )

    hidden = conv(x)
    hidden = k.layers.PReLU()(hidden)
    residual = conv(hidden)
    return tf.add(x, residual)

In [None]:
def subpixel_convolution(x, block_size, filters=256):
    """1-D sub-pixel (pixel-shuffle) upsampling block: conv1d -> shuffle -> PReLU.

    A width-preserving conv produces ``filters`` channels, which are then folded
    into the length axis: (batch, width, filters) -> (batch, width * block_size,
    filters // block_size).

    ``tf.depth_to_space`` (used previously) only accepts 4-D NHWC input and so
    cannot be applied to the 3-D conv1d output. The reshape below is its 1-D
    equivalent: output position ``w * r + j`` receives channel group ``j`` of
    input position ``w``.

    Args:
        x: 3-D input tensor, assumed (batch, width, channels).
        block_size: upsampling factor ``r``; must be >= 1 and divide ``filters``.
        filters: channels produced by the conv before shuffling (default 256).

    Returns:
        The PReLU-activated tensor of shape (batch, width * r, filters // r).

    Raises:
        ValueError: if ``block_size`` is < 1 or does not divide ``filters``.
    """
    if block_size < 1 or filters % block_size:
        raise ValueError('block_size ({}) must be >= 1 and divide filters ({}).'.format(block_size, filters))

    y = ly.conv1d(
        inputs=x,
        filters=filters,
        kernel_size=3,
        strides=1,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(stddev=0.02),
        activation=tf.identity
        )

    out_channels = filters // block_size
    width = y.get_shape().as_list()[1]
    if width is None:
        # Dynamic length: give the batch size explicitly so only one dim is inferred.
        # NOTE(review): PReLU() without shared_axes needs a static width to build alpha.
        y = tf.reshape(y, [tf.shape(y)[0], -1, out_channels])
    else:
        # Keep the static length so downstream Keras layers can build their weights.
        y = tf.reshape(y, [-1, width * block_size, out_channels])

    return k.layers.PReLU()(y)

In [None]:
def convolution_block(x, filters, stride, training=True):
    """Conv1d (kernel 3, 'same' padding) -> batch normalization -> LeakyReLU(0.2).

    Args:
        x: input tensor for ``ly.conv1d`` (presumably (batch, width, channels)).
        filters: number of output channels.
        stride: convolution stride.
        training: forwarded to ``ly.batch_normalization``. Defaults to True so the
            layer normalises with batch statistics and registers its moving-average
            updates in ``tf.GraphKeys.UPDATE_OPS``. Previously the argument was
            omitted, and TF1 then defaults to ``training=False``. In that mode the
            layer uses its never-updated moving mean/variance (initialised to 0/1),
            so it did not actually normalise. Pass False for inference.

    Returns:
        The activated output tensor.
    """
    y = ly.conv1d(
        inputs=x,
        filters=filters,
        kernel_size=3,
        strides=stride,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(stddev=0.02),
        activation=tf.identity
        )
    y = ly.batch_normalization(y, training=training)

    return k.layers.LeakyReLU(alpha=0.2)(y)

In [None]:
class ReInitDataSampler:
    """Reinitializable TFRecord input pipeline yielding (latent, target, metadata) batches.

    One ``tf.data.Iterator`` is shared by the train/valid/test datasets; run the
    op returned by ``initialize(name)`` to point it at one of them. The training
    dataset is (optionally) shuffled and repeated forever. The validation and
    test datasets are single passes that end with ``tf.errors.OutOfRangeError``.
    """

    # Fixed lengths of the float features stored in every record.
    LATENT_SIZE = 4000
    TARGET_SIZE = 280
    METADATA_SIZE = 2

    def __init__(self, train_filepath, batch_size, valid_filepath=None, test_filepath=None, shuffle=True):
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Stay None when a split is not provided so initialize() can report it
        # with a ValueError instead of failing with AttributeError.
        self.valid_init_op = None
        self.test_init_op = None

        train_dataset = self.make_dataset(self._list_tfrecords(train_filepath))
        self.iter = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
        self.train_init_op = self.iter.make_initializer(train_dataset)

        if valid_filepath is not None:
            valid_dataset = self.make_test_dataset(self._list_tfrecords(valid_filepath))
            self.valid_init_op = self.iter.make_initializer(valid_dataset)

        if test_filepath is not None:
            test_dataset = self.make_test_dataset(self._list_tfrecords(test_filepath))
            self.test_init_op = self.iter.make_initializer(test_dataset)

    @staticmethod
    def _list_tfrecords(directory):
        """Return the sorted paths of the ``*.tfrecords`` files directly inside ``directory``.

        Sorting makes the file order independent of ``os.listdir`` ordering.

        Raises:
            ValueError: if the directory contains no such files (an empty training
                set would otherwise be repeated forever).
        """
        files = sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith('.tfrecords')
        )
        if not files:
            raise ValueError('No .tfrecords files found in {}'.format(directory))
        return files

    def make_dataset(self, files):
        """Training dataset: decoded, optionally shuffled, repeated forever, batched."""
        dataset = tf.data.TFRecordDataset(files).map(self.decoder)

        if self.shuffle:
            dataset = dataset.shuffle(buffer_size=10000)

        return dataset.repeat().batch(self.batch_size)

    def make_test_dataset(self, files):
        """Evaluation dataset: decoded and batched, a single pass in file order."""
        dataset = tf.data.TFRecordDataset(files).map(self.decoder)
        return dataset.batch(self.batch_size)

    def initialize(self, dataset='train'):
        """Return the initializer op that points the shared iterator at ``dataset``.

        Args:
            dataset: one of 'train', 'valid', 'test'.

        Raises:
            ValueError: for an unknown name or a split that was not configured.
        """
        init_ops = {
            'train': self.train_init_op,
            'valid': self.valid_init_op,
            'test': self.test_init_op,
        }
        init_op = init_ops.get(dataset)
        if init_op is None:
            raise ValueError('Dataset unknown or unavailable.')
        return init_op

    def decoder(self, example_proto):
        """Parse one serialized example into (latent, target, metadata) float32 vectors."""
        keys_to_features = {'latent' : tf.FixedLenFeature(self.LATENT_SIZE, tf.float32),
                            'target' : tf.FixedLenFeature(self.TARGET_SIZE, tf.float32),
                            'metadata' : tf.FixedLenFeature(self.METADATA_SIZE, tf.float32)}
        parsed_features = tf.parse_single_example(example_proto, keys_to_features)
        return parsed_features['latent'], parsed_features['target'], parsed_features['metadata']

    def get_batch(self):
        """Next batch as (batch, 4000, 1) latent, (batch, 280, 1) target, (batch, 2) metadata."""
        x, y, z = self.iter.get_next()
        x = tf.reshape(x, [-1, self.LATENT_SIZE, 1])
        y = tf.reshape(y, [-1, self.TARGET_SIZE, 1])
        z = tf.reshape(z, [-1, self.METADATA_SIZE])
        return x, y, z

In [None]:
class Discriminator:
    """Convolutional real/fake classifier over 1-D spectra; returns one raw logit per example.

    Args (constructor):
        params: hyperparameter mapping; stored but not read by this class.
        name: variable-scope name, also used to select this model's variables.
    """

    # (filters, stride) of the stacked convolution blocks, applied in order.
    _CONV_STACK = ((64, 2), (128, 1), (128, 2), (256, 1), (256, 2), (512, 1), (512, 2))

    def __init__(self, params, name='discriminator'):
        self.params = params
        self.name = name

    def __call__(self, x, reuse=False):
        """Build (or, with ``reuse=True``, re-apply) the discriminator on ``x`` and return logits."""
        with tf.variable_scope(self.name) as vs:
            if reuse:
                vs.reuse_variables()

            y = ly.conv1d(
                inputs=x,
                filters=64,
                kernel_size=3,
                strides=1,
                padding='same',
                kernel_initializer=tf.random_normal_initializer(stddev=0.02),
                activation=tf.identity
                )
            y = k.layers.LeakyReLU(alpha=0.2)(y)

            for filters, stride in self._CONV_STACK:
                y = convolution_block(y, filters=filters, stride=stride)

            y = ly.Flatten()(y)

            y = ly.dense(
                inputs=y,
                units=512,
                activation=tf.identity,
                )
            y = k.layers.LeakyReLU(alpha=0.2)(y)
            # Linear output: the losses apply the sigmoid themselves.
            y = ly.dense(
                inputs=y,
                units=1,
                activation=tf.identity,
                )

            return y

    @property
    def vars(self):
        """Trainable variables whose name contains this model's scope name.

        Only trainable variables are returned. ``tf.global_variables()`` also holds
        the batch-norm moving mean/variance and, once an optimizer exists, its slot
        variables; none of these belong in an optimizer's ``var_list``.
        """
        return [var for var in tf.trainable_variables() if self.name in var.name]

In [None]:
class Generator:
    """Maps a latent spectrum batch to a (batch, 280, 1) predicted target spectrum.

    Layers, in order:
    - conv1d(64) + PReLU;
    - ``num_blocks`` residual blocks wrapped by a long skip connection;
    - three strided conv1d layers (strides 10, 4, 4);
    - dense(512) + PReLU;
    - dense(280), reshaped to (batch, 280, 1).

    Args (constructor):
        params: mapping with integer kernel sizes 'g_l1_kern', 'res_block_kern',
            'g_l2_kern', 'g_l3_kern' and 'g_l4_kern'.
        num_blocks: number of residual blocks.
        name: variable-scope name, also used to select this model's variables.
    """

    def __init__(self, params, num_blocks=8, name='generator'):
        self.params = params
        self.num_blocks = num_blocks
        self.name = name

    def __call__(self, x):
        """Build the generator on ``x`` and return the (batch, 280, 1) output tensor."""
        with tf.variable_scope(self.name):

            y = ly.conv1d(
                inputs=x,
                filters=64,
                kernel_size=self.params['g_l1_kern'],
                strides=1,
                padding='same',
                kernel_initializer=tf.random_normal_initializer(stddev=0.02),
                activation=tf.identity
                )
            y = k.layers.PReLU()(y)
            skip = tf.identity(y)

            for _ in range(self.num_blocks):
                y = residual_block(y, kern=self.params['res_block_kern'], filters=64)

            # Long skip connection around the whole residual stack.
            y = tf.add(y, skip)

            # NOTE(review): no activation between these three strided convolutions,
            # so together they act as a single linear map; confirm this is intended.
            y = ly.conv1d(
                inputs=y,
                filters=64,
                kernel_size=self.params['g_l2_kern'],
                strides=10,
                padding='same',
                kernel_initializer=tf.random_normal_initializer(stddev=0.02),
                activation=tf.identity
                )

            y = ly.conv1d(
                inputs=y,
                filters=64,
                kernel_size=self.params['g_l3_kern'],
                strides=4,
                padding='same',
                kernel_initializer=tf.random_normal_initializer(stddev=0.02),
                activation=tf.identity
                )

            y = ly.conv1d(
                inputs=y,
                filters=64,
                kernel_size=self.params['g_l4_kern'],
                strides=4,
                padding='same',
                kernel_initializer=tf.random_normal_initializer(stddev=0.02),
                activation=tf.identity
                )

            y = ly.Flatten()(y)

            y = ly.dense(
                inputs=y,
                units=512,
                activation=tf.identity,
                )
            y = k.layers.PReLU()(y)
            y = ly.dense(
                inputs=y,
                units=280,
                activation=None,
                )

            y = tf.reshape(y, [-1, 280, 1])

            return y

    @property
    def vars(self):
        """Trainable variables whose name contains this model's scope name.

        Restricted to trainable variables so optimizer slot variables created for
        this scope are never returned (consistent with ``Discriminator.vars``).
        """
        return [var for var in tf.trainable_variables() if self.name in var.name]

In [None]:
class GAN:
    """Builds the training graph for a generator/discriminator pair and runs train/eval/inference.

    The generator is optimised on the MSE between its output and the target
    (``g_loss``). The discriminator loss and the adversarial loss are built and
    logged to TensorBoard, but ``train`` only runs ``g_adam``.

    Args:
        generator: callable mapping the latent batch to a predicted target; exposes ``vars``.
        discriminator: callable ``(x, reuse=False)`` returning logits; exposes ``vars``.
        data_sampler: provides ``get_batch()`` and ``initialize(name)`` (e.g. ReInitDataSampler).
        logdir: TensorBoard summary directory (created if missing).
        ckptdir: checkpoint directory (created if missing).
        lr: initial learning rate; a falsy value falls back to 1e-4.
        verbose: print progress messages.
    """

    def __init__(self, generator, discriminator, data_sampler, logdir, ckptdir, lr=None, verbose=False):
        self.g = generator
        self.d = discriminator
        self.ds = data_sampler
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.verbose = verbose
        # Generator updates completed across *all* train() calls. Used as the
        # summary/checkpoint step so repeated train() calls append to the curves
        # instead of rewriting steps 0..N each time.
        self.global_step = 0

        self.latent, self.target, self.metadata = data_sampler.get_batch()

        self.target_ = self.g(self.latent)

        logits_real = self.d(self.target)
        logits_fake = self.d(self.target_, reuse=True)

        d_loss_real = tf.losses.sigmoid_cross_entropy(tf.ones_like(logits_real), logits_real)
        d_loss_fake = tf.losses.sigmoid_cross_entropy(tf.zeros_like(logits_fake), logits_fake)
        self.d_loss = d_loss_real + d_loss_fake

        adv_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(logits_fake), logits_fake)
        mse_loss = tf.losses.mean_squared_error(self.target_, self.target)
        # NOTE(review): adv_loss is only logged; the generator trains on MSE alone.
        self.g_loss = mse_loss

        initial_lr = lr if lr else 1e-4
        self.lr = tf.Variable(initial_lr, trainable=False)
        # Drops the learning rate 10x (1e-4 -> 1e-5 by default). Previously this op
        # only existed when lr was not supplied.
        self.lr_drop = self.lr.assign(initial_lr * 0.1)

        # Run any UPDATE_OPS (e.g. batch-norm moving averages) with each optimizer step.
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_adam = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.9).minimize(self.d_loss, var_list=self.d.vars)
            self.g_adam = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.9).minimize(self.g_loss, var_list=self.g.vars)

        for directory in (logdir, ckptdir):
            if not os.path.isdir(directory):
                os.makedirs(directory)

        tfplot.summary.plot("Generated_Spectrum", plot_spectrum, [tf.reshape(self.target_[0], [280])])
        tf.summary.scalar("D_Loss", self.d_loss)
        tf.summary.scalar("G_Total_Loss", self.g_loss)
        tf.summary.scalar("G_Adv_Loss", adv_loss)
        tf.summary.scalar("G_MSE_Loss", mse_loss)
        tf.summary.scalar("Learning_Rate", self.lr)

        self.merged_summary = tf.summary.merge_all()

        self.summary_writer = tf.summary.FileWriter(logdir)
        self.saver = tf.train.Saver(max_to_keep=1)

        self.config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
        self.sess = tf.Session(config=self.config)

        self.sess.run(tf.global_variables_initializer())

        if self.verbose:
            print("{:,} global variables initialized.".format(get_total_parameters()))

    def _restore_latest(self):
        """Restore the newest checkpoint in ``ckptdir`` into this graph's variables.

        Uses ``self.saver``, which was built over exactly these variables. The
        previous ``tf.train.import_meta_graph`` approach imported another copy of
        the whole saved graph into the default graph on every call. It also relied
        on the imported saver's op names resolving back to this graph.

        Raises:
            ValueError: if ``ckptdir`` holds no checkpoint.
        """
        checkpoint = tf.train.latest_checkpoint(self.ckptdir)
        if checkpoint is None:
            raise ValueError('No checkpoint found in {}'.format(self.ckptdir))
        self.saver.restore(self.sess, checkpoint)

    def train(self, batches=int(2e4), save=False, restore=False):
        """Run ``batches`` generator updates on the training split.

        Args:
            batches: number of ``g_adam`` steps.
            save: checkpoint after the first batch, every 10000 batches, and after the last.
            restore: load the latest checkpoint from ``ckptdir`` first.

        Returns:
            ``g_loss`` evaluated on one further training batch. Returns None if
            interrupted with Ctrl-C; a checkpoint is saved in that case
            regardless of ``save``.
        """
        self.sess.run(self.ds.initialize('train'))

        if restore:
            self._restore_latest()

        ckpt_prefix = os.path.join(self.ckptdir, 'ckpt')
        start_step = self.global_step

        try:
            for batch in pbar(range(batches), unit='batch'):
                self.sess.run(self.g_adam)
                self.global_step = start_step + batch + 1

                if batch % 100 == 0:
                    summaries = self.sess.run(self.merged_summary)
                    self.summary_writer.add_summary(summaries, start_step + batch)

                if save and (batch % 10000 == 0 or batch + 1 == batches):
                    self.saver.save(self.sess, ckpt_prefix, global_step=self.global_step)

            self.summary_writer.flush()
            return self.sess.run(self.g_loss)

        except KeyboardInterrupt:
            # self.global_step is always bound, even if interrupted before the first batch.
            print("Saving model before quitting...")
            self.saver.save(self.sess, ckpt_prefix, global_step=self.global_step)
            print("Save complete. Training stopped.")

    def evaluate(self):
        """Generator MSE over one full pass of the validation split.

        Squared errors are accumulated per element, so a smaller final batch is
        weighted correctly. A mean of per-batch means over-weights it.

        Returns:
            float: element-wise MSE, or NaN if the split produced no batches.
        """
        self.sess.run(self.ds.initialize('valid'))

        if self.verbose:
            print("Evaluating...\n")

        squared_error_sum = 0.0
        element_count = 0
        while True:
            try:
                target, target_ = self.sess.run([self.target, self.target_])
            except tf.errors.OutOfRangeError:
                break
            squared_error_sum += float(np.square(target - target_).sum(dtype=np.float64))
            element_count += target.size

        return squared_error_sum / element_count if element_count else float('nan')

    def infer(self, n, show_latent=True, savedir='/vol/data/spectralgan/pred-rev1'):
        """Restore the latest checkpoint and plot ``n`` test batches.

        For every sample, a PNG comparing the true and inferred emission (and
        optionally the latent spectrum) is saved to ``savedir``. Values are first
        mapped back with ``b * x + a``, where ``(a, b)`` is the sample's metadata.
        """
        def plot(flux1, flux2, latent, metadata, i):
            flux1, flux2 = flux1.reshape([-1, 280]), flux2.reshape([-1, 280])
            latent, metadata = latent.reshape([-1, 4000]), metadata.reshape([-1, 2])
            wave = np.arange(1210, 1280, 0.25)
            wave2 = np.arange(1280, 2280, 0.25)

            for j in pbar(range(flux1.shape[0])):
                plt.clf()
                a, b = metadata[j]
                f1, f2, l = flux1[j], flux2[j], latent[j]

                f1 = b*f1 + a
                f2 = b*f2 + a

                l = b*l + a

                plt.plot(wave, f1, label='True Emission')
                plt.plot(wave, f2, label='Inferred Emission')

                if show_latent:
                    plt.plot(wave2, l, label='Latent Spectrum')

                plt.axvline(1215.67, color='r', linestyle='--')
                plt.legend()
                savepath = os.path.join(savedir, 'pred-batch{}-sample{}.png'.format(i, j))
                plt.savefig(savepath, dpi=300, bbox_inches='tight')

        if not os.path.isdir(savedir):
            os.makedirs(savedir)

        self.sess.run(self.ds.initialize('test'))

        if self.verbose:
            print("Data sampler initialized on test dataset.")

        self._restore_latest()

        if self.verbose:
            print("Restored {:,} global parameters.".format(get_total_parameters()))

        for i in pbar(range(n)):
            latent, target, metadata, target_ = self.sess.run([self.latent, self.target, self.metadata, self.target_])
            plot(target, target_, latent, metadata, i)

    def close(self):
        """Close the summary writer and release the TensorFlow session."""
        self.summary_writer.close()
        self.sess.close()
            

In [None]:
# Search space: every generator kernel size is a discrete choice in [3, 32];
# the learning rate is searched on a log scale over [1e-6, 1e-3].
KERNEL_PARAMS = ('res_block_kern', 'g_l1_kern', 'g_l2_kern', 'g_l3_kern', 'g_l4_kern')
parameters = [sherpa.Discrete(name, [3, 32]) for name in KERNEL_PARAMS]
parameters.append(sherpa.Continuous('lr', [1e-6, 1e-3], scale='log'))

algorithm = sherpa.algorithms.BayesianOptimization(max_num_trials=50)
# The objective reported per trial is a validation MSE, hence lower_is_better.
# NOTE(review): the output shows the SHERPA dashboard served by Flask in debug
# mode on a network address; consider disabling the dashboard on shared hosts.
study = sherpa.Study(parameters, algorithm, lower_is_better=True)

INFO:sherpa.core:
-------------------------------------------------------
SHERPA Dashboard running on http://10.244.103.17:8880
-------------------------------------------------------


 * Serving Flask app "sherpa.app.app" (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: on


In [None]:
# Each trial rebuilds the graph, then runs ITERATIONS_PER_TRIAL rounds of
# BATCHES_PER_ITERATION generator updates, reporting the validation MSE to
# sherpa after every round.
ITERATIONS_PER_TRIAL = 10
BATCHES_PER_ITERATION = 1000

for i, trial in enumerate(study):

    print("\033[1mBeginning trial {}...\033[0m\n".format(i))

    for key in trial.parameters:
        print('{}\t{}'.format(key, trial.parameters[key]))
    print('\n')

    tf.reset_default_graph()

    generator = Generator(trial.parameters)
    discriminator = Discriminator(trial.parameters)

    # NOTE(review): the validation split here is the *test* directory, so the
    # hyperparameters are tuned on test data and any final test score will be
    # optimistic. Point this at a separate validation split if one exists.
    sampler = ReInitDataSampler(
        train_filepath='/vol/data/spectralgan/train-rev1',
        valid_filepath='/vol/data/spectralgan/test-rev1',
        batch_size=16)

    logdir = '/vol/projects/spectralgan/logdir/opt/opt_trial-{}'.format(i)
    ckptdir = '/vol/projects/spectralgan/ckptdir/rev2'

    gan = GAN(generator, discriminator, sampler, logdir, ckptdir, lr=trial.parameters['lr'])

    try:
        for j in range(ITERATIONS_PER_TRIAL):
            train_error = gan.train(BATCHES_PER_ITERATION)
            if train_error is None:
                # GAN.train swallows Ctrl-C and returns None; stop the study
                # instead of silently starting the next round.
                raise KeyboardInterrupt
            valid_error = gan.evaluate()

            study.add_observation(
                trial=trial,
                iteration=j,
                objective=valid_error,
                context={'training_error': train_error}
            )

        study.finalize(trial)
    finally:
        # Flush this trial's summaries and release its session (and the GPU memory
        # it holds) before the next trial resets the default graph.
        gan.summary_writer.close()
        gan.sess.close()

[1mBeginning trial 0...[0m

res_block_kern	12
g_l3_kern	12
g_l4_kern	12
g_l1_kern	12
g_l2_kern	12
lr	1e-05




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

DEBUG:matplotlib.font_manager:findfont: Matching :family=sans-serif:style=normal:variant=normal:weight=normal:stretch=normal:size=10.0 to DejaVu Sans ('/modules/matplotlib/mpl-data/fonts/ttf/DejaVuSans.ttf') with score of 4.050000





HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


[1mBeginning trial 1...[0m

res_block_kern	22
g_l3_kern	12
g_l4_kern	12
g_l1_kern	12
g_l2_kern	12
lr	1e-05




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


[1mBeginning trial 2...[0m

res_block_kern	12
g_l3_kern	12
g_l4_kern	12
g_l1_kern	12
g_l2_kern	12
lr	0.0001




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


[1mBeginning trial 3...[0m

res_block_kern	22
g_l3_kern	12
g_l4_kern	12
g_l1_kern	12
g_l2_kern	12
lr	0.0001




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


[1mBeginning trial 4...[0m

res_block_kern	12
g_l3_kern	12
g_l4_kern	22
g_l1_kern	12
g_l2_kern	12
lr	1e-05




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


[1mBeginning trial 5...[0m

res_block_kern	22
g_l3_kern	12
g_l4_kern	22
g_l1_kern	12
g_l2_kern	12
lr	1e-05




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


[1mBeginning trial 6...[0m

res_block_kern	12
g_l3_kern	12
g_l4_kern	22
g_l1_kern	12
g_l2_kern	12
lr	0.0001




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


[1mBeginning trial 7...[0m

res_block_kern	22
g_l3_kern	12
g_l4_kern	22
g_l1_kern	12
g_l2_kern	12
lr	0.0001




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


[1mBeginning trial 8...[0m

res_block_kern	12
g_l3_kern	22
g_l4_kern	12
g_l1_kern	12
g_l2_kern	12
lr	1e-05




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


[1mBeginning trial 9...[0m

res_block_kern	22
g_l3_kern	22
g_l4_kern	12
g_l1_kern	12
g_l2_kern	12
lr	1e-05




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


[1mBeginning trial 10...[0m

res_block_kern	12
g_l3_kern	22
g_l4_kern	12
g_l1_kern	12
g_l2_kern	12
lr	0.0001




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


[1mBeginning trial 11...[0m

res_block_kern	22
g_l3_kern	22
g_l4_kern	12
g_l1_kern	12
g_l2_kern	12
lr	0.0001




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


[1mBeginning trial 12...[0m

res_block_kern	12
g_l3_kern	22
g_l4_kern	22
g_l1_kern	12
g_l2_kern	12
lr	1e-05




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


[1mBeginning trial 13...[0m

res_block_kern	22
g_l3_kern	22
g_l4_kern	22
g_l1_kern	12
g_l2_kern	12
lr	1e-05




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


[1mBeginning trial 14...[0m

res_block_kern	12
g_l3_kern	22
g_l4_kern	22
g_l1_kern	12
g_l2_kern	12
lr	0.0001




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


[1mBeginning trial 15...[0m

res_block_kern	22
g_l3_kern	22
g_l4_kern	22
g_l1_kern	12
g_l2_kern	12
lr	0.0001




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


[1mBeginning trial 16...[0m

res_block_kern	12
g_l3_kern	12
g_l4_kern	12
g_l1_kern	12
g_l2_kern	22
lr	1e-05




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


[1mBeginning trial 17...[0m

res_block_kern	22
g_l3_kern	12
g_l4_kern	12
g_l1_kern	12
g_l2_kern	22
lr	1e-05




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


[1mBeginning trial 18...[0m

res_block_kern	12
g_l3_kern	12
g_l4_kern	12
g_l1_kern	12
g_l2_kern	22
lr	0.0001




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


[1mBeginning trial 19...[0m

res_block_kern	22
g_l3_kern	12
g_l4_kern	12
g_l1_kern	12
g_l2_kern	22
lr	0.0001




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

In [14]:
# Best observation so far: lowest validation MSE, since the study uses lower_is_better=True.
best_result = study.get_best_result()
best_result

{'Iteration': 8,
 'Objective': 2.80417799949646,
 'Trial-ID': 48,
 'g_l1_kern': 22,
 'g_l2_kern': 12,
 'g_l3_kern': 22,
 'g_l4_kern': 22,
 'lr': 0.0001,
 'res_block_kern': 22,
 'training_error': 2.6794958114624023}