In [None]:
import os

import warnings
warnings.filterwarnings('ignore')

In [None]:
# Threading settings for the OpenMP runtime used by PyTorch / IPEX. They are
# set in this cell, before torch and intel_extension_for_pytorch are imported
# below, because the OpenMP runtime typically reads them when it starts up.

# Number of OpenMP worker threads (torch.set_num_threads is an alternative).
os.environ["OMP_NUM_THREADS"] = "16"

# How long (ms) an idle thread keeps waiting before going back to sleep
# between parallel regions.
os.environ["KMP_BLOCKTIME"] = "50"

# Leave the thread topology (KMP_HW_SUBSET, e.g. "1s,1n,56c,2t") and the
# thread-to-core affinity (KMP_AFFINITY, e.g. "granularity=fine,compact,1,0")
# at runtime defaults by making sure neither variable is set.
os.environ.pop('KMP_HW_SUBSET', None)
os.environ.pop('KMP_AFFINITY', None)

# Disable oneDNN (MKL-DNN) verbose logging.
os.environ["MKLDNN_VERBOSE"] = "0"

# Show the effective values. Pure Python instead of `!echo`, so this also works
# outside a POSIX shell; unset variables print as empty.
for _var in ("OMP_NUM_THREADS", "KMP_BLOCKTIME", "KMP_HW_SUBSET", "KMP_AFFINITY"):
    print(f"{_var}={os.environ.get(_var, '')}")


# use top + 1 + t , to check the utilization of your cores.

# AI-Powered Customer Care Chatbots (based on Intel AI Reference Kit)  

Briefly, given a customer query, the AI system must understand the intent and the entities involved in the query, look up or retrieve the relevant information, and return an appropriate response to the customer within a reasonable amount of time. In this example, we focus on using the Intel® oneAPI AI Analytics Toolkit to train and deploy an accurate, fast AI system that predicts the intent and entities of a user query.

In [None]:
import argparse
import logging
import pathlib
import time

import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast

# if using intel, optimize the model and the optimizer
import intel_extension_for_pytorch as ipex

In [None]:
# !/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause

# pylint: disable=C0415,E0401,R0914

"""
Code adopted from
https://github.com/sz128/slot_filling_and_intent_detection_of_SLU
"""

import operator
from typing import Any, Dict, List, Union

import numpy as np
import torch
from torch.utils.data import Dataset


def construct_vocab(
    input_seqs: List[str],
    vocab_config: Dict[str, Any] = None
) -> Union[Dict[str, int], Dict[int, str]]:
    """Build word->id and id->word lookup tables from tokenized sequences.

    Each element of ``input_seqs`` is either a list of tokens or a single
    token. Tokens are ranked by descending frequency; ties keep the order in
    which the tokens were first seen. Reserved symbols always take the lowest
    ids: ``<pad>``=0 and ``<unk>``=1, followed by ``<s>``=2 and ``</s>``=3 when
    ``vocab_config['bos_eos'] is True``. Any occurrences of those symbols in
    the input are ignored.

    Args:
        input_seqs: sequences (lists of tokens) or bare tokens to count.
        vocab_config: ``{'mini_word_freq': int, 'bos_eos': bool}``; tokens
            seen fewer than ``mini_word_freq`` times are dropped. Defaults to
            ``{'mini_word_freq': 1, 'bos_eos': False}``.

    Returns:
        A ``(word2id, id2word)`` tuple of dictionaries.
    """
    config = vocab_config if vocab_config is not None else {
        'mini_word_freq': 1, 'bos_eos': False}

    # Count occurrences; a non-list entry counts as one token.
    counts: Dict[str, int] = {}
    for seq in input_seqs:
        tokens = seq if isinstance(seq, list) else [seq]
        for token in tokens:
            counts[token] = counts.get(token, 0) + 1

    # Reserved symbols are re-inserted below with fixed ids.
    for reserved in ('<s>', '<pad>', '</s>', '<unk>'):
        counts.pop(reserved, None)

    specials = ['<pad>', '<unk>']
    if config['bos_eos'] is True:
        specials += ['<s>', '</s>']
    word2id = {token: i for i, token in enumerate(specials)}
    id2word = {i: token for i, token in enumerate(specials)}

    # Most frequent first; sorted() is stable, so ties keep first-seen order.
    by_frequency = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
    threshold = config['mini_word_freq']
    for word, freq in by_frequency:
        if freq >= threshold:
            new_id = len(word2id)
            word2id[word] = new_id
            id2word[new_id] = word

    return word2id, id2word


def read_vocab_file(
        vocab_path: str,
        bos_eos: bool = False,
        no_pad: bool = False,
        no_unk: bool = False,
        separator: str = ':'
) -> Union[Dict[str, int], Dict[int, str]]:
    """Read a pre-existing vocabulary file.

    Each non-blank line is either a bare token, which gets the next free id,
    or ``"<token> <separator> <id>"`` (for example ``"atis_flight : 3"``),
    which uses the explicit id. Reserved symbols are registered first, in
    this order: ``<pad>``, ``<unk>``, then ``<s>`` and ``</s>``. A token that
    is already present keeps its first id.

    Args:
        vocab_path (str): path to the vocab file.
        bos_eos (bool, optional): also add ``<s>`` / ``</s>``. Defaults to False.
        no_pad (bool, optional): do NOT add ``<pad>``. Defaults to False.
        no_unk (bool, optional): do NOT add ``<unk>``. Defaults to False.
        separator (str, optional): separator between token and id on explicit
            lines; it must be surrounded by single spaces. Defaults to ':'.

    Returns:
        A ``(word2id, id2word)`` tuple of lookup and reverse-lookup dicts.

    Raises:
        ValueError: if an explicit line's id is not an integer.
    """

    word2id, id2word = {}, {}
    if not no_pad:
        word2id['<pad>'] = len(word2id)
        id2word[len(id2word)] = '<pad>'
    if not no_unk:
        word2id['<unk>'] = len(word2id)
        id2word[len(id2word)] = '<unk>'
    if bos_eos is True:
        word2id['<s>'] = len(word2id)
        id2word[len(id2word)] = '<s>'
        word2id['</s>'] = len(word2id)
        id2word[len(id2word)] = '</s>'

    # Test for the same spaced separator that is split on. Otherwise a token
    # that merely contains the separator character (e.g. "hh:mm") would fail
    # to unpack. rsplit allows the token itself to contain the spaced form.
    spaced_sep = ' ' + separator + ' '
    with open(vocab_path, 'r', encoding="utf8") as file:
        for line in file:
            entry = line.strip('\r\n')
            if spaced_sep in entry:
                word, idx = entry.rsplit(spaced_sep, 1)
                idx = int(idx)
            else:
                word = line.strip()
                if not word:
                    # Blank line: do not register '' as a vocabulary entry.
                    continue
                idx = len(word2id)
            if word not in word2id:
                word2id[word] = idx
                id2word[idx] = word
    return word2id, id2word


def read_vocab_from_data_file(
    data_path: str,
    vocab_config: Dict[str, Any] = None,
    with_tag: bool = True,
    separator: str = ':'
) -> Union[Dict[str, int], Dict[int, str]]:
    """Build a word vocabulary from an ATIS-style data file.

    Only the part of each line before ``" <=> "`` is used. It is split on
    single spaces; with ``with_tag`` each item is ``word<separator>tag`` and
    the word is everything before the last separator. Lines whose utterance
    part is empty are skipped. The token lists are passed to
    ``construct_vocab`` together with ``vocab_config``.

    Args:
        data_path (str): file path of data.
        vocab_config (Dict[str, Any], optional): vocab config, passed on to
            ``construct_vocab``. The optional key ``'lowercase'`` (default
            False) lower-cases words first. Defaults to
            ``{'mini_word_freq': 1, 'bos_eos': False, 'lowercase': False}``.
        with_tag (bool, optional): items carry a ``:tag`` suffix. Defaults to True.
        separator (str, optional): word/tag separator. Defaults to ':'.

    Returns:
        A ``(word2idx, idx2word)`` tuple of lookup and reverse-lookup dicts.
    """

    if vocab_config is None:
        vocab_config = {'mini_word_freq': 1,
                        'bos_eos': False, 'lowercase': False}
    # 'lowercase' is optional, so a config shaped like construct_vocab's own
    # default (which has no 'lowercase' key) does not raise KeyError.
    lowercase = vocab_config.get('lowercase', False)

    print('Reading source data ...')
    input_seqs = []
    with open(data_path, 'r', encoding="utf8") as file:
        for line in file:
            slot_tag_line = line.strip('\n\r').split(' <=> ')[0]
            if slot_tag_line == "":
                continue
            in_seq = []
            for item in slot_tag_line.split(' '):
                if with_tag:
                    # Drop the trailing tag; earlier separators stay in the word.
                    word = separator.join(item.split(separator)[:-1])
                else:
                    word = item
                if lowercase:
                    word = word.lower()
                in_seq.append(word)
            input_seqs.append(in_seq)

    print('Constructing input vocabulary from ', data_path, ' ...')
    word2idx, idx2word = construct_vocab(input_seqs, vocab_config)
    return (word2idx, idx2word)


def read_seqtag_data_with_class(
    data_path: str,
    word2idx: Dict[str, int],
    tag2idx: Dict[str, int],
    class2idx: Dict[str, int],
    separator: str = ':',
    multi_class: bool = False,
    keep_order: bool = False,
    lowercase: bool = False
) -> Union[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    """Read an ATIS-style file into id sequences.

    Each line looks like ``w1:t1 w2:t2 ... <=> intent[;intent2...]``.
    Blank or whitespace-only lines, and lines with an empty utterance, are
    skipped. They still advance the line counter used by ``keep_order``.

    Args:
        data_path (str): file path of data.
        word2idx (Dict[str, int]): input vocab; must contain '<unk>'.
        tag2idx (Dict[str, int]): tag vocab; must contain '<unk>'.
        class2idx (Dict[str, int]): classification vocab.
        separator (str, optional): word/tag separator. Defaults to ':'.
        multi_class (bool, optional): return every intent of a line as a list
            of ids. Defaults to False.
        keep_order (bool, optional): append the 0-based line number to each
            input sequence. Defaults to False.
        lowercase (bool, optional): lower-case words. Defaults to False.

    Returns:
        Three dicts ``{'data': [...]}``: input features, tag labels and class
        labels.
        - Out-of-vocabulary words map to ``word2idx['<unk>']``.
        - An unknown tag is stored as the tuple ``(tag2idx['<unk>'], tag)``.
        - Without ``multi_class``, a line labelled ``"a;b"`` yields the tuple
          ``(class2idx['a'], ['a', 'b'])``.
    """

    print('Reading source data ...')
    input_seqs = []
    tag_seqs = []
    class_labels = []
    with open(data_path, 'r', encoding="utf8") as file:
        for line_num, raw_line in enumerate(file):
            line = raw_line.strip('\n\r')
            # A blank line has no " <=> " part. Skip it before unpacking,
            # otherwise the two-way unpack below raises ValueError.
            if not line.strip():
                continue
            slot_tag_line, class_name = line.split(' <=> ')
            if slot_tag_line == "":
                continue
            in_seq, tag_seq = [], []
            for item in slot_tag_line.split(' '):
                tmp = item.split(separator)
                word, tag = separator.join(tmp[:-1]), tmp[-1]
                if lowercase:
                    word = word.lower()
                in_seq.append(
                    word2idx[word] if word in word2idx else word2idx['<unk>'])
                tag_seq.append(tag2idx[tag] if tag in tag2idx else (
                    tag2idx['<unk>'], tag))
            if keep_order:
                in_seq.append(line_num)
            input_seqs.append(in_seq)
            tag_seqs.append(tag_seq)
            if multi_class:
                if class_name == '':
                    class_labels.append([])
                else:
                    class_labels.append([class2idx[val]
                                         for val in class_name.split(';')])
            else:
                if ';' not in class_name:
                    class_labels.append(class2idx[class_name])
                else:
                    # Train on the first intent, but keep all of them.
                    class_labels.append(
                        (
                            class2idx[class_name.split(';')[0]],
                            class_name.split(';')
                        )
                    )

    input_feats = {'data': input_seqs}
    tag_labels = {'data': tag_seqs}
    class_labels = {'data': class_labels}

    return input_feats, tag_labels, class_labels


class ATISDataset(Dataset):
    """PyTorch dataset pairing ATIS sentences with slot tags and intents.

    ``__getitem__`` tokenizes one sentence and returns every tokenizer output
    as a tensor, plus:

    * ``'labels'``: one slot-tag id per token position. Only tokens whose
      offset mapping starts at 0 with a non-zero end receive a tag id. For
      pre-split input these are the first sub-tokens of each word. Every
      other position, including special and padding tokens, is -100.
    * ``'class_label'``: the intent id.
    """

    def __init__(
            self, sentences, tags, class_labels, tokenizer, max_length,
            word2id, id2word,
            class2id, id2class,
            tag2id, id2tag):
        """Store the data and vocabularies.

        Args:
            sentences: whitespace-separated sentences, one per item.
            tags: per-sentence list of slot-tag strings, one per word.
            class_labels: per-sentence intent name (a key of ``class2id``).
            tokenizer: Hugging Face fast tokenizer; must support
                ``return_offsets_mapping``.
            max_length: length every encoding is padded/truncated to.
            word2id, id2word: word vocabulary and its inverse.
            class2id, id2class: intent vocabulary and its inverse.
            tag2id, id2tag: slot-tag vocabulary and its inverse.
        """
        self.len = len(sentences)
        self.sentences = sentences
        self.tags = tags
        self.class_labels = class_labels
        self.tokenizer = tokenizer
        self.max_len = max_length

        self.word2id, self.id2word = word2id, id2word
        self.class2id, self.id2class = class2id, id2class
        # The inverse tag vocabulary goes to id2tag; assigning it to id2word
        # would overwrite the word vocabulary and leave id2tag undefined.
        self.tag2id, self.id2tag = tag2id, id2tag

    def __getitem__(self, index):
        sentence = self.sentences[index].strip().split()
        word_labels = self.tags[index]
        class_label = self.class2id[self.class_labels[index]]

        labels = [self.tag2id[label] for label in word_labels]

        encoding = self.tokenizer(sentence,
                                  return_offsets_mapping=True,
                                  is_split_into_words=True,
                                  padding='max_length',
                                  truncation=True,
                                  max_length=self.max_len
                                  )

        # -100 is torch.nn.CrossEntropyLoss's default ignore_index.
        encoded_labels = np.ones(
            len(encoding['offset_mapping']), dtype=int) * -100

        # Each (0, end > 0) offset starts the next word; give it that word's tag.
        word_pos = 0
        for idx, mapping in enumerate(encoding['offset_mapping']):
            if mapping[0] == 0 and mapping[1] != 0:
                encoded_labels[idx] = labels[word_pos]
                word_pos += 1

        item = {key: torch.as_tensor(val) for key, val in encoding.items()}
        item['labels'] = torch.as_tensor(encoded_labels, dtype=torch.long)
        item['class_label'] = torch.as_tensor(class_label, dtype=torch.long)

        return item

    def __len__(self):
        return self.len


def load_dataset(data_path, tokenizer, max_length):
    """Load the ATIS train/test splits and their vocabularies.

    Reads these files under ``data_path``:
    - ``train``: used only to build the word vocabulary.
    - ``train_all`` and ``test``: the returned splits.
    - ``vocab.intent`` and ``vocab.slot``: the intent and slot vocabularies.

    NOTE(review): words in ``train_all``/``test`` that are missing from
    ``train`` come back as the literal text '<unk>'.

    Args:
        data_path (str): directory holding the ATIS files.
        tokenizer: transformers tokenizer passed to each ATISDataset.
        max_length (int): max padding length.
    Returns:
        Dict[str, Any]: ``"train"`` and ``"test"`` ATISDataset objects plus
        the ``word2id``/``id2word``, ``tag2id``/``id2tag`` and
        ``class2id``/``id2class`` vocabularies.
    """
    word2id, id2word = read_vocab_from_data_file(data_path + "/train")
    class2id, id2class = read_vocab_file(data_path + "/vocab.intent")
    tag2id, id2tag = read_vocab_file(data_path + "/vocab.slot")

    def get_ds(file_name):
        """Build an ATISDataset for one split file under ``data_path``."""
        input_feats, tag_labels, class_labels = read_seqtag_data_with_class(
            data_path + "/" + file_name, word2id, tag2id, class2id)
        sentences, labels, cls_labels = [], [], []
        for sent, tags, class_label in zip(
                input_feats['data'], tag_labels['data'], class_labels['data']):
            # Multi-intent lines arrive as (first_class_id, names); keep the first.
            if not isinstance(class_label, int):
                class_label = class_label[0]
            sentences.append(" ".join(id2word[idx] for idx in sent))
            # Unknown tags arrive as (unk_id, raw_tag). Look up unk_id
            # instead of the tuple, which is not a key of id2tag.
            labels.append([id2tag[idx[0] if isinstance(idx, tuple) else idx]
                           for idx in tags])
            cls_labels.append(id2class[class_label])
        return ATISDataset(sentences, labels, cls_labels,
                           tokenizer, max_length,
                           word2id, id2word,
                           class2id, id2class,
                           tag2id, id2tag)

    return {"train": get_ds("train_all"),
            "test": get_ds("test"),
            "word2id": word2id, "id2word": id2word,
            "tag2id": tag2id, "id2tag": id2tag,
            "class2id": class2id, "id2class": id2class}


In [None]:
# !/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause

# pylint: disable=C0415,E0401,R0914

"""
Intent and Token Classification Model built using BERT
"""

import torch
from transformers import BertModel

class IntentAndTokenClassifier(torch.nn.Module):
    """Joint intent (sequence) and slot (token) classifier on top of BERT.

    A pretrained BERT encoder feeds two linear heads:
    - a token head applied to every token's hidden state (``outputs[0]``);
    - a sequence head applied to BERT's pooled output (``outputs[1]``).
    """

    def __init__(
        self,
        num_token_labels: int,
        num_sequence_labels: int,
        pretrained_model_name: str = "bert-base-uncased"
    ) -> None:
        """Load the encoder and create both classification heads.

        Args:
            num_token_labels: number of slot tags.
            num_sequence_labels: number of intents.
            pretrained_model_name: Hugging Face checkpoint to load.
                Defaults to "bert-base-uncased".
        """
        super().__init__()
        self.num_token_labels = num_token_labels
        self.num_sequence_labels = num_sequence_labels
        self.bert = BertModel.from_pretrained(pretrained_model_name)

        # Size the heads from the encoder config rather than hard-coding 768,
        # so checkpoints with other hidden sizes also work.
        hidden_size = self.bert.config.hidden_size
        self.token_classifier = torch.nn.Linear(
            hidden_size, self.num_token_labels)
        self.sequence_classifier = torch.nn.Linear(
            hidden_size,
            self.num_sequence_labels
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        token_labels=None,
        sequence_labels=None
    ):
        """Predict the intent and the token tags for a batch of inputs.

        Args:
            input_ids (optional): tokenized sentence. Defaults to None.
            attention_mask (optional): attention mask to use. Defaults to None.
            token_type_ids (optional): token type ids. Defaults to None.
            position_ids (optional): position ids. Defaults to None.
            head_mask (optional): head mask. Defaults to None.
            output_attentions (optional): whether to output attentions.
                Defaults to None.
            output_hidden_states (optional): whether to output hidden states.
                Defaults to None.
            token_labels (optional): tag id per token. Positions equal to -100
                are ignored by the loss (CrossEntropyLoss default
                ignore_index). Defaults to None.
            sequence_labels (optional): intent id per sequence. Defaults to None.

        Returns:
            ``(logits_token, token_loss, logits_sequence, sequence_loss)`` when
            both label tensors are given, otherwise
            ``(logits_token, logits_sequence)``.
        """

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states
        )

        token_output = outputs[0]     # per-token hidden states
        sequence_output = outputs[1]  # pooled sequence representation

        logits_token = self.token_classifier(token_output)
        logits_sequence = self.sequence_classifier(sequence_output)

        token_loss = 0
        if token_labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            token_loss = loss_fct(
                logits_token.view(-1, self.num_token_labels),
                token_labels.view(-1)
            )

        sequence_loss = 0
        if sequence_labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            sequence_loss = loss_fct(
                logits_sequence.view(-1, self.num_sequence_labels),
                sequence_labels.view(-1)
            )

        if token_labels is not None and sequence_labels is not None:
            return logits_token, token_loss, logits_sequence, sequence_loss
        return logits_token, logits_sequence


In [None]:
# !/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause

# pylint: disable=C0415,E0401,R0914

"""
Train the intent and classfier model
"""

import logging
from typing import Union

from sklearn.metrics import accuracy_score
import torch

logger = logging.getLogger()


def evaluate_accuracy(
    dataloader: torch.utils.data.DataLoader,
    model: IntentAndTokenClassifier,
) -> Union[float, float]:
    """Score a model on every batch of ``dataloader`` without gradients.

    The model is switched to eval mode. Each batch is run with both label
    kinds, so the model returns token logits at index 0 and sequence logits
    at index 2.

    Args:
        dataloader (torch.utils.data.DataLoader): batches with 'input_ids',
            'attention_mask', 'labels' and 'class_label'.
        model (IntentAndTokenClassifier): model to evaluate.

    Returns:
        A ``(token_accuracy, class_accuracy)`` tuple. Token accuracy counts
        only positions whose label is not -100.
    """
    token_true, token_pred = [], []
    intent_true, intent_pred = [], []

    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            labels = batch['labels']
            class_label = batch['class_label']

            out = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                token_labels=labels,
                sequence_labels=class_label)
            token_logits, intent_logits = out[0], out[2]

            # Keep only positions that carry a real tag (the dataset marks
            # ignored positions with -100).
            flat_labels = labels.view(-1)
            keep = flat_labels != -100
            flat_pred = torch.argmax(
                token_logits.view(-1, model.num_token_labels), axis=1)
            token_true.extend(torch.masked_select(flat_labels, keep).numpy())
            token_pred.extend(torch.masked_select(flat_pred, keep).numpy())

            # Intent accuracy: one prediction per sequence.
            intent_true.extend(class_label.numpy())
            intent_pred.extend(torch.argmax(intent_logits, axis=1).numpy())

    return (
        accuracy_score(token_true, token_pred),
        accuracy_score(intent_true, intent_pred)
    )


# Model Training Function

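The `train` function below supports BF16 training with automatic mixed precision (AMP), using AMX by default or capped at AVX-512 BF16 instructions when `amx=False`. It also supports standard FP32 training.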

In [None]:
def train(
        dataloader: torch.utils.data.DataLoader,
        model: torch.nn.Module,
        epochs: int = 5,
        amx: bool = True,
        dataType: str = 'bf16',
        max_grad_norm: float = 10) -> torch.nn.Module:
    """Train a model on the given dataset with IPEX, in BF16 (AMP) or FP32.

    A fresh Adam optimizer (lr=1e-5) is created. Model and optimizer are then
    passed through ``ipex.optimize``, with ``dtype=torch.bfloat16`` when
    ``dataType == 'bf16'``. Each step minimizes the sum of the token and
    sequence losses returned by the model.

    Args:
        dataloader (torch.utils.data.DataLoader): training data; each batch
            must contain 'input_ids', 'attention_mask', 'labels' and
            'class_label'.
        model (torch.nn.Module): model to train. It must expose
            ``num_token_labels`` and return (token_logits, token_loss,
            sequence_logits, sequence_loss) when given both label kinds.
        epochs (int, optional): number of training epochs. Defaults to 5.
        amx (bool, optional): only used with bf16. When False, sets
            ONEDNN_MAX_CPU_ISA=AVX512_CORE_BF16 to request AVX-512 BF16
            kernels instead of AMX. Defaults to True.
        dataType (str, optional): 'bf16' for IPEX BF16 plus CPU autocast; any
            other value trains in FP32. Defaults to 'bf16'.
        max_grad_norm (float, optional): max global gradient norm used for
            clipping. Defaults to 10.

    Returns:
        torch.nn.Module: the trained model as returned by ``ipex.optimize``
        (possibly a different object than ``model``).
    """
    use_bf16 = dataType == 'bf16'

    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)
    model.train()

    # NOTE(review): oneDNN may read ONEDNN_MAX_CPU_ISA only once, when it
    # first initializes. Setting it here may then have no effect after an
    # earlier run in the same kernel. Confirm, or set it before importing torch.
    if not amx and use_bf16:
        print('going to AVX BF16 rather than AMX')
        os.environ["ONEDNN_MAX_CPU_ISA"] = "AVX512_CORE_BF16"
    else:
        os.environ["ONEDNN_MAX_CPU_ISA"] = "DEFAULT"

    # Optimize with BF16 or FP32 (default)
    if use_bf16:
        print('setting dtype to bf16 in IPEX')
        model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)
    else:
        model, optimizer = ipex.optimize(model, optimizer=optimizer)

    for epoch in range(1, epochs + 1):

        running_loss = 0
        tr_tk_preds, tr_tk_labels = [], []
        tr_sq_preds, tr_sq_labels = [], []

        for idx, batch in enumerate(dataloader):

            optimizer.zero_grad()

            ids = batch['input_ids']
            mask = batch['attention_mask']
            labels = batch['labels']
            class_label = batch['class_label']

            # Autocast (AMP) runs only for bf16. When disabled it is a no-op,
            # so FP32 shares the same code path.
            with torch.cpu.amp.autocast(enabled=use_bf16):
                out = model(
                    input_ids=ids,
                    attention_mask=mask,
                    token_labels=labels,
                    sequence_labels=class_label)
                combined_loss = out[1] + out[3]

            # Backward outside the autocast region, as PyTorch AMP recommends.
            combined_loss.backward()
            # Clip AFTER backward. Before it the gradients are zero/None, so
            # clipping would do nothing.
            torch.nn.utils.clip_grad_norm_(
                parameters=model.parameters(), max_norm=max_grad_norm)
            optimizer.step()

            running_loss += combined_loss.item()
            if idx % 100 == 0:
                print("loss/100 batches: %.4f" % (running_loss / (idx + 1)))

            # Token accuracy: score only positions with a real tag (not -100).
            flat_labels = labels.view(-1)
            active = flat_labels != -100
            flat_preds = torch.argmax(
                out[0].view(-1, model.num_token_labels), dim=1)
            tr_tk_labels.extend(torch.masked_select(flat_labels, active).numpy())
            tr_tk_preds.extend(torch.masked_select(flat_preds, active).numpy())

            # Sequence (intent) accuracy.
            tr_sq_labels.extend(class_label.numpy())
            tr_sq_preds.extend(torch.argmax(out[2], dim=1).numpy())

        epoch_loss = running_loss / len(dataloader)
        print("Training loss epoch #%d : %.4f" % (epoch, epoch_loss))
        print("Training NER accuracy epoch #%d : %.4f"
              % (epoch, accuracy_score(tr_tk_labels, tr_tk_preds)))
        print("Training CLS accuracy epoch #%d : %.4f"
              % (epoch, accuracy_score(tr_sq_labels, tr_sq_preds)))

    return model

In [None]:
# NOTE: this cell re-defines `train`, replacing any earlier definition in the
# kernel. It uses one code path for both modes; autocast is enabled only for bf16.
def train(
        dataloader: torch.utils.data.DataLoader,
        model: torch.nn.Module,
        epochs: int = 5,
        amx: bool = True,
        dataType: str = 'bf16',
        max_grad_norm: float = 10) -> torch.nn.Module:
    """Train a model on the given dataset with IPEX, in BF16 (AMP) or FP32.

    A fresh Adam optimizer (lr=1e-5) is created. Model and optimizer are then
    passed through ``ipex.optimize``, with ``dtype=torch.bfloat16`` when
    ``dataType == 'bf16'``. Each step minimizes the sum of the token and
    sequence losses returned by the model.

    Args:
        dataloader (torch.utils.data.DataLoader): training data; each batch
            must contain 'input_ids', 'attention_mask', 'labels' and
            'class_label'.
        model (torch.nn.Module): model to train. It must expose
            ``num_token_labels`` and return (token_logits, token_loss,
            sequence_logits, sequence_loss) when given both label kinds.
        epochs (int, optional): number of training epochs. Defaults to 5.
        amx (bool, optional): only used with bf16. When False, sets
            ONEDNN_MAX_CPU_ISA=AVX512_CORE_BF16 to request AVX-512 BF16
            kernels instead of AMX. Defaults to True.
        dataType (str, optional): 'bf16' for IPEX BF16 plus CPU autocast; any
            other value trains in FP32 (autocast disabled). Defaults to 'bf16'.
        max_grad_norm (float, optional): max global gradient norm used for
            clipping. Defaults to 10.

    Returns:
        torch.nn.Module: the trained model as returned by ``ipex.optimize``
        (possibly a different object than ``model``).
    """
    use_bf16 = dataType == 'bf16'

    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)
    model.train()

    # NOTE(review): oneDNN may read ONEDNN_MAX_CPU_ISA only once, when it
    # first initializes. Setting it here may then have no effect after an
    # earlier run in the same kernel. Confirm, or set it before importing torch.
    if not amx and use_bf16:
        print('going to AVX BF16 rather than AMX')
        os.environ["ONEDNN_MAX_CPU_ISA"] = "AVX512_CORE_BF16"
    else:
        os.environ["ONEDNN_MAX_CPU_ISA"] = "DEFAULT"

    # Optimize with BF16 or FP32 (default)
    if use_bf16:
        print('setting dtype to bf16 in IPEX')
        model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)
    else:
        model, optimizer = ipex.optimize(model, optimizer=optimizer)

    for epoch in range(1, epochs + 1):

        running_loss = 0
        tr_tk_preds, tr_tk_labels = [], []
        tr_sq_preds, tr_sq_labels = [], []

        for idx, batch in enumerate(dataloader):

            optimizer.zero_grad()

            ids = batch['input_ids']
            mask = batch['attention_mask']
            labels = batch['labels']
            class_label = batch['class_label']

            # Auto Mixed Precision only for bf16. Running FP32 under autocast
            # would silently compute in BF16 and skew the comparison.
            with torch.cpu.amp.autocast(enabled=use_bf16):
                out = model(
                    input_ids=ids,
                    attention_mask=mask,
                    token_labels=labels,
                    sequence_labels=class_label)
                combined_loss = out[1] + out[3]

            # Backward outside the autocast region, as PyTorch AMP recommends.
            combined_loss.backward()
            # Clip AFTER backward. Before it the gradients are zero/None, so
            # clipping would do nothing.
            torch.nn.utils.clip_grad_norm_(
                parameters=model.parameters(), max_norm=max_grad_norm)
            optimizer.step()

            running_loss += combined_loss.item()
            if idx % 100 == 0:
                print("loss/100 batches: %.4f" % (running_loss / (idx + 1)))

            # Token accuracy: score only positions with a real tag (not -100).
            flat_labels = labels.view(-1)
            active = flat_labels != -100
            flat_preds = torch.argmax(
                out[0].view(-1, model.num_token_labels), dim=1)
            tr_tk_labels.extend(torch.masked_select(flat_labels, active).numpy())
            tr_tk_preds.extend(torch.masked_select(flat_preds, active).numpy())

            # Sequence (intent) accuracy.
            tr_sq_labels.extend(class_label.numpy())
            tr_sq_preds.extend(torch.argmax(out[2], dim=1).numpy())

        epoch_loss = running_loss / len(dataloader)
        print("Training loss epoch #%d : %.4f" % (epoch, epoch_loss))
        print("Training NER accuracy epoch #%d : %.4f"
              % (epoch, accuracy_score(tr_tk_labels, tr_tk_preds)))
        print("Training CLS accuracy epoch #%d : %.4f"
              % (epoch, accuracy_score(tr_sq_labels, tr_sq_preds)))

    return model

## Setting some basic File Paths and Hyperparameters

In [None]:
# Run configuration. `intel`, `save_model_dir` and `model_name` are kept for
# reference; they are not read by the cells shown here.
intel = True
save_model_dir, data_path, model_name = (
    './customer-chatbot/model/',
    './customer-chatbot/data/atis-2/',
    'intel_ipex.pt',
)

# training parameters: sequence length, batch size, epochs, gradient-clip norm
MAX_LENGTH, BATCH_SIZE, EPOCHS, MAX_GRAD_NORM = 64, 100, 3, 10

# Seed torch so weight init and DataLoader shuffling are reproducible.
torch.manual_seed(0)

## Tokenizing the Dataset (Airline Travel Information System — ATIS)

In [None]:
# WordPiece tokenizer from the same checkpoint the classifier's encoder loads
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Read in the datasets and create dataloaders
print("Reading in the data...")
dataset = load_dataset(data_path, tokenizer, MAX_LENGTH)

# Shuffle only the training split; evaluation accuracy does not depend on order.
train_loader = DataLoader(dataset['train'], batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset['test'], batch_size=BATCH_SIZE)

# Example of Model Training across Various Configurations

We compare training time, inference time and accuracy for the following configurations:
- BF16 with AMX
- FP32

### IPEX bf16 w/ AMX 

In [None]:
# Build a fresh classifier. The timer starts before construction, so loading
# the pretrained encoder is counted in the training time.
start_bf16_wAMX = time.time()
model_bf16_wAMX = IntentAndTokenClassifier(
    num_sequence_labels=len(dataset['train'].class2id),
    num_token_labels=len(dataset['train'].tag2id),
)

# Train in BF16 with AMX enabled
print("Training the model...")
model_bf16_wAMX_trained = train(
    train_loader,
    model_bf16_wAMX,
    epochs=EPOCHS,
    max_grad_norm=MAX_GRAD_NORM,
    amx=True,
    dataType='bf16',
)
training_time_bf16_wAMX = time.time()  # timestamp; durations computed later

# Evaluate accuracy on the test split in batches
accuracy_ner_bf16_wAMX, accuracy_class_bf16_wAMX = evaluate_accuracy(
    test_loader, model_bf16_wAMX_trained)
testing_time_bf16_wAMX = time.time()

### IPEX FP32

In [None]:
# Build a fresh classifier. The timer starts before construction, so loading
# the pretrained encoder is counted in the training time.
start_fp32 = time.time()
model_fp32 = IntentAndTokenClassifier(
    num_sequence_labels=len(dataset['train'].class2id),
    num_token_labels=len(dataset['train'].tag2id),
)

# Train in plain FP32 (no BF16, AMX flag off)
print("Training the model...")
model_fp32_trained = train(
    train_loader,
    model_fp32,
    epochs=EPOCHS,
    max_grad_norm=MAX_GRAD_NORM,
    amx=False,
    dataType='fp32',
)
training_time_fp32 = time.time()  # timestamp; durations computed later

# Evaluate accuracy on the test split in batches
accuracy_ner_fp32, accuracy_class_fp32 = evaluate_accuracy(
    test_loader, model_fp32_trained)
testing_time_fp32 = time.time()

## Model Performance Summary

Performance varies by use, configuration and other factors. Learn more at www.Intel.com/PerformanceIndex. Performance results are based on testing as of dates shown in configurations and may not reflect all publicly available updates. See backup for configuration details. No product or component can be absolutely secure. © Intel Corporation. Intel, the Intel logo, and other Intel marks are trademarks of Intel Corporation or its subsidiaries. Other names and brands may be claimed as the property of others.

In [None]:
# Durations in seconds, computed once and reused below and in later cells.
# Each "training" span also includes model construction (loading the
# pretrained encoder); "inference" is the test-set evaluation.
fp32_training_time = (training_time_fp32 - start_fp32)
bf16_wAMX_training_time = (training_time_bf16_wAMX - start_bf16_wAMX)

fp32_inference_time = (testing_time_fp32 - training_time_fp32)
bf16_wAMX_inference_time = (testing_time_bf16_wAMX - training_time_bf16_wAMX)

print('TIME METRICS')

print("=======> FP32 ONLY - Test Accuracy on NER: %.4f" % accuracy_ner_fp32)
print("=======> FP32 ONLY - Test Accuracy on CLS: %.4f" % accuracy_class_fp32)
print("=======> FP32 ONLY - Training Time: %.2f mins" % (fp32_training_time / 60))
print("=======> FP32 ONLY - Inference Time: %.2f secs" % fp32_inference_time)
print("=======> FP32 ONLY - Total Time: %.2f mins"
      % ((fp32_training_time + fp32_inference_time) / 60))

print('-'*100)

print("=======> BF16 with AMX - Test Accuracy on NER: %.4f" % accuracy_ner_bf16_wAMX)
print("=======> BF16 with AMX - Test Accuracy on CLS: %.4f" % accuracy_class_bf16_wAMX)
print("=======> BF16 with AMX - Training Time: %.2f mins" % (bf16_wAMX_training_time / 60))
print("=======> BF16 with AMX - Inference Time: %.2f secs" % bf16_wAMX_inference_time)
print("=======> BF16 with AMX - Total Time: %.2f mins"
      % ((bf16_wAMX_training_time + bf16_wAMX_inference_time) / 60))

In [None]:
import matplotlib.pyplot as plt

# Training wall-clock time (seconds) per configuration.
fig, ax = plt.subplots()
ax.bar(["FP32", "BF16 with AMX"], [fp32_training_time, bf16_wAMX_training_time])
ax.set(title="Training Time", xlabel="Test Case", ylabel="Training Time (seconds)")
plt.show()

In [None]:
# Speed-up factor: FP32 time / BF16+AMX time (values > 1 mean BF16 is faster).
bf16AMX_Uplift_over_FP32 = fp32_training_time / bf16_wAMX_training_time
print("BF16 with AMX is %.2fX faster than FP32 in Training" % (bf16AMX_Uplift_over_FP32,))

In [None]:
# Test-set inference wall-clock time (seconds) per configuration.
fig, ax = plt.subplots()
ax.bar(["FP32", "BF16 with AMX"], [fp32_inference_time, bf16_wAMX_inference_time])
ax.set(title="Inference Time", xlabel="Test Case", ylabel="Inference Time (seconds)")
plt.show()

In [None]:
# Speed-up factor: FP32 time / BF16+AMX time (values > 1 mean BF16 is faster).
bf16AMX_Uplift_over_FP32 = fp32_inference_time / bf16_wAMX_inference_time
print("BF16 with AMX is %.2fX faster than FP32 in Inference" % (bf16AMX_Uplift_over_FP32,))