diff --git a/network/loss.py b/network/loss.py index 069a151..c4310a4 100644 --- a/network/loss.py +++ b/network/loss.py @@ -66,4 +66,4 @@ def __init__(self, device): self.zero = torch.zeros([1]).type(self.dtype) def forward(self, r_x, r_x_hat): - return torch.nn.ReLU(1 + r_x_hat[0]) + torch.nn.ReLU(1 - r_x[0]) + return F.relu(1 + r_x_hat[0]) + F.relu(1 - r_x[0])