# Calculating Functions

In [None]:
def findGenes(data, geneType, prefix):
    """Print and collect every gene in ``data`` whose name starts with ``prefix``.

    Parameters
    ----------
    data : AnnData
        Object whose ``var_names`` are scanned, in order.
    geneType : str
        Label used in the printed header and in the printed count line.
    prefix : str or tuple of str
        Gene-name prefix to match. Because it is passed straight to
        ``str.startswith``, a tuple matches a gene starting with any of
        its entries.

    Returns
    -------
    tuple[list[str], int]
        The matching gene names (in ``var_names`` order) and how many
        there are.
    """
    print(f'\n{geneType} GENES PRESENT IN SAMPLE')
    genesFound = []
    for gene in data.var_names:
        geneName = str(gene)
        if geneName.startswith(prefix):
            print(geneName)
            genesFound.append(geneName)
    geneCount = len(genesFound)
    print(f'{geneType} Gene Count: {geneCount}')
    return genesFound, geneCount

# Plotting Functions

In [None]:
def violinFilterPlot(data, parameter, filterLine, tickList, labelLst, yLabel):
    """Plot one violin per sample for a per-nucleus ``obs`` metric on a log10 scale.

    Values are log10-transformed before plotting. Zeros, and any value whose
    log10 is <= -2, are pinned to -2 so they still show up as a visible
    floor (log10 of zero is undefined). Negative values make ``math.log10``
    raise ``ValueError``.

    Parameters
    ----------
    data : AnnData
        Object whose ``obs`` table has a ``'sample'`` column and a
        ``parameter`` column.
    parameter : str
        Name of the ``obs`` column to plot.
    filterLine : float or None
        If not None, y position (already in log10 units) of a dashed red
        threshold line.
    tickList, labelLst : sequence or None
        Custom y-tick positions and labels; applied only when both are given.
    yLabel : str
        Y-axis label; also used to build the title.
    """
    logFloor = -2

    def toLogScale(value):
        # Zeros have no logarithm; send them straight to the floor.
        if value == 0:
            return logFloor
        logValue = math.log10(value)
        return logFloor if logValue <= logFloor else logValue

    samples = data.obs['sample'].unique()
    xPos = list(range(len(samples)))

    # Select rows straight from ``obs`` rather than materializing an AnnData
    # view per sample; row order is the same either way.
    logPlotData = [
        [
            toLogScale(value)
            for value in data.obs.loc[data.obs['sample'].isin([sample]), parameter]
        ]
        for sample in samples
    ]

    axisFontWeight = 'bold'
    fontSize = 20

    fig, ax = plt.subplots(figsize=(10, 10))

    # ``is not None`` rather than ``!= None``: array-like arguments would
    # otherwise compare element-wise and fail with an ambiguous truth value.
    if filterLine is not None:
        ax.axhline(y=filterLine, color='red', linestyle='--')

    ax.violinplot(logPlotData, xPos)
    ax.set_xlabel('Samples', fontweight=axisFontWeight, fontsize=fontSize)
    ax.set_ylabel(f'{yLabel}', fontweight=axisFontWeight, fontsize=fontSize)
    ax.set_title(f'{yLabel} vs. Samples', fontweight=axisFontWeight, fontsize=fontSize)
    ax.set_xticks(xPos, samples, rotation=90)
    if tickList is not None and labelLst is not None:
        ax.set_yticks(tickList, labelLst)

In [None]:
def histFilterPlot(data, parameter, filterLine, log, binNum, xLabel, yLabel, title):
    """Draw a histogram of one ``obs`` column, optionally with a threshold line.

    Parameters
    ----------
    data : AnnData
        Object whose ``obs`` table contains ``parameter``.
    parameter : str
        Name of the ``obs`` column to histogram.
    filterLine : float or None
        If not None, x position of a dashed red vertical threshold line.
    log : bool
        When true, the y (count) axis uses a logarithmic scale.
    binNum : int or sequence
        Passed to ``Axes.hist`` as ``bins``.
    xLabel, yLabel, title : str
        Axis labels and plot title.
    """
    plotData = list(data.obs[parameter])

    fig, ax = plt.subplots(figsize=(10, 10))

    if filterLine is not None:
        ax.axvline(x=filterLine, color='red', linestyle='--')

    ax.hist(plotData, bins=binNum)

    ax.set_xlabel(xLabel)
    ax.set_ylabel(yLabel)
    if log:
        ax.set_yscale('log')
    ax.set_title(title)
    

In [1]:
def doubletBar(data):
    """Draw a stacked bar per sample: non-doublet nuclei below, doublets on top.

    Reads the ``'sample'`` and ``'predicted_doublets'`` columns of
    ``data.obs``. Doublet calls may be stored either as booleans or as the
    strings ``'True'``/``'False'``; both are counted.
    """
    sampleNames = list(data.obs['sample'].unique())
    xTicks = list(range(len(sampleNames)))

    width = 0.35
    nucleiColor = 'xkcd:soft blue'
    doubletColor = 'xkcd:vermillion'

    fig, ax = plt.subplots(figsize=(10, 10))

    for xTick, sample in zip(xTicks, sampleNames):
        calls = data.obs.loc[data.obs['sample'].isin([sample]), 'predicted_doublets']
        # Normalize through str() so a boolean column (True/False) is counted
        # the same way as a string column ('True'/'False'); comparing raw
        # booleans against the strings would always give 0.
        callStrings = [str(call) for call in calls]
        nuclei = callStrings.count('False')
        doublets = callStrings.count('True')

        # Nuclei segment first, then doublets stacked on top of it.
        ax.bar(xTick, nuclei, width, color=nucleiColor)
        ax.bar(xTick, doublets, width, bottom=nuclei, color=doubletColor)

    ax.set_xlabel('Samples')
    ax.set_xticks(xTicks, sampleNames, rotation=90)
    ax.set_ylabel('Nuclei/Doublet Count')
    ax.set_title('Nuclei/Doublet Count vs. Samples')

    patchA = mpatches.Patch(color=nucleiColor, label='Nuclei Counts')
    patchB = mpatches.Patch(color=doubletColor, label='Doublet Counts')
    ax.legend(handles=[patchA, patchB], loc='best')

In [6]:
def lessThanThreeCellPlot(data):
    """Histogram of how many nuclei each gene is detected in, on a log10 axis.

    Reads ``data.var['n_cells_by_counts']``. Genes found in zero nuclei are
    placed at -2 (tick label "0"), because log10(0) is undefined. A dashed
    red line marks 3 nuclei.
    """
    minCells = 3
    zeroPosition = -2

    geneCellCounts = list(data.var['n_cells_by_counts'])
    logGeneCellCounts = [
        zeroPosition if count == 0 else math.log10(count)
        for count in geneCellCounts
    ]

    fig, ax = plt.subplots(figsize=(10, 10))

    # Threshold line at log10(3) instead of a hand-copied decimal.
    ax.axvline(x=math.log10(minCells), color='red', linestyle='--')

    ax.hist(logGeneCellCounts, bins=100)

    ax.set_xlabel('Number of Nuclei (Log Scale Data)')
    ax.set_ylabel('Gene Counts')
    # Tick k (k >= -1) is labelled 10**k; -2 is the placeholder for zero.
    ax.set_xticks(
        [-2, -1, 0, 1, 2, 3, 4, 5],
        ['"0"', '0.1', '1', '10', '100', '1000', '10000', '100000'],
    )
    ax.set_title('Number of Nuclei a Gene is Found In vs. Gene Counts')

In [None]:
def sizeFactorsPlot(data):
    """Draw a histogram of the per-nucleus values in ``data.obs['size_factors']``."""
    sizeFactors = list(data.obs['size_factors'])

    fig, ax = plt.subplots(figsize=(10, 10))

    ax.hist(sizeFactors)
    ax.set_xlabel('Size Factors')
    ax.set_ylabel('Number of Nuclei')
    ax.set_title('Number of Nuclei vs. Size Factor')