# Install Libraries
Installation instructions can be found [here](https://satijalab.org/seurat/articles/install.html).\
The Seurat guided clustering (PBMC 3k) vignette is [here](https://satijalab.org/seurat/articles/pbmc3k_tutorial.html).

In [None]:
#Seurat parameters
# (Kept in papermill's `name = value` form so values can be overridden.)
reference_data = "path-to-seurat-object" # Path to the .rds file holding the reference Seurat object
query_data = "path-to-seurat-object" # Path to the .rds file holding the query Seurat object

genome = "genome-name" # either hg38 or mm10

normalization_method = "LogNormalize"
normalization_scale_factor = 10000

variable_features_method = "vst"
variable_features_num = 2000

weight_reduction = "pca" # Dimensional reduction to use for the weighting anchors.
n_dims = 30 # Set of dimensions to use in the anchor weighting procedure. If NULL, the same dimensions that were used to find anchors will be used for weighting.

prefix = "prefix" #project name

# Number of parallel workers passed to future::plan() in the setup cell.
# It must have a default here; otherwise plan() fails with an undefined
# variable unless papermill injects it.
threads = 1

#Papermill specific parameters
papermill = TRUE

#jupyter notebook plot sizes
options(repr.plot.width=20, repr.plot.height=15)

In [None]:
# The flag may arrive as text (e.g. "TRUE"/"true") when injected; coerce it so
# later `if (papermill)` checks see a logical (NA if the text is unparseable).
papermill <- as.logical(x = papermill)

In [None]:
# Install any missing dependencies, attach them, and configure parallelism.

# CRAN packages. `grid` ships with base R, so it never needs installing.
# patchwork and cowplot are attached below, so they are checked here as well.
cran_packages <- c("Seurat", "future", "logr", "dplyr", "gridExtra",
                   "ggplot2", "patchwork", "cowplot", "BiocManager")
missing_cran <- cran_packages[!vapply(cran_packages, requireNamespace,
                                      logical(1), quietly = TRUE)]
if (length(missing_cran) > 0) {
    install.packages(missing_cran)
}

# Bioconductor Ensembl annotation databases (gene ID -> symbol lookup).
bioc_packages <- c("EnsDb.Mmusculus.v79", "EnsDb.Hsapiens.v86")
missing_bioc <- bioc_packages[!vapply(bioc_packages, requireNamespace,
                                      logical(1), quietly = TRUE)]
if (length(missing_bioc) > 0) {
    BiocManager::install(missing_bioc)
}

suppressMessages(library(Seurat))
suppressMessages(library(future))
suppressMessages(library(logr))
suppressMessages(library(dplyr))
suppressMessages(library(grid))
suppressMessages(library(gridExtra))
suppressMessages(library(ggplot2))
suppressMessages(library(patchwork))
# ggsankey is not on CRAN and is only needed for Sankey plotting code, so
# attach it when available instead of aborting the whole notebook.
if (requireNamespace("ggsankey", quietly = TRUE)) {
    suppressMessages(library(ggsankey))
}
suppressMessages(library(cowplot))
suppressMessages(library(EnsDb.Mmusculus.v79))
suppressMessages(library(EnsDb.Hsapiens.v86))

# NOTE(review): this only creates a global variable; it does not configure
# future's RNG handling. Confirm whether an option was intended.
future.seed <- TRUE

# `threads` is expected from the parameters cell (or papermill injection);
# fall back to a single worker when it is not defined.
n_workers <- if (exists("threads", inherits = FALSE)) threads else 1
plan("multisession", workers = n_workers)
options("logr.notes" = FALSE)
options(future.globals.maxSize = 10e9)
set.seed(1234)

In [None]:
# Function to convert gene ID to symbol
#
# Rebuilds a Seurat object whose RNA rownames are Ensembl gene IDs so that the
# rows are gene symbols instead.
#   obj    - Seurat object; rownames are looked up as Ensembl GENEIDs
#            (assumed unversioned, e.g. "ENSG00000141510" -- TODO confirm)
#   genome - "hg38" (EnsDb.Hsapiens.v86) or "mm10" (EnsDb.Mmusculus.v79)
# Returns a new Seurat object built from the RNA counts of genes that have a
# non-empty symbol. Duplicated symbols are de-duplicated with make.unique(),
# and the original meta.data is carried over.
# Stops with an error for any other genome.
create_seurat_obj_with_gene_symbol <- function(obj, genome){

    # Pick the annotation database. Fail loudly on unsupported genomes rather
    # than later with "object 'gene.id' not found".
    ensdb <- switch(genome,
                    hg38 = EnsDb.Hsapiens.v86,
                    mm10 = EnsDb.Mmusculus.v79,
                    stop("Unsupported genome '", genome,
                         "'; expected 'hg38' or 'mm10'.", call. = FALSE))

    # get gene symbol
    gene.id <- ensembldb::select(ensdb,
                                 keys = rownames(obj),
                                 keytype = "GENEID",
                                 columns = c("SYMBOL", "GENEID"))

    # remove genes with a missing or empty symbol
    gene.id <- gene.id[!is.na(gene.id$SYMBOL) & gene.id$SYMBOL != "", ,
                       drop = FALSE]

    # make gene symbols unique (sep = "" turns "Gene", "Gene" into "Gene", "Gene1")
    gene.id$Unique_SYMBOL <- make.unique(gene.id$SYMBOL, sep = "")

    # GetAssayData works for both v3/v4 Assay and v5 Assay5 objects; Assay5
    # has no @counts slot.
    counts <- GetAssayData(obj, assay = "RNA", slot = "counts")
    # drop = FALSE keeps a matrix even if only one gene survives.
    counts <- counts[gene.id$GENEID, , drop = FALSE]
    rownames(counts) <- gene.id$Unique_SYMBOL

    CreateSeuratObject(counts = counts, meta.data = obj@meta.data)
}

In [None]:
#Function to save plots
# Directory that receives every PNG written by printPNG().
plot_filename <- paste0(prefix, ".rna.seurat.annotation.plots.", genome)
dir.create(plot_filename, showWarnings = FALSE)

# Save a plot as <plot_filename>/<prefix>.rna.seurat.annotation.<name>.<genome>.png
#   name       - label inserted into the file name
#   plotObject - plot object handed to ggsave()
#   papermill  - the file is written only when this is TRUE
#   wf, hf     - width / height passed to ggsave() (ggsave's default units)
# Returns the target path invisibly, whether or not the file was written.
printPNG <- function(name, plotObject, papermill, wf = 22, hf = 11){
    filename <- paste0(plot_filename, "/", prefix, ".rna.seurat.annotation.",
                       name, ".", genome, ".png")
    if (papermill) {
        ggsave(plot = plotObject, filename = filename, width = wf, height = hf)
    }
    invisible(filename)
}

# Create the run's log file, <prefix>.rna.seurat.annotation.logfile.<genome>.txt,
# and open it with logr. The handle is kept in `lf`; log_print() writes to it.
logfile <- paste0(prefix, ".rna.seurat.annotation.logfile.", genome, ".txt")
lf <- log_open(file_name = logfile)

In [None]:
# Read reference data: load the Seurat object stored at `reference_data`.
# Any error is written to the log instead of stopping the notebook.
tryCatch({
    log_print("# Reading reference data...")
    obj.ref <- readRDS(file = reference_data)
    log_print("SUCCESSFUL: Reading reference data")
}, error = function(err) {
    log_print("ERROR: Reading reference data")
    log_print(err)
})

In [None]:
# Read query data: load the Seurat object stored at `query_data`.
# Any error is written to the log instead of stopping the notebook.
tryCatch({
    log_print("# Reading query data...")
    obj.query <- readRDS(file = query_data)
    log_print("SUCCESSFUL: Reading query data")
}, error = function(err) {
    log_print("ERROR: Reading query data")
    log_print(err)
})

In [None]:
# Convert gene ID to symbol for reference data: rebuild obj.ref with gene
# symbols as rownames, using the EnsDb chosen by `genome`.
# Errors are logged and execution continues.
tryCatch({
    log_print("# Converting gene id to symbol for reference data")
    obj.ref <- create_seurat_obj_with_gene_symbol(genome = genome, obj = obj.ref)
    log_print("SUCCESSFUL: Converting gene id to symbol for reference data")
}, error = function(err) {
    log_print("ERROR: Converting gene id to symbol for reference data")
    log_print(err)
})

In [None]:
# Subset reference data: restrict BOTH the reference and the query objects to
# the genes (rownames) they have in common. Errors are logged, not raised.
tryCatch(
    {
        log_print("# Subsetting reference and query data to common genes")

        gene.common <- intersect(rownames(obj.ref), rownames(obj.query))
        log_print(paste0("Number of common genes: ", length(gene.common)))

        # An empty intersection usually means the two objects use different
        # gene identifiers. Report that clearly instead of an obscure
        # subset() failure.
        if (length(gene.common) == 0) {
            stop("Reference and query share no gene names; ",
                 "check that both use the same gene identifiers.",
                 call. = FALSE)
        }

        obj.ref <- subset(obj.ref, features = gene.common)
        obj.query <- subset(obj.query, features = gene.common)

        log_print("SUCCESSFUL: Subsetting reference and query data")
    },
    error = function(cond) {
        log_print("ERROR: Subsetting reference and query data")
        log_print(cond)
    }
)

In [None]:
# Predict labels for query dataset.
# Normalize both objects and pick variable features using the notebook
# parameters. Then find CCA transfer anchors and transfer the reference
# `cell_type` labels onto the query. The resulting predicted.id /
# prediction.score.* columns are added to obj.query's metadata and are also
# kept in `predictions`. Errors are logged, not raised.
tryCatch(
    {
        log_print("# Predicting labels for query data")

        # Use the normalization parameters so both datasets are processed
        # identically and as configured.
        obj.ref <- obj.ref %>%
            NormalizeData(normalization.method = normalization_method,
                          scale.factor = normalization_scale_factor,
                          verbose = FALSE) %>%
            FindVariableFeatures(selection.method = variable_features_method,
                                 nfeatures = variable_features_num)

        obj.query <- obj.query %>%
            NormalizeData(normalization.method = normalization_method,
                          scale.factor = normalization_scale_factor,
                          verbose = FALSE) %>%
            FindVariableFeatures(selection.method = variable_features_method,
                                 nfeatures = variable_features_num)

        transfer.anchors <- FindTransferAnchors(
            reference = obj.ref,
            query = obj.query,
            reduction = "cca",
            verbose = FALSE
        )

        # Use the query's own DimReduc named `weight_reduction` when it exists.
        # Otherwise pass the name through so TransferData can interpret it
        # (e.g. "pcaproject", "cca").
        weight.reduc <- if (weight_reduction %in% Reductions(obj.query)) {
            obj.query[[weight_reduction]]
        } else {
            weight_reduction
        }

        # n_dims = NULL means "reuse the dimensions used to find anchors",
        # which is TransferData's own behaviour for dims = NULL.
        weight.dims <- if (is.null(n_dims)) NULL else seq_len(n_dims)

        predictions <- TransferData(anchorset = transfer.anchors,
                                    refdata = obj.ref$cell_type,
                                    weight.reduction = weight.reduc,
                                    dims = weight.dims,
                                    verbose = FALSE)

        obj.query <- AddMetaData(obj.query, metadata = predictions)

        log_print("SUCCESSFUL: Predicting labels for query data")
    },
    error = function(cond) {
        log_print("ERROR: Predicting labels for query data")
        log_print(cond)
    }
)

In [None]:
## Plotting
# Draw two labelled DimPlots of the query side by side: existing
# `seurat_clusters` on the left, transferred `predicted.id` on the right.
# Save them via printPNG(). Errors are logged, not raised.
tryCatch({
    log_print("# Plotting predicted labels")

    p1 <- DimPlot(object = obj.query,
                  group.by = "seurat_clusters",
                  label = TRUE,
                  label.size = 5,
                  repel = TRUE)
    p2 <- DimPlot(object = obj.query,
                  group.by = "predicted.id",
                  label = TRUE,
                  label.size = 5,
                  repel = TRUE)
    p <- p1 + p2

    printPNG(plotObject = p, name = "predicted.labels",
             papermill = papermill, hf = 6, wf = 15)

    log_print("SUCCESSFUL: Plotting predicted labels")
}, error = function(err) {
    log_print("ERROR: Plotting predicted labels")
    log_print(err)
})

In [None]:
# options(repr.plot.height = 6 ,repr.plot.width = 12)

# df <- obj.query@meta.data %>%
#   make_long(seurat_clusters, predicted.id)

# ggplot(df, aes(x = x, 
#                next_x = next_x, 
#                node = node, 
#                next_node = next_node,
#                fill = factor(node),
#               label = node)) +
#       geom_sankey() +
#     geom_sankey_label(size = 6, color = "white", fill = "gray40") +
#       theme_sankey(base_size = 12) +
#         xlab("") +
#     theme(legend.position = "none",
#         plot.title = element_text(hjust = .5))

In [None]:
## Plotting
# One violin plot per reference cell type, showing its prediction score
# across query clusters. Each plot is saved via printPNG().
options(repr.plot.height = 4, repr.plot.width = 6)

tryCatch(
    {
        log_print("# Plotting cell-type-specific score")

        # The TransferData output is expected to hold predicted.id, one
        # prediction.score.<type> column per reference cell type, and
        # prediction.score.max. Select the per-type score columns by name
        # rather than by position. Positional selection depends on column
        # order and on `:` binding tighter than `-`.
        score.cols <- grep("^prediction\\.score\\.", colnames(predictions),
                           value = TRUE)
        features <- setdiff(score.cols, "prediction.score.max")

        for (feature in features) {
            p <- VlnPlot(obj.query, features = feature, pt.size = 0,
                         group.by = "seurat_clusters", y.max = 1.0)
            printPNG(name = feature, plotObject = p, papermill = papermill,
                     wf = 6, hf = 4)
        }

        log_print("SUCCESSFUL: Plotting cell-type-specific score")
    },
    error = function(cond) {
        log_print("ERROR: Plotting cell-type-specific score")
        log_print(cond)
    }
)