# Generative Model

## 5:34 PM 6.9.22

In [1]:
import matplotlib.pyplot as plt
plt.style.use('dark_background')
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

In [2]:
for i in plt.get_fignums():
    plt.close(i)

## Load SILVA Dataset

In [3]:
from Bio import SeqIO
import numpy as np
from tqdm.notebook import tqdm
s = np.array([record for record in tqdm(SeqIO.parse('silva.fasta', "fasta"))], dtype=object)

0it [00:00, ?it/s]

In [4]:
import multiprocessing as mp
from tqdm.notebook import tqdm
def fn(i):
    return np.array(list(str(i.seq)[:300]))
with mp.Pool() as p:
    string_seqs = np.array(list(tqdm(p.imap(fn, s, chunksize=100), total=s.shape[0])))

  0%|          | 0/227331 [00:00<?, ?it/s]

In [5]:
BASES = ['A', 'U', 'G', 'C']
def fn(i):
    enc_seq = np.empty((300, 5), dtype=np.intc)
    for bp in range(string_seqs.shape[1]):
        idx = BASES.index(i[bp]) if i[bp] in BASES else 4
        enc_seq[bp] = [1 if j == idx else 0 for j in range(5)]
    return enc_seq
with mp.Pool() as p:
    seqs = np.asarray(list(tqdm(p.imap(fn, string_seqs, chunksize=100), total=string_seqs.shape[0])))

  0%|          | 0/227331 [00:00<?, ?it/s]

In [6]:
desc = np.array([i.description.split(' ')[1] for i in s])
num_items = np.vectorize(lambda i: len(i.split(';')))(desc)
parsable = num_items == 7
raw_tax = desc[parsable]
tax = []
for i in raw_tax:
    tax.append(i.split(';'))
tax = np.array(tax)
seqs = seqs[parsable]

In [7]:
from sklearn.model_selection import train_test_split
train, val = train_test_split(seqs, test_size=.1)

## VAE Definition

In [8]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [9]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super(TransformerBlock, self).__init__()
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = keras.Sequential(
            [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim),]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs, training):
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

In [10]:
class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

In [261]:
latent_dim = 50

mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
    inputs = layers.Input((300, 5))
    den = layers.Dense(5)(inputs)
    res = layers.Reshape((100, 5 * 3))(den)
    
#     conv = layers.Conv1D(20, 3)(res)
#     maxpool = layers.MaxPooling1D()(conv)
#     res2 = layers.Flatten()(maxpool)
#     norm = layers.BatchNormalization()(res2)
#     den = layers.Dense(100 * 50 * 3)(norm)
#     res = layers.Reshape((100, 50 * 3))(den)

    trans = TransformerBlock(5 * 3, 4, 100)(res)
    norm = layers.BatchNormalization()(trans)

#     trans = TransformerBlock(50 * 3, 4, 100)(norm)
#     norm = layers.BatchNormalization()(trans)

    conv = layers.Conv1D(20, 3)(norm)
    maxpool = layers.MaxPooling1D()(conv)
    res2 = layers.Flatten()(maxpool)
    norm = layers.BatchNormalization()(res2)

    z_mean = layers.Dense(latent_dim, name="z_mean")(norm)
    z_log_var = layers.Dense(latent_dim, name="z_log_var")(norm)
    z = Sampling()([z_mean, z_log_var])
    encoder = keras.Model(inputs, [z_mean, z_log_var, z], name="encoder")
    encoder.summary()

Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_26 (InputLayer)          [(None, 300, 5)]     0           []                               
                                                                                                  
 dense_107 (Dense)              (None, 300, 5)       30          ['input_26[0][0]']               
                                                                                                  
 reshape_42 (Reshape)           (None, 100, 15)      0           ['dense_107[0][0]']              
                                                                                                  
 transformer_block_32 (Transfor  (None, 100, 15)     6970        ['reshape_42[0][0]']             
 merBlock)                                                                                  

In [262]:
with mirrored_strategy.scope():
    latent_inputs = keras.Input(shape=(latent_dim,))
    flat = layers.Flatten()(latent_inputs)
    den = layers.Dense(100 * 5 * 3)(flat)
    res = layers.Reshape((100, 5 * 3))(den)
    
#     conv = layers.Conv1D(20, 3)(res)
#     maxpool = layers.MaxPooling1D()(conv)
#     res2 = layers.Flatten()(maxpool)
#     norm = layers.BatchNormalization()(res2)
#     den = layers.Dense(100 * 50 * 3)(norm)
#     res = layers.Reshape((100, 50 * 3))(den)

    trans = TransformerBlock(5 * 3, 4, 100)(res)
    norm = layers.BatchNormalization()(trans)

#     trans = TransformerBlock(50 * 3, 4, 100)(norm)
#     norm = layers.BatchNormalization()(trans)

    conv = layers.Conv1D(20, 3)(norm)
    maxpool = layers.MaxPooling1D()(conv)
    res2 = layers.Flatten()(maxpool)
    norm = layers.BatchNormalization()(res2)
    
    den = layers.Dense(300 * 5)(norm)
    res = layers.Reshape((5, 300))(den)
    mid = layers.Dense(300, activation='sigmoid')(res)
    out = layers.Reshape((300, 5))(mid)
    
    decoder = keras.Model(latent_inputs, out, name="decoder")
    decoder.summary()

Model: "decoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_27 (InputLayer)       [(None, 50)]              0         
                                                                 
 flatten_38 (Flatten)        (None, 50)                0         
                                                                 
 dense_110 (Dense)           (None, 1500)              76500     
                                                                 
 reshape_43 (Reshape)        (None, 100, 15)           0         
                                                                 
 transformer_block_33 (Trans  (None, 100, 15)          6970      
 formerBlock)                                                    
                                                                 
 batch_normalization_59 (Bat  (None, 100, 15)          60        
 chNormalization)                                          

In [263]:
class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction), axis=1
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss ** 2
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }
    
    def test_step(self, data):
        z_mean, z_log_var, z = self.encoder(data)
        reconstruction = self.decoder(z)
        
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.binary_crossentropy(data, reconstruction), axis=1
            )
        )
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        total_loss = reconstruction_loss + kl_loss ** 2
        
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }
    
    def __call__(self, data, training=False):
        z_mean, z_log_var, z = self.encoder(data)
        return self.decoder(z)

In [264]:
with mirrored_strategy.scope():
    vae = VAE(encoder, decoder)
    vae.compile(optimizer=keras.optimizers.Adam())

## Training

In [266]:
vae.fit(train, validation_data=(val,), epochs=60, batch_size=1000)

Epoch 1/60
  1/163 [..............................] - ETA: 9s - loss: 61.5986 - reconstruction_loss: 42.5179 - kl_loss: 4.3679

2022-06-12 14:03:37.711717: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Did not find a shardable source, walked to a node which is not a dataset: name: "FlatMapDataset/_9"
op: "FlatMapDataset"
input: "PrefetchDataset/_8"
attr {
  key: "Targuments"
  value {
    list {
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: -2
  }
}
attr {
  key: "f"
  value {
    func {
      name: "__inference_Dataset_flat_map_slice_batch_indices_1125419"
    }
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\023FlatMapDataset:8031"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: -1
        }
      }
    }
  }
}
attr {
  key: "output_types"
  value {
    list {
      type: DT_INT64
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_P



2022-06-12 14:03:40.847045: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Did not find a shardable source, walked to a node which is not a dataset: name: "FlatMapDataset/_9"
op: "FlatMapDataset"
input: "PrefetchDataset/_8"
attr {
  key: "Targuments"
  value {
    list {
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: -2
  }
}
attr {
  key: "f"
  value {
    func {
      name: "__inference_Dataset_flat_map_slice_batch_indices_1126615"
    }
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\023FlatMapDataset:8061"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: -1
        }
      }
    }
  }
}
attr {
  key: "output_types"
  value {
    list {
      type: DT_INT64
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_P

Epoch 2/60
Epoch 3/60
Epoch 4/60
Epoch 5/60
Epoch 6/60
Epoch 7/60
Epoch 8/60
Epoch 9/60
Epoch 10/60
Epoch 11/60
Epoch 12/60
Epoch 13/60
Epoch 14/60
Epoch 15/60
Epoch 16/60
Epoch 17/60
Epoch 18/60
Epoch 19/60
Epoch 20/60
Epoch 21/60
Epoch 22/60
Epoch 23/60
Epoch 24/60
Epoch 25/60
Epoch 26/60
Epoch 27/60
Epoch 28/60
Epoch 29/60
Epoch 30/60
Epoch 31/60
Epoch 32/60
Epoch 33/60
Epoch 34/60
Epoch 35/60
Epoch 36/60
Epoch 37/60
Epoch 38/60
Epoch 39/60
Epoch 40/60


Epoch 41/60
Epoch 42/60
Epoch 43/60
Epoch 44/60
Epoch 45/60
Epoch 46/60
Epoch 47/60
Epoch 48/60
Epoch 49/60
Epoch 50/60
  1/163 [..............................] - ETA: 2s - loss: 60.6041 - reconstruction_loss: 41.1254 - kl_loss: 4.4134

KeyboardInterrupt: 

In [268]:
# vae.save('Models/vae/full_model')
encoder.save('Models/vae/encoder')
decoder.save('Models/vae/decoder')



### Evaluation

In [267]:
tf.reduce_mean((val[:10000] - vae(val[:10000])) ** 2)

<tf.Tensor: shape=(), dtype=float32, numpy=0.042191874>

In [269]:
plt.close()
plt.hist(encoder(val)[2].numpy().mean(axis=-1))
plt.savefig('out.png')

## Threshold Finding

In [270]:
from tqdm.notebook import tqdm
a, b = np.unique(tax[:, 6], return_counts=True)
species = a[b > 1]
rng = np.random.default_rng()
pairs = []
for i in tqdm(range(500)):
    spec = rng.choice(species, 1)
    choose_from = seqs[tax[:, 6] == spec]
    pair = rng.choice(choose_from, 2, replace=False)
    pairs.append(pair)
pairs = np.array(pairs)
((pairs[:, 0] - pairs[:, 1]) ** 2).mean()

  0%|          | 0/500 [00:00<?, ?it/s]

0.16497333333333333

In [276]:
a = encoder(val)[2]
b = a + tf.keras.backend.random_normal(shape=a.shape) * .95
pair_diff = tf.reduce_mean((decoder(a) - decoder(b)) ** 2).numpy()
init_diff = tf.reduce_mean((decoder(b) - val) ** 2).numpy()
pair_diff, init_diff

(0.121285945, 0.16438128)

Optimal threshold: .95

## Species-Level Pairs

In [11]:
encoder = tf.keras.models.load_model('Models/vae/encoder')
decoder = tf.keras.models.load_model('Models/vae/decoder')

2022-06-12 14:21:02.478606: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-06-12 14:21:03.261867: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38420 MB memory:  -> device: 0, name: NVIDIA A100-PCIE-40GB, pci bus id: 0000:21:00.0, compute capability: 8.0
2022-06-12 14:21:03.263300: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 38420 MB memory:  -> device: 1, name: NVIDIA A100-PCIE-40GB, pci bus id: 0000:81:00.0, compute capability: 8.0


In [232]:
def gen_similar(arr):
    a = encoder(arr)[2]
    b = a + tf.keras.backend.random_normal(shape=a.shape) * .95
    b_dec = decoder(b)
    return tf.math.round(b_dec)