# check tf devices

In [None]:
import tensorflow as tf
# Show the devices TensorFlow reports as visible in this runtime.
visible_devices = tf.config.get_visible_devices()
print(visible_devices)

# pipeline test

In [None]:
from transformers import pipeline
# Smoke-test the default text-classification pipeline on a single sentence.
sample_text = (
    "My name is Clara and I live in Berkeley, California. "
    "I work at this cool company called Hugging Face."
)
classifier = pipeline('text-classification')
# Kept as the final expression so the notebook displays the prediction.
classifier(sample_text)

# TFBertForSequenceClassification Batch prediction

## select model & tokenizer

In [None]:
from transformers import BertTokenizer, TFBertForSequenceClassification
# Load the tokenizer and the model from the SAME checkpoint. A tokenizer built
# from a different vocabulary (e.g. the cased 'bert-base-cased') produces
# token ids and casing that do not match what this uncased model was
# trained on, which silently degrades its predictions.
model_name = 'bhadresh-savani/bert-base-uncased-emotion'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = TFBertForSequenceClassification.from_pretrained(model_name)


## define batch_predict

In [None]:
import tensorflow as tf
import numpy as np


def _predict_one_batch(model, tokenizer, batch_texts):
    """Tokenize one batch of texts, run the model on it and package the result."""
    batch = tokenizer(list(batch_texts), return_tensors="tf", padding=True)
    # Dummy labels (every entry is the last class index) are attached so the
    # model call receives the same inputs as before; they are placeholders,
    # not ground truth, so any loss in the output is not meaningful.
    batch["labels"] = tf.fill(
        [1, batch['input_ids'].shape[0]], tf.constant(model.num_labels - 1))
    return {
        'TFSequenceClassifierOutput': model(batch),
        # Text decoded back from the token ids (not the raw input strings).
        'original_text': np.array(
            [tokenizer.decode(sentence) for sentence in batch['input_ids']]),
    }


def batch_predict(model, tokenizer, texts, batch_size):
    """Run ``model`` over ``texts`` in consecutive batches of ``batch_size``.

    The texts are processed in order; the last batch holds the remainder and
    may therefore be smaller than ``batch_size``.

    Args:
        model: Callable model exposing ``num_labels``; called once per batch
            with the tokenizer output (plus a dummy ``labels`` entry).
        tokenizer: Callable tokenizer that also provides ``decode``.
        texts: Sliceable sequence of input strings (list, NumPy array, ...).
        batch_size: Maximum number of texts per model call; must be >= 1.

    Returns:
        A list with one dict per batch, holding the model output under
        ``'TFSequenceClassifierOutput'`` and a NumPy array of the decoded
        token ids under ``'original_text'``. Empty when ``texts`` is empty.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Stepping through the sequence by slices covers the full batches and the
    # trailing partial batch with a single code path.
    return [
        _predict_one_batch(model, tokenizer, texts[start:start + batch_size])
        for start in range(0, len(texts), batch_size)
    ]


## define dataset

In [None]:
import pandas as pd
# Read the CSV (first row is the header) and keep the first 100 entries of
# its 'Message' column as the benchmark inputs.
df = pd.read_csv(
    'data/SPAM text message 20170820 - Data.csv',
    header=0,
)
messages = df['Message'].to_numpy()
texts = messages[:100]
print(len(texts))

## experiment

In [None]:
import time
# Time batch_predict for every batch size from 1 to len(texts), repeating each
# measurement 5 times. result[k] holds the 5 durations (in seconds) measured
# for batch_sizes[k].
result = []
batch_sizes = []
for batch_size in range(1, len(texts) + 1):
    batch_sizes.append(batch_size)
    result_at_batch_size = []
    for _ in range(5):
        # perf_counter is a monotonic, high-resolution clock; time.time() is
        # wall-clock time and can jump if the system clock is adjusted.
        start = time.perf_counter()
        outputs = batch_predict(model, tokenizer, texts, batch_size=batch_size)
        result_at_batch_size.append(time.perf_counter() - start)
    result.append(result_at_batch_size)


### show results

In [None]:
import matplotlib.pyplot as plt
# One box per batch size showing the spread of its timings, on a log-scaled
# y axis.
fig, ax = plt.subplots(figsize=(20, 5))
ax.boxplot(result)
ax.set_xticklabels(batch_sizes)
ax.set(xlabel='batch_size', ylabel='seconds', yscale='log')
plt.show()

In [None]:
# Mean duration per batch size; row k of `result` belongs to batch_sizes[k].
avg = np.average(np.array(result), axis=-1)
# Map the index of the fastest row back to its batch size through
# batch_sizes rather than doing index arithmetic (sizes start at 1, so the
# index is always one less than the batch size).
best_index = int(avg.argmin())
print(f'batch_size: {batch_sizes[best_index]}, avg: {avg.min()} sec')