diff --git a/GANDLF/metrics/regression.py b/GANDLF/metrics/regression.py index 823ab010c..ca388ce66 100644 --- a/GANDLF/metrics/regression.py +++ b/GANDLF/metrics/regression.py @@ -62,9 +62,8 @@ def per_label_accuracy(output, label, params): torch.Tensor: The per class accuracy. """ if params["problem_type"] == "classification": - # acc = [0 for _ in params["model"]["class_list"]] - predicted_classes = np.array([0 for _ in params["model"]["class_list"]]) - label_cpu = np.array([0 for _ in params["model"]["class_list"]]) + predicted_classes = np.array([0] * len(params["model"]["class_list"])) + label_cpu = np.array([0] * len(params["model"]["class_list"])) predicted_classes[torch.argmax(output, 1).cpu().item()] = 1 label_cpu[label.cpu().item()] = 1 return torch.from_numpy((predicted_classes == label_cpu).astype(float))