In [None]:
import torch
import torch.onnx
from transformers import ViTConfig, ViTForImageClassification
import torch.nn as nn
import onnx
from onnxsim import simplify
from onnxruntime.quantization import quantize_dynamic, QuantType

In [None]:
# Rebuild the fine-tuned ViT emotion classifier, load its checkpoint and
# export it to ONNX with a dynamic batch dimension.
device = torch.device("cpu")
class_names = ["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"]
# Derived from the label list so the head size can never drift from the labels.
num_classes = len(class_names)

CHECKPOINT_PATH = '/content/best_vit.pth'
ONNX_PATH = "emotion_vit_model.onnx"
# PyTorch's TorchScript exporter only supports `scaled_dot_product_attention`
# from opset 14 onwards; exporting at a lower opset raises
# UnsupportedOperatorError when the model uses that op.
# NOTE(review): recent transformers releases appear to use SDPA attention for
# ViT by default -- confirm against the installed version.
ONNX_OPSET = 14

config = ViTConfig(num_labels=num_classes)
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    config=config,
    ignore_mismatched_sizes=True
)

# Swap the stock linear head for Dropout + Linear.
# NOTE(review): presumably this mirrors the head used during fine-tuning; the
# strict `load_state_dict` below will raise if the checkpoint keys disagree.
in_features = model.classifier.in_features
model.classifier = nn.Sequential(
    nn.Dropout(p=0.4),
    nn.Linear(in_features, num_classes)
)

# SECURITY: `weights_only=False` unpickles arbitrary Python objects and can
# execute code embedded in the file. Only load checkpoints you trust; switch to
# `weights_only=True` if the checkpoint holds nothing but tensors/primitives.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)

# Resolve the actual state dict once, whatever container the checkpoint uses:
#   * a whole pickled nn.Module        -> take its state_dict()
#   * a training checkpoint dictionary -> unwrap the known wrapper key
#   * anything else                    -> assume it already is a state dict
state_dict = checkpoint
if isinstance(checkpoint, nn.Module):
    state_dict = checkpoint.state_dict()
elif isinstance(checkpoint, dict):
    for wrapper_key in ('model_state_dict', 'state_dict'):
        if wrapper_key in checkpoint:
            state_dict = checkpoint[wrapper_key]
            break

model.load_state_dict(state_dict)

# Inference mode: disables the dropout layer in the classifier head.
model.eval()

# Example input used to trace the model; only the batch axis is exported as dynamic.
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    ONNX_PATH,
    export_params=True,
    opset_version=ONNX_OPSET,
    do_constant_folding=True,
    input_names=['pixel_values'],
    output_names=['logits'],
    dynamic_axes={
        'pixel_values': {0: 'batch_size'},
        'logits': {0: 'batch_size'}
    }
)

# Reload the exported file and run the ONNX checker on it.
onnx_model = onnx.load(ONNX_PATH)
onnx.checker.check_model(onnx_model)


In [None]:
# Simplify the exported ONNX graph with onnxsim, then apply dynamic
# quantization (QUInt8 weights) with onnxruntime.
EXPORTED_ONNX_PATH = "emotion_vit_model.onnx"
SIMPLIFIED_ONNX_PATH = "emotion_vit_model_simplified.onnx"
QUANTIZED_ONNX_PATH = "emotion_vit_model_quantized.onnx"

# Use a dedicated name: rebinding `model` here would clobber the PyTorch model
# defined in the export cell with an ONNX ModelProto.
exported_model = onnx.load(EXPORTED_ONNX_PATH)
model_simp, check = simplify(exported_model)

if check:
    print("Model simplified successfully")
    onnx.save(model_simp, SIMPLIFIED_ONNX_PATH)
    quantization_source = SIMPLIFIED_ONNX_PATH
else:
    # Previously this case was silent and produced no quantized model at all.
    # Fall back to the unsimplified export so the final artifact still exists.
    print("WARNING: simplified model failed validation; quantizing the unsimplified export instead")
    quantization_source = EXPORTED_ONNX_PATH

quantize_dynamic(
    quantization_source,
    QUANTIZED_ONNX_PATH,
    weight_type=QuantType.QUInt8
)
