In [None]:
import os
import sys

In [None]:
# Silence TensorFlow's C++ INFO and WARNING log messages.
# This is set before TensorFlow is imported in a later cell.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

In [None]:
# Make the sibling project checkouts importable. The paths are relative, so they
# are resolved against the kernel's working directory.
for repo_dir in ("../../deep-learning-dna", "../../settransformer"):
    sys.path.append(repo_dir)

In [None]:
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import display
import math
import string

import settransformer as stf
from common.models import dnabert
from common import dna
from lmdbm import Lmdb
from common.data import DnaSequenceGenerator, DnaLabelType, DnaSampleGenerator, find_dbs
import wandb

import tf_utils as tfu

In [None]:
# Presumably selects GPU 0 through the project's tf_utils helper -- confirm.
# NOTE(review): `strategy` is not referenced again in the visible cells (the model
# is not built inside a strategy scope); presumably the call is kept for a
# device-configuration side effect -- verify against tf_utils.
strategy = tfu.strategy.gpu(0)

---
# Load Data

In [None]:
# Download the pretrained DNABERT artifact from Weights & Biases and load it.
api = wandb.Api()
pretrain_artifact = api.artifact("sirdavidludwig/deep-learning-dna/dnabert-pretrain-ablation-dim:8dim")
model_path = pretrain_artifact.download()
pretrained_model = dnabert.DnaBertModel.load(model_path)
# Display the loaded model object as the cell output.
pretrained_model

In [None]:
# Download the DNA sample artifact and list the databases in its training split.
# NOTE(review): the download target is a hard-coded absolute path; it must be
# writable on the machine running this notebook.
samples_artifact = api.artifact("sirdavidludwig/nachusa-dna/dnasamples:v1")
dataset_path = samples_artifact.download('/data/dna_samples:v1')
train_dir = dataset_path + '/train'
samples = find_dbs(train_dir)
# Show one entry to confirm the listing worked.
samples[13]

---
# Create Dataset

In [None]:
# Settings for the batch generator.
subsample_length = 400
sequence_length = 150
kmer = 3
batch_size = 60
batches_per_epoch = 40
augument = True
labels = DnaLabelType.SampleIds

# Only the first five sample files are handed to the generator.
dataset = DnaSampleGenerator(
    samples=samples[:5],
    subsample_length=subsample_length,
    sequence_length=sequence_length,
    kmer=kmer,
    batch_size=batch_size,
    batches_per_epoch=batches_per_epoch,
    augment=augument,
    labels=labels,
)

In [None]:
# Sanity check: fetch the first item from the generator and display it.
dataset[0]

In [None]:
# Number of sample files held by the generator; reused further down as the
# model's output size (one output unit per sample file).
max_files = len(dataset.samples)
max_files

---
# Create Embeddings

In [None]:
#Create 8 dimensional embeddings
# Wrap the pretrained model's base in an encoder model, then freeze it so its
# weights are not updated while the downstream model trains.
pretrained_encoder= dnabert.DnaBertEncoderModel(pretrained_model.base)
pretrained_encoder.trainable = False

In [None]:
class Create_Embeddings(keras.layers.Layer):
    """Embed every sequence of every subsample with a fixed encoder.

    The input is flattened to a 2-D batch of sequences, pushed through
    ``encoder`` in sub-batches, and the encoder outputs are regrouped per
    subsample. Encoder outputs are wrapped in ``tf.stop_gradient``, so no
    gradient flows back into the encoder through this layer.

    Args:
        encoder: Callable model applied to 2-D slices of the flattened input.
            Presumably returns one embedding vector per input row -- confirm
            against the encoder implementation.
        **kwargs: Standard ``keras.layers.Layer`` arguments (``name``,
            ``dtype``, ...), forwarded to the base class.
    """

    def __init__(self, encoder, **kwargs):
        super(Create_Embeddings, self).__init__(**kwargs)
        self.encoder = encoder

    def subbatch_predict(self, model, batch, subbatch_size, concat=lambda old, new: tf.concat((old, new), axis=0)):
        """Run ``model`` over ``batch`` in consecutive chunks and merge the outputs.

        Args:
            model: Callable applied to each chunk ``batch[i:i + subbatch_size]``.
            batch: Tensor whose first axis is split into chunks.
            subbatch_size: Number of rows per chunk (int or scalar tensor).
            concat: Function ``(accumulated, new) -> merged`` used to combine
                chunk outputs; defaults to concatenation along axis 0.

        Returns:
            The merged outputs of all chunks.

        Note:
            The first chunk is always evaluated, before the loop starts, to
            seed the accumulator -- even when ``batch`` has no rows.
        """
        def predict(i, result=None):
            n = i + subbatch_size
            # Gradients are blocked so the wrapped model is never updated from here.
            pred = tf.stop_gradient(model(batch[i:n]))
            if result is None:
                return [n, pred]
            return [n, concat(result, pred)]

        # Seed the loop variables with the first chunk's output.
        i, result = predict(0)
        batch_size = tf.shape(batch)[0]
        # One chunk at a time (parallel_iterations=1).
        i, result = tf.while_loop(
            cond=lambda i, _: i < batch_size,
            body=predict,
            loop_vars=[i, result],
            parallel_iterations=1)

        return result

    def modify_data_for_input(self, data):
        """Flatten ``data``, encode it chunk by chunk, and regroup per subsample.

        Args:
            data: Tensor whose first two axes are (batch, subsample); all
                remaining axes are flattened into one before encoding.

        Returns:
            Tensor of shape ``(batch, subsample, features)`` holding the
            encoder outputs.
        """
        batch_size = tf.shape(data)[0]
        subsample_size = tf.shape(data)[1]
        flat_data = tf.reshape(data, (batch_size*subsample_size, -1))
        # Each chunk handed to the encoder is exactly one subsample.
        encoded = self.subbatch_predict(self.encoder, flat_data, subsample_size)

        # Reshaping with dynamic leading dims and -1 makes the last axis
        # statically unknown, which prevents layers that need a defined feature
        # size (e.g. Dense) from building in graph mode. Keep the width static
        # when the encoder output provides it; otherwise fall back to -1.
        feature_dim = -1
        if encoded.shape.rank == 2:
            static_dim = tf.compat.dimension_value(encoded.shape[-1])
            if static_dim is not None:
                feature_dim = static_dim
        return tf.reshape(encoded, (batch_size, subsample_size, feature_dim))

    def call(self, data):
        """Return the per-subsample embeddings for ``data``."""
        return self.modify_data_for_input(data)

---
# Set Transformer Class

In [None]:
class Set_Transformer(keras.Model):
    """Set classifier: sequence embeddings -> Dense -> ISAB stack -> PMA -> Dense.

    Args:
        embed_dim: Width used by the projection, attention and pooling layers.
        num_heads: Number of attention heads in each induced set attention block.
        stack: Number of induced set attention blocks applied in sequence.
        use_layernorm: Forwarded to the set-transformer blocks.
        pre_layernorm: Forwarded to the set-transformer blocks.
        use_keras_mha: Forwarded to the set-transformer blocks.
        seq_len: Accepted for interface compatibility; not used by the model.
        encoder: Model handed to ``Create_Embeddings`` to embed the sequences.
        output_shape: Number of units in the final Dense layer.
        num_induce: Number of inducing points per induced set attention block
            (defaults to 24, the value previously hard-coded).
        **kwargs: Standard ``keras.Model`` arguments (e.g. ``name``).
    """

    def __init__(self, embed_dim, num_heads, stack, use_layernorm, pre_layernorm, use_keras_mha, seq_len, encoder, output_shape, num_induce=24, **kwargs):
        super(Set_Transformer, self).__init__(**kwargs)

        self.embedding_layer = Create_Embeddings(encoder)
        # Projects the encoder embeddings to the attention width.
        self.linear_layer = keras.layers.Dense(embed_dim)

        self.isabs = [
            stf.InducedSetAttentionBlock(
                embed_dim=embed_dim,
                num_heads=num_heads,
                num_induce=num_induce,
                use_layernorm=use_layernorm,
                pre_layernorm=pre_layernorm,
                use_keras_mha=use_keras_mha)
            for _ in range(stack)
        ]

        # Pools the set with a single seed and a single head.
        self.pooling_layer = stf.PoolingByMultiHeadAttention(
            num_seeds=1,
            embed_dim=embed_dim,
            num_heads=1,
            use_layernorm=use_layernorm,
            pre_layernorm=pre_layernorm,
            use_keras_mha=use_keras_mha,
            is_final_block=True)

        # Presumably drops the singleton seed axis left by the pooling layer -- confirm.
        self.reshape_layer = keras.layers.Reshape((embed_dim,))

        # No activation: the model outputs unnormalised scores.
        self.output_layer = keras.layers.Dense(output_shape)

    def call(self, data):
        """Map a batch of subsamples to ``output_shape`` scores per subsample.

        Args:
            data: Batch passed straight to ``Create_Embeddings``.

        Returns:
            Output of the final Dense layer (no activation applied).
        """
        embeddings = self.embedding_layer(data)

        attention = self.linear_layer(embeddings)
        for isab in self.isabs:
            attention = isab(attention)

        pooling = self.pooling_layer(attention)
        reshape = self.reshape_layer(pooling)
        output = self.output_layer(reshape)

        return output

---
# Create Model

In [None]:
#Hyperparameters
# Width of the Dense projection and of the attention/pooling blocks.
embed_dim = 32
# Attention heads per induced set attention block.
num_heads = 4
# Number of stacked induced set attention blocks.
stack = 4
use_layernorm = True
pre_layernorm = True
use_keras_mha = True
# Equals sequence_length - kmer + 1 for the generator settings (150 - 3 + 1);
# presumably the number of k-mer tokens per sequence -- confirm. It is passed to
# Set_Transformer, which does not use it.
seq_len = 148
encoder = pretrained_encoder
# One output unit per sample file.
output_shape = max_files
# NOTE(review): this value is overwritten with 200 in a later cell before it is
# ever used; the first fit call uses a literal epoch count instead.
epochs = 2000

In [None]:
# Build the classifier from the hyperparameters defined above.
model = Set_Transformer(
    embed_dim=embed_dim,
    num_heads=num_heads,
    stack=stack,
    use_layernorm=use_layernorm,
    pre_layernorm=pre_layernorm,
    use_keras_mha=use_keras_mha,
    seq_len=seq_len,
    encoder=encoder,
    output_shape=output_shape,
)

# The final Dense layer has no activation, so the loss is computed from logits.
model.compile(
    optimizer=keras.optimizers.Adam(1e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.sparse_categorical_accuracy],
)

In [None]:
# Shape of the first element of the first batch (presumably the model inputs -- confirm).
dataset[0][0].shape

In [None]:
#keras.losses.SparseCategoricalCrossentropy(from_logits=True)(dataset[0][1], model(dataset[0][0]))

In [None]:
#keras.metrics.SparseCategoricalAccuracy()(dataset[0][1], model(dataset[0][0]))

In [None]:
# Short initial training run (20 epochs); the returned History object is not kept.
# A later fit call on this same `model` continues training from the weights reached here.
model.fit(dataset, epochs = 20)

In [None]:
# print('Max:', np.amax(prediction))
# print(np.where(prediction == np.amax(prediction)))

In [None]:
# Epoch count for the main training run; replaces the value set in the hyperparameter cell.
epochs = 200

In [None]:
# Main training run; the History object is kept for the accuracy/loss plots below.
history = model.fit(dataset, epochs=epochs, verbose=1)

In [None]:
# Training curves: accuracy in the top panel, loss in the bottom panel.
fig, (acc_ax, loss_ax) = plt.subplots(2, 1)

acc_ax.plot(history.history['sparse_categorical_accuracy'])
acc_ax.set_title('model accuracy')
acc_ax.set_ylabel('accuracy')
acc_ax.set_xlabel('epoch')

loss_ax.plot(history.history['loss'])
loss_ax.set_title('model loss')
loss_ax.set_ylabel('loss')
loss_ax.set_xlabel('epoch')

fig.tight_layout()
plt.show()