# Demonstrate loading and using ERNIE4us

## Demo: loading the ERNIE 2 model and preparing its inputs

In [None]:
import numpy as np
import dataclasses
import os
import sys
import logging
import tensorflow as tf
from ernie4us import *

In [None]:
# Show the installed TensorFlow version. This is a bare expression, so its value
# is only displayed when run as a notebook cell.
tf.__version__

In [None]:
# Load the ERNIE_BASE_EN model. The second argument is presumably the directory
# holding the model artifacts (TODO confirm against load_ernie_model).
# The call yields three things, all used below:
#   - an input builder
#   - the TF input tensors
#   - the TF output tensors
(ernie_input_builder,
 ernie_tf_inputs,
 ernie_tf_outputs) = load_ernie_model(ERNIE_BASE_EN, "./model_artifacts")

In [None]:
# Create a TF1-style session bound to the graph that owns the ERNIE input
# tensors, so later session.run calls can feed them.
session = tf.compat.v1.Session(
    graph=ernie_tf_inputs.token_ids.graph,
)

In [None]:
# Initialize all global and local variables of the session's graph.
# The initializer ops are created under session.graph.as_default() because the
# tf.compat.v1 initializer helpers collect variables from the *default* graph.
# That graph is not guaranteed to be the one this session was created on.
with session.graph.as_default():
    # Keep only truthy initializer ops (mirrors the original filter).
    initializers = [i for i in [tf.compat.v1.global_variables_initializer(),
                                tf.compat.v1.local_variables_initializer()] if i]
session.run(initializers)

In [None]:
# List every 'Const' op in the session's graph.
# For each one, print its index in the op list, its first output tensor and its op type.
for i, op in enumerate(session.graph.get_operations()):
    if op.type != 'Const':
        continue
    tensor = session.graph.get_tensor_by_name(f'{op.name}:0')
    print(i, '\t', tensor, tensor.op.type)

## Demo: using ERNIE with TensorFlow

In [None]:
# Build one ERNIE input record from a sentence pair (task 0).
# Then preview each field: its name, its shape and its first 20 values.
text_a = "ERNIE for the rest of us"
text_b = "thank you ERNIE"

record = ernie_input_builder.build(text_a, text_b, task_id=0)
for key, item in dataclasses.asdict(record).items():
    # np.ravel flattens a field of any shape/size. Reshaping to a fixed (512,)
    # only worked when every field held exactly 512 elements.
    print(key, item.shape, '=>', np.ravel(item)[:20])

In [None]:
def run_model(session, record, fetches, tf_inputs=None):
    """Run `fetches` in `session`, feeding the five ERNIE input fields of `record`.

    Args:
        session: a TF1-style session exposing ``run(fetches, feed_dict=...)``.
        record: an object with ``token_ids``, ``sentence_ids``, ``position_ids``,
            ``task_ids`` and ``input_mask`` attributes (e.g. the record returned
            by ``ernie_input_builder.build``).
        fetches: anything ``session.run`` accepts as fetches.
        tf_inputs: the input tensors to feed, with the same five attribute
            names. Defaults to the notebook-level ``ernie_tf_inputs``, so
            existing calls keep working.

    Returns:
        Whatever ``session.run(fetches, feed_dict=...)`` returns.
    """
    if tf_inputs is None:
        tf_inputs = ernie_tf_inputs
    return session.run(
        fetches,
        feed_dict={
            tf_inputs.token_ids: record.token_ids,
            tf_inputs.sentence_ids: record.sentence_ids,
            tf_inputs.position_ids: record.position_ids,
            tf_inputs.task_ids: record.task_ids,
            tf_inputs.input_mask: record.input_mask})


In [None]:
# NOTE: the output tensors are listed by hand because dataclasses.astuple()
# raised errors on tf.Tensor field values.
# Order matters: callers unpack the results as (sequence, classification).
fetches = [
    ernie_tf_outputs.sequence_features,
    ernie_tf_outputs.classification_features,
]

In [None]:
import time

# Sanity-check the graph and (re)run the variable initializers.
# Then do one warm-up inference on `record` and time `n_times` repeated inferences.
# session.as_default() only sets the default *session*. session.graph.as_default()
# is also entered so the initializer / variable-collection helpers look at this
# session's graph.
with session.as_default(), session.graph.as_default():
    # 'src_ids:0' is presumably the token-ids input tensor (TODO confirm).
    print(session.graph.get_tensor_by_name('src_ids:0'))
    initializers = [i for i in [tf.compat.v1.global_variables_initializer(),
                                tf.compat.v1.local_variables_initializer()] if i]
    session.run(initializers)
    print('n. local vars:', len(tf.compat.v1.local_variables()),
          'n. global vars:', len(tf.compat.v1.global_variables()))
    # The first run doubles as a warm-up, so it is excluded from the timing.
    ernie_sequence_features, ernie_classification_features = run_model(session, record, fetches)
    print(ernie_classification_features.shape, ernie_sequence_features.shape)
    n_times = 20
    # perf_counter is monotonic and high-resolution, so it suits interval timing.
    st = time.perf_counter()
    for _ in range(n_times):
        run_model(session, record, fetches)
    dt = time.perf_counter() - st
    print(f"finished in {dt}s. avg {dt / n_times}s/request")

In [None]:
# Build a small `num_class`-way classification head on ERNIE's classification
# features, then (re)initialize the graph's variables.
# All TF1-style APIs go through tf.compat.v1, matching the rest of the notebook;
# the bare tf.variable_scope / tf.get_variable / ... do not exist in TF 2.
# session.graph.as_default() makes sure the new variables and ops land in the
# session's graph.
with session.as_default(), session.graph.as_default():
    num_class = 3
    with tf.compat.v1.variable_scope("classifier", reuse=tf.compat.v1.AUTO_REUSE):
        output_weights = tf.compat.v1.get_variable(
            "logits_W", [num_class, ernie_tf_outputs.classification_features.shape[-1]],
            initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02))
        output_bias = tf.compat.v1.get_variable(
            "logits_b", [num_class], initializer=tf.compat.v1.zeros_initializer())
        # get_dropout_rate_tensor comes from ernie4us; presumably it yields a
        # rate tensor defaulting to 0.2 (TODO confirm its semantics).
        ernie_classification_features = tf.nn.dropout(
            ernie_tf_outputs.classification_features,
            rate=get_dropout_rate_tensor(0.2))
        logits = tf.matmul(ernie_classification_features, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        # NOTE(review): an activation on logits is unusual for a classifier
        # head; confirm that leaky_relu here is intended.
        logits = tf.nn.leaky_relu(logits)
        print('classifier logits.shape', logits.shape)
    initializers = [i for i in [tf.compat.v1.global_variables_initializer(),
                                tf.compat.v1.local_variables_initializer()] if i]
    session.run(initializers)
    print('n. local vars:', len(tf.compat.v1.local_variables()),
          'n. global vars:', len(tf.compat.v1.global_variables()))

In [None]:
import time

# Replicate the single `record` into a batch of `batch_size` identical items.
# Run the classifier head once to check the output shape, then time repeated runs.
with session.as_default():
    batch_size = 3

    def _stack_copies(field_value, copies):
        """Stack `copies` of one record field along a new axis 0, then drop the
        field's own leading axis (np.squeeze raises unless that axis has size 1)."""
        return np.squeeze(np.stack([field_value] * copies, axis=0), axis=1)

    token_ids2 = _stack_copies(record.token_ids, batch_size)
    sentence_ids2 = _stack_copies(record.sentence_ids, batch_size)
    position_ids2 = _stack_copies(record.position_ids, batch_size)
    task_ids2 = _stack_copies(record.task_ids, batch_size)
    input_mask2 = _stack_copies(record.input_mask, batch_size)
    record2 = Ernie2Input(token_ids2, sentence_ids2, position_ids2, task_ids2, input_mask2)
    # The first run doubles as a warm-up and shape check; it is not timed.
    logits_out = run_model(session, record2, [logits])
    print(logits_out[0].shape)
    n_times = 20
    # perf_counter is monotonic and high-resolution, so it suits interval timing.
    st = time.perf_counter()
    for _ in range(n_times):
        run_model(session, record2, [logits])
    dt = time.perf_counter() - st
    print(f"finished in {dt}s. avg {dt / n_times}s/req avg {dt / n_times / batch_size}s/item")

In [None]:
# Inspect the shape change made by the batching trick:
# - np.stack of `batch_size` copies of token_ids adds a new leading axis;
# - squeezing axis 1 then removes each copy's own leading axis.
# The final bare tuple is displayed as the cell's output.
batch_size = 3
before_squeezed = np.stack([record.token_ids for _ in range(batch_size)], axis=0)
token_ids2 = before_squeezed.squeeze(axis=1)
(before_squeezed.shape, token_ids2.shape)