In [46]:
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
from onnxruntime.transformers import optimizer

signal = 'TIS'

# ONNX artifacts for this signal's model, all living in the same `onnx/` folder:
#   model_fp32 -> exported FP32 graph (input)
#   model_opt  -> graph after the transformer optimizer (intermediate)
#   model_int8 -> dynamically quantized INT8 graph (final output)
# NOTE: these names hold *file paths* in this cell; the model-loading cell
# later rebinds `model_fp32` / `model_int8` to loaded ORT model objects.
model_fp32, model_opt, model_int8 = (
    f"./models/{signal}_model/onnx/{stem}.onnx"
    for stem in ("model", "model_optimized", "model_quantized")
)

# Step 1: run onnxruntime's transformer graph optimizer with the BERT rule set
# (also used here for the DistilBERT-style export). `num_heads` and
# `hidden_size` must agree with the exported model's architecture; the values
# below are assumed to match this model's config — verify if the model changes.
print(f"1. Optimizing graph for {model_fp32}...")
optimized_model = optimizer.optimize_model(
    model_fp32,
    model_type='bert',
    num_heads=4,
    hidden_size=256,
)
optimized_model.save_model_to_file(model_opt)

# Step 2: dynamic (weight-only ahead-of-time) INT8 quantization of the
# optimized graph. DefaultTensorType supplies FLOAT as the fallback element
# type for tensors whose type the quantizer cannot otherwise determine.
extra_options = dict(DefaultTensorType=onnx.TensorProto.FLOAT)

print(f"2. Quantizing {model_opt} to INT8...")
quantize_dynamic(
    model_input=model_opt,
    model_output=model_int8,
    weight_type=QuantType.QInt8,
    extra_options=extra_options,
)

print(f"Done! Final model: {model_int8}")

1. Optimizing graph for ./models/TIS_model/onnx/model.onnx...




2. Quantizing ./models/TIS_model/onnx/model_optimized.onnx to INT8...
Done! Final model: ./models/TIS_model/onnx/model_quantized.onnx


In [47]:
import time
import numpy as np
import torch
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSequenceClassification
from sklearn.metrics import accuracy_score, f1_score

# Serve both variants through optimum's ORT wrapper, pinned to the CPU
# execution provider so the latency comparison below is like-for-like.
# NOTE: this rebinds `model_fp32` / `model_int8`, which held ONNX file paths in
# the optimisation/quantisation cell.
# NOTE(review): that cell writes the files under ./models/{signal}_model/onnx/;
# loading from the parent directory presumably relies on optimum locating
# `file_name` inside that subfolder — verify if the load fails.
ort_model_dir = f"./models/{signal}_model"

model_fp32 = ORTModelForSequenceClassification.from_pretrained(
    ort_model_dir, file_name="model.onnx", provider="CPUExecutionProvider"
)
model_int8 = ORTModelForSequenceClassification.from_pretrained(
    ort_model_dir, file_name="model_quantized.onnx", provider="CPUExecutionProvider"
)

In [48]:
def evaluate_model(model, inputs, n_runs=1, warmup=0):
    """Run ``model`` on one batch of token ids, returning predictions and latency.

    Args:
        model: Callable accepting an ``input_ids=`` keyword and returning an
            object with a ``.logits`` attribute (e.g. an
            ``ORTModelForSequenceClassification``).
        inputs: Token-id batch passed as ``input_ids``; presumably a
            (batch, seq_len) integer tensor from the collate function — confirm.
        n_runs: Number of timed forward passes averaged into the latency.
            Defaults to 1, i.e. a single measured call.
        warmup: Number of untimed forward passes run first, so one-off
            first-call costs do not inflate the measured latency. Defaults to 0.

    Returns:
        ``(preds, latency_ms)``: the argmax class per row (list of int, taken
        from the last timed run) and the mean latency of the timed runs in
        milliseconds.

    Raises:
        ValueError: if ``n_runs`` is smaller than 1.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    def _forward():
        # inference_mode is a no-op for ONNX Runtime sessions but keeps any
        # torch-backed model from recording autograd state.
        with torch.inference_mode():
            return model(input_ids=inputs)

    for _ in range(warmup):
        _forward()

    latencies = []
    for _ in range(n_runs):
        start = time.perf_counter()
        outputs = _forward()
        latencies.append(time.perf_counter() - start)

    # Predicted class = index of the largest logit in each row.
    preds = torch.argmax(outputs.logits, dim=1).tolist()

    return preds, np.mean(latencies) * 1000  # seconds -> milliseconds

In [49]:
def gen_collate_fn(augment=False):
    """Return a DataLoader ``collate_fn`` that stacks rows into a model batch.

    Each row must provide ``'upstream'`` and ``'downstream'`` token lists plus a
    ``'label'``. The collate function keeps every third token of both lists,
    starting at a common offset, and concatenates the two halves along the
    sequence axis.

    Args:
        augment: When True, the offset is drawn at random from {1, 2, 3} once
            per batch (using NumPy's global RNG); otherwise it is fixed at 3.

    Returns:
        A callable mapping a list of rows to ``(token_ids, labels)`` tensors.
    """
    def collate_fn(batch):
        # Per the original note, offsets 1..3 are meant to keep every sequence
        # at 99 tokens — this depends on the input list lengths; confirm.
        if augment:
            offset = np.random.randint(3) + 1
        else:
            offset = 3

        upstream = torch.as_tensor([example['upstream'][offset::3] for example in batch])
        downstream = torch.as_tensor([example['downstream'][offset::3] for example in batch])
        labels = torch.as_tensor([example['label'] for example in batch])

        return torch.cat([upstream, downstream], dim=1), labels

    return collate_fn

In [51]:
from datasets import load_dataset, DatasetDict

# The Hub dataset ships a single 'train' split. Shuffle it once with a fixed
# seed, then slice it in order (no further shuffling):
#   25% -> test, and of the remaining 75%: 80% -> train, 20% -> val.
dataset = load_dataset('dvgodoy/DeepGSR_trinucleotides', split='train').shuffle(seed=19)
train_test = dataset.train_test_split(test_size=0.25, shuffle=False)
train_val = train_test['train'].train_test_split(test_size=0.2, shuffle=False)
dataset = DatasetDict({
    'train': train_val['train'],
    'val': train_val['test'],
    'test': train_test['test'],
})
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'organism', 'motif', 'label', 'signal', 'upstream', 'downstream'],
        num_rows: 215897
    })
    val: Dataset({
        features: ['id', 'organism', 'motif', 'label', 'signal', 'upstream', 'downstream'],
        num_rows: 53975
    })
    test: Dataset({
        features: ['id', 'organism', 'motif', 'label', 'signal', 'upstream', 'downstream'],
        num_rows: 89958
    })
})

In [52]:
# Restrict every split to the chosen signal and motif.
# (Another motif used with this notebook previously: 'AATAAA'.)
motif = 'ATG'
dataset = dataset.filter(
    lambda row: row['signal'] == signal and row['motif'] == motif
)

Filter:   0%|          | 0/215897 [00:00<?, ? examples/s]

Filter:   0%|          | 0/53975 [00:00<?, ? examples/s]

Filter:   0%|          | 0/89958 [00:00<?, ? examples/s]

In [53]:
from torch.utils.data import DataLoader
# Only the held-out split is needed here: a fixed order (no shuffling) and the
# non-augmenting collate function, so batches are reproducible.
bsize = 256
dataloaders = {
    'test': DataLoader(
        dataset['test'],
        batch_size=bsize,
        shuffle=False,
        collate_fn=gen_collate_fn(),
    ),
}
idl = iter(dataloaders['test'])

In [54]:
# Take one (token_ids, labels) batch for the comparison. `idl` is a stateful
# iterator: re-running this cell yields the *following* batch, not the first.
(inputs, labels) = next(idl)

In [56]:
# Compare the FP32 and INT8 models on the single batch pulled above.
# NOTE: these metrics cover one batch of `bsize` examples with one timed call
# per model, so treat them as rough estimates rather than test-set results.
print("\n--- Running Comparison ---")

preds_fp32, time_fp32 = evaluate_model(model_fp32, inputs)
preds_int8, time_int8 = evaluate_model(model_int8, inputs)

# Accuracy of each model against the labels, and how often the two models
# predict the same class irrespective of the label.
acc_fp32, acc_int8 = (accuracy_score(labels, p) for p in (preds_fp32, preds_int8))
agreement = accuracy_score(preds_fp32, preds_int8)

for _tag, _acc, _ms in (("FP32", acc_fp32, time_fp32), ("INT8", acc_int8, time_int8)):
    print(f"{_tag} Accuracy: {_acc:.4f} | Latency: {_ms:.2f}ms")
print(f"Model Agreement: {agreement * 100:.2f}%")
print(f"Speedup: {time_fp32 / time_int8:.2f}x faster")


--- Running Comparison ---
FP32 Accuracy: 0.9180 | Latency: 10044.36ms
INT8 Accuracy: 0.9219 | Latency: 6568.66ms
Model Agreement: 99.61%
Speedup: 1.53x faster
