In [None]:
from swissisoform.genome import GenomeHandler
from swissisoform.visualize import GenomeVisualizer
import matplotlib.pyplot as plt

# Initialize the genome handler
genome = GenomeHandler(
   '../data/genome_data/hg38.fa',
   '../data/genome_data/hg38.ncbiRefSeq.gtf'
)

# Initialize the visualizer
visualizer = GenomeVisualizer(genome)

# Get NAXE features
naxe_features = genome.find_gene_features('NAXE')
print("NAXE features:")
print(naxe_features)

# Get gene statistics
stats = genome.get_gene_stats('NAXE')
print("\nGene statistics:")
print(stats)

# Get available transcript IDs for NAXE
transcript_info = genome.get_transcript_ids('NAXE')
print("Available transcripts:")
print(transcript_info)

# Select a transcript ID (for example, the last one)
transcript_id = transcript_info.iloc[-1]['transcript_id']

NAXE features:
                  chromosome                 source feature_type      start  \
241841   chr1_MU273335v1_fix  ncbiRefSeq.2022-10-28   transcript     150333   
241842   chr1_MU273335v1_fix  ncbiRefSeq.2022-10-28         exon     150333   
241843   chr1_MU273335v1_fix  ncbiRefSeq.2022-10-28         5UTR     150333   
241844   chr1_MU273335v1_fix  ncbiRefSeq.2022-10-28          CDS     150362   
241845   chr1_MU273335v1_fix  ncbiRefSeq.2022-10-28         exon     150658   
241846   chr1_MU273335v1_fix  ncbiRefSeq.2022-10-28          CDS     150658   
241847   chr1_MU273335v1_fix  ncbiRefSeq.2022-10-28         exon     150922   
241848   chr1_MU273335v1_fix  ncbiRefSeq.2022-10-28          CDS     150922   
241849   chr1_MU273335v1_fix  ncbiRefSeq.2022-10-28         exon     151114   
241850   chr1_MU273335v1_fix  ncbiRefSeq.2022-10-28          CDS     151114   
241851   chr1_MU273335v1_fix  ncbiRefSeq.2022-10-28         exon     151965   
241852   chr1_MU273335v1_fix  ncbiRef

In [71]:
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.patches import Rectangle
from matplotlib.lines import Line2D

class GenomeVisualizer:
    def __init__(self, genome_handler):
        self.genome = genome_handler
        self.feature_colors = {
            'exon': '#4CAF50',      # green
            'CDS': '#2196F3',       # blue
            'UTR': '#FFA500',       # orange
            'start_codon': '#FF0000', # red
            'stop_codon': '#800080',  # purple
            'truncation': '#FF1493',  # deep pink
            'alternative_start': '#FFD700'  # yellow
        }
        self.track_height = 0.15
        self.codon_height = 0.2

    def visualize_transcript(self, gene_name, transcript_id, alt_features=None, output_file=None):
        """
        Visualize transcript with optional alternative start sites.
        
        Args:
            gene_name (str): Name of the gene
            transcript_id (str): Transcript ID to visualize
            alt_features (pd.DataFrame, optional): Alternative start features
            output_file (str, optional): Path to save the visualization
        """
        transcript_data = self.genome.get_transcript_features_with_sequence(transcript_id)
        features = transcript_data['features']
        
        if features.empty:
            raise ValueError(f"No features found for transcript {transcript_id}")

        # Get transcript bounds
        transcript_info = features[features['feature_type'] == 'transcript'].iloc[0]
        transcript_start = transcript_info['start']
        transcript_end = transcript_info['end']
        span = transcript_end - transcript_start
        
        # Create figure
        fig, ax = plt.subplots(figsize=(15, 3))
        
        # Draw base transcript line (thicker)
        ax.hlines(y=0.4, xmin=transcript_start, xmax=transcript_end,
                 color='black', linewidth=1.5)

        # Plot regular transcript features
        for _, feature in features.iterrows():
            width = feature['end'] - feature['start']
            
            if feature['feature_type'] == 'exon':
                rect = Rectangle((feature['start'], 0.325),
                               width,
                               self.track_height,
                               facecolor=self.feature_colors['exon'],
                               alpha=0.3)
                ax.add_patch(rect)
            
            elif feature['feature_type'] == 'CDS':
                rect = Rectangle((feature['start'], 0.325),
                               width,
                               self.track_height,
                               facecolor=self.feature_colors['CDS'])
                ax.add_patch(rect)
            
            elif feature['feature_type'] in ['5UTR', '3UTR']:
                rect = Rectangle((feature['start'], 0.6),
                               width,
                               self.track_height,
                               facecolor=self.feature_colors['UTR'])
                ax.add_patch(rect)
            
            elif feature['feature_type'] in ['start_codon', 'stop_codon']:
                rect = Rectangle((feature['start'], 0.3),
                               width,
                               self.codon_height,
                               facecolor=self.feature_colors[feature['feature_type']])
                ax.add_patch(rect)

        # Plot alternative start sites if provided
        if alt_features is not None and not alt_features.empty:
            for _, alt_feature in alt_features.iterrows():
                alt_start = alt_feature['start']
                
                # Draw truncation bracket
                alt_end = alt_feature['end']
                bracket_height = 0.05  # Height of the vertical parts of bracket
                bracket_y = 0.2  # Base y-position of the bracket
                
                # Draw horizontal line of bracket
                ax.hlines(y=bracket_y, 
                         xmin=alt_start-0.5,
                         xmax=alt_end+0.5,
                         color=self.feature_colors['truncation'],
                         linewidth=1.5,
                         label='Truncation')
                
                # Draw vertical parts of bracket
                ax.vlines(x=[alt_start, alt_end-0.5],
                         ymin=bracket_y,
                         ymax=bracket_y + bracket_height,
                         color=self.feature_colors['truncation'],
                         linewidth=0.5)
                
                # Draw alternative start marker (vertical yellow line)
                # Calculate vertical position for alternative start marker
                black_bar_y = 0.4  # y-position of the black bar
                alt_start_height = self.codon_height  # match codon height
                alt_start_ymin = black_bar_y - (alt_start_height / 2)  # center around black bar
                alt_start_ymax = black_bar_y + (alt_start_height / 2)
                
                ax.vlines(x=alt_end,
                         ymin=alt_start_ymin,
                         ymax=alt_start_ymax,
                         color=self.feature_colors['alternative_start'],
                         linewidth=2,
                         label='Alternative start')
                
                # Add start codon label if available (now above the line)
                if 'start_codon' in alt_feature:
                    plt.text(alt_end - span*0.01, alt_start_ymax + 0.05, 
                            alt_feature['start_codon'],
                            fontsize=8,
                            rotation=45)

        # Center title
        plt.title(f"{gene_name} - {transcript_id}", pad=20, y=1.05)

        # Customize plot
        ax.set_ylim(0, 1)
        ax.set_xlim(transcript_start - 10, transcript_end + 10)
        
        # Calculate dynamic tick interval based on span
        tick_span = transcript_end - transcript_start
        if tick_span > 100000:
            tick_interval = 10000
        elif tick_span > 50000:
            tick_interval = 5000
        elif tick_span > 10000:
            tick_interval = 1000
        else:
            tick_interval = 500

        # Calculate tick positions
        base_position = transcript_start - (transcript_start % tick_interval)
        tick_positions = range(base_position, transcript_end + tick_interval, tick_interval)
        
        plt.xticks(tick_positions, [f"{pos:,}" for pos in tick_positions])
        plt.xticks(rotation=45)
        
        # Remove y-axis
        ax.set_yticks([])

        # Add legend without frame
        legend_elements = [
            plt.Rectangle((0,0), 1, 1, facecolor=self.feature_colors['exon'], 
                         alpha=0.3, label='Exon'),
            plt.Rectangle((0,0), 1, 1, facecolor=self.feature_colors['CDS'], 
                         label='CDS'),
            plt.Rectangle((0,0), 1, 1, facecolor=self.feature_colors['UTR'], 
                         label='UTR'),
            plt.Rectangle((0,0), 1, 1, facecolor=self.feature_colors['start_codon'], 
                         label='Start codon'),
            plt.Rectangle((0,0), 1, 1, facecolor=self.feature_colors['stop_codon'], 
                         label='Stop codon')
        ]
        
        # Add alternative start elements to legend if relevant
        if alt_features is not None and not alt_features.empty:
            legend_elements.extend([
                Line2D([0], [0], color=self.feature_colors['truncation'], 
                      label='Truncation', linewidth=1.5),
                Line2D([0], [0], color=self.feature_colors['alternative_start'], 
                      label='Alternative start', linewidth=2)
            ])

        ax.legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), 
                 loc='upper left', borderaxespad=0., frameon=False)

        plt.tight_layout()
        
        # Remove spines        
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_visible(False)

        if output_file:
            plt.savefig(output_file, bbox_inches='tight', dpi=300,
                       facecolor='white', edgecolor='none')
            plt.close()
        else:
            plt.show()

In [72]:
import pandas as pd
from typing import Optional, Dict, List

class AlternativeIsoform:
    """
    A class to handle alternative isoform data from BED format files.
    The name field in the BED file is expected to contain gene information in the format:
    ENSG00000260916.6_CCPG1_AUG_TruncationToAnno
    """
    def __init__(self):
        """Initialize an empty isoform handler"""
        self.isoforms = pd.DataFrame()
        
    def load_bed(self, file_path: str) -> None:
        """
        Load data from BED format file.
        
        Args:
            file_path (str): Path to the BED file
        """
        # Read BED format with standard columns
        self.isoforms = pd.read_csv(file_path, sep='\t',
                                  names=['chrom', 'start', 'end', 'name', 'score', 'strand'])
        
        # Parse the name field which contains gene information
        def parse_name(name: str) -> Dict[str, str]:
            parts = name.split('_')
            return {
                'gene_id': parts[0],
                'gene_name': parts[1],
                'start_codon': parts[2],
                'isoform_type': parts[3]
            }
        
        # Parse name field and add as new columns
        info_df = pd.DataFrame(self.isoforms['name'].apply(parse_name).tolist())
        self.isoforms = pd.concat([self.isoforms, info_df], axis=1)
        
        # Add columns needed for visualization
        self.isoforms['feature_type'] = 'alternative_start'
        self.isoforms['source'] = 'truncation'

    def get_visualization_features(self, gene_name: str) -> pd.DataFrame:
        """
        Get features formatted for the GenomeVisualizer.
        
        Args:
            gene_name (str): Name of the gene to get features for
            
        Returns:
            pd.DataFrame: Features formatted for visualization
        """
        if self.isoforms.empty:
            raise ValueError("No data loaded. Please load data first.")
            
        features = self.isoforms[self.isoforms['gene_name'] == gene_name].copy()
        
        if features.empty:
            return pd.DataFrame()
            
        # Format features for visualization to match GenomeVisualizer expected format
        viz_features = pd.DataFrame({
            'chromosome': features['chrom'],
            'source': 'truncation',
            'feature_type': 'alternative_start',
            'start': features['start'],
            'end': features['end'],
            'score': features['score'],
            'strand': features['strand'],
            'frame': '.',
            'gene_id': features['gene_id'],
            'transcript_id': features['gene_id'] + '_alt',
            'gene_name': features['gene_name'],
            'start_codon': features['start_codon']
        })
        
        return viz_features

    def get_gene_list(self) -> List[str]:
        """
        Get list of all genes in the dataset.
        
        Returns:
            List[str]: List of gene names
        """
        if self.isoforms.empty:
            raise ValueError("No data loaded. Please load data first.")
            
        return sorted(self.isoforms['gene_name'].unique().tolist())

    def get_stats(self) -> Dict:
        """
        Get basic statistics about the loaded isoforms.
        
        Returns:
            Dict: Dictionary containing various statistics
        """
        if self.isoforms.empty:
            raise ValueError("No data loaded. Please load data first.")
            
        stats = {
            'total_sites': len(self.isoforms),
            'unique_genes': len(self.get_gene_list()),
            'chromosomes': sorted(self.isoforms['chrom'].unique().tolist()),
            'start_codons': sorted(self.isoforms['start_codon'].unique().tolist())
        }
        
        return stats

    def get_alternative_starts(self, gene_name: Optional[str] = None) -> pd.DataFrame:
        """
        Get alternative start sites, optionally filtered by gene name.
        
        Args:
            gene_name (str, optional): Gene name to filter by
            
        Returns:
            pd.DataFrame: Alternative start sites with positions and metadata
        """
        if self.isoforms.empty:
            raise ValueError("No data loaded. Please load data first.")
            
        if gene_name is not None:
            return self.isoforms[self.isoforms['gene_name'] == gene_name]
        return self.isoforms.copy()

In [74]:
visualizer = GenomeVisualizer(genome)

In [75]:
# Initialize handlers
alt_isoforms = AlternativeIsoform()

# Load BED file
alt_isoforms.load_bed('../data/ribosome_profiling/RiboTISHV6_Ly2024_AnnoToTruncation_exonintersect.bed')

In [76]:
alt_features = alt_isoforms.get_visualization_features('NAXE')

In [77]:
visualizer.visualize_transcript('NAXE', transcript_id, alt_features=None, output_file=f'naxe_transcript_{transcript_id}.png')

In [78]:
visualizer.visualize_transcript('NAXE', transcript_id, alt_features=alt_features, output_file=f'naxe_transcript_alt_{transcript_id}.png')