In [2]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers
import tensorflow as tf
from tensorflow import keras
x = np.load('features.npy')

In [3]:
batch_size=16
latent_dim=512
epochs=10

In [4]:
x.shape[0]

1746

In [5]:
contact_maps = x[:, :120*120]

In [6]:
contact_maps /= contact_maps.max()

In [7]:
contact_maps.shape

(1746, 14400)

In [8]:
contact_maps = contact_maps.reshape([-1, 120, 120, 1])

In [9]:
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        layers.Dense(5 * 5 * 256),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((5, 5, 256)),
        layers.Conv2DTranspose(128, (5, 5), strides=(3, 3), padding="same"),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding="same"),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding="same"),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(32, (4, 4), strides=(2, 2), padding="same"),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(1, (4, 4), strides=(1, 1), padding="same"),
    ],
    name="generator",
)

In [10]:
generator.summary()

Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 6400)              3283200   
_________________________________________________________________
batch_normalization (BatchNo (None, 6400)              25600     
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 6400)              0         
_________________________________________________________________
reshape (Reshape)            (None, 5, 5, 256)         0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 15, 15, 128)       819328    
_________________________________________________________________
batch_normalization_1 (Batch (None, 15, 15, 128)       512       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 15, 15, 128)       0 

In [11]:
discriminator = keras.Sequential(
    [
        keras.Input(shape=(120, 120, 1)),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(32, (4, 4), strides=(2, 2), padding="same"),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(64, (4, 4), strides=(2, 2), padding="same"),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(64, (4, 4), strides=(2, 2), padding="same"),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, (4, 4), strides=(2, 2), padding="same"),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(alpha=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(latent_dim),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(1)
    ],
    name="discriminator",
)

In [12]:
discriminator.summary()

Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
leaky_re_lu_5 (LeakyReLU)    (None, 120, 120, 1)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 60, 60, 32)        544       
_________________________________________________________________
batch_normalization_5 (Batch (None, 60, 60, 32)        128       
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 60, 60, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 30, 30, 64)        32832     
_________________________________________________________________
batch_normalization_6 (Batch (None, 30, 30, 64)        256       
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 30, 30, 64)      

In [13]:
generator_lr = 0.001
discriminator_lr = 0.0002

In [14]:
from tensorflow.keras import backend
 
# implementation of wasserstein loss
def wasserstein_loss(y_true, y_pred):
	return backend.mean(y_true * y_pred)

In [15]:
class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim, batch_size):
        super(GAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.batch_size = batch_size

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super(GAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        # Sample random points in the latent space
        print('got here 0')
        batch_size = tf.shape(real_images)[0]
        print('got here 1')
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        print('got here 2')
        
        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)
        print('got here 3')
        
        # Combine them with real images
        combined_images = tf.concat([generated_images, real_images], axis=0)
        print('got here 4')
        
        # Assemble labels discriminating real from fake images
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        print('got here 5')
        
        # Add random noise to the labels - important trick!
#         labels += 0.01 * tf.random.uniform(tf.shape(labels))

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions_d = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions_d)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Assemble labels that say "all real images"
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        return {"d_loss": d_loss, "g_loss": g_loss}
    def call(inputs):
        random_latent_vectors = tf.random.normal(shape=(self.batch_size, self.latent_dim))
        return self.generator(random_latent_vectors)


In [16]:
contact_maps.shape

(1746, 120, 120, 1)

In [17]:
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim, batch_size=batch_size)

In [18]:
gan.compile(
        d_optimizer=keras.optimizers.Adam(learning_rate=discriminator_lr),
        g_optimizer=keras.optimizers.Adam(learning_rate=generator_lr),
        loss_fn=keras.losses.BinaryCrossentropy(from_logits=True)
    )

In [None]:
gan.train_step(contact_maps[:10,...])

got here 0
got here 1
got here 2


In [None]:
history = gan.fit(contact_maps[:1,...], batch_size=batch_size, epochs=epochs)

Epoch 1/10
