In [1]:
import torch
import torch.nn as nn
from transformers import (
    BertForSequenceClassification,
    AutoTokenizer,
    AutoModel,
    BertPreTrainedModel,
)
from transformers import pipeline
from transformers import BertTokenizer, BertForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput
import torch.nn as nn
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
)

In [2]:
from typing import Union


class BertWithCustomHead(BertForSequenceClassification):
    """BERT sequence classifier with a small MLP classification head.

    The head maps the encoder's pooled output through
    ``hidden_size -> 128 -> 64 -> num_labels`` with ReLU activations and
    dropout (p=0.3) applied to the pooled output and after the first layer.
    """

    def __init__(self, config):
        """Build the parent model, then install the custom head.

        Args:
            config: Model configuration; ``hidden_size`` and ``num_labels``
                are read from it to size the head.
        """
        super().__init__(config)

        # Assigning `self.classifier` overrides whatever head the parent
        # constructor created under the same attribute name.
        self.custom_dropout = nn.Dropout(p=0.3)
        self.classifier = nn.Sequential(
            nn.Linear(self.config.hidden_size, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, self.config.num_labels),
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        return_dict=True,
    ) -> Union[tuple[torch.Tensor, ...], SequenceClassifierOutput]:
        """Run the encoder and the custom head.

        Args:
            input_ids: Token ids forwarded to the BERT encoder.
            attention_mask: Attention mask forwarded to the BERT encoder.
            token_type_ids: Segment ids forwarded to the BERT encoder.
            labels: Optional targets; when given, a cross-entropy loss
                between the logits and ``labels`` is computed.
            return_dict: If truthy, return a ``SequenceClassifierOutput``;
                otherwise return a tuple. ``None`` falls back to
                ``config.use_return_dict``.

        Returns:
            ``SequenceClassifierOutput`` (loss, logits, hidden_states,
            attentions) or, when ``return_dict`` is falsy, the tuple
            ``(loss?, logits, hidden_states?, attentions?)`` where optional
            entries are omitted when they are ``None``.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Always request a structured output from the encoder so the
        # attribute accesses below work even when the caller asked this
        # wrapper for a tuple (previously that path raised AttributeError).
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )

        # Pooled output -> dropout -> MLP head
        pooled_output = self.custom_dropout(outputs.pooler_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)

        if not return_dict:
            # Tuple layout: logits first, optional extras afterwards,
            # loss prepended when it was computed.
            extras = tuple(
                item
                for item in (outputs.hidden_states, outputs.attentions)
                if item is not None
            )
            output = (logits,) + extras
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# Add the custom class name to the sequence-classification name mapping.
# NOTE(review): presumably this lets `pipeline("text-classification", ...)`
# treat BertWithCustomHead as a supported model class — confirm.
# To undo the registration:
#   del MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES["bert_with_custom_head"]
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.update(
    {"bert_with_custom_head": "BertWithCustomHead"}
)

In [3]:
# Load the saved checkpoint into the custom-head class.
checkpoint_dir = "./roman-classifier"
model = BertWithCustomHead.from_pretrained(checkpoint_dir)

# Echo the head so the cell output shows which classifier layers were loaded.
model.classifier

Sequential(
  (0): Linear(in_features=128, out_features=128, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.3, inplace=False)
  (3): Linear(in_features=128, out_features=64, bias=True)
  (4): ReLU()
  (5): Linear(in_features=64, out_features=3, bias=True)
)

In [7]:
from optimum.onnxruntime import ORTModelForSequenceClassification
from optimum.exporters.tasks import TasksManager
def _force_transformers_library(*args, **kwargs):
    """Ignore all arguments and report "transformers" as the library."""
    return "transformers"


def _force_custom_head_class(*args, **kwargs):
    """Ignore all arguments and return the custom-head model class."""
    return BertWithCustomHead


# Override the TasksManager lookups so they always resolve to the
# transformers library and to BertWithCustomHead, whatever is requested.
TasksManager.infer_library_from_model = _force_transformers_library
TasksManager.get_model_class_for_task = _force_custom_head_class

In [8]:
model_checkpoint = "./roman-classifier"
save_directory = "onnx/"

# Export the checkpoint to ONNX. Model and tokenizer both read from
# `model_checkpoint`, so changing the path in one place keeps them in sync.
ort_model = ORTModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    export=True,
    task="text-classification",
)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Save the onnx model and tokenizer
ort_model.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)

I have loaded the model Sequential(
  (0): Linear(in_features=128, out_features=128, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.3, inplace=False)
  (3): Linear(in_features=128, out_features=64, bias=True)
  (4): ReLU()
  (5): Linear(in_features=64, out_features=3, bias=True)
)


  if not return_dict:


('onnx/tokenizer_config.json',
 'onnx/special_tokens_map.json',
 'onnx/vocab.txt',
 'onnx/added_tokens.json',
 'onnx/tokenizer.json')

In [9]:
# Text-classification pipeline backed by the exported ONNX model.
onnx_inference_pipeline = pipeline(
    task="text-classification", model=ort_model, tokenizer=tokenizer
)

Device set to use mps:0


In [22]:
# Load the PyTorch model from the same checkpoint directory.
model_path = "./roman-classifier"
model = BertWithCustomHead.from_pretrained(
    pretrained_model_name_or_path=model_path
)

In [23]:
# Text-classification pipeline backed by the PyTorch model.
inference_pipeline = pipeline(
    task="text-classification", model=model, tokenizer=tokenizer
)

Device set to use mps:0


In [26]:
# Spot-check the PyTorch pipeline on a handful of sentences.
sample_sentences = [
    "Namaste, tapai sanchai hunuhuncha?",
    "dattebayo chan",
    "k cha khabar",
    "hi there, how you doing?",
    "arigato gozaimasu",
]
inference_pipeline(sample_sentences)

[{'label': 'LABEL_2', 'score': 0.4014548063278198},
 {'label': 'LABEL_2', 'score': 0.40924739837646484},
 {'label': 'LABEL_2', 'score': 0.42295145988464355},
 {'label': 'LABEL_0', 'score': 0.41292551159858704},
 {'label': 'LABEL_2', 'score': 0.3565642833709717}]

In [25]:
# Spot-check the ONNX pipeline on a handful of sentences.
sample_sentences = [
    "Namaste, tapai sanchai hunuhuncha?",
    "dattebayo chan",
    "k cha khabar",
    "hi there, how you doing?",
    "arigato gozaimasu",
]
onnx_inference_pipeline(sample_sentences)

[{'label': 'LABEL_2', 'score': 0.4014548659324646},
 {'label': 'LABEL_2', 'score': 0.40924733877182007},
 {'label': 'LABEL_2', 'score': 0.42295145988464355},
 {'label': 'LABEL_0', 'score': 0.41292551159858704},
 {'label': 'LABEL_2', 'score': 0.3565642833709717}]

In [17]:
import time
import pandas as pd

# Benchmark corpus; later cells read its `sentences` column.
dataset_path = "datasets/dataset.csv"
df = pd.read_csv(dataset_path)

In [18]:
# Benchmark the PyTorch pipeline: 10 full passes over the dataset,
# one sentence per call, keeping the top prediction for each row.
model.eval()
for _ in range(10):
    # perf_counter is monotonic and high-resolution, so the interval
    # cannot be skewed by system clock adjustments.
    start = time.perf_counter()
    df["preds"] = df.sentences.apply(lambda x: inference_pipeline(x)[0])
    end = time.perf_counter()
    print("Total time taken: ", end - start)

Total time taken:  6.128638982772827
Total time taken:  4.91447901725769
Total time taken:  4.506620168685913
Total time taken:  4.699854850769043
Total time taken:  4.491797924041748
Total time taken:  4.353509902954102
Total time taken:  4.253670692443848
Total time taken:  4.248823165893555
Total time taken:  4.246496200561523
Total time taken:  4.304681777954102


In [19]:
# Benchmark the ONNX pipeline: 10 full passes over the dataset,
# one sentence per call, keeping the top prediction for each row.
# (No `model.eval()` here: the PyTorch `model` is not used by this loop.)
for _ in range(10):
    # perf_counter is monotonic and high-resolution, so the interval
    # cannot be skewed by system clock adjustments.
    start = time.perf_counter()
    df["preds"] = df.sentences.apply(lambda x: onnx_inference_pipeline(x)[0])
    end = time.perf_counter()
    print("Total time taken: ", end - start)

Total time taken:  1.264888048171997
Total time taken:  1.3372550010681152
Total time taken:  1.3093030452728271
Total time taken:  1.496689796447754
Total time taken:  1.380295991897583
Total time taken:  1.290701150894165
Total time taken:  1.3331658840179443
Total time taken:  1.6963269710540771
Total time taken:  2.06199312210083
Total time taken:  1.2847800254821777
