
Calculating the AUC and ROC #370

Open
Mohammed-Ryiad-Eiadeh opened this issue May 21, 2024 · 7 comments
Labels
question General question

Comments

@Mohammed-Ryiad-Eiadeh

Mohammed-Ryiad-Eiadeh commented May 21, 2024

Dear Tribuo developers,

I am trying to get the TPR and the FPR in order to calculate the AUC and plot the ROC curve, but the results are sometimes unreasonable: I get high accuracy yet low AUC. Here is my code:

    // use KNN classifier
    // noinspection DuplicatedCode
    var KnnTrainer =  new KNNTrainer<>(3,
            new L1Distance(),
            Runtime.getRuntime().availableProcessors(),
            new VotingCombiner(),
            KNNModel.Backend.THREADPOOL,
            NeighboursQueryFactoryType.BRUTE_FORCE);

    // display the model provenance
    var modelProvenance = KnnTrainer.getProvenance();
    System.out.println("The model provenance is \n" + ProvenanceUtil.formattedProvenanceString(modelProvenance));

    // use crossvalidation
    // noinspection DuplicatedCode
    var crossValidation = new CrossValidation<>(KnnTrainer, dataSet, new LabelEvaluator(), 10, Trainer.DEFAULT_SEED);

    // get outputs
    // noinspection DuplicatedCode
    var avgAcc = 0d;
    var sensitivity = 0d;
    var specificity = 0d;
    var macroAveragedF1 = 0d;
    var precision = 0d;
    var recall = 0d;
    var avgTP = new double[crossValidation.getK()];
    var avgFP = new double[crossValidation.getK()];
    var counter = 0;
    var sTrain = System.currentTimeMillis();
    for (var result: crossValidation.evaluate()) {
        avgAcc += result.getA().accuracy();
        sensitivity += result.getA().tp() / (result.getA().tp() + result.getA().fn());
        specificity += result.getA().tn() / (result.getA().tn() + result.getA().fp());
        macroAveragedF1 += result.getA().macroAveragedF1();
        precision += result.getA().tp() / (result.getA().tp() + result.getA().fp());
        recall += result.getA().tp() / (result.getA().tp() + result.getA().fn());
        avgTP[counter] = result.getA().tp() / (result.getA().tp() + result.getA().fn());
        avgFP[counter] = 1 - (result.getA().tn() / (result.getA().tn() + result.getA().fp()));
        counter++;
    }

    // noinspection DuplicatedCode
    var eTrain = System.currentTimeMillis();

    /*System.out.printf("The FS duration time is : %s\nThe number of selected features is : %d\nThe feature names are : %s\n",
            Util.formatDuration(sDate, eDate), SFS.featureNames().size(), SFS.featureNames());*/

    for (var stuff : List.of("The Training_Testing duration time is : " + Util.formatDuration(sTrain, eTrain),
            "The average accuracy is : " + (avgAcc / crossValidation.getK()),
            "The average sensitivity is : " + (sensitivity / crossValidation.getK()),
            "The average macroAveragedF1 is : " + (macroAveragedF1 / crossValidation.getK()),
            "The average precision is : " + (precision / crossValidation.getK()),
            "The average recall is : " + (recall / crossValidation.getK()))) {
        System.out.println(stuff);
    }

    AucCalculator aucCalculator = new AucCalculator(avgTP, avgFP);
    System.out.println("The AUC is : " + aucCalculator.getAUC());

    // Display the ROC curve chart and save it
    System.out.println(Arrays.toString(avgTP));
    System.out.println(Arrays.toString(avgFP));
@Mohammed-Ryiad-Eiadeh added the question General question label May 21, 2024
@Craigacp
Member

I'm not sure how you're plotting the ROC curve when you need a threshold to sweep through to change the point at which a label is predicted. Tribuo already supports AUCROC for classifiers which produce probabilities, but KNNTrainer doesn't.

@Mohammed-Ryiad-Eiadeh
Author

Well, I just calculate the FPR and TPR after each fold, use them to plot my ROC curve, and pass them to AUCCalculator to get the AUC value, which is computed by the trapezoidal rule. Please tell me if this is not correct so I can change it.
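For reference, the `AucCalculator` class isn't shown in this thread; a minimal self-contained sketch of trapezoidal-rule integration over (FPR, TPR) points might look like the following (the class and method names here are hypothetical, not Tribuo APIs). Note, as the maintainer points out, that fold-level rates are not a true ROC curve, so the number this produces is not a meaningful AUC:

```java
import java.util.Arrays;

public class TrapezoidAuc {
    /** Integrates TPR over FPR with the trapezoidal rule.
     *  Points are sorted by ascending FPR first, carrying TPR along. */
    static double auc(double[] fpr, double[] tpr) {
        Integer[] idx = new Integer[fpr.length];
        for (int i = 0; i < idx.length; i++) idx[i] = i;
        Arrays.sort(idx, (a, b) -> Double.compare(fpr[a], fpr[b]));
        double area = 0.0;
        for (int i = 1; i < idx.length; i++) {
            double dx = fpr[idx[i]] - fpr[idx[i - 1]];
            area += dx * (tpr[idx[i]] + tpr[idx[i - 1]]) / 2.0;
        }
        return area;
    }

    public static void main(String[] args) {
        // A curve through (0,0) -> (0,1) -> (1,1) integrates to 1.0,
        // and the chance diagonal (0,0) -> (1,1) integrates to 0.5.
        System.out.println(auc(new double[]{0.0, 0.0, 1.0}, new double[]{0.0, 1.0, 1.0}));
        System.out.println(auc(new double[]{0.0, 1.0}, new double[]{0.0, 1.0}));
    }
}
```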

@Craigacp
Member

That won't give you an appropriate ROC curve as it's not on the same data and doesn't represent how changing the classification threshold would change the false positive & true positive rate.

@Mohammed-Ryiad-Eiadeh
Author

Mohammed-Ryiad-Eiadeh commented May 22, 2024 via email

@Craigacp
Member

You'll need to use a model which supports generating probabilities, and then you can use the methods on LabelEvaluation to compute the AUC, or LabelEvaluationUtil to compute the ROC curve itself - https://tribuo.org/learn/4.3/javadoc/org/tribuo/classification/evaluation/LabelEvaluationUtil.html.
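Putting that suggestion together, a sketch of the probability-based workflow might look like the code below. This assumes the Tribuo 4.3 API from the linked javadoc; `trainSet`, `testSet`, and the `"POSITIVE"` label name are placeholders you would replace with your own data and label:

```java
import org.tribuo.classification.Label;
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.classification.evaluation.LabelEvaluator;
import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer;

// Train a model that emits probabilities (unlike KNNTrainer with VotingCombiner).
var trainer = new LogisticRegressionTrainer();
var model = trainer.train(trainSet);

// Evaluate on held-out data; LabelEvaluation exposes AUC directly.
LabelEvaluation evaluation = new LabelEvaluator().evaluate(model, testSet);
double aucPositive = evaluation.AUCROC(new Label("POSITIVE")); // per-label AUC
double macroAuc = evaluation.averageAUCROC(false);             // unweighted average over labels

// For the curve itself, LabelEvaluationUtil.generateROCCurve takes the binary
// ground truth and the predicted positive-class scores, and returns an object
// holding the fpr, tpr, and thresholds arrays to plot.
```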

@Mohammed-Ryiad-Eiadeh
Author

Dear Adam,

I need your help here. After getting the FPR, TPR, and thresholds like this:

FPR: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.007352941176470588, 0.014705882352941176, 0.022058823529411766, 0.029411764705882353, 0.03676470588235294, 0.04411764705882353, 0.051470588235294115, 0.058823529411764705, 0.0661764705882353, 0.07352941176470588, 0.08088235294117647, 0.08823529411764706, 0.09558823529411764, 0.10294117647058823, 0.11764705882352941, 1.0]

TPR: [0.0, 0.9456521739130435, 0.9565217391304348, 0.967391304347826, 0.9782608695652174, 0.9891304347826086, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]

Threshold: [Infinity, 0.9999999999999585, 0.9999999999900508, 0.9999999999532103, 0.9999999993470261, 0.999999986279678, 0.999999082458228, 0.4672622616731544, 0.010474966475835753, 7.848048691768931E-4, 4.464634619108306E-4, 1.8357524563945583E-4, 7.697946270832445E-5, 1.5905677137563365E-5, 1.258714255136621E-7, 4.6428762209717544E-8, 1.3855807195706487E-8, 8.900923141832403E-9, 7.567814072544735E-9, 7.443858692758792E-9, 2.8687081675940852E-9, 1.3147063911388807E-12, 1.7464080806775956E-82]

When plotting FPR against TPR, how do I get the number of correctly classified points corresponding to the positive label so I can complete the ROC analysis?

@Craigacp
Member

That information isn't stored in the ROC class; the number of correctly classified points is stored in your LabelEvaluation.
