<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

In [None]:
class GeneReports(TranscriptPlots, AnnDataIso):
    def __init__(self, anndata: ad.AnnData, cell_types: pd.DataFrame, palette='ghibli', gtf_file=None, reference_fasta=None):
        """
        Initializes both parent classes explicitly.

        Parameters:
        - anndata (ad.AnnData): Forwarded to AnnDataIso.__init__.
        - cell_types (pd.DataFrame): Forwarded to AnnDataIso.__init__.
        - palette: Forwarded to AnnDataIso.__init__ (default 'ghibli').
        - gtf_file: Forwarded to TranscriptPlots.__init__ (default None).
        - reference_fasta: Forwarded to TranscriptPlots.__init__ (default None).
        """
        # Each base is initialised by a direct call (not super()); the two
        # parents are called with different argument lists.
        AnnDataIso.__init__(self,anndata, cell_types, palette)
        TranscriptPlots.__init__(self,gtf_file, reference_fasta)

    # Plotting a summary of isoform-related metrics.
    def plot_isoforms_summary(self):
        """
        Builds a three-panel overview laid out side by side.

        Panels, left to right:
        1. Stacked bar drawn by `_plot_switch_gen_bar` (multiple vs. single isoform genes).
        2. Bar chart drawn by `_plot_isoforms_frequencies` (isoforms per gene).
        3. Boxplot drawn by `_plot_genes_cell_type` (per cell type).

        Returns:
        - The result of calling `savefig()` on the combined `pw` layout.
          No file name is passed to `savefig()`.
        """
        # (figure size, drawing routine, panel title) for each panel, in layout order.
        panel_specs = (
            ((4, 4), self._plot_switch_gen_bar, "Multiple isoforms genes %"),
            ((4, 4), self._plot_isoforms_frequencies, "Frequency of isoforms per gene"),
            ((3, 2), self._plot_genes_cell_type, "Nb of genes per cell type"),
        )
        bricks = []
        for size, draw_panel, title in panel_specs:
            brick = pw.Brick(figsize=size)
            draw_panel(brick)
            # The title is set after drawing, as the drawing routines do not set one.
            brick.set_title(title)
            bricks.append(brick)
        left, middle, right = bricks
        return (left | middle | right).savefig()

    # Private method to create a boxplot of genes expressed per cell type.
    def _plot_genes_cell_type(self, _ax):
        """
        Draws a boxplot, with the individual points overlaid, grouped by cell type.

        `self.X` is transposed so that every column is one observation labelled
        with its `cell_type`; each column is then summed, giving one value per
        observation. These values are plotted on the y axis as 'n_of_genes'.

        NOTE(review): the plotted value is the column sum of `self.X`, not a
        count of non-zero entries - confirm this matches the intended meaning
        of "number of genes".

        Parameters:
        - _ax (Axes or None): Axes to render the plot on. When None, the current
          pyplot axes is used and the figure is shown with `plt.show()`.
        """
        standalone = _ax is None

        per_observation = pd.DataFrame(np.transpose(self.X), columns=self.obs['cell_type'])
        per_observation = per_observation.sum(axis=0).to_frame().reset_index()
        per_observation.columns = ['cell_type', 'n_of_genes']

        # seaborn falls back to the current axes when none is given; resolving it
        # here lets one code path serve both the standalone and embedded cases.
        ax = plt.gca() if standalone else _ax
        sns.boxplot(x='cell_type', y='n_of_genes', data=per_observation, ax=ax)
        sns.stripplot(x='cell_type', y='n_of_genes', data=per_observation, color='black', size=3, ax=ax)
        # Rotate the labels without re-assigning them.
        ax.tick_params(axis='x', labelrotation=90)

        if standalone:
            plt.show()

    def plot_genes_cell_type(self):
        """
        Public method to plot the number of genes expressed per cell type.

        Delegates to `_plot_genes_cell_type` with no axes (None), i.e. the
        standalone mode of that helper. Takes no arguments and returns None.
        """
        self._plot_genes_cell_type(None)

    # Method to visualize isoform frequency distribution.
    def _plot_isoforms_frequencies(self, _ax):
        """
        Draws a bar chart of how often each value of
        `self.gene_counts['transcriptId']` occurs.

        Parameters:
        - _ax (Axes or None): Axes to render the plot on. When None, a new
          figure and axes are created with `plt.subplots()`.
        """
        if _ax is None:
            _, target = plt.subplots()
        else:
            target = _ax
        # value_counts() yields one bar per distinct isoform count.
        frequencies = self.gene_counts['transcriptId'].value_counts()
        frequencies.plot(
            ax=target,
            kind='bar',
            xlabel='number of isoforms per gene',
            ylabel='quantity of genes',
        )

    def plot_isoforms_frequencies(self):
        """
        Public method to plot the frequency distribution of isoforms per gene.

        Delegates to `_plot_isoforms_frequencies` with no axes (None), i.e. the
        standalone mode of that helper. Takes no arguments and returns None.
        """
        self._plot_isoforms_frequencies(None)

    def plot_switch_gen_bar(self, _ax=None):
        """
        Public method to plot the multiple- vs. single-isoform gene bar chart.

        Parameters:
        - _ax (Axes or None): Optional axes to draw on; forwarded to
          `_plot_switch_gen_bar`. Defaults to None (standalone mode of that
          helper), so the method can be called without arguments like the
          other public plotting methods.
        """
        # The argument used to be required and then silently ignored.
        self._plot_switch_gen_bar(_ax)

    def _plot_switch_gen_bar(self, _ax):
        """
        Draws a stacked 'genes' bar (multiple-isoform genes at the bottom,
        single-isoform genes on top) next to a 'transcripts' bar whose height is
        `len(self.var['transcriptId'])`.

        A gene counts as multi-isoform when its `self.gene_counts['transcriptId']`
        value is greater than 1; every other gene counts as single-isoform.
        Each stacked segment is annotated with its share of all genes, rounded
        to one decimal.

        Parameters:
        - _ax (Axes or None): Axes to render the plot on. When None, a new
          figure is created and shown with `plt.show()`. When an axes is given,
          nothing is shown, so the caller stays in control of rendering.
        """
        has_multiple = self.gene_counts['transcriptId'] > 1
        multiple_iso = int(has_multiple.sum())
        mono_iso = len(has_multiple) - multiple_iso
        total = multiple_iso + mono_iso

        def percent(part):
            # Share of all genes with one decimal; an empty gene table gives
            # '0.0%' instead of raising ZeroDivisionError.
            share = round(1000 * part / total) / 10 if total else 0.0
            return str(share) + '%'

        multiple_label = percent(multiple_iso)
        mono_label = percent(mono_iso)

        standalone = _ax is None
        if standalone:
            _, ax = plt.subplots()
        else:
            ax = _ax

        genes_x = ['genes']
        multiple_bars = ax.bar(genes_x, multiple_iso, label=multiple_label, color=self.colors[0])
        # The single-isoform segment gets its own colour so the stack is readable.
        mono_bars = ax.bar(genes_x, mono_iso, bottom=multiple_iso, label=mono_label, color=self.colors[1])
        ax.bar(['transcripts'], len(self.var['transcriptId']), color=self.colors[2])

        # Annotate the rectangles that were just drawn rather than ax.patches[0/1],
        # which may belong to something else on a caller-supplied axes.
        multiple_rect = multiple_bars[0]
        mono_rect = mono_bars[0]
        ax.text(
            multiple_rect.get_x() + multiple_rect.get_width() / 2,
            multiple_rect.get_height() / 2,
            multiple_label, ha="center", va="center"
        )
        ax.text(
            mono_rect.get_x() + mono_rect.get_width() / 2,
            mono_rect.get_height() / 2 + multiple_rect.get_height(),
            mono_label, ha="center", va="center"
        )
        ax.legend([multiple_bars, mono_bars], ['Multiple isoforms', 'Single isoform'])

        if standalone:
            plt.show()

    def _trsct_counts_cell_type(self, gene_name, trs_to_show, _ax):
        """
        Draws one violin plot per transcript of a gene, showing the counts in
        `self._filtered_anndata` split by cell type.

        Parameters:
        - gene_name: Value matched against the 'geneId' column of `var`.
        - trs_to_show (list): Transcript ids to keep; an empty list keeps all
          transcripts of the gene.
        - _ax: Only tested against None here; the seaborn grid creates its own
          figure. When None the figure is shown; otherwise the grid is returned.

        Returns:
        - None when `_ax` is None, otherwise the seaborn grid returned by
          `sns.catplot`.
        """
        adata = self._filtered_anndata

        # Rows become transcripts; columns are labelled with each observation's cell type.
        counts = adata.to_df().set_index(adata.obs['cell_type']).transpose()
        counts[['transcriptId', 'geneId']] = adata.var[['transcriptId', 'geneId']]

        # Keep the requested gene, then turn the table back to one row per observation.
        of_gene = counts[counts['geneId'] == gene_name]
        wide = of_gene.drop('geneId', axis=1).set_index('transcriptId').transpose()
        long_form = wide.reset_index().melt(id_vars='cell_type', var_name='transcriptId', value_name='count')
        if trs_to_show != []:
            keep = long_form['transcriptId'].isin(trs_to_show)
            long_form = long_form[keep]

        grid = sns.catplot(x="cell_type", y="count", col="transcriptId", aspect=1, dodge=False, kind="violin", data=long_form, palette=self.colors)
        # Each facet is titled with its transcript id only.
        grid.set_titles(col_template="{col_name}", size = 8)

        standalone = _ax is None
        # Rotate the x tick labels; in embedded mode the label texts are supplied explicitly.
        if standalone:
            grid.set_xticklabels(rotation=90)
        else:
            grid.set_xticklabels(rotation=90, labels=self.obs['cell_type'].unique())
        grid.fig.suptitle(gene_name)

        if standalone:
            plt.show()
            return None
        return grid

    def trsct_counts_cell_type(self, gene_name, trs_to_show=None):
        """
        Public method to plot, per transcript of a gene, the counts by cell type.

        Parameters:
        - gene_name: Gene whose transcripts are plotted.
        - trs_to_show (list or None): Transcript ids to restrict the plot to.
          None or an empty list shows every transcript of the gene.
        """
        # None replaces the former shared mutable `[]` default; the helper
        # compares against `[]`, so normalise before delegating.
        selected = [] if trs_to_show is None else trs_to_show
        self._trsct_counts_cell_type(gene_name, selected, None)

    def _plot_transcripts_per_cell_type(self, gene_name, trs_to_show, _ax):
        """
        Draws a stacked horizontal bar chart with, for every cell type, the mean
        of `obsm['Iso_prct']` for each selected transcript of a gene.

        Parameters:
        - gene_name: Value matched against the 'geneId' column of `var`.
        - trs_to_show (list or None): Transcript ids to plot; None or an empty
          sequence selects every transcript of the gene.
        - _ax (Axes or None): Axes to render the plot on. When None, a new
          figure is created, the legend is relabelled with the names returned
          by `get_transcripts_common_names`, and the y axis is labelled.
        """
        if trs_to_show is None or len(trs_to_show) == 0:
            transcripts_id = self._get_transcripts_from_gene(gene_name)
        else:
            transcripts_id = trs_to_show

        standalone = _ax is None
        if standalone:
            _, ax = plt.subplots()
        else:
            ax = _ax

        adata = self._filtered_anndata
        # assign() returns a new frame, so the table stored in obsm is not
        # modified by adding the grouping column.
        grouped = adata.obsm['Iso_prct'].assign(cell_type=self.obs['cell_type'])
        res = grouped.groupby('cell_type').mean().transpose()
        res = res.assign(transcriptId=adata.var['transcriptId'].to_list())
        res = res.assign(geneId=adata.var['geneId'].to_list())
        res = res[res['geneId'] == gene_name].drop(['geneId'], axis=1)
        res = res[res['transcriptId'].isin(transcripts_id)]
        plot_data = res.set_index('transcriptId').transpose()
        plot_data.plot(kind='barh', stacked=True, ax=ax, color=self.colors).legend(loc='center left',bbox_to_anchor=(1.0, 1.0))

        if standalone:
            # Legend entries follow the plotted column order, which is the order
            # in `var`, not the order of `trs_to_show`; take the names from the
            # plotted columns so every label matches its bar segment.
            plotted_ids = plot_data.columns.to_list()
            ax.legend(self.get_transcripts_common_names(plotted_ids), loc="upper left", bbox_to_anchor=(1, 1))
            ax.set_ylabel('Cell type')

    def plot_transcripts_per_cell_type(self, gene_name, trs_to_show=None):
        """
        Public method to plot the per-cell-type transcript breakdown of a gene.

        Parameters:
        - gene_name: Gene whose transcripts are plotted.
        - trs_to_show (list or None): Transcript ids to restrict the plot to.
          None or an empty list shows every transcript of the gene.
        """
        # None replaces the former shared mutable `[]` default; the helper
        # compares against `[]`, so normalise before delegating.
        selected = [] if trs_to_show is None else trs_to_show
        self._plot_transcripts_per_cell_type(gene_name, selected, None)

    def _get_transcripts_from_gene(self, gene_name):
        """
        Lists the transcript ids recorded for a gene.

        Parameters:
        - gene_name: Value matched against the 'geneId' column of
          `self._filtered_anndata.var`.

        Returns:
        - list: The 'transcriptId' values of the matching rows, in row order.
        """
        var_table = self._filtered_anndata.var
        belongs_to_gene = var_table['geneId'] == gene_name
        return var_table.loc[belongs_to_gene, 'transcriptId'].to_list()

    def _draw_transcripts_list_from_gene(self, gene_name, trs_to_show, _ax, colors=None):
        """
        Draws the selected transcripts of a gene one below the other, all on a
        shared coordinate range.

        Parameters:
        - gene_name: Gene whose transcripts are drawn when `trs_to_show` is empty.
        - trs_to_show (list): Transcript ids to draw; an empty list selects every
          transcript returned by `_get_transcripts_from_gene`.
        - _ax: Only tested against None here; drawing goes through pyplot's
          current figure. When None the figure is shown.
        - colors (list or None): One colour per transcript. When None, colours
          are taken from `self.colors`, cycling if there are more transcripts
          than colours.

        Returns:
        - None when `_ax` is None, otherwise the `plt` module.
        """
        transcripts_id = self._get_transcripts_from_gene(gene_name) if trs_to_show == [] else trs_to_show

        # Coordinates and direction of every transcript, in drawing order.
        exons = []
        directions = []
        for transcript in transcripts_id:
            coords, strand = self._get_coord_from_tscrpt_id(transcript)
            exons.append(coords)
            directions.append(strand)

        if colors is None:
            colors = [self.colors[k % len(self.colors)] for k in range(len(exons))]

        # Blank canvas: no axis decorations, one 0.5-high row per transcript.
        plt.axes()
        plt.xlim((-0.1, 1.1))
        plt.ylim((0.1 - 0.5 *  len(exons), 0.3))
        plt.margins(0.2)
        plt.axis('off')
        canvas = plt.gcf()
        canvas.set_size_inches(20, len(exons) * 2)

        # Common range across all transcripts; which end of the coordinate list
        # is read depends on the direction flag (1 vs. anything else).
        start = sys.maxsize
        end = -sys.maxsize
        for coords, strand in zip(exons, directions):
            if strand == 1:
                low_src, high_src = coords[0], coords[-1]
            else:
                low_src, high_src = coords[-1], coords[0]
            start = min(start, low_src[1])
            end = max(end, high_src[0])

        for row, (coords, strand, colour, name) in enumerate(zip(exons, directions, colors, transcripts_id)):
            # Each transcript is shifted down by half a unit per row.
            self._draw_transcript(coords, strand, colour, name, offset=-0.5 * row, start_override=start, end_override=end, no_render=True)

        if _ax is None:
            plt.show()
            return None
        return plt

    def draw_transcripts_list_from_gene(self, gene_name, colors=None):
        """
        Public method to draw every transcript of a gene.

        Delegates to `_draw_transcripts_list_from_gene` with an empty transcript
        selection (all transcripts of the gene) and no axes (None).

        Parameters:
        - gene_name: Gene whose transcripts are drawn.
        - colors (list or None): Forwarded unchanged to the helper.
        """
        self._draw_transcripts_list_from_gene(gene_name, [], None, colors)

    def draw_gene_summary(self, gene_name, trs_to_show=None):
        """
        Builds a combined per-gene figure with three stacked parts:
        1. The per-transcript count violins from `_trsct_counts_cell_type`.
        2. The per-cell-type stacked bars from `_plot_transcripts_per_cell_type`.
        3. The transcript drawings from `_draw_transcripts_list_from_gene`.

        Parameters:
        - gene_name: Gene to summarise.
        - trs_to_show (list or None): Transcript ids to restrict the summary to.
          None or an empty list uses every transcript of the gene.

        Returns:
        - The result of calling `savefig()` on the combined `pw` layout.
          No file name is passed to `savefig()`.
        """
        # None replaces the former shared mutable `[]` default; the helpers
        # compare against `[]`, so normalise once here.
        selected = [] if trs_to_show is None else trs_to_show

        ax1 = pw.Brick(figsize=(12,4))
        pw.overwrite_axisgrid()
        fg = self._trsct_counts_cell_type(gene_name, selected, ax1)
        ax4 = pw.load_seaborngrid(fg)
        # NOTE(review): ax1 is not part of the returned layout (ax4 is), so this
        # title is not visible in the result - confirm whether it is still needed.
        ax1.set_title("Transcripts count per cell type")

        ax2 = pw.Brick(figsize=(12,4))
        self._plot_transcripts_per_cell_type(gene_name, selected, ax2)
        # This panel shows the per-cell-type mean of obsm['Iso_prct'], not raw
        # counts, so it must not reuse the count panel's title.
        ax2.set_title("Transcripts percentage per cell type")

        ax3 = pw.Brick(figsize=(12,4))
        self._draw_transcripts_list_from_gene(gene_name, selected, ax3)
        ax3.set_title("Transcripts list")
        return (ax4/(ax2/ax3)).savefig()

    # More detailed functions omitted for brevity.