In [1]:
# -----------------------------------------------------------------------------
# Full Cohort Plotting - Poster-ready Refactor (Updated)
# -----------------------------------------------------------------------------

# === Packages ================================================================
suppressPackageStartupMessages({
  library(tidyverse)
  library(ggrepel)
  library(scales)
  library(cowplot)
})
has_patchwork <- requireNamespace("patchwork", quietly = TRUE)

# === Paths ===================================================================
# Project root; every other location in this cell is derived from it.
proj_path <- "/data/gusev/USERS/jpconnor/clinical_text_project/"

# Figure output root and the folder holding the survival-model result tables.
fig_path <- file.path(proj_path, "figures")
results_path <- file.path(proj_path, "data/survival_data/results/")

# Destination for the full-cohort poster figures; created (with parents) if
# it does not exist yet.
save_dir <- file.path(fig_path, "ASHG_figures", "full_cohort")
if (!dir.exists(save_dir)) {
  dir.create(save_dir, recursive = TRUE)
}

# === Poster theme & colors ===================================================
# Poster theme: black-and-white base (Arial, 12 pt) with grid lines removed,
# bold axis/legend/strip titles, white facet strips with a black border, and
# the legend placed above the panel.
theme_danafarber <-
  theme_bw(base_family = "Arial", base_size = 12) +
  theme(
    legend.position  = "top",
    legend.title     = element_text(face = "bold", size = 11),
    legend.text      = element_text(size = 10),
    axis.title       = element_text(face = "bold", size = 12),
    axis.text        = element_text(size = 10, color = "black"),
    strip.text       = element_text(face = "bold", size = 11),
    strip.background = element_rect(color = "black", fill = "white"),
    panel.grid       = element_blank()
  )

# Named colour vectors, keyed by the labels the plots map onto them:
#   model      - fill for the two compared models
#   delta      - fill keyed by the logical `delta > 0` ("TRUE" / "FALSE")
#   event_type - colour per event category
poster_palette <- list(
  model = c(
    "Type only"   = "#1b9e77",
    "Text + Type" = "#d95f02"
  ),
  delta = c(
    "TRUE"  = "#1b9e77",
    "FALSE" = "#d95f02"
  ),
  event_type = c(
    "Death"      = "#d95f02",
    "Metastases" = "#7570b3",
    "Phecodes"   = "gray50"
  )
)

# === Helper functions =======================================================
# Paired t-test p-value that never throws.
#
# x, y: numeric vectors of paired observations (element i of `x` is paired
#       with element i of `y`).
# Returns the two-sided p-value of `t.test(x, y, paired = TRUE)`, or NA_real_
# when the inputs cannot be paired (different lengths), fewer than two
# complete pairs remain, or `t.test` itself fails (e.g. constant differences).
safe_paired_t_test <- function(x, y) {
  # Pairing is positional, so vectors of different length cannot be paired.
  if (length(x) != length(y)) {
    return(NA_real_)
  }

  # Drop a pair when EITHER member is missing. Filtering x and y separately
  # would shift the remaining values against each other and compare
  # observations that do not belong together.
  complete <- !is.na(x) & !is.na(y)
  x <- x[complete]
  y <- y[complete]

  if (length(x) < 2) {
    return(NA_real_)
  }
  tryCatch(
    t.test(x, y, paired = TRUE)$p.value,
    error = function(e) NA_real_
  )
}

# Map p-values to significance stars.
#
# p: numeric vector of p-values (may contain NA).
# Returns a character vector of the same length:
#   p < 0.001 -> "***", p < 0.01 -> "**", p < 0.05 -> "*", otherwise "".
#   Missing p-values map to "".
p_to_sig <- function(p) {
  star_labels <- c("***", "**", "*", "")
  # findInterval() gives 0 below the first cut-off and increases by one at
  # each cut-off (intervals are closed on the left), so adding 1 indexes
  # straight into `star_labels`.
  bucket <- findInterval(p, c(0.001, 0.01, 0.05)) + 1L
  stars <- star_labels[bucket]
  stars[is.na(p)] <- ""
  stars
}

# Pick the rows that should receive a text label in a scatter plot.
#
# df:        data frame with an `event_type` column and the column named by
#            `delta_col`.
# delta_col: name (string) of the column used to rank events.
# Returns all "Death" rows followed by the three highest-`delta_col` rows of
# "Metastases" and of "Phecodes" (slice_max keeps ties by default, so a
# category can contribute more than three rows).
get_label_subset <- function(df, delta_col) {
  rank_by <- rlang::sym(delta_col)

  death_rows <- filter(df, event_type == "Death")
  met_rows <- slice_max(
    filter(df, event_type == "Metastases"),
    order_by = !!rank_by,
    n = 3
  )
  phecode_rows <- slice_max(
    filter(df, event_type == "Phecodes"),
    order_by = !!rank_by,
    n = 3
  )

  bind_rows(death_rows, met_rows, phecode_rows)
}

# === Load and prepare data ===================================================
# Per-event metrics for the full cohort, as written by the modelling step.
full_cohort_metrics_df <- read.csv(
  file.path(results_path, "full_cohort_metrics.csv")
)

# Classify each event and compute the gain from adding text to the type model.
# Events named "death" are deaths, events whose name ends in "M" are
# metastases, everything else is treated as a phecode.
df <- full_cohort_metrics_df %>%
  mutate(
    is_met   = stringr::str_ends(event, "M"),
    is_death = event == "death"
  ) %>%
  mutate(
    event_type = case_when(
      is_death ~ "Death",
      is_met   ~ "Metastases",
      TRUE     ~ "Phecodes"
    )
  ) %>%
  mutate(
    delta_c_index = text_plus_type_mean_c_index - type_mean_c_index
  )

# Common axis range for scatter plots: the span of both models' C-indices,
# widened by `pad` on each side.
pad <- 0.02
c_index_range <- range(
  df$type_mean_c_index,
  df$text_plus_type_mean_c_index,
  na.rm = TRUE
) + c(-pad, pad)

# === Summary bar plot with significance (adjusted stars) =====================
# Long format: one row per (event, model) with a shared `c_index` column.
summary_long <- df %>%
  select(event_type, type_mean_c_index, text_plus_type_mean_c_index) %>%
  pivot_longer(
    cols      = c(type_mean_c_index, text_plus_type_mean_c_index),
    names_to  = "model",
    values_to = "c_index"
  ) %>%
  mutate(
    model = recode(
      model,
      "type_mean_c_index"           = "Type only",
      "text_plus_type_mean_c_index" = "Text + Type"
    )
  )

# Paired test (type-only vs. text + type) within each event category,
# translated to significance stars.
pval_df <- summary_long %>%
  group_by(event_type) %>%
  summarise(
    p_value = safe_paired_t_test(
      c_index[model == "Type only"],
      c_index[model == "Text + Type"]
    ),
    .groups = "drop"
  ) %>%
  mutate(sig_label = p_to_sig(p_value))

# Bar heights (mean) and error bars (standard error over non-missing values).
summary_stats <- summary_long %>%
  group_by(event_type, model) %>%
  summarise(
    mean_c_index = mean(c_index, na.rm = TRUE),
    se_c_index   = sd(c_index, na.rm = TRUE) / sqrt(sum(!is.na(c_index))),
    .groups = "drop"
  )

# Star height per category: top of the tallest error bar plus 5% of the
# C-index axis span.
star_y <- summary_stats %>%
  group_by(event_type) %>%
  summarise(
    y_star = max(mean_c_index + se_c_index) +
      0.05 * diff(range(c_index_range)),
    .groups = "drop"
  )

# Fixed upper limit of the visible y-range (hand-tuned for the poster).
y_max <- 0.85

# One table carrying bar statistics, p-values/stars and star positions.
summary_df <- summary_stats %>%
  left_join(pval_df, by = "event_type") %>%
  left_join(star_y, by = "event_type")

# Dodged bars with SE whiskers; one star label per event category.
p_summary_bar <-
  ggplot(
    summary_df,
    aes(x = event_type, y = mean_c_index, fill = model)
  ) +
  geom_col(width = 0.7, position = position_dodge(0.8)) +
  geom_errorbar(
    aes(
      ymin = mean_c_index - se_c_index,
      ymax = mean_c_index + se_c_index
    ),
    position  = position_dodge(0.8),
    width     = 0.3,
    linewidth = 0.7
  ) +
  geom_text(
    data = distinct(summary_df, event_type, .keep_all = TRUE),
    aes(x = event_type, y = y_star, label = sig_label),
    size  = 5,
    vjust = 0
  ) +
  scale_fill_manual(values = poster_palette$model) +
  labs(x = NULL, y = "Mean C-Index ± SE", fill = "Model") +
  coord_cartesian(ylim = c(0.5, y_max)) +
  theme_danafarber

ggsave(
  file.path(save_dir, "summary_barplot_by_category.png"),
  p_summary_bar,
  width = 6, height = 3, dpi = 300
)

# === Δ C-Index distribution ================================================
# Spell out abbreviations in the event descriptions used as axis labels.
# The three substitutions are applied in order:
#   "met." (any case) -> "metastasis"
#   whole-word "VTE"  -> "Venous thromboembolism"
#   whole-word "GERD" -> "Gastroesophageal reflux disease"
df_delta <- df %>%
  mutate(
    event_descr = str_replace_all(event_descr, "(?i)met\\.", "metastasis"),
    event_descr = str_replace_all(event_descr, "\\bVTE\\b", "Venous thromboembolism"),
    event_descr = str_replace_all(event_descr, "\\bGERD\\b", "Gastroesophageal reflux disease")
  )

# Horizontal bars of the C-index change per event, ordered by that change and
# filled by whether adding text improved the model (delta > 0).
p_delta <-
  ggplot(
    df_delta,
    aes(
      x    = reorder(event_descr, delta_c_index),
      y    = delta_c_index,
      fill = delta_c_index > 0
    )
  ) +
  geom_col() +
  coord_flip() +
  scale_fill_manual(
    values = poster_palette$delta,
    labels = c("FALSE" = "Worse", "TRUE" = "Improved")
  ) +
  labs(x = NULL, y = "Δ C-Index", fill = "Text Effect") +
  theme_danafarber

ggsave(
  file.path(save_dir, "delta_c_index_distribution.png"),
  p_delta,
  width = 6, height = 8, dpi = 300
)

SyntaxError: invalid syntax. Perhaps you forgot a comma? (806721837.py, line 7)

In [None]:
# Display the full-cohort metrics table read from full_cohort_metrics.csv
# (interactive inspection only; nothing is modified).
full_cohort_metrics_df

In [3]:
# -----------------------------------------------------------------------------
# Stage Subset Plotting - Poster-ready Refactor
# -----------------------------------------------------------------------------

# === Packages ================================================================
suppressPackageStartupMessages({
  library(tidyverse)
  library(ggrepel)
  library(cowplot)
})
has_patchwork <- requireNamespace("patchwork", quietly = TRUE)

# === Paths ===================================================================
# Project root and the locations derived from it.
proj_path <- "/data/gusev/USERS/jpconnor/clinical_text_project/"
fig_path <- file.path(proj_path, "figures")
results_path <- file.path(proj_path, "data/survival_data/results/")

# Destination for the stage-subset figures; created (with parents) on demand.
save_dir_stage <- file.path(fig_path, "ASHG_figures", "stage_subset")
if (!dir.exists(save_dir_stage)) {
  dir.create(save_dir_stage, recursive = TRUE)
}

# Named colour vectors, keyed by the labels the plots map onto them. Note the
# singular event-type labels ("Metastasis", "Phecode") used in this cell.
poster_palette <- list(
  model = c(
    "Type only"   = "#1b9e77",
    "Text + Type" = "#d95f02"
  ),
  delta = c(
    "TRUE"  = "#1b9e77",
    "FALSE" = "#d95f02"
  ),
  event_type = c(
    "Death"      = "#d95f02",
    "Metastasis" = "#7570b3",
    "Phecode"    = "gray50"
  ),
  comparison = c(
    "Type → Text+Type"             = "#1b9e77",
    "Stage → Stage+Text"           = "#7570b3",
    "Stage-Type → Stage-Type+Text" = "#d95f02"
  )
)

# === Helper functions =======================================================
# Paired t-test p-value that never throws.
#
# x, y: numeric vectors of paired observations (element i of `x` is paired
#       with element i of `y`).
# Returns the two-sided p-value of `t.test(x, y, paired = TRUE)`, or NA_real_
# when the inputs cannot be paired (different lengths), fewer than two
# complete pairs remain, or `t.test` itself fails (e.g. constant differences).
safe_paired_t_test <- function(x, y) {
  # Pairing is positional, so vectors of different length cannot be paired.
  if (length(x) != length(y)) {
    return(NA_real_)
  }

  # Drop a pair when EITHER member is missing. Filtering x and y separately
  # would shift the remaining values against each other and compare
  # observations that do not belong together.
  complete <- !is.na(x) & !is.na(y)
  x <- x[complete]
  y <- y[complete]

  if (length(x) < 2) {
    return(NA_real_)
  }
  tryCatch(
    t.test(x, y, paired = TRUE)$p.value,
    error = function(e) NA_real_
  )
}

# Map p-values to significance stars.
#
# p: numeric vector of p-values (may contain NA).
# Returns a character vector of the same length:
#   p < 0.001 -> "***", p < 0.01 -> "**", p < 0.05 -> "*", otherwise "".
#   Missing p-values map to "".
p_to_sig <- function(p) {
  star_labels <- c("***", "**", "*", "")
  # findInterval() gives 0 below the first cut-off and increases by one at
  # each cut-off (intervals are closed on the left), so adding 1 indexes
  # straight into `star_labels`.
  bucket <- findInterval(p, c(0.001, 0.01, 0.05)) + 1L
  stars <- star_labels[bucket]
  stars[is.na(p)] <- ""
  stars
}

# Pick the rows that should receive a text label in a scatter plot.
#
# df:        data frame with an `event_type` column and the column named by
#            `delta_col`.
# delta_col: name (string) of the column used to rank events.
# Returns the three highest-`delta_col` rows of "Metastasis" followed by those
# of "Phecode" (slice_max keeps ties by default, so a category can contribute
# more than three rows). "Death" rows are not labelled in this version.
get_label_subset <- function(df, delta_col) {
  rank_by <- rlang::sym(delta_col)

  met_rows <- slice_max(
    filter(df, event_type == "Metastasis"),
    order_by = !!rank_by,
    n = 3
  )
  phecode_rows <- slice_max(
    filter(df, event_type == "Phecode"),
    order_by = !!rank_by,
    n = 3
  )

  bind_rows(met_rows, phecode_rows)
}

# === Load & prepare stage subset ============================================
# Per-event metrics for the stage subset.
stage_subset_metrics_df <- read.csv(
  file.path(results_path, "stage_subset_metrics.csv")
)

# Classify each event ("death" -> Death, name ending in "M" -> Metastasis,
# otherwise Phecode) and compute the C-index gain from adding text to each of
# the three base models (type, stage, stage + type).
df_stage <- stage_subset_metrics_df %>%
  mutate(
    is_met   = stringr::str_ends(event, "M"),
    is_death = event == "death"
  ) %>%
  mutate(
    event_type = case_when(
      is_death ~ "Death",
      is_met   ~ "Metastasis",
      TRUE     ~ "Phecode"
    )
  ) %>%
  mutate(
    delta_c_index_type       = text_plus_type_mean_c_index - type_mean_c_index,
    delta_c_index_stage      = text_plus_stage_mean_c_index - stage_mean_c_index,
    delta_c_index_stage_type = text_plus_stage_type_mean_c_index - stage_type_mean_c_index
  )

# Shared axis range for the scatter plots: span of every selected metric
# column (base and text-augmented), widened by `pad` on each side.
pad <- 0.02
c_index_range_stage <- range(
  unlist(
    select(
      df_stage,
      starts_with("type_mean"),
      starts_with("stage_mean"),
      starts_with("stage_type_mean"),
      starts_with("text_plus")
    )
  ),
  na.rm = TRUE
) + c(-pad, pad)

# === Summary bar plot ========================================================
# Mean and standard error of the C-index per event category and model
# (type-only vs. text + type), computed over non-missing values.
summary_df_stage <- df_stage %>%
  select(event_type, type_mean_c_index, text_plus_type_mean_c_index) %>%
  pivot_longer(
    cols      = c(type_mean_c_index, text_plus_type_mean_c_index),
    names_to  = "model",
    values_to = "c_index"
  ) %>%
  mutate(
    model = recode(
      model,
      "type_mean_c_index"           = "Type only",
      "text_plus_type_mean_c_index" = "Text + Type"
    )
  ) %>%
  group_by(event_type, model) %>%
  summarise(
    mean_c_index = mean(c_index, na.rm = TRUE),
    se_c_index   = sd(c_index, na.rm = TRUE) / sqrt(sum(!is.na(c_index))),
    .groups = "drop"
  )

# Dodged bars with SE whiskers; y-axis view limited to [0.5, 1.0].
p_summary_bar_stage <-
  ggplot(
    summary_df_stage,
    aes(x = event_type, y = mean_c_index, fill = model)
  ) +
  geom_col(width = 0.7, position = position_dodge(0.8)) +
  geom_errorbar(
    aes(
      ymin = mean_c_index - se_c_index,
      ymax = mean_c_index + se_c_index
    ),
    width     = 0.3,
    linewidth = 0.7,
    position  = position_dodge(0.8)
  ) +
  scale_fill_manual(values = poster_palette$model) +
  labs(x = "Event Category", y = "Mean C-Index ± SE", fill = "Model") +
  coord_cartesian(ylim = c(0.5, 1.0)) +
  theme_danafarber

ggsave(
  file.path(save_dir_stage, "summary_barplot_by_category.png"),
  p_summary_bar_stage,
  width = 6, height = 4, dpi = 300
)

# === Δ C-Index distribution ==================================================
# Horizontal bars of the type-model C-index change per event, ordered by that
# change and filled by whether adding text improved the model (delta > 0).
p_delta_stage <-
  ggplot(
    df_stage,
    aes(
      x    = reorder(event_descr, delta_c_index_type),
      y    = delta_c_index_type,
      fill = delta_c_index_type > 0
    )
  ) +
  geom_col() +
  coord_flip() +
  scale_fill_manual(
    values = poster_palette$delta,
    labels = c("FALSE" = "Worse", "TRUE" = "Improved")
  ) +
  labs(x = "Event", y = "Δ C-Index", fill = "Text Effect") +
  theme_danafarber

ggsave(
  file.path(save_dir_stage, "delta_c_index_distribution.png"),
  p_delta_stage,
  width = 8, height = 10, dpi = 300
)

# === Scatterplots ============================================================
# Build one "base model vs. base + text" scatter panel.
#
# data:      data frame holding the metric columns, `event_type`, `event_descr`.
# x_col:     name of the base-model C-index column (x axis).
# y_col:     name of the text-augmented C-index column (y axis).
# delta_col: name of the delta column used to choose which points get labels.
# x_lab, y_lab: axis titles.
# Returns a ggplot object. Reads `poster_palette`, `c_index_range_stage`,
# `theme_danafarber` and `get_label_subset()` from the calling environment.
make_scatter_panel <- function(data, x_col, y_col, delta_col, x_lab, y_lab) {
  ggplot(data, aes(x = .data[[x_col]], y = .data[[y_col]])) +
    # Identity line: points above it are events where text helped.
    geom_abline(linetype = "dashed", color = "gray60") +
    geom_point(aes(color = event_type), size = 2, alpha = 0.8) +
    geom_text_repel(
      data = get_label_subset(data, delta_col),
      aes(label = event_descr),
      size = 3,
      max.overlaps = Inf,      # never drop a label because of overlaps
      box.padding = 0.5,
      point.padding = 0.2,
      min.segment.length = 0,  # draw a connector for every label, however short
      segment.size = 0.2,
      segment.alpha = 0.6,
      force = 2,               # stronger repulsion between labels
      force_pull = 0.5
    ) +
    scale_color_manual(values = poster_palette$event_type) +
    labs(x = x_lab, y = y_lab, color = "Event Type") +
    # Same range on both axes so the identity line is the diagonal.
    coord_equal(xlim = c_index_range_stage, ylim = c_index_range_stage) +
    theme_danafarber + theme(legend.position = "bottom")
}

# One entry per comparison; list names become the names in `plots` and the
# file-name suffix used by the fallback branch below.
panel_specs <- list(
  stage = list(
    x = "stage_mean_c_index", y = "text_plus_stage_mean_c_index",
    delta = "delta_c_index_stage",
    x_lab = "Stage Model C-Index", y_lab = "Stage + Text C-Index"
  ),
  type = list(
    x = "type_mean_c_index", y = "text_plus_type_mean_c_index",
    delta = "delta_c_index_type",
    x_lab = "Type Model C-Index", y_lab = "Type + Text C-Index"
  ),
  stage_type = list(
    x = "stage_type_mean_c_index", y = "text_plus_stage_type_mean_c_index",
    delta = "delta_c_index_stage_type",
    x_lab = "Stage-Type Model C-Index", y_lab = "Text + Stage-Type C-Index"
  )
)

# Build each panel only when both of its metric columns are present.
plots <- list()
for (panel_name in names(panel_specs)) {
  spec <- panel_specs[[panel_name]]
  if (all(c(spec$x, spec$y) %in% colnames(df_stage))) {
    plots[[panel_name]] <- make_scatter_panel(
      df_stage, spec$x, spec$y, spec$delta, spec$x_lab, spec$y_lab
    )
  }
}

if (length(plots) == 3 && has_patchwork) {
  # Three panels side by side with a single collected legend at the bottom.
  # No trailing comma inside theme(): an empty trailing argument is rejected
  # by some ggplot2 versions.
  combined_plot <- (plots$stage | plots$type | plots$stage_type) +
    patchwork::plot_layout(
      guides = "collect",
      widths = c(1, 1, 1),
      design = NULL
    ) &
    theme(
      legend.position = "bottom",
      plot.margin = margin(0, 0, 0, 0),   # no outer margins per subplot
      panel.spacing = unit(0, "lines")    # no gap between panels
    )

  # Remove the margin patchwork adds around the assembled figure.
  combined_plot <- combined_plot +
    patchwork::plot_annotation(theme = theme(plot.margin = margin(0, 0, 0, 0)))

  ggsave(
    file.path(save_dir_stage, "scatter_three_panel_shared_legend.png"),
    combined_plot,
    width = 15,
    height = 5,
    dpi = 300,
    bg = "white"
  )
} else {
  # Fallback (patchwork missing or a panel could not be built): save each
  # available panel to its own file. walk2() is used because only the side
  # effect matters, so no list of results is returned or printed.
  walk2(
    plots, names(plots),
    ~ ggsave(
      file.path(save_dir_stage, paste0("scatter_", .y, ".png")),
      .x,
      width = 6,
      height = 6,
      dpi = 300,
      bg = "white"
    )
  )
}

# === Δ C-Index density across comparisons ===================================
# Stack the three delta columns into long form, one row per
# (event, comparison), with readable comparison labels.
delta_df <- df_stage %>%
  select(
    event_descr,
    event_type,
    delta_c_index_type,
    delta_c_index_stage,
    delta_c_index_stage_type
  ) %>%
  pivot_longer(
    cols      = starts_with("delta_c_index"),
    names_to  = "comparison",
    values_to = "delta_c_index"
  ) %>%
  mutate(
    comparison = recode(
      comparison,
      delta_c_index_type       = "Type → Text+Type",
      delta_c_index_stage      = "Stage → Stage+Text",
      delta_c_index_stage_type = "Stage-Type → Stage-Type+Text"
    )
  )

# Overlaid densities of the C-index change, one per comparison, with a dashed
# reference line at zero (no change).
p_delta_density <-
  ggplot(delta_df, aes(x = delta_c_index, fill = comparison)) +
  geom_density(alpha = 0.5) +
  geom_vline(xintercept = 0, color = "gray50", linetype = "dashed") +
  scale_fill_manual(values = poster_palette$comparison) +
  labs(
    x    = "Δ C-Index (Text Model − Base Model)",
    y    = "Density",
    fill = "Model Comparison"
  ) +
  theme_danafarber

ggsave(
  file.path(save_dir_stage, "delta_c_index_density.png"),
  p_delta_density,
  width = 8, height = 5, dpi = 300
)

In [4]:
# -----------------------------------------------------------------------------
# Held-out C-Index by Cancer Type (Pan vs Within) - Dana-Farber Theme
# -----------------------------------------------------------------------------

# === Packages ================================================================
suppressPackageStartupMessages({
  library(tidyverse)
  library(ggrepel)
  library(cowplot)
})
has_patchwork <- requireNamespace("patchwork", quietly = TRUE)

# === Paths ===================================================================
# Project root, the output folder for this comparison's figures, and the
# folder holding its result tables.
proj_path <- "/data/gusev/USERS/jpconnor/clinical_text_project/"
fig_path <- file.path(proj_path, "figures/ASHG_figures/pan_vs_within_type/")
results_path <- file.path(proj_path, "data/survival_data/results/pan_vs_within_cancer")

# Create the figure folder (with parents) when it is missing.
if (!dir.exists(fig_path)) {
  dir.create(fig_path, recursive = TRUE)
}

# === Dana-Farber Theme & Palette ===========================================
# Poster theme for this cell: black-and-white base (Helvetica, 14 pt) with
# grid lines removed, bold axis/legend/strip titles, white facet strips with
# a black border, and the legend placed above the panel.
theme_danafarber <-
  theme_bw(base_family = "Helvetica", base_size = 14) +
  theme(
    legend.position  = "top",
    legend.title     = element_text(face = "bold", size = 13),
    legend.text      = element_text(size = 12),
    axis.title       = element_text(face = "bold", size = 14),
    axis.text        = element_text(size = 12, color = "black"),
    strip.text       = element_text(face = "bold", size = 13),
    strip.background = element_rect(color = "black", fill = "white"),
    panel.grid       = element_blank()
  )

# Fill colours keyed by training-scheme label.
df_palette <- c(
  "Pan-cancer"    = "#d95f02",  # Dana-Farber orange
  "Within-cancer" = "#7570b3"   # Dana-Farber purple
)

# === Load & Prepare Data ====================================================
# Held-out C-index per cancer type for the pan-cancer and within-cancer models.
held_out_metrics_df <- read.csv(file.path(results_path, "cindex_by_cancer_type.csv"))

# The columns are renamed by position, so fail loudly if the file does not
# have exactly the expected number of columns instead of mislabelling them.
expected_cols <- c("cancer_type", "c_index_pan", "c_index_within", "delta_c_index", "num_held_out")
stopifnot(ncol(held_out_metrics_df) == length(expected_cols))
colnames(held_out_metrics_df) <- expected_cols

# Display labels: underscores -> spaces, then sentence case (only the first
# letter capitalised). Computed once so the factor values and the factor
# levels can never drift apart.
pretty_cancer_type <- held_out_metrics_df$cancer_type %>%
  str_replace_all("_", " ") %>%
  str_to_lower() %>%
  str_to_sentence()

# Bar order: descending within-cancer C-index. unique() guards against a
# duplicated label, which factor() would reject as a duplicated level.
within_order <- order(held_out_metrics_df$c_index_within, decreasing = TRUE)
cancer_type_levels <- unique(pretty_cancer_type[within_order])

# Long format: one row per (cancer type, training scheme).
plot_df <- held_out_metrics_df %>%
  mutate(
    cancer_type = factor(pretty_cancer_type, levels = cancer_type_levels)
  ) %>%
  pivot_longer(
    cols = c("c_index_pan", "c_index_within"),
    names_to = "training_scheme",
    values_to = "c_index"
  ) %>%
  mutate(
    training_scheme = recode(
      training_scheme,
      "c_index_pan" = "Pan-cancer",
      "c_index_within" = "Within-cancer"
    )
  )

# === Plot ===================================================================
# Dodged bars of held-out C-index per cancer type, one bar per training
# scheme; x labels are rotated 45 degrees.
p_cindex_bar <-
  ggplot(
    plot_df,
    aes(x = cancer_type, y = c_index, fill = training_scheme)
  ) +
  geom_col(width = 0.7, position = position_dodge(width = 0.8)) +
  scale_fill_manual(values = df_palette) +
  labs(x = NULL, y = "C-Index", fill = "Training Scheme") +
  theme_danafarber +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1)
  )

ggsave(
  file.path(fig_path, "bar_cindex_by_cancer_type_pan_vs_within.png"),
  p_cindex_bar,
  width = 6, height = 6, dpi = 300
)


In [5]:
# -----------------------------------------------------------------------------
# Cleaned & refactored plotting script with Dana-Farber Theme
# -----------------------------------------------------------------------------

# === Packages ================================================================
suppressPackageStartupMessages({
  library(tidyverse)
  library(ggrepel)
  library(scales)
  library(survminer)
  library(survcomp)   # for concordance.index used later
  library(cowplot)
  library(progress)
})
has_patchwork <- requireNamespace("patchwork", quietly = TRUE)

# === Paths ===================================================================
# Project root and the directory tree derived from it:
#   figures/                      -> fig_path
#   data/                         -> data_path
#   data/survival_data/           -> surv_path
#   data/survival_data/results/   -> results_path
proj_path <- "/data/gusev/USERS/jpconnor/clinical_text_project/"
fig_path <- file.path(proj_path, "figures")
data_path <- file.path(proj_path, "data")
surv_path <- file.path(data_path, "survival_data")
results_path <- file.path(surv_path, "results/")

# Create the figure root (with parents) when it is missing.
if (!dir.exists(fig_path)) {
  dir.create(fig_path, recursive = TRUE)
}

# -----------------------------------------------------------------------------
# KM curves + c-index for a single event (held-out risk scores)
# -----------------------------------------------------------------------------
# Input locations for the held-out risk scores and the output folder for the
# Kaplan-Meier figures.
data_path2 <- "/data/gusev/USERS/jpconnor/clinical_text_project/data/survival_data/"
pred_path <- file.path(data_path2, "results/phecode_held_out_risk_scores/full_cohort/")
figure_path <- file.path(proj_path, "figures/ASHG_figures/full_cohort/")

# The KM figures are written with png(), which needs the folder to exist.
if (!dir.exists(figure_path)) {
  dir.create(figure_path, recursive = TRUE)
}

# check.names = FALSE keeps the CSV's column names verbatim, so the event
# name can be used directly as a column name below.
tt_phecode_df <- read.csv(
  file.path(data_path2, "time-to-phecode/tt_vte_plus_phecodes.csv"),
  check.names = FALSE
)

# Event to analyse: `event` is the status column, `tt_<event>` the time column.
event <- "death"
tt_event <- paste0("tt_", event)

# Read predictions and merge with times/status. This is straight-line code
# (no surrounding loop), so a missing file is a hard error rather than a
# `next`, which is only valid inside a loop.
pred_file <- file.path(pred_path, event, "held_out_risk_predictions.csv")
if (!file.exists(pred_file)) {
  stop(
    "Predictions file missing for event '", event, "': ", pred_file,
    call. = FALSE
  )
}

# Drop incomplete prediction rows, then attach status and time per patient.
df_pred <- read.csv(pred_file) %>%
  na.omit() %>%
  merge(tt_phecode_df[, c("DFCI_MRN", event, tt_event)], by = "DFCI_MRN")

# Survival time & status as free-standing vectors; the survfit() formulas and
# concordance.index() calls below refer to them by name.
time <- df_pred[[tt_event]]
status <- df_pred[[event]]

# Risk quartiles (Q1 = lowest score) for the base and the text model.
quartile_labels <- c("Q1 (Lowest risk)", "Q2", "Q3", "Q4 (Highest risk)")
df_pred <- df_pred %>%
  mutate(
    base_risk_quartile = factor(ntile(base_risk_score, 4), labels = quartile_labels),
    text_risk_quartile = factor(ntile(text_risk_score, 4), labels = quartile_labels)
  )

# Fit KM curves per quartile.
base_fit <- survfit(Surv(time, status) ~ base_risk_quartile, data = df_pred)
text_fit <- survfit(Surv(time, status) ~ text_risk_quartile, data = df_pred)

# C-index of each continuous risk score.
base_cindex <- concordance.index(df_pred$base_risk_score, surv.time = time, surv.event = status)$c.index
text_cindex <- concordance.index(df_pred$text_risk_score, surv.time = time, surv.event = status)$c.index

# Shared y-axis limits for both KM plots: from just below the lowest
# survival estimate (never below 0) up to 1.
min_surv <- min(c(base_fit$surv, text_fit$surv))
ylim_vals <- c(max(0, min_surv - 0.02), 1)

# Draw a Kaplan-Meier plot for one set of risk quartiles.
#
# fit:          survfit object to plot.
# data:         data frame the fit was built from.
# quartile_col: name of the quartile factor column in `data`; its levels are
#               used as legend labels.
# title_prefix: accepted for call compatibility; currently not used.
# cindex:       accepted for call compatibility; currently not used.
# legend_title: legend title.
# Returns the ggsurvplot object. Warnings raised while building the plot are
# suppressed. Reads `theme_danafarber` and `ylim_vals` from the calling
# environment.
make_plot <- function(fit, data, quartile_col, title_prefix, cindex, legend_title) {
  suppressWarnings(
    ggsurvplot(
      fit,
      data = data,
      risk.table = FALSE,
      pval = TRUE,
      conf.int = TRUE,
      palette = "Dark2",
      legend.title = legend_title,
      legend.labs = levels(data[[quartile_col]]),
      xlab = "Time (days)",
      ylab = "Survival Probability (Death)",
      ggtheme = theme_danafarber,
      # Last argument deliberately has no trailing comma: an empty trailing
      # argument would be forwarded into ggsurvplot's `...`.
      ylim = ylim_vals
    )
  )
}

# KM plot for the base-model quartiles.
base_plot <- make_plot(
  fit          = base_fit,
  data         = df_pred,
  quartile_col = "base_risk_quartile",
  title_prefix = "Base Model",
  cindex       = base_cindex,
  legend_title = "Base Risk Quartile"
)

# KM plot for the text-model quartiles.
text_plot <- make_plot(
  fit          = text_fit,
  data         = df_pred,
  quartile_col = "text_risk_quartile",
  title_prefix = "Text Model",
  cindex       = text_cindex,
  legend_title = "Type + Text Risk Quartile"
)

# Both plots side by side (one row, two columns), built without printing.
combined_plot <- suppressWarnings(
  arrange_ggsurvplots(
    list(base_plot, text_plot),
    print = FALSE,
    ncol = 2,
    nrow = 1
  )
)

# Save
# Render a plot object to a 300-dpi PNG by printing it onto a png() device.
#
# plot_obj:  object whose print method draws the plot.
# file_name: output path.
# width, height: size in inches.
# Returns the value of dev.off() (the device that becomes current after the
# PNG device is closed). Warnings raised while printing are suppressed.
save_ggsurv <- function(plot_obj, file_name, width = 8, height = 6) {
  png(filename = file_name, width = width, height = height, units = "in", res = 300)
  png_dev <- dev.cur()
  # If printing fails, still close the device opened above so it does not
  # stay open and capture later plots. After a normal close this is a no-op.
  on.exit(if (png_dev %in% dev.list()) dev.off(png_dev), add = TRUE)

  suppressWarnings(print(plot_obj))
  dev.off(png_dev)
}

# Write the two single-model KM plots (default 8 x 6 in) and the combined
# figure (16 x 12 in) into `figure_path`; file names share a per-event prefix.
km_prefix <- paste0("km_", event, "_quartiles_")
save_ggsurv(
  base_plot,
  file.path(figure_path, paste0(km_prefix, "base_model.png"))
)
save_ggsurv(
  text_plot,
  file.path(figure_path, paste0(km_prefix, "text_model.png"))
)
save_ggsurv(
  combined_plot,
  file.path(figure_path, paste0(km_prefix, "combined_2x2.png")),
  width = 16,
  height = 12
)

In [6]:
# -----------------------------------------------------------------------------
# Cleaned & refactored plotting script with Dana-Farber Theme
# -----------------------------------------------------------------------------

# === Packages ================================================================
suppressPackageStartupMessages({
  library(tidyverse)
  library(ggrepel)
  library(scales)
  library(survminer)
  library(survcomp)   # for concordance.index used later
  library(cowplot)
  library(progress)
})
has_patchwork <- requireNamespace("patchwork", quietly = TRUE)

# === Paths ===================================================================
# Project root and the directory tree derived from it:
#   figures/                      -> fig_path
#   data/                         -> data_path
#   data/survival_data/           -> surv_path
#   data/survival_data/results/   -> results_path
proj_path <- "/data/gusev/USERS/jpconnor/clinical_text_project/"
fig_path <- file.path(proj_path, "figures")
data_path <- file.path(proj_path, "data")
surv_path <- file.path(data_path, "survival_data")
results_path <- file.path(surv_path, "results/")

# Create the figure root (with parents) when it is missing.
if (!dir.exists(fig_path)) {
  dir.create(fig_path, recursive = TRUE)
}

# -----------------------------------------------------------------------------
# KM curves + c-index for a single event (held-out risk scores)
# -----------------------------------------------------------------------------
# Input locations for the held-out risk scores and the output folder for the
# Kaplan-Meier figures.
data_path2 <- "/data/gusev/USERS/jpconnor/clinical_text_project/data/survival_data/"
pred_path <- file.path(data_path2, "results/phecode_held_out_risk_scores/full_cohort/")
figure_path <- file.path(proj_path, "figures/ASHG_figures/full_cohort/")

# The KM figures are written with png(), which needs the folder to exist.
if (!dir.exists(figure_path)) {
  dir.create(figure_path, recursive = TRUE)
}

# check.names = FALSE keeps the CSV's column names verbatim, so a numeric
# event code such as "296.22" can be used directly as a column name below.
tt_phecode_df <- read.csv(
  file.path(data_path2, "time-to-phecode/tt_vte_plus_phecodes.csv"),
  check.names = FALSE
)

# Event to analyse: `event` is the status column, `tt_<event>` the time column.
event <- "296.22"
tt_event <- paste0("tt_", event)

# Read predictions and merge with times/status. This is straight-line code
# (no surrounding loop), so a missing file is a hard error rather than a
# `next`, which is only valid inside a loop.
pred_file <- file.path(pred_path, event, "held_out_risk_predictions.csv")
if (!file.exists(pred_file)) {
  stop(
    "Predictions file missing for event '", event, "': ", pred_file,
    call. = FALSE
  )
}

# Drop incomplete prediction rows, then attach status and time per patient.
df_pred <- read.csv(pred_file) %>%
  na.omit() %>%
  merge(tt_phecode_df[, c("DFCI_MRN", event, tt_event)], by = "DFCI_MRN")

# Survival time & status as free-standing vectors; the survfit() formulas and
# concordance.index() calls below refer to them by name.
time <- df_pred[[tt_event]]
status <- df_pred[[event]]

# Risk quartiles (Q1 = lowest score) for the base and the text model.
quartile_labels <- c("Q1 (Lowest risk)", "Q2", "Q3", "Q4 (Highest risk)")
df_pred <- df_pred %>%
  mutate(
    base_risk_quartile = factor(ntile(base_risk_score, 4), labels = quartile_labels),
    text_risk_quartile = factor(ntile(text_risk_score, 4), labels = quartile_labels)
  )

# Fit KM curves per quartile.
base_fit <- survfit(Surv(time, status) ~ base_risk_quartile, data = df_pred)
text_fit <- survfit(Surv(time, status) ~ text_risk_quartile, data = df_pred)

# C-index of each continuous risk score.
base_cindex <- concordance.index(df_pred$base_risk_score, surv.time = time, surv.event = status)$c.index
text_cindex <- concordance.index(df_pred$text_risk_score, surv.time = time, surv.event = status)$c.index

# Shared y-axis limits for both KM plots: from just below the lowest
# survival estimate (never below 0) up to 1.
min_surv <- min(c(base_fit$surv, text_fit$surv))
ylim_vals <- c(max(0, min_surv - 0.02), 1)

# Draw a Kaplan-Meier plot for one set of risk quartiles.
#
# fit:          survfit object to plot.
# data:         data frame the fit was built from.
# quartile_col: name of the quartile factor column in `data`; its levels are
#               used as legend labels.
# title_prefix: accepted for call compatibility; currently not used.
# cindex:       accepted for call compatibility; currently not used.
# legend_title: legend title.
# Returns the ggsurvplot object. Warnings raised while building the plot are
# suppressed. Reads `theme_danafarber` and `ylim_vals` from the calling
# environment.
make_plot <- function(fit, data, quartile_col, title_prefix, cindex, legend_title) {
  suppressWarnings(
    ggsurvplot(
      fit,
      data = data,
      risk.table = FALSE,
      pval = TRUE,
      conf.int = TRUE,
      palette = "Dark2",
      legend.title = legend_title,
      legend.labs = levels(data[[quartile_col]]),
      xlab = "Time (days)",
      ylab = "Survival Probability (Major Depressive Disorder)",
      ggtheme = theme_danafarber,
      # Last argument deliberately has no trailing comma: an empty trailing
      # argument would be forwarded into ggsurvplot's `...`.
      ylim = ylim_vals
    )
  )
}

# KM plot for the base-model quartiles.
base_plot <- make_plot(
  fit          = base_fit,
  data         = df_pred,
  quartile_col = "base_risk_quartile",
  title_prefix = "Base Model",
  cindex       = base_cindex,
  legend_title = "Base Risk Quartile"
)

# KM plot for the text-model quartiles.
text_plot <- make_plot(
  fit          = text_fit,
  data         = df_pred,
  quartile_col = "text_risk_quartile",
  title_prefix = "Text Model",
  cindex       = text_cindex,
  legend_title = "Type + Text Risk Quartile"
)

# Both plots side by side (one row, two columns), built without printing.
combined_plot <- suppressWarnings(
  arrange_ggsurvplots(
    list(base_plot, text_plot),
    print = FALSE,
    ncol = 2,
    nrow = 1
  )
)

# Save
# Render a plot object to a 300-dpi PNG by printing it onto a png() device.
#
# plot_obj:  object whose print method draws the plot.
# file_name: output path.
# width, height: size in inches.
# Returns the value of dev.off() (the device that becomes current after the
# PNG device is closed). Warnings raised while printing are suppressed.
save_ggsurv <- function(plot_obj, file_name, width = 8, height = 6) {
  png(filename = file_name, width = width, height = height, units = "in", res = 300)
  png_dev <- dev.cur()
  # If printing fails, still close the device opened above so it does not
  # stay open and capture later plots. After a normal close this is a no-op.
  on.exit(if (png_dev %in% dev.list()) dev.off(png_dev), add = TRUE)

  suppressWarnings(print(plot_obj))
  dev.off(png_dev)
}

# Write the two single-model KM plots (default 8 x 6 in) and the combined
# figure (16 x 12 in) into `figure_path`; file names share a per-event prefix.
km_prefix <- paste0("km_", event, "_quartiles_")
save_ggsurv(
  base_plot,
  file.path(figure_path, paste0(km_prefix, "base_model.png"))
)
save_ggsurv(
  text_plot,
  file.path(figure_path, paste0(km_prefix, "text_model.png"))
)
save_ggsurv(
  combined_plot,
  file.path(figure_path, paste0(km_prefix, "combined_2x2.png")),
  width = 16,
  height = 12
)