In [1]:
import torch
import numpy as np
import numba as nb
import cv2
import pandas as pd
import os
import requests
from bs4 import BeautifulSoup
import tensorflow.keras as keras
import tensorflow as tf
import json
import random
from tensorflow.python.data.experimental import AUTOTUNE

In [2]:
train_dir = '../../Data/DIV2K_train_HR/DIV2K_train_HR/'
valid_dir = '../../Data/DIV2K_valid_HR/DIV2K_valid_HR/'
def prepare_data(data_dir=train_dir):
    sizes = []
    for fl in os.listdir(data_dir):
        img = cv2.imread(os.path.join(train_dir, fl))
        if img.shape not in sizes:
            sizes.append(img.shape)
    return sizes

def download_from_wallhaven(url = r'https://wallhaven.cc/api/v1/search?categories=111&purity=111&atleast=1920x1080&sorting=views&order=desc&apikey=VIkPSqUeNeL2Q5PZh6FAWWW4aSz3IMYD'
, from_page=1, output_dir='../../Data/Wallhaven/', total_images=1000):
    if os.path.exists(output_dir) == False:
        os.makedirs(output_dir)
    dloaded = 0
    while dloaded < total_images:
        _url = url+r'&page='+str(from_page)
        from_page += 1
        page = requests.get(_url)
        b = json.loads(page.content)
        if b == None or len(b) == 0:
            raise Exception("found page len to be 0")
        
        for i in range(len(b['data'])):
            if dloaded >= total_images:
                break
            url_to_dload = b['data'][i]['path'] 
            print(url_to_dload)
            resp = requests.get(url_to_dload, allow_redirects=True, stream=True).raw
            img = np.asarray(bytearray(resp.read()), dtype="uint8")
            image_name = url_to_dload.split('/')[-1]
            img = cv2.imdecode(img, cv2.IMREAD_COLOR)
            cv2.imwrite(output_dir+image_name, img)
            dloaded += 1
    
def prepare_csv(csv_name='train.csv', folders=[], seed=42):
    fs = []
    for folder in folders:
        for file in os.listdir(folder):
            fl = os.path.join(folder, file)
            if fl not in fs:
                fs.append(fl)
    random.seed(seed)
    random.shuffle(fs)
    data = {'path': fs}
    df = pd.DataFrame(data=data)
    df.to_csv(csv_name)
# download_from_wallhaven(from_page=1, output_dir='../../Data/Wallhaven/', total_images=1000)    
# download_from_wallhaven(from_page=1000, output_dir='../../Data/Wallhaven_test/', total_images=100)
# prepare_csv('train.csv', folders=['../../Data/Wallhaven/', '../../Data/DIV2K_train_HR/DIV2K_train_HR/'])
# prepare_csv('test.csv', folders=['../../Data/Wallhaven_test/', '../../Data/DIV2K_valid_HR/DIV2K_valid_HR/'])

In [3]:
gpu_devices = tf.config.list_physical_devices('GPU')
if gpu_devices:
  print(tf.config.experimental.get_memory_usage('GPU:0'))


Instructions for updating:
Use tf.config.experimental.get_memory_info(device)['current'] instead.
0


In [4]:
def pixel_shuffle(scale):
    return lambda x: tf.nn.depth_to_space(x, scale)

def resnet_block(x_in, filters=64, kernel=3, scaling=.1):
    x = keras.layers.Conv2D(filters, kernel, padding='same')(x_in)
    x = keras.layers.Conv2D(filters, kernel, padding='same')(x)
    if scaling != 0:
        x = keras.layers.Lambda(lambda x: x * scaling)(x)
    x = keras.layers.Add()([x_in, x])
    
def up_sample(x, filters=64, kernel=3, scaling=1):
    factor = 2
    name='scale2'
    if scaling==3:
        factor = 3
        name='scale3'
    x = keras.layers.Conv2D(filters * (factor ** 2), kernel, padidng='same', name=name)(x)
    return keras.layers.Lambda(pixel_shuffle(scaling))(x)

# EDSR

In [5]:
class ResnetBlock(tf.keras.Model):
    def __init__(self, filters, kernel, scaling=0):
        super(ResnetBlock, self).__init__()
        self.conv1 = keras.layers.Conv2D(filters, kernel, padding='same', activation='selu')
        self.conv2 = keras.layers.Conv2D(filters, kernel, padding='same')
        self.scaler = None
        self.scaling = tf.constant(scaling)
        if scaling == None or scaling == 0:
            self.scaling = 1
        self.scaler = keras.layers.Lambda(lambda x: x * scaling)
        self.c = keras.layers.Add()

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = tf.cond(self.scaling > 0,lambda: self.scaler(x),lambda: tf.identity(x))
        x = self.c([x, inputs])
        return x
        
class UpSampler(tf.keras.Model):
    def __init__(self, filters=64, kernel=3, scaling=0, *args, **kwargs):
        super(UpSampler, self).__init__(*args, **kwargs)
        self.layers_ = []
        factor = 2
        if scaling > 1:
            if scaling == 3:
                factor = 3
            self.layers_.append(keras.layers.Conv2D(filters * (factor ** 2), kernel, padding='same'))
            self.layers_.append(keras.layers.Lambda(pixel_shuffle(factor)))
            
            if scaling == 4:
                self.layers_.append(keras.layers.Conv2D(filters * (factor ** 2), kernel, padding='same'))
                self.layers_.append(keras.layers.Lambda(pixel_shuffle(factor)))
    def call(self, inputs):
        x = inputs
        for layer in self.layers_:
            x = layer(x)
        return x


class EDSR(tf.keras.Model):
    def __init__(self, filters, res_blocks, kernel_size=3, scaling=2, res_block_scaling=0, rgb_mean=np.array([0.4488, 0.4371, 0.4040]) * 255) -> None:
        super(EDSR, self).__init__()
        self.filters = filters
        self.kernel_size = kernel_size
        self.num_res_blocks = res_blocks
        self.res_block_scaling = res_block_scaling
        self.rgb_mean = rgb_mean
        self.scaling = scaling
        
        # self.input = keras.layers.Input((None, None, 3))
        self.normalizer = keras.layers.Lambda(lambda x: (x - rgb_mean) / 127.5)
        self.conv1 = keras.layers.Conv2D(self.filters, self.kernel_size, padding='same')
        self.res_blocks = []
        for i in range (self.num_res_blocks):
            self.res_blocks.append(ResnetBlock(self.filters, self.kernel_size, self.res_block_scaling))
        self.conv2 = keras.layers.Conv2D(self.filters, self.kernel_size, padding='same')    
        self.add1 = keras.layers.Add()
        
        self.upsampler1 = UpSampler(self.filters, self.kernel_size, self.scaling)
        self.conv3 = keras.layers.Conv2D(3, 3, padding='same')
        
        self.denormalizer = keras.layers.Lambda(lambda x: (x * rgb_mean) + 127.5)
        
    def call(self, inputs):
        # print(inputs.shape)
        x = self.normalizer(inputs)
        # print(x.shape)
        x = b = self.conv1(x)
        # print(x.shape, b.shape)
        for block in self.res_blocks:
            b = block(b)
        b = self.conv2(b)
        x = self.add1([x, b])
        x = self.upsampler1(x)
        x = self.conv3(x)
        x = self.denormalizer(x)
        
        # i = tf.constant(0)
        # cond = lambda i, _: i < self.num_res_blocks
        # body = lambda i, x: (self.res_blocks[i](x), tf.add(i, 1))
        # _, x = tf.while_loop(cond,lambda i, x: tf.identity(x), [i, x], swap_memory=True, parallel_iterations=1)
        
        return x
    

In [6]:
import time 


def psnr(x1, x2):
    return tf.image.psnr(x1, x2, max_val=255)

def resolve(model, lr_batch):
    lr_batch = tf.cast(lr_batch, tf.float32)
    sr_batch = model(lr_batch)
    sr_batch = tf.clip_by_value(sr_batch, 0, 255)
    sr_batch = tf.round(sr_batch)
    sr_batch = tf.cast(sr_batch, tf.uint8)
    return sr_batch

def evaluate(model, dataset):
    psnr_values = []
    for lr, hr in dataset:
        sr = resolve(model, lr)
        psnr_value = psnr(hr, sr)[0]
        psnr_values.append(psnr_value)
    return tf.reduce_mean(psnr_values)

class Trainer:
    def __init__(self, model, loss, learning_rate, checkpoint_dr='./checkpoints/', model_name = "") -> None:
        super().__init__()
        self.now = time.time()
        self.loss = loss
        self.checkpoint = tf.train.Checkpoint
        self.model = model
        
        self.checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                              psnr=tf.Variable(-1.0),
                                              optimizer=keras.optimizers.Adam(learning_rate),
                                              model=model)
        if model_name != "":
            try:
                os.makedirs(checkpoint_dr+model_name)
            except Exception as e:
                print(e)
            checkpoint_dr = os.path.join(checkpoint_dr, model_name)
        self.checkpoint_manager = tf.train.CheckpointManager(checkpoint=self.checkpoint,
                                                             directory=checkpoint_dr,
                                                             max_to_keep=3)

        self.restore()
    # @property
    # def model(self):
    #     return self.checkpoint.model
    
    def train(self, train_ds, valid_ds, epochs, steps, eval_every=1000, save_best_only=False):
        loss_metric = keras.metrics.Mean()
        self.now = time.perf_counter()
        for lr, hr in train_ds.take(steps - self.checkpoint.step.numpy()):
            # print(lr.shape, hr.shape, steps - self.checkpoint.step.numpy())
            self.checkpoint.step.assign_add(1)
            step = self.checkpoint.step.numpy()
            loss = self.train_step(lr, hr)
            loss_metric(loss)
            # print(step + 1, eval_every, step + 1 % eval_every)
            if (step + 1) % eval_every == 0:
                print('validation:: ')
                loss_val = loss_metric.result()
                loss_metric.reset_states()
                
                psnr_val = self.evaluate(valid_ds.take(1))
                
                duration = time.perf_counter() - self.now
                print(f'{step}/{steps}: loss = {loss_val.numpy():.3f}, PSNR = {psnr_val.numpy():3f} ({duration:.2f}s)')

                if save_best_only and psnr_val <= self.checkpoint.psnr:
                    self.now = time.perf_counter()
                    continue

                self.checkpoint.psnr = psnr_val
                self.checkpoint_manager.save()

                self.now = time.perf_counter()
        
        
    
    def train_step(self, lr, hr):
        with tf.GradientTape() as tape:
            lr = tf.cast(lr, tf.float32)
            hr = tf.cast(hr, tf.float32)
            
            sr = self.checkpoint.model(lr, training=True)
            loss = self.loss(hr, sr)
        grads = tape.gradient(loss, self.checkpoint.model.trainable_variables)
        self.checkpoint.optimizer.apply_gradients(zip(grads, self.checkpoint.model.trainable_variables))
        return loss
        
    def evaluate(self, dataset):
        return evaluate(self.checkpoint.model, dataset)
    
    def restore(self):
        if self.checkpoint_manager.latest_checkpoint:
            self.checkpoint.restore(self.checkpoint_manager.latest_checkpoint)
            print(f'Model restored from checkpoint at step {self.checkpoint.step.numpy()}.')
            

In [7]:
import os
import tensorflow as tf
import pandas as pd

from tensorflow.python.data.experimental import AUTOTUNE


class Dataset:
    def __init__(self, csv='train.csv', scale=2, subset='train', downgrade='bicubic', dr='', batch_size=8, split=.8, epochs=50, crop_size=96) -> None:
        self.batch_size = batch_size
        self.crop_size = crop_size
        self.epochs = epochs
        self.downgrades = ['bicubic', 'unknown', 'mild', 'difficult']
        self.df = pd.read_csv(csv)
        self._scales = [2,3,4,8]
        if scale in self._scales:
            self.scale=scale
        else:
            raise ValueError(f'scale must be in ${self.scales}')
                
        
        len_df = len(self.df)
        if subset == 'train':
            self.image_ids = range(1, int(len_df * split))
        elif subset == 'valid':
            self.image_ids = range(int(len_df * split), len_df)
        else:
            raise ValueError("subset must be train or valid")
        
        self.subset = subset
        self.downgrade = downgrade
        self.images_dr = dr
        
        
        
    def __len__(self):
        return len(self.image_ids)
    
    def dataset(self, random_transform):
        ds = tf.data.Dataset.zip((self.create_dataset(self.df.path.to_list(), self.get_path(True), True), self.create_dataset(self.df.path.to_list(), self.get_path(False), False)))
        if random_transform == True:
            ds = ds.map(lambda lr, hr: random_crop(lr, hr, self.crop_size, self.scale), num_parallel_calls=AUTOTUNE)
            ds = ds.map(random_rotate, num_parallel_calls=AUTOTUNE)
            ds = ds.map(random_flip, num_parallel_calls=AUTOTUNE)
            ds = ds.map(random_brightness, num_parallel_calls=AUTOTUNE)
        ds = ds.shuffle(self.batch_size, reshuffle_each_iteration=True)

        ds = ds.batch(self.batch_size)
        ds = ds.repeat(self.epochs)
        ds = ds.prefetch(buffer_size=AUTOTUNE)
        return ds
        
    def get_low_res(self, file):
            splt = file.split('/')
            if 'HR' in splt[-2]:
                splt[-2] = splt[-2][:-2]
            splt[-2] += f'LR_{self.subset}_{self.scale}x'
            return '/'.join(splt)
    
    def create_dataset(self, files, cache, lr=False):
        if lr == True:
            ##
            # self.prepare_low_res()
            ##
            pass
            
            files = [(lambda x: self.get_low_res(x))(x) for x in files]
        ds = tf.data.Dataset.from_tensor_slices(files).map(tf.io.read_file)
        ds = ds.map(lambda x: tf.image.decode_png(x, 3), num_parallel_calls=AUTOTUNE)
        ds.cache(filename=tf.convert_to_tensor(os.getcwd()+'\\'+self.subset+'_'+str(self.scale)+'.cache', dtype=tf.string))
        for _ in ds:
            pass
        return ds
    def get_path(self, is_lr=False):
        r = 'LR' if is_lr == True else 'HR'
        return f'{r}_{self.subset}_{self.scale}x'
    def prepare_low_res(self): 
        
        fls = self.df.path.to_list()
        
        for file in fls:
            # print(file)
            img = tf.convert_to_tensor(cv2.imread(file), dtype=tf.uint8)
            img = tf.image.resize(img, (img.shape[0]//self.scale + (1 if img.shape[0] % 2 != 0 else 0), img.shape[1]//self.scale + (1 if img.shape[1] % 2 != 0 else 0)), tf.image.ResizeMethod.BICUBIC).numpy()
            #similar to get lr func
            splt = file.split('/')
            if 'HR' in splt[-2]:
                splt[-2] = splt[-2][:-2]
            splt[-2] += f'LR_{self.subset}_{self.scale}x'
            spl = splt[:-1]
            if not os.path.exists('/'.join(spl)):
                os.mkdir('/'.join(spl))
            fl = '/'.join(splt)
            
            # print(fl)
            cv2.imwrite(fl, img)
    # def 
        
    
def random_flip(lr_img, hr_img):
    rn = tf.random.uniform(shape=(), maxval=1)
    return tf.cond(rn < .5, 
                   lambda: (lr_img, hr_img), 
                   lambda: (tf.image.flip_left_right(lr_img), tf.image.flip_left_right(hr_img))
                   )
    
def random_brightness(lr_img, hr_img):
    rn = tf.random.uniform(shape=(), maxval=1)
    shift = tf.random.uniform(shape=(), minval=-.2, maxval=.2)
    
    return tf.cond(rn < .5, 
                   lambda: (lr_img, hr_img), 
                   lambda: (tf.image.adjust_brightness(lr_img, shift), tf.image.adjust_brightness(hr_img, shift))
                   )
    

def random_rotate(lr_img, hr_img):
    rn = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
    return tf.image.rot90(lr_img, rn), tf.image.rot90(hr_img, rn)

def random_crop(lr_img, hr_img, hr_crop_size=96, scale=2):
    lr_crop_size = hr_crop_size // scale
    lr_img_shp = tf.shape(lr_img)[:2]
    
    lr_w = tf.random.uniform(shape=(), maxval=lr_img_shp[1] - lr_crop_size + 1, dtype=tf.int32)
    lr_h = tf.random.uniform(shape=(), maxval=lr_img_shp[0] - lr_crop_size + 1, dtype=tf.int32)
    if lr_w > 3:
        lr_w -=3
    if lr_h > 3:
        lr_h -= 3
    hr_w = lr_w * scale
    hr_h = lr_h * scale
    
    # print(lr_h, lr_h + lr_crop_size, lr_w, lr_w+ lr_crop_size)
    # print(hr_h,hr_h + hr_crop_size, hr_w, hr_w+ hr_crop_size)
    lr_img_cropped = lr_img[lr_h:lr_h + lr_crop_size, lr_w: lr_w+ lr_crop_size]
    hr_img_cropped = hr_img[hr_h:hr_h + hr_crop_size, hr_w: hr_w+ hr_crop_size]
    
    return lr_img_cropped, hr_img_cropped    

In [8]:
tr_ds = Dataset(subset='train', scale=4, epochs=-1, crop_size=96*2)
train_ds = tr_ds.dataset(True)
val_ds = Dataset(subset='valid', scale=4, batch_size=1, crop_size=256, epochs=-1)
valid_ds = val_ds.dataset(True)

In [9]:
# os.makedirs('checkpoints4x')
K = keras.backend
def root_mean_squared_error(y_true, y_pred):
        return K.sqrt(K.mean(K.square(y_pred - y_true))) 
    
model = EDSR(128, 24, scaling=4)
edsr = Trainer(model, root_mean_squared_error, learning_rate=keras.optimizers.schedules.PiecewiseConstantDecay(boundaries=[200000], values=[1e-3, 1e-5]), checkpoint_dr='./checkpoints4x/', model_name="res24_rmse")
edsr.train(train_ds, valid_ds.take(100), 40, 1000000, 5000, True) 

[WinError 183] Cannot create a file when that file already exists: './checkpoints4x/res24_rmse'
Model restored from checkpoint at step 4999.
validation:: 
9999/1000000: loss = 13.751, PSNR = 28.835882 (982.17s)
validation:: 
14999/1000000: loss = 17.294, PSNR = 33.131020 (968.08s)
validation:: 
19999/1000000: loss = 13.459, PSNR = 27.726410 (935.30s)
validation:: 
24999/1000000: loss = 18.242, PSNR = 28.966898 (982.68s)
validation:: 
29999/1000000: loss = 16.793, PSNR = 24.669607 (981.53s)
validation:: 
34999/1000000: loss = 13.420, PSNR = 22.958891 (934.22s)
validation:: 
39999/1000000: loss = 15.807, PSNR = 29.061581 (947.92s)
validation:: 
44999/1000000: loss = 13.427, PSNR = 31.028515 (981.09s)
validation:: 
49999/1000000: loss = 16.006, PSNR = 28.281641 (954.91s)
validation:: 
54999/1000000: loss = 17.365, PSNR = 28.995903 (954.54s)
validation:: 
59999/1000000: loss = 13.332, PSNR = 29.909019 (923.90s)


KeyboardInterrupt: 

In [10]:
edsr.model.save_weights('weights24_4', save_format='tf')

In [11]:
# model = EDSR(128, 24, scaling=4)
# model.load_weights('weights24_4')



<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x19f170e86a0>

In [16]:
# import cv2
# import matplotlib.pyplot as plt

full = '../../Data/DIV2K_train_HR/DIV2K_train_HR/0180.png'
pt = '../../Data/DIV2K_train_HR/DIV2K_train_LR_train_4x/0180.png'
lr = cv2.imread(pt)
fimg = cv2.imread(full)
with tf.device('/cpu:0'):
    sr =  resolve(model, tf.expand_dims(lr, axis=0))[0]
plt.figure(figsize=(20, 10))
images = [lr, sr, fimg]
titles = ['LR', f'SR (x{sr.shape[0] // lr.shape[0]})']
for i, (img, title) in enumerate(zip(images, titles)):
    plt.subplot(1, 2, i+1)
    plt.imshow(img)
    plt.title(title)
    plt.xticks([])
    plt.yticks([])
cv2.imwrite('lr.png', lr)
cv2.imwrite('sr2.png', sr.numpy())
cv2.imwrite('fr.png', fimg)
cv2.imwrite('up.png', cv2.resize(lr, (fimg.shape[1], fimg.shape[0])))


True

In [16]:
# cv2.imshow("a", lr)
# cv2.waitKey(0)
# cv2.destroyAllWindows()

In [20]:
cv2.imwrite('lr.png', lr)
cv2.imwrite('sr.png', sr.numpy())
cv2.imwrite('fr.png', fimg)

True

In [19]:
sr.numpy()

array([[[82, 71, 68],
        [83, 69, 68],
        [84, 70, 68],
        ...,
        [85, 70, 67],
        [85, 72, 68],
        [84, 72, 69]],

       [[83, 69, 67],
        [83, 70, 68],
        [84, 69, 68],
        ...,
        [84, 71, 68],
        [84, 72, 69],
        [85, 72, 68]],

       [[83, 68, 67],
        [83, 68, 68],
        [83, 68, 68],
        ...,
        [85, 72, 68],
        [85, 72, 69],
        [86, 74, 69]],

       ...,

       [[37, 37, 57],
        [37, 39, 59],
        [37, 38, 60],
        ...,
        [29, 12, 55],
        [30, 14, 54],
        [31, 15, 54]],

       [[37, 39, 57],
        [38, 40, 59],
        [39, 41, 61],
        ...,
        [30, 15, 55],
        [31, 17, 54],
        [31, 19, 53]],

       [[39, 40, 57],
        [38, 40, 59],
        [39, 40, 61],
        ...,
        [30, 14, 55],
        [30, 17, 54],
        [34, 22, 52]]], dtype=uint8)