# Install Libraries
Installation instructions for Seurat can be found [here](https://satijalab.org/seurat/articles/install.html)\
The Seurat mapping and label-transfer vignette is available [here](https://satijalab.org/seurat/articles/integration_mapping.html)

In [None]:
# ---- Seurat parameters ----
# Paths to the serialized Seurat objects
reference_data <- "path-to-seurat-object" # reference data
query_data <- "path-to-seurat-object" # query data

# Genome build: either hg38 or mm10
genome <- "hg38"

# Normalization settings
normalization_method <- "LogNormalize"
normalization_scale_factor <- 10000

# Variable feature selection settings
variable_features_method <- "vst"
variable_features_num <- 2000

# Name of the dimensional reduction to use for weighting the anchors
weight.reduction <- "pca"
# Number of leading dimensions (1:n_dims) to use in the anchor weighting procedure
n_dims <- 30

threads <- 8
prefix <- "prefix" # project name

# ---- Papermill specific parameters ----
papermill <- TRUE

In [None]:
# #########################
# # For test
# reference_data = "../../../ReferenceData/BrainAgingSpatialAtlas_snRNAseq.rds"
# query_data = "../../../QueryData/MouseBrain/SS-PKR-129-192-PLATE1-LEFT-HALF.rna.seurat.filtered_rds.mm10.rds"
# genome = "mm10"

In [None]:
# Coerce the papermill flag to a logical value; as.logical() also accepts
# strings such as "TRUE"/"false" (presumably in case the parameter is injected
# as text -- confirm against how the notebook is launched).
papermill <- as.logical(papermill)

In [None]:
suppressMessages(library(Seurat))
suppressMessages(library(future))
suppressMessages(library(logr))
suppressMessages(library(dplyr))
suppressMessages(library(grid))
suppressMessages(library(gridExtra))
suppressMessages(library(ggplot2))
suppressMessages(library(patchwork))
suppressMessages(library(cowplot))
suppressMessages(library(EnsDb.Mmusculus.v79))
suppressMessages(library(EnsDb.Hsapiens.v86))

In [None]:
# Session-wide settings: turn off logr notes and set the size limit (10e9) for
# globals handled by the future framework.
options(
    logr.notes = FALSE,
    future.globals.maxSize = 10e9
)

# Fixed seed for the random number generator
set.seed(1234)

In [None]:
# Function to convert gene ID to symbol
#
# Rebuilds a Seurat object so that its features are named by gene symbol
# instead of Ensembl gene ID.
#
# Args:
#   object: Seurat object whose RNA assay counts use gene IDs as row names.
#   genome: Genome build. "hg38"/"hg37" are looked up in EnsDb.Hsapiens.v86,
#           "mm10"/"mm9" in EnsDb.Mmusculus.v79.
#
# Returns:
#   A new Seurat object created from the RNA counts of the genes that have a
#   non-empty symbol, with unique gene symbols as row names and the original
#   cell metadata attached.
#
# Errors:
#   Stops if `genome` is not one of the supported builds, or if none of the
#   gene IDs maps to a non-empty symbol.
create_seurat_obj_with_gene_symbol <- function(object, genome){

    # pick the annotation database for the genome build; anything else is
    # rejected up front instead of failing later on an undefined variable
    ensdb <- switch(as.character(genome),
        hg38 = ,
        hg37 = EnsDb.Hsapiens.v86,
        mm10 = ,
        mm9 = EnsDb.Mmusculus.v79,
        stop("Unsupported genome '", genome,
             "'; expected one of hg38, hg37, mm10, mm9", call. = FALSE)
    )

    # get gene symbol
    gene.id <- ensembldb::select(ensdb,
                                 keys = rownames(object),
                                 keytype = "GENEID",
                                 columns = c("SYMBOL", "GENEID"))

    # remove genes with empty (or missing) symbol
    has_symbol <- !is.na(gene.id$SYMBOL) & gene.id$SYMBOL != ""
    gene.id <- gene.id[has_symbol, , drop = FALSE]

    if (nrow(gene.id) == 0) {
        stop("None of the gene IDs could be mapped to a gene symbol for genome '",
             genome, "'", call. = FALSE)
    }

    # make gene symbol unique (sep = "" appends the numeric suffix directly)
    gene.id$Unique_SYMBOL <- make.unique(gene.id$SYMBOL, "")

    # keep only the mapped genes; drop = FALSE preserves the matrix shape even
    # when a single gene remains
    counts <- object@assays$RNA@counts
    counts <- counts[gene.id$GENEID, , drop = FALSE]
    rownames(counts) <- gene.id$Unique_SYMBOL

    CreateSeuratObject(counts = counts, meta.data = object@meta.data)
}

In [None]:
# Directory that receives all plots, named from the project prefix and genome.
# showWarnings = FALSE keeps dir.create() quiet when it would otherwise warn
# (e.g. the directory already exists).
plot_filename <- glue::glue("{prefix}.rna.cell.annotation.plots.{genome}")
dir.create(plot_filename, showWarnings = FALSE)

# Function to save plots
#
# Args:
#   name:      Short plot identifier inserted into the output file name.
#   plot:      Plot object handed to ggsave().
#   papermill: Logical; the file is written only when this is TRUE.
#   width:     Width passed to ggsave() (default 22).
#   height:    Height passed to ggsave() (default 11).
#
# The file is written as
# "<plot_filename>/<prefix>.rna.cell.annotation.<name>.<genome>.png", where
# plot_filename, prefix and genome are read from the enclosing environment.
printPNG <- function(name, plot, papermill, width = 22, height = 11){
    filename <- glue::glue("{plot_filename}/{prefix}.rna.cell.annotation.{name}.{genome}.png")

    if(papermill){
        ggsave(plot = plot, filename = filename, width = width, height = height)
    }
}

# Open the run log; its file name combines the project prefix and the genome
logfile <- glue::glue("{prefix}.rna.cell.annotation.logfile.{genome}.txt")
logfile <- file.path(logfile)
lf <- log_open(logfile)

In [None]:
# Load the reference object from `reference_data`.
# A failure is written to the log and is not re-raised, so execution continues.
tryCatch(
    expr = {
        log_print("# Reading reference data...")
        obj.ref <- readRDS(file = reference_data)
        log_print("SUCCESSFUL: Reading reference data")
    },
    error = function(err) {
        log_print("ERROR: Reading reference data")
        log_print(err)
    }
)

In [None]:
# Load the query object from `query_data`.
# A failure is written to the log and is not re-raised, so execution continues.
tryCatch(
    expr = {
        log_print("# Reading query data...")
        obj.query <- readRDS(file = query_data)
        log_print("SUCCESSFUL: Reading query data")
    },
    error = function(err) {
        log_print("ERROR: Reading query data")
        log_print(err)
    }
)

In [None]:
# Replace the reference object with a copy whose features are gene symbols.
# A failure is written to the log and is not re-raised.
tryCatch(
    expr = {
        log_print("# Converting gene id to symbol for reference data")
        obj.ref <- create_seurat_obj_with_gene_symbol(obj.ref, genome = genome)
        log_print("SUCCESSFUL: Converting gene id to symbol for reference data")
    },
    error = function(err) {
        log_print("ERROR: Converting gene id to symbol for reference data")
        log_print(err)
    }
)

In [None]:
# Restrict both the reference and the query object to the genes they share.
# A failure is written to the log and is not re-raised.
tryCatch(
    expr = {
        log_print("# Subseting reference and query data with common genes")

        # features present in both objects
        gene.common <- intersect(x = rownames(obj.ref), y = rownames(obj.query))

        obj.ref <- subset(x = obj.ref, features = gene.common)
        obj.query <- subset(x = obj.query, features = gene.common)

        log_print(
            glue::glue("# Found {length(gene.common)} common genes between reference and query data")
        )
        log_print("SUCCESSFUL: Subseting reference data")
    },
    error = function(err) {
        log_print("ERROR: Subseting reference data")
        log_print(err)
    }
)

In [None]:
# Predict labels for query dataset
#
# Normalizes both objects and selects variable features using the notebook
# parameters, finds transfer anchors (CCA), transfers the reference
# `cell_type` labels onto the query cells, stores the predictions in the query
# metadata and writes them to a CSV file.
# A failure is written to the log and is not re-raised.
tryCatch(
    {
        log_print("# Predicting labels for query data")

        # normalize with the configured method and scale factor
        obj.ref <- obj.ref %>%
            NormalizeData(normalization.method = normalization_method,
                          scale.factor = normalization_scale_factor,
                          verbose = FALSE) %>%
            FindVariableFeatures(selection.method = variable_features_method,
                                 nfeatures = variable_features_num)

        obj.query <- obj.query %>%
            NormalizeData(normalization.method = normalization_method,
                          scale.factor = normalization_scale_factor,
                          verbose = FALSE) %>%
            FindVariableFeatures(selection.method = variable_features_method,
                                 nfeatures = variable_features_num)

        transfer.anchors <- FindTransferAnchors(
            reference = obj.ref,
            query = obj.query,
            reduction = "cca",
            verbose = FALSE
        )

        # The weighting uses the query's own reduction named by
        # `weight.reduction`. A NULL n_dims is forwarded as NULL so that
        # TransferData picks the dimensions itself; 1:NULL would be an error.
        predictions <- TransferData(anchorset = transfer.anchors,
                                    refdata = obj.ref$cell_type,
                                    weight.reduction = obj.query[[weight.reduction]],
                                    dims = if (is.null(n_dims)) NULL else seq_len(n_dims),
                                    verbose = FALSE)

        obj.query <- AddMetaData(obj.query, metadata = predictions)

        # glue interpolates {genome}; a leading "$" would end up in the file name
        write.csv(predictions,
                  file = glue::glue("{prefix}.rna.cell.annotation.prediction.{genome}.csv"),
                  quote = FALSE)

        log_print("SUCCESSFUL: Predicting labels for query data")
    },
    error = function(cond) {
        log_print("ERROR: Predicting labels for query data")
        log_print(cond)
    }

)

In [None]:
## Plotting
# Side-by-side DimPlots of the query cells: grouped by cluster (left) and by
# predicted label (right). The combined figure goes through printPNG().
# A failure is written to the log and is not re-raised.
tryCatch(
    expr = {
        log_print("# Plotting predicted labels")

        # cells grouped by Seurat cluster
        p1 <- DimPlot(
            object = obj.query,
            group.by = "seurat_clusters",
            label = TRUE,
            label.size = 5,
            repel = TRUE
        )

        # cells grouped by transferred label
        p2 <- DimPlot(
            object = obj.query,
            group.by = "predicted.id",
            label = TRUE,
            label.size = 5,
            repel = TRUE
        )

        p <- p1 + p2

        printPNG(
            name = "predicted.labels",
            plot = p,
            papermill = papermill,
            width = 15,
            height = 6
        )

        log_print("SUCCESSFUL: Plotting predicted labels")
    },
    error = function(err) {
        log_print("ERROR: Plotting predicted labels")
        log_print(err)
    }
)

In [None]:
## Plotting
# Violin plots of the per-cell-type prediction scores, one facet per cluster
# (four facets per row). The figure goes through printPNG() with a height that
# grows with the number of facet rows.
# A failure is written to the log and is not re-raised.
tryCatch(
    {
        log_print("# Plotting predicted score per cluster")

        sel_cols <- grep("prediction.score|seurat_clusters", colnames(obj.query@meta.data), value = TRUE)
        # Exclude the per-cell maximum score that TransferData adds next to the
        # per-class scores. It is removed by name, not by position, so the
        # result does not depend on the metadata column order.
        sel_cols <- setdiff(sel_cols, "prediction.score.max")

        # long format: one row per (cell, cell type) score
        df <- obj.query@meta.data[, sel_cols, drop = FALSE] %>%
            tidyr::pivot_longer(cols = -seurat_clusters,
                                names_to = "celltype",
                                values_to = "score")

        # strip the literal column prefix (fixed() avoids treating "." as a regex wildcard)
        df$celltype <- stringr::str_replace_all(df$celltype,
                                                stringr::fixed("prediction.score."),
                                                "")

        p <- ggplot(df, aes(x = celltype, y = score)) +
             geom_violin(aes(fill = celltype), scale = "width") +
             facet_wrap(~seurat_clusters, ncol = 4) +
             theme_cowplot() +
             xlab("") + ylab("Predicted score") +
             theme(axis.text.x = element_text(angle = 60, hjust = 1),
                  legend.position = "none",
                  plot.title = element_text(hjust = 0.5))

        # decide figure size
        n_clusters <- length(unique(df$seurat_clusters))
        n_rows <- ceiling(n_clusters / 4)

        printPNG(name = "predicted.scores", plot = p, papermill = papermill,
                 width = 3*4, height = 3*n_rows + 2)

        log_print("SUCCESSFUL: Plotting predicted score per cluster")
    },
    error = function(cond) {
        log_print("ERROR: Plotting predicted score per cluster")
        log_print(cond)
    }

)