# In [3]:
from sentence_transformers import SentenceTransformer
import torch
from transformers import AutoTokenizer
import os

# Directory that receives the exported ONNX graph.
onnx_path = "../data/onnx_model"

# Make sure the directory (and any missing parents) exists; exist_ok keeps
# re-running this cell from failing when the directory is already there.
os.makedirs(name=onnx_path, exist_ok=True)


# Wrap the full model (including pooling + dense layer)
class STEncoderWrapper(torch.nn.Module):
    """Expose a sentence-embedding pipeline as one traceable ``nn.Module``.

    Chains the transformer backbone, masked mean pooling and the dense
    projection (linear layer followed by its activation) so the whole
    pipeline can be called with plain ``input_ids`` / ``attention_mask``
    tensors, e.g. by ``torch.onnx.export``.

    Args:
        st_model: Model whose ``_modules`` mapping holds the transformer
            under ``'0'`` (exposing ``auto_model``), the pooling module
            under ``'1'`` and the dense module under ``'2'`` (exposing
            ``linear`` and ``activation_function``).
    """

    def __init__(self, st_model):
        super().__init__()
        modules = st_model._modules  # OrderedDict
        self.transformer = modules['0'].auto_model
        # Stored but not called: forward() always applies mean_pooling()
        # below, regardless of how this pooling module is configured.
        self.pooling = modules['1']
        # Grab the inner layers directly
        self.linear = modules['2'].linear  # nn.Linear
        self.activation = modules['2'].activation_function

    def mean_pooling(self, token_embeddings, attention_mask):
        """Average token embeddings over dim 1, ignoring masked positions.

        Args:
            token_embeddings: Per-token embeddings; the mask is broadcast to
                this tensor's size (presumably ``(batch, seq_len, hidden)``).
            attention_mask: Mask with one entry per token, non-zero for
                tokens that should contribute to the average.

        Returns:
            The masked mean over dim 1. A row whose mask is entirely zero
            yields zeros instead of NaN.
        """
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
        # Clamp the token count so a fully masked row divides by a tiny
        # positive number rather than zero (0 / 0 would produce NaN).
        sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
        return sum_embeddings / sum_mask

    def forward(self, input_ids, attention_mask):
        """Encode token ids into projected sentence embeddings.

        Args:
            input_ids: Token id tensor passed to the transformer.
            attention_mask: Mask passed to the transformer and reused for
                pooling.

        Returns:
            The activation of the linear projection of the pooled embedding.
        """
        output = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        # First element of the transformer output holds the token embeddings.
        token_embeddings = output[0]
        pooled = self.mean_pooling(token_embeddings, attention_mask)
        projected = self.activation(self.linear(pooled))
        return projected


# Load model and tokenizer
model_id = "sentence-transformers/distiluse-base-multilingual-cased-v2"
model = SentenceTransformer(model_id)
wrapper = STEncoderWrapper(model)
# Put the wrapper and all of its submodules in inference mode explicitly so
# the traced graph does not depend on the exporter's default training flag.
wrapper.eval()

tokenizer = AutoTokenizer.from_pretrained(model_id)
# A single example sentence is enough for tracing; batch and sequence
# dimensions are declared dynamic below.
dummy_inputs = tokenizer("Exporting to ONNX is fun!", return_tensors="pt")

# Export to ONNX — this model will output [batch_size, 512]
torch.onnx.export(
    wrapper,
    (dummy_inputs["input_ids"], dummy_inputs["attention_mask"]),
    os.path.join(onnx_path, "model.onnx"),
    input_names=["input_ids", "attention_mask"],
    output_names=["sentence_embedding"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "attention_mask": {0: "batch_size", 1: "seq_len"},
        "sentence_embedding": {0: "batch_size"},
    },
    opset_version=14,
)

# Save tokenizer (optional, but useful) next to the exported model. Reuse
# onnx_path so the model and tokenizer always land in the same directory.
tokenizer.save_pretrained(onnx_path)


# Cell output (presumably the value returned by tokenizer.save_pretrained):
# ('../data/onnx_model/tokenizer_config.json',
#  '../data/onnx_model/special_tokens_map.json',
#  '../data/onnx_model/vocab.txt',
#  '../data/onnx_model/added_tokens.json',
#  '../data/onnx_model/tokenizer.json')