In [1]:
from sargan_models import SARGAN
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
from sar_utilities import add_gaussian_noise
from cifar_helper import get_data, chunks
import random
import skimage.measure as ski_me

In [2]:
#get test batches from cifar dataset
test_filename = "test_batch"
test_imgs, test_classes = get_data(test_filename)
test_imgs = test_imgs.astype('float32')
BATCH_SIZE = 50
test_batches = np.array(list(chunks(test_imgs, BATCH_SIZE)))
# test_batches = test_imgs
NOISE_STD = 0.1

# params to recover the trained model
img_size = (32,32,3)
DATA_PATH = "/scratch/hle/data/"
model_trained_epoch = 1
model_path_test = os.path.join(DATA_PATH,
                               "trained_models/cifar_10_gaussian_corrupted/",
                               "cifar_10_gaussian_corrupted_model_{}.ckpt".format(model_trained_epoch))
model = SARGAN(img_size, BATCH_SIZE, img_channel=3)
saver = tf.train.Saver()

# tensorflow config
gpu = "3"
gpu_options = tf.GPUOptions(allow_growth=True, visible_device_list=gpu)
config=tf.ConfigProto(gpu_options=gpu_options)

# some display of original, corrupted and recovered
# fig=plt.figure(figsize=(8, 100))
# columns = 3
# rows = 50

# error calculation
sum_psnr = 0
with tf.Session(config=config) as sess:
    saver.restore(sess, model_path_test)
    for i in range(200):
        #get 1 batch of size BATCH_SIZE and then corrupt it
        batch = test_batches[i]
        print("batch range %s, %s" % (np.amax(batch), np.amin(batch)))
        corrupted_batch = np.array([add_gaussian_noise(image, sd=NOISE_STD) for image in batch])
        print("corrupted batch range %s, %s" % (np.amax(corrupted_batch), np.amin(corrupted_batch)))
        gen_imgs = sess.run(model.gen_img, feed_dict={model.image:batch, model.cond:corrupted_batch})
        print("recovered batch range %s, %s" % (np.amax(gen_imgs), np.amin(gen_imgs)))
#         j = 1
        for k in range(len(batch)):
#             fig.add_subplot(rows, columns, j)
            img = batch[k]
#             plt.imshow(img)
#             j+=1
#             fig.add_subplot(rows, columns, j)
#             corrupted_img = corrupted_batch[k]
#             plt.imshow(corrupted_img)
#             j+=1
#             fig.add_subplot(rows, columns, j)
            recovered_img = gen_imgs[k]
#             plt.imshow(recovered_img)
#             j+=1
            sum_psnr += ski_me.compare_psnr(img, recovered_img)
            
    plt.show()
    print(sum_psnr/10000)

                

Loading data: /scratch/hle/cifar-10-batches-py/test_batch
INFO:tensorflow:Restoring parameters from /scratch/hle/data/trained_models/cifar_10_gaussian_corrupted/cifar_10_gaussian_corrupted_model_1.ckpt
batch range 1.0, 0.0
corrupted batch range 1.31610079622, -0.340876273604
recovered batch range 1.29231, -0.173034
batch range 1.0, 0.0
corrupted batch range 1.36744246215, -0.30796166186
recovered batch range 1.30428, -0.0624188
batch range 1.0, 0.0
corrupted batch range 1.37450754834, -0.351406407982
recovered batch range 1.23362, -0.153712
batch range 1.0, 0.0
corrupted batch range 1.3394379686, -0.375399637162
recovered batch range 1.23425, -0.0991922
batch range 1.0, 0.0
corrupted batch range 1.37437590501, -0.306593283681
recovered batch range 1.21276, -0.0761571
batch range 1.0, 0.0
corrupted batch range 1.38590277736, -0.366631001712
recovered batch range 1.3419, -0.121686
batch range 1.0, 0.0
corrupted batch range 1.29749136138, -0.337660410235
recovered batch range 1.14086, -0.

recovered batch range 1.20463, -0.0899133
batch range 1.0, 0.0
corrupted batch range 1.41114218205, -0.278628347988
recovered batch range 1.18003, -0.0580806
batch range 1.0, 0.0
corrupted batch range 1.3761543331, -0.334248219215
recovered batch range 1.19872, -0.0966075
batch range 1.0, 0.0
corrupted batch range 1.34884797055, -0.329398942328
recovered batch range 1.21492, -0.154888
batch range 1.0, 0.0
corrupted batch range 1.34200211611, -0.278087847701
recovered batch range 1.27202, -0.127794
batch range 1.0, 0.0
corrupted batch range 1.38352083306, -0.390971290348
recovered batch range 1.23998, -0.113122
batch range 1.0, 0.0
corrupted batch range 1.37233380328, -0.323208630365
recovered batch range 1.24186, -0.0571087
batch range 1.0, 0.0
corrupted batch range 1.35514051474, -0.33778471451
recovered batch range 1.19607, -0.101265
batch range 1.0, 0.0
corrupted batch range 1.35889213885, -0.308183784202
recovered batch range 1.21188, -0.0871722
batch range 1.0, 0.0
corrupted batch

recovered batch range 1.14771, -0.124005
batch range 1.0, 0.0
corrupted batch range 1.40903042802, -0.320928619787
recovered batch range 1.21197, -0.182687
batch range 1.0, 0.0
corrupted batch range 1.32076256244, -0.292070401122
recovered batch range 1.24847, -0.12262
batch range 1.0, 0.0
corrupted batch range 1.32641104301, -0.370664769834
recovered batch range 1.31115, -0.135955
batch range 1.0, 0.0
corrupted batch range 1.32180313577, -0.389910583033
recovered batch range 1.29853, -0.125813
batch range 1.0, 0.0
corrupted batch range 1.31421433607, -0.31692957923
recovered batch range 1.31032, -0.107304
batch range 1.0, 0.0
corrupted batch range 1.36523652581, -0.351716162714
recovered batch range 1.24568, -0.217284
batch range 1.0, 0.0
corrupted batch range 1.36926963575, -0.334044301144
recovered batch range 1.20536, -0.138314
batch range 1.0, 0.0
corrupted batch range 1.36212021382, -0.358366213018
recovered batch range 1.37131, -0.181312
batch range 1.0, 0.0
corrupted batch rang