## Figure 3

- Panel A: Accuracy for all plates
- Panel B: PR curves for all plates
- Panel C: Confusion matrix for all plates

In [35]:
suppressPackageStartupMessages(library(dplyr))
suppressPackageStartupMessages(library(ggplot2))
suppressPackageStartupMessages(library(patchwork))
suppressPackageStartupMessages(library(arrow))
suppressPackageStartupMessages(library(RColorBrewer))

In [36]:
# R logger that matches Python ANALYSIS_TYPE style
suppressPackageStartupMessages(library(jsonlite))

# Set identifiers
RUN_ID        <- "12_06_08_07"   # or whatever run id you want
MODEL_ID      <- "ensemble"
ROLE          <- "main_figure_3"
ANALYSIS_TYPE <- "main_figure_3"      # or "figure", "generalizability", etc

log_dir <- "logs"
if (!dir.exists(log_dir)) {
  dir.create(log_dir, recursive = TRUE)
}

# Logger name and log file pattern match Python:
# logger_name: "<analysis_type>_<run_id>_<model_id>_<role>"
# log file:    "log_<analysis_type>_<run_id>_<model_id>.log"
logger_name <- sprintf(
  "%s_%s_%s_%s",
  ANALYSIS_TYPE, RUN_ID, MODEL_ID, ROLE
)
log_file <- file.path(
  log_dir,
  sprintf("log_%s_%s_%s.log", ANALYSIS_TYPE, RUN_ID, MODEL_ID)
)

log_info <- function(msg) {
  ts   <- format(Sys.time(), "%Y-%m-%dT%H:%M:%S")
  line <- sprintf("%s [%s] INFO: %s", ts, logger_name, msg)

  # write to stdout
  cat(line, "\n")

  # append to file
  cat(line, "\n", file = log_file, append = TRUE)
}

log_json_panel <- function(key, df) {
  json_txt <- jsonlite::toJSON(df, dataframe = "rows", auto_unbox = TRUE)
  log_info(paste0(key, "_json=", json_txt))
}

log_info("Initialized R logger for ensemble_main_figure_3.")
log_info(paste("Logging to file:", log_file))


2025-12-06T11:06:36 [main_figure_3_12_06_08_07_ensemble_main_figure_3] INFO: Initialized R logger for ensemble_main_figure_3. 
2025-12-06T11:06:36 [main_figure_3_12_06_08_07_ensemble_main_figure_3] INFO: Logging to file: logs/log_main_figure_3_12_06_08_07_ensemble.log 


## Set paths

In [37]:
figure_dir <- "../figures"
output_main_figure_3 <- file.path(
    figure_dir, "main_figure_3_ensemble_model_eval.png"
)

log_info(paste("Figure will be saved to", output_main_figure_3))


2025-12-06T11:06:36 [main_figure_3_12_06_08_07_ensemble_main_figure_3] INFO: Figure will be saved to ../figures/main_figure_3_ensemble_model_eval.png 


In [None]:
# ================================
# Summarize panel_exports JSON for all models
# Logs TRAIN and TEST separately to help detect overfitting
# ================================

suppressPackageStartupMessages(library(jsonlite))

# Config
PANEL_EXPORT_DIR    <- "panel_exports"
PANEL_ANALYSIS_TYPE <- "main_figure_3"          # analysis_type used when exporting panels
PANEL_RUN_ID        <- "12_06_08_07"       # run_id used in the per-model figure3 exports
PANEL_MODELS        <- c("logreg", "randomforest", "xgboost")

# Which summaries to log
LOG_CONFUSION         <- TRUE   # primary goal 1
LOG_FEATURES_SHUFFLED <- TRUE   # primary goal 3
LOG_PR_CURVES         <- FALSE  # secondary goal (off by default)
LOG_DERIV_PR          <- FALSE
LOG_COEFFICIENTS      <- FALSE
LOG_ACCURACY          <- FALSE

# Helper: read one panel export file for a given model id
read_panel_export <- function(model_id) {
  fname <- sprintf(
    "figure3_panels_%s_%s_%s.json",
    PANEL_ANALYSIS_TYPE,
    PANEL_RUN_ID,
    model_id
  )
  path <- file.path(PANEL_EXPORT_DIR, fname)

  if (!file.exists(path)) {
    log_info(
      paste(
        "panel_export_missing model_id=", model_id,
        " file=", path
      )
    )
    return(NULL)
  }

  payload <- tryCatch(
    jsonlite::fromJSON(path),
    error = function(e) {
      log_info(
        paste("panel_export_read_error model_id=", model_id,
              " file=", path,
              " msg=", conditionMessage(e))
      )
      return(NULL)
    }
  )

  if (is.null(payload$panels)) {
    log_info(
      paste(
        "panel_export_no_panels model_id=", model_id,
        " file=", path
      )
    )
    return(NULL)
  }

  log_info(
    paste("panel_export_loaded model_id=", model_id,
          " file=", path)
  )

  payload
}


summarize_confusion_panel <- function(cm_df) {
  # Internal helper to summarize one split (Train or Test)
  summarize_one_split <- function(split_name) {
    # Real (non shuffled) rows for this split and all_plates
    sub_true <- cm_df[
      cm_df$plate == "all_plates" &
        cm_df$datasplit == split_name &
        cm_df$shuffled_type == "FALSE",
      ,
      drop = FALSE
    ]

    # Shuffled rows for this split and all_plates
    sub_shuf <- cm_df[
      cm_df$plate == "all_plates" &
        cm_df$datasplit == split_name &
        cm_df$shuffled_type == "TRUE",
      ,
      drop = FALSE
    ]

    # -----------------------------
    # Real (non shuffled) summary
    # -----------------------------
    acc_true   <- NA_real_
    recalls_vec <- NULL
    total_true <- NA_real_

    if (nrow(sub_true) > 0) {
      total_true <- sum(sub_true$confusion_values)
      correct_true <- sum(
        sub_true$confusion_values[
          sub_true$true_genotype == sub_true$predicted_genotype
        ]
      )

      if (total_true > 0) {
        acc_true <- correct_true / total_true
      }

      # Recall per true class (fraction of correct within each true genotype)
      recalls_vec <- sapply(
        split(sub_true, sub_true$true_genotype),
        function(df_gt) {
          total_gt   <- sum(df_gt$confusion_values)
          correct_gt <- sum(
            df_gt$confusion_values[
              df_gt$true_genotype == df_gt$predicted_genotype
            ]
          )
          if (total_gt > 0) correct_gt / total_gt else NA_real_
        }
      )
    }

    # -----------------------------
    # Shuffled summary
    # -----------------------------
    acc_shuf   <- NA_real_
    total_shuf <- NA_real_

    if (nrow(sub_shuf) > 0) {
      total_shuf <- sum(sub_shuf$confusion_values)
      correct_shuf <- sum(
        sub_shuf$confusion_values[
          sub_shuf$true_genotype == sub_shuf$predicted_genotype
        ]
      )

      if (total_shuf > 0) {
        acc_shuf <- correct_shuf / total_shuf
      }
    }

    # ------------------------------------------------
    # Confusion counts and percentages per cell (real)
    # ------------------------------------------------
    confusion_counts_true <- NULL

    if (nrow(sub_true) > 0) {
      # Group by true -> predicted combination
      by_cell_true <- split(
        sub_true,
        paste(sub_true$true_genotype, "->", sub_true$predicted_genotype)
      )

      confusion_counts_true <- lapply(
        by_cell_true,
        function(df_cell) {
          # Sum over rows in case there are multiple
          count <- sum(as.numeric(df_cell$confusion_values))
          true_label <- as.character(df_cell$true_genotype[[1]])

          total_gt <- sum(
            sub_true$confusion_values[sub_true$true_genotype == true_label]
          )

          list(
            count        = count,
            frac_of_all  = if (!is.na(total_true) && total_true > 0)
                             count / total_true else NA_real_,
            frac_of_true = if (!is.na(total_gt) && total_gt > 0)
                             count / total_gt else NA_real_
          )
        }
      )
    }

    # ---------------------------------------------------
    # Confusion counts and percentages per cell (shuffled)
    # ---------------------------------------------------
    confusion_counts_shuf <- NULL

    if (nrow(sub_shuf) > 0) {
      by_cell_shuf <- split(
        sub_shuf,
        paste(sub_shuf$true_genotype, "->", sub_shuf$predicted_genotype)
      )

      confusion_counts_shuf <- lapply(
        by_cell_shuf,
        function(df_cell) {
          count <- sum(as.numeric(df_cell$confusion_values))
          true_label <- as.character(df_cell$true_genotype[[1]])

          total_gt_shuf <- sum(
            sub_shuf$confusion_values[sub_shuf$true_genotype == true_label]
          )

          list(
            count        = count,
            frac_of_all  = if (!is.na(total_shuf) && total_shuf > 0)
                             count / total_shuf else NA_real_,
            frac_of_true = if (!is.na(total_gt_shuf) && total_gt_shuf > 0)
                             count / total_gt_shuf else NA_real_
          )
        }
      )
    }

    # Return summary for this split
    list(
      accuracy                  = acc_true,
      accuracy_shuffled         = acc_shuf,
      recall_by_true_class      = if (!is.null(recalls_vec)) as.list(recalls_vec) else NULL,
      confusion_counts          = confusion_counts_true,
      confusion_counts_shuffled = confusion_counts_shuf
    )
  }

  # Summaries for Train and Test
  train_res <- summarize_one_split("Train")
  test_res  <- summarize_one_split("Test")

  # Pack into a single list that the logging code can consume
  list(
    train_accuracy                  = train_res$accuracy,
    train_accuracy_shuffled         = train_res$accuracy_shuffled,
    train_recall_by_true_class      = train_res$recall_by_true_class,
    train_confusion_counts          = train_res$confusion_counts,
    train_confusion_counts_shuffled = train_res$confusion_counts_shuffled,

    test_accuracy                   = test_res$accuracy,
    test_accuracy_shuffled          = test_res$accuracy_shuffled,
    test_recall_by_true_class       = test_res$recall_by_true_class,
    test_confusion_counts           = test_res$confusion_counts,
    test_confusion_counts_shuffled  = test_res$confusion_counts_shuffled
  )
}


# Helper: summarize features_shuffled panel for Train and Test, all_plates
summarize_features_shuffled_panel <- function(fs_df) {
  # Train
  sub_train <- fs_df[
    fs_df$plate == "all_plates" &
      fs_df$datasplit == "Train",
    ,
    drop = FALSE
  ]

  train_acc_true <- NA_real_
  train_acc_shuf <- NA_real_

  if (nrow(sub_train) > 0) {
    acc_train_true_vec <- sub_train$accuracy[sub_train$shuffled_type == "FALSE"]
    acc_train_shuf_vec <- sub_train$accuracy[sub_train$shuffled_type == "TRUE"]

    if (length(acc_train_true_vec) > 0) {
      train_acc_true <- acc_train_true_vec[[1]]
    }
    if (length(acc_train_shuf_vec) > 0) {
      train_acc_shuf <- acc_train_shuf_vec[[1]]
    }
  }

  # Test
  sub_test <- fs_df[
    fs_df$plate == "all_plates" &
      fs_df$datasplit == "Test",
    ,
    drop = FALSE
  ]

  test_acc_true <- NA_real_
  test_acc_shuf <- NA_real_

  if (nrow(sub_test) > 0) {
    acc_test_true_vec <- sub_test$accuracy[sub_test$shuffled_type == "FALSE"]
    acc_test_shuf_vec <- sub_test$accuracy[sub_test$shuffled_type == "TRUE"]

    if (length(acc_test_true_vec) > 0) {
      test_acc_true <- acc_test_true_vec[[1]]
    }
    if (length(acc_test_shuf_vec) > 0) {
      test_acc_shuf <- acc_test_shuf_vec[[1]]
    }
  }

  list(
    train_accuracy          = train_acc_true,
    train_accuracy_shuffled = train_acc_shuf,
    test_accuracy           = test_acc_true,
    test_accuracy_shuffled  = test_acc_shuf
  )
}

# Loop over models and log summaries
for (model_id in PANEL_MODELS) {
  payload <- read_panel_export(model_id)
  if (is.null(payload)) {
    next
  }

  panels <- payload$panels

  # 1) Confusion matrix summary (primary goal)
  if (LOG_CONFUSION && !is.null(panels$confusion_matrix)) {
    cm_summary <- summarize_confusion_panel(panels$confusion_matrix)

    if (!is.null(cm_summary)) {
      # Separate log lines for Train and Test
      train_payload <- list(
        split                    = "Train",
        accuracy                 = cm_summary$train_accuracy,
        accuracy_shuffled        = cm_summary$train_accuracy_shuffled,
        recall_by_true_class     = cm_summary$train_recall_by_true_class
      )
      test_payload <- list(
        split                    = "Test",
        accuracy                 = cm_summary$test_accuracy,
        accuracy_shuffled        = cm_summary$test_accuracy_shuffled,
        recall_by_true_class     = cm_summary$test_recall_by_true_class
      )

      msg_train <- paste0(
        "figure3_confusion_summary_", model_id, "_train=",
        jsonlite::toJSON(train_payload, auto_unbox = TRUE)
      )
      msg_test <- paste0(
        "figure3_confusion_summary_", model_id, "_test=",
        jsonlite::toJSON(test_payload, auto_unbox = TRUE)
      )

      log_info(msg_train)
      log_info(msg_test)
    } else {
      log_info(
        paste(
          "figure3_confusion_summary_no_all_plates_rows model_id=",
          model_id
        )
      )
    }
  }

  # 3) Features shuffled summary (primary goal)
  if (LOG_FEATURES_SHUFFLED && !is.null(panels$features_shuffled)) {
    fs_summary <- summarize_features_shuffled_panel(panels$features_shuffled)

    if (!is.null(fs_summary)) {
      fs_train_payload <- list(
        split             = "Train",
        accuracy          = fs_summary$train_accuracy,
        accuracy_shuffled = fs_summary$train_accuracy_shuffled
      )
      fs_test_payload <- list(
        split             = "Test",
        accuracy          = fs_summary$test_accuracy,
        accuracy_shuffled = fs_summary$test_accuracy_shuffled
      )

      msg_fs_train <- paste0(
        "figure3_features_shuffled_summary_", model_id, "_train=",
        jsonlite::toJSON(fs_train_payload, auto_unbox = TRUE)
      )
      msg_fs_test <- paste0(
        "figure3_features_shuffled_summary_", model_id, "_test=",
        jsonlite::toJSON(fs_test_payload, auto_unbox = TRUE)
      )

      log_info(msg_fs_train)
      log_info(msg_fs_test)
    } else {
      log_info(
        paste(
          "figure3_features_shuffled_summary_no_all_plates_rows model_id=",
          model_id
        )
      )
    }
  }

  # Hooks for other panels if you decide to turn them on later
  if (LOG_PR_CURVES && !is.null(panels$pr_curves)) {
    log_info(
      paste(
        "figure3_pr_curves_present model_id=",
        model_id,
        " (logging is disabled by config)"
      )
    )
  }

  if (LOG_ACCURACY && !is.null(panels$accuracy)) {
    log_info(
      paste(
        "figure3_accuracy_panel_present model_id=",
        model_id,
        " (logging is disabled by config)"
      )
    )
  }

  if (LOG_DERIV_PR && !is.null(panels$derivative_pr_curves)) {
    log_info(
      paste(
        "figure3_derivative_pr_panel_present model_id=",
        model_id,
        " (logging is disabled by config)"
      )
    )
  }

  if (LOG_COEFFICIENTS && !is.null(panels$coefficients)) {
    log_info(
      paste(
        "figure3_coefficients_panel_present model_id=",
        model_id,
        " (logging is disabled by config)"
      )
    )
  }
}

log_info("Finished summarizing panel_exports for all models in ensemble_main_figure_3.")


2025-12-06T11:06:36 [main_figure_3_12_06_08_07_ensemble_main_figure_3] INFO: panel_export_loaded model_id= logreg  file= panel_exports/figure3_panels_main_figure_3_12_06_08_07_logreg.json 


ERROR: Error in summarize_confusion_panel(panels$confusion_matrix): argument "split" is missing, with no default
