In [None]:
import math
import os.path

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import activations

from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import time
from skimage import transform, metrics
from umap import UMAP
import datetime
from scipy.signal import argrelextrema
import os
import pandas as pd
import scipy
from skimage import io
import csv
from sklearn.manifold import TSNE
from skimage.metrics import structural_similarity as ssim
import math
import plotly.express as px

from src.nn import ConvCVAE, Decoder, Encoder
from src.nn import RSU7, RSU6, RSU5, RSU4, RSU4F, ConvBlock
from src.nn_utils import SaveImageCallback

# Enable on-demand GPU memory allocation. Loop over every visible GPU:
# on a CPU-only machine the list is empty (no IndexError), and TensorFlow
# expects the memory-growth setting to be the same on all GPUs.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
# Sanity check: pixel count of a 64x64 image, the factor the CVAE reconstruction loss is scaled by.
np.multiply(64, 64)

In [None]:
def train_step(data, model, optimizer):
    """Run a single gradient-descent step on ``model``.

    The model is called with ``is_train=True`` under a gradient tape, the
    gradient of its ``'loss'`` output is applied with ``optimizer``, and the
    batch means of the total, reconstruction and latent losses are returned.

    Returns:
        tuple ``(total_loss, recon_loss, latent_loss)`` of numpy scalars.
    """
    with tf.GradientTape() as tape:
        outputs = model(data, is_train=True)

    params = model.trainable_variables
    gradients = tape.gradient(outputs['loss'], params)
    optimizer.apply_gradients(zip(gradients, params))

    # Same order as the returned tuple: total, reconstruction, latent.
    return tuple(outputs[key].numpy().mean() for key in ('loss', 'reconstr_loss', 'latent_loss'))

In [None]:
# Hyper-parameters for the 64x64 CVAE experiment.
label_dim, latent_dim = 40, 128
image_dim = [64, 64, 3]  # [height, width, channels]
beta = 0.65  # weight of the KL term in the ELBO


In [None]:
# Optimisation / data settings for this run.
learning_rate = 0.001
train_size = 0.01  # NOTE(review): not used by any cell shown here
batch_size = 32

# Model: encoder/decoder pair wrapped in the conditional VAE.
encoder = Encoder(latent_dim)
decoder = Decoder()
cvae_kwargs = dict(label_dim=label_dim, latent_dim=latent_dim, beta=beta, image_dim=image_dim)
model = ConvCVAE(encoder, decoder, **cvae_kwargs)

# Optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

In [None]:
# Wrap the Encoder in a functional model on a 1024x1024 single-channel
# input so its layers and output shapes can be inspected.
image_shape = (1024, 1024, 1)

inputs = tf.keras.Input(shape=image_shape)
encoder = Encoder(latent_dim)
out = encoder(inputs)

model = tf.keras.Model(inputs=inputs, outputs=out)

In [None]:
# Print the layer-by-layer summary (output shapes, parameter counts) of the model built in the previous cell.
model.summary()

In [None]:
class Encoder(tf.keras.Model):
    """RSU-based image encoder with a separate label branch.

    The image goes through five RSU stages, each followed by 2x2 max-pooling,
    is globally average-pooled and projected to ``2 * latent_dim`` features
    (presumably the concatenated posterior mean and log-variance -- TODO
    confirm with the consumer). The label vector is projected independently
    to ``label_dim`` features.

    Args:
        latent_dim: size of the latent space; the image head outputs
            ``2 * latent_dim`` features.
        label_dim: number of output features of the label branch.
    """

    def __init__(self, latent_dim, label_dim):
        super(Encoder, self).__init__()

        self.latent_dim = latent_dim
        self.label_dim = label_dim

        # Image branch: RSU stages, each followed by 2x downsampling.
        self.stage1 = RSU7(16, 64)
        self.pool12 = layers.MaxPool2D((2, 2), 2)

        self.stage2 = RSU6(32, 64)
        self.pool23 = layers.MaxPool2D((2, 2), 2)

        self.stage3 = RSU5(64, 128)
        self.pool34 = layers.MaxPool2D((2, 2), 2)

        self.stage4 = RSU4(128, 256)
        self.pool45 = layers.MaxPool2D((2, 2), 2)
        # out_w_h=64: spatial size at this point for a 1024x1024 input after four
        # 2x poolings (assuming the RSU stages preserve spatial size).

        self.stage5 = RSU4F(256, 512)
        self.pool56 = layers.MaxPool2D((2, 2), 2)

        # Heads are created once here rather than inside the forward pass, so
        # their weights are tracked by the model and reused on every call.
        self.global_pool = layers.GlobalAveragePooling2D()
        self.hidden_dense = layers.Dense(512)
        self.hidden_act = layers.LeakyReLU()
        self.latent_dense = layers.Dense(self.latent_dim * 2)
        self.latent_act = layers.LeakyReLU()

        # Label branch.
        self.label_dense1 = layers.Dense(128)
        self.label_act1 = layers.LeakyReLU()
        self.label_dense2 = layers.Dense(self.label_dim)
        self.label_act2 = layers.LeakyReLU()

    def call(self, inputs):
        """Encode ``[image, label]``.

        Args:
            inputs: two-element sequence ``[input_image, input_label]``.

        Returns:
            tuple ``(x, x_label)``: ``x`` has ``2 * latent_dim`` features and
            ``x_label`` has ``label_dim`` features.
        """
        input_image, input_label = inputs[0], inputs[1]

        x_label = self.label_act1(self.label_dense1(input_label))
        x_label = self.label_act2(self.label_dense2(x_label))

        hx1 = self.stage1(input_image)
        hx = self.pool12(hx1)

        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        hx5 = self.stage5(hx)
        # GlobalAveragePooling2D already yields (batch, channels); no Flatten needed.
        x = self.global_pool(self.pool56(hx5))

        x = self.hidden_act(self.hidden_dense(x))
        # NOTE(review): LeakyReLU on the latent head damps negative values of
        # what may be the mean/log-variance; confirm this is intended.
        x = self.latent_act(self.latent_dense(x))

        return x, x_label


In [None]:
# Shape tuples (not ints) for the functional-model probes below.
latent_dim, label_dim = (200,), (5,)

In [None]:
# Functional model around the notebook-defined Encoder:
# a 1024x1024 single-channel image plus a 5-dim label vector.
encoder = Encoder(latent_dim=200, label_dim=5)
image_input = layers.Input(shape=(1024, 1024, 1))
label_input = layers.Input(shape=(5,))

encoder_inputs = [image_input, label_input]
outputs = encoder(encoder_inputs)
model = tf.keras.Model(inputs=encoder_inputs, outputs=outputs)

In [None]:
# Print the layer-by-layer summary (output shapes, parameter counts) of the encoder model built in the previous cell.
model.summary()

In [None]:
class Decoder(tf.keras.Model):
    """RSU-based decoder mapping a latent code and label to an image and a label embedding.

    The concatenated inputs are projected to 256 features, reshaped to a
    16x16x1 grid and passed through six RSU stages, each followed by 2x
    bilinear upsampling (16 -> 1024 spatially, assuming the RSU stages
    preserve spatial size -- TODO confirm against src.nn). The image output
    has a sigmoid applied, so it is in [0, 1], not logits.

    Args:
        label_dim: number of features produced by the label branch.
        batch_size: only used to size the ``dense`` / ``reshape`` layers,
            which ``call`` does not use.
        out_ch: number of channels of the reconstructed image.
    """

    def __init__(self, label_dim, batch_size=32, out_ch=1):
        super(Decoder, self).__init__()

        self.batch_size = batch_size
        self.label_dim = label_dim
        # NOTE(review): `dense` and `reshape` are never used by `call`; they are
        # kept only so existing attribute access keeps working.
        self.dense = tf.keras.layers.Dense(4 * 4 * self.batch_size * 8)
        self.reshape = tf.keras.layers.Reshape(target_shape=(4, 4, self.batch_size * 8))

        # Layers used by `call` are created once here (not inside the forward
        # pass) so their weights are tracked and reused across calls.
        self.latent_proj = layers.Dense(256)
        self.to_grid = layers.Reshape(target_shape=(16, 16, 1))  # 256 = 16 * 16

        self.label_dense1 = layers.Dense(128)
        self.label_act1 = layers.LeakyReLU()
        self.label_dense2 = layers.Dense(self.label_dim)
        self.label_act2 = layers.LeakyReLU()

        self.stage6 = RSU4F(256, 512)
        self.stage5d = RSU4F(256, 512)
        self.stage4d = RSU4(128, 256)
        self.stage3d = RSU5(64, 128)
        self.stage2d = RSU6(32, 64)
        self.stage1d = RSU7(16, out_ch)

        self.upsample_1 = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')
        self.upsample_2 = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')
        self.upsample_3 = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')
        self.upsample_4 = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')
        self.upsample_5 = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')
        self.upsample_6 = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')

    def call(self, z_inputs):
        """Decode ``[z, label]``.

        Args:
            z_inputs: sequence of rank-2 tensors concatenated along axis 1
                (e.g. ``[z, label]``).

        Returns:
            tuple ``(image, x_label)``: the sigmoid-activated image and a
            ``label_dim``-wide label embedding.
        """
        z_data = tf.concat(z_inputs, axis=1)
        x = tf.nn.leaky_relu(self.latent_proj(z_data))

        x_label = self.label_act1(self.label_dense1(x))
        x_label = self.label_act2(self.label_dense2(x_label))

        x = self.to_grid(x)

        hx6 = self.stage6(x)
        hx6up = self.upsample_6(hx6)

        hx5d = self.stage5d(hx6up)
        hx5dup = self.upsample_5(hx5d)

        hx4d = self.stage4d(hx5dup)
        hx4dup = self.upsample_4(hx4d)

        hx3d = self.stage3d(hx4dup)
        hx3dup = self.upsample_3(hx3d)

        hx2d = self.stage2d(hx3dup)
        hx2dup = self.upsample_2(hx2d)

        hx1d = self.stage1d(hx2dup)
        x = self.upsample_1(hx1d)

        x = activations.sigmoid(x)

        return x, x_label

In [None]:
# Input shapes for the decoder probe: a 200-dim latent code and a 5-dim label.
latent_dim, label_dim = (200,), (5,)

In [None]:
# Call the notebook-defined Decoder symbolically on a latent code and label.
decoder = Decoder(label_dim=5)
z_input = layers.Input(shape=latent_dim)
label_input = layers.Input(shape=label_dim)
decoder_inputs = [z_input, label_input]
outputs = decoder(decoder_inputs)

In [None]:
# Wrap the symbolic decoder call in a functional model and print its summary.
model_inputs = [z_input, label_input]
model = tf.keras.Model(model_inputs, outputs)
model.summary()

In [None]:

class ConvCVAE(tf.keras.Model):
    """Convolutional conditional VAE built from an encoder and a decoder.

    The label vector is broadcast over the image plane and appended to the
    image channels before encoding. The sampled latent code is concatenated
    with the label before decoding.

    Args:
        encoder: model whose output is split in two along axis 1 into the
            posterior mean and log-variance (``latent_dim`` features each).
        decoder: model mapping ``latent_dim + label_dim`` features to image
            logits (a sigmoid is applied to its output here).
        label_dim: length of the conditioning label vector.
        latent_dim: size of the latent space.
        beta: weight of the KL term in the loss.
        image_dim: ``(height, width, channels)`` of the input images; a tuple
            or a list.
    """

    def __init__(self,
                 encoder,
                 decoder,
                 label_dim,
                 latent_dim,
                 beta=1,
                 image_dim=(64, 64, 3)):
        super(ConvCVAE, self).__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.label_dim = label_dim
        self.latent_dim = latent_dim
        self.beta = beta
        self.image_dim = image_dim

    def call(self, inputs, is_train=False):
        """Encode, sample, decode and compute the beta-weighted ELBO loss.

        Args:
            inputs: ``[images, labels]``; images shaped ``(batch, *image_dim)``,
                labels ``(batch, label_dim)``.
            is_train: accepted so ``model(data, is_train=True)`` works;
                currently unused because nothing here differs in training.

        Returns:
            dict with ``recon_img`` (sigmoid of the decoder output), per-example
            ``latent_loss`` and ``reconstr_loss`` of shape ``(batch,)``, the
            scalar batch-mean ``loss``, and ``z_mean`` / ``z_log_var``.
        """
        input_img, input_label, conditional_input = self.conditional_input(inputs)

        z_mean, z_log_var = tf.split(self.encoder(conditional_input), num_or_size_splits=2, axis=1)
        z_cond = self.reparametrization(z_mean, z_log_var, input_label)
        logits = self.decoder(z_cond)

        recon_img = tf.nn.sigmoid(logits)

        # KL divergence between N(z_mean, exp(z_log_var)) and N(0, I), per example.
        latent_loss = -0.5 * tf.reduce_sum(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var),
                                           axis=-1)

        # Binary cross-entropy between the input and the reconstruction,
        # averaged over each image's elements and scaled by height * width
        # (4096 for the default 64x64 images). It is computed from the logits
        # for numerical stability; mathematically it equals BCE(input_img, recon_img).
        batch = tf.shape(input_img)[0]
        flat_img = tf.reshape(input_img, [batch, -1])
        flat_logits = tf.reshape(logits, [batch, -1])
        pixel_count = float(np.prod(self.image_dim[:2]))
        reconstr_loss = pixel_count * tf.keras.losses.binary_crossentropy(
            flat_img, flat_logits, from_logits=True)

        # Beta-weighted ELBO, averaged over the batch.
        loss = tf.reduce_mean(reconstr_loss + self.beta * latent_loss)

        return {
            'recon_img': recon_img,
            'latent_loss': latent_loss,
            'reconstr_loss': reconstr_loss,
            'loss': loss,
            'z_mean': z_mean,
            'z_log_var': z_log_var
        }

    def conditional_input(self, inputs):
        """Build the conditional encoder input.

        Returns:
            tuple ``(input_img, input_label, conditional_input)``, where
            ``conditional_input`` is the image with the label vector repeated
            at every pixel and appended as extra channels:
            ``(batch, H, W, C + label_dim)``.
        """
        input_img = tf.cast(inputs[0], tf.float32)
        input_label = tf.cast(inputs[1], tf.float32)

        height, width = self.image_dim[0], self.image_dim[1]
        # (batch, label_dim) -> (batch, 1, 1, label_dim), then broadcast over the
        # image plane. Broadcasting keeps the batch dimension dynamic.
        labels = tf.reshape(input_label, [-1, 1, 1, self.label_dim])
        labels = labels * tf.ones([1, height, width, 1], dtype=tf.float32)

        conditional_input = tf.concat([input_img, labels], axis=3)

        return input_img, input_label, conditional_input

    def reparametrization(self, z_mean, z_log_var, input_label):
        """Apply the reparametrization trick and append the label.

        Samples ``z = z_mean + exp(0.5 * z_log_var) * eps`` with
        ``eps ~ N(0, I)``.

        Returns:
            tensor of shape ``(batch, latent_dim + label_dim)``.
        """
        # eps takes its (dynamic) shape from z_mean, so an unknown static batch size works.
        eps = tf.random.normal(shape=tf.shape(z_mean), mean=0.0, stddev=1.0)
        z = z_mean + tf.math.exp(z_log_var * .5) * eps
        z_cond = tf.concat([z, input_label], axis=1)  # (batch_size, latent_dim + label_dim)

        return z_cond