In [1]:
from data import drone_data

In [2]:
# Build the training pipeline from the 'train' split, batched 6 at a time.
# `train_ds` is reused by every trainer and SRGAN fine-tuning cell below.
train_loader = drone_data(subset='train')
train_ds = train_loader.dataset(batch_size=6)

In [3]:
# Build the validation pipeline with batch_size=1, so `valid_ds.take(n)` in the
# cells below selects n single-sample batches (assumes `.dataset()` returns a
# tf.data.Dataset — TODO confirm in data.py).
valid_loader = drone_data(subset='valid')
valid_ds = valid_loader.dataset(batch_size=1)

# EDSR Model

In [4]:
import tensorflow as tf
from models.edsr import edsr
from train import EdsrTrainer

# EDSR baseline: x4 upscaling with 16 residual blocks. The checkpoint directory
# name is derived from these values so it cannot drift out of sync with the
# architecture actually being trained (resolves to '.ckpt/edsr-16-x4').
EDSR_SCALE = 4
EDSR_NUM_RES_BLOCKS = 16

trainer = EdsrTrainer(model=edsr(scale=EDSR_SCALE, num_res_blocks=EDSR_NUM_RES_BLOCKS),
                      checkpoint_dir=f'.ckpt/edsr-{EDSR_NUM_RES_BLOCKS}-x{EDSR_SCALE}')

# Train for 2000 steps and evaluate on 40 validation batches every 20 steps.
# save_best_only=True presumably keeps only checkpoints that improve the
# validation metric — confirm in train.py.
trainer.train(train_ds,
              valid_ds.take(40),
              steps=2000,
              evaluate_every=20,
              save_best_only=True)

# Reload the saved checkpoint into trainer.model before evaluation.
trainer.restore()

Model restored from checkpoint at step 2000.
Model restored from checkpoint at step 2000.


In [5]:
import os

# Evaluate the restored EDSR model on up to 2200 validation batches.
psnr = trainer.evaluate(valid_ds.take(2200))
# `.3f` prints 3 decimal places; the former `:3f` only set a minimum field
# width of 3 and fell back to the default 6-digit precision.
print(f'PSNR = {psnr.numpy():.3f}')

# Save weights outside the checkpoint directory; the SRGAN fine-tuning cell
# loads 'weights/edsr-16-x4/weights.h5' as its pre-trained generator.
os.makedirs('weights/edsr-16-x4', exist_ok=True)
trainer.model.save_weights('weights/edsr-16-x4/weights.h5')

PSNR = 31.162270


# WDSR Model

In [6]:
from models.wdsr import wdsr_b
from train import WdsrTrainer

# WDSR-B: x4 upscaling with 32 residual blocks. The checkpoint directory was
# previously hard-coded as '.ckpt/wdsr-b-8-x4' despite the model having 32
# blocks (and its weights being saved under 'weights/wdsr-b-32-x4'), which
# risks restoring an incompatible checkpoint. Derive it from the constants so
# it resolves to '.ckpt/wdsr-b-32-x4'.
# NOTE: to resume from an existing run, rename '.ckpt/wdsr-b-8-x4' accordingly.
WDSR_SCALE = 4
WDSR_NUM_RES_BLOCKS = 32

trainer = WdsrTrainer(model=wdsr_b(scale=WDSR_SCALE, num_res_blocks=WDSR_NUM_RES_BLOCKS),
                      checkpoint_dir=f'.ckpt/wdsr-b-{WDSR_NUM_RES_BLOCKS}-x{WDSR_SCALE}')

# Train for 3300 steps and evaluate on 40 validation batches every 40 steps.
trainer.train(train_ds,
              valid_ds.take(40),
              steps=3300,
              evaluate_every=40,
              save_best_only=True)

# Reload the saved checkpoint into trainer.model before evaluation.
trainer.restore()

Model restored from checkpoint at step 3280.
Model restored from checkpoint at step 3280.


In [7]:
# Evaluate the restored WDSR-B model on up to 2200 validation batches.
psnr = trainer.evaluate(valid_ds.take(2200))
# `.3f` prints 3 decimal places (`:3f` only set a minimum width of 3).
print(f'PSNR = {psnr.numpy():.3f}')

PSNR = 31.551722


In [8]:
import os
# Save weights outside the checkpoint directory; the SRGAN fine-tuning cell
# loads 'weights/wdsr-b-32-x4/weights.h5' as its pre-trained generator.
# The directory must exist before save_weights writes into it.
os.makedirs('weights/wdsr-b-32-x4', exist_ok=True)
trainer.model.save_weights('weights/wdsr-b-32-x4/weights.h5')

# SRGAN Fine-Tuning of the EDSR Model

In [9]:
# Only `discriminator` is needed from models.srgan: the name `generator` is
# bound to the EDSR model below, so importing the SRGAN `generator` factory
# here was immediately shadowed and never used.
from models.srgan import discriminator
from models.edsr import edsr
from train import SrganTrainer

# Create the EDSR generator and initialise it with the pre-trained weights
# saved by the EDSR training section.
generator = edsr(scale=4, num_res_blocks=16)
generator.load_weights('weights/edsr-16-x4/weights.h5')

# Fine-tune the EDSR model via SRGAN (adversarial) training for 300 steps.
gan_trainer = SrganTrainer(generator=generator, discriminator=discriminator())
gan_trainer.train(train_ds, steps=300)

50/300, perceptual loss = 0.0148, discriminator loss = 2.7900
100/300, perceptual loss = 0.0078, discriminator loss = 1.9429
150/300, perceptual loss = 0.0085, discriminator loss = 3.0903
200/300, perceptual loss = 0.0106, discriminator loss = 2.0236
250/300, perceptual loss = 0.0097, discriminator loss = 1.9306
300/300, perceptual loss = 0.0086, discriminator loss = 2.4216


In [10]:
# Persist the SRGAN-fine-tuned EDSR generator alongside its pre-trained weights.
generator.save_weights('weights/edsr-16-x4/finetuned_weights.h5')

# SRGAN Fine-Tuning of the WDSR Model

In [4]:
# Only `discriminator` is needed from models.srgan: the name `generator` is
# bound to the WDSR-B model below, so importing the SRGAN `generator` factory
# here was immediately shadowed and never used.
from models.srgan import discriminator
from models.wdsr import wdsr_b
from train import SrganTrainer

# Create the WDSR-B generator and initialise it with the pre-trained weights
# saved by the WDSR training section.
generator = wdsr_b(scale=4, num_res_blocks=32)
generator.load_weights('weights/wdsr-b-32-x4/weights.h5')

# Fine-tune the WDSR-B model via SRGAN (adversarial) training for 300 steps.
gan_trainer = SrganTrainer(generator=generator, discriminator=discriminator())
gan_trainer.train(train_ds, steps=300)

50/300, perceptual loss = 0.0119, discriminator loss = 3.2831
100/300, perceptual loss = 0.0065, discriminator loss = 1.7058
150/300, perceptual loss = 0.0075, discriminator loss = 1.8289
200/300, perceptual loss = 0.0124, discriminator loss = 2.3117
250/300, perceptual loss = 0.0113, discriminator loss = 1.8097
300/300, perceptual loss = 0.0097, discriminator loss = 2.7430


In [5]:
# Persist the SRGAN-fine-tuned WDSR-B generator alongside its pre-trained weights.
generator.save_weights('weights/wdsr-b-32-x4/finetuned_weights.h5')