diff --git a/torch_struct/semirings/sample.py b/torch_struct/semirings/sample.py index 611c04b9..d815cec7 100644 --- a/torch_struct/semirings/sample.py +++ b/torch_struct/semirings/sample.py @@ -102,7 +102,7 @@ def forward(ctx, logits, dim): @staticmethod def backward(ctx, grad_output): logits, = ctx.saved_tensors - return logits.softmax(-1), None + return logits.softmax(-1) * grad_output, None class _GumbelCRFLogSumExp(torch.autograd.Function): @staticmethod