[torch-frontend] update torch-mlir to support nll_loss_forward (#481)
as title
qingyunqu authored Nov 7, 2024
1 parent 2f544d9 commit 42c5e7b
Showing 2 changed files with 12 additions and 12 deletions.
@@ -156,18 +156,18 @@ def test_attention():

 # ==============================================================================

-# class NllLossStaticModule(torch.nn.Module):
-#     # Here the 2nd index is ignored.
-#     def forward(self, x, y):
-#         return torch.ops.aten.nll_loss_forward(
-#             x, target=y, weight=None, reduction=0, ignore_index=2
-#         )

-# def test_nll_loss_forward():
-#     inputs = [tu.rand(2, 3), tu.randint(low=0, high=3, size=(2,))]
-#     module = compile(NllLossStaticModule(), inputs, "stablehlo", verbose=True, debug=torch_frontend.DebugType(1))
-#     numerical_test_helper(module, inputs, model(*inputs))
+class NllLossStaticModule(torch.nn.Module):
+    # Here the 2nd index is ignored.
+    def forward(self, x, y):
+        return torch.ops.aten.nll_loss_forward(
+            x, target=y, weight=None, reduction=0, ignore_index=2
+        )

+def test_nll_loss_forward():
+    inputs = [tu.rand(2, 3), tu.randint(low=0, high=3, size=(2,))]
+    model = NllLossStaticModule()
+    module = compile(model, inputs, "stablehlo")
+    numerical_test_helper(module, inputs, model(*inputs))

 # ==============================================================================

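For context, a minimal standalone sketch of what the restored test exercises (not part of this commit; it assumes standard PyTorch semantics for aten.nll_loss_forward, which returns a pair of tensors):

import torch

# Same call shape as the test above; with reduction=0 ('none') the op returns
# a per-sample loss of shape (N,) plus a total_weight tensor. Targets equal to
# ignore_index (here 2) contribute a zero loss entry.
x = torch.rand(2, 3)
y = torch.randint(low=0, high=3, size=(2,))
output, total_weight = torch.ops.aten.nll_loss_forward(
    x, target=y, weight=None, reduction=0, ignore_index=2
)
print(output.shape, total_weight)  # torch.Size([2]) and a scalar tensor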
