import torch
import torchvision.models as models
import sys
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)
from tubevit.model import TubeViTLightningModule
# Load a pretrained PyTorch model
model = TubeViTLightningModule(
    num_classes=3,
    video_shape=[3, 1, 448, 224],
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    mlp_dim=3072,
    weight_path="../weights/tubevit_vitbase_nc3_fpc3_448_224.pt",
    test_each_epoch=False,
)
model.eval()
# Define a sample input for the model (batch, channels, frames, height, width)
dummy_input = torch.randn(1, 3, 1, 448, 224)
# Path where the ONNX file will be saved
onnx_file_path = "./weights/test.onnx"
# Export the model to ONNX format
torch.onnx.export(model, dummy_input, onnx_file_path, verbose=True, opset_version=12)  # Only opset versions 7-16 are available, but none of them work
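The `sys.path` setup at the top of the script locates the repository root from the script's own location so that `tubevit` can be imported. A `pathlib` sketch of the same idea follows, assuming the script sits one directory below the folder that contains the `tubevit` package; `add_repo_root_to_path` is a hypothetical helper name, and the example path stands in for `__file__`. Unlike `os.path.abspath`, `Path.resolve()` also follows symlinks.

```python
import sys
from pathlib import Path


def add_repo_root_to_path(script_path: str) -> Path:
    """Put the parent of the script's directory at the front of sys.path.

    Mirrors os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    except that resolve() also follows symlinks. Returns the directory used.
    """
    repo_root = Path(script_path).resolve().parent.parent
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
    return repo_root


# In a real script pass __file__; this hypothetical path is only for illustration.
root = add_repo_root_to_path("/tmp/project/scripts/export_onnx.py")
```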