In [1]:
import os
import matplotlib.pyplot as plt
from dataset import dataset
from model.edsr import edsr
from train import EdsrTrainer
from numba import cuda

In [2]:
# Hyperparameters shared by every cell below.
depth = 16  # passed to edsr() as num_res_blocks
scale = 4   # passed to edsr() as scale; the "x4" in the weight/checkpoint paths

# Keep a handle on the current CUDA device (via numba) so the last cell can
# release GPU memory with device.reset().
device = cuda.get_current_device()

In [3]:
# Directory for this run's weights. It is derived from depth/scale so it stays
# consistent with the hyperparameters above and with the checkpoint path used
# by the trainer (same "edsr-{depth}-x{scale}" naming).
weights_dir = f'./weights/edsr-{depth}-x{scale}/'
# Fine-tuned weights go to a separate file so the pretrained weights in the
# same directory are not overwritten.
weights_file = os.path.join(weights_dir, 'weightsM.h5')
os.makedirs(weights_dir, exist_ok=True)

In [4]:
# Training data: batches of 16 with random transforms enabled.
train_ds = dataset(subset='train', batch_size=16, random_transform=True)
# Validation data: one sample per batch, no random transforms.
valid_ds = dataset(subset='valid', batch_size=1, random_transform=False)

In [5]:
# Build the EDSR model and initialise it from the weights of the model trained
# by krasserm before fine-tuning. The pretrained file lives in weights_dir, so
# its path is derived from it instead of being hard-coded a second time.
model = edsr(scale=scale, num_res_blocks=depth)
pretrained_file = os.path.join(weights_dir, 'weights.h5')
model.load_weights(pretrained_file)

In [7]:
# Wrap the model in the trainer; its checkpoints are kept under .ckpt/,
# in a sub-directory named after the depth/scale configuration.
trainer = EdsrTrainer(
    model=model,
    checkpoint_dir=f'.ckpt/edsr-{depth}-x{scale}',
)

In [8]:
# Fine-tune the model.
TRAIN_STEPS = 3000   # total optimisation steps
EVAL_EVERY = 100     # run a validation pass every EVAL_EVERY steps
N_VALID_EVAL = 10    # validation samples used by each periodic evaluation

# Only the first N_VALID_EVAL validation samples are used during training to
# keep each evaluation cheap. save_best_only=True presumably keeps only the
# best-scoring checkpoint (see EdsrTrainer.train) — confirm.
trainer.train(
    train_ds,
    valid_ds.take(N_VALID_EVAL),
    steps=TRAIN_STEPS,
    evaluate_every=EVAL_EVERY,
    save_best_only=True,
)

100/3000: loss = 8.318, PSNR = 23.869587 (14.96s)
200/3000: loss = 7.922, PSNR = 24.067587 (11.15s)
300/3000: loss = 7.756, PSNR = 24.188211 (11.42s)
400/3000: loss = 7.650, PSNR = 24.277884 (11.12s)
500/3000: loss = 7.575, PSNR = 24.371241 (11.14s)
600/3000: loss = 7.479, PSNR = 24.462727 (11.14s)
700/3000: loss = 7.394, PSNR = 24.527266 (11.18s)
800/3000: loss = 7.335, PSNR = 24.608912 (11.20s)
900/3000: loss = 7.274, PSNR = 24.656963 (11.24s)
1000/3000: loss = 7.221, PSNR = 24.728397 (11.24s)
1100/3000: loss = 7.148, PSNR = 24.815470 (11.24s)
1200/3000: loss = 7.105, PSNR = 24.832607 (11.19s)
1300/3000: loss = 7.080, PSNR = 24.896782 (11.26s)
1400/3000: loss = 7.042, PSNR = 24.945065 (11.18s)
1500/3000: loss = 6.999, PSNR = 24.996777 (11.23s)
1600/3000: loss = 6.969, PSNR = 25.055237 (11.21s)
1700/3000: loss = 6.926, PSNR = 25.083309 (11.32s)
1800/3000: loss = 6.878, PSNR = 25.121912 (11.30s)
1900/3000: loss = 6.852, PSNR = 25.164642 (11.23s)
2000/3000: loss = 6.807, PSNR = 25.18922

In [9]:
# Restore the model weights from the best checkpoint saved during training;
# the trainer prints which step it restored.
# NOTE(review): the recorded output reports step 2800, but the training log of
# the previous cell stops at step 2000 — that log looks truncated; re-run to confirm.

trainer.restore()

Model restored from checkpoint at step 2800.


In [10]:
# Evaluate the (restored) model on the whole validation set, not just the
# subset used during training, and report the PSNR.
psnrv = trainer.evaluate(valid_ds)
# `.3f` -> three decimal places (`3f` would mean "minimum width 3", 6 decimals).
print(f'PSNR = {psnrv.numpy():.3f}')

KeyboardInterrupt: 

In [11]:
# Save the weights currently held by the trainer's model (the restored best
# checkpoint, if the restore cell was run) to weights_file. That file is
# distinct from the pretrained weights.h5, so those are not overwritten.

trainer.model.save_weights(weights_file)

In [12]:
# Release GPU memory by resetting the CUDA device obtained via numba at the top
# of the notebook.
# NOTE(review): resetting deletes the device's CUDA context and its
# allocations; GPU work in this kernel will likely fail afterwards, so restart
# the kernel before training/evaluating again.

device.reset()