In [1]:
import polars as pl
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, Patch
from pathlib import Path
import pickle

plt.rcParams['pdf.fonttype'] = 42
sns.set_style("whitegrid")

In [2]:
import polars as pl
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, Patch
from pathlib import Path
import pickle

plt.rcParams['pdf.fonttype'] = 42
sns.set_style("whitegrid")

def load_genes_from_comparison(comparison_dir, cell_line='HEK293', 
                                comparison_name='HEK293_vs_GM12878'):
    """
    Load gene lists from Find_Genes.ipynb comparison results
    
    Parameters:
    -----------
    comparison_dir : str
        Directory where comparison CSV files are saved
    cell_line : str
        Which cell line to use ('HEK293' or 'GM12878')
    comparison_name : str
        Prefix of comparison files (e.g., 'HEK293_vs_GM12878')
    
    Returns:
    --------
    dict: Dictionary with gene sets for each modification combination
    """
    comparison_dir = Path(comparison_dir)
    
    print("="*70)
    print(f"Loading genes from {comparison_name} comparison")
    print("="*70)
    
    gene_sets = {}
    
    # Define patterns for different combinations
    patterns = {
        # Single modifications
        'm6A': f"{comparison_name}_m6A_unique_{cell_line}.csv",
        'm5C': f"{comparison_name}_m5C_unique_{cell_line}.csv",
        'psi': f"{comparison_name}_psi_unique_{cell_line}.csv",
        'inosine': f"{comparison_name}_inosine_unique_{cell_line}.csv",
        
        # 2-way combinations
        'm6A_m5C': f"{comparison_name}_combo_2way_m6a_m5c_unique_{cell_line}.csv",
        'm6A_psi': f"{comparison_name}_combo_2way_m6a_psi_unique_{cell_line}.csv",
        'm6A_inosine': f"{comparison_name}_combo_2way_m6a_ino_unique_{cell_line}.csv",
        'm5C_psi': f"{comparison_name}_combo_2way_m5c_psi_unique_{cell_line}.csv",
        'm5C_inosine': f"{comparison_name}_combo_2way_m5c_ino_unique_{cell_line}.csv",
        'psi_inosine': f"{comparison_name}_combo_2way_psi_ino_unique_{cell_line}.csv",
        
        # 3-way combinations
        'm6A_m5C_psi': f"{comparison_name}_combo_3way_m6a_m5c_psi_unique_{cell_line}.csv",
        'm6A_m5C_inosine': f"{comparison_name}_combo_3way_m6a_m5c_ino_unique_{cell_line}.csv",
        'm6A_psi_inosine': f"{comparison_name}_combo_3way_m6a_psi_ino_unique_{cell_line}.csv",
        'm5C_psi_inosine': f"{comparison_name}_combo_3way_m5c_psi_ino_unique_{cell_line}.csv",
        
        # Also check shared genes
        'm6A_shared': f"{comparison_name}_m6A_shared.csv",
        'm5C_shared': f"{comparison_name}_m5C_shared.csv",
        'psi_shared': f"{comparison_name}_psi_shared.csv",
        'inosine_shared': f"{comparison_name}_inosine_shared.csv",
        'm6A_m5C_psi_shared': f"{comparison_name}_combo_3way_m6a_m5c_psi_shared.csv",
        'm6A_psi_inosine_shared': f"{comparison_name}_combo_3way_m6a_psi_ino_shared.csv",
    }
    
    for combo_name, filename in patterns.items():
        filepath = comparison_dir / filename
        
        if filepath.exists():
            df = pl.read_csv(str(filepath))
            genes = set(df['gene_id'].to_list())
            gene_sets[combo_name] = genes
            print(f"✓ Loaded {combo_name}: {len(genes)} genes")
        else:
            gene_sets[combo_name] = set()
    
    return gene_sets

def create_multi_mod_df_from_genes(drs_df, gene_sets, combinations_to_plot=None):
    """
    Create a multi_mod_df DataFrame from pre-identified gene sets
    
    Parameters:
    -----------
    drs_df : polars.DataFrame
        DRS data with all modification sites
    gene_sets : dict
        Dictionary of gene sets from load_genes_from_comparison
    combinations_to_plot : list, optional
        Specific combinations to include (e.g., ['m6A_m5C_psi', 'm6A_psi_inosine'])
        If None, includes all 3-way combinations
    
    Returns:
    --------
    polars.DataFrame: DataFrame with detailed gene information
    """
    print("\n" + "="*70)
    print("Creating detailed gene information from DRS data")
    print("="*70)
    
    if isinstance(drs_df, pd.DataFrame):
        drs_df = pl.from_pandas(drs_df)
    
    # Mod code mapping
    mod_codes = {
        'm6A': ['a'],
        'm5C': ['m'],
        'psi': ['17802', 'psi'],
        'inosine': ['17596', 'inosine', 'I']
    }
    
    # If no specific combinations given, find all 3-way combos
    if combinations_to_plot is None:
        combinations_to_plot = [k for k in gene_sets.keys() 
                               if k.count('_') == 2 and 'shared' not in k]
    
    gene_details_list = []
    
    for combo_name in combinations_to_plot:
        if combo_name not in gene_sets or len(gene_sets[combo_name]) == 0:
            print(f"⚠ Skipping {combo_name}: No genes found")
            continue
        
        print(f"\nProcessing {combo_name}: {len(gene_sets[combo_name])} genes")
        
        # Parse modification types from combo name
        mods = combo_name.replace('_shared', '').split('_')
        
        for gene_id in gene_sets[combo_name]:
            gene_data = drs_df.filter(pl.col('gene_id') == gene_id)
            
            if len(gene_data) == 0:
                continue
            
            gene_name = gene_data['gene_name'][0]
            chrom = gene_data['chrom'][0]
            
            # Get modification details
            mod_counts = {}
            region_details_dict = {}
            
            for mod in mods:
                if mod not in mod_codes:
                    continue
                    
                mod_sites = gene_data.filter(pl.col('mod').is_in(mod_codes[mod]))
                mod_counts[mod] = len(mod_sites)
                
                # Group by region
                for region in mod_sites['feature_type'].unique().to_list():
                    region_key = f"{mod}_{region}"
                    sites_in_region = mod_sites.filter(pl.col('feature_type') == region)
                    
                    locations = sites_in_region.select(
                        (pl.col('chrom').cast(pl.Utf8) + ":" + 
                         pl.col('drs_start').cast(pl.Utf8) + "-" + 
                         pl.col('drs_end').cast(pl.Utf8))
                    ).to_series().to_list()
                    
                    region_details_dict[region_key] = locations
            
            # Get position ranges
            min_pos = gene_data['drs_start'].min()
            max_pos = gene_data['drs_end'].max()
            
            gene_details_list.append({
                'gene_id': gene_id,
                'gene_name': gene_name,
                'chromosome': chrom,
                'start': min_pos,
                'end': max_pos,
                'span_kb': round((max_pos - min_pos) / 1000, 2),
                'combination': combo_name,
                'modifications': ','.join(mods),
                **{f'n_{mod}_sites': mod_counts.get(mod, 0) for mod in mods},
                'total_sites': sum(mod_counts.values()),
                'regions': ','.join(region_details_dict.keys()),
                'region_locations': str(region_details_dict)
            })
    
    if len(gene_details_list) > 0:
        multi_mod_df = pl.DataFrame(gene_details_list).sort('total_sites', descending=True)
        print(f"\n✓ Created DataFrame with {len(multi_mod_df)} genes")
        return multi_mod_df
    else:
        print("\n⚠ No genes found")
        return None

def create_swarm_plots_from_genes(drs_df, multi_mod_df, 
                                  output_dir="./multi_mod_plots/",
                                  by_combination=True, top_n=10):
    """
    Create swarm plots from gene lists (strand-aware)
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    
    print("\n" + "="*70)
    print("Creating swarm plots (strand-aware)")
    print("="*70)
    
    if isinstance(drs_df, pd.DataFrame):
        drs_df = pl.from_pandas(drs_df)
    
    mod_colors = {
        'm6A': '#4A90E2',
        'm5C': '#E94B3C',
        'psi': '#50C878',
        'inosine': '#9B59B6'
    }
    
    mod_codes = {
        'm6A': ['a'],
        'm5C': ['m'],
        'psi': ['17802', 'psi'],
        'inosine': ['17596', 'inosine', 'I']
    }
    
    if by_combination:
        # Get unique combinations
        combinations = multi_mod_df['combination'].unique().to_list()
        
        for combo_name in combinations:
            print(f"\nPlotting {combo_name}...")
            
            combo_genes = multi_mod_df.filter(
                pl.col('combination') == combo_name
            ).head(top_n)
            
            if len(combo_genes) == 0:
                continue
            
            gene_list = combo_genes['gene_name'].to_list()
            mods = combo_genes['modifications'][0].split(',')
            
            # Create figure
            fig, axes = plt.subplots(len(gene_list), 1, 
                                    figsize=(8.5, 3.5 * len(gene_list)))
            if len(gene_list) == 1:
                axes = [axes]
            
            for ax, gene_name in zip(axes, gene_list):
                gene_data = drs_df.filter(pl.col('gene_name') == gene_name)
                
                if len(gene_data) == 0:
                    continue
                
                strand = gene_data['strand'][0]
                
                # Prepare plot data
                plot_data = []
                for mod_name in mods:
                    mod_sites = gene_data.filter(pl.col('mod').is_in(mod_codes[mod_name]))
                    
                    for row in mod_sites.iter_rows(named=True):
                        plot_data.append({
                            'start': row['drs_start'],
                            'mod_type': mod_name
                        })
                
                if len(plot_data) == 0:
                    continue
                
                plot_df = pd.DataFrame(plot_data)
                
                # Create swarm plot
                sns.swarmplot(
                    data=plot_df,
                    x='start',
                    hue='mod_type',
                    palette=mod_colors,
                    ax=ax,
                    size=5,
                    alpha=0.8,
                    hue_order=mods
                )
                
                # Add strand direction
                strand_symbol = '→' if strand == '+' else '←'
                strand_text = f"5' {strand_symbol} 3'" if strand == '+' else f"3' {strand_symbol} 5'"
                
                ax.set_title(f"{gene_name} ({strand} strand) - {combo_name.replace('_', ' + ')} | {strand_text}", 
                           fontsize=16, fontweight='bold', pad=10)
                ax.set_xlabel("Genomic Position (bp)", fontsize=12)
                ax.set_ylabel("")
                ax.set_yticks([])
                
                if strand == '-':
                    ax.invert_xaxis()
                
                ax.ticklabel_format(style='plain', axis='x')
                ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{int(x):,}'))
                
                ax.legend(title='Modification', bbox_to_anchor=(1.02, 1), 
                         loc='upper left', frameon=True)
                ax.grid(axis='x', alpha=0.3, linestyle='--')
            
            plt.tight_layout()
            
            output_file = f"{output_dir}/swarm_plot_{combo_name}.pdf"
            fig.savefig(output_file, dpi=300, bbox_inches='tight')
            print(f"  ✓ Saved: {output_file}")
            
            plt.close()

In [3]:
def create_swarm_plots_from_genes(drs_df, multi_mod_df, 
                                  output_dir="./multi_mod_plots/",
                                  by_combination=True, top_n=10):
    """
    Create swarm plots from gene lists (strand-aware)
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    
    print("\n" + "="*70)
    print("Creating swarm plots (strand-aware)")
    print("="*70)
    
    if isinstance(drs_df, pd.DataFrame):
        drs_df = pl.from_pandas(drs_df)
    
    mod_colors = {
        'm6A': '#0072B2',
        'm5C': '#CC79A7',
        'psi': '#D55E00',
        'inosine': '#009E73'
    }
    
    mod_codes = {
        'm6A': ['a'],
        'm5C': ['m'],
        'psi': ['17802', 'psi'],
        'inosine': ['17596', 'inosine', 'I']
    }
    
    if by_combination:
        # Get unique combinations
        combinations = multi_mod_df['combination'].unique().to_list()
        
        for combo_name in combinations:
            print(f"\nPlotting {combo_name}...")
            
            combo_genes = multi_mod_df.filter(
                pl.col('combination') == combo_name
            ).head(top_n)
            
            if len(combo_genes) == 0:
                continue
            
            gene_list = combo_genes['gene_name'].to_list()
            mods = combo_genes['modifications'][0].split(',')
            
            # Create figure
            fig, axes = plt.subplots(len(gene_list), 1, 
                                    figsize=(8.5, 3.5 * len(gene_list)))
            if len(gene_list) == 1:
                axes = [axes]
            
            for ax, gene_name in zip(axes, gene_list):
                gene_data = drs_df.filter(pl.col('gene_name') == gene_name)
                
                if len(gene_data) == 0:
                    continue
                
                strand = gene_data['strand'][0]
                
                # Prepare plot data
                plot_data = []
                for mod_name in mods:
                    mod_sites = gene_data.filter(pl.col('mod').is_in(mod_codes[mod_name]))
                    
                    for row in mod_sites.iter_rows(named=True):
                        plot_data.append({
                            'start': row['drs_start'],
                            'mod_type': mod_name
                        })
                
                if len(plot_data) == 0:
                    continue
                
                plot_df = pd.DataFrame(plot_data)
                
                # Create swarm plot
                sns.swarmplot(
                    data=plot_df,
                    x='start',
                    hue='mod_type',
                    palette=mod_colors,
                    ax=ax,
                    size=5,
                    alpha=0.8,
                    hue_order=mods
                )
                
                # Add strand direction
                strand_symbol = '→' if strand == '+' else '←'
                strand_text = f"5' {strand_symbol} 3'" if strand == '+' else f"3' {strand_symbol} 5'"
                
                ax.set_title(f"{gene_name} ({strand} strand) - {combo_name.replace('_', ' + ')} | {strand_text}", 
                           fontsize=16, fontweight='bold', pad=10)
                ax.set_xlabel("Genomic Position (bp)", fontsize=12)
                ax.set_ylabel("")
                ax.set_yticks([])
                
                if strand == '-':
                    ax.invert_xaxis()
                
                ax.ticklabel_format(style='plain', axis='x')
                ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{int(x):,}'))
                
                ax.legend(title='Modification', bbox_to_anchor=(1.02, 1), 
                         loc='upper left', frameon=True)
                ax.grid(axis='x', alpha=0.3, linestyle='--')
            
            plt.tight_layout()
            
            output_file = f"{output_dir}/swarm_plot_{combo_name}_GM12878.pdf"
            fig.savefig(output_file, dpi=300, bbox_inches='tight')
            print(f"  ✓ Saved: {output_file}")
            
            plt.close()

In [4]:
def create_utr_maps(drs_df, multi_mod_df, 
                   output_dir="./multi_mod_plots/", top_n=10):
    """
    Create UTR maps showing gene structure and modification positions
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    
    print("\n" + "="*70)
    print(f"Creating UTR maps for top {top_n} genes")
    print("="*70)
    
    if isinstance(drs_df, pd.DataFrame):
        drs_df = pl.from_pandas(drs_df)
    
    top_genes = multi_mod_df.head(top_n)
    
    mod_colors = {
        'm6A': '#4A90E2',
        'm5C': '#E94B3C',
        'psi': '#50C878',
        'inosine': '#9B59B6'
    }
    
    mod_codes = {
        'm6A': ['a'],
        'm5C': ['m'],
        'psi': ['17802', 'psi'],
        'inosine': ['17596', 'inosine', 'I']
    }
    
    fig, axes = plt.subplots(len(top_genes), 1, 
                            figsize=(20, 2.5 * len(top_genes)))
    
    if len(top_genes) == 1:
        axes = [axes]
    
    for ax, row in zip(axes, top_genes.iter_rows(named=True)):
        gene_id = row['gene_id']
        gene_name = row['gene_name']
        mods_in_gene = row['modifications'].split(',')
        
        gene_features = drs_df.filter(pl.col('gene_id') == gene_id)
        
        if len(gene_features) == 0:
            continue
        
        gene_start = gene_features['feature_start'].min()
        gene_end = gene_features['feature_end'].max()
        gene_span = gene_end - gene_start
        
        # Plot gene body
        ax.add_patch(Rectangle((gene_start, 0.35), gene_span, 0.15, 
                               facecolor='lightgray', edgecolor='black', linewidth=1))
        
        # Plot regions
        regions = gene_features.select(['feature_type', 'feature_start', 'feature_end']).unique()
        
        region_colors = {
            "5' UTR": '#FFE4B5',
            "3' UTR": '#FFE4B5',
            'CDS': '#B0C4DE',
            'exon': '#B0C4DE',
            'intron': '#F0F0F0'
        }
        
        for region_row in regions.iter_rows(named=True):
            region_type = region_row['feature_type']
            r_start = region_row['feature_start']
            r_end = region_row['feature_end']
            
            color = region_colors.get(region_type, 'white')
            ax.add_patch(Rectangle((r_start, 0.37), r_end - r_start, 0.11,
                                  facecolor=color, edgecolor='black', 
                                  linewidth=0.5, alpha=0.7))
            
            mid = (r_start + r_end) / 2
            ax.text(mid, 0.25, region_type, ha='center', va='top', 
                   fontsize=8, style='italic')
        
        # Plot modification sites
        y_positions = {'m6A': 0.6, 'm5C': 0.7, 'psi': 0.8, 'inosine': 0.9}
        
        for mod_name in mods_in_gene:
            mod_sites = gene_features.filter(pl.col('mod').is_in(mod_codes[mod_name]))
            
            if len(mod_sites) > 0:
                positions = mod_sites['drs_start'].to_list()
                y_pos = y_positions[mod_name]
                
                for pos in positions:
                    ax.plot([pos, pos], [0.5, y_pos], 
                           color=mod_colors[mod_name], linewidth=2, alpha=0.8)
                    ax.scatter(pos, y_pos, color=mod_colors[mod_name], 
                             s=100, zorder=5, edgecolor='black', linewidth=0.5)
        
        ax.set_xlim(gene_start - gene_span*0.05, gene_end + gene_span*0.05)
        ax.set_ylim(0, 1)
        ax.set_yticks([])
        ax.set_title(f"{gene_name} - {' + '.join(mods_in_gene)}", 
                    fontsize=14, fontweight='bold')
        ax.set_xlabel("Genomic Position (bp)", fontsize=10)
        
        ax.ticklabel_format(style='plain', axis='x')
        ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{int(x):,}'))
        
        legend_elements = [Patch(facecolor=mod_colors[m], label=m) for m in mods_in_gene]
        ax.legend(handles=legend_elements, loc='upper right', frameon=True)
        
        print(f"  ✓ Created map for {gene_name}")
    
    plt.tight_layout()
    
    output_file = f"{output_dir}/utr_maps_multi_mod.pdf"
    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"\n✓ Saved: {output_file}")
    
    plt.close()

In [6]:
# Load DRS data
print("="*70)
print("LOADING DATA")
print("="*70)

drs_file = '/Volumes/AJS_SSD/HEK293/modkit_output/Annotated_Data/HEK293_08_01_25_09_09_25_10_06_25_GRCh38_sorted_filtered_combined_modkit_ivt_corrected_annotated_annotated_valid_kmer.pkl'
with open(drs_file, 'rb') as f:
    drs_df = pickle.load(f)

if isinstance(drs_df, pd.DataFrame):
    drs_df = pl.from_pandas(drs_df)

print(f"✓ Loaded DRS data: {len(drs_df):,} total sites")

# FILTER TO ONLY VALIDATED SITES
print("\n" + "="*70)
print("FILTERING TO VALIDATED SITES ONLY")
print("="*70)

validated_dir = Path('/Volumes/AJS_SSD/HEK293/orthogonal_validated/GM12878/')
validated_files = list(validated_dir.glob('*.csv'))

validated_site_ids = set()
for file in validated_files:
    val_df = pl.read_csv(str(file))
    if 'site_id' in val_df.columns:
        validated_site_ids.update(val_df['site_id'].to_list())
        print(f"  ✓ Loaded {len(val_df):,} validated sites from {file.name}")

print(f"\nTotal unique validated sites: {len(validated_site_ids):,}")

# Create site_id column if it doesn't exist
if 'site_id' not in drs_df.columns:
    drs_df = drs_df.with_columns([
        (pl.col('chrom').cast(pl.Utf8) + '_' + 
         pl.col('drs_end').cast(pl.Int64).cast(pl.Utf8)).alias('site_id')
    ])
    print("  ✓ Created site_id column in DRS data")

# Filter DRS data to ONLY validated sites
drs_df = drs_df.filter(pl.col('site_id').is_in(list(validated_site_ids)))
print(f"\n✓ Filtered DRS data: {len(drs_df):,} validated sites (down from original)")

# Load genes from Find_Genes.ipynb output
comparison_dir = '/Volumes/AJS_SSD/comparative_analysis'
gene_sets = load_genes_from_comparison(
    comparison_dir=comparison_dir,
    cell_line='GM12878',
    comparison_name='HEK293_vs_GM12878'
)

# Create multi_mod_df from the loaded genes
combinations_to_plot = [
    # 'm6A_psi_inosine',
    'm6A_m5C_psi',
    'm6A_m5C_inosine',
    'm6A_psi_inosine_shared'
    
    
    # Add more as needed
]

multi_mod_df = create_multi_mod_df_from_genes(
    drs_df=drs_df,
    gene_sets=gene_sets,
    combinations_to_plot=combinations_to_plot
)

# Save the detailed gene info and create plots
if multi_mod_df is not None:
    output_dir = '/Volumes/AJS_SSD/HEK293/multi_mod_plots_from_comparison/'
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    
    multi_mod_df.write_csv(f"{output_dir}/multi_mod_genes_detailed.csv")
    print(f"\n✓ Saved detailed gene info to {output_dir}/multi_mod_genes_detailed.csv")
    
    # Create swarm plots
    create_swarm_plots_from_genes(
        drs_df=drs_df,
        multi_mod_df=multi_mod_df,
        output_dir=output_dir,
        by_combination=True,
        top_n=10
    )
    
    print("\n" + "="*70)
    print("COMPLETE!")
    print("="*70)
    print(f"\nAll plots saved to: {output_dir}")
else:
    print("\n⚠ No genes found to plot")

LOADING DATA
✓ Loaded DRS data: 93,571 total sites

FILTERING TO VALIDATED SITES ONLY
  ✓ Loaded 42,525 validated sites from GM12878_m6A_GLORI1_validated.csv
  ✓ Loaded 31,802 validated sites from GM12878_m6A_GLORI2_validated.csv
  ✓ Loaded 159 validated sites from GM12878_m5C_validated.csv
  ✓ Loaded 50 validated sites from GM12878_Psi_BIDseq_validated.csv
  ✓ Loaded 2,148 validated sites from GM12878_Inosine_validated.csv

Total unique validated sites: 47,054
  ✓ Created site_id column in DRS data

✓ Filtered DRS data: 27,239 validated sites (down from original)
Loading genes from HEK293_vs_GM12878 comparison
✓ Loaded m6A: 2288 genes
✓ Loaded m5C: 129 genes
✓ Loaded psi: 20 genes
✓ Loaded inosine: 165 genes
✓ Loaded m6A_m5C: 98 genes
✓ Loaded m6A_psi: 14 genes
✓ Loaded m6A_inosine: 113 genes
✓ Loaded m5C_psi: 2 genes
✓ Loaded m5C_inosine: 8 genes
✓ Loaded m6A_m5C_psi: 1 genes
✓ Loaded m6A_m5C_inosine: 7 genes
✓ Loaded m6A_shared: 4509 genes
✓ Loaded m5C_shared: 28 genes
✓ Loaded psi_