# Install Libraries
Installation instructions for Seurat can be found [here](https://satijalab.org/seurat/articles/install.html).\
The Seurat reference-mapping (label transfer) vignette is [here](https://satijalab.org/seurat/articles/integration_mapping.html).

In [None]:
# Default parameter values; papermill may inject overrides after this cell.

# Input data
reference_data_id <- "id-of-reference-data"  # dataset ID passed to cellxgene.census::download_source_h5ad()
query_data <- "path-to-seurat-object"        # path to an RDS file holding the query Seurat object

# Reference metadata column whose values are transferred to the query
reference_label <- "cell_type_annot"

# Genome build: one of hg38, hg37, mm10, mm9 (selects the EnsDb used for ID conversion)
genome <- "mm10"
gene_id_to_symbol <- TRUE  # convert reference gene IDs to symbols before matching the query

prefix <- "prefix"  # project name, used as the prefix of every output file

# Papermill-specific parameters
papermill <- TRUE

In [None]:
source(file = "/usr/local/bin/cell_annotation_helper_functions.R")  # pipeline helper functions

In [None]:
# #########################
# For test
# source("/data/pinello/PROJECTS/2023_02_SHARE_Pipeline/epi-SHARE-seq-pipeline/src/R/cell_annotation_helper_functions.R")
# reference_data <- "reference.h5ad"
# query_data <- "../../../QueryData/MouseBrain/SS-PKR-129-192-PLATE1-LEFT-HALF.rna.seurat.filtered_rds.mm10.rds"
# genome <- "mm10"

In [None]:
# The flag may arrive as a string ("TRUE"/"FALSE"); normalise it to a single
# logical. Fail fast on anything else: printPNG() evaluates `if (papermill)`,
# and an NA or multi-value flag would otherwise error inside every plotting cell.
papermill <- as.logical(papermill)
if (length(papermill) != 1L || is.na(papermill)) {
    stop("Parameter 'papermill' must be a single TRUE/FALSE value", call. = FALSE)
}

In [None]:
suppressMessages(library(Seurat))
suppressMessages(library(anndata))
suppressMessages(library(reticulate))
suppressMessages(library(Matrix))
suppressMessages(library(future))
suppressMessages(library(logr))
suppressMessages(library(dplyr))
suppressMessages(library(grid))
suppressMessages(library(gridExtra))
suppressMessages(library(ggplot2))
suppressMessages(library(patchwork))
suppressMessages(library(cowplot))
suppressMessages(library(EnsDb.Mmusculus.v79))
suppressMessages(library(EnsDb.Hsapiens.v86))
suppressMessages(library(cellxgene.census))

In [None]:
# Session-wide settings: turn off logr notes, raise future's limit on the
# size of exported globals, and fix the RNG seed for reproducibility.
options(
    logr.notes = FALSE,
    future.globals.maxSize = 10e9
)
set.seed(seed = 1234)

In [None]:
#Function to save plots
plot_filename = glue::glue("{prefix}.rna.cell.annotation.plots.{genome}")
dir.create(plot_filename, showWarnings=F)

printPNG <- function(name, plot, papermill, width = 22, height = 11){
    filename = glue::glue("{plot_filename}/{prefix}.rna.cell.annotation.{name}.{genome}.png")
    
    if(papermill){
        ggsave(plot = plot, filename = filename, width = width, height = height)
    }
}

#Create log file
logfile <- file.path(glue::glue("{prefix}.rna.cell.annotation.logfile.{genome}.txt"))
lf <- log_open(logfile)

In [None]:
# Download reference data based on ID
# Download the reference H5AD identified by `reference_data_id` to
# ./reference.h5ad. Failures are logged rather than re-raised, so the notebook
# keeps running and later cells log their own errors.
tryCatch(
    expr = {
        log_print("# Downloading reference data...")
        cellxgene.census::download_source_h5ad(
            reference_data_id,
            file = "reference.h5ad"
        )
        log_print("SUCCESSFUL: downloading reference data")
    },
    error = function(err) {
        log_print("ERROR: downloading reference data")
        log_print(err)
    }
)

In [None]:
# Loading H5AD file
# Load reference.h5ad and build the reference Seurat object (genes x cells
# counts plus the cell-level obs metadata). Errors are logged, not re-raised.
tryCatch(
    {
        log_print("# Loading reference data...")
        adata <- read_h5ad("reference.h5ad")

        # Prefer raw counts from adata$raw. Some H5AD files (e.g. when X
        # already stores raw counts) have no raw slot, so fall back to X
        # instead of crashing on NULL$X.
        if (!is.null(adata$raw)) {
            raw_x <- adata$raw$X
            gene_names <- adata$raw$var_names
        } else {
            raw_x <- adata$X
            gene_names <- adata$var_names
        }

        # Keep the matrix sparse: as.matrix() densified the full
        # cells x genes matrix, which can exhaust memory for large references.
        # Transpose to genes x cells as Seurat expects.
        counts <- Matrix::t(methods::as(raw_x, "CsparseMatrix"))
        colnames(counts) <- adata$obs_names
        # raw may hold a different gene set than X, so use the var_names that
        # belong to the matrix actually used.
        rownames(counts) <- gene_names

        metadata <- as.data.frame(adata$obs)

        obj.ref <- CreateSeuratObject(counts = counts, assay = "RNA")
        obj.ref <- AddMetaData(obj.ref, metadata)

        rm(counts, raw_x, gene_names)
        invisible(gc())

        log_print("SUCCESSFUL: loading reference data")
    },
    error = function(cond){
        log_print("ERROR: loading reference data")
        log_print(cond)
    }
)

In [None]:
# Read query data
# Load the query Seurat object from the RDS path in `query_data`.
# A read failure is logged and execution continues.
tryCatch(
    expr = {
        log_print("# Reading query data...")
        obj.query <- readRDS(file = query_data)
        log_print("SUCCESSFUL: Reading query data")
    },
    error = function(err) {
        log_print("ERROR: Reading query data")
        log_print(err)
    }
)

In [None]:
# Convert gene ID to symbol for reference data
# Optionally rename reference genes from IDs (EnsDb GENEID keys) to unique gene
# symbols so they can be matched against the query's feature names.
if (gene_id_to_symbol) {
    tryCatch(
        {
            log_print("# Converting gene id to symbol for reference data")

            # Pick the annotation database for the genome build. An unknown
            # genome used to leave gene.id undefined, so the cell failed with
            # an unrelated "object not found" error; fail explicitly instead.
            if (genome %in% c("hg38", "hg37")) {
                ensdb <- EnsDb.Hsapiens.v86
            } else if (genome %in% c("mm10", "mm9")) {
                ensdb <- EnsDb.Mmusculus.v79
            } else {
                stop(glue::glue("Unsupported genome '{genome}'; expected one of hg38, hg37, mm10, mm9"))
            }

            gene.id <- ensembldb::select(ensdb,
                                         keys = rownames(obj.ref),
                                         keytype = "GENEID",
                                         columns = c("SYMBOL", "GENEID"))

            # remove genes with a missing or empty symbol
            gene.id <- gene.id[!is.na(gene.id$SYMBOL) & gene.id$SYMBOL != "", ]

            # make gene symbols unique (duplicates get a suffix with no separator)
            gene.id$Unique_SYMBOL <- make.unique(gene.id$SYMBOL, "")

            counts <- obj.ref@assays$RNA@counts
            colnames(counts) <- colnames(obj.ref)
            rownames(counts) <- rownames(obj.ref)

            # drop = FALSE keeps a matrix even if only one gene maps
            counts <- counts[gene.id$GENEID, , drop = FALSE]
            rownames(counts) <- gene.id$Unique_SYMBOL

            obj.ref <- CreateSeuratObject(counts = counts,
                                          meta.data = obj.ref@meta.data)

            log_print("SUCCESSFUL: Converting gene id to symbol for reference data")
        },
        error = function(cond) {
            log_print("ERROR: Converting gene id to symbol for reference data")
            log_print(cond)
        }
    )
}

In [None]:
# Subset reference data
# Restrict reference and query to their shared genes, then preprocess each
# (normalize, variable features, scale, PCA, UMAP on 30 dims) for label transfer.
tryCatch(
    {
        log_print("# Subseting reference and query data with common genes")

        gene.common <- intersect(rownames(obj.ref), rownames(obj.query))
        # Log the overlap up front so it is recorded even if a later step fails.
        log_print(glue::glue("# Found {length(gene.common)} common genes between reference and query data"))

        # An empty overlap would otherwise fail deep inside CreateSeuratObject.
        if (length(gene.common) == 0) {
            stop("No common genes between reference and query data; check 'genome' and 'gene_id_to_symbol'")
        }

        # drop = FALSE keeps a matrix even when only one gene is shared
        counts <- obj.ref@assays$RNA@counts[gene.common, , drop = FALSE]
        obj.ref <- CreateSeuratObject(counts = counts,
                                      assay = "RNA",
                                      meta.data = obj.ref@meta.data)

        counts <- obj.query@assays$RNA@counts[gene.common, , drop = FALSE]
        obj.query <- CreateSeuratObject(counts = counts,
                                        assay = "RNA",
                                        meta.data = obj.query@meta.data)

        obj.ref <- obj.ref %>%
            NormalizeData(verbose = FALSE) %>%
            FindVariableFeatures() %>%
            ScaleData() %>%
            RunPCA(verbose = FALSE) %>%
            RunUMAP(verbose = FALSE, dims = 1:30)

        obj.query <- obj.query %>%
            NormalizeData(verbose = FALSE) %>%
            FindVariableFeatures() %>%
            ScaleData() %>%
            RunPCA(verbose = FALSE) %>%
            RunUMAP(verbose = FALSE, dims = 1:30)

        log_print("SUCCESSFUL: Subseting reference and query data with common genes")
    },
    error = function(cond) {
        log_print("ERROR: Subseting reference and query data with common genes")
        log_print(cond)
    }
)

In [None]:
# Predict labels for query dataset
# Transfer `reference_label` from the reference to the query via CCA anchors,
# add the predictions to the query metadata and write them to a CSV file.
tryCatch(
    {
        log_print("# Predicting labels for query data")

        # Validate the label column before the expensive anchor search.
        if (!(reference_label %in% colnames(obj.ref@meta.data))) {
            stop(glue::glue("Column '{reference_label}' not found in reference metadata"))
        }

        transfer.anchors <- FindTransferAnchors(
            reference = obj.ref,
            query = obj.query,
            reduction = "cca",
            verbose = TRUE
        )

        predictions <- TransferData(anchorset = transfer.anchors,
                                    refdata = obj.ref[[reference_label]][, 1],
                                    weight.reduction = obj.query[["pca"]],
                                    dims = 1:30,
                                    verbose = TRUE)

        obj.query <- AddMetaData(obj.query, metadata = predictions)

        # Quote fields: cell-type labels may contain commas
        # (e.g. "CD4-positive, alpha-beta T cell"), which broke the columns
        # of the CSV when it was written with quote = FALSE.
        write.csv(predictions,
                  file = glue::glue("{prefix}.rna.cell.annotation.prediction.{genome}.csv"),
                  quote = TRUE)

        log_print("SUCCESSFUL: Predicting labels for query data")
    },
    error = function(cond) {
        log_print("ERROR: Predicting labels for query data")
        log_print(cond)
    }
)

In [None]:
## Plotting
# UMAP of the query cells coloured by the transferred label (predicted.id),
# saved as a 6 x 6 PNG through printPNG(). Errors are logged, not re-raised.
tryCatch(
    expr = {
        log_print("# Plotting predicted labels")

        p <- DimPlot(
            obj.query,
            reduction = "umap",
            group.by = "predicted.id",
            label = TRUE,
            label.size = 5,
            repel = TRUE
        )

        printPNG(
            name = "predicted.labels",
            plot = p,
            papermill = papermill,
            width = 6,
            height = 6
        )

        log_print("SUCCESSFUL: Plotting predicted labels")
    },
    error = function(err) {
        log_print("ERROR: Plotting predicted labels")
        log_print(err)
    }
)

In [None]:
# ## Plotting
# tryCatch(
#     {
#         log_print("# Plotting predicted labels")
        
#         p1 <- DimPlot(obj.query, group.by = "seurat_clusters", label = TRUE, 
#                       label.size = 5, repel = TRUE)

#         p2 <- DimPlot(obj.query, group.by = "predicted.id", label = TRUE, 
#                       label.size = 5, repel = TRUE)
        
#         p <- p1 + p2
        
#         printPNG(name = "predicted.labels", plot = p, papermill = papermill, 
#                  width = 15, height = 6)
        
#         log_print("SUCCESSFUL: Plotting predicted labels")
        

#     },
#     error = function(cond) {
#         log_print("ERROR: Plotting predicted labels")
#         log_print(cond)
#     }

# )

In [None]:
# ## Plotting
# tryCatch(
#     {
#         log_print("# Plotting predicted score per cluster")
        
#         sel_cols <- grep("prediction.score|seurat_clusters", 
#                          colnames(obj.query@meta.data), value=TRUE)
#         sel_cols <- sel_cols[1:length(sel_cols) - 1]

#         df <- obj.query@meta.data %>%
#             subset(select = sel_cols) %>%
#             tidyr::gather(key = "celltype", value = "score", -seurat_clusters)
        
#         df$celltype <- stringr::str_replace_all(df$celltype, "prediction.score.", "")
        
#         p <- ggplot(df, aes(x = celltype, y = score)) +
#              geom_violin(aes(fill = celltype), scale = "width") +
#              facet_wrap(~seurat_clusters, ncol = 4) +
#              theme_cowplot() +
#              xlab("") + ylab("Predictied score") +
#              theme(axis.text.x = element_text(angle=60, hjust = 1),
#                   legend.position = "none",
#                   plot.title = element_text(hjust = 0.5)) 

#         # decide figure size
#         n_clusters <- length(unique(df$seurat_clusters))
#         n_rows <- ceiling(n_clusters / 4)
        
#         printPNG(name = "predicted.scores", plot = p, papermill = papermill, 
#                  width = 3*4, height = 3*n_rows + 2)
        
#         log_print("SUCCESSFUL: Plotting predicted score per cluster")
#     },
#     error = function(cond) {
#         log_print("ERROR: Plotting predicted score per cluster")
#         log_print(cond)
#     }

# )