### Librerías

In [20]:
import numpy as np
import os
import sys
import warnings
import matplotlib.pyplot as plt
from google.colab import drive

import tensorflow as tf
import tensorflow.keras.backend as K
import keras
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input, Dense, Conv2D, Dropout, BatchNormalization, Conv2DTranspose, UpSampling2D, Flatten, Activation, Reshape 
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.optimizers import Adam

%matplotlib inline

In [21]:
# Switch TensorFlow out of eager execution. The classes defined further down
# build their training steps with K.placeholder, K.gradients and K.function.
tf.compat.v1.disable_eager_execution()
# Hide FutureWarning messages for the rest of the session.
warnings.filterwarnings("ignore", category=FutureWarning)

### Acceder a los datos de Google Drive

In [None]:
# Mount Google Drive at /content/drive; every path used below lives under it.
drive.mount('/content/drive')

### Cargar datos

In [22]:
# Path of the .npy file holding the training images.
input_images = "/content/drive/MyDrive/mamma_disco_duro.npy"
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects, which
# can execute code from an untrusted file. Only keep it if the file is trusted
# and actually stores an object array -- TODO confirm and switch to False if not.
data = np.load(input_images, allow_pickle=True)

In [23]:
# Directory prefix (trailing slash included) under which generator checkpoints
# are saved during training.
GENERATOR_PATH = "/content/drive/MyDrive/mammi_disco_duro_WGAN_BATCH_SIZE_8_GP_10_models/"

### Clase WGAN: contiene inicialización de parámetros y entrenamiento de la red

In [29]:
class WGAN_GP:
    """WGAN-GP trainer: owns the generator and critic models and the training loop.

    Attributes set by the constructor:
        img_shape: image shape handed to both networks.
        sample_shape: (rows, cols) of the preview grid written by sample_images().
        latent_dim: length of the noise vector fed to the generator.
        n_critic: number of critic updates per generator update (fixed to 5).
        generator_model / discriminator_model: the two Keras models.
        generator: the generator model used for sampling and checkpointing.
        discriminator_train_func / combined_train_func: backend functions that
            run one critic update and one generator update respectively.
    """

    # Directory that sample_images() writes its PNG grids to (created on demand).
    # Override on the class or on an instance to redirect the output.
    SAMPLES_PATH = "/content/drive/MyDrive/mammi_disco_duro_WGAN_BATCH_SIZE_8_GP_10_images/"

    def __init__(self, img_shape, sample_shape=(4,4), latent=128):
        """Initialise the parameters and build the generator and critic.

        Args:
            img_shape: image shape forwarded to Generator and Discriminator.
            sample_shape: (rows, cols) of the grid produced by sample_images().
            latent: dimensionality of the generator's noise input.
        """
        self.img_shape = img_shape
        self.sample_shape = sample_shape
        self.latent_dim = latent
        self.n_critic = 5

        self.generator_model = Generator(self.img_shape, self.latent_dim).architecture()
        self.discriminator_model = Discriminator(self.img_shape, self.latent_dim).architecture()

        # Transfer learning (disabled): load the weights of the best model
        # trained on the USF images before building the training functions.
        #
        # self.generator_model.load_weights(GENERATOR_PATH+'generator_weights.h5')
        # self.discriminator_model.load_weights(DISCRIMINATOR_PATH+'discriminator_weights.h5')
        #
        # NOTE(review): DISCRIMINATOR_PATH is not defined in the visible part of
        # the notebook; define it before re-enabling the lines above.

        discriminator = Discriminator(self.img_shape, self.latent_dim, self.generator_model, self.discriminator_model)
        self.discriminator_train_func = discriminator.build()

        combined = Generator(self.img_shape, self.latent_dim, self.generator_model, self.discriminator_model)
        self.generator = combined.generator
        self.combined_train_func = combined.build()

    def train_one_epoch(self, X_train, epoch, batch_size, valid, fake, dummy):
        """Run n_critic critic updates followed by one generator update.

        Despite the name this is a single training iteration: each critic step
        draws one random batch (with replacement) from X_train.

        Args:
            X_train: array of training images, indexed along axis 0.
            epoch: iteration number, only used in the printed log line.
            batch_size: number of images / noise vectors per update.
            valid, fake, dummy: accepted for interface compatibility; unused.
        """
        for _ in range(self.n_critic):
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            # One interpolation coefficient per sample for the gradient penalty.
            alpha = np.random.uniform(size=(batch_size, 1, 1 ,1))
            d_loss_real, d_loss_fake = self.discriminator_train_func([imgs, noise, alpha])
            # Only the value from the last critic step survives to be printed.
            d_loss = d_loss_real - d_loss_fake

        noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
        g_loss, = self.combined_train_func([noise])

        print ("%d [D loss: %f] [G loss: %f]" % (epoch, d_loss, g_loss))

    def general_training(self, data, epochs, batch_size=128, sample_interval=200):
        """Run the full training loop.

        Args:
            data: image array, either (N, H, W) or already (N, H, W, C).
                Values are rescaled with (x - 127.5) / 127.5, which maps
                0..255 to [-1, 1] -- assumes 8-bit pixel values, TODO confirm.
            epochs: number of training iterations to run.
            batch_size: batch size for every critic / generator update.
            sample_interval: every this many iterations a sample grid is
                written and the generator is saved under GENERATOR_PATH.

        Raises:
            ValueError: if data is neither 3-D nor 4-D.
        """
        X_train = (data.astype(np.float32) - 127.5) / 127.5

        # Add the channel axis only when it is missing; expanding data that is
        # already (N, H, W, C) would yield a 5-D array the networks cannot take.
        if X_train.ndim == 3:
            X_train = np.expand_dims(X_train, axis=3)
        elif X_train.ndim != 4:
            raise ValueError(
                "data must have shape (N, H, W) or (N, H, W, C); got %s" % (X_train.shape,)
            )

        # Kept only because train_one_epoch's signature expects them.
        valid =-np.ones((batch_size, 1))
        fake = np.ones((batch_size, 1))
        dummy = np.zeros((batch_size, 1))

        for epoch in range(1, epochs+1):
            self.train_one_epoch(X_train, epoch, batch_size, valid, fake, dummy)

            if epoch % sample_interval == 0:
                self.sample_images(epoch)
                # Make sure the checkpoint directory exists before saving.
                os.makedirs(GENERATOR_PATH, exist_ok=True)
                self.generator.save(os.path.join(GENERATOR_PATH, 'wgangp_generator_%d.h5' % (epoch)))

    def sample_images(self, epoch):
        """Generate a grid of synthetic images and save it as '<epoch>.png'.

        Args:
            epoch: iteration number, used as the output file name.
        """
        rows, cols = self.sample_shape
        noise = np.random.normal(0, 1, (rows * cols, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Map the generator's tanh range [-1, 1] to [0, 1] for display.
        gen_imgs = 0.5 * gen_imgs + 0.5

        samples = self.SAMPLES_PATH
        # makedirs also creates missing parents and tolerates an existing dir.
        os.makedirs(samples, exist_ok=True)

        # squeeze=False keeps axs two-dimensional even for 1xN / Nx1 / 1x1 grids.
        fig, axs = plt.subplots(rows, cols, squeeze=False)
        cnt = 0
        for i in range(rows):
            for j in range(cols):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig(os.path.join(samples, "%d.png" % epoch))
        # Close this specific figure so repeated sampling does not leak memory.
        plt.close(fig)


### Clase Generador

In [30]:
class Generator:
    """Generator network definition plus its training function.

    Used in two ways: with only (img_shape, latent) to create the model through
    architecture(), and with all four arguments to create the generator update
    step through build().
    """

    def __init__(self, img_shape=None, latent=None, generator_model=None, discriminator_model=None):
        """Store the parameters.

        Args:
            img_shape: image shape; architecture() reads img_shape[2] as the
                number of output channels.
            latent: dimensionality of the noise input.
            generator_model: generator model, required by build().
            discriminator_model: critic model, required by build().
        """
        self.img_shape = img_shape
        self.latent_dim = latent

        self.generator = generator_model
        self.discriminator = discriminator_model

    def build(self):
        """Create the backend function that performs one generator update.

        Returns:
            A K.function taking [noise] and returning [loss], where loss is
            minus the mean critic score of the generated images; calling it
            applies one Adam step to the generator's trainable weights.
        """
        # Only the generator's weights are updated in this phase.
        self.discriminator.trainable = False
        self.generator.trainable = True

        noise = Input(shape=(self.latent_dim,))
        generated = self.generator(noise)
        loss = -K.mean(self.discriminator(generated))

        # `learning_rate` replaces the deprecated `lr` keyword alias.
        optimizer = Adam(learning_rate=0.0001, beta_1=0.0, beta_2=0.9)
        training_updates = optimizer.get_updates(params=self.generator.trainable_weights, loss=loss)
        train_func = K.function([noise], [loss], training_updates)

        return train_func

    def architecture(self):
        """Build the generator model.

        A dense layer is reshaped to a 7x7x64 feature map and passed through
        five UpSampling2D + Conv2DTranspose + ReLU stages (7 * 2**5 = 224 per
        side with the default 2x upsampling), then two Conv2D layers ending in
        tanh. The spatial size is fixed by these constants, not by img_shape.

        Returns:
            A Keras Model mapping a noise vector to an image.
        """
        model = Sequential()

        model.add(Dense(64 * 7 * 7, activation="relu", input_dim=self.latent_dim))
        model.add(Reshape((7, 7, 64)))

        # Filter counts must be integers: use floor division, not `/`, which
        # yields floats (32.0, 16.0, ...) in Python 3.
        for filters in (64 // 2, 64 // 4, 64 // 8, 64 // 16, 64 // 16):
            model.add(UpSampling2D())
            model.add(Conv2DTranspose(filters, kernel_size=4, padding="same"))
            model.add(Activation("relu"))

        model.add(Conv2D(64, kernel_size=4, padding="same"))
        model.add(Activation("relu"))
        # One output channel per image channel, squashed to [-1, 1] by tanh.
        model.add(Conv2D(self.img_shape[2], kernel_size=4, padding="same"))
        model.add(Activation("tanh"))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(inputs=noise, outputs=img)

### Clase Discriminador

In [31]:
class Discriminator:
    """Critic ("discriminator") network definition plus its training function.

    Used in two ways: with only (img_shape, latent) to create the model through
    architecture(), and with the generator and critic models supplied to create
    the critic update step through build().
    """

    def __init__(self, img_shape=None, latent=None, generator_model=None, discriminator_model=None, gp_weight=10):
        """Store the parameters.

        Args:
            img_shape: shape of one input image.
            latent: dimensionality of the generator's noise input.
            generator_model: generator model, required by build().
            discriminator_model: critic model, required by build().
            gp_weight: multiplier applied to the gradient-penalty term of the
                critic loss (defaults to the previously hard-coded 10).
        """
        self.img_shape = img_shape
        self.latent_dim = latent

        self.generator = generator_model
        self.discriminator = discriminator_model
        self.gp_weight = gp_weight

    def get_interpolated(self, real_img, fake_img):
        """Create the sample interpolated between real and generated images.

        Args:
            real_img: tensor of real images.
            fake_img: tensor of generated images.

        Returns:
            (interpolated_img, alpha): the tensor alpha*real + (1-alpha)*fake
            wrapped in an Input layer, and the placeholder through which the
            per-sample mixing coefficients are fed.
        """
        alpha = K.placeholder(shape=(None,1,1,1))
        interpolated_img = Input(shape=self.img_shape,
                                tensor=alpha*real_img + (1-alpha)*fake_img)

        return interpolated_img, alpha

    def gradient_penalty_loss(self, real_img, fake_img, interpolated_img):
        """Compute the critic loss, including the gradient penalty.

        Args:
            real_img: tensor of real images.
            fake_img: tensor of generated images.
            interpolated_img: tensor returned by get_interpolated().

        Returns:
            (loss_real, loss_fake, loss): mean critic score on real images,
            mean critic score on generated images, and the total
            loss_fake - loss_real + gp_weight * gradient_penalty.
        """
        loss_real = K.mean(self.discriminator(real_img))
        loss_fake = K.mean(self.discriminator(fake_img))

        # Gradient of the critic output with respect to the interpolated input.
        grad_mixed = K.gradients(self.discriminator(interpolated_img), [interpolated_img])[0]
        gradients_sqr = K.square(grad_mixed)
        # Per-sample L2 norm: sum over every axis except the batch axis.
        # NOTE(review): sqrt has an undefined gradient at exactly 0; consider
        # adding a small epsilon inside the sqrt if NaNs ever appear.
        norm_grad_mixed = K.sqrt(K.sum(gradients_sqr, axis=np.arange(1, len(gradients_sqr.shape))))
        # Penalise deviation of the gradient norm from 1.
        grad_penalty = K.mean(K.square(norm_grad_mixed-1))

        loss = loss_fake - loss_real + self.gp_weight * grad_penalty

        return loss_real, loss_fake, loss

    def build(self):
        """Create the backend function that performs one critic update.

        Returns:
            A K.function taking [real_img, noise, alpha] and returning
            [loss_real, loss_fake]; calling it applies one Adam step to the
            critic's trainable weights.
        """
        # Only the critic's weights are updated in this phase.
        self.generator.trainable = False
        self.discriminator.trainable = True

        real_img = Input(shape=self.img_shape)
        noise = Input(shape=(self.latent_dim,))
        fake_img = self.generator(noise)

        interpolated_img, alpha = self.get_interpolated(real_img, fake_img)

        loss_real, loss_fake, loss = self.gradient_penalty_loss(real_img, fake_img, interpolated_img)

        # `learning_rate` replaces the deprecated `lr` keyword alias.
        optimizer = Adam(learning_rate=0.0001, beta_1=0.0, beta_2=0.9)
        training_updates = optimizer.get_updates(params=self.discriminator.trainable_weights, loss=loss)
        discriminator_train = K.function([real_img, noise, alpha],
                                [loss_real, loss_fake],
                                training_updates)

        return discriminator_train

    def architecture(self):
        """Build the critic model.

        Four Conv2D + LeakyReLU(0.2) + Dropout(0.25) stages with 64, 128, 256
        and 512 filters (the first three with stride 2), flattened into a
        single linear output unit.

        Returns:
            A Keras Model mapping an image to one unbounded score.
        """
        model = Sequential()

        model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))

        model.add(Conv2D(64*2, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))

        model.add(Conv2D(64*4, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))

        model.add(Conv2D(64*8, kernel_size=3, strides=1, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))

        # No activation on the output: the critic emits a raw score.
        model.add(Flatten())
        model.add(Dense(1))

        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(inputs=img, outputs=validity)


In [None]:
if __name__ == "__main__":
    # Build the WGAN-GP for single-channel 224x224 images ...
    wgan_gp = WGAN_GP(img_shape=(224, 224, 1))
    # ... and train it on the array loaded earlier, writing a sample grid and a
    # generator checkpoint every 100 iterations.
    wgan_gp.general_training(
        data=data,
        epochs=4000,
        batch_size=8,
        sample_interval=100,
    )