In [1]:


import torch.onnx

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import os

from transformer_net import TransformerNet, SmallTransformerNet, \
    EfficientTransformerNet, SmallEfficientTransformerNet, \
    SmallTransformerNet48, SmallEfficientTransformerNet48, \
    MobileTransformerNet, EfficientMobileTransformerNet

import torch.nn.utils.prune as prune


# ---------------------------------------------------------------------------
# User-editable settings: adjust these paths before running the export.
# ---------------------------------------------------------------------------

# Inputs
pth_name = "test/model.pth"              # trained weights (state_dict) to load
dataset_path = "train"                   # image folder named by the disabled calibration notes
style_image_path = "test/capture.png"    # not used by the export code in this file

# Outputs: the ONNX file is written to "<save_model_dir>/<model_name>"
save_model_dir = "test"
model_name = "model.onnx"

def export_to_onnx(pth_path, onnx_path, image_size=(1280, 720)):
    """
    Load a .pth state_dict, apply dynamic quantization and export the model to ONNX.

    Parameters:
    - pth_path: Path to the .pth file containing the model's state_dict.
    - onnx_path: Path where the ONNX model should be saved. A missing parent
      directory is created.
    - image_size: Two-element sequence used as the last two dimensions of the
      dummy input, whose shape is (1, 3, *image_size). Defaults to (1280, 720).
      Batch, height and width are exported as dynamic axes, so this only sets
      the resolution used while tracing.
      NOTE(review): axis 2 is labelled 'height' and axis 3 'width' below, so the
      default traces a 1280-high, 720-wide input -- confirm that order is intended.

    Returns:
    - None. The ONNX file is written to onnx_path and a confirmation line is printed.

    Raises:
    - ValueError: if image_size does not contain exactly two elements.
    """
    if len(image_size) != 2:
        raise ValueError(
            f"image_size must have exactly two elements, got {image_size!r}"
        )

    transformer = EfficientMobileTransformerNet()

    # map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
    # machine; the freshly constructed model lives on the CPU as well.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load(pth_path, map_location="cpu")
    transformer.load_state_dict(state_dict)
    transformer.eval()  # inference behaviour for the export

    # TODO: replace dynamic quantization with static quantization. Outline of
    # the approach explored so far (not yet working):
    #   1. Optionally fuse conv + relu layers in the model.
    #   2. Set transformer.qconfig = torch.quantization.get_default_qconfig('fbgemm').
    #   3. torch.quantization.prepare(...) to insert observers, then run a few
    #      calibration batches from an ImageFolder over `dataset_path`
    #      (e.g. 100 images, batch size 10, 5 batches).
    #   4. torch.quantization.convert(...) to obtain the quantized model.
    #   The model class may need QuantStub / DeQuantStub around its forward pass.
    # Another idea: L1-unstructured pruning (e.g. 20%) of Conv2d weights via
    # torch.nn.utils.prune, followed by prune.remove to make it permanent.

    # Apply dynamic quantization (only weights are quantized).
    # NOTE(review): PyTorch documents dynamic quantization for Linear and
    # recurrent layers; the Conv2d and ReLU entries may be silently left in
    # float -- verify that the exported graph is actually quantized.
    quantized_model = torch.quantization.quantize_dynamic(
        transformer, {torch.nn.Linear, torch.nn.Conv2d, torch.nn.ReLU}, dtype=torch.qint8
    )

    # Dummy input used only for tracing: batch of 1, 3 channels, then image_size.
    dummy_input = torch.randn(1, 3, *image_size)

    # Create the output directory up front so the export does not fail on a
    # missing folder. An empty dirname means the current directory.
    output_dir = os.path.dirname(onnx_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Input and output share the same dynamic axis names, so their sizes are
    # tied together; they could be named separately (e.g. for compression).
    dynamic_axes = {
        'input': {0: 'batch_size', 2: 'height', 3: 'width'},
        'output': {0: 'batch_size', 2: 'height', 3: 'width'},
    }
    torch.onnx.export(quantized_model, dummy_input, onnx_path,
                      export_params=True, opset_version=18,
                      do_constant_folding=True,
                      input_names=['input'], output_names=['output'],
                      dynamic_axes=dynamic_axes)

    print(f"Model loaded from {pth_path} and saved to {onnx_path} in ONNX format with dynamic input resolution.")

# Run the export with the paths configured at the top of the file; the ONNX
# file is written to "<save_model_dir>/<model_name>".
export_to_onnx(pth_name, f"{save_model_dir}/{model_name}")



Model loaded from test/model.pth and saved to test/model.onnx in ONNX format with dynamic input resolution.


  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(
