In [1]:
!pip install -q Levenshtein

In [2]:
import os

# Work inside Google Drive when it is mounted, otherwise fall back to the
# local /content disk. The directory is created on first use.
_workdir = '/content/drive/MyDrive/aslfr' if os.path.isdir('/content/drive/MyDrive') else '/content/aslfr'
os.makedirs(_workdir, exist_ok=True)
os.chdir(_workdir)
print(os.getcwd())

/content/drive/MyDrive/aslfr


In [3]:
import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedGroupKFold, GroupKFold, KFold
import gc
from tqdm.auto import tqdm
import Levenshtein
import time
import os

In [4]:
DEBUG = True  # set to False for submission
N = -1  # -1 selects all samples
# Weight files of the ensemble; a separate get_model() instance is created
# for each entry and the file is loaded into it.
MODEL_PATH = [
    'aslfr-fp16-192d-17l-ctcattjoint-seed42-fold0-last.h5',
    'aslfr-fp16-192d-17l-ctcattjoint-seed42-foldall-last.h5',
    'aslfr-fp16-192d-17l-ctcattjoint-seed43-foldall-last.h5',
    'aslfr-fp16-192d-17l-ctcattjoint-seed44-foldall-last.h5',
]

In [5]:
# NOTE: the GCS paths below expire after several weeks. To refresh them, run
# KaggleDatasets.get_gcs_path(dataset_name) in a Kaggle notebook and update GCS_PATH.
# Helper notebook: https://www.kaggle.com/hoyso48/aslfr-get-gcs-path/edit

GCS_PATH = {
            'aslfr':'gs://kds-1dadda248a69bd8cbc18044d03c2444a9593eed795d5f632a2307052',
            'aslfr-5fold':'gs://kds-bf210dd73d66268f4c9d4897567ab8b79267f25eab4aa5c501305eef',
            }

# All TFRecord shards of the 5-fold dataset (fold membership is encoded in the file name).
TRAIN_FILENAMES = tf.io.gfile.glob(GCS_PATH['aslfr-5fold']+'/*.tfrecords')
COMPETITION_PATH = GCS_PATH['aslfr']

print(len(TRAIN_FILENAMES))
# Copy the competition metadata into the current working directory.
!gsutil cp {COMPETITION_PATH}/train.csv .
!gsutil cp {COMPETITION_PATH}/character_to_prediction_index.json .

135
Copying gs://kds-1dadda248a69bd8cbc18044d03c2444a9593eed795d5f632a2307052/train.csv...
/ [1 files][  5.0 MiB/  5.0 MiB]                                                
Operation completed over 1 objects/5.0 MiB.                                      
Copying gs://kds-1dadda248a69bd8cbc18044d03c2444a9593eed795d5f632a2307052/character_to_prediction_index.json...
/ [1 files][  405.0 B/  405.0 B]                                                
Operation completed over 1 objects/405.0 B.                                      


In [6]:
# Column names of the landmark table, grouped by body part and, within each
# part, by axis: all x columns, then all y columns, then all z columns.
SEL_COLS = [
    f'{axis}_{part}_{i}'
    for part, count in (('face', 468), ('left_hand', 21), ('right_hand', 21), ('pose', 33))
    for axis in ('x', 'y', 'z')
    for i in range(count)
]
print(len(SEL_COLS))

1629


In [7]:
import json

with open('./character_to_prediction_index.json') as json_file:
    CHAR_TO_NUM = json.load(json_file)

# Inverse vocabulary with every id shifted up by one, which leaves id 0 free.
NUM_TO_CHAR = {index + 1: char for char, index in CHAR_TO_NUM.items()}
# Extra symbols: ids 60 / 61 / 0 are the defaults used by the decoders in this
# notebook for the start-of-sequence, end-of-sequence and pad tokens.
NUM_TO_CHAR.update({60: 'S', 61: 'E', 0: 'P'})

# LABEL_DICT

In [8]:
# Index tables for the flip_lr function. LEFT[i] is matched with RIGHT[i]
# (i.e. LEFT[i](x) == -RIGHT[i](x)); flip_lr swaps the two groups and writes the
# CENTRE landmarks back to their own index.
# Computed from https://github.com/google/mediapipe/blob/master/mediapipe/modules/face_geometry/data/canonical_face_model.obj

LEFT = [
         248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264,
         265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281,
         282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298,
         299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315,
         316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332,
         333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349,
         350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366,
         367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383,
         384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400,
         401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417,
         418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434,
         435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451,
         452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467,  # left face
         468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, # left hand
         493, 494, 495, 497, 499, 501, 503, 505, 507, 509, 511, 513, # left pose
         515, 517, 519, 521, # left leg
         ]

RIGHT = [
         3, 7, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
         39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
         60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
         81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 95, 96, 97, 98, 99, 100, 101, 102,
         103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120,
         121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138,
         139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 153, 154, 155, 156, 157, 158,
         159, 160, 161, 162, 163, 165, 166, 167, 169, 170, 171, 172, 173, 174, 176, 177, 178, 179,
         180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 196, 198, 201,
         202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219,
         220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237,
         238, 239, 240, 241, 242, 243, 244, 245, 246, 247, # right face
        522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, # right hand
        490, 491, 492, 496, 498, 500, 502, 504, 506, 508, 510, 512, # right pose
        514, 516, 518, 520, # right leg
        ]

CENTRE = [
          0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 94, 151, 152, 164, 168, 175, 195, 197, 199, 200, # face
          489, # pose
          ]

# Sanity check: the three tables together should account for all 543 landmarks.
print(len(LEFT+RIGHT+CENTRE))

543


In [9]:
# Sequence / landmark geometry shared by preprocessing and the model.
ROWS_PER_FRAME = 543
MAX_LEN = 768
CROP_LEN = MAX_LEN
NUM_CLASSES = len(NUM_TO_CHAR)  # 62
PAD = -100.

# Landmark index ranges of the two hands, and the full set of landmarks kept.
LHAND = list(range(468, 489))
RHAND = list(range(522, 543))
POINT_LANDMARKS = list(range(543))

NUM_NODES = len(POINT_LANDMARKS)
CHANNELS = NUM_NODES * 3  # three coordinates per landmark

print(NUM_NODES)
print(CHANNELS)

def interp1d_(x, target_len, method='random'):
    """Resize `x` to `target_len` rows with tf.image.resize, keeping dim 1 unchanged.

    Args:
        x: tensor accepted by tf.image.resize (presumably (frames, landmarks,
            channels) - TODO confirm against callers).
        target_len: requested number of rows; clamped to at least 1.
        method: a tf.image.resize method name, or 'random' to pick
            'bilinear' / 'bicubic' / 'nearest' with probability
            0.33 / 0.335 / 0.335.

    Returns:
        The resized tensor.
    """
    target_len = tf.maximum(1, target_len)
    out_size = (target_len, tf.shape(x)[1])
    if method != 'random':
        resized = tf.image.resize(x, out_size, method)
    elif tf.random.uniform(()) < 0.33:
        resized = tf.image.resize(x, out_size, 'bilinear')
    elif tf.random.uniform(()) < 0.5:
        # second draw only happens when the first one did not select bilinear
        resized = tf.image.resize(x, out_size, 'bicubic')
    else:
        resized = tf.image.resize(x, out_size, 'nearest')
    return resized

def tf_nan_mean(x, axis=0, keepdims=False):
    """Mean of `x` over `axis` that ignores NaN entries.

    NaNs contribute neither to the sum nor to the count. When every entry
    along `axis` is NaN the count is zero, so the result is NaN (0 / 0).
    """
    nan_mask = tf.math.is_nan(x)
    zeros = tf.zeros_like(x)
    total = tf.reduce_sum(tf.where(nan_mask, zeros, x), axis=axis, keepdims=keepdims)
    count = tf.reduce_sum(tf.where(nan_mask, zeros, tf.ones_like(x)), axis=axis, keepdims=keepdims)
    return total / count

def tf_nan_std(x, center=None, axis=0, keepdims=False):
    """NaN-ignoring standard deviation of `x` over `axis`.

    Args:
        x: input tensor, may contain NaNs.
        center: value to measure deviations from; defaults to the NaN-ignoring
            mean of `x` over `axis` (computed with keepdims=True).
        axis: axis or axes to reduce.
        keepdims: whether the reduced axes are kept.
    """
    mu = tf_nan_mean(x, axis=axis, keepdims=True) if center is None else center
    deviation = x - mu
    variance = tf_nan_mean(deviation * deviation, axis=axis, keepdims=keepdims)
    return tf.math.sqrt(variance)

def filter_nans_tf(x, ref_point=POINT_LANDMARKS):
    """Drop the entries along axis 0 whose `ref_point` landmarks are all NaN.

    An entry is removed only when every value of every reference landmark
    (gathered along axis 1) is NaN.
    """
    reference = tf.gather(x, ref_point, axis=1)
    all_missing = tf.reduce_all(tf.math.is_nan(reference), axis=[-2,-1])
    return tf.boolean_mask(x, tf.math.logical_not(all_missing), axis=0)

def is_left_handed(x, left=LHAND, right=RHAND):
    """Return a boolean scalar: True when the `left` landmarks hold fewer NaNs than the `right` ones."""
    def _nan_count(landmarks):
        # total number of NaN values among the selected landmarks (axis 1)
        selected = tf.gather(x, landmarks, axis=1)
        return tf.reduce_sum(tf.cast(tf.math.is_nan(selected), tf.int32))

    return _nan_count(left) < _nan_count(right)

def flip_lr(x):
    """Mirror a landmark tensor left/right.

    The first coordinate of the last axis is replaced by `1 - value`, then the
    landmarks (axis 1) listed in LEFT and RIGHT trade places while the CENTRE
    landmarks stay at their own index.
    """
    xs, ys, zs = tf.unstack(x, axis=-1)
    mirrored = tf.stack([1 - xs, ys, zs], -1)
    # put the landmark axis first so it can be gathered / scattered on axis 0
    by_landmark = tf.transpose(mirrored, [1,0,2])
    out_shape = tf.shape(by_landmark)

    from_left = tf.gather(by_landmark, LEFT, axis=0)
    from_right = tf.gather(by_landmark, RIGHT, axis=0)
    from_centre = tf.gather(by_landmark, CENTRE, axis=0)

    # tf.tensor_scatter_nd_update showed weird behavior in tflite, so three
    # disjoint scatter_nd results are summed instead.
    swapped = (
        tf.scatter_nd(tf.constant(LEFT)[...,None], from_right, out_shape)
        + tf.scatter_nd(tf.constant(RIGHT)[...,None], from_left, out_shape)
        + tf.scatter_nd(tf.constant(CENTRE)[...,None], from_centre, out_shape)
    )
    return tf.transpose(swapped, [1,0,2])

class Preprocess(tf.keras.layers.Layer):
    """Turns one raw landmark sequence into a normalized, flattened model input.

    Steps: drop all-NaN frames, mirror left-handed samples, add a batch axis,
    truncate to `max_len`, normalize with NaN-ignoring mean/std, flatten the
    landmark and coordinate axes, and replace remaining NaNs with 0.
    """

    def __init__(self, max_len=MAX_LEN, point_landmarks=POINT_LANDMARKS, **kwargs):
        super().__init__(**kwargs)
        self.max_len = max_len
        self.point_landmarks = point_landmarks

    def call(self, inputs):
        """Preprocess a single unbatched sequence; returns shape (1, T, 3 * len(point_landmarks))."""
        frames = filter_nans_tf(inputs)
        frames = tf.cond(is_left_handed(frames), lambda: flip_lr(frames), lambda: frames)

        batch = frames[None,...]
        if self.max_len is not None:
            batch = batch[:,:self.max_len]
        length = tf.shape(batch)[1]

        points = tf.gather(batch, self.point_landmarks, axis=2) #N,T,P,C
        mean = tf_nan_mean(points, axis=[1,2], keepdims=True)
        # fall back to a fixed centre when a coordinate has no valid value at all
        mean = tf.where(tf.math.is_nan(mean), tf.constant([0.5,0.5,0.], points.dtype), mean)
        std = tf_nan_std(points, center=mean, axis=[1,2], keepdims=True)
        normalized = (points - mean) / std

        flat = tf.concat([
            tf.reshape(normalized, (-1, length, 3*len(self.point_landmarks))),
        ], axis=-1)

        return tf.where(tf.math.is_nan(flat), tf.constant(0., flat.dtype), flat)

543
1629


In [10]:
def decode_tfrec(record_bytes):
    """Parse one serialized example into {'coordinates', 'phrase'}.

    'coordinates' is decoded from raw float32 bytes and reshaped to
    (-1, 3 * ROWS_PER_FRAME); 'phrase' is passed through as a string tensor.
    'phrase_encoded' is part of the parse spec but is not returned.
    """
    feature_spec = {
        'coordinates': tf.io.FixedLenFeature([], tf.string),
        'phrase_encoded': tf.io.VarLenFeature(dtype=tf.int64),
        'phrase': tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(record_bytes, feature_spec)
    flat_coords = tf.io.decode_raw(parsed['coordinates'], tf.float32)
    return {
        'coordinates': tf.reshape(flat_coords, (-1,3*ROWS_PER_FRAME)),
        'phrase': parsed['phrase'],
    }

# Read the fold-0 shards and keep the first decoded example for inspection.
ds = tf.data.TFRecordDataset(
    [path for path in TRAIN_FILENAMES if 'fold0' in path],
    num_parallel_reads=tf.data.AUTOTUNE,
    compression_type='GZIP',
)
ds = ds.map(decode_tfrec, tf.data.AUTOTUNE)
for x in ds.take(1):
    frames, phrase = x['coordinates'], x['phrase']

In [11]:
class ECA(tf.keras.layers.Layer):
    """Channel-attention gate: rescales each channel by a sigmoid weight.

    The weights come from a bias-free 1-D convolution run across the channel
    axis of the (mask-aware) time-averaged input.
    """

    def __init__(self, kernel_size=5, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True
        self.kernel_size = kernel_size
        self.conv = tf.keras.layers.Conv1D(1, kernel_size=kernel_size, strides=1, padding="same", use_bias=False)

    def call(self, inputs, mask=None):
        # average over time -> one descriptor per channel
        pooled = tf.keras.layers.GlobalAveragePooling1D()(inputs, mask=mask)
        # treat the channel axis as the conv "sequence" axis
        gate = self.conv(tf.expand_dims(pooled, -1))
        gate = tf.nn.sigmoid(tf.squeeze(gate, -1))
        # broadcast the per-channel gate over the time axis
        return inputs * gate[:,None,:]

class LateDropout(tf.keras.layers.Layer):
    """Dropout that is bypassed until the layer has seen `start_step` training calls."""

    def __init__(self, rate, noise_shape=None, start_step=0, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True
        self.rate = rate
        self.start_step = start_step
        self.dropout = tf.keras.layers.Dropout(rate, noise_shape=noise_shape)

    def build(self, input_shape):
        super().build(input_shape)
        # non-trainable counter of training-mode calls
        self._train_counter = tf.Variable(
            0,
            dtype="int64",
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
            trainable=False,
        )

    def call(self, inputs, training=False):
        still_waiting = self._train_counter < self.start_step
        x = tf.cond(
            still_waiting,
            lambda: inputs,
            lambda: self.dropout(inputs, training=training),
        )
        if training:
            self._train_counter.assign_add(1)
        return x

class MaskingConv1D(tf.keras.layers.Layer):
    """Conv1D wrapper that zeroes masked timesteps before convolving.

    Only `padding='same'` is supported. When `strides > 1` the propagated mask
    is subsampled with the same stride.

    Args:
        filters: number of output channels of the wrapped Conv1D.
        kernel_size: convolution window length.
        groups, strides, dilation_rate, padding, use_bias, kernel_initializer:
            forwarded unchanged to tf.keras.layers.Conv1D.
    """

    def __init__(self, filters, kernel_size, groups=1, strides=1,
        dilation_rate=1,
        padding='same',
        use_bias=False,
        kernel_initializer='glorot_uniform',**kwargs):
        super().__init__(**kwargs)
        assert padding == 'same'
        # Store the constructor argument. The previous code referenced an
        # undefined name here, which raised NameError on every instantiation.
        self.filters = filters
        self.strides = strides
        self.groups = groups
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.use_bias = use_bias
        self.padding = padding
        self.conv = tf.keras.layers.Conv1D(
                            filters,
                            kernel_size,
                            strides=strides,
                            groups=groups,
                            dilation_rate=dilation_rate,
                            padding=padding,
                            use_bias=use_bias,
                            kernel_initializer=kernel_initializer)
        self.supports_masking = True

    def compute_mask(self, inputs, mask=None):
        """Return the input mask, subsampled along time when the conv is strided."""
        if mask is not None:
            if self.strides > 1:
                mask = mask[:,::self.strides]
        return mask

    def call(self, inputs, mask=None):
        """Apply the convolution after replacing masked-out timesteps with zeros."""
        x = inputs
        if mask is not None:
            x = tf.where(mask[...,None], x, tf.constant(0., dtype=x.dtype))
        x = self.conv(x)
        return x

class MaskingDWConv1D(tf.keras.layers.Layer):
    """DepthwiseConv1D wrapper that zeroes masked timesteps before convolving.

    Only `padding='same'` is supported. With `strides > 1` the propagated mask
    is subsampled by the same stride.
    """

    def __init__(self, kernel_size, strides=1,
        dilation_rate=1,
        padding='same',
        use_bias=False,
        kernel_initializer='glorot_uniform',**kwargs):
        super().__init__(**kwargs)
        assert padding == 'same'
        self.strides = strides
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.use_bias = use_bias
        self.padding = padding
        conv_options = dict(
            strides=strides,
            dilation_rate=dilation_rate,
            padding=padding,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
        )
        self.conv = tf.keras.layers.DepthwiseConv1D(kernel_size, **conv_options)
        self.supports_masking = True

    def compute_mask(self, inputs, mask=None):
        """Return the input mask, subsampled along time when the conv is strided."""
        if mask is None or self.strides <= 1:
            return mask
        return mask[:,::self.strides]

    def call(self, inputs, mask=None):
        """Apply the depthwise convolution after zeroing masked-out timesteps."""
        if mask is not None:
            inputs = tf.where(mask[...,None], inputs, tf.constant(0., dtype=inputs.dtype))
        return self.conv(inputs)

def Conv1DBlock(channel_size,
          kernel_size,
          dilation_rate=1,
          strides=1,
          drop_rate=0.0,
          expand_ratio=2,
          se_ratio=0.25,
          activation='swish',
          name=None):
    '''
    Efficient conv1d block, @hoyso48.

    Returns a function `apply(inputs)` that builds, in order:
    BatchNorm -> Dense expansion (channels_in * expand_ratio, with activation)
    -> masked depthwise Conv1D -> BatchNorm -> ECA -> Dense projection to
    `channel_size` -> optional Dropout -> residual add.

    The residual connection is only added when the channel count is unchanged
    and `strides == 1`. `se_ratio` is accepted but not used in this body.
    When `name` is None a unique prefix is generated from the "mbblock" uid counter.
    '''
    if name is None:
        name = str(tf.keras.backend.get_uid("mbblock"))
    # Expansion phase
    def apply(inputs):
        channels_in = tf.keras.backend.int_shape(inputs)[-1]
        channels_expand = channels_in * expand_ratio

        skip = inputs

        # NOTE(review): 'pre_bn' / 'conv_bn' are concatenated without an underscore,
        # unlike the other layer names; presumably this must stay as-is to match
        # previously saved models - confirm before renaming.
        x = tf.keras.layers.BatchNormalization(momentum=0.95, name=name + 'pre_bn')(inputs)

        x = tf.keras.layers.Dense(
            channels_expand,
            use_bias=True,
            activation=activation,
            name=name + '_expand_conv')(x)

        # Depthwise Convolution
        x = MaskingDWConv1D(kernel_size,
            dilation_rate=dilation_rate,
            strides=strides,
            use_bias=False,
            name=name + '_dwconv')(x)

        x = tf.keras.layers.BatchNormalization(momentum=0.95, name=name + 'conv_bn')(x)

        # channel attention gate
        x  = ECA()(x)

        # projection back to channel_size (no activation)
        x = tf.keras.layers.Dense(
            channel_size,
            use_bias=True,
            name=name + '_project_conv')(x)

        if drop_rate > 0:
            # noise_shape (None,1,1): one keep/drop decision per sample, i.e. the
            # whole block output of a sample is dropped together
            x = tf.keras.layers.Dropout(drop_rate, noise_shape=(None,1,1), name=name + '_drop')(x)

        if (channels_in == channel_size) and (strides == 1):
            x = tf.keras.layers.add([x, skip], name=name + '_add')
        return x

    return apply

In [12]:
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head attention with separate q/k/v projections and optional cached states.

    Besides the usual (q, k, v) call, previously projected keys/values can be
    passed as `key_state` / `value_state`; new projections are appended to them
    along the time axis, which allows step-by-step decoding.
    """
    def __init__(self, dim=256, num_heads=4, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        # NOTE: the scale uses the full model dim, not the per-head dim.
        self.scale = self.dim ** -0.5
        self.num_heads = num_heads
        # self.qkv = tf.keras.layers.Dense(3 * dim, use_bias=False)
        self.q = tf.keras.layers.Dense(dim, use_bias=False)
        self.k = tf.keras.layers.Dense(dim, use_bias=False)
        self.v = tf.keras.layers.Dense(dim, use_bias=False)
        self.drop1 = tf.keras.layers.Dropout(dropout)
        self.proj = tf.keras.layers.Dense(dim, use_bias=False)
        self.supports_masking = True

    def get_causal_mask(self, q, k):
        """Boolean (q_len, k_len) mask that is True where query index >= key index."""
        q_len = tf.shape(q)[1]
        k_len = tf.shape(k)[1]
        i = tf.range(q_len)[:, None]
        j = tf.range(k_len)
        mask = i >= j
        mask = tf.reshape(mask, (q_len, k_len))
        return mask

    def merge_input_state(self, input, state, layer):
        """Combine a cached projection with the projection of a new input.

        - both given: concatenate `state` and `layer(input)` along axis 1
        - only input: return `layer(input)`
        - only state: return `state` unchanged (no projection is applied)
        - neither: raise ValueError
        """
        if input is not None and state is not None:
            return tf.keras.layers.Concatenate(axis=1)([state, layer(input)])
        elif input is not None and state is None:
            return layer(input)
        elif input is None and state is not None:
            return state
        else:
            raise ValueError
        # return out

    def call(self, q, k=None, v=None, key_state=None, value_state=None, return_states=False, use_causal_mask=False):
        """Attend from `q` over the (possibly state-extended) keys and values.

        Returns the projected attention output, or `(output, k, v)` with the
        merged projected keys/values when `return_states` is True.
        """
        q = self.q(q)
        k = self.merge_input_state(k, key_state, self.k)
        v = self.merge_input_state(v, value_state, self.v)
        # padding mask is taken from the Keras mask attached to the keys, if any
        mask = getattr(k, '_keras_mask', None)
        if mask is not None:
            mask = mask[:,None,None,:]
        if use_causal_mask:
            if mask is not None:
                mask = tf.logical_and(mask, self.get_causal_mask(q,k)[None,None,:,:])
            else:
                mask = self.get_causal_mask(q,k)[None,None,:,:]
        # split the channel axis into heads and move the head axis before time
        q_ = tf.keras.layers.Permute((2, 1, 3))(tf.keras.layers.Reshape((-1, self.num_heads, self.dim // self.num_heads))(q))
        k_ = tf.keras.layers.Permute((2, 1, 3))(tf.keras.layers.Reshape((-1, self.num_heads, self.dim // self.num_heads))(k))
        v_ = tf.keras.layers.Permute((2, 1, 3))(tf.keras.layers.Reshape((-1, self.num_heads, self.dim // self.num_heads))(v))
        attn = tf.matmul(q_, k_, transpose_b=True) * self.scale

        attn = tf.keras.layers.Softmax(axis=-1)(attn, mask=mask)
        attn = self.drop1(attn)

        x = attn @ v_
        # merge the heads back into a single channel axis
        x = tf.keras.layers.Reshape((-1, self.dim))(tf.keras.layers.Permute((2, 1, 3))(x))
        x = self.proj(x)
        if return_states:
            return x, k, v
        else:
            return x

class PosEmbedding(tf.keras.layers.Layer):
    """Adds a learned position embedding (table of `max_len` rows) to its input."""

    def __init__(self, dim=64, max_len=64, **kwargs):
        super().__init__(**kwargs)
        self.pos_emb = tf.keras.layers.Embedding(input_dim=max_len, output_dim=dim)
        self.supports_masking = True

    def call(self, x, positions=None):
        """Return `x` plus the embeddings of `positions` (default: 0 .. length-1 along axis 1)."""
        if positions is None:
            seq_len = tf.shape(x)[1]
            positions = tf.range(start=0, limit=seq_len, delta=1)
        return x + self.pos_emb(positions)

def TransformerDecoderBlock(dim=256, num_heads=4, expand=4, attn_dropout=0.2, drop_rate=0.2, activation='swish', name=''):
    """Build a pre-norm decoder block: causal self-attention, cross-attention, feed-forward.

    Returns `apply(q, k, v)` where `q` is the decoder stream and `k`/`v` are the
    inputs to cross-attention. Every sub-layer is followed by a residual add.
    Layers are named `name + '_bn1'`, `'_self_attn'`, `'_add1'`, ... so that they
    can be fetched individually with `get_layer` (ATTGreedyDecoder does this).
    `drop_rate` is currently unused: the residual dropouts are commented out.
    """
    def apply(q,k,v):
        x = q
        # key_mask=None
        # 1) causal self-attention over the decoder stream
        x = tf.keras.layers.BatchNormalization(momentum=0.95, name=name + '_bn1')(x)
        x = MultiHeadAttention(dim=dim,num_heads=num_heads,dropout=attn_dropout, name=name + '_self_attn')(x,x,x,use_causal_mask=True)
        # print(x.shape, q.shape)
        # x = tf.keras.layers.Dropout(drop_rate, noise_shape=(None,1,1))(x)
        x = tf.keras.layers.Add(name=name + '_add1')([q, x])
        attn_out1 = x

        # 2) cross-attention: decoder stream attends over k / v
        x = tf.keras.layers.BatchNormalization(momentum=0.95, name=name + '_bn2')(x)
        x = MultiHeadAttention(dim=dim,num_heads=num_heads,dropout=attn_dropout, name=name + '_cross_attn')(x,k,v)
        # x = tf.keras.layers.Dropout(drop_rate, noise_shape=(None,1,1))(x)
        x = tf.keras.layers.Add(name=name + '_add2')([attn_out1, x])
        attn_out2 = x

        # 3) position-wise feed-forward (dim -> dim*expand -> dim, no biases)
        x = tf.keras.layers.BatchNormalization(momentum=0.95, name=name + '_bn3')(x)
        x = tf.keras.layers.Dense(dim*expand, use_bias=False, activation=activation, name=name + '_fc1')(x)
        x = tf.keras.layers.Dense(dim, use_bias=False, name=name + '_fc2')(x)
        # x = tf.keras.layers.Dropout(drop_rate, noise_shape=(None,1,1))(x)
        x = tf.keras.layers.Add(name=name + '_add3')([attn_out2, x])
        return x
    return apply

In [13]:
class MultiHeadSelfAttention(tf.keras.layers.Layer):
    """Multi-head self-attention using one fused q/k/v projection.

    Attention logits are scaled by `dim ** -0.5` (the full model dim).
    """

    def __init__(self, dim=256, num_heads=4, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.scale = self.dim ** -0.5
        self.num_heads = num_heads
        self.qkv = tf.keras.layers.Dense(3 * dim, use_bias=False)
        self.drop1 = tf.keras.layers.Dropout(dropout)
        self.proj = tf.keras.layers.Dense(dim, use_bias=False)
        self.supports_masking = True

    def call(self, inputs, mask=None):
        """Self-attend over `inputs`; `mask` (if given) hides key positions."""
        head_dim = self.dim // self.num_heads

        # fused projection, then (batch, heads, time, 3 * head_dim)
        fused = self.qkv(inputs)
        fused = tf.keras.layers.Reshape((-1, self.num_heads, self.dim * 3 // self.num_heads))(fused)
        fused = tf.keras.layers.Permute((2, 1, 3))(fused)
        q, k, v = tf.split(fused, [head_dim] * 3, axis=-1)

        scores = tf.matmul(q, k, transpose_b=True) * self.scale

        # broadcast the (batch, time) mask over heads and query positions
        key_mask = None if mask is None else mask[:, None, None, :]
        weights = tf.keras.layers.Softmax(axis=-1)(scores, mask=key_mask)
        weights = self.drop1(weights)

        context = weights @ v
        context = tf.keras.layers.Permute((2, 1, 3))(context)
        context = tf.keras.layers.Reshape((-1, self.dim))(context)
        return self.proj(context)

def TransformerBlock(dim=256, num_heads=4, expand=4, attn_dropout=0.2, drop_rate=0.2, activation='swish'):
    """Build a pre-norm encoder block: self-attention then feed-forward, each with a residual add.

    Returns `apply(inputs)`. Both residual branches end with a Dropout whose
    noise shape (None,1,1) makes a single keep/drop decision per sample.
    """
    def apply(inputs):
        # --- self-attention branch ---
        attn = tf.keras.layers.BatchNormalization(momentum=0.95)(inputs)
        attn = MultiHeadSelfAttention(dim=dim,num_heads=num_heads,dropout=attn_dropout)(attn)
        attn = tf.keras.layers.Dropout(drop_rate, noise_shape=(None,1,1))(attn)
        attn_out = tf.keras.layers.Add()([inputs, attn])

        # --- feed-forward branch (dim -> dim*expand -> dim, no biases) ---
        ffn = tf.keras.layers.BatchNormalization(momentum=0.95)(attn_out)
        ffn = tf.keras.layers.Dense(dim*expand, use_bias=False, activation=activation)(ffn)
        ffn = tf.keras.layers.Dense(dim, use_bias=False)(ffn)
        ffn = tf.keras.layers.Dropout(drop_rate, noise_shape=(None,1,1))(ffn)
        return tf.keras.layers.Add()([attn_out, ffn])
    return apply

In [14]:
def get_model(max_len=MAX_LEN, target_len=64, dim=192, dtype='float32'):
    """Build the joint CTC + attention model.

    Three sub-models are created and then wired together:
      * 'encoder'     - stem Dense followed by Conv1DBlock / TransformerBlock stages,
                        with one stride-2 Conv1DBlock that halves the time axis
      * 'ctc_decoder' - GRU -> Dense -> Dropout -> per-frame classifier
      * 'att_decoder' - token + position embedding -> one TransformerDecoderBlock
                        -> Dropout -> classifier

    Args:
        max_len: number of input frames.
        target_len: size of the decoder position-embedding table.
        dim: model width.
        dtype: dtype of the float inputs.

    Returns:
        tf.keras.Model with inputs [frames, decoder tokens] and outputs
        [attention-decoder logits, CTC logits].
    """
    ################# ENCODER #################
    inp1 = tf.keras.Input((max_len,CHANNELS),dtype=dtype)
#     x = tf.keras.layers.Masking(mask_value=PAD,input_shape=(max_len,CHANNELS))(inp1)
    x = inp1
    ksize = 17
    drop_rate = 0.2
    x = tf.keras.layers.Dense(dim,use_bias=False,name='stem_conv')(x)
    x = Conv1DBlock(dim,ksize,expand_ratio=4,drop_rate=drop_rate)(x)
    x = Conv1DBlock(dim,ksize,expand_ratio=4,drop_rate=drop_rate)(x)
    x = Conv1DBlock(dim,ksize,expand_ratio=4,drop_rate=drop_rate)(x)
    x = TransformerBlock(dim,expand=2,num_heads=4,drop_rate=drop_rate,attn_dropout=0.2)(x)
    x = Conv1DBlock(dim,ksize,expand_ratio=4,drop_rate=drop_rate)(x)
    x = Conv1DBlock(dim,ksize,expand_ratio=4,drop_rate=drop_rate)(x)
    x = Conv1DBlock(dim,ksize,expand_ratio=4,drop_rate=drop_rate)(x)
    x = TransformerBlock(dim,expand=2,num_heads=4,drop_rate=drop_rate,attn_dropout=0.2)(x)
    # strided block has no residual path, so dropping it would zero the whole output
    x = Conv1DBlock(dim,ksize,expand_ratio=4,drop_rate=0,strides=2)(x) #drop_rate=0 since we don't want to drop the whole output
    x = Conv1DBlock(dim,ksize,expand_ratio=4,drop_rate=drop_rate)(x)
    x = Conv1DBlock(dim,ksize,expand_ratio=4,drop_rate=drop_rate)(x)
    x = Conv1DBlock(dim,ksize,expand_ratio=4,drop_rate=drop_rate)(x)
    x = TransformerBlock(dim,expand=2,num_heads=4,drop_rate=drop_rate,attn_dropout=0.2)(x)
    x = Conv1DBlock(dim,ksize,expand_ratio=4,drop_rate=drop_rate)(x)
    x = Conv1DBlock(dim,ksize,expand_ratio=4,drop_rate=drop_rate)(x)
    x = Conv1DBlock(dim,ksize,expand_ratio=4,drop_rate=drop_rate)(x)
    x = TransformerBlock(dim,expand=2,num_heads=4,drop_rate=drop_rate,attn_dropout=0.2)(x)
    x = tf.keras.layers.BatchNormalization(momentum=0.95)(x)

    encoder = tf.keras.Model(inp1,x,name='encoder')

    ################# CTC DECODER #################
    # x.shape[1] is the encoder's output length (time axis after the stride-2 block)
    inp3 = tf.keras.Input((x.shape[1],dim),name='ctc_decoder_inp2',dtype=dtype)
    x = inp3
    x = tf.keras.layers.RNN(tf.keras.layers.GRUCell(dim), return_sequences=True)(x)
    x = tf.keras.layers.Dense(dim*2)(x)
    x = tf.keras.layers.Dropout(0.5)(x)
    x = tf.keras.layers.Dense(NUM_CLASSES,name='ctc_classifier')(x) #include sos, eos token
    ctc_decoder = tf.keras.Model(inp3,x,name='ctc_decoder')

    ################# ATT DECODER #################
    inp2 = tf.keras.Input((None,),name='att_decoder_inp1',dtype='int32')
    # here x is the CTC classifier output; its time axis equals the encoder output length
    inp3 = tf.keras.Input((x.shape[1],dim),name='att_decoder_inp2',dtype=dtype)

    x = inp3
#     y = tf.keras.layers.Masking(mask_value=0,input_shape=(None,),name='att_decoder_input_masking')(inp2)
    y = inp2
    y = tf.keras.layers.Embedding(NUM_CLASSES,dim,name='att_decoder_token_emb')(y) #include sos token
    y = PosEmbedding(dim,max_len=target_len,name='att_decoder_pos_emb')(y)
    y = TransformerDecoderBlock(dim,expand=2,num_heads=4,attn_dropout=0.2,name='att_decoder_block1')(y,x,x)
    y = tf.keras.layers.Dropout(0.5)(y)
    y = tf.keras.layers.Dense(NUM_CLASSES,name='att_decoder_classifier')(y)

    decoder = tf.keras.Model([inp2,inp3],y,name='att_decoder')

    ################### MODEL #####################
    # fresh inputs for the combined model; both decoders share the encoder output
    inp1 = tf.keras.Input((max_len,CHANNELS),dtype=dtype)
    inp2 = tf.keras.Input((None,),dtype='int32')

    x = inp1
    enc_out = encoder(x)
    y = inp2
    dec_out = decoder([y, enc_out])
    ctc_out = ctc_decoder(enc_out)
    model = tf.keras.Model([inp1,inp2], [dec_out,ctc_out])

    return model

# Build one instance with the default configuration to inspect its structure.
model = get_model()
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 768, 1629)]  0           []                               
                                                                                                  
 input_3 (InputLayer)           [(None, None)]       0           []                               
                                                                                                  
 encoder (Functional)           (None, 384, 192)     5565377     ['input_2[0][0]']                
                                                                                                  
 att_decoder (Functional)       (None, None, 62)     480830      ['input_3[0][0]',                
                                                                  'encoder[0][0]']            

In [15]:
# One freshly built model per weight file in MODEL_PATH (the ensemble members).
model_list = [get_model() for _ in MODEL_PATH]

# NOTE: this loop rebinds the notebook-level name `model`; afterwards it refers
# to the last ensemble member, not to the model built in the previous cell.
for model, path in zip(model_list, MODEL_PATH):
    model.load_weights(path)

In [16]:
class CTCGreedyDecoder(tf.keras.layers.Layer):
    """Greedy CTC decoding using the 'encoder' and 'ctc_decoder' sub-models of `model`."""

    def __init__(self, model, pad_token_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.encoder = model.get_layer('encoder')
        self.ctc_decoder = model.get_layer('ctc_decoder')
        self.pad_token_idx = pad_token_idx

    def decode_phrase(self, pred):
        """Collapse per-frame predictions (frames, classes) into a token sequence.

        Takes the argmax per frame, keeps one token per run of equal
        predictions, and removes `pad_token_idx` tokens.

        NOTE(review): a token is emitted at the frame where the prediction
        changes, so the final run is never emitted. That is harmless if the
        sequence ends in pad/blank frames - confirm that this always holds.
        """
        frame_tokens = tf.argmax(pred, axis=1, output_type=tf.int32)
        changes = tf.not_equal(frame_tokens[:-1], frame_tokens[1:])
        run_end_indices = tf.where(changes)[:, 0]
        collapsed = tf.gather(frame_tokens, run_end_indices)
        keep = collapsed != self.pad_token_idx
        return tf.boolean_mask(collapsed, keep, axis=0)

    def call(self, batch_x):
        """Decode the first sample of `batch_x`; the result has a leading axis of size 1."""
        ctc_probs = self.ctc_decoder(self.encoder(batch_x))
        first_sample = self.decode_phrase(ctc_probs[0])
        return tf.identity([first_sample])

In [17]:
class ATTGreedyDecoder(tf.keras.layers.Layer):
    """Greedy autoregressive decoding with the attention decoder of `model`.

    The decoder block is re-assembled layer by layer (via `get_layer`) so that
    each step only processes the newest token and reuses cached self-attention
    keys/values plus the pre-projected encoder keys/values.
    """
    def __init__(self, model, max_output_length=64, input_strides=2, sos_token_idx=60, eos_token_idx=61, pad_token_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.model = model
        self.encoder = self.model.get_layer('encoder')
        self.decoder = self.model.get_layer('att_decoder')
        self.max_output_length = max_output_length
        self.sos_token_idx = sos_token_idx
        self.eos_token_idx = eos_token_idx
        self.pad_token_idx = pad_token_idx
        self.input_strides = input_strides

    def att_inference_module(self, query, query_position, key_state, value_state, encoder_key_state, encoder_value_state):
        """Run one decoding step.

        Args:
            query: newest token ids (the caller passes `predictions[:, -1:]`).
            query_position: position index used for the position embedding.
            key_state, value_state: cached self-attention keys/values, or None on the first step.
            encoder_key_state, encoder_value_state: cross-attention keys/values,
                already projected from the encoder output.

        Returns:
            (logits, updated key_state, updated value_state).
        """
        x = self.decoder.get_layer('att_decoder_inp1')(query)
        x = self.decoder.get_layer('att_decoder_token_emb')(x)
        x = self.decoder.get_layer('att_decoder_pos_emb')(x, positions=query_position)

        q = x
        x = self.decoder.get_layer('att_decoder_block1_bn1')(x)
        # no causal mask needed: the cached states only contain earlier positions
        x, k, v = self.decoder.get_layer('att_decoder_block1_self_attn')(x, x, x, key_state=key_state, value_state=value_state, return_states=True)
        x = self.decoder.get_layer('att_decoder_block1_add1')([q,x])
        attn_out1 = x

        x = self.decoder.get_layer('att_decoder_block1_bn2')(x)
        # k / v inputs are None: the pre-projected encoder states are used as-is
        x = self.decoder.get_layer('att_decoder_block1_cross_attn')(x, None, None, key_state=encoder_key_state, value_state=encoder_value_state)
        x = self.decoder.get_layer('att_decoder_block1_add2')([attn_out1,x])
        attn_out2 = x

        x = self.decoder.get_layer('att_decoder_block1_bn3')(x)
        x = self.decoder.get_layer('att_decoder_block1_fc1')(x)
        x = self.decoder.get_layer('att_decoder_block1_fc2')(x)
        x = self.decoder.get_layer('att_decoder_block1_add3')([attn_out2,x])
        out = self.decoder.get_layer('att_decoder_classifier')(x)
        return out, k, v

    def compute_input_length(self, batch_x):
        """Return ceil(number of input frames / input_strides) as an int32 scalar."""
        input_length = tf.cast(tf.shape(batch_x)[1], tf.float32)
        input_length = tf.math.ceil(input_length/self.input_strides)
        return tf.cast(input_length, tf.int32)

    def call(self, batch_x):
        """Greedily decode `batch_x`; returns int32 token ids starting with the SOS token."""
        encoder_out = self.encoder(batch_x)
        input_length = self.compute_input_length(batch_x)
        # project the encoder output once; reused by every decoding step
        encoder_key_state = self.decoder.get_layer('att_decoder_block1_cross_attn').k(encoder_out)
        encoder_value_state = self.decoder.get_layer('att_decoder_block1_cross_attn').v(encoder_out)

        # NOTE: this local name shadows the imported `time` module inside this method.
        time = tf.constant(0, dtype=tf.int32)
        predictions = tf.ones((tf.shape(batch_x)[0],1), dtype=tf.int32) * self.sos_token_idx
        pad = tf.ones((tf.shape(batch_x)[0],), dtype=tf.int32) * self.pad_token_idx
        init = True
        # empty placeholders for the cached states
        # NOTE(review): 192 is hard-coded here and in shape_invariants below;
        # presumably it must equal the model's `dim` - confirm.
        key_state = tf.zeros((0,0,192))
        value_state = tf.zeros((0,0,192))

        def condition(_time, predictions, key_state, value_state, init):
            # continue while: below max_output_length, not every row has produced EOS,
            # and the step count is below the (strided) input length
            return tf.logical_and(tf.logical_and(_time < self.max_output_length, tf.logical_not(tf.reduce_all(tf.reduce_any(predictions==self.eos_token_idx, axis=1)))), tf.reduce_any(_time < input_length))

        def body(_time, predictions, key_state, value_state, init):
            if init:
                # first step: no cached self-attention states yet
                out, key_state, value_state = self.att_inference_module(predictions[:,-1:], _time, None, None, encoder_key_state, encoder_value_state)
                init = False
            else:
                out, key_state, value_state = self.att_inference_module(predictions[:,-1:],  _time, key_state, value_state, encoder_key_state, encoder_value_state)
            # rows that already emitted EOS (or ran past the input length) get the pad token
            pred_curr = tf.where(tf.logical_or(tf.reduce_any(predictions==self.eos_token_idx, axis=1), _time >= input_length), pad, tf.argmax(out[:,-1], axis=-1, output_type=tf.int32))
            predictions = tf.concat([predictions, pred_curr[...,None]], axis=1)
            return _time+1, predictions, key_state, value_state, init

        # predictions and the cached states grow along the time axis each step,
        # hence the relaxed shape invariants
        _, predictions, _, _, _ = tf.while_loop(condition, body,
                                                shape_invariants=[tf.TensorShape([]),
                                                                  tf.TensorShape([None,None]),
                                                                  tf.TensorShape([None,None,192]),
                                                                  tf.TensorShape([None,None,192]),
                                                                  tf.TensorShape([])],
                                                loop_vars=[time, predictions, key_state, value_state, init])
        return predictions

In [18]:
def get_ctc_initial_states(log_probs, blank_idx=0):
    """Build the CTC forward states used before any token has been decoded.

    Args:
        log_probs: CTC log-probabilities whose last axis is the vocabulary
            (the ensemble decoder in this file passes a frames x models x vocab
            tensor).
        blank_idx: position of the CTC blank symbol on the last axis.

    Returns:
        A pair ``(states_n, states_b)`` with the shape of ``log_probs`` minus its
        last axis:
        * ``states_n`` is filled with ``tf.float32.min`` (stand-in for log 0).
        * ``states_b`` is the running sum of the blank log-probabilities along
          axis 0.
    """
    blank_log_probs = log_probs[..., blank_idx]

    # Blank-only paths: accumulate the blank log-probabilities frame by frame.
    blank_state = tf.math.cumsum(blank_log_probs, axis=0)

    # Nothing can end in a non-blank symbol yet, so start at "log zero".
    non_blank_state = tf.ones_like(blank_log_probs, dtype=tf.float32) * tf.float32.min

    return non_blank_state, blank_state

def compute_ctc_prefix_scores(beams, log_probs, states_n, states_b, eos_idx=61, blank_idx=0):
    """Score every one-token extension of a hypothesis under CTC and advance the CTC states.

    Args:
        beams: (N,) token ids of the current hypothesis. The caller in this file
            passes ``predictions[0]``, which starts with the SOS token, so
            ``N == 1`` corresponds to "nothing emitted yet".
        log_probs: (L0, M, V) CTC log-probabilities
            (L0=(padded/strided) input length, M=num_models, V=vocab_size).
        states_n: (L, M) state for the current hypothesis; presumably the
            log-probability of paths ending in a non-blank symbol, as in
            standard CTC prefix scoring -- TODO confirm.
        states_b: (L, M) companion state; presumably paths ending in blank.
        eos_idx: vocabulary index whose score is overwritten with the
            end-of-sequence score.
        blank_idx: vocabulary index of the CTC blank; its score is forced to
            ``tf.float32.min`` so it is never selected.

    Returns:
        log_psi: (M, V) prefix log-score of appending each vocabulary entry.
        new_states_n, new_states_b: states for every candidate token, produced
            by a ``tf.scan`` over ``L - 1`` steps, i.e. shape (L-1, M, V). The
            caller picks the slice of the chosen token, so the states get one
            frame shorter on every call and the number of decoding steps must
            stay below the number of frames.

    NOTE(review): the probability window starts at frame ``N`` on every call
    (``start`` is fixed to 1). Reference implementations of CTC prefix scoring
    start at ``max(len(prefix), 1)`` with the prefix length excluding SOS --
    confirm the one-frame offset here is intended.
    """
    # beams: (N=hypothesis_length)
    # probs: (L=(padded/strided)input_length,M=num_models,V=vocab_size)
    # states_n: (L,M)
    # states_b: (L,M)

    N = tf.shape(beams)[0]
    L = tf.shape(states_n)[0]
    V = tf.shape(log_probs)[-1]
    M = tf.shape(states_n)[1]
    # Candidate states start at "log zero" (tf.float32.min) ...
    new_states_n = tf.ones((L,M,V), dtype=tf.float32) * tf.float32.min
    new_states_b = tf.ones((L,M,V), dtype=tf.float32) * tf.float32.min
    # ... except on the first call (only SOS so far), where the non-blank state
    # is seeded with the raw log-probabilities.
    new_states_n = tf.cond(N==1, lambda:log_probs, lambda:new_states_n)

    # Log-sum of both incoming states per frame and model.
    r_sum = tf.math.reduce_logsumexp([states_n, states_b], axis=0) #(L,M)
    last = beams[-1] # scalar: last token of the hypothesis

    repeated_idx = last

    # log_phi: what each candidate token may extend from. Every token uses
    # r_sum, except a repeat of the last token, which only uses states_b.
    log_phi_ = tf.repeat(r_sum[None,...], repeats=V, axis=0) #(V,L,M)
    log_phi = tf.tensor_scatter_nd_update(log_phi_, [[repeated_idx]], [states_b])
    log_phi = tf.transpose(log_phi, (1,2,0)) #(L,M,V)

    # On the first call there is no previous character, so the repeat rule is skipped.
    log_phi = tf.cond(N==1, lambda:tf.transpose(log_phi_, (1,2,0)), lambda:log_phi)

    def step_function(prev, inputs):
        # One frame of the forward recursion for all candidate tokens at once.
        prev_r_n, prev_r_b = prev
        current_log_phi, current_prob = inputs
        # Non-blank state: stay on the token or enter it from log_phi, then emit the token.
        updated_r_n = tf.math.reduce_logsumexp([prev_r_n, current_log_phi], axis=0) + current_prob
        # Blank state: come from either state, then emit blank (broadcast over V).
        updated_r_b = tf.math.reduce_logsumexp([prev_r_b, prev_r_n], axis=0) + current_prob[...,blank_idx][...,None]
        return updated_r_n, updated_r_b

    start = 1
    log_psi = new_states_n[start-1]

    # log_phi covers frames 0..L-2 of the current (shrunken) state window; the
    # matching probabilities are shifted by N frames into the full log_probs.
    sequence_log_phi = log_phi[start-1:L-1]
    sequence_probs = log_probs[start-1+N:L-1+N]
    sequences = (sequence_log_phi, sequence_probs) #((L-start,M,V), (L-start,M,V))

    initial_state = (new_states_n[start-1], new_states_b[start-1]) #((M,V),(M,V))

    # Prefix score: log-sum over frames of (log_phi + emission), combined with the seed value.
    log_psi = tf.math.reduce_logsumexp([tf.math.reduce_logsumexp(sequence_log_phi + sequence_probs, axis=0), log_psi], axis=0)

    new_states_n, new_states_b = tf.scan(step_function, sequences, initial_state)

    # EOS score is the total state mass at the last frame; blank is never a valid next token.
    log_psi_eos = r_sum[-1]
    model_idx = tf.range(M)
    eos_idxs = tf.stack([model_idx, tf.fill((M,), eos_idx)], axis=-1)
    blank_idxs = tf.stack([model_idx, tf.fill((M,), blank_idx)], axis=-1)
    log_psi = tf.tensor_scatter_nd_update(log_psi, eos_idxs, log_psi_eos)
    log_psi = tf.tensor_scatter_nd_update(log_psi, blank_idxs, tf.fill((M,), tf.float32.min))

    return log_psi, new_states_n, new_states_b #(M,V), (L-1,M,V), (L-1,M,V)


class EnsembleCTCAttentionJointGreedyDecoder(tf.keras.layers.Layer):
    """Greedy decoder mixing attention-decoder and CTC prefix scores over an ensemble.

    At every step the per-model attention log-probabilities and the per-model
    CTC prefix log-scores are averaged over the models and combined as
    ``(1 - ctc_weight) * log_att + ctc_weight * log_ctc``; the arg-max token is
    appended to the hypothesis.

    Only a single sequence per call is supported: the hypothesis is read as
    ``predictions[0]``, the attention output as ``out[:, 0, 0]`` and the loop
    invariant pins the batch axis of ``predictions`` to 1.
    """

    def __init__(self, model_list, ctc_weight=0.2, input_strides=2, max_output_length=64, blank_idx=0, pad_frame_idx=-100, sos_token_idx=60, eos_token_idx=61, pad_token_idx=0, from_logits=True, **kwargs):
        """Collect the sub-layers of every ensemble member and store decoding options.

        Args:
            model_list: models exposing layers named 'encoder', 'att_decoder'
                and 'ctc_decoder'.
            ctc_weight: weight of the CTC score in the joint score; the
                attention score gets ``1 - ctc_weight``.
            input_strides: divisor applied to the frame count in
                ``compute_input_length``.
            max_output_length: upper bound on the number of decoding steps.
            blank_idx: vocabulary index of the CTC blank.
            pad_frame_idx: stored for reference; not read inside this class.
            sos_token_idx: token that starts every hypothesis.
            eos_token_idx: token that ends decoding.
            pad_token_idx: stored for reference; not read inside this class.
            from_logits: when True the CTC/attention heads are treated as
                logits and normalized with log-softmax; otherwise they are
                treated as probabilities and only the log is taken.
        """
        super().__init__(**kwargs)
        self.encoder_list = [m.get_layer('encoder') for m in model_list]
        self.decoder_list = [m.get_layer('att_decoder') for m in model_list]
        self.ctc_decoder_list = [m.get_layer('ctc_decoder') for m in model_list]
        self.ctc_weight = ctc_weight
        self.input_strides = input_strides
        self.max_output_length = max_output_length
        self.blank_idx = blank_idx
        self.pad_frame_idx = pad_frame_idx
        self.sos_token_idx = sos_token_idx
        self.eos_token_idx = eos_token_idx
        self.pad_token_idx = pad_token_idx
        self.from_logits = from_logits

    def compute_input_length(self, batch_x):
        """Return ceil(number_of_frames / input_strides) as an int32 scalar.

        The frame count is taken from axis 1 of ``batch_x``.
        """
        input_length = tf.cast(tf.shape(batch_x)[1], tf.float32)#tf.reduce_sum(tf.cast(mask, tf.float32), axis=-1)
        input_length = tf.math.ceil(input_length/self.input_strides)
        return tf.cast(input_length, tf.int32)

    def att_inference_module(self, query, query_position, key_state_list, value_state_list, encoder_key_state_list, encoder_value_state_list):
        """Run one decoding step of every attention decoder in the ensemble.

        Args:
            query: token ids fed to the decoders for this step.
            query_position: position passed to the positional embedding.
            key_state_list / value_state_list: per-model self-attention caches
                indexed on their first axis, or None on the first step.
            encoder_key_state_list / encoder_value_state_list: per-model
                cross-attention keys/values indexed on their first axis, or None.

        Returns:
            ``(outputs, key_states, value_states)``: the per-model classifier
            outputs and updated self-attention caches, each converted from a
            Python list into a single tensor with the model axis first.
        """
        outputs = []
        key_states = []
        value_states = []
        for i in range(len(self.decoder_list)):
            decoder = self.decoder_list[i]
            key_state = key_state_list[i] if key_state_list is not None else None
            value_state = value_state_list[i] if value_state_list is not None else None
            encoder_key_state = encoder_key_state_list[i] if encoder_key_state_list is not None else None
            encoder_value_state = encoder_value_state_list[i] if encoder_value_state_list is not None else None
            x = decoder.get_layer('att_decoder_inp1')(query)
            x = decoder.get_layer('att_decoder_token_emb')(x)
            x = decoder.get_layer('att_decoder_pos_emb')(x, positions=query_position)

            # Self-attention sub-block with residual connection; also returns the updated cache.
            q = x
            x = decoder.get_layer('att_decoder_block1_bn1')(x)
            x, k, v = decoder.get_layer('att_decoder_block1_self_attn')(x, x, x, key_state=key_state, value_state=value_state, return_states=True)
            x = decoder.get_layer('att_decoder_block1_add1')([q,x])
            attn_out1 = x

            # Cross-attention over the pre-computed encoder keys/values.
            x = decoder.get_layer('att_decoder_block1_bn2')(x)
            x = decoder.get_layer('att_decoder_block1_cross_attn')(x, None, None, key_state=encoder_key_state, value_state=encoder_value_state)
            x = decoder.get_layer('att_decoder_block1_add2')([attn_out1,x])
            attn_out2 = x

            # Feed-forward sub-block followed by the classifier head.
            x = decoder.get_layer('att_decoder_block1_bn3')(x)
            x = decoder.get_layer('att_decoder_block1_fc1')(x)
            x = decoder.get_layer('att_decoder_block1_fc2')(x)
            x = decoder.get_layer('att_decoder_block1_add3')([attn_out2,x])
            out = decoder.get_layer('att_decoder_classifier')(x)
            outputs.append(out)
            key_states.append(k)
            value_states.append(v)
        return tf.identity(outputs), tf.identity(key_states), tf.identity(value_states)

    def get_initial_states(self, batch_x):
        """Encode the input once and prepare everything the decoding loop needs.

        Returns:
            ``(encoder_key_states, encoder_value_states, key_states,
            value_states, ctc_log_probs, ctc_states_n, ctc_states_b)`` where the
            first four are stacked with the model axis first, ``ctc_log_probs``
            is stacked with the model axis second, and the CTC states come from
            ``get_ctc_initial_states``.
        """
        encoder_outputs = [enc(batch_x) for enc in self.encoder_list]
        encoder_key_states = [dec.get_layer('att_decoder_block1_cross_attn').k(x) for dec, x in zip(self.decoder_list, encoder_outputs)]
        encoder_value_states = [dec.get_layer('att_decoder_block1_cross_attn').v(x) for dec, x in zip(self.decoder_list, encoder_outputs)]
        # Zero-sized placeholders so the while_loop has tensors for the
        # self-attention caches; the first decoding step passes None instead of
        # them. The 192 is hard-coded -- presumably the decoder width; TODO
        # confirm or derive it from the model.
        key_states = [tf.zeros((0,0,192)) for _ in self.encoder_list]
        value_states = [tf.zeros((0,0,192)) for _ in self.encoder_list]
        # First sequence of each CTC head's output (batch of one).
        ctc_outputs = [dec(x)[0] for dec,x in zip(self.ctc_decoder_list, encoder_outputs)]

        encoder_key_states = tf.stack(encoder_key_states)
        encoder_value_states = tf.stack(encoder_value_states)
        key_states = tf.stack(key_states)
        value_states = tf.stack(value_states)

        if self.from_logits:
            # log_softmax is the numerically stable form of log(softmax(x)):
            # it does not underflow to -inf for very unlikely symbols.
            ctc_log_probs = [tf.nn.log_softmax(x, axis=-1) for x in ctc_outputs]
        else:
            ctc_log_probs = [tf.math.log(x) for x in ctc_outputs]
        ctc_log_probs = tf.stack(ctc_log_probs, axis=1)
        ctc_states_n, ctc_states_b = get_ctc_initial_states(ctc_log_probs, self.blank_idx)
        return encoder_key_states, encoder_value_states, key_states, value_states, ctc_log_probs, ctc_states_n, ctc_states_b

    def call(self, batch_x):
        """Greedily decode ``batch_x`` (batch of one) and return the token ids.

        Returns:
            int32 tensor of shape (1, steps + 1): the SOS token followed by the
            tokens chosen at each step. Decoding stops once EOS has been
            emitted or after ``min(max_output_length, input_length)`` steps.
        """
        encoder_key_state, encoder_value_state, key_state, value_state, ctc_log_probs, ctc_states_n, ctc_states_b = self.get_initial_states(batch_x)
        input_length = self.compute_input_length(batch_x)

        start_step = tf.constant(0, dtype=tf.int32)
        predictions = tf.ones((tf.shape(batch_x)[0],1), dtype=tf.int32) * self.sos_token_idx#tf.TensorArray(dtype=tf.int32,size=self.max_output_length)
        init = True

        def condition(_time, predictions, ctc_states_n, ctc_states_b, key_state, value_state, init):
            # Continue while under the step budget and not every row contains EOS.
            return tf.logical_and(_time < tf.minimum(self.max_output_length, input_length), tf.logical_not(tf.reduce_all(tf.reduce_any(predictions==self.eos_token_idx, axis=1))))

        def body(_time, predictions, ctc_states_n, ctc_states_b, key_state, value_state, init):
            # The first step runs without a self-attention cache.
            if init:
                out, key_state, value_state = self.att_inference_module(predictions[:,-1:], _time, None, None, encoder_key_state, encoder_value_state)
                init = False
            else:
                out, key_state, value_state = self.att_inference_module(predictions[:,-1:],  _time, key_state, value_state, encoder_key_state, encoder_value_state)

            log_ctc, new_ctc_states_n, new_ctc_states_b = compute_ctc_prefix_scores(predictions[0],
                                                                                   ctc_log_probs,
                                                                                   ctc_states_n,
                                                                                   ctc_states_b,
                                                                                   self.eos_token_idx,
                                                                                   self.blank_idx)
            log_ctc = tf.reduce_mean(log_ctc, axis=0) #log-prob ensemble
            out = out[:,0,0] #(M,V)
            if self.from_logits:
                # Stable log(softmax(.)); see get_initial_states.
                out = tf.nn.log_softmax(out, axis=-1)
            else:
                out = tf.math.log(out)
            log_att = tf.reduce_mean(out, axis=0) #log-prob ensemble

            probs_final = (1-self.ctc_weight) * log_att + self.ctc_weight * log_ctc #tf.expand_dims(log_psi, axis=0)
            next_token = tf.argmax(probs_final, axis=-1, output_type=tf.int32)#[0]

            # Keep only the CTC states that belong to the chosen token.
            ctc_states_n = new_ctc_states_n[...,next_token]
            ctc_states_b = new_ctc_states_b[...,next_token]

            predictions = tf.concat([predictions, [next_token[...,None]]], axis=1)
            return _time+1, predictions, ctc_states_n, ctc_states_b, key_state, value_state, init

        _, predictions, _, _, _, _, _ = tf.while_loop(condition, body,
                                                shape_invariants=[tf.TensorShape([]),
                                                                  tf.TensorShape([1,None]),
                                                                  tf.TensorShape([None,len(self.encoder_list)]),
                                                                  tf.TensorShape([None,len(self.encoder_list)]),
                                                                  tf.TensorShape([len(self.encoder_list),None,None,None]),
                                                                  tf.TensorShape([len(self.encoder_list),None,None,None]),
                                                                  tf.TensorShape([])],
                                                loop_vars=[start_step, predictions, ctc_states_n, ctc_states_b, key_state, value_state, init])
        return predictions

In [19]:
class TFLiteModel(tf.Module):
    """tf.Module wrapper exposing preprocessing and decoding as one exportable function.

    ``__call__`` takes a (frames, len(SEL_COLS)) float32 tensor and returns
    ``{'outputs': one_hot}`` with 59 classes, one row per decoded character.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.preprocess = Preprocess()

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, len(SEL_COLS)], dtype=tf.float32, name='inputs')])
    def __call__(self, inputs, training=False):
        # `training` is accepted for interface compatibility and is not used.

        # (frames, columns) -> (frames, 3, 543) -> (frames, 543, 3).
        landmarks = tf.reshape(inputs, (-1, 3, 543))
        landmarks = tf.transpose(landmarks, (0, 2, 1))

        # An empty clip is replaced by a single all-zero frame.
        has_no_frames = tf.shape(landmarks)[0] == 0
        safe_landmarks = tf.cond(
            has_no_frames,
            lambda: tf.zeros((1, 543, 3), dtype=tf.float32),
            lambda: tf.identity(landmarks),
        )

        features = self.preprocess(safe_landmarks)
        # Decoder output for the first (only) sequence of the batch.
        token_ids = self.model(features)[0]

        # Shift ids down by one and keep only those landing in 0..58.
        char_ids = token_ids - 1
        in_range = tf.logical_and(char_ids >= 0, char_ids <= 58)
        kept_positions = tf.where(in_range)[..., 0]
        kept_ids = tf.gather(char_ids, kept_positions)

        # Emit a single class-0 prediction when nothing survived the filter.
        nothing_left = tf.shape(kept_ids)[0] == 0
        final_ids = tf.cond(
            nothing_left,
            lambda: tf.zeros(1, tf.int32),
            lambda: tf.identity(kept_ids),
        )

        one_hot = tf.one_hot(final_ids, 59)
        return {'outputs': one_hot}

In [20]:
# Reset Keras' global state, then force a garbage-collection pass.
# gc.collect() is the last expression, so its return value is what the cell displays.
tf.keras.backend.clear_session()
gc.collect()

21671

In [21]:
# Wrap the decoder to export. The active choice is CTC greedy decoding with the
# first model; the two commented lines are the alternative decoders.
tflitemodel_base = TFLiteModel(CTCGreedyDecoder(model_list[0]))
# tflitemodel_base = TFLiteModel(ATTGreedyDecoder(model_list[0]))
# tflitemodel_base = TFLiteModel(EnsembleCTCAttentionJointGreedyDecoder(model_list[:3], ctc_weight=0.3))
# Load the character -> index table and build the inverse (index -> character).
# NOTE(review): `json` is not among the imports at the top of the notebook; it is
# presumably imported in an earlier cell -- confirm.
with open ("character_to_prediction_index.json", "r") as f:
    character_map = json.load(f)
rev_character_map = {j:i for i,j in character_map.items()}
# Sanity check: decode `frames` and show the prediction next to the reference `phrase`
# (both come from an earlier cell).
pred = tflitemodel_base(frames)["outputs"].numpy().argmax(-1)
''.join([rev_character_map[x] for x in pred]), phrase.numpy().decode('utf-8')

('100-407-5928', '100-407-5928')

In [22]:
# Convert the wrapper to a TFLite flatbuffer restricted to built-in ops
# (the SELECT_TF_OPS fallback is left commented out).
keras_model_converter = tf.lite.TFLiteConverter.from_keras_model(tflitemodel_base)
keras_model_converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]#, tf.lite.OpsSet.SELECT_TF_OPS]
# Private/experimental converter switch; presumably required for the tensor-list
# ops generated by the decoding loop -- TODO confirm.
keras_model_converter._experimental_default_to_single_batch_in_tensor_list_ops = True
# Default optimizations with float16 as the supported type.
keras_model_converter.optimizations = [tf.lite.Optimize.DEFAULT]
keras_model_converter.target_spec.supported_types = [tf.float16]
tflite_model = keras_model_converter.convert()
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

# Record the input columns the model expects alongside the model file.
with open('inference_args.json', "w") as f:
    json.dump({"selected_columns" : SEL_COLS}, f)

# Bundle both files into the submission archive.
!zip submission.zip  './model.tflite' './inference_args.json'



updating: model.tflite (deflated 9%)
updating: inference_args.json (deflated 84%)


In [23]:
# Load the exported model back through the TFLite interpreter.
interpreter = tf.lite.Interpreter("model.tflite")

# Signature name and output key used when running and reading predictions below.
REQUIRED_SIGNATURE = "serving_default"
REQUIRED_OUTPUT = "outputs"

# Callable that runs the signature; invoked later as prediction_fn(inputs=...).
prediction_fn = interpreter.get_signature_runner(REQUIRED_SIGNATURE)

In [24]:
if DEBUG:

    fold = 0
    # pqfiles = df[df['fold']==fold].file_id.unique()
    # N comes from the config cell (-1 = all samples); reassign it here for a
    # quicker run, e.g. N = 100. (The former `N = N` was a no-op and was removed.)

    # Keep (coordinates, phrase) pairs, cap the stream at N samples, and prefetch
    # as the final stage so only the samples that are actually used get buffered.
    test_dataset = ds.map(lambda x:(x['coordinates'],x['phrase'])).take(N).prefetch(tf.data.AUTOTUNE)

In [25]:
from Levenshtein import distance

def competition_metric(true, pred):
    """Compute the normalized edit-distance score and the mean edit distance.

    Args:
        true: list of reference strings.
        pred: list of predicted strings, paired with ``true`` by position.

    Returns:
        A tuple ``(score, mean_distance)`` where
        * ``score = max((N - D) / N, 0)`` with ``N`` the total number of
          reference characters and ``D`` the summed Levenshtein distance;
        * ``mean_distance = D / len(true)``.

        Degenerate inputs no longer raise ``ZeroDivisionError``: with no samples
        the result is ``(0.0, 0.0)``, and when the references contain no
        characters at all the score is defined as ``0.0``.
    """
    num_samples = len(true)
    if num_samples == 0:
        # Nothing was evaluated (e.g. inference failed on the first sample).
        return 0., 0.

    total_distance = sum(distance(x, y) for x, y in zip(true, pred))
    total_chars = sum(len(x) for x in true)

    if total_chars == 0:
        score = 0.
    else:
        score = max((total_chars - total_distance) / total_chars, 0.)
    return score, total_distance / num_samples

if DEBUG:
    # Run the exported TFLite model over test_dataset, collecting references
    # and predictions together with per-sample timings.
    true = []
    pred = []

    start_time = time.perf_counter()

    loop_durations = []
    for frame, target in tqdm(test_dataset):
        loop_start = time.perf_counter()

        try:
            output = prediction_fn(inputs=frame)
            # arg-max over the class axis, then map indices back to characters
            # (unknown indices are dropped).
            prediction_str = "".join([rev_character_map.get(s, "") for s in np.argmax(output[REQUIRED_OUTPUT], axis=1)])
            target = target.numpy().decode("utf-8")
            true.append(target)
            pred.append(prediction_str)
        except Exception as e:
            # Best effort: report the failure and stop; the samples collected so
            # far are still summarized below.
            print(e)
            break

        # Only successful iterations contribute to the latency figures.
        loop_durations.append(time.perf_counter() - loop_start)

    total_duration = time.perf_counter() - start_time

    avg_loop_duration = sum(loop_durations) / len(loop_durations) if loop_durations else 0
    throughput = 1 / avg_loop_duration if avg_loop_duration != 0 else 0  # throughput

    model_size = os.path.getsize("model.tflite") / (1024 * 1024)  # MB

    print(true[:5])
    print(pred[:5])
    if true:
        print(competition_metric(true, pred))
    else:
        # Empty dataset or failure on the very first sample: the metric would
        # divide by zero, so report it instead of crashing.
        print("No samples were evaluated; competition metric is undefined.")

    print(f"\n---- Execution Metrics ----")
    print(f"Latency (ms per sample): {avg_loop_duration * 1000:.2f}")  # ms
    print(f"Throughput (samples per second): {throughput:.2f}")
    print(f"Model File Size (MB): {model_size:.2f}")

0it [00:00, ?it/s]

['100-407-5928', '+86-6197-6479-5413-52', 'bruce peterson', '844-233-4773', '1288 mccoys creek road']
['100-407-5928', '+82-6107-6479-541-52', 'care parson', '844-233-4773', '12 ncoys creek road']
(0.799991698489125, 3.5914138779160765)

---- Execution Metrics ----
Latency (ms per sample): 53.53
Throughput (samples per second): 18.68
Model File Size (MB): 11.41
