
Dnnl Execution Provider GetMemoryAndReshape function issues with Status Message: not a valid reshape, inconsistent dim product #20426

Open
varunkatiyar819 opened this issue Apr 23, 2024 · 0 comments
Labels
ep:oneDNN questions/issues related to DNNL EP

Comments


varunkatiyar819 commented Apr 23, 2024

Describe the issue

I created a simple feed-forward model with linear layers and ReLU activations, and converted it to ONNX format with all the essential parameters. When I run inference with onnxruntime using the DnnlExecutionProvider on batches of 8 (8x784, where 784 is the fixed embedding dimension) over 25 sentences in total, the full batches of 8 work fine, i.e. the first 24 sentences. But when it comes to processing the 25th sentence, which has shape (1, 784), it fails with the error:
[E:onnxruntime:Default, dnnl_subgraph_primitive.cc:561 GetMemoryAndReshape] fc.0.weight, Dims From: 784 1024 , To: 784 784]
[E:onnxruntime:, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running DNNL_15662691735485739911_0 node. Name:'DnnlExecutionProvider_DNNL_15662691735485739911_0_0' Status Message: not a valid reshape, inconsistent dim product]

This is the relevant stack trace:
/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running DNNL_15662691735485739911_0 node. Name:'DnnlExecutionProvider_DNNL_15662691735485739911_0_0' Status Message: not a valid reshape, inconsistent dim product

I also tried running the same inference with the CPUExecutionProvider and the CUDAExecutionProvider, using the same batch size and the same number of sentences, and both work fine.

Possible Approach

Something in the memory allocation/reuse path of the GetMemoryAndReshape function seems to be causing the issue. The log shows fc.0.weight being reshaped from (784, 1024) to (784, 784); since 784 × 1024 ≠ 784 × 784, the element counts differ, which is exactly the "inconsistent dim product" the error reports. It looks as if memory prepared for one batch shape is being reused when the batch size changes.
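As a possible workaround (a sketch based on an assumption about the failure mode, not verified against the DNNL EP internals), padding the final partial batch up to the fixed batch size and dropping the padded rows afterwards keeps every feed at a shape the provider has already seen. The run_padded helper below is hypothetical, not part of onnxruntime:

import numpy as np

def run_padded(ort_sess, batch, batch_size):
    # Hypothetical helper: pad a partial batch to `batch_size` rows with zeros,
    # run the session, then discard the outputs for the padded rows.
    n = batch.shape[0]
    if n < batch_size:
        pad = np.zeros((batch_size - n, batch.shape[1]), dtype=batch.dtype)
        batch = np.concatenate([batch, pad], axis=0)
    outputs = ort_sess.run(None, {"input": batch})
    return [out[:n] for out in outputs]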

To reproduce

Here is the code for creating the model, converting it to ONNX, and running inference:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# Define your own custom model
class CustomModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(CustomModel, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size,1024),  
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, output_size)  
        )

    def forward(self, x):
        return self.fc(x)

# Set random seed for reproducibility
torch.manual_seed(42)

# Define input and output dimensions
input_size =  28 * 28
output_size = 10

# Instantiate your custom model
model = CustomModel(input_size, output_size)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()  # Example loss function
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Example optimizer

# Load the dataset and create dataloaders
train_dataset = MNIST(root='data/', train=True, download=True, transform=ToTensor())
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Train your model
num_epochs = 10  # Example: number of training epochs
for epoch in range(num_epochs):
    # Loop over the dataset and perform forward pass, backward pass, and optimization
    for batch in train_loader:
        images, labels = batch
        images = images.view(images.size(0), -1)  # Flatten the input images
        optimizer.zero_grad()  # Clear gradients
        outputs = model(images)  # Forward pass
        loss = criterion(outputs, labels)  # Calculate loss
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights


# Save the trained model weights to a file
torch.save(model.state_dict(), 'custom_model.pth')

# Loading the weights
model.load_state_dict(torch.load('custom_model.pth'))
model.eval()

input_names = ["input"]
output_names = ["output"]
MODEL_MAX_LENGTH = 28 * 28
# Export the loaded model to ONNX format with dynamic axes
dummy_input = {key: torch.ones(1, MODEL_MAX_LENGTH, dtype=torch.float32) for key in input_names}
dynamic_axes = {'input': {0: 'batch_size'}}  # Specify dynamic axes for input tensor


torch.onnx.export(model,
                 dummy_input["input"],
                 "custom_model.onnx",
                 verbose=False,
                 input_names=input_names,
                 output_names=output_names,
                 export_params=True,
                 dynamic_axes=dynamic_axes
                 )
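To confirm the export actually recorded a symbolic batch dimension (a quick sanity check, not part of the original script), the graph input shape can be inspected with the onnx package:

import onnx

m = onnx.load("custom_model.onnx")
# Axis 0 should carry dim_param 'batch_size'; axis 1 should be dim_value 784.
print(m.graph.input[0].type.tensor_type.shape)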
                 
# Inference of the converted ONNX model

import os
import time

import numpy as np
import onnx
import onnxruntime as rt
import torch

def inference_onnx():
    CONVERTED_COMET_MODEL_PATH = os.path.join(os.getcwd() ,'custom_model.onnx')

    onnx.checker.check_model(CONVERTED_COMET_MODEL_PATH, full_check=True)

    # ort_sess = rt.InferenceSession(CONVERTED_COMET_MODEL_PATH, providers=['CUDAExecutionProvider'])
    ort_sess = rt.InferenceSession(CONVERTED_COMET_MODEL_PATH, providers=['DnnlExecutionProvider'])
    # ort_sess = rt.InferenceSession(CONVERTED_COMET_MODEL_PATH, providers=['CPUExecutionProvider'])


    batch_size = 4
    shape = (5, 28 * 28)
    concatenated_data = {"input": torch.rand(*shape)}

    for idx in range(0, shape[0], batch_size):
       
        inp = {"input": np.array(concatenated_data["input"][idx: idx+batch_size])}
        start_time = time.time()
        print("inp = ", inp, inp["input"].shape)
        outputs = ort_sess.run(None, inp)
        print("output_sentence = ", outputs, outputs[0].shape)
        end_time = time.time()
        print("Time Taken : ", round(end_time - start_time, 2))
    # outputs = ort_sess.run(["score"], inp)

inference_onnx()
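Since the CPU provider handles the same feeds without complaint, a side-by-side run makes the divergence easy to demonstrate. This is a sketch mirroring the loop above (model path and input shapes taken from the repro):

import numpy as np
import onnxruntime as rt

model_path = 'custom_model.onnx'
full = np.random.rand(4, 28 * 28).astype(np.float32)   # a full batch
tail = np.random.rand(1, 28 * 28).astype(np.float32)   # the partial final batch

for provider in ['CPUExecutionProvider', 'DnnlExecutionProvider']:
    sess = rt.InferenceSession(model_path, providers=[provider])
    sess.run(None, {"input": full})        # primes the session with batch size 4
    out = sess.run(None, {"input": tail})  # CPU succeeds; DNNL raises the reshape error
    print(provider, out[0].shape)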

Urgency

No response

Platform

Linux

OS Version

Ubuntu 22.04.4 LTS

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.18.0

ONNX Runtime API

Python

Architecture

X64

Execution Provider

oneDNN

Execution Provider Library Version

No response
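For completeness, the installed build and the providers it exposes can be confirmed like this (a quick check, not part of the original report):

import onnxruntime as rt

print(rt.__version__)                 # reported above: 1.18.0
print(rt.get_available_providers())   # should include 'DnnlExecutionProvider'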

@github-actions github-actions bot added the ep:oneDNN questions/issues related to DNNL EP label Apr 23, 2024