In [None]:
import math
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Reshape, LeakyReLU, Dropout, UpSampling2D,AveragePooling2D,Conv2DTranspose, Input, Concatenate, Add, BatchNormalization, Activation, MultiHeadAttention
import tensorflow_hub as hub
import tensorflow_text as text
from ipywidgets import IntProgress
from IPython.display import display
import cv2

In [None]:
# Path to the test image.
image_path = "C:/users/user/desktop/8k_128/1313693129_71d0b21c63.jpg"
bgr_image = cv2.imread(image_path)
# cv2.imread does not raise on a missing/unreadable file - it returns None,
# which would otherwise surface as an obscure error inside cvtColor.
if bgr_image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")

rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
# Images are kept on a 0..1 scale by dividing by 256; later cells multiply by 256 to undo it.
original_image = tf.image.resize(tf.cast(rgb_image, tf.float32) / 256, (128, 128))
small_image = tf.image.resize(original_image, (32, 32), method="bicubic")
bicubic_result = tf.image.resize(small_image, (128, 128), method="bicubic")
# Bicubic interpolation can overshoot [0, 1]; clip for display only.
plt.imshow(tf.clip_by_value(bicubic_result, 0, 1))

In [None]:
# ESRGAN super-resolution model (TF2 port) loaded from TensorFlow Hub.
SAVED_MODEL_PATH = "https://tfhub.dev/captain-pool/esrgan-tf2/1"
model_upscaler = hub.load(handle=SAVED_MODEL_PATH)

In [None]:
# Small BERT text encoder and its matching preprocessing model from TF Hub.
tfhub_handle_preprocess = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'
tfhub_handle_encoder = 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1'
bert_preprocess_model = hub.KerasLayer(tfhub_handle_preprocess)
bert_model = hub.KerasLayer(tfhub_handle_encoder)

def process_text(text_batch):
    """Encode a batch of strings with BERT and return the pooled sentence embeddings."""
    return bert_model(bert_preprocess_model(text_batch))["pooled_output"]

In [None]:
# ESRGAN is fed a batch on the 0..256 pixel scale; the output is scaled back to 0..1.
esrgan_result = model_upscaler(small_image[tf.newaxis, ...] * 256)[0] / 256
plt.imshow(esrgan_result)

In [None]:
regression_sr = keras.models.load_model(filepath='regression_rs.h5')  # non-diffusion (regression) SR baseline

In [None]:
# The regression model receives images on the same 0..1 scale as small_image.
regression_result = regression_sr(small_image[tf.newaxis, ...])[0]
plt.imshow(regression_result)

In [None]:
# More details on what happens here: https://keras.io/examples/generative/ddpm/
class GaussianDiffusion:
    """Gaussian diffusion utility (DDPM forward and reverse process).

    Args:
        beta_start: Initial value of the variance schedule.
        beta_end: Final value of the variance schedule.
        timesteps: Number of time steps of the forward (and then reverse) diffusion.
        clip_min: Lower bound for the reconstructed x_0 when `clip_denoised` is set.
        clip_max: Upper bound for the reconstructed x_0 when `clip_denoised` is set.
    """

    def __init__(
        self, beta_start=1e-4, beta_end=0.02, timesteps=1000, clip_min=-1.0, clip_max=1.0):
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.timesteps = timesteps
        self.clip_min = clip_min
        self.clip_max = clip_max

        # Linear variance schedule. Computed in float64 for better precision;
        # everything is cast to float32 only when stored as TF constants below.
        betas = np.linspace(
            beta_start,
            beta_end,
            timesteps,
            dtype=np.float64,
        )
        self.num_timesteps = int(timesteps)

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        self.betas = tf.constant(betas, dtype=tf.float32)
        self.alphas_cumprod = tf.constant(alphas_cumprod, dtype=tf.float32)
        self.alphas_cumprod_prev = tf.constant(alphas_cumprod_prev, dtype=tf.float32)

        # Terms for the forward diffusion q(x_t | x_{t-1}) and related quantities.
        self.sqrt_alphas_cumprod = tf.constant(np.sqrt(alphas_cumprod), dtype=tf.float32)

        self.sqrt_one_minus_alphas_cumprod = tf.constant(np.sqrt(1.0 - alphas_cumprod), dtype=tf.float32)

        self.log_one_minus_alphas_cumprod = tf.constant(np.log(1.0 - alphas_cumprod), dtype=tf.float32)

        self.sqrt_recip_alphas_cumprod = tf.constant(np.sqrt(1.0 / alphas_cumprod), dtype=tf.float32)
        self.sqrt_recipm1_alphas_cumprod = tf.constant(np.sqrt(1.0 / alphas_cumprod - 1), dtype=tf.float32)

        # Terms for the posterior q(x_{t-1} | x_t, x_0).
        posterior_variance = (betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        self.posterior_variance = tf.constant(posterior_variance, dtype=tf.float32)

        # The posterior variance is 0 at the start of the chain, so clamp it before taking the log.
        self.posterior_log_variance_clipped = tf.constant(np.log(np.maximum(posterior_variance, 1e-20)), dtype=tf.float32)

        self.posterior_mean_coef1 = tf.constant(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod), dtype=tf.float32)

        self.posterior_mean_coef2 = tf.constant(
            (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod), dtype=tf.float32)

    def _extract(self, a, t, x_shape):
        """Gather coefficients at the given time steps and reshape them to
        [batch_size, 1, 1, 1] so they broadcast against a 4-D image batch.

        Args:
            a: Tensor of per-timestep coefficients to gather from.
            t: Time step index (per sample) to gather.
            x_shape: Shape of the current batch (only the batch size is used).
        """
        batch_size = x_shape[0]
        out = tf.gather(a, t)
        return tf.reshape(out, [batch_size, 1, 1, 1])

    def q_mean_variance(self, x_start, t):
        """Return mean, variance and log-variance of q(x_t | x_0).

        Args:
            x_start: Initial sample (before the first diffusion step).
            t: Current time step.
        """
        x_start_shape = tf.shape(x_start)
        mean = self._extract(self.sqrt_alphas_cumprod, t, x_start_shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod, t, x_start_shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod, t, x_start_shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise):
        """Diffuse the data (sample from q(x_t | x_0)).

        Args:
            x_start: Initial sample (before the first diffusion step).
            t: Current time step.
            noise: Gaussian noise to add at the current time step.
        Returns:
            Diffused samples at time step `t`.
        """
        x_start_shape = tf.shape(x_start)

        return (
            self._extract(self.sqrt_alphas_cumprod, t, x_start_shape) * x_start
            + self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start_shape)
            * noise
        )

    def predict_start_from_noise(self, x_t, t, noise):
        """Reconstruct x_0 from x_t and the (predicted) noise."""
        x_t_shape = tf.shape(x_t)

        return (
            self._extract(self.sqrt_recip_alphas_cumprod, t, x_t_shape) * x_t
            - self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t_shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """Compute mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0).

        Args:
            x_start: Starting point (sample) for the posterior computation.
            x_t: Sample at time step `t`.
            t: Current time step.
        Returns:
            Posterior mean, variance and clipped log-variance at the current time step.
        """

        x_t_shape = tf.shape(x_t)
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t_shape) * x_start
            + self._extract(self.posterior_mean_coef2, t, x_t_shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t_shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped, t, x_t_shape)

        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, pred_noise, x, t, clip_denoised=True):
        """Posterior mean/variance of x_{t-1}, using x_0 reconstructed from `pred_noise`."""
        x_recon = self.predict_start_from_noise(x, t=t, noise=pred_noise)
        if clip_denoised:
            x_recon = tf.clip_by_value(x_recon, self.clip_min, self.clip_max)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)

        return model_mean, posterior_variance, posterior_log_variance

    def p_sample(self, pred_noise, x, t, clip_denoised=True):
        """One reverse-diffusion step: sample x_{t-1} given x_t and the predicted noise.

        Args:
            pred_noise: Noise predicted by the diffusion model.
            x: Samples at the time step for which the noise was predicted.
            t: Current time step (one value per sample).
            clip_denoised (bool): Whether to clip the reconstructed x_0 to
                [clip_min, clip_max] before computing the posterior.
        """
        model_mean, _, model_log_variance = self.p_mean_variance(pred_noise, x=x, t=t, clip_denoised=clip_denoised)
        # Dynamic shape, so this also works when the batch size is not statically known
        # (x.shape would contain None inside a tf.function).
        x_shape = tf.shape(x)
        noise = tf.random.normal(shape=x_shape, dtype=x.dtype)
        # No noise is added when t == 0.
        nonzero_mask = tf.reshape(1 - tf.cast(tf.equal(t, 0), tf.float32), [x_shape[0], 1, 1, 1])

        return model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise


In [None]:
class SR3DiffusionModel(keras.Model):
    """Inference wrapper for the SR3 diffusion model (32x32 -> 128x128).

    The low-resolution conditioning images are fed to the network at their
    input size; the reverse diffusion runs over 128x128x3 noise.

    Args:
        network: Keras model predicting noise from
            [noisy_samples, timestep, text_embedding, low_res_images].
        ema_network: EMA copy of `network`; used when `ema_mode=True`.
        timesteps: Number of reverse diffusion steps.
        gdf_util: `GaussianDiffusion` instance providing `p_sample`.
        ema: Unused; kept for interface compatibility.
    """

    def __init__(self, network, ema_network, timesteps, gdf_util, ema=0.999):
        super().__init__()
        self.network = network
        self.ema_network = ema_network
        self.timesteps = timesteps
        self.gdf_util = gdf_util

    def plot_images(self, num_rows, num_cols, figsize, images):
        """Show `images` (on a [-1, 1] scale) in a num_rows x num_cols grid."""
        images = (tf.clip_by_value(images * 127 + 127, 0.0, 255.0).numpy().astype(np.uint8))
        # squeeze=False always yields a 2-D array of axes, so grids with a single
        # row or a single column are indexed the same way as larger ones.
        _, ax = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
        for cell, image in zip(ax.flat, images):
            cell.imshow(image)
            cell.axis("off")

        plt.tight_layout()
        plt.show()

    def run_generation(self, small_images, num_rows=2, num_cols=8, figsize=(5, 2), annotation = " ", ema_mode = True, ex_rate = 0, from_annotation = False):
        """Super-resolve `small_images` by running the full reverse diffusion.

        Args:
            small_images: Batch of low-resolution images on a 0..255 scale; its
                batch size must equal num_rows * num_cols.
            num_rows, num_cols: Plot grid; their product is the number of images sampled.
            figsize: Matplotlib figure size for the plots.
            annotation: Text prompt used when `from_annotation` is True.
            ema_mode: Predict noise with `ema_network` (True) or `network` (False).
            ex_rate: Guidance scale. When > 0 the predicted noise is extrapolated
                away from the prediction for the empty (" ") negative prompt.
            from_annotation: Condition on `annotation`; otherwise on the empty prompt.

        Returns:
            uint8 numpy array of generated images in [0, 255].
        """
        num_images = num_rows * num_cols
        # 0..255 -> roughly [-0.5, 0.5]; presumably the range the network was trained on.
        small_images = (small_images / 127 - 1) / 2.0
        self.plot_images(num_rows, num_cols, figsize, small_images * 2.0)

        def _embed(prompt):
            # Pooled BERT embedding of one prompt, repeated per image and shaped
            # (num_images, 1, 1, 512) as the network input expects.
            embedding = process_text(tf.expand_dims(prompt, axis=0))
            embedding = tf.repeat(tf.expand_dims(embedding, axis=0), num_images, axis=0)
            return tf.reshape(embedding, shape=(-1, 1, 1, 512))

        embeddings = _embed(annotation) if from_annotation else None
        negative_embeddings = None
        if ex_rate > 0 or not from_annotation:
            # Embedding of the empty negative prompt.
            negative_embeddings = _embed(" ")
        if not from_annotation:
            embeddings = negative_embeddings

        # The same network must serve both the conditional and the negative
        # prediction, otherwise guidance mixes two different models.
        denoiser = self.ema_network if ema_mode else self.network

        samples = tf.random.normal(shape=(num_images, 128, 128, 3), dtype=tf.float32)
        bar = IntProgress(min=0, max=self.timesteps)
        display(bar)
        for t in reversed(range(0, self.timesteps)):
            bar.value += 1
            tt = tf.cast(tf.fill(num_images, t), dtype=tf.int64)
            tt_input = tf.reshape(tt, shape=(-1, 1, 1, 1))
            # Calling the model directly avoids Model.predict's per-call setup
            # overhead, which dominates when invoked once per diffusion step.
            pred_noise = denoiser([samples, tt_input, embeddings, small_images], training=False)
            if ex_rate > 0:
                pred_negative_noise = denoiser([samples, tt_input, negative_embeddings, small_images], training=False)
                # Extrapolate the noise from the negative prompt towards the positive one.
                resulted_noise = pred_noise + (pred_noise - pred_negative_noise) * ex_rate
            else:
                resulted_noise = pred_noise
            samples = self.gdf_util.p_sample(resulted_noise, samples, tt, clip_denoised=True)

        generated_samples = samples
        self.plot_images(num_rows, num_cols, figsize, generated_samples * 2.0)
        return tf.clip_by_value((generated_samples * 2.0) * 127 + 127, 0.0, 255.0).numpy().astype(np.uint8)

In [None]:
# Globals that the saved .h5 network presumably looks up when it is rebuilt
# (they are not stored in the file) - keep them defined before load_model.
embedding_dims = 32
embedding_max_frequency = 1000.0
image_size = 128
img_channels = 3

total_timesteps = 300
gdf_util = GaussianDiffusion(timesteps=total_timesteps)

network = tf.keras.models.load_model('sr3 32 to 128.h5')

# No separate EMA weights are loaded: both roles point at the same network.
SR3model = SR3DiffusionModel(
    network=network,
    ema_network=network,
    gdf_util=gdf_util,
    timesteps=total_timesteps,
)

In [None]:
# Single-image SR3 run with a text prompt and guidance scale 3; rescaled to 0..1.
SR3_result = SR3model.run_generation(
    tf.expand_dims(small_image * 256, axis=0),
    num_rows=1,
    num_cols=1,
    ex_rate=3,
    from_annotation=True,
    annotation="Some trees near the field",
)[0] / 256

In [None]:
class SR64DiffusionModel(keras.Model):
    """Inference wrapper for the first cascade stage (32x32 -> 64x64).

    The low-resolution conditioning images are resized to 64x64 (gaussian
    resampling) and the reverse diffusion runs over 64x64x3 noise.

    Args:
        network: Keras model predicting noise from
            [noisy_samples, timestep, text_embedding, low_res_images].
        ema_network: EMA copy of `network`; used when `ema_mode=True`.
        timesteps: Number of reverse diffusion steps.
        gdf_util: `GaussianDiffusion` instance providing `p_sample`.
        ema: Unused; kept for interface compatibility.
    """

    def __init__(self, network, ema_network, timesteps, gdf_util, ema=0.999):
        super().__init__()
        self.network = network
        self.ema_network = ema_network
        self.timesteps = timesteps
        self.gdf_util = gdf_util

    def plot_images(self, num_rows, num_cols, figsize, images):
        """Show `images` (on a [-1, 1] scale) in a num_rows x num_cols grid."""
        images = (tf.clip_by_value(images * 127 + 127, 0.0, 255.0).numpy().astype(np.uint8))
        # squeeze=False always yields a 2-D array of axes, so grids with a single
        # row or a single column are indexed the same way as larger ones.
        _, ax = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
        for cell, image in zip(ax.flat, images):
            cell.imshow(image)
            cell.axis("off")

        plt.tight_layout()
        plt.show()

    def run_generation(self, small_images, num_rows=2, num_cols=8, figsize=(5, 2), annotation = " ", ema_mode = True, ex_rate = 0, from_annotation = False):
        """Upscale `small_images` to 64x64 by running the full reverse diffusion.

        Args:
            small_images: Batch of low-resolution images on a 0..255 scale; its
                batch size must equal num_rows * num_cols.
            num_rows, num_cols: Plot grid; their product is the number of images sampled.
            figsize: Matplotlib figure size for the plots.
            annotation: Text prompt used when `from_annotation` is True.
            ema_mode: Predict noise with `ema_network` (True) or `network` (False).
            ex_rate: Guidance scale. When > 0 the predicted noise is extrapolated
                away from the prediction for the empty (" ") negative prompt.
            from_annotation: Condition on `annotation`; otherwise on the empty prompt.

        Returns:
            uint8 numpy array of generated 64x64 images in [0, 255].
        """
        num_images = num_rows * num_cols
        # 0..255 -> roughly [-0.5, 0.5]; presumably the range the network was trained on.
        small_images = (small_images / 127 - 1) / 2.0
        small_images = tf.image.resize(small_images, (64, 64), method="gaussian")
        self.plot_images(num_rows, num_cols, figsize, small_images * 2.0)

        def _embed(prompt):
            # Pooled BERT embedding of one prompt, repeated per image and shaped
            # (num_images, 1, 1, 512) as the network input expects.
            embedding = process_text(tf.expand_dims(prompt, axis=0))
            embedding = tf.repeat(tf.expand_dims(embedding, axis=0), num_images, axis=0)
            return tf.reshape(embedding, shape=(-1, 1, 1, 512))

        embeddings = _embed(annotation) if from_annotation else None
        negative_embeddings = None
        if ex_rate > 0 or not from_annotation:
            # Embedding of the empty negative prompt.
            negative_embeddings = _embed(" ")
        if not from_annotation:
            embeddings = negative_embeddings

        # The same network must serve both the conditional and the negative
        # prediction, otherwise guidance mixes two different models.
        denoiser = self.ema_network if ema_mode else self.network

        samples = tf.random.normal(shape=(num_images, 64, 64, 3), dtype=tf.float32)
        bar = IntProgress(min=0, max=self.timesteps)
        display(bar)
        for t in reversed(range(0, self.timesteps)):
            bar.value += 1
            tt = tf.cast(tf.fill(num_images, t), dtype=tf.int64)
            tt_input = tf.reshape(tt, shape=(-1, 1, 1, 1))
            # Calling the model directly avoids Model.predict's per-call setup
            # overhead, which dominates when invoked once per diffusion step.
            pred_noise = denoiser([samples, tt_input, embeddings, small_images], training=False)
            if ex_rate > 0:
                pred_negative_noise = denoiser([samples, tt_input, negative_embeddings, small_images], training=False)
                # Extrapolate the noise from the negative prompt towards the positive one.
                resulted_noise = pred_noise + (pred_noise - pred_negative_noise) * ex_rate
            else:
                resulted_noise = pred_noise
            samples = self.gdf_util.p_sample(resulted_noise, samples, tt, clip_denoised=True)

        generated_samples = samples
        self.plot_images(num_rows, num_cols, figsize, generated_samples * 2.0)
        return tf.clip_by_value((generated_samples * 2.0) * 127 + 127, 0.0, 255.0).numpy().astype(np.uint8)

In [None]:
# Globals that the saved .h5 network presumably looks up when it is rebuilt
# (they are not stored in the file) - keep them defined before load_model.
embedding_dims = 32
embedding_max_frequency = 1000.0
image_size = 128
img_channels = 3

# Gaussian diffusion utilities for the 32 -> 64 stage.
total_timesteps = 300
gdf_util = GaussianDiffusion(timesteps=total_timesteps)

network = tf.keras.models.load_model('sr3 32 to 64.h5')

# No separate EMA weights are loaded: both roles point at the same network.
SR64model = SR64DiffusionModel(
    network=network,
    ema_network=network,
    gdf_util=gdf_util,
    timesteps=total_timesteps,
)

In [None]:
class SR128DiffusionModel(keras.Model):
    """Inference wrapper for the second cascade stage (64x64 -> 128x128).

    The conditioning images are resized to 128x128 (gaussian resampling) and
    the reverse diffusion runs over 128x128x3 noise.

    Args:
        network: Keras model predicting noise from
            [noisy_samples, timestep, text_embedding, low_res_images].
        ema_network: EMA copy of `network`; used when `ema_mode=True`.
        timesteps: Number of reverse diffusion steps.
        gdf_util: `GaussianDiffusion` instance providing `p_sample`.
        ema: Unused; kept for interface compatibility.
    """

    def __init__(self, network, ema_network, timesteps, gdf_util, ema=0.999):
        super().__init__()
        self.network = network
        self.ema_network = ema_network
        self.timesteps = timesteps
        self.gdf_util = gdf_util

    def plot_images(self, num_rows, num_cols, figsize, images):
        """Show `images` (on a [-1, 1] scale) in a num_rows x num_cols grid."""
        images = (tf.clip_by_value(images * 127 + 127, 0.0, 255.0).numpy().astype(np.uint8))
        # squeeze=False always yields a 2-D array of axes, so grids with a single
        # row or a single column are indexed the same way as larger ones.
        _, ax = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
        for cell, image in zip(ax.flat, images):
            cell.imshow(image)
            cell.axis("off")

        plt.tight_layout()
        plt.show()

    def run_generation(self, small_images, num_rows=2, num_cols=8, figsize=(5, 2), annotation = " ", ema_mode = True, ex_rate = 0, from_annotation = False):
        """Upscale `small_images` to 128x128 by running the full reverse diffusion.

        Args:
            small_images: Batch of lower-resolution images on a 0..255 scale; its
                batch size must equal num_rows * num_cols.
            num_rows, num_cols: Plot grid; their product is the number of images sampled.
            figsize: Matplotlib figure size for the plots.
            annotation: Text prompt used when `from_annotation` is True.
            ema_mode: Predict noise with `ema_network` (True) or `network` (False).
            ex_rate: Guidance scale. When > 0 the predicted noise is extrapolated
                away from the prediction for the empty (" ") negative prompt.
            from_annotation: Condition on `annotation`; otherwise on the empty prompt.

        Returns:
            uint8 numpy array of generated 128x128 images in [0, 255].
        """
        num_images = num_rows * num_cols
        # 0..255 -> roughly [-0.5, 0.5]; presumably the range the network was trained on.
        small_images = (small_images / 127 - 1) / 2.0
        small_images = tf.image.resize(small_images, (128, 128), method="gaussian")
        self.plot_images(num_rows, num_cols, figsize, small_images * 2.0)

        def _embed(prompt):
            # Pooled BERT embedding of one prompt, repeated per image and shaped
            # (num_images, 1, 1, 512) as the network input expects.
            embedding = process_text(tf.expand_dims(prompt, axis=0))
            embedding = tf.repeat(tf.expand_dims(embedding, axis=0), num_images, axis=0)
            return tf.reshape(embedding, shape=(-1, 1, 1, 512))

        embeddings = _embed(annotation) if from_annotation else None
        negative_embeddings = None
        if ex_rate > 0 or not from_annotation:
            # Embedding of the empty negative prompt.
            negative_embeddings = _embed(" ")
        if not from_annotation:
            embeddings = negative_embeddings

        # The same network must serve both the conditional and the negative
        # prediction, otherwise guidance mixes two different models.
        denoiser = self.ema_network if ema_mode else self.network

        samples = tf.random.normal(shape=(num_images, 128, 128, 3), dtype=tf.float32)
        bar = IntProgress(min=0, max=self.timesteps)
        display(bar)
        for t in reversed(range(0, self.timesteps)):
            bar.value += 1
            tt = tf.cast(tf.fill(num_images, t), dtype=tf.int64)
            tt_input = tf.reshape(tt, shape=(-1, 1, 1, 1))
            # Calling the model directly avoids Model.predict's per-call setup
            # overhead, which dominates when invoked once per diffusion step.
            pred_noise = denoiser([samples, tt_input, embeddings, small_images], training=False)
            if ex_rate > 0:
                pred_negative_noise = denoiser([samples, tt_input, negative_embeddings, small_images], training=False)
                # Extrapolate the noise from the negative prompt towards the positive one.
                resulted_noise = pred_noise + (pred_noise - pred_negative_noise) * ex_rate
            else:
                resulted_noise = pred_noise
            samples = self.gdf_util.p_sample(resulted_noise, samples, tt, clip_denoised=True)

        generated_samples = samples
        self.plot_images(num_rows, num_cols, figsize, generated_samples * 2.0)
        # A single clip is enough; the result is already within [0, 255].
        return tf.clip_by_value((generated_samples * 2.0) * 127 + 127, 0.0, 255.0).numpy().astype(np.uint8)

In [None]:
# Globals that the saved .h5 network presumably looks up when it is rebuilt
# (they are not stored in the file) - keep them defined before load_model.
embedding_dims = 32
embedding_max_frequency = 1000.0
image_size = 128
img_channels = 3

total_timesteps = 300
gdf_util = GaussianDiffusion(timesteps=total_timesteps)

network = tf.keras.models.load_model('sr3 64 to 128.h5')

# No separate EMA weights are loaded: both roles point at the same network.
SR128model = SR128DiffusionModel(
    network=network,
    ema_network=network,
    gdf_util=gdf_util,
    timesteps=total_timesteps,
)

In [None]:
# Two-stage cascade: 32 -> 64 with SR64model, then 64 -> 128 with SR128model.
# Each result is rescaled by 1/256 to the 0..1 scale used by the other results.
SR64_result = SR64model.run_generation(
    tf.expand_dims(small_image * 256, axis=0),
    num_rows=1,
    num_cols=1,
    ex_rate=3,
    from_annotation=True,
    annotation="A grass field",
)[0] / 256
SR128_result = SR128model.run_generation(
    tf.expand_dims(SR64_result * 256, axis=0),
    num_rows=1,
    num_cols=1,
    ex_rate=3,
    from_annotation=True,
    annotation="A grass field",
)[0] / 256

In [None]:
_, ax = plt.subplots(2, 3, figsize=(20, 10))

# Panels in row-major order: reference and baselines on top, learned models below.
for panel, (title, image) in zip(ax.flat, [
    ('Reference', original_image),
    ('Bicubic', bicubic_result),
    ('Esrgan', esrgan_result),
    ('Regression', regression_result),
    ('SR3', SR3_result),
    ('Cascade SR3', SR128_result),
]):
    panel.imshow(tf.clip_by_value(image, 0, 1))
    panel.axis("off")
    panel.set_title(title, fontsize=16)

plt.tight_layout()
plt.show()

In [None]:
# Same input upscaled with a guiding text prompt (guidance scale 5) and without any prompt.
SR3_result_text = SR3model.run_generation(
    tf.expand_dims(small_image * 256, axis=0),
    num_rows=1,
    num_cols=1,
    ex_rate=5,
    from_annotation=True,
    annotation="A dog on the green grass",
)[0] / 256
SR3_result_no_text = SR3model.run_generation(
    tf.expand_dims(small_image * 256, axis=0),
    num_rows=1,
    num_cols=1,
    ex_rate=0,
    from_annotation=False,
)[0] / 256

_, ax = plt.subplots(1, 2, figsize=(20, 10))
for panel, (title, image) in zip(ax, [('No text', SR3_result_no_text), ('With text', SR3_result_text)]):
    panel.imshow(tf.clip_by_value(image, 0, 1))
    panel.axis("off")
    panel.set_title(title, fontsize=16)

plt.tight_layout()
plt.show()

In [None]:
def parse_record(record):
    """Parse one serialized example into (image, b).

    All three features are serialized tensors and are required to be present.
    'a' holds image file name(s) - only the first is loaded; 'b' is a float32
    tensor (presumably a precomputed text embedding - confirm); 'c' is
    required by the schema but not used here.
    """
    feature_description = {
        'a': tf.io.FixedLenFeature([], tf.string),
        'b': tf.io.FixedLenFeature([], tf.string),
        'c': tf.io.FixedLenFeature([], tf.string)
    }
    parsed_record = tf.io.parse_single_example(record, feature_description)

    a = tf.io.parse_tensor(parsed_record['a'], out_type=tf.string)
    b = tf.io.parse_tensor(parsed_record['b'], out_type=tf.float32)
    a = preprocess_image(a[0])
    return a, b

def preprocess_image(img_filename):
    """Load a JPEG given its stored relative path; returns float32 RGB in [0, 255]."""
    # Stored paths contain stray spaces and Windows separators; normalize them.
    img_filename = tf.strings.regex_replace(img_filename, " ", "")
    img_filename = tf.strings.regex_replace(img_filename, "\\\\", "/")
    img = tf.io.read_file('C:/users/user/desktop/' + img_filename)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.cast(img, tf.float32)
    return img

# Parse in parallel (element order is preserved by default) and prefetch last,
# so whole batches are prepared while the previous one is being consumed.
dataset = (
    tf.data.TFRecordDataset('C:/users/user/ai tests/superresolution/halfmillion_custom_endless_no_images_shuffled.tfrecord')
    .map(parse_record, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(16000)
    .batch(32)
    .shuffle(50)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

In [None]:
#Расчитаем FID метрику
#подробнее: https://machinelearningmastery.com/how-to-implement-the-frechet-inception-distance-fid-from-scratch/

from numpy import cov
from numpy import trace
from numpy import iscomplexobj
from numpy import asarray
from scipy.linalg import sqrtm
from keras.applications.inception_v3 import InceptionV3
from keras.applications.inception_v3 import preprocess_input

def calculate_fid(model, images1, images2, eps=1e-6):
    """Frechet Inception Distance between two image sets.

    Args:
        model: Feature extractor with a `predict` method returning (n, d) activations.
        images1, images2: Batches accepted by `model.predict`.
        eps: Diagonal offset added to both covariances when the matrix square
            root is not finite, e.g. for rank-deficient covariances estimated
            from fewer samples than feature dimensions.

    Returns:
        The FID value (lower means the two feature distributions are closer).
    """
    act1 = np.asarray(model.predict(images1), dtype=np.float64)
    act2 = np.asarray(model.predict(images2), dtype=np.float64)
    mu1, mu2 = act1.mean(axis=0), act2.mean(axis=0)
    # np.cov squeezes a single-feature covariance to 0-d; sqrtm needs a matrix.
    sigma1 = np.atleast_2d(cov(act1, rowvar=False))
    sigma2 = np.atleast_2d(cov(act2, rowvar=False))
    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # Regularize singular covariances and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
    # sqrtm may return tiny imaginary parts from numerical error.
    if iscomplexobj(covmean):
        covmean = covmean.real
    return ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)

# ImageNet-pretrained InceptionV3 without the classifier head; global average pooling gives one feature vector per image.
inception_model = InceptionV3(weights="imagenet", include_top=False, pooling="avg", input_shape=(299, 299, 3))

In [None]:
# FID comparison on one batch (32 images) drawn from the dataset.
# Batch outputs use a `batch_` prefix so the single-image results used by the
# comparison plots above (bicubic_result, SR3_result, ...) are not overwritten
# with 4-D batches - re-running those plot cells would otherwise fail.
images, annotations = next(iter(dataset.take(1)))

small_images = tf.image.resize(images, (32, 32), method="bicubic")

batch_bicubic_result = tf.image.resize(small_images, (128, 128), method="bicubic")
batch_esrgan_result = model_upscaler(small_images)
batch_regression_result = regression_sr(small_images / 256) * 256

# A 4 x 8 grid is 32 images, matching the dataset batch size.
generation_kwargs = dict(num_rows=4, num_cols=8, ex_rate=0, figsize=(20, 10), from_annotation=False)
batch_SR3_result = SR3model.run_generation(small_images, **generation_kwargs)
batch_SR64_result = SR64model.run_generation(small_images, **generation_kwargs)
batch_SR128_result = SR128model.run_generation(batch_SR64_result, **generation_kwargs)


def _to_inception_input(batch):
    """Apply InceptionV3 preprocessing to a 0..255 batch and resize it to 299x299."""
    return tf.image.resize(preprocess_input(batch), (299, 299))


fid_reference = _to_inception_input(images)
for name, result in [
    ("Bicubic", batch_bicubic_result),
    ("Esrgan", batch_esrgan_result),
    ("Regression", batch_regression_result),
    ("SR3", batch_SR3_result),
    ("Cascade SR3", batch_SR128_result),
]:
    print(f"{name} FID: ", calculate_fid(inception_model, fid_reference, _to_inception_input(result)))
