# Read the README [in the repository](https://github.com/jzlotek/cs583-final) before continuing


In [7]:
#
# download_dataset.py
#

import requests
import urllib
import os
import sys
from zipfile import ZipFile

OUTDIR='./dataset'

def download_file_from_google_drive(id, destination):
    URL = "https://docs.google.com/uc?export=download"

    session = requests.Session()

    response = session.get(URL, params = { 'id' : id }, stream = True, timeout=None)
    token = get_confirm_token(response)

    if token:
        params = { 'id' : id, 'confirm' : token }
        response = session.get(URL, params = params, stream = True, timeout=None)

    save_response_content(response, destination)    

def get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value

    return None

def save_response_content(response, destination):
    CHUNK_SIZE = 10000
    downloaded = 0
    total = 26926678016.00 * 1.1 / 100.00
    sys.stdout.write("\r0.00% downloaded")
    sys.stdout.flush()

    with open(destination, "wb") as f:
        for chunk in response.iter_content(CHUNK_SIZE):
            if chunk: # filter out keep-alive new chunks
                f.write(chunk)
                downloaded+=len(chunk)
                if downloaded % (CHUNK_SIZE*50) == 0:
                    sys.stdout.write('\r%0.2f%% downloaded' % (downloaded/(total)))
    sys.stdout.write("\r100.00% downloaded")

# Ensure directory exists
if not os.path.exists(OUTDIR):
  os.mkdir(OUTDIR)

filepath = OUTDIR+'/Sony.zip'
# Too much data, download the smaller dataset for now
print('Downloading Sony data... (25GB)')
download_file_from_google_drive('10kpAcvldtcb9G2ze5hTcF1odzu4V_Zvh', filepath)

fileout = OUTDIR+'/Sony'
# Always unzip to reset directory
print('\nUnzipping, this will take a while...')
print('Usage of storage should be around 117GB')
ZipFile(filepath).extractall(path=OUTDIR)

    
print("Data ready in %s" % (fileout))

Downloading Sony data... (25GB)
100.00% downloaded
Unzipping, this will take a while...
Usage of storage should be around 117GB
Data ready in ./dataset/Sony




```
# This is formatted as code
```

Data is downloaded, time to train the model.

In [10]:
!pip install rawpy
!pip install imageio



In [13]:
#
# train_Sony.py
#

from __future__ import division
import os, time, scipy.io
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
import imageio
import rawpy
import glob
import sys

# Constant
input_dir = './dataset/Sony/short/'
gt_dir = './dataset/Sony/long/'
checkpoint_dir = './result_Sony/'
result_dir = './result_Sony/'

# get train IDs
#
# Gets all the filenames that match the pattern 
# Chops off the first 5 characters to get the file's "ID"
train_fns = glob.glob(gt_dir + '0*.ARW') 
train_ids = [int(os.path.basename(train_fn)[0:5]) for train_fn in train_fns]

ps = 512  # patch size for training
save_freq = 500

DEBUG = 0
if DEBUG == 1:
    save_freq = 2
    train_ids = train_ids[0:5]

# Activation function
def lrelu(x):
    return tf.maximum(x * 0.2, x)


def upsample_and_concat(x1, x2, output_channels, in_channels):
    pool_size = 2
    deconv_filter = tf.Variable(tf.truncated_normal([pool_size, pool_size, output_channels, in_channels], stddev=0.02))
    deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2), strides=[1, pool_size, pool_size, 1])

    deconv_output = tf.concat([deconv, x2], 3)
    deconv_output.set_shape([None, None, None, output_channels * 2])

    return deconv_output

# Defines the structure of the model and handles running and input through it
def network(input):
    conv1 = slim.conv2d(input, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv1_1')
    conv1 = slim.conv2d(conv1, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv1_2')
    pool1 = slim.max_pool2d(conv1, [2, 2], padding='SAME')

    conv2 = slim.conv2d(pool1, 64, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv2_1')
    conv2 = slim.conv2d(conv2, 64, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv2_2')
    pool2 = slim.max_pool2d(conv2, [2, 2], padding='SAME')

    conv3 = slim.conv2d(pool2, 128, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv3_1')
    conv3 = slim.conv2d(conv3, 128, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv3_2')
    pool3 = slim.max_pool2d(conv3, [2, 2], padding='SAME')

    conv4 = slim.conv2d(pool3, 256, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv4_1')
    conv4 = slim.conv2d(conv4, 256, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv4_2')
    pool4 = slim.max_pool2d(conv4, [2, 2], padding='SAME')

    conv5 = slim.conv2d(pool4, 512, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv5_1')
    conv5 = slim.conv2d(conv5, 512, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv5_2')

    up6 = upsample_and_concat(conv5, conv4, 256, 512)
    conv6 = slim.conv2d(up6, 256, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv6_1')
    conv6 = slim.conv2d(conv6, 256, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv6_2')

    up7 = upsample_and_concat(conv6, conv3, 128, 256)
    conv7 = slim.conv2d(up7, 128, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv7_1')
    conv7 = slim.conv2d(conv7, 128, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv7_2')

    up8 = upsample_and_concat(conv7, conv2, 64, 128)
    conv8 = slim.conv2d(up8, 64, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv8_1')
    conv8 = slim.conv2d(conv8, 64, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv8_2')

    up9 = upsample_and_concat(conv8, conv1, 32, 64)
    conv9 = slim.conv2d(up9, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv9_1')
    conv9 = slim.conv2d(conv9, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv9_2')

    conv10 = slim.conv2d(conv9, 12, [1, 1], rate=1, activation_fn=None, scope='g_conv10')
    out = tf.depth_to_space(conv10, 2)
    return out


def pack_raw(raw):
    # pack Bayer image to 4 channels
    im = raw.raw_image_visible.astype(np.float32)
    im = np.maximum(im - 512, 0) / (16383 - 512)  # subtract the black level

    im = np.expand_dims(im, axis=2)
    img_shape = im.shape
    H = img_shape[0]
    W = img_shape[1]

    out = np.concatenate((im[0:H:2, 0:W:2, :],
                          im[0:H:2, 1:W:2, :],
                          im[1:H:2, 1:W:2, :],
                          im[1:H:2, 0:W:2, :]), axis=2)
    return out

# Reset current tf globals
tf.reset_default_graph()

sess = tf.Session()
in_image = tf.placeholder(tf.float32, [None, None, None, 4])
gt_image = tf.placeholder(tf.float32, [None, None, None, 3])
out_image = network(in_image)

G_loss = tf.reduce_mean(tf.abs(out_image - gt_image))

t_vars = tf.trainable_variables()
lr = tf.placeholder(tf.float32)
G_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss)

saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt:
    print('loaded ' + ckpt.model_checkpoint_path)
    saver.restore(sess, ckpt.model_checkpoint_path)

g_loss = np.zeros((5000, 1))

allfolders = glob.glob('./result/*0')
lastepoch = 0
for folder in allfolders:
    lastepoch = np.maximum(lastepoch, int(folder[-4:]))

# Batching
total_epochs = 11 # 4001 was used in the paper, but takes a while
current_batch = 0
batch_size = 25 # Uses < 8gb of RAM
shuffled_ids = np.random.permutation(len(train_ids))

while (current_batch+1) * batch_size - 1 < len(shuffled_ids) :
    batch_ids = shuffled_ids[current_batch*batch_size:(current_batch+1)*batch_size-1]
    current_batch += 1
    learning_rate = 1e-4
   
    # Hold batch in memory
    gt_images = [None] * 6000
    input_images = {}
    input_images['300'] = [None] * len(train_ids)
    input_images['250'] = [None] * len(train_ids)
    input_images['100'] = [None] * len(train_ids)

    for epoch in range(lastepoch, total_epochs): 
        if os.path.isdir("result/%04d" % epoch):
            continue
        cnt = 0
        if epoch > (total_epochs // 2): # Decrease learning rate halfway through
            learning_rate = 1e-5      

        for ind in np.random.permutation(batch_ids):
            # get the path from image id
            train_id = train_ids[ind]
            in_files = glob.glob(input_dir + '%05d_00*.ARW' % train_id)
            in_path = None
            if len(in_files) <= 1:
                in_path = in_files[0]
            else :
                in_path = in_files[np.random.randint(0, len(in_files) - 1)]
            in_fn = os.path.basename(in_path)

            gt_files = glob.glob(gt_dir + '%05d_00*.ARW' % train_id)
            gt_path = gt_files[0]
            gt_fn = os.path.basename(gt_path)
            in_exposure = float(in_fn[9:-5])
            gt_exposure = float(gt_fn[9:-5])
            ratio = min(gt_exposure / in_exposure, 300)

            st = time.time()
            cnt += 1

            if input_images[str(ratio)[0:3]][ind] is None:
                raw = rawpy.imread(in_path)
                input_images[str(ratio)[0:3]][ind] = \
                  np.expand_dims(pack_raw(raw), axis=0) * ratio

                gt_raw = rawpy.imread(gt_path)
                im = gt_raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
                gt_images[ind] = np.expand_dims(np.float32(im / 65535.0), axis=0)

            # crop
            H = input_images[str(ratio)[0:3]][ind].shape[1]
            W = input_images[str(ratio)[0:3]][ind].shape[2]

            xx = np.random.randint(0, W - ps)
            yy = np.random.randint(0, H - ps)
            input_patch = input_images[str(ratio)[0:3]][ind][:, yy:yy + ps, xx:xx + ps, :]
            gt_patch = gt_images[ind][:, yy * 2:yy * 2 + ps * 2, xx * 2:xx * 2 + ps * 2, :]
        
            if np.random.randint(2, size=1)[0] == 1:  # random flip
                input_patch = np.flip(input_patch, axis=1)
                gt_patch = np.flip(gt_patch, axis=1)
            if np.random.randint(2, size=1)[0] == 1:
                input_patch = np.flip(input_patch, axis=2)
                gt_patch = np.flip(gt_patch, axis=2)
            if np.random.randint(2, size=1)[0] == 1:  # random transpose
                input_patch = np.transpose(input_patch, (0, 2, 1, 3))
                gt_patch = np.transpose(gt_patch, (0, 2, 1, 3))

            input_patch = np.minimum(input_patch, 1.0)

            _, G_current, output = sess.run([G_opt, G_loss, out_image],
                                        feed_dict={in_image: input_patch, gt_image: gt_patch, lr: learning_rate})
            output = np.minimum(np.maximum(output, 0), 1)
            g_loss[ind] = G_current

            sys.stdout.write("\rBatch=%-4d Epoch=%-4d Count=%-4d Loss=%-6.3f Time=%-6.3f" % (current_batch-1, epoch, cnt, np.mean(g_loss[np.where(g_loss)]), time.time() - st))

            if epoch % save_freq == 0:
                if not os.path.isdir(result_dir + '%04d' % epoch):
                    os.makedirs(result_dir + '%04d' % epoch)

                temp = np.concatenate((gt_patch[0, :, :, :], output[0, :, :, :]), axis=1)
                temp *= 255
                temp = scipy.misc.bytescale(temp, cmin=0, cmax=255, high=255, low=0)

                imageio.imwrite(result_dir + '%04d/%05d_00_train_%d.jpg' % (epoch, train_id, ratio), np.array(temp, dtype=np.uint8))

    saver.save(sess, checkpoint_dir + 'model.ckpt')




Batch=0    Epoch=0    Count=1    Loss=0.272  Time=3.071 



Batch=0    Epoch=0    Count=2    Loss=0.325  Time=2.372 



Batch=0    Epoch=0    Count=3    Loss=0.277  Time=2.863 



Batch=0    Epoch=0    Count=4    Loss=0.273  Time=2.772 



Batch=0    Epoch=0    Count=5    Loss=0.321  Time=2.756 



Batch=0    Epoch=0    Count=6    Loss=0.314  Time=2.467 



Batch=0    Epoch=0    Count=7    Loss=0.292  Time=2.430 



Batch=0    Epoch=0    Count=8    Loss=0.270  Time=2.424 



Batch=0    Epoch=0    Count=9    Loss=0.262  Time=2.505 



Batch=0    Epoch=0    Count=10   Loss=0.252  Time=2.390 



Batch=0    Epoch=0    Count=11   Loss=0.239  Time=2.482 



Batch=0    Epoch=0    Count=12   Loss=0.237  Time=2.464 



Batch=0    Epoch=0    Count=13   Loss=0.252  Time=2.484 



Batch=0    Epoch=0    Count=14   Loss=0.241  Time=2.592 



Batch=0    Epoch=0    Count=15   Loss=0.236  Time=2.469 



Batch=0    Epoch=0    Count=16   Loss=0.237  Time=2.551 



Batch=0    Epoch=0    Count=17   Loss=0.230  Time=2.435 



Batch=0    Epoch=0    Count=18   Loss=0.231  Time=2.469 



Batch=0    Epoch=0    Count=19   Loss=0.228  Time=2.406 



Batch=0    Epoch=0    Count=20   Loss=0.221  Time=2.525 



Batch=0    Epoch=0    Count=21   Loss=0.220  Time=2.455 



Batch=0    Epoch=0    Count=22   Loss=0.230  Time=2.373 



Batch=0    Epoch=0    Count=23   Loss=0.225  Time=2.422 



Batch=0    Epoch=10   Count=24   Loss=0.076  Time=0.137 



Batch=1    Epoch=0    Count=1    Loss=0.078  Time=2.418 



Batch=1    Epoch=0    Count=2    Loss=0.077  Time=2.364 



Batch=1    Epoch=0    Count=3    Loss=0.079  Time=2.559 



Batch=1    Epoch=0    Count=4    Loss=0.078  Time=2.388 



Batch=1    Epoch=0    Count=5    Loss=0.078  Time=2.365 



Batch=1    Epoch=0    Count=6    Loss=0.079  Time=2.354 



Batch=1    Epoch=0    Count=7    Loss=0.079  Time=2.422 



Batch=1    Epoch=0    Count=8    Loss=0.080  Time=2.765 



Batch=1    Epoch=0    Count=9    Loss=0.079  Time=2.678 



Batch=1    Epoch=0    Count=10   Loss=0.082  Time=2.735 



Batch=1    Epoch=0    Count=11   Loss=0.082  Time=2.377 



Batch=1    Epoch=0    Count=12   Loss=0.082  Time=2.373 



Batch=1    Epoch=0    Count=13   Loss=0.082  Time=2.420 



Batch=1    Epoch=0    Count=14   Loss=0.082  Time=2.340 



Batch=1    Epoch=0    Count=15   Loss=0.081  Time=2.470 



Batch=1    Epoch=0    Count=16   Loss=0.081  Time=2.373 



Batch=1    Epoch=0    Count=17   Loss=0.080  Time=2.894 



Batch=1    Epoch=0    Count=18   Loss=0.080  Time=2.702 



Batch=1    Epoch=0    Count=19   Loss=0.080  Time=2.535 



Batch=1    Epoch=0    Count=20   Loss=0.079  Time=2.380 



Batch=1    Epoch=0    Count=21   Loss=0.079  Time=2.417 



Batch=1    Epoch=0    Count=22   Loss=0.081  Time=2.397 



Batch=1    Epoch=0    Count=23   Loss=0.081  Time=2.652 



Batch=1    Epoch=10   Count=24   Loss=0.073  Time=0.140 



Batch=2    Epoch=0    Count=1    Loss=0.072  Time=2.489 



Batch=2    Epoch=0    Count=2    Loss=0.074  Time=2.430 



Batch=2    Epoch=0    Count=3    Loss=0.073  Time=2.392 



Batch=2    Epoch=0    Count=4    Loss=0.073  Time=2.343 



Batch=2    Epoch=0    Count=5    Loss=0.073  Time=2.347 



Batch=2    Epoch=0    Count=6    Loss=0.073  Time=2.418 



Batch=2    Epoch=0    Count=7    Loss=0.072  Time=2.408 



Batch=2    Epoch=0    Count=8    Loss=0.072  Time=2.373 



Batch=2    Epoch=0    Count=9    Loss=0.072  Time=2.342 



Batch=2    Epoch=0    Count=10   Loss=0.072  Time=2.367 



Batch=2    Epoch=0    Count=11   Loss=0.072  Time=2.382 



Batch=2    Epoch=0    Count=12   Loss=0.072  Time=2.376 



Batch=2    Epoch=0    Count=13   Loss=0.072  Time=2.502 



Batch=2    Epoch=0    Count=14   Loss=0.072  Time=2.707 



Batch=2    Epoch=0    Count=15   Loss=0.073  Time=2.753 



Batch=2    Epoch=0    Count=16   Loss=0.073  Time=2.722 



Batch=2    Epoch=0    Count=17   Loss=0.073  Time=2.384 



Batch=2    Epoch=0    Count=18   Loss=0.073  Time=2.344 



Batch=2    Epoch=0    Count=19   Loss=0.074  Time=2.424 



Batch=2    Epoch=0    Count=20   Loss=0.075  Time=2.324 



Batch=2    Epoch=0    Count=21   Loss=0.075  Time=2.420 



Batch=2    Epoch=0    Count=22   Loss=0.075  Time=2.392 



Batch=2    Epoch=0    Count=23   Loss=0.075  Time=2.481 



Batch=2    Epoch=10   Count=24   Loss=0.070  Time=0.133 



Batch=3    Epoch=0    Count=1    Loss=0.069  Time=2.409 



Batch=3    Epoch=0    Count=2    Loss=0.069  Time=2.543 



Batch=3    Epoch=0    Count=3    Loss=0.069  Time=2.725 



Batch=3    Epoch=0    Count=4    Loss=0.069  Time=2.720 



Batch=3    Epoch=0    Count=5    Loss=0.069  Time=2.690 



Batch=3    Epoch=0    Count=6    Loss=0.069  Time=2.367 



Batch=3    Epoch=0    Count=7    Loss=0.068  Time=2.330 



Batch=3    Epoch=0    Count=8    Loss=0.068  Time=2.353 



Batch=3    Epoch=0    Count=9    Loss=0.068  Time=2.353 



Batch=3    Epoch=0    Count=10   Loss=0.067  Time=2.271 



Batch=3    Epoch=0    Count=11   Loss=0.067  Time=2.359 



Batch=3    Epoch=0    Count=12   Loss=0.067  Time=2.349 



Batch=3    Epoch=0    Count=13   Loss=0.066  Time=2.409 



Batch=3    Epoch=0    Count=14   Loss=0.066  Time=2.340 



Batch=3    Epoch=0    Count=15   Loss=0.066  Time=2.428 



Batch=3    Epoch=0    Count=16   Loss=0.066  Time=2.316 



Batch=3    Epoch=0    Count=17   Loss=0.066  Time=2.476 



Batch=3    Epoch=0    Count=18   Loss=0.066  Time=2.369 



Batch=3    Epoch=0    Count=19   Loss=0.067  Time=2.446 



Batch=3    Epoch=0    Count=20   Loss=0.067  Time=2.322 



Batch=3    Epoch=0    Count=21   Loss=0.067  Time=2.430 



Batch=3    Epoch=0    Count=22   Loss=0.067  Time=2.315 



Batch=3    Epoch=0    Count=23   Loss=0.067  Time=2.438 



Batch=3    Epoch=10   Count=24   Loss=0.066  Time=0.131 



Batch=4    Epoch=0    Count=1    Loss=0.066  Time=2.387 



Batch=4    Epoch=0    Count=2    Loss=0.066  Time=2.344 



Batch=4    Epoch=0    Count=3    Loss=0.066  Time=2.676 



Batch=4    Epoch=0    Count=4    Loss=0.066  Time=2.726 



Batch=4    Epoch=0    Count=5    Loss=0.066  Time=2.719 



Batch=4    Epoch=0    Count=6    Loss=0.066  Time=2.483 



Batch=4    Epoch=0    Count=7    Loss=0.066  Time=2.321 



Batch=4    Epoch=0    Count=8    Loss=0.066  Time=2.369 



Batch=4    Epoch=0    Count=9    Loss=0.066  Time=2.414 



Batch=4    Epoch=0    Count=10   Loss=0.066  Time=2.602 



Batch=4    Epoch=0    Count=11   Loss=0.066  Time=2.717 



Batch=4    Epoch=0    Count=12   Loss=0.066  Time=2.310 



Batch=4    Epoch=0    Count=13   Loss=0.066  Time=2.398 



Batch=4    Epoch=0    Count=14   Loss=0.065  Time=2.452 



Batch=4    Epoch=0    Count=15   Loss=0.065  Time=2.347 



Batch=4    Epoch=0    Count=16   Loss=0.065  Time=2.446 



Batch=4    Epoch=0    Count=17   Loss=0.065  Time=2.241 



Batch=4    Epoch=0    Count=18   Loss=0.065  Time=2.375 



Batch=4    Epoch=0    Count=19   Loss=0.065  Time=2.434 



Batch=4    Epoch=0    Count=20   Loss=0.066  Time=2.221 



Batch=4    Epoch=0    Count=21   Loss=0.066  Time=2.441 



Batch=4    Epoch=0    Count=22   Loss=0.066  Time=2.296 



Batch=4    Epoch=0    Count=23   Loss=0.066  Time=2.774 



Batch=4    Epoch=10   Count=24   Loss=0.063  Time=0.123 



Batch=5    Epoch=0    Count=1    Loss=0.063  Time=2.381 



Batch=5    Epoch=0    Count=2    Loss=0.063  Time=2.327 



Batch=5    Epoch=0    Count=3    Loss=0.063  Time=2.431 



Batch=5    Epoch=0    Count=4    Loss=0.064  Time=2.311 



Batch=5    Epoch=0    Count=5    Loss=0.064  Time=2.378 



Batch=5    Epoch=0    Count=6    Loss=0.065  Time=2.363 



Batch=5    Epoch=0    Count=7    Loss=0.064  Time=2.293 



Batch=5    Epoch=0    Count=8    Loss=0.065  Time=2.348 



Batch=5    Epoch=0    Count=9    Loss=0.065  Time=2.331 



Batch=5    Epoch=0    Count=10   Loss=0.065  Time=2.325 



Batch=5    Epoch=0    Count=11   Loss=0.065  Time=2.376 



Batch=5    Epoch=0    Count=12   Loss=0.066  Time=2.475 



Batch=5    Epoch=0    Count=13   Loss=0.065  Time=2.692 



Batch=5    Epoch=0    Count=14   Loss=0.065  Time=2.685 



Batch=5    Epoch=0    Count=15   Loss=0.065  Time=2.749 



Batch=5    Epoch=0    Count=16   Loss=0.065  Time=2.247 



Batch=5    Epoch=0    Count=17   Loss=0.065  Time=2.422 



Batch=5    Epoch=0    Count=18   Loss=0.065  Time=2.346 



Batch=5    Epoch=0    Count=19   Loss=0.065  Time=2.386 



Batch=5    Epoch=0    Count=20   Loss=0.065  Time=2.322 



Batch=5    Epoch=0    Count=21   Loss=0.065  Time=2.371 



Batch=5    Epoch=0    Count=22   Loss=0.066  Time=2.465 



Batch=5    Epoch=0    Count=23   Loss=0.066  Time=2.439 



Batch=5    Epoch=10   Count=24   Loss=0.065  Time=0.133 

Now that the model is built into result_Sony, we can test it.

In [14]:
#
# test_Sony.py
#

from __future__ import division
import os, scipy.io
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
import imageio
import rawpy
import glob

input_dir = './dataset/Sony/short/'
gt_dir = './dataset/Sony/long/'
checkpoint_dir = './result_Sony/'
result_dir = './result_Sony/'

# get test IDs
test_fns = glob.glob(gt_dir + '/1*.ARW')
test_ids = [int(os.path.basename(test_fn)[0:5]) for test_fn in test_fns]

DEBUG = 0
if DEBUG == 1:
    save_freq = 2
    test_ids = test_ids[0:5]


def lrelu(x):
    return tf.maximum(x * 0.2, x)


def upsample_and_concat(x1, x2, output_channels, in_channels):
    pool_size = 2
    deconv_filter = tf.Variable(tf.truncated_normal([pool_size, pool_size, output_channels, in_channels], stddev=0.02))
    deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2), strides=[1, pool_size, pool_size, 1])

    deconv_output = tf.concat([deconv, x2], 3)
    deconv_output.set_shape([None, None, None, output_channels * 2])

    return deconv_output


def network(input):
    conv1 = slim.conv2d(input, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv1_1')
    conv1 = slim.conv2d(conv1, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv1_2')
    pool1 = slim.max_pool2d(conv1, [2, 2], padding='SAME')

    conv2 = slim.conv2d(pool1, 64, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv2_1')
    conv2 = slim.conv2d(conv2, 64, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv2_2')
    pool2 = slim.max_pool2d(conv2, [2, 2], padding='SAME')

    conv3 = slim.conv2d(pool2, 128, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv3_1')
    conv3 = slim.conv2d(conv3, 128, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv3_2')
    pool3 = slim.max_pool2d(conv3, [2, 2], padding='SAME')

    conv4 = slim.conv2d(pool3, 256, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv4_1')
    conv4 = slim.conv2d(conv4, 256, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv4_2')
    pool4 = slim.max_pool2d(conv4, [2, 2], padding='SAME')

    conv5 = slim.conv2d(pool4, 512, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv5_1')
    conv5 = slim.conv2d(conv5, 512, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv5_2')

    up6 = upsample_and_concat(conv5, conv4, 256, 512)
    conv6 = slim.conv2d(up6, 256, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv6_1')
    conv6 = slim.conv2d(conv6, 256, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv6_2')

    up7 = upsample_and_concat(conv6, conv3, 128, 256)
    conv7 = slim.conv2d(up7, 128, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv7_1')
    conv7 = slim.conv2d(conv7, 128, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv7_2')

    up8 = upsample_and_concat(conv7, conv2, 64, 128)
    conv8 = slim.conv2d(up8, 64, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv8_1')
    conv8 = slim.conv2d(conv8, 64, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv8_2')

    up9 = upsample_and_concat(conv8, conv1, 32, 64)
    conv9 = slim.conv2d(up9, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv9_1')
    conv9 = slim.conv2d(conv9, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv9_2')

    conv10 = slim.conv2d(conv9, 12, [1, 1], rate=1, activation_fn=None, scope='g_conv10')
    out = tf.depth_to_space(conv10, 2)
    return out


def pack_raw(raw):
    # pack Bayer image to 4 channels
    im = raw.raw_image_visible.astype(np.float32)
    im = np.maximum(im - 512, 0) / (16383 - 512)  # subtract the black level

    im = np.expand_dims(im, axis=2)
    img_shape = im.shape
    H = img_shape[0]
    W = img_shape[1]

    out = np.concatenate((im[0:H:2, 0:W:2, :],
                          im[0:H:2, 1:W:2, :],
                          im[1:H:2, 1:W:2, :],
                          im[1:H:2, 0:W:2, :]), axis=2)
    return out

tf.reset_default_graph()
sess = tf.Session()
in_image = tf.placeholder(tf.float32, [None, None, None, 4])
gt_image = tf.placeholder(tf.float32, [None, None, None, 3])
out_image = network(in_image)

saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt:
    print('loaded ' + ckpt.model_checkpoint_path)
    saver.restore(sess, ckpt.model_checkpoint_path)

if not os.path.isdir(result_dir + 'final/'):
    os.makedirs(result_dir + 'final/')

for test_id in test_ids:
    # test the first image in each sequence
    in_files = glob.glob(input_dir + '%05d_00*.ARW' % test_id)
    for k in range(len(in_files)):
        in_path = in_files[k]
        in_fn = os.path.basename(in_path)
        print(in_fn)
        gt_files = glob.glob(gt_dir + '%05d_00*.ARW' % test_id)
        gt_path = gt_files[0]
        gt_fn = os.path.basename(gt_path)
        in_exposure = float(in_fn[9:-5])
        gt_exposure = float(gt_fn[9:-5])
        ratio = min(gt_exposure / in_exposure, 300)

        raw = rawpy.imread(in_path)
        input_full = np.expand_dims(pack_raw(raw), axis=0) * ratio

        im = raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
        # scale_full = np.expand_dims(np.float32(im/65535.0),axis = 0)*ratio
        scale_full = np.expand_dims(np.float32(im / 65535.0), axis=0)

        gt_raw = rawpy.imread(gt_path)
        im = gt_raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
        gt_full = np.expand_dims(np.float32(im / 65535.0), axis=0)

        input_full = np.minimum(input_full, 1.0)

        output = sess.run(out_image, feed_dict={in_image: input_full})
        output = np.minimum(np.maximum(output, 0), 1)

        output = output[0, :, :, :]
        gt_full = gt_full[0, :, :, :]
        scale_full = scale_full[0, :, :, :]
        scale_full = scale_full * np.mean(gt_full) / np.mean(
            scale_full)  # scale the low-light image to the same mean of the groundtruth
        
        output *= 255
        output = scipy.misc.bytescale(output, cmin=0, cmax=255, high=255, low=0)
        scale_full *= 255
        scale_full = scipy.misc.bytescale(scale_full, cmin=0, cmax=255, high=255, low=0)
        gt_full *= 255
        gt_full = scipy.misc.bytescale(gt_full, cmin=0, cmax=255, high=255, low=0)
        
        imageio.imwrite(result_dir + 'final/%5d_00_%d_out.png' % (test_id, ratio), np.array(output, dtype=np.uint8))
        imageio.imwrite(result_dir + 'final/%5d_00_%d_scale.png' % (test_id, ratio), np.array(scale_full, dtype=np.uint8))
        imageio.imwrite(result_dir + 'final/%5d_00_%d_gt.png' % (test_id, ratio), np.array(gt_full, dtype=np.uint8))


loaded ./result_Sony/model.ckpt
Instructions for updating:
Use standard file APIs to check for files with this prefix.


Instructions for updating:
Use standard file APIs to check for files with this prefix.


INFO:tensorflow:Restoring parameters from ./result_Sony/model.ckpt


INFO:tensorflow:Restoring parameters from ./result_Sony/model.ckpt


10077_00_0.1s.ARW
10217_00_0.033s.ARW
10217_00_0.04s.ARW
10217_00_0.1s.ARW
10034_00_0.04s.ARW
10034_00_0.1s.ARW
10163_00_0.1s.ARW
10170_00_0.1s.ARW
10185_00_0.04s.ARW
10185_00_0.1s.ARW
10185_00_0.033s.ARW
10192_00_0.1s.ARW
10192_00_0.04s.ARW
10192_00_0.033s.ARW
10030_00_0.04s.ARW
10030_00_0.1s.ARW
10193_00_0.1s.ARW
10193_00_0.04s.ARW
10193_00_0.033s.ARW
10003_00_0.04s.ARW
10003_00_0.1s.ARW
10093_00_0.1s.ARW
10069_00_0.04s.ARW
10069_00_0.1s.ARW
10191_00_0.033s.ARW
10191_00_0.1s.ARW
10191_00_0.04s.ARW
10213_00_0.04s.ARW
10213_00_0.1s.ARW
10213_00_0.033s.ARW
10032_00_0.04s.ARW
10032_00_0.1s.ARW
10106_00_0.1s.ARW
10040_00_0.1s.ARW
10040_00_0.04s.ARW
10068_00_0.04s.ARW
10068_00_0.1s.ARW
10011_00_0.04s.ARW
10011_00_0.1s.ARW
10035_00_0.1s.ARW
10035_00_0.04s.ARW
10140_00_0.1s.ARW
10187_00_0.1s.ARW
10187_00_0.04s.ARW
10187_00_0.033s.ARW
10105_00_0.1s.ARW
10172_00_0.1s.ARW
10101_00_0.1s.ARW
10139_00_0.1s.ARW
10198_00_0.04s.ARW
10198_00_0.1s.ARW
10198_00_0.033s.ARW
10176_00_0.1s.ARW
10087_00_0.1s

The tests have run and output images into result_Sony/final as well as the model into result_Sony. Lets zip and download a few to check are results. Fair warning, the file download is buggy and will not play nice if you are not using chrome. 

If you are having issues, restart the runtime and refresh the page.

In [15]:
!zip model.zip result_Sony/*
!zip results_10191.zip result_Sony/final/10191_00_3*
!zip results_10187.zip result_Sony/final/10187_00_3*
!zip results_00001.zip result_Sony/final/10016_00_3*

  adding: result_Sony/0000/ (stored 0%)
  adding: result_Sony/checkpoint (deflated 42%)
  adding: result_Sony/final/ (stored 0%)
  adding: result_Sony/model.ckpt.data-00000-of-00001 (deflated 7%)
  adding: result_Sony/model.ckpt.index (deflated 63%)
  adding: result_Sony/model.ckpt.meta (deflated 91%)
  adding: result_Sony/final/10191_00_300_gt.png (deflated 0%)
  adding: result_Sony/final/10191_00_300_out.png (deflated 0%)
  adding: result_Sony/final/10191_00_300_scale.png (deflated 0%)
  adding: result_Sony/final/10187_00_300_gt.png (deflated 0%)
  adding: result_Sony/final/10187_00_300_out.png (deflated 0%)
  adding: result_Sony/final/10187_00_300_scale.png (deflated 0%)

zip error: Nothing to do! (results_00001.zip)


In [24]:
!ls -al dataset/Sony/short/00001_0*

-rw-r--r-- 1 root root 24707072 Jun  3 02:47 dataset/Sony/short/00001_00_0.04s.ARW
-rw-r--r-- 1 root root 24772608 Jun  3 02:47 dataset/Sony/short/00001_00_0.1s.ARW
-rw-r--r-- 1 root root 24707072 Jun  3 02:47 dataset/Sony/short/00001_01_0.04s.ARW
-rw-r--r-- 1 root root 24772608 Jun  3 02:47 dataset/Sony/short/00001_01_0.1s.ARW
-rw-r--r-- 1 root root 24772608 Jun  3 02:47 dataset/Sony/short/00001_02_0.1s.ARW
-rw-r--r-- 1 root root 24772608 Jun  3 02:47 dataset/Sony/short/00001_03_0.1s.ARW
-rw-r--r-- 1 root root 24772608 Jun  3 02:47 dataset/Sony/short/00001_04_0.1s.ARW
-rw-r--r-- 1 root root 24772608 Jun  3 02:47 dataset/Sony/short/00001_05_0.1s.ARW
-rw-r--r-- 1 root root 24772608 Jun  3 02:47 dataset/Sony/short/00001_06_0.1s.ARW
-rw-r--r-- 1 root root 24772608 Jun  3 02:47 dataset/Sony/short/00001_07_0.1s.ARW
-rw-r--r-- 1 root root 24772608 Jun  3 02:47 dataset/Sony/short/00001_08_0.1s.ARW
-rw-r--r-- 1 root root 24772608 Jun  3 02:47 dataset/Sony/short/00001_09_0.1s.ARW


In [0]:
from google.colab import files

files.download('model.zip')

In [0]:
files.download('results_10191.zip')

In [0]:
files.download('results_10187.zip')

In [0]:
files.download('results_10016.zip')

In [0]:
files.download('dataset/Sony/short/00001_01_0.04s.ARW')