In [None]:
from __future__ import print_function, division

import os

import tensorflow as tf

from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, GaussianNoise
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers import MaxPooling2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras import losses
from keras.utils import to_categorical
import keras.backend as K

from torchvision.datasets import STL10

import matplotlib.pyplot as plt

import numpy as np

from PIL import Image

class ContextEncoder():
    """Context-encoder GAN that inpaints a regular grid of hidden image tiles.

    Each image is tiled with ``2n x 2n`` squares and the bottom-right ``n x n``
    quadrant of every square is hidden (see ``_grid_mask``). For the 96x96
    images used here the hidden pixels always add up to one 48x48 "missing"
    image: the generator predicts it from the masked input and the
    discriminator judges whether a 48x48 patch is real or generated.
    """

    # Training-set indices rendered by ``sample_images`` at the end of ``train``.
    SAMPLE_INDICES = (5, 8, 82, 93, 115, 116, 121, 135, 153, 166,
                      183, 244, 249, 305, 343, 375, 419, 428, 512, 639,
                      691, 745, 746, 825, 832)

    def __init__(self, mask_name, base_dir="drive/My Drive/ContextEncoder", num_train=10000):
        """Load the training images and build/compile the Keras models.

        Args:
            mask_name: sub-directory of ``base_dir`` that receives this run's
                metrics file and sample images.
            base_dir: root directory for the dataset cache and all outputs.
            num_train: number of STL10 'unlabeled' images kept for training.
        """
        print("Model initializing..")
        self.img_rows = 96
        self.img_cols = 96
        self.mask_height = 48
        self.mask_width = 48
        self.channels = 3
        self.mask_name = mask_name
        self.base_dir = base_dir
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.missing_shape = (self.mask_height, self.mask_width, self.channels)

        # download=True fetches the archive if it is not already present under root.
        stl10 = STL10(root=os.path.join(base_dir, 'dataset'), split='unlabeled', download=True)
        # Move axis 1 to the end: channels-first -> channels-last, matching img_shape.
        self.X_train = np.moveaxis(stl10.data[:num_train], 1, 3)

        self.optimizer = Adam(0.0002, 0.5)

        # The discriminator is compiled on its own so it can be trained directly
        # on real / generated 48x48 patches.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])

        # The generator maps a masked full-size image to the hidden patch.
        self.generator = self.build_generator()
        masked_img = Input(shape=self.img_shape)
        gen_missing = self.generator(masked_img)

        # Freeze the discriminator for the combined model so that training the
        # combined model updates only the generator. This presumably relies on
        # Keras capturing `trainable` at compile time: the standalone
        # discriminator was compiled above, before the flag was cleared.
        self.discriminator.trainable = False
        valid = self.discriminator(gen_missing)

        # Combined model: reconstruction (MSE) loss plus a lightly weighted
        # adversarial term that rewards fooling the discriminator.
        self.combined = Model(masked_img, [gen_missing, valid])
        self.combined.compile(loss=['mse', 'binary_crossentropy'],
                              loss_weights=[0.999, 0.001],
                              optimizer=self.optimizer)

    def build_generator(self):
        """Build the encoder-decoder generator.

        Returns:
            A Keras ``Model`` taking an image of shape ``img_shape``. With
            "same" padding, the four stride-2 convolutions take 96 -> 6 and the
            three 2x upsamplings take 6 -> 48. The output is therefore
            48x48x``channels`` (``missing_shape``), squashed to [-1, 1] by tanh.
        """
        model = Sequential()

        # Encoder: 96 -> 48 -> 24 -> 12
        model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        # Bottleneck: 12 -> 6, 512 channels
        model.add(Conv2D(512, kernel_size=1, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.5))

        # Decoder: 6 -> 12 -> 24 -> 48
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(Activation('relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding="same"))
        model.add(Activation('relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(32, kernel_size=3, padding="same"))
        model.add(Activation('relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
        model.add(Activation('tanh'))

        model.summary()

        masked_img = Input(shape=self.img_shape)
        gen_missing = model(masked_img)
        # inputs: masked images; outputs: generated missing patches
        return Model(masked_img, gen_missing)

    def build_discriminator(self):
        """Build the patch discriminator.

        Returns:
            A Keras ``Model`` mapping a ``missing_shape`` patch to a single
            sigmoid probability that the patch is real.
        """
        model = Sequential()

        model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=self.missing_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(256, kernel_size=3, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Flatten())
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.missing_shape)
        validity = model(img)

        return Model(img, validity)

    def _grid_mask(self, n):
        """Return the (img_rows, img_cols, channels) float 0/1 keep-mask for tile size ``n``.

        The image is tiled with ``2n x 2n`` squares. In each square the
        bottom-right ``n x n`` quadrant is 0 (hidden) and the rest is 1 (kept).

        Raises:
            ValueError: if ``2 * n`` does not divide both image dimensions
                (the tiling would not cover the image).
        """
        if n < 1 or self.img_rows % (2 * n) or self.img_cols % (2 * n):
            raise ValueError("mask_n=%r: 2*mask_n must divide the %dx%d image size"
                             % (n, self.img_rows, self.img_cols))
        tile = np.ones((2 * n, 2 * n, self.channels))
        tile[n:, n:] = 0
        reps = (self.img_rows // (2 * n), self.img_cols // (2 * n), 1)
        return np.tile(tile, reps)

    def maskn(self, imgs, n):
        """Hide the ``n x n`` tile grid and return the hidden pixels packed densely.

        Args:
            imgs: array-like batch of shape (num, img_rows, img_cols, channels);
                a list of images is accepted.
            n: tile size; ``2 * n`` must divide both image dimensions.

        Returns:
            (masked_imgs, missing_parts):
            - masked_imgs has the input's shape and dtype, with hidden pixels
              set to 0.
            - missing_parts is a float64 array of shape
              (num, mask_height, mask_width, channels). It holds the hidden
              pixels in row-major order, which is the layout ``unmask`` expects.
        """
        imgs = np.asarray(imgs)
        mask = self._grid_mask(n)
        masked_imgs = (imgs * mask).astype(imgs.dtype, copy=False)
        # Boolean index over (row, col) keeps the channel axis and visits hidden
        # pixels in row-major order; the hidden area is exactly 48x48 pixels.
        hole = mask[..., 0] == 0
        missing_parts = imgs[:, hole].reshape(
            (len(imgs), self.mask_height, self.mask_width, self.channels)).astype(np.float64)
        return masked_imgs, missing_parts

    def total_mask(self, imgs, n):
        """Split each image into its kept part and its hidden part, both at full size.

        Args:
            imgs: array-like batch of shape (num, img_rows, img_cols, channels).
            n: tile size; ``2 * n`` must divide both image dimensions.

        Returns:
            (masked_imgs, missing_parts), both of the input's spatial shape.
            - masked_imgs keeps the input dtype and has hidden pixels zeroed.
            - missing_parts (float64) has the kept pixels zeroed.
            ``masked_imgs + missing_parts`` reconstructs the input.
        """
        imgs = np.asarray(imgs)
        mask = self._grid_mask(n)
        masked_imgs = (imgs * mask).astype(imgs.dtype, copy=False)
        missing_parts = imgs * (1 - mask)
        return masked_imgs, missing_parts

    def unmask(self, masked, missing, mask_n):
        """Write a packed missing patch back into its hidden tile positions.

        Args:
            masked: full-size image (img_rows, img_cols, channels); modified in place.
            missing: (mask_height, mask_width, channels) patch laid out like
                ``maskn``'s ``missing_parts``.
            mask_n: the tile size that produced ``missing``.

        Returns:
            ``masked`` (the same array object), with the hidden pixels filled in.
        """
        hole = self._grid_mask(mask_n)[..., 0] == 0
        masked[hole] = np.reshape(missing, (-1, self.channels))
        return masked

    def total_unmask(self, masked, missing, mask_n):
        """Combine the kept pixels of ``masked`` with the hidden pixels of ``missing``.

        Both inputs are full-size images (or batches of them); the tile mask
        is broadcast over any leading batch axis.

        Returns:
            A new array ``masked * keep + missing * (1 - keep)``.
        """
        mask = self._grid_mask(mask_n)
        return masked * mask + missing * (1 - mask)

    def train(self, epochs, batch_size=128, sample_interval=50, mask_n=1):
        """Alternately train discriminator and generator for ``epochs`` batches.

        Each "epoch" is a single random batch, not a pass over the dataset.

        Args:
            epochs: number of training iterations.
            batch_size: images per iteration.
            sample_interval: losses are recorded every ``sample_interval`` iterations.
            mask_n: tile size used by ``maskn``.

        Returns:
            (d_losses, g_losses): lists with the recorded discriminator and
            combined-generator losses.

        Side effects:
            Overwrites ``<base_dir>/<mask_name>/metrics.txt`` every 1000
            iterations. At the end, saves ``SAMPLE_INDICES`` inpainting
            samples via ``sample_images``.
        """
        # Rescale [0, 255] -> [-1, 1] to match the generator's tanh output.
        # float32 halves the memory of the float64 copy plain division would make.
        train_set = self.X_train.astype(np.float32) / 127.5 - 1

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        d_losses = []
        g_losses = []

        # Create the output directory before spending any time on training.
        metrics_path = os.path.join(self.base_dir, self.mask_name, "metrics.txt")
        os.makedirs(os.path.dirname(metrics_path), exist_ok=True)

        for epoch in range(epochs):
            # ---------------------
            #  Train Discriminator
            # ---------------------
            idx = np.random.randint(0, train_set.shape[0], batch_size)
            imgs = train_set[idx]

            masked_imgs, missing_parts = self.maskn(imgs, mask_n)
            gen_missing = self.generator.predict(masked_imgs)

            d_loss_real = self.discriminator.train_on_batch(missing_parts, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_missing, fake)
            # [loss, accuracy] averaged over the real and fake halves
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------
            # g_loss = [weighted total, mse, binary_crossentropy]
            g_loss = self.combined.train_on_batch(masked_imgs, [missing_parts, valid])

            if epoch % 100 == 0:
                print("%d [D loss: %f, acc: %.2f%%] [G loss: %f, mse: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss[0], g_loss[1]))

            if epoch % 1000 == 0:
                # Snapshot the latest losses (file is overwritten each time).
                with open(metrics_path, "w") as metrics_file:
                    metrics_file.write("[Discriminative loss: %f]\n[Generative loss: %f]" % (d_loss[0], g_loss[0]))

            if epoch % sample_interval == 0:
                d_losses.append(d_loss[0])
                g_losses.append(g_loss[0])

        for i in self.SAMPLE_INDICES:
            self.sample_images(train_set, mask_n=mask_n, col=i)
        return d_losses, g_losses

    def sample_images(self, imgs, mask_n, col=0):
        """Inpaint ``imgs[col]`` with the generator and save it as a PNG.

        ``imgs`` is expected in [-1, 1]. The result is written to
        ``<base_dir>/<mask_name>/images<mask_n>/<col>.png``.
        """
        masked_imgs, _ = self.maskn([imgs[col]], mask_n)
        gen_missing = self.generator.predict(masked_imgs)

        # Map [-1, 1] -> [0, 1] for imshow; this builds a new array, so the
        # in-place fill below never touches ``imgs``.
        filled_in = 0.5 * imgs[col] + 0.5
        gen_missing = 0.5 * gen_missing + 0.5
        filled_in = self.unmask(filled_in, gen_missing[0], mask_n=mask_n)

        out_path = os.path.join(self.base_dir, self.mask_name, "images%d" % mask_n, "%d.png" % col)
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

        fig = plt.figure()
        plt.axis("off")
        plt.imshow(filled_in)
        plt.savefig(out_path, bbox_inches='tight', pad_inches=0)
        plt.close(fig)

    def save_masks(self, upto):
        """Save the mask pattern for tile sizes 1, 2, 4, ... below ``upto``.

        Each pattern is a masked all-ones image, written to
        ``<base_dir>/masks/mask_<n>.png``.
        """
        out_dir = os.path.join(self.base_dir, "masks")
        os.makedirs(out_dir, exist_ok=True)
        blank = np.ones((1, self.img_rows, self.img_cols, self.channels))
        mask_n = 1
        while mask_n < upto:
            masked_imgs, _ = self.maskn(blank, mask_n)
            fig = plt.figure()
            plt.imshow(masked_imgs[0])
            plt.axis("off")
            plt.savefig(os.path.join(out_dir, "mask_%d.png" % mask_n), bbox_inches='tight', pad_inches=0)
            plt.close(fig)
            mask_n *= 2

    def save_original(self, img_list):
        """Save the unmodified training images at the given indices.

        Files are written to ``<base_dir>/original/<i>.png``.
        """
        out_dir = os.path.join(self.base_dir, "original")
        os.makedirs(out_dir, exist_ok=True)
        for i in img_list:
            fig = plt.figure()
            plt.axis('off')
            plt.imshow(self.X_train[i])
            plt.savefig(os.path.join(out_dir, "%d.png" % i), bbox_inches='tight', pad_inches=0)
            plt.close(fig)


In [None]:
# Root of all outputs written by this driver and by ContextEncoder.
_RESULTS_DIR = "drive/My Drive/ContextEncoder"


def _save_loss_plot(losses, sample_interval, label, path):
    """Plot one loss curve against its (cumulative) epoch index and save it to ``path``."""
    fig = plt.figure()
    # One recorded loss per `sample_interval` epochs, counted across all phases.
    epochs_axis = np.arange(len(losses)) * sample_interval
    plt.plot(epochs_axis, losses, label=label)
    plt.legend()
    plt.title('Losses')
    plt.xlabel("Epochs")
    plt.ylabel('Loss Value')
    plt.savefig(path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)


def model(batch_size=64, sample_interval=1, train_list=((1, 15001),), mask_name="1x1"):
    """Train one ContextEncoder through a curriculum of tile sizes and plot its losses.

    Args:
        batch_size: images per training iteration.
        sample_interval: losses are recorded every ``sample_interval`` iterations.
        train_list: iterable of ``(mask_n, epochs)`` phases, run in order on the
            same model.
        mask_name: output sub-directory (under ``_RESULTS_DIR``) for this run.

    Side effects:
        Writes ``Losses_gen.png`` and ``Losses_disc.png`` (plus whatever
        ``ContextEncoder.train`` writes) under ``<_RESULTS_DIR>/<mask_name>``.
    """
    # Materialize once: the schedule is iterated twice below.
    train_list = list(train_list)

    # Create every directory the run writes to up front, so a missing folder
    # cannot abort the run after training has already happened.
    out_dir = os.path.join(_RESULTS_DIR, mask_name)
    for n, _ in train_list:
        os.makedirs(os.path.join(out_dir, "images%d" % n), exist_ok=True)
    os.makedirs(out_dir, exist_ok=True)

    context_encoder = ContextEncoder(mask_name)

    d_history = []
    g_history = []
    for n, epochs in train_list:
        d_partial, g_partial = context_encoder.train(epochs=epochs, batch_size=batch_size, sample_interval=sample_interval, mask_n=n)
        d_history.extend(d_partial)
        g_history.extend(g_partial)

    d_losses = np.asarray(d_history, dtype=float)
    g_losses = np.asarray(g_history, dtype=float)

    _save_loss_plot(g_losses, sample_interval, 'generative loss', os.path.join(out_dir, "Losses_gen.png"))
    _save_loss_plot(d_losses, sample_interval, 'discriminative loss', os.path.join(out_dir, "Losses_disc.png"))

In [None]:
# Curriculum runs: each run repeats the previous run's schedule and appends the
# next (doubled) tile size; every phase lasts 1001 iterations. The run name
# lists the tile sizes, e.g. "1x1 + 2x2".
for mask_sizes in ([1], [1, 2], [1, 2, 4], [1, 2, 4, 8]):
    model(
        train_list=[(size, 1001) for size in mask_sizes],
        mask_name=" + ".join("%dx%d" % (size, size) for size in mask_sizes),
    )