In [20]:
import pickle as pk
import matplotlib.pyplot as plt
import numpy as np
import time
from tensorflow.examples.tutorials.mnist import input_data

from keras.models import Sequential
from keras.layers import Conv2D, Conv2DTranspose, UpSampling2D, Activation, LeakyReLU, Dropout,BatchNormalization
from keras.optimizers import Adam, RMSprop

In [21]:
# Load the pickled simulation records once for interactive inspection.
# NOTE(review): pickle executes arbitrary code while loading -- only open this
# file if it comes from a trusted source.
with open('Fake_sims_for_separation_test.pk', 'rb') as sims_file:
    # The context manager closes the handle even if unpickling fails.
    fake_sims = pk.load(sims_file)

In [22]:
class ElapsedTimer(object):
    """Wall-clock stopwatch that starts when the instance is created."""

    def __init__(self):
        # Reference point for every later elapsed_time() call.
        self.start_time = time.time()

    def elapsed(self, sec):
        """Render a duration given in seconds as a string with a unit suffix.

        Durations under a minute are reported in seconds, durations under an
        hour in minutes, and everything else in hours.
        """
        minute = 60
        hour = 60 * 60
        if sec < minute:
            amount, unit = sec, " sec"
        elif sec < hour:
            amount, unit = sec / minute, " min"
        else:
            amount, unit = sec / hour, " hr"
        return str(amount) + unit

    def elapsed_time(self):
        """Print the time that has passed since the timer was constructed."""
        seconds_passed = time.time() - self.start_time
        print("Elapsed: %s " % self.elapsed(seconds_passed))

In [32]:
class Autoencoder(object):
    """Convolutional encoder/decoder pair stacked into a single Keras model.

    The encoder downsamples the input image twice (two stride-2 convolutions)
    and the decoder upsamples twice, so for image sizes divisible by 4 the
    output has the same height and width as the input. Sub-models are built
    lazily and cached on the instance.

    Parameters
    ----------
    img_rows, img_cols : int
        Height and width of the input images.
    channel : int
        Number of input channels.
    """

    # Number of filters in the first encoder convolution; later layers use
    # multiples of this value.
    _BASE_DEPTH = 64

    def __init__(self, img_rows=64, img_cols=64, channel=1):

        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channel = channel
        self.E = None   # encoder
        self.D = None   # decoder
        self.AE = None  # auto encoder model

    def _latent_shape(self):
        """Return the (rows, cols, depth) shape produced by the encoder.

        Each of the two stride-2 convolutions uses 'same' padding, which
        yields ceil(n / 2) outputs along a spatial axis of length n.
        """
        rows, cols = self.img_rows, self.img_cols
        for _ in range(2):
            rows = -(-rows // 2)  # ceiling division without importing math
            cols = -(-cols // 2)
        return (rows, cols, self._BASE_DEPTH * 4)

    def encoder(self):
        """Build (once) and return the encoder model."""
        if self.E is not None:
            return self.E
        self.E = Sequential()
        depth = self._BASE_DEPTH
        dropout = 0.6
        # In: img_rows x img_cols x channel
        # Out: img_rows x img_cols x depth
        input_shape = (self.img_rows, self.img_cols, self.channel)
        self.E.add(Conv2D(depth*1, 3, strides=1, input_shape=input_shape,
            padding='same'))
        self.E.add(LeakyReLU(alpha=0.2))
        self.E.add(Dropout(dropout))

        # Out: img_rows/2 x img_cols/2 x 2*depth
        self.E.add(Conv2D(depth*2, 5, strides=2, padding='same'))
        self.E.add(BatchNormalization(momentum=0.9))
        self.E.add(LeakyReLU(alpha=0.2))
        self.E.add(Dropout(dropout))

        # Out: img_rows/4 x img_cols/4 x 4*depth
        self.E.add(Conv2D(depth*4, 5, strides=2, padding='same'))
        self.E.add(BatchNormalization(momentum=0.9))
        self.E.add(LeakyReLU(alpha=0.2))
        self.E.add(Dropout(dropout))

        self.E.summary()
        return self.E

    def decoder(self):
        """Build (once) and return the decoder model.

        The decoder input shape is derived from the image size so that it
        always matches what the encoder emits, instead of being fixed to the
        16 x 16 x 256 tensor that only a 64 x 64 input produces.
        """
        if self.D is not None:
            return self.D
        self.D = Sequential()
        latent_shape = self._latent_shape()
        depth = latent_shape[-1]

        # In: rows x cols x depth (encoder output)
        # Out: 2*rows x 2*cols x depth/2
        self.D.add(UpSampling2D(input_shape=latent_shape))
        self.D.add(Conv2DTranspose(depth // 2, 5, padding='same'))
        self.D.add(BatchNormalization(momentum=0.9))
        self.D.add(Activation('relu'))

        # Out: 4*rows x 4*cols x depth/4
        self.D.add(UpSampling2D(size=(2,2)))
        self.D.add(Conv2DTranspose(depth // 4, 5, padding='same'))
        self.D.add(BatchNormalization(momentum=0.9))
        self.D.add(Activation('relu'))

        # Out: 4*rows x 4*cols x depth/8
        self.D.add(Conv2DTranspose(depth // 8, 5, padding='same'))
        self.D.add(Activation('tanh'))

        # Out: 4*rows x 4*cols x 1, unbounded values (linear activation).
        # The output always has a single channel, independent of self.channel.
        self.D.add(Conv2DTranspose(1, 5, padding='same'))
        self.D.add(Activation('linear'))
        self.D.summary()
        return self.D

    def autoencoder_model(self):
        """Build (once), compile and return the stacked encoder+decoder model."""
        if self.AE is not None:
            return self.AE
        optimizer = RMSprop(lr=0.0001, decay=3e-8)
        self.AE = Sequential()
        self.AE.add(self.encoder())
        self.AE.add(self.decoder())
        # NOTE(review): the decoder ends in a linear activation, so its
        # outputs are not confined to [0, 1], yet the loss is binary
        # cross-entropy and the metric is classification accuracy. If the
        # targets are not probabilities, a regression loss such as
        # 'mean_squared_error' is likely the better fit -- confirm against the
        # value range of the training targets before changing it.
        self.AE.compile(loss='binary_crossentropy', optimizer=optimizer,
            metrics=['accuracy'])
        return self.AE
    

In [66]:
class FKSIMS_AE(object):
    """Trains the autoencoder to map `train_mixed` images onto `train_fgs` images.

    Parameters
    ----------
    data_path : str
        Pickle file holding an iterable of 3-item records. The second item of
        each record becomes a training target (`train_fgs`) and the third the
        matching network input (`train_mixed`); the first item is ignored.
    """

    # Panels per row in plot_images; kept even so that each target image sits
    # directly next to its reconstruction.
    _GRID_COLS = 4
    # Smallest number of rows, which preserves the historical 4 x 4 layout.
    _MIN_GRID_ROWS = 4

    def __init__(self, data_path='Fake_sims_for_separation_test.pk'):
        self.img_rows = 64
        self.img_cols = 64
        self.channel = 1

        self.train_mixed, self.train_fgs = self._load_training_arrays(
            data_path, self.img_rows, self.img_cols, self.channel)

        # Hand the image geometry over explicitly so the network and the
        # reshaped data cannot drift apart.
        self.AE = Autoencoder(self.img_rows, self.img_cols, self.channel)
        self.autoencoder = self.AE.autoencoder_model()

    @staticmethod
    def _load_training_arrays(data_path, img_rows, img_cols, channel):
        """Read the pickle at `data_path` and return (mixed, fgs) float32 arrays.

        Both arrays are reshaped to (-1, img_rows, img_cols, channel).
        """
        # NOTE(review): pickle executes arbitrary code while loading -- only
        # use files from a trusted source.
        with open(data_path, 'rb') as handle:
            records = pk.load(handle)
        _, fgs, mixed = zip(*records)
        shape = (-1, img_rows, img_cols, channel)
        mixed = np.array(mixed).reshape(shape).astype(np.float32)
        fgs = np.array(fgs).reshape(shape).astype(np.float32)
        return mixed, fgs

    @classmethod
    def _grid_shape(cls, n_panels):
        """Return (rows, cols) of a subplot grid large enough for `n_panels`."""
        rows_needed = -(-n_panels // cls._GRID_COLS)  # ceiling division
        return max(cls._MIN_GRID_ROWS, rows_needed), cls._GRID_COLS

    def train(self, train_steps=2000, batch_size=25, save_interval=0):
        """Run `train_steps` updates on random batches drawn with replacement.

        Prints the loss and accuracy reported by Keras after every step. When
        `save_interval` is positive, a comparison figure of the current batch
        is written to disk every `save_interval` steps.
        """
        for i in range(train_steps):
            batch_idx = np.random.randint(0, self.train_mixed.shape[0], size=batch_size)
            tmix = self.train_mixed[batch_idx, :, :, :]
            tfgs = self.train_fgs[batch_idx, :, :, :]
            ae_loss = self.autoencoder.train_on_batch(tmix, tfgs)
            log_mesg = "%d  [AE loss: %f, acc: %f]" % (i, ae_loss[0], ae_loss[1])
            print(log_mesg)
            if save_interval > 0 and (i+1) % save_interval == 0:
                self.plot_images(tmix, tfgs, save2file=True, samples=8, step=(i+1))

    def plot_images(self, inputs, truths, save2file=False, samples=8, step=0):
        """Show target images next to the network's reconstructions.

        Up to `samples * 2` panels are drawn. Even-numbered panels (0, 2, ...)
        show `truths[i]`; the panel to the right of each shows the prediction
        made from `inputs[i]`.

        Parameters
        ----------
        inputs, truths : array-like, shape (n, img_rows, img_cols, 1)
            Network inputs and the matching target images.
        save2file : bool
            Save to "fake_sims_<step>.png" and close the figure instead of
            displaying it.
        samples : int
            Number of target/prediction pairs to draw.
        step : int
            Training step used in the output file name.
        """
        filename = "fake_sims_%d.png" % step
        # Never index past the data actually supplied (e.g. a small batch).
        n_panels = min(samples*2, len(inputs), len(truths))
        inputs = inputs[:n_panels]
        truths = truths[:n_panels]
        fakes = self.autoencoder.predict(inputs)

        n_rows, n_cols = self._grid_shape(n_panels)
        plt.figure(figsize=(10,10))
        for i in range(n_panels):
            plt.subplot(n_rows, n_cols, i+1)
            if i % 2 == 0:
                image = truths[i, :, :, :]
            else:
                # Reconstruction belonging to the target drawn just before.
                image = fakes[i-1, :, :, :]
            image = np.reshape(image, [self.img_rows, self.img_cols])
            plt.imshow(image, cmap='gray')
            plt.axis('off')
        plt.tight_layout()
        if save2file:
            plt.savefig(filename)
            plt.close('all')
        else:
            plt.show()

In [67]:
# Build the trainer, run the training loop, and report how long it took.
fksims_ae = FKSIMS_AE()
timer = ElapsedTimer()
fksims_ae.train(train_steps=10000, batch_size=200, save_interval=50)
timer.elapsed_time()

# plot_images takes the network inputs and matching targets explicitly and
# has no `fake` keyword, so pass a fixed preview slice of the training data.
preview_mixed = fksims_ae.train_mixed[:16]
preview_fgs = fksims_ae.train_fgs[:16]
fksims_ae.plot_images(preview_mixed, preview_fgs)
fksims_ae.plot_images(preview_mixed, preview_fgs, save2file=True)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_70 (Conv2D)           (None, 64, 64, 64)        640       
_________________________________________________________________
leaky_re_lu_70 (LeakyReLU)   (None, 64, 64, 64)        0         
_________________________________________________________________
dropout_70 (Dropout)         (None, 64, 64, 64)        0         
_________________________________________________________________
conv2d_71 (Conv2D)           (None, 32, 32, 128)       204928    
_________________________________________________________________
batch_normalization_88 (Batc (None, 32, 32, 128)       512       
_________________________________________________________________
leaky_re_lu_71 (LeakyReLU)   (None, 32, 32, 128)       0         
_________________________________________________________________
dropout_71 (Dropout)         (None, 32, 32, 128)       0         
__________

KeyboardInterrupt: 