In [None]:
#!pip install Tensorflow
#!pip install Keras


In [None]:
import scipy
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import cifar10
from PIL import Image


class DataLoader():
    """Serve random (high-res, low-res) CIFAR-10 image pairs for SRGAN training.

    Each sampled image is bicubically resized to ``img_res`` for the
    high-resolution target and to a quarter of that size for the
    low-resolution input, then scaled from [0, 255] to [-1, 1].
    """

    def __init__(self, dataset_name, img_res=(128, 128)):
        """
        Args:
            dataset_name: label for the dataset; only stored, not used for loading.
            img_res: (height, width) of the high-resolution images.
        """
        self.dataset_name = dataset_name
        self.img_res = img_res
        # Cached (train_images, test_images) so CIFAR-10 is read once instead
        # of on every batch.
        self._splits = None

    def _images(self, is_testing):
        """Return the image array of the requested CIFAR-10 split (loaded lazily)."""
        if self._splits is None:
            (x_train, _), (x_test, _) = cifar10.load_data()
            self._splits = (np.asarray(x_train), np.asarray(x_test))
        return self._splits[1] if is_testing else self._splits[0]

    def load_data(self, batch_size=1, is_testing=False):
        """Sample a random batch of aligned high-/low-resolution image pairs.

        Args:
            batch_size: number of pairs to draw (indices drawn with replacement).
            is_testing: if True, sample from the CIFAR-10 test split and skip
                augmentation; otherwise sample from the training split and flip
                each pair horizontally with probability 0.5.

        Returns:
            (imgs_hr, imgs_lr): float arrays scaled to [-1, 1]; the high-res
            images are img_res-sized and the low-res ones a quarter of that.
        """
        images = self._images(is_testing)
        h, w = self.img_res
        low_h, low_w = int(h / 4), int(w / 4)

        batch_indices = np.random.choice(images.shape[0], size=batch_size)
        imgs_hr = []
        imgs_lr = []

        for img_index in batch_indices:
            img = Image.fromarray(images[img_index])
            # PIL's resize takes (width, height), whereas img_res is (height, width).
            img_hr = np.array(img.resize((w, h), Image.BICUBIC))
            img_lr = np.array(img.resize((low_w, low_h), Image.BICUBIC))

            if not is_testing and np.random.random() < 0.5:
                img_hr = np.fliplr(img_hr)
                img_lr = np.fliplr(img_lr)

            imgs_hr.append(img_hr)
            imgs_lr.append(img_lr)

        imgs_hr = np.array(imgs_hr) / 127.5 - 1
        imgs_lr = np.array(imgs_lr) / 127.5 - 1

        return imgs_hr, imgs_lr





In [None]:
from __future__ import print_function, division
import scipy
from keras.layers import BatchNormalization, Input, Dense, Reshape, Flatten, Dropout, Concatenate, Activation, ZeroPadding2D, Add, Conv2D, UpSampling2D
from tensorflow.keras.applications import VGG19
from keras.models import Sequential, Model
from tensorflow.keras.layers import LeakyReLU, PReLU
from keras.optimizers.legacy import Adam
import datetime
import matplotlib.pyplot as plt
import numpy as np
import sys
import os


import keras.backend as k


In [None]:
class SRGAN():
    """Super-resolution GAN trained on upscaled CIFAR-10 images.

    Builds a residual generator mapping ``lr_shape`` images to ``hr_shape``
    images (4x upscaling via two 2x upsampling stages), a convolutional
    discriminator that outputs a ``disc_patch`` grid of validity scores, and a
    combined model that trains the generator on an adversarial loss plus a
    VGG19 feature (perceptual) MSE loss.
    """

    def __init__(self):
        self.channels = 3
        self.lr_height = 64
        self.lr_width = 64
        self.lr_shape = (self.lr_height, self.lr_width, self.channels)

        # High-resolution images are 4x the low-resolution size.
        self.hr_height = self.lr_height * 4
        self.hr_width = self.lr_width * 4
        self.hr_shape = (self.hr_height, self.hr_width, self.channels)

        # Number of residual blocks in the generator.
        self.n_residual = 16

        # Frozen pretrained VGG19 that extracts features from real and generated
        # high-res images; the generator minimises the MSE between them. It is
        # never trained directly, so it is not compiled.
        self.vgg = self.build_vgg()
        self.vgg.trainable = False

        self.dataset_name = "cifar10_dataset"
        self.dataloader = DataLoader(dataset_name=self.dataset_name,
                                     img_res=(self.hr_height, self.hr_width))

        # The discriminator has four stride-2 convolutions, so its output is a
        # (H/16, W/16, 1) grid of per-patch validity scores.
        patch_h = int(self.hr_height / 2 ** 4)
        patch_w = int(self.hr_width / 2 ** 4)
        self.disc_patch = (patch_h, patch_w, 1)

        # Base number of filters in the generator and discriminator.
        self.gf = 64
        self.df = 64

        # Each compiled model gets its own optimizer so Adam's moment estimates
        # and step counter are not shared between discriminator and generator updates.
        self.descriminator = self.build_descriminator()
        self.descriminator.compile(loss='mse', optimizer=self._make_optimizer(),
                                   metrics=['accuracy'])

        self.generator = self.build_generator()

        img_hr = Input(shape=self.hr_shape)
        img_lr = Input(shape=self.lr_shape)

        fake_hr = self.generator(img_lr)
        fake_features = self.vgg(fake_hr)

        # Freeze the discriminator inside the combined model so only the
        # generator's weights are updated when training it.
        self.descriminator.trainable = False
        validity = self.descriminator(fake_hr)

        # NOTE(review): img_hr is not connected to any output; it is kept as an
        # input only so train() can pass [imgs_lr, imgs_hr].
        self.combined = Model([img_lr, img_hr], [validity, fake_features])
        self.combined.compile(loss=['binary_crossentropy', 'mse'],
                              loss_weights=[1e-3, 1],
                              optimizer=self._make_optimizer())

    @staticmethod
    def _make_optimizer():
        """Return a fresh Adam optimizer (learning rate 2e-4, beta_1 0.5)."""
        return Adam(0.0002, 0.5)

    def build_vgg(self):
        """Return an ImageNet VGG19 (no top) truncated at its block3_conv3 output.

        NOTE(review): ImageNet VGG19 weights expect inputs prepared with
        vgg19.preprocess_input; here images in [-1, 1] are fed directly —
        confirm this is intended.
        """
        vgg = VGG19(weights="imagenet", include_top=False, input_shape=self.hr_shape)
        # block3_conv3 is layer index 9 of the headless VGG19.
        return Model(inputs=vgg.input, outputs=vgg.get_layer("block3_conv3").output)

    def build_generator(self):
        """Build the generator: lr_shape image -> hr_shape image in [-1, 1] (tanh).

        Layout: 9x9 conv + ReLU, ``n_residual`` residual blocks, a 3x3 conv +
        BatchNorm with a long skip connection, two 2x upsampling stages and a
        9x9 tanh output conv.
        """
        def residual_block(layer_input, filters):
            """Conv-ReLU-BN-Conv with an identity skip connection."""
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
            d = Activation('relu')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
            return Add()([d, layer_input])

        def deconv2D(layer_input):
            """2x UpSampling2D followed by a 256-filter 3x3 conv + ReLU."""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
            return Activation('relu')(u)

        img_lr = Input(shape=self.lr_shape)

        # Pre-residual conv; uses self.gf so the residual Add()s always match.
        c1 = Conv2D(self.gf, kernel_size=9, strides=1, padding='same')(img_lr)
        c1 = Activation('relu')(c1)

        r = residual_block(c1, self.gf)
        for _ in range(self.n_residual - 1):
            r = residual_block(r, self.gf)

        # Post-residual conv with a long skip back to c1.
        c2 = Conv2D(self.gf, kernel_size=3, strides=1, padding='same')(r)
        c2 = BatchNormalization(momentum=0.8)(c2)
        c2 = Add()([c2, c1])

        u1 = deconv2D(c2)
        u2 = deconv2D(u1)

        gen_hr = Conv2D(self.channels, kernel_size=9, strides=1, padding='same',
                        activation='tanh')(u2)
        return Model(img_lr, gen_hr)

    def build_descriminator(self):
        """Build the discriminator: hr_shape image -> disc_patch grid of sigmoid scores."""
        def d_block(layer_input, filters, strides=1, bn=True):
            """3x3 conv + LeakyReLU(0.2), optionally followed by BatchNorm."""
            d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        d0 = Input(shape=self.hr_shape)

        d1 = d_block(d0, self.df, bn=False)
        d2 = d_block(d1, self.df, strides=2)
        d3 = d_block(d2, self.df * 2)
        d4 = d_block(d3, self.df * 2, strides=2)
        d5 = d_block(d4, self.df * 4)
        d6 = d_block(d5, self.df * 4, strides=2)
        d7 = d_block(d6, self.df * 8)
        d8 = d_block(d7, self.df * 8, strides=2)

        # Dense layers act on the channel axis, preserving the spatial patch grid.
        d9 = Dense(self.df * 16)(d8)
        d10 = LeakyReLU(alpha=0.2)(d9)
        validity = Dense(1, activation='sigmoid')(d10)

        return Model(d0, validity)

    def train(self, epochs, batch_size=1, sample_intervals=50):
        """Alternate one discriminator update and one generator update per iteration.

        Args:
            epochs: number of training iterations. Each iteration draws one fresh
                batch per model; it is not a full pass over the dataset.
            batch_size: images per batch.
            sample_intervals: save sample images every this many iterations
                (0 disables sampling).
        """
        start_time = datetime.datetime.now()

        for epoch in range(epochs):
            # Discriminator: real HR images -> 1, generated HR images -> 0.
            imgs_hr, imgs_lr = self.dataloader.load_data(batch_size)
            fake_hr = self.generator.predict(imgs_lr)

            valid = np.ones((batch_size,) + self.disc_patch)
            fake = np.zeros((batch_size,) + self.disc_patch)

            d_loss_real = self.descriminator.train_on_batch(imgs_hr, valid)
            d_loss_fake = self.descriminator.train_on_batch(fake_hr, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # Generator: fool the (frozen) discriminator and match VGG features.
            imgs_hr, imgs_lr = self.dataloader.load_data(batch_size)
            valid = np.ones((batch_size,) + self.disc_patch)
            image_features = self.vgg.predict(imgs_hr)
            g_loss = self.combined.train_on_batch([imgs_lr, imgs_hr],
                                                  [valid, image_features])

            elapsed_time = datetime.datetime.now() - start_time
            # train_on_batch may return a scalar or a list whose first entry is
            # the total loss; ravel handles both.
            print("%d time: %s [D loss: %f] [G loss: %f]"
                  % (epoch, elapsed_time,
                     float(np.ravel(d_loss)[0]), float(np.ravel(g_loss)[0])))

            if sample_intervals and epoch % sample_intervals == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Save generated-vs-original HR images and the LR inputs for one test batch.

        Writes images/<dataset_name>/<epoch>.png (2x2 grid) and
        images/<dataset_name>/<epoch>_lowres<i>.png for each low-res input.
        """
        os.makedirs("images/%s" % self.dataset_name, exist_ok=True)

        r, c = 2, 2
        imgs_hr, imgs_lr = self.dataloader.load_data(batch_size=r, is_testing=True)
        fake_hr = self.generator.predict(imgs_lr)

        # Rescale from [-1, 1] to [0, 1] for imshow.
        imgs_lr = 0.5 * imgs_lr + 0.5
        fake_hr = 0.5 * fake_hr + 0.5
        imgs_hr = 0.5 * imgs_hr + 0.5

        titles = ['Generated', 'Original']
        fig, axs = plt.subplots(r, c)
        for row in range(r):
            for col, image in enumerate([fake_hr, imgs_hr]):
                axs[row, col].imshow(image[row])
                axs[row, col].set_title(titles[col])
                axs[row, col].axis("off")
        fig.savefig('images/%s/%d.png' % (self.dataset_name, epoch))
        plt.close(fig)

        for i in range(r):
            fig = plt.figure()
            plt.imshow(imgs_lr[i])
            fig.savefig('images/%s/%d_lowres%d.png' % (self.dataset_name, epoch, i))
            plt.close(fig)

    def save_model(self):
        """Save generator and discriminator architectures (JSON) and weights (HDF5).

        Files go under saved_model/, which is created if it does not exist.
        """
        os.makedirs("saved_model", exist_ok=True)

        def save(model, model_name):
            model_path = "saved_model/%s.json" % model_name
            weights_path = "saved_model/%s_weights.hdf5" % model_name
            with open(model_path, 'w') as arch_file:
                arch_file.write(model.to_json())
            model.save_weights(weights_path)

        save(self.generator, "generator")
        save(self.descriminator, "discriminator")













In [None]:
if __name__ == "__main__":
    gan = SRGAN()
    gan.train(epochs=100, batch_size=1, sample_intervals=10)
    # Pure-Python replacement for `!mkdir saved_model`, which errors when the
    # directory already exists (e.g. on a re-run) and is not portable.
    os.makedirs("saved_model", exist_ok=True)
    gan.save_model()

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
0 time: 0:00:30.098856
1 time: 0:00:36.099395
2 time: 0:00:38.406297
3 time: 0:00:40.457463
4 time: 0:00:42.344154
5 time: 0:00:44.965628
6 time: 0:00:46.965499
7 time: 0:00:48.929651
8 time: 0:00:50.782159
9 time: 0:00:52.737100
10 time: 0:00:54.619690
11 time: 0:00:58.926984
12 time: 0:01:00.793631
13 time: 0:01:02.670620
14 time: 0:01:04.594272
15 time: 0:01:06.534547
16 time: 0:01:09.113640
17 time: 0:01:11.293706
18 time: 0:01:13.330051
19 time: 0:01:15.286194
20 time: 0:01:17.197599
21 time: 0:01:20.959230
22 time: 0:01:23.316634
23 time: 0:01:26.427262
24 time: 0:01:28.374732
25 time: 0:01:30.315283
26 time: 0:01:32.409836
27 time: 0:01:34.979072
28 time: 0:01:36.949657
29 time: 0:01:38.910986
30 time: 0:01:40.788567
31 time: 0:01:44.330184
32 time: 0:01:47.083830
3