# PyTorch Training Optimizations with Advanced Matrix Extensions Bfloat16

This code sample trains a ResNet50 model on the CIFAR10 dataset using Intel® Extension for PyTorch* (IPEX). The model is trained in FP32 and in BF16 precision, and the BF16 runs are repeated with and without Intel® Advanced Matrix Extensions (AMX). AMX accelerates BF16 and INT8 matrix operations and is available starting with the 4th Gen Intel® Xeon® Scalable Processors (codenamed Sapphire Rapids). Comparing the training times shows the speedup that AMX provides.
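BF16 keeps the FP32 sign bit and the full 8-bit exponent, so it covers the same range of magnitudes as FP32, but it stores only 7 of the 23 mantissa bits and takes 2 bytes instead of 4. The stdlib-only sketch below illustrates both effects with a hypothetical helper, `to_bf16`, that is not part of PyTorch or IPEX; it emulates an FP32-to-BF16 conversion using round-to-nearest-even and omits NaN handling for brevity.

```python
import struct


def fp32_bits(x: float) -> int:
    """Return the IEEE 754 single-precision bit pattern of x as an int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round x to BF16 (round-to-nearest-even) and widen it back to a Python float.

    Hypothetical helper for illustration: BF16 keeps the upper 16 bits of the
    FP32 encoding (sign, 8 exponent bits, 7 mantissa bits). NaN is not handled.
    """
    bits = fp32_bits(x)
    lsb = (bits >> 16) & 1          # lowest bit that survives in BF16
    rounding_bias = 0x7FFF + lsb     # ties round to the even neighbor
    bf16_bits = ((bits + rounding_bias) >> 16) & 0xFFFF
    # Zero-fill the 16 dropped mantissa bits to get an FP32 value again
    return struct.unpack("<f", struct.pack("<I", bf16_bits << 16))[0]


# Range is preserved: a value near the FP32 maximum stays finite
print(to_bf16(3.0e38))
# Precision is reduced: 1 + 2**-10 needs more than 7 mantissa bits, so it rounds to 1.0
print(to_bf16(1.0 + 2**-10))
# 1 + 2**-7 fits in 7 mantissa bits and is represented exactly
print(to_bf16(1.0 + 2**-7))
```

This loss of low-order mantissa bits is why BF16 training is usually run with mixed precision, keeping weights and certain operations in FP32.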

## Environment Setup

Ensure the PyTorch kernel is activated before running this notebook.

## Imports, Dataset, Hyperparameters

```python
import os
from time import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import intel_extension_for_pytorch as ipex
```

```python
# Hyperparameters and constants
LR = 0.001
MOMENTUM = 0.9
DOWNLOAD = True
DATA = 'datasets/cifar10/'
```

## Identify Supported ISA
The cell below inspects the CPU flags reported by the `py-cpuinfo` package to determine whether the processor supports AMX. A 4th Gen Intel® Xeon® Scalable Processor (codenamed Sapphire Rapids) or newer is required to run this sample.

```python
# Check if hardware supports AMX
import sys
sys.path.append('../../')
import version_check  # local helper module located two directories up
from cpuinfo import get_cpu_info

info = get_cpu_info()
flags = info.get('flags', [])
# AMX-capable CPUs report flags such as amx_bf16, amx_tile, and amx_int8
amx_supported = any("amx" in flag for flag in flags)
if not amx_supported:
    print("AMX is not supported on current hardware. Code sample cannot be run.\n")
```

If the message "AMX is not supported on current hardware. Code sample cannot be run." is printed above, the hardware being used does not support AMX. Therefore, this code sample cannot proceed.
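On Linux, the same information is also exposed in `/proc/cpuinfo`, where AMX-capable processors list flags such as `amx_bf16`, `amx_tile`, and `amx_int8`. As a dependency-free cross-check, a minimal sketch could parse that file directly. The helper `amx_flags_from_cpuinfo` is hypothetical, and the parsing assumes the x86 `flags : ...` line format.

```python
from pathlib import Path


def amx_flags_from_cpuinfo(cpuinfo_text: str) -> set:
    """Return the AMX-related flags found in /proc/cpuinfo-style text (hypothetical helper)."""
    for line in cpuinfo_text.splitlines():
        # On x86 Linux each processor block has a line like "flags\t\t: fpu vme ... amx_tile"
        if line.startswith("flags"):
            _, _, value = line.partition(":")
            return {flag for flag in value.split() if flag.startswith("amx")}
    return set()


cpuinfo_path = Path("/proc/cpuinfo")
if cpuinfo_path.exists():
    found = amx_flags_from_cpuinfo(cpuinfo_path.read_text())
    print("AMX flags:", sorted(found) if found else "none")
else:
    print("/proc/cpuinfo is not available on this system")
```

Reading only the first `flags` line is enough because all cores of a single-socket or homogeneous multi-socket system report the same ISA flags.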

## Training the Model
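The function below trains the ResNet50 model in FP32 or BF16 and, for BF16, with or without AMX. The environment variable `ONEDNN_MAX_CPU_ISA` enables or disables AMX: setting it to `AVX512_CORE_BF16` restricts oneDNN to AVX-512 BF16 instructions, while `DEFAULT` lets it use the most capable ISA available, including AMX. For more information, refer to the [oneDNN documentation on CPU Dispatcher Control](https://oneapi-src.github.io/oneDNN/dev_guide_cpu_dispatcher_control.html). To run operations in BF16, wrap the forward pass and loss computation in the `torch.cpu.amp.autocast()` context manager (written as `torch.autocast(device_type="cpu", dtype=torch.bfloat16)` in recent PyTorch releases) and run the backward pass outside it, as the PyTorch automatic mixed precision documentation recommends.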
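The autocast pattern used for BF16 training can be tried in isolation on a toy model before running the full ResNet50 workload. In this minimal sketch, a small `torch.nn.Linear` layer and random tensors are hypothetical stand-ins for ResNet50 and a CIFAR10 batch; it shows that activations come out in BF16 while the weights and their gradients stay in FP32.

```python
import torch

# Toy model and data: hypothetical stand-ins for ResNet50 and a CIFAR10 batch
model = torch.nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()
data = torch.randn(8, 16)
target = torch.randint(0, 4, (8,))

optimizer.zero_grad()
# Forward pass and loss under CPU autocast: eligible ops such as linear run in BF16
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    output = model(data)
    loss = criterion(output, target)
# Backward runs outside the autocast region, following the PyTorch AMP guidance
loss.backward()
optimizer.step()

print("activation dtype:", output.dtype)
print("weight dtype:", model.weight.dtype)
print("gradient dtype:", model.weight.grad.dtype)
```

Keeping the master weights in FP32 while computing in BF16 is what makes this "mixed" precision: the optimizer update is applied at full precision even though the matrix multiplications, which AMX accelerates, run in BF16.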

```python
# Function to run one training test case
def trainModel(train_loader, modelName="myModel", amx=True, dataType="fp32"):
    """
    Train a ResNet50 model for one epoch and return the training time.

    Input parameters
        train_loader: a torch DataLoader object containing the training data
        modelName: a string used to name the saved checkpoint file
        amx: set to False to disable AMX for BF16 training (default True)
        dataType: the data type for training, supported values - "fp32", "bf16"
    Return value
        training_time: the time in seconds it takes to train the model
    """
    if dataType not in ("fp32", "bf16"):
        raise ValueError("dataType must be 'fp32' or 'bf16', got %r" % dataType)

    # Cap the oneDNN ISA at AVX-512 BF16 to disable AMX; otherwise allow the best available ISA
    if not amx and dataType == "bf16":
        os.environ["ONEDNN_MAX_CPU_ISA"] = "AVX512_CORE_BF16"
    else:
        os.environ["ONEDNN_MAX_CPU_ISA"] = "DEFAULT"

    # Initialize the model
    model = torchvision.models.resnet50()
    model = model.to(memory_format=torch.channels_last)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
    model.train()

    # Apply IPEX optimizations for BF16 or FP32 (default)
    if dataType == "bf16":
        model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)
    else:
        model, optimizer = ipex.optimize(model, optimizer=optimizer)

    # Train the model for one epoch
    num_batches = len(train_loader)
    start_time = time()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        # channels_last memory format can improve performance with 4D input data (optional)
        data = data.to(memory_format=torch.channels_last)
        if dataType == "bf16":
            # Auto mixed precision: forward pass and loss run in BF16 where eligible
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                output = model(data)
                loss = criterion(output, target)
        else:
            output = model(data)
            loss = criterion(output, target)
        # Backward pass runs outside autocast, as recommended for PyTorch AMP
        loss.backward()
        optimizer.step()
        if (batch_idx + 1) % 50 == 0:
            print("Batch %d/%d complete" % (batch_idx + 1, num_batches))
    end_time = time()
    training_time = end_time - start_time
    print("Training took %.3f seconds" % training_time)

    # Save a checkpoint of the trained model
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, 'checkpoint_%s.pth' % modelName)

    return training_time
```

## Loading the Dataset
The CIFAR10 dataset is used for this sample. Each 32×32 image is resized to 224×224, the input resolution commonly used with ResNet50, and the batch size is set to 128.
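A quick calculation shows how much work each epoch and each batch involve with these settings. The CIFAR10 training split contains 50,000 images, and `DataLoader` keeps the final partial batch by default, so the progress messages printed during training count up to the total computed here. The constant names below are illustrative only.

```python
import math

NUM_TRAIN_IMAGES = 50_000   # size of the CIFAR10 training split
BATCH_SIZE = 128
C, H, W = 3, 224, 224       # channels and spatial size after Resize((224, 224))
BYTES_PER_FP32 = 4

# DataLoader keeps the last, smaller batch unless drop_last=True
num_batches = math.ceil(NUM_TRAIN_IMAGES / BATCH_SIZE)
last_batch = NUM_TRAIN_IMAGES - (num_batches - 1) * BATCH_SIZE
# Size of one FP32 input batch, in MiB
batch_mib = BATCH_SIZE * C * H * W * BYTES_PER_FP32 / 2**20

print(f"{num_batches} batches per epoch, last batch holds {last_batch} images")
print(f"One FP32 input batch occupies {batch_mib:.1f} MiB")
```

With 391 batches per epoch, the loop in `trainModel` prints a progress line every 50 batches, up to batch 350.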

```python
# Resize CIFAR10 images to 224x224 and normalize each channel to [-1, 1]
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_dataset = torchvision.datasets.CIFAR10(
    root=DATA,
    train=True,
    transform=transform,
    download=DOWNLOAD,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=128,
)
```

## Training with FP32 and BF16, including AMX
Train the ResNet50 model in three different cases:
1. FP32 (baseline)
2. BF16 without AMX
3. BF16 with AMX

The training time of each case is recorded.
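`trainModel` measures elapsed time with `time()`, which follows the system wall clock and can jump if the clock is adjusted during a long run. A minimal sketch of an alternative, assuming you want to time each case from outside the function, uses the monotonic `time.perf_counter()` in a context manager; the name `timed` is hypothetical.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label: str, results: dict):
    """Store the duration of the enclosed block, in seconds, in results[label]."""
    start = time.perf_counter()  # monotonic, high-resolution clock
    try:
        yield
    finally:
        results[label] = time.perf_counter() - start


timings = {}
with timed("demo", timings):
    sum(i * i for i in range(100_000))  # stand-in for a training run
print(f"demo took {timings['demo']:.4f} seconds")
```

For example, `with timed("fp32", timings): trainModel(train_loader, modelName="fp32", dataType="fp32")` would store the FP32 time under `timings["fp32"]`.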

```python
print("Training model with FP32")
fp32_training_time = trainModel(train_loader, modelName="fp32", dataType="fp32")
```

```python
print("Training model with BF16 without AMX")
bf16_noAmx_training_time = trainModel(train_loader, modelName="bf16_noAmx", amx=False, dataType="bf16")
```

```python
print("Training model with BF16 with AMX")
bf16_withAmx_training_time = trainModel(train_loader, modelName="bf16_withAmx", dataType="bf16")
```

## Summary of Results
The cells below summarize the training times for all three cases and plot the resulting speedups.

```python
print("Summary")
print("FP32 training time: %.3f seconds" % fp32_training_time)
print("BF16 without AMX training time: %.3f seconds" % bf16_noAmx_training_time)
print("BF16 with AMX training time: %.3f seconds" % bf16_withAmx_training_time)
```

```python
plt.figure()
plt.title("ResNet50 Training Time")
plt.xlabel("Test Case")
plt.ylabel("Training Time (seconds)")
plt.bar(["FP32", "BF16 no AMX", "BF16 with AMX"], [fp32_training_time, bf16_noAmx_training_time, bf16_withAmx_training_time])
plt.show()
```

The training times for the three cases are printed and plotted above. Using BF16 should significantly reduce training time compared with FP32. However, there is little to no difference between AVX-512 BF16 and AMX BF16 because the amount of computation required for one batch is too small with this dataset.
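One additional factor worth ruling out when the two BF16 times are nearly identical: oneDNN generally reads `ONEDNN_MAX_CPU_ISA` only once, when its CPU dispatcher is first initialized in a process, so changing `os.environ` inside `trainModel` after an earlier case has already run in the same kernel may not change the ISA actually used. A minimal sketch, assuming each case is moved into its own script, sets the variable in the environment of a fresh child interpreter instead. The helper `run_with_isa_cap` is hypothetical, and the probe below only echoes the variable rather than training a model.

```python
import os
import subprocess
import sys


def run_with_isa_cap(code: str, max_isa: str) -> str:
    """Run Python code in a fresh interpreter with ONEDNN_MAX_CPU_ISA set (hypothetical helper).

    The variable is in the child's environment from the start, so it is
    visible before oneDNN is loaded in that process.
    """
    env = dict(os.environ, ONEDNN_MAX_CPU_ISA=max_isa)
    completed = subprocess.run(
        [sys.executable, "-c", code],
        env=env, capture_output=True, text=True, check=True,
    )
    return completed.stdout.strip()


# Stand-in child program; a real run would import torch/IPEX and train one case here
probe = "import os; print(os.environ['ONEDNN_MAX_CPU_ISA'])"
for isa in ("AVX512_CORE_BF16", "DEFAULT"):
    print(isa, "->", run_with_isa_cap(probe, isa))
```

In practice, the argument list would launch a standalone training script (for example, a hypothetical `train_case.py`) instead of the `-c` probe, with the script printing its measured training time for the notebook to collect.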

```python
speedup_from_fp32 = fp32_training_time / bf16_withAmx_training_time
print("BF16 with AMX is %.2fX faster than FP32" % speedup_from_fp32)
speedup_from_bf16 = bf16_noAmx_training_time / bf16_withAmx_training_time
print("BF16 with AMX is %.2fX faster than BF16 without AMX" % speedup_from_bf16)
```

```python
plt.figure()
plt.title("Speedup of BF16 with AMX")
plt.xlabel("Baseline")
plt.ylabel("Speedup")
plt.bar(["FP32", "BF16 no AMX"], [speedup_from_fp32, speedup_from_bf16])
plt.show()
```

This figure shows the speedup of BF16 with AMX relative to FP32 and to BF16 with AVX-512.

## Conclusion
This code sample shows how to enable and disable AMX at runtime, as well as the performance improvements from AMX BF16 when training the ResNet50 model. Performance will vary based on your hardware and software versions. To see a larger difference between AVX-512 BF16 and AMX BF16, increase the amount of computation per batch, for example by increasing the batch size with CIFAR10 or by using another dataset. For even more speedup, consider using the Intel® Extension for PyTorch* [Launch Script](https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/launch_script.html).

```python
print('[CODE_SAMPLE_COMPLETED_SUCCESFULLY]')
```