# Intel TensorFlow AMX BF16 Training
This code sample trains a DistilBERT model using Intel-optimized TensorFlow. The model is trained in FP32 and in BF16 precision, and the BF16 run is done both without and with Intel(R) Advanced Matrix Extensions (AMX). AMX supports the BF16 data type starting with 4th Generation Intel(R) Xeon(R) Scalable Processors. The training times are compared to show the speedup provided by AMX.

## Environment Setup
Ensure the TensorFlow kernel is activated before running this notebook.

# Imports, Dataset, Hyperparameters

In [None]:
import os
import sys
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint
import transformers
import time
import matplotlib.pyplot as plt
import argparse

In [None]:
# Notebook configuration.
# True: fine-tune in this session and save the weights; False: load previously saved weights.
is_tune_model = True
# NOTE(review): log_dir and profiling_needed are not read by any cell visible here.
log_dir = "logs"
profiling_needed = False
# "graph" makes the next cell disable eager execution.
execution_mode = "graph"
# Weights are loaded from and saved to the same directory.
load_weights_dir = save_weights_dir = "weights"

In [None]:
use_graph_mode = execution_mode == "graph"
if use_graph_mode:
    # Switch TensorFlow to TF1-style graph execution for everything run after this point.
    tf.compat.v1.disable_eager_execution()

# Identify Supported ISA
We identify the underlying supported ISA to determine whether AMX is supported. You must use a 4th Gen Intel® Xeon® Scalable Processor or newer to run this sample.

In [None]:
# Check if hardware supports AMX

from cpuinfo import get_cpu_info
info = get_cpu_info()
# The 'flags' entry may be absent on some platforms; treat that as "no flags"
# instead of failing with a KeyError.
flags = info.get('flags', [])
# Any CPU flag whose name contains "amx" counts as AMX support.
amx_supported = any("amx" in flag for flag in flags)
if amx_supported:
    # Printed exactly once, and only when AMX was actually detected
    # (previously this ran once per CPU flag, even on non-AMX hardware).
    print("AMX is supported on current hardware. Code sample can be run.\n")
else:
    print("AMX is not supported on current hardware. Code sample cannot be run.\n")
    sys.exit("AMX is not supported on current hardware. Code sample cannot be run.\n")


If the message "AMX is not supported on current hardware. Code sample cannot be run." is printed above, the hardware being used does not support AMX, and this code sample cannot proceed.

# Build the Model
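The functions below build the DistilBERT model; the training cells further down decide whether AMX is enabled and whether the FP32 or BF16 data type is used. The environment variable ONEDNN_MAX_CPU_ISA is used to enable or disable AMX. For more information, refer to the oneDNN documentation on CPU Dispatcher Control. This sample runs operations in BF16 through oneDNN's automatic mixed-precision graph rewrite, enabled with `tf.config.optimizer.set_experimental_options({'auto_mixed_precision_onednn_bfloat16': True})`. Alternatively, BF16 can be enabled with the Keras `tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')` function.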

In [None]:
def bert_encode(texts, tokenizer, max_len=512):
    """Encode raw texts into fixed-length rows of token ids.

    Each text is tokenized, truncated to ``max_len - 2`` tokens, wrapped as
    ``[CLS] ... [SEP]``, converted to ids and right-padded with id 0 up to
    ``max_len``. Id 0 is presumably the tokenizer's [PAD] id; confirm this if
    you use a different vocabulary.

    Args:
        texts: Iterable of strings to encode.
        tokenizer: Object providing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)``, e.g. a Hugging Face tokenizer.
        max_len: Length of every output row. Must be at least 2 so the
            [CLS] and [SEP] markers fit.

    Returns:
        ``np.ndarray`` with one row of ``max_len`` token ids per input text;
        shape ``(0, max_len)`` when ``texts`` is empty.

    Raises:
        ValueError: If ``max_len`` is smaller than 2.
    """
    if max_len < 2:
        # With max_len < 2, the slice below becomes negative (e.g. [:-1]), does not
        # truncate, and produces rows longer than max_len.
        raise ValueError(f"max_len must be at least 2, got {max_len}")

    all_tokens = []
    for text in texts:
        # Reserve two positions for the [CLS] and [SEP] markers.
        pieces = tokenizer.tokenize(text)[:max_len - 2]
        input_sequence = ["[CLS]"] + pieces + ["[SEP]"]

        token_ids = tokenizer.convert_tokens_to_ids(input_sequence)
        token_ids += [0] * (max_len - len(input_sequence))
        all_tokens.append(token_ids)

    if not all_tokens:
        # np.array([]) would be a 1-D float array; keep the 2-D integer layout.
        return np.empty((0, max_len), dtype=int)
    return np.array(all_tokens)
    
def build_model(transformer, max_len=512, learning_rate=1e-5):
    """Build and compile a binary classifier on top of a pre-trained encoder.

    Args:
        transformer: Encoder called on the token-id input. Element ``[0]`` of its
            output is used as the per-position sequence output; this is indexed as a
            rank-3 tensor, presumably (batch, max_len, hidden) — confirm for the model used.
        max_len: Fixed length of the ``input_word_ids`` sequence.
        learning_rate: Adam learning rate used for fine-tuning (default 1e-5).

    Returns:
        A compiled Keras ``Model`` mapping token ids to a single sigmoid output,
        trained with binary cross-entropy and reporting accuracy.
    """
    input_word_ids = Input(shape=(max_len,), dtype=tf.int32, name="input_word_ids")
    sequence_output = transformer(input_word_ids)[0]
    # Use the hidden state of the first position (where bert_encode places [CLS])
    # as the sentence representation.
    cls_token = sequence_output[:, 0, :]
    out = Dense(1, activation='sigmoid')(cls_token)

    model = Model(inputs=input_word_ids, outputs=out)
    # Pass `learning_rate`, not the deprecated `lr` alias: newer Keras optimizers
    # ignore or reject `lr`, which would silently train at the default rate instead.
    model.compile(Adam(learning_rate=learning_rate), loss='binary_crossentropy', metrics=['accuracy'])

    return model

In [None]:
# Disaster-tweets dataset: training split, test split and the submission template.
train = pd.read_csv("data/train.csv")
test = pd.read_csv("data/test.csv")
classified_results = pd.read_csv("data/sample_submission.csv")

# Pre-trained DistilBERT (uncased) backbone and its matching tokenizer from Hugging Face.
pretrained_name = 'distilbert-base-uncased'
transformer_layer = transformers.TFDistilBertModel.from_pretrained(pretrained_name)
tokenizer = transformers.DistilBertTokenizer.from_pretrained(pretrained_name)

# Training with FP32 and BF16, including AMX
Train the DistilBERT model in three different cases:

1. FP32 (baseline)
2. BF16 without AMX
3. BF16 with AMX

The training time of each case is recorded.

In [None]:
# FP32 (baseline)
# Explicitly keep the oneDNN BF16 auto-mixed-precision rewrite off, so this cell
# measures FP32 even if a BF16 cell was run earlier in the same kernel.
tf.config.optimizer.set_experimental_options({'auto_mixed_precision_onednn_bfloat16': False})

# build model
model = build_model(transformer_layer, max_len=160)

print("Training model with FP32")
# fine tune model according to disaster tweets dataset
if is_tune_model:
    train_input = bert_encode(train.text.values, tokenizer, max_len=160)
    train_labels = train.target.values
    # perf_counter is a monotonic clock intended for measuring durations.
    start_time = time.perf_counter()
    train_history = model.fit(train_input, train_labels, validation_split=0.2, epochs=1, batch_size=16)
    end_time = time.perf_counter()
    fp32_training_time = end_time - start_time
    # save model weights so we don't have to fine tune it every time
    os.makedirs(save_weights_dir, exist_ok=True)
    model.save_weights(os.path.join(save_weights_dir, "model_weights.h5"))
else:
    try:
        model.load_weights(os.path.join(load_weights_dir, "model_weights.h5"))
    except (OSError, tf.errors.NotFoundError):
        # OSError covers FileNotFoundError and h5py's older "Unable to open file" error.
        sys.exit("\n\nTuned model weights not available. Tune model first by setting is_tune_model = True")
    # No training was timed; NaN keeps the summary cells printable instead of
    # failing on undefined start/end times.
    fp32_training_time = float("nan")

In [None]:
# BF16 without AMX
# Cap oneDNN's CPU dispatcher at AVX-512 with BF16 support so AMX kernels are not
# used. oneDNN names this ISA level "AVX512_CORE_BF16"; "AVX512_BF16" is not a
# recognized ONEDNN_MAX_CPU_ISA value.
# NOTE(review): oneDNN reads ONEDNN_MAX_CPU_ISA once per process, when it first
# dispatches a kernel. This assignment may therefore have no effect if oneDNN
# already ran (e.g. in the FP32 cell). For a reliable comparison, run each case
# in a fresh kernel.
os.environ["ONEDNN_MAX_CPU_ISA"] = "AVX512_CORE_BF16"
# Enable oneDNN's automatic BF16 mixed-precision graph rewrite.
tf.config.optimizer.set_experimental_options({'auto_mixed_precision_onednn_bfloat16': True})

# Reload the pre-trained weights so this run does not start from the FP32-tuned layer.
transformer_layer = transformers.TFDistilBertModel.from_pretrained('distilbert-base-uncased')
tokenizer = transformers.DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = build_model(transformer_layer, max_len=160)

print("Training model with BF16 without AMX")
# fine tune model according to disaster tweets dataset
if is_tune_model:
    train_input = bert_encode(train.text.values, tokenizer, max_len=160)
    train_labels = train.target.values
    # perf_counter is a monotonic clock intended for measuring durations.
    start_time = time.perf_counter()
    train_history = model.fit(train_input, train_labels, validation_split=0.2, epochs=1, batch_size=16)
    end_time = time.perf_counter()
    bf16_noAmx_training_time = end_time - start_time
    # save model weights so we don't have to fine tune it every time
    os.makedirs(save_weights_dir, exist_ok=True)
    model.save_weights(os.path.join(save_weights_dir, "bf16_model_weights.h5"))
else:
    try:
        model.load_weights(os.path.join(load_weights_dir, "bf16_model_weights.h5"))
    except (OSError, tf.errors.NotFoundError):
        # OSError covers FileNotFoundError and h5py's older "Unable to open file" error.
        sys.exit("\n\nTuned model weights not available. Tune model first by setting is_tune_model = True")
    # No training was timed; NaN keeps the summary cells printable instead of
    # failing on undefined start/end times.
    bf16_noAmx_training_time = float("nan")

In [None]:
# BF16 with AMX
# Allow oneDNN to dispatch up to AMX kernels. oneDNN names this ISA level
# "AVX512_CORE_AMX"; "AMX_BF16" is not a recognized ONEDNN_MAX_CPU_ISA value.
# NOTE(review): oneDNN reads ONEDNN_MAX_CPU_ISA once per process, when it first
# dispatches a kernel. This assignment may therefore have no effect if oneDNN
# already ran in this kernel. For a reliable comparison, run each case in a
# fresh kernel.
os.environ["ONEDNN_MAX_CPU_ISA"] = "AVX512_CORE_AMX"
# Enable the BF16 rewrite here as well, so this cell does not depend on the
# previous cell having been run.
tf.config.optimizer.set_experimental_options({'auto_mixed_precision_onednn_bfloat16': True})

# Reload the pre-trained weights so this run does not start from a previously tuned layer.
transformer_layer = transformers.TFDistilBertModel.from_pretrained('distilbert-base-uncased')
tokenizer = transformers.DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = build_model(transformer_layer, max_len=160)

print("Training model with BF16 with AMX")
# fine tune model according to disaster tweets dataset
if is_tune_model:
    train_input = bert_encode(train.text.values, tokenizer, max_len=160)
    train_labels = train.target.values
    # perf_counter is a monotonic clock intended for measuring durations.
    start_time = time.perf_counter()
    train_history = model.fit(train_input, train_labels, validation_split=0.2, epochs=1, batch_size=16)
    end_time = time.perf_counter()
    bf16_withAmx_training_time = end_time - start_time
    # save model weights so we don't have to fine tune it every time
    os.makedirs(save_weights_dir, exist_ok=True)
    model.save_weights(os.path.join(save_weights_dir, "AMX_bf16_model_weights.h5"))
else:
    try:
        model.load_weights(os.path.join(load_weights_dir, "AMX_bf16_model_weights.h5"))
    except (OSError, tf.errors.NotFoundError):
        # OSError covers FileNotFoundError and h5py's older "Unable to open file" error.
        sys.exit("\n\nTuned model weights not available. Tune model first by setting is_tune_model = True")
    # No training was timed; NaN keeps the summary cells printable instead of
    # failing on undefined start/end times.
    bf16_withAmx_training_time = float("nan")

# Summary of Results
The cells below summarize the training time for all three cases and display graphs showing the performance speedup.

In [None]:
# (label, seconds) pairs for the three measured cases, in report order.
timings = (
    ("FP32", fp32_training_time),
    ("BF16 without AMX", bf16_noAmx_training_time),
    ("BF16 with AMX", bf16_withAmx_training_time),
)
print("Summary")
for case_name, seconds in timings:
    print(f"{case_name} training time: {seconds:.3f}")

In [None]:
# One bar per test case, heights in seconds.
case_labels = ["FP32", "BF16 no AMX", "BF16 with AMX"]
case_times = [fp32_training_time, bf16_noAmx_training_time, bf16_withAmx_training_time]

plt.figure()
plt.title("DistilBERT Training Time")
plt.xlabel("Test Case")
plt.ylabel("Training Time (seconds)")
plt.bar(case_labels, case_times)

In [None]:
# Speedup = baseline time / BF16-with-AMX time; values above 1 mean the AMX run was faster.
amx_time = bf16_withAmx_training_time
speedup_from_fp32 = fp32_training_time / amx_time
print(f"BF16 with AMX is {speedup_from_fp32:.2f}X faster than FP32")
speedup_from_bf16 = bf16_noAmx_training_time / amx_time
print(f"BF16 with AMX is {speedup_from_bf16:.2f}X faster than BF16 without AMX")

In [None]:
# Each bar is the AMX speedup relative to the named baseline case.
baseline_labels = ["FP32", "BF16 no AMX"]
speedups = [speedup_from_fp32, speedup_from_bf16]

plt.figure()
plt.title("AMX Speedup")
plt.xlabel("Test Case")
plt.ylabel("Speedup")
plt.bar(baseline_labels, speedups)

In [None]:
# NOTE(review): the marker is spelled "SUCCESFULLY". It is presumably matched
# verbatim by an external harness, so it is left unchanged; confirm before fixing.
completion_marker = '[CODE_SAMPLE_COMPLETED_SUCCESFULLY]'
print(completion_marker)