In [2]:
# Define the packages
using Pkg
Pkg.activate(".")
Pkg.instantiate()
using CategoricalArrays
using MAT
using DataFrames
using MLJ
using LinearAlgebra
using Statistics
using Plots
using StatsPlots
using CSV

[32m[1m  Activating[22m[39m new project at `~/ETH/TNM_Final_Project/TNM-Safe/Project_8/Classification_and_Regression_DCM`


In [2]:
# Treatment-response labels for the 26 subjects. Every tuple below is positional and
# follows this subject order, which must match `target_subjects` in the next cell (and
# therefore the row order of the feature matrix built from `valid_files`):
# sub-MOA101, sub-MOA102, sub-MOA104, sub-MOA105, sub-MOA107, sub-MOA108, sub-MOA109, sub-MOA110, sub-MOA111, sub-MOA112, sub-MOA114,
# sub-MOA115, sub-MOA116, sub-MOA118, sub-MOA121, sub-MOA122, sub-MOA123, sub-MOA124, sub-MOA126, sub-MOA127, sub-MOA128, sub-MOA130,
# sub-MOA131, sub-MOA133, sub-MOA134, sub-MOA135

# Binary treatment response categories: "Response" / "No Response"
y_binary = ("Response","Response","Response","No Response","Response","No Response","Response","No Response","Response","Response",
    "No Response","No Response","Response","Response","No Response","No Response","No Response","No Response","Response","Response",
    "No Response", "No Response" ,"No Response", "No Response","Response", "No Response")

# Tri label response categories: "Remission" / "Partial Response" / "No Response"
# NOTE(review): entry 17 (sub-MOA123) is "Partial Response" here but "Stable" in y_quad
# and y_penta, while every other "Stable" subject is "No Response" in this tuple —
# confirm against the clinical source.
y_tri = ("Partial Response","Partial Response","Remission","No Response","Partial Response","No Response","Partial Response","No Response",
    "Remission", "Remission","Partial Response","Partial Response","Partial Response","Remission","No Response","No Response",
    "Partial Response","No Response","Remission","Partial Response","Partial Response","Partial Response","No Response","No Response",
    "Remission","No Response")

# Quad label response categories: "Remission" / "Partial Response" / "Stable" / "Deterioration"
y_quad = ("Partial Response","Partial Response","Remission","Stable","Partial Response","Deterioration","Partial Response","Stable",
    "Remission","Remission","Partial Response","Partial Response","Partial Response","Remission","Stable","Deterioration", "Stable",
    "Stable","Remission","Partial Response","Partial Response","Partial Response","Stable","Stable","Remission","Stable")

# Penta label response categories: "Remission" / "Strong Response" / "Mild Response" /
# "Stable" / "Deterioration"
y_penta = ("Strong Response","Strong Response","Remission","Stable","Strong Response","Deterioration","Strong Response","Stable",
    "Remission","Remission","Mild Response","Mild Response","Strong Response","Remission","Stable","Deterioration","Stable","Stable",
    "Remission","Strong Response","Mild Response","Mild Response","Stable","Stable","Remission","Stable")



# Convert each tuple to a CategoricalArray (the names are rebound in place), which is
# the label type the MLJ classifiers below are given
y_binary = categorical(collect(y_binary))
y_tri = categorical(collect(y_tri))
y_quad = categorical(collect(y_quad))
y_penta = categorical(collect(y_penta))


26-element CategoricalArray{String,1,UInt32}:
 "Strong Response"
 "Strong Response"
 "Remission"
 "Stable"
 "Strong Response"
 "Deterioration"
 "Strong Response"
 "Stable"
 "Remission"
 "Remission"
 "Mild Response"
 "Mild Response"
 "Strong Response"
 "Remission"
 "Stable"
 "Deterioration"
 "Stable"
 "Stable"
 "Remission"
 "Strong Response"
 "Mild Response"
 "Mild Response"
 "Stable"
 "Stable"
 "Remission"
 "Stable"

In [3]:
# Subjects with some degree of depression. The order is significant: it must match the
# positional label tuples above, because features are stacked in this order.
target_subjects = [
    "sub-MOA101", "sub-MOA102", "sub-MOA104", "sub-MOA105", "sub-MOA107", "sub-MOA108",
    "sub-MOA109", "sub-MOA110", "sub-MOA111", "sub-MOA112", "sub-MOA114", "sub-MOA115",
    "sub-MOA116", "sub-MOA118", "sub-MOA121", "sub-MOA122", "sub-MOA123", "sub-MOA124",
    "sub-MOA126", "sub-MOA127", "sub-MOA128", "sub-MOA130", "sub-MOA131", "sub-MOA133",
    "sub-MOA134", "sub-MOA135",
]

# Root directory containing one folder per subject
base_path = "Spectral_DCM_Collection_15x15"

# Paths of the spDCM files that exist, in `target_subjects` order
valid_files = String[]

for subject in target_subjects
    # Expected layout: <base>/<subject>/ses-b0/glm/spDCM_rsTozzi.mat
    session_dir = joinpath(base_path, subject, "ses-b0")
    dcm_path = joinpath(session_dir, "glm", "spDCM_rsTozzi.mat")

    if isdir(session_dir)
        if isfile(dcm_path)
            push!(valid_files, dcm_path)
        else
            @warn "Missing spDCM_rsTozzi.mat for $subject"
        end
    else
        @warn "Missing session folder: $session_dir"
    end
end

println("Found spDCM_rsTozzi.mat for $(length(valid_files)) out of $(length(target_subjects)) subjects.")


Found spDCM_rsTozzi.mat for 26 out of 26 subjects.


In [4]:
"""
    extract_features(file; key="params")

Read the MATLAB file `file` with `matread` and return the variable stored under `key`,
converted to a `Matrix` and flattened column-major (`vec`) into a feature vector.

NOTE(review): the variable is presumed to be the DCM connectivity matrix; for the
`*_15x15` collection it is presumably 15×15, giving 225 features — confirm.

Throws an `ArgumentError` naming the file when `key` is not present.
"""
function extract_features(file; key::AbstractString="params")
    mat = matread(file)
    # Fail with the file name and the available variables instead of a bare KeyError
    haskey(mat, key) || throw(ArgumentError(
        "variable \"$key\" not found in $file; available: $(join(keys(mat), ", "))"))
    return vec(Matrix(mat[key]))  # column-major flatten of the whole matrix
end

extract_features (generic function with 1 method)

In [5]:
# Create feature dataset: one row per subject (in `valid_files` order), one column per
# flattened connectivity entry. `reduce(hcat, …)` avoids splatting every vector into a
# varargs call, and `permutedims` yields a plain matrix with subjects as rows.
isempty(valid_files) &&
    throw(ArgumentError("No spDCM files found; cannot build the feature matrix"))
X = permutedims(reduce(hcat, [extract_features(file) for file in valid_files]))

X_df = DataFrame(X, :auto)  # convert to MLJ-compatible table (columns x1, x2, …)

# Every label vector must have one entry per subject row; a missing DCM file would
# otherwise silently misalign features and labels. Unlike @assert, these checks cannot
# be compiled away.
for (label_name, labels) in (("y_binary", y_binary), ("y_tri", y_tri),
                             ("y_quad", y_quad), ("y_penta", y_penta))
    size(X_df, 1) == length(labels) || throw(DimensionMismatch(
        "Mismatch between number of samples in X ($(size(X_df, 1))) and " *
        "$label_name ($(length(labels)))"))
end

In [6]:
"""
    evaluate_logistic_model(X_df, y, nfolds::Int, label_type::String)

Cross-validate an MLJLinearModels `LogisticClassifier` on the feature table `X_df` and
categorical labels `y` using `nfolds`-fold `StratifiedCV` (shuffled, `rng=42`).

For each fold it collects accuracy, `MulticlassFScore()` (MLJ defaults) and the
macro-averaged precision and recall derived from the fold's confusion matrix. Classes
whose precision/recall is undefined in a fold are left out of that fold's macro average.

Side effects (output directories are created if missing):
- `tables/table_<nfolds>_grouped_metrics_logistics_<label_type>_15x15.csv`: per-fold
  rows plus an "Average" row.
- `figures/<nfolds>_grouped_metrics_logistics_<label_type>_15x15.png`: grouped bar chart.

Returns the value of `savefig` for the figure.
"""
function evaluate_logistic_model(X_df, y, nfolds::Int, label_type::String)
    # Mean over the non-NaN entries; NaN marks an undefined per-class or per-fold score
    nanmean(v) = (kept = filter(!isnan, v); isempty(kept) ? NaN : mean(kept))

    # Load model (verbosity=0 silences the banner printed on every call)
    LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0
    pipe = LogisticClassifier()
    mach = machine(pipe, X_df, y)

    # Setup evaluation
    measures = [accuracy, MulticlassFScore(), ConfusionMatrix()]
    cv = StratifiedCV(nfolds=nfolds, shuffle=true, rng=42)

    # Evaluate; predict_mode gives point predictions so label-based measures apply
    eval_result = evaluate!(
        mach,
        resampling=cv,
        measures=measures,
        operation=predict_mode,
        verbosity=0,
        check_measure=false
    )

    accs = eval_result.per_fold[1]
    f1s = eval_result.per_fold[2]
    cmats = eval_result.per_fold[3]

    macro_precisions = Float64[]
    macro_recalls = Float64[]

    for cm in cmats
        counts = cm.mat
        # MLJ confusion matrices have predictions on the rows and ground truth on the
        # columns: row sums are TP + FP, column sums are TP + FN.
        tp = diag(counts)
        predicted_totals = vec(sum(counts, dims=2))
        true_totals = vec(sum(counts, dims=1))

        # A class never predicted (precision) or absent from the fold (recall) has an
        # undefined score; mark it NaN so nanmean excludes it.
        class_precisions = [n == 0 ? NaN : t / n for (t, n) in zip(tp, predicted_totals)]
        class_recalls = [n == 0 ? NaN : t / n for (t, n) in zip(tp, true_totals)]

        push!(macro_precisions, nanmean(class_precisions))
        push!(macro_recalls, nanmean(class_recalls))
    end

    # Collect metrics; Fold is text so the "Average" row shares the column type
    metrics_df = DataFrame(
        Fold = string.(1:nfolds),
        Accuracy = accs,
        F1_Score = f1s,
        Macro_Precision = macro_precisions,
        Macro_Recall = macro_recalls
    )

    # Add row for averages
    avg_row = DataFrame(
        Fold = ["Average"],
        Accuracy = [nanmean(accs)],
        F1_Score = [nanmean(f1s)],
        Macro_Precision = [nanmean(macro_precisions)],
        Macro_Recall = [nanmean(macro_recalls)]
    )

    metrics_table = vcat(metrics_df, avg_row)

    # Save table as csv
    csv_table_path = "tables/table_$(nfolds)_grouped_metrics_logistics_$(label_type)_15x15.csv"
    mkpath(dirname(csv_table_path))
    CSV.write(csv_table_path, metrics_table)

    # Prepare long-format dataframe: one row per (fold, metric)
    metrics = ["Accuracy", "F1 Score", "Macro Precision", "Macro Recall"]
    metric_values = [accs, f1s, macro_precisions, macro_recalls]

    plot_df = DataFrame(
        Fold = repeat(1:nfolds, outer=length(metrics)),
        Metric = repeat(metrics, inner=nfolds),
        Value = vcat(metric_values...)
    )

    # Plot
    @df plot_df groupedbar(
        string.(:Fold), :Value, group=:Metric,
        bar_position=:dodge,
        bar_width=0.2,
        xlabel="Fold", ylabel="Metric Value",
        yticks=0:0.05:1.0,
        title="Logistic Classification Fold-wise Metrics",
        legend=:outertop,
        size=(750, 500),
        guidefontsize=10,
        tickfontsize=10,
        dpi=300
    )

    figure_path = "figures/$(nfolds)_grouped_metrics_logistics_$(label_type)_15x15.png"
    mkpath(dirname(figure_path))
    return savefig(figure_path)
end

evaluate_logistic_model (generic function with 1 method)

In [8]:
# Evaluate the logistic classifier with 3-, 4- and 5-fold cross validation, running every
# label granularity (binary, tri, quad, penta) for each fold count in turn.
let label_sets = ("binary" => y_binary, "tri" => y_tri, "quad" => y_quad, "penta" => y_penta)
    figure_paths = [evaluate_logistic_model(X_df, labels, nfolds, label_type)
                    for nfolds in (3, 4, 5) for (label_type, labels) in label_sets]
    # Leave the last call's result as the cell value, as the unrolled calls did
    last(figure_paths)
end

import MLJLinearModels ✔


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJLinearModels ✔


[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJLinearModels ✔


[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJLinearModels ✔


[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJLinearModels ✔


[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJLinearModels ✔


[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJLinearModels ✔
import MLJLinearModels ✔


[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJLinearModels ✔


[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJLinearModels ✔


[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJLinearModels ✔
import MLJLinearModels ✔


[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m


"/Users/keyshavmor/ETH/TNM_Final_Project/Project_8/figures/5_grouped_metrics_logistics_penta_15x15.png"

In [9]:
"""
    evaluate_multinomial_classifier_model(X_df, y, nfolds::Int, label_type::String)

Cross-validate an MLJLinearModels `MultinomialClassifier` on the feature table `X_df` and
categorical labels `y` using `nfolds`-fold `StratifiedCV` (shuffled, `rng=42`).

For each fold it collects accuracy, `MulticlassFScore()` (MLJ defaults) and the
macro-averaged precision and recall derived from the fold's confusion matrix. Classes
whose precision/recall is undefined in a fold are left out of that fold's macro average.

Side effects (output directories are created if missing):
- `tables/table_<nfolds>_grouped_metrics_multinomial_<label_type>_15x15.csv`: per-fold
  rows plus an "Average" row.
- `figures/<nfolds>_grouped_metrics_multinomial_<label_type>_15x15.png`: grouped bar chart.

Returns the value of `savefig` for the figure.
"""
function evaluate_multinomial_classifier_model(X_df, y, nfolds::Int, label_type::String)
    # Mean over the non-NaN entries; NaN marks an undefined per-class or per-fold score
    nanmean(v) = (kept = filter(!isnan, v); isempty(kept) ? NaN : mean(kept))

    # Load model (verbosity=0 silences the banner printed on every call)
    MultinomialClassifier = @load MultinomialClassifier pkg=MLJLinearModels verbosity=0
    pipe = MultinomialClassifier()
    mach = machine(pipe, X_df, y)

    # Setup evaluation
    measures = [accuracy, MulticlassFScore(), ConfusionMatrix()]
    cv = StratifiedCV(nfolds=nfolds, shuffle=true, rng=42)

    # Evaluate; predict_mode gives point predictions so label-based measures apply
    eval_result = evaluate!(
        mach,
        resampling=cv,
        measures=measures,
        operation=predict_mode,
        verbosity=0,
        check_measure=false
    )

    accs = eval_result.per_fold[1]
    f1s = eval_result.per_fold[2]
    cmats = eval_result.per_fold[3]

    macro_precisions = Float64[]
    macro_recalls = Float64[]

    for cm in cmats
        counts = cm.mat
        # MLJ confusion matrices have predictions on the rows and ground truth on the
        # columns: row sums are TP + FP, column sums are TP + FN.
        tp = diag(counts)
        predicted_totals = vec(sum(counts, dims=2))
        true_totals = vec(sum(counts, dims=1))

        # A class never predicted (precision) or absent from the fold (recall) has an
        # undefined score; mark it NaN so nanmean excludes it.
        class_precisions = [n == 0 ? NaN : t / n for (t, n) in zip(tp, predicted_totals)]
        class_recalls = [n == 0 ? NaN : t / n for (t, n) in zip(tp, true_totals)]

        push!(macro_precisions, nanmean(class_precisions))
        push!(macro_recalls, nanmean(class_recalls))
    end

    # Collect metrics; Fold is text so the "Average" row shares the column type
    metrics_df = DataFrame(
        Fold = string.(1:nfolds),
        Accuracy = accs,
        F1_Score = f1s,
        Macro_Precision = macro_precisions,
        Macro_Recall = macro_recalls
    )

    # Add row for averages
    avg_row = DataFrame(
        Fold = ["Average"],
        Accuracy = [nanmean(accs)],
        F1_Score = [nanmean(f1s)],
        Macro_Precision = [nanmean(macro_precisions)],
        Macro_Recall = [nanmean(macro_recalls)]
    )

    metrics_table = vcat(metrics_df, avg_row)

    # Save table as csv
    csv_table_path = "tables/table_$(nfolds)_grouped_metrics_multinomial_$(label_type)_15x15.csv"
    mkpath(dirname(csv_table_path))
    CSV.write(csv_table_path, metrics_table)

    # Prepare long-format dataframe: one row per (fold, metric)
    metrics = ["Accuracy", "F1 Score", "Macro Precision", "Macro Recall"]
    metric_values = [accs, f1s, macro_precisions, macro_recalls]

    plot_df = DataFrame(
        Fold = repeat(1:nfolds, outer=length(metrics)),
        Metric = repeat(metrics, inner=nfolds),
        Value = vcat(metric_values...)
    )

    # Plot
    @df plot_df groupedbar(
        string.(:Fold), :Value, group=:Metric,
        bar_position=:dodge,
        bar_width=0.2,
        xlabel="Fold", ylabel="Metric Value",
        yticks=0:0.05:1.0,
        title="Multinomial Classification Fold-wise Metrics",
        legend=:outertop,
        size=(750, 500),
        guidefontsize=10,
        tickfontsize=10,
        dpi=300
    )

    figure_path = "figures/$(nfolds)_grouped_metrics_multinomial_$(label_type)_15x15.png"
    mkpath(dirname(figure_path))
    return savefig(figure_path)
end

evaluate_multinomial_classifier_model (generic function with 1 method)

In [10]:
# Evaluate the multinomial classifier with 3-, 4- and 5-fold cross validation, running
# every label granularity (binary, tri, quad, penta) for each fold count in turn.
let label_sets = ("binary" => y_binary, "tri" => y_tri, "quad" => y_quad, "penta" => y_penta)
    figure_paths = [evaluate_multinomial_classifier_model(X_df, labels, nfolds, label_type)
                    for nfolds in (3, 4, 5) for (label_type, labels) in label_sets]
    # Leave the last call's result as the cell value, as the unrolled calls did
    last(figure_paths)
end

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJLinearModels ✔


[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJLinearModels ✔
import MLJLinearModels ✔


[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJLinearModels ✔
import MLJLinearModels ✔


[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJLinearModels ✔
import MLJLinearModels ✔


[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJLinearModels ✔
import MLJLinearModels ✔


[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJLinearModels ✔
import MLJLinearModels ✔


[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor 

import MLJLinearModels ✔


[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m


"/Users/keyshavmor/ETH/TNM_Final_Project/Project_8/figures/5_grouped_metrics_multinomial_penta_15x15.png"

In [11]:
using MLJScikitLearnInterface
"""
    evaluate_random_forest_model(X_df, y, nfolds::Int, label_type::String)

Cross-validate a `RandomForestClassifier` from MLJScikitLearnInterface on the feature
table `X_df` and categorical labels `y` using `nfolds`-fold `StratifiedCV` (shuffled,
`rng=42` for the fold split; the forest itself uses its default settings).

For each fold it collects accuracy, `MulticlassFScore()` (MLJ defaults) and the
macro-averaged precision and recall derived from the fold's confusion matrix. Classes
whose precision/recall is undefined in a fold are left out of that fold's macro average.

Side effects (output directories are created if missing):
- `tables/table_<nfolds>_grouped_metrics_random_forest_<label_type>_15x15.csv`: per-fold
  rows plus an "Average" row.
- `figures/<nfolds>_grouped_metrics_random_forest_<label_type>_15x15.png`: grouped bar
  chart.

Returns the value of `savefig` for the figure.
"""
function evaluate_random_forest_model(X_df, y, nfolds::Int, label_type::String)
    # Mean over the non-NaN entries; NaN marks an undefined per-class or per-fold score
    nanmean(v) = (kept = filter(!isnan, v); isempty(kept) ? NaN : mean(kept))

    # Load model (verbosity=0 silences the banner printed on every call)
    RandomForestClassifier = @load RandomForestClassifier pkg=MLJScikitLearnInterface verbosity=0
    pipe = RandomForestClassifier()
    mach = machine(pipe, X_df, y)

    # Setup evaluation
    measures = [accuracy, MulticlassFScore(), ConfusionMatrix()]
    cv = StratifiedCV(nfolds=nfolds, shuffle=true, rng=42)

    # Evaluate; predict_mode gives point predictions so label-based measures apply
    eval_result = evaluate!(
        mach,
        resampling=cv,
        measures=measures,
        operation=predict_mode,
        verbosity=0,
        check_measure=false
    )

    accs = eval_result.per_fold[1]
    f1s = eval_result.per_fold[2]
    cmats = eval_result.per_fold[3]

    macro_precisions = Float64[]
    macro_recalls = Float64[]

    for cm in cmats
        counts = cm.mat
        # MLJ confusion matrices have predictions on the rows and ground truth on the
        # columns: row sums are TP + FP, column sums are TP + FN.
        tp = diag(counts)
        predicted_totals = vec(sum(counts, dims=2))
        true_totals = vec(sum(counts, dims=1))

        # A class never predicted (precision) or absent from the fold (recall) has an
        # undefined score; mark it NaN so nanmean excludes it.
        class_precisions = [n == 0 ? NaN : t / n for (t, n) in zip(tp, predicted_totals)]
        class_recalls = [n == 0 ? NaN : t / n for (t, n) in zip(tp, true_totals)]

        push!(macro_precisions, nanmean(class_precisions))
        push!(macro_recalls, nanmean(class_recalls))
    end

    # Collect metrics; Fold is text so the "Average" row shares the column type
    metrics_df = DataFrame(
        Fold = string.(1:nfolds),
        Accuracy = accs,
        F1_Score = f1s,
        Macro_Precision = macro_precisions,
        Macro_Recall = macro_recalls
    )

    # Add row for averages
    avg_row = DataFrame(
        Fold = ["Average"],
        Accuracy = [nanmean(accs)],
        F1_Score = [nanmean(f1s)],
        Macro_Precision = [nanmean(macro_precisions)],
        Macro_Recall = [nanmean(macro_recalls)]
    )

    metrics_table = vcat(metrics_df, avg_row)

    # Save table as csv
    csv_table_path = "tables/table_$(nfolds)_grouped_metrics_random_forest_$(label_type)_15x15.csv"
    mkpath(dirname(csv_table_path))
    CSV.write(csv_table_path, metrics_table)

    # Prepare long-format dataframe: one row per (fold, metric)
    metrics = ["Accuracy", "F1 Score", "Macro Precision", "Macro Recall"]
    metric_values = [accs, f1s, macro_precisions, macro_recalls]

    plot_df = DataFrame(
        Fold = repeat(1:nfolds, outer=length(metrics)),
        Metric = repeat(metrics, inner=nfolds),
        Value = vcat(metric_values...)
    )

    # Plot
    @df plot_df groupedbar(
        string.(:Fold), :Value, group=:Metric,
        bar_position=:dodge,
        bar_width=0.2,
        xlabel="Fold", ylabel="Metric Value",
        yticks=0:0.05:1.0,
        title="Random Forest Classification Fold-wise Metrics",
        legend=:outertop,
        size=(750, 500),
        guidefontsize=10,
        tickfontsize=10,
        dpi=300
    )

    figure_path = "figures/$(nfolds)_grouped_metrics_random_forest_$(label_type)_15x15.png"
    mkpath(dirname(figure_path))
    return savefig(figure_path)
end

evaluate_random_forest_model (generic function with 1 method)

In [12]:
# Evaluate the random forest classifier with 3-, 4- and 5-fold cross validation, running
# every label granularity (binary, tri, quad, penta) for each fold count in turn.
let label_sets = ("binary" => y_binary, "tri" => y_tri, "quad" => y_quad, "penta" => y_penta)
    figure_paths = [evaluate_random_forest_model(X_df, labels, nfolds, label_type)
                    for nfolds in (3, 4, 5) for (label_type, labels) in label_sets]
    # Leave the last call's result as the cell value, as the unrolled calls did
    last(figure_paths)
end

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJScikitLearnInterface ✔


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJScikitLearnInterface ✔


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJScikitLearnInterface ✔


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJScikitLearnInterface ✔


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJScikitLearnInterface ✔


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJScikitLearnInterface ✔


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJScikitLearnInterface ✔


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJScikitLearnInterface ✔


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJScikitLearnInterface ✔


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJScikitLearnInterface ✔


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJScikitLearnInterface ✔


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJScikitLearnInterface ✔


"/Users/keyshavmor/ETH/TNM_Final_Project/Project_8/figures/5_grouped_metrics_random_forest_penta_15x15.png"