In [1]:
import pickle
import polars as pl
import pandas as pd
import glob
from pathlib import Path
from typing import Dict, Set, Tuple


class ModificationAnalyzer:
    """Analyzes RNA modifications across cell lines"""
    
    # Modification type mappings
    # Note: ChEBI IDs are used in DRS data (17802=psi, 17596=inosine)
    MOD_MAPPING = {
        'm6A': ['a'],
        'm5C': ['m'],
        'psi': ['17802', 'psi'],  # ChEBI ID: 17802
        'inosine': ['17596', 'inosine', 'I']  # ChEBI ID: 17596
    }
    
    def __init__(self, base_dir: str, output_dir: str):
        """
        Initialize analyzer
        
        Parameters:
        -----------
        base_dir : str
            Base directory containing cell line subdirectories
        output_dir : str
            Directory for output files
        """
        self.base_dir = Path(base_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        
        self.cell_lines = {}
        self.drs_data = {}
        
    def load_cell_line_data(self, cell_line: str, orthogonal_subdir: str, drs_file: str):
        """
        Load validated sites and DRS data for a cell line
        
        Parameters:
        -----------
        cell_line : str
            Cell line name (e.g., 'HEK293', 'GM12878')
        orthogonal_subdir : str
            Subdirectory path for orthogonal validated data
        drs_file : str
            Path to DRS pickle file
        """
        print(f"\n{'='*70}")
        print(f"Loading data for {cell_line}")
        print('='*70)
        
        # Load orthogonal validated data
        files = glob.glob(str(self.base_dir / orthogonal_subdir / '*'))
        validated_dict = self._read_validated_data(files, cell_line)
        
        # Load DRS data
        with open(drs_file, 'rb') as f:
            drs_df = pd.read_pickle(f)
        
        if not isinstance(drs_df, pl.DataFrame):
            drs_df = pl.from_pandas(drs_df)
        
        self.cell_lines[cell_line] = validated_dict
        self.drs_data[cell_line] = drs_df
        
        print(f"Loaded {len(validated_dict)} validated datasets for {cell_line}")
        print(f"DRS data: {len(drs_df)} sites")
        
    def _read_validated_data(self, files: list, cell_line: str) -> Dict:
        """Read and organize validated modification data"""
        data_dict = {}
        
        for file in files:
            filename = Path(file).name
            lines = filename.split('_')
            name = lines[0]
            
            if len(lines) > 1 and lines[1] != 'orthogonal':
                name = f"{lines[0]}_{lines[1]}"
            
            df = pl.read_csv(
                str(file),
                schema_overrides={'chromosome': pl.Utf8}
            )
            
            data_dict[name] = df
            print(f"  Loaded {name}: {len(df)} rows")
        
        return data_dict
    
    def _determine_mod_type(self, tech_name: str) -> Tuple[str, list]:
        """Determine modification type from technology name"""
        tech_lower = tech_name.lower()
        
        if 'm6a' in tech_lower or 'glori' in tech_lower:
            return 'm6A', self.MOD_MAPPING['m6A']
        elif 'm5c' in tech_lower:
            return 'm5C', self.MOD_MAPPING['m5C']
        elif 'psi' in tech_lower or 'bid' in tech_lower or 'praise' in tech_lower:
            return 'psi', self.MOD_MAPPING['psi']
        elif 'inosine' in tech_lower:
            return 'inosine', self.MOD_MAPPING['inosine']
        return None, None
    
    def analyze_cell_line(self, cell_line: str) -> Dict:
        """Analyze validated genes for a single cell line"""
        print(f"\n{'='*70}")
        print(f"Analyzing Validated Genes for {cell_line}")
        print('='*70)
        
        drs_df = self.drs_data[cell_line]
        validated_dict = self.cell_lines[cell_line]
        
        # Store gene IDs for each modification type
        validated_genes = {mod_type: set() for mod_type in self.MOD_MAPPING.keys()}
        detailed_results = {}
        
        # Process each validated dataset
        for tech_name, validated_df in validated_dict.items():
            print(f"\n--- Processing {tech_name} ---")
            
            mod_type, mod_codes = self._determine_mod_type(tech_name)
            if not mod_type:
                print(f"  Warning: Could not determine modification type")
                continue
            
            if 'site_id' not in validated_df.columns:
                print(f"  Warning: No site_id column")
                continue
            
            validated_sites = set(validated_df['site_id'].to_list())
            
            # Filter DRS data for this modification
            # Ensure all mod_codes are strings for comparison
            mod_codes_str = [str(code) for code in mod_codes]
            drs_mod = drs_df.filter(pl.col('mod').cast(pl.Utf8).is_in(mod_codes_str))
            
            drs_mod = drs_mod.with_columns([
                (pl.col('chrom').cast(pl.Utf8) + '_' + 
                 pl.col('drs_end').cast(pl.Int64).cast(pl.Utf8)).alias('site_id')
            ])
            
            # Find matching sites
            matched_sites = drs_mod.filter(pl.col('site_id').is_in(validated_sites))
            gene_ids = set(matched_sites['gene_id'].unique().to_list())
            gene_ids.discard(None)
            
            validated_genes[mod_type].update(gene_ids)
            
            print(f"  Modification: {mod_type}")
            print(f"  Validated sites: {len(validated_sites)}")
            print(f"  Matched sites: {len(matched_sites)}")
            print(f"  Unique genes: {len(gene_ids)}")
            
            detailed_results[tech_name] = {
                'mod_type': mod_type,
                'validated_sites': len(validated_sites),
                'matched_sites': len(matched_sites),
                'unique_genes': len(gene_ids),
                'gene_ids': gene_ids
            }
        
        # Calculate modification combinations
        cross_mod = self._calculate_combinations(validated_genes)
        
        return {
            'validated_genes': validated_genes,
            'detailed_results': detailed_results,
            'cross_mod_analysis': cross_mod
        }
    
    def _calculate_combinations(self, validated_genes: Dict[str, Set]) -> Dict:
        """Calculate all modification combinations"""
        # Four-way
        all_four = (validated_genes['m6A'] & validated_genes['m5C'] & 
                   validated_genes['psi'] & validated_genes['inosine'])
        
        # Three-way combinations
        three_mods = {
            'm6a_m5c_psi': validated_genes['m6A'] & validated_genes['m5C'] & validated_genes['psi'],
            'm6a_m5c_ino': validated_genes['m6A'] & validated_genes['m5C'] & validated_genes['inosine'],
            'm6a_psi_ino': validated_genes['m6A'] & validated_genes['psi'] & validated_genes['inosine'],
            'm5c_psi_ino': validated_genes['m5C'] & validated_genes['psi'] & validated_genes['inosine']
        }
        
        # Two-way combinations
        two_mods = {
            'm6a_m5c': validated_genes['m6A'] & validated_genes['m5C'],
            'm6a_psi': validated_genes['m6A'] & validated_genes['psi'],
            'm6a_ino': validated_genes['m6A'] & validated_genes['inosine'],
            'm5c_psi': validated_genes['m5C'] & validated_genes['psi'],
            'm5c_ino': validated_genes['m5C'] & validated_genes['inosine'],
            'psi_ino': validated_genes['psi'] & validated_genes['inosine']
        }
        
        return {
            'all_four': all_four,
            'three_mods': three_mods,
            'two_mods': two_mods
        }
    
    def compare_cell_lines(self, cell_line1: str, cell_line2: str, 
                          results1: Dict, results2: Dict) -> Dict:
        """
        Compare validated genes between two cell lines
        
        Returns dict with unique and shared genes for each modification and combination
        """
        print(f"\n{'='*70}")
        print(f"Comparing {cell_line1} vs {cell_line2}")
        print('='*70)
        
        comparison = {}
        
        # Compare individual modifications
        for mod_type in self.MOD_MAPPING.keys():
            genes1 = results1['validated_genes'][mod_type]
            genes2 = results2['validated_genes'][mod_type]
            
            unique1 = genes1 - genes2
            unique2 = genes2 - genes1
            shared = genes1 & genes2
            
            comparison[mod_type] = {
                f'unique_{cell_line1}': unique1,
                f'unique_{cell_line2}': unique2,
                'shared': shared
            }
            
            print(f"\n{mod_type}:")
            print(f"  {cell_line1} only: {len(unique1)} genes")
            print(f"  {cell_line2} only: {len(unique2)} genes")
            print(f"  Shared: {len(shared)} genes")
        
        # Compare two-way combinations
        print(f"\nTwo-modification combinations:")
        for combo_name in results1['cross_mod_analysis']['two_mods'].keys():
            genes1 = results1['cross_mod_analysis']['two_mods'][combo_name]
            genes2 = results2['cross_mod_analysis']['two_mods'][combo_name]
            
            unique1 = genes1 - genes2
            unique2 = genes2 - genes1
            shared = genes1 & genes2
            
            comparison[f'combo_2way_{combo_name}'] = {
                f'unique_{cell_line1}': unique1,
                f'unique_{cell_line2}': unique2,
                'shared': shared
            }
            
            if len(unique1) > 0 or len(unique2) > 0 or len(shared) > 0:
                print(f"  {combo_name}: {cell_line1}={len(unique1)}, "
                      f"{cell_line2}={len(unique2)}, shared={len(shared)}")
        
        # Compare three-way combinations
        print(f"\nThree-modification combinations:")
        for combo_name in results1['cross_mod_analysis']['three_mods'].keys():
            genes1 = results1['cross_mod_analysis']['three_mods'][combo_name]
            genes2 = results2['cross_mod_analysis']['three_mods'][combo_name]
            
            unique1 = genes1 - genes2
            unique2 = genes2 - genes1
            shared = genes1 & genes2
            
            comparison[f'combo_3way_{combo_name}'] = {
                f'unique_{cell_line1}': unique1,
                f'unique_{cell_line2}': unique2,
                'shared': shared
            }
            
            if len(unique1) > 0 or len(unique2) > 0 or len(shared) > 0:
                print(f"  {combo_name}: {cell_line1}={len(unique1)}, "
                      f"{cell_line2}={len(unique2)}, shared={len(shared)}")
        
        # Compare four-way combination
        print(f"\nFour-modification combination:")
        genes1 = results1['cross_mod_analysis']['all_four']
        genes2 = results2['cross_mod_analysis']['all_four']
        
        unique1 = genes1 - genes2
        unique2 = genes2 - genes1
        shared = genes1 & genes2
        
        comparison['combo_4way_all_four'] = {
            f'unique_{cell_line1}': unique1,
            f'unique_{cell_line2}': unique2,
            'shared': shared
        }
        
        print(f"  all_four: {cell_line1}={len(unique1)}, "
              f"{cell_line2}={len(unique2)}, shared={len(shared)}")
        
        return comparison
    
    def save_comparison_results(self, comparison: Dict, cell_line1: str, cell_line2: str):
        """Save comparison results to CSV files"""
        print(f"\n{'='*70}")
        print("Saving Comparison Results")
        print('='*70)
        
        for key, data in comparison.items():
            for category, gene_set in data.items():
                if len(gene_set) > 0:
                    df = pl.DataFrame({'gene_id': list(gene_set)})
                    filename = f"{cell_line1}_vs_{cell_line2}_{key}_{category}.csv"
                    output_path = self.output_dir / filename
                    df.write_csv(str(output_path))
                    print(f"  Saved {filename}: {len(gene_set)} genes")
        
        # Create summary
        summary_data = []
        for key, data in comparison.items():
            for category, gene_set in data.items():
                summary_data.append({
                    'modification': key,
                    'category': category,
                    'n_genes': len(gene_set)
                })
        
        summary_df = pl.DataFrame(summary_data).sort('n_genes', descending=True)
        summary_path = self.output_dir / f"{cell_line1}_vs_{cell_line2}_comparison_summary.csv"
        summary_df.write_csv(str(summary_path))
        print(f"\nSaved comparison summary: {summary_path}")


# Example usage:
if __name__ == "__main__":
    # Initialize analyzer
    analyzer = ModificationAnalyzer(
        base_dir="/Volumes/AJS_SSD/",
        output_dir="/Volumes/AJS_SSD/comparative_analysis/"
    )
    
    # Load HEK293 data
    analyzer.load_cell_line_data(
        cell_line='HEK293',
        orthogonal_subdir='HEK293/orthogonal_validated/HEK293/',
        drs_file='/Volumes/AJS_SSD/HEK293/modkit_output/Annotated_Data/HEK293_08_01_25_09_09_25_10_06_25_GRCh38_sorted_filtered_combined_modkit_ivt_corrected_annotated_annotated_valid_kmer.pkl'
    )
    
    # Load GM12878 data
    analyzer.load_cell_line_data(
        cell_line='GM12878',
        orthogonal_subdir='HEK293/orthogonal_validated/GM12878/',
        drs_file='/Volumes/AJS_SSD/HEK293/modkit_output/Annotated_Data/08_07_24_R9RNA_GM12878_mRNA_RT_sup_8mods_polyA_annotated_valid_kmer.pkl'  # Update path
    )
    
    # Analyze each cell line
    hek293_results = analyzer.analyze_cell_line('HEK293')
    gm12878_results = analyzer.analyze_cell_line('GM12878')
    
    # Compare cell lines
    comparison = analyzer.compare_cell_lines(
        'HEK293', 'GM12878',
        hek293_results, gm12878_results
    )
    
    # Save comparison results
    analyzer.save_comparison_results(comparison, 'HEK293', 'GM12878')


Loading data for HEK293
  Loaded HEK293_m6A: 47518 rows
  Loaded HEK293_m6A: 31580 rows
  Loaded HEK293_m5C: 59 rows
  Loaded HEK293_Psi: 45 rows
  Loaded HEK293_Inosine: 1340 rows
Loaded 4 validated datasets for HEK293
DRS data: 93571 sites

Loading data for GM12878
  Loaded GM12878_m6A: 42525 rows
  Loaded GM12878_m6A: 31802 rows
  Loaded GM12878_m5C: 159 rows
  Loaded GM12878_Psi: 50 rows
  Loaded GM12878_Inosine: 2148 rows
Loaded 4 validated datasets for GM12878
DRS data: 120842 sites

Analyzing Validated Genes for HEK293

--- Processing HEK293_m6A ---
  Modification: m6A
  Validated sites: 31580
  Matched sites: 32385
  Unique genes: 5905

--- Processing HEK293_m5C ---
  Modification: m5C
  Validated sites: 59
  Matched sites: 57
  Unique genes: 55

--- Processing HEK293_Psi ---
  Modification: psi
  Validated sites: 45
  Matched sites: 47
  Unique genes: 44

--- Processing HEK293_Inosine ---
  Modification: inosine
  Validated sites: 1340
  Matched sites: 1242
  Unique genes: 26