In [None]:
import sys
import os
import random
import time
import scanpy as sc
import pandas as pd
import numpy as np
import pdb
import seaborn as sns

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D      
import matplotlib.cm as cm
from matplotlib.patches import Rectangle
import networkx as nx
from networkx.drawing.nx_agraph import graphviz_layout

%matplotlib inline
import umap

sys.path.insert(1, '../scripts/')
import FuncVizBench as vb

#sys.path.insert(2, '../scripts/scVitalPackage/src')
import scVital as sv

In [None]:
%matplotlib inline

In [None]:
# Matplotlib defaults for every figure in this notebook: keep text as text in
# SVG output, Arial at 10 pt, and transparent backgrounds when saving.
plt.rcParams.update({
    'svg.fonttype': 'none',
    'font.sans-serif': "Arial",
    'font.family': 'Arial',
    'font.size': 10,
    'savefig.transparent': True,
})

# Scanpy plotting parameters and the least verbose logging level.
sc.set_figure_params(scanpy=True, dpi=100, dpi_save=150, fontsize=10, format='svg')
sc.settings.verbosity = 0

In [None]:

def zScores(scoreMat, cutoff = 1.5):
    """Label each row (cell) with its highest-z-scoring column (signature).

    Every column of ``scoreMat`` is standardised on its own (population std,
    ``ddof=0``; NaNs skipped, as ``np.mean``/``np.std`` on a Series do).
    A row whose z-score exceeds ``cutoff`` in at least one column is labelled
    with the name of its highest-z column minus the last five characters
    (presumably a fixed suffix such as ``"Score"`` -- TODO confirm with callers).
    All other rows are labelled ``"out"``.

    Parameters
    ----------
    scoreMat : pandas.DataFrame
        Rows are cells, columns are signature scores.
    cutoff : float
        z-score a row must exceed in at least one column.

    Returns
    -------
    numpy.ndarray
        One string label per row, in row order.
    """
    # Column-wise standardisation, vectorised (one pass instead of per-row .loc,
    # which also failed when the index had duplicate labels).
    zMat = (scoreMat - scoreMat.mean()) / scoreMat.std(ddof=0)

    # Rows that pass the cutoff in at least one signature.
    passes = (zMat > cutoff).any(axis=1).to_numpy()

    labels = np.full(len(zMat), "out", dtype=object)
    if passes.any():
        # idxmax skips NaN (zero-variance columns) and returns the first maximum,
        # like the previous per-row Series argmax.
        best = zMat.loc[passes].idxmax(axis=1)
        labels[passes] = [str(col)[:-5] for col in best]
    return np.array(labels.tolist())
    
def makeAdata(dirName,datasetName,paramName):
    """Load one benchmark run's AnnData and attach its latents and cluster labels.

    Reads ``{dirName}/{datasetName}/{paramName}/vaeOut_EvalBench.h5ad``. Then:

    * every ``latents/<name>.csv`` is added as ``adata.obsm["X_<name>"]``;
    * the first data column of every ``clusters/<name>.csv`` is added, as
      strings, as ``adata.obs["<name>"]``.

    A missing ``latents``/``clusters`` directory is skipped.

    The batch and label keys are parsed from ``paramName``. Its 10th and 11th
    ``_``-separated tokens are expected to be ``batchName~<key>`` and
    ``cellName~<key>``.

    Returns
    -------
    (adata, batchName, labelName)
    """
    runDir = os.path.join(dirName, datasetName, paramName)
    adata = sc.read_h5ad(os.path.join(runDir, "vaeOut_EvalBench.h5ad"))

    # Parse from paramName itself. The previous code indexed component 5 of the
    # full path, which only matched paramName for one particular dirName depth.
    paramTokens = paramName.split("_")
    batchName = paramTokens[9].split("~")[1]
    labelName = paramTokens[10].split("~")[1]

    # Read each directory once. Previously everything was re-read once per
    # sub-directory of runDir.
    latentDir = os.path.join(runDir, "latents")
    if os.path.isdir(latentDir):
        for file in sorted(os.listdir(latentDir)):
            if file.endswith(".csv"):
                name = file.split(".")[0]
                adata.obsm[f"X_{name}"] = pd.read_table(os.path.join(latentDir, file), sep=",", header=None).to_numpy()

    clusterDir = os.path.join(runDir, "clusters")
    if os.path.isdir(clusterDir):
        for file in sorted(os.listdir(clusterDir)):
            if file.endswith(".csv"):
                name = file.split(".")[0]
                adata.obs[name] = pd.read_table(os.path.join(clusterDir, file), sep=",", index_col=0, dtype=str).iloc[:, 0].to_numpy()

    return(adata, batchName, labelName)

def getAllStats(dirName,datasetName,paramName):
    """Collect per-method benchmark statistics for one run into a stats x methods table.

    Each ``latents/<method>.csv`` contributes one column with four rows:

    * ``"Time\\n(min)  "``: the first value of column ``<method>`` in
      ``figures/scale_<method>.csv``, divided by 60. If that file does not
      exist, ``figures/scale_No_Integration.csv`` is used instead.
    * ``"ARI"`` and ``"FM"``: taken from ``figures/metrics_<method>.csv``.
    * ``"LSS"``: the first value of column ``X_<method>`` in
      ``figures/lssAucScore.csv``.

    The ``normal`` method is relabelled ``"No\\nintegration"`` and placed last.
    """
    runDir = f"{dirName}/{datasetName}/{paramName}"
    figDir = f"{runDir}/figures"
    # Only CSV files are latents; skip anything else (e.g. .ipynb_checkpoints).
    latents = [lat.split(".")[0] for lat in os.listdir(f"{runDir}/latents") if lat.endswith(".csv")]
    rowNames = ["Time\n(min)  ","ARI","FM","LSS"]

    lss = pd.read_csv(f"{figDir}/lssAucScore.csv",index_col=0)
    columns = {}
    for lat in latents:
        met = pd.read_csv(f"{figDir}/metrics_{lat}.csv",index_col=0)
        try:
            scale = pd.read_csv(f"{figDir}/scale_{lat}.csv",index_col=0)
        except FileNotFoundError:
            scale = pd.read_csv(f"{figDir}/scale_No_Integration.csv",index_col=0)
        columns[lat] = [
            scale[lat].iloc[0] / 60,
            met.loc["ARI", lat],
            met.loc["FM", lat],
            lss.loc[:, f"X_{lat}"].values[0],
        ]

    # Same layout as before: other methods in listing order, unintegrated baseline last.
    order = [lat for lat in latents if lat != "normal"]
    labels = list(order)
    if "normal" in columns:
        order.append("normal")
        labels.append("No\nintegration")
    allStats = pd.DataFrame({label: columns[lat] for label, lat in zip(labels, order)},
                            index=rowNames, dtype=float)
    return(allStats)

def getCmapValue(value, vals):
    """Map one statistic onto [0, 1] for colouring with a colormap.

    If ``round(max(vals)) <= 1``, ``value`` is returned unchanged. Otherwise
    (e.g. the run-time row) the row is min-max scaled and inverted, so the
    smallest value maps to 1 and the largest to 0. If all values in such a
    row are equal, 1.0 is returned.
    """
    vals = np.asarray(vals, dtype=float)
    maxVal = np.max(vals)
    if np.round(maxVal) > 1:
        # Scale with the true max. Using the rounded max (as before) pushed the
        # largest value below 0 and could push the smallest above 1.
        span = maxVal - np.min(vals)
        if span == 0:
            return 1.0
        return (maxVal - value) / span
    return value

def vizAllStats(allStats, h=2, w=4, name="", order=False, save=False):
    """Draw a stats x methods table as a coloured grid with each value printed in its cell.

    Cells are coloured row-wise with ``RdYlGn`` via ``getCmapValue``. Text is
    black or white depending on the luminance of the cell colour.

    Parameters
    ----------
    allStats : pandas.DataFrame
        Statistics (rows) x methods (columns), e.g. from ``getAllStats``.
    h, w : float
        Figure height and width in inches.
    name : str
        Suffix of the saved file name.
    order : bool
        If True, arrange columns in the fixed paper order. Methods missing
        from ``allStats`` are skipped instead of raising KeyError.
    save : str or False
        Directory to write ``allStats_<name>.svg`` into, or False.
    """
    if(order):
        paperOrder = ["No\nintegration","BBKNN","Harmony","scVital","scVI","scDREAMER"]
        allStats = allStats[[col for col in paperOrder if col in allStats.columns]]
    matrix = allStats.values
    xLabels = allStats.index     # statistics, drawn along the y axis
    yLabels = allStats.columns   # methods, drawn along the x axis
    nRows, nCols = matrix.shape

    RdYlGnCmap = matplotlib.colormaps['RdYlGn']
    fig, ax = plt.subplots(figsize=(w,h))

    for i in range(nRows):
        for j in range(nCols):
            value = matrix[i, j]
            color = RdYlGnCmap(getCmapValue(value, matrix[i, :]))
            ax.add_patch(Rectangle((j - 0.5, i - 0.5), 1, 1, facecolor=color, edgecolor='black'))
            # Luminance from the R, G, B channels (Rec. 709 weights). The old
            # code used the alpha channel color[3] in place of blue color[2].
            luminance = 0.2126*color[0] + 0.7152*color[1] + 0.0722*color[2]
            textColor = 'white' if luminance < 0.6 else 'black'
            ax.text(j, i, f"{value:.2f}", ha='center', va='center', color=textColor)

    # Ticks every half cell: labelled at cell centres (integer positions),
    # blank at cell edges.
    xPos = np.arange(-0.5, nCols, 0.5)
    yPos = np.arange(-0.5, nRows, 0.5)
    ax.set_xticks(xPos, [yLabels[k//2] if k%2 == 1 else "" for k in range(len(xPos))])
    ax.set_yticks(yPos, [xLabels[k//2] if k%2 == 1 else "" for k in range(len(yPos))])
    ax.tick_params(axis="x", labelrotation=70)
    ax.set_yticks([], minor=True)
    ax.set_xticks([], minor=True)

    ax.set_title("Integration Statistics")
    ax.grid(False)

    fig.tight_layout()
    if(save):
        fig.savefig(f"{save}/allStats_{name}.svg", format="svg")
    plt.show()


def plotOneUmap(title,x, y, c, edgecolors, name="", save=False, **kwargs):
    """Scatter-plot one 2-D embedding without axis labels or ticks.

    ``c`` gives the marker face colours and ``edgecolors`` the outline colours.
    Extra keyword arguments (e.g. ``linewidths``, ``s``, ``alpha``) are passed
    to ``Axes.scatter``.

    ``save`` was previously read from an undefined global, which raised
    NameError. When it is a directory, the figure is written to
    ``<save>/umap_<title without '/'>_<name>.png``.
    """
    fig, ax = plt.subplots()

    ax.scatter(x, y, c=c, edgecolors=edgecolors, **kwargs)
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)

    fig.tight_layout()
    if(save):
        # '/' in the title would otherwise be read as a path separator.
        safeTitle = title.replace("/", "")
        fig.savefig(f"{save}/umap_{safeTitle}_{name}.png", format="png")
    plt.show()


def makeLegend(ctVals, btVals, cellTypeColorDict, batchColorDict, name="", save=False):
    """Draw the shared batch/cell-type legend and map every cell to its colours.

    Parameters
    ----------
    ctVals, btVals : iterable
        Per-cell cell-type and batch keys.
    cellTypeColorDict : dict
        Cell-type key -> (legend label, colour, ...).
    batchColorDict : dict
        Batch key -> (legend label, colour) or (legend label, colour, marker).
        The marker defaults to ``"o"``; 2-tuples previously raised ValueError.
    name : str
        Suffix of the saved file name.
    save : str or False
        Directory to write ``legend_<name>.svg`` into, or False.

    Returns
    -------
    (legendEle, ctColors, btColors)
        Legend handles (batches, a blank spacer, then cell types), and numpy
        arrays with one colour per cell.
    """
    ctLegendEle = [Line2D([0], [0], color=entry[1], marker="o", lw=0, label=entry[0])
                   for entry in cellTypeColorDict.values()]
    spaceLegEle = [Line2D([0], [0], marker='o', lw=0, color='white', markeredgecolor='white', label="")]
    # Batch markers are hollow: white face, edge in the batch colour.
    btLegendEle = [Line2D([0], [0], color="white", lw=0,
                          marker=entry[2] if len(entry) > 2 else "o",
                          markeredgecolor=entry[1], label=entry[0])
                   for entry in batchColorDict.values()]

    ctColors = np.array([cellTypeColorDict[ct][1] for ct in ctVals])
    btColors = np.array([batchColorDict[bt][1] for bt in btVals])

    legendEle = btLegendEle + spaceLegEle + ctLegendEle

    fig, ax = plt.subplots()
    ax.legend(handles=legendEle)
    ax.axis('off')
    fig.tight_layout()
    if(save):
        fig.savefig(f"{save}/legend_{name}.svg", format="svg")
    plt.show()

    return(legendEle, ctColors, btColors)

def plotInteg(inUmaps, titles, ctColors, btColors, shuff):
    """Plot every embedding in ``inUmaps`` as its own figure via ``plotOneUmap``.

    ``shuff`` indexes the points and both colour arrays, so the drawing order
    can be chosen. ``titles[k]`` titles the k-th embedding.
    """
    for idx, emb in enumerate(inUmaps):
        plotOneUmap(
            titles[idx],
            x=emb[shuff, 0],
            y=emb[shuff, 1],
            c=ctColors[shuff],
            edgecolors=btColors[shuff],
        )

def plotCbar(title,name, norm, cmap, save=False):
    """Draw a standalone colorbar for ``norm``/``cmap`` on a figure with no visible axes.

    When ``save`` is a directory, writes
    ``<save>/cbarleg_<title without '/'>_<name>.svg``.
    """
    fig, ax = plt.subplots(1)
    mappable = cm.ScalarMappable(norm=norm, cmap=cmap)
    fig.colorbar(mappable, ax=ax)
    plt.axis('off')
    plt.tight_layout()
    plt.show()
    cleanTitle = title.replace("/", "")
    if save:
        fig.savefig(f"{save}/cbarleg_{cleanTitle}_{name}.svg", format="svg")


def df2StackBar(clustBatch, neighborsKey, colorLabelDict, label, save=False):
    """Stacked bar chart of cell counts per cluster, split by the ``label`` column.

    Parameters
    ----------
    clustBatch : pandas.DataFrame
        Per-cell table with at least the columns ``neighborsKey`` (cluster)
        and ``label`` (group, e.g. batch).
    neighborsKey : str
        Cluster column; clusters are placed along x in sorted order.
    colorLabelDict : dict
        Group value -> (legend label, colour, ...). Only the first two entries
        are used, so 3-tuple batch dicts work as well.
    label : str
        Column whose values form the stacks.
    save : str or False
        Directory to write ``stackBar_<neighborsKey>_<label>.svg`` into, or False.
    """
    fig, ax = plt.subplots(1, figsize = (4,4))

    clusters = np.unique(clustBatch[neighborsKey]).tolist()
    batches = np.unique(clustBatch[label]).tolist()

    # Count cells per (cluster, group) directly; absent combinations become 0.
    # Grouping on just these two columns keeps extra columns in clustBatch
    # from splitting the counts, and replaces the old bare-except lookup loop.
    if len(clustBatch) == 0:
        counts = pd.DataFrame(np.zeros((len(clusters), len(batches))), index=clusters, columns=batches)
    else:
        counts = (clustBatch.groupby([neighborsKey, label], observed=True).size()
                  .unstack(fill_value=0)
                  .reindex(index=clusters, columns=batches, fill_value=0))

    positions = range(len(clusters))
    bottom = np.zeros(len(clusters))
    for bat in counts.columns:
        vals = counts[bat].to_numpy(dtype=float)
        lab, color = colorLabelDict[bat][:2]
        ax.bar(positions, vals, bottom=bottom, label=lab, color=color)
        bottom = bottom + vals

    ax.set_title(f"# of Cells of each Cluster by {label}")
    ax.set_ylabel("# cells")
    ax.legend(loc='center right', bbox_to_anchor=(1.4,0.5))
    ax.grid(False)

    if(save):
        fig.savefig(f"{save}/stackBar_{neighborsKey}_{label}.svg", format="svg")
    plt.show()


def findPair(pairs, ctf, visited=None, colorDict=None): #modified with copilot
    """Follow ``pairs`` from label ``ctf`` to a partner label that has a colour.

    If a pair contains ``ctf`` and its other member is a key of ``colorDict``,
    that member is returned (first such pair in ``pairs`` order). Otherwise the
    search continues depth-first from the last unresolved partner seen on each
    side of a pair, never revisiting a label.

    Parameters
    ----------
    pairs : sequence of (label, label)
        Iterated once per visited label, so it must not be a one-shot iterator.
    ctf : label to resolve.
    visited : set, optional
        Labels already explored. It is now shared with the recursive calls;
        previously each call started a fresh set, so a cycle such as
        ``[(a, b), (b, a)]`` recursed until RecursionError.
    colorDict : dict, optional
        Labels that count as resolved. Defaults to the notebook-level
        ``annoToColorDict`` (the previous, implicit behaviour).

    Returns
    -------
    The partner label, or None if none is reachable.
    """
    if colorDict is None:
        colorDict = annoToColorDict
    if visited is None:
        visited = set()
    if ctf in visited:
        return None
    visited.add(ctf)

    check1 = None
    check2 = None
    for ct1, ct2 in pairs:
        if ctf == ct1:
            if ct2 in colorDict:
                return ct2
            check2 = ct2
        elif ctf == ct2:
            if ct1 in colorDict:
                return ct1
            check1 = ct1

    # Try both unresolved partners. The old code never tried check2 once
    # check1 was set, even when check1 led nowhere.
    for candidate in (check1, check2):
        if candidate is not None:
            found = findPair(pairs, candidate, visited, colorDict)
            if found is not None:
                return found
    return None

def getOverColors(ogLabel, overlabel, pairs, colorDict):
    """Return the colour of every label in ``ogLabel``.

    Labels found in ``overlabel`` are looked up in ``colorDict`` directly.
    Any other label is first mapped to a partner via ``findPair(pairs, label)``.
    The result is a copy of ``ogLabel`` with each entry replaced by its colour.
    """
    resolved = ogLabel.copy()
    for idx, label in enumerate(ogLabel):
        key = label if label in overlabel else findPair(pairs, label)
        resolved[idx] = colorDict[key]
    return resolved

def getOverColorDict(overlabel, pairs, colorLabelDict):
    """Ensure every label in ``overlabel`` has an entry in ``colorLabelDict``.

    A label without an entry gets ``(label, colour of its findPair partner)``.
    ``colorLabelDict`` is modified in place and also returned.
    """
    for label in overlabel:
        if label in colorLabelDict:
            continue
        partner = findPair(pairs, label)
        colorLabelDict[label] = (label, colorLabelDict[partner][1])
    return colorLabelDict

#written with co-pilot
def group_pairs(pairs):
    """Merge name pairs into groups of names linked directly or transitively.

    Parameters
    ----------
    pairs : iterable of (name, name)

    Returns
    -------
    list of list
        One list per connected group. Groups come in order of first
        appearance; the order of names within a group is not defined.
    """
    # name -> the set object of the group it currently belongs to. Every member
    # of every group is a key, so a dict lookup replaces the previous linear
    # scan over all groups for every pair.
    groupOf = {}

    for name1, name2 in pairs:
        group1 = groupOf.get(name1)
        group2 = groupOf.get(name2)

        if group1 is not None and group2 is not None:
            if group1 is not group2:
                # Merge the smaller group into the larger one and repoint its members.
                if len(group1) < len(group2):
                    group1, group2 = group2, group1
                group1.update(group2)
                for name in group2:
                    groupOf[name] = group1
        elif group1 is not None:
            group1.add(name2)
            groupOf[name2] = group1
        elif group2 is not None:
            group2.add(name1)
            groupOf[name1] = group2
        else:
            newGroup = {name1, name2}
            groupOf[name1] = newGroup
            groupOf[name2] = newGroup

    # Each group is referenced once per member; keep one reference per set object.
    uniqueGroups = {id(group): group for group in groupOf.values()}.values()
    return [list(group) for group in uniqueGroups]

def makeGraphLSS(clustDist, batchToColorDict, annoToColorDict, overlap, pairs, shape=False, qCut = 0.28, save=False):
    """Draw the cluster-similarity graph for a mouse/human run.

    Nodes are the columns of ``clustDist`` (named ``"<batch>~<cellType>"``).
    Two clusters are connected when their distance is below the ``qCut``
    quantile of all entries of ``clustDist``. Nodes whose name contains
    ``"human"`` are grey-edged circles; all others are black-edged squares.
    Fill colours come from
    ``getOverColors(cellTypes, overlap, pairs, annoToColorDict)``.
    Previously the node colours used the global ``adata`` instead of the
    ``overlap``/``pairs`` arguments.

    ``batchToColorDict`` and ``shape`` are kept for interface compatibility
    but are not used for drawing. When ``save`` is a directory, writes
    ``<save>/graphLLS_scVital.svg``.
    """
    fig, ax = plt.subplots()
    G = nx.Graph()
    G.add_nodes_from(clustDist.columns)

    cutoff = np.quantile(clustDist.to_numpy().flatten(), qCut)
    nClust = len(clustDist.columns)
    for i in range(nClust):
        for j in range(i + 1, nClust):
            if clustDist.iloc[i, j] < cutoff:
                G.add_edge(clustDist.columns[i], clustDist.index[j], weight=clustDist.iloc[i, j])

    isHuman = np.array(["human" in l for l in clustDist.columns], dtype=bool)
    mouseNodes = clustDist.columns[~isHuman]
    humanNodes = clustDist.columns[isHuman]
    mCtColors = getOverColors([label.split("~")[1] for label in mouseNodes], overlap, pairs, annoToColorDict)
    hCtColors = getOverColors([label.split("~")[1] for label in humanNodes], overlap, pairs, annoToColorDict)

    lw = 1
    pos = graphviz_layout(G)

    nx.draw_networkx_nodes(G, pos, nodelist=mouseNodes, edgecolors="black", node_color=mCtColors, node_shape='s', linewidths=lw, ax=ax)
    nx.draw_networkx_nodes(G, pos, nodelist=humanNodes, edgecolors="gray", node_color=hCtColors, node_shape='o', linewidths=lw, ax=ax)
    nx.draw_networkx_edges(G, pos, ax=ax)

    ax.axis("off")
    fig.tight_layout()
    if(save):
        fig.savefig(f"{save}/graphLLS_scVital.svg", format="svg")
    plt.show()



def makeGraphLSSMulti(clustDist, batchDict, annoToColorDict, overlap, pairs, name="", prog="neato", wLab=False, qCut = 0.28, save=False):
    """Draw the cluster-similarity graph for any number of batches.

    Nodes are the columns of ``clustDist`` (named ``"<batch>~<cellType>"``).
    Two clusters are connected when their distance is below the ``qCut``
    quantile of the strictly positive entries of ``clustDist``. Nodes are
    filled with their cell-type colour (via ``getOverColors``) and
    outlined/shaped per batch using
    ``batchDict[batch] = (label, colour[, marker])``; the marker defaults
    to ``"o"``.

    Parameters
    ----------
    prog : str
        Graphviz program used by ``graphviz_layout``.
    wLab : bool
        If True, also draw the node names (previously ignored).
    qCut : float
        Quantile of the positive distances used as the edge cutoff.
    save : str or False
        Directory to write ``multiLSS_<name>scVital.svg`` into, or False.
    """
    batchToColorDict = {lab: batchDict[lab][1] for lab in batchDict}
    batchToShapeDict = {lab: batchDict[lab][2] if len(batchDict[lab]) > 2 else "o" for lab in batchDict}

    fig, ax = plt.subplots()
    G = nx.Graph()
    G.add_nodes_from(clustDist.columns)

    allDists = clustDist.to_numpy().flatten()
    positive = allDists[allDists > 0]
    # With no positive distances there is nothing to threshold: draw no edges.
    cutoff = np.quantile(positive, qCut) if positive.size else -np.inf

    nClust = len(clustDist.columns)
    for i in range(nClust):
        for j in range(i + 1, nClust):
            if clustDist.iloc[i, j] < cutoff:
                G.add_edge(clustDist.columns[i], clustDist.index[j], weight=clustDist.iloc[i, j])
    pos = graphviz_layout(G, prog=prog)
    nx.draw_networkx_edges(G, pos, ax=ax)

    lw = 1
    # Batches in first-appearance order. A set iterates in an order that can
    # change between runs, which changed which node layer was drawn on top.
    batches = list(dict.fromkeys(cl.split("~")[0] for cl in clustDist.columns))
    for bat in batches:
        nodes = clustDist.columns[[bat == cl.split("~")[0] for cl in clustDist.columns]]
        ctColors = getOverColors([label.split("~")[1] for label in nodes], overlap, pairs, annoToColorDict)
        nx.draw_networkx_nodes(G, pos, nodelist=nodes, node_color=ctColors, node_size=150,
                               edgecolors=batchToColorDict[bat], node_shape=batchToShapeDict[bat],
                               linewidths=lw, ax=ax)
    if wLab:
        nx.draw_networkx_labels(G, pos, font_size=10, ax=ax)

    fig.suptitle(f"LSS Cut:{qCut}, Prog:{prog}")
    ax.axis("off")
    plt.tight_layout()
    if(save):
        fig.savefig(f"{save}/multiLSS_{name}scVital.svg", format="svg")

def plotQuadUMAP(scVitalUmap, pcaUmap, ctColors, btColors, reorder, name="", save=False, **kwargs):
    """2x2 UMAP panel: unintegrated (left, yellow tint) vs scVital (right, purple tint).

    The top row colours points by batch. The bottom row colours them by cell
    type, outlined in the batch colour. With ``reorder``, points are
    re-ordered by ``reorderColors`` (``#d3d3d3`` points first, so they are
    drawn underneath). Extra keyword arguments go to every ``Axes.scatter``
    call. When ``save`` is a directory, writes
    ``<save>/quadUMAP_<name>scVital.png`` at 300 dpi.
    """
    if reorder:
        perm = reorderColors(ctColors)
        ctColors, btColors = ctColors[perm], btColors[perm]
        scVitalUmap, pcaUmap = scVitalUmap[perm, :], pcaUmap[perm, :]

    fig, axs = plt.subplots(2,2, figsize=(5,5), dpi=300)

    # (row, col), embedding, face colours, background tint, tint alpha
    panels = [
        ((0, 0), pcaUmap, btColors, "#fced62", 0.1),
        ((1, 0), pcaUmap, ctColors, "#fced62", 0.1),
        ((0, 1), scVitalUmap, btColors, "#6b3b8a", 0.05),
        ((1, 1), scVitalUmap, ctColors, "#6b3b8a", 0.05),
    ]
    for (row, col), emb, faceCols, tint, tintAlpha in panels:
        panelAx = axs[row, col]
        panelAx.scatter(x=emb[:, 0], y=emb[:, 1], c=faceCols, edgecolors=btColors, **kwargs)
        panelAx.patch.set_facecolor(tint)
        panelAx.patch.set_alpha(tintAlpha)

    for panelAx in axs.flat:
        panelAx.set_xlabel('')
        panelAx.set_ylabel('')
        panelAx.set_xticks([])
        panelAx.set_yticks([])
        for side in ("top", "right", "left", "bottom"):
            panelAx.spines[side].set_visible(False)

    plt.tight_layout()
    plt.show()
    if save:
        fig.savefig(f"{save}/quadUMAP_{name}scVital.png", format="png",dpi=300, pad_inches=0)

def plotAllUMAP(scVitalUmap, pcaUmap, ctColors, btColors, reorder, name="", save=False, adata=None, **kwargs):
    """2x3 panel of UMAPs, one per integration method.

    The scVital and unintegrated embeddings are passed in. BBKNN is taken
    from ``adata.obsm["X_BBKNN"]`` as stored (presumably already 2-D; TODO
    confirm). Harmony, scVI and scDREAMER latents are embedded here with UMAP.

    Parameters
    ----------
    reorder : bool
        Re-order points with ``reorderColors`` (``#d3d3d3`` points drawn first).
    save : str or False
        Directory to write ``UMAPsAll_<name>integ.png`` into, or False.
    adata : AnnData, optional
        Source of the other methods' latents. Defaults to the notebook-level
        ``adata`` (the previous, implicit behaviour).
    **kwargs
        Scatter-style overrides applied on top of the defaults
        ``linewidths=0.2, s=10, alpha=0.7``. They were previously accepted
        but ignored.
    """
    if adata is None:
        adata = globals()["adata"]

    runUmap = umap.UMAP(n_neighbors=50, min_dist=0.5, random_state=42)
    bbknnUmap = adata.obsm["X_BBKNN"]
    harmUmap = runUmap.fit_transform(adata.obsm["X_Harmony"])
    scviUmap = runUmap.fit_transform(adata.obsm["X_scVI"])
    scDrUmap = runUmap.fit_transform(adata.obsm["X_scDREAMER"])

    embeddings = {"No Integration": pcaUmap, "BBKNN": bbknnUmap, "Harmony": harmUmap,
                  "scVI": scviUmap, "scVital": scVitalUmap, "scDREAMER": scDrUmap}
    if(reorder):
        newOrder = reorderColors(ctColors)
        ctColors = ctColors[newOrder]
        btColors = btColors[newOrder]
        embeddings = {title: emb[newOrder, :] for title, emb in embeddings.items()}

    style = {"linewidths": 0.2, "s": 10, "alpha": 0.7}
    style.update(kwargs)

    # Panel position and optional (background tint, tint alpha) for each method.
    layout = [("No Integration", (0, 0), ("#fced62", 0.1)),
              ("BBKNN", (0, 1), None),
              ("Harmony", (0, 2), None),
              ("scVI", (1, 0), None),
              ("scVital", (1, 1), ("#6b3b8a", 0.05)),
              ("scDREAMER", (1, 2), None)]

    fig, axs = plt.subplots(2,3, figsize=(9,6), dpi=300)
    for title, (row, col), tint in layout:
        ax = axs[row, col]
        emb = embeddings[title]
        ax.scatter(x=emb[:, 0], y=emb[:, 1], c=ctColors, edgecolors=btColors, **style)
        if tint is not None:
            ax.patch.set_facecolor(tint[0])
            ax.patch.set_alpha(tint[1])
        ax.set_title(title)

    for ax in axs.flat:
        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.set_xticks([])
        ax.set_yticks([])
        for side in ("top", "right", "left", "bottom"):
            ax.spines[side].set_visible(False)

    fig.tight_layout()
    if(save):
        fig.savefig(f"{save}/UMAPsAll_{name}integ.png", format="png",dpi=300, pad_inches=0)
    plt.show()


def reorderColors(ctColors, col="#d3d3d3", seed=42):
    """Return a drawing order: every point coloured ``col`` first, then the rest shuffled.

    Plotting in this order puts the ``col`` points (grey, "unlabeled" in the
    colour dicts below) underneath the others. The shuffle uses a private
    ``RandomState(seed)``. This produces the same permutation as the previous
    ``np.random.seed(seed)`` + ``np.random.shuffle`` but no longer reseeds
    numpy's global random generator as a side effect.

    Returns
    -------
    list of int
        A permutation of ``range(len(ctColors))``.
    """
    grayCols = []
    otherCols = []
    for i, color in enumerate(ctColors):
        if color == col:
            grayCols.append(i)
        else:
            otherCols.append(i)
    otherCols = np.array(otherCols, dtype=int)
    np.random.RandomState(seed).shuffle(otherCols)
    order = np.concatenate((np.array(grayCols, dtype=int), otherCols))
    return [int(g) for g in order]

def runUmapGetColors(adata, batchName, labelName, cellTypeDict=None, batchDict=None, name="", save=False):
    """Embed the scVital and PCA spaces with UMAP and resolve per-cell colours.

    If ``cellTypeDict`` / ``batchDict`` are not given, they are built (and
    printed) from the categories of ``adata.obs[labelName]`` /
    ``adata.obs[batchName]`` and the palettes in ``adata.uns["<key>_colors"]``.
    Auto-built batch entries get the marker ``"o"``, matching the
    ``(label, colour, marker)`` form that ``makeLegend`` unpacks.

    Returns
    -------
    (scVitalUmap, pcaUmap, legendEle, ctColors, btColors, batchToColorDict, annoToColorDict)
    """
    runUmap = umap.UMAP(n_neighbors=50, min_dist=0.5, random_state=42)
    scVitalUmap = runUmap.fit_transform(adata.obsm["X_scVital"])
    pcaUmap = runUmap.fit_transform(adata.obsm["X_pca"])

    if cellTypeDict is None:
        ctCats = adata.obs[labelName].cat.categories
        cellTypeDict = {cat: (cat, color) for cat, color in zip(ctCats, adata.uns[f"{labelName}_colors"])}
        print(f"cellTypeDict = :{cellTypeDict}")
    if batchDict is None:
        btCats = adata.obs[batchName].cat.categories
        batchDict = {cat: (cat, color, "o") for cat, color in zip(btCats, adata.uns[f"{batchName}_colors"])}
        print(f"batchDict=:{batchDict}")

    batchToColorDict = {lab: batchDict[lab][1] for lab in batchDict}
    annoToColorDict = {lab: cellTypeDict[lab][1] for lab in cellTypeDict}

    legendEle, ctColors, btColors = makeLegend(adata.obs[labelName], adata.obs[batchName], cellTypeDict, batchDict, name=name, save=save)

    return(scVitalUmap, pcaUmap, legendEle, ctColors, btColors, batchToColorDict, annoToColorDict)

#cellStateName = paramName.split("_")[-2].split("~")[1]


def makeFigures(adata, batchName, labelName, cellTypeDict, batchDict, allStats, vizStatsOrder=True,  name="", qCut=0.28, saveDir = False, lssName=None):
    """Produce the standard figure set for one dataset.

    1. Statistics grid (``vizAllStats``).
    2. Legend and 2x2 UMAP panel (``runUmapGetColors`` / ``plotQuadUMAP``).
    3. LSS cluster-similarity graph on the scVital latent
       (``sv.lss.calcPairsLSS`` / ``sv.lss.plotGraphLSS``).

    Parameters
    ----------
    saveDir : str or False
        Output directory; created (with parents) if missing. False disables saving.
    lssName : str, optional
        ``name`` passed to ``sv.lss.plotGraphLSS``. Defaults to the
        notebook-level ``datasetName`` (the previous, implicit behaviour).
    """
    if(saveDir):
        os.makedirs(saveDir, exist_ok=True)
    if lssName is None:
        lssName = datasetName

    vizAllStats(allStats, h=2, w=4, name=name, order=vizStatsOrder, save=saveDir)
    scVitalUmap, pcaUmap, _, ctColors, btColors, _, annoToColorDict = runUmapGetColors(
        adata, batchName, labelName, cellTypeDict=cellTypeDict, batchDict=batchDict, name=name, save=saveDir)
    plotQuadUMAP(scVitalUmap, pcaUmap, ctColors, btColors, reorder=True, name=f"{name}Patient", save=saveDir, linewidths=0.2, s=6, alpha=0.7)

    clustDist = sv.lss.calcPairsLSS(adata, latent="X_scVital", batchName=batchName, cellTypeLabel=labelName)[0]
    sv.lss.plotGraphLSS(adata, labelName, batchName, clustDist, name=lssName,
                        batchDict=batchDict, annoToColorDict=annoToColorDict,
                        prog="neato", wLab=False, qCut=qCut, plot=True, save=saveDir)
		          

In [None]:
# Output directory for saved figures; makeFigures creates it when it is missing.
saveDir = "figuresPaper"

In [None]:
dirName="../../results/allParamsOUT"
datasetName="filename~PDACmouseHumanT6M"
# os.listdir order is arbitrary; sort so the same parameter run is picked every time.
paramName=sorted(os.listdir(f"{dirName}/{datasetName}"))[0]
adata, batchName, labelName = makeAdata(dirName,datasetName,paramName)
allStats = getAllStats(dirName,datasetName,paramName)

In [None]:
cellTypeDict={'Acinar cell': ('Acinar cell', '#1f77b4'), 
              'Basal': ('Basal', '#ff7f0e'), 
              'Classical': ('Classical', '#279e68'), 
              'Ductal cell type 1': ('Ductal cell type 1', '#d62728'), 
              'EMT': ('EMT', '#aa40fc'), 
              'Endocrine cell': ('Endocrine cell', '#8c564b'), 
              'out': ('unlabeled', '#d3d3d3')}
# Batch key -> (legend label, edge colour, marker). makeLegend reads the marker
# from the third entry, so (label, colour) 2-tuples made this cell fail.
batchDict={'HumanAll': ('HumanT6', '#c5c7c9', "o"), 
           'Mouse': ('Mouse', '#050505', "s")} 

makeFigures(adata, batchName, labelName, cellTypeDict, batchDict, allStats, vizStatsOrder=True, saveDir = False)

## PDAC mouse and Human (Hwang) <--

In [None]:
#"PDACmouseHumanHuwangGMM":
dirName="../../results/allParamsOUT"
datasetName="filename~PDACmouseHumanHuwangGMM"
# os.listdir order is arbitrary; sort so the same parameter run is picked every time.
paramName = sorted(os.listdir(f"{dirName}/{datasetName}"))[0]
print(paramName)
adata, batchName, labelName = makeAdata(dirName,datasetName,paramName)
allStats = getAllStats(dirName,datasetName,paramName)

cellTypeDict={'Basal': ('Basal', '#1f77b4'), 
              'Classical': ('Classical', '#ff7f0e'), 
              'EMT': ('Mesenchymal', '#279e68'),   # shown as "Mesenchymal" in the legend
              'out': ('unlabeled', '#d3d3d3')}

# Batch key -> (legend label, edge colour, marker). The two mouse batches get
# their own markers; all U* batches share one grey-circle style.
uBatches = ["U1", "U2", "U3", "U5", "U6", "U7", "U9", "U10", "U13", "U14", "U16", "U17", "U18"]
batchDict={'Mouse': ('Mouse', '#050505',"P"), 
           'MouseVeh': ('MouseVeh', '#050505',"X"), 
           **{ub: (ub, '#c5c7c9', "o") for ub in uBatches}}

makeFigures(adata, batchName, labelName, cellTypeDict, batchDict, allStats, vizStatsOrder=True, name="PDACHwang", qCut=0.20, saveDir = saveDir)

In [None]:
#"PDACmouseHumanU6M":
dirName="../../results/allParamsOUT"
datasetName="filename~PDACmouseHumanU6M"
paramName=os.listdir(f"{dirName}/{datasetName}")[0]
adata, batchName, labelName = makeAdata(dirName,datasetName,paramName)
allStats = getAllStats(dirName,datasetName,paramName)



In [None]:
cellTypeDict={'Basal': ('Basal', '#1f77b4'), 
            'Classical': ('Classical', '#ff7f0e'), 
            'EMT': ('EMT', '#279e68'),
            'out': ('unlabeled', '#d3d3d3')}
# Batch key -> (legend label, edge colour, marker). makeLegend reads the marker
# from the third entry, so (label, colour) 2-tuples made this cell fail.
batchDict={'HwangG': ('HwangG', '#c5c7c9', "o"), 
           'Mouse': ('Mouse', '#050505', "s")} 

makeFigures(adata, batchName, labelName, cellTypeDict, batchDict, allStats, saveDir = False)

## Lung injury mouse

In [None]:
#"lungLUADinjMouse":
dirName="../../results/allParamsOUT"
datasetName="filename~lungLUADinjMouse"
# os.listdir order is arbitrary; sort so the same parameter run is picked every time.
paramName=sorted(os.listdir(f"{dirName}/{datasetName}"))[0]

adata, batchName, labelName = makeAdata(dirName,datasetName,paramName)
allStats = getAllStats(dirName,datasetName,paramName)

In [None]:
# Cell-type key -> (legend label, colour). 'Cycling AT2' reuses the 'AT2' label
# and colour; makeLegend adds one entry per value, so 'AT2' appears twice in the legend.
cellTypeDict={'AT1': ('AT1', '#56B4E9'), 
              'AT1-like': ('AT1-like', '#56B4E9'), 
                                
              'AT2': ('AT2', '#0072B2'), 
              'Cycling AT2': ('AT2', '#0072B2'), 
              'AT2-like': ('AT2-like', '#0072B2'), 
              
              'HPCS': ('HPCS', '#E69F00'), 
              'DATP': ('DATP', '#F0E442'), 

              'Adv': ('Adv', '#099E00'), 
              'EMT': ('EMT', '#076301'), 
              'Endoderm-like': ('Endoderm-like', '#97FD91'), 
              'Rib': ('Rib', '#537E51'), 
              'unlabeled': ('unlabeled', '#d3d3d3'), 

              'NuEnd': ('NuEnd', '#009E73'), 
              'Cili': ('Cili', '#D55E00'), 
              'pBasal': ('pBasal', '#CC79A7')}

# Batch key -> (legend label, edge colour, marker shape).
batchDict={'NonM': ('NonMalig', '#d4d4d4',"o"),   
           'Inj': ('Inj', '#706f6f',"s"), 
           'shRen': ('Malig', '#050505',"p")} 

makeFigures(adata, batchName, labelName, cellTypeDict, batchDict, allStats, vizStatsOrder=True, name="MouLungLUAD", qCut=0.2, saveDir = saveDir)

## All LUAD Mouse and Human

In [None]:
#"LUADmouseHumanBKHM":
dirName="../../results/allParamsOUT"
datasetName="filename~LUADmouseHumanBKHM"
# os.listdir order is arbitrary; sort so the same parameter run is picked every time.
paramName=sorted(os.listdir(f"{dirName}/{datasetName}"))[0]

adata, batchName, labelName = makeAdata(dirName,datasetName,paramName)
allStats = getAllStats(dirName,datasetName,paramName)

In [None]:
cellTypeDict={'AT1-like': ('AT1-like', '#56B4E9'), 
              'AT2-like': ('AT2-like', '#0072B2'), 
              'Adv': ('Adv', '#D55E00'), 
              'EMT': ('EMT', '#CC79A7'), 
              'Endoderm-like': ('Endoderm-like', '#E69F00'), 
              'HPCS': ('HPCS', '#009E73'), 
              'unlabeled': ('unlabeled', '#d3d3d3')}

# Batch key -> (legend label, edge colour, marker); every non-mouse batch has
# its own marker. (A stray, unused `shapes="ospPX><",` tuple was removed here.)
batchDict={'mouse': ('Mouse', '#050505',"o"),
           'P2': ('P2', '#c5c7c9',"s"),
           'P14T': ('P14T', '#c5c7c9',"p"), 
           'T30': ('T30', '#c5c7c9',"P"),
           'T34': ('T34', '#c5c7c9',"X"),
           'p018': ('p018', '#c5c7c9',"^"),
           'p024': ('p024', '#c5c7c9',">"), 
           'p032': ('p032', '#c5c7c9',"<")}
makeFigures(adata, batchName, labelName, cellTypeDict, batchDict, allStats, vizStatsOrder=True, name="LUAD", qCut=0.20, saveDir = saveDir)

## Lung Human Malignant and Non-Malignant <--

In [None]:
#"luadMDA_P2T7":
dirName="../../results/allParamsOUT"
datasetName="filename~luadMDA_P2T7"
# os.listdir order is arbitrary; sort so the same parameter run is picked every time.
paramName=sorted(os.listdir(f"{dirName}/{datasetName}"))[0]
print(paramName)
# Example paramName; makeAdata takes the batch and label keys from its
# batchName~ and cellName~ tokens:
#paramName='batchSize~512_numEpoch~50_learningRate~1e-3_inLayerDims~1024-128_lastLayer~12_inDiscLayer~6_reconCoef~5_klCoef~5e-2_discCoef~2_batchName~Tumor_cellName~cs_res~4e-1'

adata, batchName, labelName = makeAdata(dirName,datasetName,paramName)
allStats = getAllStats(dirName,datasetName,paramName)

# Trailing ';' suppresses the repr of the returned Axes list.
sc.pl.umap(adata, color=[batchName, labelName], show=False);

In [None]:
# Cell-type key -> (legend label, colour).
cellTypeDict={'AT2-like': ('AT2-like', '#0072B2'), 
              'alveoli_AT2': ('alveoli_AT2', '#11567D'), 

              'HPCS': ('HPCS', '#009E73'), 
                     
              'alveoli_AT1': ('alveoli_AT1', '#56B4E9'), 
              'alveoli_AVP': ('alveoli_AVP', '#7CC4EC'), 
              'alveoli_Bronchio': ('alveoli_Bronchio', '#93CCEC'), 
                     
              'airway_Basal': ('airway_Basal', '#F0E442'), 
              'airway_Ciliated': ('airway_Ciliated', '#CC79A7'), 
              'airway_ClubandSecretory': ('airway_ClubandSecretory', '#E69F00'), 
              'airway_Ionocyte': ('airway_Ionocyte', '#D55E00'), 
              'airway_Neuroendocrine': ('airway_Neuroendocrine', '#000000')}

# Batch key -> (legend label, edge colour, marker shape).
batchDict={'normal': ('NonMalig', '#c5c7c9',"o"), 
           'tumor': ('Malig', '#050505',"s")}

# Preview variant without saving:
#makeFigures(adata, batchName, labelName, cellTypeDict, batchDict, allStats, saveDir = False)
makeFigures(adata, batchName, labelName, cellTypeDict, batchDict, allStats, vizStatsOrder=True, name="HumLungLUAD", qCut=0.35, saveDir = saveDir)

## Simulated 

In [None]:
# Preview which parameter-sweep directory will be used for the simulated data
# (first entry of os.listdir; order is filesystem-dependent).
dirName = "../../results/allParamsOUT"
datasetName = "filename~simb2c5o3BIG"
paramDirs = os.listdir(f"{dirName}/{datasetName}")
paramName = paramDirs[0]
paramName

In [None]:
# Simulated dataset ("simb2c5o3BIG").
dirName = "../../results/allParamsOUT"
datasetName = "filename~simb2c5o3BIG"
paramName = os.listdir(f"{dirName}/{datasetName}")[0]
adata, batchName, labelName = makeAdata(dirName, datasetName, paramName)
allStats = getAllStats(dirName, datasetName, paramName)

sc.pl.umap(adata, color=[batchName, labelName], show=False)

# Start from scanpy's per-category colours (label -> (label, colour)).
pairs = adata.uns["pairs"]
cellTypeDict = {cat: (cat, col) for cat, col in zip(adata.obs[labelName].cat.categories, adata.uns[f"{labelName}_colors"])}
overlapDict = {cat: (cat, col) for cat, col in zip(adata.obs["overlapLabel"].cat.categories, adata.uns["overlapLabel_colors"])}
gpairs = group_pairs(pairs)

# Give every member of a paired group the colour of its first member that has an
# overlap label. overColor is recomputed per group, so a group with no overlap
# match keeps its own colours instead of inheriting a stale colour from the
# previous group (or from an earlier cell, since this runs at notebook scope).
for gpair in gpairs:
    overColor = next((overlapDict[csc][1] for csc in gpair if csc in overlapDict), None)
    if overColor is None:
        continue
    for csc in gpair:
        cellTypeDict[csc] = (csc, overColor)

# Show scanpy's automatic batch palette for reference, then override it with the
# fixed colours/markers used in the figures.
batchDict = {cat: (cat, col) for cat, col in zip(adata.obs[batchName].cat.categories, adata.uns[f"{batchName}_colors"])}
print(f"batchDict = {batchDict}")

batchDict = {'Batch_1': ('Batch_1', 'gray', "o"), 'Batch_2': ('Batch_2', 'black',"s")}


In [None]:
# Figures for the simulated data. makeFigures is defined elsewhere in the notebook;
# saveDir=False presumably disables writing files — confirm.
makeFigures(adata, batchName, labelName, cellTypeDict, batchDict, allStats, saveDir = False)

## Juk T92

In [None]:
# 5050_JukT92 dataset.
dirName = "../../results/allParamsOUT"
datasetName = "filename~5050_JukT92"
paramName = os.listdir(f"{dirName}/{datasetName}")[0]

adata, batchName, labelName = makeAdata(dirName, datasetName, paramName)
allStats = getAllStats(dirName, datasetName, paramName)
sc.pl.umap(adata, color=[batchName, labelName], show=False)

# Start from scanpy's per-category colours (label -> (label, colour)).
pairs = adata.uns["pairs"]
cellTypeDict = {cat: (cat, col) for cat, col in zip(adata.obs[labelName].cat.categories, adata.uns[f"{labelName}_colors"])}
overlapDict = {cat: (cat, col) for cat, col in zip(adata.obs["overlapLabel"].cat.categories, adata.uns["overlapLabel_colors"])}
gpairs = group_pairs(pairs)

# Give every member of a paired group the colour of its first member that has an
# overlap label. overColor is recomputed per group so a group with no overlap
# match keeps its own colours instead of inheriting a stale colour.
for gpair in gpairs:
    overColor = next((overlapDict[csc][1] for csc in gpair if csc in overlapDict), None)
    if overColor is None:
        continue
    for csc in gpair:
        cellTypeDict[csc] = (csc, overColor)

# Show scanpy's automatic batch palette for reference (overridden in the next cell).
batchDict = {cat: (cat, col) for cat, col in zip(adata.obs[batchName].cat.categories, adata.uns[f"{batchName}_colors"])}
print(f"batchDict = {batchDict}")

In [None]:
# Fixed batch colours/markers: key -> (display label, colour, marker).
batchDict = dict(
    Batch_1=('Batch_1', 'gray', "o"),
    Batch_2=('Batch_2', 'black', "s"),
)
makeFigures(adata, batchName, labelName, cellTypeDict, batchDict, allStats,
            saveDir=False)

## PBMC

In [None]:
# PBMCwCounts dataset.
dirName = "../../results/allParamsOUT"
datasetName = "filename~PBMCwCounts"
paramName = os.listdir(f"{dirName}/{datasetName}")[0]
adata, batchName, labelName = makeAdata(dirName, datasetName, paramName)
allStats = getAllStats(dirName, datasetName, paramName)

sc.pl.umap(adata, color=[batchName, labelName], show=False)

# Start from scanpy's per-category colours (label -> (label, colour)).
pairs = adata.uns["pairs"]
cellTypeDict = {cat: (cat, col) for cat, col in zip(adata.obs[labelName].cat.categories, adata.uns[f"{labelName}_colors"])}
overlapDict = {cat: (cat, col) for cat, col in zip(adata.obs["overlapLabel"].cat.categories, adata.uns["overlapLabel_colors"])}
gpairs = group_pairs(pairs)

# Give every member of a paired group the colour of its first member that has an
# overlap label. overColor is recomputed per group so a group with no overlap
# match keeps its own colours instead of inheriting a stale colour.
for gpair in gpairs:
    overColor = next((overlapDict[csc][1] for csc in gpair if csc in overlapDict), None)
    if overColor is None:
        continue
    for csc in gpair:
        cellTypeDict[csc] = (csc, overColor)

# Show scanpy's automatic batch palette for reference, then override it:
# batch '0' is labelled 5' and batch '1' is labelled 3'.
batchDict = {cat: (cat, col) for cat, col in zip(adata.obs[batchName].cat.categories, adata.uns[f"{batchName}_colors"])}
print(f"batchDict = {batchDict}")

batchDict = {'0': ('5\'', 'black',"o"), '1': ('3\'', '#c5c7c9',"s")}

print(cellTypeDict)

In [None]:
# Largest value in allStats' "LSS" row, rounded for display.
lssRow = allStats.loc["LSS", :]
np.round(np.max(lssRow))

In [None]:
# PBMC figures. makeFigures is defined elsewhere; name/saveDir presumably set the
# output file prefix and directory — confirm.
makeFigures(adata, batchName, labelName, cellTypeDict, batchDict, allStats, name="PBMC", saveDir = saveDir)

In [None]:
# kNN graph on the full scVital latent, stored under the neighbors key "sv".
umapKey = "sv"
latentDims = adata.obsm["X_scVital"].shape[1]
sc.pp.neighbors(adata, n_pcs=latentDims, use_rep="X_scVital", key_added=umapKey)

# Leiden at several resolutions; each result goes to adata.obs[f"res{res}"].
allres = [0.2, 0.4, 0.5, 1]
leidenKeys = [f'res{res}' for res in allres]
for res, key in zip(allres, leidenKeys):
    sc.tl.leiden(adata, resolution=res, key_added=key, neighbors_key=umapKey)

# Note: this draws the existing adata.obsm["X_umap"]; no UMAP is recomputed here.
sc.pl.umap(adata, color=leidenKeys)

In [None]:

# Ground-truth annotation ("anno") on the UMAP, using scanpy's stored palette.
obsRes = "anno"
annoCats = adata.obs[obsRes].cat.categories
clust2Col = dict(zip(annoCats, adata.uns[f"{obsRes}_colors"]))
c = [clust2Col[l] for l in adata.obs[obsRes]]
x, y = adata.obsm["X_umap"][:, 0], adata.obsm["X_umap"][:, 1]

fig, ax = plt.subplots(figsize=(3, 3))
ax.scatter(x, y, c=c, s=1)
ax.set(xlabel='', ylabel='', xticks=[], yticks=[])
for side in ("top", "right", "left", "bottom"):
    ax.spines[side].set_visible(False)
ax.set_title(f"True Clusters ({len(annoCats)})")

plt.tight_layout()
plt.show()

fig.savefig(f"{saveDir}/umap_Fig1TrueClust.png", format="png")

In [None]:
# Hand-picked palettes (keyed by Leiden cluster id as a string) so that matching
# clusters keep similar colours across the three resolution panels.
clus7Col = dict(zip([str(i) for i in range(7)],  ['#ff7f0e','#8c564b','#1f77b4','#279e68','#e377c2','#17becf','#d62728']))
clus9Col = dict(zip([str(i) for i in range(9)],  ["#98df8a",'#8c564b','#1f77b4','#ff7f0e','#279e68','#b5bd61','#e377c2','#d62728',"#17becf"]))
clus10Col = dict(zip([str(i) for i in range(10)],['#ff7f0e','#8c564b','#1f77b4','#279e68',"#ff9896",'#b5bd61','#e377c2','#d62728','#c5b0d5',"#17becf"]))
colDicts = [clus7Col,clus9Col,clus10Col]

fig, axs = plt.subplots(1,3, figsize=(9,3))
x=adata.obsm["X_umap"][:,0]
y=adata.obsm["X_umap"][:,1]

# One panel per (resolution, palette) pair. zip stops after the three palettes,
# so the last resolution in allres is not plotted here.
for ax, resVal, colDict in zip(axs.flatten(), allres, colDicts):
    obsRes = f"res{resVal}"
    # Fall back to scanpy's stored palette for any cluster id the hand-picked
    # palette does not cover (e.g. if Leiden yields a different cluster count),
    # instead of failing with a KeyError.
    fallback = dict(zip(adata.obs[obsRes].cat.categories, adata.uns.get(f"{obsRes}_colors", [])))
    c = [colDict[l] if l in colDict else fallback[l] for l in adata.obs[obsRes]]
    ax.scatter(x, y, c=c,  s=1)
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["left"].set_visible(False)
    ax.spines["bottom"].set_visible(False)
    ax.set_title(f"# Leiden clusters: {len(adata.obs[obsRes].cat.categories)}")
    
plt.tight_layout()
plt.show()

# NOTE(review): the doubled "umap_" prefix looks like a typo; kept so existing
# output file names do not change.
fig.savefig(f"{saveDir}/umap_umap_Fig1manyClust.png", format="png")

In [None]:
from sklearn import metrics

In [None]:
# Score each Leiden resolution against the ground-truth labels (Fowlkes-Mallows and
# adjusted Rand index), alongside the LSS value.
# NOTE(review): `lssAUC` is not defined in this part of the notebook; it must come
# from an earlier cell — confirm it holds the intended LSS value for this dataset.
metNames = ["LSS", "FM", "ARI"]
for res in [f'res{r}' for r in allres]:
    truth, pred = adata.obs[labelName], adata.obs[res]
    metVals = [lssAUC,
               metrics.fowlkes_mallows_score(truth, pred),
               metrics.adjusted_rand_score(truth, pred)]
    metClus = [f"# clusters: {len(pred.cat.categories)}"]
    metFrame = pd.DataFrame(metVals, index=metNames, columns=metClus)
    vizAllStats(metFrame, h=2, w=1.5, name=f"Fig1met{res}", order=False, save=saveDir)

## Muscle

In [None]:
# MusclehumanmouseNU dataset (mouse + human batches).
dirName = "../../results/allParamsOUT"
datasetName = "filename~MusclehumanmouseNU"
paramName = os.listdir(f"{dirName}/{datasetName}")[0]
adata, batchName, labelName = makeAdata(dirName, datasetName, paramName)
allStats = getAllStats(dirName, datasetName, paramName)

# Start from scanpy's per-category colours (label -> (label, colour)).
pairs = adata.uns["pairs"]
cellTypeDict = {cat: (cat, col) for cat, col in zip(adata.obs[labelName].cat.categories, adata.uns[f"{labelName}_colors"])}
overlapDict = {cat: (cat, col) for cat, col in zip(adata.obs["overlapLabel"].cat.categories, adata.uns["overlapLabel_colors"])}
gpairs = group_pairs(pairs)

# Give every member of a paired group the colour of its first member that has an
# overlap label. overColor is recomputed per group so a group with no overlap
# match keeps its own colours instead of inheriting a stale colour.
for gpair in gpairs:
    overColor = next((overlapDict[csc][1] for csc in gpair if csc in overlapDict), None)
    if overColor is None:
        continue
    for csc in gpair:
        cellTypeDict[csc] = (csc, overColor)

# Show scanpy's automatic batch palette for reference, then override it.
batchDict = {cat: (cat, col) for cat, col in zip(adata.obs[batchName].cat.categories, adata.uns[f"{batchName}_colors"])}
print(f"batchDict = {batchDict}")

batchDict = {'mouse': ('Mouse', '#050505', "o"), 'human': ('Human', '#c5c7c9', "s")}

print(cellTypeDict)

In [None]:
# Muscle figures. vizStatsOrder/qCut are passed through to makeFigures (defined
# elsewhere); their exact meaning is not visible here — see its definition.
makeFigures(adata, batchName, labelName, cellTypeDict, batchDict, allStats, 
            name="muscle", vizStatsOrder=True, qCut=0.2, saveDir = saveDir)

## Lung

In [None]:
# lunghumanmouse dataset (mouse + human batches).
dirName = "../../results/allParamsOUT"
datasetName = "filename~lunghumanmouse"
paramName = os.listdir(f"{dirName}/{datasetName}")[0]

adata, batchName, labelName = makeAdata(dirName, datasetName, paramName)
allStats = getAllStats(dirName, datasetName, paramName)

sc.pl.umap(adata, color=[batchName, labelName], show=False)

# Start from scanpy's per-category colours (label -> (label, colour)).
pairs = adata.uns["pairs"]
cellTypeDict = {cat: (cat, col) for cat, col in zip(adata.obs[labelName].cat.categories, adata.uns[f"{labelName}_colors"])}
overlapDict = {cat: (cat, col) for cat, col in zip(adata.obs["overlapLabel"].cat.categories, adata.uns["overlapLabel_colors"])}
gpairs = group_pairs(pairs)

# Give every member of a paired group the colour of its first member that has an
# overlap label. overColor is recomputed per group so a group with no overlap
# match keeps its own colours instead of inheriting a stale colour.
for gpair in gpairs:
    overColor = next((overlapDict[csc][1] for csc in gpair if csc in overlapDict), None)
    if overColor is None:
        continue
    for csc in gpair:
        cellTypeDict[csc] = (csc, overColor)

# Show scanpy's automatic batch palette for reference, then override it.
batchDict = {cat: (cat, col) for cat, col in zip(adata.obs[batchName].cat.categories, adata.uns[f"{batchName}_colors"])}
print(f"batchDict = {batchDict}")

batchDict = {'mouse': ('Mouse', '#050505', "o"), 'human': ('Human', '#c5c7c9', "s")}

print(cellTypeDict)

In [None]:
# Lung figures (qCut passed through to makeFigures, defined elsewhere).
makeFigures(adata, batchName, labelName, cellTypeDict, batchDict, allStats, qCut=0.2, name="lung", saveDir = saveDir)

## Liver

In [None]:
# Liverhumanmouse dataset (mouse + human batches).
dirName = "../../results/allParamsOUT"
datasetName = "filename~Liverhumanmouse"
paramName = os.listdir(f"{dirName}/{datasetName}")[0]

adata, batchName, labelName = makeAdata(dirName, datasetName, paramName)
allStats = getAllStats(dirName, datasetName, paramName)

sc.pl.umap(adata, color=[batchName, labelName], show=False)

# Start from scanpy's per-category colours (label -> (label, colour)).
pairs = adata.uns["pairs"]
cellTypeDict = {cat: (cat, col) for cat, col in zip(adata.obs[labelName].cat.categories, adata.uns[f"{labelName}_colors"])}
overlapDict = {cat: (cat, col) for cat, col in zip(adata.obs["overlapLabel"].cat.categories, adata.uns["overlapLabel_colors"])}
gpairs = group_pairs(pairs)

# Give every member of a paired group the colour of its first member that has an
# overlap label. overColor is recomputed per group so a group with no overlap
# match keeps its own colours instead of inheriting a stale colour.
for gpair in gpairs:
    overColor = next((overlapDict[csc][1] for csc in gpair if csc in overlapDict), None)
    if overColor is None:
        continue
    for csc in gpair:
        cellTypeDict[csc] = (csc, overColor)

# Show scanpy's automatic batch palette for reference, then override it.
batchDict = {cat: (cat, col) for cat, col in zip(adata.obs[batchName].cat.categories, adata.uns[f"{batchName}_colors"])}
print(f"batchDict = {batchDict}")

batchDict = {'mouse': ('Mouse', '#050505', "o"), 'human': ('Human', '#c5c7c9', "s")}

print(cellTypeDict)

In [None]:
# Liver figures (qCut passed through to makeFigures, defined elsewhere).
makeFigures(adata, batchName, labelName, cellTypeDict, batchDict, allStats, qCut=0.25, name="liver", saveDir = saveDir)

## Pancreas

In [None]:
# pancreashumanmouse dataset (mouse + human batches).
dirName = "../../results/allParamsOUT"
datasetName = "filename~pancreashumanmouse"
paramName = os.listdir(f"{dirName}/{datasetName}")[0]

adata, batchName, labelName = makeAdata(dirName, datasetName, paramName)
allStats = getAllStats(dirName, datasetName, paramName)

sc.pl.umap(adata, color=[batchName, labelName], show=False)

# Start from scanpy's per-category colours (label -> (label, colour)).
pairs = adata.uns["pairs"]
cellTypeDict = {cat: (cat, col) for cat, col in zip(adata.obs[labelName].cat.categories, adata.uns[f"{labelName}_colors"])}
overlapDict = {cat: (cat, col) for cat, col in zip(adata.obs["overlapLabel"].cat.categories, adata.uns["overlapLabel_colors"])}
gpairs = group_pairs(pairs)

# Give every member of a paired group the colour of its first member that has an
# overlap label. overColor is recomputed per group so a group with no overlap
# match keeps its own colours instead of inheriting a stale colour.
for gpair in gpairs:
    overColor = next((overlapDict[csc][1] for csc in gpair if csc in overlapDict), None)
    if overColor is None:
        continue
    for csc in gpair:
        cellTypeDict[csc] = (csc, overColor)

# Show scanpy's automatic batch palette for reference, then override it.
batchDict = {cat: (cat, col) for cat, col in zip(adata.obs[batchName].cat.categories, adata.uns[f"{batchName}_colors"])}
print(f"batchDict = {batchDict}")

batchDict = {'mouse': ('Mouse', '#050505', "o"), 'human': ('Human', '#c5c7c9', "s")}

print(cellTypeDict)

In [None]:
# Pancreas figures (makeFigures defined elsewhere; default qCut).
makeFigures(adata, batchName, labelName, cellTypeDict, batchDict, allStats, name="pancreas", saveDir = saveDir)

## Bladder

In [None]:
# bladderMouseHuman dataset (mouse + human batches).
dirName = "../../results/allParamsOUT"
datasetName = "filename~bladderMouseHuman"
paramName = os.listdir(f"{dirName}/{datasetName}")[0]

adata, batchName, labelName = makeAdata(dirName, datasetName, paramName)
allStats = getAllStats(dirName, datasetName, paramName)

sc.pl.umap(adata, color=[batchName, labelName], show=False)

# Start from scanpy's per-category colours (label -> (label, colour)).
pairs = adata.uns["pairs"]
cellTypeDict = {cat: (cat, col) for cat, col in zip(adata.obs[labelName].cat.categories, adata.uns[f"{labelName}_colors"])}
overlapDict = {cat: (cat, col) for cat, col in zip(adata.obs["overlapLabel"].cat.categories, adata.uns["overlapLabel_colors"])}
gpairs = group_pairs(pairs)

# Give every member of a paired group the colour of its first member that has an
# overlap label. overColor is recomputed per group so a group with no overlap
# match keeps its own colours instead of inheriting a stale colour.
for gpair in gpairs:
    overColor = next((overlapDict[csc][1] for csc in gpair if csc in overlapDict), None)
    if overColor is None:
        continue
    for csc in gpair:
        cellTypeDict[csc] = (csc, overColor)

# Show scanpy's automatic batch palette (with markers "o", "s" in category order;
# zip truncates to two batches) for reference, then override it.
batchDict = {cat: (cat, col, marker) for cat, col, marker in zip(adata.obs[batchName].cat.categories, adata.uns[f"{batchName}_colors"], "os")}
print(f"batchDict = {batchDict}")

batchDict = {'mouse': ('Mouse', '#050505',"o"), 'human': ('Human', '#c5c7c9',"s")}

print(cellTypeDict)

In [None]:
# Bladder figures (qCut passed through to makeFigures, defined elsewhere).
makeFigures(adata, batchName, labelName, cellTypeDict, batchDict, allStats, qCut= 0.25, name="bladder", saveDir = saveDir)

## UPS

In [None]:
# UPS dataset: mouse + human PDX tumours ("mouseHuman_PDX_tum_labeled_G1_outer").
dirName = "../../results/allParamsOUT"
datasetName = "filename~mouseHuman_PDX_tum_labeled_G1_outer"
paramName = os.listdir(f"{dirName}/{datasetName}")[0]

runArgs = (dirName, datasetName, paramName)
adata, batchName, labelName = makeAdata(*runArgs)
# Colour by the "chemo" obs column instead of the label column makeAdata returned.
labelName = "chemo"
allStats = getAllStats(*runArgs)

In [None]:
# Recompute the neighbour graph (12 dims of the scVital latent), Leiden clusters
# and UMAP. Naming note: the neighbour graph is stored under `umapKey` and the
# Leiden labels under `neighborsKey` (both "scVital"); later cells read the
# clusters from adata.obs["scVital"].
pcaRep = "X_scVital"
umapKey = "scVital"
neighborsKey = "scVital"
sc.pp.neighbors(adata, n_pcs=12, use_rep=pcaRep, key_added=umapKey)
sc.tl.leiden(adata, resolution=0.7, key_added=neighborsKey, neighbors_key=umapKey)
sc.tl.umap(adata, neighbors_key=umapKey)
sc.pl.umap(adata, color=[batchName, labelName, "scVital"], show=False)

# Start from scanpy's per-category colours (label -> (label, colour)).
pairs = adata.uns["pairs"]
cellTypeDict = {cat: (cat, col) for cat, col in zip(adata.obs[labelName].cat.categories, adata.uns[f"{labelName}_colors"])}
overlapDict = {cat: (cat, col) for cat, col in zip(adata.obs["overlapLabel"].cat.categories, adata.uns["overlapLabel_colors"])}
gpairs = group_pairs(pairs)

# Give every member of a paired group the colour of its first member that has an
# overlap label. overColor is recomputed per group so a group with no overlap
# match keeps its own colours instead of inheriting a stale colour.
for gpair in gpairs:
    overColor = next((overlapDict[csc][1] for csc in gpair if csc in overlapDict), None)
    if overColor is None:
        continue
    for csc in gpair:
        cellTypeDict[csc] = (csc, overColor)

# Scanpy's automatic batch palette with markers "o", "s", "p" in category order
# (zip truncates to three batches); printed for reference.
batchDict = {cat: (cat, col, marker) for cat, col, marker in zip(adata.obs[batchName].cat.categories, adata.uns[f"{batchName}_colors"], "osp")}
print(f"batchDict = {batchDict}")
print(f"cellTypeDict = {cellTypeDict}")

In [None]:
# Chemo arm -> (display label, colour).
cellTypeDict = {key: (label, colour) for key, label, colour in [
    ('ctl', 'Vehicle', '#009E73'),
    ('wk1', 'Short Chemo', '#E69F00'),
    ('wk3/4', "Long Chemo", '#D55E00'),
]}

# Batch -> (display label, colour, marker): mouse in black, the two PDX lines in
# light grey, distinguished by marker.
batchDict = {key: (key, colour, marker) for key, colour, marker in [
    ('Mouse', '#000000', 'o'),
    ('UPS2236', '#DCDADA', 's'),
    ('MFH9', '#DCDADA', 'p'),
]}

umapOutputs = runUmapGetColors(adata, batchName, labelName,
                               cellTypeDict=cellTypeDict,
                               batchDict=batchDict, name="sarcChemo", save=saveDir)
(scVitalUmap, pcaUmap, legendEle, ctColors, btColors,
 batchToColorDict, annoToColorDict) = umapOutputs

In [None]:
# All-UMAP panel for the sarcoma chemo data (plotAllUMAP is defined elsewhere;
# inputs come from runUmapGetColors in the previous cell).
plotAllUMAP(scVitalUmap, pcaUmap, ctColors, btColors, reorder=True, name="sarcChemoAll", save=saveDir)

In [None]:
def plotDoubleUMAP(scVitalUmap, pcaUmap, ctColors, btColors, reorder, name="", save=False, **kwargs):
    """Plot the unintegrated and scVital UMAPs side by side, coloured by batch.

    Parameters
    ----------
    scVitalUmap, pcaUmap : array-like, indexable as ``[rows, :]``
        UMAP coordinates; column 0 is plotted as x and column 1 as y.
    ctColors : array-like
        Per-cell cell-type colours. Only used to compute the draw order when
        ``reorder`` is true; points are coloured by ``btColors``.
    btColors : array-like
        Per-cell batch colours used in both panels.
    reorder : bool
        If true, permute the per-cell arrays by ``reorderColors(ctColors)``
        (defined elsewhere in the notebook) before plotting.
    name : str
        Inserted into the saved file name ``UMAPsAll_{name}integ.png``.
    save : str or False
        Output directory; falsy disables saving.
    **kwargs
        Extra keyword arguments forwarded to both ``Axes.scatter`` calls
        (e.g. ``alpha``, ``linewidths``). Marker size ``s`` defaults to 1.
    """
    if reorder:
        newOrder = reorderColors(ctColors)
        btColors = btColors[newOrder]
        pcaUmap = pcaUmap[newOrder, :]
        scVitalUmap = scVitalUmap[newOrder, :]

    # Defaults first so callers can override them through **kwargs.
    scatterKwargs = {"s": 1, **kwargs}
    fig, axs = plt.subplots(1, 2, figsize=(4, 2), dpi=300)

    # Left: unintegrated (PCA) embedding on a faint yellow background.
    axs[0].scatter(x=pcaUmap[:, 0], y=pcaUmap[:, 1], c=btColors, **scatterKwargs)
    axs[0].patch.set_facecolor("#fced62")
    axs[0].patch.set_alpha(0.1)
    axs[0].set_title("No Integration")

    # Right: scVital embedding on a faint purple background.
    axs[1].scatter(x=scVitalUmap[:, 0], y=scVitalUmap[:, 1], c=btColors, **scatterKwargs)
    axs[1].patch.set_facecolor("#6b3b8a")
    axs[1].patch.set_alpha(0.05)
    axs[1].set_title("scVital")

    # Bare panels: no labels, ticks or spines.
    for ax in axs.flatten():
        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["left"].set_visible(False)
        ax.spines["bottom"].set_visible(False)

    plt.tight_layout()
    # Save before showing, while the figure is guaranteed to still be open
    # (inline backends close figures after displaying them).
    if save:
        fig.savefig(f"{save}/UMAPsAll_{name}integ.png", format="png", dpi=300, pad_inches=0)
    plt.show()


# Batch-coloured unintegrated vs. scVital UMAPs, reordered by cell-type colour;
# with name="" the file is saved as {saveDir}/UMAPsAll_integ.png.
plotDoubleUMAP(scVitalUmap, pcaUmap, ctColors, btColors, True, name="", save=saveDir)

In [None]:
def plotOneUmap(title, x, y, c, edgecolors, name="", reorder=True, save=False, **kwargs):
    """Scatter a single UMAP embedding and optionally save it as a PNG.

    Parameters
    ----------
    title : str
        Axes title. "/" characters are stripped only when building the file
        name; the displayed title keeps them.
    x, y : array-like
        Per-cell UMAP coordinates.
    c : array-like
        Per-cell face colours.
    edgecolors : array-like
        Per-cell edge colours.
    name : str
        Suffix inserted into the saved file name.
    reorder : bool
        If true, permute x, y, c and edgecolors by ``reorderColors(c)``
        (defined elsewhere in the notebook) before plotting.
    save : str or False
        Output directory; falsy disables saving. The file is written to
        ``{save}/umap_{title-without-slashes}_{name}.png``.
    **kwargs
        Forwarded to ``Axes.scatter`` (e.g. ``s``, ``alpha``, ``linewidths``).
    """
    if reorder:
        newOrder = reorderColors(c)
        x = x[newOrder]
        y = y[newOrder]
        c = c[newOrder]
        edgecolors = edgecolors[newOrder]

    fig, ax = plt.subplots()

    ax.scatter(x, y, c=c, edgecolors=edgecolors, **kwargs)
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["left"].set_visible(False)
    ax.spines["bottom"].set_visible(False)
    ax.set_title(title)

    plt.tight_layout()

    # Save before showing, while the figure is guaranteed to still be open
    # (inline backends close figures after displaying them).
    if save:
        # "/" would be read as a directory separator in the output path.
        fileTitle = title.replace("/", "")
        fig.savefig(f"{save}/umap_{fileTitle}_{name}.png", format="png")

    plt.show()

# Colour the scVital UMAP by Leiden cluster (adata.obs["scVital"]) with a fixed,
# hard-coded palette (cluster ids "0".."9") rather than adata.uns' stored colours;
# edges show the batch colour.
clusterLabel = "scVital"
clusterPalette = ['#1f77b4', '#d62728', '#279e68', '#ff7f0e', '#aa40fc',
                  '#8c564b', '#e377c2', '#b5bd61', '#17becf', '#aec7e8']
clust2Col = {str(idx): colour for idx, colour in enumerate(clusterPalette)}
c = np.array([clust2Col[l] for l in adata.obs[clusterLabel]])
x, y = scVitalUmap[:, 0], scVitalUmap[:, 1]
plotOneUmap("UPS", x, y, c, edgecolors=btColors, name="UPSLeiden", save=saveDir,
            linewidths=0.2, s=10, alpha=0.7)

In [None]:
# Chemo arm -> (display label, colour).
cellTypeDict = {key: (label, colour) for key, label, colour in [
    ('ctl', 'Vehicle', '#009E73'),
    ('wk1', 'Short Chemo', '#E69F00'),
    ('wk3/4', "Long Chemo", '#D55E00'),
]}

# Stacked bars of chemo-arm composition per scVital Leiden cluster (not saved;
# pass save=saveDir to write the figure).
clustBatch = adata.obs[["scVital", "chemo"]]
df2StackBar(clustBatch, "scVital", cellTypeDict, "chemo", save=False)

In [None]:
def rgba_to_hex(values):
    """Convert an RGB(A) colour with channels in [0, 1] to a "#rrggbb" string.

    The alpha channel, if present, is ignored. Each channel is clamped to
    [0, 1], scaled by 255 and truncated (not rounded) to an integer, so e.g.
    0.5 maps to 0x7f.

    Parameters
    ----------
    values : iterable of float
        ``(r, g, b)`` or ``(r, g, b, a)``, e.g. the tuple or row returned by a
        matplotlib colormap.

    Returns
    -------
    str
        Lower-case hex colour code without an alpha component.
    """
    r, g, b, *_ = values
    # Clamp so out-of-range inputs cannot produce a malformed (negative or
    # three-digit) channel.
    channels = (min(max(v, 0.0), 1.0) for v in (r, g, b))
    return "#" + "".join(f"{int(v * 255):02x}" for v in channels)


# Raw (un-scaled) expression for colouring cells by individual genes.
adatar = adata.raw.to_adata()

titles = ["BNIP3/Bnip3","SLC2A1/Slc2a1"]

# The colormap is loop-invariant; look it up once.
cmap = plt.get_cmap("Reds")

for gene in titles:
    geneMask = adatar.var_names == gene
    if not geneMask.any():
        raise KeyError(f"{gene!r} not found in adata.raw var_names")

    geneX = adatar[:, geneMask].X
    # .X may be sparse or dense depending on how the raw layer was stored.
    if hasattr(geneX, "toarray"):
        geneX = geneX.toarray()
    geneExpr = np.asarray(geneX).flatten()

    norm = plt.Normalize(geneExpr.min(), geneExpr.max())
    # Map all cells through the colormap in one call, then hex-encode each RGBA row.
    ctPlotColors = np.array([rgba_to_hex(rgba) for rgba in cmap(norm(geneExpr))])

    plotOneUmap(gene, x=scVitalUmap[:, 0], y=scVitalUmap[:, 1], c=ctPlotColors, edgecolors=btColors, name="UPSHypox",  reorder=True, save=saveDir, linewidths=0.2, s=10)

    plotCbar(gene,"UPSHypox", norm, cmap)