# ROC and PR Curves in Julia

**Packages Used:** PlotlyJS, MLJBase, MLJLinearModels, *Random*, CategoricalArrays, DataFrames

*Italic packages* are Julia Standard Library Packages as of **v1.6**.

**Note:** The examples below are wrapped in functions, since this is best practice in Julia.

We will plot Receiver Operating Characteristic (ROC) and Precision-Recall (PR) curves with PlotlyJS in the Julia language.

# Preliminary plots
Before visualizing the receiver operating characteristic (ROC) curve, we will look at two plots that will give some context to the mechanism behind the ROC and PR curves.

We first plot the distribution of predicted probabilities for the two categories of an artificial dataset generated with `make_blobs` from **MLJBase**. The dataset has 5000 entries.
The model is from **MLJLinearModels**, and we use `histogram` to plot the distributions.

In the second plot, we plot the false positive rate, `fprs`, and the true positive rate, `tprs`, against the threshold, using **MLJBase**'s `roc` function.
**PlotlyJS** provides the `scatter` trace, so we can plot both.

In [1]:
using PlotlyJS, MLJBase, MLJLinearModels, Random

function display_binary_score_histogram()
    Random.seed!(42) # for reproducibility
    X, y = make_blobs(5000, 2; centers=2, cluster_std=1.5)
    mach = machine(LogisticClassifier(), X, y)
    fit!(mach)
    ŷ = predict(mach, X)
    
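    # one histogram per category: the predicted probability of that category over all samples
    cat1_trace = histogram(; x=pdf.(ŷ,1), name="Category 1", opacity=0.50)
    cat2_trace = histogram(; x=pdf.(ŷ,2), name="Category 2", opacity=0.50)
    # barmode="overlay" draws the two histograms on top of each other, so the opacity takes effect
    layout = Layout(xaxis_title="Predicted probability", yaxis_title="Count", barmode="overlay",
                    font=attr(family="Verdana", size=14,color="Black"))
    Plot([cat2_trace,cat1_trace], layout)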
end

display_binary_score_histogram()

┌ Info: Training Machine{LogisticClassifier,…}.
└ @ MLJBase /home/markus/.julia/packages/MLJBase/HZmTU/src/machines.jl:403


In [2]:
using PlotlyJS, MLJBase, MLJLinearModels, Random

function display_binary_score_rate_threshold()
    Random.seed!(42) # for reproducibility
    X, y = make_blobs(5000, 2; centers=2, cluster_std=1.5)
    mach = machine(LogisticClassifier(), X, y)
    fit!(mach)
    
    fprs, tprs, ts = roc(predict(mach, X), y)
    fprs_trace = scatter(; x=ts, y=fprs, name="(fprs) False Positive Rate")
    tprs_trace = scatter(; x=ts, y=tprs, name="(tprs) True Positive Rate")
    layout = Layout(xaxis_title="Threshold", font=attr(family="Verdana", size=14,color="Black"))
    Plot([tprs_trace,fprs_trace], layout)
end

display_binary_score_rate_threshold()

┌ Info: Training Machine{LogisticClassifier,…}.
└ @ MLJBase /home/markus/.julia/packages/MLJBase/HZmTU/src/machines.jl:403
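
The two rates plotted against the threshold can also be computed by hand. The following is a minimal sketch in plain Julia, not the implementation used by **MLJBase**: the helper name `rates_at_threshold`, the toy `scores` and `labels`, and the choice of `>=` for the comparison are all assumptions made for illustration.

```julia
# Hypothetical helper: false/true positive rate at a single threshold.
# scores: predicted probability of the positive class
# labels: true labels as Bool (true = positive)
function rates_at_threshold(scores, labels, threshold)
    flagged = scores .>= threshold            # samples predicted as positive
    tp = count(flagged .& labels)             # flagged and truly positive
    fp = count(flagged .& .!labels)           # flagged but truly negative
    tpr = tp / count(labels)                  # true positive rate
    fpr = fp / count(.!labels)                # false positive rate
    return fpr, tpr
end

# Toy data, made up for this example
scores = [0.1, 0.4, 0.35, 0.8]
labels = [false, false, true, true]

for t in (0.0, 0.5, 1.0)
    fpr, tpr = rates_at_threshold(scores, labels, t)
    println("threshold = $t: FPR = $fpr, TPR = $tpr")
end
```

At threshold 0 every sample is flagged as positive, so both rates are 1. Raising the threshold lowers both rates until nothing is flagged and both reach 0.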


# Basic Binary ROC Curve
Here we use the previous code with slight modifications.
**MLJBase** offers the function `auc` to calculate the area under the curve (AUC) score, `aucs`.
**PlotlyJS**, in turn, offers various styling options for line plots constructed from `scatter`.

In [3]:
using PlotlyJS, MLJBase, MLJLinearModels, Random

function display_binary_roc()
    Random.seed!(42) # for reproducibility
    X, y = make_blobs(5000, 2; centers=2, cluster_std=1.5)
    mach = machine(LogisticClassifier(), X, y)
    fit!(mach)
    
    ŷ = predict(mach, X)
    fprs, tprs, ts = roc(ŷ, y)
    aucs = auc(ŷ, y)
    roc_trace = scatter(; x=fprs, y=tprs)
    diagonal_trace = scatter(; x=0:0.1:1, y=0:0.1:1, mode="lines", line=attr(dash="dash", color="black"))
    layout = Layout(xaxis_title="False Positive Rate", yaxis_title="True Positive Rate", 
                    title = "Area Under Curve Score = $(aucs)", 
                    font=attr(family="Verdana", size=14,color="Black"), showlegend=false)
    Plot([roc_trace, diagonal_trace], layout)
end

display_binary_roc()

┌ Info: Training Machine{LogisticClassifier,…}.
└ @ MLJBase /home/markus/.julia/packages/MLJBase/HZmTU/src/machines.jl:403
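
To give some intuition for the score in the plot title, the area under a ROC curve can be approximated from its points with the trapezoidal rule. This is only a sketch under stated assumptions: `trapezoid_area` is a hypothetical helper, the ROC points are made up, and this is not necessarily how `auc` in **MLJBase** computes its result.

```julia
# Hypothetical helper: trapezoidal area under the points (xs[i], ys[i]).
# Assumes xs is sorted in ascending order.
function trapezoid_area(xs, ys)
    area = 0.0
    for i in 2:length(xs)
        # width of the strip times the mean height of its two edges
        area += (xs[i] - xs[i-1]) * (ys[i] + ys[i-1]) / 2
    end
    return area
end

# Made-up ROC points, ordered from (0, 0) to (1, 1)
fprs = [0.0, 0.0, 0.5, 0.5, 1.0]
tprs = [0.0, 0.5, 0.5, 1.0, 1.0]

println(trapezoid_area(fprs, tprs))              # 0.75
println(trapezoid_area([0.0, 1.0], [0.0, 1.0]))  # 0.5, the dashed diagonal
```

The dashed diagonal has an area of 0.5, which is why it serves as the baseline for a classifier that guesses at random.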


# Multiclass ROC Curve
In other languages, there are functions for calculating multiclass ROC curves.
For the time being, Julia lacks a library with such a function.
Nevertheless, the following code implements the simple one-vs-rest logic for calculating multiclass ROC curves: each class in turn is treated as the positive class, and all other classes as negative.
Here we use the iris data set from **PlotlyJS**.
**CategoricalArrays** is used to prepare the labels to be fed into `machine`.
As before, the plotting simply adds one `scatter` curve per class.

In [6]:
using PlotlyJS, MLJBase, MLJLinearModels, Random, CategoricalArrays, DataFrames

function display_multiclass_roc()
    Random.seed!(42) # for reproducibility
    df = DataFrame(dataset("iris"))
    df = df[shuffle(1:nrows(df)),:]
    X, y = unpack(df, x -> x!=:species && x!=:species_id, ==(:species))
    y = categorical(y)
    X = coerce(X, :petal_length => Continuous, :petal_width => Continuous, 
              :sepal_length => Continuous, :sepal_width => Continuous)
    
    train, test = partition(1:nrows(X), 0.5)
    mach = machine(LogisticClassifier(), X, y)
    fit!(mach; rows = train)
    
    ŷ = predict(mach, X[test, :])
    yt = y[test]
    
    ts = 0:0.01:1
    
    tprs = []
    fprs = []
    
    count_true_pos(preds, gts, thhold, class) = count(yy -> class==yy[1] && pdf(yy[2],class)>thhold, zip(gts,preds))
    count_false_pos(preds, gts, thhold, class) = count(yy -> class!=yy[1] && pdf(yy[2],class)>thhold, zip(gts,preds))
    
    for lvl in levels(yt)
        push!(tprs, [count_true_pos(ŷ, yt, t, lvl)/count(==(lvl),yt) for t in ts])
        push!(fprs, [count_false_pos(ŷ, yt, t, lvl)/count(!=(lvl),yt) for t in ts])
    end
    
    roc_traces = [scatter(; x=fprs[i], y=tprs[i], name=levels(yt)[i], opacity=0.65) for i in 1:length(levels(yt))]
    diagonal_trace = scatter(; x=0:0.1:1, y=0:0.1:1, mode="lines", line=attr(dash="dash", color="black"), 
                               name="baseline")
    layout = Layout(xaxis_title="False Positive Rate", yaxis_title="True Positive Rate", 
                    title = "Multiclass ROCs", 
                    font=attr(family="Verdana", size=14,color="Black"))
    Plot([roc_traces..., diagonal_trace], layout)
end

display_multiclass_roc()

┌ Info: Training Machine{LogisticClassifier,…}.
└ @ MLJBase /home/markus/.julia/packages/MLJBase/HZmTU/src/machines.jl:403
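
The counting logic used for each class can be illustrated on a tiny example without any packages. The sketch below is an assumption-laden illustration, not part of the tutorial code: the probability matrix `probs`, the `truth` labels, and the helper `one_vs_rest_rates` are all made up.

```julia
# Made-up example: probs[i, k] is the predicted probability that
# sample i belongs to classes[k]
classes = ["setosa", "versicolor", "virginica"]
probs = [0.8 0.1 0.1;
         0.2 0.6 0.2;
         0.1 0.3 0.6;
         0.3 0.5 0.2]
truth = ["setosa", "versicolor", "virginica", "virginica"]

# Hypothetical helper: treat classes[k] as "positive" and every other
# class as "negative", then compute the rates at one threshold
function one_vs_rest_rates(probs, truth, classes, k, threshold)
    is_pos = truth .== classes[k]             # samples that truly are class k
    flagged = probs[:, k] .> threshold        # samples predicted as class k
    tpr = count(flagged .& is_pos) / count(is_pos)
    fpr = count(flagged .& .!is_pos) / count(.!is_pos)
    return fpr, tpr
end

for k in eachindex(classes)
    fpr, tpr = one_vs_rest_rates(probs, truth, classes, k, 0.4)
    println(classes[k], ": FPR = ", fpr, ", TPR = ", tpr)
end
```

Repeating this for a range of thresholds gives one ROC curve per class, which is what the loop over `levels(yt)` does in the function above.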


# Precision-Recall Curves
In addition to ROC curves, one can also plot Precision-Recall (PR) curves.
We use the previous code with a slight modification to plot PR curves instead of ROC curves.

In [7]:
using PlotlyJS, MLJBase, MLJLinearModels, Random, CategoricalArrays, DataFrames

function display_multiclass_pr()
    Random.seed!(42) # for reproducibility
    df = DataFrame(dataset("iris"))
    df = df[shuffle(1:nrows(df)),:]
    X, y = unpack(df, x -> x!=:species && x!=:species_id, ==(:species))
    y = categorical(y)
    X = coerce(X, :petal_length => Continuous, :petal_width => Continuous, 
              :sepal_length => Continuous, :sepal_width => Continuous)
    
    train, test = partition(1:nrows(X), 0.5)
    mach = machine(LogisticClassifier(), X, y)
    fit!(mach; rows = train)
    
    ŷ = predict(mach, X[test, :])
    yt = y[test]
    
    ts = 0:0.01:1
    
    precs = []
    reclls = []
    
    # true positives: samples of `class` whose predicted probability for `class` exceeds the threshold
    count_true_pos(preds, gts, thhold, class) = count(yy -> class==yy[1] && pdf(yy[2],class)>thhold, zip(gts,preds))
    
    for lvl in levels(yt)
        # precision = TP / (TP + FP): the denominator counts every sample flagged as `lvl`
        # (this is 0/0 = NaN once the threshold is so high that nothing is flagged)
        push!(precs, [count_true_pos(ŷ, yt, t, lvl)/count(x -> pdf(x,lvl)>t, ŷ) for t in ts])
        # recall = TP / (TP + FN): the denominator counts every sample that truly is `lvl`
        push!(reclls, [count_true_pos(ŷ, yt, t, lvl)/count(==(lvl),yt) for t in ts])
    end
    
    pr_traces = [scatter(; x=reclls[i], y=precs[i], name=levels(yt)[i], opacity=0.65) for i in 1:length(levels(yt))]
    diagonal_trace = scatter(; x=0:0.1:1, y=1:-0.1:0, mode="lines", line=attr(dash="dash", color="black"), 
                               name="baseline")
    layout = Layout(xaxis_title="Recall", yaxis_title="Precision", 
                    title = "Multiclass PR Curves", 
                    font=attr(family="Verdana", size=14,color="Black"))
    Plot([pr_traces..., diagonal_trace], layout)
end

display_multiclass_pr()

┌ Info: Training Machine{LogisticClassifier,…}.
└ @ MLJBase /home/markus/.julia/packages/MLJBase/HZmTU/src/machines.jl:403
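
For reference, precision and recall at a single threshold follow directly from the counts of true positives (TP), false positives (FP), and false negatives (FN): precision is TP / (TP + FP) and recall is TP / (TP + FN). The following sketch uses a hypothetical helper `precision_recall` and made-up data to show the two definitions side by side.

```julia
# Hypothetical helper: precision and recall at a single threshold.
# scores: predicted probability of the positive class
# labels: true labels as Bool (true = positive)
function precision_recall(scores, labels, threshold)
    flagged = scores .> threshold
    tp = count(flagged .& labels)       # flagged and truly positive
    fp = count(flagged .& .!labels)     # flagged but truly negative
    fn = count(.!flagged .& labels)     # missed positives
    prec = tp / (tp + fp)               # NaN (0/0) when nothing is flagged
    rec = tp / (tp + fn)
    return prec, rec
end

# Toy data, made up for this example
scores = [0.9, 0.7, 0.6, 0.3, 0.2]
labels = [true, false, true, true, false]

prec, rec = precision_recall(scores, labels, 0.5)
println("precision = $prec, recall = $rec")
```

With these toy values, three samples are flagged at threshold 0.5 and two of them are true positives, while one of the three actual positives is missed, so precision and recall both equal 2/3.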


# Further Reading

- [PlotlyJS Julia Documentation](http://juliaplots.org/PlotlyJS.jl/stable/)
    - [PlotlyJS Julia scatter](https://plotly.com/julia/line-and-scatter/)
- [Julia MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/)
    - [MLJ Evaluating_Model_Performance](https://alan-turing-institute.github.io/MLJ.jl/dev/evaluating_model_performance/)
    - [MLJ Machines](https://alan-turing-institute.github.io/MLJ.jl/dev/machines/)
    - [MLJ Linear Models](https://juliaai.github.io/MLJLinearModels.jl/dev/)
- [Julia Categorical Arrays](https://categoricalarrays.juliadata.org/v0.2/index.html)
- [Julia Docs](https://docs.julialang.org/en/v1/)
    - [Julia Random](https://docs.julialang.org/en/v1/stdlib/Random/)
- [Julia Dataframes](https://dataframes.juliadata.org/stable/)
- [Julia make_blobs dataset](https://alan-turing-institute.github.io/MLJ.jl/dev/generating_synthetic_data/#Regression-data-generated-from-noisy-linear-models)
- [Google ROC Background](https://developers.google.com/machine-learning/crash-course/classification/roc-and-auc)