Skip to content

Commit

Permalink
logic changed to not use for loop
Browse files Browse the repository at this point in the history
  • Loading branch information
sarthakpati committed Sep 4, 2022
1 parent b9c7735 commit 2f5175d
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions GANDLF/metrics/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,8 @@ def per_label_accuracy(output, label, params):
torch.Tensor: The per class accuracy.
"""
if params["problem_type"] == "classification":
# acc = [0 for _ in params["model"]["class_list"]]
predicted_classes = np.array([0 for _ in params["model"]["class_list"]])
label_cpu = np.array([0 for _ in params["model"]["class_list"]])
predicted_classes = np.array([0] * len(params["model"]["class_list"]))
label_cpu = np.array([0] * len(params["model"]["class_list"]))
predicted_classes[torch.argmax(output, 1).cpu().item()] = 1
label_cpu[label.cpu().item()] = 1
return torch.from_numpy((predicted_classes == label_cpu).astype(float))
Expand Down

0 comments on commit 2f5175d

Please sign in to comment.