In [1]:
using Pkg
using CategoricalArrays
using MAT
using DataFrames
using MLJ

In [2]:
# sub-MOA101, sub-MOA102, sub-MOA104, sub-MOA105, sub-MOA107, sub-MOA108, sub-MOA109, sub-MOA110, sub-MOA111, sub-MOA112, sub-MOA114,
# sub-MOA115, sub-MOA116, sub-MOA118, sub-MOA121, sub-MOA122, sub-MOA123, sub-MOA124, sub-MOA126, sub-MOA127, sub-MOA128, sub-MOA130,
# sub-MOA131, sub-MOA133, sub-MOA134, sub-MOA135

# Treatment response categories
y = ("Response", "Response", "Response", "Stable", "Response", "Stable", "Response", "Stable", "Remission", "Remission", "Stable", 
    "Stable", "Response", "Remission", "Stable", "Stable", "Stable", "Stable", "Response", "Response", "Stable", "Stable", "Stable", 
    "Stable", "Response", "Stable")

# Convert to CategoricalArray
y_cat = categorical(collect(y))

println("Ground truth targets as categories:", y_cat)

Ground truth targets as categories:CategoricalValue{String, UInt32}["Response", "Response", "Response", "Stable", "Response", "Stable", "Response", "Stable", "Remission", "Remission", "Stable", "Stable", "Response", "Remission", "Stable", "Stable", "Stable", "Stable", "Response", "Response", "Stable", "Stable", "Stable", "Stable", "Response", "Stable"]


In [3]:
# List of subjects that have some degree of depression
target_subjects = [
    "sub-MOA101", "sub-MOA102", "sub-MOA104", "sub-MOA105", "sub-MOA107", "sub-MOA108", "sub-MOA109", "sub-MOA110", "sub-MOA111",
    "sub-MOA112", "sub-MOA114", "sub-MOA115", "sub-MOA116", "sub-MOA118", "sub-MOA121", "sub-MOA122", "sub-MOA123","sub-MOA124", 
    "sub-MOA126", "sub-MOA127", "sub-MOA128", "sub-MOA130", "sub-MOA131", "sub-MOA133", "sub-MOA134", "sub-MOA135"]

# Base path to your subject folders
base_path = "Spectral_DCM_Collection_DMN"

# Collect valid file paths
valid_files = String[]

for subj in target_subjects
    subj_path = joinpath(base_path, subj)
    ses_path = joinpath(subj_path, "ses-b0")
    glm_path = joinpath(ses_path, "glm")
    dcm_file = joinpath(glm_path, "spDCM_DMN.mat")

    if !isdir(ses_path)
        @warn "Missing session folder: $ses_path"
    elseif !isfile(dcm_file)
        @warn "Missing spDCM_DMN.mat for $subj"
    else
        push!(valid_files, dcm_file)
    end
end

println("✅ Found Spectral_DCM.mat for $(length(valid_files)) out of $(length(target_subjects)) subjects.")


✅ Found Spectral_DCM.mat for 26 out of 26 subjects.


In [4]:
# Extract A matrix features as a flat 16-element vector
function extract_features(file)
    mat = matread(file)
    A = mat["params"]  # 4×4 matrix
    return vec(Matrix(A))  # Flatten to 16-element vector
end

extract_features (generic function with 1 method)

In [5]:
# Create feature dataset
X = hcat([extract_features(file) for file in valid_files]...)'

X_df = DataFrame(X, :auto)  # convert to MLJ-compatible table

@show size(X_df)
@show length(y_cat)

@assert size(X_df, 1) == length(y_cat) "Mismatch between number of samples in X and y"

size(X_df) = (26, 16)
length(y_cat) = 26


In [23]:
#Declare logistic regression model
LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels
# Build pipeline
pipe = LogisticClassifier()

# Machine
mach = machine(pipe, X_df, y_cat)

# Compatible measures
measures = [accuracy, MulticlassFScore(), ConfusionMatrix()]

# Cross-validation setup
cv = StratifiedCV(nfolds=3, shuffle=true, rng=42)

# Evaluate model
eval_result = evaluate!(
    mach,
    resampling=cv,
    measures=measures,
    operation=predict_mode,
    verbosity=1,
    check_measure=false  # Force bypassing check for unsupported measures
)

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFor silent loading, specify `verbosity=0`. 


import MLJLinearModels ✔


[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m
[33m[1m└ [22m[39m[90m@ Optim ~/.julia/packages/Optim/8dE7C/src/types.jl:120[39m


PerformanceEvaluation object with these fields:
  model, measure, operation,
  measurement, per_fold, per_observation,
  fitted_params_per_fold, report_per_fold,
  train_test_rows, resampling, repeats
Extract:
┌───┬──────────────────────────────┬──────────────┬─────────────────────────────
│[22m   [0m│[22m measure                      [0m│[22m operation    [0m│[22m measurement               [0m ⋯
├───┼──────────────────────────────┼──────────────┼─────────────────────────────
│ A │ Accuracy()                   │ predict_mode │ 0.538                      ⋯
│ B │ MulticlassFScore(            │ predict_mode │ 0.519                      ⋯
│   │   beta = 1.0,                │              │                            ⋯
│   │   average = MacroAvg(),      │              │                            ⋯
│   │   return_type = LittleDict,  │              │                            ⋯
│   │   levels = nothing,          │              │                            ⋯
│   │   perm = nothing,  

In [26]:
using LinearAlgebra
using Statistics
using Plots
using StatsPlots

# Extract per-fold metrics
accs = eval_result.per_fold[1]  # accuracy (macro)
f1s  = eval_result.per_fold[2]  # F1 score (macro)
cmats = eval_result.per_fold[3] # confusion matrices

nfolds = length(cmats)

# Initialize metric storage
macro_precisions = Float64[]
macro_recalls = Float64[]
weighted_precisions = Float64[]
weighted_recalls = Float64[]



for cm in cmats
    cmatrix = cm.mat
    
    TP = diag(cmatrix)
    FP = sum(cmatrix, dims=1)' .- TP
    FN = sum(cmatrix, dims=2) .- TP

    println("True Positives: \n",TP)
    println("False Positives: \n",FP)
    println("False Negatives: \n",FN)
    
    support = sum(cmatrix, dims=2)[:]
    total = sum(support)

    # Avoid division by zero
    class_precisions = map((tp, fp) -> (tp + fp == 0 ? NaN : tp / (tp + fp)), TP, FP)
    class_recalls    = map((tp, fn) -> (tp + fn == 0 ? NaN : tp / (tp + fn)), TP, FN)

    # Macro average (ignores class imbalance)
    macro_precision = mean(skipmissing(class_precisions))
    macro_recall    = mean(skipmissing(class_recalls))

    push!(macro_precisions, macro_precision)
    push!(macro_recalls, macro_recall)
end

# Print for verification
println("Accuracies:", accs)
println("Macro Precisions:", macro_precisions)
println("Macro Recalls:", macro_recalls)
println("F1 Scores:", f1s)

# Assuming these have been computed earlier:
# accs, f1s, macro_precisions, macro_recalls
nfolds = 3

# Define metrics
metrics = ["Accuracy", "F1 Score", "Macro Precision", "Macro Recall"]
metric_values = [accs, f1s, macro_precisions, macro_recalls]

# Prepare DataFrame
plot_df = DataFrame(
    Fold = repeat(1:nfolds, outer=length(metrics)),
    Metric = repeat(metrics, inner=nfolds),
    Value = vcat(metric_values...)
)

# Grouped bar plot with folds labeled as 1, 2, 3
@df plot_df groupedbar(
    string.(:Fold), :Value, group=:Metric,
    bar_position=:dodge,
    bar_width=0.2,
    xlabel="Fold", ylabel="Metric Value",
    title="Fold-wise Accuracy, F1, Precision, Recall",
    legend=:topright,
    size=(1000, 600), dpi=300
)

savefig("fold_grouped-metrics_logistic.png")

True Positives: 
[0, 18, 27]
False Positives: 
[9; 9; 18;;]
False Negatives: 
[9; 18; 9;;]
True Positives: 
[9, 18, 0]
False Positives: 
[0; 9; 45;;]
False Negatives: 
[9; 36; 9;;]
True Positives: 
[8, 16, 24]
False Positives: 
[0; 8; 8;;]
False Negatives: 
[0; 8; 8;;]
Accuracies:[0.5555555555555556, 0.3333333333333333, 0.75]
Macro Precisions:[0.4222222222222222, 0.5555555555555555, 0.8055555555555555]
Macro Recalls:[0.4166666666666667, 0.27777777777777773, 0.8055555555555555]
F1 Scores:[0.4126984126984127, 0.37037037037037035, 0.8055555555555555]


"/Users/keyshavmor/ETH/TNM_Final_Project/Project_8/fold_grouped-metrics_logistic.png"