In [None]:
!pip install optimum[onnxruntime] transformers

In [12]:
from optimum.onnxruntime import ORTModelForFeatureExtraction
import onnxruntime as onnxrt
from transformers import AutoTokenizer, pipeline

In [13]:
# Export the Hub checkpoint to ONNX, then write the model and its tokenizer
# into a single directory for the later cells to load. (~3 min)
model_checkpoint = "intfloat/multilingual-e5-large"
save_directory = "/tmp/onnx"

original_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
original_model = ORTModelForFeatureExtraction.from_pretrained(model_checkpoint, export=True)

original_model.save_pretrained(save_directory)
# Left as the final expression so the notebook echoes the tokenizer files written.
original_tokenizer.save_pretrained(save_directory)

Framework not specified. Using pt to export the model.
Using the export variant default. Available variants are:
    - default: The default ONNX variant.
Using framework PyTorch: 2.1.0+cu121
Overriding 1 configuration item(s)
	- use_cache -> False
Saving external data to one file...


('/tmp/onnx/tokenizer_config.json',
 '/tmp/onnx/special_tokens_map.json',
 '/tmp/onnx/sentencepiece.bpe.model',
 '/tmp/onnx/added_tokens.json',
 '/tmp/onnx/tokenizer.json')

In [14]:
!ls -lh /tmp/onnx

total 2.2G
-rw-r--r-- 1 root root  716 Feb 28 10:19 config.json
-rw-r--r-- 1 root root 534K Feb 28 10:19 model.onnx
-rw-r--r-- 1 root root 2.1G Feb 28 10:19 model.onnx_data
-rw-r--r-- 1 root root 4.9M Feb 28 10:19 sentencepiece.bpe.model
-rw-r--r-- 1 root root  964 Feb 28 10:19 special_tokens_map.json
-rw-r--r-- 1 root root 1.2K Feb 28 10:19 tokenizer_config.json
-rw-r--r-- 1 root root  17M Feb 28 10:19 tokenizer.json


In [15]:
import multiprocessing

# ORTModelForFeatureExtraction, AutoTokenizer, pipeline and onnxrt come from the
# import cell at the top of the notebook.

onnxrt_options = onnxrt.SessionOptions()

# Execute graph nodes one at a time, letting each operator use as many threads
# as multiprocessing.cpu_count() reports (logical CPUs).
onnxrt_options.execution_mode = onnxrt.ExecutionMode.ORT_SEQUENTIAL
onnxrt_options.intra_op_num_threads = multiprocessing.cpu_count()
# Alternative worth benchmarking: ExecutionMode.ORT_PARALLEL together with
# inter_op_num_threads, which runs independent graph branches concurrently.

onnxrt_options.graph_optimization_level = onnxrt.GraphOptimizationLevel.ORT_ENABLE_ALL
# Allow idle intra-op threads to spin before blocking: lower wake-up latency
# between operators at the cost of CPU time while idle.
onnxrt_options.add_session_config_entry('session.intra_op.allow_spinning', '1')

# Load the artifacts written by the export cell instead of re-exporting.
model = ORTModelForFeatureExtraction.from_pretrained(
    save_directory,
    session_options=onnxrt_options,
    providers=['CPUExecutionProvider'],
)

tokenizer = AutoTokenizer.from_pretrained(save_directory)

onnx_extractor = pipeline("feature-extraction", model=model, tokenizer=tokenizer)

# ~6s

In [16]:
# Smoke test: embed a single sentence with the ONNX-backed pipeline.
# NOTE(review): the multilingual-e5 model card asks for a "query: " or
# "passage: " prefix on every input; this demo sends the raw sentence. Confirm
# whether that is intended before using these vectors for retrieval.
text = "Giraffes live in Africa."
embedding = onnx_extractor(text)

In [17]:
# Show the width of embedding[0][0] and its first five values.
# NOTE(review): the feature-extraction pipeline returns per-token hidden states
# without pooling, so embedding[0][0] is the first token's vector for the single
# input, not a pooled sentence embedding.
print(len(embedding[0][0]), embedding[0][0][:5], sep="\n")

1024
[-0.3003837764263153, 0.7525163292884827, -0.5127310752868652, -1.2323561906814575, 1.0537605285644531]


In [18]:
from torch.utils.data import Dataset
from tqdm.auto import tqdm
from datetime import datetime

class EmbeddingDataset(Dataset):
    """Minimal map-style torch ``Dataset`` over an in-memory sequence.

    Wrapping the inputs this way lets a transformers pipeline iterate over
    them in batches.

    Args:
        data_list: Indexable, sized sequence of items (here: input strings).
    """

    def __init__(self, data_list):
        self.data_list = data_list

    def __len__(self):
        """Return the number of items in the wrapped sequence."""
        return len(self.data_list)

    def __getitem__(self, index):
        """Return the item stored at ``index``, unchanged."""
        return self.data_list[index]

# Benchmark batched embedding throughput on a synthetic corpus.
import time
from datetime import timedelta

REPEATS = 500     # each base text/metadata pair is repeated this many times
BATCH_SIZE = 100  # inputs per pipeline forward pass

text_list = ["Test One", "Test Two"] * REPEATS
metadata_list = ["Metadata One", "Metadata Two"] * REPEATS
# metadata_list[i] describes text_list[i]; later cells zip them together.
assert len(metadata_list) == len(text_list)

dataset = EmbeddingDataset(text_list)

embeddings_list = []

# perf_counter() is monotonic and high-resolution; datetime.now() is wall-clock
# time and can jump if the system clock is adjusted during the run.
start_time = time.perf_counter()

for output in tqdm(onnx_extractor(dataset, batch_size=BATCH_SIZE), total=len(dataset), desc="embedding"):
    # Each per-input output is a one-element outer list; extend() unwraps it so
    # embeddings_list gains exactly one entry (a list of token vectors) per text.
    embeddings_list.extend(output)

end_time = time.perf_counter()

assert len(embeddings_list) == len(dataset), "expected one embedding per input text"

# Printed as a timedelta to keep the H:MM:SS.ffffff format.
print(timedelta(seconds=end_time - start_time))

0:00:38.301701


In [19]:
# Sanity check: number of embedded inputs, then the width of one token vector.
print(len(embeddings_list), len(embeddings_list[0][0]), sep="\n")

1000
1024


In [20]:
# Preview the first ten rows: metadata, source text, and the first five values
# of embeddings_item[0] (the first token's vector; no pooling is applied).
preview_rows = zip(metadata_list[:10], text_list[:10], embeddings_list[:10])
for metadata_item, text_item, embeddings_item in preview_rows:
    print(f'{metadata_item} - {text_item} - {embeddings_item[0][:5]}')

Metadata One - Test One - [0.580862820148468, 0.5749151110649109, -0.12789195775985718, -2.678431987762451, 1.2358945608139038]
Metadata Two - Test Two - [0.9441616535186768, 0.4838103652000427, -0.3671732544898987, -2.198167562484741, 0.7757053375244141]
Metadata One - Test One - [0.580862820148468, 0.5749151110649109, -0.12789195775985718, -2.678431987762451, 1.2358945608139038]
Metadata Two - Test Two - [0.9441616535186768, 0.4838103652000427, -0.3671732544898987, -2.198167562484741, 0.7757053375244141]
Metadata One - Test One - [0.580862820148468, 0.5749151110649109, -0.12789195775985718, -2.678431987762451, 1.2358945608139038]
Metadata Two - Test Two - [0.9441616535186768, 0.4838103652000427, -0.3671732544898987, -2.198167562484741, 0.7757053375244141]
Metadata One - Test One - [0.580862820148468, 0.5749151110649109, -0.12789195775985718, -2.678431987762451, 1.2358945608139038]
Metadata Two - Test Two - [0.9441616535186768, 0.4838103652000427, -0.3671732544898987, -2.1981675624847