In [1]:
import os,sys,warnings,time
from IPython.display import clear_output
warnings.filterwarnings('ignore')
os.cpu_count()

96

In [2]:
!pip install -q -U plotly
!pip install google-cloud-storage
clear_output()

In [3]:
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import numpy as np
from typing import Literal
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
pio.templates.default = "plotly_dark"
import tensorflow as tf
from tensorflow import keras
from tensorflow.sparse import to_dense
from tensorflow.io import FixedLenFeature,VarLenFeature,parse_tensor,parse_single_example
from tensorflow.data import TFRecordDataset
from tensorflow import keras
from keras.utils import Progbar
tf.get_logger().setLevel("ERROR")
from google.cloud import storage
client = storage.Client("kaggle-406814")
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()
user_secrets.set_tensorflow_credential(user_credential)
import re,time,random,math
%matplotlib inline

In [4]:
try:
    tpu_cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
    print("Running on TPU ADDR: ",tpu_cluster.cluster_spec().as_dict())
    tf.config.experimental_connect_to_cluster(tpu_cluster)
    tf.tpu.experimental.initialize_tpu_system(tpu_cluster)
    strategy = tf.distribute.TPUStrategy(tpu_cluster)
    print(strategy)
    print(strategy.num_replicas_in_sync)
    tpu_present = True
except ValueError:
    print("Error: Not connected to TPU runtime using GPU")
    tpu_present = False

if not tpu_present:
    strategy = tf.distribute.OneDeviceStrategy("GPU")
    print(strategy)

Running on TPU ADDR:  {}
<tensorflow.python.distribute.tpu_strategy.TPUStrategyV2 object at 0x7c19e121d660>
8


In [5]:
PATH = "gs://stanfordrna/ribo/train_*.tfrecord"
total_files = tf.io.gfile.glob(PATH)
train_files = total_files[:25]
valid_files = total_files[25:28]
test_files = total_files[28:]
train_raw_ds = TFRecordDataset(train_files,compression_type="GZIP",num_parallel_reads=tf.data.AUTOTUNE)
valid_raw_ds = TFRecordDataset(valid_files,compression_type="GZIP",num_parallel_reads=tf.data.AUTOTUNE)
test_raw_ds = TFRecordDataset(test_files,compression_type="GZIP",num_parallel_reads=tf.data.AUTOTUNE)

In [6]:
# @title
rna_feature = dict(
    seq_id = FixedLenFeature([],tf.string),
    seq = VarLenFeature(tf.string),
    reads_2a3 = FixedLenFeature([],tf.string),
    reads_dms = FixedLenFeature([],tf.string),
    signal_to_noise_2a3 = FixedLenFeature([],tf.string),
    signal_to_noise_dms = FixedLenFeature([],tf.string),
    reactivity_2a3 = VarLenFeature(tf.string),
    reactivity_dms = VarLenFeature(tf.string),
    reactivity_error_2a3 = VarLenFeature(tf.string),
    reactivity_error_dms = VarLenFeature(tf.string),
    bpp_matrix = VarLenFeature(tf.string),
    bracket_seq = VarLenFeature(tf.string)
    )

def rna_example(example):
    example = parse_single_example(example, rna_feature)

    ### Dense Features
    example["seq_id"] = parse_tensor(example["seq_id"], out_type=tf.string)
    example["reads_2a3"] = parse_tensor(example["reads_2a3"], out_type=tf.float32)
    example["reads_dms"] = parse_tensor(example["reads_dms"], out_type=tf.float32)
    example["signal_to_noise_2a3"] = parse_tensor(example["signal_to_noise_2a3"], out_type=tf.float32)
    example["signal_to_noise_dms"] = parse_tensor(example["signal_to_noise_dms"], out_type=tf.float32)

    ### Sparse Features
    example["seq"] = parse_tensor(to_dense(example["seq"])[0], out_type=tf.float32)
    example["reactivity_2a3"] = parse_tensor(to_dense(example["reactivity_2a3"])[0], out_type=tf.float32)
    example["reactivity_dms"] = parse_tensor(to_dense(example["reactivity_dms"])[0], out_type=tf.float32)
    example["reactivity_error_2a3"] = parse_tensor(to_dense(example["reactivity_error_2a3"])[0], out_type=tf.float32)
    example["reactivity_error_dms"] = parse_tensor(to_dense(example["reactivity_error_dms"])[0], out_type=tf.float32)
    example["bpp_matrix"] = parse_tensor(to_dense(example["bpp_matrix"])[0], out_type=tf.float32)
    example["bracket_seq"] = parse_tensor(to_dense(example["bracket_seq"])[0], out_type=tf.float32)

    return example

train_modified_ds = train_raw_ds.map(rna_example,num_parallel_calls=tf.data.AUTOTUNE)
valid_modified_ds = valid_raw_ds.map(rna_example,num_parallel_calls=tf.data.AUTOTUNE)
test_modified_ds = test_raw_ds.map(rna_example,num_parallel_calls=tf.data.AUTOTUNE)

In [7]:
# @title
single_example = train_modified_ds.take(1).get_single_element()
print("id"," : ",single_example["seq_id"].numpy())
print("\n\n")
print("seq"," : ",single_example["seq"].numpy()[:20])
print("seq shape :",single_example["seq"].numpy().shape)
print("\n\n")
print("reads_2a3",": ",single_example["reads_2a3"].numpy())
print("\n\n")
print("reads_dms",": ",single_example["reads_dms"].numpy())
print("\n\n")
print("signal_to_noise_2a3",": ",single_example["signal_to_noise_2a3"].numpy())
print("\n\n")
print("signal_to_noise_dms",": ",single_example["signal_to_noise_dms"].numpy())
print("\n\n")
print("reactivity_2a3"," : ",single_example["reactivity_2a3"].numpy()[:20])
print("shape of reactivity_2a3 :",single_example["reactivity_2a3"].shape)
print("\n\n")
print("reactivity_dms"," : ",single_example["reactivity_dms"].numpy()[:20])
print("shape of reactivity_dms :",single_example["reactivity_dms"].shape)
print("\n\n")
print("reactivity_error_2a3",": ",single_example["reactivity_error_2a3"].numpy()[:20])
print("shape of reactivity_error_2a3",single_example["reactivity_error_2a3"].shape)
print("\n\n")
print("reactivity_error_dms",": ",single_example["reactivity_error_dms"].numpy()[:20])
print("shape of reactivity_error_dms",single_example["reactivity_error_dms"].shape)
print("\n\n")
print("bpp_matrix :",single_example["bpp_matrix"].numpy()[:5,:5])
print("bpp_matrix shape :",single_example["bpp_matrix"].numpy().shape)
print("\n\n")
print("bracket_sequence :",single_example["bracket_seq"].numpy()[:20])
print('bracket_sequence shape :',single_example["bracket_seq"].numpy().shape)

id  :  b'25ce8d5109cd'



seq  :  [3. 3. 3. 1. 1. 2. 3. 1. 2. 4. 2. 3. 1. 3. 4. 1. 3. 1. 3. 4.]
seq shape : (170,)



reads_2a3 :  4647.0



reads_dms :  1964.0



signal_to_noise_2a3 :  2.347



signal_to_noise_dms :  1.848



reactivity_2a3  :  [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan]
shape of reactivity_2a3 : (206,)



reactivity_dms  :  [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan]
shape of reactivity_dms : (206,)



reactivity_error_2a3 :  [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan]
shape of reactivity_error_2a3 (206,)



reactivity_error_dms :  [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan]
shape of reactivity_error_dms (206,)



bpp_matrix : [[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
bpp_matrix shape : (170, 170)



bracket_sequence : [1. 1. 9. 9. 9. 1. 1. 1. 1. 1. 1. 9. 9. 9. 9. 

# The

## Total number of rows in each dataset is: 821840

## Total number of rows in sn_filter > 1 is: 181267

### Total number of tfrecord files: 164

### Total nunber of instances per file: 5005

### Total number of steps per epoch: 6420

### Total number of steps per epoch for 3 test files: 117



## Example sizes of the input sequences
- [170, 177, 115, 155, 206]
- minimum of reactivity columns : -129.281
- maximum of reactivity columns :  129.281
- positions upto 26 and from 126 are all nans can be padded


In [9]:
K = keras.backend
seq_map = {"A":1,"C":2,"G":3,"U":4,"START":4,"END":5,"EMPTY":0}
bracket_map = {"(":1,")":2,"[":3,"]":4,"{":5,"}":6,"<":7,">":8,".":9,"START":10,"END":11,"EMPTY":0}

def convert_and_pad(ex):

    l = tf.shape(ex["seq"],tf.int32)[0]

    shift = tf.random.uniform(shape=[1],minval=0,maxval=206-l+1,dtype=tf.int32)[0]

    # Sequence Processing and Mask Processing
    seq = ex["seq"] + 1
    seq = tf.pad(seq,[[1,0]],constant_values=seq_map["START"])                               # seq_map["START"]
    seq = tf.pad(seq,[[0,1]],constant_values=seq_map["END"])                                 # seq_map["END"]
    seq = tf.pad(seq,[[shift,206-l-shift]])                                                 # seq_map["EMPTY"]

    # Bracket Processing
    brac = ex["bracket_seq"] + 1
    brac = tf.pad(brac,[[1,0]],constant_values=bracket_map["START"])                            # bracket_map["START"]
    brac = tf.pad(brac,[[0,1]],constant_values=bracket_map["END"])                              # bracket_map["END"]
    brac = tf.pad(brac,[[shift,206-l-shift]],constant_values=bracket_map["EMPTY"])             # bracket_map["EMPTY"]

    # Reactivity Processing
    reac = tf.stack([ex["reactivity_2a3"][:l],ex["reactivity_dms"][:l]],axis=-1)
    reac = tf.pad(reac,[[shift+1,206+1-l-shift],[0,0]],constant_values=np.nan)

    # BPPMatrix
    bppm = ex["bpp_matrix"][:l,:l]
    bppm = tf.pad(bppm,[[shift+1,206+1-l-shift],[shift+1,206+1-l-shift]])

    return (seq,brac,bppm),reac


def shape_set(X,y,batch_size):
    (seq,brac,bpp),reac = X,y
    seq.set_shape([batch_size,208])
    brac.set_shape([batch_size,208])
    bpp.set_shape([batch_size,208,208])
    reac.set_shape([batch_size,208,2])
    return (seq,brac,bpp),reac


def create_train_ds(dataset,batch_size):
    dataset = dataset.map(convert_and_pad,num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.ignore_errors()
    dataset = dataset.shuffle(10000)
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size,drop_remainder=True,num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.map(lambda X,y: shape_set(X,y,batch_size),num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.prefetch(tf.data.AUTOTUNE)


def create_val_ds(dataset,batch_size):
    dataset = dataset.map(convert_and_pad,num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size,drop_remainder=True,num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.map(lambda X,y: shape_set(X,y,batch_size),num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.cache()
    return dataset.prefetch(tf.data.AUTOTUNE)

def create_test_ds(dataset,batch_size):
    dataset = dataset.map(convert_and_pad,num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size,drop_remainder=True,num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.map(lambda X,y: shape_set(X,y,batch_size),num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.prefetch(tf.data.AUTOTUNE)


testing_ds = create_train_ds(train_modified_ds,batch_size=32)

In [10]:
X,y = testing_ds.take(1).get_single_element()

seq_input = X[0]
bracket_input = X[1]
bpp_matrix = X[2]
reactivity = y

In [11]:
print("sequence input shape: ",seq_input.shape)
print("bracket input :",bracket_input.shape)
print("bpp matrix shape :",bpp_matrix.shape)
print("reactivity shape :",reactivity.shape)

sequence input shape:  (32, 208)
bracket input : (32, 208)
bpp matrix shape : (32, 208, 208)
reactivity shape : (32, 208, 2)


# Model Trying

In [12]:
class StaticPosEncoding(keras.layers.Layer):

    def __init__(self,
            vocab_size:int,
            d_model,
            length,
            **kwargs):

        super(StaticPosEncoding,self).__init__(**kwargs)

        assert d_model%2 == 0,"Depth of model(d_model) should be even"

        d_model = d_model//2
        positions = np.arange(length)[:,np.newaxis]
        angles = np.arange(d_model)[np.newaxis,:]/d_model
        angles = 1/(10000**angles)
        angle_rads = positions * angles
        encode = tf.concat([tf.sin(angle_rads),tf.cos(angle_rads)],axis=-1)
        self.encode = tf.cast(encode,tf.float32)
        self.factor = tf.sqrt(tf.cast(d_model,tf.float32))
        self.embedding_layer = keras.layers.Embedding(vocab_size,d_model*2,mask_zero=True)

    def compute_mask(self, *args,**kwargs):
        return self.embedding_layer.compute_mask(*args,**kwargs)

    def call(self,x):
        seq_l = tf.shape(x)[1]
        x = self.embedding_layer(x)
        x *= self.factor
        return x + self.encode[tf.newaxis,:seq_l,:]

In [13]:
seq_pos_encode = StaticPosEncoding(vocab_size=len(seq_map),d_model=512,length=2048)
brac_pos_encode = StaticPosEncoding(vocab_size=len(bracket_map),d_model=512,length=2048)
seq_pos_encode_output = seq_pos_encode(X[0])
brac_pos_encode_output = brac_pos_encode(X[1])
print(seq_pos_encode_output.shape)
print(brac_pos_encode_output.shape)

(32, 208, 512)
(32, 208, 512)


In [14]:
class ConvolutionLayer(keras.layers.Layer):

    def __init__(self,
                 n_layers,
                 n_filters,
                 ksize,
                 drop_rate,
                 **kwargs):
        super(ConvolutionLayer,self).__init__(**kwargs)
        self.conv_layers = [keras.layers.Conv2D(filters=n_filters,kernel_size=ksize,padding="same",use_bias=False,activation="relu") for _ in range(n_layers)]
        self.perm = keras.layers.Permute(dims=[3,1,2])
        self.norm = keras.layers.LayerNormalization()
        self.drop = keras.layers.Dropout(drop_rate)

    def call(self,x):
        x = tf.expand_dims(x,axis=-1)
        for layer in self.conv_layers:
            x = layer(x)
        x = self.perm(x)
        return self.norm(self.drop(x))

In [15]:
conv = ConvolutionLayer(n_layers=3,n_filters=6,ksize=3,drop_rate=0.1)
bpp_conv_ouput = conv(bpp_matrix)
bpp_conv_ouput.shape

TensorShape([32, 6, 208, 208])

In [16]:
class ConvolutedAttention(keras.layers.Layer):

    def __init__(self,
            num_heads,
            key_dim,
            **kwargs):

        super(ConvolutedAttention,self).__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.wq = keras.layers.Dense(num_heads*key_dim)
        self.wv = keras.layers.Dense(num_heads*key_dim)
        self.wk = keras.layers.Dense(num_heads*key_dim)
        self.dense = keras.layers.Dense(key_dim)
        self.first_drop = keras.layers.Dropout(0.1)
        self.last_drop = keras.layers.Dropout(0.1)
        self.layer_norm = keras.layers.LayerNormalization()
        self.factor = tf.math.rsqrt(tf.constant(key_dim,tf.float32))
        self.softmax = keras.layers.Softmax()
        self.add = keras.layers.Add()


    def call(self,query,key,value,convoluted=None,return_attention_scores=False,mask=None):


        batch_size = tf.shape(query)[0]
        seq_len = tf.shape(query)[1]
        q = self.wq(query)
        k = self.wk(key)
        v = self.wv(value)

        q = tf.transpose(tf.reshape(q,shape=[batch_size,-1,self.num_heads,self.key_dim]),perm=[0,2,1,3])
        k = tf.transpose(tf.reshape(k,shape=[batch_size,-1,self.num_heads,self.key_dim]),perm=[0,2,1,3])
        v = tf.transpose(tf.reshape(v,shape=[batch_size,-1,self.num_heads,self.key_dim]),perm=[0,2,1,3])

        attention_score = self.factor * tf.matmul(q,k,transpose_b=True)
        attention_score = self.first_drop(self.softmax(attention_score + convoluted))

        attention_out = tf.reshape(tf.transpose(tf.matmul(attention_score,v),perm=[0,2,1,3]),shape=[batch_size,seq_len,-1])
        attention_out = self.last_drop(self.dense(attention_out))
        attention_out = self.layer_norm(self.add([attention_out,query]))

        if return_attention_scores:
            return attention_out,attention_score
        return attention_out

In [17]:
conv_attn = ConvolutedAttention(num_heads=6,key_dim=512)
conv_attn_out = conv_attn(query=seq_pos_encode_output,key=seq_pos_encode_output,value=seq_pos_encode_output,convoluted=bpp_conv_ouput)
conv_attn_out.shape

TensorShape([32, 208, 512])

In [18]:
class FeedForward(keras.layers.Layer):

    def __init__(self,
            feed_out_unit,
            feed_forward,
            feed_forward_drop,
            **kwargs):

        super(FeedForward,self).__init__(**kwargs)
        self.dense_out = keras.layers.Dense(feed_out_unit)
        self.feed_forward = keras.layers.Dense(feed_forward)
        self.feed_forward_drop = keras.layers.Dropout(feed_forward_drop)
        self.norm = keras.layers.LayerNormalization()
        self.add = keras.layers.Add()

    def call(self,x):

        x_copy = x
        x = self.feed_forward(x)
        x = self.dense_out(x)
        x = self.feed_forward_drop(x)
        x = self.feed_forward_drop(x)
        return self.norm(self.add([x_copy,x]))

In [19]:
feed_forward = FeedForward(feed_out_unit=512,feed_forward=2048,feed_forward_drop=0.2)
feed_out = feed_forward(conv_attn_out)
feed_out.shape

TensorShape([32, 208, 512])

In [20]:
class BracketEncoderUnit(keras.models.Model):

    def __init__(self,
            encoder_convolution_layers:int,
            encoder_convolution_ksize:int,
            encoder_convolution_drop_rate:float,
            encoder_convolution_num_heads:int,
            encoder_key_dim:int,
            encoder_feed_out_units:int,
            encoder_feed_forward_units:int,
            encoder_feed_forward_drop_rate:float,
            **kwargs):


        super(BracketEncoderUnit,self).__init__(**kwargs)

        self.convolution_layer = ConvolutionLayer(
            encoder_convolution_layers,
            encoder_convolution_num_heads,
            encoder_convolution_ksize,
            encoder_convolution_drop_rate,
        )

        self.convoluted_attention = ConvolutedAttention(
            encoder_convolution_num_heads,
            encoder_key_dim
        )

        self.feed_forward = FeedForward(
            encoder_feed_out_units,
            encoder_feed_forward_units,
            encoder_feed_forward_drop_rate
        )

    def call(self,bracket_encode_out,bpp_matrix):

        convolution_out = self.convolution_layer(bpp_matrix)
        convoluted_attention_out = self.convoluted_attention(
            query = bracket_encode_out,
            key = bracket_encode_out,
            value = bracket_encode_out,
            convoluted = convolution_out
        )
        feed_out = self.feed_forward(convoluted_attention_out)

        return feed_out

In [21]:
bracket_seq_encoder = BracketEncoderUnit(
    encoder_convolution_layers=3,
    encoder_convolution_ksize=3,
    encoder_convolution_drop_rate=0.1,
    encoder_convolution_num_heads=6,
    encoder_key_dim=512,
    encoder_feed_forward_drop_rate=0.1,
    encoder_feed_forward_units=2048,
    encoder_feed_out_units=512
)

bracket_seq_encoder_out = bracket_seq_encoder(brac_pos_encode_output,bpp_matrix)
bracket_seq_encoder_out.shape

TensorShape([32, 208, 512])

In [22]:
class SequenceEncoderUint(keras.models.Model):

    def __init__(self,
            encoder_convolution_layers:int,
            encoder_convolution_ksize:int,
            encoder_convolution_drop_rate:float,
            encoder_convolution_num_heads:int,
            encoder_cross_attention_num_heads:int,
            encoder_key_dim:int,
            encoder_feed_out_units:int,
            encoder_feed_forward_units:int,
            encoder_feed_forward_drop_rate:float,
            **kwargs):

        super(SequenceEncoderUint,self).__init__(**kwargs)

        self.convolution_layer = ConvolutionLayer(
            encoder_convolution_layers,
            encoder_convolution_num_heads,
            encoder_convolution_ksize,
            encoder_convolution_drop_rate
            )

        self.convoluted_attention = ConvolutedAttention(
            encoder_convolution_num_heads,
            encoder_key_dim
            )

        self.encoder_cross_attention = keras.layers.MultiHeadAttention(
            num_heads=encoder_cross_attention_num_heads,
            key_dim=encoder_key_dim
            )

        self.feed_forward = FeedForward(
            encoder_feed_out_units,
            encoder_feed_forward_units,
            encoder_feed_forward_drop_rate
            )

    def call(self,sequence_encode_out,bpp_matrix,bracket_encode_out):

        convolution_out = self.convolution_layer(bpp_matrix)
        convolutedattention_out = self.convoluted_attention(
            query=sequence_encode_out,
            key=sequence_encode_out,
            value=sequence_encode_out,
            convoluted=convolution_out
            )
        crossattention_out = self.encoder_cross_attention(
            query=convolutedattention_out,
            key=bracket_encode_out,
            value=bracket_encode_out
            )
        feed_out = self.feed_forward(crossattention_out)

        return feed_out

In [23]:
seq_encoder = SequenceEncoderUint(
    encoder_convolution_layers=3,
    encoder_convolution_ksize=3,
    encoder_convolution_drop_rate=0.1,
    encoder_convolution_num_heads=6,
    encoder_cross_attention_num_heads=6,
    encoder_key_dim=512,
    encoder_feed_forward_drop_rate=0.1,
    encoder_feed_forward_units=2048,
    encoder_feed_out_units=512
    )

seq_encoder_out = seq_encoder(seq_pos_encode_output,bpp_matrix,bracket_seq_encoder_out)
seq_encoder_out.shape

TensorShape([32, 208, 512])

In [24]:
class Transformer(keras.models.Model):

    def __init__(self,

            # Seuqence encoder arguments

            sequence_encoder_vocab_size:int=len(seq_map),
            sequence_encoder_d_model:int=256,
            sequence_encoder_length:int=2048,
            sequence_encoder_num_layers:int=3,
            sequence_encoder_convolution_layers:int=3,
            sequence_encoder_convolution_ksize:int=3,
            sequence_encoder_convolution_drop_rate:float=0.1,
            sequence_encoder_convolution_num_heads:int=4,
            sequence_encoder_cross_attention_num_heads:int=4,
            sequence_encoder_key_dim:int=256,
            sequence_encoder_feed_out_units:int=256,
            sequence_encoder_feed_forward_units:int=1024,
            sequence_encoder_feed_forward_drop_rate:float=0.1,

            # Bracket Sequence Encoder arguments

            bracket_sequence_encoder_vocab_size:int=len(bracket_map),
            bracket_sequence_encoder_d_model:int=256,
            bracket_sequence_encoder_length:int=2048,
            bracket_sequence_encoder_num_layers:int=3,
            bracket_sequence_encoder_convolution_layers:int=3,
            bracket_sequence_encoder_convolution_ksize:int=3,
            bracket_sequence_encoder_convolution_drop_rate:float=0.1,
            bracket_sequence_encoder_num_heads:int=4,
            bracket_sequence_encoder_key_dim:int=256,
            bracket_sequence_encoder_feed_out_units:int=256,
            bracket_sequence_encoder_feed_forward_units:int=1024,
            bracket_sequence_encoder_feed_forward_drop_rate:float=0.1,

            # Out Dense
            total_out_units = 2,
            **kwargs):


        super(Transformer,self).__init__(**kwargs)

        self.sequence_pos_encoder = StaticPosEncoding(
            sequence_encoder_vocab_size,
            sequence_encoder_d_model,
            sequence_encoder_length,
        )

        self.bracket_sequence_pos_encoder = StaticPosEncoding(
            bracket_sequence_encoder_vocab_size,
            bracket_sequence_encoder_d_model,
            bracket_sequence_encoder_length,
        )

        self.sequence_encoder_units_list = [SequenceEncoderUint(
            sequence_encoder_convolution_layers,
            sequence_encoder_convolution_ksize,
            sequence_encoder_convolution_drop_rate,
            sequence_encoder_convolution_num_heads,
            sequence_encoder_cross_attention_num_heads,
            sequence_encoder_key_dim,
            sequence_encoder_feed_out_units,
            sequence_encoder_feed_forward_units,
            sequence_encoder_feed_forward_drop_rate
        ) for _ in range(sequence_encoder_num_layers)]

        self.bracket_sequence_encoder_units_list = [BracketEncoderUnit(
            bracket_sequence_encoder_convolution_layers,
            bracket_sequence_encoder_convolution_ksize,
            bracket_sequence_encoder_convolution_drop_rate,
            bracket_sequence_encoder_num_heads,
            bracket_sequence_encoder_key_dim,
            bracket_sequence_encoder_feed_out_units,
            bracket_sequence_encoder_feed_forward_units,
            bracket_sequence_encoder_feed_forward_drop_rate
        ) for _ in range(bracket_sequence_encoder_num_layers)]

        self.total_out = keras.layers.Dense(total_out_units,activation=None)


    def call(self,X):

        sequence,bracket_sequence,bpp_matrix = X
        out = self.bracket_sequence_pos_encoder(bracket_sequence)

        for layer in self.bracket_sequence_encoder_units_list:
            out = layer(out,bpp_matrix)

        seq_out = self.sequence_pos_encoder(sequence)

        for layer in self.sequence_encoder_units_list:
            seq_out = layer(seq_out,bpp_matrix,out)


        seq_out = self.total_out(seq_out)

        try:
            del seq_out._keras_mask
        except AttributeError:
            pass

        return seq_out

In [25]:
transformer = Transformer()
transformer_out = transformer(X)
transformer_out.shape

TensorShape([32, 208, 2])

In [26]:
transformer.summary()

In [27]:
class CustomMetric2A3(keras.metrics.Metric):

    def __init__(self,**kwargs):

        super(CustomMetric2A3,self).__init__(**kwargs)
        self.mae = self.add_weight(name="mae",initializer="zeros")

    def update_state(self,y_true,y_pred,sample_weight=None):

        y_true = tf.clip_by_value(y_true[...,0],clip_value_max=1,clip_value_min=0)
        y_pred = tf.clip_by_value(y_pred[...,0],clip_value_max=1,clip_value_min=0)
        mae = tf.abs(tf.subtract(y_true,y_pred))
        mae = tf.reduce_mean(mae[~tf.math.is_nan(mae)])
        self.mae.assign_add(mae)


    def result(self):
        return self.mae


    def reset_state(self):
        self.mae.assign(0.)

In [28]:
class CustomMetricDMS(keras.metrics.Metric):

    def __init__(self,**kwargs):

        super(CustomMetricDMS,self).__init__(**kwargs)
        self.mae = self.add_weight(name="mae",initializer="zeros")

    def update_state(self,y_true,y_pred,sample_weight=None):

        y_true = tf.clip_by_value(y_true[...,1],clip_value_min=0,clip_value_max=1)
        y_pred = tf.clip_by_value(y_pred[...,1],clip_value_max=1,clip_value_min=0)
        mae = tf.abs(tf.subtract(y_true,y_pred))
        mae = tf.reduce_mean(mae[~tf.math.is_nan(mae)])
        self.mae.assign_add(mae)

    def result(self):
        return self.mae


    def reset_state(self):
        self.mae.assign(0.)

In [29]:
class CustomLoss(keras.losses.Loss):

    def __init__(self,**kwargs):

        super(CustomLoss,self).__init__(**kwargs)

    def __call__(self,y_true,y_pred,sample_weight=None):

        y_true = tf.clip_by_value(y_true,clip_value_max=1,clip_value_min=0)
        y_pred = tf.clip_by_value(y_pred,clip_value_max=1,clip_value_min=0)
        mae_loss = tf.reduce_mean(tf.abs(tf.subtract(y_true,y_pred)),axis=-1)
        return tf.reduce_mean(mae_loss[~tf.math.is_nan(mae_loss)])

In [30]:
cust_loss= CustomLoss()
cust_metric_2a3 = CustomMetric2A3()
cust_metric_dms = CustomMetricDMS()
cust_metric_2a3.update_state(y,transformer_out)
cust_metric_dms.update_state(y,transformer_out)
print("Loss:       ",cust_loss(y,transformer_out).numpy())
print("Metric 2A3: ",cust_metric_2a3.result().numpy())
print("Metric DMS: ",cust_metric_dms.result().numpy())

Loss:        0.35013685
Metric 2A3:  0.33700576
Metric DMS:  0.36326793


In [31]:
class ExpLR(keras.callbacks.Callback):

    def __init__(self,factor,**kwargs):

        super(ExpLR,self).__init__(**kwargs)
        self.factor = factor
        self.rates = []
        self.losses = []

    def on_epoch_begin(self,epoch,logs=None):
        self.sum_of_epoch_losses = 0

    def on_batch_end(self,batch,logs=None):
        mean_epoch_loss = logs["loss"]
        new_sum_of_epoch_losses = mean_epoch_loss * (batch+1)
        batch_loss = new_sum_of_epoch_losses - self.sum_of_epoch_losses
        self.sum_of_epoch_losses = new_sum_of_epoch_losses
        self.rates.append(K.get_value(self.model.optimizer.learning_rate))
        self.losses.append(batch_loss)
        K.set_value(self.model.optimizer.learning_rate,self.model.optimizer.learning_rate* self.factor)

In [33]:
def find_learning_rate(model,ds,iterations,epochs,batch_size,min_rate:float=1e-7,max_rate:float=1):

    init_weights = model.get_weights()
    total_iterations = iterations * epochs
    steps_per_epoch = iterations//BATCH_SIZE
    factor = (max_rate/min_rate) ** (1/total_iterations)
    init_lr = K.get_value(model.optimizer.learning_rate)
    K.set_value(model.optimizer.learning_rate,min_rate)
    exp_lr = ExpLR(factor)
    history = model.fit(ds,epochs=epochs,steps_per_epoch=steps_per_epoch,callbacks=[exp_lr])
    return exp_lr.rates,exp_lr.losses,init_weights

In [None]:
with strategy.scope():
    model = Transformer()
    optimizer = keras.optimizers.SGD(1e-7)
    training_loss = keras.metrics.Mean(name="Custom_Loss",dtype=tf.float32)
    training_metrics_2a3 = CustomMetric2A3(name="Custom_Metric_2A3",dtype=tf.float32)
    training_metrics_dms = CustomMetricDMS(name="CUstom_Metric_DMS",dtype=tf.float32)



BATCH_SIZE = 16 * strategy.num_replicas_in_sync
steps_per_epoch = 1175
num_train_steps = steps_per_epoch * BATCH_SIZE
validation_steps = 141
test_steps  = 94
per_replica_batch_size = BATCH_SIZE // strategy.num_replicas_in_sync
cust_loss = CustomLoss(name="custom_loss")
train_ds = strategy.distribute_datasets_from_function(lambda _: create_train_ds(train_modified_ds,per_replica_batch_size))

@tf.function
def train_step(iterator):
    
    def step_fn(inputs):
        
        X,y_true = inputs
        
        with tf.GradientTape() as tape:
            
            y_pred = model(X,training=True)
            
            loss = cust_loss(y_true,y_pred)
            
        grads = tape.gradient(loss,model.trainable_variables)
        optimizer.apply_gradients(list(zip(grads,model.trainable_variables)))
        training_loss.update_state(loss*strategy.num_replicas_in_sync)
        training_metrics_2a3.update_state(y_true,y_pred)
        training_metrics_dms.update_state(y_true,y_pred)
        
    strategy.run(step_fn,args=(next(iterator),))
    
    


train_iterator = iter(train_ds)
num_epochs = 5
for epoch in range(num_epochs):
    print(f"Epoch: {epoch+1}/{num_epochs}")
    
    pb_i = Progbar(num_train_steps,stateful_metrics=["custom metric 2a3","custom metric dms"])
    
    for step in range(steps_per_epoch):
        
        train_step(train_iterator)
        
        values = [("custom loss",training_loss.result().numpy()),("custom metric 2a3",training_metrics_2a3.result().numpy()),("custom metric dms",training_metrics_dms.result().numpy())]
        
        pb_i.add(BATCH_SIZE,values=values)
        
    print(f"Total loss for epoch: {training_loss.result().numpy()}")
    print(f"Total Metric 2A3 for epoch : {training_metrics_2a3.result().numpy()}")
    print(f"Total Metric DMS for epoch : {training_metrics_dms.result().numpy()}")
    training_loss.reset_state()
    training_metrics_2a3.reset_state()
    training_metrics_dms.reset_state()

Epoch: 1/5
[1m 33152/150400[0m [32m━━━━[0m[37m━━━━━━━━━━━━━━━━[0m [1m4:16[0m 2ms/step - custom loss: 3.9834 - custom metric 2a3: 169.8111 - custom metric dms: 81.1902