In [None]:
import numpy as np
import pandas as pd
import anndata as ad
from matplotlib import pyplot as plt
import seaborn as sns
import requests
import sys
import patchworklib as pw
import urllib 
import os
import gzip
import shutil

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from joblib import Parallel, delayed
from itertools import combinations
from threading import Thread

In [None]:
def download_test_data(url: str = "https://ftp.ncbi.nlm.nih.gov/geo/samples/GSM3748nnn/GSM3748087/suppl/GSM3748087%5F190c.isoforms.matrix.txt.gz",
                       archive_path: str = 'file.txt.gz',
                       output_path: str = 'sample_isomatrix.txt') -> "str | None":
    """Download the gzipped test isoform matrix and decompress it next to the archive.

    Parameters
    ----------
    url : str
        Location of the gzipped matrix (defaults to the GEO supplementary file of
        sample GSM3748087, "190c").
    archive_path : str
        Where the downloaded ``.gz`` archive is written.
    output_path : str
        Where the decompressed, tab-separated matrix is written.

    Returns
    -------
    str or None
        Absolute path of the decompressed file, or None if the archive does not
        exist after the download step.

    Raises
    ------
    urllib.error.URLError
        Propagated from ``urlretrieve`` when the download itself fails.
    """
    # `import urllib` alone does not load the `urllib.request` submodule.
    import urllib.request

    urllib.request.urlretrieve(url, archive_path)

    if not os.path.exists(archive_path):
        print("Failed to download the file")
        return None

    print("File downloaded successfully")
    with gzip.open(archive_path, 'rb') as f_in, open(output_path, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)
    print("File extracted successfully")
    return os.path.abspath(output_path)
    
path = download_test_data()
# download_test_data() returns None when the archive is missing after the download;
# fail here with a clear message rather than later inside create_iso_adata(open(None)).
if path is None:
    raise RuntimeError("Test data download failed; cannot build the isoform matrix.")

def create_iso_adata(path: str  # path to a tab separated file, sicelore output with transcript counts and columns transcriptId, geneId
                     ):
    """Build a cells x transcripts AnnData from a sicelore isoform count matrix.

    The file is read as tab-separated with one row per transcript: the first two
    columns are expected to be ``transcriptId`` and ``geneId``; every remaining
    column holds the counts of one cell barcode.

    Returns
    -------
    ad.AnnData
        ``obs`` has a ``barcodes`` column (also used as ``obs_names``);
        ``var`` has ``transcriptId`` and ``geneId`` columns and ``var_names``
        are the transcript ids.
    """
    df = pd.read_csv(path, sep='\t')

    counts = df.iloc[:, 2:]
    adata_iso = ad.AnnData(counts).transpose()

    # Assign .var BEFORE .var_names: assigning a DataFrame to .var replaces the
    # var names with that frame's index (0..n-1), so setting the names first
    # would silently discard the transcript ids.
    adata_iso.var = df[['transcriptId', 'geneId']]
    adata_iso.var_names = df['transcriptId'].astype(str).to_list()

    barcodes = pd.DataFrame({'barcodes': df.columns.values[2:]})
    adata_iso.obs = barcodes
    adata_iso.obs_names = barcodes['barcodes'].tolist()
    return adata_iso

# Cells x transcripts AnnData built from the extracted test matrix.
adata_iso = create_iso_adata(path=path)

File downloaded successfully
File extracted successfully


In [None]:
# Cell-type annotation; the file is tab-separated despite its .csv extension.
# NOTE(review): 'leiden' is assigned positionally, so its rows are assumed to be
# in the same order as adata_iso's barcodes — TODO confirm against the file.
cell_type = pd.read_table('./annot_190c.csv')
adata_iso.obs['cell_type'] = cell_type['leiden'].to_list()

In [None]:
class AnnDataIso(ad.AnnData):
    """AnnData subclass for isoform (transcript-level) counts and isoform-switch testing.

    On construction it copies the given AnnData, stores the cell-type labels in
    ``obs['cell_type']``, counts isoforms per gene (``gene_counts``) and keeps a
    private object restricted to genes with more than one isoform, whose
    ``obsm['Iso_prct']`` holds per-cell within-gene isoform fractions.
    ``var`` is expected to contain ``transcriptId`` and ``geneId`` columns and
    ``obs`` a ``barcodes`` column.
    """

    def __filter_isodata(self):
        """Return a view of ``self`` restricted to genes that have more than one isoform."""
        genes, frequencies = np.unique(self.var['geneId'], return_counts=True)
        multi_iso_genes = genes[frequencies > 1].tolist()
        return self[:, self.var['geneId'].isin(multi_iso_genes)]

    def iso_percent(self, df, barcodes_regex="^[ACGT]+$"):
        """Convert per-transcript counts into within-gene isoform fractions.

        Parameters
        ----------
        df : pd.DataFrame
            One row per transcript; columns matching ``barcodes_regex`` hold the
            counts of each cell and a ``geneId`` column groups transcripts.
        barcodes_regex : str
            Regex identifying the barcode (cell) columns.

        Returns
        -------
        pd.DataFrame
            A copy of ``df`` whose barcode columns are divided by the per-gene
            total of that barcode; NaN results (e.g. 0/0) are replaced by 0.0.

        Raises
        ------
        ValueError
            If no column matches ``barcodes_regex``.
        """
        iso_perc_df = df.copy(deep=True)
        barcode_cols = iso_perc_df.filter(regex=barcodes_regex).columns.to_list()
        if len(barcode_cols) < 1:
            raise ValueError("No barcode was identified. Please check the names of the columns.")
        # Sum only the count columns; the id columns do not need aggregating.
        gene_totals = iso_perc_df.groupby(['geneId'])[barcode_cols].transform('sum')
        iso_perc_df[barcode_cols] = iso_perc_df[barcode_cols] / gene_totals
        return iso_perc_df.replace(np.nan, 0.0)

    def __init__(self, anndata: ad.AnnData, cell_types: pd.DataFrame):
        """Build the isoform object from an existing AnnData.

        Parameters
        ----------
        anndata : ad.AnnData
            Cells x transcripts counts (copied, not modified).
        cell_types : list-like or pd.Series
            One label per cell, stored in ``obs['cell_type']``.
        """
        self._init_as_actual(anndata.copy())
        # Labels are set before the filtered view is taken: an AnnData view takes
        # its own subset of obs when created, so later assignments would not reach it.
        self.obs['cell_type'] = cell_types
        # Rows of var per gene; the 'transcriptId' column then holds the isoform count.
        self.gene_counts = self.var.reset_index().groupby(by='geneId').count()
        self.__filtered_anndata = self.__filter_isodata()

        # transcripts x barcodes table annotated with transcript/gene ids
        df = self.__filtered_anndata.to_df().set_index(self.__filtered_anndata.obs['barcodes'])
        df = df.transpose()
        df[['transcriptId', 'geneId']] = self.__filtered_anndata.var
        df_m_iso = self.iso_percent(df)
        # drop the two id columns again and return to barcodes x transcripts
        df_m_iso = df_m_iso.iloc[:, :-2].transpose()
        self.__filtered_anndata.obsm['Iso_prct'] = df_m_iso

    def plot_isoforms_summary(self):
        """Combine the three overview plots into one patchworklib layout; returns its savefig() result."""
        ax1 = pw.Brick(figsize=(4,4))
        self._plot_switch_gen_bar(ax1)
        ax1.set_title("Multiple isoforms genes %")
        ax2 = pw.Brick(figsize=(4,4))
        self._plot_isoforms_frequencies(ax2)
        ax2.set_title("Frequency of isoforms per gene")
        ax3 = pw.Brick(figsize=(3,2))
        self._plot_genes_cell_type(ax3)
        ax3.set_title("Nb of genes per cell type")
        return (ax1|ax2|ax3).savefig()

    def _plot_genes_cell_type(self, _ax):
        """Box + strip plot of per-cell totals of ``X`` grouped by ``obs['cell_type']``.

        Draws on ``_ax`` when given, otherwise on the current pyplot axes and shows it.
        NOTE(review): the plotted values are per-cell sums of counts, not numbers
        of distinct genes, despite the 'n_of_genes' label.
        """
        df = pd.DataFrame(np.transpose(self.X), columns=self.obs['cell_type'])
        df = df.sum(axis=0).to_frame().reset_index()
        df.columns = ['cell_type', 'n_of_genes']
        ax = sns.boxplot(x='cell_type', y='n_of_genes', data=df, ax=_ax)
        sns.stripplot(x='cell_type', y='n_of_genes', data=df, color='black', size=3, ax=ax)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
        if _ax is None:
            plt.show()

    def plot_genes_cell_type(self):
        """Show the per-cell-type totals plot in its own figure."""
        self._plot_genes_cell_type(None)

    def _plot_isoforms_frequencies(self, _ax):
        """Bar plot of how many genes have 1, 2, ... isoforms (new figure when ``_ax`` is None)."""
        ax = _ax if _ax is not None else plt.subplots()[1]
        self.gene_counts['transcriptId'].value_counts().plot(ax=ax, kind='bar',
                                                             xlabel='number of isoforms per gene', ylabel='quantity of genes')

    def plot_isoforms_frequencies(self):
        """Show the isoforms-per-gene bar plot in its own figure."""
        self._plot_isoforms_frequencies(None)

    def _plot_switch_gen_bar(self, _ax):
        """Stacked bar of multi- vs single-isoform genes next to the total transcript count.

        Draws on ``_ax`` when given; otherwise creates a figure and shows it.
        """
        iso_per_gene = self.gene_counts
        is_multi = iso_per_gene['transcriptId'] > 1
        multiple_iso = int(is_multi.sum())
        mono_iso = int((~is_multi).sum())
        total = multiple_iso + mono_iso
        labels = [str(round(1000*multiple_iso/total)/10) + '%', str(round(1000*mono_iso/total)/10) + '%']

        ax = _ax if _ax is not None else plt.subplots()[1]
        mult = ax.bar(['genes'], multiple_iso, color='deepskyblue', label=labels[0])
        mono = ax.bar(['genes'], mono_iso, bottom=multiple_iso, color='sandybrown', label=labels[1])
        ax.bar(['transcripts'], len(self.var['transcriptId']))
        # Position the percentage labels from the bars just drawn, not from
        # ax.patches, so pre-existing patches on a supplied axes cannot shift them.
        mult_patch = mult.patches[0]
        mono_patch = mono.patches[0]
        ax.text(
            mult_patch.get_x() + mult_patch.get_width() / 2, mult_patch.get_height() / 2, labels[0], ha="center", va="center"
            )
        ax.text(
            mono_patch.get_x() + mono_patch.get_width() / 2, mono_patch.get_height() / 2 + mult_patch.get_height(), labels[1], ha="center", va="center"
            )
        ax.legend(['Multiple isoforms', 'Single isoform'])
        if _ax is None:
            plt.show()

    def plot_switch_gen_bar(self):
        """Show the multi- vs single-isoform summary bar in its own figure."""
        self._plot_switch_gen_bar(None)

    def _plot_transcripts_per_cell_type(self, gene_name, _ax):
        """Stacked horizontal bars of the mean isoform fraction of ``gene_name`` per cell type."""
        # assign() returns a copy, so the stored obsm table is not modified.
        grouped = self.__filtered_anndata.obsm['Iso_prct'].assign(cell_type=self.obs['cell_type'])
        res = grouped.groupby('cell_type').mean().transpose()
        var = self.__filtered_anndata.var
        res = res.assign(transcriptId=var['transcriptId'].to_list())
        res = res.assign(geneId=var['geneId'].to_list())
        res = res[res['geneId'] == gene_name].drop(['geneId'], axis=1)
        plot_data = res.set_index('transcriptId').transpose()
        plot_data.plot(kind='barh', stacked=True, ax=_ax).legend(loc='center left', bbox_to_anchor=(1.0, 1.0))

    def plot_transcripts_per_cell_type(self, gene_name):
        """Plot the mean isoform fractions of ``gene_name`` per cell type in a new figure."""
        self._plot_transcripts_per_cell_type(gene_name, None)

    def _trsct_counts_cell_type(self, gene_name, _ax):
        """Violin plots (one facet per transcript) of ``gene_name`` counts per cell type.

        Shows the grid and returns None when ``_ax`` is None; otherwise returns the
        seaborn FacetGrid so it can be placed in a patchworklib layout.
        """
        filtered = self.__filtered_anndata
        # transcripts x cells table with cell types as column labels
        df = filtered.to_df().set_index(filtered.obs['cell_type'])
        df = df.transpose()
        df[['transcriptId', 'geneId']] = filtered.var
        gene_iso_count = df[df['geneId'] == gene_name]
        gene_iso_count = gene_iso_count.drop('geneId', axis=1).set_index('transcriptId').transpose()
        gene_iso_count_long = gene_iso_count.reset_index().melt(id_vars='cell_type', var_name='transcriptId', value_name='count')
        g = sns.catplot(x="cell_type", y="count", col="transcriptId", aspect=1, dodge=False, kind="violin", data=gene_iso_count_long)
        # One facet title per transcript id
        g.set_titles(col_template="{col_name}", size=8)
        if _ax is None:
            # Rotate the x tick labels
            g.set_xticklabels(rotation=90)
            g.fig.suptitle(gene_name)
            plt.show()
            return None
        g.set_xticklabels(rotation=90, labels=self.obs['cell_type'].unique())
        g.fig.suptitle(gene_name)
        return g

    def trsct_counts_cell_type(self, gene_name):
        """Show per-cell-type violin plots of the transcripts of ``gene_name``."""
        self._trsct_counts_cell_type(gene_name, None)

    def __get_coord_from_tscrpt_id(self, transcript_id, timeout=30):
        """Fetch exon coordinates and strand of a transcript from the Ensembl REST API.

        Any version suffix (``.N``) is stripped before the lookup.

        Returns
        -------
        tuple
            ``(exon_coord, strand)`` where ``exon_coord`` is a list of
            ``[end, start]`` pairs in the order returned by Ensembl.

        Raises
        ------
        requests.HTTPError
            On a 4xx/5xx response.
        requests.Timeout
            If Ensembl does not answer within ``timeout`` seconds.
        """
        transcript_id = transcript_id.split('.')[0]
        server = "https://rest.ensembl.org"
        ext = "/lookup/id/" + transcript_id + "?expand=1"

        r = requests.get(server+ext, headers={"Content-Type": "application/json"}, timeout=timeout)
        # raise_for_status() already raises for every non-ok response.
        r.raise_for_status()

        decoded = r.json()
        exon_coord = [[e.get('end'), e.get('start')] for e in decoded['Exon']]
        strand = decoded['strand']
        return(exon_coord, strand)

    def __draw_exons(self, exons, direction, color, transcript_name, offset=0, start_override=None, end_override=None, no_render=False):
        """Draw one transcript as exon rectangles with a strand arrow and start/end labels.

        ``exons`` holds ``[end, start]`` pairs as returned by
        ``__get_coord_from_tscrpt_id``; ``direction == 1`` selects the first
        layout branch, any other value the reversed one. When both overrides are
        given they define a shared genomic window so several transcripts share
        one x scale. With ``no_render`` the drawing goes to the current axes and
        no figure is created or shown.
        """
        if not no_render:
            plt.axes()
            plt.xlim((-0.1, 1))
            plt.ylim((-0.3, 0.3))
            plt.margins(0.2)
            plt.axis('off')
            fig = plt.gcf()
            fig.set_size_inches(20, 2)
        height = 0.2
        plt.plot([offset + 0.1, offset + 0.1], linestyle='solid', linewidth=0.5, c='grey')
        j = 0
        k = 1
        if direction == 1:
            pos_start = exons[0][1]
            pos_end = exons[-1][0]
        else:  # direction == -1
            pos_start = exons[-1][1]
            pos_end = exons[0][0]
            j = 1
            k = 0
        real_start = pos_start
        real_end = pos_end
        if start_override is not None and end_override is not None:
            pos_start = start_override
            pos_end = end_override
        total_length = pos_end - pos_start
        total_length_with_margin = 1.05 * total_length
        pos_start_with_margin = pos_start - 0.025*total_length
        for exon in exons:
            rectangle = plt.Rectangle(((exon[j] - pos_start_with_margin)/total_length_with_margin, offset), (exon[k] - exon[j])/total_length_with_margin, height, fc=color, ec="black")
            plt.gca().add_patch(rectangle)
        # Strand arrow only for transcripts with more than one exon.
        # plt.arrow already adds the FancyArrow to the current axes.
        if len(exons) > 1:
            if direction < 0:
                plt.arrow(1, offset - height/4, -1, 0, width=0.0015, head_length=0.01, head_width=0.1, length_includes_head=True, overhang=1)
            else:
                plt.arrow(0, offset - height/4, 1, 0, width=0.0015, head_length=0.01, head_width=0.1, length_includes_head=True, overhang=1)
        # Tick marks and coordinate labels at this transcript's own start/end.
        x_start = 0.025 + (real_start - pos_start) / (total_length)/1.05
        x_end = 1 - 0.025 - (pos_end - real_end) / (total_length)/1.05
        tick_y = np.array([offset - height/4 - 0.03, offset - height/4 + 0.03])
        plt.plot(np.array([x_start, x_start]), tick_y, color='black')
        plt.plot(np.array([x_end, x_end]), tick_y, color='black')
        plt.text(x_start, offset - height/4 - 0.075, real_start, horizontalalignment='center', verticalalignment='center', fontsize=9)
        plt.text(x_end, offset - height/4 - 0.075, real_end, horizontalalignment='center', verticalalignment='center', fontsize=9)
        plt.text(1, offset - height, transcript_name, horizontalalignment='right', verticalalignment='top', fontsize=12)
        if not no_render:
            plt.show()

    def __get_transcripts_from_gene(self, gene_name):
        """Transcript ids of ``gene_name`` among the multi-isoform (filtered) transcripts."""
        elems = self.__filtered_anndata.var
        return elems[elems['geneId'] == gene_name]['transcriptId'].to_list()

    def __draw_transcripts_list(self, gene_name, _ax, colors=None):
        """Draw all filtered transcripts of ``gene_name`` stacked on a shared genomic scale.

        Exon coordinates come from Ensembl. ``colors`` defaults to cycling five
        fixed colours. Drawing goes to a new pyplot axes (``_ax`` is not drawn
        on): the figure is shown when ``_ax`` is None, otherwise the ``pyplot``
        module is returned.
        """
        transcripts_id = self.__get_transcripts_from_gene(gene_name)
        exons = []
        directions = []
        for tr in transcripts_id:
            t, d = self.__get_coord_from_tscrpt_id(tr)
            exons.append(t)
            directions.append(d)
        if colors is None:
            palette = ['lightblue', 'lightgreen', 'orange', 'yellow', 'brown']
            colors = [palette[i % len(palette)] for i in range(len(exons))]

        def get_limits(ex, dirs):
            """Smallest start and largest end over all transcripts, honouring strand."""
            start = sys.maxsize
            end = -sys.maxsize
            for (e, d) in zip(ex, dirs):
                if d == 1:
                    start = min(start, e[0][1])
                    end = max(end, e[-1][0])
                else:
                    start = min(start, e[-1][1])
                    end = max(end, e[0][0])
            return (start, end)

        plt.axes()
        plt.xlim((-0.1, 1.1))
        plt.ylim((0.1 - 0.5 * len(exons), 0.3))
        plt.margins(0.2)
        plt.axis('off')
        fig = plt.gcf()
        fig.set_size_inches(20, len(exons) * 2)
        (start, end) = get_limits(exons, directions)
        for i, (ex, di, co, name) in enumerate(zip(exons, directions, colors, transcripts_id)):
            self.__draw_exons(ex, di, co, name, offset=-0.5 * i, start_override=start, end_override=end, no_render=True)
        if _ax is None:
            plt.show()
        else:
            return plt

    def draw_transcripts_list(self, gene_name, colors=None):
        """Show the transcript structures of ``gene_name``."""
        self.__draw_transcripts_list(gene_name, None, colors)

    def draw_gene_summary(self, gene_name):
        """Patchworklib summary for one gene; returns the layout's savefig() result.

        Panels: per-cell-type count violins, mean isoform fractions per cell
        type, and the transcript structures.
        """
        ax1 = pw.Brick(figsize=(12,4))
        pw.overwrite_axisgrid()
        fg = self._trsct_counts_cell_type(gene_name, ax1)
        ax4 = pw.load_seaborngrid(fg)
        ax1.set_title("Transcripts count per cell type")
        ax2 = pw.Brick(figsize=(12,4))
        self._plot_transcripts_per_cell_type(gene_name, ax2)
        ax2.set_title("Mean isoform fraction per cell type")
        ax3 = pw.Brick(figsize=(12,4))
        self.__draw_transcripts_list(gene_name, ax3)
        ax3.set_title("Transcripts list")
        return (ax4/(ax2/ax3)).savefig()

    def __model(self, data, total_counts):
        """Pyro model: each row of ``data`` is a Dirichlet-multinomial draw sharing one positive ``alpha``."""
        alpha = pyro.param('alpha', torch.ones(data.size(1)), constraint=dist.constraints.positive)
        with pyro.plate('data_%i' % (data.size(0)), data.size(0)):
            pyro.sample('obs', dist.DirichletMultinomial(concentration=alpha, is_sparse=True, total_count=total_counts), obs=data)

    def __guide(self, data, total_counts):
        """Empty guide: the model has no latent sample sites, only the ``alpha`` parameter."""
        pass

    def __perform_mle(self, data, total_counts, num_steps=10000, lr=0.01, tolerance=1e-4, patience=10):
        """Fit ``alpha`` with SVI/Adam and early stopping.

        Stops after ``num_steps`` or when the loss has not improved by more than
        ``tolerance`` for ``patience`` consecutive steps. Uses Pyro's global
        parameter store (cleared here, fixed name 'alpha'), so calls must not
        run concurrently.

        Returns
        -------
        tuple
            ``(best_loss, alpha_est)``: the lowest loss seen and the final
            ``alpha`` as a numpy array.
        """
        pyro.clear_param_store()
        optimizer = Adam({"lr": lr})
        svi = SVI(self.__model, self.__guide, optimizer, loss=Trace_ELBO())

        best_loss = float('inf')
        patience_counter = 0

        for step in range(num_steps):
            loss = svi.step(data, total_counts)

            if loss < best_loss - tolerance:
                best_loss = loss
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                break

        alpha_est = pyro.param('alpha').detach().cpu().numpy()
        return best_loss, alpha_est

    def __LRT_test(self, data1, data2):
        """Likelihood-ratio test: one shared ``alpha`` for both groups vs one per group.

        With no latent sites the ELBO loss is presumably the negative
        log-likelihood, so ``2 * (loss_full - (loss1 + loss2))`` is compared
        with a chi-square whose degrees of freedom equal the number of
        transcript columns. Returns NaN statistic and p-value when that
        quantity is negative.
        """
        total_counts1 = data1.sum(dim=-1).float()
        total_counts2 = data2.sum(dim=-1).float()
        combined_data = torch.cat([data1, data2], dim=0)
        combined_counts = torch.cat([total_counts1, total_counts2], dim=0)

        loss_full, alpha_full = self.__perform_mle(combined_data, combined_counts)
        loss1, alpha1 = self.__perform_mle(data1, total_counts1)
        loss2, alpha2 = self.__perform_mle(data2, total_counts2)

        chi2_stat = 2 * (loss_full - (loss1 + loss2))
        chi2_stat_tensor = torch.tensor(chi2_stat)

        if chi2_stat_tensor.item() < 0:
            return float('nan'), float('nan'), alpha_full, alpha1, alpha2

        p_value = 1 - torch.distributions.Chi2(df=data1.size(1)).cdf(chi2_stat_tensor).item()

        return chi2_stat, p_value, alpha_full, alpha1, alpha2

    def __compare_groups(self, group_1_label, group_2_label, cell_group_column, gene_id):
        """Run the LRT on the transcripts of ``gene_id`` between two cell groups.

        Cells with 9 or fewer total counts for the gene are dropped. Returns None
        when a group has no remaining cells or the statistic is NaN; otherwise a
        dict with ``chi2_stat``, ``p_value`` and the three alpha vectors as lists.
        """
        group_1 = self[self.obs[cell_group_column] == group_1_label]
        group_2 = self[self.obs[cell_group_column] == group_2_label]

        data1 = torch.tensor(group_1[:, group_1.var['geneId'] == gene_id].X.toarray(), dtype=torch.float)
        data2 = torch.tensor(group_2[:, group_2.var['geneId'] == gene_id].X.toarray(), dtype=torch.float)

        total_counts1 = data1.sum(dim=-1)
        total_counts2 = data2.sum(dim=-1)
        # keep only cells with more than 9 counts for this gene
        enough_counts1 = total_counts1 > 9
        enough_counts2 = total_counts2 > 9
        data1 = data1[enough_counts1]
        data2 = data2[enough_counts2]

        if data1.size(0) == 0 or data2.size(0) == 0:
            return None

        chi2_stat, p_value, alpha_combined, alpha1, alpha2 = self.__LRT_test(data1, data2)
        if not np.isnan(chi2_stat):
            return {
                "chi2_stat": chi2_stat,
                "p_value": p_value,
                "alpha_combined": alpha_combined.tolist(),
                "alpha1": alpha1.tolist(),
                "alpha2": alpha2.tolist()
            }
        return None

    @staticmethod
    def __transcript_fractions(counts):
        """Share of each column (transcript) in the total of ``counts`` (cells x transcripts).

        A zero total yields NaN, which never satisfies a ``>`` comparison.
        """
        column_totals = np.asarray(counts.sum(axis=0), dtype=float).ravel()
        with np.errstate(invalid='ignore', divide='ignore'):
            return column_totals / column_totals.sum()

    def __filter_genes(self, group_1_label, group_2_label, cell_group_column):
        """Select the transcripts worth testing between two cell groups.

        Keeps transcripts with more than 9 counts summed over both groups that are
        non-zero in more than 9 cells of each group, and only for genes where at
        least one transcript's share of the gene's pooled counts differs by more
        than 0.1 between the groups. Prints the kept var rows and returns them as
        a new AnnDataIso.
        """
        group_1 = self[self.obs[cell_group_column] == group_1_label]
        group_2 = self[self.obs[cell_group_column] == group_2_label]

        gene_counts_group_1 = np.array(group_1.X.sum(axis=0)).flatten()
        gene_counts_group_2 = np.array(group_2.X.sum(axis=0)).flatten()
        total_gene_counts = gene_counts_group_1 + gene_counts_group_2

        valid_genes = total_gene_counts > 9
        adata = self[:, valid_genes]

        group_1 = adata[adata.obs[cell_group_column] == group_1_label]
        group_2 = adata[adata.obs[cell_group_column] == group_2_label]

        non_zero_genes_group_1 = np.array((group_1.X != 0).sum(axis=0)).flatten() > 9
        non_zero_genes_group_2 = np.array((group_2.X != 0).sum(axis=0)).flatten() > 9
        filtered_with_valid_genes = adata[:, non_zero_genes_group_1 & non_zero_genes_group_2]

        # Isoform fractions are computed over all transcripts of `adata`, and
        # each candidate gene is evaluated once.
        gene_ids = adata.var['geneId'].to_numpy()
        genes_with_switch = set()
        for gene_id in pd.unique(filtered_with_valid_genes.var['geneId']):
            in_gene = gene_ids == gene_id
            fractions_1 = self.__transcript_fractions(group_1.X[:, in_gene])
            fractions_2 = self.__transcript_fractions(group_2.X[:, in_gene])
            if np.any(np.abs(fractions_1 - fractions_2) > 0.1):
                genes_with_switch.add(gene_id)
        keep = filtered_with_valid_genes.var['geneId'].isin(genes_with_switch).to_numpy()
        filtered_with_valid_genes = filtered_with_valid_genes[:, keep]

        print(filtered_with_valid_genes.var[['geneId', 'transcriptId']])
        return AnnDataIso(filtered_with_valid_genes, self.obs['cell_type'])

    def main(self, cell_group_column):
        """Test isoform usage for every pair of groups in ``obs[cell_group_column]``.

        Returns
        -------
        pd.DataFrame
            One row per (gene, group pair) with a non-NaN result: gene_id,
            group_1, group_2, p_value, chi2_stat, alpha_combined, alpha1,
            alpha2, transcript_ids.
        """
        cell_types = self.obs[cell_group_column].unique()
        gene_ids = self.var['geneId'].unique()
        print(f"Processing {len(gene_ids)} genes for {len(cell_types)} cell types...")

        results = []
        for group_1_label, group_2_label in combinations(cell_types, 2):
            print(f"Comparing {group_1_label} vs {group_2_label}...")
            filtered_adata = self.__filter_genes(group_1_label, group_2_label, cell_group_column)
            # Sequential on purpose: __perform_mle clears and reuses Pyro's global
            # parameter store, so concurrent fits would overwrite each other's
            # 'alpha'. Genes absent from filtered_adata would only yield None.
            for gene_id in filtered_adata.var['geneId'].unique():
                result = filtered_adata.__compare_groups(group_1_label, group_2_label, cell_group_column, gene_id)
                if result:
                    results.append({
                        'gene_id': gene_id,
                        'group_1': group_1_label,
                        'group_2': group_2_label,
                        'p_value': result['p_value'],
                        'chi2_stat': result['chi2_stat'],
                        'alpha_combined': result['alpha_combined'],
                        'alpha1': result['alpha1'],
                        'alpha2': result['alpha2'],
                        'transcript_ids': filtered_adata.var_names[filtered_adata.var['geneId'] == gene_id].tolist()
                    })

        results_df = pd.DataFrame(results)
        return results_df

In [None]:
# Wrap the isoform AnnData with the Leiden labels, then run the pairwise isoform-usage tests.
adata_m = AnnDataIso(anndata=adata_iso, cell_types=cell_type['leiden'].to_list())
results = adata_m.main(cell_group_column="cell_type")

Processing 10587 genes for 8 cell types...
Comparing Cycling Radial glia vs Radial glia...
         geneId           transcriptId
1018      Atp5j  ENSMUST00000023608.13
1284   Pafah1b3  ENSMUST00000005583.11
2888     Atp5g1  ENSMUST00000090541.11
3240      Crmp1  ENSMUST00000031004.10
6187      Ubxn1   ENSMUST00000096255.5
7409        Set   ENSMUST00000067996.6
7421    Ngfrap1  ENSMUST00000053540.10
7639        Fau   ENSMUST00000178310.7
8920        Evl  ENSMUST00000077735.12
11932   Smarcb1   ENSMUST00000000925.9
13385    Pabpc1  ENSMUST00000001809.14
13868   Trappc4   ENSMUST00000034623.7
13957       Pkm   ENSMUST00000163694.3
15782       Pkm  ENSMUST00000034834.15
15928    Zfp428  ENSMUST00000071361.12
18453   Anapc11   ENSMUST00000026128.9
18784     Rsrp1   ENSMUST00000078084.6
20449     Ppdpf  ENSMUST00000016488.12
20701    Rbfox2   ENSMUST00000166610.7
Comparing Cycling Radial glia vs intermediate progenitor...
Empty DataFrame
Columns: [geneId, transcriptId]
Index: []
Comparing C

Exception in thread Thread-55243:
Traceback (most recent call last):
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/tmp/ipykernel_511566/2694172863.py", line 327, in __model
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/subsample_messenger.py", line 88, in __init__
    self.size, self.subsample_size, self._indices = self._subsample(
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/subsample_messenger.py", line 126, in _subsample
    apply_stack(msg)
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/runtime.py", line 221, in apply_stack
    frame._postprocess_message(msg)
  File "/home/diamant/.conda/envs/iso_sw

Comparing Cycling Radial glia vs Cajal Retzius...
        geneId           transcriptId
13        Rbm3   ENSMUST00000115616.7
1018     Atp5j  ENSMUST00000023608.13
1790    Dnaja1   ENSMUST00000030118.9
3240     Crmp1  ENSMUST00000031004.10
7321     U2af1   ENSMUST00000014684.4
7383    Rpl35a   ENSMUST00000115075.1
7421   Ngfrap1  ENSMUST00000053540.10
10788   Rpl35a  ENSMUST00000078804.11
11179    Srsf3   ENSMUST00000130216.1
13385   Pabpc1  ENSMUST00000001809.14
14420     Meg3   ENSMUST00000143836.7
15928   Zfp428  ENSMUST00000071361.12
16928     Meg3   ENSMUST00000129245.7
17804    Rpl27  ENSMUST00000077856.12
18938    Rpl27   ENSMUST00000107249.7
20449    Ppdpf  ENSMUST00000016488.12
20701   Rbfox2   ENSMUST00000166610.7
Comparing Radial glia vs intermediate progenitor...
Empty DataFrame
Columns: [geneId, transcriptId]
Index: []
Comparing Radial glia vs Imature glutamatergic...
         geneId           transcriptId
143       H2afv   ENSMUST00000109737.8
385      Mrps21   ENSMUST000

Exception in thread Thread-86236:
Traceback (most recent call last):
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/threading.py", line 980, in _bootstrap_inner
    self.run()
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/threading.py", line 917, in run
    self._target(*self._args, **self._kwargs)
  File "/tmp/ipykernel_511566/2694172863.py", line 471, in add_res
  File "/tmp/ipykernel_511566/2694172863.py", line 451, in process_gene
  File "/tmp/ipykernel_511566/2694172863.py", line 393, in __compare_groups
  File "/tmp/ipykernel_511566/2694172863.py", line 359, in __LRT_test
  File "/tmp/ipykernel_511566/2694172863.py", line 340, in __perform_mle
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/infer/svi.py", line 145, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py", line 97, in __exit__
    return super()

Comparing Radial glia vs Mature glutamatergic...
       geneId           transcriptId
13       Rbm3   ENSMUST00000115616.7
385    Mrps21   ENSMUST00000067298.4
505      Ufc1   ENSMUST00000111302.3
969      Nnat   ENSMUST00000109526.1
1324    Eif4h   ENSMUST00000202622.3
...       ...                    ...
18647   Pnisr  ENSMUST00000029911.11
19047    Nfib   ENSMUST00000135024.1
19282  Mrpl30  ENSMUST00000027256.11
20082   Parp6  ENSMUST00000026267.15
20441  Pantr1   ENSMUST00000181725.7

[72 rows x 2 columns]


Exception in thread Thread-96306:
Traceback (most recent call last):
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/threading.py", line 980, in _bootstrap_inner
    self.run()
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/threading.py", line 917, in run
    self._target(*self._args, **self._kwargs)
  File "/tmp/ipykernel_511566/2694172863.py", line 471, in add_res
  File "/tmp/ipykernel_511566/2694172863.py", line 451, in process_gene
  File "/tmp/ipykernel_511566/2694172863.py", line 393, in __compare_groups
  File "/tmp/ipykernel_511566/2694172863.py", line 359, in __LRT_test
  File "/tmp/ipykernel_511566/2694172863.py", line 340, in __perform_mle
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/infer/svi.py", line 145, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py", line 97, in __exit__
    return super()

Comparing Radial glia vs mature GABAergic...
         geneId           transcriptId
143       H2afv   ENSMUST00000109737.8
1284   Pafah1b3  ENSMUST00000005583.11
1324      Eif4h   ENSMUST00000202622.3
1394      Dtymk   ENSMUST00000112890.2
1858      Mtch1  ENSMUST00000095427.10
...         ...                    ...
19245    Ube2v2   ENSMUST00000115777.9
19446     Apex1  ENSMUST00000049411.11
19975      Acp1  ENSMUST00000062740.14
20441    Pantr1   ENSMUST00000181725.7
20449     Ppdpf  ENSMUST00000016488.12

[74 rows x 2 columns]


Exception in thread Thread-107666:
Traceback (most recent call last):
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/threading.py", line 980, in _bootstrap_inner
    self.run()
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/threading.py", line 917, in run
    self._target(*self._args, **self._kwargs)
  File "/tmp/ipykernel_511566/2694172863.py", line 471, in add_res
  File "/tmp/ipykernel_511566/2694172863.py", line 451, in process_gene
  File "/tmp/ipykernel_511566/2694172863.py", line 393, in __compare_groups
  File "/tmp/ipykernel_511566/2694172863.py", line 359, in __LRT_test
  File "/tmp/ipykernel_511566/2694172863.py", line 340, in __perform_mle
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/infer/svi.py", line 145, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py", line 97, in __exit__
    return super(

  File "/tmp/ipykernel_511566/2694172863.py", line 359, in __LRT_test
  File "/tmp/ipykernel_511566/2694172863.py", line 451, in process_gene
  File "/tmp/ipykernel_511566/2694172863.py", line 340, in __perform_mle
  File "/tmp/ipykernel_511566/2694172863.py", line 393, in __compare_groups
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/infer/svi.py", line 145, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/infer/trace_elbo.py", line 140, in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/infer/elbo.py", line 236, in _get_traces
  File "/tmp/ipykernel_511566/2694172863.py", line 359, in __LRT_test
    yield self._get_trace(model, guide, args, kwargs)
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/infer/t

Comparing Radial glia vs Imature GABAergic...
          geneId           transcriptId
114       Hnrnpk   ENSMUST00000176207.7
385       Mrps21   ENSMUST00000067298.4
969         Nnat   ENSMUST00000109526.1
1073   Hnrnpa2b1   ENSMUST00000114459.7
1324       Eif4h   ENSMUST00000202622.3
...          ...                    ...
19820     Nap1l4   ENSMUST00000072727.6
19975       Acp1  ENSMUST00000062740.14
20157        Bax   ENSMUST00000033093.9
20396       Nxf1   ENSMUST00000010248.3
20441     Pantr1   ENSMUST00000181725.7

[114 rows x 2 columns]


Exception in thread Thread-116667:
Traceback (most recent call last):
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/threading.py", line 980, in _bootstrap_inner
    self.run()
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/threading.py", line 917, in run
    self._target(*self._args, **self._kwargs)
  File "/tmp/ipykernel_511566/2694172863.py", line 471, in add_res
  File "/tmp/ipykernel_511566/2694172863.py", line 451, in process_gene
  File "/tmp/ipykernel_511566/2694172863.py", line 393, in __compare_groups
  File "/tmp/ipykernel_511566/2694172863.py", line 359, in __LRT_test
  File "/tmp/ipykernel_511566/2694172863.py", line 340, in __perform_mle
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/infer/svi.py", line 145, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py", line 97, in __exit__
    return super(

Exception in thread Thread-122331:
Traceback (most recent call last):
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/tmp/ipykernel_511566/2694172863.py", line 327, in __model
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/subsample_messenger.py", line 88, in __init__
    self.size, self.subsample_size, self._indices = self._subsample(
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/subsample_messenger.py", line 126, in _subsample
    apply_stack(msg)
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/runtime.py", line 221, in apply_stack
    frame._postprocess_message(msg)
  File "/home/diamant/.conda/envs/iso_s

Comparing Radial glia vs Cajal Retzius...
         geneId           transcriptId
13         Rbm3   ENSMUST00000115616.7
267      Cacnb3   ENSMUST00000230490.1
385      Mrps21   ENSMUST00000067298.4
1284   Pafah1b3  ENSMUST00000005583.11
1548      Cdc42   ENSMUST00000030417.9
1790     Dnaja1   ENSMUST00000030118.9
1858      Mtch1  ENSMUST00000095427.10
3154     Polr2i  ENSMUST00000019882.15
3582        Ogt   ENSMUST00000155792.1
3705     Dnaja1   ENSMUST00000164233.7
3873       Lsm6   ENSMUST00000051867.6
4539       Ly6h   ENSMUST00000127095.7
4614       Fis1   ENSMUST00000019198.6
5143      Actr3   ENSMUST00000178474.7
5563      Mtch1   ENSMUST00000118366.7
5590      Psma3   ENSMUST00000160027.7
5786      Ube2b  ENSMUST00000020657.12
6184   Pafah1b3   ENSMUST00000148150.7
6473      Psmd4   ENSMUST00000107237.7
6723     Ppp1cc  ENSMUST00000102528.10
7209     Serbp1   ENSMUST00000203077.2
7321      U2af1   ENSMUST00000014684.4
7639        Fau   ENSMUST00000178310.7
8546      Rufy3  ENSMU

Exception in thread Thread-201363:
Traceback (most recent call last):
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/threading.py", line 980, in _bootstrap_inner
    self.run()
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/threading.py", line 917, in run
    self._target(*self._args, **self._kwargs)
  File "/tmp/ipykernel_511566/2694172863.py", line 471, in add_res
  File "/tmp/ipykernel_511566/2694172863.py", line 451, in process_gene
  File "/tmp/ipykernel_511566/2694172863.py", line 393, in __compare_groups
  File "/tmp/ipykernel_511566/2694172863.py", line 359, in __LRT_test
  File "/tmp/ipykernel_511566/2694172863.py", line 340, in __perform_mle
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/infer/svi.py", line 145, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py", line 97, in __exit__
    return super(

    model_trace, guide_trace = get_importance_trace(
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/infer/enum.py", line 75, in get_importance_trace
    model_trace.compute_log_prob()
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/trace_struct.py", line 236, in compute_log_prob
    raise ValueError(
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/trace_struct.py", line 230, in compute_log_prob
    log_p = site["fn"].log_prob(
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/distributions/conjugate.py", line 208, in log_prob
    self._validate_sample(value)
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/torch/distributions/distribution.py", line 281, in _validate_sample
    raise ValueError('The right-most size of value must match event_shape: {} vs {}.'.
ValueError: Error while computing log_prob at site 'obs':
The right-most size of value 

Comparing Imature glutamatergic vs Imature GABAergic...
        geneId           transcriptId
114     Hnrnpk   ENSMUST00000176207.7
969       Nnat   ENSMUST00000109526.1
1198      Lsm7   ENSMUST00000035775.8
1517    Zfp428   ENSMUST00000177205.1
1548     Cdc42   ENSMUST00000030417.9
1600      Nasp   ENSMUST00000154811.7
1685    Mrps12   ENSMUST00000056078.8
1819      Nasp  ENSMUST00000030457.11
1858     Mtch1  ENSMUST00000095427.10
2527     Rps28   ENSMUST00000173844.7
2663      Myl6   ENSMUST00000218127.1
3139      Cfl2   ENSMUST00000078124.7
3240     Crmp1  ENSMUST00000031004.10
3723     Mat2a   ENSMUST00000059472.9
3911      Rer1   ENSMUST00000030914.3
4324      Rtn4   ENSMUST00000102843.9
4567    Ube2d3   ENSMUST00000197859.4
5563     Mtch1   ENSMUST00000118366.7
6167      Tecr  ENSMUST00000019382.16
6770    Mrpl51   ENSMUST00000032485.6
7639       Fau   ENSMUST00000178310.7
7912     Pdcd6   ENSMUST00000022060.6
8039     Tra2b   ENSMUST00000161286.7
8353    Hnrnpc   ENSMUST00000227

Exception in thread Thread-211950:
Traceback (most recent call last):
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/tmp/ipykernel_511566/2694172863.py", line 327, in __model
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/subsample_messenger.py", line 88, in __init__
    self.size, self.subsample_size, self._indices = self._subsample(
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/subsample_messenger.py", line 126, in _subsample
    apply_stack(msg)
  File "/home/diamant/.conda/envs/iso_swt/lib/python3.9/site-packages/pyro/poutine/runtime.py", line 221, in apply_stack
    frame._postprocess_message(msg)
  File "/home/diamant/.conda/envs/iso_s

In [None]:
results