### Preprocessing: tile raw images into 128×128 HR / 32×32 LR training pairs

In [None]:
import numpy as np
import tifffile
import os
import cv2
import random
import tensorflow as tf
from tqdm import tqdm

In [None]:
# Cut the central half of every raw image into 128x128 high-resolution (HR)
# tiles and save a 4x-downscaled (32x32) low-resolution (LR) copy of each.
# The two copies share a file name ("<index>.tif") so they can be paired later.
RAW_DIR = "raw"
HR_DIR = "input/hr"
LR_DIR = "input/lr"
TILE = 128     # HR tile edge in pixels
SCALE = 4      # downscaling factor HR -> LR
SEED = 0       # makes the per-tile interpolation choice reproducible

# One interpolation is drawn uniformly per tile so the model sees several
# different downscaling kernels rather than a single one.
INTERPOLATIONS = (
    cv2.INTER_NEAREST,
    cv2.INTER_LINEAR,
    cv2.INTER_AREA,
    cv2.INTER_CUBIC,
    cv2.INTER_LANCZOS4,
)

# tifffile.imwrite does not create missing parent directories.
os.makedirs(HR_DIR, exist_ok=True)
os.makedirs(LR_DIR, exist_ok=True)
rng = random.Random(SEED)
lr_size = TILE // SCALE

tile_idx = 0
# Sorted so tile numbering does not depend on filesystem listing order.
for fname in sorted(os.listdir(RAW_DIR)):
    img = tifffile.imread(os.path.join(RAW_DIR, fname))
    # Keep only the central half along each spatial axis.
    rows, cols = img.shape[0], img.shape[1]
    img = img[rows // 4:3 * rows // 4, cols // 4:3 * cols // 4]
    for j in range(img.shape[0] // TILE):
        for k in range(img.shape[1] // TILE):
            tile = img[j * TILE:(j + 1) * TILE, k * TILE:(k + 1) * TILE]
            tifffile.imwrite(os.path.join(HR_DIR, f"{tile_idx}.tif"), tile)
            small = cv2.resize(tile, dsize=(lr_size, lr_size),
                               interpolation=rng.choice(INTERPOLATIONS))
            tifffile.imwrite(os.path.join(LR_DIR, f"{tile_idx}.tif"), small)
            tile_idx += 1

### Training: SRResNet generator (with optional SRGAN adversarial loss)

In [None]:
import numpy as np
import tifffile
import os
import cv2
import random
import tensorflow as tf
from tqdm import tqdm

In [None]:
# Load the paired HR/LR tiles written by the preprocessing step. Each tile is
# min-max scaled to [0, 1] independently.
def _load_normalized(path):
    """Read one TIFF tile and min-max scale it to [0, 1] as float64.

    A constant tile (max == min) would otherwise divide 0/0 and fill the
    array with NaNs, so it is mapped to all zeros instead.
    """
    img = tifffile.imread(path).astype(np.float64)
    lo, hi = img.min(), img.max()
    if hi > lo:
        return (img - lo) / (hi - lo)
    return np.zeros_like(img)

# os.listdir order is arbitrary and need not match between the two folders.
# Pair HR and LR tiles by file name, not by position in two separate
# listings. Sorting gives a deterministic order.
tile_names = sorted(os.listdir("input/hr"))
missing_lr = sorted(set(tile_names) - set(os.listdir("input/lr")))
if missing_lr:
    raise FileNotFoundError(
        f"{len(missing_lr)} HR tiles have no LR counterpart, e.g. {missing_lr[:5]}"
    )
num = len(tile_names)
hr = np.full((num,128,128,1),np.nan)
lr = np.full((num,32,32,1),np.nan)
for i in tqdm(range(num)):
    hr[i,:,:,0] = _load_normalized(os.path.join("input/hr", tile_names[i]))
    lr[i,:,:,0] = _load_normalized(os.path.join("input/lr", tile_names[i]))

In [None]:
# Pair each LR tile with its HR target. Reshuffle the pairs on every pass
# so batches are not always built in the same file-name order. A buffer
# the size of the dataset gives a full shuffle, but it holds a second
# in-memory copy; lower it if RAM is tight. The max(..., 1) guards against
# a zero buffer size when the dataset is empty.
data = tf.data.Dataset.from_tensor_slices((lr,hr)).shuffle(
    buffer_size=max(len(lr), 1), reshuffle_each_iteration=True
)

In [None]:
def residual_block_gen(ch=64,k_s=3,st=1):
    """Build the inner branch of a generator residual block.

    The branch is two identical Conv2D -> BatchNormalization -> LeakyReLU
    stages. The caller adds the branch output back onto its input.

    Args:
        ch: number of filters in each convolution.
        k_s: convolution kernel size.
        st: stride applied on both spatial axes.

    Returns:
        A tf.keras.Sequential holding the six layers in order.
    """
    stages = []
    for _ in range(2):
        stages += [
            tf.keras.layers.Conv2D(ch, k_s, strides=(st, st), padding="same"),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.LeakyReLU(),
        ]
    return tf.keras.Sequential(stages)

def Upsample_block(x, ch=256, k_s=3, st=1):
    """Apply one 2x pixel-shuffle upsampling stage to the Keras tensor ``x``.

    A convolution expands the channels to ``ch``. depth_to_space with block
    size 2 then rearranges them into a 2x larger spatial grid with ch/4
    channels. A LeakyReLU follows.

    Args:
        x: input Keras tensor.
        ch: filters in the convolution before the shuffle.
        k_s: convolution kernel size.
        st: convolution stride on both spatial axes.

    Returns:
        The upsampled, activated tensor.
    """
    expand = tf.keras.layers.Conv2D(ch, k_s, strides=(st, st), padding="same")
    shuffled = tf.nn.depth_to_space(expand(x), 2)
    return tf.keras.layers.LeakyReLU()(shuffled)

# SRResNet-style generator, built in these stages:
#   1. 9x9 feature extraction
#   2. 5 residual blocks plus a long skip connection
#   3. two 2x pixel-shuffle upsamplers (4x total)
#   4. a single-channel output convolution
# Spatial dims are None, so the fully convolutional model accepts any H x W.
input_lr=tf.keras.layers.Input(shape=(None,None,1))
input_conv=tf.keras.layers.Conv2D(64,9,padding="same")(input_lr)
input_conv=tf.keras.layers.LeakyReLU()(input_conv)
SRRes=input_conv
# Each iteration adds a fresh residual branch onto the running features.
for x in range(5):
    res_output=residual_block_gen()(SRRes)
    SRRes=tf.keras.layers.Add()([SRRes,res_output])
SRRes=tf.keras.layers.Conv2D(64,9,padding="same")(SRRes)
SRRes=tf.keras.layers.BatchNormalization()(SRRes)
# Long skip: add back the features from before the residual stack.
SRRes=tf.keras.layers.Add()([SRRes,input_conv])
SRRes=Upsample_block(SRRes)
SRRes=Upsample_block(SRRes)
# NOTE(review): tanh outputs lie in [-1, 1], but the HR targets are min-max
# scaled to [0, 1]. MSE still trains, but a sigmoid or linear head would
# match the target range exactly.
output_sr=tf.keras.layers.Conv2D(1,9,activation="tanh",padding="same")(SRRes)
SRResnet=tf.keras.models.Model(input_lr,output_sr)

In [None]:
def residual_block_disc(ch=64,k_s=3,st=1):
    """Build one discriminator stage: Conv2D -> BatchNormalization -> LeakyReLU.

    Despite the name, this stage has no skip connection. It is a plain conv
    block whose stride controls downsampling.

    Args:
        ch: number of convolution filters.
        k_s: convolution kernel size.
        st: stride on both spatial axes (2 halves the resolution).

    Returns:
        A tf.keras.Sequential of the three layers.
    """
    conv = tf.keras.layers.Conv2D(ch, k_s, strides=(st, st), padding="same")
    norm = tf.keras.layers.BatchNormalization()
    act = tf.keras.layers.LeakyReLU()
    return tf.keras.Sequential([conv, norm, act])

# Discriminator: maps a 128x128 single-channel image to a sigmoid probability
# that it is a real HR tile. Those probabilities feed BinaryCrossentropy.
# NOTE(review): the input is named input_lr, but it receives HR or
# super-resolved images, not LR ones.
input_lr=tf.keras.layers.Input(shape=(128,128,1))
input_conv=tf.keras.layers.Conv2D(64,1,padding="same")(input_lr)
input_conv=tf.keras.layers.LeakyReLU()(input_conv)
# Four stride-2 stages halve the resolution 128 -> 8 while widening to 512 channels.
channel_nums=[64,128,128,256,256,512,512]
stride_sizes=[2,1,2,1,2,1,2]
disc=input_conv
for x in range(7):
    disc=residual_block_disc(ch=channel_nums[x],st=stride_sizes[x])(disc)
# Dense classification head over the flattened 8x8x512 feature map.
disc=tf.keras.layers.Flatten()(disc)
disc=tf.keras.layers.Dense(1024)(disc)
disc=tf.keras.layers.LeakyReLU()(disc)
disc_output=tf.keras.layers.Dense(1,activation="sigmoid")(disc)
discriminator=tf.keras.models.Model(input_lr,disc_output)

In [None]:
def PSNR(y_true,y_pred,max_val=1.0):
    """Peak signal-to-noise ratio in dB between targets and predictions.

    The squared error is averaged over every element of the batch first. The
    result is therefore the PSNR of the pooled MSE, not the mean of
    per-image PSNRs.

    Args:
        y_true: target tensor (any float dtype; cast to float64).
        y_pred: prediction tensor (any float dtype; cast to float64).
        max_val: peak signal value. Defaults to 1.0 for [0, 1]-scaled data.

    Returns:
        Scalar float64 tensor. It is +inf when the MSE is exactly zero.
    """
    # Cast both sides: TF does not auto-promote mixed float32/float64 operands.
    y_true = tf.cast(y_true, tf.float64)
    y_pred = tf.cast(y_pred, tf.float64)
    mse = tf.reduce_mean((y_true - y_pred) ** 2)
    return 20 * log10(max_val / (mse ** 0.5))

def log10(x):
    """Base-10 logarithm of a tensor via change of base: ln(x) / ln(10).

    The constant 10 is created in the same dtype as ln(x), so the division
    never mixes float types.
    """
    ln_x = tf.math.log(x)
    return ln_x / tf.math.log(tf.constant(10, dtype=ln_x.dtype))

def pixel_MSE(y_true,y_pred):
    """Mean squared error over all elements, computed in float64.

    Both inputs are cast to float64. The loss then works whether the targets
    arrive as float64 (the numpy arrays built in this notebook) or as
    float32. TF does not auto-promote mixed float dtypes.

    Returns:
        Scalar float64 tensor.
    """
    y_true = tf.cast(y_true, tf.float64)
    y_pred = tf.cast(y_pred, tf.float64)
    return tf.reduce_mean((y_true - y_pred) ** 2)

In [None]:
# Shared binary cross-entropy. With default arguments it expects
# probabilities, not logits.
cross_entropy = tf.keras.losses.BinaryCrossentropy()

def discriminator_loss(real_output, fake_output):
    """Discriminator BCE: real scores are pushed toward 1, fake scores toward 0.

    Returns the sum of the two cross-entropy terms.
    """
    return (cross_entropy(tf.ones_like(real_output), real_output)
            + cross_entropy(tf.zeros_like(fake_output), fake_output))

def generator_loss(fake_output):
    """Generator adversarial BCE: rewards fake samples scored as real (target 1)."""
    wanted = tf.ones_like(fake_output)
    return cross_entropy(wanted, fake_output)

In [None]:
@tf.function()
def train_step(data,loss_func=pixel_MSE,adv_learning=True,evaluate=("PSNR",),adv_ratio=0.001):
    """Run one optimisation step on a batch.

    The generator (SRResnet) is always updated. The discriminator is updated
    only when ``adv_learning`` is true.

    Args:
        data: (low_resolution, high_resolution) batch tuple.
        loss_func: reconstruction loss ``f(y_true, y_pred)``.
        adv_learning: if true, add the weighted adversarial term to the
            generator loss and also train the discriminator.
        evaluate: metric names to report. Only "PSNR" is recognised.
        adv_ratio: weight of the adversarial term in the generator loss.

    Returns:
        Dict with "reconstruction" (the loss before the adversarial term).
        When adversarial, it also has "adv_g" and "adv_d". It includes any
        requested metrics.
    """
    logs={}
    disc_loss=0
    low_resolution,high_resolution=data
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        super_resolution = SRResnet(low_resolution, training=True)
        gen_loss=loss_func(high_resolution,super_resolution)
        logs["reconstruction"]=gen_loss
        if adv_learning:
            real_output = discriminator(high_resolution, training=True)
            fake_output = discriminator(super_resolution, training=True)
            adv_loss_g = generator_loss(fake_output) * adv_ratio
            # The BCE term is float32 while pixel_MSE returns float64. TF
            # refuses to add mixed dtypes, so align before summing.
            gen_loss = gen_loss + tf.cast(adv_loss_g, gen_loss.dtype)
            disc_loss = discriminator_loss(real_output, fake_output)
            logs["adv_g"]=adv_loss_g
            logs["adv_d"]=disc_loss
    gradients_of_generator = gen_tape.gradient(gen_loss, SRResnet.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, SRResnet.trainable_variables))
    if adv_learning:
        gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
        discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    for metric in evaluate:
        if metric=="PSNR":
            logs[metric]=PSNR(high_resolution,super_resolution)
    return logs

In [None]:
# Optimisers and switches passed to train_step by the epoch loop.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
adv_ratio = 0.001          # weight of the adversarial term in the generator loss
evaluate = ["PSNR"]        # metrics computed each step
loss_func = pixel_MSE      # reconstruction loss
adv_learning = False       # False: train the generator on loss_func alone (discriminator untouched)

In [None]:
# Train for a fixed number of epochs and report metrics averaged over each
# epoch. Reporting only the last batch gives a noisy picture.
EPOCHS = 500
BATCH_SIZE = 32
# Build the batched pipeline once. Iterating it again each epoch still
# re-runs any upstream per-iteration shuffle.
batched_data = data.batch(BATCH_SIZE)
for epoch in range(EPOCHS):
    print("epoch:", epoch)
    totals = {}
    n_batches = 0
    for image_batch in tqdm(batched_data, position=0, leave=True):
        logs = train_step(image_batch, loss_func, adv_learning, evaluate, adv_ratio)
        for key, value in logs.items():
            totals[key] = totals.get(key, 0.0) + float(value)
        n_batches += 1
    if n_batches == 0:
        raise ValueError("training dataset is empty; check input/hr and input/lr")
    means = {key: total / n_batches for key, total in totals.items()}
    print("reconstruction:", means["reconstruction"], "  PSNR:", means.get("PSNR"))

In [None]:
# Save the trained generator in Keras HDF5 format (chosen by the .h5
# extension). Create the target directory first so the save cannot fail
# on a fresh checkout.
os.makedirs("output", exist_ok=True)
SRResnet.save("output/model.h5")