<a href="https://colab.research.google.com/github/davidandw190/faas-dl-inference/blob/main/notebooks/toxicity_assessment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install datasets transformers onnx onnxruntime tqdm

In [2]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import numpy as np
from datasets import load_metric
import transformers
import transformers.convert_graph_to_onnx as onnx_convert
from pathlib import Path
from onnxruntime.quantization import quantize_dynamic, QuantType
import onnxruntime as ort
from tqdm.auto import tqdm

In [None]:
# Train on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Raw civil_comments data plus the tokenizer of the model we fine-tune.
dataset = load_dataset("civil_comments")

model_name = 'microsoft/xtremedistil-l6-h256-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)


def tokenize_function(examples):
    """Tokenize a batch of comments into fixed-length encodings.

    Every sequence is truncated or padded to exactly 128 tokens. The ONNX
    evaluation later stacks these lists straight into NumPy arrays, which
    requires equal lengths.
    """
    return tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=128,
    )


tokenized_datasets = dataset.map(tokenize_function, num_proc=4, batched=True)

In [None]:
def prepare_dataset(examples, threshold=0.5):
    """Attach binary toxicity labels to a batch of civil_comments rows.

    Args:
        examples: A batch dict (as produced by a batched ``Dataset.map``) whose
            ``"toxicity"`` entry is a list of per-comment toxicity scores.
        threshold: Score at or above which a comment is labelled toxic (1).
            Defaults to 0.5. That matches the civil_comments / Jigsaw convention
            that ``toxicity >= 0.5`` is the positive class, so comments scored
            exactly 0.5 are no longer mislabelled as non-toxic.

    Returns:
        The same dict with a ``"label"`` list of 0/1 ints added.
    """
    examples["label"] = [int(score >= threshold) for score in examples["toxicity"]]
    return examples

# `prepare_dataset` iterates over `examples["toxicity"]`, so it must receive
# whole batches (lists of scores). Without batched=True, `map` passes one row
# at a time, "toxicity" is then a single number, and iterating it raises
# TypeError.
prepared_datasets = tokenized_datasets.map(prepare_dataset, batched=True, num_proc=4)

In [None]:
# Pretrained encoder with a fresh 2-way classification head (non-toxic / toxic),
# placed on the device selected earlier.
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(device)

In [None]:
# Accuracy metric object; reused at the end of the notebook to score the FP32
# and INT8 ONNX predictions against the reference labels.
# NOTE(review): `datasets.load_metric` is deprecated upstream (removed in
# datasets 3.0) — confirm the installed datasets version still provides it.
metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    """Compute classification accuracy for the Trainer's evaluation loop.

    Accuracy is computed directly with NumPy rather than through the global,
    hub-loaded ``metric`` object. Training-time evaluation therefore does not
    depend on the deprecated ``load_metric`` script.

    Args:
        eval_pred: A ``(logits, labels)`` pair, as passed by ``Trainer``.
            ``logits`` may also arrive as a tuple of model outputs, in which
            case its first element is taken as the class scores.

    Returns:
        ``{"accuracy": float}`` — the fraction of rows whose arg-max class
        equals the label (same key and value the ``accuracy`` metric reports).
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        # Some models return several outputs; the class scores come first.
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float(np.mean(predictions == np.asarray(labels)))}


In [None]:
# Reproducible subsamples (seed 42): 100k training rows and 10k rows from the
# "test" split, which serves as the Trainer's evaluation set.
# NOTE(review): the "test" split drives per-epoch evaluation and
# load_best_model_at_end selection and is also reused for the ONNX accuracy
# check; consider a held-out validation split for selection if one is available.
_shuffled_train = prepared_datasets["train"].shuffle(seed=42)
_shuffled_test = prepared_datasets["test"].shuffle(seed=42)
train_dataset = _shuffled_train.select(range(100_000))
eval_dataset = _shuffled_test.select(range(10_000))

In [None]:
# Fine-tuning configuration: one epoch, batch size 64, AdamW-style weight decay,
# evaluating and checkpointing once per epoch.
# `eval_strategy` is the current name of the former `evaluation_strategy`
# keyword (renamed in transformers 4.41; the old name is rejected from 4.46).
training_args = TrainingArguments(
    "toxicity_classifier",  # output directory for checkpoints
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=1,
    learning_rate=5e-5,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    # NOTE(review): with num_train_epochs=1 there is a single checkpoint, so
    # this just reloads it after training.
    load_best_model_at_end=True,
    push_to_hub=False,
    use_cpu=False,
)

In [None]:

# Fine-tune on the training subsample, then report metrics (loss + accuracy
# from compute_metrics) on the evaluation subsample.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()

eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")

In [None]:
# Export from the CPU copy of the fine-tuned model; the ONNX Runtime sessions
# created afterwards also use the CPU execution provider.
model = model.to("cpu")
pipeline = transformers.pipeline("text-classification", model=model, tokenizer=tokenizer)

fp32_model_file = "toxicity_classifier.onnx"
int8_model_file = "toxicity_classifier_int8.onnx"

# Full-precision graph at ONNX opset 14 (use_external_format=False).
onnx_convert.convert_pytorch(
    pipeline,
    opset=14,
    output=Path(fp32_model_file),
    use_external_format=False,
)

# Dynamically quantized copy with unsigned 8-bit (QUInt8) weights.
quantize_dynamic(fp32_model_file, int8_model_file, weight_type=QuantType.QUInt8)

In [None]:
# One CPU-only inference session per exported model (FP32 and INT8).
_cpu_only = ['CPUExecutionProvider']
session = ort.InferenceSession("toxicity_classifier.onnx", providers=_cpu_only)
session_int8 = ort.InferenceSession("toxicity_classifier_int8.onnx", providers=_cpu_only)

In [None]:
# First 1000 evaluation rows, fed to ONNX Runtime as one batch.
# The token columns are lists of Python ints. np.array would pick the
# platform's default integer type for them (int32 on Windows with NumPy < 2).
# The exported graph presumably expects int64 ids (PyTorch's integer tensor
# type), so the dtype is pinned explicitly.
input_sample = eval_dataset.select(range(1000))
input_feed = {
    name: np.array(input_sample[name], dtype=np.int64)
    for name in ("input_ids", "attention_mask", "token_type_ids")
}

In [None]:
# Run the same 1000-row batch through both models; `run` returns a list of
# requested outputs, of which we keep the single 'output_0' array.
out, out_int8 = (
    sess.run(output_names=['output_0'], input_feed=input_feed)[0]
    for sess in (session, session_int8)
)


In [None]:
# Predicted class per row = index of the highest score along the last axis.
predictions, predictions_int8 = (np.argmax(scores, axis=-1) for scores in (out, out_int8))

In [None]:
# Score both ONNX models against the same ground-truth labels.
reference_labels = input_sample['label']
onnx_accuracy = metric.compute(predictions=predictions, references=reference_labels)
onnx_int8_accuracy = metric.compute(predictions=predictions_int8, references=reference_labels)

print(f"ONNX model accuracy: {onnx_accuracy}")
print(f"ONNX INT8 model accuracy: {onnx_int8_accuracy}")

In [None]:
# Colab-only: trigger browser downloads of the INT8 model, then the FP32 model.
from google.colab import files

for _artifact in ('toxicity_classifier_int8.onnx', 'toxicity_classifier.onnx'):
    files.download(_artifact)