In [1]:
import time
import requests
import numpy as np
import pandas as pd
import torch
import evaluate
import torch.nn as nn
import torch.nn.functional as F
import onnxruntime as ort
from onnxruntime.quantization import quantize_dynamic, QuantType
from optimum.onnxruntime import ORTModelForTokenClassification, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig

from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    DistilBertConfig,
    DistilBertForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification,
    pipeline
)

from datasets import load_dataset
from peft import PeftModel, LoraConfig, get_peft_model, TaskType

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DATA_DIR = '../data/location_detection/'

## Load data

In [3]:
uk_train_processed_dataset = load_dataset(
    'parquet',
    data_files=DATA_DIR + 'uk_geo_dataset_processed_train_av.parquet',
    split='train[:10%]'
)
uk_holdout_processed_dataset = load_dataset(
    'parquet',
    data_files=DATA_DIR + 'uk_geo_dataset_processed_holdout_av.parquet',
    split='train'
)

In [4]:
test_dataset = pd.read_csv(DATA_DIR + 'competition/test.csv', converters={'locations': eval})

## Roberta distilation

### Load models

In [5]:
base_model_name = 'xlm-roberta-base'
uk_checkpoint = DATA_DIR + 'models/uk-loc/checkpoint-2000'

labels = ['S', 'O', 'B-LOC', 'I-LOC']
label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for i, l in enumerate(labels)}

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
base_model = AutoModelForTokenClassification.from_pretrained(base_model_name, label2id=label2id, id2label=id2label)

uk_model = PeftModel.from_pretrained(base_model, uk_checkpoint)

Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
def get_distill_config():
    return DistilBertConfig(
        vocab_size=tokenizer.vocab_size,
        label2id=label2id,
        id2label=id2label,
        n_layers=3,
        n_heads=6,
        hidden_dim=1024,
        dim=384
    )

def get_distill_base_model():
    return DistilBertForTokenClassification(get_distill_config())

distill_uk_model = get_distill_base_model()

### Align labels

In [7]:
def align_labels_with_word_ids(labels, word_ids):
    new_labels = []
    current_word = None

    for word_id in word_ids:
        if word_id is None:
            # special tokens
            current_word = word_id
            new_labels.append(label2id["S"])
        elif word_id != current_word:
            # start of new word
            current_word = word_id
            new_labels.append(label2id[labels[word_id]])
        else:
            # part of a word
            label = labels[word_id]

            if label == "B-LOC":
                label = "I-LOC"

            new_labels.append(label2id[label])

    return new_labels

def align_labels(examples):
    bert_tokens = tokenizer(examples['tokens'], truncation=True, is_split_into_words=True)

    new_labels = []
    for i, labels in enumerate(examples['labels']):
        word_ids = bert_tokens.word_ids(i)
        new_labels.append(align_labels_with_word_ids(labels, word_ids))

    bert_tokens['labels'] = new_labels
    return bert_tokens

In [8]:
uk_train_dataset = uk_train_processed_dataset.map(
    align_labels,
    batched=True
)
uk_eval_dataset = uk_holdout_processed_dataset.map(
    align_labels,
    batched=True
)

### Metric

In [9]:
metric = evaluate.load("seqeval")

def compute_metrics(eval_preds):
  logits, labels = eval_preds
  predictions = np.argmax(logits, axis=-1)

  true_labels = [[id2label[l] for l in label if l != -100] for label in labels]
  prediction_label = [[id2label[p] for p, l in zip(prediction, label) if l != -100]
                      for prediction, label in zip(predictions, labels)]

  all_metrics = metric.compute(predictions=prediction_label, references=true_labels)
  return {
      'precision': all_metrics['overall_precision'],
      'recall': all_metrics['overall_recall'],
      'f1': all_metrics['overall_f1']
  }

### Training

In [None]:
class DistillTrainer(Trainer):
    def __init__(self, student_model, teacher_model, temperature, lambda_param, *args, **kwargs):
        super().__init__(model=student_model, *args, **kwargs)
        self.teacher = teacher_model
        self.student = student_model
        self.loss_fn = nn.KLDivLoss()
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.teacher.to(device)
        self.teacher.eval()
        self.temperature = temperature
        self.lambda_param = lambda_param

    def compute_loss(self, student, inputs, return_outputs=False):
        student_output = self.student(**inputs)

        with torch.no_grad():
            teacher_output = self.teacher(**inputs)

        soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)

        # calculate loss for diference between student and teacher
        distill_loss = self.loss_fn(soft_student, soft_teacher) * (self.temperature**2)

        # loss for student target predictions
        student_target_loss = student_output.loss

        # combine student teacher loss and student target loss
        loss = (1-self.lambda_param)*student_target_loss + self.lambda_param*distill_loss
        
        return (loss, student_output) if return_outputs else loss

In [None]:
def train(output_dir, student_model, teacher_model, train_dataset, eval_dataset):
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        evaluation_strategy='steps',
        save_strategy='steps',
        logging_strategy='steps',
        eval_steps=100,
        logging_steps=100,
        save_steps=100,
        save_total_limit=3,
        num_train_epochs=1,
    )
    trainer = DistillTrainer(
        student_model=student_model,
        teacher_model=teacher_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        temperature=2,
        lambda_param=0.5
    )
    trainer.train()

In [None]:
train(
    DATA_DIR + 'models/uk-distilled',
    distill_uk_model,
    uk_model,
    uk_train_dataset,
    uk_eval_dataset
)

### Model comparison

In [10]:
base_uk_checkpoint = DATA_DIR + 'models/uk-loc/checkpoint-2000'
distill_uk_checkpoint = DATA_DIR + 'models/uk-distilled/checkpoint-800'

base_model = AutoModelForTokenClassification.from_pretrained(base_model_name, label2id=label2id, id2label=id2label)
base_uk_model = PeftModel.from_pretrained(base_model, base_uk_checkpoint).merge_and_unload()
base_uk_model.eval()

distill_uk_model = DistilBertForTokenClassification.from_pretrained(distill_uk_checkpoint)
distill_uk_model.eval();

Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
def time_model(model, dataset):
    start_time = time.time()
    
    input = tokenizer(dataset['tokens'], return_tensors='pt', max_length=50, padding='max_length', truncation=True, is_split_into_words=True)
    labels = [p + ([0] * (50 - len(p))) for p in [t[:50] for t in dataset['labels']]]
    
    with torch.no_grad():
        logits = model(**input).logits.cpu().detach().numpy()
        
    end_time = time.time()
    return end_time - start_time, compute_metrics((logits, labels))

In [12]:
base_time, base_metrics = time_model(base_uk_model, uk_eval_dataset)

distill_time, distill_metrics = time_model(distill_uk_model, uk_eval_dataset)

print(f"Original model inference time: {base_time}, f1: {base_metrics['f1']}")
print(f"Distilled model inference time: {distill_time}, f1: {distill_metrics['f1']}")
print(f"Inference acceleration: {round(base_time/distill_time, 2)}, f1 degradation: {base_metrics['f1'] - distill_metrics['f1']}")

Original model inference time: 25.507636547088623, f1: 0.4667684195920632
Distilled model inference time: 1.730381965637207, f1: 0.42507580565545444
Inference acceleration: 14.74, f1 degradation: 0.04169261393660878


As we can see, distilled model is 15 times faster, but metric is only ~10% worse.

## Onnx conversion

In [13]:
onnx_directory = DATA_DIR + 'models/onnx/'

In [None]:
ort_model = ORTModelForTokenClassification.from_pretrained(distill_uk_checkpoint, export=True)

ort_model.save_pretrained(onnx_directory + 'regular/')

In [14]:
onnx_uk_model = ORTModelForTokenClassification.from_pretrained(onnx_directory + 'regular/', file='model.onnx')

In [15]:
onnx_time, onnx_metrics = time_model(onnx_uk_model, uk_eval_dataset)

print(f"Onnx model inference time: {onnx_time}, f1: {onnx_metrics['f1']}")
print(f"Inference acceleration: {round(distill_time/onnx_time, 2)}, f1 degradation: {distill_metrics['f1'] - onnx_metrics['f1']}")

Onnx model inference time: 1.6918556690216064, f1: 0.42507580565545444
Inference acceleration: 1.02, f1 degradation: 0.0


Onnx model is slower than distilled model. But since metric not degraded, using onnx model for quantization:

## Quantization

In [None]:
qconfig = AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
quantizer = ORTQuantizer.from_pretrained(onnx_uk_model)

quantizer.quantize(save_dir=onnx_directory + 'quantized/', quantization_config=qconfig)

In [16]:
onnx_quant_uk_model = ORTModelForTokenClassification.from_pretrained(onnx_directory + 'quantized/', file='model_quantized.onnx')

In [17]:
quant_time, quant_metrics = time_model(onnx_quant_uk_model, uk_eval_dataset)

print(f"Qunatized model inference time: {quant_time}, f1: {quant_metrics['f1']}")
print(f"Inference acceleration: {round(distill_time/quant_time, 2)}, f1 degradateion: {distill_metrics['f1'] - quant_metrics['f1']}")

Qunatized model inference time: 1.4115355014801025, f1: 0.41883185840707965
Inference acceleration: 1.23, f1 degradateion: 0.006243947248374793


As we can see, quantized model is a little bit faster than regular distilled model

## FastAPI service

Code for api and docker is in ./service folder

Predictions from distilled quantized onnx model:

In [18]:
request_data = {'texts': test_dataset['text'].sample(5).to_list()}
request_data

{'texts': ['У пам’ять про загиблого випускника НаУКМА збирають кошти на меморіальну стипендію\n\nСтипендія, яку заснували на честь Євгена Олефіренка, покриє плату за навчання для студента чи студентки на магістерській програмі «Історія» в Могилянці.\n\nНа початку повномасштабного вторгнення Євген мобілізувався до 206-го київського батальйону ТрО, був командиром взводу Першої окремої бригади спеціального призначення імені Івана Богуна.\nХлопець воював під Миколаєвом, був гранатометником і мінометником, навчав іноземних добровольців. Його підрозділ захищав Лисичанськ на Луганщині. І могилянець зумів вивести свій взвод з напівоточеного міста неушкодженим. Євген загинув під Бахмутом 7 липня 2022 року.\n\nНа кафедрі історії зазначають, що наразі вдалось зібрати повну суму на 2-й рік навчання, а для повної стипендії на 2023-2024 роки бракує лише 13 тис. грн. Необхідна сума: 40\u202f000 грн. Здійснити переказ можна за посиланням. \n\n🔷\xa0Підписатися на Telegram\xa0| Instagram\xa0| Facebook\x

In [19]:
requests.post('http://localhost:8088/extract_locations', json=request_data).text

'[["Микола"],[],["Києві"],["Україні","Київ"],["Київ"]]'

As we can see, distilled model misses a lot of locations and mostly predicts obvious ones: "Київ...", "Україна", etc.

This is probably happens because distilled model was trained on very small dataset and didn't see a lot of rare locations.

To solve this, model should be trained on bigger dataset.