# Readme

This script contains the code used to generate the figures described in the manuscript titled 'Benchmarking algorithms for joint integration of unpaired and paired single-cell RNA-seq and ATAC-seq data'.

Please run the steps described in evaluate_vary_situations_public.ipynb before running this script.

For scenarios 1–3, one summary plot is generated for each challenge. For scenarios 4 and 5, multiple plots may be generated, depending on the number of missing cell types.

Additional plots are generated using umap_plot_generation_public.ipynb and hpap_result_plot_public.ipynb.


# Functions

In [None]:
# Run with an R kernel
library(dplyr)
library(ggplot2)
library(tidyverse)
library(ComplexHeatmap)
library(gridExtra)

# Plot each metric against `var` (one point per row, plus a line through the
# per-method mean at each x value), coloured by method, and write the plots
# to a PDF at paste0(folder_path, file_name).
#
# Arguments:
#   res_df       data.frame with a `method` column, the x column named by
#                `var`, and one column per metric in ct_metric/batch_metric.
#   folder_path  output folder; joined to file_name with paste0(), so it
#                should end with "/".
#   file_name    PDF file name.
#   ct_metric    cell-type annotation metric columns (arranged together).
#   batch_metric batch-removal metric columns (arranged together, after the
#                cell-type plots).
#   var          name of the x-axis column.
#   x_label      x-axis label; when "" the value of `var` is used.
# Returns list(ct_glist, batch_glist): the ggplot objects, named by metric.
plot_metric <- function(res_df, folder_path,file_name,ct_metric=c('ARI','NMI','ct_aws','clisi'),
                        batch_metric=c('b_saw','kbet','ilisi','gconn'),var='var',x_label=""){
    if(x_label==""){
        x_label <- var
    }
    pdf(paste0(folder_path,file_name))
    # Close the PDF device even when building or drawing a plot fails;
    # otherwise the device stays open and captures later plots.
    on.exit(dev.off(), add = TRUE)

    # Scatter of one metric against `var` with a per-method mean line.
    make_panel <- function(col){
        ggplot(res_df,aes_string(var,col,color="method")) +
            geom_point() +
            stat_summary(aes(group = method),fun=mean, geom="line") +
            xlab(x_label)
    }

    # ct metrics
    print("===== cell type annotation metrics =====")
    ct_glist <- list()
    for (col in ct_metric){
        ct_glist[[col]] <- make_panel(col)
    }
    gridExtra::grid.arrange(grobs=ct_glist)

    # batch metrics
    batch_glist <- list()
    for (col in batch_metric){
        batch_glist[[col]] <- make_panel(col)
    }
    print("===== batch removal metrics =====")
    gridExtra::grid.arrange(grobs=batch_glist)
    return(list(ct_glist,batch_glist))
}



# Draw one boxplot per metric (x = method, filled by method) and write the
# plots to a PDF at paste0(folder_path, file_name).
#
# Arguments:
#   res_df       data.frame with a `method` column and one column per metric
#                in ct_metric/batch_metric.
#   folder_path  output folder; joined to file_name with paste0(), so it
#                should end with "/".
#   file_name    PDF file name.
#   ct_metric    cell-type annotation metric columns (arranged together).
#   batch_metric batch-removal metric columns (arranged together, after the
#                cell-type plots).
#   var          not used; kept so the signature matches plot_metric().
#   title        centred title added to every plot.
# Returns list(ct_glist, batch_glist): the ggplot objects, named by metric.
plot_metric_boxplot <- function(res_df, folder_path,file_name,ct_metric=c('ARI','NMI','ct_aws','clisi'),
                        batch_metric=c('b_saw','kbet','ilisi','gconn'),var='var',title=""){
    pdf(paste0(folder_path,file_name))
    # Close the PDF device even when building or drawing a plot fails.
    on.exit(dev.off(), add = TRUE)

    # Boxplot of one metric per method, with rotated method labels.
    make_box <- function(col){
        ggplot(res_df,aes_string(x="method",y=col,fill="method")) +
            geom_boxplot() +
            ggtitle(title) +
            theme(plot.title = element_text(hjust = 0.5),
                  axis.text.x = element_text(angle = 90, vjust = 0.5, hjust=1))
    }

    # ct metrics
    print("===== cell type annotation metrics =====")
    ct_glist <- list()
    for (col in ct_metric){
        ct_glist[[col]] <- make_box(col)
    }
    gridExtra::grid.arrange(grobs=ct_glist)

    # batch metrics
    batch_glist <- list()
    for (col in batch_metric){
        batch_glist[[col]] <- make_box(col)
    }
    print("===== batch removal metrics =====")
    gridExtra::grid.arrange(grobs=batch_glist)
    return(list(ct_glist,batch_glist))
}



# Collect the per-run evaluation results found under `dir_path` into one
# data.frame.
#
# Each relative path returned by list.files() is split on "/" and indexed as:
#   elements[1] - run folder. The text before its first "_" gives `orig`
#                 (digits removed) and `var` (letters removed); the text after
#                 its last "_" gives `rep`.
#   elements[2] - result output folder (stored as `out`)
#   elements[3] - method (metric and prediction files)
#   elements[4] - metric file name (stored as `ct_ref`); for runtime files,
#                 the text before its first "_" is used as the method.
# NOTE(review): this layout is read off the indexing below; confirm it against
# the folders written by evaluate_vary_situations_public.ipynb.
#
# Arguments:
#   dir_path          root folder, searched recursively.
#   res_folder_filter optional character vector; if given, only rows whose
#                     `out` is in it are returned.
#   runtime_exists    if TRUE, read "*runtime.txt" files and join a numeric
#                     `runtime` column.
#   prediction_exists if TRUE, read "*prediction_eval.txt" files (tab-separated,
#                     header-less name/value pairs) and join their columns;
#                     percent_recovered_50kb and f1 are converted to numeric.
#   metric_file       regular expression passed to list.files() to find the
#                     metric CSV files.
# Returns the joined data.frame, with `var` converted to integer.
plot_result <- function(dir_path,res_folder_filter=NULL,runtime_exists=T,prediction_exists=T,metric_file="metric.csv"){
    file_paths = list.files(dir_path,metric_file,recursive = T)
    res_df <- c()
    # Metric files: the CSV columns plus the path-derived columns.
    for(i in file_paths){
        #print(i)
        m_i = read.csv(file.path(dir_path,i))
        elements = stringr::str_split(i,"/")[[1]]
        result_out = elements[2]
        method = elements[3]
        ct_ref = elements[4]
        orig = gsub("[0-9]*","",gsub("_.*","",elements[1]))
        var = gsub("[a-z]*[A-Z]*","",gsub("_.*","",elements[1]))
        rep = gsub(".*_","",elements[1])
        res_i = as.data.frame(c(m_i,c("orig"=orig,"var"=var,"rep"=rep,"method"=method,"out"=result_out,"ct_ref"=ct_ref)))
        res_df <- rbind(res_df,res_i)
    }
    # `var` was parsed from the folder name as text.
    res_df$var <- as.integer(res_df$var)
    
    # save intermediate result to output variable 
    res_all <- res_df
    if(runtime_exists){
        runtimes = list.files(dir_path,".*runtime.txt",recursive = T)
        res_df2 <- c()
        for(i in runtimes){
            m_i = read.table(file.path(dir_path,i))
            elements = stringr::str_split(i,"/")[[1]]
            result_out = elements[2]
            method = gsub("_.*","",elements[4])
            orig = gsub("[0-9]*","",gsub("_.*","",elements[1]))
            var = gsub("[a-z]*[A-Z]*","",gsub("_.*","",elements[1]))
            rep = gsub(".*_","",elements[1])
            res_i = as.data.frame(c(m_i,c("orig"=orig,"var"=var,"rep"=rep,"method"=method,"out"=result_out)))
            res_df2 <- rbind(res_df2,res_i)
        }

        rownames(res_df2) <- NULL
        res_df2<-data.frame(res_df2)
        res_df2$var<-as.integer(res_df2$var)
        # The first column read from the runtime file is the runtime value.
        colnames(res_df2)[1]="runtime"

        res_df2$runtime<-as.numeric(res_df2$runtime)
        
        # No `by`: natural join on the shared columns (orig, var, rep, method, out).
        res_all <- left_join(res_df,res_df2)
    }
    
    if(prediction_exists){
        # NOTE(review): this pattern also matches the "truth_prediction_eval.txt"
        # files read by get_pred_truth(); left_join only attaches their columns
        # to metric rows with identical orig/var/rep/method/out.
        prediciton_eval = list.files(dir_path,".*prediction_eval.txt",recursive = T)
        res_df3 <- c()
        for(i in prediciton_eval){
            m_i = read.csv(file.path(dir_path,i),header = F,sep='\t')
            # Named vector: metric name (column 1) -> value (column 2).
            vec = m_i[,2]
            names(vec) = unlist(m_i[1])
            elements = stringr::str_split(i,"/")[[1]]
            result_out = elements[2]
            method = elements[3]
            orig = gsub("[0-9]*","",gsub("_.*","",elements[1]))
            var = gsub("[a-z]*[A-Z]*","",gsub("_.*","",elements[1]))
            rep = gsub(".*_","",elements[1])
            # One-row character matrix whose column names are the metric names
            # followed by the path-derived fields.
            res_i = t(data.frame(c(vec,"orig"=orig,"var"=var,"rep"=rep,"method"=method,"out"=result_out)))
            res_df3 <- rbind(res_df3,res_i)
        }

        rownames(res_df3) <- NULL
        res_df3<-data.frame(res_df3)
        res_df3$var<-as.integer(res_df3$var)
        res_df3$percent_recovered_50kb<-as.numeric(res_df3$percent_recovered_50kb)
        res_df3$f1<-as.numeric(res_df3$f1)
        # Natural join on the shared columns (orig, var, rep, method, out).
        res_all <- left_join(res_all,res_df3)
    }
    if(!is.null(res_folder_filter)){
        res_all = res_all %>% filter(out %in% res_folder_filter)
    }
   return(res_all)
}


# Format the result table of an unpaired integration method so that it can be
# shown next to methods evaluated at several numbers of multiome cells.
# Unpaired methods do not use the multiome cells, so one (averaged) value per
# metric is simply repeated n_times.
#
# Arguments:
#   res     data.frame (or matrix) with one column per metric. It goes through
#           apply(), so a data.frame is first coerced to a matrix.
#   n_times number of copies made of each column.
# Returns a data.frame with ncol(res) * n_times columns: all copies of column 1
# first, then all copies of column 2, and so on.
format_unpaired_res <- function(res, n_times){
    # Use the `res` argument; the previous version read the global
    # `unpaired_method_perform` and silently ignored `res`.
    output <- do.call(cbind,apply(res,2,function(x){
        y<-as.data.frame(x)
        do.call(cbind,rep(list(y),n_times))})
    )
    return(output)
}

# Get the observed ("truth") peak-gene association accuracy (percentage of
# pairs recovered and F1).
## Reads every "truth_prediction_eval.txt" found (recursively) under dir_path
## and tags each row with the run information parsed from its path.
#
# Each file is a tab-separated, header-less table of metric name / value.
# The relative path is split on "/": elements[1] is the run folder (the text
# before its first "_" gives `orig` with digits removed and `var` with letters
# removed; the text after its last "_" gives `rep`), elements[2] is `out` and
# elements[3] is `method`.
#
# Returns a data.frame with one row per file: the metric columns (in file
# order), then orig, var (integer), rep, method and out. The columns
# percent_recovered_50kb and f1 are converted to numeric.
# Stops with an error if no truth file is found.
get_pred_truth <- function(dir_path){
    prediction_eval <- list.files(dir_path, "truth_prediction_eval\\.txt$", recursive = TRUE)
    if (length(prediction_eval) == 0) {
        stop("No 'truth_prediction_eval.txt' files found under ", dir_path, call. = FALSE)
    }
    rows <- lapply(prediction_eval, function(i){
        m_i <- read.csv(file.path(dir_path, i), header = FALSE, sep = '\t')
        # Named vector: metric name (column 1) -> value (column 2).
        vec <- m_i[, 2]
        names(vec) <- unlist(m_i[1])
        elements <- strsplit(i, "/", fixed = TRUE)[[1]]
        run_prefix <- gsub("_.*", "", elements[1])
        # One-row character matrix whose column names are the metric names
        # followed by the path-derived fields.
        t(data.frame(c(vec,
                       "orig" = gsub("[0-9]*", "", run_prefix),
                       "var" = gsub("[a-z]*[A-Z]*", "", run_prefix),
                       "rep" = gsub(".*_", "", elements[1]),
                       "method" = elements[3],
                       "out" = elements[2])))
    })
    res_df3 <- do.call(rbind, rows)

    rownames(res_df3) <- NULL
    res_df3 <- data.frame(res_df3)
    res_df3$var <- as.integer(res_df3$var)
    res_df3$percent_recovered_50kb <- as.numeric(res_df3$percent_recovered_50kb)
    res_df3$f1 <- as.numeric(res_df3$f1)
    return(res_df3)
}


# Collect the per-run results of the rare-cell-type experiments into one
# data.frame.
#
# Like plot_result(), but the run folder (elements[1] of each relative path)
# is split on "_" as "<orig>_<var>_<rep>", with letters removed from the
# second field to give `var`, and the folder name itself is kept as
# `out_path`. Relative paths are split on "/" and indexed as:
#   elements[1] - run folder (`out_path`)
#   elements[2] - result output folder (`out`)
#   elements[3] - method (metric and prediction files)
#   elements[4] - metric file name (`ct_ref`); for runtime files, the text
#                 before its first "_" is used as the method.
# NOTE(review): layout read off the indexing below; confirm against the
# folders written by evaluate_vary_situations_public.ipynb.
#
# Arguments:
#   dir_path          root folder, searched recursively.
#   res_folder_filter optional character vector; if given, only rows whose
#                     `out` is in it are returned.
#   runtime_exists    if TRUE, join a numeric `runtime` column read from
#                     "*runtime.txt" files.
#   prediction_exists if TRUE, join the columns of "*prediction_eval.txt"
#                     files (tab-separated, header-less name/value pairs);
#                     percent_recovered_50kb and f1 are converted to numeric.
#   metric_file       regular expression passed to list.files() to find the
#                     metric CSV files; only the first row of each is used.
# Returns the joined data.frame, with `var` converted to integer.
plot_rare_cell_result <- function(dir_path,res_folder_filter=NULL,runtime_exists=TRUE,prediction_exists=TRUE,metric_file="metric.csv"){
    # ---- metric files ----
    file_paths <- list.files(dir_path,metric_file,recursive = TRUE)
    res_df <- do.call(rbind, lapply(file_paths, function(i){
        m_i <- read.table(file.path(dir_path,i),sep=',',header = TRUE)[1,]
        elements <- stringr::str_split(i,"/")[[1]]
        elements_sep <- stringr::str_split(elements[1],"_")[[1]]
        as.data.frame(c(m_i,c("orig"=elements_sep[1],
                              "var"=gsub("[a-z]*[A-Z]*","",elements_sep[2]),
                              "rep"=elements_sep[3],
                              "method"=elements[3],
                              "out"=elements[2],
                              "ct_ref"=elements[4],
                              out_path = elements[1])))
    }))
    # `var` was parsed from the folder name as text.
    res_df$var <- as.integer(res_df$var)
    
    # save intermediate result to output variable 
    res_all <- res_df

    # ---- runtime files ----
    if(runtime_exists){
        runtimes <- list.files(dir_path,".*runtime.txt",recursive = TRUE)
        res_df2 <- do.call(rbind, lapply(runtimes, function(i){
            m_i <- read.table(file.path(dir_path,i))
            elements <- stringr::str_split(i,"/")[[1]]
            as.data.frame(c(m_i,c("method"=gsub("_.*","",elements[4]),
                                  "out"=elements[2],
                                  out_path = elements[1])))
        }))

        rownames(res_df2) <- NULL
        res_df2 <- data.frame(res_df2)
        # The first column read from the runtime file is the runtime value.
        colnames(res_df2)[1] <- "runtime"
        res_df2$runtime <- as.numeric(res_df2$runtime)
        
        res_all <- left_join(res_df,res_df2,by = c("method","out","out_path"))
    }
    
    # ---- peak-gene prediction files ----
    if(prediction_exists){
        prediction_eval <- list.files(dir_path,".*prediction_eval.txt",recursive = TRUE)
        res_df3 <- do.call(rbind, lapply(prediction_eval, function(i){
            m_i <- read.csv(file.path(dir_path,i),header = FALSE,sep='\t')
            # Named vector: metric name (column 1) -> value (column 2).
            vec <- m_i[,2]
            names(vec) <- unlist(m_i[1])
            elements <- stringr::str_split(i,"/")[[1]]
            # Key on the same path fields as the runtime join. Re-deriving
            # orig/var/rep with plot_result()'s "<orig><var>_<rep>" parsing
            # gives a different `var` than the "<orig>_<var>_<rep>" split
            # above, so the previous natural join never matched.
            t(data.frame(c(vec,"method"=elements[3],"out"=elements[2],"out_path"=elements[1])))
        }))

        rownames(res_df3) <- NULL
        res_df3 <- data.frame(res_df3)
        res_df3$percent_recovered_50kb <- as.numeric(res_df3$percent_recovered_50kb)
        res_df3$f1 <- as.numeric(res_df3$f1)
        res_all <- left_join(res_all,res_df3,by = c("method","out","out_path"))
    }
    if(!is.null(res_folder_filter)){
        res_all <- res_all %>% filter(out %in% res_folder_filter)
    }
    return(res_all)
}


# Color palette

In [None]:
# One fixed colour per method display name, so every figure colours a method
# the same way (plots look colours up with hex[method_names]).
method_names <- c("MultiVI","Seurat v4", "Cobolt","Seurat v3","BindSC","FigR", "LIGER","GLUE","scMoMaT","Seurat v4 integrate")
hex <- setNames(
    c("#cc6677","#999933","#44aa99","#117733","#88ccee","#882255","#332288","#ddcc77","#aa4499","#555555"),
    method_names
)
hex

# Scenario 1

## PBMC

### Load result

In [None]:
dir_path = "dataset/multiome_pbmc_10k/pbmc_vary_cell_test/"
# Output folder for the summary PDFs. The notebook otherwise only assigns it in
# the SHARE-seq section, which comes after the PBMC and BMMC summary plots that
# also write to it, so define and create it here.
folder_path <- "figures/metric_plots/"
dir.create(folder_path, showWarnings = FALSE, recursive = TRUE)

# Keep the two scenario-1 result folders.
res_all<- plot_result(dir_path,res_folder_filter = c("results_single_mod","results_single_same_cell_number"),runtime_exists=TRUE,prediction_exists=TRUE)

res_all$log2_runtime <- log2(res_all$runtime)

# add _2 to method names from "results_single_same_cell_number" and add a type column
unpaired_method_list<-c("seurat3","rfigr","rbindsc","rliger","glue")
res_plot <- res_all %>% 
    filter(var %in% c(1000,3000,8000)) %>% 
    mutate(method_name = ifelse(out=="results_single_same_cell_number",paste0(method,"_2"),method)) %>%
    # assign method type 
    mutate(type = ifelse(out!="results_single_same_cell_number",
                         ifelse(method%in%unpaired_method_list,"Unpaired","Multiome-guided"),
                         "Unpaired_multiome_splitted"))

# Metric columns (internal names) shown in the summary plot.
metric_sel = c("ARI","NMI","ct_aws","clisi","b_saw","kbet","gconn","percent_recovered_50kb","f1",
               "log2_runtime")

# Keep only the metrics computed from the "ct3_metric.csv" files.
res_plot = res_plot %>% filter(ct_ref =="ct3_metric.csv")
table(res_plot$method_name)

# Display name and method type for each method_name, in plotting order.
method_names = c("MultiVI","Seurat v4", "Cobolt","scMoMaT","Seurat v3","BindSC","FigR", "LIGER","GLUE",
                 "Seurat v3","BindSC","FigR", "LIGER","GLUE")
row_order = c("multivi","seurat4","cobolt","scmomat","seurat3","rbindsc","rfigr","rliger","glue",
              "seurat3_2","rbindsc_2","rfigr_2","rliger_2","glue_2")
names(method_names)  = row_order

method_type = c(rep("Multiome-guided",4),rep("Unpaired",5),rep("Unpaired with \n Multiome-splitted",5))
names(method_type) <- row_order

# Facet title for each metric in metric_sel.
title_list = c("ARI","NMI","Cell type ASW","cLisi", "Batch ASW","kBET","Graph \n connectivity", 
               "Peak-gene \n pair % recovered","F1","log2(Runtime)")
names(title_list)=metric_sel


# Long format: one row per (method_name, var, metric) value, with method type,
# display name and metric title as ordered factors.
res_plot_2 <- res_plot %>% 
    select(ARI:gconn,percent_recovered_50kb:log2_runtime,method_name,var) %>%
    gather(key = "metric_type",
               value = "value", -method_name,-var)%>%
    mutate(method = method_name) %>%
    mutate(type=factor(method_type[method],levels = c("Unpaired","Unpaired with \n Multiome-splitted","Multiome-guided"))) %>% 
    mutate(method_name= factor(method_names[method],levels=unique(method_names))) %>% 
    mutate(metric_name = factor(title_list[metric_type],levels = title_list))

#res_plot_2$metric_name <- dplyr::recode(res_plot_2$metric_name,"Gene-peak \n pair % recovered" ="Peak-gene \n pair % recovered")

# Observed ("truth") peak-gene accuracy, averaged over repeats; drawn as a
# reference line in the summary plot.
plot_ref_accuracy = TRUE
truth_tbl <- get_pred_truth(dir_path)
# NOTE(review): assumes the first two columns of truth_tbl are
# percent_recovered_50kb and f1, in that order (row order of the truth files).
truth_df <- data.frame(metric_name = factor(c("Peak-gene \n pair % recovered","F1"),
                                            levels=c("Peak-gene \n pair % recovered","F1")),
                       Z = as.numeric(round(colMeans(truth_tbl[,1:2]),digits=2)))
truth_df



### Summary plot

In [None]:
# Scenario 1 (PBMC) summary figure. Each element of subplot_idx becomes one
# ggplot faceted by method type (columns) and metric (rows). The sub-plots are
# placed side by side with one shared legend and written to
# file.path(folder_path, file_name); `folder_path` must already be defined in
# the session before this cell runs.
file_name <- "pbmc_single_mod_n_splitted_metric_summary19_7ct_w_observed_rev1.pdf"

# Metric pairs (titles from title_list) drawn in each sub-plot.
subplot_idx = list(c("ARI","NMI"),c("Cell type ASW","cLisi"),c("Batch ASW","kBET"),
                   c("Peak-gene \n pair % recovered","F1"),c('log2(Runtime)',"Graph \n connectivity"))

# Summary statistics. Unpaired methods do not use the multiome cells, so they
# are summarized across all `var` values (1000/3000/8000 multiome cells) and
# drawn without an error bar. Every other method is summarized across repeats
# at each `var` (mean +/- sd).
# Note: the result is left grouped by metric_name and method.
res_plot_3 <- res_plot_2 %>% 
    group_by(metric_name,var,method) %>% 
    # mean after grouping by metric, method, and var
    dplyr::mutate(mean_var = mean(value),
                  sd_var = sd(value)) %>%
    group_by(metric_name,method) %>%
    # mean after grouping by metric and method
    dplyr::mutate(mean_method = mean(value),
                  sd_method = sd(value)) %>%
    mutate(mean_f = ifelse(type=="Unpaired",mean_method,mean_var),
           sd_f = ifelse(type=="Unpaired",0,sd_var))
x_interval = c(1000,3000,8000) 

# Per-facet y-axis limits and break counts, used two at a time (one pair per
# sub-plot) through `counter`. Within a pair, the order follows the facet rows,
# i.e. the metric_name factor levels (title_list order), not the order in
# subplot_idx. For example, in the last sub-plot "Graph connectivity" comes
# before "log2(Runtime)". Values outside these limits are dropped by ggplot2.
ylim_list = list(c(0.2,1),c(0.4,0.9),c(0.5,0.8),c(0.8,1),c(0.6,1),c(0,1),c(0.25,0.5),c(0.25,0.45),c(0,1),c(6,14))
nbreaks = list(4,5,6,4,4,5,5,4,5,8)
counter=1

subplots = list()
for(i in 1:length(subplot_idx)){
    print(i)
    # Colours are looked up by display name in `hex`; point shape encodes type.
    subplots[[i]] = res_plot_3 %>%  
            filter(metric_name %in% subplot_idx[[i]]) %>%  
            ggplot(.,aes(x=var,y=mean_f,group=method,col=method_name,shape=type)) + 
            geom_point(size=3) +
            geom_line() +
            geom_errorbar(aes(ymin=mean_f-sd_f, ymax=mean_f+sd_f),
                         position=position_dodge(0.05),alpha=0.2)+
            scale_shape_manual(values = c(15,16,17))+#c(0,1,2))+
            scale_x_continuous(breaks=x_interval) +
            facet_grid(cols = vars(type),rows=vars(metric_name),scales = "free_y", switch="y") +
            scale_color_manual(values=hex[method_names])+
            theme_light()+
            theme(strip.text.x = element_text(size = 30,angle=90,colour = "black"),
                  strip.text.y = element_text(size = 30,angle=90,colour = "black"),
                  axis.text.x  = element_text(size = 25,angle = 90),axis.text.y = element_text(size = 25),
                  axis.title.x = element_blank(),axis.title.y = element_blank(),
                  strip.background = element_blank(), strip.placement = "outside",
                  panel.spacing.y = unit(5, "lines"),panel.spacing.x = unit(2, "lines"),
                 panel.border=element_blank(), axis.line=element_line())+
    ggh4x::facetted_pos_scales(
        y = list(
            scale_y_continuous(limits = ylim_list[[counter]],n.breaks = nbreaks[[counter]]),
            scale_y_continuous(limits = ylim_list[[counter+1]],n.breaks = nbreaks[[counter+1]])

        )
    )
    counter = counter + 2
    # Dashed red line at the observed (truth) accuracy. truth_df carries
    # metric_name, so each value is drawn in its own facet row.
    if(plot_ref_accuracy){
        if("F1" %in% subplot_idx[[i]]){
            subplots[[i]] = subplots[[i]] +
                geom_hline(data = truth_df, aes(yintercept = Z),linetype =2,color="red")
        }
    }

}
# Drop the per-plot legends; a single shared legend is added below.
# grid.arrange() also draws the combined grob on the current device.
subplots_no_bdg = lapply(subplots,function(x){x+theme(legend.position = "none")})
pcomb <- grid.arrange(grobs=subplots_no_bdg,nrow=1)


# function to extract legend from plot
# NOTE(review): ggplot2 >= 3.5 names legend grobs "guide-box-<position>", which
# this exact-name match would not find. Confirm against the installed version.
get_only_legend <- function(plot) {
  plot_table <- ggplot_gtable(ggplot_build(plot))
  legend_plot <- which(sapply(plot_table$grobs, function(x) x$name) == "guide-box")
  legend <- plot_table$grobs[[legend_plot]]
  return(legend)
}
                            
# extract legend from plot1 using above function
legend <- get_only_legend(subplots[[1]])   
                                                          
# Write the combined panels (10/11 of the width) plus the legend to the PDF.
pdf(file.path(folder_path,file_name),width=35,height=15)                 
grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1))
dev.off()

## BMMC

### Load result

In [None]:
dir_path ="dataset/bmmc/bmmc_vary_cell_test/"
# Output folder for the summary PDF written in the next cell. The notebook
# otherwise only assigns it in the later SHARE-seq section, so define and
# create it here.
folder_path <- "figures/metric_plots/"
dir.create(folder_path, showWarnings = FALSE, recursive = TRUE)

# Keep the two scenario-1 result folders.
res_all<- plot_result(dir_path,res_folder_filter = c("results_single_mod","results_single_same_cell_number"),runtime_exists=TRUE,prediction_exists=TRUE)

res_all$log2_runtime <- log2(res_all$runtime)

# add _2 to method names from "results_single_same_cell_number" and add a type column
unpaired_method_list<-c("seurat3","rfigr","rbindsc","rliger","glue")
res_plot <- res_all %>% 
    mutate(method_name = ifelse(out=="results_single_same_cell_number",paste0(method,"_2"),method)) %>%
    # assign method type 
    mutate(type = ifelse(out!="results_single_same_cell_number",
                         ifelse(method%in%unpaired_method_list,"Unpaired","Multiome-guided"),
                         "Unpaired_multiome_splitted"))

# Metric columns (internal names) shown in the summary plot.
metric_sel = c("ARI","NMI","ct_aws","clisi","b_saw","kbet","gconn","percent_recovered_50kb","f1",
               "log2_runtime")

# Keep only the metrics computed from the "ct3_metric.csv" files.
res_plot = res_plot %>% filter(ct_ref =="ct3_metric.csv")
table(res_plot$method_name)

# Display name and method type for each method_name, in plotting order.
method_names = c("MultiVI","Seurat v4", "Cobolt","scMoMaT","Seurat v3","BindSC","FigR", "LIGER","GLUE",
                 "Seurat v3","BindSC","FigR", "LIGER","GLUE")
row_order = c("multivi","seurat4","cobolt","scmomat","seurat3","rbindsc","rfigr","rliger","glue",
              "seurat3_2","rbindsc_2","rfigr_2","rliger_2","glue_2")
names(method_names)  = row_order

method_type = c(rep("Multiome-guided",4),rep("Unpaired",5),rep("Unpaired with \n Multiome-splitted",5))
names(method_type) <- row_order

# Facet title for each metric in metric_sel.
title_list = c("ARI","NMI","Cell type ASW","cLisi", "Batch ASW","kBET","Graph \n connectivity", 
               "Peak-gene \n pair % recovered","F1","log2(Runtime)")
names(title_list)=metric_sel

# Long format: one row per (method_name, var, metric) value, with method type,
# display name and metric title as ordered factors.
res_plot_2 <- res_plot %>% 
    select(ARI:gconn,percent_recovered_50kb:log2_runtime,method_name,var) %>%
    gather(key = "metric_type",
               value = "value", -method_name,-var)%>%
    mutate(method = method_name) %>%
    mutate(type=factor(method_type[method],levels = c("Unpaired","Unpaired with \n Multiome-splitted","Multiome-guided"))) %>% 
    mutate(method_name= factor(method_names[method],levels=unique(method_names))) %>% 
    mutate(metric_name = factor(title_list[metric_type],levels = title_list))

# Legacy relabel: title_list no longer produces "Gene-peak ..." titles, so this
# appears to leave metric_name unchanged.
res_plot_2$metric_name <- dplyr::recode(res_plot_2$metric_name,"Gene-peak \n pair % recovered" ="Peak-gene \n pair % recovered")


# Observed ("truth") peak-gene accuracy, averaged over repeats; drawn as a
# reference line in the summary plot.
plot_ref_accuracy = TRUE
truth_tbl <- get_pred_truth(dir_path)
# NOTE(review): assumes the first two columns of truth_tbl are
# percent_recovered_50kb and f1, in that order (row order of the truth files).
truth_df <- data.frame(metric_name = factor(c("Peak-gene \n pair % recovered","F1"),
                                            levels=c("Peak-gene \n pair % recovered","F1")),
                       Z = as.numeric(round(colMeans(truth_tbl[,1:2]),digits=2)))
truth_df


### Summary plot

In [None]:
# Scenario 1 (BMMC) summary figure. Each element of subplot_idx becomes one
# ggplot faceted by method type (columns) and metric (rows). The sub-plots are
# placed side by side with one shared legend and written to
# file.path(folder_path, file_name); `folder_path` must already be defined in
# the session before this cell runs.
file_name <- "bmmc_single_mod_n_splitted_metric_summary19_21ct_w_observed_rev1.pdf"

# Metric pairs (titles from title_list) drawn in each sub-plot.
subplot_idx = list(c("ARI","NMI"),c("Cell type ASW","cLisi"),c("Batch ASW","kBET"),
                   c("Peak-gene \n pair % recovered","F1"),c('log2(Runtime)',"Graph \n connectivity"))

# Summary statistics. Unpaired methods do not use the multiome cells, so they
# are summarized across all `var` values (numbers of multiome cells) and drawn
# without an error bar. Every other method is summarized across repeats at
# each `var` (mean +/- sd).
# Note: the result is left grouped by metric_name and method.
res_plot_3 <- res_plot_2 %>% 
    group_by(metric_name,var,method) %>% 
    # mean after grouping by metric, method, and var
    dplyr::mutate(mean_var = mean(value),
                  sd_var = sd(value)) %>%
    group_by(metric_name,method) %>%
    # mean after grouping by metric and method
    dplyr::mutate(mean_method = mean(value),
                  sd_method = sd(value)) %>%
    mutate(mean_f = ifelse(type=="Unpaired",mean_method,mean_var),
           sd_f = ifelse(type=="Unpaired",0,sd_var))
x_interval = c(1000,2000,4000) # unique(res_plot_3$var)
# Per-facet y-axis limits and break counts, used two at a time (one pair per
# sub-plot) through `counter`. Within a pair, the order follows the facet rows,
# i.e. the metric_name factor levels (title_list order), not the order in
# subplot_idx. Values outside these limits are dropped by ggplot2.
ylim_list = list(c(0,0.5),c(0.2,0.7),c(0.45,0.7),c(0.85,1),c(0.4,1),c(0,1),c(0.05,0.3),c(0.05,0.3),c(0,1),c(6,14))
nbreaks = list(5,5,5,6,6,5,5,5,5,8)
counter=1
subplots = list()
for(i in 1:length(subplot_idx)){
    print(i)
    # Colours are looked up by display name in `hex`; point shape encodes type.
    subplots[[i]] = res_plot_3 %>%  
            filter(metric_name %in% subplot_idx[[i]]) %>%  
            ggplot(.,aes(x=var,y=mean_f,group=method,col=method_name,shape=type)) + 
            geom_point(size=3) +
            geom_line() +
            geom_errorbar(aes(ymin=mean_f-sd_f, ymax=mean_f+sd_f),
                         position=position_dodge(0.05),alpha=0.2)+
            scale_shape_manual(values = c(15,16,17))+#c(0,1,2))+
            scale_x_continuous(breaks=x_interval) +
            facet_grid(cols = vars(type),rows=vars(metric_name),scales = "free_y", switch="y") +
            scale_color_manual(values=hex[method_names])+
            theme_light()+
            theme(strip.text.x = element_text(size = 30,angle=90,colour = "black"),
                  strip.text.y = element_text(size = 30,angle=90,colour = "black"),
                  axis.text.x  = element_text(size = 25,angle = 90),axis.text.y = element_text(size = 25),
                  axis.title.x = element_blank(),axis.title.y = element_blank(),
                  strip.background = element_blank(), strip.placement = "outside",
                  panel.spacing.y = unit(5, "lines"),panel.spacing.x = unit(2, "lines"),
                 panel.border=element_blank(), axis.line=element_line())+
    ggh4x::facetted_pos_scales(
        y = list(
            scale_y_continuous(limits = ylim_list[[counter]],n.breaks = nbreaks[[counter]]),
            scale_y_continuous(limits = ylim_list[[counter+1]],n.breaks = nbreaks[[counter+1]])

        )
    )
    counter = counter + 2
    # Dashed red line at the observed (truth) accuracy. truth_df carries
    # metric_name, so each value is drawn in its own facet row.
    if(plot_ref_accuracy){
        if("F1" %in% subplot_idx[[i]]){
            subplots[[i]] = subplots[[i]] +
                geom_hline(data = truth_df, aes(yintercept = Z),linetype =2,color="red")
        }
    }

}
# Drop the per-plot legends; a single shared legend is added below.
# grid.arrange() also draws the combined grob on the current device.
subplots_no_bdg = lapply(subplots,function(x){x+theme(legend.position = "none")})
pcomb <- grid.arrange(grobs=subplots_no_bdg,nrow=1)


# function to extract legend from plot
# NOTE(review): ggplot2 >= 3.5 names legend grobs "guide-box-<position>", which
# this exact-name match would not find. Confirm against the installed version.
get_only_legend <- function(plot) {
  plot_table <- ggplot_gtable(ggplot_build(plot))
  legend_plot <- which(sapply(plot_table$grobs, function(x) x$name) == "guide-box")
  legend <- plot_table$grobs[[legend_plot]]
  return(legend)
}
                            
# extract legend from plot1 using above function
legend <- get_only_legend(subplots[[1]])   
                                                           
# Write the combined panels (10/11 of the width) plus the legend to the PDF.
pdf(file.path(folder_path,file_name),width=35,height=15)                 
grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1))
dev.off()

## SHARE-seq 

### Load result

In [None]:
dir_path = "dataset/mouse_skin/multiome_ncells_pmat/"
folder_path <- "figures/metric_plots/"
# recursive = TRUE also creates "figures/" if missing; showWarnings = FALSE
# keeps re-runs quiet when the folder already exists.
dir.create(folder_path, showWarnings = FALSE, recursive = TRUE)

# Keep the two scenario-1 result folders.
res_all<- plot_result(dir_path,res_folder_filter = c("results_single_mod","results_single_same_cell_number"),runtime_exists=TRUE,prediction_exists=TRUE)
table(res_all$method)
table(res_all$out)

res_all$log2_runtime <- log2(res_all$runtime)

# add _2 to method names from "results_single_same_cell_number" and add a type column
unpaired_method_list<-c("seurat3","rfigr","rbindsc","rliger","glue")
res_plot <- res_all %>% 
    mutate(method_name = ifelse(out=="results_single_same_cell_number",paste0(method,"_2"),method)) %>%
    # assign method type 
    mutate(type = ifelse(out!="results_single_same_cell_number",
                         ifelse(method%in%unpaired_method_list,"Unpaired","Multiome-guided"),
                         "Unpaired_multiome_splitted"))

# Metric columns (internal names) shown in the summary plot.
metric_sel = c("ARI","NMI","ct_aws","clisi","b_saw","kbet","gconn","percent_recovered_50kb","f1",
               "log2_runtime")

# Keep only the metrics computed from the "ct3_metric.csv" files.
res_plot = res_plot %>% filter(ct_ref =="ct3_metric.csv")
table(res_plot$method_name)

# Display name and method type for each method_name, in plotting order.
method_names = c("MultiVI","Seurat v4", "Cobolt","scMoMaT","Seurat v3","BindSC","FigR", "LIGER","GLUE",
                 "Seurat v3","BindSC","FigR", "LIGER","GLUE")
row_order = c("multivi","seurat4","cobolt","scmomat","seurat3","rbindsc","rfigr","rliger","glue",
              "seurat3_2","rbindsc_2","rfigr_2","rliger_2","glue_2")
names(method_names)  = row_order

method_type = c(rep("Multiome-guided",4),rep("Unpaired",5),rep("Unpaired with \n Multiome-splitted",5))
names(method_type) <- row_order

# Facet title for each metric in metric_sel.
title_list = c("ARI","NMI","Cell type ASW","cLisi", "Batch ASW","kBET","Graph \n connectivity", 
               "Peak-gene \n pair % recovered","F1","log2(Runtime)")
names(title_list)=metric_sel

# Long format: one row per (method_name, var, metric) value, with method type,
# display name and metric title as ordered factors.
res_plot_2 <- res_plot %>% 
    select(ARI:gconn,percent_recovered_50kb:log2_runtime,method_name,var) %>%
    gather(key = "metric_type",
               value = "value", -method_name,-var)%>%
    mutate(method = method_name) %>%
    mutate(type=factor(method_type[method],levels = c("Unpaired","Unpaired with \n Multiome-splitted","Multiome-guided"))) %>% 
    mutate(method_name= factor(method_names[method],levels=unique(method_names))) %>% 
    mutate(metric_name = factor(title_list[metric_type],levels = title_list))

# Legacy relabel: title_list no longer produces "Gene-peak ..." titles, so this
# appears to leave metric_name unchanged.
res_plot_2$metric_name <- dplyr::recode(res_plot_2$metric_name,"Gene-peak \n pair % recovered" ="Peak-gene \n pair % recovered")


# Observed ("truth") peak-gene accuracy, averaged over repeats; drawn as a
# reference line in the summary plot.
plot_ref_accuracy = TRUE

truth_tbl <- get_pred_truth(dir_path)
# NOTE(review): assumes the first two columns of truth_tbl are
# percent_recovered_50kb and f1, in that order (row order of the truth files).
truth_df <- data.frame(metric_name = factor(c("Peak-gene \n pair % recovered","F1"),
                                            levels=c("Peak-gene \n pair % recovered","F1")),
                       Z = as.numeric(round(colMeans(truth_tbl[,1:2]),digits=2)))
truth_df


### Summary Plot

In [None]:
# Scenario 1 (SHARE-seq, mouse skin) summary figure. Each element of
# subplot_idx becomes one ggplot faceted by method type (columns) and metric
# (rows). The sub-plots are placed side by side with one shared legend and
# written to file.path(folder_path, file_name).
file_name <- "mouse_skin_single_mod_n_splitted_metric_summary19_21ct_w_observed_rev1.pdf"

# Metric pairs (titles from title_list) drawn in each sub-plot.
subplot_idx = list(c("ARI","NMI"),c("Cell type ASW","cLisi"),c("Batch ASW","kBET"),
                   c("Peak-gene \n pair % recovered","F1"),c('log2(Runtime)',"Graph \n connectivity"))

# Summary statistics. Unpaired methods do not use the multiome cells, so they
# are summarized across all `var` values (numbers of multiome cells) and drawn
# without an error bar. Every other method is summarized across repeats at
# each `var` (mean +/- sd).
# Note: the result is left grouped by metric_name and method.
res_plot_3 <- res_plot_2 %>% 
    group_by(metric_name,var,method) %>% 
    # mean after grouping by metric, method, and var
    dplyr::mutate(mean_var = mean(value),
                  sd_var = sd(value)) %>%
    group_by(metric_name,method) %>%
    # mean after grouping by metric and method
    dplyr::mutate(mean_method = mean(value),
                  sd_method = sd(value)) %>%
    mutate(mean_f = ifelse(type=="Unpaired",mean_method,mean_var),
           sd_f = ifelse(type=="Unpaired",0,sd_var))
x_interval = c(5000,10000,15000) # unique(res_plot_3$var)
# Per-facet y-axis limits and break counts, used two at a time (one pair per
# sub-plot) through `counter`. Within a pair, the order follows the facet rows,
# i.e. the metric_name factor levels (title_list order), not the order in
# subplot_idx. Values outside these limits are dropped by ggplot2.
ylim_list = list(c(0,0.6),c(0.2,0.7),c(0.45,0.7),c(0.85,1),c(0.4,1),c(0,1),c(0.1,0.7),c(0.05,0.5),c(0,1),c(6,15))
nbreaks = list(5,5,5,6,6,5,8,5,5,8)
counter=1
subplots = list()
for(i in 1:length(subplot_idx)){
    print(i)
    # Colours are looked up by display name in `hex`; point shape encodes type.
    subplots[[i]] = res_plot_3 %>%  
            filter(metric_name %in% subplot_idx[[i]]) %>%  
            ggplot(.,aes(x=var,y=mean_f,group=method,col=method_name,shape=type)) + 
            geom_point(size=3) +
        #     stat_summary(fun.y=mean, geom="point",aes(col=method_name)) +
            geom_line() +
            geom_errorbar(aes(ymin=mean_f-sd_f, ymax=mean_f+sd_f),
                         position=position_dodge(0.05),alpha=0.2)+
            scale_shape_manual(values = c(15,16,17))+#c(0,1,2))+
            scale_x_continuous(breaks=x_interval) +
            facet_grid(cols = vars(type),rows=vars(metric_name),scales = "free_y", switch="y") +
            scale_color_manual(values=hex[method_names])+
            theme_light()+
            theme(strip.text.x = element_text(size = 30,angle=90,colour = "black"),
                  strip.text.y = element_text(size = 30,angle=90,colour = "black"),
                  axis.text.x  = element_text(size = 25,angle = 90),axis.text.y = element_text(size = 25),
                  axis.title.x = element_blank(),axis.title.y = element_blank(),
                  strip.background = element_blank(), strip.placement = "outside",
                  panel.spacing.y = unit(5, "lines"),panel.spacing.x = unit(2, "lines"),
                 panel.border=element_blank(), axis.line=element_line())+
    ggh4x::facetted_pos_scales(
        y = list(
            scale_y_continuous(limits = ylim_list[[counter]],n.breaks = nbreaks[[counter]]),
            scale_y_continuous(limits = ylim_list[[counter+1]],n.breaks = nbreaks[[counter+1]])

        )
    )
    counter = counter + 2
    # Dashed red line at the observed (truth) accuracy. truth_df carries
    # metric_name, so each value is drawn in its own facet row.
    if(plot_ref_accuracy){
        if("F1" %in% subplot_idx[[i]]){
            subplots[[i]] = subplots[[i]] +
                geom_hline(data = truth_df, aes(yintercept = Z),linetype =2,color="red")
        }
    }

}
# Drop the per-plot legends; a single shared legend is added below.
# grid.arrange() also draws the combined grob on the current device.
subplots_no_bdg = lapply(subplots,function(x){x+theme(legend.position = "none")})
pcomb <- grid.arrange(grobs=subplots_no_bdg,nrow=1)


# function to extract legend from plot
# NOTE(review): ggplot2 >= 3.5 names legend grobs "guide-box-<position>", which
# this exact-name match would not find. Confirm against the installed version.
get_only_legend <- function(plot) {
  plot_table <- ggplot_gtable(ggplot_build(plot))
  legend_plot <- which(sapply(plot_table$grobs, function(x) x$name) == "guide-box")
  legend <- plot_table$grobs[[legend_plot]]
  return(legend)
}
                            
# extract legend from plot1 using above function
legend <- get_only_legend(subplots[[1]])   
                                                           
# Write the combined panels (10/11 of the width) plus the legend to the PDF.
pdf(file.path(folder_path,file_name),width=35,height=15)                 
grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1))
dev.off()

# Scenario 2

## PBMC 2000 cells

### Load result

In [None]:
dir_path = "dataset/multiome_pbmc_10k/nmulti2000_7ct_vdepth_test/"

# Scenario 2 results: only the "results" folder and the "ct3_metric.csv" files.
res_all <- plot_result(dir_path,res_folder_filter = c("results"),runtime_exists=TRUE,prediction_exists=TRUE,metric_file = "ct3_metric.csv")

res_all$log2_runtime <- log2(res_all$runtime)
# Multiome-guided methods keep their name; every other method gets the "_2"
# suffix that row_order uses for the multiome-splitted unpaired runs.
multiome_guided_method_list<-c("multivi","seurat4","cobolt","scmomat")
unpaired_method_list<-c("seurat3_2","rfigr_2","rbindsc_2","rliger_2","glue_2")
res_plot <- res_all %>%  
    mutate(method_name = ifelse(method %in% multiome_guided_method_list,method,paste0(method,"_2")))%>% 
    # assign method type: the "_2" methods are the multiome-splitted unpaired
    # runs (these two labels were previously swapped).
    mutate(type = ifelse(method_name %in% unpaired_method_list,"Unpaired_multiome_splitted","Multiome-guided"))

# Metric columns (internal names) shown in the summary plot.
metric_sel = c("ARI","NMI","ct_aws","clisi","b_saw","kbet","gconn","percent_recovered_50kb","f1",
               "log2_runtime")

table(res_plot$method_name)

# Display name and method type for each method_name, in plotting order.
method_names = c("MultiVI","Seurat v4", "Cobolt","scMoMaT","Seurat v3","BindSC","FigR", "LIGER","GLUE",
                 "Seurat v3","BindSC","FigR", "LIGER","GLUE")
row_order = c("multivi","seurat4","cobolt","scmomat","seurat3","rbindsc","rfigr","rliger","glue",
              "seurat3_2","rbindsc_2","rfigr_2","rliger_2","glue_2")
names(method_names)  = row_order

method_type = c(rep("Multiome-guided",4),rep("Unpaired",5),rep("Unpaired with \n Multiome-splitted",5))
names(method_type) <- row_order

# Facet title for each metric in metric_sel.
title_list = c("ARI","NMI","Cell type ASW","cLisi", "Batch ASW","kBET","Graph \n connectivity", 
               "Peak-gene \n pair % recovered","F1","log2(Runtime)")
names(title_list)=metric_sel


# Long format: one row per (method_name, var, metric) value. `type` is rebuilt
# here from method_type, so the plotted type labels come from row_order.
res_plot_2 <- res_plot %>% 
    select(ARI:gconn,percent_recovered_50kb:log2_runtime,method_name,var) %>%
    gather(key = "metric_type",
               value = "value", -method_name,-var)%>%
    mutate(method = method_name) %>%
    mutate(type=factor(method_type[method],levels = c("Unpaired","Unpaired with \n Multiome-splitted","Multiome-guided"))) %>% 
    mutate(method_name= factor(method_names[method],levels=unique(method_names))) %>% 
    mutate(metric_name = factor(title_list[metric_type],levels = title_list))




### Summary plot

In [None]:
file_name <- "pbmc_vdepth_summary19_7ct_fair_rev1.pdf"
# This section never sets the output directory itself; define it here instead
# of silently reusing whatever an earlier cell left in `folder_path`.
folder_path <- "figures/metric_plots/"
dir.create(folder_path, recursive = TRUE, showWarnings = FALSE)

# One figure column per element: the two metrics stacked in that column.
subplot_idx <- list(c("ARI", "NMI"), c("Cell type ASW", "cLisi"), c("Batch ASW", "kBET"),
                    c("Peak-gene \n pair % recovered", "F1"), c("log2(Runtime)", "Graph \n connectivity"))

# Plotted statistics: "Unpaired" methods are averaged over every value of `var`
# (regardless of the 1000/3000/8000 multiome cells) and drawn without an error
# bar; all other methods are averaged over repeats at each `var`.
res_plot_3 <- res_plot_2 %>%
    group_by(metric_name, var, method) %>%
    mutate(mean_var = mean(value), sd_var = sd(value)) %>%
    group_by(metric_name, method) %>%
    mutate(mean_method = mean(value), sd_method = sd(value)) %>%
    ungroup() %>%
    mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
           sd_f = ifelse(type == "Unpaired", 0, sd_var))
x_interval <- unique(res_plot_3$var)

# y limits / number of breaks per facet row, in drawing order (metric_name
# factor order); entries 2i-1 and 2i belong to subplot i.
ylim_list <- list(c(0.2, 1), c(0.4, 0.9), c(0.5, 0.8), c(0.8, 1), c(0.6, 1), c(0, 1), c(0.25, 0.5), c(0.25, 0.45), c(0, 1), c(6, 12))
nbreaks <- list(4, 5, 6, 4, 4, 5, 5, 4, 5, 8)

subplots <- vector("list", length(subplot_idx))
for (i in seq_along(subplot_idx)) {
    print(i)
    top <- 2 * i - 1
    bottom <- 2 * i
    subplots[[i]] <- res_plot_3 %>%
        filter(metric_name %in% subplot_idx[[i]]) %>%
        ggplot(aes(x = var, y = mean_f, group = method, col = method_name, shape = type)) +
        geom_point(size = 3) +
        geom_line() +
        geom_errorbar(aes(ymin = mean_f - sd_f, ymax = mean_f + sd_f),
                      position = position_dodge(0.05), alpha = 0.2) +
        scale_shape_manual(values = c(15, 16, 17)) +
        scale_x_continuous(breaks = x_interval) +
        facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
        scale_color_manual(values = hex[method_names]) +
        theme_light() +
        theme(strip.text.x = element_text(size = 30, angle = 90, colour = "black"),
              strip.text.y = element_text(size = 30, angle = 90, colour = "black"),
              axis.text.x = element_text(size = 25, angle = 90), axis.text.y = element_text(size = 25),
              axis.title.x = element_blank(), axis.title.y = element_blank(),
              strip.background = element_blank(), strip.placement = "outside",
              panel.spacing.y = unit(5, "lines"), panel.spacing.x = unit(2, "lines"),
              panel.border = element_blank(), axis.line = element_line()) +
        ggh4x::facetted_pos_scales(
            y = list(
                scale_y_continuous(limits = ylim_list[[top]], n.breaks = nbreaks[[top]]),
                scale_y_continuous(limits = ylim_list[[bottom]], n.breaks = nbreaks[[bottom]])
            )
        )
}
subplots_no_bdg <- lapply(subplots, function(p) p + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)

# Return the (single, non-empty) legend grob of a ggplot.
# ggplot2 < 3.5 names it "guide-box"; ggplot2 >= 3.5 uses "guide-box-<position>"
# and fills unused positions with zeroGrob, so match the prefix and skip empties.
get_only_legend <- function(plot) {
    plot_table <- ggplot_gtable(ggplot_build(plot))
    grob_names <- vapply(plot_table$grobs, function(g) as.character(g$name)[1], character(1))
    is_empty <- vapply(plot_table$grobs, inherits, logical(1), what = "zeroGrob")
    legend_idx <- which(grepl("^guide-box", grob_names) & !is_empty)
    if (length(legend_idx) == 0) {
        stop("No legend found in plot", call. = FALSE)
    }
    plot_table$grobs[[legend_idx[1]]]
}

# The legend is shared by all panels: draw it once beside the combined panels.
legend <- get_only_legend(subplots[[1]])

pdf(file.path(folder_path, file_name), width = 35, height = 15)
grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1))
dev.off()


## BMMC 2000 cells

### Load result

In [None]:
dir_path <- "dataset/bmmc/nmulti2000_all_ct_vdepth/"
folder_path <- "figures/metric_plots/"
dir.create(folder_path, recursive = TRUE, showWarnings = FALSE)

res_all <- plot_result(dir_path, res_folder_filter = c("results"), runtime_exists = TRUE,
                       prediction_exists = TRUE, metric_file = "ct3wbatch_metric.csv")
table(res_all$method)
table(res_all$out)

res_all$log2_runtime <- log2(res_all$runtime)

# Multiome-guided methods keep their name; every other method in this scenario
# was run on the multiome-split data and gets a "_2" suffix.
multiome_guided_method_list <- c("multivi", "seurat4", "cobolt", "scmomat")
unpaired_method_list <- c("seurat3_2", "rfigr_2", "rbindsc_2", "rliger_2", "glue_2")
res_plot <- res_all %>%
    mutate(method_name = ifelse(method %in% multiome_guided_method_list, method, paste0(method, "_2"))) %>%
    # assign method type: the "_2" methods are the multiome-split runs
    mutate(type = ifelse(method_name %in% unpaired_method_list, "Unpaired_multiome_splitted", "Multiome-guided"))

metric_sel <- c("ARI", "NMI", "ct_aws", "clisi", "b_saw", "kbet", "gconn", "percent_recovered_50kb", "f1",
                "log2_runtime")

table(res_plot$method_name)

# Display name for every (possibly suffixed) method id.
method_names <- c("MultiVI", "Seurat v4", "Cobolt", "scMoMaT", "Seurat v3", "BindSC", "FigR", "LIGER", "GLUE",
                  "Seurat v3", "BindSC", "FigR", "LIGER", "GLUE")
row_order <- c("multivi", "seurat4", "cobolt", "scmomat", "seurat3", "rbindsc", "rfigr", "rliger", "glue",
               "seurat3_2", "rbindsc_2", "rfigr_2", "rliger_2", "glue_2")
names(method_names) <- row_order

method_type <- c(rep("Multiome-guided", 4), rep("Unpaired", 5), rep("Unpaired with \n Multiome-splitted", 5))
names(method_type) <- row_order

# Facet labels for each metric column.
title_list <- c("ARI", "NMI", "Cell type ASW", "cLisi", "Batch ASW", "kBET", "Graph \n connectivity",
                "Peak-gene \n pair % recovered", "F1", "log2(Runtime)")
names(title_list) <- metric_sel

# Long format: one row per (run, metric); `type` comes from method_type.
res_plot_2 <- res_plot %>%
    select(ARI:gconn, percent_recovered_50kb:log2_runtime, method_name, var) %>%
    pivot_longer(-c(method_name, var), names_to = "metric_type", values_to = "value") %>%
    mutate(method = method_name,
           type = factor(method_type[method],
                         levels = c("Unpaired", "Unpaired with \n Multiome-splitted", "Multiome-guided")),
           method_name = factor(method_names[method], levels = unique(method_names)),
           metric_name = factor(title_list[metric_type], levels = title_list))


### Summary plot

In [None]:
file_name <- "bmmc_nmulti2000_vdepth_metric_summary19_21ct_fair_rev1.pdf"
folder_path <- "figures/metric_plots/"
dir.create(folder_path, recursive = TRUE, showWarnings = FALSE)

# One figure column per element: the two metrics stacked in that column.
subplot_idx <- list(c("ARI", "NMI"), c("Cell type ASW", "cLisi"), c("Batch ASW", "kBET"),
                    c("Peak-gene \n pair % recovered", "F1"), c("log2(Runtime)", "Graph \n connectivity"))

# Plotted statistics: "Unpaired" methods are averaged over every value of `var`
# (regardless of the number of multiome cells) and drawn without an error bar;
# all other methods are averaged over repeats at each `var`.
res_plot_3 <- res_plot_2 %>%
    group_by(metric_name, var, method) %>%
    mutate(mean_var = mean(value), sd_var = sd(value)) %>%
    group_by(metric_name, method) %>%
    mutate(mean_method = mean(value), sd_method = sd(value)) %>%
    ungroup() %>%
    mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
           sd_f = ifelse(type == "Unpaired", 0, sd_var))
x_interval <- unique(res_plot_3$var)

# y limits / number of breaks per facet row, in drawing order (metric_name
# factor order); entries 2i-1 and 2i belong to subplot i.
ylim_list <- list(c(0, 0.5), c(0.2, 0.7), c(0.45, 0.7), c(0.85, 1), c(0.4, 1), c(0, 1), c(0.05, 0.3), c(0.05, 0.3), c(0, 1), c(6, 12))
nbreaks <- list(5, 5, 5, 6, 6, 5, 5, 5, 5, 8)

subplots <- vector("list", length(subplot_idx))
for (i in seq_along(subplot_idx)) {
    print(i)
    top <- 2 * i - 1
    bottom <- 2 * i
    subplots[[i]] <- res_plot_3 %>%
        filter(metric_name %in% subplot_idx[[i]]) %>%
        ggplot(aes(x = var, y = mean_f, group = method, col = method_name, shape = type)) +
        geom_point(size = 3) +
        geom_line() +
        geom_errorbar(aes(ymin = mean_f - sd_f, ymax = mean_f + sd_f),
                      position = position_dodge(0.05), alpha = 0.2) +
        scale_shape_manual(values = c(15, 16, 17)) +
        scale_x_continuous(breaks = x_interval) +
        facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
        scale_color_manual(values = hex[method_names]) +
        theme_light() +
        theme(strip.text.x = element_text(size = 30, angle = 90, colour = "black"),
              strip.text.y = element_text(size = 30, angle = 90, colour = "black"),
              axis.text.x = element_text(size = 25, angle = 90), axis.text.y = element_text(size = 25),
              axis.title.x = element_blank(), axis.title.y = element_blank(),
              strip.background = element_blank(), strip.placement = "outside",
              panel.spacing.y = unit(5, "lines"), panel.spacing.x = unit(2, "lines"),
              panel.border = element_blank(), axis.line = element_line()) +
        ggh4x::facetted_pos_scales(
            y = list(
                scale_y_continuous(limits = ylim_list[[top]], n.breaks = nbreaks[[top]]),
                scale_y_continuous(limits = ylim_list[[bottom]], n.breaks = nbreaks[[bottom]])
            )
        )
}
subplots_no_bdg <- lapply(subplots, function(p) p + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)

# Return the (single, non-empty) legend grob of a ggplot.
# ggplot2 < 3.5 names it "guide-box"; ggplot2 >= 3.5 uses "guide-box-<position>"
# and fills unused positions with zeroGrob, so match the prefix and skip empties.
get_only_legend <- function(plot) {
    plot_table <- ggplot_gtable(ggplot_build(plot))
    grob_names <- vapply(plot_table$grobs, function(g) as.character(g$name)[1], character(1))
    is_empty <- vapply(plot_table$grobs, inherits, logical(1), what = "zeroGrob")
    legend_idx <- which(grepl("^guide-box", grob_names) & !is_empty)
    if (length(legend_idx) == 0) {
        stop("No legend found in plot", call. = FALSE)
    }
    plot_table$grobs[[legend_idx[1]]]
}

# The legend is shared by all panels: draw it once beside the combined panels.
legend <- get_only_legend(subplots[[1]])

pdf(file.path(folder_path, file_name), width = 35, height = 15)
grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1))
dev.off()


## BMMC 4000 cells

### Load result

In [None]:
# Relative path like every other dataset in this notebook (previously a
# hard-coded absolute path on one specific cluster).
dir_path <- "dataset/bmmc/nmulti4000_all_ct_vdepth/"
folder_path <- "figures/metric_plots/"
dir.create(folder_path, recursive = TRUE, showWarnings = FALSE)

res_all <- plot_result(dir_path, res_folder_filter = c("results"), runtime_exists = TRUE,
                       prediction_exists = TRUE, metric_file = "ct3_metric.csv")
table(res_all$method)
table(res_all$out)

res_all$log2_runtime <- log2(res_all$runtime)

# Multiome-guided methods keep their name; every other method in this scenario
# was run on the multiome-split data and gets a "_2" suffix.
multiome_guided_method_list <- c("multivi", "seurat4", "cobolt", "scmomat")
unpaired_method_list <- c("seurat3_2", "rfigr_2", "rbindsc_2", "rliger_2", "glue_2")
res_plot <- res_all %>%
    mutate(method_name = ifelse(method %in% multiome_guided_method_list, method, paste0(method, "_2"))) %>%
    # assign method type: the "_2" methods are the multiome-split runs
    mutate(type = ifelse(method_name %in% unpaired_method_list, "Unpaired_multiome_splitted", "Multiome-guided"))

metric_sel <- c("ARI", "NMI", "ct_aws", "clisi", "b_saw", "kbet", "gconn", "percent_recovered_50kb", "f1",
                "log2_runtime")

table(res_plot$method_name)

# Display name for every (possibly suffixed) method id.
method_names <- c("MultiVI", "Seurat v4", "Cobolt", "scMoMaT", "Seurat v3", "BindSC", "FigR", "LIGER", "GLUE",
                  "Seurat v3", "BindSC", "FigR", "LIGER", "GLUE")
row_order <- c("multivi", "seurat4", "cobolt", "scmomat", "seurat3", "rbindsc", "rfigr", "rliger", "glue",
               "seurat3_2", "rbindsc_2", "rfigr_2", "rliger_2", "glue_2")
names(method_names) <- row_order

method_type <- c(rep("Multiome-guided", 4), rep("Unpaired", 5), rep("Unpaired with \n Multiome-splitted", 5))
names(method_type) <- row_order

# Facet labels for each metric column.
title_list <- c("ARI", "NMI", "Cell type ASW", "cLisi", "Batch ASW", "kBET", "Graph \n connectivity",
                "Peak-gene \n pair % recovered", "F1", "log2(Runtime)")
names(title_list) <- metric_sel

# Long format: one row per (run, metric); `type` comes from method_type.
res_plot_2 <- res_plot %>%
    select(ARI:gconn, percent_recovered_50kb:log2_runtime, method_name, var) %>%
    pivot_longer(-c(method_name, var), names_to = "metric_type", values_to = "value") %>%
    mutate(method = method_name,
           type = factor(method_type[method],
                         levels = c("Unpaired", "Unpaired with \n Multiome-splitted", "Multiome-guided")),
           method_name = factor(method_names[method], levels = unique(method_names)),
           metric_name = factor(title_list[metric_type], levels = title_list))


### Summary plot

In [None]:
file_name <- "bmmc_nmulti4000_vdepth_metric_summary19_21ct_fair_rev1.pdf"
# Define the output directory here as well so the cell can be re-run on its own.
folder_path <- "figures/metric_plots/"
dir.create(folder_path, recursive = TRUE, showWarnings = FALSE)

# One figure column per element: the two metrics stacked in that column.
subplot_idx <- list(c("ARI", "NMI"), c("Cell type ASW", "cLisi"), c("Batch ASW", "kBET"),
                    c("Peak-gene \n pair % recovered", "F1"), c("log2(Runtime)", "Graph \n connectivity"))

# Plotted statistics: "Unpaired" methods are averaged over every value of `var`
# (regardless of the number of multiome cells) and drawn without an error bar;
# all other methods are averaged over repeats at each `var`.
res_plot_3 <- res_plot_2 %>%
    group_by(metric_name, var, method) %>%
    mutate(mean_var = mean(value), sd_var = sd(value)) %>%
    group_by(metric_name, method) %>%
    mutate(mean_method = mean(value), sd_method = sd(value)) %>%
    ungroup() %>%
    mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
           sd_f = ifelse(type == "Unpaired", 0, sd_var))
x_interval <- unique(res_plot_3$var)

# y limits / number of breaks per facet row, in drawing order (metric_name
# factor order); entries 2i-1 and 2i belong to subplot i.
ylim_list <- list(c(0, 0.5), c(0.2, 0.7), c(0.45, 0.7), c(0.85, 1), c(0.4, 1), c(0, 1), c(0.05, 0.3), c(0.05, 0.3), c(0, 1), c(6, 12))
nbreaks <- list(5, 5, 5, 6, 6, 5, 5, 5, 5, 8)

subplots <- vector("list", length(subplot_idx))
for (i in seq_along(subplot_idx)) {
    print(i)
    top <- 2 * i - 1
    bottom <- 2 * i
    subplots[[i]] <- res_plot_3 %>%
        filter(metric_name %in% subplot_idx[[i]]) %>%
        ggplot(aes(x = var, y = mean_f, group = method, col = method_name, shape = type)) +
        geom_point(size = 3) +
        geom_line() +
        geom_errorbar(aes(ymin = mean_f - sd_f, ymax = mean_f + sd_f),
                      position = position_dodge(0.05), alpha = 0.2) +
        scale_shape_manual(values = c(15, 16, 17)) +
        scale_x_continuous(breaks = x_interval) +
        facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
        scale_color_manual(values = hex[method_names]) +
        theme_light() +
        theme(strip.text.x = element_text(size = 30, angle = 90, colour = "black"),
              strip.text.y = element_text(size = 30, angle = 90, colour = "black"),
              axis.text.x = element_text(size = 25, angle = 90), axis.text.y = element_text(size = 25),
              axis.title.x = element_blank(), axis.title.y = element_blank(),
              strip.background = element_blank(), strip.placement = "outside",
              panel.spacing.y = unit(5, "lines"), panel.spacing.x = unit(2, "lines"),
              panel.border = element_blank(), axis.line = element_line()) +
        ggh4x::facetted_pos_scales(
            y = list(
                scale_y_continuous(limits = ylim_list[[top]], n.breaks = nbreaks[[top]]),
                scale_y_continuous(limits = ylim_list[[bottom]], n.breaks = nbreaks[[bottom]])
            )
        )
}
subplots_no_bdg <- lapply(subplots, function(p) p + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)

# Return the (single, non-empty) legend grob of a ggplot.
# ggplot2 < 3.5 names it "guide-box"; ggplot2 >= 3.5 uses "guide-box-<position>"
# and fills unused positions with zeroGrob, so match the prefix and skip empties.
get_only_legend <- function(plot) {
    plot_table <- ggplot_gtable(ggplot_build(plot))
    grob_names <- vapply(plot_table$grobs, function(g) as.character(g$name)[1], character(1))
    is_empty <- vapply(plot_table$grobs, inherits, logical(1), what = "zeroGrob")
    legend_idx <- which(grepl("^guide-box", grob_names) & !is_empty)
    if (length(legend_idx) == 0) {
        stop("No legend found in plot", call. = FALSE)
    }
    plot_table$grobs[[legend_idx[1]]]
}

# The legend is shared by all panels: draw it once beside the combined panels.
legend <- get_only_legend(subplots[[1]])

pdf(file.path(folder_path, file_name), width = 35, height = 15)
grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1))
dev.off()

## BMMC Increasing number of cells; 100% depth

### Load result

In [None]:
# read in metric in ct3_metric.csv
dir_path <- "dataset/bmmc/bmmc_vcells_intervals/"

res_all <- plot_result(dir_path, res_folder_filter = c("results"), runtime_exists = TRUE,
                       prediction_exists = TRUE, metric_file = "ct3_metric.csv")

res_all$log2_runtime <- log2(res_all$runtime)

# Multiome-guided methods keep their name; every other method in this scenario
# was run on the multiome-split data and gets a "_2" suffix.
multiome_guided_method_list <- c("multivi", "seurat4", "cobolt")
unpaired_method_list <- c("seurat3_2", "rfigr_2", "rbindsc_2", "rliger_2")
res_plot <- res_all %>%
    mutate(method_name = ifelse(method %in% multiome_guided_method_list, method, paste0(method, "_2"))) %>%
    # assign method type: the "_2" methods are the multiome-split runs
    mutate(type = ifelse(method_name %in% unpaired_method_list, "Unpaired_multiome_split", "Multiome-guided"))

metric_sel <- c("ARI", "NMI", "ct_aws", "clisi", "b_saw", "kbet", "gconn", "percent_recovered_50kb", "f1",
                "log2_runtime")

table(res_plot$method_name)

# Only Seurat v3 (multiome-split) and Seurat v4 are shown in this scenario;
# other method ids get NA display names and types.
method_names <- c("Seurat v3", "Seurat v4")
row_order <- c("seurat3_2", "seurat4")
names(method_names) <- row_order

method_type <- c("Unpaired with \n Multiome-split", "Multiome-guided")
names(method_type) <- row_order

# Facet labels for each metric column.
title_list <- c("ARI", "NMI", "Cell type ASW", "cLisi", "Batch ASW", "kBET", "Graph \n connectivity",
                "Peak-gene \n pair % recovered", "F1", "log2(Runtime)")
names(title_list) <- metric_sel

# Long format: one row per (run, metric); `type` comes from method_type.
res_plot_2 <- res_plot %>%
    select(ARI:gconn, percent_recovered_50kb:log2_runtime, method_name, var) %>%
    pivot_longer(-c(method_name, var), names_to = "metric_type", values_to = "value") %>%
    mutate(method = method_name,
           type = factor(method_type[method],
                         levels = c("Unpaired", "Unpaired with \n Multiome-split", "Multiome-guided")),
           method_name = factor(method_names[method], levels = unique(method_names)),
           metric_name = factor(title_list[metric_type], levels = title_list))




### Summary plot

In [None]:
# Named after the increasing-number-of-cells dataset; the previous name
# ("..._vdepth_10interv_...") was shared with the increasing-depth figure, so
# the two outputs overwrote each other.
file_name <- "bmmc_vcells_intervals_summary19_21ct_fair.pdf"
folder_path <- "figures/metric_plots/"
dir.create(folder_path, recursive = TRUE, showWarnings = FALSE)

# One figure column per element: the two metrics stacked in that column.
subplot_idx <- list(c("ARI", "NMI"), c("Cell type ASW", "cLisi"), c("Batch ASW", "kBET"),
                    c("Peak-gene \n pair % recovered", "F1"), c("log2(Runtime)", "Graph \n connectivity"))

# Plotted statistics: "Unpaired" methods are averaged over every value of `var`
# and drawn without an error bar; all other methods are averaged over repeats
# at each `var`.
res_plot_3 <- res_plot_2 %>%
    group_by(metric_name, var, method) %>%
    mutate(mean_var = mean(value), sd_var = sd(value)) %>%
    group_by(metric_name, method) %>%
    mutate(mean_method = mean(value), sd_method = sd(value)) %>%
    ungroup() %>%
    mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
           sd_f = ifelse(type == "Unpaired", 0, sd_var))
x_interval <- unique(res_plot_3$var)

# y limits / number of breaks per facet row, in drawing order (metric_name
# factor order); entries 2i-1 and 2i belong to subplot i.
ylim_list <- list(c(0, 0.5), c(0.2, 0.7), c(0.45, 0.7), c(0.85, 1), c(0.4, 1), c(0, 1), c(0.05, 0.3), c(0.05, 0.3), c(0, 1), c(6, 14))
nbreaks <- list(5, 5, 5, 6, 6, 5, 5, 5, 5, 8)

subplots <- vector("list", length(subplot_idx))
for (i in seq_along(subplot_idx)) {
    print(i)
    top <- 2 * i - 1
    bottom <- 2 * i
    subplots[[i]] <- res_plot_3 %>%
        filter(metric_name %in% subplot_idx[[i]]) %>%
        ggplot(aes(x = var, y = mean_f, group = method, col = method_name, shape = type)) +
        geom_point(size = 3) +
        geom_line() +
        geom_errorbar(aes(ymin = mean_f - sd_f, ymax = mean_f + sd_f),
                      position = position_dodge(0.05), alpha = 0.2) +
        scale_shape_manual(values = c(15, 16, 17)) +
        scale_x_continuous(breaks = x_interval) +
        # Colours looked up by display name (as in the other figures); legend
        # labels then follow the breaks instead of the positional order of `hex`.
        scale_color_manual(values = hex[method_names]) +
        facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
        theme_light() +
        theme(strip.text.x = element_text(size = 30, angle = 90, colour = "black"),
              strip.text.y = element_text(size = 30, angle = 90, colour = "black"),
              axis.text.x = element_text(size = 25, angle = 90), axis.text.y = element_text(size = 25),
              axis.title.x = element_blank(), axis.title.y = element_blank(),
              strip.background = element_blank(), strip.placement = "outside",
              panel.spacing.y = unit(5, "lines"), panel.spacing.x = unit(2, "lines"),
              panel.border = element_blank(), axis.line = element_line()) +
        ggh4x::facetted_pos_scales(
            y = list(
                scale_y_continuous(limits = ylim_list[[top]], n.breaks = nbreaks[[top]]),
                scale_y_continuous(limits = ylim_list[[bottom]], n.breaks = nbreaks[[bottom]])
            )
        )
}
subplots_no_bdg <- lapply(subplots, function(p) p + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)

# Return the (single, non-empty) legend grob of a ggplot.
# ggplot2 < 3.5 names it "guide-box"; ggplot2 >= 3.5 uses "guide-box-<position>"
# and fills unused positions with zeroGrob, so match the prefix and skip empties.
get_only_legend <- function(plot) {
    plot_table <- ggplot_gtable(ggplot_build(plot))
    grob_names <- vapply(plot_table$grobs, function(g) as.character(g$name)[1], character(1))
    is_empty <- vapply(plot_table$grobs, inherits, logical(1), what = "zeroGrob")
    legend_idx <- which(grepl("^guide-box", grob_names) & !is_empty)
    if (length(legend_idx) == 0) {
        stop("No legend found in plot", call. = FALSE)
    }
    plot_table$grobs[[legend_idx[1]]]
}

# The legend is shared by all panels: draw it once beside the combined panels.
legend <- get_only_legend(subplots[[1]])

pdf(file.path(folder_path, file_name), width = 35, height = 15)
grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1))
dev.off()


## BMMC Increasing depth; 4000 cells

### Load result

In [None]:
dir_path <- "dataset/bmmc/bmmc_vdepth_intervals/"

res_all <- plot_result(dir_path, res_folder_filter = c("results"), runtime_exists = TRUE,
                       prediction_exists = TRUE, metric_file = "ct3_metric.csv")

res_all$log2_runtime <- log2(res_all$runtime)

# Multiome-guided methods keep their name; every other method in this scenario
# was run on the multiome-split data and gets a "_2" suffix.
multiome_guided_method_list <- c("multivi", "seurat4", "cobolt")
unpaired_method_list <- c("seurat3_2", "rfigr_2", "rbindsc_2", "rliger_2")
res_plot <- res_all %>%
    mutate(method_name = ifelse(method %in% multiome_guided_method_list, method, paste0(method, "_2"))) %>%
    # assign method type: the "_2" methods are the multiome-split runs
    mutate(type = ifelse(method_name %in% unpaired_method_list, "Unpaired_multiome_split", "Multiome-guided"))

metric_sel <- c("ARI", "NMI", "ct_aws", "clisi", "b_saw", "kbet", "gconn", "percent_recovered_50kb", "f1",
                "log2_runtime")

table(res_plot$method_name)

# Display names must match the keys of `hex` used by the other figures
# ("BindSC", "LIGER"), otherwise those methods get no colour.
method_names <- c("MultiVI", "Seurat v4", "Cobolt",
                  "Seurat v3", "BindSC", "FigR", "LIGER")
row_order <- c("multivi", "seurat4", "cobolt",
               "seurat3_2", "rbindsc_2", "rfigr_2", "rliger_2")
names(method_names) <- row_order

method_type <- c(rep("Multiome-guided", 3), rep("Unpaired with \n Multiome-split", 4))
names(method_type) <- row_order

# Facet labels for each metric column.
title_list <- c("ARI", "NMI", "Cell type ASW", "cLisi", "Batch ASW", "kBET", "Graph \n connectivity",
                "Peak-gene \n pair % recovered", "F1", "log2(Runtime)")
names(title_list) <- metric_sel

# Long format: one row per (run, metric); `type` comes from method_type.
res_plot_2 <- res_plot %>%
    select(ARI:gconn, percent_recovered_50kb:log2_runtime, method_name, var) %>%
    pivot_longer(-c(method_name, var), names_to = "metric_type", values_to = "value") %>%
    mutate(method = method_name,
           type = factor(method_type[method],
                         levels = c("Unpaired", "Unpaired with \n Multiome-split", "Multiome-guided")),
           method_name = factor(method_names[method], levels = unique(method_names)),
           metric_name = factor(title_list[metric_type], levels = title_list))




### Summary plot

In [None]:
# Unique name for the increasing-depth figure (the previous name was shared with
# the increasing-number-of-cells figure, and the PDF went to a scratch file).
file_name <- "bmmc_vdepth_intervals_summary19_21ct_fair.pdf"
folder_path <- "figures/metric_plots/"
dir.create(folder_path, recursive = TRUE, showWarnings = FALSE)

# One figure column per element: the two metrics stacked in that column.
subplot_idx <- list(c("ARI", "NMI"), c("Cell type ASW", "cLisi"), c("Batch ASW", "kBET"),
                    c("Peak-gene \n pair % recovered", "F1"), c("log2(Runtime)", "Graph \n connectivity"))

# Plotted statistics: "Unpaired" methods are averaged over every value of `var`
# and drawn without an error bar; all other methods are averaged over repeats
# at each `var`.
res_plot_3 <- res_plot_2 %>%
    group_by(metric_name, var, method) %>%
    mutate(mean_var = mean(value), sd_var = sd(value)) %>%
    group_by(metric_name, method) %>%
    mutate(mean_method = mean(value), sd_method = sd(value)) %>%
    ungroup() %>%
    mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
           sd_f = ifelse(type == "Unpaired", 0, sd_var))
x_interval <- unique(res_plot_3$var)

# y limits / number of breaks per facet row, in drawing order (metric_name
# factor order); entries 2i-1 and 2i belong to subplot i.
ylim_list <- list(c(0, 0.5), c(0.2, 0.7), c(0.45, 0.7), c(0.85, 1), c(0.4, 1), c(0, 1), c(0.05, 0.3), c(0.05, 0.3), c(0, 1), c(6, 14))
nbreaks <- list(5, 5, 5, 6, 6, 5, 5, 5, 5, 8)

subplots <- vector("list", length(subplot_idx))
for (i in seq_along(subplot_idx)) {
    print(i)
    top <- 2 * i - 1
    bottom <- 2 * i
    subplots[[i]] <- res_plot_3 %>%
        filter(metric_name %in% subplot_idx[[i]]) %>%
        ggplot(aes(x = var, y = mean_f, group = method, col = method_name, shape = type)) +
        geom_point(size = 3) +
        geom_line() +
        geom_errorbar(aes(ymin = mean_f - sd_f, ymax = mean_f + sd_f),
                      position = position_dodge(0.05), alpha = 0.2) +
        scale_shape_manual(values = c(15, 16, 17)) +
        scale_x_continuous(breaks = x_interval) +
        # Colours looked up by display name (as in the other figures); legend
        # labels then follow the breaks instead of the positional order of `hex`.
        scale_color_manual(values = hex[method_names]) +
        facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
        theme_light() +
        theme(strip.text.x = element_text(size = 30, angle = 90, colour = "black"),
              strip.text.y = element_text(size = 30, angle = 90, colour = "black"),
              axis.text.x = element_text(size = 25, angle = 90), axis.text.y = element_text(size = 25),
              axis.title.x = element_blank(), axis.title.y = element_blank(),
              strip.background = element_blank(), strip.placement = "outside",
              panel.spacing.y = unit(5, "lines"), panel.spacing.x = unit(2, "lines"),
              panel.border = element_blank(), axis.line = element_line()) +
        ggh4x::facetted_pos_scales(
            y = list(
                scale_y_continuous(limits = ylim_list[[top]], n.breaks = nbreaks[[top]]),
                scale_y_continuous(limits = ylim_list[[bottom]], n.breaks = nbreaks[[bottom]])
            )
        )
}
subplots_no_bdg <- lapply(subplots, function(p) p + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)

# Return the (single, non-empty) legend grob of a ggplot.
# ggplot2 < 3.5 names it "guide-box"; ggplot2 >= 3.5 uses "guide-box-<position>"
# and fills unused positions with zeroGrob, so match the prefix and skip empties.
get_only_legend <- function(plot) {
    plot_table <- ggplot_gtable(ggplot_build(plot))
    grob_names <- vapply(plot_table$grobs, function(g) as.character(g$name)[1], character(1))
    is_empty <- vapply(plot_table$grobs, inherits, logical(1), what = "zeroGrob")
    legend_idx <- which(grepl("^guide-box", grob_names) & !is_empty)
    if (length(legend_idx) == 0) {
        stop("No legend found in plot", call. = FALSE)
    }
    plot_table$grobs[[legend_idx[1]]]
}

# The legend is shared by all panels: draw it once beside the combined panels.
legend <- get_only_legend(subplots[[1]])

pdf(file.path(folder_path, file_name), width = 35, height = 15)
grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1))
dev.off()


# Scenario 3

## Technical batch effect

### Load result

In [None]:
# Read the cell-type metrics (ct3wbatch_metric.csv) and the site-batch metrics
# (ct3wbatch_metric_sample_batch.csv) for every run, then combine them.
dir_path <- "dataset/bmmc/multiome_ncells_pmat_batch/"
folder_path <- "figures/metric_plots/"
# recursive = TRUE also creates the parent "figures/" directory;
# showWarnings = FALSE keeps re-runs quiet when the directory already exists.
dir.create(folder_path, showWarnings = FALSE, recursive = TRUE)

res_all <- plot_result(dir_path, res_folder_filter = c("results"),
                       runtime_exists = TRUE, prediction_exists = TRUE,
                       metric_file = "ct3wbatch_metric.csv")
table(res_all$method)
table(res_all$out)

res_all_batch <- plot_result(dir_path, res_folder_filter = c("results"),
                             runtime_exists = TRUE, prediction_exists = TRUE,
                             metric_file = "ct3wbatch_metric_sample_batch.csv")
table(res_all_batch$method)
table(res_all_batch$out)

# Columns 2:10 of the batch table are joined on the column names they share
# with res_all (dplyr's default `by`).
# NOTE(review): depends on the column layout returned by plot_result() —
# confirm that columns 2:10 are the join keys plus the batch_* metrics.
res_all_comb <- left_join(res_all, res_all_batch[2:10])
res_all_comb$log2_runtime <- log2(res_all_comb$runtime)

# How many runs lack a matching site-batch metric.
table(is.na(res_all_comb$batch_b_saw))

# Every method that is not multiome-guided gets a "_2" suffix (its
# multiome-split variant); then tag each row with its integration type.
multiome_guided_method_list <- c("multivi", "seurat4", "cobolt", "scmomat")
unpaired_method_list <- c("seurat3_2", "rfigr_2", "rbindsc_2", "rliger_2", "glue_2")
res_plot <- res_all_comb %>%
  filter(var %in% c(1000, 3000, 5000)) %>%
  mutate(method_name = ifelse(method %in% multiome_guided_method_list,
                              method, paste0(method, "_2"))) %>%
  # "_2" methods are unpaired methods run on the multiome split; all other
  # methods are multiome-guided.
  mutate(type = ifelse(method_name %in% unpaired_method_list,
                       "Unpaired_multiome_split", "Multiome-guided"))

metric_sel <- c("ARI", "NMI", "ct_aws", "clisi", "b_saw", "kbet", "gconn",
                "batch_b_saw", "batch_kbet", "log2_runtime")

res_plot <- res_plot %>% filter(ct_ref == "ct3wbatch_metric.csv")
table(res_plot$method_name)

# Display name and integration type for each internal method name.
method_names <- c("MultiVI", "Seurat v4", "Cobolt", "scMoMaT",
                  "Seurat v3", "BindSC", "FigR", "LIGER", "GLUE",
                  "Seurat v3", "BindSC", "FigR", "LIGER", "GLUE")
row_order <- c("multivi", "seurat4", "cobolt", "scmomat",
               "seurat3", "rbindsc", "rfigr", "rliger", "glue",
               "seurat3_2", "rbindsc_2", "rfigr_2", "rliger_2", "glue_2")
names(method_names) <- row_order

method_type <- c(rep("Multiome-guided", 4), rep("Unpaired", 5),
                 rep("Unpaired with \n Multiome-split", 5))
names(method_type) <- row_order

# Facet label for each metric column.
title_list <- c("ARI", "NMI", "Cell type ASW", "cLisi", "Batch ASW", "kBET",
                "Graph \n connectivity", "Site ASW", "Site kBET", "log2(Runtime)")
names(title_list) <- metric_sel

# Long format: one row per (run, metric). Any selected column that is not in
# metric_sel gets an NA metric_name.
res_plot_2 <- res_plot %>%
  select(ARI:gconn, percent_recovered_50kb:log2_runtime, method_name, var) %>%
  gather(key = "metric_type", value = "value", -method_name, -var) %>%
  mutate(method = method_name) %>%
  mutate(type = factor(method_type[method],
                       levels = c("Unpaired", "Unpaired with \n Multiome-split",
                                  "Multiome-guided"))) %>%
  mutate(method_name = factor(method_names[method], levels = unique(method_names))) %>%
  mutate(metric_name = factor(title_list[metric_type], levels = title_list))



### Summary plot

In [None]:
file_name <- "bmmc_tbatch_metric_summary19_21ct_fair_rev1.pdf"
folder_path <- "figures/metric_plots/"
# recursive = TRUE also creates the parent "figures/" directory;
# showWarnings = FALSE keeps re-runs quiet when the directory already exists.
dir.create(folder_path, showWarnings = FALSE, recursive = TRUE)

# One column of panels per entry; each entry lists the two metrics (facet
# rows) shown in that column.
subplot_idx <- list(c("ARI", "NMI"),
                    c("Cell type ASW", "cLisi"),
                    c("Batch ASW", "kBET"),
                    c("Site ASW", "Site kBET"),
                    c("Graph \n connectivity", "log2(Runtime)"))

# y-axis limits and break counts, one per metric, in flattened subplot_idx order.
ylim_list <- list(c(0.1, 0.7), c(0.3, 0.8),
                  c(0.4, 0.7), c(0.8, 1.0),
                  c(0.5, 1), c(0.1, 0.8),
                  c(0.2, 1), c(0.1, 1),
                  c(0, 1), c(6, 12))
nbreaks <- list(6, 5,
                6, 5,
                5, 6,
                7, 8,
                5, 6)

# Summary values. Unpaired methods are averaged across all var values
# (regardless of the number of multiome cells, here 1000/3000/5000) and drawn
# without an error bar; other methods get mean +/- sd across repeats at each var.
res_plot_3 <- res_plot_2 %>%
  group_by(metric_name, var, method) %>%
  dplyr::mutate(mean_var = mean(value),
                sd_var = sd(value)) %>%
  group_by(metric_name, method) %>%
  dplyr::mutate(mean_method = mean(value),
                sd_method = sd(value)) %>%
  ungroup() %>%
  mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
         sd_f = ifelse(type == "Unpaired", 0, sd_var))

x_interval <- unique(res_plot_3$var)
subplots <- list()
counter <- 1
for (i in seq_along(subplot_idx)) {
  print(i)
  subplots[[i]] <- res_plot_3 %>%
    filter(metric_name %in% subplot_idx[[i]]) %>%
    ggplot(., aes(x = var, y = mean_f, group = method, col = method_name, shape = type)) +
    geom_point(size = 3) +
    geom_line() +
    geom_errorbar(aes(ymin = mean_f - sd_f, ymax = mean_f + sd_f),
                  position = position_dodge(0.05), alpha = 0.2) +
    scale_shape_manual(values = c(15, 16, 17)) +
    scale_x_continuous(breaks = x_interval) +
    facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
    scale_color_manual(values = hex[method_names]) +
    theme_light() +
    theme(strip.text.x = element_blank(),
          strip.text.y = element_text(size = 35, angle = 90, colour = "black"),
          axis.text.x = element_text(size = 30, angle = 90),
          axis.text.y = element_text(size = 30),
          axis.title.x = element_blank(), axis.title.y = element_blank(),
          strip.background = element_blank(), strip.placement = "outside",
          panel.spacing.y = unit(2, "lines"), panel.spacing.x = unit(1, "lines"),
          panel.border = element_blank(), axis.line = element_line()) +
    # Per-row y scales: the two metrics of this column use consecutive entries.
    ggh4x::facetted_pos_scales(
      y = list(
        scale_y_continuous(limits = ylim_list[[counter]], n.breaks = nbreaks[[counter]]),
        scale_y_continuous(limits = ylim_list[[counter + 1]], n.breaks = nbreaks[[counter + 1]])
      )
    )
  counter <- counter + 2
}
# Drop per-panel legends; one shared legend is added next to the combined figure.
subplots_no_bdg <- lapply(subplots, function(x) x + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)


# Extract the legend grob from a ggplot so it can be drawn once beside a row
# of legend-less panels.
#
# The legend is located through the gtable layout names instead of the grobs'
# own `name` field. ggplot2 < 3.5 names the slot "guide-box"; ggplot2 >= 3.5
# adds one slot per position ("guide-box-right", "guide-box-bottom", ...) and
# fills unused positions with zeroGrob placeholders, which are skipped here.
#
# @param plot A ggplot object with a visible legend.
# @return The first non-empty legend grob.
# Errors with an explicit message when the plot has no legend.
get_only_legend <- function(plot) {
  plot_table <- ggplot_gtable(ggplot_build(plot))
  is_legend <- startsWith(plot_table$layout$name, "guide-box")
  candidates <- Filter(function(g) !inherits(g, "zeroGrob"),
                       plot_table$grobs[is_legend])
  if (length(candidates) == 0) {
    stop("get_only_legend(): the plot has no legend to extract", call. = FALSE)
  }
  candidates[[1]]
}
                            
# Put the legend taken from the first panel to the right of the legend-free
# panel row (10:1 width ratio) and save the figure as a 30 x 10 inch PDF.
legend <- get_only_legend(subplots[[1]])
pdf(file = file.path(folder_path, file_name), width = 30, height = 10)
grid.arrange(grobs = list(pcomb, legend), ncol = 2, widths = c(10, 1))
dev.off()

## Biological batch effect

### Load result

In [None]:
# Read the cell-type metrics (ct3wbatch_metric.csv) and the donor-batch metrics
# (ct3wbatch_metric_sample_batch.csv) for every run, then combine them.
dir_path <- "dataset/bmmc/bmmc_biological_batch_test/"
folder_path <- "figures/metric_plots/"
# recursive = TRUE also creates the parent "figures/" directory;
# showWarnings = FALSE keeps re-runs quiet when the directory already exists.
dir.create(folder_path, showWarnings = FALSE, recursive = TRUE)

res_all <- plot_result(dir_path, res_folder_filter = c("results"),
                       runtime_exists = TRUE, prediction_exists = TRUE,
                       metric_file = "ct3wbatch_metric.csv")

res_all_batch <- plot_result(dir_path, res_folder_filter = c("results"),
                             runtime_exists = TRUE, prediction_exists = TRUE,
                             metric_file = "ct3wbatch_metric_sample_batch.csv")

# Columns 2:10 of the batch table are joined on the column names they share
# with res_all (dplyr's default `by`).
# NOTE(review): depends on the column layout returned by plot_result() —
# confirm that columns 2:10 are the join keys plus the batch_* metrics.
res_all_comb <- left_join(res_all, res_all_batch[2:10])
res_all_comb$log2_runtime <- log2(res_all_comb$runtime)

# How many runs lack a matching donor-batch metric.
table(is.na(res_all_comb$batch_b_saw))

# Every method that is not multiome-guided gets a "_2" suffix (its
# multiome-split variant); then tag each row with its integration type.
multiome_guided_method_list <- c("multivi", "seurat4", "cobolt", "scmomat")
unpaired_method_list <- c("seurat3_2", "rfigr_2", "rbindsc_2", "rliger_2", "glue_2")
res_plot <- res_all_comb %>%
  filter(var %in% c(1000, 3000, 5000)) %>%
  mutate(method_name = ifelse(method %in% multiome_guided_method_list,
                              method, paste0(method, "_2"))) %>%
  # "_2" methods are unpaired methods run on the multiome split; all other
  # methods are multiome-guided.
  mutate(type = ifelse(method_name %in% unpaired_method_list,
                       "Unpaired_multiome_split", "Multiome-guided"))

metric_sel <- c("ARI", "NMI", "ct_aws", "clisi", "b_saw", "kbet", "gconn",
                "batch_b_saw", "batch_kbet", "log2_runtime")

res_plot <- res_plot %>% filter(ct_ref == "ct3wbatch_metric.csv")
table(res_plot$method_name)

# Display name and integration type for each internal method name.
method_names <- c("MultiVI", "Seurat v4", "Cobolt", "scMoMaT",
                  "Seurat v3", "BindSC", "FigR", "LIGER", "GLUE",
                  "Seurat v3", "BindSC", "FigR", "LIGER", "GLUE")
row_order <- c("multivi", "seurat4", "cobolt", "scmomat",
               "seurat3", "rbindsc", "rfigr", "rliger", "glue",
               "seurat3_2", "rbindsc_2", "rfigr_2", "rliger_2", "glue_2")
names(method_names) <- row_order

method_type <- c(rep("Multiome-guided", 4), rep("Unpaired", 5),
                 rep("Unpaired with \n Multiome-split", 5))
names(method_type) <- row_order

# Facet label for each metric column.
title_list <- c("ARI", "NMI", "Cell type ASW", "cLisi", "Batch ASW", "kBET",
                "Graph \n connectivity", "Donor ASW", "Donor kBET", "log2(Runtime)")
names(title_list) <- metric_sel

# Long format: one row per (run, metric). Any selected column that is not in
# metric_sel gets an NA metric_name.
res_plot_2 <- res_plot %>%
  select(ARI:gconn, percent_recovered_50kb:log2_runtime, method_name, var) %>%
  gather(key = "metric_type", value = "value", -method_name, -var) %>%
  mutate(method = method_name) %>%
  mutate(type = factor(method_type[method],
                       levels = c("Unpaired", "Unpaired with \n Multiome-split",
                                  "Multiome-guided"))) %>%
  mutate(method_name = factor(method_names[method], levels = unique(method_names))) %>%
  mutate(metric_name = factor(title_list[metric_type], levels = title_list))



### Summary plot

In [None]:
file_name <- "summary_bmmc_biological_batch_test_rev1.pdf"
folder_path <- "figures/metric_plots/"
# recursive = TRUE also creates the parent "figures/" directory;
# showWarnings = FALSE keeps re-runs quiet when the directory already exists.
dir.create(folder_path, showWarnings = FALSE, recursive = TRUE)

# One column of panels per entry; each entry lists the two metrics (facet
# rows) shown in that column.
subplot_idx <- list(c("ARI", "NMI"),
                    c("Cell type ASW", "cLisi"),
                    c("Batch ASW", "kBET"),
                    c("Donor ASW", "Donor kBET"),
                    c("Graph \n connectivity", "log2(Runtime)"))

# y-axis limits and break counts, one per metric, in flattened subplot_idx order.
ylim_list <- list(c(0.1, 0.7), c(0.3, 0.8),
                  c(0.4, 0.7), c(0.8, 1.0),
                  c(0.5, 1), c(0.1, 0.8),
                  c(0.2, 1), c(0.1, 1),
                  c(0, 1), c(6, 12))
nbreaks <- list(6, 5,
                6, 5,
                5, 6,
                8, 9,
                5, 6)

# Summary values. Unpaired methods are averaged across all var values
# (regardless of the number of multiome cells, here 1000/3000/5000) and drawn
# without an error bar; other methods get mean +/- sd across repeats at each var.
res_plot_3 <- res_plot_2 %>%
  group_by(metric_name, var, method) %>%
  dplyr::mutate(mean_var = mean(value),
                sd_var = sd(value)) %>%
  group_by(metric_name, method) %>%
  dplyr::mutate(mean_method = mean(value),
                sd_method = sd(value)) %>%
  ungroup() %>%
  mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
         sd_f = ifelse(type == "Unpaired", 0, sd_var))

x_interval <- unique(res_plot_3$var)
subplots <- list()
counter <- 1
for (i in seq_along(subplot_idx)) {
  print(i)
  subplots[[i]] <- res_plot_3 %>%
    filter(metric_name %in% subplot_idx[[i]]) %>%
    ggplot(., aes(x = var, y = mean_f, group = method, col = method_name, shape = type)) +
    geom_point(size = 3) +
    geom_line() +
    geom_errorbar(aes(ymin = mean_f - sd_f, ymax = mean_f + sd_f),
                  position = position_dodge(0.05), alpha = 0.2) +
    scale_shape_manual(values = c(15, 16, 17)) +
    scale_x_continuous(breaks = x_interval) +
    facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
    scale_color_manual(values = hex[method_names]) +
    theme_light() +
    theme(strip.text.x = element_blank(),
          strip.text.y = element_text(size = 35, angle = 90, colour = "black"),
          axis.text.x = element_text(size = 30, angle = 90),
          axis.text.y = element_text(size = 30),
          axis.title.x = element_blank(), axis.title.y = element_blank(),
          strip.background = element_blank(), strip.placement = "outside",
          panel.spacing.y = unit(2, "lines"), panel.spacing.x = unit(1, "lines"),
          panel.border = element_blank(), axis.line = element_line()) +
    # Per-row y scales: the two metrics of this column use consecutive entries.
    ggh4x::facetted_pos_scales(
      y = list(
        scale_y_continuous(limits = ylim_list[[counter]], n.breaks = nbreaks[[counter]]),
        scale_y_continuous(limits = ylim_list[[counter + 1]], n.breaks = nbreaks[[counter + 1]])
      )
    )
  counter <- counter + 2
}
# Drop per-panel legends; one shared legend is added next to the combined figure.
subplots_no_bdg <- lapply(subplots, function(x) x + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)


# Extract the legend grob from a ggplot so it can be drawn once beside a row
# of legend-less panels.
#
# The legend is located through the gtable layout names instead of the grobs'
# own `name` field. ggplot2 < 3.5 names the slot "guide-box"; ggplot2 >= 3.5
# adds one slot per position ("guide-box-right", "guide-box-bottom", ...) and
# fills unused positions with zeroGrob placeholders, which are skipped here.
#
# @param plot A ggplot object with a visible legend.
# @return The first non-empty legend grob.
# Errors with an explicit message when the plot has no legend.
get_only_legend <- function(plot) {
  plot_table <- ggplot_gtable(ggplot_build(plot))
  is_legend <- startsWith(plot_table$layout$name, "guide-box")
  candidates <- Filter(function(g) !inherits(g, "zeroGrob"),
                       plot_table$grobs[is_legend])
  if (length(candidates) == 0) {
    stop("get_only_legend(): the plot has no legend to extract", call. = FALSE)
  }
  candidates[[1]]
}
                            
# Put the legend taken from the first panel to the right of the legend-free
# panel row (10:1 width ratio) and save the figure as a 30 x 10 inch PDF.
legend <- get_only_legend(subplots[[1]])
pdf(file = file.path(folder_path, file_name), width = 30, height = 10)
grid.arrange(grobs = list(pcomb, legend), ncol = 2, widths = c(10, 1))
dev.off()


## Complex test #1

### Load result

In [None]:
# Read the cell-type metrics (ct3_metric.csv) and the sample-batch metrics
# (ct3_metric_sample_batch.csv) for every run, then combine them.
dir_path <- "dataset/bmmc/bmmc_complex1_test/"
folder_path <- "figures/metric_plots/"
# recursive = TRUE also creates the parent "figures/" directory;
# showWarnings = FALSE keeps re-runs quiet when the directory already exists.
dir.create(folder_path, showWarnings = FALSE, recursive = TRUE)

res_all <- plot_result(dir_path, res_folder_filter = c("results"),
                       runtime_exists = TRUE, prediction_exists = TRUE,
                       metric_file = "ct3_metric.csv")

res_all_batch <- plot_result(dir_path, res_folder_filter = c("results"),
                             runtime_exists = TRUE, prediction_exists = TRUE,
                             metric_file = "ct3_metric_sample_batch.csv")

# Columns 2:10 of the batch table are joined on the column names they share
# with res_all (dplyr's default `by`).
# NOTE(review): depends on the column layout returned by plot_result() —
# confirm that columns 2:10 are the join keys plus the batch_* metrics.
res_all_comb <- left_join(res_all, res_all_batch[2:10])
res_all_comb$log2_runtime <- log2(res_all_comb$runtime)

# How many runs lack a matching sample-batch metric.
table(is.na(res_all_comb$batch_b_saw))

# Every method that is not multiome-guided gets a "_2" suffix (its
# multiome-split variant); then tag each row with its integration type.
multiome_guided_method_list <- c("multivi", "seurat4", "cobolt", "seurat4int", "scmomat")
unpaired_method_list <- c("seurat3_2", "rfigr_2", "rbindsc_2", "rliger_2", "glue_2")
res_plot <- res_all_comb %>%
  mutate(method_name = ifelse(method %in% multiome_guided_method_list,
                              method, paste0(method, "_2"))) %>%
  # "_2" methods are unpaired methods run on the multiome split; all other
  # methods are multiome-guided.
  mutate(type = ifelse(method_name %in% unpaired_method_list,
                       "Unpaired_multiome_splitted", "Multiome-guided"))

metric_sel <- c("ARI", "NMI", "ct_aws", "clisi", "b_saw", "kbet", "gconn",
                "batch_b_saw", "batch_kbet", "log2_runtime")

res_plot <- res_plot %>% filter(ct_ref == "ct3_metric.csv")
table(res_plot$method_name)

# Display name and integration type for each internal method name.
method_names <- c("MultiVI", "Seurat v4", "Cobolt", "scMoMaT", "Seurat v4 integrate",
                  "Seurat v3", "bindSC", "FigR", "LIGER", "GLUE",
                  "Seurat v3", "bindSC", "FigR", "LIGER", "GLUE")
row_order <- c("multivi", "seurat4", "cobolt", "scmomat", "seurat4int",
               "seurat3", "rbindsc", "rfigr", "rliger", "glue",
               "seurat3_2", "rbindsc_2", "rfigr_2", "rliger_2", "glue_2")
names(method_names) <- row_order

method_type <- c(rep("Multiome-guided", 5), rep("Unpaired", 5),
                 rep("Unpaired with \n Multiome-splitted", 5))
names(method_type) <- row_order

# Facet label for each metric column.
title_list <- c("ARI", "NMI", "Cell type ASW", "cLisi", "Batch ASW", "kBET",
                "Graph \n connectivity", "Donor ASW", "Donor kBET", "log2(Runtime)")
names(title_list) <- metric_sel

# Long format: one row per (run, metric). Any selected column that is not in
# metric_sel gets an NA metric_name.
res_plot_2 <- res_plot %>%
  select(ARI:gconn, percent_recovered_50kb:log2_runtime, method_name, var) %>%
  gather(key = "metric_type", value = "value", -method_name, -var) %>%
  mutate(method = method_name) %>%
  mutate(type = factor(method_type[method],
                       levels = c("Unpaired", "Unpaired with \n Multiome-splitted",
                                  "Multiome-guided"))) %>%
  mutate(method_name = factor(method_names[method], levels = unique(method_names))) %>%
  mutate(metric_name = factor(title_list[metric_type], levels = title_list))


### Summary plot

In [None]:
file_name <- "summary_bmmc_complex1_test_rev1.pdf"
folder_path <- "figures/metric_plots/"
# recursive = TRUE also creates the parent "figures/" directory;
# showWarnings = FALSE keeps re-runs quiet when the directory already exists.
dir.create(folder_path, showWarnings = FALSE, recursive = TRUE)

# One column of panels per entry; each entry lists the two metrics (facet
# rows) shown in that column.
subplot_idx <- list(c("ARI", "NMI"),
                    c("Cell type ASW", "cLisi"),
                    c("Batch ASW", "kBET"),
                    c("Donor ASW", "Donor kBET"),
                    c("Graph \n connectivity", "log2(Runtime)"))

# y-axis limits and break counts, one per metric, in flattened subplot_idx order.
ylim_list <- list(c(0.2, 0.65), c(0.4, 0.8),
                  c(0.4, 0.7), c(0.7, 1.0),
                  c(0.5, 1), c(0, 0.8),
                  c(0.4, 1), c(0, 0.9),
                  c(0.5, 1), c(6, 12))
nbreaks <- list(6, 6,
                6, 6,
                5, 8,
                8, 9,
                5, 6)

# Summary values (Unpaired: mean over all var values, sd 0; others: mean/sd
# per var). The boxplots below draw the raw `value` column; mean_f / sd_f are
# kept for parity with the line-plot summaries.
res_plot_3 <- res_plot_2 %>%
  group_by(metric_name, var, method) %>%
  dplyr::mutate(mean_var = mean(value),
                sd_var = sd(value)) %>%
  group_by(metric_name, method) %>%
  dplyr::mutate(mean_method = mean(value),
                sd_method = sd(value)) %>%
  ungroup() %>%
  mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
         sd_f = ifelse(type == "Unpaired", 0, sd_var))

# Fix the order of the method groups in the boxplots.
res_plot_3$method <- factor(res_plot_3$method,
                            levels = c("seurat3_2", "rbindsc_2", "rfigr_2", "rliger_2", "glue_2",
                                       "cobolt", "multivi", "seurat4", "seurat4int", "scmomat"))

x_interval <- unique(res_plot_3$var)
subplots <- list()
counter <- 1
for (i in seq_along(subplot_idx)) {
  print(i)
  # NOTE(review): with a continuous x and group = method, ggplot draws one box
  # per method pooled over all var values present — confirm this dataset has
  # a single var value or that pooling is intended.
  subplots[[i]] <- res_plot_3 %>%
    filter(metric_name %in% subplot_idx[[i]]) %>%
    ggplot(., aes(x = var, y = value, group = method, col = method_name, shape = type)) +
    geom_boxplot() +
    scale_x_continuous(breaks = x_interval) +
    facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
    scale_color_manual(values = hex[method_names]) +
    theme_light() +
    theme(strip.text.x = element_blank(),
          strip.text.y = element_text(size = 35, angle = 90, colour = "black"),
          axis.text.x = element_text(size = 30, angle = 90),
          axis.text.y = element_text(size = 30),
          axis.title.x = element_blank(), axis.title.y = element_blank(),
          strip.background = element_blank(), strip.placement = "outside",
          panel.spacing.y = unit(2, "lines"), panel.spacing.x = unit(1, "lines"),
          panel.border = element_blank(), axis.line = element_line()) +
    # Per-row y scales: the two metrics of this column use consecutive entries.
    ggh4x::facetted_pos_scales(
      y = list(
        scale_y_continuous(limits = ylim_list[[counter]], n.breaks = nbreaks[[counter]]),
        scale_y_continuous(limits = ylim_list[[counter + 1]], n.breaks = nbreaks[[counter + 1]])
      )
    )
  counter <- counter + 2
}
# Drop per-panel legends; one shared legend is added next to the combined figure.
subplots_no_bdg <- lapply(subplots, function(x) x + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)


# Extract the legend grob from a ggplot so it can be drawn once beside a row
# of legend-less panels.
#
# The legend is located through the gtable layout names instead of the grobs'
# own `name` field. ggplot2 < 3.5 names the slot "guide-box"; ggplot2 >= 3.5
# adds one slot per position ("guide-box-right", "guide-box-bottom", ...) and
# fills unused positions with zeroGrob placeholders, which are skipped here.
#
# @param plot A ggplot object with a visible legend.
# @return The first non-empty legend grob.
# Errors with an explicit message when the plot has no legend.
get_only_legend <- function(plot) {
  plot_table <- ggplot_gtable(ggplot_build(plot))
  is_legend <- startsWith(plot_table$layout$name, "guide-box")
  candidates <- Filter(function(g) !inherits(g, "zeroGrob"),
                       plot_table$grobs[is_legend])
  if (length(candidates) == 0) {
    stop("get_only_legend(): the plot has no legend to extract", call. = FALSE)
  }
  candidates[[1]]
}
                            
# Put the legend taken from the first panel to the right of the legend-free
# panel row (10:1 width ratio) and save the figure as a 30 x 10 inch PDF.
legend <- get_only_legend(subplots[[1]])
pdf(file = file.path(folder_path, file_name), width = 30, height = 10)
grid.arrange(grobs = list(pcomb, legend), ncol = 2, widths = c(10, 1))
dev.off()



## Complex test #2

### Load result

In [None]:
# Read the cell-type metrics (ct3_metric.csv) and the sample-batch metrics
# (ct3_metric_sample_batch.csv) for every run, then combine them.
dir_path <- "dataset/bmmc/bmmc_complex2_test/"
folder_path <- "figures/metric_plots/"
# recursive = TRUE also creates the parent "figures/" directory;
# showWarnings = FALSE keeps re-runs quiet when the directory already exists.
dir.create(folder_path, showWarnings = FALSE, recursive = TRUE)

res_all <- plot_result(dir_path, res_folder_filter = c("results"),
                       runtime_exists = TRUE, prediction_exists = TRUE,
                       metric_file = "ct3_metric.csv")

res_all_batch <- plot_result(dir_path, res_folder_filter = c("results"),
                             runtime_exists = TRUE, prediction_exists = TRUE,
                             metric_file = "ct3_metric_sample_batch.csv")

# Columns 2:10 of the batch table are joined on the column names they share
# with res_all (dplyr's default `by`).
# NOTE(review): depends on the column layout returned by plot_result() —
# confirm that columns 2:10 are the join keys plus the batch_* metrics.
res_all_comb <- left_join(res_all, res_all_batch[2:10])
res_all_comb$log2_runtime <- log2(res_all_comb$runtime)

# How many runs lack a matching sample-batch metric.
table(is.na(res_all_comb$batch_b_saw))

# Every method that is not multiome-guided gets a "_2" suffix (its
# multiome-split variant); then tag each row with its integration type.
multiome_guided_method_list <- c("multivi", "seurat4", "cobolt", "scmomat", "seurat4int")
unpaired_method_list <- c("seurat3_2", "rfigr_2", "rbindsc_2", "rliger_2", "glue_2")
res_plot <- res_all_comb %>%
  mutate(method_name = ifelse(method %in% multiome_guided_method_list,
                              method, paste0(method, "_2"))) %>%
  # "_2" methods are unpaired methods run on the multiome split; all other
  # methods are multiome-guided.
  mutate(type = ifelse(method_name %in% unpaired_method_list,
                       "Unpaired_multiome_splitted", "Multiome-guided"))

metric_sel <- c("ARI", "NMI", "ct_aws", "clisi", "b_saw", "kbet", "gconn",
                "batch_b_saw", "batch_kbet", "log2_runtime")

res_plot <- res_plot %>% filter(ct_ref == "ct3_metric.csv")
table(res_plot$method_name)

# Display name and integration type for each internal method name.
method_names <- c("MultiVI", "Seurat v4", "Cobolt", "scMoMaT", "Seurat v4 integrate",
                  "Seurat v3", "bindSC", "FigR", "LIGER", "GLUE",
                  "Seurat v3", "bindSC", "FigR", "LIGER", "GLUE")
row_order <- c("multivi", "seurat4", "cobolt", "scmomat", "seurat4int",
               "seurat3", "rbindsc", "rfigr", "rliger", "glue",
               "seurat3_2", "rbindsc_2", "rfigr_2", "rliger_2", "glue_2")
names(method_names) <- row_order

method_type <- c(rep("Multiome-guided", 5), rep("Unpaired", 5),
                 rep("Unpaired with \n Multiome-splitted", 5))
names(method_type) <- row_order

# Facet label for each metric column.
title_list <- c("ARI", "NMI", "Cell type ASW", "cLisi", "Batch ASW", "kBET",
                "Graph \n connectivity", "Donor ASW", "Donor kBET", "log2(Runtime)")
names(title_list) <- metric_sel

# Long format: one row per (run, metric). Any selected column that is not in
# metric_sel gets an NA metric_name.
res_plot_2 <- res_plot %>%
  select(ARI:gconn, percent_recovered_50kb:log2_runtime, method_name, var) %>%
  gather(key = "metric_type", value = "value", -method_name, -var) %>%
  mutate(method = method_name) %>%
  mutate(type = factor(method_type[method],
                       levels = c("Unpaired", "Unpaired with \n Multiome-splitted",
                                  "Multiome-guided"))) %>%
  mutate(method_name = factor(method_names[method], levels = unique(method_names))) %>%
  mutate(metric_name = factor(title_list[metric_type], levels = title_list))


### Summary plot

In [None]:
# Summary boxplots for the second complex test (metric layout, axis limits,
# per-method summaries, and the combined panel row).
file_name <- "summary_bmmc_complex2_test.pdf"
folder_path <- "figures/metric_plots/"
# recursive = TRUE also creates the parent "figures/" directory;
# showWarnings = FALSE keeps re-runs quiet when the directory already exists.
dir.create(folder_path, showWarnings = FALSE, recursive = TRUE)

# One column of panels per entry; each entry lists the two metrics (facet
# rows) shown in that column.
subplot_idx <- list(c("ARI", "NMI"),
                    c("Cell type ASW", "cLisi"),
                    c("Batch ASW", "kBET"),
                    c("Donor ASW", "Donor kBET"),
                    c("Graph \n connectivity", "log2(Runtime)"))

# y-axis limits and break counts, one per metric, in flattened subplot_idx order.
ylim_list <- list(c(0.2, 0.7), c(0.45, 0.8),
                  c(0.4, 0.7), c(0.85, 1.0),
                  c(0.5, 1), c(0, 0.8),
                  c(0.4, 1), c(0, 0.9),
                  c(0.5, 1), c(6, 14))
nbreaks <- list(6, 6,
                6, 8,
                5, 8,
                8, 9,
                5, 8)

# Summary values (Unpaired: mean over all var values, sd 0; others: mean/sd
# per var). The boxplots below draw the raw `value` column; mean_f / sd_f are
# kept for parity with the line-plot summaries.
res_plot_3 <- res_plot_2 %>%
  group_by(metric_name, var, method) %>%
  dplyr::mutate(mean_var = mean(value),
                sd_var = sd(value)) %>%
  group_by(metric_name, method) %>%
  dplyr::mutate(mean_method = mean(value),
                sd_method = sd(value)) %>%
  ungroup() %>%
  mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
         sd_f = ifelse(type == "Unpaired", 0, sd_var))

# Fix the order of the method groups in the boxplots.
res_plot_3$method <- factor(res_plot_3$method,
                            levels = c("seurat3_2", "rbindsc_2", "rfigr_2", "rliger_2", "glue_2",
                                       "cobolt", "multivi", "seurat4", "seurat4int", "scmomat"))

x_interval <- unique(res_plot_3$var)
subplots <- list()
counter <- 1
for (i in seq_along(subplot_idx)) {
  print(i)
  # NOTE(review): with a continuous x and group = method, ggplot draws one box
  # per method pooled over all var values present — confirm this dataset has
  # a single var value or that pooling is intended.
  subplots[[i]] <- res_plot_3 %>%
    filter(metric_name %in% subplot_idx[[i]]) %>%
    ggplot(., aes(x = var, y = value, group = method, col = method_name, shape = type)) +
    geom_boxplot() +
    scale_x_continuous(breaks = x_interval) +
    facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
    scale_color_manual(values = hex[method_names]) +
    theme_light() +
    theme(strip.text.x = element_blank(),
          strip.text.y = element_text(size = 35, angle = 90, colour = "black"),
          axis.text.x = element_text(size = 30, angle = 90),
          axis.text.y = element_text(size = 30),
          axis.title.x = element_blank(), axis.title.y = element_blank(),
          strip.background = element_blank(), strip.placement = "outside",
          panel.spacing.y = unit(2, "lines"), panel.spacing.x = unit(1, "lines"),
          panel.border = element_blank(), axis.line = element_line()) +
    # Per-row y scales: the two metrics of this column use consecutive entries.
    ggh4x::facetted_pos_scales(
      y = list(
        scale_y_continuous(limits = ylim_list[[counter]], n.breaks = nbreaks[[counter]]),
        scale_y_continuous(limits = ylim_list[[counter + 1]], n.breaks = nbreaks[[counter + 1]])
      )
    )
  counter <- counter + 2
}
# Drop per-panel legends; one shared legend is added next to the combined figure.
subplots_no_bdg <- lapply(subplots, function(x) x + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)


# Extract the legend grob from a ggplot so it can be drawn once beside a row
# of legend-less panels.
#
# The legend is located through the gtable layout names instead of the grobs'
# own `name` field. ggplot2 < 3.5 names the slot "guide-box"; ggplot2 >= 3.5
# adds one slot per position ("guide-box-right", "guide-box-bottom", ...) and
# fills unused positions with zeroGrob placeholders, which are skipped here.
#
# @param plot A ggplot object with a visible legend.
# @return The first non-empty legend grob.
# Errors with an explicit message when the plot has no legend.
get_only_legend <- function(plot) {
  plot_table <- ggplot_gtable(ggplot_build(plot))
  is_legend <- startsWith(plot_table$layout$name, "guide-box")
  candidates <- Filter(function(g) !inherits(g, "zeroGrob"),
                       plot_table$grobs[is_legend])
  if (length(candidates) == 0) {
    stop("get_only_legend(): the plot has no legend to extract", call. = FALSE)
  }
  candidates[[1]]
}
                            
# Put the legend taken from the first panel to the right of the legend-free
# panel row (10:1 width ratio) and save the figure as a 30 x 10 inch PDF.
legend <- get_only_legend(subplots[[1]])
pdf(file = file.path(folder_path, file_name), width = 30, height = 10)
grid.arrange(grobs = list(pcomb, legend), ncol = 2, widths = c(10, 1))
dev.off()
                              

# Scenario 4

## PBMC

### Load result

In [None]:
# Load the PBMC missing-cell-type results and reshape them for plotting.
dir_path <- "dataset/pbmc/single_modality_fixed_missing_ct/"
folder_path <- "figures/metric_plots/"
# recursive = TRUE also creates the parent "figures/" directory;
# showWarnings = FALSE keeps re-runs quiet when the directory already exists.
dir.create(folder_path, showWarnings = FALSE, recursive = TRUE)


res_all <- plot_rare_cell_result(dir_path,
                                 runtime_exists = TRUE, prediction_exists = FALSE)
res_all$log2_runtime <- log2(res_all$runtime)
res_all

# When TRUE, the plotting cells overlay reference lines from `truth_df`.
plot_ref_accuracy <- FALSE

# Runs in "results_single_same_cell_number" are the multiome-split variants
# and get a "_2" suffix; the remaining runs are tagged Unpaired or
# Multiome-guided by method.
unpaired_method_list <- c("seurat3", "rfigr", "rbindsc", "rliger", "glue")
res_plot <- res_all %>%
  filter(var %in% c(1000, 3000, 6000)) %>%
  mutate(method_name = ifelse(out == "results_single_same_cell_number",
                              paste0(method, "_2"), method)) %>%
  mutate(type = ifelse(out != "results_single_same_cell_number",
                       ifelse(method %in% unpaired_method_list, "Unpaired", "Multiome-guided"),
                       "Unpaired_multiome_splitted"),
         # fixed = TRUE: strip the literal suffix ("." is not a regex wildcard)
         ct = gsub("_metric.csv", "", ct_ref, fixed = TRUE),
         metric = paste0(orig, "_", ct))

metric_cond <- c("noMiss", "rnaMissNK", "atacMissNK")
cts <- c("ct3_NK")
cts_name <- c("NK")

# Keys of the form "<condition>_<cell type>", e.g. "noMiss_NK".
metric_sel <- paste0(rep(metric_cond, each = length(cts_name)), "_",
                     rep(cts_name, length(metric_cond)))

table(res_plot$method_name)

# Display name and integration type for each internal method name.
method_names <- c("MultiVI", "Seurat v4", "Cobolt", "scMoMaT",
                  "Seurat v3", "bindSC", "FigR", "LIGER", "GLUE",
                  "Seurat v3", "bindSC", "FigR", "LIGER", "GLUE")
row_order <- c("multivi", "seurat4", "cobolt", "scmomat",
               "seurat3", "rbindsc", "rfigr", "rliger", "glue",
               "seurat3_2", "rbindsc_2", "rfigr_2", "rliger_2", "glue_2")
names(method_names) <- row_order

method_type <- c(rep("Multiome-guided", 4), rep("Unpaired", 5),
                 rep("Unpaired with \n Multiome-splitted", 5))
names(method_type) <- row_order

# Facet labels are the condition/cell-type keys themselves.
title_list <- metric_sel
names(title_list) <- metric_sel


# Split the metric key into its parts, derive per-run target counts and
# accuracy, then reshape to one row per (run, key).
# NOTE(review): the four-part split of `metric` assumes `orig` contributes
# exactly two "_"-separated fields — confirm against plot_rare_cell_result().
res_plot_2 <- res_plot %>%
  separate(metric, c("cond", "fill", "ct", "dataset"), sep = "_") %>%
  mutate(nTarget = TP + FN) %>%
  mutate(perAcc = TP / nTarget) %>%
  select(F1, nTarget, perAcc, log2_runtime, method_name, var, cond, ct, dataset) %>%
  gather(key = "key", value = "value", -method_name, -var, -cond, -ct, -dataset) %>%
  mutate(method = method_name, metric_type = paste0(cond, "_", ct)) %>%
  mutate(type = factor(method_type[method],
                       levels = c("Unpaired", "Unpaired with \n Multiome-splitted",
                                  "Multiome-guided"))) %>%
  mutate(method_name = factor(method_names[method], levels = unique(method_names))) %>%
  mutate(metric_name = factor(title_list[metric_type], levels = title_list))



### Plots

#### RNA

In [None]:
file_name <- "pbmc_rare_cell_1_f1_try2_rna.pdf"

# One panel column per condition key.
subplot_idx <- list(c("noMiss_NK"),
                    c("atacMissNK_NK"))

# F1 on the scRNA side. Unpaired methods are averaged across all var values
# (regardless of the number of multiome cells, here 1000/3000/6000) with sd 0;
# other methods get mean +/- sd across repeats at each var.
res_plot_3 <- res_plot_2 %>%
  filter(dataset == "scRNA") %>%
  filter(key == "F1") %>%
  group_by(metric_name, var, method) %>%
  dplyr::mutate(mean_var = mean(value),
                sd_var = sd(value)) %>%
  group_by(metric_name, method) %>%
  dplyr::mutate(mean_method = mean(value),
                sd_method = sd(value)) %>%
  ungroup() %>%
  mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
         sd_f = ifelse(type == "Unpaired", 0, sd_var))

x_interval <- c(1000, 3000, 6000)
# One y-range and break count per panel (all F1, so the same for every panel).
ylim_list <- rep(list(c(0, 1)), length(unlist(subplot_idx)))
nbreaks <- rep(list(8), length(unlist(subplot_idx)))
counter <- 1

subplots <- list()
for (i in seq_along(subplot_idx)) {
  print(i)
  subplots[[i]] <- res_plot_3 %>%
    filter(metric_name %in% subplot_idx[[i]]) %>%
    ggplot(., aes(x = var, y = mean_f, group = method, col = method_name, shape = type)) +
    geom_point(size = 3) +
    geom_line() +
    geom_errorbar(aes(ymin = mean_f - sd_f, ymax = mean_f + sd_f),
                  position = position_dodge(0.05), alpha = 0.2) +
    scale_shape_manual(values = c(15, 16, 17)) +
    scale_x_continuous(breaks = x_interval) +
    facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
    scale_color_manual(values = hex[method_names]) +
    theme_light() +
    theme(strip.text.x = element_text(size = 30, angle = 90, colour = "black"),
          strip.text.y = element_text(size = 30, angle = 90, colour = "black"),
          axis.text.x = element_text(size = 25, angle = 90),
          axis.text.y = element_text(size = 25),
          axis.title.x = element_blank(), axis.title.y = element_blank(),
          strip.background = element_blank(), strip.placement = "outside",
          panel.spacing.y = unit(5, "lines"), panel.spacing.x = unit(2, "lines"),
          panel.border = element_blank(), axis.line = element_line()) +
    ggh4x::facetted_pos_scales(
      y = list(
        scale_y_continuous(limits = ylim_list[[counter]], n.breaks = nbreaks[[counter]])
      )
    )
  counter <- counter + 1
  # NOTE(review): subplot_idx holds condition keys such as "noMiss_NK", never
  # "F1", so this overlay cannot trigger as written, and truth_df is not
  # defined in this part of the file. Confirm the intended test before enabling.
  if (plot_ref_accuracy) {
    if ("F1" %in% subplot_idx[[i]]) {
      subplots[[i]] <- subplots[[i]] +
        geom_hline(data = truth_df, aes(yintercept = Z), linetype = 2, color = "red")
    }
  }
}
# Drop per-panel legends; one shared legend is added next to the combined figure.
subplots_no_bdg <- lapply(subplots, function(x) x + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)


# Extract the legend grob from a ggplot so it can be drawn once beside a row
# of legend-less panels.
#
# The legend is located through the gtable layout names instead of the grobs'
# own `name` field. ggplot2 < 3.5 names the slot "guide-box"; ggplot2 >= 3.5
# adds one slot per position ("guide-box-right", "guide-box-bottom", ...) and
# fills unused positions with zeroGrob placeholders, which are skipped here.
#
# @param plot A ggplot object with a visible legend.
# @return The first non-empty legend grob.
# Errors with an explicit message when the plot has no legend.
get_only_legend <- function(plot) {
  plot_table <- ggplot_gtable(ggplot_build(plot))
  is_legend <- startsWith(plot_table$layout$name, "guide-box")
  candidates <- Filter(function(g) !inherits(g, "zeroGrob"),
                       plot_table$grobs[is_legend])
  if (length(candidates) == 0) {
    stop("get_only_legend(): the plot has no legend to extract", call. = FALSE)
  }
  candidates[[1]]
}
                            
# Put the legend taken from the first panel to the right of the legend-free
# panel row (10:1 width ratio) and save the figure as a 15 x 10 inch PDF.
legend <- get_only_legend(subplots[[1]])
pdf(file = file.path(folder_path, file_name), width = 15, height = 10)
grid.arrange(grobs = list(pcomb, legend), ncol = 2, widths = c(10, 1))
dev.off()

#### ATAC

In [None]:
file_name <- "pbmc_rare_cell_1_f1_try2_atac.pdf"
library(gridExtra)

# One subplot per entry; each entry lists the metric_name value(s) it shows.
subplot_idx <- list(c("noMiss_NK"),
                    c("rnaMissNK_NK"))

# Summarize F1 per method: Unpaired methods are averaged across all var values
# (regardless of the 1000/3000/6000 multiome cells) with the error bar set to
# 0; all other methods are averaged across repeats within each var.
res_plot_3 <- res_plot_2 %>%
    filter(dataset == "snATAC") %>%
    filter(key == "F1") %>%
    # mean/sd after grouping by metric, method, and var
    group_by(metric_name, var, method) %>%
    dplyr::mutate(mean_var = mean(value),
                  sd_var = sd(value)) %>%
    # mean/sd after grouping by metric and method
    group_by(metric_name, method) %>%
    dplyr::mutate(mean_method = mean(value),
                  sd_method = sd(value)) %>%
    ungroup() %>%
    mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
           sd_f = ifelse(type == "Unpaired", 0, sd_var))

x_interval <- c(1000, 3000, 6000)
# y-axis limits and break counts, indexed by subplot position
ylim_list <- rep(list(c(0, 1)), length(unlist(subplot_idx)))
nbreaks <- rep(list(8), length(unlist(subplot_idx)))

subplots <- vector("list", length(subplot_idx))
for (i in seq_along(subplot_idx)) {
    print(i)
    subplots[[i]] <- res_plot_3 %>%
        filter(metric_name %in% subplot_idx[[i]]) %>%
        ggplot(aes(x = var, y = mean_f, group = method, col = method_name, shape = type)) +
        geom_point(size = 3) +
        geom_line() +
        geom_errorbar(aes(ymin = mean_f - sd_f, ymax = mean_f + sd_f),
                      position = position_dodge(0.05), alpha = 0.2) +
        scale_shape_manual(values = c(15, 16, 17)) +
        scale_x_continuous(breaks = x_interval) +
        facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
        scale_color_manual(values = hex[method_names]) +
        theme_light() +
        theme(strip.text.x = element_text(size = 30, angle = 90, colour = "black"),
              strip.text.y = element_text(size = 30, angle = 90, colour = "black"),
              axis.text.x = element_text(size = 25, angle = 90),
              axis.text.y = element_text(size = 25),
              axis.title.x = element_blank(), axis.title.y = element_blank(),
              strip.background = element_blank(), strip.placement = "outside",
              panel.spacing.y = unit(5, "lines"), panel.spacing.x = unit(2, "lines"),
              panel.border = element_blank(), axis.line = element_line()) +
        ggh4x::facetted_pos_scales(
            y = list(scale_y_continuous(limits = ylim_list[[i]], n.breaks = nbreaks[[i]]))
        )
    # NOTE(review): subplot_idx holds metric names such as "noMiss_NK", never
    # "F1", so this reference line is not drawn as written; truth_df is also
    # not defined in this section.
    if (plot_ref_accuracy && "F1" %in% subplot_idx[[i]]) {
        subplots[[i]] <- subplots[[i]] +
            geom_hline(data = truth_df, aes(yintercept = Z), linetype = 2, color = "red")
    }
}
# Hide per-subplot legends; one shared legend is added when the PDF is written.
subplots_no_bdg <- lapply(subplots, function(x) x + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)


# Extract the legend grob from a ggplot so it can be drawn once next to a row
# of legend-free subplots.
#
# ggplot2 < 3.5.0 names the legend grob "guide-box"; ggplot2 >= 3.5.0 uses
# position-specific names ("guide-box-right", "guide-box-bottom", ...) and
# fills unused positions with empty zeroGrobs. Matching the "guide-box" prefix
# and preferring a non-empty grob works with both.
#
# plot: a ggplot object.
# Returns the legend grob; stops with an error if the plot has no legend.
get_only_legend <- function(plot) {
  plot_table <- ggplot_gtable(ggplot_build(plot))
  grob_names <- vapply(plot_table$grobs,
                       function(x) if (is.null(x$name)) "" else x$name,
                       character(1))
  legend_idx <- which(startsWith(grob_names, "guide-box"))
  is_empty <- vapply(plot_table$grobs[legend_idx],
                     function(x) inherits(x, "zeroGrob"), logical(1))
  if (any(!is_empty)) {
    legend_idx <- legend_idx[!is_empty]
  }
  if (length(legend_idx) == 0) {
    stop("get_only_legend(): plot has no legend", call. = FALSE)
  }
  plot_table$grobs[[legend_idx[1]]]
}
                            
# Extract the shared legend from the first subplot and save the row of
# legend-free subplots next to it.
legend <- get_only_legend(subplots[[1]])

pdf(file.path(folder_path, file_name), width = 15, height = 10)
# finally = dev.off(): close the PDF device even if drawing fails, so later
# plots are not written into this file.
tryCatch(grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1)),
         finally = dev.off())

## SHARE-seq

### Load result

In [None]:
# Relative path, consistent with the PBMC sections (the leading "/" pointed at
# the filesystem root).
dir_path <- "dataset/mouse_skin/single_modality_fixed_missing_ct/"
folder_path <- "figures/metric_plots/"
# recursive: also create "figures/" if missing; no warning if it already exists
dir.create(folder_path, recursive = TRUE, showWarnings = FALSE)

res_all <- plot_rare_cell_result(dir_path,
                                 runtime_exists = TRUE, prediction_exists = FALSE)
plot_ref_accuracy <- FALSE

res_all <- res_all %>% filter(out %in% c("results_single_mod", "results_single_same_cell_number"))
res_all$log2_runtime <- log2(res_all$runtime)
res_all

# add _2 to method name and add type column
unpaired_method_list <- c("seurat3", "rfigr", "rbindsc", "rliger", "glue")
res_plot <- res_all %>%
    filter(var %in% c(1000, 3000, 6000)) %>%
    # rows from results_single_same_cell_number get a "_2" suffix; method_type
    # below maps these to "Unpaired with Multiome-splitted"
    mutate(method_name = ifelse(out == "results_single_same_cell_number", paste0(method, "_2"), method)) %>%
    # assign method type
    mutate(type = ifelse(out != "results_single_same_cell_number",
                         ifelse(method %in% unpaired_method_list, "Unpaired", "Multiome-guided"),
                         "Unpaired_multiome_splitted"),
           ct = gsub("_metric.csv", "", ct_ref),
           metric = paste0(orig, "_", ct))

metric_cond <- c("noMiss", "rnaMissHS", "atacMissHS", "rnaMissEndo", "atacMissEndo",
                 "rnaMissTwo", "atacMissTwo", "eachMissOne", "eachMissOneAlt")
cts <- c("10k_HS", "10k_Endo")
cts_name <- c("HS", "Endo")

# every condition x cell type combination, e.g. "noMiss_HS", "noMiss_Endo", ...
metric_sel <- paste0(rep(metric_cond, each = length(cts)), "_", rep(cts_name, length(metric_cond)))

# display name per internal method id; "_2" ids are the multiome-splitted runs
method_names <- c("MultiVI", "Seurat v4", "Cobolt", "scMoMaT", "Seurat v3", "bindSC", "FigR", "LIGER", "GLUE",
                  "Seurat v3", "bindSC", "FigR", "LIGER", "GLUE")
row_order <- c("multivi", "seurat4", "cobolt", "scmomat", "seurat3", "rbindsc", "rfigr", "rliger", "glue",
               "seurat3_2", "rbindsc_2", "rfigr_2", "rliger_2", "glue_2")
names(method_names) <- row_order

method_type <- c(rep("Multiome-guided", 4), rep("Unpaired", 5), rep("Unpaired with \n Multiome-splitted", 5))
names(method_type) <- row_order

title_list <- paste0(rep(metric_cond, each = length(cts_name)), "_", rep(cts_name, length(metric_cond)))
names(title_list) <- metric_sel

# Long format: one row per (run, metric) with the metric name in `key` and its
# value in `value`, plus display-ready method, type and metric factors.
res_plot_2 <- res_plot %>%
    separate(metric, c("cond", "fill", "ct", "dataset"), sep = "_") %>%
    mutate(nTarget = TP + FN) %>%
    mutate(perAcc = TP / nTarget) %>%
    select(F1, nTarget, perAcc, log2_runtime, method_name, var, cond, ct, dataset) %>%
    pivot_longer(cols = c(F1, nTarget, perAcc, log2_runtime),
                 names_to = "key", values_to = "value") %>%
    mutate(method = method_name, metric_type = paste0(cond, "_", ct)) %>%
    mutate(type = factor(method_type[method], levels = c("Unpaired", "Unpaired with \n Multiome-splitted", "Multiome-guided"))) %>%
    mutate(method_name = factor(method_names[method], levels = unique(method_names))) %>%
    mutate(metric_name = factor(title_list[metric_type], levels = title_list))


res_plot_2


### Plots

#### scRNA F1 HS

In [None]:
file_name <- "shareseq_mouse_skin_rare_cell_1_f1_try2_rna_HS.pdf"
# One subplot per entry; each entry lists the metric_name value(s) it shows.
subplot_idx <- list(c("noMiss_HS"),
                    c("atacMissHS_HS"),
                    c("atacMissTwo_HS"))

# Summarize F1 per method: Unpaired methods are averaged across all var values
# (regardless of the 1000/3000/6000 multiome cells) with the error bar set to
# 0; all other methods are averaged across repeats within each var.
res_plot_3 <- res_plot_2 %>%
    filter(dataset == "scRNA") %>%
    filter(key == "F1") %>%
    # mean/sd after grouping by metric, method, and var
    group_by(metric_name, var, method) %>%
    dplyr::mutate(mean_var = mean(value),
                  sd_var = sd(value)) %>%
    # mean/sd after grouping by metric and method
    group_by(metric_name, method) %>%
    dplyr::mutate(mean_method = mean(value),
                  sd_method = sd(value)) %>%
    ungroup() %>%
    mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
           sd_f = ifelse(type == "Unpaired", 0, sd_var))

x_interval <- c(1000, 3000, 6000)
# y-axis limits and break counts, indexed by subplot position
ylim_list <- rep(list(c(0, 1)), length(unlist(subplot_idx)))
nbreaks <- rep(list(8), length(unlist(subplot_idx)))

subplots <- vector("list", length(subplot_idx))
for (i in seq_along(subplot_idx)) {
    print(i)
    subplots[[i]] <- res_plot_3 %>%
        filter(metric_name %in% subplot_idx[[i]]) %>%
        ggplot(aes(x = var, y = mean_f, group = method, col = method_name, shape = type)) +
        geom_point(size = 3) +
        geom_line() +
        geom_errorbar(aes(ymin = mean_f - sd_f, ymax = mean_f + sd_f),
                      position = position_dodge(0.05), alpha = 0.2) +
        scale_shape_manual(values = c(15, 16, 17)) +
        scale_x_continuous(breaks = x_interval) +
        facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
        scale_color_manual(values = hex[method_names]) +
        theme_light() +
        theme(strip.text.x = element_text(size = 30, angle = 90, colour = "black"),
              strip.text.y = element_text(size = 30, angle = 90, colour = "black"),
              axis.text.x = element_text(size = 25, angle = 90),
              axis.text.y = element_text(size = 25),
              axis.title.x = element_blank(), axis.title.y = element_blank(),
              strip.background = element_blank(), strip.placement = "outside",
              panel.spacing.y = unit(5, "lines"), panel.spacing.x = unit(2, "lines"),
              panel.border = element_blank(), axis.line = element_line()) +
        ggh4x::facetted_pos_scales(
            y = list(scale_y_continuous(limits = ylim_list[[i]], n.breaks = nbreaks[[i]]))
        )
    # NOTE(review): subplot_idx holds metric names such as "noMiss_HS", never
    # "F1", so this reference line is not drawn as written; truth_df is also
    # not defined in this section.
    if (plot_ref_accuracy && "F1" %in% subplot_idx[[i]]) {
        subplots[[i]] <- subplots[[i]] +
            geom_hline(data = truth_df, aes(yintercept = Z), linetype = 2, color = "red")
    }
}
# Hide per-subplot legends; one shared legend is added when the PDF is written.
subplots_no_bdg <- lapply(subplots, function(x) x + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)


# Extract the legend grob from a ggplot so it can be drawn once next to a row
# of legend-free subplots.
#
# ggplot2 < 3.5.0 names the legend grob "guide-box"; ggplot2 >= 3.5.0 uses
# position-specific names ("guide-box-right", "guide-box-bottom", ...) and
# fills unused positions with empty zeroGrobs. Matching the "guide-box" prefix
# and preferring a non-empty grob works with both.
#
# plot: a ggplot object.
# Returns the legend grob; stops with an error if the plot has no legend.
get_only_legend <- function(plot) {
  plot_table <- ggplot_gtable(ggplot_build(plot))
  grob_names <- vapply(plot_table$grobs,
                       function(x) if (is.null(x$name)) "" else x$name,
                       character(1))
  legend_idx <- which(startsWith(grob_names, "guide-box"))
  is_empty <- vapply(plot_table$grobs[legend_idx],
                     function(x) inherits(x, "zeroGrob"), logical(1))
  if (any(!is_empty)) {
    legend_idx <- legend_idx[!is_empty]
  }
  if (length(legend_idx) == 0) {
    stop("get_only_legend(): plot has no legend", call. = FALSE)
  }
  plot_table$grobs[[legend_idx[1]]]
}
                            
# Extract the shared legend from the first subplot and save the row of
# legend-free subplots next to it.
legend <- get_only_legend(subplots[[1]])

pdf(file.path(folder_path, file_name), width = 20, height = 10)
# finally = dev.off(): close the PDF device even if drawing fails, so later
# plots are not written into this file.
tryCatch(grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1)),
         finally = dev.off())
                              
   
                    

#### scRNA F1 Endo

In [None]:
file_name <- "shareseq_mouse_skin_rare_cell_1_f1_try2_rna_Endo.pdf"
# One subplot per entry; each entry lists the metric_name value(s) it shows.
subplot_idx <- list(c("noMiss_Endo"),
                    c("atacMissEndo_Endo"),
                    c("atacMissTwo_Endo"))

# Summarize F1 per method: Unpaired methods are averaged across all var values
# (regardless of the 1000/3000/6000 multiome cells) with the error bar set to
# 0; all other methods are averaged across repeats within each var.
res_plot_3 <- res_plot_2 %>%
    filter(dataset == "scRNA") %>%
    filter(key == "F1") %>%
    # mean/sd after grouping by metric, method, and var
    group_by(metric_name, var, method) %>%
    dplyr::mutate(mean_var = mean(value),
                  sd_var = sd(value)) %>%
    # mean/sd after grouping by metric and method
    group_by(metric_name, method) %>%
    dplyr::mutate(mean_method = mean(value),
                  sd_method = sd(value)) %>%
    ungroup() %>%
    mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
           sd_f = ifelse(type == "Unpaired", 0, sd_var))

x_interval <- c(1000, 3000, 6000)
# y-axis limits and break counts, indexed by subplot position
ylim_list <- rep(list(c(0, 1)), length(unlist(subplot_idx)))
nbreaks <- rep(list(8), length(unlist(subplot_idx)))

subplots <- vector("list", length(subplot_idx))
for (i in seq_along(subplot_idx)) {
    print(i)
    subplots[[i]] <- res_plot_3 %>%
        filter(metric_name %in% subplot_idx[[i]]) %>%
        ggplot(aes(x = var, y = mean_f, group = method, col = method_name, shape = type)) +
        geom_point(size = 3) +
        geom_line() +
        geom_errorbar(aes(ymin = mean_f - sd_f, ymax = mean_f + sd_f),
                      position = position_dodge(0.05), alpha = 0.2) +
        scale_shape_manual(values = c(15, 16, 17)) +
        scale_x_continuous(breaks = x_interval) +
        facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
        scale_color_manual(values = hex[method_names]) +
        theme_light() +
        theme(strip.text.x = element_text(size = 30, angle = 90, colour = "black"),
              strip.text.y = element_text(size = 30, angle = 90, colour = "black"),
              axis.text.x = element_text(size = 25, angle = 90),
              axis.text.y = element_text(size = 25),
              axis.title.x = element_blank(), axis.title.y = element_blank(),
              strip.background = element_blank(), strip.placement = "outside",
              panel.spacing.y = unit(5, "lines"), panel.spacing.x = unit(2, "lines"),
              panel.border = element_blank(), axis.line = element_line()) +
        ggh4x::facetted_pos_scales(
            y = list(scale_y_continuous(limits = ylim_list[[i]], n.breaks = nbreaks[[i]]))
        )
    # NOTE(review): subplot_idx holds metric names such as "noMiss_Endo", never
    # "F1", so this reference line is not drawn as written; truth_df is also
    # not defined in this section.
    if (plot_ref_accuracy && "F1" %in% subplot_idx[[i]]) {
        subplots[[i]] <- subplots[[i]] +
            geom_hline(data = truth_df, aes(yintercept = Z), linetype = 2, color = "red")
    }
}
# Hide per-subplot legends; one shared legend is added when the PDF is written.
subplots_no_bdg <- lapply(subplots, function(x) x + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)


# Extract the legend grob from a ggplot so it can be drawn once next to a row
# of legend-free subplots.
#
# ggplot2 < 3.5.0 names the legend grob "guide-box"; ggplot2 >= 3.5.0 uses
# position-specific names ("guide-box-right", "guide-box-bottom", ...) and
# fills unused positions with empty zeroGrobs. Matching the "guide-box" prefix
# and preferring a non-empty grob works with both.
#
# plot: a ggplot object.
# Returns the legend grob; stops with an error if the plot has no legend.
get_only_legend <- function(plot) {
  plot_table <- ggplot_gtable(ggplot_build(plot))
  grob_names <- vapply(plot_table$grobs,
                       function(x) if (is.null(x$name)) "" else x$name,
                       character(1))
  legend_idx <- which(startsWith(grob_names, "guide-box"))
  is_empty <- vapply(plot_table$grobs[legend_idx],
                     function(x) inherits(x, "zeroGrob"), logical(1))
  if (any(!is_empty)) {
    legend_idx <- legend_idx[!is_empty]
  }
  if (length(legend_idx) == 0) {
    stop("get_only_legend(): plot has no legend", call. = FALSE)
  }
  plot_table$grobs[[legend_idx[1]]]
}
                            
# Extract the shared legend from the first subplot and save the row of
# legend-free subplots next to it.
legend <- get_only_legend(subplots[[1]])

pdf(file.path(folder_path, file_name), width = 20, height = 10)
# finally = dev.off(): close the PDF device even if drawing fails, so later
# plots are not written into this file.
tryCatch(grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1)),
         finally = dev.off())
                              
   
                    

#### snATAC F1 HS

In [None]:

file_name <- "shareseq_mouse_skin_rare_cell_1_f1_try2_atac_HS.pdf"

# One subplot per entry; each entry lists the metric_name value(s) it shows.
subplot_idx <- list(c("noMiss_HS"),
                    c("rnaMissHS_HS"),
                    c("rnaMissTwo_HS"))

# Summarize F1 per method: Unpaired methods are averaged across all var values
# (regardless of the 1000/3000/6000 multiome cells) with the error bar set to
# 0; all other methods are averaged across repeats within each var.
res_plot_3 <- res_plot_2 %>%
    filter(dataset == "snATAC") %>%
    filter(key == "F1") %>%
    # mean/sd after grouping by metric, method, and var
    group_by(metric_name, var, method) %>%
    dplyr::mutate(mean_var = mean(value),
                  sd_var = sd(value)) %>%
    # mean/sd after grouping by metric and method
    group_by(metric_name, method) %>%
    dplyr::mutate(mean_method = mean(value),
                  sd_method = sd(value)) %>%
    ungroup() %>%
    mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
           sd_f = ifelse(type == "Unpaired", 0, sd_var))

x_interval <- c(1000, 3000, 6000)
# y-axis limits and break counts, indexed by subplot position
ylim_list <- rep(list(c(0, 1)), length(unlist(subplot_idx)))
nbreaks <- rep(list(8), length(unlist(subplot_idx)))

subplots <- vector("list", length(subplot_idx))
for (i in seq_along(subplot_idx)) {
    print(i)
    subplots[[i]] <- res_plot_3 %>%
        filter(metric_name %in% subplot_idx[[i]]) %>%
        ggplot(aes(x = var, y = mean_f, group = method, col = method_name, shape = type)) +
        geom_point(size = 3) +
        geom_line() +
        geom_errorbar(aes(ymin = mean_f - sd_f, ymax = mean_f + sd_f),
                      position = position_dodge(0.05), alpha = 0.2) +
        scale_shape_manual(values = c(15, 16, 17)) +
        scale_x_continuous(breaks = x_interval) +
        facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
        scale_color_manual(values = hex[method_names]) +
        theme_light() +
        theme(strip.text.x = element_text(size = 30, angle = 90, colour = "black"),
              strip.text.y = element_text(size = 30, angle = 90, colour = "black"),
              axis.text.x = element_text(size = 25, angle = 90),
              axis.text.y = element_text(size = 25),
              axis.title.x = element_blank(), axis.title.y = element_blank(),
              strip.background = element_blank(), strip.placement = "outside",
              panel.spacing.y = unit(5, "lines"), panel.spacing.x = unit(2, "lines"),
              panel.border = element_blank(), axis.line = element_line()) +
        ggh4x::facetted_pos_scales(
            y = list(scale_y_continuous(limits = ylim_list[[i]], n.breaks = nbreaks[[i]]))
        )
    # NOTE(review): subplot_idx holds metric names such as "noMiss_HS", never
    # "F1", so this reference line is not drawn as written; truth_df is also
    # not defined in this section.
    if (plot_ref_accuracy && "F1" %in% subplot_idx[[i]]) {
        subplots[[i]] <- subplots[[i]] +
            geom_hline(data = truth_df, aes(yintercept = Z), linetype = 2, color = "red")
    }
}
# Hide per-subplot legends; one shared legend is added when the PDF is written.
subplots_no_bdg <- lapply(subplots, function(x) x + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)


# Extract the legend grob from a ggplot so it can be drawn once next to a row
# of legend-free subplots.
#
# ggplot2 < 3.5.0 names the legend grob "guide-box"; ggplot2 >= 3.5.0 uses
# position-specific names ("guide-box-right", "guide-box-bottom", ...) and
# fills unused positions with empty zeroGrobs. Matching the "guide-box" prefix
# and preferring a non-empty grob works with both.
#
# plot: a ggplot object.
# Returns the legend grob; stops with an error if the plot has no legend.
get_only_legend <- function(plot) {
  plot_table <- ggplot_gtable(ggplot_build(plot))
  grob_names <- vapply(plot_table$grobs,
                       function(x) if (is.null(x$name)) "" else x$name,
                       character(1))
  legend_idx <- which(startsWith(grob_names, "guide-box"))
  is_empty <- vapply(plot_table$grobs[legend_idx],
                     function(x) inherits(x, "zeroGrob"), logical(1))
  if (any(!is_empty)) {
    legend_idx <- legend_idx[!is_empty]
  }
  if (length(legend_idx) == 0) {
    stop("get_only_legend(): plot has no legend", call. = FALSE)
  }
  plot_table$grobs[[legend_idx[1]]]
}
                            
# Extract the shared legend from the first subplot and save the row of
# legend-free subplots next to it.
legend <- get_only_legend(subplots[[1]])

pdf(file.path(folder_path, file_name), width = 20, height = 10)
# finally = dev.off(): close the PDF device even if drawing fails, so later
# plots are not written into this file.
tryCatch(grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1)),
         finally = dev.off())

#### snATAC F1 Endo

In [None]:

file_name <- "shareseq_mouse_skin_rare_cell_1_f1_try2_atac_Endo.pdf"

# One subplot per entry; each entry lists the metric_name value(s) it shows.
# These must be conditions from this section's metric_cond: the previous
# "multiMissEndo_Endo"/"atacOnlyEndo_Endo" are Scenario 5 names that are not
# in title_list here, so those subplots came out empty. Mirrors the snATAC HS
# cell (no missing / Endo missing in RNA / two types missing in RNA).
subplot_idx <- list(c("noMiss_Endo"),
                    c("rnaMissEndo_Endo"),
                    c("rnaMissTwo_Endo"))

# Summarize F1 per method: Unpaired methods are averaged across all var values
# (regardless of the 1000/3000/6000 multiome cells) with the error bar set to
# 0; all other methods are averaged across repeats within each var.
res_plot_3 <- res_plot_2 %>%
    filter(dataset == "snATAC") %>%
    filter(key == "F1") %>%
    # mean/sd after grouping by metric, method, and var
    group_by(metric_name, var, method) %>%
    dplyr::mutate(mean_var = mean(value),
                  sd_var = sd(value)) %>%
    # mean/sd after grouping by metric and method
    group_by(metric_name, method) %>%
    dplyr::mutate(mean_method = mean(value),
                  sd_method = sd(value)) %>%
    ungroup() %>%
    mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
           sd_f = ifelse(type == "Unpaired", 0, sd_var))

x_interval <- c(1000, 3000, 6000)
# y-axis limits and break counts, indexed by subplot position
ylim_list <- rep(list(c(0, 1)), length(unlist(subplot_idx)))
nbreaks <- rep(list(8), length(unlist(subplot_idx)))

subplots <- vector("list", length(subplot_idx))
for (i in seq_along(subplot_idx)) {
    print(i)
    subplots[[i]] <- res_plot_3 %>%
        filter(metric_name %in% subplot_idx[[i]]) %>%
        ggplot(aes(x = var, y = mean_f, group = method, col = method_name, shape = type)) +
        geom_point(size = 3) +
        geom_line() +
        geom_errorbar(aes(ymin = mean_f - sd_f, ymax = mean_f + sd_f),
                      position = position_dodge(0.05), alpha = 0.2) +
        scale_shape_manual(values = c(15, 16, 17)) +
        scale_x_continuous(breaks = x_interval) +
        facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
        scale_color_manual(values = hex[method_names]) +
        theme_light() +
        theme(strip.text.x = element_text(size = 30, angle = 90, colour = "black"),
              strip.text.y = element_text(size = 30, angle = 90, colour = "black"),
              axis.text.x = element_text(size = 25, angle = 90),
              axis.text.y = element_text(size = 25),
              axis.title.x = element_blank(), axis.title.y = element_blank(),
              strip.background = element_blank(), strip.placement = "outside",
              panel.spacing.y = unit(5, "lines"), panel.spacing.x = unit(2, "lines"),
              panel.border = element_blank(), axis.line = element_line()) +
        ggh4x::facetted_pos_scales(
            y = list(scale_y_continuous(limits = ylim_list[[i]], n.breaks = nbreaks[[i]]))
        )
    # NOTE(review): subplot_idx holds metric names such as "noMiss_Endo", never
    # "F1", so this reference line is not drawn as written; truth_df is also
    # not defined in this section.
    if (plot_ref_accuracy && "F1" %in% subplot_idx[[i]]) {
        subplots[[i]] <- subplots[[i]] +
            geom_hline(data = truth_df, aes(yintercept = Z), linetype = 2, color = "red")
    }
}
# Hide per-subplot legends; one shared legend is added when the PDF is written.
subplots_no_bdg <- lapply(subplots, function(x) x + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)


# Extract the legend grob from a ggplot so it can be drawn once next to a row
# of legend-free subplots.
#
# ggplot2 < 3.5.0 names the legend grob "guide-box"; ggplot2 >= 3.5.0 uses
# position-specific names ("guide-box-right", "guide-box-bottom", ...) and
# fills unused positions with empty zeroGrobs. Matching the "guide-box" prefix
# and preferring a non-empty grob works with both.
#
# plot: a ggplot object.
# Returns the legend grob; stops with an error if the plot has no legend.
get_only_legend <- function(plot) {
  plot_table <- ggplot_gtable(ggplot_build(plot))
  grob_names <- vapply(plot_table$grobs,
                       function(x) if (is.null(x$name)) "" else x$name,
                       character(1))
  legend_idx <- which(startsWith(grob_names, "guide-box"))
  is_empty <- vapply(plot_table$grobs[legend_idx],
                     function(x) inherits(x, "zeroGrob"), logical(1))
  if (any(!is_empty)) {
    legend_idx <- legend_idx[!is_empty]
  }
  if (length(legend_idx) == 0) {
    stop("get_only_legend(): plot has no legend", call. = FALSE)
  }
  plot_table$grobs[[legend_idx[1]]]
}
                            
# Extract the shared legend from the first subplot and save the row of
# legend-free subplots next to it.
legend <- get_only_legend(subplots[[1]])

pdf(file.path(folder_path, file_name), width = 20, height = 10)
# finally = dev.off(): close the PDF device even if drawing fails, so later
# plots are not written into this file.
tryCatch(grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1)),
         finally = dev.off())

# Scenario 5

## PBMC

### Load result

In [None]:
dir_path <- "dataset/pbmc/multiome_fixed_missing_ct/"
folder_path <- "figures/metric_plots/"
# recursive: also create "figures/" if missing; no warning if it already exists
dir.create(folder_path, recursive = TRUE, showWarnings = FALSE)

res_all <- plot_rare_cell_result(dir_path, runtime_exists = TRUE, prediction_exists = FALSE)
plot_ref_accuracy <- FALSE

res_all <- res_all %>% filter(out %in% c("results_single_mod", "results_single_same_cell_number"))
res_all$log2_runtime <- log2(res_all$runtime)
res_all

# add _2 to method name and add type column
unpaired_method_list <- c("seurat3", "rfigr", "rbindsc", "rliger", "glue")
res_plot <- res_all %>%
    filter(var %in% c(1000, 3000, 6000)) %>%
    # rows from results_single_same_cell_number get a "_2" suffix; method_type
    # below maps these to "Unpaired with Multiome-splitted"
    mutate(method_name = ifelse(out == "results_single_same_cell_number", paste0(method, "_2"), method)) %>%
    # assign method type
    mutate(type = ifelse(out != "results_single_same_cell_number",
                         ifelse(method %in% unpaired_method_list, "Unpaired", "Multiome-guided"),
                         "Unpaired_multiome_splitted"),
           ct = gsub("_metric.csv", "", ct_ref),
           metric = paste0(orig, "_", ct))

metric_cond <- c("noMiss", "multiMissNK", "rnaOnlyNK", "atacOnlyNK")

cts_name <- c("NK")

# every condition x cell type combination, e.g. "noMiss_NK", "multiMissNK_NK", ...
metric_sel <- paste0(rep(metric_cond, each = length(cts_name)), "_", rep(cts_name, length(metric_cond)))

table(res_plot$method_name)

# display name per internal method id; "_2" ids are the multiome-splitted runs
# NOTE(review): "BindSC" is capitalised here but spelled "bindSC" in the
# SHARE-seq sections; colours come from hex[method_names], so confirm the
# spelling matches the names of `hex`.
method_names <- c("MultiVI", "Seurat v4", "Cobolt", "scMoMaT", "Seurat v3", "BindSC", "FigR", "LIGER", "GLUE",
                  "Seurat v3", "BindSC", "FigR", "LIGER", "GLUE")
row_order <- c("multivi", "seurat4", "cobolt", "scmomat", "seurat3", "rbindsc", "rfigr", "rliger", "glue",
               "seurat3_2", "rbindsc_2", "rfigr_2", "rliger_2", "glue_2")
names(method_names) <- row_order

method_type <- c(rep("Multiome-guided", 4), rep("Unpaired", 5), rep("Unpaired with \n Multiome-splitted", 5))
names(method_type) <- row_order

title_list <- paste0(rep(metric_cond, each = length(cts_name)), "_", rep(cts_name, length(metric_cond)))
names(title_list) <- metric_sel

# Long format: one row per (run, metric) with the metric name in `key` and its
# value in `value`, plus display-ready method, type and metric factors.
res_plot_2 <- res_plot %>%
    separate(metric, c("cond", "fill", "ct", "dataset"), sep = "_") %>%
    mutate(nTarget = TP + FN) %>%
    mutate(perAcc = TP / nTarget) %>%
    select(F1, nTarget, perAcc, log2_runtime, method_name, var, cond, ct, dataset) %>%
    pivot_longer(cols = c(F1, nTarget, perAcc, log2_runtime),
                 names_to = "key", values_to = "value") %>%
    mutate(method = method_name, metric_type = paste0(cond, "_", ct)) %>%
    mutate(type = factor(method_type[method], levels = c("Unpaired", "Unpaired with \n Multiome-splitted", "Multiome-guided"))) %>%
    mutate(method_name = factor(method_names[method], levels = unique(method_names))) %>%
    mutate(metric_name = factor(title_list[metric_type], levels = title_list))


res_plot_2


### Plots

#### boxplot RNA

In [None]:
file_name <- "pbmc_rare_cell_2_f1_try2_rna_boxplot.pdf"
library(gridExtra)

# One subplot per entry; each entry lists the metric_name value(s) it shows.
subplot_idx <- list(c("noMiss_NK"),
                    c("multiMissNK_NK"),
                    c("rnaOnlyNK_NK"))

# Summarize F1 per method (mean_f/sd_f). Unpaired methods are averaged across
# all var values with the error bar set to 0; all other methods are averaged
# across repeats within each var. The boxplots below draw the raw `value`.
res_plot_3 <- res_plot_2 %>%
    filter(dataset == "scRNA") %>%
    filter(key == "F1") %>%
    # mean/sd after grouping by metric, method, and var
    group_by(metric_name, var, method) %>%
    dplyr::mutate(mean_var = mean(value),
                  sd_var = sd(value)) %>%
    # mean/sd after grouping by metric and method
    group_by(metric_name, method) %>%
    dplyr::mutate(mean_method = mean(value),
                  sd_method = sd(value)) %>%
    ungroup() %>%
    mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
           sd_f = ifelse(type == "Unpaired", 0, sd_var))

# plot only the unpaired and multiome-guided categories
res_plot_3 <- res_plot_3 %>% filter(type %in% c("Unpaired", "Multiome-guided"))

x_interval <- c(3000)
# y-axis limits and break counts, indexed by subplot position
ylim_list <- rep(list(c(0, 1)), length(unlist(subplot_idx)))
nbreaks <- rep(list(8), length(unlist(subplot_idx)))

subplots <- vector("list", length(subplot_idx))
for (i in seq_along(subplot_idx)) {
    print(i)
    subplots[[i]] <- res_plot_3 %>%
        filter(metric_name %in% subplot_idx[[i]]) %>%
        ggplot(aes(x = var, y = value, group = method, col = method_name, shape = type)) +
        geom_boxplot(lwd = 1) +
        scale_x_continuous(breaks = x_interval) +
        scale_color_manual(values = hex[method_names]) +
        facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
        theme_light() +
        theme(strip.text.x = element_blank(),
              strip.text.y = element_text(size = 50, angle = 90, colour = "black"),
              axis.text.x = element_text(size = 30, angle = 90),
              axis.text.y = element_text(size = 50),
              axis.title.x = element_blank(), axis.title.y = element_blank(),
              strip.background = element_blank(), strip.placement = "outside",
              panel.spacing.y = unit(2, "lines"), panel.spacing.x = unit(1, "lines"),
              panel.border = element_blank(), axis.line = element_line()) +
        ggh4x::facetted_pos_scales(
            y = list(scale_y_continuous(limits = ylim_list[[i]], n.breaks = nbreaks[[i]]))
        )
    # NOTE(review): subplot_idx holds metric names such as "noMiss_NK", never
    # "F1", so this reference line is not drawn as written; truth_df is also
    # not defined in this section.
    if (plot_ref_accuracy && "F1" %in% subplot_idx[[i]]) {
        subplots[[i]] <- subplots[[i]] +
            geom_hline(data = truth_df, aes(yintercept = Z), linetype = 2, color = "red")
    }
}
# Hide per-subplot legends; one shared legend is added when the PDF is written.
subplots_no_bdg <- lapply(subplots, function(x) x + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)


# Extract the legend grob from a ggplot so it can be drawn once next to a row
# of legend-free subplots.
#
# ggplot2 < 3.5.0 names the legend grob "guide-box"; ggplot2 >= 3.5.0 uses
# position-specific names ("guide-box-right", "guide-box-bottom", ...) and
# fills unused positions with empty zeroGrobs. Matching the "guide-box" prefix
# and preferring a non-empty grob works with both.
#
# plot: a ggplot object.
# Returns the legend grob; stops with an error if the plot has no legend.
get_only_legend <- function(plot) {
  plot_table <- ggplot_gtable(ggplot_build(plot))
  grob_names <- vapply(plot_table$grobs,
                       function(x) if (is.null(x$name)) "" else x$name,
                       character(1))
  legend_idx <- which(startsWith(grob_names, "guide-box"))
  is_empty <- vapply(plot_table$grobs[legend_idx],
                     function(x) inherits(x, "zeroGrob"), logical(1))
  if (any(!is_empty)) {
    legend_idx <- legend_idx[!is_empty]
  }
  if (length(legend_idx) == 0) {
    stop("get_only_legend(): plot has no legend", call. = FALSE)
  }
  plot_table$grobs[[legend_idx[1]]]
}
                            
# Extract the shared legend from the first subplot and save the row of
# legend-free subplots next to it.
legend <- get_only_legend(subplots[[1]])

pdf(file.path(folder_path, file_name), width = 25, height = 8)
# finally = dev.off(): close the PDF device even if drawing fails, so later
# plots are not written into this file.
tryCatch(grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1)),
         finally = dev.off())

#### boxplot ATAC

In [None]:
file_name <- "pbmc_rare_cell_2_f1_try2_atac_boxplot.pdf"
library(gridExtra)

# One subplot per entry; each entry lists the metric_name value(s) it shows.
subplot_idx <- list(c("noMiss_NK"),
                    c("multiMissNK_NK"),
                    c("atacOnlyNK_NK"))

# Summarize F1 per method (mean_f/sd_f). Unpaired methods are averaged across
# all var values with the error bar set to 0; all other methods are averaged
# across repeats within each var. The boxplots below draw the raw `value`.
res_plot_3 <- res_plot_2 %>%
    filter(dataset == "snATAC") %>%
    filter(key == "F1") %>%
    # mean/sd after grouping by metric, method, and var
    group_by(metric_name, var, method) %>%
    dplyr::mutate(mean_var = mean(value),
                  sd_var = sd(value)) %>%
    # mean/sd after grouping by metric and method
    group_by(metric_name, method) %>%
    dplyr::mutate(mean_method = mean(value),
                  sd_method = sd(value)) %>%
    ungroup() %>%
    mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
           sd_f = ifelse(type == "Unpaired", 0, sd_var))

# plot only the unpaired and multiome-guided categories
res_plot_3 <- res_plot_3 %>% filter(type %in% c("Unpaired", "Multiome-guided"))

x_interval <- c(3000)
# y-axis limits and break counts, indexed by subplot position
ylim_list <- rep(list(c(0, 1)), length(unlist(subplot_idx)))
nbreaks <- rep(list(8), length(unlist(subplot_idx)))

subplots <- vector("list", length(subplot_idx))
for (i in seq_along(subplot_idx)) {
    print(i)
    subplots[[i]] <- res_plot_3 %>%
        filter(metric_name %in% subplot_idx[[i]]) %>%
        ggplot(aes(x = var, y = value, group = method, col = method_name, shape = type)) +
        geom_boxplot(lwd = 1) +
        scale_x_continuous(breaks = x_interval) +
        scale_color_manual(values = hex[method_names]) +
        facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
        theme_light() +
        theme(strip.text.x = element_blank(),
              strip.text.y = element_text(size = 50, angle = 90, colour = "black"),
              axis.text.x = element_text(size = 30, angle = 90),
              axis.text.y = element_text(size = 50),
              axis.title.x = element_blank(), axis.title.y = element_blank(),
              strip.background = element_blank(), strip.placement = "outside",
              panel.spacing.y = unit(2, "lines"), panel.spacing.x = unit(1, "lines"),
              panel.border = element_blank(), axis.line = element_line()) +
        ggh4x::facetted_pos_scales(
            y = list(scale_y_continuous(limits = ylim_list[[i]], n.breaks = nbreaks[[i]]))
        )
    # NOTE(review): subplot_idx holds metric names such as "noMiss_NK", never
    # "F1", so this reference line is not drawn as written; truth_df is also
    # not defined in this section.
    if (plot_ref_accuracy && "F1" %in% subplot_idx[[i]]) {
        subplots[[i]] <- subplots[[i]] +
            geom_hline(data = truth_df, aes(yintercept = Z), linetype = 2, color = "red")
    }
}
# Hide per-subplot legends; one shared legend is added when the PDF is written.
subplots_no_bdg <- lapply(subplots, function(x) x + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)


# Extract the legend grob from a ggplot so it can be drawn once next to a row
# of legend-free subplots.
#
# ggplot2 < 3.5.0 names the legend grob "guide-box"; ggplot2 >= 3.5.0 uses
# position-specific names ("guide-box-right", "guide-box-bottom", ...) and
# fills unused positions with empty zeroGrobs. Matching the "guide-box" prefix
# and preferring a non-empty grob works with both.
#
# plot: a ggplot object.
# Returns the legend grob; stops with an error if the plot has no legend.
get_only_legend <- function(plot) {
  plot_table <- ggplot_gtable(ggplot_build(plot))
  grob_names <- vapply(plot_table$grobs,
                       function(x) if (is.null(x$name)) "" else x$name,
                       character(1))
  legend_idx <- which(startsWith(grob_names, "guide-box"))
  is_empty <- vapply(plot_table$grobs[legend_idx],
                     function(x) inherits(x, "zeroGrob"), logical(1))
  if (any(!is_empty)) {
    legend_idx <- legend_idx[!is_empty]
  }
  if (length(legend_idx) == 0) {
    stop("get_only_legend(): plot has no legend", call. = FALSE)
  }
  plot_table$grobs[[legend_idx[1]]]
}
                            
# Extract the shared legend from the first subplot and save the row of
# legend-free subplots next to it.
legend <- get_only_legend(subplots[[1]])

pdf(file.path(folder_path, file_name), width = 25, height = 8)
# finally = dev.off(): close the PDF device even if drawing fails, so later
# plots are not written into this file.
tryCatch(grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1)),
         finally = dev.off())

## SHARE-seq

### Load result

In [None]:
# Relative path, consistent with the PBMC sections (replaces a hard-coded
# private cluster path).
dir_path <- "dataset/mouse_skin/multiome_fixed_missing_ct/"
folder_path <- "figures/metric_plots/"
# recursive: also create "figures/" if missing; no warning if it already exists
dir.create(folder_path, recursive = TRUE, showWarnings = FALSE)

res_all <- plot_rare_cell_result(dir_path, runtime_exists = TRUE, prediction_exists = FALSE)
plot_ref_accuracy <- FALSE

res_all <- res_all %>% filter(out %in% c("results_single_mod", "results_single_same_cell_number"))
res_all$log2_runtime <- log2(res_all$runtime)
res_all

# add _2 to method name and add type column
unpaired_method_list <- c("seurat3", "rfigr", "rbindsc", "rliger", "glue")
res_plot <- res_all %>%
    filter(var %in% c(1000, 3000, 6000)) %>%
    # rows from results_single_same_cell_number get a "_2" suffix; method_type
    # below maps these to "Unpaired with Multiome-splitted"
    mutate(method_name = ifelse(out == "results_single_same_cell_number", paste0(method, "_2"), method)) %>%
    # assign method type
    mutate(type = ifelse(out != "results_single_same_cell_number",
                         ifelse(method %in% unpaired_method_list, "Unpaired", "Multiome-guided"),
                         "Unpaired_multiome_splitted"),
           ct = gsub("_metric.csv", "", ct_ref),
           metric = paste0(orig, "_", ct))

metric_cond <- c("noMiss", "multiMissHS", "rnaOnlyHS", "atacOnlyHS", "multiOnlyHS",
                 "multiMissEndo", "rnaOnlyEndo", "atacOnlyEndo", "multiOnlyEndo")
cts <- c("10k_HS", "10k_Endo")
cts_name <- c("HS", "Endo")

# every condition x cell type combination, e.g. "noMiss_HS", "noMiss_Endo", ...
metric_sel <- paste0(rep(metric_cond, each = length(cts)), "_", rep(cts_name, length(metric_cond)))

table(res_plot$method_name)

# display name per internal method id; "_2" ids are the multiome-splitted runs
method_names <- c("MultiVI", "Seurat v4", "Cobolt", "scMoMaT", "Seurat v3", "bindSC", "FigR", "LIGER", "GLUE",
                  "Seurat v3", "bindSC", "FigR", "LIGER", "GLUE")
row_order <- c("multivi", "seurat4", "cobolt", "scmomat", "seurat3", "rbindsc", "rfigr", "rliger", "glue",
               "seurat3_2", "rbindsc_2", "rfigr_2", "rliger_2", "glue_2")
names(method_names) <- row_order

method_type <- c(rep("Multiome-guided", 4), rep("Unpaired", 5), rep("Unpaired with \n Multiome-splitted", 5))
names(method_type) <- row_order

title_list <- paste0(rep(metric_cond, each = length(cts_name)), "_", rep(cts_name, length(metric_cond)))
names(title_list) <- metric_sel

# Long format: one row per (run, metric) with the metric name in `key` and its
# value in `value`, plus display-ready method, type and metric factors.
res_plot_2 <- res_plot %>%
    separate(metric, c("cond", "fill", "ct", "dataset"), sep = "_") %>%
    mutate(nTarget = TP + FN) %>%
    mutate(perAcc = TP / nTarget) %>%
    select(F1, nTarget, perAcc, log2_runtime, method_name, var, cond, ct, dataset) %>%
    pivot_longer(cols = c(F1, nTarget, perAcc, log2_runtime),
                 names_to = "key", values_to = "value") %>%
    mutate(method = method_name, metric_type = paste0(cond, "_", ct)) %>%
    mutate(type = factor(method_type[method], levels = c("Unpaired", "Unpaired with \n Multiome-splitted", "Multiome-guided"))) %>%
    mutate(method_name = factor(method_names[method], levels = unique(method_names))) %>%
    mutate(metric_name = factor(title_list[metric_type], levels = title_list))


res_plot_2


### Plots

#### box plot - RNA HS

In [None]:
file_name <- "shareseq_mouse_skin_rare_cell_2_f1_try2_boxplot_rna_HS.pdf"

# One subplot per entry; each entry lists the metric_name values shown in it.
subplot_idx <- list(c("noMiss_HS"),
                    c("multiMissHS_HS"),
                    c("rnaOnlyHS_HS"))

# Add summarized values: for unpaired methods, summarize across all var values
# (regardless of the 1000/3000/6000 multiome cells kept when loading); for the
# other methods, summarize across repeats within each var.
# NOTE(review): mean_f / sd_f are not used by the box plots below.
res_plot_3 <- res_plot_2 %>%
    filter(dataset == "scRNA", key == "F1") %>%
    group_by(metric_name, var, method) %>%
    # mean/sd after grouping by metric, method, and var
    dplyr::mutate(mean_var = mean(value),
                  sd_var = sd(value)) %>%
    group_by(metric_name, method) %>%
    # mean/sd after grouping by metric and method
    dplyr::mutate(mean_method = mean(value),
                  sd_method = sd(value)) %>%
    ungroup() %>%
    dplyr::mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
                  sd_f = ifelse(type == "Unpaired", 0, sd_var))

# plot only the unpaired and multiome-guided categories
res_plot_3 <- res_plot_3 %>% filter(type %in% c("Unpaired", "Multiome-guided"))

# Shared x break plus per-panel y limits / number of y breaks.
x_interval <- c(3000)
n_panels <- length(unlist(subplot_idx))
ylim_list <- rep(list(c(0, 1)), n_panels)
nbreaks <- rep(list(8), n_panels)

subplots <- vector("list", length(subplot_idx))
for (i in seq_along(subplot_idx)) {
    print(i)
    subplots[[i]] <- res_plot_3 %>%
        filter(metric_name %in% subplot_idx[[i]]) %>%
        ggplot(aes(x = var, y = value, group = method, col = method_name, shape = type)) +
        # group = method: each box pools all var values of one method
        geom_boxplot(lwd = 1) +
        scale_x_continuous(breaks = x_interval) +
        scale_color_manual(values = hex[method_names]) +
        facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
        theme_light() +
        theme(strip.text.x = element_blank(),
              strip.text.y = element_text(size = 50, angle = 90, colour = "black"),
              axis.text.x = element_text(size = 30, angle = 90),
              axis.text.y = element_text(size = 50),
              axis.title.x = element_blank(), axis.title.y = element_blank(),
              strip.background = element_blank(), strip.placement = "outside",
              panel.spacing.y = unit(2, "lines"), panel.spacing.x = unit(1, "lines"),
              panel.border = element_blank(), axis.line = element_line()) +
        ggh4x::facetted_pos_scales(
            y = list(scale_y_continuous(limits = ylim_list[[i]], n.breaks = nbreaks[[i]]))
        )
    # NOTE(review): unreachable as written. plot_ref_accuracy is FALSE, subplot_idx holds
    # metric_name values (never "F1"), and truth_df is not defined in this part of the notebook.
    if (plot_ref_accuracy && "F1" %in% subplot_idx[[i]]) {
        subplots[[i]] <- subplots[[i]] +
            geom_hline(data = truth_df, aes(yintercept = Z), linetype = 2, color = "red")
    }
}
subplots_no_bdg <- lapply(subplots, function(x) x + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)

# Return the legend ("guide-box" grob) of a ggplot; errors clearly if there is none.
get_only_legend <- function(plot) {
    plot_table <- ggplot_gtable(ggplot_build(plot))
    grob_names <- vapply(plot_table$grobs,
                         function(x) if (is.null(x$name)) NA_character_ else as.character(x$name)[1],
                         character(1))
    legend_idx <- which(grob_names == "guide-box")
    if (length(legend_idx) == 0) {
        stop("get_only_legend(): no 'guide-box' grob found; does the plot have a legend?",
             call. = FALSE)
    }
    plot_table$grobs[[legend_idx[1]]]
}

# extract legend from the first subplot
legend <- get_only_legend(subplots[[1]])

# Close the PDF device even if drawing fails.
pdf(file.path(folder_path, file_name), width = 25, height = 8)
invisible(tryCatch(
    grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1)),
    finally = dev.off()
))

#### box plot - RNA Endo

In [None]:
file_name <- "shareseq_mouse_skin_rare_cell_2_f1_try2_boxplot_rna_Endo.pdf"
# One subplot per entry; each entry lists the metric_name values shown in it.
subplot_idx <- list(c("noMiss_Endo"),
                    c("multiMissEndo_Endo"),
                    c("rnaOnlyEndo_Endo"))

# Add summarized values: for unpaired methods, summarize across all var values
# (regardless of the 1000/3000/6000 multiome cells kept when loading); for the
# other methods, summarize across repeats within each var.
# NOTE(review): mean_f / sd_f are not used by the box plots below.
res_plot_3 <- res_plot_2 %>%
    filter(dataset == "scRNA", key == "F1") %>%
    group_by(metric_name, var, method) %>%
    # mean/sd after grouping by metric, method, and var
    dplyr::mutate(mean_var = mean(value),
                  sd_var = sd(value)) %>%
    group_by(metric_name, method) %>%
    # mean/sd after grouping by metric and method
    dplyr::mutate(mean_method = mean(value),
                  sd_method = sd(value)) %>%
    ungroup() %>%
    dplyr::mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
                  sd_f = ifelse(type == "Unpaired", 0, sd_var))

# plot only the unpaired and multiome-guided categories
res_plot_3 <- res_plot_3 %>% filter(type %in% c("Unpaired", "Multiome-guided"))

# Shared x break plus per-panel y limits / number of y breaks.
x_interval <- c(3000)
n_panels <- length(unlist(subplot_idx))
ylim_list <- rep(list(c(0, 1)), n_panels)
nbreaks <- rep(list(8), n_panels)

subplots <- vector("list", length(subplot_idx))
for (i in seq_along(subplot_idx)) {
    print(i)
    subplots[[i]] <- res_plot_3 %>%
        filter(metric_name %in% subplot_idx[[i]]) %>%
        ggplot(aes(x = var, y = value, group = method, col = method_name, shape = type)) +
        # group = method: each box pools all var values of one method
        geom_boxplot(lwd = 1) +
        scale_x_continuous(breaks = x_interval) +
        scale_color_manual(values = hex[method_names]) +
        facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
        theme_light() +
        theme(strip.text.x = element_blank(),
              strip.text.y = element_text(size = 50, angle = 90, colour = "black"),
              axis.text.x = element_text(size = 30, angle = 90),
              axis.text.y = element_text(size = 50),
              axis.title.x = element_blank(), axis.title.y = element_blank(),
              strip.background = element_blank(), strip.placement = "outside",
              panel.spacing.y = unit(2, "lines"), panel.spacing.x = unit(1, "lines"),
              panel.border = element_blank(), axis.line = element_line()) +
        ggh4x::facetted_pos_scales(
            y = list(scale_y_continuous(limits = ylim_list[[i]], n.breaks = nbreaks[[i]]))
        )
    # NOTE(review): unreachable as written. plot_ref_accuracy is FALSE, subplot_idx holds
    # metric_name values (never "F1"), and truth_df is not defined in this part of the notebook.
    if (plot_ref_accuracy && "F1" %in% subplot_idx[[i]]) {
        subplots[[i]] <- subplots[[i]] +
            geom_hline(data = truth_df, aes(yintercept = Z), linetype = 2, color = "red")
    }
}
subplots_no_bdg <- lapply(subplots, function(x) x + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)

# Return the legend ("guide-box" grob) of a ggplot; errors clearly if there is none.
get_only_legend <- function(plot) {
    plot_table <- ggplot_gtable(ggplot_build(plot))
    grob_names <- vapply(plot_table$grobs,
                         function(x) if (is.null(x$name)) NA_character_ else as.character(x$name)[1],
                         character(1))
    legend_idx <- which(grob_names == "guide-box")
    if (length(legend_idx) == 0) {
        stop("get_only_legend(): no 'guide-box' grob found; does the plot have a legend?",
             call. = FALSE)
    }
    plot_table$grobs[[legend_idx[1]]]
}

# extract legend from the first subplot
legend <- get_only_legend(subplots[[1]])

# Close the PDF device even if drawing fails.
pdf(file.path(folder_path, file_name), width = 25, height = 8)
invisible(tryCatch(
    grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1)),
    finally = dev.off()
))

#### box plot - ATAC HS

In [None]:
file_name <- "shareseq_mouse_skin_rare_cell_2_f1_try2_boxplot_atac_HS.pdf"

# One subplot per entry; each entry lists the metric_name values shown in it.
subplot_idx <- list(c("noMiss_HS"),
                    c("multiMissHS_HS"),
                    c("atacOnlyHS_HS"))

# Add summarized values: for unpaired methods, summarize across all var values
# (regardless of the 1000/3000/6000 multiome cells kept when loading); for the
# other methods, summarize across repeats within each var.
# NOTE(review): mean_f / sd_f are not used by the box plots below.
res_plot_3 <- res_plot_2 %>%
    filter(dataset == "snATAC", key == "F1") %>%
    group_by(metric_name, var, method) %>%
    # mean/sd after grouping by metric, method, and var
    dplyr::mutate(mean_var = mean(value),
                  sd_var = sd(value)) %>%
    group_by(metric_name, method) %>%
    # mean/sd after grouping by metric and method
    dplyr::mutate(mean_method = mean(value),
                  sd_method = sd(value)) %>%
    ungroup() %>%
    dplyr::mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
                  sd_f = ifelse(type == "Unpaired", 0, sd_var))

# plot only the unpaired and multiome-guided categories
res_plot_3 <- res_plot_3 %>% filter(type %in% c("Unpaired", "Multiome-guided"))

# Shared x break plus per-panel y limits / number of y breaks.
x_interval <- c(3000)
n_panels <- length(unlist(subplot_idx))
ylim_list <- rep(list(c(0, 1)), n_panels)
nbreaks <- rep(list(8), n_panels)

subplots <- vector("list", length(subplot_idx))
for (i in seq_along(subplot_idx)) {
    print(i)
    subplots[[i]] <- res_plot_3 %>%
        filter(metric_name %in% subplot_idx[[i]]) %>%
        ggplot(aes(x = var, y = value, group = method, col = method_name, shape = type)) +
        # group = method: each box pools all var values of one method
        geom_boxplot(lwd = 1) +
        scale_x_continuous(breaks = x_interval) +
        scale_color_manual(values = hex[method_names]) +
        facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
        theme_light() +
        theme(strip.text.x = element_blank(),
              strip.text.y = element_text(size = 50, angle = 90, colour = "black"),
              axis.text.x = element_text(size = 30, angle = 90),
              axis.text.y = element_text(size = 50),
              axis.title.x = element_blank(), axis.title.y = element_blank(),
              strip.background = element_blank(), strip.placement = "outside",
              panel.spacing.y = unit(2, "lines"), panel.spacing.x = unit(1, "lines"),
              panel.border = element_blank(), axis.line = element_line()) +
        ggh4x::facetted_pos_scales(
            y = list(scale_y_continuous(limits = ylim_list[[i]], n.breaks = nbreaks[[i]]))
        )
    # NOTE(review): unreachable as written. plot_ref_accuracy is FALSE, subplot_idx holds
    # metric_name values (never "F1"), and truth_df is not defined in this part of the notebook.
    if (plot_ref_accuracy && "F1" %in% subplot_idx[[i]]) {
        subplots[[i]] <- subplots[[i]] +
            geom_hline(data = truth_df, aes(yintercept = Z), linetype = 2, color = "red")
    }
}
subplots_no_bdg <- lapply(subplots, function(x) x + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)

# Return the legend ("guide-box" grob) of a ggplot; errors clearly if there is none.
get_only_legend <- function(plot) {
    plot_table <- ggplot_gtable(ggplot_build(plot))
    grob_names <- vapply(plot_table$grobs,
                         function(x) if (is.null(x$name)) NA_character_ else as.character(x$name)[1],
                         character(1))
    legend_idx <- which(grob_names == "guide-box")
    if (length(legend_idx) == 0) {
        stop("get_only_legend(): no 'guide-box' grob found; does the plot have a legend?",
             call. = FALSE)
    }
    plot_table$grobs[[legend_idx[1]]]
}

# extract legend from the first subplot
legend <- get_only_legend(subplots[[1]])

# Close the PDF device even if drawing fails.
pdf(file.path(folder_path, file_name), width = 25, height = 8)
invisible(tryCatch(
    grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1)),
    finally = dev.off()
))

#### box plot - ATAC Endo

In [None]:
file_name <- "shareseq_mouse_skin_rare_cell_2_f1_try2_boxplot_atac_Endo.pdf"

# One subplot per entry; each entry lists the metric_name values shown in it.
subplot_idx <- list(c("noMiss_Endo"),
                    c("multiMissEndo_Endo"),
                    c("atacOnlyEndo_Endo"))
# Add summarized values: for unpaired methods, summarize across all var values
# (regardless of the 1000/3000/6000 multiome cells kept when loading); for the
# other methods, summarize across repeats within each var.
# NOTE(review): mean_f / sd_f are not used by the box plots below.
res_plot_3 <- res_plot_2 %>%
    filter(dataset == "snATAC", key == "F1") %>%
    group_by(metric_name, var, method) %>%
    # mean/sd after grouping by metric, method, and var
    dplyr::mutate(mean_var = mean(value),
                  sd_var = sd(value)) %>%
    group_by(metric_name, method) %>%
    # mean/sd after grouping by metric and method
    dplyr::mutate(mean_method = mean(value),
                  sd_method = sd(value)) %>%
    ungroup() %>%
    dplyr::mutate(mean_f = ifelse(type == "Unpaired", mean_method, mean_var),
                  sd_f = ifelse(type == "Unpaired", 0, sd_var))

# plot only the unpaired and multiome-guided categories
res_plot_3 <- res_plot_3 %>% filter(type %in% c("Unpaired", "Multiome-guided"))

# Shared x break plus per-panel y limits / number of y breaks.
x_interval <- c(3000)
n_panels <- length(unlist(subplot_idx))
ylim_list <- rep(list(c(0, 1)), n_panels)
nbreaks <- rep(list(8), n_panels)

subplots <- vector("list", length(subplot_idx))
for (i in seq_along(subplot_idx)) {
    print(i)
    subplots[[i]] <- res_plot_3 %>%
        filter(metric_name %in% subplot_idx[[i]]) %>%
        ggplot(aes(x = var, y = value, group = method, col = method_name, shape = type)) +
        # group = method: each box pools all var values of one method
        geom_boxplot(lwd = 1) +
        scale_x_continuous(breaks = x_interval) +
        scale_color_manual(values = hex[method_names]) +
        facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
        theme_light() +
        theme(strip.text.x = element_blank(),
              strip.text.y = element_text(size = 50, angle = 90, colour = "black"),
              axis.text.x = element_text(size = 30, angle = 90),
              axis.text.y = element_text(size = 50),
              axis.title.x = element_blank(), axis.title.y = element_blank(),
              strip.background = element_blank(), strip.placement = "outside",
              panel.spacing.y = unit(2, "lines"), panel.spacing.x = unit(1, "lines"),
              panel.border = element_blank(), axis.line = element_line()) +
        ggh4x::facetted_pos_scales(
            y = list(scale_y_continuous(limits = ylim_list[[i]], n.breaks = nbreaks[[i]]))
        )
    # NOTE(review): unreachable as written. plot_ref_accuracy is FALSE, subplot_idx holds
    # metric_name values (never "F1"), and truth_df is not defined in this part of the notebook.
    if (plot_ref_accuracy && "F1" %in% subplot_idx[[i]]) {
        subplots[[i]] <- subplots[[i]] +
            geom_hline(data = truth_df, aes(yintercept = Z), linetype = 2, color = "red")
    }
}
subplots_no_bdg <- lapply(subplots, function(x) x + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)

# Return the legend ("guide-box" grob) of a ggplot; errors clearly if there is none.
get_only_legend <- function(plot) {
    plot_table <- ggplot_gtable(ggplot_build(plot))
    grob_names <- vapply(plot_table$grobs,
                         function(x) if (is.null(x$name)) NA_character_ else as.character(x$name)[1],
                         character(1))
    legend_idx <- which(grob_names == "guide-box")
    if (length(legend_idx) == 0) {
        stop("get_only_legend(): no 'guide-box' grob found; does the plot have a legend?",
             call. = FALSE)
    }
    plot_table$grobs[[legend_idx[1]]]
}

# extract legend from the first subplot
legend <- get_only_legend(subplots[[1]])

# Close the PDF device even if drawing fails.
pdf(file.path(folder_path, file_name), width = 25, height = 8)
invisible(tryCatch(
    grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1)),
    finally = dev.off()
))

# HPAP integration

## Load results

In [None]:
dir_path <- "dataset/hpap/real_data/"
folder_path <- "figures/metric_plots/"
# recursive = TRUE also creates "figures/" when missing; showWarnings = FALSE keeps
# re-runs quiet when the folder already exists.
dir.create(folder_path, showWarnings = FALSE, recursive = TRUE)

# NOTE(review): plot_result() is not defined in this cell; assumed to come from the
# Functions section of this notebook.
res_all <- plot_result(dir_path, res_folder_filter = c("results"), runtime_exists = TRUE,
                       prediction_exists = TRUE, metric_file = "ct3_metric.csv")

res_all_batch <- plot_result(dir_path, res_folder_filter = c("results"), runtime_exists = TRUE,
                             prediction_exists = TRUE, metric_file = "ct3_metric_sample_batch.csv")

# NOTE(review): columns 2:10 of res_all_batch are selected by position and the join keys
# are inferred from the shared column names; confirm both against plot_result()'s output.
res_all_comb <- left_join(res_all, res_all_batch[2:10])
res_all_comb$log2_runtime <- log2(res_all_comb$runtime)

table(is.na(res_all_comb$batch_b_saw))


# Add "_2" to every method that is not multiome-guided and add a type column.
multiome_guided_method_list <- c("multivi", "seurat4", "cobolt", "scmomat", "seurat4int")
unpaired_method_list <- c("seurat3_2", "rfigr_2", "rbindsc_2", "rliger_2", "glue_2")
res_plot <- res_all_comb %>%
    mutate(method_name = ifelse(method %in% multiome_guided_method_list, method, paste0(method, "_2"))) %>%
    # assign method type: "_2" unpaired methods are the multiome-splitted runs (same
    # labelling as the SHARE-seq results); the rest are multiome-guided.
    # Previously these two labels were swapped.
    mutate(type = ifelse(method_name %in% unpaired_method_list,
                         "Unpaired_multiome_splitted", "Multiome-guided"))

metric_sel <- c("ARI", "NMI", "ct_aws", "clisi", "b_saw", "kbet", "gconn", "batch_b_saw", "batch_kbet",
                "log2_runtime")

res_plot <- res_plot %>% filter(ct_ref == "ct3_metric.csv")
table(res_plot$method_name)

# Display names and method categories, keyed by internal method name.
method_names <- c("MultiVI", "Seurat v4", "Cobolt", "scMoMaT", "Seurat v4 integrate", "Seurat v3", "BindSC", "FigR", "LIGER", "GLUE",
                  "Seurat v3", "BindSC", "FigR", "LIGER", "GLUE")
row_order <- c("multivi", "seurat4", "cobolt", "scmomat", "seurat4int", "seurat3", "rbindsc", "rfigr", "rliger", "glue",
               "seurat3_2", "rbindsc_2", "rfigr_2", "rliger_2", "glue_2")

names(method_names) <- row_order

method_type <- c(rep("Multiome-guided", 5), rep("Unpaired", 5), rep("Unpaired with \n Multiome-splitted", 5))

names(method_type) <- row_order

# Panel titles for each metric column.
title_list <- c("ARI", "NMI", "Cell type ASW", "cLisi", "Batch ASW", "kBET", "Graph \n connectivity",
                "Donor ASW", "Donor kBET", "log2(Runtime)")
names(title_list) <- metric_sel

# Long format: one row per (method, var, metric) value; metrics not in metric_sel get
# an NA metric_name.
res_plot_2 <- res_plot %>%
    select(ARI:gconn, percent_recovered_50kb:log2_runtime, method_name, var) %>%
    gather(key = "metric_type", value = "value", -method_name, -var) %>%
    mutate(method = method_name) %>%
    mutate(type = factor(method_type[method],
                         levels = c("Unpaired", "Unpaired with \n Multiome-splitted", "Multiome-guided"))) %>%
    mutate(method_name = factor(method_names[method], levels = unique(method_names))) %>%
    mutate(metric_name = factor(title_list[metric_type], levels = title_list))


## Summary plot

In [None]:
file_name <- "hpap_fair_long_scmomatCorrected.pdf"

# Flattened so that every metric gets its own panel (and its own y scale below).
subplot_idx <- list(c("Batch ASW", "kBET"),
                    c("Donor ASW", "Donor kBET"),
                    c("log2(Runtime)"))
subplot_idx <- as.list(unlist(subplot_idx))

# Per-panel y limits and number of y breaks, in the same order as subplot_idx.
ylim_list <- list(c(0.5, 1), c(0, 0.5), c(0.5, 1), c(0, 0.7), c(8, 16))
nbreaks <- list(5, 8, 8, 9, 8)
stopifnot(length(ylim_list) == length(subplot_idx),
          length(nbreaks) == length(subplot_idx))

res_plot_2$method <- factor(res_plot_2$method,
                            levels = c("seurat3_2", "rbindsc_2", "rfigr_2", "rliger_2", "glue_2",
                                       "cobolt", "multivi", "seurat4", "seurat4int", "scmomat"))

x_interval <- unique(res_plot_2$var)
subplots <- vector("list", length(subplot_idx))
for (i in seq_along(subplot_idx)) {
    print(i)
    subplots[[i]] <- res_plot_2 %>%
        filter(metric_name %in% subplot_idx[[i]]) %>%
        ggplot(aes(x = var, y = value, group = method, col = method_name, shape = type)) +
        geom_point(size = 3) +
        scale_x_continuous(breaks = x_interval) +
        facet_grid(cols = vars(type), rows = vars(metric_name), scales = "free_y", switch = "y") +
        scale_color_manual(values = hex[method_names]) +
        theme_light() +
        theme(strip.text.x = element_blank(),
              strip.text.y = element_text(size = 35, angle = 90, colour = "black"),
              axis.text.x = element_text(size = 30, angle = 90), axis.text.y = element_text(size = 30),
              axis.title.x = element_blank(), axis.title.y = element_blank(),
              strip.background = element_blank(), strip.placement = "outside",
              panel.spacing.y = unit(2, "lines"), panel.spacing.x = unit(1, "lines"),
              panel.border = element_blank(), axis.line = element_line()) +
        ggh4x::facetted_pos_scales(
            y = list(scale_y_continuous(limits = ylim_list[[i]], n.breaks = nbreaks[[i]]))
        )
}
subplots_no_bdg <- lapply(subplots, function(x) x + theme(legend.position = "none"))
pcomb <- grid.arrange(grobs = subplots_no_bdg, nrow = 1)

# Return the legend ("guide-box" grob) of a ggplot; errors clearly if there is none.
get_only_legend <- function(plot) {
    plot_table <- ggplot_gtable(ggplot_build(plot))
    grob_names <- vapply(plot_table$grobs,
                         function(x) if (is.null(x$name)) NA_character_ else as.character(x$name)[1],
                         character(1))
    legend_idx <- which(grob_names == "guide-box")
    if (length(legend_idx) == 0) {
        stop("get_only_legend(): no 'guide-box' grob found; does the plot have a legend?",
             call. = FALSE)
    }
    plot_table$grobs[[legend_idx[1]]]
}

# extract legend from the first subplot
legend <- get_only_legend(subplots[[1]])

# Close the PDF device even if drawing fails.
pdf(file.path(folder_path, file_name), width = 30, height = 8)
invisible(tryCatch(
    grid.arrange(pcomb, legend, ncol = 2, widths = c(10, 1)),
    finally = dev.off()
))

                              