# SetBERT Architecture

In [None]:
# All necessary imports.
# `from __future__` must be the first statement of the cell/module; placing it
# after other statements is a SyntaxError.
from __future__ import absolute_import, division, print_function, unicode_literals

import glob
import os
import sys
import time
import warnings

# Make modules from the ../deep-learning-dna checkout importable. This still runs
# before every third-party/project import, as in the original ordering.
sys.path.append("../deep-learning-dna")

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from lmdbm import Lmdb

# The warning filter only applies inside the `with` block, so the import it is
# meant to quiet must be inside it. NOTE(review): presumed intent is to silence
# FutureWarnings emitted while importing wandb — confirm.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    import wandb

import settransformer as st
import tf_utils as tfu
from common.models import CustomModel
from common.models import dnabert

In [None]:
# Start a Weights & Biases run for this notebook; it is closed later via run.finish().
run = wandb.init(project="SetBERT", entity="cguptil", name="SetBERT_Run_1")

In [None]:
#strategy = tfu.strategy.gpu(0) #Optional strategy distribution

In [None]:
# List the devices TensorFlow can see (e.g. to confirm a GPU is available).
tf.config.get_visible_devices()

## Custom classes for Model Architecture

In [None]:
# Adapted from David Ludwig's DNABERT architecture; only the mask inversion is used here.
class InvertMask(keras.layers.Layer):
    """
    Pass values through unchanged while logically negating the incoming Keras mask.

    Useful for DNABERT-style models where we *want* to pay attention to the
    elements that were masked out upstream. When no mask arrives, none is
    produced.
    """
    def compute_mask(self, inputs, mask=None):
        return None if mask is None else tf.logical_not(mask)

    def call(self, inputs):
        # Adding zero produces a new tensor with identical values instead of
        # returning `inputs` itself (hacky, but leaves the data unmodified).
        shifted = inputs + 0
        return shifted

In [None]:
# Set-level masking for BERT-style training. Hides a fraction (15% in this
# notebook) of the sequences in each set: whole sequences are blocked out,
# not individual embedding values.
class SetMask(keras.layers.Layer):
    """Randomly hide whole elements of every set.

    For each set in the batch, ``int(set_size * mask_ratio)`` positions along
    axis 1 are chosen uniformly at random without replacement. Those positions
    are zeroed across the full embedding vector. The attached Keras mask has
    the same shape as the inputs and is ``False`` at hidden positions, ``True``
    elsewhere. Any incoming mask is ignored.

    Args:
        mask_ratio: Fraction of the elements of each set to hide.
    """
    def __init__(self, mask_ratio, **kwargs):
        super().__init__(**kwargs)
        self.mask_ratio = tf.Variable(mask_ratio, trainable=False, dtype=tf.float32, name='Mask_Ratio')

    def call(self, inputs, mask=None):
        mask = self.compute_mask(inputs, mask)
        outputs = tf.cast(mask, dtype=inputs.dtype) * inputs
        # The mask is random, so it must be drawn once and reused. Since this
        # layer overrides `compute_mask`, Keras would otherwise call it again
        # after `call` to build the output mask, drawing a *different* mask from
        # the one used to zero the outputs. Attaching the mask here (the same
        # pattern `keras.layers.Masking` uses) makes Keras reuse it.
        try:
            outputs._keras_mask = mask  # pylint: disable=protected-access
        except AttributeError:
            pass
        return outputs

    def compute_mask(self, inputs, mask=None):
        """Draw a fresh random boolean mask shaped like ``inputs`` (rank >= 3 expected)."""
        batch_size = tf.shape(inputs)[0]
        seq_len = tf.shape(inputs)[1]
        embed_dim = tf.shape(inputs)[2]
        mask_len = tf.cast(tf.cast(seq_len, dtype=tf.float32) * self.mask_ratio, dtype=tf.int32)

        # The top-k positions of i.i.d. uniform noise form a uniformly random
        # subset of size mask_len per set (sampling without replacement).
        noise = tf.random.uniform((batch_size, seq_len), 0, 1)
        _, indices = tf.math.top_k(noise, mask_len)
        # Pair each selected position with its batch row for scatter_nd.
        batch_indices = tf.reshape(tf.repeat(tf.range(batch_size), mask_len), (-1, 1))
        element_indices = tf.reshape(indices, (-1, 1))
        scatter_indices = tf.concat((batch_indices, element_indices), axis=1)
        keep = tf.ones((batch_size, seq_len))
        keep = tf.tensor_scatter_nd_update(keep, scatter_indices, tf.zeros((batch_size * mask_len)))
        # Broadcast the per-element decision across the embedding axis.
        keep = tf.tile(tf.expand_dims(keep, axis=2), (1, 1, embed_dim))
        return tf.cast(keep, dtype=tf.bool)

    def get_config(self):
        config = super().get_config()
        config.update({
            # Plain Python float so the config is JSON-serializable.
            "mask_ratio": float(self.mask_ratio.numpy())
        })
        return config

In [None]:
# Prepends a learnable class token, with the same dimensionality as the set
# elements, to the beginning of each set. Used for BERT-style training.
class SetClassToken(keras.layers.Layer):
    """Prepend a trainable ``(1, 1, embedding_dim)`` token to every set.

    The output is the input with one extra element at position 0 along axis 1.
    If the input carries a Keras mask, an all-``True`` entry for the token is
    concatenated in front of it. With no input mask, no mask is produced.

    Args:
        embedding_dim: Size of the token vector; must match the last input axis
            for the concatenation to succeed.
    """
    def __init__(self, embedding_dim, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.class_token = self.add_weight(shape=(1, 1, embedding_dim), initializer='random_normal',
                                           trainable=True, name='Class_token')

    def compute_mask(self, inputs, mask=None):
        if mask is None:
            # Nothing to extend (tf.concat would fail on None).
            return None
        # Unmasked entry for the token, shaped like one position of the incoming
        # mask. This covers both per-element (batch, set) masks and per-value
        # (batch, set, embedding_dim) masks such as the one SetMask produces.
        token_mask = tf.ones_like(mask[:, :1], dtype=tf.bool)
        return tf.concat((token_mask, mask), axis=1)

    def call(self, inputs, mask=None):
        batch_size = tf.shape(inputs)[0]
        # One copy of the shared token per set in the batch.
        tokens = tf.tile(self.class_token, (batch_size, 1, 1))
        return tf.concat((tokens, inputs), axis=1)

    def get_config(self):
        config = super().get_config()
        config.update({
            "embedding_dim": self.embedding_dim
        })
        return config

In [None]:
# Wraps a functional model in CustomModel so training can use sub-batches
# (see `subbatch_size` in fit), allowing larger batches without running out of memory.
class SubBatchModel(CustomModel):
    """Delegate every forward pass to the wrapped ``model``."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def call(self, inputs, training=None):
        wrapped = self.model
        return wrapped(inputs, training=training)

    def compute_mask(self, inputs, mask):
        # Masks pass through untouched.
        return mask

    def get_config(self):
        # NOTE(review): the wrapped model is stored as a live object, not a
        # serialized config, so this config is likely not JSON-serializable —
        # confirm before relying on full-model saving.
        return {**super().get_config(), "model": self.model}

In [None]:
#Data generator used for shuffling training and validation data.
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that yields random batches of subsampled sets.

    Each batch picks ``batch_size`` sample stores uniformly at random (with
    replacement). From each chosen store it reads ``subsample_size`` distinct
    random entries. Entries are stored under the keys ``"0"``, ``"1"``, ... and
    are decoded as float32 vectors. Batches are returned as ``(x, y)`` with
    ``y`` equal to ``x``.

    Args:
        samples: Paths of LMDB sample stores. Stores with fewer than
            ``subsample_size`` entries are skipped with a printed warning.
        batch_size: Number of sets per batch.
        subsample_size: Number of entries drawn from each chosen store.
        num_batches: Value reported by ``len()``. Batches are random, so the
            index passed to ``__getitem__`` is ignored.
        rng: Optional ``numpy.random.Generator``; a new default generator is
            created when omitted.
    """
    def __init__(self, samples, batch_size, subsample_size, num_batches, rng=None):
        super().__init__()
        self.batch_size = batch_size
        self.subsample_size = subsample_size
        self.samples = self.open_samples(samples)
        self.num_batches = num_batches
        self.rng = rng if rng is not None else np.random.default_rng()

    def __getitem__(self, i):
        # `i` is ignored: every call draws a fresh random batch.
        if not self.samples:
            raise ValueError("DataGenerator has no usable samples to draw a batch from.")
        sample_indices = self.rng.integers(len(self.samples), size=self.batch_size)
        batch = np.array([self._subsample(self.samples[si]) for si in sample_indices])
        return batch, batch.copy()

    def __len__(self):
        return self.num_batches

    def _subsample(self, store):
        """Read ``subsample_size`` distinct random entries of ``store`` as float32 vectors."""
        indices = self.rng.choice(len(store), self.subsample_size, replace=False)
        return [np.frombuffer(store[str(idx)], dtype=np.float32) for idx in indices]

    def open_samples(self, samples):
        """Open each store (``lock=False``), dropping those too small to subsample."""
        samples_final = []
        for sample in samples:
            store = Lmdb.open(sample, lock=False)
            if len(store) < self.subsample_size:
                print(f"Warning: Sample '{sample}' only contains {len(store)} sequences. This sample will not be included.")
                store.close()
                continue
            samples_final.append(store)
        return samples_final

    def close(self):
        """Close every open sample store; the generator cannot produce batches afterwards."""
        for store in self.samples:
            store.close()
        self.samples = []

## Model Creation

In [None]:
import random

In [None]:
# Collect every pre-embedded training sample store and shuffle the list, so the
# train/validation split below is random (Monte-Carlo cross-validation).
path = './pre_embedded_samples_complete/train'
samples_train = glob.glob(os.path.join(path, '*.db'))
random.shuffle(samples_train)

In [None]:
# First 80% of the shuffled samples train the model; the remainder validates it.
split_index = int(0.8 * len(samples_train))

In [None]:
# Inspect the training portion of the split.
samples_train[:split_index]

In [None]:
# Separate generators for training and validation so their samples never mix.
# Each batch holds 32 sets of 1000 sequences; len() reports 20 / 10 batches.
seq_gen_train = DataGenerator(samples_train[:split_index], batch_size=32, subsample_size=1000, num_batches=20)
seq_gen_val = DataGenerator(samples_train[split_index:], batch_size=32, subsample_size=1000, num_batches=10)

In [None]:
# Input: sets of 1000 pre-embedded sequences, each an 8-dimensional vector.
input_layer = keras.layers.Input((1000, 8))
# BERT-style corruption: hide 15% of the sequences in every set.
masking_layer = SetMask(mask_ratio=0.15)
masked = masking_layer(input_layer)
# Prepend a learnable 8-dim class token to every set.
class_tokens = SetClassToken(8)
tokens_added = class_tokens(masked)
# Stack of 8 ISAB blocks (30 inducing points each, per the original note).
for _ in range(8):
    tokens_added = st.ISAB(8, 2, 30, pre_layernorm=True)(tokens_added)
# Flip the mask so the hidden 15% become the "valid" positions, for computing the
# reconstruction loss on them. NOTE(review): keras.layers.Lambda without a `mask`
# argument does not propagate masks — confirm the inverted mask reaches the loss.
inverted_mask = InvertMask()(tokens_added)
# Position 0 is the class token; positions 1: are the 1000 contextualized sequences.
output = keras.layers.Lambda(lambda x: x[:, 1:, :])(inverted_mask)
class_token_embeddings = keras.layers.Lambda(lambda x: x[:, 0, :])(inverted_mask)

In [None]:
# Reconstruction model: outputs the 1000 contextualized sequences of each set.
setbert = SubBatchModel(keras.Model(inputs=input_layer, outputs=output))
keras.utils.plot_model(setbert.model, show_shapes=True, expand_nested=True)

In [None]:
# Embedding model: built from the same layers (hence the same weights) as
# `setbert`, but outputs each set's contextualized class token.
setbert_tokens = SubBatchModel(keras.Model(inputs=input_layer, outputs=class_token_embeddings))
keras.utils.plot_model(setbert_tokens.model, show_shapes=True, expand_nested=True)

## Model Training

In [None]:
# Compile with Nadam (learning rate 1e-4) and a log-cosh reconstruction loss.
setbert.compile(
    optimizer=keras.optimizers.Nadam(learning_rate=1e-4),
    loss=tf.keras.losses.LogCosh(),
)

In [None]:
# Train for 300 epochs in sub-batches of 8 sets, logging to WandB; the callback
# uploads weights only, not the full model.
wandb_callback = wandb.keras.WandbCallback(save_weights_only=True)
history = setbert.fit(
    seq_gen_train,
    epochs=300,
    subbatch_size=8,
    validation_data=seq_gen_val,
    callbacks=[wandb_callback],
)

In [None]:
# Save a local backup of the weights; save_format="h5" forces HDF5 even though
# the path has no .h5 extension.
setbert.save_weights("./SetBERTSave", save_format="h5")

In [None]:
# Record which samples were held out for validation, one path per line.
with open("new_validation_samples.txt", 'w') as f:
    f.writelines(sample + '\n' for sample in samples_train[split_index:])

In [None]:
# Upload the validation sample list to the WandB run.
wandb.save("new_validation_samples.txt")

In [None]:
# Finish the WandB run started at the top of the notebook.
run.finish()

In [None]:
# Plot training and validation loss per epoch, labelled so the curves can be told apart.
plt.figure(1)
plt.plot(history.history['loss'], label='Training loss')
plt.plot(history.history['val_loss'], label='Validation loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend()
plt.tight_layout()
plt.show()

## Model Testing

In [None]:
# The held-out validation samples are reused as the test set.
samples_test = list(samples_train[split_index:])
samples_test

In [None]:
# Build one batch (32 sets of 1000 sequences) per test sample, each from its own
# DataGenerator so sets from different samples are never mixed. Only the first
# `max_test_samples` usable samples are loaded, and `labels[k]` stays aligned
# with `seq_gen_collection[k]`.
max_test_samples = 10
seq_gen_collection = []
labels = []
for sample in samples_test:
    if len(seq_gen_collection) >= max_test_samples:
        break
    seq_gen = DataGenerator([sample], 32, 1000, 20)
    if len(seq_gen.samples) > 0:
        seq_gen_collection.append(seq_gen[0][0])
        labels.append(sample)
    # The batch has been copied into a numpy array; release the LMDB handles.
    for store in seq_gen.samples:
        store.close()

In [None]:
# Run the class-token model on each test batch; one array (one row per set) per sample.
class_tokens = [setbert_tokens.predict(batch) for batch in seq_gen_collection]

In [None]:
# Stack the per-sample class-token arrays along the first axis, in sample order.
class_tokens = np.concatenate(class_tokens)

## MDS Plotting

In [None]:
from sklearn.manifold import MDS
from scipy.spatial.distance import cdist

In [None]:
# Pairwise Euclidean distances between the set embeddings captured by the class tokens.
dist_mat = cdist(class_tokens, class_tokens, metric='euclidean')
dist_mat

In [None]:
# Square matrix: one row/column per set.
dist_mat.shape

In [None]:
# Metric MDS into 8 dimensions (matching the 8-dim embedding space) from the precomputed distance matrix.
mds = MDS(n_components=8, metric=True, dissimilarity='precomputed', n_jobs=10)

In [None]:
# Embed every set as a point in the MDS space.
points = mds.fit_transform(dist_mat)
points.shape

In [None]:
# Plot only the first two MDS coordinates for visualization, one colour per
# test sample (each sample contributed 32 consecutive rows).
for group in points[:, :2].reshape(len(seq_gen_collection), 32, 2):
    xs, ys = group.T
    plt.scatter(xs, ys)
plt.legend([os.path.basename(s) for s in labels], loc='center left', bbox_to_anchor=(1., 0.5))
plt.title("Bad Pre-Embedding")

In [None]:
from common import metrics

In [None]:
# Compute the pairwise Chamfer distance between sets with David Ludwig's chamfer distance implementation.
# Runs noticeably longer than the SetBERT distance-matrix calculation.
chamfer_dist = metrics.chamfer_distance_matrix(np.concatenate(seq_gen_collection), p=1, workers=10,
                                fn=metrics.chamfer_distance)

In [None]:
# Inspect the shape of the Chamfer distance matrix.
chamfer_dist.shape

In [None]:
# Display the Chamfer distance matrix.
chamfer_dist

In [None]:
# Metric MDS into 8 dimensions from the precomputed Chamfer distances, matching the SetBERT MDS above.
mds_chamfer = MDS(n_components=8, metric=True, dissimilarity='precomputed', n_jobs=10)

In [None]:
# Embed every set as a point in the Chamfer MDS space.
chamfer_points = mds_chamfer.fit_transform(chamfer_dist)

In [None]:
# Plot only the first two MDS coordinates for visualization, one colour per
# test sample (each sample contributed 32 consecutive rows).
for group in chamfer_points[:, :2].reshape(len(seq_gen_collection), 32, 2):
    xs, ys = group.T
    plt.scatter(xs, ys)
plt.legend([os.path.basename(s) for s in labels], loc='center left', bbox_to_anchor=(1., 0.5))
plt.title("Bad Pre-Embedding")

## KMeans Plotting

In [None]:
from sklearn.cluster import KMeans

In [None]:
# KMeans on the SetBERT class tokens, for comparison with the Chamfer-distance clustering.
# One cluster per test sample. Explicit n_init/random_state make results reproducible
# and independent of scikit-learn's version-dependent n_init default.
kmeans = KMeans(n_clusters=len(seq_gen_collection), n_init=10, random_state=0)
labels = kmeans.fit_predict(class_tokens)
labels

In [None]:
# KMeans on the Chamfer MDS points, one cluster per test sample, configured
# identically to the SetBERT clustering so the two are comparable.
kmeans = KMeans(n_clusters=len(seq_gen_collection), n_init=10, random_state=0)
chamfer_labels = kmeans.fit_predict(chamfer_points)
chamfer_labels

In [None]:
# Distinct KMeans cluster ids for the SetBERT embeddings.
unique_labels = np.unique(labels)

In [None]:
# Scatter the first two MDS coordinates of the SetBERT embeddings, coloured by KMeans cluster.
for cluster_id in unique_labels:
    in_cluster = labels == cluster_id
    plt.scatter(points[in_cluster, 0], points[in_cluster, 1], label=cluster_id)
plt.legend(loc='center left', bbox_to_anchor=(1., 0.5))
plt.title("KMeans SetBERT")

In [None]:
# Distinct KMeans cluster ids for the Chamfer embeddings.
unique_labels = np.unique(chamfer_labels)

In [None]:
# Scatter the first two MDS coordinates of the Chamfer embeddings, coloured by KMeans cluster.
for cluster_id in unique_labels:
    in_cluster = chamfer_labels == cluster_id
    plt.scatter(chamfer_points[in_cluster, 0], chamfer_points[in_cluster, 1], label=cluster_id)
plt.legend(loc='center left', bbox_to_anchor=(1., 0.5))
plt.title("KMeans Chamfer")

In [None]:
from sklearn.metrics import normalized_mutual_info_score

## SetBERT Mutual Information Score:

In [None]:
# NMI between each set's true sample (32 consecutive sets per sample) and its SetBERT KMeans cluster.
normalized_mutual_info_score(np.repeat(np.arange(len(seq_gen_collection)), 32), labels)

## Chamfer Mutual Information Score:

In [None]:
# NMI between each set's true sample (32 consecutive sets per sample) and its Chamfer KMeans cluster.
normalized_mutual_info_score(np.repeat(np.arange(len(seq_gen_collection)), 32), chamfer_labels)