# ImageNet

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
import urllib
import sys

import skimage.io as io
import tensorflow as tf

from utils import read_imagenet_data, add_noise, normalize

In [None]:
def store_raw_images(link, save_path, im_size=(128, 128)):
    """Download every image listed at ``link`` into ``save_path`` as grayscale JPEGs.

    ``link`` must return a plain-text, newline-separated list of image URLs.
    Each image that downloads and decodes is loaded as grayscale, resized to
    ``im_size`` and written back as ``<n>.jpg``, with ``n`` counting up from 1.

    Files that already exist are left untouched and counted, so an interrupted
    run can be resumed. URLs that fail to download or decode are reported,
    their partial file is removed, and they do not consume a file number.

    Args:
        link: URL of the text file listing one image URL per line.
        save_path: Output directory; created if it does not exist.
        im_size: Target size passed straight to ``cv2.resize``.

    Returns:
        None. Progress and a final count are printed.
    """
    # `import urllib` alone does not guarantee the `request` submodule is loaded.
    import urllib.request

    image_urls = urllib.request.urlopen(link).read().decode()
    os.makedirs(save_path, exist_ok=True)

    pic_num = 1
    for url in image_urls.splitlines():
        url = url.strip()
        if not url:
            # Blank lines (e.g. the trailing newline) are not URLs.
            continue
        target = os.path.join(save_path, str(pic_num) + ".jpg")
        if os.path.exists(target):
            # Already fetched by a previous run.
            pic_num += 1
            continue
        try:
            print(pic_num, url)
            urllib.request.urlretrieve(url, target)
            img = cv2.imread(target, cv2.IMREAD_GRAYSCALE)
            if img is None:
                # imread returns None instead of raising for non-image payloads.
                raise ValueError("could not decode image from " + url)
            # Store every image at the same fixed size.
            resized_image = cv2.resize(img, im_size)
            cv2.imwrite(target, resized_image)
        except Exception as e:
            # Best-effort scraper: report and move on, but never leave a broken
            # file behind, or a later iteration/run would count it as valid.
            print(str(e))
            try:
                os.remove(target)
            except OSError:
                pass
            continue
        pic_num += 1
    # pic_num is always one past the last stored file.
    print("Total", pic_num - 1, "images loaded successfully")

In [None]:
# Fetch the URL list for ImageNet synset n01317541 and store its images under ./images/animals.
# NOTE(review): confirm this image-net.org text API endpoint is still served.
save_path = r'./images/animals'
link = r'http://image-net.org/api/text/imagenet.synset.geturls?wnid=n01317541'
store_raw_images(link=link, save_path=save_path)

In [None]:
# Fetch the URL list for ImageNet synset n00017222 and store its images under ./images/plants.
# NOTE(review): confirm this image-net.org text API endpoint is still served.
save_path = r'./images/plants'
link = r'http://image-net.org/api/text/imagenet.synset.geturls?wnid=n00017222'
store_raw_images(link=link, save_path=save_path)

In [None]:
# Fetch the URL list for ImageNet synset n00021939 and store its images under ./images/artifacts.
# NOTE(review): confirm this image-net.org text API endpoint is still served.
save_path = r'./images/artifacts'
link = r'http://image-net.org/api/text/imagenet.synset.geturls?wnid=n00021939'
store_raw_images(link=link, save_path=save_path)

In [None]:
# The download cells in this notebook write into ./images/<category>; read from
# that same directory (this previously pointed at ../images).
root_path = r'./images'
ims = read_imagenet_data(root_path)
# Append a trailing channel axis. Presumably read_imagenet_data returns (N, H, W)
# grayscale images, giving (N, H, W, 1) — TODO confirm against utils.
ims = normalize(ims[:, :, :, np.newaxis])
print(ims.shape, ims.dtype, ims.max(), ims.min())

In [None]:
# Build the noisy counterpart of the clean batch with utils.add_noise
# (Gaussian, mean=0, var=1e-3). The result is not re-normalized afterwards;
# the printed max/min show its actual value range.
ims_noise = add_noise(ims, n_type='gaussian', mean=0, var=1e-3)
print(ims_noise.shape, ims_noise.dtype, ims_noise.max(), ims_noise.min())

In [None]:
def error(x1, x2, mode='mse'):
    """Return the mean error between two array-likes.

    Args:
        x1, x2: Array-likes of broadcast-compatible shape (e.g. two image batches).
        mode: ``'mse'`` for mean squared error or ``'mae'`` for mean absolute error.

    Returns:
        The scalar mean error as a NumPy floating value.

    Raises:
        ValueError: If ``mode`` is neither ``'mse'`` nor ``'mae'``.
    """
    a = np.asarray(x1)
    b = np.asarray(x2)
    # Subtract in a floating dtype so integer inputs (e.g. uint8 images) cannot
    # wrap around; float32/float64 inputs keep their own precision.
    diff = np.subtract(a, b, dtype=np.result_type(a, b, np.float32))
    if mode == 'mse':
        return np.mean(np.square(diff))
    if mode == 'mae':
        return np.mean(np.abs(diff))
    # An unknown mode used to return None silently; fail loudly instead.
    raise ValueError(f"unknown error mode {mode!r}; expected 'mse' or 'mae'")

def psnr(x1, x2, max_val=1):
    """Return the mean peak signal-to-noise ratio (in dB) over a batch of images.

    Thin wrapper around ``tf.image.psnr``: it produces one PSNR value per image,
    averages them with ``tf.reduce_mean`` and converts the result to a NumPy
    scalar with ``.numpy()``, which requires eager execution.

    Args:
        x1, x2: Image batches of matching shape, laid out as tf.image.psnr
            expects, with height, width and channels as the last three axes.
        max_val: Dynamic range of the pixel values. The default of 1 assumes
            images scaled to [0, 1]. NOTE(review): confirm ``normalize``
            produces that range.

    Returns:
        The batch-mean PSNR as a NumPy scalar.
    """
    return tf.reduce_mean(tf.image.psnr(x1, x2, max_val=max_val)).numpy()

# Baseline distance between the noisy and clean batches: MAE, MSE, then mean PSNR.
print(error(ims_noise, ims, mode='mae'))
print(error(ims_noise, ims, mode='mse'))
print(psnr(x1=ims_noise, x2=ims))

In [None]:
if __name__ == '__main__':
    # Show the first N_show samples: clean image on the left, noisy one on the right.
    N_show = 3

    plt.figure(figsize=(20, 5 * N_show))
    for i in range(N_show):
        # Column 1 holds the clean image and column 2 the noisy one, so the
        # subplot index for row i is 2*i + column.
        for col, batch in enumerate((ims, ims_noise), start=1):
            plt.subplot(N_show, 2, 2 * i + col)
            plt.imshow(batch[i].squeeze(), cmap='gray')
            plt.axis('off')
    plt.show()

# DnCNN data