In [None]:
#Import required dependencies
!mkdir images
!mkdir saved_model
from __future__ import print_function, division
from keras.datasets import cifar10
from keras.layers import Flatten, Dropout,BatchNormalization,Activation, Dense, Input,Reshape, Multiply, GaussianNoise, Embedding, ZeroPadding2D, MaxPooling2D
from tensorflow.keras.layers import LeakyReLU, PReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from keras.optimizers.legacy import Adam
from keras.models import Sequential, Model
from keras import losses
from keras.utils import to_categorical
import keras.backend as K


import matplotlib.pyplot as plt
import numpy as np


In [None]:
class contextEncoder():
  """Context-encoder GAN that inpaints a square hole in 32x32 RGB images.

  The generator maps a full-size masked image (hole pixels set to 0) to a
  ``mask_height x mask_width`` patch predicting the missing region. The
  discriminator classifies such patches as real (cut from the original image)
  or generated. The combined model trains the generator on a weighted sum of
  reconstruction MSE (weight 0.999) and adversarial binary cross-entropy
  (weight 0.001). Training data is the cat (label 3) and dog (label 5) subset
  of the CIFAR-10 training split, rescaled to [-1, 1].

  The attribute name ``descriminator`` (sic) is kept so existing callers and
  saved-model file names keep working.
  """

  def __init__(self):
    """Set image/mask geometry, then build and compile all three models."""
    self.img_rows = 32
    self.img_cols = 32
    self.mask_height = 8
    self.mask_width = 8
    self.channels = 3
    self.num_classes = 2  # not used by the models below
    self.img_shape = (self.img_rows, self.img_cols, self.channels)
    self.missing_image = (self.mask_height, self.mask_width, self.channels)

    # One optimizer per compiled model: Adam keeps per-model state (moments and
    # the iteration counter used for bias correction), which should not be shared.
    d_optimizer = Adam(0.0002, 0.5)
    g_optimizer = Adam(0.0002, 0.5)

    # Discriminator: compiled on its own so its train_on_batch() updates its weights.
    self.descriminator = self.build_descriminator()
    self.descriminator.compile(loss="binary_crossentropy", optimizer=d_optimizer, metrics=['accuracy'])

    # Generator: takes the masked full-size image and predicts the missing patch.
    self.generator = self.build_generator()

    masked_img = Input(shape=self.img_shape)
    gen_missing = self.generator(masked_img)

    # Freeze the discriminator inside the combined model so that
    # combined.train_on_batch() updates only the generator.
    # NOTE(review): whether this flag also affects the separately compiled
    # discriminator depends on the Keras version; if D_loss sits at ~0.693
    # (ln 2) for the whole run, confirm the discriminator weights really change.
    self.descriminator.trainable = False

    # Real-or-fake score for the generated patch.
    valid = self.descriminator(gen_missing)

    # Combined model: masked image -> (predicted patch, discriminator score).
    self.combined = Model(masked_img, [gen_missing, valid])
    self.combined.compile(loss=["mse", "binary_crossentropy"], loss_weights=[.999, .001], optimizer=g_optimizer)

  def build_generator(self):
    """Build the encoder-decoder generator.

    Four stride-2 convolutions shrink the 32x32 input to 2x2x512, and two
    upsampling stages grow it back to 8x8. The final 3-channel tanh convolution
    yields an ``(mask_height, mask_width, channels)`` patch in [-1, 1].

    Returns:
      A functional ``Model`` wrapping the Sequential stack (input: ``img_shape``).
    """
    model = Sequential()

    # Encoder
    model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding='same'))
    model.add(LeakyReLU(alpha=.2))
    model.add(BatchNormalization(momentum=.8))

    model.add(Conv2D(64, kernel_size=3, strides=2, padding='same'))
    model.add(LeakyReLU(alpha=.2))
    model.add(BatchNormalization(momentum=.8))

    model.add(Conv2D(128, kernel_size=3, strides=2, padding='same'))
    model.add(LeakyReLU(alpha=.2))
    model.add(BatchNormalization(momentum=.8))

    model.add(Conv2D(512, kernel_size=1, strides=2, padding='same'))
    model.add(LeakyReLU(alpha=.2))
    model.add(Dropout(.5))

    # Decoder (upsampling)
    model.add(UpSampling2D())
    model.add(Conv2D(128, kernel_size=3, padding='same'))
    model.add(Activation("relu"))
    model.add(BatchNormalization(momentum=.8))

    model.add(UpSampling2D())
    model.add(Conv2D(64, kernel_size=3, padding='same'))
    model.add(Activation("relu"))
    model.add(BatchNormalization(momentum=.8))

    model.add(Conv2D(self.channels, kernel_size=3, padding='same'))
    model.add(Activation('tanh'))

    model.summary()

    masked_img = Input(self.img_shape)
    gen_missing = model(masked_img)

    return Model(masked_img, gen_missing)

  def build_descriminator(self):
    """Build the patch discriminator.

    Three stride-2 convolutions reduce the ``missing_image``-shaped patch to
    1x1x256. The result is flattened and fed to a single sigmoid unit.

    Returns:
      A functional ``Model`` mapping a patch to a probability of being real.
    """
    model = Sequential()

    model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=self.missing_image, padding='same'))
    model.add(LeakyReLU(alpha=.2))
    model.add(BatchNormalization(momentum=.8))

    model.add(Conv2D(128, kernel_size=3, strides=2, padding='same'))
    model.add(LeakyReLU(alpha=.2))
    model.add(BatchNormalization(momentum=.8))

    model.add(Conv2D(256, kernel_size=1, strides=2, padding='same'))
    model.add(LeakyReLU(alpha=.2))
    model.add(BatchNormalization(momentum=.8))

    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))

    model.summary()

    img = Input(shape=self.missing_image)
    validity = model(img)

    return Model(img, validity)

  def mask_random(self, imgs):
    """Cut one randomly placed ``mask_height x mask_width`` hole out of every image.

    Args:
      imgs: array expected to be shaped (batch, img_rows, img_cols, channels).

    Returns:
      ``(masked_imgs, missing_parts, (y1, y2, x1, x2))`` where
        masked_imgs   -- copy of ``imgs`` (same dtype) with each hole set to 0;
        missing_parts -- float array (batch, mask_height, mask_width, channels)
                         holding the removed pixels;
        y1, y2, x1, x2 -- per-image hole bounds as half-open slices
                         (``y2 = y1 + mask_height``, ``x2 = x1 + mask_width``).
    """
    n = imgs.shape[0]

    # randint's high bound is exclusive; the +1 lets the hole touch the bottom edge.
    y1 = np.random.randint(0, self.img_rows - self.mask_height + 1, n)
    y2 = y1 + self.mask_height

    # The horizontal offset is bounded by the image width (img_cols), not its height.
    x1 = np.random.randint(0, self.img_cols - self.mask_width + 1, n)
    x2 = x1 + self.mask_width

    masked_imgs = np.empty_like(imgs)
    missing_parts = np.empty(shape=(n, self.mask_height, self.mask_width, self.channels))

    for i, img in enumerate(imgs):
      masked_img = img.copy()
      hole = (slice(y1[i], y2[i]), slice(x1[i], x2[i]))
      missing_parts[i] = masked_img[hole]  # assignment copies into missing_parts
      masked_img[hole] = 0
      masked_imgs[i] = masked_img

    return (masked_imgs, missing_parts, (y1, y2, x1, x2))

  def _load_training_images(self):
    """Return the CIFAR-10 training cats (label 3) then dogs (label 5), scaled to [-1, 1] as float32."""
    (X_train, y_train), (_, _) = cifar10.load_data()
    labels = y_train.flatten()
    X_pets = np.vstack((X_train[labels == 3], X_train[labels == 5]))
    # Match the generator's tanh output range; float32 halves memory versus float64.
    return X_pets.astype(np.float32) / 127.5 - 1.0

  def train(self, epochs, batch_size=128, sample_interval=50, save_epoch=None):
    """Alternate discriminator and generator updates for ``epochs`` iterations.

    Each iteration draws ``batch_size`` random training images (with replacement),
    masks them, trains the discriminator on real and generated patches, then
    trains the generator through the combined model.

    Args:
      epochs: number of training iterations (one mini-batch each).
      batch_size: images per mini-batch.
      sample_interval: write a sample grid to ``images/<epoch>.png`` every this
        many iterations (starting at iteration 0).
      save_epoch: iteration at which ``save_model()`` is called. Defaults to the
        final iteration (``epochs - 1``), so runs of any length are saved.
    """
    X_train = self._load_training_images()

    # Adversarial ground truth.
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    if save_epoch is None:
      save_epoch = epochs - 1

    for epoch in range(epochs):
      # ---- Discriminator ----
      idx = np.random.randint(0, X_train.shape[0], batch_size)
      imgs = X_train[idx]

      masked_imgs, missing_parts, _ = self.mask_random(imgs)
      gen_missing = self.generator.predict(masked_imgs, verbose=0)

      d_loss_real = self.descriminator.train_on_batch(missing_parts, valid)
      d_loss_fake = self.descriminator.train_on_batch(gen_missing, fake)
      d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)  # mean of [loss, acc]

      # ---- Generator (discriminator frozen inside `combined`) ----
      g_loss = self.combined.train_on_batch(masked_imgs, [missing_parts, valid])

      # g_loss = [weighted total, mse term, adversarial term]
      print("%d [D_loss: %f , acc= %.2f%%] [G_loss: %f, mse: %f ]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss[0], g_loss[1]))

      if epoch % sample_interval == 0:
        idx = np.random.randint(0, X_train.shape[0], 6)
        self.sample_images(epoch, X_train[idx])

      if epoch == save_epoch:
        self.save_model()

  def sample_images(self, epoch, imgs):
    """Save a 3-row grid to ``images/<epoch>.png``: originals, masked inputs, inpainted results.

    Args:
      epoch: iteration number used in the file name.
      imgs: images in [-1, 1]; one grid column is drawn per image.
    """
    import os

    masked_imgs, _, (y1, y2, x1, x2) = self.mask_random(imgs)
    gen_missing = self.generator.predict(masked_imgs, verbose=0)

    # Undo the [-1, 1] scaling for display.
    imgs = 0.5 * imgs + 0.5
    masked_imgs = 0.5 * masked_imgs + 0.5
    gen_missing = 0.5 * gen_missing + 0.5

    r, c = 3, imgs.shape[0]
    # squeeze=False keeps `axs` 2-D even when only one image is passed.
    fig, axs = plt.subplots(r, c, squeeze=False)
    for i in range(c):
      axs[0, i].imshow(imgs[i, :, :])
      axs[0, i].axis('off')
      axs[1, i].imshow(masked_imgs[i, :, :])
      axs[1, i].axis('off')
      filled_im = imgs[i].copy()
      filled_im[y1[i]:y2[i], x1[i]:x2[i], :] = gen_missing[i]
      axs[2, i].imshow(filled_im)
      axs[2, i].axis('off')

    os.makedirs('images', exist_ok=True)
    fig.savefig('images/%d.png' % epoch)
    plt.close(fig)

  def save_model(self):
    """Write generator and discriminator architectures (JSON) and weights (HDF5) to ``saved_model/``."""
    import os

    def save(model, model_name):
      model_path = "saved_model/%s.json" % model_name
      weights_path = "saved_model/%s_weights.hdf5" % model_name
      # `with` guarantees the JSON file is flushed and closed.
      with open(model_path, 'w') as fh:
        fh.write(model.to_json())
      model.save_weights(weights_path)

    os.makedirs("saved_model", exist_ok=True)
    save(self.generator, "generator")
    save(self.descriminator, "descriminator")









In [None]:
if __name__ == "__main__":
    # Build generator/discriminator/combined models, then train for 30k steps,
    # writing a sample grid to images/ every 1000 steps.
    context_encoder = contextEncoder()
    context_encoder.train(
        epochs=30000,
        batch_size=128,
        sample_interval=1000,
    )


Streaming output truncated to the last 5000 lines.
16843 [D_loss: 0.693270 , acc= 50.000000%] [G_loss: 0.093007, mse: 0.092406 ]
16844 [D_loss: 0.693269 , acc= 41.015625%] [G_loss: 0.089595, mse: 0.088991 ]
16845 [D_loss: 0.693269 , acc= 35.546875%] [G_loss: 0.090116, mse: 0.089512 ]
16846 [D_loss: 0.693269 , acc= 46.875000%] [G_loss: 0.091609, mse: 0.091007 ]
16847 [D_loss: 0.693269 , acc= 36.718750%] [G_loss: 0.091305, mse: 0.090702 ]
16848 [D_loss: 0.693269 , acc= 33.984375%] [G_loss: 0.100971, mse: 0.100378 ]
16849 [D_loss: 0.693269 , acc= 48.437500%] [G_loss: 0.092236, mse: 0.091635 ]
16850 [D_loss: 0.693269 , acc= 35.937500%] [G_loss: 0.089441, mse: 0.088837 ]
16851 [D_loss: 0.693269 , acc= 47.265625%] [G_loss: 0.091809, mse: 0.091207 ]
16852 [D_loss: 0.693269 , acc= 43.750000%] [G_loss: 0.095793, mse: 0.095195 ]
16853 [D_loss: 0.693269 , acc= 35.937500%] [G_loss: 0.090213, mse: 0.089609 ]
16854 [D_loss: 0.693269 , acc= 42.578125%] [G_loss: 0.098637, mse: 0.098042 ]