[Bug] Slice indexing in ONNX #94

Closed
soodoshll opened this issue Feb 12, 2023 · 4 comments · Fixed by #106


@soodoshll
Collaborator

soodoshll commented Feb 12, 2023

Please refer to pytorch/pytorch#24251

Basically, ONNX uses extremely large numbers (e.g. INT64_MAX) to represent slicing until the end of a dimension, which is rejected by the defensive condition in

```python
if not (-n <= i <= n and -n <= j <= n):
```
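For concreteness, Python/NumPy slicing (which the exporter's semantics lean on) silently clamps an out-of-range end, so the sentinel works there:

```python
import numpy as np

x = np.arange(100)
# The exporter encodes "slice to the end" as 2**63 - 1; NumPy clamps
# this to len(x) instead of rejecting it.
print(x[2:9223372036854775807].shape)  # (98,)
```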

@yaoyaoding
Member

The indexing of hidet tensors follows the Array API standard. Thus, we need to deal with the difference between ONNX semantics and the Array API standard when importing an ONNX model.

Could you please provide an example that triggers this error? If not, we could leave it for the future, when an actual model triggers it.
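For illustration, a minimal sketch of how the importer could normalize such sentinel bounds before the defensive check; the helper name and the clamping rule are assumptions, not hidet's actual fix:

```python
def normalize_slice_bounds(start: int, end: int, n: int) -> tuple:
    """Clamp ONNX-style slice bounds into the safe range [-n, n]."""
    start = max(-n, min(start, n))
    end = max(-n, min(end, n))
    return start, end

# e.g. normalize_slice_bounds(2, 9223372036854775807, 100) -> (2, 100)
```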

@soodoshll
Collaborator Author

Yes, a very simple snippet (torch->onnx->hidet) can trigger this error:

```python
import torch
import hidet
import onnx

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Open-ended slice: exported to ONNX as a Slice op whose `ends`
        # is a huge sentinel value.
        return x[2:]

device = 'cuda'

model = Foo()
model.to(device)

x = torch.ones([100], dtype=torch.int32, device=device)
z = model(x)

# Export the model to ONNX and load it back.
torch.onnx.export(model, (x,), 'tmp.onnx', input_names=['x'],
                  output_names=['z'])
model = onnx.load('tmp.onnx')

hidet.torch.dynamo_config.search_space(1)

# Import the ONNX model into hidet and build an optimized CUDA graph;
# the Slice op's sentinel `ends` value trips the defensive bounds check.
x = hidet.from_torch(x)
symbol_data = [hidet.symbol_like(x)]
hidet_onnx_module = hidet.graph.frontend.from_onnx(model)
symbol_output = hidet_onnx_module(*symbol_data)
graph: hidet.FlowGraph = hidet.trace_from(symbol_output, inputs=symbol_data)
with hidet.graph.PassContext() as ctx:
    graph_opt: hidet.FlowGraph = hidet.graph.optimize(graph)
cuda_graph = graph_opt.cuda_graph()
outputs = cuda_graph.run([x])
print(outputs[0])
```

@yaoyaoding
Member

Thanks @soodoshll, working on it.

@yaoyaoding
Member

Thanks @soodoshll, this error should be fixed in #106.

yaoyaoding pushed a commit that referenced this issue Apr 3, 2024
Add graph module for using flash attention and clarify some differences
in flash attention vs torch sdpa.

**Attention: (pun intended)**

Softmax has a temperature-scaling option: it divides the inputs by a scalar. A good explanation of the numerical effects is [here](https://medium.com/@harshit158/softmax-temperature-5492e4007f71).

It is used when the softmax inputs QK are too big for float16 (absolute value > 65504). This usually means the numbers are so large that dividing by a small (< 4) scalar has little effect.
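For reference, a minimal sketch of temperature-scaled softmax (a hypothetical helper, not the module added in this commit; the temperature value and shapes are illustrative):

```python
import torch

def softmax_with_temperature(scores: torch.Tensor,
                             temperature: float = 2.0) -> torch.Tensor:
    # Dividing by a temperature > 1 shrinks the logits before exp(),
    # reducing the risk of float16 overflow (max finite value: 65504).
    return torch.softmax(scores / temperature, dim=-1)

# Illustrative attention scores: scores = q @ k.transpose(-1, -2) / sqrt(d)
scores = torch.randn(8, 128, 128)
probs = softmax_with_temperature(scores)
```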

Stable Diffusion does not use this, as torch sdpa supports float32 (or somehow avoids NaNs from large values). No visual or significant numeric differences were noticed in this output layer.

Towards #57.