In [7]:
import os
import matplotlib.pyplot as plt
from dataset import dataset
from wdsr import wdsr
from train import WdsrTrainer
from common import resolve
from common import resolve_single
from common import evaluate

In [8]:
# Run hyper-parameters. Later cells read these names when building the model,
# naming the weights file and configuring the trainer.
scale = 4        # scale factor passed to wdsr(scale=...); kept at 4 in this notebook
depth = 16       # number of residual blocks, passed as wdsr(num_res_blocks=...)
filters = 128    # filter count, passed as wdsr(num_filters=...)
loss = 'MAE'     # loss identifier handed to WdsrTrainer(loss=...)
norm = 'nn'      # normalization option for wdsr(norm=...); meaning of 'nn' lives in wdsr — TODO confirm

In [9]:
# Destination for the exported weights. The file name encodes the run's
# hyper-parameters; with the defaults it is weightsB16F128-x4-MAEnn.h5.
weights_dir = './weights/wdsr/'
os.makedirs(weights_dir, exist_ok=True)
weights_file = os.path.join(weights_dir, f'weightsB{depth}F{filters}-x{scale}-{loss}{norm}.h5')

In [10]:
# Training pipeline: batches of 16 with random_transform enabled
# (repeat_count left at the dataset() default).
train_ds = dataset(subset='train', batch_size=16, random_transform=True)

# Validation pipeline: one image per batch, a single pass (repeat_count=1),
# and random_transform disabled so evaluation is deterministic.
valid_ds = dataset(
    subset='valid',
    batch_size=1,
    random_transform=False,
    repeat_count=1,
)

In [11]:
# Build the WDSR network from the hyper-parameters in the config cell.
# res_block_scaling is passed explicitly as None; how wdsr() interprets None
# is defined in the wdsr module — TODO confirm.
model = wdsr(scale=scale, num_filters=filters, num_res_blocks=depth,
             res_block_scaling=None, norm=norm)

# Optional warm start. This replaces a commented-out load_weights line with an
# explicit switch: set to a weights path (e.g. 'weights/wdsr/weights-pretuned.h5')
# to initialise from saved weights. Leave as None to train from scratch.
pretrained_weights = None
if pretrained_weights is not None:
    model.load_weights(pretrained_weights)

In [12]:
# Initialize the trainer.
# The checkpoint directory is keyed on the same hyper-parameters as
# weights_file (depth, filters, scale, loss, norm). Runs that differ only in
# filter count, loss or norm therefore no longer share, and restore, one
# another's checkpoints.
# Note: checkpoints written under the previous '.ckpt/wdsr-{depth}-x{scale}'
# name are not picked up by this directory.
checkpoint_dir = f'.ckpt/wdsr-B{depth}F{filters}-x{scale}-{loss}{norm}'
trainer = WdsrTrainer(model=model, loss=loss, checkpoint_dir=checkpoint_dir)

In [13]:
# Train for a fixed number of steps, evaluating every `evaluate_every` steps
# on the first `n_eval_batches` validation batches. valid_ds uses
# batch_size=1, so this is 10 images. With save_best_only=True the trainer
# presumably checkpoints only when the validation metric improves — see
# WdsrTrainer.train to confirm.
train_steps = 30000
evaluate_every = 500
n_eval_batches = 10

trainer.train(train_ds, valid_ds.take(n_eval_batches), steps=train_steps,
              evaluate_every=evaluate_every, save_best_only=True)

500/30000: loss = 17.615, PSNR = 22.012589
1000/30000: loss = 10.414, PSNR = 22.629023
1500/30000: loss = 9.749, PSNR = 22.942406
2000/30000: loss = 9.320, PSNR = 23.225838
2500/30000: loss = 8.986, PSNR = 23.427240
3000/30000: loss = 8.656, PSNR = 23.596500
3500/30000: loss = 8.380, PSNR = 23.956503
4000/30000: loss = 8.078, PSNR = 24.137165
4500/30000: loss = 7.807, PSNR = 24.584650
5000/30000: loss = 7.530, PSNR = 24.831776
5500/30000: loss = 7.208, PSNR = 25.046782
6000/30000: loss = 6.981, PSNR = 25.393698
6500/30000: loss = 6.706, PSNR = 25.591116
7000/30000: loss = 6.445, PSNR = 25.966806
7500/30000: loss = 6.157, PSNR = 26.463383
8000/30000: loss = 5.898, PSNR = 26.925694
8500/30000: loss = 5.692, PSNR = 27.109777
9000/30000: loss = 5.506, PSNR = 27.402899
9500/30000: loss = 5.352, PSNR = 27.513275
10000/30000: loss = 5.198, PSNR = 28.032263
10500/30000: loss = 5.047, PSNR = 28.243313
11000/30000: loss = 4.908, PSNR = 28.491262
11500/30000: loss = 4.831, PSNR = 28.522650
12000/

In [14]:
# restore the weights from the best checkpoint
# Loads the checkpoint saved in the trainer's checkpoint_dir back into the
# model. Because training ran with save_best_only=True, this is presumably
# the best-scoring step rather than the last one — TODO confirm in WdsrTrainer.

trainer.restore()

Model restored from checkpoint at step 25500.


In [15]:
# Validate the restored model on the whole validation set. valid_ds is a
# single pass (repeat_count=1), not the 10-batch subset used during training.
# How PSNR is aggregated over images is defined by WdsrTrainer.evaluate.
psnrv = trainer.evaluate(valid_ds)
# '.3f' means 3 decimal places. The previous '3f' meant a minimum width of 3
# at the default precision of 6.
print(f'PSNR = {psnrv.numpy():.3f}')

PSNR = 30.212297


In [None]:
# save weights
# Exports the trainer's model weights to weights_file (its name encodes the
# run's hyper-parameters). Run this after the restore cell so the exported
# weights are the restored checkpoint rather than the final training step.

trainer.model.save_weights(weights_file)

In [6]:
# Layer-by-layer overview of the network. Keras' Model.summary() prints the
# table itself and returns None, so it is called directly. Wrapping it in
# print() appends a stray "None" line. Requires the model-building cell above.
model.summary()

Model: "wdsr"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, None, None,  0                                            
__________________________________________________________________________________________________
lambda (Lambda)                 (None, None, None, 3 0           input_1[0][0]                    
__________________________________________________________________________________________________
weight_normalization (WeightNor (None, None, None, 3 929         lambda[0][0]                     
__________________________________________________________________________________________________
weight_normalization_1 (WeightN (None, None, None, 1 6529        weight_normalization[0][0]       
_______________________________________________________________________________________________