# Make sure the dataset is in the following structure
```
* dataset ("/input/dataset/")
|-dataset_iden
 |-dataset_iden
  |-x.tfrecord
  |-x.tfrecord
  |-x.tfrecord
  .................
|-config.json
|-enc.h5
|-seq.h5
|-pos.h5
|-fuse.h5
```

# Change the variables in the first cell accordingly

In [1]:
# ------------------------------------------------------------
# User-changeable parameters: edit these before running
# ------------------------------------------------------------
# --> dataset pipeline
PER_REPLICA_BATCH_SIZE = 128  # batch size per replica
EPOCHS = 100  # number of epochs to train
DATASET_IDENTIFIER = 'apsis-cdr-gen-bangla-final'  # kaggle dataset name
use_pretrained = True  # train from a pretrained version
eval_split = 20  # % of total data to use for evaluation

# Everything is fixed from this point on

In [None]:
import os
inp_path = f"../input/{DATASET_IDENTIFIER}/"
# --> data property
config_json = f'{inp_path}config.json'  # path to config.json
# --> weights
# only applicable when use_pretrained is true
enc_weights_path = f'{inp_path}enc.h5'    # path to "enc.h5"
seq_weights_path = f'{inp_path}seq.h5'    # path to "seq.h5"
pos_weights_path = f'{inp_path}pos.h5'    # path to "pos.h5"
fuse_weights_path = f'{inp_path}fuse.h5'  # path to "fuse.h5"

# Fail early, and with the offending path in the message, if a required
# input is missing. An explicit raise is used instead of `assert` so the
# check still runs when Python is started with -O.
_required_inputs = [(config_json, "config.json")]
if use_pretrained:
    _required_inputs += [
        (enc_weights_path, "enc.h5"),
        (seq_weights_path, "seq.h5"),
        (pos_weights_path, "pos.h5"),
        (fuse_weights_path, "fuse.h5"),
    ]
for _required_path, _required_name in _required_inputs:
    if not os.path.exists(_required_path):
        raise FileNotFoundError(f"{_required_name} not found: {_required_path}")



In [None]:
#----------------
# imports
#---------------
import tensorflow as tf
import random
import json
import os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from glob import glob
from tqdm.auto import tqdm
from kaggle_datasets import KaggleDatasets
import random
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  
import tensorflow as tf
#-------------------
# fixed params
#------------------
nb_channels = 3    # channel count used when decoding / reshaping images
enc_filters = 256  # channel width of the encoder output and the embeddings
factor = 32        # divisor applied to image height/width to size the feature map
# -------------
# config-globals (read from config.json)
# -------------
with open(config_json) as f:
    conf = json.load(f)

print(conf)

img_height, img_width = conf["img_height"], conf["img_width"]
vocab = conf["vocab"]
pos_max = conf["pos_max"]
RECORD_SIZE = conf["tf_size"]
zip_iden = conf["zip_iden"]


# derived values
enc_shape = (img_height // factor, img_width // factor, enc_filters)
attn_shape = (None, enc_filters)
# number of spatial cells in the encoder feature map
mask_len = int(enc_shape[1] * enc_shape[0])
# indices of the special tokens inside the vocabulary list
start_value, end_value, pad_value = (
    vocab.index(token) for token in ("start", "end", "pad")
)

print("Label len:", pos_max)
print("Vocab len:", len(vocab))
print("Start value:", start_value)
print("End value:", end_value)
print("pad_value:", pad_value)



#--------------------------
# GCS Paths and tfrecords
#-------------------------
def get_tfrecs(tfrec_folder_path):
    '''
      Lists the tfrecord files of a folder.
      * globs "<tfrec_folder_path>/*.tfrecord" through tf.io.gfile
      * shuffles the matched paths in place
      * prints the folder together with the number of files found
      args:
        tfrec_folder_path : folder that holds the .tfrecord files
      returns:
        list of shuffled tfrecord paths
    '''
    matches = tf.io.gfile.glob(os.path.join(tfrec_folder_path, '*.tfrecord'))
    random.shuffle(matches)
    found = len(matches)
    print(f"{tfrec_folder_path}:", found)
    return matches




# Resolve the dataset's GCS location and list the tfrecords stored under
# <GCS_PATH>/<zip_iden>/<zip_iden>
GCS_PATH = KaggleDatasets().get_gcs_path(DATASET_IDENTIFIER)
rec_path = os.path.join(GCS_PATH, zip_iden, zip_iden)
recs = get_tfrecs(rec_path)

# dist/split: the first eval_split % of the (already shuffled) files go to
# evaluation, the remainder to training
len_recs = len(recs)
nb_eval_recs = int(len_recs * (eval_split / 100))
eval_recs, train_recs = recs[:nb_eval_recs], recs[nb_eval_recs:]

# shuffle each subset once more (evaluation first, then training)
random.shuffle(eval_recs)
random.shuffle(train_recs)

print("Eval-recs:", len(eval_recs))
print("Train-recs:", len(train_recs))

In [None]:
#----------------------------------------------------------
# Detect hardware, return appropriate distribution strategy
#----------------------------------------------------------
# TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    # the resolver could not be created: continue without a TPU
    tpu = None

if not tpu:
    # default distribution strategy in Tensorflow. Works on CPU and single GPU.
    strategy = tf.distribute.get_strategy()
else:
    # connect to the TPU cluster and initialise it before building the strategy
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    # NOTE(review): the experimental TPUStrategy alias may be deprecated in
    # newer TensorFlow releases - confirm against the installed version
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
    tf.config.optimizer.set_jit(True)

print("REPLICAS: ", strategy.num_replicas_in_sync)

In [None]:
#-------------------------------------
# batching , strategy and steps
#-------------------------------------
# global batch = per-replica batch x number of replicas
# (with a single replica this is just the per-replica batch size)
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync

# whole batches available per pass, assuming RECORD_SIZE samples per tfrecord
STEPS_PER_EPOCH, EVAL_STEPS = (
    (len(subset) * RECORD_SIZE) // BATCH_SIZE
    for subset in (train_recs, eval_recs)
)
print("Steps:", STEPS_PER_EPOCH)
print("Batch Size:", BATCH_SIZE)
print("Eval Steps:", EVAL_STEPS)

In [None]:
#------------------------------
# parsing tfrecords basic
#------------------------------
def data_input_fn(recs,mode):
    '''
      Builds the batched, endlessly repeating tf.data pipeline for a list of
      tfrecord files.
      args:
        recs : list of tfrecord paths to read
        mode : "train" / "eval" tag passed by the callers; it is currently not
               used, both modes get the same shuffled and repeated pipeline
      returns:
        tf.data.Dataset of (inputs,target) batches of BATCH_SIZE samples
        * inputs : dict with
            "image" : float32 image scaled to [0,1], (img_height,img_width,nb_channels)
            "label" : int32 label of length pos_max
            "pos"   : int32 positions 0..pos_max-1
            "mask"  : float32 (pos_max,mask_len), inverted record mask
        * target : the label cast to float32
    '''
    def _parser(example):
        feature ={  'image'  : tf.io.FixedLenFeature([],tf.string) ,
                    'label'  : tf.io.FixedLenFeature([pos_max],tf.int64),
                    'mask'   : tf.io.FixedLenFeature([mask_len],tf.int64)
        }
        parsed_example=tf.io.parse_single_example(example,feature)
        # image: png bytes -> float32 in [0,1] with a fixed static shape
        image=tf.image.decode_png(parsed_example['image'],channels=nb_channels)
        image=tf.cast(image,tf.float32)/255.0
        image=tf.reshape(image,(img_height,img_width,nb_channels))

        # label
        label=parsed_example['label']

        # position
        pos=tf.range(0,pos_max)
        pos=tf.cast(pos,tf.int32)

        # mask: inverted so that 1 marks the cells DotAttention pushes
        # towards -inf, then repeated once per output position
        mask=1-tf.cast(parsed_example['mask'],tf.float32)
        mask=tf.stack([mask]*pos_max)

        inputs={"image":image,
                "label":tf.cast(label, tf.int32),
                "pos":pos,
                "mask":mask}
        return inputs,tf.cast(label, tf.float32)

    dataset = tf.data.TFRecordDataset(recs)
    dataset = dataset.map(_parser,num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # drop unreadable records right after parsing: placed after batch() a
    # single bad record would throw away the whole partially built batch
    dataset = dataset.apply(tf.data.experimental.ignore_errors())
    dataset = dataset.shuffle(2048,reshuffle_each_iteration=True)
    dataset = dataset.repeat()
    dataset = dataset.batch(BATCH_SIZE,drop_remainder=True)
    # prefetch goes last so complete batches are prepared ahead of the consumer
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset

train_ds  =   data_input_fn(train_recs,"train")
eval_ds  =   data_input_fn(eval_recs,"eval")

In [None]:
#------------------------
# visualizing data
#------------------------


_rule = "---------------------------------------------------------------"

print(_rule)
print("visualizing data")
print(_rule)
# pull a single batch and show its first sample plus the batch shapes
for x, y in train_ds.take(1):
    first_image = np.squeeze(x["image"][0])
    plt.imshow(first_image)
    plt.show()
    print(_rule)
    # first sample of every non-image input (mask: first position row only)
    for title, sample in (("label:", x["label"][0]),
                          ("pos:", x["pos"][0]),
                          ("mask:", x["mask"][0][0])):
        print(title, sample)
        print(_rule)
    # batch shapes of all inputs
    for title, key in (('Image Batch Shape:', "image"),
                       ('Label Batch Shape:', "label"),
                       ('Position Batch Shape:', "pos"),
                       ('Mask Batch Shape:', "mask")):
        print(title, x[key].shape)
    print(_rule)
    print('Target Batch Shape:', y.shape)

In [None]:
#-----------------------------------
#creating Embedding Weights
#-----------------------------------
if not use_pretrained:
    # torch is only used to draw the initial embedding tables, which are
    # handed to the keras Embedding layers as numpy arrays
    import torch
    import torch.nn as nn

    # sequence (character) embedding table
    seq_emb = nn.Embedding(num_embeddings=len(vocab) + 1,
                           embedding_dim=enc_filters,
                           padding_idx=pad_value)
    seq_emb_weight = seq_emb.weight.data.numpy()
    print(seq_emb_weight.shape)

    # position embedding table
    pos_emb = nn.Embedding(num_embeddings=pos_max + 1,
                           embedding_dim=enc_filters)
    pos_emb_weight = pos_emb.weight.data.numpy()
    print(pos_emb_weight.shape)

In [None]:
class DotAttention(tf.keras.layers.Layer):
    """
        Scaled dot-product attention with an additive mask.

        The query/key product is divided by sqrt(depth of k); wherever the
        mask is 1 a large negative constant is added before the softmax so
        those keys receive (almost) no weight.

        Args:
        q: query shape == (..., seq_len_q, depth)
        k: key shape == (..., seq_len_k, depth)
        v: value shape == (..., seq_len_v, depth_v), seq_len_v == seq_len_k
        mask: Float tensor broadcastable to (..., seq_len_q, seq_len_k),
              or None to skip masking.

        Returns:
        attention output of shape (..., seq_len_q, depth_v)
    """
    def __init__(self):
        super().__init__()
        # stand-in for -infinity used to suppress masked positions
        self.inf_val=-1e9

    def call(self,q, k, v, mask):
        # scaling term: square root of the key depth
        depth = tf.cast(tf.shape(k)[-1], tf.float32)
        # (..., seq_len_q, seq_len_k)
        scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(depth)

        if mask is not None:
            scores = scores + mask * self.inf_val

        # normalise over the key axis so that each query's weights sum to 1
        weights = tf.nn.softmax(scores, axis=-1)

        # weighted sum of the values: (..., seq_len_q, depth_v)
        return tf.matmul(weights, v)

In [None]:
import tensorflow as tf 


def encoder():
    '''
    builds the image encoder:
    * backbone : DenseNet121 without its classification top and without
                 pretrained weights **changeable
    * a 3x3 "same" convolution maps the backbone output to enc_filters channels

    returns:
      tf.keras.Model "rs_encoder" : input "image" -> channel reduced feature map

    '''
    image_in = tf.keras.Input(shape=(img_height, img_width, nb_channels), name='image')
    densenet = tf.keras.applications.DenseNet121(include_top=False,
                                                 weights=None,
                                                 input_tensor=image_in)
    # channel reduction on top of the backbone features
    reduce_conv = tf.keras.layers.Conv2D(enc_filters, kernel_size=3, padding="same")
    features = reduce_conv(densenet.output)
    return tf.keras.Model(inputs=image_in, outputs=features, name="rs_encoder")

def seq_decoder():
    '''
    sequence attention decoder (for training)
    Tensorflow implementation of :
    https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/sequence_attention_decoder.py

    inputs  : "label" (pos_max,) int32 , "enc_seq" enc_shape , "mask" (pos_max,mask_len)
    returns : tf.keras.Model "rs_seq_decoder" whose output is the attention
              of the embedded + 2xLSTM label sequence over the encoder features
    '''
    # inputs: label, mask, encoder features
    gt = tf.keras.Input(shape=(pos_max,), dtype='int32', name="label")
    mask = tf.keras.Input(shape=(pos_max, mask_len), dtype='float32', name="mask")
    enc = tf.keras.Input(shape=enc_shape, name='enc_seq')

    # embedding: initialised from seq_emb_weight unless pretrained weights
    # are going to be loaded
    emb_kwargs = {} if use_pretrained else {"weights": [seq_emb_weight]}
    hidden = tf.keras.layers.Embedding(len(vocab) + 1, enc_filters, **emb_kwargs)(gt)

    # sequence layer (2xlstm) -> query
    for _ in range(2):
        hidden = tf.keras.layers.LSTM(enc_filters, return_sequences=True)(hidden)

    # flatten the spatial grid of the encoder output; it is used both as
    # key and as value
    _, feat_h, feat_w, feat_c = enc.shape
    value = tf.keras.layers.Reshape((feat_h * feat_w, feat_c))(enc)

    attn = DotAttention()(hidden, value, value, mask)
    return tf.keras.Model(inputs=[gt, enc, mask], outputs=attn, name="rs_seq_decoder")
 


def pos_decoder():
    '''
    position attention decoder (for training)
    Tensorflow implementation of :
    https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/position_attention_decoder.py

    inputs  : "pos" (pos_max,) int32 , "enc_pos" enc_shape , "mask" (pos_max,mask_len)
    returns : tf.keras.Model "rs_pos_decoder" whose output is the attention
              of the embedded positions over the encoder features
    '''
    # inputs: positions, mask, encoder features
    pt = tf.keras.Input(shape=(pos_max,), dtype='int32', name="pos")
    mask = tf.keras.Input(shape=(pos_max, mask_len), dtype='float32', name="mask")
    enc = tf.keras.Input(shape=enc_shape, name='enc_pos')

    # query: position embedding, initialised from pos_emb_weight unless
    # pretrained weights are going to be loaded
    emb_kwargs = {} if use_pretrained else {"weights": [pos_emb_weight]}
    query = tf.keras.layers.Embedding(pos_max + 1, enc_filters, **emb_kwargs)(pt)

    # value: encoder features with the spatial grid flattened
    _, feat_h, feat_w, feat_c = enc.shape
    value = tf.keras.layers.Reshape((feat_h * feat_w, feat_c))(enc)

    # position aware module: 2xlstm over the flattened grid ...
    aware = value
    for _ in range(2):
        aware = tf.keras.layers.LSTM(enc_filters, return_sequences=True)(aware)
    aware = tf.keras.layers.Reshape((feat_h, feat_w, feat_c))(aware)
    # ... followed by the mixer: conv -> relu -> conv
    aware = tf.keras.layers.Conv2D(enc_filters, kernel_size=3, padding="same")(aware)
    aware = tf.keras.layers.Activation("relu")(aware)
    key = tf.keras.layers.Conv2D(enc_filters, kernel_size=3, padding="same")(aware)

    # key: flatten the mixed grid again
    _, key_h, key_w, _ = key.shape
    key = tf.keras.layers.Reshape((key_h * key_w, feat_c))(key)

    attn = DotAttention()(query, key, value, mask)
    return tf.keras.Model(inputs=[pt, enc, mask], outputs=attn, name="rs_pos_decoder")

def fusion():
    '''
    fuses the outputs of the sequence (gt_attn) and position (pt_attn)
    decoders into per-step vocabulary logits:
    concat -> linear(2*enc_filters) -> gate (x * sigmoid(x)) -> linear(len(vocab))

    returns : tf.keras.Model "rs_fusion" with inputs "gt_attn" and "pt_attn"
    '''
    gt_attn = tf.keras.Input(shape=attn_shape, name="gt_attn")
    pt_attn = tf.keras.Input(shape=attn_shape, name="pt_attn")

    merged = tf.keras.layers.Concatenate()([gt_attn, pt_attn])
    # Linear
    projected = tf.keras.layers.Dense(enc_filters * 2, activation=None)(merged)
    # gating: the projection multiplied by its own sigmoid
    passthrough = tf.keras.layers.Activation("linear")(projected)
    gate = tf.keras.layers.Activation("sigmoid")(projected)
    gated = tf.keras.layers.Multiply()([passthrough, gate])
    # prediction (logits, no activation)
    logits = tf.keras.layers.Dense(len(vocab), activation=None)(gated)
    return tf.keras.Model(inputs=[gt_attn, pt_attn], outputs=logits, name="rs_fusion")

# build the four sub-models inside the strategy scope so that their
# variables are created under the selected distribution strategy
with strategy.scope():
    rs_encoder = encoder()
    rs_seq_decoder = seq_decoder()
    rs_pos_decoder = pos_decoder()
    rs_fusion = fusion()
    if use_pretrained:
        # restore every sub-model from its own weight file
        for sub_model, weights_path in ((rs_encoder, enc_weights_path),
                                        (rs_seq_decoder, seq_weights_path),
                                        (rs_pos_decoder, pos_weights_path),
                                        (rs_fusion, fuse_weights_path)):
            sub_model.load_weights(weights_path)


In [None]:
with strategy.scope():
    # optimizer (`learning_rate` is the supported keyword; `lr` is a
    # deprecated alias)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
    # loss: per-element values are kept so padding can be masked out below
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
    def CE_loss(real, pred):
        '''
        masked cross entropy
        args:
          real : target token ids (pad_value marks padding)
          pred : logits with the vocabulary on the last axis
        returns:
          mean cross entropy over the non-pad targets; 0 when every target
          is padding (instead of the NaN a plain division would give)
        '''
        mask = tf.math.logical_not(tf.math.equal(real, pad_value))
        loss_ = loss_object(real, pred)
        mask = tf.cast(mask, dtype=loss_.dtype)
        loss_ *= mask
        return tf.math.divide_no_nan(tf.reduce_sum(loss_), tf.reduce_sum(mask))
    def C_acc(real, pred):
        '''
        masked character accuracy
        args:
          real : target token ids (pad_value marks padding)
          pred : rank-3 logits, the arg-max over axis 2 is the predicted id
        returns:
          fraction of non-pad targets predicted correctly; 0 when every
          target is padding
        '''
        accuracies = tf.equal(tf.cast(real,tf.int64), tf.argmax(pred, axis=2))
        mask = tf.math.logical_not(tf.math.equal(real,pad_value))
        accuracies = tf.math.logical_and(mask, accuracies)
        accuracies = tf.cast(accuracies, dtype=tf.float32)
        mask = tf.cast(mask, dtype=tf.float32)
        return tf.math.divide_no_nan(tf.reduce_sum(accuracies), tf.reduce_sum(mask))

In [None]:
class robust_scanner(tf.keras.Model):
    '''
    Wraps the encoder, the sequence decoder, the position decoder and the
    fusion head into one trainable model with custom train / test steps.
    * train_step : teacher forcing, the ground truth label feeds the
                   sequence decoder
    * test_step  : greedy decoding, each prediction is written back into
                   the label that feeds the next step
    args:
      encoder     : image encoder model
      seq_decoder : sequence attention decoder model
      pos_decoder : position attention decoder model
      fusion      : fusion / prediction model
    '''
    def __init__(self,encoder,seq_decoder,pos_decoder,fusion):
        super(robust_scanner, self).__init__()
        self.encoder     = encoder
        self.seq_decoder = seq_decoder
        self.pos_decoder = pos_decoder
        self.fusion      = fusion

    def compile(self,optimizer,loss_fn,acc):
        '''
        args:
          optimizer : optimizer shared by all four sub-models
          loss_fn   : callable (target,logits) -> scalar loss
          acc       : callable (target,logits) -> scalar accuracy
        '''
        super(robust_scanner, self).compile()
        self.optimizer = optimizer
        self.loss_fn   = loss_fn
        self.acc       = acc

    def train_step(self, batch_data):
        '''
        one teacher-forced optimisation step
        args:
          batch_data : (inputs,target) pair as produced by the data pipeline
        returns:
          dict with "loss" and "char_acc" of the batch
        '''
        data,gt= batch_data
        image=data["image"]
        pos  =data["pos"]
        mask =data["mask"]

        # a single tape is enough: all sub-models take part in one forward
        # pass and are optimised against the same loss
        with tf.GradientTape() as tape:
            enc    = self.encoder(image, training=True)
            pt_attn= self.pos_decoder({"pos":pos,"enc_pos":enc,"mask":mask},training=True)
            gt_attn= self.seq_decoder({"label":gt,"enc_seq":enc,"mask":mask},training=True)
            pred   = self.fusion({"gt_attn":gt_attn,"pt_attn":pt_attn},training=True)
            # the prediction at step t is scored against the target at t+1
            loss = self.loss_fn(gt[:,1:],pred[:,:-1,:])

        # c acc (not differentiated, so computed outside the tape)
        char_acc=self.acc(gt[:,1:],pred[:,:-1,:])

        # one gradient computation and one optimizer update for all the
        # variables, so the optimizer's step counter advances once per batch
        variables = [*self.encoder.trainable_variables,
                     *self.pos_decoder.trainable_variables,
                     *self.seq_decoder.trainable_variables,
                     *self.fusion.trainable_variables]
        grads = tape.gradient(loss,variables)
        self.optimizer.apply_gradients(zip(grads,variables))

        return {"loss"    : loss,
                "char_acc": char_acc}

    def test_step(self, batch_data):
        '''
        one evaluation step with greedy autoregressive decoding
        args:
          batch_data : (inputs,target) pair as produced by the data pipeline
        returns:
          dict with "loss" and "char_acc" of the batch
        '''
        data,gt= batch_data
        image=data["image"]
        pos  =data["pos"]
        mask =data["mask"]
        # decoded label so far: starts as all "start" tokens
        label=tf.ones_like(gt,dtype=tf.float32)*start_value
        preds=[]

        enc    = self.encoder(image, training=False)
        pt_attn= self.pos_decoder({"pos":pos,"enc_pos":enc,"mask":mask},training=False)

        for i in range(pos_max):
            # re-run the sequence decoder on the label decoded so far and
            # keep only the current step
            gt_attn=self.seq_decoder({"label":label,"enc_seq":enc,"mask":mask},training=False)
            step_gt_attn=gt_attn[:,i,:]
            step_pt_attn=pt_attn[:,i,:]
            pred=self.fusion({"gt_attn":step_gt_attn,"pt_attn":step_pt_attn},training=False)
            preds.append(pred)
            # greedy choice: softmax is monotonic, so the arg-max of the
            # logits is the arg-max of the probabilities
            max_idx =tf.math.argmax(pred,axis=-1)
            if i < pos_max - 1:
                # write the prediction into slot i+1 of the label
                label=tf.unstack(label,axis=-1)
                label[i+1]=tf.cast(max_idx,tf.float32)
                label=tf.stack(label,axis=-1)

        pred=tf.stack(preds,axis=1)
        # loss
        loss = self.loss_fn(gt[:,1:],pred[:,:-1,:])
        # c acc
        char_acc=self.acc(gt[:,1:],pred[:,:-1,:])


        return {"loss"    : loss,
                "char_acc": char_acc}
    
    

In [None]:
with strategy.scope():
    model = robust_scanner(rs_encoder,
                           rs_seq_decoder,
                           rs_pos_decoder,
                           rs_fusion)

    model.compile(optimizer = optimizer,
                  loss_fn   = CE_loss,
                  acc       = C_acc)

In [None]:
# scale the learning rate by `factor` after `patience` epochs without
# improvement of the callback's default monitored metric
lr_reducer = tf.keras.callbacks.ReduceLROnPlateau(
    patience=3,
    factor=0.1,
    cooldown=10,
    min_lr=1e-8,
    verbose=1,
)
# stop training after `patience` epochs without improvement
early_stopping = tf.keras.callbacks.EarlyStopping(
    mode='auto',
    patience=15,
    verbose=1,
)


class SaveBestModel(tf.keras.callbacks.Callback):
    '''
    Saves the weights of the four sub-models (enc.h5, seq.h5, pos.h5,
    fuse.h5 in the working directory) every time the validation loss
    reaches a new minimum.
    '''
    def __init__(self):
        # let the base class set up its own attributes first
        super().__init__()
        self.best = float('inf')

    def on_epoch_end(self, epoch, logs=None):
        # logs may be None, and val_loss is missing when an epoch ran
        # without validation: nothing to compare in either case
        metric_value = (logs or {}).get('val_loss')
        if metric_value is None:
            return
        if metric_value < self.best:
            print(f"Loss Improved epoch:{epoch} from {self.best} to {metric_value}")
            self.best = metric_value
            self.model.encoder.save_weights("enc.h5")
            self.model.seq_decoder.save_weights("seq.h5")
            self.model.pos_decoder.save_weights("pos.h5")
            self.model.fusion.save_weights("fuse.h5")
            print("Saved Best Weights")

    def set_model(self, model):
        # delegate to the base class, which stores the model on the callback
        super().set_model(model)

model_save=SaveBestModel()
model_save.set_model(model)
callbacks= [lr_reducer,early_stopping,model_save]

In [None]:
# train on the repeating training dataset; because both datasets repeat
# forever, the number of steps per epoch / per validation run is explicit
history = model.fit(
    train_ds,
    validation_data=eval_ds,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=EVAL_STEPS,
    callbacks=callbacks,
    verbose=1,
)