In [1]:
import os
import matplotlib.pyplot as plt
from dataset import dataset
from edsr import edsr
from train import EdsrTrainer
from common import resolve
from common import resolve_single
from common import evaluate

In [7]:
# Model hyperparameters: number of residual blocks (depth), scale factor,
# number of convolution filters, and the name of the training loss.
#
# Note: the authors settled on 32 res blocks and 256 filters, with a scaling
# multiple of 0.1 for the residual blocks; the values below are smaller.
depth, scale, filters = 16, 4, 64
loss = 'MAE'

In [8]:
# Location for the saved model weights. The directory is created up front
# (no error if it already exists); the file name encodes depth, filter
# count and loss name.
weights_dir = './weights/edsr/'
os.makedirs(weights_dir, exist_ok=True)

weights_name = 'weightsB{}F{}-{}.h5'.format(depth, filters, loss)
weights_file = os.path.join(weights_dir, weights_name)

In [9]:
# Training dataset: 'train' subset, batch size 10, random transforms enabled.
train_ds = dataset(
    subset='train',
    batch_size=10,
    random_transform=True,
)

# Validation dataset: 'valid' subset, batch size 1, repeat_count=1,
# random transforms disabled.
valid_ds = dataset(
    subset='valid',
    batch_size=1,
    repeat_count=1,
    random_transform=False,
)

In [10]:
# Build the EDSR model from the hyperparameters set earlier in the notebook.
# No weights are loaded here. To start from the weights trained by Krasserm,
# call model.load_weights('weights/edsr/weights.h5') after construction.
model = edsr(
    scale=scale,
    num_filters=filters,
    num_res_blocks=depth,
    res_block_scaling=0.1,
)

In [None]:
# Create the trainer. The checkpoint directory name encodes depth and scale.
checkpoint_dir = f'.ckpt/edsr-{depth}-x{scale}'
trainer = EdsrTrainer(model=model, loss=loss, checkpoint_dir=checkpoint_dir)

In [None]:
# Run training: 45000 steps, evaluating every 500 steps. Only the first 10
# elements of the validation dataset (valid_ds.take(10)) are passed in here.
trainer.train(
    train_ds,
    valid_ds.take(10),
    steps=45000,
    evaluate_every=500,
    save_best_only=True,
)

In [None]:
# Restore the weights from the best checkpoint.
# NOTE(review): what restore() actually reloads is decided by EdsrTrainer
# (imported from the train module, not shown here); presumably the checkpoint
# kept by save_best_only=True during training -- confirm.

trainer.restore()

In [None]:
# Validate results: evaluate the trainer on the validation dataset and report
# the PSNR.

psnrv = trainer.evaluate(valid_ds)
# Use '.3f' (three decimal places). The previous '3f' only set a minimum
# field width of 3 and still printed the default six decimals.
print(f'PSNR = {psnrv.numpy():.3f}')

In [None]:
# Save the trainer's model weights to weights_file (the path built in the
# weights-directory cell earlier in the notebook).

trainer.model.save_weights(weights_file)

In [11]:
# Show the layer-by-layer model summary. summary() prints the table itself and
# returns None, so wrapping it in print() would append a stray "None" line.
model.summary()

Model: "edsr"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, None, None,  0                                            
__________________________________________________________________________________________________
lambda_36 (Lambda)              (None, None, None, 3 0           input_2[0][0]                    
__________________________________________________________________________________________________
conv2d_67 (Conv2D)              (None, None, None, 6 1792        lambda_36[0][0]                  
__________________________________________________________________________________________________
conv2d_68 (Conv2D)              (None, None, None, 6 36928       conv2d_67[0][0]                  
_______________________________________________________________________________________________