In [None]:
import os
import torch
from torchvision import models
import torch.onnx

# Input checkpoint (PyTorch state_dict) and output ONNX file locations.
# NOTE(review): these are machine-specific absolute paths; adjust per environment.
model_path = "/home/mostafabakr/Desktop/Project X/Final_models/asl_image_model.pth"
onnx_model_path = "/home/mostafabakr/Desktop/Project X/hardware/asl_image_model.onnx"
# Historical misspelling kept as an alias so existing references keep working.
onxx_model_path = onnx_model_path

def convert_to_onnx(model_path, onnx_model_path, input_size=(1, 3, 224, 224), *,
                    num_classes=24, opset_version=11):
    """Load a fine-tuned MobileNetV2 state_dict and export it to ONNX.

    The checkpoint at ``model_path`` must be a ``state_dict`` for a torchvision
    ``mobilenet_v2`` whose ``classifier[1]`` layer was replaced by
    ``Linear(model.last_channel, num_classes)``.

    Args:
        model_path: Path to the saved PyTorch ``state_dict`` file.
        onnx_model_path: Destination path of the exported ``.onnx`` file.
            Missing parent directories are created.
        input_size: Shape of the random dummy tensor used to trace the model.
            Axis 0 is exported as a dynamic ``batch_size`` dimension.
        num_classes: Output size of the replacement classifier layer; must
            match the checkpoint (default 24).
        opset_version: ONNX opset version to target (default 11).

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file '{model_path}' not found.")

    print("Loading the model...")
    model = models.mobilenet_v2()
    # Swap in a classifier layer whose shape matches the fine-tuned checkpoint.
    model.classifier[1] = torch.nn.Linear(model.last_channel, num_classes)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs, so no arbitrary pickled code can run.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    # Trace in inference mode so layers such as dropout/batch-norm behave as at deployment.
    model.eval()

    print("Preparing dummy input...")
    # Only the shape matters: the export traces the model once with this tensor.
    dummy_input = torch.randn(input_size)

    # Make sure the destination directory exists before the exporter writes to it.
    output_dir = os.path.dirname(onnx_model_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print("Exporting model to ONNX format...")
    torch.onnx.export(
        model,
        dummy_input,
        onnx_model_path,
        export_params=True,              # Embed the trained weights in the .onnx file
        opset_version=opset_version,     # ONNX opset to target
        do_constant_folding=True,        # Pre-compute constant subgraphs at export time
        input_names=['input'],           # Name of the graph input
        output_names=['output'],         # Name of the graph output
        dynamic_axes={                   # Allow any batch size at inference time
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'},
        },
    )

    print(f"Model successfully exported to '{onnx_model_path}'")

if __name__ == "__main__":
    # Script entry point: export the configured checkpoint to the configured ONNX path.
    convert_to_onnx(
        model_path=model_path,
        onnx_model_path=onxx_model_path,
    )


Loading the model...
Preparing dummy input...
Exporting model to ONNX format...


  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


Model successfully exported to '/home/mostafabakr/Desktop/Project X/models/asl_image_model.onnx'
