diff --git a/forest_example.ipynb b/forest_example.ipynb index 3c76b07e..19a91440 100644 --- a/forest_example.ipynb +++ b/forest_example.ipynb @@ -381,11 +381,11 @@ "outputs": [], "source": [ "preds_valid = np.array(clf_xgb.predict_proba(X_valid, ))\n", - "valid_acc = accuracy_score(y_pred=np.argmax(preds_valid, axis=1), y_true=y_valid)\n", + "valid_acc = accuracy_score(y_pred=np.argmax(preds_valid, axis=1) + 1, y_true=y_valid)\n", "print(valid_acc)\n", "\n", "preds_test = np.array(clf_xgb.predict_proba(X_test))\n", - "test_acc = accuracy_score(y_pred=np.argmax(preds_test, axis=1), y_true=y_test)\n", + "test_acc = accuracy_score(y_pred=np.argmax(preds_test, axis=1) + 1, y_true=y_test)\n", "print(test_acc)" ] },