In [0]:
# Mount Google Drive at /content/drive (Colab only).
# NOTE: `drive` here is the google.colab.drive module; a later cell rebinds the
# name `drive` to a PyDrive GoogleDrive client, which is what the Drive lookup
# helpers (drive.ListFile / drive.CreateFile) use.
from google.colab import drive
drive.mount('/content/drive')

In [0]:
!pip install pydrive
import pydrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# 1. Authenticate and create the PyDrive client.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

In [0]:
## To import Training Data
# Drive folder that holds the training arrays.
TRAIN_FOLDER_ID = '1vsl4F2NQoUN5ov0vee39mLnALX01DKU6'

def get_file_from_drive(file_name, folder_id=TRAIN_FOLDER_ID):
  """Return the Drive id of the file titled `file_name` inside folder `folder_id`.

  Lists the non-trashed children of the folder through the PyDrive client
  `drive` and returns the 'id' of the first child whose 'title' equals
  `file_name`.

  Args:
    file_name: exact title of the file on Drive.
    folder_id: Drive id of the folder to search; defaults to the
      training-data folder.

  Returns:
    The matching file's Drive id (its 'id' field).

  Raises:
    FileNotFoundError: if no non-trashed file with that title is in the
      folder. Raising here gives a clearer error than handing None on to
      drive.CreateFile.

  NOTE: the following cell in this notebook defines a function with the same
  name, which replaces this one once that cell has run.
  """
  query = "'%s' in parents and trashed=false" % folder_id
  for file in drive.ListFile({'q': query}).GetList():
    if file['title'] == file_name:
      return file['id']
  raise FileNotFoundError("'%s' not found in Drive folder %s" % (file_name, folder_id))

In [0]:
## To import Validation Data
# This definition replaces any earlier get_file_from_drive, and every later
# download (X_train.npy, X_test.npy and the model checkpoint) goes through it.
# It therefore searches all the project's Drive folders. The validation folder
# is searched first, so validation files resolve exactly as before.
VALIDATION_FOLDER_ID = '17nYS9E1qRCrJBpS00KMmCHRyQwK84zc1'
TRAIN_FOLDER_ID = '1vsl4F2NQoUN5ov0vee39mLnALX01DKU6'
MODELS_FOLDER_ID = '1K_POvFkuMLSVN4Qrcpr7dMfP2syslRdH'
DATA_FOLDER_IDS = (VALIDATION_FOLDER_ID, TRAIN_FOLDER_ID, MODELS_FOLDER_ID)

def get_file_from_drive(file_name, folder_ids=DATA_FOLDER_IDS):
  """Return the Drive id of the file titled `file_name`, searching several folders.

  For each folder in `folder_ids` (in order), lists its non-trashed children
  through the PyDrive client `drive`. Returns the 'id' of the first child
  whose 'title' equals `file_name`.

  Args:
    file_name: exact title of the file on Drive.
    folder_ids: a single Drive folder id, or an ordered sequence of them.
      Defaults to the validation, training and model folders.

  Returns:
    The matching file's Drive id (its 'id' field).

  Raises:
    FileNotFoundError: if no searched folder contains a non-trashed file with
      that title.
  """
  if isinstance(folder_ids, str):
    folder_ids = (folder_ids,)
  for folder_id in folder_ids:
    query = "'%s' in parents and trashed=false" % folder_id
    for file in drive.ListFile({'q': query}).GetList():
      if file['title'] == file_name:
        return file['id']
  raise FileNotFoundError("'%s' not found in Drive folders %s" % (file_name, ', '.join(folder_ids)))

In [0]:
# ## To import Models
# def get_file_from_drive(file_name):
#   file_list = drive.ListFile({'q': "'1K_POvFkuMLSVN4Qrcpr7dMfP2syslRdH' in parents and trashed=false"}).GetList()
#   for file in file_list:
#     if(file['title'] == file_name):
#       return file['id']

In [0]:
def upload_data_system(file_name):
  """Download `file_name` from Google Drive into the current working directory.

  Despite the name, nothing is uploaded. The file's Drive id is resolved with
  get_file_from_drive, and the content is written to a local file with the
  same name.

  Args:
    file_name: exact title of the file on Drive; also used as the local path.

  Raises:
    FileNotFoundError: if get_file_from_drive finds no matching file.
  """
  file_id = get_file_from_drive(file_name)
  # get_file_from_drive may return None when nothing matches. Fail clearly here
  # instead of passing a None id to CreateFile, which only breaks later.
  if file_id is None:
    raise FileNotFoundError("'%s' was not found on Google Drive" % file_name)
  downloaded = drive.CreateFile({'id': file_id})
  downloaded.GetContentFile(file_name)

In [0]:
## Importing Training Data

file_name = 'X_train.npy'

# Fetch the clean training images from Drive into the working directory.
for message in ("Importing data from drive...", 'Importing X_train...'):
  print(message)
upload_data_system(file_name)

In [0]:
## Importing Validation data

file_name_1 = 'X_test.npy'

# Only the clean validation images are downloaded. The noisy counterparts
# (Y_test) are built from X_test in a later cell rather than fetched from Drive.
print('Importing data from drive...\nImporting X_test')
upload_data_system(file_name_1)

In [0]:
## Importing Model

file_name = 'model_046.hdf5'

# Download a saved checkpoint into the working directory. find_initial_epoch()
# later scans that directory for model_*.hdf5 files to decide where to resume.
print('Importing Model %s' % file_name)
upload_data_system(file_name)

In [0]:
import numpy as np
import matplotlib.pyplot as plt
import glob
import re
from keras.layers import Conv2D, Activation, BatchNormalization, Input, Subtract
from keras.models import Model, load_model
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
import keras.backend as K

%matplotlib inline

In [0]:
# Load the arrays that the import cells above downloaded into the working directory.
X_train, X_test = (np.load(path) for path in ('X_train.npy', 'X_test.npy'))

In [0]:
# Build Y_test, the noisy validation inputs, from the clean images X_test.
# The noise model matches data_generator: additive zero-mean Gaussian noise
# whose standard deviation is drawn per image from randint(0, 55). Here it is
# expressed in X_test's own pixel units (data_generator uses sigma/255 on
# images scaled by 1/255, which is equivalent).
# The previous version added one constant integer per image. That is a
# brightness shift, not noise, and it can wrap around if X_test is uint8.
MAX_NOISE_SIGMA = 55

def add_gaussian_noise(images, max_sigma=MAX_NOISE_SIGMA):
  """Return float32 copies of `images` with per-image additive Gaussian noise.

  Args:
    images: array of images indexed along axis 0 (any remaining axes).
    max_sigma: exclusive upper bound for the per-image noise std, drawn
      uniformly from the integers [0, max_sigma).

  Returns:
    float32 array with the same shape as `images`; values are not clipped,
    matching the unclipped training batches.
  """
  images = np.asarray(images, dtype='float32')
  # One sigma per image, broadcast over all remaining axes.
  sigma_shape = (images.shape[0],) + (1,) * (images.ndim - 1)
  sigmas = np.random.randint(0, max_sigma, size=sigma_shape)
  noise = np.random.normal(0.0, 1.0, images.shape) * sigmas
  return (images + noise).astype('float32')

Y_test = add_gaussian_noise(X_test)

In [0]:
def DnCNN_S(depth, filters=64, channels=1):
    """Build a DnCNN-style residual denoiser with `depth` convolutional layers.

    Layout: Conv+ReLU, then (depth - 2) blocks of Conv+BatchNorm+ReLU, then a
    final Conv back to `channels` feature maps. All convs are 3x3, stride 1,
    'same' padding and orthogonal init. The model output is
    `input - last_conv(...)`, so the last conv learns the residual between
    input and target. With data_generator's (noisy, clean) pairs, that
    residual is the added noise. For depth < 2 the loop is empty and only the
    first and last convs are built.

    Args:
        depth: total number of conv layers (20 in this notebook).
        filters: feature maps in every hidden conv layer (default 64).
        channels: image channels for input and output (default 1, grayscale).

    Returns:
        An uncompiled keras Model taking (None, None, channels) inputs, so
        any spatial size is accepted.
    """
    inpt = Input(shape = (None, None, channels))
    x = Conv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1), padding = 'same', kernel_initializer = 'orthogonal')(inpt)
    x = Activation(activation = 'relu')(x)

    for _ in range(depth - 2):
        x = Conv2D(filters = filters, kernel_size = (3, 3), strides=(1, 1), padding='same', kernel_initializer='orthogonal')(x)
        x = BatchNormalization()(x)
        x = Activation(activation='relu')(x)

    x = Conv2D(filters = channels, kernel_size = (3, 3), strides = (1, 1), padding = 'same', kernel_initializer = 'orthogonal')(x)
    # Residual learning: subtract the predicted residual from the input.
    x = Subtract()([inpt, x])

    model = Model(inputs = inpt, outputs = x)

    # summary() prints the table itself and returns None; wrapping it in
    # print() emitted a stray "None" line.
    model.summary()
    return model

In [0]:
def data_generator(epochs, batch_size, data=None):
    """Yield an endless stream of (noisy, clean) training batches.

    Each pass reshuffles the sample indices and walks them in consecutive,
    non-overlapping chunks of `batch_size`. Samples left over after the last
    full batch are skipped for that pass. Every batch draws one noise level
    sigma from randint(0, 55). The clean images are scaled by 1/255, and
    Gaussian noise with std sigma/255 is added to form the inputs.

    Args:
        epochs: kept for backward compatibility. The stream is infinite either
            way; Keras bounds it through steps_per_epoch and epochs.
        batch_size: images per batch; must be between 1 and the number of
            samples.
        data: array of clean images (0-255 pixel units) indexed along axis 0.
            Defaults to the global X_train.

    Yields:
        (noisy, clean) float32 arrays of shape (batch_size,) + data.shape[1:],
        i.e. (network input, training target).

    Raises:
        ValueError: on the first next() if batch_size is out of range. In that
            case no batch could ever be formed and the loop would spin forever.
    """
    images = X_train if data is None else data
    n_samples = images.shape[0]
    if not 1 <= batch_size <= n_samples:
        raise ValueError('batch_size must be in [1, %d], got %d' % (n_samples, batch_size))

    while True:
        indices = np.random.permutation(n_samples)
        # Stop at the last *full* batch. Bounding the range by
        # n_samples // batch_size (as before) produced only a handful of
        # batches per shuffle.
        for start in range(0, n_samples - batch_size + 1, batch_size):
            sigma = np.random.randint(0, 55)
            clean = images[indices[start:start + batch_size]].astype('float32') / 255.0
            noise = np.random.normal(0, sigma / 255.0, clean.shape)
            noisy = (clean + noise).astype('float32')
            yield noisy, clean

def lr_schedule(epoch):
    """Piecewise-constant learning rate for keras' LearningRateScheduler.

    Epochs up to 30 use 0.001, epochs 31-60 use 0.001 / 10, and every later
    epoch uses 0.001 / 20. The chosen rate is printed and returned.

    NOTE(review): the 61-80 and >80 ranges share the same rate (/20); confirm
    whether a further decay was intended for the last range.
    """
    base_lr = 0.001
    divisor = 1 if epoch <= 30 else (10 if epoch <= 60 else 20)
    lr = base_lr / divisor

    print('current learning rate is %2.8f' % lr)
    return lr

def find_initial_epoch(directory='.'):
    """Return the highest epoch number among checkpoint files in `directory`.

    Looks for files named exactly model_<digits>.hdf5 (the pattern the
    ModelCheckpoint callback writes) and returns the largest <digits> value,
    or 0 when there are none. Files that match the glob but not that exact
    form (e.g. model_best.hdf5) are ignored instead of crashing int().

    Args:
        directory: folder to scan; defaults to the current working directory.

    Returns:
        int epoch to resume from (0 means start from scratch).
    """
    import os  # local import keeps this cell self-contained

    pattern = os.path.join(glob.escape(directory), 'model_*.hdf5')
    epochs_finished = []
    for path in glob.glob(pattern):
        match = re.fullmatch(r'model_(\d+)\.hdf5', os.path.basename(path))
        if match:
            epochs_finished.append(int(match.group(1)))
    return max(epochs_finished, default=0)

def sum_squared_error(y_true, y_pred):
    """Keras loss: half the total squared error over every element of the batch.

    Computes sum((y_pred - y_true) ** 2) / 2 as a scalar tensor. Because it is
    a sum rather than a mean, its magnitude grows with batch size and image size.
    """
    residual = y_pred - y_true
    return K.sum(K.square(residual)) / 2

In [0]:
print("Setting up Model...")
model = DnCNN_S(20)

EPOCHS = 50
BATCH_SIZE = 64
STEPS_PER_EPOCH = 2520  # NOTE(review): carried over as-is; presumably tied to the training-set size — confirm.

# Saves the full model after every epoch as model_<epoch, 3 digits>.hdf5.
# find_initial_epoch() and the resume branch below rely on this name pattern.
checkpoint = ModelCheckpoint("model_{epoch:03d}.hdf5", save_weights_only = False, period = 1, verbose = 1)
lr_scheduler = LearningRateScheduler(lr_schedule)

initial_epoch = find_initial_epoch()

if(initial_epoch > 0):
    print('Resuming by loading epoch %03d' % initial_epoch)
    # compile=False: the loss and optimizer are attached by model.compile() below.
    model = load_model('model_%03d.hdf5' % initial_epoch, compile = False)

model.compile(loss = sum_squared_error, optimizer = Adam(0.001))

# data_generator yields (noisy input, clean target) pairs scaled by 1/255, so
# the validation pair must use the same order and scaling. Y_test holds noisy
# versions of X_test in X_test's own pixel units. Passing (X_test, Y_test)
# unscaled, as before, asked the model to map clean images to noisy ones.
val_inputs = Y_test.astype('float32') / 255.0
val_targets = X_test.astype('float32') / 255.0

print("Training Model...")
History = model.fit_generator(data_generator(epochs = EPOCHS, batch_size = BATCH_SIZE), steps_per_epoch = STEPS_PER_EPOCH, epochs = EPOCHS,
                              callbacks = [checkpoint, lr_scheduler], initial_epoch = initial_epoch, validation_data = (val_inputs, val_targets))

# No metrics are passed to compile(), so History has no 'mean_squared_error'
# entry; plot the training loss (sum_squared_error) instead.
plt.plot(History.history['loss'])
plt.title('LOSS VS EPOCH')
plt.ylabel('Loss (sum of squared errors / 2)')
plt.xlabel('Epoch')
plt.show()

Setting up Model...
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            (None, None, None, 1 0                                            
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, None, None, 6 640         input_2[0][0]                    
__________________________________________________________________________________________________
activation_20 (Activation)      (None, None, None, 6 0           conv2d_21[0][0]                  
__________________________________________________________________________________________________
conv2d_22 (Conv2D)              (None, None, None, 6 36928       activation_20[0][0]              
_________________________________________________________________________________________

In [0]:
# Show which series Keras recorded in History.history, which is useful for
# checking which keys (e.g. 'loss', 'val_loss') the plotting cell can read.
print(History.history.keys())

In [0]:
# Training vs. validation loss per epoch. The loss is sum_squared_error (half
# the summed squared error), so its scale depends on batch and image size.
# A legend is required: without it the two curves cannot be told apart.
fig, ax = plt.subplots()
ax.plot(History.history['loss'], label='train')
ax.plot(History.history['val_loss'], label='validation')
ax.set_title('LOSS VS EPOCH')
ax.set_ylabel('Loss')
ax.set_xlabel('Epoch')
ax.legend()
plt.show()