onnx infer #24

Open
yxl502 opened this issue Jan 23, 2024 · 0 comments
yxl502 commented Jan 23, 2024

```python
import os
import sys

import torch

# Make the repository root importable so that `tubevit` resolves
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)

from tubevit.model import TubeViTLightningModule

# Load a pretrained PyTorch model
model = TubeViTLightningModule(
    num_classes=3,
    video_shape=[3, 1, 448, 224],
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    mlp_dim=3072,
    weight_path="../weights/tubevit_vitbase_nc3_fpc3_448_224.pt",
    test_each_epoch=False,
)
model.eval()

# Define an example model input: a batch of 1 followed by video_shape
dummy_input = torch.randn(1, 3, 1, 448, 224)

# Path of the ONNX file to save
onnx_file_path = "./weights/test.onnx"

# Export the model to ONNX format
# Only opset versions 7-16 are available, but none of them support this model
torch.onnx.export(model, dummy_input, onnx_file_path, verbose=True, opset_version=12)
```
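When every opset fails, the error raised for each one usually names the operator the exporter cannot convert, and that message is the most useful detail to attach to a report like this. Below is a minimal sketch for collecting those messages across opsets in one run. `probe_opsets` and `make_torch_exporter` are hypothetical helpers, not part of TubeViT or PyTorch; the only real API used is `torch.onnx.export`, and `torch` is imported lazily so the probing logic can be exercised on its own.

```python
from typing import Callable, Dict, Iterable, Optional


def probe_opsets(export_fn: Callable[[int], None],
                 opsets: Iterable[int]) -> Dict[int, Optional[str]]:
    """Call export_fn once per opset.

    Maps each opset to None if the export succeeded, or to
    "<ExceptionType>: <message>" if it raised.
    """
    results: Dict[int, Optional[str]] = {}
    for opset in opsets:
        try:
            export_fn(opset)
            results[opset] = None
        except Exception as exc:  # record the failure and keep probing
            results[opset] = f"{type(exc).__name__}: {exc}"
    return results


def make_torch_exporter(model, dummy_input,
                        path_template: str) -> Callable[[int], None]:
    """Hypothetical adapter: wraps torch.onnx.export for a given opset.

    path_template must contain one "{}" placeholder for the opset number,
    so each attempt writes to its own file.
    """
    import torch  # imported lazily; only needed when an export actually runs

    def _export(opset: int) -> None:
        torch.onnx.export(model, dummy_input, path_template.format(opset),
                          opset_version=opset)

    return _export
```

With the `model` and `dummy_input` from the script above, `probe_opsets(make_torch_exporter(model, dummy_input, "./weights/test_opset{}.onnx"), range(7, 17))` returns one entry per opset from 7 through 16. Printing the non-`None` values gives the exact error for each version, which is more actionable than "none of them work".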
