In [1]:
from data import drone_data

In [2]:
# Training data: build a drone_data loader for the 'train' subset and ask it
# for a dataset with batch_size=6. `train_ds` is what every training cell
# below passes to its trainer.
train_loader = drone_data(subset='train')
train_ds = train_loader.dataset(batch_size=6)

In [3]:
# Validation data: same loader type for the 'valid' subset, with batch_size=1.
# Later cells never use `valid_ds` whole; they slice it with `.take(n)`.
valid_loader = drone_data(subset='valid')
valid_ds = valid_loader.dataset(batch_size=1)

# EDSR Model

In [4]:
import tensorflow as tf
from models.edsr import edsr
from train import EdsrTrainer

# EDSR baseline: x4 upscaling with 16 residual blocks. Checkpoints are kept
# under ../.ckpt/edsr-16-x4.
edsr_model = edsr(scale=4, num_res_blocks=16)
trainer = EdsrTrainer(model=edsr_model, checkpoint_dir='../.ckpt/edsr-16-x4')

# Train for 2000 steps, evaluating every 20 steps on valid_ds.take(40) and
# saving with save_best_only=True.
edsr_valid_subset = valid_ds.take(40)
trainer.train(
    train_ds,
    edsr_valid_subset,
    steps=2000,
    evaluate_every=20,
    save_best_only=True,
)

# Restore from the checkpoint directory before the evaluation cell runs.
trainer.restore()

Model restored from checkpoint at step 2000.
Model restored from checkpoint at step 2000.


In [5]:
import os
# Evaluate the current trainer on valid_ds.take(2200) and report the PSNR.
psnr = trainer.evaluate(valid_ds.take(2200))
print(f'PSNR = {psnr.numpy():3f}')

# Export the model weights to their own folder, away from the training
# checkpoints, creating the folder if it does not exist yet.
edsr_weights_dir = '../weights/edsr-16-x4'
os.makedirs(edsr_weights_dir, exist_ok=True)
trainer.model.save_weights(f'{edsr_weights_dir}/weights.h5')

PSNR = 31.162270


# WDSR Model

In [6]:
from models.wdsr import wdsr_b
from train import WdsrTrainer

# WDSR-B baseline. The depth and scale are named once so that the checkpoint
# directory always matches the architecture being trained.
#
# NOTE: this cell previously used '../.ckpt/wdsr-b-8-x4' for a 32-block model,
# while its weights are saved under 'wdsr-b-32-x4'. Checkpoints already written
# to the old directory must be moved to '../.ckpt/wdsr-b-32-x4' to be picked up.
WDSR_DEPTH = 32  # number of residual blocks
WDSR_SCALE = 4   # super-resolution factor

trainer = WdsrTrainer(model=wdsr_b(scale=WDSR_SCALE, num_res_blocks=WDSR_DEPTH),
                      checkpoint_dir=f'../.ckpt/wdsr-b-{WDSR_DEPTH}-x{WDSR_SCALE}')

# Train for 3300 steps, evaluating every 40 steps on valid_ds.take(40) and
# saving with save_best_only=True.
trainer.train(train_ds,
              valid_ds.take(40),
              steps=3300,
              evaluate_every=40,
              save_best_only=True)

# Restore from the checkpoint directory before the evaluation cell runs.
trainer.restore()

Model restored from checkpoint at step 3280.
Model restored from checkpoint at step 3280.


In [7]:
# Evaluate the current trainer on valid_ds.take(2200) and print the PSNR.
# The format spec `3f` means minimum width 3 with the default six decimals,
# which is what the recorded output shows.
# NOTE(review): `.3f` (three decimals) may have been intended; confirm.
psnr = trainer.evaluate(valid_ds.take(2200))
print(f'PSNR = {psnr.numpy():3f}')

PSNR = 31.551722


In [8]:
import os
# Save the model weights to a separate location (../weights/...) rather than
# the checkpoint directory, creating the target folder if it is missing.
os.makedirs('../weights/wdsr-b-32-x4', exist_ok=True)
trainer.model.save_weights('../weights/wdsr-b-32-x4/weights.h5')

# SRGAN fine-tuning of the EDSR Model

In [9]:
from models.srgan import generator, discriminator
from models.edsr import edsr
from train import SrganTrainer

# Build an EDSR network (x4, 16 residual blocks) to act as the GAN generator
# and initialise it from ../weights/edsr-16-x4/weights.h5.
# NOTE: binding the model to `generator` hides the factory of the same name
# imported from models.srgan at the top of this cell. Any later cell that wants
# to call generator() has to import it again first.
generator = edsr(scale=4, num_res_blocks=16)
generator.load_weights('../weights/edsr-16-x4/weights.h5')

# Pair the generator with a fresh discriminator and fine-tune it through
# SRGAN training for 300 steps.
gan_discriminator = discriminator()
gan_trainer = SrganTrainer(generator=generator, discriminator=gan_discriminator)
gan_trainer.train(train_ds, steps=300)

50/300, perceptual loss = 0.0148, discriminator loss = 2.7900
100/300, perceptual loss = 0.0078, discriminator loss = 1.9429
150/300, perceptual loss = 0.0085, discriminator loss = 3.0903
200/300, perceptual loss = 0.0106, discriminator loss = 2.0236
250/300, perceptual loss = 0.0097, discriminator loss = 1.9306
300/300, perceptual loss = 0.0086, discriminator loss = 2.4216


In [10]:
# Write the current `generator` weights under a separate file name, so the
# weights.h5 already stored in the same folder is not overwritten.
generator.save_weights('../weights/edsr-16-x4/finetuned_weights.h5')

# SRGAN fine-tuning of the WDSR Model

In [4]:
from models.srgan import generator, discriminator
from models.wdsr import wdsr_b
from train import SrganTrainer

# Build a WDSR-B network (x4, 32 residual blocks) to act as the GAN generator
# and initialise it from ../weights/wdsr-b-32-x4/weights.h5.
# NOTE: binding the model to `generator` hides the factory of the same name
# imported from models.srgan at the top of this cell. Any later cell that wants
# to call generator() has to import it again first.
generator = wdsr_b(scale=4, num_res_blocks=32)
generator.load_weights('../weights/wdsr-b-32-x4/weights.h5')

# Pair the generator with a fresh discriminator and fine-tune it through
# SRGAN training for 300 steps.
gan_discriminator = discriminator()
gan_trainer = SrganTrainer(generator=generator, discriminator=gan_discriminator)
gan_trainer.train(train_ds, steps=300)

50/300, perceptual loss = 0.0119, discriminator loss = 3.2831
100/300, perceptual loss = 0.0065, discriminator loss = 1.7058
150/300, perceptual loss = 0.0075, discriminator loss = 1.8289
200/300, perceptual loss = 0.0124, discriminator loss = 2.3117
250/300, perceptual loss = 0.0113, discriminator loss = 1.8097
300/300, perceptual loss = 0.0097, discriminator loss = 2.7430


In [5]:
# Write the current `generator` weights under a separate file name, so the
# weights.h5 already stored in the same folder is not overwritten.
generator.save_weights('../weights/wdsr-b-32-x4/finetuned_weights.h5')

# SRGAN Model

In [5]:
from models.srgan import generator
from train import SrganGeneratorTrainer

import os

# Create a training context for the generator (SRResNet) alone.
# `generator` must be the factory from models.srgan (re-imported at the top of
# this cell); earlier cells rebind the same name to a model instance.
pre_trainer = SrganGeneratorTrainer(model=generator(), checkpoint_dir='../.ckpt/pre_generator')

# Create the output folder before training starts. Without it, save_weights at
# the end of the run has nowhere to write and the whole training run is wasted.
os.makedirs('../weights/srgan', exist_ok=True)

# Pre-train the generator. This run uses 4,000 steps and evaluates every 100
# steps on valid_ds.take(10). The original recipe comment recommends 1,000,000
# steps (100,000 works fine too), so treat these weights as a short run.
pre_trainer.train(train_ds, valid_ds.take(10), steps=4000, evaluate_every=100)

# Save weights of pre-trained generator (needed for fine-tuning with GAN).
pre_trainer.model.save_weights('../weights/srgan/pre_generator.h5')

Model restored from checkpoint at step 200.
300/4000: loss = 237.381, PSNR = 22.157049 (475.10s)
400/4000: loss = 150.067, PSNR = 21.504644 (461.68s)
500/4000: loss = 142.026, PSNR = 22.961258 (447.06s)
600/4000: loss = 127.015, PSNR = 24.027056 (383.37s)
700/4000: loss = 98.534, PSNR = 24.289984 (418.23s)
800/4000: loss = 146.285, PSNR = 24.207560 (397.58s)
900/4000: loss = 108.621, PSNR = 22.497898 (320.21s)
1000/4000: loss = 277.233, PSNR = 26.083469 (321.05s)
1100/4000: loss = 224.403, PSNR = 27.737858 (321.09s)
1200/4000: loss = 191.350, PSNR = 27.578867 (320.76s)
1300/4000: loss = 136.751, PSNR = 26.912033 (320.14s)
1400/4000: loss = 152.394, PSNR = 28.228680 (327.34s)
1500/4000: loss = 194.451, PSNR = 27.594666 (331.47s)
1600/4000: loss = 128.516, PSNR = 26.617075 (330.99s)
1700/4000: loss = 126.056, PSNR = 28.367420 (329.10s)
1800/4000: loss = 159.528, PSNR = 28.749945 (327.69s)
1900/4000: loss = 156.064, PSNR = 22.445560 (328.66s)
2000/4000: loss = 68.865, PSNR = 25.233292 (32

In [8]:
from models.srgan import generator, discriminator
from train import SrganTrainer

# Instantiate the SRGAN generator and initialise it from the pre-trained
# weights in ../weights/srgan/pre_generator.h5.
gan_generator = generator()
gan_generator.load_weights('../weights/srgan/pre_generator.h5')

# Train generator and discriminator together for 750 steps.
gan_discriminator = discriminator()
gan_trainer = SrganTrainer(generator=gan_generator, discriminator=gan_discriminator)
gan_trainer.train(train_ds, steps=750)

# Persist both networks side by side in the SRGAN weights folder.
srgan_weights_dir = '../weights/srgan'
gan_trainer.generator.save_weights(f'{srgan_weights_dir}/gan_generator.h5')
gan_trainer.discriminator.save_weights(f'{srgan_weights_dir}/gan_discriminator.h5')

50/750, perceptual loss = 0.0208, discriminator loss = 3.0945
100/750, perceptual loss = 0.0178, discriminator loss = 2.9196
150/750, perceptual loss = 0.0208, discriminator loss = 3.9222
200/750, perceptual loss = 0.0260, discriminator loss = 6.6371
250/750, perceptual loss = 0.0210, discriminator loss = 4.5799
300/750, perceptual loss = 0.0185, discriminator loss = 4.6304
350/750, perceptual loss = 0.0174, discriminator loss = 4.4746
400/750, perceptual loss = 0.0216, discriminator loss = 2.9348
450/750, perceptual loss = 0.0181, discriminator loss = 2.7856
500/750, perceptual loss = 0.0178, discriminator loss = 3.8018
550/750, perceptual loss = 0.0209, discriminator loss = 4.9150
600/750, perceptual loss = 0.0193, discriminator loss = 2.6541
650/750, perceptual loss = 0.0176, discriminator loss = 3.0440
700/750, perceptual loss = 0.0159, discriminator loss = 2.3343
750/750, perceptual loss = 0.0340, discriminator loss = 1.0428
