In [None]:
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch


# Checkpoint identifier shared by the tokenizer and the classification model.
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"

# Load the pretrained tokenizer and model, then switch the model to
# evaluation mode before running or exporting it.
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()

# Length that the sample input is padded or truncated to
# (choose based on your dataset).
max_length = 512

# Tokenize one sample sentence into PyTorch tensors of exactly `max_length`
# tokens; these tensors serve as the example inputs for the export below.
text = "Hello, my dog is cute"
inputs = tokenizer(
    text,
    return_tensors="pt",
    max_length=max_length,
    padding="max_length",
    truncation=True,
)

# Run a single forward pass with gradient tracking disabled as a sanity
# check that the model accepts the tokenized inputs.
with torch.no_grad():
    logits = model(**inputs).logits

# Names given to the ONNX graph inputs, in the order the example tensors
# are passed to the model, and to its single output.
onnx_input_names = ["input_ids", "attention_mask"]
onnx_output_names = ["logits"]

# Both inputs are exported with a dynamic batch axis (0) and a dynamic
# sequence axis (1); the output is dynamic along the batch axis only.
onnx_dynamic_axes = {
    name: {0: "batch_size", 1: "sequence"} for name in onnx_input_names
}
onnx_dynamic_axes["logits"] = {0: "batch_size"}

# Export the model, with its parameters included, to "distilbert.onnx".
torch.onnx.export(
    model,
    args=tuple(inputs[name] for name in onnx_input_names),
    f="distilbert.onnx",
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=onnx_input_names,
    output_names=onnx_output_names,
    dynamic_axes=onnx_dynamic_axes,
)
