In [1]:
# Opening invocation printed at the top of the notebook.
print("Bismillah Hirrahamaa Nirraheem", end="\n")

Bismillah Hirrahamaa Nirraheem


In [2]:
import os,sys,warnings,time,re,math
from IPython.display import clear_output
warnings.filterwarnings("ignore")
os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
from concurrent.futures import ThreadPoolExecutor,ProcessPoolExecutor,wait
import numpy as np
import pandas as pd
from typing import Literal
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
pio.templates.default = "plotly_dark"
from einops import rearrange
import keras
import tensorflow as tf
import torch
from torch import nn
tf.get_logger().setLevel("ERROR")
%matplotlib inline
clear_output()

In [3]:
# Pick a distribution strategy: TPU if a TPU cluster resolves, otherwise the
# first GPU, otherwise the CPU. `device` records the choice and `details` holds
# the backend's device description (None on TPU).
try:
    tpu_cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
    # No TPU reachable from this runtime.
    tpu_cluster = None

if tpu_cluster is not None:
    # The cluster must be connected before the TPU system is initialized.
    tf.config.experimental_connect_to_cluster(tpu_cluster)
    tf.tpu.experimental.initialize_tpu_system(tpu_cluster)
    strategy = tf.distribute.TPUStrategy(tpu_cluster)
    device = "TPU"
    details = None
elif tf.config.list_physical_devices("GPU"):
    # `elif` so a resolved TPU strategy is not overwritten below.
    strategy = tf.distribute.OneDeviceStrategy("/gpu:0")
    device = "GPU"
    details = tf.config.experimental.get_device_details(tf.config.list_physical_devices("GPU")[0])
else:
    strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
    device = "CPU"
    details = tf.config.experimental.get_device_details(tf.config.list_physical_devices("CPU")[0])

# Data Loading

In [4]:
# Shard-level split of the TFRecords: the first N_TRAIN_FILES shards train, the
# next N_VAL_FILES validate, the rest test. The glob result is sorted so the
# split does not depend on the order the filesystem happens to list files in.
PATH = "gs://stanfordrna/ribo/train_*.tfrecord"
N_TRAIN_FILES, N_VAL_FILES = 150, 10

total_files = sorted(tf.io.gfile.glob(PATH))
assert len(total_files) > N_TRAIN_FILES + N_VAL_FILES, (
    f"expected more than {N_TRAIN_FILES + N_VAL_FILES} shards under {PATH}, found {len(total_files)}"
)
train_files = total_files[:N_TRAIN_FILES]
val_files = total_files[N_TRAIN_FILES:N_TRAIN_FILES + N_VAL_FILES]
test_files = total_files[N_TRAIN_FILES + N_VAL_FILES:]
train_raw_ds = tf.data.TFRecordDataset(train_files,compression_type="GZIP")
val_raw_ds = tf.data.TFRecordDataset(val_files,compression_type="GZIP")
test_raw_ds = tf.data.TFRecordDataset(test_files,compression_type="GZIP")

In [5]:
# Parsing spec for one serialized record. Every field is a bytes feature;
# scalar-valued fields use a fixed-length slot, array-valued ones (flagged True)
# a variable-length one. Listed in the same order as the record schema.
rna_feature = {
    name: (tf.io.VarLenFeature(tf.string) if is_array else tf.io.FixedLenFeature([], tf.string))
    for name, is_array in (
        ("seq_id", False),
        ("seq", True),
        ("dataset_name_2a3", False),
        ("dataset_name_dms", False),
        ("reads_2a3", False),
        ("reads_dms", False),
        ("signal_to_noise_2a3", False),
        ("signal_to_noise_dms", False),
        ("reactivity_2a3", True),
        ("reactivity_dms", True),
        ("reactivity_error_2a3", True),
        ("reactivity_error_dms", True),
        ("sn_filter_2a3", False),
        ("sn_filter_dms", False),
        ("length", False),
        ("bpp_matrix", True),
        ("bracket_seq", True),
    )
}

def rna_example(example):
    """Decode one serialized RNA record into a dict of tensors.

    The record is parsed with ``rna_feature``; each field's bytes are a
    serialized TensorProto and are decoded with ``tf.io.parse_tensor``.
    String fields come back as ``tf.string``, everything else as ``tf.float32``.
    """
    parsed = tf.io.parse_single_example(example, rna_feature)

    # Fixed-length fields: the bytes value is the TensorProto itself.
    scalar_dtypes = {
        "seq_id": tf.string,
        "reads_2a3": tf.float32,
        "reads_dms": tf.float32,
        "sn_filter_2a3": tf.float32,
        "sn_filter_dms": tf.float32,
        "dataset_name_2a3": tf.string,
        "dataset_name_dms": tf.string,
        "signal_to_noise_2a3": tf.float32,
        "signal_to_noise_dms": tf.float32,
        "length": tf.float32,
    }
    for name, dtype in scalar_dtypes.items():
        parsed[name] = tf.io.parse_tensor(parsed[name], out_type=dtype)

    # Variable-length fields arrive as a SparseTensor of bytes; its first entry
    # holds the serialized float32 array.
    array_names = (
        "seq",
        "reactivity_2a3",
        "reactivity_dms",
        "reactivity_error_2a3",
        "reactivity_error_dms",
        "bpp_matrix",
        "bracket_seq",
    )
    for name in array_names:
        first_entry = tf.sparse.to_dense(parsed[name])[0]
        parsed[name] = tf.io.parse_tensor(first_entry, out_type=tf.float32)

    return parsed

# Decode every split with the same parser, in parallel.
train_modified_ds, val_modified_ds, test_modified_ds = (
    raw_ds.map(rna_example, num_parallel_calls=tf.data.AUTOTUNE)
    for raw_ds in (train_raw_ds, val_raw_ds, test_raw_ds)
)

In [6]:
single_example = train_modified_ds.take(1).get_single_element()

# Print every decoded field in schema order: scalars in full, 1-D arrays as a
# 20-element preview plus shape, 2-D arrays (the bpp matrix) as a 5x5 corner
# plus shape.
for key in (
    "seq_id", "seq", "dataset_name_2a3", "dataset_name_dms",
    "reads_2a3", "reads_dms", "signal_to_noise_2a3", "signal_to_noise_dms",
    "reactivity_2a3", "reactivity_dms", "reactivity_error_2a3", "reactivity_error_dms",
    "sn_filter_2a3", "sn_filter_dms", "length", "bpp_matrix", "bracket_seq",
):
    value = single_example[key].numpy()
    if np.ndim(value) == 0:
        print(f"{key} : {value}")
    elif np.ndim(value) == 1:
        print(f"{key} : {value[:20]}")
        print(f"{key} shape : {value.shape}")
    else:
        print(f"{key} : {value[:5, :5]}")
        print(f"{key} shape : {value.shape}")
    print("\n\n")

id  :  b'8cdfeef009ea'



seq  :  [3. 3. 3. 0. 0. 2. 3. 0. 2. 4. 2. 3. 0. 3. 4. 0. 3. 0. 3. 4.]
seq shape : (170,)



dataset_name_2a3  :  b'15k_2A3'



dataset_name_dms  :  b'15k_DMS'



reads_2a3 :  2343.0



reads_dms :  1668.0



signal_to_noise_2a3 :  0.944



signal_to_noise_dms :  0.972



reactivity_2a3  :  [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan]
shape of reactivity_2a3 : (206,)



reactivity_dms  :  [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan]
shape of reactivity_dms : (206,)



reactivity_error_2a3 :  [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan]
shape of reactivity_error_2a3 (206,)



reactivity_error_dms :  [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan]
shape of reactivity_error_dms (206,)



sn_filter_2a3 :  0.0



sn_filter_dms :  0.0



Lenght : 170.0



bpp_matrix : [[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]

# Data Processing

In [7]:
# Token ids fed to the model. Id 0 is padding (masked by the embedding).
# convert_and_pad adds 1 to the raw nucleotide ids stored in the TFRecords, so
# the nucleotide entries are raw id + 1. Raw encoding assumed to be A=0, C=2,
# G=3, U=4: the sample printed above starts 3 3 3 0 0 2 3 0 2 4 ..., which
# reads as GGGAACGACU under that mapping -- TODO confirm against the writer.
# Raw id 1 does not appear in that sample; it is kept as RESERVED.
# START/END must not collide with any nucleotide id, and len(seq_map) must be
# one more than the largest id because it is used as the embedding vocab size.
seq_map = {"A":1,"RESERVED":2,"C":3,"G":4,"U":5,"START":6,"END":7,"EMPTY":0}
bracket_map = {"(":1,")":2,"[":3,"]":4,"{":5,"}":6,"<":7,">":8,".":9,"START":10,"END":11,"EMPTY":0}

def convert_and_pad(ex,Lmax:int=206,shift=True,sn_filter:bool=True):
    """Turn one decoded example into fixed-length model inputs and targets.

    Every per-position tensor is laid out on an axis of length ``Lmax + 2``:
    ``offset`` leading pad slots, one START slot, the ``n`` nucleotides, one
    END slot, then trailing pad slots. ``offset`` is drawn uniformly from
    ``[0, Lmax - n]`` when ``shift`` is truthy (position augmentation) and is
    0 otherwise.

    Returns:
        ``((seq, brac, mask, sn, bppm), (reac, mask))`` when ``sn_filter`` is
        true, else ``((seq, brac, mask, bppm), (reac, mask))``. ``mask`` is
        True exactly at nucleotide slots; ``reac`` stacks the 2A3 and DMS
        reactivities on the last axis and is NaN elsewhere; ``sn`` is True when
        both sn_filter values are positive; ``bppm`` is the ``n x n`` corner of
        the bpp matrix zero-padded to ``(Lmax + 2, Lmax + 2)``.
    """
    n = tf.shape(ex["seq"], tf.int32)[0]

    if shift:
        offset = tf.random.uniform(shape=[1], minval=0, maxval=Lmax - n + 1, dtype=tf.int32)[0]
    else:
        offset = 0

    # Pad slots before/after the nucleotides (START and END included).
    before = offset + 1
    after = Lmax + 1 - n - offset

    def _frame(tokens, start, end, empty):
        # Wrap in START/END, then pad out to Lmax + 2 with `empty`.
        tokens = tf.pad(tokens, [[1, 0]], constant_values=start)
        tokens = tf.pad(tokens, [[0, 1]], constant_values=end)
        return tf.pad(tokens, [[offset, after - 1]], constant_values=empty)

    nucleotides = ex["seq"] + 1
    mask = tf.math.greater(tf.pad(nucleotides, [[before, after]]), 0)
    seq = _frame(nucleotides, seq_map["START"], seq_map["END"], 0)
    brac = _frame(ex["bracket_seq"] + 1, bracket_map["START"], bracket_map["END"], bracket_map["EMPTY"])

    reac = tf.stack([ex["reactivity_2a3"][:n], ex["reactivity_dms"][:n]], axis=-1)
    reac = tf.pad(reac, [[before, after], [0, 0]], constant_values=np.nan)

    sn = (ex["sn_filter_2a3"] > 0) & (ex["sn_filter_dms"] > 0)

    bppm = tf.pad(ex["bpp_matrix"][:n, :n], [[before, after], [before, after]])

    inputs = (seq, brac, mask, sn, bppm) if sn_filter else (seq, brac, mask, bppm)
    return inputs, (reac, mask)


# Global batch size: 16 examples per replica on TPU, a flat 32 otherwise.
if device == "TPU":
    BATCH_SIZE = 16 * strategy.num_replicas_in_sync
else:
    BATCH_SIZE = 32


def create_train_ds(dataset,batch_size:int=BATCH_SIZE):
    """Build the infinite, shuffled, batched training pipeline.

    Applies ``convert_and_pad`` with its defaults (random position shift on),
    repeats forever, shuffles with a 20000-element buffer, batches and prefetches.
    """
    # NOTE(review): the shuffle buffer holds 20000 already-padded examples, each
    # carrying a (Lmax+2) x (Lmax+2) float32 bpp matrix -- roughly 3.5 GB at the
    # default Lmax=206. Consider shuffling before the map or a smaller buffer.
    return (
        dataset
        .map(convert_and_pad, num_parallel_calls=tf.data.AUTOTUNE)
        .repeat()
        .shuffle(20000)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )


def create_val_ds(dataset,batch_size:int=BATCH_SIZE,shift:bool=False):
    """Build the batched, cached validation pipeline.

    Args:
        dataset: decoded examples (output of ``rna_example``).
        batch_size: examples per batch.
        shift: forwarded to ``convert_and_pad``. Off by default so every
            evaluation sees the same, deterministic padding; with the random
            shift on, ``cache()`` would freeze whatever offsets the first pass
            drew, and metrics would vary from session to session.
    """
    dataset = dataset.map(
        lambda ex: convert_and_pad(ex, shift=shift),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    dataset = dataset.batch(batch_size)
    dataset = dataset.cache()
    return dataset.prefetch(tf.data.AUTOTUNE)

def create_test_ds(dataset,batch_size:int=BATCH_SIZE,shift:bool=False):
    """Build the batched test pipeline.

    Args:
        dataset: decoded examples (output of ``rna_example``).
        batch_size: examples per batch.
        shift: forwarded to ``convert_and_pad``. Off by default so test inputs
            are deterministic; with it on, every iteration over the dataset
            draws new random offsets.
    """
    dataset = dataset.map(
        lambda ex: convert_and_pad(ex, shift=shift),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(tf.data.AUTOTUNE)


# Materialize the three pipelines with the shared global batch size.
train_ds, val_ds, test_ds = (
    create_train_ds(train_modified_ds, batch_size=BATCH_SIZE),
    create_val_ds(val_modified_ds, batch_size=BATCH_SIZE),
    create_test_ds(test_modified_ds, batch_size=BATCH_SIZE),
)

In [8]:
# Pull one training batch to drive the layer-by-layer shape checks below.
X, y = train_ds.take(1).get_single_element()
seq_input, bracket_input, mask, sn_filter, bpp_matrix = X
reactivity = y[0]

# Custom Metrics and Losses

In [9]:
class CustomMetric2A3(keras.metrics.Metric):
    """Masked, clipped mean absolute error on the 2A3 reactivity channel.

    ``y_true`` is the ``(reactivity, mask)`` pair built by ``convert_and_pad``:
    channel 0 of ``y_true[0]`` is the 2A3 target and ``y_true[1]`` marks the
    nucleotide positions. Positions outside the mask or with a NaN target are
    skipped; remaining targets and predictions are clipped to [0, 1].

    ``result()`` is the MAE over every valid position seen since the last
    ``reset_state()`` (a running mean, not a running sum of batch means).
    """

    def __init__(self,**kwargs):
        super(CustomMetric2A3,self).__init__(**kwargs)
        # Running sum of absolute errors and the number of positions it covers.
        self.total = self.add_weight(name="mae_2a3_metric",initializer="zeros")
        self.count = self.add_weight(name="count_2a3_metric",initializer="zeros")

    def update_state(self,y_true,y_pred, **kwargs):
        # Boolean indexing needs a bool mask; it may not arrive as one.
        mask = tf.cast(y_true[1], tf.bool)
        target = y_true[0][..., 0]
        pred = tf.cast(y_pred[..., 0], target.dtype)
        # Drop NaN targets before clipping so they never reach the sum.
        valid = mask & ~tf.math.is_nan(target)
        target = tf.clip_by_value(target[valid], clip_value_min=0.0, clip_value_max=1.0)
        pred = tf.clip_by_value(pred[valid], clip_value_min=0.0, clip_value_max=1.0)
        self.total.assign_add(tf.reduce_sum(tf.abs(target - pred)))
        self.count.assign_add(tf.cast(tf.size(target), tf.float32))

    def result(self):
        # 0 until at least one valid position has been seen.
        return tf.math.divide_no_nan(self.total, self.count)

    def reset_state(self):
        self.total.assign(0.0)
        self.count.assign(0.0)

In [10]:
class CustomMetricDMS(keras.metrics.Metric):
    """Masked, clipped mean absolute error on the DMS reactivity channel.

    ``y_true`` is the ``(reactivity, mask)`` pair built by ``convert_and_pad``:
    channel 1 of ``y_true[0]`` is the DMS target and ``y_true[1]`` marks the
    nucleotide positions. Positions outside the mask or with a NaN target are
    skipped; remaining targets and predictions are clipped to [0, 1].

    ``result()`` is the MAE over every valid position seen since the last
    ``reset_state()`` (a running mean, not a running sum of batch means).
    """

    def __init__(self,**kwargs):
        super(CustomMetricDMS,self).__init__(**kwargs)
        # Running sum of absolute errors and the number of positions it covers.
        self.total = self.add_weight(name="mae_dms_metric",initializer="zeros")
        self.count = self.add_weight(name="count_dms_metric",initializer="zeros")

    def update_state(self,y_true,y_pred, **kwargs):
        # Boolean indexing needs a bool mask; it may not arrive as one.
        mask = tf.cast(y_true[1], tf.bool)
        target = y_true[0][..., 1]
        pred = tf.cast(y_pred[..., 1], target.dtype)
        # Drop NaN targets before clipping so they never reach the sum.
        valid = mask & ~tf.math.is_nan(target)
        target = tf.clip_by_value(target[valid], clip_value_min=0.0, clip_value_max=1.0)
        pred = tf.clip_by_value(pred[valid], clip_value_min=0.0, clip_value_max=1.0)
        self.total.assign_add(tf.reduce_sum(tf.abs(target - pred)))
        self.count.assign_add(tf.cast(tf.size(target), tf.float32))

    def result(self):
        # 0 until at least one valid position has been seen.
        return tf.math.divide_no_nan(self.total, self.count)

    def reset_state(self):
        self.total.assign(0.0)
        self.count.assign(0.0)

In [11]:
# Smoke-test both metrics on the sample batch. Predictions are drawn from
# [-1, 2) so the clip to [0, 1] is exercised on both sides.
cust_metric_2a3, cust_metric_dms = CustomMetric2A3(), CustomMetricDMS()
y_pred = tf.random.uniform(minval=-1, maxval=2, shape=y[0].shape)
cust_metric_2a3.update_state(y, y_pred)
cust_metric_dms.update_state(y, y_pred)
print(cust_metric_2a3.result().numpy())
print(cust_metric_dms.result().numpy())

0.4863567
0.4832784


In [12]:
class CustomLoss(keras.losses.Loss):
    """Masked, clipped mean absolute error over both reactivity channels.

    ``y_true`` is the ``(reactivity, mask)`` pair built by ``convert_and_pad``
    and ``y_pred`` has the same shape as ``y_true[0]``. An element contributes
    only if its position is inside the mask and its own target is not NaN.
    NaNs are filtered per element, so a position with a valid 2A3 target but
    a missing DMS target still contributes its 2A3 error. Targets and
    predictions are clipped to [0, 1]. Returns 0 instead of NaN when no
    element is valid.
    """

    def __init__(self,**kwargs):
        super(CustomLoss,self).__init__(**kwargs)

    def call(self,y_true,y_pred):
        # The mask may not arrive as bool (targets can be converted to the
        # loss dtype); boolean indexing needs bool.
        mask = tf.cast(y_true[1], tf.bool)
        target = y_true[0]
        pred = tf.cast(y_pred, target.dtype)
        # Broadcast the per-position mask across the channel axis.
        valid = mask[..., tf.newaxis] & ~tf.math.is_nan(target)
        target = tf.clip_by_value(target[valid], clip_value_min=0.0, clip_value_max=1.0)
        pred = tf.clip_by_value(pred[valid], clip_value_min=0.0, clip_value_max=1.0)
        abs_err = tf.abs(target - pred)
        return tf.math.divide_no_nan(tf.reduce_sum(abs_err), tf.cast(tf.size(abs_err), abs_err.dtype))

In [13]:
# Evaluate the loss on the sample batch with the same random predictions.
cust_loss = CustomLoss()
cust_loss(y_true=y, y_pred=y_pred).numpy()

0.48453128

# Transformer Model

## Encoding Block

In [14]:
class StaticPosEncoding(keras.Layer):
    """Token embedding plus a fixed sinusoidal position encoding.

    Args:
        vocab_size: number of token ids; id 0 is padding and is masked.
        d_model: embedding width; must be even.
        length: rows in the precomputed encoding table. Inputs with more
            positions than this will not broadcast against the table.
        casting: ``"concat"`` places all sines in the first half of the
            channels and all cosines in the second half. Any other value
            interleaves them (sin on even channels, cos on odd channels).
    """

    def __init__(self,
            vocab_size:int,
            d_model,
            length,
            casting,
            **kwargs):

        super(StaticPosEncoding,self).__init__(**kwargs)

        assert d_model%2 == 0,"Depth of model(d_model) should be even"

        half = d_model//2
        positions = np.arange(length)[:,np.newaxis]
        # Frequencies 1 / 10000^(i / half) for i in [0, half).
        freqs = 1/(10000**(np.arange(half)[np.newaxis,:]/half))
        angle_rads = positions * freqs                        # (length, half)
        if casting == "concat":
            encode = np.concatenate([np.sin(angle_rads),np.cos(angle_rads)],axis=-1)
        else:
            # The table is d_model wide so each strided slice holds `half` columns.
            encode = np.zeros(shape=[length,d_model])
            encode[:,::2] = np.sin(angle_rads)
            encode[:,1::2] = np.cos(angle_rads)

        self.encode = tf.cast(encode,tf.float32)
        # Scale embeddings by sqrt of the full width, not the halved value.
        self.factor = tf.sqrt(tf.cast(d_model,tf.float32))
        self.embedding_layer = keras.layers.Embedding(vocab_size,d_model,mask_zero=True)

    def compute_mask(self, *args,**kwargs):
        # Padding mask comes straight from the embedding (mask_zero=True).
        return self.embedding_layer.compute_mask(*args,**kwargs)

    def call(self,x):
        seq_l = tf.shape(x)[1]
        x = self.embedding_layer(x)
        # Cast constants to the embedding dtype so mixed precision works.
        x *= tf.cast(self.factor, x.dtype)
        return x + tf.cast(self.encode[tf.newaxis,:seq_l,:], x.dtype)

In [15]:
# Embed the sample batch's nucleotide tokens and check the output shape.
pos_encode = StaticPosEncoding(
    vocab_size=len(seq_map),
    d_model=512,
    length=2048,
    casting="concat",
)
pos_encode_output = pos_encode(seq_input)
pos_encode_output.shape

TensorShape([32, 208, 512])

## Attention

In [18]:
class ConvolutionLayer(keras.Layer):
    """Stack of 2-D convolutions turning a square map into per-channel maps.

    The rank-3 input ``x`` (e.g. a ``(batch, L, L)`` bpp matrix) gets a
    trailing channel axis. It then passes through ``n_layers`` Conv2D layers
    with ``n_filters`` filters each and is permuted to channels-first
    ``(batch, n_filters, L, L)``. Finally it is dropped out and normalized.

    Args:
        n_layers, n_filters, ksize, padding, bias: Conv2D stack settings
            (``bias`` is passed as ``use_bias``).
        drop_rate: dropout rate applied before normalization.
        norm: ``"layer_norm"`` selects LayerNormalization over the last axis;
            any other value selects BatchNormalization over the channel axis.
    """

    def __init__(self,
                 n_layers,
                 n_filters,
                 ksize,
                 padding,
                 bias,
                 drop_rate,
                 norm,
                 **kwargs):
        super(ConvolutionLayer,self).__init__(**kwargs)
        self.conv_layers = [keras.layers.Conv2D(filters=n_filters,kernel_size=ksize,padding=padding,use_bias=bias) for _ in range(n_layers)]
        self.perm = keras.layers.Permute(dims=[3,1,2])
        if norm == "layer_norm":
            self.norm = keras.layers.LayerNormalization()
        else:
            # Channels sit on axis 1 after the permute. The default axis=-1
            # would normalize per sequence position and tie the parameters to L.
            self.norm = keras.layers.BatchNormalization(axis=1)
        self.drop = keras.layers.Dropout(drop_rate)

    def call(self,x):
        x = tf.expand_dims(x,axis=-1)          # channels-last input for Conv2D
        for layer in self.conv_layers:
            x = layer(x)
        x = self.perm(x)                       # -> (batch, n_filters, L, L)
        return self.norm(self.drop(x))

In [19]:
# Run the sample bpp matrices through a 3-layer, 6-filter convolution stack.
conv = ConvolutionLayer(
    n_layers=3, n_filters=6, ksize=3, padding="same",
    bias=False, drop_rate=0.1, norm="layer_norm",
)
bpp_conv_ouput = conv(bpp_matrix)
bpp_conv_ouput.shape

TensorShape([32, 6, 208, 208])

In [20]:
# Disabled sanity check. It compares ConvolutionLayer's channels-last route
# (add trailing axis -> three Conv2D -> Permute([3,1,2])) against three
# channels_first Conv2D layers with the same kernels. The kernels of conv1..conv3
# are copied into con1..con3 after building those on dummy inputs, and the last
# line reports the fraction of exactly-equal elements.
# NOTE(review): exact float equality can fall below 1.0 even when both routes
# agree (different conv kernels / accumulation order); np.allclose would be a
# more robust comparison.
# a = tf.random.uniform(shape=[32,208,208])
# first_a = a[:,tf.newaxis,:,:]
# last_a = a[...,tf.newaxis]
# conv1 = keras.layers.Conv2D(filters=6,kernel_size=3,padding="same",data_format="channels_first",use_bias=False)
# conv2 = keras.layers.Conv2D(filters=6,kernel_size=3,padding="same",data_format="channels_first",use_bias=False)
# conv3 = keras.layers.Conv2D(filters=6,kernel_size=3,padding="same",data_format="channels_first",use_bias=False)
# first_a = conv1(first_a)
# first_a = conv2(first_a)
# first_a = conv3(first_a)
# weights1 = conv1.get_weights()
# weights2 = conv2.get_weights()
# weights3 = conv3.get_weights()
# con1 = keras.layers.Conv2D(filters=6,kernel_size=3,padding="same",use_bias=False)
# con2 = keras.layers.Conv2D(filters=6,kernel_size=3,padding="same",use_bias=False)
# con3 = keras.layers.Conv2D(filters=6,kernel_size=3,padding="same",use_bias=False)
# transopose_layer = keras.layers.Permute(dims=[3,1,2])
# _ = con1(tf.random.uniform(shape=[32,208,208,1]))
# _ = con2(tf.random.uniform(shape=[32,208,208,6]))
# _ = con3(tf.random.uniform(shape=[32,208,208,6]))
# con1.set_weights(weights1)
# con2.set_weights(weights2)
# con3.set_weights(weights3)
# last_a = con1(last_a)
# last_a = con2(last_a)
# last_a = con3(last_a)
# last_a = transopose_layer(last_a)
# tf.reduce_mean(tf.cast(tf.equal(first_a,last_a),tf.float32))

In [23]:
class ConvolutedAttention(keras.Layer):
    """Multi-head attention whose logits are biased by an additive per-head map.

    Scores are ``q k^T / sqrt(key_dim)`` plus the optional ``convoluted`` bias,
    masked if a mask is given, then passed through softmax and dropout. The
    attended values are projected back to ``key_dim`` channels and dropped
    out, then added to ``query`` and layer-normalized. Because of the residual,
    ``query``'s last dimension must equal ``key_dim``.

    Args:
        num_heads: number of attention heads.
        key_dim: per-head width of the Q/K/V projections, and the output width.
    """

    def __init__(self,
            num_heads,
            key_dim,
            **kwargs):

        super(ConvolutedAttention,self).__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.wq = keras.layers.Dense(num_heads*key_dim)
        self.wv = keras.layers.Dense(num_heads*key_dim)
        self.wk = keras.layers.Dense(num_heads*key_dim)
        self.dense = keras.layers.Dense(key_dim)
        self.first_drop = keras.layers.Dropout(0.1)
        self.last_drop = keras.layers.Dropout(0.1)
        self.layer_norm = keras.layers.LayerNormalization()
        self.factor = tf.math.rsqrt(tf.constant(key_dim,tf.float32))
        self.softmax = keras.layers.Softmax()
        self.add = keras.layers.Add()


    def call(self,query,key,value,convoluted=None,return_attention_scores=False,mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query, key, value: ``(batch, seq, features)`` tensors.
            convoluted: optional additive bias broadcastable to
                ``(batch, num_heads, q_len, k_len)``, e.g. ConvolutionLayer
                output with ``n_filters == num_heads``. Skipped when None.
            return_attention_scores: also return the post-dropout attention
                probabilities.
            mask: optional boolean mask, True = may be attended. A rank-2
                ``(batch, k_len)`` mask is a key padding mask; a rank-3
                ``(batch, q_len, k_len)`` mask is a full attention mask.
        """
        batch_size = tf.shape(query)[0]
        seq_len = tf.shape(query)[1]
        q = self.wq(query)
        k = self.wk(key)
        v = self.wv(value)

        # (batch, seq, heads*key_dim) -> (batch, heads, seq, key_dim)
        q = tf.transpose(tf.reshape(q,shape=[batch_size,-1,self.num_heads,self.key_dim]),perm=[0,2,1,3])
        k = tf.transpose(tf.reshape(k,shape=[batch_size,-1,self.num_heads,self.key_dim]),perm=[0,2,1,3])
        v = tf.transpose(tf.reshape(v,shape=[batch_size,-1,self.num_heads,self.key_dim]),perm=[0,2,1,3])

        attention_score = self.factor * tf.matmul(q,k,transpose_b=True)
        if convoluted is not None:
            attention_score = attention_score + convoluted
        if mask is not None:
            mask = tf.cast(mask, tf.bool)
            if mask.shape.rank == 2:
                mask = mask[:, tf.newaxis, tf.newaxis, :]
            else:
                mask = mask[:, tf.newaxis, :, :]
            # Disallowed positions get the most negative finite logit.
            attention_score = tf.where(
                mask,
                attention_score,
                tf.constant(attention_score.dtype.min, dtype=attention_score.dtype),
            )
        attention_score = self.first_drop(self.softmax(attention_score))

        # (batch, heads, seq, key_dim) -> (batch, seq, heads*key_dim)
        attention_out = tf.reshape(tf.transpose(tf.matmul(attention_score,v),perm=[0,2,1,3]),shape=[batch_size,seq_len,-1])
        attention_out = self.last_drop(self.dense(attention_out))
        attention_out = self.layer_norm(self.add([attention_out,query]))

        if return_attention_scores:
            return attention_out,attention_score
        return attention_out

In [24]:
# Self-attention over the embedded batch, biased by the bpp convolution maps.
conv_attn = ConvolutedAttention(num_heads=6, key_dim=512)
conv_attn_out = conv_attn(
    query=pos_encode_output,
    key=pos_encode_output,
    value=pos_encode_output,
    convoluted=bpp_conv_ouput,
)
conv_attn_out.shape

TensorShape([32, 208, 512])

## Feed Forward

In [25]:
class FeedForward(keras.Layer):
    """Position-wise feed-forward sublayer with residual add and normalization.

    Computes ``x -> Dense(feed_forward, activation) -> Dense(feed_out_unit)
    -> Dropout``, then ``norm(x + that)``. ``feed_out_unit`` must equal the
    input's last dimension for the residual add.

    Args:
        feed_out_unit: output width (must match the input width).
        feed_forward: hidden width.
        feed_forward_drop: dropout rate on the sublayer output.
        norm: ``"layer_norm"`` (default) for LayerNormalization, any other
            value for BatchNormalization -- same convention as ConvolutionLayer.
        activation: activation of the hidden Dense layer (default ``"relu"``).
            Without it the two Dense layers collapse into one linear map.
    """

    def __init__(self,
            feed_out_unit,
            feed_forward,
            feed_forward_drop,
            norm="layer_norm",
            activation="relu",
            **kwargs):

        super(FeedForward,self).__init__(**kwargs)
        self.dense_out = keras.layers.Dense(feed_out_unit)
        self.feed_forward = keras.layers.Dense(feed_forward, activation=activation)
        self.feed_forward_drop = keras.layers.Dropout(feed_forward_drop)
        self.norm = keras.layers.LayerNormalization() if norm == "layer_norm" else keras.layers.BatchNormalization()
        self.add = keras.layers.Add()

    def call(self,x):

        residual = x
        x = self.feed_forward(x)
        x = self.dense_out(x)
        # Dropout is applied once; applying the same layer twice compounds the rate.
        x = self.feed_forward_drop(x)
        return self.norm(self.add([residual,x]))

In [26]:
# Feed the attention output through a 2048-wide feed-forward sublayer.
feed_forward = FeedForward(
    feed_out_unit=512,
    feed_forward=2048,
    feed_forward_drop=0.2,
    norm="layer_norm",
)
feed_out = feed_forward(conv_attn_out)
feed_out.shape

TensorShape([32, 208, 512])

# Encoder Unit

In [44]:
class SequenceEncoderUint(keras.Model):
    """One encoder layer for the nucleotide stream.

    The steps are:
    1. ``bpp_matrix`` goes through ConvolutionLayer to produce one attention
       bias map per head.
    2. ConvolutedAttention runs self-attention over ``sequence_encode_out``
       with that bias.
    3. MultiHeadAttention cross-attends from that result to
       ``bracket_encode_out``.
    4. FeedForward produces the output.

    Like the other two sublayers, the cross-attention is wrapped in a residual
    Add + LayerNormalization. The ConvolutionLayer uses
    ``encoder_convolution_num_heads`` filters so its channel count matches the
    ConvolutedAttention heads.
    """

    def __init__(self,
            encoder_convolution_layers:int,
            encoder_convolution_ksize:int,
            encoder_convolution_padding:str,
            encoder_convolution_bias:bool,
            encoder_convolution_drop_rate:float,
            encoder_convolution_norm:str,
            encoder_convolution_num_heads:int,
            encoder_cross_attention_num_heads:int,
            encoder_key_dim:int,
            encoder_feed_out_units:int,
            encoder_feed_forward_units:int,
            encoder_feed_forward_drop_rate:float,
            **kwargs):

        super(SequenceEncoderUint,self).__init__(**kwargs)

        self.convolution_layer = ConvolutionLayer(
            n_layers=encoder_convolution_layers,
            n_filters=encoder_convolution_num_heads,
            ksize=encoder_convolution_ksize,
            padding=encoder_convolution_padding,
            bias=encoder_convolution_bias,
            drop_rate=encoder_convolution_drop_rate,
            norm=encoder_convolution_norm,
            )

        self.convoluted_attention = ConvolutedAttention(
            encoder_convolution_num_heads,
            encoder_key_dim
            )

        self.encoder_cross_attention = keras.layers.MultiHeadAttention(
            num_heads=encoder_cross_attention_num_heads,
            key_dim=encoder_key_dim
            )
        # Residual + norm around the cross-attention. Without it the
        # self-attention output only survives as the cross-attention query.
        self.cross_attention_add = keras.layers.Add()
        self.cross_attention_norm = keras.layers.LayerNormalization()

        self.feed_forward = FeedForward(
            encoder_feed_out_units,
            encoder_feed_forward_units,
            encoder_feed_forward_drop_rate
            )

    def call(self,sequence_encode_out,bpp_matrix,bracket_encode_out):

        convolution_out = self.convolution_layer(bpp_matrix)
        convolutedattention_out = self.convoluted_attention(
            query=sequence_encode_out,
            key=sequence_encode_out,
            value=sequence_encode_out,
            convoluted=convolution_out
            )
        crossattention_out = self.encoder_cross_attention(
            query=convolutedattention_out,
            key=bracket_encode_out,
            value=bracket_encode_out
            )
        crossattention_out = self.cross_attention_norm(
            self.cross_attention_add([convolutedattention_out, crossattention_out])
            )
        feed_out = self.feed_forward(crossattention_out)

        return feed_out

In [45]:
# Build one sequence-encoder layer and run it on random stand-ins that have
# the shapes produced upstream (batch 32, 208 positions, width 512).
seq_encoder = SequenceEncoderUint(
    encoder_convolution_layers=3,
    encoder_convolution_ksize=3,
    encoder_convolution_padding="same",
    encoder_convolution_bias=False,
    encoder_convolution_drop_rate=0.1,
    encoder_convolution_norm="layer_norm",
    encoder_convolution_num_heads=6,
    encoder_cross_attention_num_heads=6,
    encoder_key_dim=512,
    encoder_feed_out_units=512,
    encoder_feed_forward_units=2048,
    encoder_feed_forward_drop_rate=0.1,
)

seq_encode_out = tf.random.uniform(shape=[32, 208, 512])
bpp_mat = tf.random.uniform(shape=[32, 208, 208])
bracket_sequence_out = tf.random.uniform(shape=[32, 208, 512])
seq_encoder_out = seq_encoder(seq_encode_out, bpp_mat, bracket_sequence_out)
seq_encoder_out.shape

TensorShape([32, 208, 512])

In [48]:
class BracketEncoderUnit(keras.Model):
    """One encoder layer for the dot-bracket stream.

    ``bpp_matrix`` goes through ConvolutionLayer to produce one attention bias
    map per head. ConvolutedAttention then runs self-attention over
    ``bracket_encode_out`` with that bias, followed by FeedForward. The
    convolution uses ``encoder_convolution_num_heads`` filters so its channel
    count matches the attention heads.
    """

    def __init__(self,
            encoder_convolution_layers:int,
            encoder_convolution_ksize:int,
            encoder_convolution_padding:str,
            encoder_convolution_bias:bool,
            encoder_convolution_drop_rate:float,
            encoder_convolution_norm:str,
            encoder_convolution_num_heads:int,
            encoder_key_dim:int,
            encoder_feed_out_units:int,
            encoder_feed_forward_units:int,
            encoder_feed_forward_drop_rate:float,
            **kwargs):


        super(BracketEncoderUnit,self).__init__(**kwargs)

        self.convolution_layer = ConvolutionLayer(
            encoder_convolution_layers,
            encoder_convolution_num_heads,      # n_filters: one bias map per head
            encoder_convolution_ksize,
            encoder_convolution_padding,
            encoder_convolution_bias,
            encoder_convolution_drop_rate,
            encoder_convolution_norm
        )

        self.convoluted_attention = ConvolutedAttention(
            encoder_convolution_num_heads,
            encoder_key_dim
        )

        self.feed_forward = FeedForward(
            encoder_feed_out_units,
            encoder_feed_forward_units,
            encoder_feed_forward_drop_rate
        )

    def call(self,bracket_encode_out,bpp_matrix):

        convolution_out = self.convolution_layer(bpp_matrix)
        convoluted_attention_out = self.convoluted_attention(
            query = bracket_encode_out,
            key = bracket_encode_out,
            value = bracket_encode_out,
            convoluted = convolution_out
        )
        feed_out = self.feed_forward(convoluted_attention_out)

        return feed_out

In [50]:
# Build one bracket-encoder layer and run it on random stand-ins that have
# the shapes produced upstream (batch 32, 208 positions, width 512).
bracket_seq_encoder = BracketEncoderUnit(
    encoder_convolution_layers=3,
    encoder_convolution_ksize=3,
    encoder_convolution_padding="same",
    encoder_convolution_bias=False,
    encoder_convolution_drop_rate=0.1,
    encoder_convolution_norm="layer_norm",
    encoder_convolution_num_heads=6,
    encoder_key_dim=512,
    encoder_feed_out_units=512,
    encoder_feed_forward_units=2048,
    encoder_feed_forward_drop_rate=0.1,
)

bracket_encode_out = tf.random.uniform(shape=[32, 208, 512])
bpp_mat = tf.random.uniform(shape=[32, 208, 208])
bracket_seq_encoder_out = bracket_seq_encoder(bracket_encode_out, bpp_mat)
bracket_seq_encoder_out.shape

TensorShape([32, 208, 512])

In [None]:
class Transformer(keras.Model):

    def __init__(self,
            
            **kwargs):
