In [79]:
from gliner import GLiNER

# Load the pretrained GLiNER checkpoint identified by its model id.
# `model` is the PyTorch-backed object that the export cells below convert to ONNX.
model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1")


Fetching 4 files: 100%|██████████| 4/4 [00:00<00:00, 39016.78it/s]


In [80]:
# Write the loaded model to a local directory; the ONNX files produced later
# in this notebook are saved into this same directory.
model.save_pretrained("../models/gliner_medium-v2.1")

In [81]:
# Destination file for the exported ONNX graph (read again by the quantization cell).
ONNX_SAVE_PATH = "../models/gliner_medium-v2.1/model.onnx"

In [82]:
# Sample sentence and label set used only to build example tensors for the export.
text = (
    "ONNX is an open-source format designed to enable the interoperability "
    "of AI models across various frameworks and tools."
)
labels = ['format', 'model', 'tool', 'cat']

# prepare_model_inputs returns a pair; only the first element (the model
# inputs, indexed by name in the export cell) is needed here.
inputs, _ = model.prepare_model_inputs([text], labels)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [84]:
import torch

# Inputs and dynamic axes that both span modes have in common.
input_names = ['input_ids', 'attention_mask', 'words_mask', 'text_lengths']
dynamic_axes = {
    "input_ids": {0: "batch_size", 1: "sequence_length"},
    "attention_mask": {0: "batch_size", 1: "sequence_length"},
    "words_mask": {0: "batch_size", 1: "sequence_length"},
    "text_lengths": {0: "batch_size", 1: "value"},
}

if model.config.span_mode == 'token_level':
    # Token-level models are exported with the shared inputs only.
    dynamic_axes["logits"] = {0: "position", 1: "batch_size", 2: "sequence_length", 3: "num_classes"}
else:
    # Every other span mode also feeds the span index / span mask tensors.
    input_names = input_names + ['span_idx', 'span_mask']
    dynamic_axes["span_idx"] = {0: "batch_size", 1: "num_spans", 2: "idx"}
    dynamic_axes["span_mask"] = {0: "batch_size", 1: "num_spans"}
    dynamic_axes["logits"] = {0: "batch_size", 1: "sequence_length", 2: "num_spans", 3: "num_classes"}

# Example tensors handed to the exporter, in the same order as input_names.
all_inputs = tuple(inputs[name] for name in input_names)

print('Converting the model...')
torch.onnx.export(
    model.model,
    all_inputs,
    f=ONNX_SAVE_PATH,
    input_names=input_names,
    output_names=["logits"],
    dynamic_axes=dynamic_axes,
    opset_version=14,
)

Converting the model...




In [85]:
# Dynamic quantization of the exported ONNX model.
import os
from onnxruntime.quantization import quantize_dynamic, QuantType

# The quantized graph is written next to the full-precision export.
quantized_save_path = "../models/gliner_medium-v2.1/model_quantized.onnx"

print("Quantizing the model...")
quantize_dynamic(
    model_input=ONNX_SAVE_PATH,
    model_output=quantized_save_path,
    weight_type=QuantType.QUInt8,  # store weights as unsigned 8-bit integers
)

Quantizing the model...


  elem_type: 7
  shape {
    dim {
      dim_value: 2
    }
    dim {
      dim_param: "unk__788"
    }
  }
}
.
  elem_type: 7
  shape {
    dim {
      dim_value: 2
    }
    dim {
      dim_param: "unk__789"
    }
  }
}
.
  elem_type: 7
  shape {
    dim {
      dim_value: 2
    }
    dim {
      dim_param: "unk__803"
    }
  }
}
.


In [86]:
from gliner import GLiNER

# Reload from the local directory, this time requesting the ONNX export and
# the tokenizer stored there. This rebinds `model`, replacing the object
# that was loaded at the top of the notebook.
model = GLiNER.from_pretrained(
    "../models/gliner_medium-v2.1",
    load_onnx=True,
    load_tokenizer=True,
)

config.json not found in /workspaces/NER-project/models/gliner_medium-v2.1
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [87]:
# Sentence and label set for the inference check against the exported model.
text2 = (
    "My name is Tom, I live in New York and my girlfriend's name is Elaine. "
    "Our parents live in Viet Nam, Nha Trang city, and their names are Que and Mai"
)
labels = ['Person', 'Place']

# `inputs` holds the named model inputs; `raw_batch` carries the tokens and
# the id-to-class mapping that the decoding cell reads.
inputs, raw_batch = model.prepare_model_inputs([text2], labels)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [88]:
import onnxruntime as ort
import numpy as np
# Open an ONNX Runtime session on the full-precision export (model.onnx),
# not on the quantized file (model_quantized.onnx).
ort_sess = ort.InferenceSession('../models/gliner_medium-v2.1/model.onnx')

In [89]:
import torch

# Build the feed from the inputs the loaded graph actually declares instead of
# a fixed list: a graph exported from a token-level model has no
# 'span_idx' / 'span_mask' inputs, and ONNX Runtime rejects unknown feed names.
# detach().cpu() makes the conversion to NumPy safe for any torch tensor.
ort_feed = {
    graph_input.name: inputs[graph_input.name].detach().cpu().numpy()
    for graph_input in ort_sess.get_inputs()
}

# run(None, ...) returns every graph output; the first one holds the logits.
outputs = ort_sess.run(None, ort_feed)[0]

# The decoding cell below receives the logits as a torch tensor.
outputs = torch.from_numpy(outputs)

In [90]:
# Turn the raw logits into entity spans using the model's decoder.
decoded_batch = model.decoder.decode(
    raw_batch["tokens"],
    raw_batch["id_to_classes"],
    outputs,
    flat_ner=True,
    threshold=0.5,
    multi_label=False,
)

# A single text was submitted, so keep the first entry of the batch.
# NOTE: `outputs` is rebound from the logits tensor to the decoded list, so
# the inference cell must be re-run before this cell can be run again.
outputs = decoded_batch[0]

outputs

[(3, 3, 'Person', 0.9672063589096069),
 (8, 9, 'Place', 0.8985015153884888),
 (17, 17, 'Person', 0.9670381546020508),
 (23, 24, 'Place', 0.9371719360351562),
 (26, 28, 'Place', 0.9003996849060059),
 (34, 34, 'Person', 0.8820420503616333),
 (36, 36, 'Person', 0.7397370934486389)]

In [91]:
# Tokens of the single input text; decoded spans index into this list.
texts = raw_batch['tokens'][0]

for span in outputs:
    start, end, entity = span[0], span[1], span[2]
    # `end` is inclusive, hence the +1 when slicing.
    span_tokens = texts[start:end + 1]
    print(f"{span_tokens} => {entity}")

['Tom'] => Person
['New', 'York'] => Place
['Elaine'] => Person
['Viet', 'Nam'] => Place
['Nha', 'Trang', 'city'] => Place
['Que'] => Person
['Mai'] => Person


In [20]:
# Disabled cell: saves the exported ONNX model through BentoML.
# Uncomment to run; it needs the `bentoml` and `onnx` packages and ONNX_SAVE_PATH.
# import bentoml
# import onnx

# model_onnx = onnx.load(ONNX_SAVE_PATH)
# signatures = {
#     "run": {"batchable": True},
# }
# bento_model = bentoml.onnx.save_model("onnx_ner", model_onnx, signatures=signatures)
# print(bento_model.tag)