# Comparison Notebook for HEK293 DRS Data and Orthogonal Chemistries

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import pandas as pd
from pathlib import Path
from glob import glob
import seaborn as sns
from matplotlib_venn import venn2, venn3
from collections import defaultdict
import gzip
import pickle
from functools import reduce
from typing import Any, Union, Dict

## Get orthogonal datasets
#### Links to papers used for chemistry
- m6A: https://www.nature.com/articles/s41596-023-00937-1
- Psi: https://www.nature.com/articles/s41592-024-02439-8
- m5C: https://academic.oup.com/nar/article/50/D1/D196/6431823
- Ino: https://www.nature.com/articles/s41467-022-28841-4#Sec39

### Find all files

In [3]:
orthog_dir = Path("/Volumes/AJS_SSD/HEK293/orthogonal_datasets/")

# Build a complete file index
mod_files = defaultdict(lambda: defaultdict(list))

for mod_type in ['m6A', 'm5C', 'psi', 'inosine', '2OMe']:
    mod_dir = orthog_dir / mod_type
    if not mod_dir.exists():
        continue
    
    # Get files directly in the modification directory
    root_files = [f for f in mod_dir.iterdir() if f.is_file()]
    if root_files:
        mod_files[mod_type]['root'] = root_files
    
    # Get files from subdirectories (like GLORI_1.0 and GLORI_2.0+)
    subdirs = [d for d in mod_dir.iterdir() if d.is_dir() and not d.name.startswith('.')]
    for subdir in subdirs:
        subdir_files = [f for f in subdir.rglob("*") if f.is_file()]
        mod_files[mod_type][subdir.name] = subdir_files

# Display what we have
for mod_type, locations in mod_files.items():
    print(f"\n{mod_type}:")
    for location, files in locations.items():
        print(f"  {location}: {len(files)} files")
        for f in files[:2]:  # Show first 2 files
            print(f"    - {f.name}")
        if len(files) > 2:
            print(f"    ... and {len(files)-2} more")


m6A:
  GLORI_1.0: 3 files
    - 41587_2022_1487_MOESM3_ESM(1).xlsx
    - ._.ipynb_checkpoints
    ... and 1 more
  GLORI_2.0+: 2 files
    - 41592_2025_2680_MOESM3_ESM.xlsb
    - ._41592_2025_2680_MOESM3_ESM.xlsb

m5C:
  root: 1 files
    - GSE225614_HEK293T-WT_sites.tsv.gz

psi:
  PRAISE: 1 files
    - 41589_2015_BFnchembio1836_MOESM158_ESM.xlsx
  BID-Seq: 3 files
    - GSE179798_HEK293T_mRNA_WT_BID-seq.xlsx
    - bid_seq_sites.tsv
    ... and 1 more

inosine:
  root: 12 files
    - GSM2325076_RHH1392.editLevel.txt
    - GSM2325073_RHH1389.editLevel.txt
    ... and 10 more

2OMe:
  root: 4 files
    - NIHMS922173-supplement-Supplementarty_Data.xlsx
    - GSE90164_HEKmRNA.Nm.genome.bed.txt
    ... and 2 more


### Orthogonal Data Loader

In [4]:
# Define an updated type hint to reflect that a pickle can be any object
LoadResult = Union[pd.DataFrame, dict[str, pd.DataFrame], Any]

class OrthogonalDataloader:
    """
    A reusable dataloader that can load various file types, including gzipped
    pickle files (.pkl.gz). For multi-sheet Excel files, it returns a 
    dictionary of DataFrames. For pickle files, it returns the unpickled object.
    """
    def __init__(self, file_path: Path | str) -> None:
        self.file_path: Path = Path(file_path)
        if not self.file_path.is_file():
            raise FileNotFoundError(f"Error: The file was not found at '{self.file_path}'")

    def load_data(self, **kwargs: Any) -> LoadResult:
        """
        Loads data from the file.
        - Returns a dict of DataFrames for multi-sheet Excel files.
        - Returns the original object for pickle files.
        - Returns a single DataFrame for all other types.
        """
        suffixes: list[str] = self.file_path.suffixes
        compression: str | None = 'gzip' if '.gz' in suffixes else None
        
        print(f" Loading '{self.file_path.name}'...")
        try:
            # NEW: Handle pickle files (.pkl or .pkl.gz)
            if '.pkl' in suffixes:
                if compression == 'gzip':
                    with gzip.open(self.file_path, 'rb') as f:
                        return pickle.load(f, **kwargs)
                else:
                    with open(self.file_path, 'rb') as f:
                        return pickle.load(f, **kwargs)
            
            # Handle Excel files, defaulting to load all sheets
            elif '.xlsb' in suffixes or '.xlsx' in suffixes:
                if 'sheet_name' not in kwargs:
                    kwargs['sheet_name'] = None
                engine = 'pyxlsb' if '.xlsb' in suffixes else None
                return pd.read_excel(self.file_path, engine=engine, **kwargs)

            # Handle CSV files (.csv)
            elif '.csv' in suffixes:
                return pd.read_csv(self.file_path, compression=compression, **kwargs)

            # Handle TSV and TXT files, assuming they are tab-separated
            elif '.tsv' in suffixes or '.txt' in suffixes:
                if 'sep' not in kwargs:
                    kwargs['sep'] = '\t'
                return pd.read_csv(self.file_path, compression=compression, **kwargs)

            # If no supported extension is found
            else:
                raise ValueError(f"Unsupported file type: '{''.join(suffixes)}'")
        
        except Exception as e:
            print(f" Failed to load file {self.file_path.name}. Error: {e}")
            return None

### Load all the Orthogonal Data

In [5]:
# Assume 'mod_files' and the UniversalDataloader class are defined

# This dictionary will store the loaded DataFrames
# The structure might be slightly different now for Excel files
loaded_data = defaultdict(lambda: defaultdict(dict))

print("\n--- Starting Data Loading Process ---")

for mod_type, locations in mod_files.items():
    for location, files in locations.items():
        for file_path in files:
            loader = OrthogonalDataloader(file_path)
            data = loader.load_data() # Load the data
            
            if isinstance(data, pd.DataFrame) and not data.empty:
                # It's a single DataFrame (CSV, TXT, etc.)
                loaded_data[mod_type][location][file_path.name] = data
                print(f" Successfully loaded '{file_path.name}'")

            elif isinstance(data, dict):
                # It's a dictionary of DataFrames from an Excel file
                for sheet_name, df in data.items():
                    if not df.empty:
                        # Create a unique name for each sheet
                        unique_key = f"{file_path.name} | {sheet_name}"
                        loaded_data[mod_type][location][unique_key] = df
                        print(f" Successfully loaded sheet '{sheet_name}' from '{file_path.name}'")



--- Starting Data Loading Process ---
 Loading '41587_2022_1487_MOESM3_ESM(1).xlsx'...
 Successfully loaded sheet 'Sheet1' from '41587_2022_1487_MOESM3_ESM(1).xlsx'
 Loading '._.ipynb_checkpoints'...
 Failed to load file ._.ipynb_checkpoints. Error: Unsupported file type: '.ipynb_checkpoints'
 Loading '._41587_2022_1487_MOESM3_ESM(1).xlsx'...
 Failed to load file ._41587_2022_1487_MOESM3_ESM(1).xlsx. Error: Excel file format cannot be determined, you must specify an engine manually.
 Loading '41592_2025_2680_MOESM3_ESM.xlsb'...
 Successfully loaded sheet '50ng_mRNA_input' from '41592_2025_2680_MOESM3_ESM.xlsb'
 Successfully loaded sheet '10ng_mRNA_input' from '41592_2025_2680_MOESM3_ESM.xlsb'
 Successfully loaded sheet '2ng_mRNA_input' from '41592_2025_2680_MOESM3_ESM.xlsb'
 Loading '._41592_2025_2680_MOESM3_ESM.xlsb'...
 Failed to load file ._41592_2025_2680_MOESM3_ESM.xlsb. Error: File is not a zip file
 Loading 'GSE225614_HEK293T-WT_sites.tsv.gz'...
 Successfully loaded 'GSE225614_

## Process Orthogonal m6A

### GLORI_1.0

In [10]:
glori_1_updated['41587_2022_1487_MOESM3_ESM(1).xlsx | Sheet1']

Unnamed: 0,Chr,Sites,Strand,Gene,AGCov_rep1,AGCov_rep2,m6A_level_rep1,m6A_level_rep2,Cluster_info
0,chr10,47499,-,TUBB8,76,65,0.30263,0.24615,Non-cluster
1,chr10,3835789,+,ELSE,33,31,0.69697,0.58065,Non-cluster
2,chr10,3899385,+,ELSE,60,57,0.91667,0.89474,Non-cluster
3,chr10,5357066,+,ELSE,65,53,0.98462,0.94340,Non-cluster
4,chr10,7596572,+,ELSE,28,30,0.39286,0.23333,Non-cluster
...,...,...,...,...,...,...,...,...,...
170235,chrY,25307890,+,RPL41P6,27,33,0.70370,0.63636,Non-cluster
170236,chrY,25307971,+,RPL41P6,87,92,0.19540,0.18478,Non-cluster
170237,chrY,25308031,+,RPL41P6,36,38,0.66667,0.57895,Non-cluster
170238,chrY,25393262,-,LINC00265-3P,29,19,0.27586,0.52632,Non-cluster


In [7]:
glori_1_updated = loaded_data['m6A']['GLORI_1.0']
# glori_1 = loaded_data['m6A']['GLORI_1.0']['GSM7438377_HEK293T-GLORI-rep1.FDR.tsv.gz']
# glori_1_1 = loaded_data['m6A']['GLORI_1.0']['GSM7438378_HEK293T-GLORI-rep2.FDR.tsv.gz']

In [11]:
def process_new_glori1(glori1_df):
    """
    Process the new GLORI-1 format with 2 replicates
    
    Parameters:
    -----------
    glori1_df : DataFrame
        New GLORI-1 data with columns: Chr, Sites, Strand, Gene, 
        AGCov_rep1, AGCov_rep2, m6A_level_rep1, m6A_level_rep2, Cluster_info
    
    Returns:
    --------
    DataFrame with averaged m6A levels
    """
    import pandas as pd
    
    # Create a copy to avoid modifying original
    df = glori1_df.copy()
    
    # m6A levels are already in decimal format (0.30263 = 30.263%)
    # Convert to percentage for consistency with other datasets
    df['m6A_level_rep1_pct'] = df['m6A_level_rep1'] * 100
    df['m6A_level_rep2_pct'] = df['m6A_level_rep2'] * 100
    
    # Calculate mean m6A level
    df['m6A_level_mean'] = (df['m6A_level_rep1_pct'] + df['m6A_level_rep2_pct']) / 2
    
    # Calculate mean coverage
    df['AGCov_mean'] = (df['AGCov_rep1'] + df['AGCov_rep2']) / 2
    
    print(f"Processed new GLORI-1 data:")
    print(f"  Total sites: {len(df):,}")
    print(f"  Mean m6A level: {df['m6A_level_mean'].mean():.2f}%")
    print(f"  Mean coverage: {df['AGCov_mean'].mean():.1f}")
    
    return df

# Usage:
# Load your new GLORI-1 file

new_glori1_raw =loaded_data['m6A']['GLORI_1.0']['41587_2022_1487_MOESM3_ESM(1).xlsx | Sheet1']

# Process it
new_glori1 = process_new_glori1(new_glori1_raw)

Processed new GLORI-1 data:
  Total sites: 170,240
  Mean m6A level: 45.52%
  Mean coverage: 205.0


### Old Glori 1.0

In [6]:
import pandas as pd
import numpy as np

def combine_replicates(
    df1: pd.DataFrame, 
    df2: pd.DataFrame,
    rep1_name: str = 'rep1',
    rep2_name: str = 'rep2'
) -> pd.DataFrame:
    """
    Combines two technical replicate DataFrames by finding common sites
    and averaging their quantitative values.

    Args:
        df1: DataFrame for the first replicate.
        df2: DataFrame for the second replicate.
        rep1_name: Suffix for replicate 1's columns.
        rep2_name: Suffix for replicate 2's columns.

    Returns:
        A new DataFrame with combined and averaged data for common sites.
    """
    
    # --- Step 1: Merge to find common sites ---
    # An 'inner' merge keeps only rows where the keys ('Chr', 'Sites', 'Strand')
    # exist in BOTH DataFrames.
    merged_df = pd.merge(
        df1,
        df2,
        on=['Chr', 'Sites', 'Strand'],
        how='inner',
        suffixes=(f'_{rep1_name}', f'_{rep2_name}')
    )

    # --- Step 2: Calculate mean values for key columns ---
    # List of columns to average. Add or remove as needed.
    cols_to_average = ['AGcov', 'Acov', 'Ratio', 'P_adjust']

    for col in cols_to_average:
        rep1_col = f'{col}_{rep1_name}'
        rep2_col = f'{col}_{rep2_name}'
        
        # Check if both replicate columns exist before averaging
        if rep1_col in merged_df.columns and rep2_col in merged_df.columns:
            merged_df[f'{col}_mean'] = merged_df[[rep1_col, rep2_col]].mean(axis=1)

    # --- Step 3: (Optional) Filter for sites significant in both replicates ---
    # This is a common and robust step in bioinformatics.
    # You can uncomment this block and adjust the p-value threshold as needed.
    # p_adj_threshold = 0.05
    # significant_sites = merged_df[
    #     (merged_df[f'P_adjust_{rep1_name}'] < p_adj_threshold) &
    #     (merged_df[f'P_adjust_{rep2_name}'] < p_adj_threshold)
    # ]
    # return significant_sites

    return merged_df


In [7]:
# Assuming 'replicate1_df' and 'replicate2_df' are your loaded DataFrames

# Combine the replicates
combined_glori_1 = combine_replicates(glori_1, glori_1_1)

# Display the first few rows of the combined data
print("Combined DataFrame head:")
print(combined_glori_1.shape)

# Display the new averaged columns
print("\nColumns in the new DataFrame:")
print(combined_glori_1.columns)

Combined DataFrame head:
(60462, 23)

Columns in the new DataFrame:
Index(['Chr', 'Sites', 'Strand', 'Gene_rep1', 'CR_rep1', 'AGcov_rep1',
       'Acov_rep1', 'Genecov_rep1', 'Ratio_rep1', 'Pvalue_rep1',
       'P_adjust_rep1', 'Gene_rep2', 'CR_rep2', 'AGcov_rep2', 'Acov_rep2',
       'Genecov_rep2', 'Ratio_rep2', 'Pvalue_rep2', 'P_adjust_rep2',
       'AGcov_mean', 'Acov_mean', 'Ratio_mean', 'P_adjust_mean'],
      dtype='object')


#### Plotting GLORI_1 Overlap

In [8]:
import pandas as pd
from matplotlib_venn import venn2, venn2_circles
from matplotlib import pyplot as plt
from typing import Set

def compare_dataframes_venn(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    df1_name: str = "DataFrame 1",
    df2_name: str = "DataFrame 2",
    title: str = "Genomic Site Overlap",
    key_cols_df1: tuple[str, str, str] = ('Chr', 'Sites', 'Strand'),
    key_cols_df2: tuple[str, str, str] = ('Chr', 'Sites', 'Strand')
) -> None:
    """
    Generates a 2D Venn diagram to compare common genomic sites
    between two DataFrames.
    """
    print(f"Preparing data for Venn diagram: {df1_name} vs {df2_name}")

    # Prepare df1
    df1_coords = df1[list(key_cols_df1)].copy()
    df1_coords.columns = ['chrom', 'position', 'strand']
    df1_coords['position'] = df1_coords['position'].astype(int)
    df1_sites: Set[str] = set(df1_coords.apply(lambda r: f"{r['chrom']}_{r['position']}_{r['strand']}", axis=1))

    # Prepare df2
    df2_coords = df2[list(key_cols_df2)].copy()
    df2_coords.columns = ['chrom', 'position', 'strand']
    df2_coords['position'] = df2_coords['position'].astype(int)
    df2_sites: Set[str] = set(df2_coords.apply(lambda r: f"{r['chrom']}_{r['position']}_{r['strand']}", axis=1))

    print(f"Found {len(df1_sites)} unique sites in {df1_name}.")
    print(f"Found {len(df2_sites)} unique sites in {df2_name}.")

    # Generate the Venn Diagram
    plt.figure(figsize=(8, 8))
    venn2([df1_sites, df2_sites], (df1_name, df2_name))
    venn2_circles([df1_sites, df2_sites], linestyle="solid", color="black", linewidth=1.0)
    plt.title(title)
    plt.savefig("/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/m6A/GLORI_1.0/GLORI_1.0_Overlap_Between_Replicates.pdf")
    plt.show()

    # Print summary
    overlap_count = len(df1_sites.intersection(df2_sites))
    print(f"\n--- Venn Diagram Summary ---")
    print(f"Sites unique to {df1_name}: {len(df1_sites) - overlap_count}")
    print(f"Sites unique to {df2_name}: {len(df2_sites) - overlap_count}")
    print(f"Sites common to both (overlap): {overlap_count}\n")

In [9]:
# Assuming 'replicate1_df' and 'replicate2_df' are your loaded DataFrames

# Generate the Venn diagram to compare the two replicates
compare_dataframes_venn(
    df1=glori_1,
    df2=glori_1_1,
    df1_name="Replicate 1",
    df2_name="Replicate 2",
    title="Overlap of Detected Sites Between GLORI-Seq 1.0 Technical Replicates",
    # The key columns are the same for both dataframes
    key_cols_df1=('Chr', 'Sites', 'Strand'),
    key_cols_df2=('Chr', 'Sites', 'Strand')
)

Preparing data for Venn diagram: Replicate 1 vs Replicate 2
Found 72449 unique sites in Replicate 1.
Found 81554 unique sites in Replicate 2.

--- Venn Diagram Summary ---
Sites unique to Replicate 1: 11987
Sites unique to Replicate 2: 21092
Sites common to both (overlap): 60462



  plt.show()


### GLORI_2.0

In [14]:
glori2_2 = loaded_data['m6A']['GLORI_2.0+']['41592_2025_2680_MOESM3_ESM.xlsb | 2ng_mRNA_input']
glori2_10 = loaded_data['m6A']['GLORI_2.0+']['41592_2025_2680_MOESM3_ESM.xlsb | 10ng_mRNA_input']
glori2_50 = loaded_data['m6A']['GLORI_2.0+']['41592_2025_2680_MOESM3_ESM.xlsb | 50ng_mRNA_input']

In [15]:
import pandas as pd
from functools import reduce

def combine_three_replicates(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    df3: pd.DataFrame,
    df1_name: str = '2ng',
    df2_name: str = '10ng',
    df3_name: str = '50ng'
) -> pd.DataFrame:
    """
    Merges three replicate DataFrames on common sites and calculates
    the mean of their m6A levels.

    Args:
        df1, df2, df3: The three replicate DataFrames to combine.
        df1_name, df2_name, df3_name: Names for each replicate condition.

    Returns:
        A single DataFrame with data for sites common to all three replicates.
    """
    # Put dataframes and their names into lists for easier processing
    dfs = [df1, df2, df3]
    names = [df1_name, df2_name, df3_name]
    
    # Define the columns that uniquely identify a site
    merge_keys = ['Chr', 'Site', 'Strand']

    # Prepare each DataFrame by adding a suffix to its data columns
    prepared_dfs = []
    for df, name in zip(dfs, names):
        df_copy = df.copy()
        # Rename 'Site' to 'position' for consistency if needed, but 'Site' is fine
        df_copy.set_index(merge_keys, inplace=True)
        df_copy = df_copy.add_suffix(f'_{name}')
        df_copy.reset_index(inplace=True)
        prepared_dfs.append(df_copy)

    # Use reduce to iteratively merge all DataFrames in the list
    merged_df = reduce(lambda left, right: pd.merge(left, right, on=merge_keys, how='inner'), prepared_dfs)

    # Calculate the mean of the m6A levels from all three replicates
    m6a_cols = [col for col in merged_df.columns if 'm6A_level' in col]
    merged_df['m6A_level_mean'] = merged_df[m6a_cols].mean(axis=1)
    
    return merged_df

# --- How to Use ---
# Assuming your dataframes are named glori2_2, glori2_10, and glori2_50
combined_glori_2 = combine_three_replicates(glori2_2, glori2_10, glori2_50)



In [12]:
from matplotlib_venn import venn3, venn3_circles
from matplotlib import pyplot as plt
from typing import Set

def compare_three_dataframes_venn(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    df3: pd.DataFrame,
    set_names: tuple[str, str, str] = ('Set 1', 'Set 2', 'Set 3'),
    title: str = "Genomic Site Overlap"
) -> None:
    """
    Generates a 3-way Venn diagram to compare common genomic sites
    between three DataFrames.
    """
    key_cols = ('Chr', 'Site', 'Strand')
    dfs = [df1, df2, df3]
    all_sites = []

    print("Preparing data for 3-way Venn diagram...")
    for i, df in enumerate(dfs):
        coords = df[list(key_cols)].copy()
        coords.columns = ['chrom', 'position', 'strand']
        coords['position'] = coords['position'].astype(int)
        
        sites: Set[str] = set(coords.apply(lambda r: f"{r['chrom']}_{r['position']}_{r['strand']}", axis=1))
        all_sites.append(sites)
        print(f"Found {len(sites)} unique sites in {set_names[i]}.")

    # Generate the Venn Diagram
    plt.figure(figsize=(10, 10))
    venn3(all_sites, set_names)
    venn3_circles(all_sites, linestyle="solid", color="black", linewidth=1.0)
    plt.title(title)
    plt.savefig("/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/m6A/GLORI_2.0/GLORI_2.0_Overlap_Between_Different_Concentrations.pdf")
    plt.show()

# --- How to Use ---
# Call the function with your three replicate DataFrames
compare_three_dataframes_venn(
    glori2_2,
    glori2_10,
    glori2_50,
    set_names=('2ng Input', '10ng Input', '50ng Input'),
    title="Overlap of Detected m6A Sites Across Different Input Concentrations"
)

Preparing data for 3-way Venn diagram...
Found 74254 unique sites in 2ng Input.
Found 101613 unique sites in 10ng Input.
Found 99707 unique sites in 50ng Input.


  plt.show()


### Inner GLORI Comparison

In [13]:
import pandas as pd
from matplotlib_venn import venn2, venn2_circles
from matplotlib import pyplot as plt
from typing import Set

def compare_dataframes_venn(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    df1_name: str = "DataFrame 1",
    df2_name: str = "DataFrame 2",
    title: str = "Genomic Site Overlap",
    key_cols_df1: tuple[str, str, str] = ('Chr', 'Sites', 'Strand'),
    key_cols_df2: tuple[str, str, str] = ('Chr', 'Sites', 'Strand')
) -> None:
    """
    Generates a 2D Venn diagram to compare common genomic sites
    between two DataFrames.
    """
    print(f"Preparing data for Venn diagram: {df1_name} vs {df2_name}")

    # Prepare df1
    df1_coords = df1[list(key_cols_df1)].copy()
    df1_coords.columns = ['chrom', 'position', 'strand']
    df1_coords['position'] = df1_coords['position'].astype(int)
    df1_sites: Set[str] = set(df1_coords.apply(lambda r: f"{r['chrom']}_{r['position']}_{r['strand']}", axis=1))

    # Prepare df2
    df2_coords = df2[list(key_cols_df2)].copy()
    df2_coords.columns = ['chrom', 'position', 'strand']
    df2_coords['position'] = df2_coords['position'].astype(int)
    df2_sites: Set[str] = set(df2_coords.apply(lambda r: f"{r['chrom']}_{r['position']}_{r['strand']}", axis=1))

    print(f"Found {len(df1_sites)} unique sites in {df1_name}.")
    print(f"Found {len(df2_sites)} unique sites in {df2_name}.")

    # Generate the Venn Diagram
    plt.figure(figsize=(8, 8))
    venn2([df1_sites, df2_sites], (df1_name, df2_name))
    venn2_circles([df1_sites, df2_sites], linestyle="solid", color="black", linewidth=1.0)
    plt.title(title)
    plt.savefig("/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/m6A/GLORI_Comparison_Between_Chemistries_GLORI2_10ng.pdf")
    plt.show()

    # Print summary
    overlap_count = len(df1_sites.intersection(df2_sites))
    print(f"\n--- Venn Diagram Summary ---")
    print(f"Sites unique to {df1_name}: {len(df1_sites) - overlap_count}")
    print(f"Sites unique to {df2_name}: {len(df2_sites) - overlap_count}")
    print(f"Sites common to both (overlap): {overlap_count}\n")

# Generate the Venn diagram to compare the two GLORI versions
compare_dataframes_venn(
    df1=combined_glori_1,
    df2=combined_glori_2,
    df1_name="GLORI 1.0 Combined",
    df2_name="GLORI 2.0 Combined",
    title="Comparison of Sites Detected by GLORI 1.0 vs GLORI 2.0 10 ng set",
    # IMPORTANT: Specify the correct position column name for each DataFrame
    key_cols_df1=('Chr', 'Sites', 'Strand'), # From your GLORI 1.0 data
    key_cols_df2=('Chr', 'Site', 'Strand')   # From your GLORI 2.0+ data
)

Preparing data for Venn diagram: GLORI 1.0 Combined vs GLORI 2.0 Combined
Found 60462 unique sites in GLORI 1.0 Combined.
Found 65687 unique sites in GLORI 2.0 Combined.

--- Venn Diagram Summary ---
Sites unique to GLORI 1.0 Combined: 22423
Sites unique to GLORI 2.0 Combined: 27648
Sites common to both (overlap): 38039



  plt.show()


## Process Orthogonal m5C

In [16]:
m5c = loaded_data['m5C']['root']['GSE225614_HEK293T-WT_sites.tsv.gz']

In [18]:
# Invert the boolean condition INSIDE the brackets
m5c_orthogonal_df = m5c[~(m5c['gene_type'] == 'rRNA')
    & ~(m5c['gene_type'] == 'tRNA')
    ]


## Process Orthogonal Inosine

In [19]:
ino_1 = loaded_data['inosine']['root']['Data_S2_A-to-I_sites_identified_by_slic-seq.xlsx | HEK293T-rep1']
ino_2 = loaded_data['inosine']['root']['Data_S2_A-to-I_sites_identified_by_slic-seq.xlsx | HEK293T-rep2']
ino_3 = loaded_data['inosine']['root']['Data_S2_A-to-I_sites_identified_by_slic-seq.xlsx | HEK293T-rep3']


In [20]:
import pandas as pd

# Define the Location values to keep
# Include intergenic, exonic, and any UTR variants
keep_locations = ['intergenic', 'exonic', 'UTR3', 'UTR5', 'UTR5;UTR3']

# Filter each dataframe
ino_1_filtered = ino_1[ino_1['Location'].isin(keep_locations)].copy()
ino_2_filtered = ino_2[ino_2['Location'].isin(keep_locations)].copy()
ino_3_filtered = ino_3[ino_3['Location'].isin(keep_locations)].copy()

# Add a column to track which replicate each row came from
ino_1_filtered['replicate'] = 'ino_1'
ino_2_filtered['replicate'] = 'ino_2'
ino_3_filtered['replicate'] = 'ino_3'

# Combine all three dataframes vertically
combined_ino = pd.concat([ino_1_filtered, ino_2_filtered, ino_3_filtered], 
                        ignore_index=True)


## 2' OMethyl

In [21]:
import pandas as pd

# Load 2'OMe data from Tang et al. paper (Table S2)
# This loads from HEK293T cell line data
df = loaded_data['2OMe']['root']['1-s2.0-S2667237524000365-mmc3.xlsx | HEK293T'].copy()

# Set first row as column names (header is in row 0)
df.columns = df.iloc[0]
df = df[1:]  # Drop the first row
df = df.reset_index(drop=True)  # Reset index

# Select only the columns you need
df_condensed = df[['Chr', 'Position', 'Strand', 'Nm', 'ID']].copy()

# Rename 'ID' to 'Gene' for clarity
df_condensed = df_condensed.rename(columns={'ID': 'Gene'})

# Convert Position to numeric (it's probably a string after the header fix)
df_condensed['Position'] = pd.to_numeric(df_condensed['Position'])

# Split into four dataframes based on the Nm (nucleotide) column
OMe_A = df_condensed[df_condensed['Nm'] == 'A'].copy()
OMe_C = df_condensed[df_condensed['Nm'] == 'C'].copy()
OMe_G = df_condensed[df_condensed['Nm'] == 'G'].copy()
OMe_U = df_condensed[df_condensed['Nm'] == 'U'].copy()

# Check the results
print(f"Total rows: {len(df_condensed)}")
print(f"A modifications: {len(OMe_A)}")
print(f"C modifications: {len(OMe_C)}")
print(f"G modifications: {len(OMe_G)}")
print(f"U modifications: {len(OMe_U)}")

Total rows: 2059
A modifications: 314
C modifications: 650
G modifications: 645
U modifications: 450


## Psi

In [22]:
praise_df = loaded_data['psi']['PRAISE']['41589_2015_BFnchembio1836_MOESM158_ESM.xlsx | ψ sites in human']

In [23]:
loaded_bid = loaded_data['psi']['BID-Seq']['GSE179798_HEK293T_mRNA_WT_BID-seq.xlsx | Sheet1']
loaded_bid.columns = loaded_bid.iloc[2]  # Set 3rd row as column names
loaded_bid = loaded_bid[3:].reset_index(drop=True)  # Add drop=True
print(loaded_bid.keys())
# Correct rename syntax - need parentheses and 'columns=' parameter
bid_seq_df = loaded_bid

Index(['chr', 'pos', 'name', 'refseq', 'seg', 'strand', 'Deletion_rep1',
       'Deletion_rep2', 'Deletion_rep3', 'Deletion_Ave', 'Motif_1', 'Motif_2',
       'Frac_rep1 %', 'Frac_rep2 %', 'Frac_rep3 %', 'Frac_Ave %',
       'Deletion count_rep1', 'Deletion count_rep2', 'Deletion count_rep3'],
      dtype='object', name=2)


In [24]:
import pandas as pd
import gzip
from collections import defaultdict

def parse_gencode_gtf(gtf_path):
    """
    Parse GENCODE GTF and build mappings:
    - transcript_id -> exon structure
    - gene_name -> best transcript_id (canonical or longest)
    """
    transcript_exons = defaultdict(list)
    gene_transcripts = defaultdict(list)
    transcript_info = {}
    
    opener = gzip.open if gtf_path.endswith('.gz') else open
    
    with opener(gtf_path, 'rt') as f:
        for line in f:
            if line.startswith('#'):
                continue
            
            fields = line.strip().split('\t')
            if len(fields) < 9:
                continue
            
            chrom, source, feature, start, end, score, strand, frame, attributes = fields
            
            # Parse attributes
            attr_dict = {}
            for attr in attributes.split(';'):
                attr = attr.strip()
                if not attr:
                    continue
                parts = attr.split(' ', 1)
                if len(parts) == 2:
                    key, val = parts
                    attr_dict[key] = val.strip('"')
            
            transcript_id = attr_dict.get('transcript_id')
            gene_name = attr_dict.get('gene_name')
            
            if feature == 'transcript':
                # Store transcript metadata
                transcript_type = attr_dict.get('transcript_type', '')
                tags = attr_dict.get('tag', '')
                transcript_info[transcript_id] = {
                    'gene_name': gene_name,
                    'chrom': chrom,
                    'strand': strand,
                    'transcript_type': transcript_type,
                    'is_canonical': 'Ensembl_canonical' in tags or 'MANE_Select' in tags,
                    'tags': tags
                }
                if gene_name:
                    gene_transcripts[gene_name].append(transcript_id)
            
            elif feature == 'exon':
                exon_number = int(attr_dict.get('exon_number', 0))
                
                if transcript_id:
                    transcript_exons[transcript_id].append({
                        'chrom': chrom,
                        'start': int(start),  # GTF is 1-based
                        'end': int(end),
                        'strand': strand,
                        'exon_number': exon_number
                    })
    
    # Sort exons by exon number for each transcript
    for transcript_id in transcript_exons:
        transcript_exons[transcript_id].sort(key=lambda x: x['exon_number'])
    
    # For each gene, select the best transcript
    gene_to_transcript = {}
    for gene_name, transcripts in gene_transcripts.items():
        # Prefer canonical transcript
        canonical = [t for t in transcripts if transcript_info.get(t, {}).get('is_canonical', False)]
        if canonical:
            best_transcript = canonical[0]
        else:
            # Otherwise, pick the longest transcript with protein_coding priority
            def transcript_length(tid):
                exons = transcript_exons.get(tid, [])
                return sum(e['end'] - e['start'] + 1 for e in exons)
            
            # Prioritize protein_coding transcripts
            protein_coding = [t for t in transcripts 
                            if transcript_info.get(t, {}).get('transcript_type') == 'protein_coding']
            if protein_coding:
                best_transcript = max(protein_coding, key=transcript_length)
            else:
                best_transcript = max(transcripts, key=transcript_length)
        
        gene_to_transcript[gene_name] = best_transcript
    
    return dict(transcript_exons), gene_to_transcript, transcript_info


def transcript_to_genomic(transcript_id, position, transcript_exons):
    """
    Convert transcript position to genomic position
    
    Args:
        transcript_id: Transcript identifier
        position: Position in transcript space (1-based)
        transcript_exons: Dict mapping transcript_id to list of exons
    
    Returns:
        tuple: (chromosome, genomic_position, strand) or (None, None, None) if not found
    """
    if transcript_id not in transcript_exons:
        return None, None, None
    
    exons = transcript_exons[transcript_id]
    if not exons:
        return None, None, None
    
    chrom = exons[0]['chrom']
    strand = exons[0]['strand']
    
    # For negative strand, reverse the exon order for transcript coordinate system
    if strand == '-':
        exons = list(reversed(exons))
    
    # Walk through exons to find the genomic position
    cumulative_length = 0
    
    for exon in exons:
        exon_length = exon['end'] - exon['start'] + 1
        
        if cumulative_length + exon_length >= position:
            # Position falls in this exon
            offset_in_exon = position - cumulative_length - 1  # 0-based offset
            
            if strand == '+':
                genomic_pos = exon['start'] + offset_in_exon
            else:
                genomic_pos = exon['end'] - offset_in_exon
            
            return chrom, genomic_pos, strand
        
        cumulative_length += exon_length
    
    # Position beyond transcript length
    return chrom, None, strand


def convert_dataframe(df, gtf_path):
    """
    Add genomic coordinates to the dataframe using gene names
    
    Args:
        df: DataFrame with 'gene' and 'Postion' columns
        gtf_path: Path to GENCODE GTF file
    
    Returns:
        DataFrame with added columns: transcript_id, chromosome, genomic_position, strand
    """
    print("Parsing GENCODE GTF file...")
    transcript_exons, gene_to_transcript, transcript_info = parse_gencode_gtf(gtf_path)
    print(f"Loaded {len(transcript_exons)} transcripts")
    print(f"Mapped {len(gene_to_transcript)} genes to transcripts")
    
    # Add new columns
    results = []
    not_found = []
    
    for idx, row in df.iterrows():
        gene_name = row['gene']
        position = row['Postion']  # Note: typo in column name
        
        # Get best transcript for this gene
        transcript_id = gene_to_transcript.get(gene_name)
        
        if transcript_id:
            chrom, genomic_pos, strand = transcript_to_genomic(
                transcript_id, position, transcript_exons
            )
        else:
            transcript_id = None
            chrom, genomic_pos, strand = None, None, None
            not_found.append(gene_name)
        
        results.append({
            'transcript_id': transcript_id,
            'chromosome': chrom,
            'genomic_position': genomic_pos,
            'strand': strand
        })
    
    # Add results to dataframe
    result_df = df.copy()
    for col in ['transcript_id', 'chromosome', 'genomic_position', 'strand']:
        result_df[col] = [r[col] for r in results]
    
    # Report unmapped genes
    if not_found:
        unique_not_found = set(not_found)
        print(f"\nWarning: {len(not_found)} positions could not be mapped")
        print(f"Unique genes not found: {len(unique_not_found)}")
        print(f"Examples: {list(unique_not_found)[:10]}")
    
    return result_df


# Example usage:
praise_with_genomic = convert_dataframe(praise_df, '/Volumes/AJS_SSD/HEK293/gencode_annotations/gencode.v47.annotation.gtf')
# df_with_genomic.to_csv('data_with_genomic_coords.csv', index=False)

# For quick testing:
print(praise_with_genomic[['Accession Number', 'Postion', 'chromosome', 'genomic_position', 'strand']].head())

# Remove rows where genomic_position is None/NaN
praise_filtered = praise_with_genomic[praise_with_genomic['genomic_position'].notna()]

# Or if you want to be explicit about both chromosome and genomic_position
praise_filtered = praise_with_genomic[
    (praise_with_genomic['chromosome'].notna()) & 
    (praise_with_genomic['genomic_position'].notna())
]

print(f"Original rows: {len(praise_with_genomic)}")
print(f"Filtered rows: {len(praise_filtered)}")
print(f"Removed: {len(praise_with_genomic) - len(praise_filtered)}")

Parsing GENCODE GTF file...
Loaded 385659 transcripts
Mapped 77114 genes to transcripts

Unique genes not found: 96
Examples: ['LOC101926955', 'LOC101928062', 'C9orf3', 'chr1.trna26-AsnGTT', 'GBA', 'UFD1L', 'LOC100287072', 'LOC100507002', 'C12orf4', 'chr1.trna111-HisGTG']
  Accession Number  Postion chromosome  genomic_position strand
0        NM_000041     1035      chr19        44909262.0      +
1        NM_000081     8851       chr1       235792037.0      -
2        NM_000100      315      chr21        43774445.0      -
3        NM_000112     4152       chr5       149983498.0      +
4        NM_000116      638       None               NaN   None
Original rows: 2084
Filtered rows: 1900
Removed: 184


In [25]:
# Filter out tRNA, rRNA, and positions with no genomic coordinates
praise_filtered = praise_with_genomic[
    (praise_with_genomic['chromosome'].notna()) & 
    (praise_with_genomic['genomic_position'].notna()) &
    (~praise_with_genomic['gene'].str.contains('trna', case=False, na=False)) &
    (~praise_with_genomic['gene'].str.contains('rrna', case=False, na=False))
]

print(f"Original rows: {len(praise_with_genomic)}")
print(f"Filtered rows: {len(praise_filtered)}")
print(f"Removed: {len(praise_with_genomic) - len(praise_filtered)}")

# Check what was removed
removed = praise_with_genomic[~praise_with_genomic.index.isin(praise_filtered.index)]
# print("\nRemoved gene types:")
# print(removed['gene'].value_counts().head(20))

Original rows: 2084
Filtered rows: 1900
Removed: 184


## Read Dorado Paths

In [26]:
import polars as pl
from pathlib import Path
from collections import defaultdict

# Updated UniversalDataloader class for Parquet
class UniversalDataloader:
    def __init__(self, filepath):
        self.filepath = Path(filepath)
    
    def load_data(self):
        """Load data from parquet file"""
        try:
            data = pl.read_parquet(self.filepath)
            return data
        except Exception as e:
            print(f"  Failed to load file {self.filepath.name}. Error: {e}")
            return None

# This structure maps a mod_type to a dictionary of filenames and their data
dorado_mods_dict = defaultdict(dict)

GM12878_path = Path('/Volumes/AJS_SSD/GM12878/mod_specific_dataframes/')
HEK293_path = Path('/Volumes/AJS_SSD/HEK293/modkit_output/mod_specific_dataframes')

# Define mappings for mod keys (same for both cell lines)
mod_mapping = {
    'a': 'm6a',
    '17802': 'psi',
    '17596': 'inosine',
    'm': 'm5c',
    '19227': '2OMeU',
    '19228': '2OMeC',
    '19229': '2OMeG',  
    '69426': '2OMeA'
}

print("Starting file loading process...")

# Process both directories
for cell_line, path in [('GM12878', GM12878_path), ('HEK293', HEK293_path)]:
    if not path.exists():
        print(f"Path {path} does not exist, skipping {cell_line}")
        continue
        
    dorado_mod_files = [f for f in path.iterdir() if f.is_file() and f.suffix == '.parquet']
    print(f"\nProcessing {cell_line} files from {path}")
    print(f"Found {len(dorado_mod_files)} files")
    
    for file in dorado_mod_files:
        # Skip hidden files
        if file.name.startswith('.'):
            print(f"  Skipping hidden file: {file.name}")
            continue
            
        try:
            # Both cell lines use the same naming pattern: 'filtered_X_dataframe.parquet'
            if 'filtered_' in file.name:
                mod_code = file.name.split('filtered_')[1].split('_')[0]
                mod_key = mod_mapping.get(mod_code, mod_code)
            else:
                print(f"  Skipping {file.name} - doesn't match expected pattern")
                continue
            
            # Add cell line to the key for clarity
            full_key = f"{cell_line}_{mod_key}"
            
            # Use the loader to get the data
            print(f"  Loading '{file.name}'...")
            loader = UniversalDataloader(file)
            data_object = loader.load_data()
            
            # Store the loaded object if it's not None
            if data_object is not None:
                dorado_mods_dict[full_key][file.name] = data_object
                print(f"  ✓ Loaded '{file.name}' as {full_key}")
                
        except IndexError:
            print(f"  Could not extract key from filename: {file.name}")
        except Exception as e:
            print(f"  An unexpected error occurred with {file.name}: {e}")

print("\nLoading complete!")

# Display summary
print("\nSummary of loaded data:")
for mod_key, files in sorted(dorado_mods_dict.items()):
    print(f"  {mod_key}: {len(files)} files")

Starting file loading process...

Processing GM12878 files from /Volumes/AJS_SSD/GM12878/mod_specific_dataframes
Found 0 files

Processing HEK293 files from /Volumes/AJS_SSD/HEK293/modkit_output/mod_specific_dataframes
Found 16 files
  Loading 'filtered_a_dataframe.parquet'...
  ✓ Loaded 'filtered_a_dataframe.parquet' as HEK293_m6a
  Skipping hidden file: ._filtered_a_dataframe.parquet
  Loading 'filtered_17802_dataframe.parquet'...
  ✓ Loaded 'filtered_17802_dataframe.parquet' as HEK293_psi
  Skipping hidden file: ._filtered_17802_dataframe.parquet
  Loading 'filtered_17596_dataframe.parquet'...
  ✓ Loaded 'filtered_17596_dataframe.parquet' as HEK293_inosine
  Skipping hidden file: ._filtered_17596_dataframe.parquet
  Loading 'filtered_m_dataframe.parquet'...
  ✓ Loaded 'filtered_m_dataframe.parquet' as HEK293_m5c
  Skipping hidden file: ._filtered_m_dataframe.parquet
  Loading 'filtered_19227_dataframe.parquet'...
  ✓ Loaded 'filtered_19227_dataframe.parquet' as HEK293_2OMeU
  Skippi

# HEK293 Plotting

## Colors

In [37]:
# ==========================================
# ENHANCED COLOR SCHEME WITH OVERLAP COLORS
# ==========================================

import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.colors import LogNorm
from matplotlib_venn import venn2, venn3
import numpy as np

# Base colors for each technology
MODIFICATION_COLORS = {
    # ============ m6A Technologies ============
    'm6a': {
        'DRS_HEK293': '#1f77b4',       # Blue
        'DRS_GM12878': '#aec7e8',      # Light blue
        'GLORI1': '#ff7f0e',           # Orange
        'GLORI2': '#d62728',           # Red
        'GLORI_combined': '#9467bd',   # Purple (for intersection)
        # Overlap colors for 3-way diagrams
        'HEK_GLORI1': '#bcbd22',       # Olive (HEK293 ∩ GLORI1)
        'HEK_GLORI2': '#17becf',       # Cyan (HEK293 ∩ GLORI2)
        'GLORI1_GLORI2': '#e377c2',    # Pink (GLORI1 ∩ GLORI2)
        'ALL_THREE': '#7f7f7f',        # Gray (all three)
    },
    
    # ============ m5C Technologies ============
    'm5c': {
        'DRS_HEK293': '#2ca02c',       # Green
        'DRS_GM12878': '#98df8a',      # Light green
        'Orthogonal': '#e377c2',       # Pink
        # Overlap colors
        'HEK_Orth': '#8c564b',
        'GM_Orth': '#c49c94',
        'HEK_GM': '#bcbd22',
        'ALL_THREE': '#7f7f7f',
    },
    
    # ============ Pseudouridine (Ψ) ============
    'psi': {
        'DRS_HEK293': '#9467bd',       # Purple
        'DRS_GM12878': '#c5b0d5',      # Light purple
        'BID-seq': '#8c564b',          # Brown
        'PRAISE': '#ff7f0e',           # Orange
        'Combined': '#d62728',         # Red
        # Overlap colors
        'HEK_BID': '#bcbd22',
        'HEK_PRAISE': '#17becf',
        'BID_PRAISE': '#e377c2',
        'ALL_THREE': '#7f7f7f',
    },
    
    # ============ Inosine ============
    'inosine': {
        'DRS_HEK293': '#17becf',       # Cyan
        'DRS_GM12878': '#9edae5',      # Light cyan
        'Orthogonal': '#ff7f0e',       # Orange
        # Overlap colors
        'HEK_Orth': '#8c564b',
        'GM_Orth': '#c49c94',
        'HEK_GM': '#bcbd22',
        'ALL_THREE': '#7f7f7f',
    },
    
    # ============ 2'-O-Methylation ============
    '2ome': {
        'DRS_HEK293_A': '#e377c2',     # Pink
        'DRS_HEK293_C': '#f7b6d2',     # Light pink
        'DRS_HEK293_G': '#7f7f7f',     # Gray
        'DRS_HEK293_U': '#c7c7c7',     # Light gray
        'DRS_GM12878_A': '#1f77b4',    # Blue
        'DRS_GM12878_C': '#aec7e8',    # Light blue
        'DRS_GM12878_G': '#ff7f0e',    # Orange
        'DRS_GM12878_U': '#ffbb78',    # Light orange
        'Orthogonal_A': '#2ca02c',     # Green
        'Orthogonal_C': '#98df8a',     # Light green
        'Orthogonal_G': '#d62728',     # Red
        'Orthogonal_U': '#ff9896',     # Light red
    }
}

ALPHA = 0.7

HEATMAP_CMAPS = {
    'm6a_glori1': 'YlOrRd',
    'm6a_glori2': 'YlGnBu',
    'm6a_combined': 'RdPu',
    'm5c': 'Greens',
    'psi': 'Purples',
    'inosine': 'Blues',
    '2ome': 'Oranges',
}

print("✓ Enhanced color scheme loaded!")

✓ Enhanced color scheme loaded!


In [38]:
import polars as pl

def get_drs_sites_colored(mod_dict, cell_line, mod):
    """Extract DRS sites with proper chromosome formatting"""
    key = f"{cell_line}_{mod}"
    if key in mod_dict:
        df = list(mod_dict[key].values())[0]
        if 'Adjusted_Mod_Proportion' in df.columns:
            df_filtered = df.filter((pl.col('Adjusted_Mod_Proportion') >= 20) & (pl.col('Score') >= 20))
        else:
            df_filtered = df.filter(pl.col('Score') >= 20)
        
        sample_chr = df_filtered['Chromosome'][0] if len(df_filtered) > 0 else None
        if sample_chr and not str(sample_chr).startswith('chr'):
            sites = set(('chr' + df_filtered['Chromosome'].cast(pl.Utf8) + '_' + 
                        df_filtered['End'].cast(pl.Int64).cast(pl.Utf8)).to_list())
        else:
            sites = set((df_filtered['Chromosome'].cast(pl.Utf8) + '_' + 
                        df_filtered['End'].cast(pl.Int64).cast(pl.Utf8)).to_list())
        
        print(f"{key}: {len(sites)} sites")
        return sites
    return set()

def get_drs_values_colored(mod_dict, cell_line, mod):
    """Extract DRS sites WITH modification values"""
    key = f"{cell_line}_{mod}"
    if key in mod_dict:
        df = list(mod_dict[key].values())[0]
        if 'Adjusted_Mod_Proportion' in df.columns:
            df_filtered = df.filter((pl.col('Adjusted_Mod_Proportion') >= 20) & (pl.col('Score') >= 20))
        else:
            df_filtered = df.filter(pl.col('Score') >= 20)
        
        sample_chr = df_filtered['Chromosome'][0] if len(df_filtered) > 0 else None
        if sample_chr and not str(sample_chr).startswith('chr'):
            df_filtered = df_filtered.with_columns([
                ('chr' + pl.col('Chromosome').cast(pl.Utf8) + '_' + 
                 pl.col('End').cast(pl.Int64).cast(pl.Utf8)).alias('site_id')
            ])
        else:
            df_filtered = df_filtered.with_columns([
                (pl.col('Chromosome').cast(pl.Utf8) + '_' + 
                 pl.col('End').cast(pl.Int64).cast(pl.Utf8)).alias('site_id')
            ])
        
        print(f"{key}: {len(df_filtered)} sites")
        return df_filtered.select(['site_id', 'Adjusted_Mod_Proportion'])
    return None

def process_orthogonal_smart_colored(df, chr_col, pos_col, label="Orthogonal"):
    """Process orthogonal data with consistent chromosome/position formatting"""
    if df is None:
        return set()
    
    if isinstance(df, pl.DataFrame):
        sample_chr = df[chr_col][0] if len(df) > 0 else None
        if sample_chr and not str(sample_chr).startswith('chr'):
            sites = set(('chr' + df[chr_col].cast(pl.Utf8) + '_' + 
                        df[pos_col].cast(pl.Float64).cast(pl.Int64).cast(pl.Utf8)).to_list())
        else:
            sites = set((df[chr_col].cast(pl.Utf8) + '_' + 
                        df[pos_col].cast(pl.Float64).cast(pl.Int64).cast(pl.Utf8)).to_list())
    else:
        sample_chr = str(df[chr_col].iloc[0]) if len(df) > 0 else None
        pos_int = df[pos_col].astype(float).astype(int).astype(str)
        if sample_chr and sample_chr.startswith('chr'):
            sites = set(df[chr_col].astype(str) + '_' + pos_int)
        else:
            sites = set('chr' + df[chr_col].astype(str) + '_' + pos_int)
    
    print(f"{label}: {len(sites)} sites")
    return sites

def process_orthogonal_values_colored(df, chr_col, pos_col, value_col, label="Orthogonal"):
    """Extract orthogonal sites WITH modification values"""
    if isinstance(df, pl.DataFrame):
        sample_chr = df[chr_col][0] if len(df) > 0 else None
        if sample_chr and not str(sample_chr).startswith('chr'):
            df_processed = df.with_columns([
                ('chr' + pl.col(chr_col).cast(pl.Utf8) + '_' + 
                 pl.col(pos_col).cast(pl.Float64).cast(pl.Int64).cast(pl.Utf8)).alias('site_id')
            ])
        else:
            df_processed = df.with_columns([
                (pl.col(chr_col).cast(pl.Utf8) + '_' + 
                 pl.col(pos_col).cast(pl.Float64).cast(pl.Int64).cast(pl.Utf8)).alias('site_id')
            ])
    else:
        sample_chr = str(df[chr_col].iloc[0])
        pos_int = df[pos_col].astype(float).astype(int).astype(str)
        if sample_chr.startswith('chr'):
            site_ids = df[chr_col].astype(str) + '_' + pos_int
        else:
            site_ids = 'chr' + df[chr_col].astype(str) + '_' + pos_int
        
        df_processed = pl.DataFrame({
            'site_id': site_ids.tolist(),
            value_col: df[value_col].tolist()
        })
    
    result = df_processed.select(['site_id', value_col])
    print(f"{label}: {len(result)} sites")
    return result

print("✓ Utility functions loaded successfully!")

✓ Utility functions loaded successfully!


## m6A Plotting

In [40]:
def plot_m6a_venns_new_glori1(dorado_mods_dict, new_glori1, combined_glori_2_df, 
                              mode='HEK293',
                              glori_combine_mode='intersection',  # NEW: 'intersection' or 'union'
                              output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/m6a_venns_new.pdf'):
    """
    Create m6A Venn diagrams with NEW GLORI-1 data and fixed 3-way colors
    
    Parameters:
    -----------
    mode : str - 'HEK293', 'GM12878', or 'both'
    glori_combine_mode : str - 'intersection' (GLORI1 ∩ GLORI2) or 'union' (GLORI1 ∪ GLORI2)
    """
    print("\n" + "="*60)
    print(f"=== m6A Venn Diagrams (NEW GLORI-1) - Mode: {mode} ===")
    print(f"=== GLORI Combine Mode: {glori_combine_mode.upper()} ===")
    print("="*60)
    
    colors = MODIFICATION_COLORS['m6a']
    
    # Process GLORI data
    glori1_sites = process_orthogonal_smart_colored(new_glori1, 'Chr', 'Sites', 'GLORI-1 (New)')
    glori2_sites = process_orthogonal_smart_colored(combined_glori_2_df, 'Chr', 'Site', 'GLORI-2')
    
    # Choose combination method
    if glori_combine_mode == 'intersection':
        glori_combined = glori1_sites & glori2_sites  # Intersection
        combine_label = 'GLORI-1 ∩ GLORI-2'
    else:
        glori_combined = glori1_sites | glori2_sites  # Union
        combine_label = 'GLORI-1 ∪ GLORI-2'
    
    print(f"{combine_label}: {len(glori_combined)} sites")
    print(f"  GLORI-1 only: {len(glori1_sites - glori2_sites)}")
    print(f"  GLORI-2 only: {len(glori2_sites - glori1_sites)}")
    print(f"  GLORI-1 ∩ GLORI-2: {len(glori1_sites & glori2_sites)}")
    
    if mode == 'HEK293':
        hek293_sites = get_drs_sites_colored(dorado_mods_dict, 'HEK293', 'm6a')
        
        hek_vs_glori1 = hek293_sites & glori1_sites
        hek_vs_glori2 = hek293_sites & glori2_sites
        hek_vs_combined = hek293_sites & glori_combined
        
        print(f"\nHEK293 overlaps:")
        print(f"  vs GLORI-1: {len(hek_vs_glori1)} ({100*len(hek_vs_glori1)/len(hek293_sites):.1f}%)")
        print(f"  vs GLORI-2: {len(hek_vs_glori2)} ({100*len(hek_vs_glori2)/len(hek293_sites):.1f}%)")
        print(f"  vs Combined: {len(hek_vs_combined)} ({100*len(hek_vs_combined)/len(hek293_sites):.1f}%)")
        
        fig = plt.figure(figsize=(12, 4))
        
        # 1. HEK293 vs Combined GLORI
        ax1 = plt.subplot(1, 3, 1)
        plt.sca(ax1)
        venn2([hek293_sites, glori_combined],
              set_labels=('HEK293 DRS', combine_label),
              set_colors=(colors['DRS_HEK293'], colors['GLORI_combined']), 
              alpha=ALPHA)
        ax1.set_title('m6A: HEK293 vs GLORI', fontweight='bold', fontsize=14)
        
        # 2. Three-way comparison - FIXED COLOR ASSIGNMENT
        ax2 = plt.subplot(1, 3, 2)
        plt.sca(ax2)
        venn = venn3([hek293_sites, glori1_sites, glori2_sites],
                     set_labels=('HEK293', 'GLORI-1', 'GLORI-2'))
        
        # Assign colors to each region explicitly
        # Region IDs: '100'=A only, '010'=B only, '001'=C only,
        #             '110'=A∩B, '101'=A∩C, '011'=B∩C, '111'=A∩B∩C
        patch_colors = {
            '100': colors['DRS_HEK293'],      # HEK293 only
            '010': colors['GLORI1'],          # GLORI-1 only
            '001': colors['GLORI2'],          # GLORI-2 only
            '110': colors['HEK_GLORI1'],      # HEK293 ∩ GLORI-1
            '101': colors['HEK_GLORI2'],      # HEK293 ∩ GLORI-2
            '011': colors['GLORI1_GLORI2'],   # GLORI-1 ∩ GLORI-2
            '111': colors['ALL_THREE']         # All three
        }
        
        for region_id, color in patch_colors.items():
            patch = venn.get_patch_by_id(region_id)
            if patch:
                patch.set_facecolor(color)
                patch.set_alpha(ALPHA)
        
        ax2.set_title('m6A: Three-way', fontweight='bold', fontsize=14)
        
        # 3. GLORI-1 vs GLORI-2
        ax3 = plt.subplot(1, 3, 3)
        plt.sca(ax3)
        venn2([glori1_sites, glori2_sites],
              set_labels=('GLORI-1', 'GLORI-2'),
              set_colors=(colors['GLORI1'], colors['GLORI2']), 
              alpha=ALPHA)
        ax3.set_title('m6A: GLORI-1 vs GLORI-2', fontweight='bold', fontsize=14)
        
        fig.suptitle(f'm6A Sites - HEK293 (GLORI {glori_combine_mode.title()})', 
                     fontsize=16, fontweight='bold')
    
    elif mode == 'GM12878':
        gm12878_sites = get_drs_sites_colored(dorado_mods_dict, 'GM12878', 'm6a')
        
        gm_vs_glori1 = gm12878_sites & glori1_sites
        gm_vs_glori2 = gm12878_sites & glori2_sites
        gm_vs_combined = gm12878_sites & glori_combined
        
        print(f"\nGM12878 overlaps:")
        print(f"  vs GLORI-1: {len(gm_vs_glori1)} ({100*len(gm_vs_glori1)/len(gm12878_sites):.1f}%)")
        print(f"  vs GLORI-2: {len(gm_vs_glori2)} ({100*len(gm_vs_glori2)/len(gm12878_sites):.1f}%)")
        print(f"  vs Combined: {len(gm_vs_combined)} ({100*len(gm_vs_combined)/len(gm12878_sites):.1f}%)")
        
        fig = plt.figure(figsize=(12, 4))
        
        ax1 = plt.subplot(1, 3, 1)
        plt.sca(ax1)
        venn2([gm12878_sites, glori_combined],
              set_labels=('GM12878 DRS', combine_label),
              set_colors=(colors['DRS_GM12878'], colors['GLORI_combined']), 
              alpha=ALPHA)
        ax1.set_title('m6A: GM12878 vs GLORI', fontweight='bold', fontsize=14)
        
        ax2 = plt.subplot(1, 3, 2)
        plt.sca(ax2)
        venn = venn3([gm12878_sites, glori1_sites, glori2_sites],
                     set_labels=('GM12878', 'GLORI-1', 'GLORI-2'))
        
        # Fixed color assignment
        patch_colors = {
            '100': colors['DRS_GM12878'],     # GM12878 only
            '010': colors['GLORI1'],          # GLORI-1 only
            '001': colors['GLORI2'],          # GLORI-2 only
            '110': colors['HEK_GLORI1'],      # GM12878 ∩ GLORI-1
            '101': colors['HEK_GLORI2'],      # GM12878 ∩ GLORI-2
            '011': colors['GLORI1_GLORI2'],   # GLORI-1 ∩ GLORI-2
            '111': colors['ALL_THREE']         # All three
        }
        
        for region_id, color in patch_colors.items():
            patch = venn.get_patch_by_id(region_id)
            if patch:
                patch.set_facecolor(color)
                patch.set_alpha(ALPHA)
        
        ax2.set_title('m6A: Three-way', fontweight='bold', fontsize=14)
        
        ax3 = plt.subplot(1, 3, 3)
        plt.sca(ax3)
        venn2([glori1_sites, glori2_sites],
              set_labels=('GLORI-1', 'GLORI-2'),
              set_colors=(colors['GLORI1'], colors['GLORI2']), 
              alpha=ALPHA)
        ax3.set_title('m6A: GLORI-1 vs GLORI-2', fontweight='bold', fontsize=14)
        
        fig.suptitle(f'm6A Sites - GM12878 (GLORI {glori_combine_mode.title()})', 
                     fontsize=16, fontweight='bold')
    
    elif mode == 'both':
        hek293_sites = get_drs_sites_colored(dorado_mods_dict, 'HEK293', 'm6a')
        gm12878_sites = get_drs_sites_colored(dorado_mods_dict, 'GM12878', 'm6a')
        
        fig = plt.figure(figsize=(14, 8))
        
        # 1. Three-way: Both cells vs GLORI combined
        ax1 = plt.subplot(2, 3, 1)
        plt.sca(ax1)
        venn = venn3([hek293_sites, gm12878_sites, glori_combined],
                     set_labels=('HEK293', 'GM12878', 'GLORI'))
        
        patch_colors = {
            '100': colors['DRS_HEK293'],
            '010': colors['DRS_GM12878'],
            '001': colors['GLORI_combined'],
            '110': '#bcbd22',  # HEK ∩ GM
            '101': '#17becf',  # HEK ∩ GLORI
            '011': '#e377c2',  # GM ∩ GLORI
            '111': '#7f7f7f'   # All three
        }
        
        for region_id, color in patch_colors.items():
            patch = venn.get_patch_by_id(region_id)
            if patch:
                patch.set_facecolor(color)
                patch.set_alpha(ALPHA)
        
        ax1.set_title(f'm6A: Both Cells vs GLORI\n({glori_combine_mode})', 
                     fontweight='bold', fontsize=12)
        
        # 2. HEK293 vs GLORI
        ax2 = plt.subplot(2, 3, 2)
        plt.sca(ax2)
        venn2([hek293_sites, glori_combined],
              set_labels=('HEK293', combine_label),
              set_colors=(colors['DRS_HEK293'], colors['GLORI_combined']), 
              alpha=ALPHA)
        ax2.set_title('m6A: HEK293 vs GLORI', fontweight='bold', fontsize=12)
        
        # 3. GM12878 vs GLORI
        ax3 = plt.subplot(2, 3, 3)
        plt.sca(ax3)
        venn2([gm12878_sites, glori_combined],
              set_labels=('GM12878', combine_label),
              set_colors=(colors['DRS_GM12878'], colors['GLORI_combined']), 
              alpha=ALPHA)
        ax3.set_title('m6A: GM12878 vs GLORI', fontweight='bold', fontsize=12)
        
        # 4. Cell line comparison
        ax4 = plt.subplot(2, 3, 4)
        plt.sca(ax4)
        venn2([hek293_sites, gm12878_sites],
              set_labels=('HEK293', 'GM12878'),
              set_colors=(colors['DRS_HEK293'], colors['DRS_GM12878']), 
              alpha=ALPHA)
        ax4.set_title('m6A: HEK293 vs GM12878', fontweight='bold', fontsize=12)
        
        # 5. GLORI methods
        ax5 = plt.subplot(2, 3, 5)
        plt.sca(ax5)
        venn2([glori1_sites, glori2_sites],
              set_labels=('GLORI-1', 'GLORI-2'),
              set_colors=(colors['GLORI1'], colors['GLORI2']), 
              alpha=ALPHA)
        ax5.set_title('m6A: GLORI-1 vs GLORI-2', fontweight='bold', fontsize=14)
        
        # 6. HEK293 three-way - FIXED COLORS
        ax6 = plt.subplot(2, 3, 6)
        plt.sca(ax6)
        venn = venn3([hek293_sites, glori1_sites, glori2_sites],
                     set_labels=('HEK293', 'GLORI-1', 'GLORI-2'))
        
        patch_colors = {
            '100': colors['DRS_HEK293'],      # HEK293 only
            '010': colors['GLORI1'],          # GLORI-1 only
            '001': colors['GLORI2'],          # GLORI-2 only
            '110': colors['HEK_GLORI1'],      # HEK293 ∩ GLORI-1
            '101': colors['HEK_GLORI2'],      # HEK293 ∩ GLORI-2
            '011': colors['GLORI1_GLORI2'],   # GLORI-1 ∩ GLORI-2
            '111': colors['ALL_THREE']         # All three
        }
        
        for region_id, color in patch_colors.items():
            patch = venn.get_patch_by_id(region_id)
            if patch:
                patch.set_facecolor(color)
                patch.set_alpha(ALPHA)
        
        ax6.set_title('m6A: HEK293 All Methods', fontweight='bold', fontsize=12)
        
        fig.suptitle(f'm6A Sites - Complete (GLORI {glori_combine_mode.title()})', 
                     fontsize=16, fontweight='bold')
    
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"\n✓ Saved to {output_path}")

# # Usage examples:
# # For INTERSECTION (only sites in BOTH GLORI methods):
# plot_m6a_venns_new_glori1(dorado_mods_dict, new_glori1, combined_glori_2, 
#                           mode='HEK293', glori_combine_mode='intersection',
#                           output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/m6a_venns_intersection.pdf')

# For UNION (sites in EITHER GLORI method):
plot_m6a_venns_new_glori1(dorado_mods_dict, new_glori1, combined_glori_2, 
                          mode='HEK293', glori_combine_mode='union',
                          output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/m6a_venns_union.pdf')


=== m6A Venn Diagrams (NEW GLORI-1) - Mode: HEK293 ===
=== GLORI Combine Mode: UNION ===
GLORI-1 (New): 170240 sites
GLORI-2: 65687 sites
GLORI-1 ∪ GLORI-2: 176758 sites
  GLORI-1 only: 111071
  GLORI-2 only: 6518
  GLORI-1 ∩ GLORI-2: 59169
HEK293_m6a: 67517 sites

HEK293 overlaps:
  vs GLORI-1: 47518 (70.4%)
  vs GLORI-2: 31580 (46.8%)
  vs Combined: 49935 (74.0%)

✓ Saved to /Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/m6a_venns_union.pdf


  plt.show()


In [33]:
def plot_single_heatmap_colored(ax, drs_df, ortho_df, drs_col, ortho_col, title, colormap):
    """Helper function to plot a single heatmap with custom colormap"""
    merged = drs_df.join(ortho_df, on='site_id', how='inner')
    
    if len(merged) == 0:
        ax.text(0.5, 0.5, 'No overlapping sites', ha='center', va='center', fontsize=12)
        ax.set_title(title, fontweight='bold', fontsize=11)
        return
    
    print(f"\n{title}: {len(merged)} overlapping sites")
    
    drs_values = merged[drs_col].to_numpy()
    ortho_values = merged[ortho_col].to_numpy()
    
    bandwidth_2d = 5
    bins = np.arange(0, 100 + bandwidth_2d, bandwidth_2d)
    hist, xedges, yedges = np.histogram2d(drs_values, ortho_values, bins=bins)
    
    im = ax.imshow(hist.T,
                   norm=LogNorm(vmin=1, vmax=10**3),
                   origin='lower',
                   extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
                   aspect='auto',
                   cmap=colormap,
                   interpolation='nearest')
    
    cbar = plt.colorbar(im, ax=ax, ticks=np.logspace(0, 3, 4))
    cbar.set_ticklabels(['$10^0$', '$10^1$', '$10^2$', '$10^3$'])
    cbar.set_label('Site count', fontsize=10)
    
    correlation = np.corrcoef(drs_values, ortho_values)[0, 1]
    
    ax.set_xlabel('DRS Mod %', fontsize=11)
    ax.set_ylabel('Orthogonal Mod %', fontsize=11)
    ax.set_title(title, fontweight='bold', fontsize=12)
    
    ax.plot([0, 100], [0, 100], 'k--', linewidth=1.5)
    
    stats_text = f'n = {len(merged):,}\nr = {correlation:.3f}'
    ax.text(0.05, 0.95, stats_text, transform=ax.transAxes, 
            fontsize=10, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    print(f"  Mean DRS: {drs_values.mean():.2f}%")
    print(f"  Mean Orthogonal: {ortho_values.mean():.2f}%")
    print(f"  Correlation: {correlation:.3f}")

def plot_m6a_heatmap_new_glori1(dorado_mods_dict, new_glori1, combined_glori_2_df, mode='HEK293',
                                output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/m6a_heatmap_new.pdf'):
    """
    Create m6A heatmaps with NEW GLORI-1 format
    """
    print("\n" + "="*60)
    print(f"=== m6A Heatmaps (NEW GLORI-1) - Mode: {mode} ===")
    print("="*60)
    
    # Get GLORI values - NEW GLORI-1 already has m6A_level_mean in percentage
    glori1_df = process_orthogonal_values_colored(new_glori1, 'Chr', 'Sites', 'm6A_level_mean', 'GLORI-1 (New)')
    glori2_df = process_orthogonal_values_colored(combined_glori_2_df, 'Chr', 'Site', 'm6A_level_mean', 'GLORI-2')
    
    if mode == 'HEK293':
        drs_df = get_drs_values_colored(dorado_mods_dict, 'HEK293', 'm6a')
        
        fig = plt.figure(figsize=(10, 4))
        
        # Plot 1: HEK293 vs GLORI-1
        ax1 = plt.subplot(1, 2, 1)
        plot_single_heatmap_colored(ax1, drs_df, glori1_df, 'Adjusted_Mod_Proportion', 
                                    'm6A_level_mean', 'HEK293 vs GLORI-1 (New)', 
                                    HEATMAP_CMAPS['m6a_glori1'])
        
        # Plot 2: HEK293 vs GLORI-2
        ax2 = plt.subplot(1, 2, 2)
        plot_single_heatmap_colored(ax2, drs_df, glori2_df, 'Adjusted_Mod_Proportion', 
                                    'm6A_level_mean', 'HEK293 vs GLORI-2', 
                                    HEATMAP_CMAPS['m6a_glori2'])
        
        fig.suptitle('m6A Modification Levels - HEK293 (New GLORI-1)', fontsize=16, fontweight='bold')
    
    elif mode == 'GM12878':
        drs_df = get_drs_values_colored(dorado_mods_dict, 'GM12878', 'm6a')
        
        fig = plt.figure(figsize=(10, 4))
        
        ax1 = plt.subplot(1, 2, 1)
        plot_single_heatmap_colored(ax1, drs_df, glori1_df, 'Adjusted_Mod_Proportion', 
                                    'm6A_level_mean', 'GM12878 vs GLORI-1 (New)', 
                                    HEATMAP_CMAPS['m6a_glori1'])
        
        ax2 = plt.subplot(1, 2, 2)
        plot_single_heatmap_colored(ax2, drs_df, glori2_df, 'Adjusted_Mod_Proportion', 
                                    'm6A_level_mean', 'GM12878 vs GLORI-2', 
                                    HEATMAP_CMAPS['m6a_glori2'])
        
        fig.suptitle('m6A Modification Levels - GM12878 (New GLORI-1)', fontsize=16, fontweight='bold')
    
    elif mode == 'both':
        hek_drs = get_drs_values_colored(dorado_mods_dict, 'HEK293', 'm6a')
        gm_drs = get_drs_values_colored(dorado_mods_dict, 'GM12878', 'm6a')
        
        fig = plt.figure(figsize=(15, 8))
        
        # Row 1: HEK293
        ax1 = plt.subplot(2, 2, 1)
        plot_single_heatmap_colored(ax1, hek_drs, glori1_df, 'Adjusted_Mod_Proportion', 
                                    'm6A_level_mean', 'HEK293 vs GLORI-1', 
                                    HEATMAP_CMAPS['m6a_glori1'])
        
        ax2 = plt.subplot(2, 2, 2)
        plot_single_heatmap_colored(ax2, hek_drs, glori2_df, 'Adjusted_Mod_Proportion', 
                                    'm6A_level_mean', 'HEK293 vs GLORI-2', 
                                    HEATMAP_CMAPS['m6a_glori2'])
        
        # Row 2: GM12878
        ax3 = plt.subplot(2, 2, 3)
        plot_single_heatmap_colored(ax3, gm_drs, glori1_df, 'Adjusted_Mod_Proportion', 
                                    'm6A_level_mean', 'GM12878 vs GLORI-1', 
                                    HEATMAP_CMAPS['m6a_glori1'])
        
        ax4 = plt.subplot(2, 2, 4)
        plot_single_heatmap_colored(ax4, gm_drs, glori2_df, 'Adjusted_Mod_Proportion', 
                                    'm6A_level_mean', 'GM12878 vs GLORI-2', 
                                    HEATMAP_CMAPS['m6a_glori2'])
        
        fig.suptitle('m6A Modification Levels - Both Cell Lines (New GLORI-1)', fontsize=16, fontweight='bold')
    
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=False, bbox_inches='tight')
    plt.show()
    print(f"\n✓ Saved to {output_path}")

# Usage:
plot_m6a_heatmap_new_glori1(dorado_mods_dict, new_glori1, combined_glori_2, mode='HEK293')


=== m6A Heatmaps (NEW GLORI-1) - Mode: HEK293 ===
GLORI-1 (New): 170240 sites
GLORI-2: 65687 sites
HEK293_m6a: 67517 sites

HEK293 vs GLORI-1 (New): 47518 overlapping sites
  Mean DRS: 54.28%
  Mean Orthogonal: 55.12%
  Correlation: 0.914

HEK293 vs GLORI-2: 31580 overlapping sites
  Mean DRS: 56.43%
  Mean Orthogonal: 52.97%
  Correlation: 0.928

✓ Saved to /Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/m6a_heatmap_new.pdf


  plt.show()


### Inner GLORI Hist

In [30]:
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.colors import LogNorm
import polars as pl
import numpy as np

# Style settings
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['font.family'] = 'sans-serif'
mpl.rcParams['font.sans-serif'] = ['Helvetica']
mpl.rcParams['figure.dpi'] = 600
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.right'] = False

def process_orthogonal_for_comparison(df, chr_col, pos_col, value_col, label="Orthogonal", scale_by_100=False):
    """Extract sites with their modification values for comparison"""
    if isinstance(df, pl.DataFrame):
        # Check if chromosomes need 'chr' prefix
        sample_chr = df[chr_col][0]
        if sample_chr is not None and not str(sample_chr).startswith('chr'):
            # Add 'chr' prefix to chromosomes
            df_processed = df.with_columns([
                ('chr' + pl.col(chr_col).cast(pl.Utf8) + '_' + pl.col(pos_col).cast(pl.Int64).cast(pl.Utf8)).alias('site_id')
            ])
        else:
            # Use chromosomes as-is
            df_processed = df.with_columns([
                (pl.col(chr_col).cast(pl.Utf8) + '_' + pl.col(pos_col).cast(pl.Int64).cast(pl.Utf8)).alias('site_id')
            ])
    else:
        # pandas DataFrame
        sample_chr = str(df[chr_col].iloc[0])
        pos_int = df[pos_col].astype(int).astype(str)
        if sample_chr.startswith('chr'):
            site_ids = df[chr_col].astype(str) + '_' + pos_int
        else:
            site_ids = 'chr' + df[chr_col].astype(str) + '_' + pos_int
        
        df_processed = pl.DataFrame({
            'site_id': site_ids.tolist(),
            value_col: df[value_col].tolist()
        })
    
    result = df_processed.select(['site_id', value_col])
    
    # Scale by 100 if needed (for ratio to percentage conversion)
    if scale_by_100:
        result = result.with_columns([
            (pl.col(value_col) * 100).alias(value_col)
        ])
    
    print(f"{label}: {len(result)} sites")
    return result

def plot_glori_comparison(combined_glori_1, combined_glori_2_df, 
                         output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/glori1_vs_glori2_heatmap.pdf',
                         bandwidth=5, debug=False):
    """
    Create a 2D histogram comparing GLORI-1 and GLORI-2 modification levels
    
    Parameters:
    -----------
    combined_glori_1 : DataFrame
        GLORI-1 data with columns 'Chr', 'Sites', 'Ratio_mean'
    combined_glori_2_df : DataFrame
        GLORI-2 data with columns 'Chr', 'Site', 'm6A_level_mean'
    output_path : str
        Path to save the PDF output
    bandwidth : int
        Bin width for the 2D histogram (default 5%)
    debug : bool
        If True, print additional debug information
    """
    
    print("\n" + "="*60)
    print("=== GLORI-1 vs GLORI-2 Comparison ===")
    print("="*60)
    
    # Process both datasets
    glori1_df = process_orthogonal_for_comparison(combined_glori_1, 'Chr', 'Sites', 
                                                  'Ratio_mean', 'GLORI-1', scale_by_100=True)
    glori2_df = process_orthogonal_for_comparison(combined_glori_2_df, 'Chr', 'Site', 
                                                  'm6A_level_mean', 'GLORI-2')
    
    # Perform inner join to get only overlapping sites
    merged = glori1_df.join(glori2_df, on='site_id', how='inner')
    
    print(f"\nOverlapping sites: {len(merged)}")
    
    if len(merged) == 0:
        print("No overlapping sites found!")
        return
    
    # Convert to numpy arrays
    glori1_values = merged['Ratio_mean'].to_numpy()
    glori2_values = merged['m6A_level_mean'].to_numpy()
    
    if debug:
        print(f"\nGLORI-1 range: {glori1_values.min():.2f}% - {glori1_values.max():.2f}%")
        print(f"GLORI-2 range: {glori2_values.min():.2f}% - {glori2_values.max():.2f}%")
        print(f"GLORI-1 mean: {glori1_values.mean():.2f}% ± {glori1_values.std():.2f}%")
        print(f"GLORI-2 mean: {glori2_values.mean():.2f}% ± {glori2_values.std():.2f}%")
    
    # Create figure
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6), facecolor='white')
    
    # ===== First plot: 2D Histogram =====
    bins = np.arange(0, 100 + bandwidth, bandwidth)
    hist, xedges, yedges = np.histogram2d(glori1_values, glori2_values, bins=bins)
    
    im = ax1.imshow(hist.T,
                    norm=LogNorm(vmin=1, vmax=hist.max()),
                    origin='lower',
                    extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
                    aspect='auto',
                    cmap='magma',
                    interpolation='nearest')
    
    # Add colorbar
    cbar1 = plt.colorbar(im, ax=ax1)
    cbar1.set_label('Site count (log scale)', fontsize=10)
    
    # Calculate correlation
    correlation = np.corrcoef(glori1_values, glori2_values)[0, 1]
    
    # Add diagonal reference line
    ax1.plot([0, 100], [0, 100], 'k--', linewidth=1.5, alpha=0.5)
    
    # Labels and title
    ax1.set_xlabel('GLORI-1 Modification %', fontsize=11)
    ax1.set_ylabel('GLORI-2 Modification %', fontsize=11)
    ax1.set_title('GLORI-1 vs GLORI-2: 2D Histogram', fontweight='bold', fontsize=12)
    ax1.set_xlim(0, 100)
    ax1.set_ylim(0, 100)
    
    # Add statistics box
    stats_text = f'n = {len(merged):,}\nr = {correlation:.3f}\n'
    stats_text += f'GLORI-1: {glori1_values.mean():.1f}%±{glori1_values.std():.1f}%\n'
    stats_text += f'GLORI-2: {glori2_values.mean():.1f}%±{glori2_values.std():.1f}%'
    ax1.text(0.05, 0.95, stats_text, transform=ax1.transAxes,
            fontsize=9, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    # ===== Second plot: Hexbin for better density visualization =====
    hexbin = ax2.hexbin(glori1_values, glori2_values, gridsize=50, 
                        cmap='magma', mincnt=1, xscale='linear', yscale='linear')
    
    # Add colorbar
    cbar2 = plt.colorbar(hexbin, ax=ax2)
    cbar2.set_label('Site count', fontsize=10)
    
    # Add diagonal reference line
    ax2.plot([0, 100], [0, 100], 'k--', linewidth=1.5, alpha=0.5)
    
    # Labels and title
    ax2.set_xlabel('GLORI-1 Modification %', fontsize=11)
    ax2.set_ylabel('GLORI-2 Modification %', fontsize=11)
    ax2.set_title('GLORI-1 vs GLORI-2: Hexbin Density', fontweight='bold', fontsize=12)
    ax2.set_xlim(0, 100)
    ax2.set_ylim(0, 100)
    
    # Add difference statistics
    differences = glori1_values - glori2_values
    mean_diff = differences.mean()
    std_diff = differences.std()
    
    diff_text = f'Mean diff: {mean_diff:.1f}%±{std_diff:.1f}%\n'
    diff_text += f'Median diff: {np.median(differences):.1f}%\n'
    diff_text += f'MAE: {np.mean(np.abs(differences)):.1f}%'
    ax2.text(0.05, 0.95, diff_text, transform=ax2.transAxes,
            fontsize=9, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    # Overall title
    fig.suptitle('GLORI-1 vs GLORI-2 m6A Modification Comparison', fontsize=14, fontweight='bold')
    
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=False, bbox_inches='tight')
    plt.show()
    
    print(f"\n=== Summary Statistics ===")
    print(f"Overlapping sites: {len(merged):,}")
    print(f"Correlation: {correlation:.3f}")
    print(f"Mean difference (GLORI-1 - GLORI-2): {mean_diff:.2f}% ± {std_diff:.2f}%")
    print(f"Median difference: {np.median(differences):.2f}%")
    print(f"Mean Absolute Error: {np.mean(np.abs(differences)):.2f}%")
    print(f"\nSaved to {output_path}")

# Usage:
plot_glori_comparison(combined_glori_1, combined_glori_2, debug=True)


=== GLORI-1 vs GLORI-2 Comparison ===
GLORI-1: 60462 sites
GLORI-2: 65687 sites

Overlapping sites: 38039

GLORI-1 range: 10.16% - 100.00%
GLORI-2 range: 10.93% - 99.54%
GLORI-1 mean: 55.78% ± 25.42%
GLORI-2 mean: 50.15% ± 24.72%


1 extra bytes in post.stringData array
'created' timestamp seems very low; regarding as unix timestamp



=== Summary Statistics ===
Overlapping sites: 38,039
Correlation: 0.939
Mean difference (GLORI-1 - GLORI-2): 5.63% ± 8.77%
Median difference: 4.87%
Mean Absolute Error: 7.95%

Saved to /Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/glori1_vs_glori2_heatmap.pdf


  plt.show()


In [14]:
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib_venn import venn2, venn3
import polars as pl

# Style settings
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['font.family'] = 'sans-serif'
mpl.rcParams['font.sans-serif'] = ['Helvetica']
mpl.rcParams['figure.dpi'] = 600
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.right'] = False

def get_drs_sites(mod_dict, cell_line, mod):
    key = f"{cell_line}_{mod}"
    if key in mod_dict:
        df = list(mod_dict[key].values())[0]
        if 'Adjusted_Mod_Proportion' in df.columns:
            df_filtered = df.filter((pl.col('Adjusted_Mod_Proportion') >= 20) & (pl.col('Score') >= 20))
        else:
            df_filtered = df.filter(pl.col('Score') >= 20)
        sites = set((df_filtered['Chromosome'].cast(pl.Utf8) + '_' + df_filtered['End'].cast(pl.Utf8)).to_list())
        print(f"{key}: {len(sites)} sites")
        return sites
    return set()

def process_orthogonal_smart(df, chr_col, pos_col, label="Orthogonal"):
    if isinstance(df, pl.DataFrame):
        sites = set((df[chr_col].cast(pl.Utf8) + '_' + df[pos_col].cast(pl.Int64).cast(pl.Utf8)).to_list())
    else:
        sample_chr = str(df[chr_col].iloc[0])
        pos_int = df[pos_col].astype(int).astype(str)
        if sample_chr.startswith('chr'):
            sites = set(df[chr_col].astype(str) + '_' + pos_int)
        else:
            sites = set('chr' + df[chr_col].astype(str) + '_' + pos_int)
    print(f"{label}: {len(sites)} sites")
    return sites

def plot_m6a_complete(dorado_mods_dict, combined_glori_1, combined_glori_2_df, mode='both',
                      output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/m6a_venns.pdf'):
    """
    Complete m6A Venn diagram suite
    
    Parameters:
    -----------
    mode : str
        'GM12878' - GM12878 only
        'HEK293' - HEK293 only  
        'both' - Both cell lines with all comparisons
    """
    print("\n" + "="*60)
    print(f"=== Processing m6A - Mode: {mode} ===")
    print("="*60)
    
    glori1_sites = process_orthogonal_smart(combined_glori_1, 'Chr', 'Sites', 'GLORI-1')
    glori2_sites = process_orthogonal_smart(combined_glori_2_df, 'Chr', 'Site', 'GLORI-2')
    glori_combined = glori1_sites | glori2_sites
    
    print(f"GLORI combined: {len(glori_combined)} sites")
    
    if mode == 'both':
        # Both cell lines - 3x2 grid
        gm12878_sites = get_drs_sites(dorado_mods_dict, 'GM12878', 'm6a')
        hek293_sites = get_drs_sites(dorado_mods_dict, 'HEK293', 'm6a')
        
        gm_vs_combined = gm12878_sites & glori_combined
        hek_vs_combined = hek293_sites & glori_combined
        gm_vs_hek = gm12878_sites & hek293_sites
        gm_vs_glori1 = gm12878_sites & glori1_sites
        gm_vs_glori2 = gm12878_sites & glori2_sites
        hek_vs_glori1 = hek293_sites & glori1_sites
        hek_vs_glori2 = hek293_sites & glori2_sites
        
        print(f"\nOverlaps:")
        print(f"  GM12878 vs GLORI combined: {len(gm_vs_combined)}")
        print(f"  HEK293 vs GLORI combined: {len(hek_vs_combined)}")
        print(f"  GM12878 vs HEK293: {len(gm_vs_hek)}")
        
        fig = plt.figure(figsize=(14, 8))
        
        # 1. Three-way: GM12878, HEK293, Combined GLORI
        ax1 = plt.subplot(2, 3, 1)
        plt.sca(ax1)
        venn = venn3([gm12878_sites, hek293_sites, glori_combined],
                     set_labels=('GM12878', 'HEK293', 'GLORI'))
        for patch, color in zip(venn.patches, ['lightgreen', 'skyblue', 'salmon', 'purple', 'gold', 'orange', 'gray']):
            if patch:
                patch.set_facecolor(color)
                patch.set_alpha(0.7)
        ax1.set_title('m6A: Both Cells vs GLORI', fontweight='bold', fontsize=12)
        
        # 2. GM12878 vs GLORI
        ax2 = plt.subplot(2, 3, 2)
        plt.sca(ax2)
        venn2([gm12878_sites, glori_combined],
              set_labels=('GM12878', 'GLORI'),
              set_colors=('lightgreen', 'salmon'), alpha=0.7)
        ax2.set_title('m6A: GM12878 vs GLORI', fontweight='bold', fontsize=12)
        
        # 3. HEK293 vs GLORI
        ax3 = plt.subplot(2, 3, 3)
        plt.sca(ax3)
        venn2([hek293_sites, glori_combined],
              set_labels=('HEK293', 'GLORI'),
              set_colors=('skyblue', 'salmon'), alpha=0.7)
        ax3.set_title('m6A: HEK293 vs GLORI', fontweight='bold', fontsize=12)
        
        # 4. GM12878 vs HEK293
        ax4 = plt.subplot(2, 3, 4)
        plt.sca(ax4)
        venn2([gm12878_sites, hek293_sites],
              set_labels=('GM12878', 'HEK293'),
              set_colors=('lightgreen', 'skyblue'), alpha=0.7)
        ax4.set_title('m6A: GM12878 vs HEK293', fontweight='bold', fontsize=12)
        
        # 5. GLORI-1 vs GLORI-2
        ax5 = plt.subplot(2, 3, 5)
        plt.sca(ax5)
        venn2([glori1_sites, glori2_sites],
              set_labels=('GLORI-1', 'GLORI-2'),
              set_colors=('salmon', 'gold'), alpha=0.7)
        ax5.set_title('m6A: GLORI-1 vs GLORI-2', fontweight='bold', fontsize=12)
        
        # 6. GM12878 vs GLORI-1 vs GLORI-2
        ax6 = plt.subplot(2, 3, 6)
        plt.sca(ax6)
        venn = venn3([gm12878_sites, glori1_sites, glori2_sites],
                     set_labels=('GM12878', 'GLORI-1', 'GLORI-2'))
        for patch, color in zip(venn.patches, ['lightgreen', 'salmon', 'gold', 'purple', 'orange', 'lightblue', 'gray']):
            if patch:
                patch.set_facecolor(color)
                patch.set_alpha(0.7)
        ax6.set_title('m6A: GM12878 vs GLORI Methods', fontweight='bold', fontsize=12)
        
        fig.suptitle('m6A Sites - Complete', fontsize=16, fontweight='bold')
    
    elif mode == 'GM12878':
        gm12878_sites = get_drs_sites(dorado_mods_dict, 'GM12878', 'm6a')
        
        gm_vs_combined = gm12878_sites & glori_combined
        gm_vs_glori1 = gm12878_sites & glori1_sites
        gm_vs_glori2 = gm12878_sites & glori2_sites
        
        print(f"\nGM12878 overlaps:")
        print(f"  vs GLORI-1: {len(gm_vs_glori1)}")
        print(f"  vs GLORI-2: {len(gm_vs_glori2)}")
        print(f"  vs Combined: {len(gm_vs_combined)}")
        
        fig = plt.figure(figsize=(12, 4))
        
        ax1 = plt.subplot(1, 3, 1)
        plt.sca(ax1)
        venn2([gm12878_sites, glori_combined],
              set_labels=('GM12878', 'GLORI'),
              set_colors=('lightgreen', 'salmon'), alpha=0.7)
        ax1.set_title('m6A: GM12878 vs GLORI', fontweight='bold', fontsize=14)
        
        ax2 = plt.subplot(1, 3, 2)
        plt.sca(ax2)
        venn = venn3([gm12878_sites, glori1_sites, glori2_sites],
                     set_labels=('GM12878', 'GLORI-1', 'GLORI-2'))
        for patch, color in zip(venn.patches, ['lightgreen', 'salmon', 'gold', 'purple', 'orange', 'lightblue', 'gray']):
            if patch:
                patch.set_facecolor(color)
                patch.set_alpha(0.7)
        ax2.set_title('m6A: Three-way', fontweight='bold', fontsize=14)
        
        ax3 = plt.subplot(1, 3, 3)
        plt.sca(ax3)
        venn2([glori1_sites, glori2_sites],
              set_labels=('GLORI-1', 'GLORI-2'),
              set_colors=('salmon', 'gold'), alpha=0.7)
        ax3.set_title('m6A: GLORI-1 vs GLORI-2', fontweight='bold', fontsize=14)
        
        fig.suptitle('m6A Sites - GM12878', fontsize=16, fontweight='bold')
    
    elif mode == 'HEK293':
        hek293_sites = get_drs_sites(dorado_mods_dict, 'HEK293', 'm6a')
        
        hek_vs_combined = hek293_sites & glori_combined
        hek_vs_glori1 = hek293_sites & glori1_sites
        hek_vs_glori2 = hek293_sites & glori2_sites
        
        print(f"\nHEK293 overlaps:")
        print(f"  vs GLORI-1: {len(hek_vs_glori1)}")
        print(f"  vs GLORI-2: {len(hek_vs_glori2)}")
        print(f"  vs Combined: {len(hek_vs_combined)}")
        
        fig = plt.figure(figsize=(12, 4))
        
        ax1 = plt.subplot(1, 3, 1)
        plt.sca(ax1)
        venn2([hek293_sites, glori_combined],
              set_labels=('HEK293', 'GLORI'),
              set_colors=('skyblue', 'salmon'), alpha=0.7)
        ax1.set_title('m6A: HEK293 vs GLORI', fontweight='bold', fontsize=14)
        
        ax2 = plt.subplot(1, 3, 2)
        plt.sca(ax2)
        venn = venn3([hek293_sites, glori1_sites, glori2_sites],
                     set_labels=('HEK293', 'GLORI-1', 'GLORI-2'))
        for patch, color in zip(venn.patches, ['skyblue', 'salmon', 'gold', 'purple', 'orange', 'lightblue', 'gray']):
            if patch:
                patch.set_facecolor(color)
                patch.set_alpha(0.7)
        ax2.set_title('m6A: Three-way', fontweight='bold', fontsize=14)
        
        ax3 = plt.subplot(1, 3, 3)
        plt.sca(ax3)
        venn2([glori1_sites, glori2_sites],
              set_labels=('GLORI-1', 'GLORI-2'),
              set_colors=('salmon', 'gold'), alpha=0.7)
        ax3.set_title('m6A: GLORI-1 vs GLORI-2', fontweight='bold', fontsize=14)
        
        fig.suptitle('m6A Sites - HEK293', fontsize=16, fontweight='bold')
    
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"Saved to {output_path}\n")

# Usage:
# plot_m6a_complete(dorado_mods_dict, combined_glori_1, combined_glori_2, mode='GM12878')
plot_m6a_complete(dorado_mods_dict, combined_glori_1, combined_glori_2, mode='HEK293', output_path = '/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/HEK293_m6A_glori_outer_venns.pdf' )
# plot_m6a_complete(dorado_mods_dict, combined_glori_1, combined_glori_2, mode='both')

1 extra bytes in post.stringData array
'created' timestamp seems very low; regarding as unix timestamp



=== Processing m6A - Mode: HEK293 ===
GLORI-1: 60462 sites
GLORI-2: 65687 sites
GLORI combined: 88110 sites
HEK293_m6a: 67517 sites

HEK293 overlaps:
  vs GLORI-1: 23426
  vs GLORI-2: 31580
  vs Combined: 36357
Saved to /Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/HEK293_m6A_glori_outer_venns.pdf



  plt.show()


## m5C Plotting

In [41]:
def plot_m5c_venns_colored(dorado_mods_dict, m5c_orthogonal_df, mode='HEK293',
                          output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/m5c_venns_colored.pdf'):
    """
    Create m5C Venn diagrams with fixed 3-way color scheme
    """
    print("\n" + "="*60)
    print(f"=== m5C Venn Diagrams - Mode: {mode} ===")
    print("="*60)
    
    colors = MODIFICATION_COLORS['m5c']
    
    orthogonal_sites = process_orthogonal_smart_colored(m5c_orthogonal_df, 'chromosome', 'position', 'm5C Orthogonal')
    
    if mode == 'HEK293':
        hek293_sites = get_drs_sites_colored(dorado_mods_dict, 'HEK293', 'm5c')
        overlap = hek293_sites & orthogonal_sites
        
        print(f"\nHEK293 vs Orthogonal: {len(overlap)} sites ({100*len(overlap)/len(hek293_sites):.1f}%)")
        
        fig, ax = plt.subplots(figsize=(4, 4))
        venn2([hek293_sites, orthogonal_sites],
              set_labels=('HEK293 DRS', 'Orthogonal'),
              set_colors=(colors['DRS_HEK293'], colors['Orthogonal']), 
              alpha=ALPHA, ax=ax)
        ax.set_title('m5C Sites: HEK293 vs Orthogonal', fontweight='bold', fontsize=14)
        fig.suptitle('m5C Sites - HEK293', fontsize=16, fontweight='bold')
        
    elif mode == 'GM12878':
        gm12878_sites = get_drs_sites_colored(dorado_mods_dict, 'GM12878', 'm5c')
        overlap = gm12878_sites & orthogonal_sites
        
        print(f"\nGM12878 vs Orthogonal: {len(overlap)} sites ({100*len(overlap)/len(gm12878_sites):.1f}%)")
        
        fig, ax = plt.subplots(figsize=(4, 4))
        venn2([gm12878_sites, orthogonal_sites],
              set_labels=('GM12878 DRS', 'Orthogonal'),
              set_colors=(colors['DRS_GM12878'], colors['Orthogonal']), 
              alpha=ALPHA, ax=ax)
        ax.set_title('m5C Sites: GM12878 vs Orthogonal', fontweight='bold', fontsize=14)
        fig.suptitle('m5C Sites - GM12878', fontsize=16, fontweight='bold')
        
    elif mode == 'both':
        hek293_sites = get_drs_sites_colored(dorado_mods_dict, 'HEK293', 'm5c')
        gm12878_sites = get_drs_sites_colored(dorado_mods_dict, 'GM12878', 'm5c')
        
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 8))
        
        # Three-way - FIXED COLORS
        plt.sca(ax1)
        venn = venn3([hek293_sites, gm12878_sites, orthogonal_sites],
                     set_labels=('HEK293', 'GM12878', 'Orthogonal'))
        
        patch_colors = {
            '100': colors['DRS_HEK293'],      # HEK293 only
            '010': colors['DRS_GM12878'],     # GM12878 only
            '001': colors['Orthogonal'],      # Orthogonal only
            '110': colors['HEK_GM'],          # HEK293 ∩ GM12878
            '101': colors['HEK_Orth'],        # HEK293 ∩ Orthogonal
            '011': colors['GM_Orth'],         # GM12878 ∩ Orthogonal
            '111': colors['ALL_THREE']         # All three
        }
        
        for region_id, color in patch_colors.items():
            patch = venn.get_patch_by_id(region_id)
            if patch:
                patch.set_facecolor(color)
                patch.set_alpha(ALPHA)
        
        ax1.set_title('m5C: Three-way', fontweight='bold', fontsize=14)
        
        # Cell lines
        plt.sca(ax2)
        venn2([hek293_sites, gm12878_sites],
              set_labels=('HEK293', 'GM12878'),
              set_colors=(colors['DRS_HEK293'], colors['DRS_GM12878']), alpha=ALPHA)
        ax2.set_title('m5C: Cell Lines', fontweight='bold', fontsize=14)
        
        # HEK293 vs Orthogonal
        plt.sca(ax3)
        venn2([hek293_sites, orthogonal_sites],
              set_labels=('HEK293', 'Orthogonal'),
              set_colors=(colors['DRS_HEK293'], colors['Orthogonal']), alpha=ALPHA)
        ax3.set_title('m5C: HEK293 vs Orth', fontweight='bold', fontsize=14)
        
        # GM12878 vs Orthogonal
        plt.sca(ax4)
        venn2([gm12878_sites, orthogonal_sites],
              set_labels=('GM12878', 'Orthogonal'),
              set_colors=(colors['DRS_GM12878'], colors['Orthogonal']), alpha=ALPHA)
        ax4.set_title('m5C: GM12878 vs Orth', fontweight='bold', fontsize=14)
        
        fig.suptitle('m5C Sites - Complete', fontsize=16, fontweight='bold')
    
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"\n✓ Saved to {output_path}")

# Usage:
plot_m5c_venns_colored(dorado_mods_dict, m5c_orthogonal_df, mode='HEK293')


=== m5C Venn Diagrams - Mode: HEK293 ===
m5C Orthogonal: 2191 sites
HEK293_m5c: 18159 sites

HEK293 vs Orthogonal: 59 sites (0.3%)

✓ Saved to /Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/m5c_venns_colored.pdf


  plt.show()


## Inosine Plotting

In [43]:
def plot_inosine_venns_colored(dorado_mods_dict, combined_ino, mode='HEK293',
                               output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/inosine_venns_colored.pdf'):
    """
    Create Inosine Venn diagrams with fixed 3-way color scheme
    """
    print("\n" + "="*60)
    print(f"=== Inosine Venn Diagrams - Mode: {mode} ===")
    print("="*60)
    
    colors = MODIFICATION_COLORS['inosine']
    
    orthogonal_sites = process_orthogonal_smart_colored(combined_ino, 'Chromosome', 'position', 'Inosine Orthogonal')
    
    if mode == 'HEK293':
        hek293_sites = get_drs_sites_colored(dorado_mods_dict, 'HEK293', 'inosine')
        overlap = hek293_sites & orthogonal_sites
        
        print(f"\nHEK293 vs Orthogonal: {len(overlap)} sites ({100*len(overlap)/len(hek293_sites):.1f}%)")
        
        fig, ax = plt.subplots(figsize=(4, 4))
        venn2([hek293_sites, orthogonal_sites],
              set_labels=('HEK293 DRS', 'Orthogonal'),
              set_colors=(colors['DRS_HEK293'], colors['Orthogonal']), 
              alpha=ALPHA, ax=ax)
        ax.set_title('Inosine (A-to-I) Sites:\nHEK293 vs Orthogonal', fontweight='bold', fontsize=14)
        fig.suptitle('Inosine Sites - HEK293', fontsize=16, fontweight='bold')
        
    elif mode == 'GM12878':
        gm12878_sites = get_drs_sites_colored(dorado_mods_dict, 'GM12878', 'inosine')
        overlap = gm12878_sites & orthogonal_sites
        
        print(f"\nGM12878 vs Orthogonal: {len(overlap)} sites ({100*len(overlap)/len(gm12878_sites):.1f}%)")
        
        fig, ax = plt.subplots(figsize=(4, 4))
        venn2([gm12878_sites, orthogonal_sites],
              set_labels=('GM12878 DRS', 'Orthogonal'),
              set_colors=(colors['DRS_GM12878'], colors['Orthogonal']), 
              alpha=ALPHA, ax=ax)
        ax.set_title('Inosine (A-to-I) Sites:\nGM12878 vs Orthogonal', fontweight='bold', fontsize=14)
        fig.suptitle('Inosine Sites - GM12878', fontsize=16, fontweight='bold')
        
    elif mode == 'both':
        hek293_sites = get_drs_sites_colored(dorado_mods_dict, 'HEK293', 'inosine')
        gm12878_sites = get_drs_sites_colored(dorado_mods_dict, 'GM12878', 'inosine')
        
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 8))
        
        # Three-way - FIXED COLORS
        plt.sca(ax1)
        venn = venn3([hek293_sites, gm12878_sites, orthogonal_sites],
                     set_labels=('HEK293', 'GM12878', 'Orthogonal'))
        
        patch_colors = {
            '100': colors['DRS_HEK293'],      # HEK293 only
            '010': colors['DRS_GM12878'],     # GM12878 only
            '001': colors['Orthogonal'],      # Orthogonal only
            '110': colors['HEK_GM'],          # HEK293 ∩ GM12878
            '101': colors['HEK_Orth'],        # HEK293 ∩ Orthogonal
            '011': colors['GM_Orth'],         # GM12878 ∩ Orthogonal
            '111': colors['ALL_THREE']         # All three
        }
        
        for region_id, color in patch_colors.items():
            patch = venn.get_patch_by_id(region_id)
            if patch:
                patch.set_facecolor(color)
                patch.set_alpha(ALPHA)
        
        ax1.set_title('Inosine: Three-way', fontweight='bold', fontsize=14)
        
        # Cell lines
        plt.sca(ax2)
        venn2([hek293_sites, gm12878_sites],
              set_labels=('HEK293', 'GM12878'),
              set_colors=(colors['DRS_HEK293'], colors['DRS_GM12878']), alpha=ALPHA)
        ax2.set_title('Inosine: Cell Lines', fontweight='bold', fontsize=14)
        
        # HEK293 vs Orthogonal
        plt.sca(ax3)
        venn2([hek293_sites, orthogonal_sites],
              set_labels=('HEK293', 'Orthogonal'),
              set_colors=(colors['DRS_HEK293'], colors['Orthogonal']), alpha=ALPHA)
        ax3.set_title('Inosine: HEK293 vs Orth', fontweight='bold', fontsize=14)
        
        # GM12878 vs Orthogonal
        plt.sca(ax4)
        venn2([gm12878_sites, orthogonal_sites],
              set_labels=('GM12878', 'Orthogonal'),
              set_colors=(colors['DRS_GM12878'], colors['Orthogonal']), alpha=ALPHA)
        ax4.set_title('Inosine: GM12878 vs Orth', fontweight='bold', fontsize=14)
        
        fig.suptitle('Inosine (A-to-I) Sites - Complete', fontsize=16, fontweight='bold')
    
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"\n✓ Saved to {output_path}")

# Usage:
plot_inosine_venns_colored(dorado_mods_dict, combined_ino, mode='HEK293')


=== Inosine Venn Diagrams - Mode: HEK293 ===
Inosine Orthogonal: 29745 sites
HEK293_inosine: 6956 sites

HEK293 vs Orthogonal: 1340 sites (19.3%)

✓ Saved to /Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/inosine_venns_colored.pdf


  plt.show()


## Psi Plotting

In [42]:
def plot_psi_venns_colored(dorado_mods_dict, bid_seq_df, praise_filtered, mode='HEK293',
                          output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/psi_venns_colored.pdf'):
    """
    Create Pseudouridine Venn diagrams with fixed 3-way color scheme
    """
    print("\n" + "="*60)
    print(f"=== Pseudouridine (Ψ) Venn Diagrams - Mode: {mode} ===")
    print("="*60)
    
    colors = MODIFICATION_COLORS['psi']
    
    bid_sites = process_orthogonal_smart_colored(bid_seq_df, 'chr', 'pos', 'BID-seq')
    praise_sites = process_orthogonal_smart_colored(praise_filtered, 'chromosome', 'genomic_position', 'PRAISE')
    praise_bid_combined = praise_sites & bid_sites
    
    print(f"PRAISE ∩ BID-seq: {len(praise_bid_combined)} sites")
    
    if mode == 'HEK293':
        hek293_sites = get_drs_sites_colored(dorado_mods_dict, 'HEK293', 'psi')
        
        hek_vs_bid = hek293_sites & bid_sites
        hek_vs_praise = hek293_sites & praise_sites
        hek_vs_combined = hek293_sites & praise_bid_combined
        
        print(f"\nHEK293 overlaps:")
        print(f"  vs BID-seq: {len(hek_vs_bid)}")
        print(f"  vs PRAISE: {len(hek_vs_praise)}")
        print(f"  vs Combined: {len(hek_vs_combined)}")
        
        fig = plt.figure(figsize=(12, 4))
        
        ax1 = plt.subplot(1, 3, 1)
        plt.sca(ax1)
        venn2([hek293_sites, praise_bid_combined],
              set_labels=('HEK293 DRS', 'PRAISE ∩ BID-seq'),
              set_colors=(colors['DRS_HEK293'], colors['Combined']), 
              alpha=ALPHA)
        ax1.set_title('Ψ: HEK293 vs Orthogonal', fontweight='bold', fontsize=14)
        
        ax2 = plt.subplot(1, 3, 2)
        plt.sca(ax2)
        venn = venn3([hek293_sites, praise_sites, bid_sites],
                     set_labels=('HEK293', 'PRAISE', 'BID-seq'))
        
        # FIXED COLOR ASSIGNMENT
        patch_colors = {
            '100': colors['DRS_HEK293'],      # HEK293 only
            '010': colors['PRAISE'],          # PRAISE only
            '001': colors['BID-seq'],         # BID-seq only
            '110': colors['HEK_PRAISE'],      # HEK293 ∩ PRAISE
            '101': colors['HEK_BID'],         # HEK293 ∩ BID-seq
            '011': colors['BID_PRAISE'],      # BID-seq ∩ PRAISE
            '111': colors['ALL_THREE']         # All three
        }
        
        for region_id, color in patch_colors.items():
            patch = venn.get_patch_by_id(region_id)
            if patch:
                patch.set_facecolor(color)
                patch.set_alpha(ALPHA)
        
        ax2.set_title('Ψ: Three-way', fontweight='bold', fontsize=14)
        
        ax3 = plt.subplot(1, 3, 3)
        plt.sca(ax3)
        venn2([praise_sites, bid_sites],
              set_labels=('PRAISE', 'BID-seq'),
              set_colors=(colors['PRAISE'], colors['BID-seq']), 
              alpha=ALPHA)
        ax3.set_title('Ψ: PRAISE vs BID-seq', fontweight='bold', fontsize=14)
        
        fig.suptitle('Pseudouridine (Ψ) Sites - HEK293', fontsize=16, fontweight='bold')
        
    elif mode == 'GM12878':
        gm12878_sites = get_drs_sites_colored(dorado_mods_dict, 'GM12878', 'psi')
        
        gm_vs_bid = gm12878_sites & bid_sites
        gm_vs_praise = gm12878_sites & praise_sites
        gm_vs_combined = gm12878_sites & praise_bid_combined
        
        print(f"\nGM12878 overlaps:")
        print(f"  vs BID-seq: {len(gm_vs_bid)}")
        print(f"  vs PRAISE: {len(gm_vs_praise)}")
        print(f"  vs Combined: {len(gm_vs_combined)}")
        
        fig = plt.figure(figsize=(12, 4))
        
        ax1 = plt.subplot(1, 3, 1)
        plt.sca(ax1)
        venn2([gm12878_sites, praise_bid_combined],
              set_labels=('GM12878 DRS', 'PRAISE ∩ BID-seq'),
              set_colors=(colors['DRS_GM12878'], colors['Combined']), 
              alpha=ALPHA)
        ax1.set_title('Ψ: GM12878 vs Orthogonal', fontweight='bold', fontsize=14)
        
        ax2 = plt.subplot(1, 3, 2)
        plt.sca(ax2)
        venn = venn3([gm12878_sites, praise_sites, bid_sites],
                     set_labels=('GM12878', 'PRAISE', 'BID-seq'))
        
        # FIXED COLOR ASSIGNMENT
        patch_colors = {
            '100': colors['DRS_GM12878'],     # GM12878 only
            '010': colors['PRAISE'],          # PRAISE only
            '001': colors['BID-seq'],         # BID-seq only
            '110': colors['HEK_PRAISE'],      # GM12878 ∩ PRAISE (reusing HEK color)
            '101': colors['HEK_BID'],         # GM12878 ∩ BID-seq (reusing HEK color)
            '011': colors['BID_PRAISE'],      # BID-seq ∩ PRAISE
            '111': colors['ALL_THREE']         # All three
        }
        
        for region_id, color in patch_colors.items():
            patch = venn.get_patch_by_id(region_id)
            if patch:
                patch.set_facecolor(color)
                patch.set_alpha(ALPHA)
        
        ax2.set_title('Ψ: Three-way', fontweight='bold', fontsize=14)
        
        ax3 = plt.subplot(1, 3, 3)
        plt.sca(ax3)
        venn2([praise_sites, bid_sites],
              set_labels=('PRAISE', 'BID-seq'),
              set_colors=(colors['PRAISE'], colors['BID-seq']), 
              alpha=ALPHA)
        ax3.set_title('Ψ: PRAISE vs BID-seq', fontweight='bold', fontsize=14)
        
        fig.suptitle('Pseudouridine (Ψ) Sites - HEK293', fontsize=16, fontweight='bold')
    
    elif mode == 'both':
        hek293_sites = get_drs_sites_colored(dorado_mods_dict, 'HEK293', 'psi')
        gm12878_sites = get_drs_sites_colored(dorado_mods_dict, 'GM12878', 'psi')
        
        fig = plt.figure(figsize=(14, 8))
        
        # 1. Three-way: Both cells vs Combined
        ax1 = plt.subplot(2, 3, 1)
        plt.sca(ax1)
        venn = venn3([hek293_sites, gm12878_sites, praise_bid_combined],
                     set_labels=('HEK293', 'GM12878', 'PRAISE ∩ BID'))
        
        patch_colors = {
            '100': colors['DRS_HEK293'],
            '010': colors['DRS_GM12878'],
            '001': colors['Combined'],
            '110': '#bcbd22',
            '101': '#17becf',
            '011': '#e377c2',
            '111': colors['ALL_THREE']
        }
        
        for region_id, color in patch_colors.items():
            patch = venn.get_patch_by_id(region_id)
            if patch:
                patch.set_facecolor(color)
                patch.set_alpha(ALPHA)
        
        ax1.set_title('Ψ: DRS vs Orthogonal', fontweight='bold', fontsize=12)
        
        # 2. Cell lines
        ax2 = plt.subplot(2, 3, 2)
        plt.sca(ax2)
        venn2([hek293_sites, gm12878_sites],
              set_labels=('HEK293', 'GM12878'),
              set_colors=(colors['DRS_HEK293'], colors['DRS_GM12878']), alpha=ALPHA)
        ax2.set_title('Ψ: Cell Lines', fontweight='bold', fontsize=12)
        
        # 3. Orthogonal methods
        ax3 = plt.subplot(2, 3, 3)
        plt.sca(ax3)
        venn2([praise_sites, bid_sites],
              set_labels=('PRAISE', 'BID-seq'),
              set_colors=(colors['PRAISE'], colors['BID-seq']), alpha=ALPHA)
        ax3.set_title('Ψ: PRAISE vs BID-seq', fontweight='bold', fontsize=12)
        
        # 4. HEK293 All Methods - FIXED COLORS
        ax4 = plt.subplot(2, 3, 4)
        plt.sca(ax4)
        venn = venn3([hek293_sites, praise_sites, bid_sites],
                     set_labels=('HEK293', 'PRAISE', 'BID-seq'))
        
        patch_colors = {
            '100': colors['DRS_HEK293'],
            '010': colors['PRAISE'],
            '001': colors['BID-seq'],
            '110': colors['HEK_PRAISE'],
            '101': colors['HEK_BID'],
            '011': colors['BID_PRAISE'],
            '111': colors['ALL_THREE']
        }
        
        for region_id, color in patch_colors.items():
            patch = venn.get_patch_by_id(region_id)
            if patch:
                patch.set_facecolor(color)
                patch.set_alpha(ALPHA)
        
        ax4.set_title('Ψ: HEK293 All', fontweight='bold', fontsize=12)
        
        # 5. HEK293 vs Combined
        ax5 = plt.subplot(2, 3, 5)
        plt.sca(ax5)
        venn2([hek293_sites, praise_bid_combined],
              set_labels=('HEK293', 'PRAISE ∩ BID'),
              set_colors=(colors['DRS_HEK293'], colors['Combined']), alpha=ALPHA)
        ax5.set_title('Ψ: HEK293 vs Orth', fontweight='bold', fontsize=12)
        
        # 6. GM12878 vs Combined
        ax6 = plt.subplot(2, 3, 6)
        plt.sca(ax6)
        venn2([gm12878_sites, praise_bid_combined],
              set_labels=('GM12878', 'PRAISE ∩ BID'),
              set_colors=(colors['DRS_GM12878'], colors['Combined']), alpha=ALPHA)
        ax6.set_title('Ψ: GM12878 vs Orth', fontweight='bold', fontsize=12)
        
        fig.suptitle('Pseudouridine (Ψ) Sites - Complete', fontsize=16, fontweight='bold')
    
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"\n✓ Saved to {output_path}")

# Usage:
plot_psi_venns_colored(dorado_mods_dict, bid_seq_df, praise_filtered, mode='HEK293')


=== Pseudouridine (Ψ) Venn Diagrams - Mode: HEK293 ===
BID-seq: 543 sites
PRAISE: 1801 sites
PRAISE ∩ BID-seq: 1 sites
HEK293_psi: 3103 sites

HEK293 overlaps:
  vs BID-seq: 45
  vs PRAISE: 9
  vs Combined: 1

✓ Saved to /Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/psi_venns_colored.pdf


  plt.show()


In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.colors import LogNorm
from matplotlib_venn import venn2, venn3
import polars as pl
import numpy as np

# ==========================================
# COLOR SCHEME DEFINITIONS
# ==========================================

# Using tab20 colormap for 20 distinct colors
tab20 = plt.cm.tab20.colors

# Define colors for each modification type and technology
MODIFICATION_COLORS = {
    # m6A technologies (use blues/greens)
    'm6a': {
        'GM12878': tab20[0],      # Dark blue
        'HEK293': tab20[1],       # Light blue  
        'GLORI-1': tab20[4],      # Dark green
        'GLORI-2': tab20[5],      # Light green
        'GLORI_combined': tab20[8], # Dark purple
    },
    
    # m5C technologies (use oranges/reds)
    'm5c': {
        'GM12878': tab20[2],      # Dark orange
        'HEK293': tab20[3],       # Light orange
        'Orthogonal': tab20[6],   # Dark red
    },
    
    # Pseudouridine technologies (use purples/pinks)
    'psi': {
        'GM12878': tab20[9],      # Light purple
        'HEK293': tab20[10],      # Dark brown
        'PRAISE': tab20[12],      # Dark pink
        'BID-seq': tab20[13],     # Light pink
        'PRAISE_BID_combined': tab20[14], # Dark gray
    },
    
    # Inosine technologies (use browns/grays)
    'inosine': {
        'GM12878': tab20[16],     # Dark cyan
        'HEK293': tab20[17],      # Light cyan
        'Orthogonal': tab20[18],  # Dark yellow
    }
}

# Style settings
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['font.family'] = 'sans-serif'
mpl.rcParams['font.sans-serif'] = ['Helvetica']
mpl.rcParams['figure.dpi'] = 600
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.right'] = False

# ==========================================
# HELPER FUNCTIONS
# ==========================================

def get_drs_sites(mod_dict, cell_line, mod):
    """Extract DRS sites for Venn diagrams"""
    key = f"{cell_line}_{mod}"
    if key in mod_dict:
        df = list(mod_dict[key].values())[0]
        if 'Adjusted_Mod_Proportion' in df.columns:
            df_filtered = df.filter((pl.col('Adjusted_Mod_Proportion') >= 20) & (pl.col('Score') >= 20))
        else:
            df_filtered = df.filter(pl.col('Score') >= 20)
        
        # Check if chromosomes need 'chr' prefix
        sample_chr = df_filtered['Chromosome'][0]
        if sample_chr is not None and not str(sample_chr).startswith('chr'):
            sites = set(('chr' + df_filtered['Chromosome'].cast(pl.Utf8) + '_' + df_filtered['End'].cast(pl.Utf8)).to_list())
        else:
            sites = set((df_filtered['Chromosome'].cast(pl.Utf8) + '_' + df_filtered['End'].cast(pl.Utf8)).to_list())
        
        print(f"{key}: {len(sites)} sites")
        return sites
    return set()

def get_drs_values(mod_dict, cell_line, mod):
    """Extract DRS sites with their values for heatmaps"""
    key = f"{cell_line}_{mod}"
    if key in mod_dict:
        df = list(mod_dict[key].values())[0]
        if 'Adjusted_Mod_Proportion' in df.columns:
            df_filtered = df.filter((pl.col('Adjusted_Mod_Proportion') >= 20) & (pl.col('Score') >= 20))
        else:
            df_filtered = df.filter(pl.col('Score') >= 20)
        
        # Check if chromosomes need 'chr' prefix
        sample_chr = df_filtered['Chromosome'][0]
        if sample_chr is not None and not str(sample_chr).startswith('chr'):
            df_filtered = df_filtered.with_columns([
                ('chr' + pl.col('Chromosome').cast(pl.Utf8) + '_' + pl.col('End').cast(pl.Utf8)).alias('site_id')
            ])
        else:
            df_filtered = df_filtered.with_columns([
                (pl.col('Chromosome').cast(pl.Utf8) + '_' + pl.col('End').cast(pl.Utf8)).alias('site_id')
            ])
        
        print(f"{key}: {len(df_filtered)} sites")
        return df_filtered.select(['site_id', 'Adjusted_Mod_Proportion'])
    return None

def process_orthogonal_smart(df, chr_col, pos_col, label="Orthogonal"):
    """Process orthogonal data for Venn diagrams"""
    if isinstance(df, pl.DataFrame):
        sample_chr = df[chr_col][0]
        if sample_chr is not None and not str(sample_chr).startswith('chr'):
            sites = set(('chr' + df[chr_col].cast(pl.Utf8) + '_' + df[pos_col].cast(pl.Int64).cast(pl.Utf8)).to_list())
        else:
            sites = set((df[chr_col].cast(pl.Utf8) + '_' + df[pos_col].cast(pl.Int64).cast(pl.Utf8)).to_list())
    else:
        sample_chr = str(df[chr_col].iloc[0])
        pos_int = df[pos_col].astype(int).astype(str)
        if sample_chr.startswith('chr'):
            sites = set(df[chr_col].astype(str) + '_' + pos_int)
        else:
            sites = set('chr' + df[chr_col].astype(str) + '_' + pos_int)
    print(f"{label}: {len(sites)} sites")
    return sites

def process_orthogonal_values(df, chr_col, pos_col, value_col, label="Orthogonal", scale_by_100=False):
    """Extract sites with their modification values for heatmaps"""
    if isinstance(df, pl.DataFrame):
        sample_chr = df[chr_col][0]
        if sample_chr is not None and not str(sample_chr).startswith('chr'):
            df_processed = df.with_columns([
                ('chr' + pl.col(chr_col).cast(pl.Utf8) + '_' + pl.col(pos_col).cast(pl.Int64).cast(pl.Utf8)).alias('site_id')
            ])
        else:
            df_processed = df.with_columns([
                (pl.col(chr_col).cast(pl.Utf8) + '_' + pl.col(pos_col).cast(pl.Int64).cast(pl.Utf8)).alias('site_id')
            ])
    else:
        sample_chr = str(df[chr_col].iloc[0])
        pos_int = df[pos_col].astype(int).astype(str)
        if sample_chr.startswith('chr'):
            site_ids = df[chr_col].astype(str) + '_' + pos_int
        else:
            site_ids = 'chr' + df[chr_col].astype(str) + '_' + pos_int
        
        df_processed = pl.DataFrame({
            'site_id': site_ids.tolist(),
            value_col: df[value_col].tolist()
        })
    
    result = df_processed.select(['site_id', value_col])
    
    if scale_by_100:
        result = result.with_columns([
            (pl.col(value_col) * 100).alias(value_col)
        ])
    
    print(f"{label}: {len(result)} sites")
    return result

def plot_single_heatmap(ax, drs_df, ortho_df, drs_col, ortho_col, title, colormap='viridis'):
    """Create a single heatmap comparison"""
    # Merge on common sites
    merged = drs_df.join(ortho_df, on='site_id', how='inner')
    
    if len(merged) == 0:
        ax.text(0.5, 0.5, 'No overlapping sites', ha='center', va='center', fontsize=12)
        ax.set_title(title, fontweight='bold', fontsize=11)
        return
    
    print(f"\n{title}: {len(merged)} overlapping sites")
    
    # Convert to numpy for plotting
    drs_values = merged[drs_col].to_numpy()
    ortho_values = merged[ortho_col].to_numpy()
    
    # Set up histogram bins with fixed bandwidth
    bandwidth_2d = 5
    bins = np.arange(0, 100 + bandwidth_2d, bandwidth_2d)
    
    # Create 2D histogram
    hist, xedges, yedges = np.histogram2d(drs_values, ortho_values, bins=bins)
    
    # Plot heatmap with log normalization
    im = ax.imshow(hist.T,
                   norm=LogNorm(vmin=1, vmax=10**3),
                   origin='lower',
                   extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
                   aspect='auto',
                   cmap=colormap,
                   interpolation='nearest')
    
    # Add colorbar with log scale ticks
    cbar = plt.colorbar(im, ax=ax, ticks=np.logspace(0, 3, 4))
    cbar.set_ticklabels(['$10^0$', '$10^1$', '$10^2$', '$10^3$'])
    cbar.set_label('Site count', fontsize=10)
    
    # Calculate correlation
    correlation = np.corrcoef(drs_values, ortho_values)[0, 1]
    
    # Labels and title
    drs_method = 'DRS' if 'Adjusted' in drs_col else ortho_col
    ortho_method = ortho_col.replace('_mean', '').replace('GLORI_combined_mean', 'GLORI Combined').replace('_', ' ')
    
    ax.set_xlabel(f'{drs_method} Mod Percentage', fontsize=11)
    ax.set_ylabel(f'{ortho_method} Mod Percentage', fontsize=11)
    ax.set_title(title, fontweight='bold', fontsize=12)
    
    # Add diagonal reference line (black dashed)
    min_val = min(xedges[0], yedges[0])
    max_val = max(xedges[-1], yedges[-1])
    ax.plot([min_val, max_val], [min_val, max_val], color='k', linestyle='--', linewidth=1.5)
    
    # Add text with statistics in the upper left corner
    stats_text = f'n = {len(merged):,}\nr = {correlation:.3f}'
    ax.text(0.05, 0.95, stats_text, transform=ax.transAxes, 
            fontsize=10, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    # Print statistics
    print(f"  Mean DRS: {drs_values.mean():.2f}%")
    print(f"  Mean Orthogonal: {ortho_values.mean():.2f}%")
    print(f"  Correlation: {correlation:.3f}")

# ==========================================
# m6A PLOTTING FUNCTIONS
# ==========================================

def plot_m6a_heatmap(dorado_mods_dict, combined_glori_1, combined_glori_2_df, mode='both',
                      output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/m6a_heatplot_colored.pdf', 
                      debug=False, combine_mode='inner'):
    """Create heatmaps comparing m6A modification levels across methods"""
    
    print("\n" + "="*60)
    print(f"=== Creating m6A Heatmaps - Mode: {mode} ===")
    print("="*60)
    
    # Process orthogonal data
    glori1_df = process_orthogonal_values(combined_glori_1, 'Chr', 'Sites', 'Ratio_mean', 'GLORI-1', scale_by_100=True)
    glori2_df = process_orthogonal_values(combined_glori_2_df, 'Chr', 'Site', 'm6A_level_mean', 'GLORI-2')
    
    # Debug: Check sample site IDs to ensure consistent formatting
    if debug:
        print("\nDEBUG - Sample site IDs:")
        print(f"GLORI-1 first 3: {glori1_df['site_id'].head(3).to_list()}")
        print(f"GLORI-2 first 3: {glori2_df['site_id'].head(3).to_list()}")
    
    # Create combined GLORI
    if combine_mode == 'inner':
        glori_combined = glori1_df.join(glori2_df, on='site_id', how='inner')
        glori_combined = glori_combined.with_columns([
            ((pl.col('Ratio_mean') + pl.col('m6A_level_mean')) / 2).alias('GLORI_combined_mean')
        ])
        glori_combined = glori_combined.select(['site_id', 'GLORI_combined_mean'])
        print(f"GLORI combined (INNER): {len(glori_combined)} sites (intersection only)")
    else:  # 'outer'
        glori1_for_combine = glori1_df.select([
            pl.col('site_id'),
            pl.col('Ratio_mean').alias('value'),
            pl.lit('GLORI1').alias('source')
        ])
        glori2_for_combine = glori2_df.select([
            pl.col('site_id'),
            pl.col('m6A_level_mean').alias('value'),
            pl.lit('GLORI2').alias('source')
        ])
        all_glori = pl.concat([glori1_for_combine, glori2_for_combine])
        glori_combined = all_glori.group_by('site_id').agg([
            pl.col('value').mean().alias('GLORI_combined_mean'),
            pl.col('source').count().alias('n_sources')
        ])
        glori_combined = glori_combined.filter(pl.col('site_id').is_not_null())
        print(f"GLORI combined (OUTER): {len(glori_combined)} sites (union)")
    
    # Different colormaps for different comparisons
    colormaps = ['viridis', 'plasma', 'cividis']
    
    if mode == 'HEK293':
        hek293_df = get_drs_values(dorado_mods_dict, 'HEK293', 'm6a')
        
        fig = plt.figure(figsize=(15, 4))
        
        comparisons = [
            (hek293_df, glori1_df, 'Adjusted_Mod_Proportion', 'Ratio_mean', 'HEK293 vs GLORI-1', 1, colormaps[0]),
            (hek293_df, glori2_df, 'Adjusted_Mod_Proportion', 'm6A_level_mean', 'HEK293 vs GLORI-2', 2, colormaps[1]),
            (hek293_df, glori_combined, 'Adjusted_Mod_Proportion', 'GLORI_combined_mean', 'HEK293 vs GLORI Combined', 3, colormaps[2])
        ]
        
        for drs_df, ortho_df, drs_col, ortho_col, title, idx, cmap in comparisons:
            ax = plt.subplot(1, 3, idx)
            plot_single_heatmap(ax, drs_df, ortho_df, drs_col, ortho_col, title, cmap)
        
        fig.suptitle('m6A Modification Levels - HEK293', fontsize=16, fontweight='bold')
    
    # Similar for other modes...
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=False, bbox_inches='tight')
    plt.show()
    print(f"\nSaved to {output_path}")

def plot_m6a_complete(dorado_mods_dict, combined_glori_1, combined_glori_2_df, mode='both',
                      output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/m6a_venns_colored.pdf'):
    """Complete m6A Venn diagram suite with technology-specific colors"""
    
    print("\n" + "="*60)
    print(f"=== Processing m6A - Mode: {mode} ===")
    print("="*60)
    
    colors = MODIFICATION_COLORS['m6a']
    
    glori1_sites = process_orthogonal_smart(combined_glori_1, 'Chr', 'Sites', 'GLORI-1')
    glori2_sites = process_orthogonal_smart(combined_glori_2_df, 'Chr', 'Site', 'GLORI-2')
    glori_combined = glori1_sites | glori2_sites
    
    print(f"GLORI combined: {len(glori_combined)} sites")
    
    if mode == 'both':
        gm12878_sites = get_drs_sites(dorado_mods_dict, 'GM12878', 'm6a')
        hek293_sites = get_drs_sites(dorado_mods_dict, 'HEK293', 'm6a')
        
        fig = plt.figure(figsize=(14, 8))
        
        # 1. Three-way: GM12878, HEK293, Combined GLORI
        ax1 = plt.subplot(2, 3, 1)
        plt.sca(ax1)
        venn = venn3([gm12878_sites, hek293_sites, glori_combined],
                     set_labels=('GM12878', 'HEK293', 'GLORI'))
        if venn.patches[0]: venn.patches[0].set_facecolor(colors['GM12878'])
        if venn.patches[1]: venn.patches[1].set_facecolor(colors['HEK293'])
        if venn.patches[2]: venn.patches[2].set_facecolor(colors['GLORI_combined'])
        for patch in venn.patches:
            if patch: patch.set_alpha(0.7)
        ax1.set_title('m6A: Both Cells vs GLORI', fontweight='bold', fontsize=12)
        
        # 2. GM12878 vs GLORI
        ax2 = plt.subplot(2, 3, 2)
        plt.sca(ax2)
        venn2([gm12878_sites, glori_combined],
              set_labels=('GM12878', 'GLORI'),
              set_colors=(colors['GM12878'], colors['GLORI_combined']), alpha=0.7)
        ax2.set_title('m6A: GM12878 vs GLORI', fontweight='bold', fontsize=12)
        
        # 3. HEK293 vs GLORI
        ax3 = plt.subplot(2, 3, 3)
        plt.sca(ax3)
        venn2([hek293_sites, glori_combined],
              set_labels=('HEK293', 'GLORI'),
              set_colors=(colors['HEK293'], colors['GLORI_combined']), alpha=0.7)
        ax3.set_title('m6A: HEK293 vs GLORI', fontweight='bold', fontsize=12)
        
        # 4. GM12878 vs HEK293
        ax4 = plt.subplot(2, 3, 4)
        plt.sca(ax4)
        venn2([gm12878_sites, hek293_sites],
              set_labels=('GM12878', 'HEK293'),
              set_colors=(colors['GM12878'], colors['HEK293']), alpha=0.7)
        ax4.set_title('m6A: GM12878 vs HEK293', fontweight='bold', fontsize=12)
        
        # 5. GLORI-1 vs GLORI-2
        ax5 = plt.subplot(2, 3, 5)
        plt.sca(ax5)
        venn2([glori1_sites, glori2_sites],
              set_labels=('GLORI-1', 'GLORI-2'),
              set_colors=(colors['GLORI-1'], colors['GLORI-2']), alpha=0.7)
        ax5.set_title('m6A: GLORI-1 vs GLORI-2', fontweight='bold', fontsize=12)
        
        # 6. GM12878 vs GLORI-1 vs GLORI-2
        ax6 = plt.subplot(2, 3, 6)
        plt.sca(ax6)
        venn = venn3([gm12878_sites, glori1_sites, glori2_sites],
                     set_labels=('GM12878', 'GLORI-1', 'GLORI-2'))
        if venn.patches[0]: venn.patches[0].set_facecolor(colors['GM12878'])
        if venn.patches[1]: venn.patches[1].set_facecolor(colors['GLORI-1'])
        if venn.patches[2]: venn.patches[2].set_facecolor(colors['GLORI-2'])
        for patch in venn.patches:
            if patch: patch.set_alpha(0.7)
        ax6.set_title('m6A: GM12878 vs GLORI Methods', fontweight='bold', fontsize=12)
        
        fig.suptitle('m6A Sites - Complete', fontsize=16, fontweight='bold')
    
    elif mode == 'HEK293':
        hek293_sites = get_drs_sites(dorado_mods_dict, 'HEK293', 'm6a')
        
        fig = plt.figure(figsize=(12, 4))
        
        ax1 = plt.subplot(1, 3, 1)
        plt.sca(ax1)
        venn2([hek293_sites, glori_combined],
              set_labels=('HEK293', 'GLORI'),
              set_colors=(colors['HEK293'], colors['GLORI_combined']), alpha=0.7)
        ax1.set_title('m6A: HEK293 vs GLORI', fontweight='bold', fontsize=14)
        
        ax2 = plt.subplot(1, 3, 2)
        plt.sca(ax2)
        venn = venn3([hek293_sites, glori1_sites, glori2_sites],
                     set_labels=('HEK293', 'GLORI-1', 'GLORI-2'))
        if venn.patches[0]: venn.patches[0].set_facecolor(colors['HEK293'])
        if venn.patches[1]: venn.patches[1].set_facecolor(colors['GLORI-1'])
        if venn.patches[2]: venn.patches[2].set_facecolor(colors['GLORI-2'])
        for patch in venn.patches:
            if patch: patch.set_alpha(0.7)
        ax2.set_title('m6A: Three-way', fontweight='bold', fontsize=14)
        
        ax3 = plt.subplot(1, 3, 3)
        plt.sca(ax3)
        venn2([glori1_sites, glori2_sites],
              set_labels=('GLORI-1', 'GLORI-2'),
              set_colors=(colors['GLORI-1'], colors['GLORI-2']), alpha=0.7)
        ax3.set_title('m6A: GLORI-1 vs GLORI-2', fontweight='bold', fontsize=14)
        
        fig.suptitle('m6A Sites - HEK293', fontsize=16, fontweight='bold')
    
    # Similar for GM12878...
    
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"Saved to {output_path}\n")

# ==========================================
# m5C PLOTTING FUNCTIONS
# ==========================================

def plot_m5c_complete(dorado_mods_dict, m5c_orthogonal_df, mode='both',
                      output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/m5c_venns_colored.pdf'):
    """Complete m5C Venn diagram suite with technology-specific colors"""
    
    print("\n" + "="*60)
    print(f"=== Processing m5C - Mode: {mode} ===")
    print("="*60)
    
    colors = MODIFICATION_COLORS['m5c']
    
    orthogonal_sites = process_orthogonal_smart(m5c_orthogonal_df, 'chromosome', 'position', 'm5C Orthogonal')
    
    if mode == 'HEK293':
        hek293_sites = get_drs_sites(dorado_mods_dict, 'HEK293', 'm5c')
        
        fig, ax = plt.subplots(figsize=(4, 4))
        venn2([hek293_sites, orthogonal_sites],
              set_labels=('HEK293 DRS', 'Orthogonal'),
              set_colors=(colors['HEK293'], colors['Orthogonal']), alpha=0.7, ax=ax)
        ax.set_title('m5C Sites: HEK293 vs Orthogonal', fontweight='bold', fontsize=14)
        fig.suptitle('m5C Sites - HEK293', fontsize=16, fontweight='bold')
    
    # Similar for other modes...
    
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"Saved to {output_path}\n")

# ==========================================
# PSEUDOURIDINE PLOTTING FUNCTIONS
# ==========================================

def plot_psi_complete(dorado_mods_dict, bid_seq_df, praise_filtered, mode='both',
                      output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/psi_venns_colored.pdf'):
    """Complete Pseudouridine Venn diagram suite with technology-specific colors"""
    
    print("\n" + "="*60)
    print(f"=== Processing Pseudouridine - Mode: {mode} ===")
    print("="*60)
    
    colors = MODIFICATION_COLORS['psi']
    
    bid_sites = process_orthogonal_smart(bid_seq_df, 'chr', 'pos', 'BID-seq')
    praise_sites = process_orthogonal_smart(praise_filtered, 'chromosome', 'genomic_position', 'PRAISE')
    praise_bid_combined = praise_sites & bid_sites
    
    print(f"PRAISE & BID-seq overlap: {len(praise_bid_combined)} sites")
    
    if mode == 'HEK293':
        hek293_sites = get_drs_sites(dorado_mods_dict, 'HEK293', 'psi')
        
        fig = plt.figure(figsize=(12, 4))
        
        ax1 = plt.subplot(1, 3, 1)
        plt.sca(ax1)
        venn2([hek293_sites, praise_bid_combined],
              set_labels=('HEK293', 'PRAISE & BID'),
              set_colors=(colors['HEK293'], colors['PRAISE_BID_combined']), alpha=0.7)
        ax1.set_title('Ψ: HEK293 vs Orthogonal', fontweight='bold', fontsize=14)
        
        ax2 = plt.subplot(1, 3, 2)
        plt.sca(ax2)
        venn = venn3([hek293_sites, praise_sites, bid_sites],
                     set_labels=('HEK293', 'PRAISE', 'BID-seq'))
        if venn.patches[0]: venn.patches[0].set_facecolor(colors['HEK293'])
        if venn.patches[1]: venn.patches[1].set_facecolor(colors['PRAISE'])
        if venn.patches[2]: venn.patches[2].set_facecolor(colors['BID-seq'])
        for patch in venn.patches:
            if patch: patch.set_alpha(0.7)
        ax2.set_title('Ψ: Three-way', fontweight='bold', fontsize=14)
        
        ax3 = plt.subplot(1, 3, 3)
        plt.sca(ax3)
        venn2([praise_sites, bid_sites],
              set_labels=('PRAISE', 'BID-seq'),
              set_colors=(colors['PRAISE'], colors['BID-seq']), alpha=0.7)
        ax3.set_title('Ψ: PRAISE vs BID-seq', fontweight='bold', fontsize=14)
        
        fig.suptitle('Pseudouridine (Ψ) Sites - HEK293', fontsize=16, fontweight='bold')
    
    # Similar for other modes...
    
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"Saved to {output_path}\n")

# ==========================================
# INOSINE PLOTTING FUNCTIONS
# ==========================================

def plot_inosine_complete(dorado_mods_dict, combined_ino_df, mode='both',
                          output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/inosine_venns_colored.pdf'):
    """Complete Inosine Venn diagram suite with technology-specific colors"""
    
    print("\n" + "="*60)
    print(f"=== Processing Inosine - Mode: {mode} ===")
    print("="*60)
    
    colors = MODIFICATION_COLORS['inosine']
    
    orthogonal_sites = process_orthogonal_smart(combined_ino_df, 'Chromosome', 'position', 'Inosine Orthogonal')
    
    if mode == 'HEK293':
        hek293_sites = get_drs_sites(dorado_mods_dict, 'HEK293', 'inosine')
        
        fig, ax = plt.subplots(figsize=(4, 4))
        venn2([hek293_sites, orthogonal_sites],
              set_labels=('HEK293 DRS', 'Orthogonal'),
              set_colors=(colors['HEK293'], colors['Orthogonal']), alpha=0.7, ax=ax)
        ax.set_title('Inosine (A-to-I) Sites:\nHEK293 vs Orthogonal', fontweight='bold', fontsize=14)
        fig.suptitle('Inosine Sites - HEK293', fontsize=16, fontweight='bold')
    
    # Similar for other modes...
    
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"Saved to {output_path}\n")

## 2' OMethyl Ploting

In [43]:
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib_venn import venn2, venn3
import polars as pl

# Style settings
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['font.family'] = 'sans-serif'
mpl.rcParams['font.sans-serif'] = ['Helvetica']
mpl.rcParams['figure.dpi'] = 600
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.right'] = False

def get_drs_sites(mod_dict, cell_line, mod):
    key = f"{cell_line}_{mod}"
    if key in mod_dict:
        df = list(mod_dict[key].values())[0]
        if 'Adjusted_Mod_Proportion' in df.columns:
            df_filtered = df.filter((pl.col('Adjusted_Mod_Proportion') >= 20) & (pl.col('Score') >= 20))
        else:
            df_filtered = df.filter(pl.col('Score') >= 20)
        sites = set((df_filtered['Chromosome'].cast(pl.Utf8) + '_' + df_filtered['End'].cast(pl.Utf8)).to_list())
        print(f"{key}: {len(sites)} sites")
        return sites
    return set()

def process_orthogonal_smart(df, chr_col, pos_col, label="Orthogonal"):
    if isinstance(df, pl.DataFrame):
        sites = set((df[chr_col].cast(pl.Utf8) + '_' + df[pos_col].cast(pl.Int64).cast(pl.Utf8)).to_list())
    else:
        sample_chr = str(df[chr_col].iloc[0])
        pos_int = df[pos_col].astype(int).astype(str)
        if sample_chr.startswith('chr'):
            sites = set(df[chr_col].astype(str) + '_' + pos_int)
        else:
            sites = set('chr' + df[chr_col].astype(str) + '_' + pos_int)
    print(f"{label}: {len(sites)} sites")
    return sites

def plot_2ome_complete(dorado_mods_dict, ome_A_df, ome_C_df, ome_G_df, ome_U_df, mode='both',
                       output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/2ome_venns.pdf'):
    """
    Complete 2'OMe Venn diagram suite for all 4 bases
    
    Parameters:
    -----------
    ome_A_df, ome_C_df, ome_G_df, ome_U_df : DataFrames
        Orthogonal 2'OMe data for each base with 'Chr' and 'Position' columns
        (from HEK293T in Tang et al. paper)
    mode : str
        'GM12878' - GM12878 only
        'HEK293' - HEK293 only  
        'both' - Both cell lines with all comparisons
    """
    print("\n" + "="*60)
    print(f"=== Processing 2'OMe - Mode: {mode} ===")
    print("="*60)
    
    # Get orthogonal sites for each base
    orth_A_sites = process_orthogonal_smart(ome_A_df, 'Chr', 'Position', "2'OMe-A Orth (HEK293T)")
    orth_C_sites = process_orthogonal_smart(ome_C_df, 'Chr', 'Position', "2'OMe-C Orth (HEK293T)")
    orth_G_sites = process_orthogonal_smart(ome_G_df, 'Chr', 'Position', "2'OMe-G Orth (HEK293T)")
    orth_U_sites = process_orthogonal_smart(ome_U_df, 'Chr', 'Position', "2'OMe-U Orth (HEK293T)")
    orth_all = orth_A_sites | orth_C_sites | orth_G_sites | orth_U_sites
    
    print(f"Combined orthogonal: {len(orth_all)} sites")
    
    if mode == 'both':
        # Both cell lines - 3x3 grid
        gm_A = get_drs_sites(dorado_mods_dict, 'GM12878', '2OMeA')
        gm_C = get_drs_sites(dorado_mods_dict, 'GM12878', '2OMeC')
        gm_G = get_drs_sites(dorado_mods_dict, 'GM12878', '2OMeG')
        gm_U = get_drs_sites(dorado_mods_dict, 'GM12878', '2OMeU')
        gm_all = gm_A | gm_C | gm_G | gm_U
        
        hek_A = get_drs_sites(dorado_mods_dict, 'HEK293', '2OMeA')
        hek_C = get_drs_sites(dorado_mods_dict, 'HEK293', '2OMeC')
        hek_G = get_drs_sites(dorado_mods_dict, 'HEK293', '2OMeG')
        hek_U = get_drs_sites(dorado_mods_dict, 'HEK293', '2OMeU')
        hek_all = hek_A | hek_C | hek_G | hek_U
        
        gm_all_vs_orth = gm_all & orth_all
        hek_all_vs_orth = hek_all & orth_all
        gm_vs_hek_all = gm_all & hek_all
        
        print(f"\nCombined overlaps:")
        print(f"  GM12878 vs Orthogonal: {len(gm_all_vs_orth)}")
        print(f"  HEK293 vs Orthogonal: {len(hek_all_vs_orth)}")
        print(f"  GM12878 vs HEK293: {len(gm_vs_hek_all)}")
        
        fig = plt.figure(figsize=(15, 12))
        
        # 1. Three-way combined
        ax1 = plt.subplot(3, 3, 1)
        plt.sca(ax1)
        venn = venn3([gm_all, hek_all, orth_all],
                     set_labels=('GM12878', 'HEK293', 'HEK293T Orth'))
        for patch, color in zip(venn.patches, ['lightgreen', 'skyblue', 'salmon', 'purple', 'gold', 'orange', 'gray']):
            if patch:
                patch.set_facecolor(color)
                patch.set_alpha(0.7)
        ax1.set_title("2'OMe: All Bases Combined", fontweight='bold', fontsize=11)
        
        # 2. GM12878 vs Orthogonal
        ax2 = plt.subplot(3, 3, 2)
        plt.sca(ax2)
        venn2([gm_all, orth_all],
              set_labels=('GM12878', 'HEK293T Orth'),
              set_colors=('lightgreen', 'salmon'), alpha=0.7)
        ax2.set_title("2'OMe: GM12878 vs Orth", fontweight='bold', fontsize=11)
        
        # 3. HEK293 vs Orthogonal
        ax3 = plt.subplot(3, 3, 3)
        plt.sca(ax3)
        venn2([hek_all, orth_all],
              set_labels=('HEK293', 'HEK293T Orth'),
              set_colors=('skyblue', 'salmon'), alpha=0.7)
        ax3.set_title("2'OMe: HEK293 vs Orth", fontweight='bold', fontsize=11)
        
        # Row 2: Base-specific three-way
        for idx, (base, gm_sites, hek_sites, orth_sites) in enumerate([
            ('A', gm_A, hek_A, orth_A_sites),
            ('C', gm_C, hek_C, orth_C_sites),
            ('G', gm_G, hek_G, orth_G_sites)
        ]):
            ax = plt.subplot(3, 3, 4 + idx)
            plt.sca(ax)
            venn = venn3([gm_sites, hek_sites, orth_sites],
                         set_labels=('GM', 'HEK', 'Orth'))
            for patch, color in zip(venn.patches, ['lightgreen', 'skyblue', 'salmon', 'purple', 'gold', 'orange', 'gray']):
                if patch:
                    patch.set_facecolor(color)
                    patch.set_alpha(0.7)
            ax.set_title(f"2'OMe-{base}", fontweight='bold', fontsize=11)
        
        # Row 3
        # 7. Base U
        ax7 = plt.subplot(3, 3, 7)
        plt.sca(ax7)
        venn = venn3([gm_U, hek_U, orth_U_sites],
                     set_labels=('GM', 'HEK', 'Orth'))
        for patch, color in zip(venn.patches, ['lightgreen', 'skyblue', 'salmon', 'purple', 'gold', 'orange', 'gray']):
            if patch:
                patch.set_facecolor(color)
                patch.set_alpha(0.7)
        ax7.set_title("2'OMe-U", fontweight='bold', fontsize=11)
        
        # 8. GM12878 vs HEK293
        ax8 = plt.subplot(3, 3, 8)
        plt.sca(ax8)
        venn2([gm_all, hek_all],
              set_labels=('GM12878', 'HEK293'),
              set_colors=('lightgreen', 'skyblue'), alpha=0.7)
        ax8.set_title("2'OMe: GM vs HEK", fontweight='bold', fontsize=11)
        
        # 9. Summary
        ax9 = plt.subplot(3, 3, 9)
        ax9.axis('off')
        summary = f"""
        2'O-Methylation
        
        GM12878 DRS:
          Total: {len(gm_all):,}
        
        HEK293 DRS:
          Total: {len(hek_all):,}
        
        HEK293T Orth:
          Total: {len(orth_all):,}
        
        Overlaps:
          GM vs Orth: {len(gm_all_vs_orth):,}
          HEK vs Orth: {len(hek_all_vs_orth):,}
          GM vs HEK: {len(gm_vs_hek_all):,}
        """
        ax9.text(0.1, 0.5, summary, fontsize=9, verticalalignment='center', fontfamily='monospace')
        
        fig.suptitle("2'O-Methylation Sites - Complete Comparison", fontsize=16, fontweight='bold')
    
    elif mode == 'GM12878':
        # GM12878 only
        gm_A = get_drs_sites(dorado_mods_dict, 'GM12878', '2OMeA')
        gm_C = get_drs_sites(dorado_mods_dict, 'GM12878', '2OMeC')
        gm_G = get_drs_sites(dorado_mods_dict, 'GM12878', '2OMeG')
        gm_U = get_drs_sites(dorado_mods_dict, 'GM12878', '2OMeU')
        gm_all = gm_A | gm_C | gm_G | gm_U
        
        overlap_A = gm_A & orth_A_sites
        overlap_C = gm_C & orth_C_sites
        overlap_G = gm_G & orth_G_sites
        overlap_U = gm_U & orth_U_sites
        overlap_all = gm_all & orth_all
        
        print(f"\nGM12878 vs HEK293T orthogonal overlaps:")
        print(f"  A: {len(overlap_A)}")
        print(f"  C: {len(overlap_C)}")
        print(f"  G: {len(overlap_G)}")
        print(f"  U: {len(overlap_U)}")
        print(f"  Combined: {len(overlap_all)}")
        
        fig = plt.figure(figsize=(12, 8))
        
        ax1 = plt.subplot(2, 3, 1)
        plt.sca(ax1)
        venn2([gm_all, orth_all],
              set_labels=('GM12878 DRS', 'HEK293T Orth'),
              set_colors=('lightgreen', 'salmon'), alpha=0.7)
        ax1.set_title(f"2'OMe Combined\n({len(overlap_all)} overlap)", fontweight='bold', fontsize=12)
        
        for idx, (base, gm_sites, orth_sites, overlap) in enumerate([
            ('A', gm_A, orth_A_sites, overlap_A),
            ('C', gm_C, orth_C_sites, overlap_C),
            ('G', gm_G, orth_G_sites, overlap_G),
            ('U', gm_U, orth_U_sites, overlap_U)
        ]):
            ax = plt.subplot(2, 3, 2 + idx)
            plt.sca(ax)
            venn2([gm_sites, orth_sites],
                  set_labels=('DRS', 'Orth'),
                  set_colors=(['lightblue', 'lightcoral', 'palegreen', 'lightyellow'][idx], 'salmon'),
                  alpha=0.7)
            ax.set_title(f"2'OMe-{base}\n({len(overlap)})", fontweight='bold', fontsize=12)
        
        ax6 = plt.subplot(2, 3, 6)
        ax6.axis('off')
        summary = f"""
        2'O-Methylation
        GM12878 vs HEK293T
        
        DRS Sites:
          A: {len(gm_A):,}
          C: {len(gm_C):,}
          G: {len(gm_G):,}
          U: {len(gm_U):,}
          Total: {len(gm_all):,}
        
        Orthogonal:
          Total: {len(orth_all):,}
        
        Overlap: {len(overlap_all):,}
        """
        ax6.text(0.1, 0.5, summary, fontsize=9, verticalalignment='center', fontfamily='monospace')
        
        fig.suptitle("2'O-Methylation Sites - GM12878 vs HEK293T", fontsize=16, fontweight='bold')
    
    elif mode == 'HEK293':
        # HEK293 only
        hek_A = get_drs_sites(dorado_mods_dict, 'HEK293', '2OMeA')
        hek_C = get_drs_sites(dorado_mods_dict, 'HEK293', '2OMeC')
        hek_G = get_drs_sites(dorado_mods_dict, 'HEK293', '2OMeG')
        hek_U = get_drs_sites(dorado_mods_dict, 'HEK293', '2OMeU')
        hek_all = hek_A | hek_C | hek_G | hek_U
        
        overlap_A = hek_A & orth_A_sites
        overlap_C = hek_C & orth_C_sites
        overlap_G = hek_G & orth_G_sites
        overlap_U = hek_U & orth_U_sites
        overlap_all = hek_all & orth_all
        
        print(f"\nHEK293 vs HEK293T orthogonal overlaps:")
        print(f"  A: {len(overlap_A)}")
        print(f"  C: {len(overlap_C)}")
        print(f"  G: {len(overlap_G)}")
        print(f"  U: {len(overlap_U)}")
        print(f"  Combined: {len(overlap_all)}")
        
        fig = plt.figure(figsize=(12, 8))
        
        ax1 = plt.subplot(2, 3, 1)
        plt.sca(ax1)
        venn2([hek_all, orth_all],
              set_labels=('HEK293 DRS', 'HEK293T Orth'),
              set_colors=('skyblue', 'salmon'), alpha=0.7)
        ax1.set_title(f"2'OMe Combined\n({len(overlap_all)} overlap)", fontweight='bold', fontsize=12)
        
        for idx, (base, hek_sites, orth_sites, overlap) in enumerate([
            ('A', hek_A, orth_A_sites, overlap_A),
            ('C', hek_C, orth_C_sites, overlap_C),
            ('G', hek_G, orth_G_sites, overlap_G),
            ('U', hek_U, orth_U_sites, overlap_U)
        ]):
            ax = plt.subplot(2, 3, 2 + idx)
            plt.sca(ax)
            venn2([hek_sites, orth_sites],
                  set_labels=('DRS', 'Orth'),
                  set_colors=(['lightblue', 'lightcoral', 'palegreen', 'lightyellow'][idx], 'salmon'),
                  alpha=0.7)
            ax.set_title(f"2'OMe-{base}\n({len(overlap)})", fontweight='bold', fontsize=12)
        
        ax6 = plt.subplot(2, 3, 6)
        ax6.axis('off')
        summary = f"""
        2'O-Methylation
        HEK293 vs HEK293T
        
        DRS Sites:
          A: {len(hek_A):,}
          C: {len(hek_C):,}
          G: {len(hek_G):,}
          U: {len(hek_U):,}
          Total: {len(hek_all):,}
        
        Orthogonal:
          Total: {len(orth_all):,}
        
        Overlap: {len(overlap_all):,}
        """
        ax6.text(0.1, 0.5, summary, fontsize=9, verticalalignment='center', fontfamily='monospace')
        
        fig.suptitle("2'O-Methylation Sites - HEK293 vs HEK293T", fontsize=16, fontweight='bold')
    
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"Saved to {output_path}\n")

# Usage:
# plot_2ome_complete(dorado_mods_dict, OMe_A, OMe_C, OMe_G, OMe_U, mode='GM12878')
plot_2ome_complete(dorado_mods_dict, OMe_A, OMe_C, OMe_G, OMe_U, mode='HEK293')
# plot_2ome_complete(dorado_mods_dict, OMe_A, OMe_C, OMe_G, OMe_U, mode='both')

1 extra bytes in post.stringData array
'created' timestamp seems very low; regarding as unix timestamp



=== Processing 2'OMe - Mode: HEK293 ===
2'OMe-A Orth (HEK293T): 314 sites
2'OMe-C Orth (HEK293T): 650 sites
2'OMe-G Orth (HEK293T): 645 sites
2'OMe-U Orth (HEK293T): 450 sites
Combined orthogonal: 2059 sites
HEK293_2OMeA: 203 sites
HEK293_2OMeC: 1231 sites
HEK293_2OMeG: 587 sites
HEK293_2OMeU: 2194 sites

HEK293 vs HEK293T orthogonal overlaps:
  A: 0
  C: 0
  G: 0
  U: 0
  Combined: 0
Saved to /Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/2ome_venns.pdf



  plt.show()


## RMSE

In [47]:
# ============================================================
# RMSE CALCULATION - ONLY ON MATCHED SITES (FIXED)
# ============================================================

import numpy as np
from sklearn.metrics import mean_squared_error
import pandas as pd
import polars as pl

def calculate_rmse_comparison(drs_df, ortho_df, drs_col, ortho_col, 
                               chr_col_drs, pos_col_drs,
                               chr_col_ortho, pos_col_ortho,
                               comparison_name="Comparison",
                               scale_ortho_by_100=False,
                               show_examples=True):
    """
    Calculate RMSE between DRS and orthogonal modification percentages
    *** ONLY for sites present in BOTH datasets ***
    
    Parameters:
    -----------
    show_examples : bool
        If True, show example paired values to verify calculation
    """
    
    print(f"\n{'='*70}")
    print(f"RMSE Calculation: {comparison_name}")
    print(f"{'='*70}")
    
    # ============================================================
    # STEP 1: Process DRS data
    # ============================================================
    print("\n  Step 1: Processing DRS data...")
    drs_processed = drs_df.with_columns([
        (pl.col(chr_col_drs).cast(pl.Utf8) + '_' + 
         pl.col(pos_col_drs).cast(pl.Int64).cast(pl.Utf8)).alias('site_id')
    ]).select(['site_id', drs_col])
    
    print(f"    Total DRS sites: {len(drs_processed):,}")
    
    # ============================================================
    # STEP 2: Process orthogonal data
    # ============================================================
    print("\n  Step 2: Processing Orthogonal data...")
    if isinstance(ortho_df, pl.DataFrame):
        sample_chr = ortho_df[chr_col_ortho][0] if len(ortho_df) > 0 else None
        if sample_chr and not str(sample_chr).startswith('chr'):
            ortho_processed = ortho_df.with_columns([
                ('chr' + pl.col(chr_col_ortho).cast(pl.Utf8) + '_' + 
                 pl.col(pos_col_ortho).cast(pl.Float64).cast(pl.Int64).cast(pl.Utf8)).alias('site_id')
            ])
        else:
            ortho_processed = ortho_df.with_columns([
                (pl.col(chr_col_ortho).cast(pl.Utf8) + '_' + 
                 pl.col(pos_col_ortho).cast(pl.Float64).cast(pl.Int64).cast(pl.Utf8)).alias('site_id')
            ])
        ortho_processed = ortho_processed.select(['site_id', ortho_col])
    else:
        # Pandas DataFrame
        sample_chr = str(ortho_df[chr_col_ortho].iloc[0]) if len(ortho_df) > 0 else None
        
        # CRITICAL FIX: Drop NaN positions before processing
        ortho_clean = ortho_df.dropna(subset=[pos_col_ortho])
        
        pos_int = ortho_clean[pos_col_ortho].astype(float).astype(int).astype(str)
        if sample_chr and sample_chr.startswith('chr'):
            site_ids = ortho_clean[chr_col_ortho].astype(str) + '_' + pos_int
        else:
            site_ids = 'chr' + ortho_clean[chr_col_ortho].astype(str) + '_' + pos_int
        
        ortho_processed = pl.DataFrame({
            'site_id': site_ids.tolist(),
            ortho_col: ortho_clean[ortho_col].tolist()
        })
    
    # Scale if needed
    if scale_ortho_by_100:
        ortho_processed = ortho_processed.with_columns([
            (pl.col(ortho_col) * 100).alias(ortho_col)
        ])
        print(f"    ✓ Scaled by 100 (ratio → percentage)")
    
    print(f"    Total Orthogonal sites: {len(ortho_processed):,}")
    
    # ============================================================
    # STEP 3: INNER JOIN - MATCHED SITES ONLY
    # ============================================================
    print("\n  Step 3: Finding matched sites (INNER JOIN)...")
    merged = drs_processed.join(ortho_processed, on='site_id', how='inner')
    
    print(f"\n  ✓ MATCHED SITES: {len(merged):,}")
    print(f"    DRS-only:  {len(drs_processed) - len(merged):,}")
    print(f"    Orth-only: {len(ortho_processed) - len(merged):,}")
    
    if len(merged) == 0:
        print("\n  ⚠️  No overlapping sites!")
        return None
    
    # ============================================================
    # STEP 4: Calculate metrics
    # ============================================================
    print("\n  Step 4: Calculating metrics...")
    
    drs_values = merged[drs_col].to_numpy()
    ortho_values = merged[ortho_col].to_numpy()
    
    # CRITICAL: Remove any NaN values
    valid_mask = ~(np.isnan(drs_values) | np.isnan(ortho_values))
    if not valid_mask.all():
        print(f"    ⚠️  Removing {(~valid_mask).sum()} sites with NaN values")
        drs_values = drs_values[valid_mask]
        ortho_values = ortho_values[valid_mask]
    
    if len(drs_values) == 0:
        print("\n  ⚠️  No valid paired values after removing NaNs!")
        return None
    
    # Calculate metrics
    rmse = np.sqrt(mean_squared_error(drs_values, ortho_values))
    mae = np.mean(np.abs(drs_values - ortho_values))
    correlation = np.corrcoef(drs_values, ortho_values)[0, 1]
    mean_diff = np.mean(drs_values - ortho_values)
    std_diff = np.std(drs_values - ortho_values)
    median_diff = np.median(drs_values - ortho_values)
    
    # ============================================================
    # SHOW EXAMPLE PAIRED VALUES FOR VERIFICATION
    # ============================================================
    if show_examples:
        print(f"\n  📋 Example paired values (first 10 matched sites):")
        print(f"  {'DRS %':>8} {'Orth %':>8} {'Diff':>8}")
        print(f"  {'-'*8} {'-'*8} {'-'*8}")
        for i in range(min(10, len(drs_values))):
            diff = drs_values[i] - ortho_values[i]
            print(f"  {drs_values[i]:8.2f} {ortho_values[i]:8.2f} {diff:+8.2f}")
        
        # Show largest disagreements
        abs_diffs = np.abs(drs_values - ortho_values)
        largest_diff_indices = np.argsort(abs_diffs)[-5:][::-1]
        print(f"\n  📋 Largest disagreements (top 5):")
        print(f"  {'DRS %':>8} {'Orth %':>8} {'Diff':>8}")
        print(f"  {'-'*8} {'-'*8} {'-'*8}")
        for i in largest_diff_indices:
            diff = drs_values[i] - ortho_values[i]
            print(f"  {drs_values[i]:8.2f} {ortho_values[i]:8.2f} {diff:+8.2f}")
    
    # ============================================================
    # PRINT RESULTS
    # ============================================================
    print(f"\n  {'─'*70}")
    print(f"  RESULTS (based on {len(drs_values):,} matched sites)")
    print(f"  {'─'*70}")
    
    print(f"\n  DRS ({drs_col}):")
    print(f"    Mean: {drs_values.mean():.2f}%")
    print(f"    Std:  {drs_values.std():.2f}%")
    print(f"    Range: [{drs_values.min():.2f}, {drs_values.max():.2f}]")
    
    print(f"\n  Orthogonal ({ortho_col}):")
    print(f"    Mean: {ortho_values.mean():.2f}%")
    print(f"    Std:  {ortho_values.std():.2f}%")
    print(f"    Range: [{ortho_values.min():.2f}, {ortho_values.max():.2f}]")
    
    print(f"\n  📊 AGREEMENT METRICS:")
    print(f"    RMSE:        {rmse:.3f}%")
    print(f"    MAE:         {mae:.3f}%")
    print(f"    Correlation: {correlation:.3f}")
    print(f"    Mean Diff:   {mean_diff:+.3f}% ± {std_diff:.3f}%")
    print(f"    Median Diff: {median_diff:+.3f}%")
    
    # Sanity check explanation
    print(f"\n  ℹ️  RMSE Interpretation:")
    print(f"    • RMSE of {rmse:.2f}% means the average error is ~{rmse:.2f} percentage points")
    print(f"    • For comparison: if DRS=50% and Orth=60%, error = 10%")
    print(f"    • RMSE is in the same units as your measurements (percentage points)")
    
    return {
        'comparison': comparison_name,
        'n_sites': len(drs_values),
        'n_drs_only': len(drs_processed) - len(drs_values),
        'n_ortho_only': len(ortho_processed) - len(drs_values),
        'rmse': rmse,
        'mae': mae,
        'correlation': correlation,
        'mean_diff': mean_diff,
        'std_diff': std_diff,
        'median_diff': median_diff,
        'drs_mean': drs_values.mean(),
        'drs_std': drs_values.std(),
        'ortho_mean': ortho_values.mean(),
        'ortho_std': ortho_values.std()
    }

# ============================================================
# RUN ALL RMSE CALCULATIONS
# ============================================================

results_summary = {}

print("\n" + "="*80)
print("CALCULATING RMSE - ONLY ON MATCHED SITES")
print("="*80)

# ============================================================
# m6A COMPARISONS
# ============================================================

print("\n" + "🔴"*40)
print("m6A RMSE CALCULATIONS")
print("🔴"*40)

# Filter DRS data once
hek293_m6a_filtered = hek293_m6a_drs.filter(
    (pl.col('Adjusted_Mod_Proportion') >= 20) & (pl.col('Score') >= 20)
)

# HEK293 vs GLORI-1 (NEW)
results_summary['HEK293_vs_GLORI1_new'] = calculate_rmse_comparison(
    drs_df=hek293_m6a_filtered,
    ortho_df=new_glori1,
    drs_col='Adjusted_Mod_Proportion',
    ortho_col='m6A_level_mean',
    chr_col_drs='Chromosome',
    pos_col_drs='End',
    chr_col_ortho='Chr',
    pos_col_ortho='Sites',
    comparison_name="HEK293 DRS vs GLORI-1 (New)",
    scale_ortho_by_100=False,
    show_examples=True
)

# HEK293 vs GLORI-2
results_summary['HEK293_vs_GLORI2'] = calculate_rmse_comparison(
    drs_df=hek293_m6a_filtered,
    ortho_df=combined_glori_2,
    drs_col='Adjusted_Mod_Proportion',
    ortho_col='m6A_level_mean',
    chr_col_drs='Chromosome',
    pos_col_drs='End',
    chr_col_ortho='Chr',
    pos_col_ortho='Site',
    comparison_name="HEK293 DRS vs GLORI-2",
    scale_ortho_by_100=False,
    show_examples=True
)

# ============================================================
# NEW: HEK293 vs GLORI-1 ∩ GLORI-2 (INTERSECTION ONLY)
# ============================================================

print("\n" + "🔴"*40)
print("m6A - GLORI INTERSECTION COMPARISON")
print("🔴"*40)

# Create GLORI intersection dataset
# First, convert both to Polars for consistent handling
glori1_pl = pl.from_pandas(new_glori1)
glori2_pl = pl.from_pandas(combined_glori_2)

# Add site IDs to both
glori1_with_id = glori1_pl.with_columns([
    (pl.col('Chr').cast(pl.Utf8) + '_' + pl.col('Sites').cast(pl.Int64).cast(pl.Utf8)).alias('site_id')
])

glori2_with_id = glori2_pl.with_columns([
    (pl.col('Chr').cast(pl.Utf8) + '_' + pl.col('Site').cast(pl.Int64).cast(pl.Utf8)).alias('site_id')
])

# Inner join to get only sites in BOTH GLORI methods
glori_intersection = glori1_with_id.join(
    glori2_with_id.select(['site_id', pl.col('m6A_level_mean').alias('m6A_level_mean_g2')]), 
    on='site_id', 
    how='inner'
)

# Average the two GLORI values for sites in intersection
glori_intersection = glori_intersection.with_columns([
    ((pl.col('m6A_level_mean') + pl.col('m6A_level_mean_g2')) / 2).alias('m6A_level_mean_combined')
])

print(f"\n  GLORI-1 ∩ GLORI-2: {len(glori_intersection):,} sites")

# Now compare DRS vs GLORI intersection
results_summary['HEK293_vs_GLORI_intersection'] = calculate_rmse_comparison(
    drs_df=hek293_m6a_filtered,
    ortho_df=glori_intersection,
    drs_col='Adjusted_Mod_Proportion',
    ortho_col='m6A_level_mean_combined',
    chr_col_drs='Chromosome',
    pos_col_drs='End',
    chr_col_ortho='Chr',
    pos_col_ortho='Sites',
    comparison_name="HEK293 DRS vs GLORI-1∩GLORI-2",
    scale_ortho_by_100=False,
    show_examples=True
)

# ============================================================
# m5C COMPARISONS
# ============================================================

print("\n" + "🟢"*40)
print("m5C RMSE CALCULATIONS")
print("🟢"*40)

hek293_m5c_filtered = hek293_m5c_drs.filter(
    (pl.col('Adjusted_Mod_Proportion') >= 20) & (pl.col('Score') >= 20)
)

results_summary['HEK293_vs_m5C_orth'] = calculate_rmse_comparison(
    drs_df=hek293_m5c_filtered,
    ortho_df=m5c_orthogonal_df,
    drs_col='Adjusted_Mod_Proportion',
    ortho_col='ratio',
    chr_col_drs='Chromosome',
    pos_col_drs='End',
    chr_col_ortho='chromosome',
    pos_col_ortho='position',
    comparison_name="HEK293 DRS vs m5C Orthogonal",
    scale_ortho_by_100=True,
    show_examples=True
)

# ============================================================
# PSEUDOURIDINE (Ψ) COMPARISONS
# ============================================================

print("\n" + "🟣"*40)
print("PSEUDOURIDINE (Ψ) RMSE CALCULATIONS")
print("🟣"*40)

hek293_psi_filtered = hek293_psi_drs.filter(
    (pl.col('Adjusted_Mod_Proportion') >= 20) & (pl.col('Score') >= 20)
)

# BID-seq - Convert to numeric and drop NaNs
bid_seq_numeric = bid_seq_df.copy()
bid_seq_numeric['Frac_Ave %'] = pd.to_numeric(bid_seq_numeric['Frac_Ave %'], errors='coerce')
bid_seq_numeric['pos'] = pd.to_numeric(bid_seq_numeric['pos'], errors='coerce')
bid_seq_numeric = bid_seq_numeric.dropna(subset=['Frac_Ave %', 'pos'])

results_summary['HEK293_vs_BIDseq'] = calculate_rmse_comparison(
    drs_df=hek293_psi_filtered,
    ortho_df=bid_seq_numeric,
    drs_col='Adjusted_Mod_Proportion',
    ortho_col='Frac_Ave %',
    chr_col_drs='Chromosome',
    pos_col_drs='End',
    chr_col_ortho='chr',
    pos_col_ortho='pos',
    comparison_name="HEK293 DRS vs BID-seq",
    scale_ortho_by_100=False,
    show_examples=True
)

# PRAISE - Calculate average and REMOVE NaNs
praise_with_avg = praise_filtered.copy()

# CRITICAL FIX: Drop rows with NaN genomic_position BEFORE calculation
praise_with_avg = praise_with_avg.dropna(subset=['genomic_position'])

praise_with_avg['stop_rate_avg'] = (
    praise_with_avg['rep1-difference of stop rate'] + 
    praise_with_avg['rep2-difference of stop rate'] + 
    praise_with_avg['rep3-difference of stop rate']
) / 3
praise_with_avg['stop_rate_avg_pct'] = praise_with_avg['stop_rate_avg'] * 100

print(f"\n  PRAISE sites after removing NaNs: {len(praise_with_avg):,}")

results_summary['HEK293_vs_PRAISE'] = calculate_rmse_comparison(
    drs_df=hek293_psi_filtered,
    ortho_df=praise_with_avg,
    drs_col='Adjusted_Mod_Proportion',
    ortho_col='stop_rate_avg_pct',
    chr_col_drs='Chromosome',
    pos_col_drs='End',
    chr_col_ortho='chromosome',
    pos_col_ortho='genomic_position',
    comparison_name="HEK293 DRS vs PRAISE",
    scale_ortho_by_100=False,
    show_examples=True
)

# ============================================================
# NO PERCENTAGE DATA
# ============================================================

print("\n" + "🟡"*40)
print("INOSINE - NO PERCENTAGE DATA")
print("🟡"*40)
print("  ⚠️  Only coverage data, no modification %")

print("\n" + "🟤"*40)
print("2'-O-METHYLATION - NO PERCENTAGE DATA")
print("🟤"*40)
print("  ⚠️  Only presence/absence, no modification %")

# ============================================================
# SUMMARY TABLE
# ============================================================

print("\n" + "="*80)
print("RMSE SUMMARY TABLE")
print("="*80)

if results_summary:
    summary_df = pd.DataFrame(results_summary).T
    summary_df = summary_df.round(3)
    
    col_order = ['comparison', 'n_sites', 'n_drs_only', 'n_ortho_only',
                 'correlation', 'rmse', 'mae', 
                 'mean_diff', 'std_diff', 'median_diff', 
                 'drs_mean', 'drs_std', 'ortho_mean', 'ortho_std']
    summary_df = summary_df[col_order]
    
    # Add match percentages
    summary_df['pct_drs_matched'] = (summary_df['n_sites'] / 
                                     (summary_df['n_sites'] + summary_df['n_drs_only']) * 100).round(1)
    summary_df['pct_ortho_matched'] = (summary_df['n_sites'] / 
                                       (summary_df['n_sites'] + summary_df['n_ortho_only']) * 100).round(1)
    
    print("\n" + summary_df.to_string())
    
    # Save
    summary_path = '/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/RMSE_summary.csv'
    summary_df.to_csv(summary_path, index=False)
    print(f"\n✓ Saved to {summary_path}")
    
    # ============================================================
    # INTERPRETATION
    # ============================================================
    print("\n" + "="*80)
    print("INTERPRETATION")
    print("="*80)
    
    for idx, row in summary_df.iterrows():
        print(f"\n{row['comparison']}:")
        print(f"  Sites: {int(row['n_sites']):,} matched ({row['pct_drs_matched']:.1f}% of DRS)")
        print(f"  Correlation: {row['correlation']:.3f}")
        print(f"  RMSE: {row['rmse']:.2f}% (avg error per site)")
        print(f"  MAE:  {row['mae']:.2f}%")
        
        # Context for RMSE
        if row['rmse'] < 10:
            rmse_interp = "Excellent agreement (< 10% error)"
        elif row['rmse'] < 15:
            rmse_interp = "Good agreement (10-15% error)"
        elif row['rmse'] < 20:
            rmse_interp = "Moderate agreement (15-20% error)"
        else:
            rmse_interp = "Weaker agreement (> 20% error)"
        
        print(f"  → {rmse_interp}")

print("\n" + "="*80)
print("✓ RMSE ANALYSIS COMPLETE")
print("="*80)
print("\nNOTE: All RMSE values calculated ONLY on sites present in BOTH datasets")
print("      Site matching based on exact chromosome + position coordinates")


CALCULATING RMSE - ONLY ON MATCHED SITES

🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴
m6A RMSE CALCULATIONS
🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴🔴

RMSE Calculation: HEK293 DRS vs GLORI-1 (New)

  Step 1: Processing DRS data...
    Total DRS sites: 67,517

  Step 2: Processing Orthogonal data...
    Total Orthogonal sites: 170,240

  Step 3: Finding matched sites (INNER JOIN)...

  ✓ MATCHED SITES: 47,518
    DRS-only:  19,999
    Orth-only: 122,722

  Step 4: Calculating metrics...

  📋 Example paired values (first 10 matched sites):
     DRS %   Orth %     Diff
  -------- -------- --------
     95.76    89.44    +6.31
     46.62    63.61   -16.99
     52.31    55.26    -2.95
     24.22    21.16    +3.07
     56.67    71.68   -15.01
     47.20    48.62    -1.42
     48.23    67.73   -19.50
     45.40    65.22   -19.81
     52.30    26.76   +25.54
     76.09    68.35    +7.73

  📋 Largest disagreements (top 5):
     DRS %   Orth %     Diff
  -------- -------- --------
     99.88    17.

TypeError: Expected numeric dtype, got object instead.

## Save Orthogonally Validated Positions

In [29]:
import polars as pl
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib_venn import venn2, venn3

# ============================================
# PART 1: SAVE VALIDATED SITES WITH FIXED POSITION HANDLING
# ============================================

def save_validated_sites(
    dorado_mods_dict,
    combined_glori_1, 
    combined_glori_2_df, 
    m5c_orthogonal_df,
    bid_seq_df,
    praise_filtered,
    combined_ino,
    cell_line='HEK293',  # or 'GM12878'
    output_dir="/Volumes/AJS_SSD/HEK293/validated_sites/",
    debug=False
):
    """
    Save sites that are validated by both DRS and orthogonal methods
    with proper chromosome naming and position type consistency
    """
    # Create output directory if it doesn't exist
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    
    print("="*60)
    print(f"Saving Validated Sites (DRS ∩ Orthogonal) for {cell_line}")
    print("="*60)
    
    # Helper function to get DRS sites with values
    def get_drs_data(mod_type):
        key = f"{cell_line}_{mod_type}"
        if key in dorado_mods_dict:
            df = list(dorado_mods_dict[key].values())[0]
            if 'Adjusted_Mod_Proportion' in df.columns:
                df_filtered = df.filter((pl.col('Adjusted_Mod_Proportion') >= 20) & (pl.col('Score') >= 20))
            else:
                df_filtered = df.filter(pl.col('Score') >= 20)
            
            # Check if chromosomes need 'chr' prefix
            sample_chr = df_filtered['Chromosome'][0] if len(df_filtered) > 0 else None
            
            if debug:
                print(f"  DRS {mod_type} sample chromosome: {sample_chr}")
            
            # Ensure positions are integers
            if sample_chr and not str(sample_chr).startswith('chr'):
                df_filtered = df_filtered.with_columns([
                    ('chr' + pl.col('Chromosome').cast(pl.Utf8) + '_' + 
                     pl.col('End').cast(pl.Int64).cast(pl.Utf8)).alias('site_id')
                ])
            else:
                df_filtered = df_filtered.with_columns([
                    (pl.col('Chromosome').cast(pl.Utf8) + '_' + 
                     pl.col('End').cast(pl.Int64).cast(pl.Utf8)).alias('site_id')
                ])
            
            if debug and len(df_filtered) > 0:
                print(f"  DRS {mod_type} sample site IDs: {df_filtered['site_id'].head(3).to_list()}")
            
            return df_filtered
        return None
    
    # Helper function to process orthogonal data with consistent chromosome naming and integer positions
    def process_orthogonal(df, chr_col, pos_col, label=""):
        if df is None:
            return None
            
        if not isinstance(df, pl.DataFrame):
            df = pl.from_pandas(df)
        
        # Check chromosome format
        sample_chr = df[chr_col][0] if len(df) > 0 else None
        
        if debug:
            print(f"  {label} sample chromosome: {sample_chr}")
        
        # CRITICAL: Cast positions to Int64 to handle floats
        # Ensure consistent chromosome naming and integer positions
        if sample_chr and not str(sample_chr).startswith('chr'):
            df = df.with_columns([
                ('chr' + pl.col(chr_col).cast(pl.Utf8) + '_' + 
                 pl.col(pos_col).cast(pl.Float64).cast(pl.Int64).cast(pl.Utf8)).alias('site_id')
            ])
        else:
            df = df.with_columns([
                (pl.col(chr_col).cast(pl.Utf8) + '_' + 
                 pl.col(pos_col).cast(pl.Float64).cast(pl.Int64).cast(pl.Utf8)).alias('site_id')
            ])
        
        if debug and len(df) > 0:
            print(f"  {label} sample site IDs: {df['site_id'].head(3).to_list()}")
        
        return df
    
    # 1. m6A VALIDATION (DRS ∩ GLORI)
    print("\n--- m6A Processing ---")
    drs_m6a = get_drs_data('m6a')
    if drs_m6a is not None and (combined_glori_1 is not None or combined_glori_2_df is not None):
        # Process GLORI-1
        glori1_processed = process_orthogonal(combined_glori_1, 'Chr', 'Sites', 'GLORI-1') if combined_glori_1 is not None else None
        # Process GLORI-2
        glori2_processed = process_orthogonal(combined_glori_2_df, 'Chr', 'Site', 'GLORI-2') if combined_glori_2_df is not None else None
        
        # Get DRS sites
        drs_sites = set(drs_m6a['site_id'].to_list())
        
        # 1a. SAVE GLORI-1 VALIDATED SITES SEPARATELY
        if glori1_processed is not None:
            glori1_sites = set(glori1_processed['site_id'].to_list())
            validated_glori1_sites = drs_sites & glori1_sites
            
            # Create GLORI-1 validated dataframe
            validated_glori1 = drs_m6a.filter(pl.col('site_id').is_in(validated_glori1_sites))
            glori1_values = glori1_processed.select(['site_id', 'Ratio_mean']).rename({'Ratio_mean': 'GLORI1_value'})
            validated_glori1 = validated_glori1.join(glori1_values, on='site_id', how='left', coalesce=True)
            
            # Save GLORI-1 validated sites
            output_path = f"{output_dir}/{cell_line}_m6A_GLORI1_validated_sites.csv"
            validated_glori1.write_csv(output_path)
            print(f"✓ Saved m6A-GLORI1 validated sites: {len(validated_glori1)} sites")
        
        # 1b. SAVE GLORI-2 VALIDATED SITES SEPARATELY
        if glori2_processed is not None:
            glori2_sites = set(glori2_processed['site_id'].to_list())
            validated_glori2_sites = drs_sites & glori2_sites
            
            # Create GLORI-2 validated dataframe
            validated_glori2 = drs_m6a.filter(pl.col('site_id').is_in(validated_glori2_sites))
            glori2_values = glori2_processed.select(['site_id', 'm6A_level_mean']).rename({'m6A_level_mean': 'GLORI2_value'})
            validated_glori2 = validated_glori2.join(glori2_values, on='site_id', how='left', coalesce=True)
            
            # Save GLORI-2 validated sites
            output_path = f"{output_dir}/{cell_line}_m6A_GLORI2_validated_sites.csv"
            validated_glori2.write_csv(output_path)
            print(f"✓ Saved m6A-GLORI2 validated sites: {len(validated_glori2)} sites")
        
        # 1c. SAVE COMBINED GLORI (UNION) VALIDATED SITES
        glori_sites = set()
        if glori1_processed is not None:
            glori_sites.update(glori1_processed['site_id'].to_list())
        if glori2_processed is not None:
            glori_sites.update(glori2_processed['site_id'].to_list())
        
        validated_m6a_sites = drs_sites & glori_sites
        validated_m6a = drs_m6a.filter(pl.col('site_id').is_in(validated_m6a_sites))
        
        if glori1_processed is not None:
            glori1_values = glori1_processed.select(['site_id', 'Ratio_mean']).rename({'Ratio_mean': 'GLORI1_value'})
            validated_m6a = validated_m6a.join(glori1_values, on='site_id', how='left', coalesce=True)
        
        if glori2_processed is not None:
            glori2_values = glori2_processed.select(['site_id', 'm6A_level_mean']).rename({'m6A_level_mean': 'GLORI2_value'})
            validated_m6a = validated_m6a.join(glori2_values, on='site_id', how='left', coalesce=True)
        
        output_path = f"{output_dir}/{cell_line}_m6A_GLORI_combined_validated_sites.csv"
        validated_m6a.write_csv(output_path)
        print(f"✓ Saved m6A-GLORI combined validated sites: {len(validated_m6a)} sites")
    
    # 2. m5C VALIDATION (DRS ∩ Orthogonal)
    print("\n--- m5C Processing ---")
    drs_m5c = get_drs_data('m5c')
    if drs_m5c is not None and m5c_orthogonal_df is not None:
        m5c_processed = process_orthogonal(m5c_orthogonal_df, 'chromosome', 'position', 'm5C-orthogonal')
        
        drs_sites = set(drs_m5c['site_id'].to_list())
        ortho_sites = set(m5c_processed['site_id'].to_list())
        
        validated_m5c_sites = drs_sites & ortho_sites
        validated_m5c = drs_m5c.filter(pl.col('site_id').is_in(validated_m5c_sites))
        
        if 'modification_level' in m5c_processed.columns:
            ortho_values = m5c_processed.select(['site_id', 'modification_level']).rename({'modification_level': 'orthogonal_value'})
            validated_m5c = validated_m5c.join(ortho_values, on='site_id', how='left', coalesce=True)
        
        output_path = f"{output_dir}/{cell_line}_m5C_validated_sites.csv"
        validated_m5c.write_csv(output_path)
        print(f"✓ Saved m5C validated sites: {len(validated_m5c)} sites")
    
    # 3. PSEUDOURIDINE VALIDATION (DRS ∩ (PRAISE ∪ BID-seq))
    print("\n--- Pseudouridine Processing ---")
    drs_psi = get_drs_data('psi')
    if drs_psi is not None and (bid_seq_df is not None or praise_filtered is not None):
        # Process orthogonal methods with proper position handling
        bid_processed = process_orthogonal(bid_seq_df, 'chr', 'pos', 'BID-seq') if bid_seq_df is not None else None
        praise_processed = process_orthogonal(praise_filtered, 'chromosome', 'genomic_position', 'PRAISE') if praise_filtered is not None else None
        
        drs_sites = set(drs_psi['site_id'].to_list())
        
        # 3a. SAVE BID-seq VALIDATED SITES SEPARATELY
        if bid_processed is not None:
            bid_sites = set(bid_processed['site_id'].to_list())
            validated_bid_sites = drs_sites & bid_sites
            
            validated_bid = drs_psi.filter(pl.col('site_id').is_in(validated_bid_sites))
            
            output_path = f"{output_dir}/{cell_line}_psi_BIDseq_validated_sites.csv"
            validated_bid.write_csv(output_path)
            print(f"✓ Saved Ψ-BIDseq validated sites: {len(validated_bid)} sites")
        
        # 3b. SAVE PRAISE VALIDATED SITES SEPARATELY
        if praise_processed is not None:
            praise_sites = set(praise_processed['site_id'].to_list())
            validated_praise_sites = drs_sites & praise_sites
            
            validated_praise = drs_psi.filter(pl.col('site_id').is_in(validated_praise_sites))
            
            output_path = f"{output_dir}/{cell_line}_psi_PRAISE_validated_sites.csv"
            validated_praise.write_csv(output_path)
            print(f"✓ Saved Ψ-PRAISE validated sites: {len(validated_praise)} sites")
        
        # 3c. SAVE COMBINED (UNION) VALIDATED SITES
        ortho_sites = set()
        if bid_processed is not None:
            ortho_sites.update(bid_processed['site_id'].to_list())
        if praise_processed is not None:
            ortho_sites.update(praise_processed['site_id'].to_list())
        
        validated_psi_sites = drs_sites & ortho_sites
        validated_psi = drs_psi.filter(pl.col('site_id').is_in(validated_psi_sites))
        
        # Add source information
        bid_sites = set(bid_processed['site_id'].to_list()) if bid_processed is not None else set()
        praise_sites = set(praise_processed['site_id'].to_list()) if praise_processed is not None else set()
        
        sources = []
        for site in validated_psi['site_id'].to_list():
            source_list = []
            if site in bid_sites:
                source_list.append('BID-seq')
            if site in praise_sites:
                source_list.append('PRAISE')
            sources.append(','.join(source_list))
        
        validated_psi = validated_psi.with_columns([
            pl.Series('orthogonal_source', sources)
        ])
        
        output_path = f"{output_dir}/{cell_line}_psi_combined_validated_sites.csv"
        validated_psi.write_csv(output_path)
        print(f"✓ Saved Ψ combined validated sites: {len(validated_psi)} sites")
    
    # 4. INOSINE VALIDATION (DRS ∩ Orthogonal)
    print("\n--- Inosine Processing ---")
    drs_inosine = get_drs_data('inosine')
    if drs_inosine is not None and combined_ino is not None:
        ino_processed = process_orthogonal(combined_ino, 'Chromosome', 'position', 'Inosine-orthogonal')
        
        drs_sites = set(drs_inosine['site_id'].to_list())
        ortho_sites = set(ino_processed['site_id'].to_list())
        
        validated_ino_sites = drs_sites & ortho_sites
        validated_inosine = drs_inosine.filter(pl.col('site_id').is_in(validated_ino_sites))
        
        if 'editing_level' in ino_processed.columns:
            ortho_values = ino_processed.select(['site_id', 'editing_level']).rename({'editing_level': 'orthogonal_value'})
            validated_inosine = validated_inosine.join(ortho_values, on='site_id', how='left', coalesce=True)
        
        output_path = f"{output_dir}/{cell_line}_inosine_validated_sites.csv"
        validated_inosine.write_csv(output_path)
        print(f"✓ Saved Inosine validated sites: {len(validated_inosine)} sites")
    
    # Print summary statistics
    print("\n" + "="*60)
    print("Validation Summary:")
    print("="*60)
    
    if drs_m6a is not None:
        total_drs_m6a = len(drs_m6a)
        validated_m6a_count = len(validated_m6a) if 'validated_m6a' in locals() else 0
        validated_glori1_count = len(validated_glori1) if 'validated_glori1' in locals() else 0
        validated_glori2_count = len(validated_glori2) if 'validated_glori2' in locals() else 0
        
        print(f"\nm6A Total DRS sites: {total_drs_m6a}")
        if 'validated_glori1' in locals():
            print(f"  GLORI-1 validated: {validated_glori1_count} ({100*validated_glori1_count/total_drs_m6a:.1f}%)")
        if 'validated_glori2' in locals():
            print(f"  GLORI-2 validated: {validated_glori2_count} ({100*validated_glori2_count/total_drs_m6a:.1f}%)")
        print(f"  Combined GLORI validated: {validated_m6a_count} ({100*validated_m6a_count/total_drs_m6a:.1f}%)")
    
    if drs_m5c is not None:
        total_drs_m5c = len(drs_m5c)
        validated_m5c_count = len(validated_m5c) if 'validated_m5c' in locals() else 0
        print(f"\nm5C Total DRS sites: {total_drs_m5c}")
        print(f"  Validated: {validated_m5c_count} ({100*validated_m5c_count/total_drs_m5c:.1f}%)")
    
    if drs_psi is not None:
        total_drs_psi = len(drs_psi)
        validated_psi_count = len(validated_psi) if 'validated_psi' in locals() else 0
        validated_bid_count = len(validated_bid) if 'validated_bid' in locals() else 0
        validated_praise_count = len(validated_praise) if 'validated_praise' in locals() else 0
        
        print(f"\nΨ Total DRS sites: {total_drs_psi}")
        if 'validated_bid' in locals():
            print(f"  BID-seq validated: {validated_bid_count} ({100*validated_bid_count/total_drs_psi:.1f}%)")
        if 'validated_praise' in locals():
            print(f"  PRAISE validated: {validated_praise_count} ({100*validated_praise_count/total_drs_psi:.1f}%)")
        print(f"  Combined validated: {validated_psi_count} ({100*validated_psi_count/total_drs_psi:.1f}%)")
    
    if drs_inosine is not None:
        total_drs_inosine = len(drs_inosine)
        validated_ino_count = len(validated_inosine) if 'validated_inosine' in locals() else 0
        print(f"\nInosine Total DRS sites: {total_drs_inosine}")
        print(f"  Validated: {validated_ino_count} ({100*validated_ino_count/total_drs_inosine:.1f}%)")
    
    return True


# ============================================
# PART 2: PLOTTING FUNCTIONS WITH FIXED POSITION HANDLING
# ============================================

# Style settings
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['font.family'] = 'sans-serif'
mpl.rcParams['font.sans-serif'] = ['Helvetica']
mpl.rcParams['figure.dpi'] = 600
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.right'] = False

def get_drs_sites(mod_dict, cell_line, mod, debug=False):
    """Get DRS sites with proper chromosome naming and integer positions"""
    key = f"{cell_line}_{mod}"
    if key in mod_dict:
        df = list(mod_dict[key].values())[0]
        if 'Adjusted_Mod_Proportion' in df.columns:
            df_filtered = df.filter((pl.col('Adjusted_Mod_Proportion') >= 20) & (pl.col('Score') >= 20))
        else:
            df_filtered = df.filter(pl.col('Score') >= 20)
        
        # Check if chromosomes need 'chr' prefix
        sample_chr = df_filtered['Chromosome'][0] if len(df_filtered) > 0 else None
        
        if debug:
            print(f"  {key} sample chromosome: {sample_chr}")
        
        # Ensure positions are integers
        if sample_chr and not str(sample_chr).startswith('chr'):
            sites = set(('chr' + df_filtered['Chromosome'].cast(pl.Utf8) + '_' + 
                        df_filtered['End'].cast(pl.Int64).cast(pl.Utf8)).to_list())
        else:
            sites = set((df_filtered['Chromosome'].cast(pl.Utf8) + '_' + 
                        df_filtered['End'].cast(pl.Int64).cast(pl.Utf8)).to_list())
        
        print(f"{key}: {len(sites)} sites")
        
        if debug and len(sites) > 0:
            print(f"  Sample site IDs: {list(sites)[:3]}")
        
        return sites
    return set()

def process_orthogonal_smart(df, chr_col, pos_col, label="Orthogonal", debug=False):
    """Process orthogonal data with consistent chromosome naming and integer positions"""
    if df is None:
        return set()
    
    if isinstance(df, pl.DataFrame):
        sample_chr = df[chr_col][0] if len(df) > 0 else None
        
        if debug:
            print(f"  {label} sample chromosome: {sample_chr}")
            # Check if positions are floats
            sample_pos = df[pos_col][0] if len(df) > 0 else None
            print(f"  {label} sample position: {sample_pos} (type: {type(sample_pos).__name__})")
        
        # CRITICAL: Handle float positions by casting to int
        if sample_chr is not None and not str(sample_chr).startswith('chr'):
            sites = set(('chr' + df[chr_col].cast(pl.Utf8) + '_' + 
                        df[pos_col].cast(pl.Float64).cast(pl.Int64).cast(pl.Utf8)).to_list())
        else:
            sites = set((df[chr_col].cast(pl.Utf8) + '_' + 
                        df[pos_col].cast(pl.Float64).cast(pl.Int64).cast(pl.Utf8)).to_list())
    else:
        # Handle pandas DataFrames
        sample_chr = str(df[chr_col].iloc[0]) if len(df) > 0 else None
        
        if debug:
            print(f"  {label} sample chromosome: {sample_chr}")
            sample_pos = df[pos_col].iloc[0] if len(df) > 0 else None
            print(f"  {label} sample position: {sample_pos} (type: {type(sample_pos).__name__})")
        
        # Cast to int to handle floats
        pos_int = df[pos_col].astype(float).astype(int).astype(str)
        if sample_chr and sample_chr.startswith('chr'):
            sites = set(df[chr_col].astype(str) + '_' + pos_int)
        else:
            sites = set('chr' + df[chr_col].astype(str) + '_' + pos_int)
    
    print(f"{label}: {len(sites)} sites")
    
    if debug and len(sites) > 0:
        print(f"  Sample site IDs: {list(sites)[:3]}")
    
    return sites

# Example plotting function with fixed position handling
def plot_psi_complete(dorado_mods_dict, bid_seq_df, praise_filtered, mode='both',
                      output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/psi_venns.pdf',
                      debug=False):
    """Complete Pseudouridine Venn diagram suite with fixed position handling"""
    
    print("\n" + "="*60)
    print(f"=== Processing Pseudouridine - Mode: {mode} ===")
    print("="*60)
    
    # Process orthogonal data with proper position handling
    bid_sites = process_orthogonal_smart(bid_seq_df, 'chr', 'pos', 'BID-seq', debug)
    praise_sites = process_orthogonal_smart(praise_filtered, 'chromosome', 'genomic_position', 'PRAISE', debug)
    
    # Calculate overlaps
    praise_bid_overlap = praise_sites & bid_sites
    praise_or_bid = praise_sites | bid_sites
    
    print(f"\nOrthogonal method comparison:")
    print(f"  PRAISE & BID-seq intersection: {len(praise_bid_overlap)} sites")
    print(f"  PRAISE only: {len(praise_sites - bid_sites)} sites")
    print(f"  BID-seq only: {len(bid_sites - praise_sites)} sites")
    print(f"  PRAISE | BID-seq union: {len(praise_or_bid)} sites")
    
    if mode == 'HEK293':
        hek293_sites = get_drs_sites(dorado_mods_dict, 'HEK293', 'psi', debug)
        
        # Calculate all overlaps
        hek_vs_bid = hek293_sites & bid_sites
        hek_vs_praise = hek293_sites & praise_sites
        hek_vs_either = hek293_sites & praise_or_bid
        hek_vs_both = hek293_sites & praise_bid_overlap
        
        print(f"\nHEK293 DRS overlaps:")
        print(f"  vs BID-seq: {len(hek_vs_bid)} sites ({100*len(hek_vs_bid)/len(hek293_sites) if len(hek293_sites) > 0 else 0:.1f}% of DRS)")
        print(f"  vs PRAISE: {len(hek_vs_praise)} sites ({100*len(hek_vs_praise)/len(hek293_sites) if len(hek293_sites) > 0 else 0:.1f}% of DRS)")
        print(f"  vs either method: {len(hek_vs_either)} sites ({100*len(hek_vs_either)/len(hek293_sites) if len(hek293_sites) > 0 else 0:.1f}% of DRS)")
        print(f"  vs both methods: {len(hek_vs_both)} sites ({100*len(hek_vs_both)/len(hek293_sites) if len(hek293_sites) > 0 else 0:.1f}% of DRS)")
        
        fig = plt.figure(figsize=(12, 4))
        
        ax1 = plt.subplot(1, 3, 1)
        plt.sca(ax1)
        venn2([hek293_sites, praise_or_bid],
              set_labels=('HEK293', 'PRAISE ∪ BID-seq'),
              set_colors=('skyblue', 'salmon'), alpha=0.7)
        ax1.set_title('Ψ: HEK293 vs Orthogonal (Union)', fontweight='bold', fontsize=14)
        
        ax2 = plt.subplot(1, 3, 2)
        plt.sca(ax2)
        venn = venn3([hek293_sites, praise_sites, bid_sites],
                     set_labels=('HEK293', 'PRAISE', 'BID-seq'))
        for patch, color in zip(venn.patches, ['skyblue', 'salmon', 'gold', 'purple', 'orange', 'lightblue', 'gray']):
            if patch:
                patch.set_facecolor(color)
                patch.set_alpha(0.7)
        ax2.set_title('Ψ: Three-way Comparison', fontweight='bold', fontsize=14)
        
        ax3 = plt.subplot(1, 3, 3)
        plt.sca(ax3)
        venn2([praise_sites, bid_sites],
              set_labels=('PRAISE', 'BID-seq'),
              set_colors=('salmon', 'gold'), alpha=0.7)
        ax3.set_title('Ψ: PRAISE vs BID-seq', fontweight='bold', fontsize=14)
        
        fig.suptitle('Pseudouridine (Ψ) Sites - HEK293', fontsize=16, fontweight='bold')
    
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"\nSaved to {output_path}")

# Usage:
save_validated_sites(
    dorado_mods_dict,
    combined_glori_1, 
    combined_glori_2, 
    m5c_orthogonal_df,
    bid_seq_df,
    praise_filtered,
    combined_ino,
    cell_line='HEK293',
    output_dir="/Volumes/AJS_SSD/HEK293/orthogonal_validated/",
    debug=True
)

Saving Validated Sites (DRS ∩ Orthogonal) for HEK293

--- m6A Processing ---
  DRS m6a sample chromosome: chr1
  DRS m6a sample site IDs: ['chr1_14415', 'chr1_14517', 'chr1_14638']
  GLORI-1 sample chromosome: chr10
  GLORI-1 sample site IDs: ['chr10_3777586', 'chr10_3779328', 'chr10_3781837']
  GLORI-2 sample chromosome: chr10
  GLORI-2 sample site IDs: ['chr10_100114427', 'chr10_100150475', 'chr10_100150834']
✓ Saved m6A-GLORI1 validated sites: 23426 sites
✓ Saved m6A-GLORI2 validated sites: 31580 sites
✓ Saved m6A-GLORI combined validated sites: 36357 sites

--- m5C Processing ---
  DRS m5c sample chromosome: chr1
  DRS m5c sample site IDs: ['chr1_14516', 'chr1_14953', 'chr1_15091']
  m5C-orthogonal sample chromosome: 1
  m5C-orthogonal sample site IDs: ['chr1_918816', 'chr1_941220', 'chr1_944475']
✓ Saved m5C validated sites: 59 sites

--- Pseudouridine Processing ---
  DRS psi sample chromosome: chr1
  DRS psi sample site IDs: ['chr1_186579', 'chr1_630836', 'chr1_939369']
  BID-se

True

# Give them hell

In [28]:
dorado_mods_dict['HEK293_inosine']

{'filtered_17596_dataframe.parquet': shape: (8_235_544, 19)
 ┌────────────┬──────────┬──────────┬───────┬───┬────────┬────────┬───────────┬─────────────────────┐
 │ Chromosome ┆ Start    ┆ End      ┆ Call  ┆ … ┆ N_fail ┆ N_diff ┆ N_no_call ┆ Adjusted_Mod_Propor │
 │ ---        ┆ ---      ┆ ---      ┆ ---   ┆   ┆ ---    ┆ ---    ┆ ---       ┆ tion                │
 │ str        ┆ i64      ┆ i64      ┆ str   ┆   ┆ i64    ┆ i64    ┆ i64       ┆ ---                 │
 │            ┆          ┆          ┆       ┆   ┆        ┆        ┆           ┆ f64                 │
 ╞════════════╪══════════╪══════════╪═══════╪═══╪════════╪════════╪═══════════╪═════════════════════╡
 │ chr1       ┆ 14378    ┆ 14379    ┆ 17596 ┆ … ┆ 1      ┆ 0      ┆ 0         ┆ 0.0                 │
 │ chr1       ┆ 14383    ┆ 14384    ┆ 17596 ┆ … ┆ 3      ┆ 0      ┆ 0         ┆ 0.0                 │
 │ chr1       ┆ 14385    ┆ 14386    ┆ 17596 ┆ … ┆ 5      ┆ 0      ┆ 0         ┆ 0.0                 │
 │ chr1       ┆ 14386 

In [29]:
combined_ino.head()

Unnamed: 0,Gene ID,Chromosome,position,strand,coverage,truncated reads,Gene symbol,Location,repeatfamily,replicate
0,ENSG00000284733.1,chr1,492159,-,63,51,"OR4F29(dist=40462),RF00026(dist=24217)",intergenic,./.,ino_1
1,ENSG00000284662.1,chr1,727131,-,29,16,"OR4F16(dist=40458),RNU6-1199P(dist=31102)",intergenic,./.,ino_1
2,ENSG00000284662.1,chr1,727162,-,332,292,"OR4F16(dist=40489),RNU6-1199P(dist=31071)",intergenic,./.,ino_1
3,ENSG00000284662.1,chr1,727708,-,33,19,"OR4F16(dist=41035),RNU6-1199P(dist=30525)",intergenic,./.,ino_1
4,ENSG00000284662.1,chr1,753268,-,6,6,"OR4F16(dist=66595),RNU6-1199P(dist=4965)",intergenic,SINE/Alu,ino_1


In [28]:
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib_venn import venn2
import polars as pl
import pandas as pd

# ============ HELPER FUNCTIONS ============

def get_gm12878_sites(mod_dict, mod):
    """Extract GM12878 dataframe from dictionary and create site set - POLARS VERSION"""
    key = f"GM12878_{mod}"
    if key in mod_dict:
        # Get the first (and likely only) dataframe for this key
        df = list(mod_dict[key].values())[0]
        
        # POLARS FILTERING - use .filter() not []
        if 'Adjusted_Mod_Proportion' in df.columns:
            df_filtered = df.filter(
                (pl.col('Adjusted_Mod_Proportion') >= 20) & (pl.col('Score') >= 20)
            )
        else:
            df_filtered = df.filter(pl.col('Score') >= 20)
        
        # Create sites using Polars - convert to string and concatenate
        sites = set(
            (df_filtered['Chromosome'].cast(pl.Utf8) + '_' + df_filtered['End'].cast(pl.Utf8)).to_list()
        )
        print(f"{key}: {len(sites)} sites after filtering")
        return sites
    else:
        print(f"Warning: {key} not found in dictionary")
        return set()

def process_orthogonal_data(df, chr_col, pos_col, label="Orthogonal"):
    """Create genomic site identifiers from orthogonal datasets - handles both Pandas and Polars"""
    # Better detection: check if it's a Polars DataFrame specifically
    if isinstance(df, pl.DataFrame):  # Polars
        sites = set(
            (df[chr_col].cast(pl.Utf8) + '_' + df[pos_col].cast(pl.Utf8)).to_list()
        )
    else:  # Pandas
        sites = set(df[chr_col].astype(str) + '_' + df[pos_col].astype(str))
    
    print(f"{label}: {len(sites)} sites")
    return sites

# ============ INOSINE PLOTTING FUNCTION ============

def plot_gm12878_inosine(dorado_mods_dict, combined_ino_df,
                         output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/GM12878_inosine_venn.pdf'):
    """
    Create Venn diagram for GM12878 Inosine vs orthogonal validation
    
    Parameters:
    -----------
    dorado_mods_dict : dict
        Dictionary containing DRS data
    combined_ino_df : DataFrame (Pandas)
        Combined orthogonal inosine data with 'Chromosome' and 'position' columns
    output_path : str
        Path to save the output PDF
    """
    print("\n" + "="*60)
    print("=== Processing GM12878 Inosine ===")
    print("="*60)
    
    gm12878_sites = get_gm12878_sites(dorado_mods_dict, 'inosine')
    orthogonal_sites = process_orthogonal_data(combined_ino_df, 'Chromosome', 'position', 'Inosine Orthogonal')
    
    # Calculate overlap
    overlap = gm12878_sites & orthogonal_sites
    print(f"Overlap: {len(overlap)} sites")
    if len(gm12878_sites) > 0 and len(orthogonal_sites) > 0:
        print(f"Overlap %: {100*len(overlap)/len(gm12878_sites):.1f}% of GM12878, {100*len(overlap)/len(orthogonal_sites):.1f}% of Orthogonal")
    
    # Create figure
    fig, ax = plt.subplots(figsize=(4, 4))
    
    venn = venn2([gm12878_sites, orthogonal_sites],
                 set_labels=('GM12878 DRS', 'Orthogonal'),
                 set_colors=('lightgreen', 'salmon'), 
                 alpha=0.7,
                 ax=ax)
    
    ax.set_title('Inosine (A-to-I) Sites:\nGM12878 vs Orthogonal', fontweight='bold', fontsize=14)
    
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"Saved to {output_path}\n")

# ============ USAGE ============
plot_gm12878_inosine(dorado_mods_dict, combined_ino)


=== Processing GM12878 Inosine ===
GM12878_inosine: 12986 sites after filtering
Inosine Orthogonal: 29745 sites
Overlap: 2148 sites
Overlap %: 16.5% of GM12878, 7.2% of Orthogonal
Saved to /Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/GM12878_inosine_venn.pdf



  plt.show()


In [None]:
dorado_mods_dict['GM12878_2OMeA']

In [29]:
def plot_gm12878_2ome_exact_matches_only(dorado_mods_dict, ome_A_df, ome_C_df, ome_G_df, ome_U_df,
                      output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/GM12878_2ome_venns.pdf'):
    """
    Create Venn diagrams for GM12878 2'OMe - EXACT MATCHES ONLY
    NOTE: 2'OMe shows very low overlap between GM12878 and HEK293T compared to other modifications
    """
    print("\n" + "="*60)
    print("=== Processing GM12878 2'O-Methylation (EXACT MATCHES) ===")
    print("="*60)
    print("NOTE: Using exact coordinate matches only")
    print("      2'OMe shows lower conservation across cell lines")
    print("="*60)
    
    # Get DRS sites for each base (using Start to match 1-based coords)
    def get_sites_using_start(mod_type):
        key = f"GM12878_{mod_type}"
        df = list(dorado_mods_dict[key].values())[0]
        if 'Adjusted_Mod_Proportion' in df.columns:
            df_filtered = df.filter((pl.col('Adjusted_Mod_Proportion') >= 20) & (pl.col('Score') >= 20))
        else:
            df_filtered = df.filter(pl.col('Score') >= 20)
        # Use Start (0-based) + 1 = 1-based to match orthogonal
        sites = set((df_filtered['Chromosome'].cast(pl.Utf8) + '_' + 
                    (df_filtered['Start'] + 1).cast(pl.Utf8)).to_list())
        print(f"GM12878_{mod_type}: {len(sites)} sites")
        return sites
    
    gm12878_A_sites = get_sites_using_start('2OMeA')
    gm12878_C_sites = get_sites_using_start('2OMeC')
    gm12878_G_sites = get_sites_using_start('2OMeG')
    gm12878_U_sites = get_sites_using_start('2OMeU')
    
    # Get orthogonal sites (add chr prefix)
    orth_A_sites = set('chr' + ome_A_df['Chr'].astype(str) + '_' + ome_A_df['Position'].astype(str))
    orth_C_sites = set('chr' + ome_C_df['Chr'].astype(str) + '_' + ome_C_df['Position'].astype(str))
    orth_G_sites = set('chr' + ome_G_df['Chr'].astype(str) + '_' + ome_G_df['Position'].astype(str))
    orth_U_sites = set('chr' + ome_U_df['Chr'].astype(str) + '_' + ome_U_df['Position'].astype(str))
    
    print(f"2'OMe-A Orthogonal (HEK293T): {len(orth_A_sites)} sites")
    print(f"2'OMe-C Orthogonal (HEK293T): {len(orth_C_sites)} sites")
    print(f"2'OMe-G Orthogonal (HEK293T): {len(orth_G_sites)} sites")
    print(f"2'OMe-U Orthogonal (HEK293T): {len(orth_U_sites)} sites")
    
    # Calculate overlaps
    overlap_A = gm12878_A_sites & orth_A_sites
    overlap_C = gm12878_C_sites & orth_C_sites
    overlap_G = gm12878_G_sites & orth_G_sites
    overlap_U = gm12878_U_sites & orth_U_sites
    
    print(f"\nExact coordinate overlaps:")
    print(f"  A: {len(overlap_A)} sites")
    print(f"  C: {len(overlap_C)} sites")
    print(f"  G: {len(overlap_G)} sites")
    print(f"  U: {len(overlap_U)} sites")
    print(f"  Total: {len(overlap_A) + len(overlap_C) + len(overlap_G) + len(overlap_U)} sites")
    
    # Combined
    gm12878_all = gm12878_A_sites | gm12878_C_sites | gm12878_G_sites | gm12878_U_sites
    orth_all = orth_A_sites | orth_C_sites | orth_G_sites | orth_U_sites
    overlap_all = gm12878_all & orth_all
    
    print(f"\nCombined: {len(overlap_all)} overlapping sites")
    
    # Create figure
    fig = plt.figure(figsize=(12, 8))
    
    # 1. Combined
    ax1 = plt.subplot(2, 3, 1)
    plt.sca(ax1)
    venn2([gm12878_all, orth_all],
          set_labels=('GM12878 DRS', 'HEK293T Orthogonal'),
          set_colors=('lightgreen', 'salmon'), alpha=0.7)
    ax1.set_title(f"2'OMe Combined\n({len(overlap_all)} exact matches)", fontweight='bold', fontsize=12)
    
    # Individual bases
    ax2 = plt.subplot(2, 3, 2)
    plt.sca(ax2)
    venn2([gm12878_A_sites, orth_A_sites], set_labels=('DRS', 'Orth'),
          set_colors=('lightblue', 'lightsalmon'), alpha=0.7)
    ax2.set_title(f"2'OMe-A ({len(overlap_A)})", fontweight='bold', fontsize=12)
    
    ax3 = plt.subplot(2, 3, 3)
    plt.sca(ax3)
    venn2([gm12878_C_sites, orth_C_sites], set_labels=('DRS', 'Orth'),
          set_colors=('lightcoral', 'gold'), alpha=0.7)
    ax3.set_title(f"2'OMe-C ({len(overlap_C)})", fontweight='bold', fontsize=12)
    
    ax4 = plt.subplot(2, 3, 4)
    plt.sca(ax4)
    venn2([gm12878_G_sites, orth_G_sites], set_labels=('DRS', 'Orth'),
          set_colors=('palegreen', 'plum'), alpha=0.7)
    ax4.set_title(f"2'OMe-G ({len(overlap_G)})", fontweight='bold', fontsize=12)
    
    ax5 = plt.subplot(2, 3, 5)
    plt.sca(ax5)
    venn2([gm12878_U_sites, orth_U_sites], set_labels=('DRS', 'Orth'),
          set_colors=('lightyellow', 'lightpink'), alpha=0.7)
    ax5.set_title(f"2'OMe-U ({len(overlap_U)})", fontweight='bold', fontsize=12)
    
    # Summary
    ax6 = plt.subplot(2, 3, 6)
    ax6.axis('off')
    summary_text = f"""
    2'O-Methylation Summary
    (GM12878 vs HEK293T)
    
    DRS Sites:
      A: {len(gm12878_A_sites):,}
      C: {len(gm12878_C_sites):,}
      G: {len(gm12878_G_sites):,}
      U: {len(gm12878_U_sites):,}
      Total: {len(gm12878_all):,}
    
    Orthogonal Sites:
      A: {len(orth_A_sites):,}
      C: {len(orth_C_sites):,}
      G: {len(orth_G_sites):,}
      U: {len(orth_U_sites):,}
      Total: {len(orth_all):,}
    
    Exact Matches: {len(overlap_all):,}
    Within ±1kb: ~69 sites*
    
    *2'OMe shows lower cell-line
    conservation than other mods
    """
    ax6.text(0.05, 0.5, summary_text, fontsize=9, verticalalignment='center', fontfamily='monospace')
    
    fig.suptitle("2'O-Methylation - GM12878 DRS vs HEK293T Orthogonal\n(Exact coordinate matches only)", 
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"\nSaved to {output_path}")

# Run it
plot_gm12878_2ome_exact_matches_only(dorado_mods_dict, OMe_A, OMe_C, OMe_G, OMe_U)

NameError: name 'OMe_A' is not defined

In [30]:
# Pick a few sample sites and manually check them
sample_sites = OMe_A.head(5)

print("Check these positions in UCSC Genome Browser:")
print("="*70)
for idx, row in sample_sites.iterrows():
    chr_num = row['Chr']
    pos = row['Position']
    gene = row['Gene']
    
    print(f"\nGene: {gene}")
    print(f"  hg38: https://genome.ucsc.edu/cgi-bin/hgTracks?db=hg38&position=chr{chr_num}:{pos-100}-{pos+100}")
    print(f"  hg19: https://genome.ucsc.edu/cgi-bin/hgTracks?db=hg19&position=chr{chr_num}:{pos-100}-{pos+100}")
    print(f"  → Does this position fall INSIDE the gene {gene} in hg38 or hg19?")

NameError: name 'OMe_A' is not defined

In [31]:
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib_venn import venn2, venn3
import pandas as pd

# Apply Genometech Lab style settings
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['font.family'] = 'sans-serif'
mpl.rcParams['font.sans-serif'] = ['Helvetica']
mpl.rcParams['font.size'] = 8
mpl.rcParams['axes.labelsize'] = 14
mpl.rcParams['axes.titlesize'] = 16
mpl.rcParams['axes.titleweight'] = 'bold'
mpl.rcParams['figure.titlesize'] = 16
mpl.rcParams['figure.titleweight'] = 'bold'
mpl.rcParams['figure.figsize'] = [3.25, 2.25]
mpl.rcParams['figure.dpi'] = 600
mpl.rcParams['savefig.dpi'] = 600
mpl.rcParams['savefig.transparent'] = True
mpl.rcParams['savefig.bbox'] = 'tight'
mpl.rcParams['legend.fontsize'] = 8
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.linewidth'] = 0.8
mpl.rcParams['xtick.labelsize'] = 12
mpl.rcParams['ytick.labelsize'] = 12

# ============ UTILITY FUNCTIONS ============

def get_drs_sites(mod_dict, cell_line, mod):
    """Extract dataframe from dictionary and create site set"""
    key = f"{cell_line}_{mod}"
    if key in mod_dict:
        # Get the first (and likely only) dataframe for this key
        df = list(mod_dict[key].values())[0]
        
        # Filter and create sites
        if 'Adjusted_Mod_Proportion' in df.columns:
            df_filtered = df[(df['Adjusted_Mod_Proportion'] >= 20) & (df['Score'] >= 20)]
        else:
            df_filtered = df[df['Score'] >= 20]
        
        sites = set(df_filtered['Chromosome'] + '_' + df_filtered['End'].astype(str))
        print(f"{key}: {len(sites)} sites after filtering")
        return sites
    else:
        print(f"Warning: {key} not found in dictionary")
        return set()

def process_orthogonal_data(df, chr_col, pos_col):
    """Create genomic site identifiers from orthogonal datasets"""
    sites = set(df[chr_col] + '_' + df[pos_col].astype(str))
    return sites

# ============ M5C PLOTTING FUNCTION ============

def plot_m5c_venns(dorado_mods_dict, m5c_orthogonal_df, output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/m5c_venns.pdf'):
    """
    Create comprehensive Venn diagram plots for m5C modification sites
    
    Parameters:
    -----------
    dorado_mods_dict : dict
        Dictionary containing DRS data for both cell lines
    m5c_orthogonal_df : DataFrame
        Orthogonal validation data with 'chromosome' and 'position' columns
    output_path : str
        Path to save the output PDF
    """
    print("\nProcessing m5C datasets...")
    hek293_m5c_sites = get_drs_sites(dorado_mods_dict, 'HEK293', 'm5c')
    gm12878_m5c_sites = get_drs_sites(dorado_mods_dict, 'GM12878', 'm5c')
    m5c_orthogonal_sites = process_orthogonal_data(m5c_orthogonal_df, 'chromosome', 'position')

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(6.5, 6))

    # Three-way comparison
    plt.sca(ax1)
    venn = venn3([hek293_m5c_sites, gm12878_m5c_sites, m5c_orthogonal_sites],
          set_labels=('HEK293 DRS', 'GM12878 DRS', 'Orthogonal'))
    for patch, color in zip(venn.patches, ['skyblue', 'lightgreen', 'salmon', 'purple', 'gold', 'orange', 'gray']):
        if patch:
            patch.set_facecolor(color)
            patch.set_alpha(0.7)
    plt.text(-0.6, 0.6, f"Total: {len(hek293_m5c_sites)}", fontsize=8, color='blue', ha='center')
    plt.text(0.6, 0.5, f"Total: {len(gm12878_m5c_sites)}", fontsize=8, color='green', ha='center')
    plt.text(0.0, -0.85, f"Total: {len(m5c_orthogonal_sites)}", fontsize=8, color='purple', ha='center')
    ax1.set_title('m5C: Three-way Comparison', fontweight='bold', fontsize=14)

    # Cell line comparison
    plt.sca(ax2)
    venn2([hek293_m5c_sites, gm12878_m5c_sites],
          set_labels=('HEK293', 'GM12878'),
          set_colors=('skyblue', 'lightgreen'), alpha=0.7)
    ax2.set_title('m5C: Cell Line Comparison', fontweight='bold', fontsize=14)

    # HEK293 vs Orthogonal
    plt.sca(ax3)
    venn2([hek293_m5c_sites, m5c_orthogonal_sites],
          set_labels=('HEK293 DRS', 'Orthogonal'),
          set_colors=('skyblue', 'salmon'), alpha=0.7)
    ax3.set_title('m5C: HEK293 vs Orthogonal', fontweight='bold', fontsize=14)

    # GM12878 vs Orthogonal
    plt.sca(ax4)
    venn2([gm12878_m5c_sites, m5c_orthogonal_sites],
          set_labels=('GM12878 DRS', 'Orthogonal'),
          set_colors=('lightgreen', 'salmon'), alpha=0.7)
    ax4.set_title('m5C: GM12878 vs Orthogonal', fontweight='bold', fontsize=14)

    fig.suptitle('m5C Sites', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"Saved m5C Venn diagrams to {output_path}")

# ============ PSEUDOURIDINE PLOTTING FUNCTION ============

def plot_psi_venns(dorado_mods_dict, bid_seq_df, praise_filtered, output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/psi_venns.pdf'):
    """
    Create comprehensive Venn diagram plots for Pseudouridine (Ψ) modification sites
    
    Parameters:
    -----------
    dorado_mods_dict : dict
        Dictionary containing DRS data for both cell lines
    bid_seq_df : DataFrame
        BID-seq orthogonal data with 'chr' and 'pos' columns
    praise_filtered : DataFrame
        PRAISE orthogonal data with 'chromosome' and 'genomic_position' columns
    output_path : str
        Path to save the output PDF
    """
    print("\nProcessing Psi datasets...")
    hek293_psi_sites = get_drs_sites(dorado_mods_dict, 'HEK293', 'psi')
    gm12878_psi_sites = get_drs_sites(dorado_mods_dict, 'GM12878', 'psi')
    bid_sites = process_orthogonal_data(bid_seq_df, 'chr', 'pos')
    praise_sites = process_orthogonal_data(praise_filtered, 'chromosome', 'genomic_position')
    praise_bid_combined = praise_sites & bid_sites

    fig = plt.figure(figsize=(9.75, 6))

    # Three-way comparison (DRS vs combined orthogonal)
    ax1 = plt.subplot(2, 3, 1)
    plt.sca(ax1)
    venn = venn3([hek293_psi_sites, gm12878_psi_sites, praise_bid_combined],
          set_labels=('HEK293 DRS', 'GM12878 DRS', 'PRAISE & BID-seq'))
    for patch, color in zip(venn.patches, ['skyblue', 'lightgreen', 'salmon', 'purple', 'gold', 'orange', 'gray']):
        if patch:
            patch.set_facecolor(color)
            patch.set_alpha(0.7)
    plt.text(-0.6, 0.6, f"Total: {len(hek293_psi_sites)}", fontsize=8, color='blue', ha='center')
    plt.text(0.6, 0.5, f"Total: {len(gm12878_psi_sites)}", fontsize=8, color='green', ha='center')
    plt.text(0.0, -0.85, f"Total: {len(praise_bid_combined)}", fontsize=8, color='purple', ha='center')
    ax1.set_title('Ψ: DRS vs Orthogonal', fontweight='bold', fontsize=14)

    # Cell line comparison
    ax2 = plt.subplot(2, 3, 2)
    plt.sca(ax2)
    venn2([hek293_psi_sites, gm12878_psi_sites],
          set_labels=('HEK293', 'GM12878'),
          set_colors=('skyblue', 'lightgreen'), alpha=0.7)
    ax2.set_title('Ψ: Cell Line Comparison', fontweight='bold', fontsize=14)

    # Orthogonal methods comparison
    ax3 = plt.subplot(2, 3, 3)
    plt.sca(ax3)
    venn2([praise_sites, bid_sites],
          set_labels=('PRAISE', 'BID-seq'),
          set_colors=('salmon', 'gold'), alpha=0.7)
    ax3.set_title('Ψ: PRAISE vs BID-seq', fontweight='bold', fontsize=14)

    # HEK293 All Methods
    ax4 = plt.subplot(2, 3, 4)
    plt.sca(ax4)
    venn = venn3([hek293_psi_sites, praise_sites, bid_sites],
          set_labels=('HEK293 DRS', 'PRAISE', 'BID-seq'))
    for patch, color in zip(venn.patches, ['skyblue', 'salmon', 'gold', 'purple', 'orange', 'lightgreen', 'gray']):
        if patch:
            patch.set_facecolor(color)
            patch.set_alpha(0.7)
    ax4.set_title('Ψ: HEK293 All Methods', fontweight='bold', fontsize=14)

    # HEK293 vs Combined Orthogonal
    ax5 = plt.subplot(2, 3, 5)
    plt.sca(ax5)
    venn2([hek293_psi_sites, praise_bid_combined],
          set_labels=('HEK293 DRS', 'PRAISE & BID-seq'),
          set_colors=('skyblue', 'salmon'), alpha=0.7)
    ax5.set_title('Ψ: HEK293 vs Orthogonal', fontweight='bold', fontsize=14)

    # GM12878 vs Combined Orthogonal
    ax6 = plt.subplot(2, 3, 6)
    plt.sca(ax6)
    venn2([gm12878_psi_sites, praise_bid_combined],
          set_labels=('GM12878 DRS', 'PRAISE & BID-seq'),
          set_colors=('lightgreen', 'salmon'), alpha=0.7)
    ax6.set_title('Ψ: GM12878 vs Orthogonal', fontweight='bold', fontsize=14)

    fig.suptitle('Pseudouridine (Ψ) Sites', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"Saved Pseudouridine Venn diagrams to {output_path}")

# ============ M6A PLOTTING FUNCTION ============

def plot_m6a_venns(dorado_mods_dict, combined_glori_1, combined_glori_2_df, output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/m6a_venns.pdf'):
    """
    Create comprehensive Venn diagram plots for m6A modification sites
    
    Parameters:
    -----------
    dorado_mods_dict : dict
        Dictionary containing DRS data for both cell lines
    combined_glori_1 : DataFrame
        GLORI-1 orthogonal data with 'Chr' and 'Sites' columns
    combined_glori_2_df : DataFrame
        GLORI-2 orthogonal data with 'Chr' and 'Site' columns
    output_path : str
        Path to save the output PDF
    """
    print("\nProcessing m6A datasets...")
    hek293_m6a_sites = get_drs_sites(dorado_mods_dict, 'HEK293', 'm6a')
    gm12878_m6a_sites = get_drs_sites(dorado_mods_dict, 'GM12878', 'm6a')
    glori1_sites = process_orthogonal_data(combined_glori_1, 'Chr', 'Sites')
    glori2_sites = process_orthogonal_data(combined_glori_2_df, 'Chr', 'Site')
    glori_combined = glori1_sites & glori2_sites

    fig = plt.figure(figsize=(9.75, 6))

    # Three-way comparison
    ax1 = plt.subplot(2, 3, 1)
    plt.sca(ax1)
    venn = venn3([hek293_m6a_sites, gm12878_m6a_sites, glori_combined],
          set_labels=('HEK293 DRS', 'GM12878 DRS', 'GLORI'))
    for patch, color in zip(venn.patches, ['skyblue', 'lightgreen', 'salmon', 'purple', 'gold', 'orange', 'gray']):
        if patch:
            patch.set_facecolor(color)
            patch.set_alpha(0.7)
    plt.text(-0.6, 0.6, f"Total: {len(hek293_m6a_sites)}", fontsize=8, color='blue', ha='center')
    plt.text(0.6, 0.5, f"Total: {len(gm12878_m6a_sites)}", fontsize=8, color='green', ha='center')
    plt.text(0.0, -0.85, f"Total: {len(glori_combined)}", fontsize=8, color='purple', ha='center')
    ax1.set_title('m6A: DRS vs GLORI', fontweight='bold', fontsize=14)

    # Cell line comparison
    ax2 = plt.subplot(2, 3, 2)
    plt.sca(ax2)
    venn2([hek293_m6a_sites, gm12878_m6a_sites],
          set_labels=('HEK293', 'GM12878'),
          set_colors=('skyblue', 'lightgreen'), alpha=0.7)
    ax2.set_title('m6A: Cell Line Comparison', fontweight='bold', fontsize=14)

    # GLORI comparison
    ax3 = plt.subplot(2, 3, 3)
    plt.sca(ax3)
    venn2([glori1_sites, glori2_sites],
          set_labels=('GLORI-1', 'GLORI-2'),
          set_colors=('salmon', 'gold'), alpha=0.7)
    ax3.set_title('m6A: GLORI-1 vs GLORI-2', fontweight='bold', fontsize=14)

    # HEK293 All Methods
    ax4 = plt.subplot(2, 3, 4)
    plt.sca(ax4)
    venn = venn3([hek293_m6a_sites, glori1_sites, glori2_sites],
          set_labels=('HEK293 DRS', 'GLORI-1', 'GLORI-2'))
    for patch, color in zip(venn.patches, ['skyblue', 'salmon', 'gold', 'purple', 'orange', 'lightgreen', 'gray']):
        if patch:
            patch.set_facecolor(color)
            patch.set_alpha(0.7)
    ax4.set_title('m6A: HEK293 All Methods', fontweight='bold', fontsize=14)

    # HEK293 vs GLORI Combined
    ax5 = plt.subplot(2, 3, 5)
    plt.sca(ax5)
    venn2([hek293_m6a_sites, glori_combined],
          set_labels=('HEK293 DRS', 'GLORI'),
          set_colors=('skyblue', 'salmon'), alpha=0.7)
    ax5.set_title('m6A: HEK293 vs GLORI', fontweight='bold', fontsize=14)

    # GM12878 vs GLORI Combined
    ax6 = plt.subplot(2, 3, 6)
    plt.sca(ax6)
    venn2([gm12878_m6a_sites, glori_combined],
          set_labels=('GM12878 DRS', 'GLORI'),
          set_colors=('lightgreen', 'salmon'), alpha=0.7)
    ax6.set_title('m6A: GM12878 vs GLORI', fontweight='bold', fontsize=14)

    fig.suptitle('m6A Sites', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"Saved m6A Venn diagrams to {output_path}")

# ============ USAGE EXAMPLE ============
# Simply call the functions with your data:

# Plot m5C
plot_m5c_venns(dorado_mods_dict, m5c_orthogonal_df)

# Plot Pseudouridine
plot_psi_venns(dorado_mods_dict, bid_seq_df, praise_filtered)

# Plot m6A
plot_m6a_venns(dorado_mods_dict, combined_glori_1, combined_glori_2_df)


Processing m5C datasets...


ValueError: expected 19 values when selecting columns by boolean mask, got 7161645

In [35]:
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib_venn import venn2, venn3
import pandas as pd
import polars as pl

# ============ STYLE SETTINGS ============
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['font.family'] = 'sans-serif'
mpl.rcParams['font.sans-serif'] = ['Helvetica']
mpl.rcParams['font.size'] = 8
mpl.rcParams['axes.labelsize'] = 14
mpl.rcParams['axes.titlesize'] = 16
mpl.rcParams['axes.titleweight'] = 'bold'
mpl.rcParams['figure.titlesize'] = 16
mpl.rcParams['figure.titleweight'] = 'bold'
mpl.rcParams['figure.figsize'] = [3.25, 2.25]
mpl.rcParams['figure.dpi'] = 600
mpl.rcParams['savefig.dpi'] = 600
mpl.rcParams['savefig.transparent'] = True
mpl.rcParams['savefig.bbox'] = 'tight'
mpl.rcParams['legend.fontsize'] = 8
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.linewidth'] = 0.8
mpl.rcParams['xtick.labelsize'] = 12
mpl.rcParams['ytick.labelsize'] = 12

# ============ UTILITY FUNCTIONS ============

def get_gm12878_sites(mod_dict, mod):
    """Extract GM12878 dataframe from dictionary and create site set - POLARS VERSION"""
    key = f"GM12878_{mod}"
    if key in mod_dict:
        # Get the first (and likely only) dataframe for this key
        df = list(mod_dict[key].values())[0]
        
        # POLARS FILTERING - use .filter() not []
        if 'Adjusted_Mod_Proportion' in df.columns:
            df_filtered = df.filter(
                (pl.col('Adjusted_Mod_Proportion') >= 20) & (pl.col('Score') >= 20)
            )
        else:
            df_filtered = df.filter(pl.col('Score') >= 20)
        
        # Create sites using Polars - convert to string and concatenate
        sites = set(
            (df_filtered['Chromosome'].cast(pl.Utf8) + '_' + df_filtered['End'].cast(pl.Utf8)).to_list()
        )
        print(f"{key}: {len(sites)} sites after filtering")
        return sites
    else:
        print(f"Warning: {key} not found in dictionary")
        return set()

def process_orthogonal_data_smart(df, chr_col, pos_col, label="Orthogonal"):
    """
    Create genomic site identifiers - AUTOMATICALLY handles chromosome naming AND float positions
    Adds 'chr' prefix if needed, converts floats to ints
    """
    if isinstance(df, pl.DataFrame):  # Polars
        sites = set(
            (df[chr_col].cast(pl.Utf8) + '_' + df[pos_col].cast(pl.Int64).cast(pl.Utf8)).to_list()
        )
    else:  # Pandas
        # Check if chromosomes already have 'chr' prefix
        sample_chr = str(df[chr_col].iloc[0])
        
        # Convert position to int first (handles floats like 44909262.0 → 44909262)
        pos_int = df[pos_col].astype(int).astype(str)
        
        if sample_chr.startswith('chr'):
            # Already has prefix
            sites = set(df[chr_col].astype(str) + '_' + pos_int)
        else:
            # Add prefix
            sites = set('chr' + df[chr_col].astype(str) + '_' + pos_int)
    
    print(f"{label}: {len(sites)} sites")
    return sites

# ============ M5C PLOTTING FUNCTION ============

def plot_gm12878_m5c(dorado_mods_dict, m5c_orthogonal_df, 
                     output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/GM12878_m5c_venn.pdf'):
    """
    Create Venn diagram for GM12878 m5C vs orthogonal validation
    """
    print("\n" + "="*60)
    print("=== Processing GM12878 m5C ===")
    print("="*60)
    
    gm12878_sites = get_gm12878_sites(dorado_mods_dict, 'm5c')
    orthogonal_sites = process_orthogonal_data_smart(m5c_orthogonal_df, 'chromosome', 'position', 'm5C Orthogonal')
    
    # Calculate overlap
    overlap = gm12878_sites & orthogonal_sites
    print(f"Overlap: {len(overlap)} sites")
    if len(gm12878_sites) > 0 and len(orthogonal_sites) > 0:
        print(f"Overlap %: {100*len(overlap)/len(gm12878_sites):.1f}% of GM12878, {100*len(overlap)/len(orthogonal_sites):.1f}% of Orthogonal")
    
    # Create figure
    fig, ax = plt.subplots(figsize=(4, 4))
    
    venn = venn2([gm12878_sites, orthogonal_sites],
                 set_labels=('GM12878 DRS', 'Orthogonal'),
                 set_colors=('lightgreen', 'salmon'), 
                 alpha=0.7,
                 ax=ax)
    
    ax.set_title('m5C Sites: GM12878 vs Orthogonal', fontweight='bold', fontsize=14)
    
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"Saved to {output_path}\n")

# ============ PSEUDOURIDINE PLOTTING FUNCTION ============

def plot_gm12878_psi(dorado_mods_dict, bid_seq_df, praise_filtered,
                     output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/GM12878_psi_venns.pdf'):
    """
    Create Venn diagrams for GM12878 Pseudouridine vs orthogonal validation
    """
    print("\n" + "="*60)
    print("=== Processing GM12878 Pseudouridine ===")
    print("="*60)
    
    gm12878_sites = get_gm12878_sites(dorado_mods_dict, 'psi')
    bid_sites = process_orthogonal_data_smart(bid_seq_df, 'chr', 'pos', 'BID-seq')
    praise_sites = process_orthogonal_data_smart(praise_filtered, 'chromosome', 'genomic_position', 'PRAISE')
    praise_bid_combined = praise_sites & bid_sites
    
    print(f"PRAISE & BID-seq overlap: {len(praise_bid_combined)} sites")
    
    # Calculate overlaps
    gm_vs_combined = gm12878_sites & praise_bid_combined
    gm_vs_bid = gm12878_sites & bid_sites
    gm_vs_praise = gm12878_sites & praise_sites
    
    print(f"GM12878 vs BID-seq: {len(gm_vs_bid)} sites")
    print(f"GM12878 vs PRAISE: {len(gm_vs_praise)} sites")
    print(f"GM12878 vs Combined Orthogonal: {len(gm_vs_combined)} sites")
    
    # Create figure with 3 subplots
    fig = plt.figure(figsize=(9.75, 3.5))
    
    # GM12878 vs Combined Orthogonal
    ax1 = plt.subplot(1, 3, 1)
    plt.sca(ax1)
    venn2([gm12878_sites, praise_bid_combined],
          set_labels=('GM12878 DRS', 'PRAISE & BID-seq'),
          set_colors=('lightgreen', 'salmon'), 
          alpha=0.7)
    ax1.set_title('Ψ: GM12878 vs Orthogonal', fontweight='bold', fontsize=14)
    
    # Three-way comparison
    ax2 = plt.subplot(1, 3, 2)
    plt.sca(ax2)
    venn = venn3([gm12878_sites, praise_sites, bid_sites],
                 set_labels=('GM12878 DRS', 'PRAISE', 'BID-seq'))
    for patch, color in zip(venn.patches, ['lightgreen', 'salmon', 'gold', 'purple', 'orange', 'lightblue', 'gray']):
        if patch:
            patch.set_facecolor(color)
            patch.set_alpha(0.7)
    ax2.set_title('Ψ: Three-way Comparison', fontweight='bold', fontsize=14)
    
    # PRAISE vs BID-seq
    ax3 = plt.subplot(1, 3, 3)
    plt.sca(ax3)
    venn2([praise_sites, bid_sites],
          set_labels=('PRAISE', 'BID-seq'),
          set_colors=('salmon', 'gold'), 
          alpha=0.7)
    ax3.set_title('Ψ: PRAISE vs BID-seq', fontweight='bold', fontsize=14)
    
    fig.suptitle('Pseudouridine (Ψ) Sites - GM12878', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"Saved to {output_path}\n")

# ============ M6A PLOTTING FUNCTION ============

def plot_gm12878_m6a(dorado_mods_dict, combined_glori_1, combined_glori_2_df,
                     output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/GM12878_m6a_venns.pdf'):
    """
    Create Venn diagrams for GM12878 m6A vs GLORI validation
    """
    print("\n" + "="*60)
    print("=== Processing GM12878 m6A ===")
    print("="*60)
    
    gm12878_sites = get_gm12878_sites(dorado_mods_dict, 'm6a')
    glori1_sites = process_orthogonal_data_smart(combined_glori_1, 'Chr', 'Sites', 'GLORI-1')
    glori2_sites = process_orthogonal_data_smart(combined_glori_2_df, 'Chr', 'Site', 'GLORI-2')
    glori_combined = glori1_sites & glori2_sites
    
    print(f"GLORI combined: {len(glori_combined)} sites")
    
    # Calculate overlaps
    gm_vs_combined = gm12878_sites & glori_combined
    gm_vs_glori1 = gm12878_sites & glori1_sites
    gm_vs_glori2 = gm12878_sites & glori2_sites
    
    print(f"GM12878 vs GLORI-1: {len(gm_vs_glori1)} sites")
    print(f"GM12878 vs GLORI-2: {len(gm_vs_glori2)} sites")
    print(f"GM12878 vs Combined GLORI: {len(gm_vs_combined)} sites")
    
    # Create figure with 3 subplots
    fig = plt.figure(figsize=(9.75, 3.5))
    
    # GM12878 vs Combined GLORI
    ax1 = plt.subplot(1, 3, 1)
    plt.sca(ax1)
    venn2([gm12878_sites, glori_combined],
          set_labels=('GM12878 DRS', 'GLORI'),
          set_colors=('lightgreen', 'salmon'), 
          alpha=0.7)
    ax1.set_title('m6A: GM12878 vs GLORI', fontweight='bold', fontsize=14)
    
    # Three-way comparison
    ax2 = plt.subplot(1, 3, 2)
    plt.sca(ax2)
    venn = venn3([gm12878_sites, glori1_sites, glori2_sites],
                 set_labels=('GM12878 DRS', 'GLORI-1', 'GLORI-2'))
    for patch, color in zip(venn.patches, ['lightgreen', 'salmon', 'gold', 'purple', 'orange', 'lightblue', 'gray']):
        if patch:
            patch.set_facecolor(color)
            patch.set_alpha(0.7)
    ax2.set_title('m6A: Three-way Comparison', fontweight='bold', fontsize=14)
    
    # GLORI-1 vs GLORI-2
    ax3 = plt.subplot(1, 3, 3)
    plt.sca(ax3)
    venn2([glori1_sites, glori2_sites],
          set_labels=('GLORI-1', 'GLORI-2'),
          set_colors=('salmon', 'gold'), 
          alpha=0.7)
    ax3.set_title('m6A: GLORI-1 vs GLORI-2', fontweight='bold', fontsize=14)
    
    fig.suptitle('m6A Sites - GM12878', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"Saved to {output_path}\n")

# ============ INOSINE PLOTTING FUNCTION ============

def plot_gm12878_inosine(dorado_mods_dict, combined_ino_df,
                         output_path='/Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/GM12878_inosine_venn.pdf'):
    """
    Create Venn diagram for GM12878 Inosine vs orthogonal validation
    """
    print("\n" + "="*60)
    print("=== Processing GM12878 Inosine ===")
    print("="*60)
    
    gm12878_sites = get_gm12878_sites(dorado_mods_dict, 'inosine')
    orthogonal_sites = process_orthogonal_data_smart(combined_ino_df, 'Chromosome', 'position', 'Inosine Orthogonal')
    
    # Calculate overlap
    overlap = gm12878_sites & orthogonal_sites
    print(f"Overlap: {len(overlap)} sites")
    if len(gm12878_sites) > 0 and len(orthogonal_sites) > 0:
        print(f"Overlap %: {100*len(overlap)/len(gm12878_sites):.1f}% of GM12878, {100*len(overlap)/len(orthogonal_sites):.1f}% of Orthogonal")
    
    # Create figure
    fig, ax = plt.subplots(figsize=(4, 4))
    
    venn = venn2([gm12878_sites, orthogonal_sites],
                 set_labels=('GM12878 DRS', 'Orthogonal'),
                 set_colors=('lightgreen', 'salmon'), 
                 alpha=0.7,
                 ax=ax)
    
    ax.set_title('Inosine (A-to-I) Sites:\nGM12878 vs Orthogonal', fontweight='bold', fontsize=14)
    
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', dpi=600, transparent=True, bbox_inches='tight')
    plt.show()
    print(f"Saved to {output_path}\n")

# ============================================================================
# USAGE: Run all GM12878 comparison plots
# ============================================================================

print("\n" + "="*70)
print("GENERATING ALL GM12878 VENN DIAGRAMS")
print("="*70)

# 1. m5C
plot_gm12878_m5c(dorado_mods_dict, m5c_orthogonal_df)

# 2. Pseudouridine
plot_gm12878_psi(dorado_mods_dict, bid_seq_df, praise_filtered)

# 3. m6A
plot_gm12878_m6a(dorado_mods_dict, combined_glori_1, combined_glori_2)

# 4. Inosine
plot_gm12878_inosine(dorado_mods_dict, combined_ino)

print("\n" + "="*70)
print("✓ ALL PLOTS COMPLETE!")
print("="*70)
print("\nGenerated files:")
print("  1. GM12878_m5c_venn.pdf")
print("  2. GM12878_psi_venns.pdf")
print("  3. GM12878_m6a_venns.pdf")
print("  4. GM12878_inosine_venn.pdf")

1 extra bytes in post.stringData array
'created' timestamp seems very low; regarding as unix timestamp
  plt.show()
1 extra bytes in post.stringData array
'created' timestamp seems very low; regarding as unix timestamp



GENERATING ALL GM12878 VENN DIAGRAMS

=== Processing GM12878 m5C ===
GM12878_m5c: 31473 sites after filtering
m5C Orthogonal: 2191 sites
Overlap: 159 sites
Overlap %: 0.5% of GM12878, 7.3% of Orthogonal
Saved to /Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/GM12878_m5c_venn.pdf


=== Processing GM12878 Pseudouridine ===
GM12878_psi: 3877 sites after filtering
BID-seq: 543 sites
PRAISE: 1801 sites
PRAISE & BID-seq overlap: 1 sites
GM12878 vs BID-seq: 50 sites
GM12878 vs PRAISE: 5 sites
GM12878 vs Combined Orthogonal: 1 sites
Saved to /Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/GM12878_psi_venns.pdf


=== Processing GM12878 m6A ===
GM12878_m6a: 71468 sites after filtering
GLORI-1: 60462 sites
GLORI-2: 65687 sites
GLORI combined: 38039 sites
GM12878 vs GLORI-1: 24385 sites
GM12878 vs GLORI-2: 31802 sites
GM12878 vs Combined GLORI: 20371 sites


  plt.show()
1 extra bytes in post.stringData array
'created' timestamp seems very low; regarding as unix timestamp
  plt.show()
1 extra bytes in post.stringData array
'created' timestamp seems very low; regarding as unix timestamp


Saved to /Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/GM12878_m6a_venns.pdf


=== Processing GM12878 Inosine ===
GM12878_inosine: 12986 sites after filtering
Inosine Orthogonal: 29745 sites
Overlap: 2148 sites
Overlap %: 16.5% of GM12878, 7.2% of Orthogonal
Saved to /Volumes/AJS_SSD/HEK293/scripts/notebooks/Plots/GM12878_inosine_venn.pdf


✓ ALL PLOTS COMPLETE!

Generated files:
  1. GM12878_m5c_venn.pdf
  2. GM12878_psi_venns.pdf
  3. GM12878_m6a_venns.pdf
  4. GM12878_inosine_venn.pdf


  plt.show()


In [33]:
# Check m5C orthogonal chromosome naming
print("="*70)
print("M5C ORTHOGONAL DATA CHECK")
print("="*70)
print("Columns:", m5c_orthogonal_df.columns.tolist())
print("\nFirst 5 rows:")
print(m5c_orthogonal_df[['chromosome', 'position']].head())
print("\nUnique chromosomes:")
print(sorted(m5c_orthogonal_df['chromosome'].unique()[:20]))

# Check Psi orthogonal chromosome naming
print("\n" + "="*70)
print("PSI ORTHOGONAL DATA CHECK - BID-seq")
print("="*70)
print("Columns:", bid_seq_df.columns.tolist())
print("\nFirst 5 rows:")
print(bid_seq_df[['chr', 'pos']].head())
print("\nUnique chromosomes:")
print(sorted(bid_seq_df['chr'].unique()[:20]))

print("\n" + "="*70)
print("PSI ORTHOGONAL DATA CHECK - PRAISE")
print("="*70)
print("Columns:", praise_filtered.columns.tolist())
print("\nFirst 5 rows:")
print(praise_filtered[['chromosome', 'genomic_position']].head())
print("\nUnique chromosomes:")
print(sorted(praise_filtered['chromosome'].unique()[:20]))

# Compare to DRS
print("\n" + "="*70)
print("DRS CHROMOSOME NAMING (for reference)")
print("="*70)
drs_sample = list(dorado_mods_dict['GM12878_m5c'].values())[0]
print("DRS chromosomes sample:", drs_sample['Chromosome'].unique()[:10].to_list())

M5C ORTHOGONAL DATA CHECK
Columns: ['chromosome', 'position', 'strand', 'gene_type', 'gene_name', 'gene_pos', 'unconverted', 'converted', 'ratio']

First 5 rows:
    chromosome position
213          1   918816
214          1   941220
215          1   944475
216          1   999131
217          1   999142

Unique chromosomes:
['1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '3', '4', '5', '6', '7']

PSI ORTHOGONAL DATA CHECK - BID-seq
Columns: ['chr', 'pos', 'name', 'refseq', 'seg', 'strand', 'Deletion_rep1', 'Deletion_rep2', 'Deletion_rep3', 'Deletion_Ave', 'Motif_1', 'Motif_2', 'Frac_rep1 %', 'Frac_rep2 %', 'Frac_rep3 %', 'Frac_Ave %', 'Deletion count_rep1', 'Deletion count_rep2', 'Deletion count_rep3']

First 5 rows:
2    chr        pos
0   chr1   43450939
1   chrX   78129969
2  chr14   69380270
3   chr2   27326828
4  chr12  124012693

Unique chromosomes:
['chr1', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr16', 'chr17', 'chr18', 'chr19', 

In [None]:
import polars as pl

# ============ INSPECT DRS DATA ============

def inspect_drs_data(dorado_mods_dict, cell_line, mod_type):
    """Inspect a specific DRS dataframe"""
    key = f"{cell_line}_{mod_type}"
    
    if key in dorado_mods_dict:
        df = list(dorado_mods_dict[key].values())[0]
        print(f"\n{'='*60}")
        print(f"DRS Data: {key}")
        print(f"{'='*60}")
        print(f"Shape: {df.shape}")
        print(f"Columns: {df.columns}")
        print(f"\nFirst 5 rows:")
        print(df.head())
        print(f"\nData types:")
        print(df.dtypes)
        return df
    else:
        print(f"Warning: {key} not found in dictionary")
        return None

# ============ EXTRACT GM12878 2'OMe DATA ============

print("\n" + "="*70)
print("EXTRACTING GM12878 2'OMe DATA")
print("="*70)

# Extract all 4 types of 2'OMe modifications
gm12878_2ome_A = inspect_drs_data(dorado_mods_dict, 'GM12878', '2OMeA')
gm12878_2ome_C = inspect_drs_data(dorado_mods_dict, 'GM12878', '2OMeC')
gm12878_2ome_G = inspect_drs_data(dorado_mods_dict, 'GM12878', '2OMeG')
gm12878_2ome_U = inspect_drs_data(dorado_mods_dict, 'GM12878', '2OMeU')

# ============ EXTRACT GM12878 INOSINE DATA ============

print("\n" + "="*70)
print("EXTRACTING GM12878 INOSINE DATA")
print("="*70)

gm12878_inosine = inspect_drs_data(dorado_mods_dict, 'GM12878', 'inosine')

# ============ INSPECT ORTHOGONAL DATA ============

print("\n" + "="*70)
print("ORTHOGONAL DATA INSPECTION")
print("="*70)

# Check what orthogonal datasets you have available
print("\nPlease provide information about your orthogonal datasets:")
print("\n1. For 2'OMe orthogonal data:")
print("   - Variable name: ?")
print("   - Run: print(your_2ome_orthogonal_df.head())")
print("   - Run: print(your_2ome_orthogonal_df.columns)")

print("\n2. For Inosine orthogonal data:")
print("   - Variable name: ?")
print("   - Run: print(your_inosine_orthogonal_df.head())")
print("   - Run: print(your_inosine_orthogonal_df.columns)")

# ============ HELPER FUNCTION TO INSPECT ANY DATAFRAME ============

def inspect_orthogonal_data(df, data_name):
    """Inspect orthogonal dataset structure"""
    print(f"\n{'='*60}")
    print(f"Orthogonal Data: {data_name}")
    print(f"{'='*60}")
    
    # Check if Polars or Pandas
    if hasattr(df, 'filter'):
        print("Type: Polars DataFrame")
        print(f"Shape: {df.shape}")
        print(f"Columns: {df.columns}")
        print(f"\nFirst 5 rows:")
        print(df.head())
    else:
        print("Type: Pandas DataFrame")
        print(f"Shape: {df.shape}")
        print(f"Columns: {list(df.columns)}")
        print(f"\nFirst 5 rows:")
        print(df.head())
    
    return df

# ============ EXAMPLE USAGE FOR ORTHOGONAL DATA ============

# Once you identify your orthogonal datasets, run:
# inspect_orthogonal_data(your_2ome_df, "2'OMe Orthogonal")
# inspect_orthogonal_data(your_inosine_df, "Inosine Orthogonal")

print("\n" + "="*70)
print("NEXT STEPS:")
print("="*70)
print("1. Run the code above to see your DRS data structure")
print("2. Tell me the variable names for your orthogonal datasets:")
print("   - 2'OMe orthogonal data variable name?")
print("   - Inosine orthogonal data variable name?")
print("3. Use inspect_orthogonal_data() to show me their structure")
print("4. Let me know:")
print("   - Which chromosome column to use (chr, Chr, chromosome, etc.)?")
print("   - Which position column to use (pos, position, Site, etc.)?")
print("   - For 2'OMe: Does it have all bases in one df or separate dfs?")

# Old Shit

In [None]:
merged

In [None]:
# Merge the datasets
merged = GLORI1.merge(GLORI2, how="inner", on=['Chr', 'Sites', 'Strand'], suffixes=('_rep1', '_rep2'))

In [None]:
dorado_df = load_df(dorado_file)

In [None]:
# Filtering
dorado = dorado_df[(dorado_df['Call'] == 'a') & (dorado_df['N_valid_cov'] >= 10) & (dorado_df['Mod_Percent'] > 0)]
dorado_og_filtered = dorado[dorado['Mod_Percent'] >= 10]
dorado_filtered = dorado[dorado['Adjusted_Mod_Percent'] >= 10]

In [None]:
# dorado_filtered = dorado[dorado["Mod_Percent"] >= 20]
# dorado_with_glori = pd.merge(dorado, merged, left_on = ["Chromosome", "End", "Strand"], right_on = ["Chr", "Sites", "Strand"])
dorado_og_outer = dorado_og_filtered.merge(merged, how = "outer", left_on = ["Chromosome", "End", "Strand"], right_on = ["Chr", "Sites", "Strand"])
dorado_og_filtered_with_glori = dorado_og_filtered.merge(merged, how = "inner", left_on = ["Chromosome", "End", "Strand"], right_on = ["Chr", "Sites", "Strand"])

dorado_outer = dorado_filtered.merge(merged, how = "outer", left_on = ["Chromosome", "End", "Strand"], right_on = ["Chr", "Sites", "Strand"])
dorado_filtered_with_glori = dorado_filtered.merge(merged, how = "inner", left_on = ["Chromosome", "End", "Strand"], right_on = ["Chr", "Sites", "Strand"])

In [None]:
print(f'The length of the dataframes before merging are: GLORI1: {len(GLORI1)}; GLORI2: {len(GLORI2)}')
print(f'The length after merging on Chr, Site and Strand is: {len(merged)}')
print(f'The length of Dorado with valid coverage is: {len(dorado_og_filtered)}')
# print(f'The length of Dorado with valid coverage and at least 20% modification is: {len(dorado[dorado["Mod_Percent"] >= 20])}')
print(f'The length of sites where dorado and GLORI agree is: {len(dorado_og_filtered_with_glori)}')
print(f'The percentage of total overlap is: {(len(dorado_og_filtered_with_glori) / len(dorado_og_outer)) * 100:.2f}%')
# print(f'The percent of total aggregate GLORI sites that Dorado agrees with are: {((len(dorado_filtered_with_glori)/len(merged)) * 100):.2f}%')

In [None]:
print(f'The length of the dataframes before merging are: GLORI1: {len(GLORI1)}; GLORI2: {len(GLORI2)}')
print(f'The length after merging on Chr, Site and Strand is: {len(merged)}')
print(f'The length of Dorado with valid coverage is: {len(dorado_filtered)}')
# print(f'The length of Dorado with valid coverage and at least 20% modification is: {len(dorado[dorado["Mod_Percent"] >= 20])}')
print(f'The length of sites where dorado and GLORI agree is: {len(dorado_filtered_with_glori)}')
print(f'The percentage of total overlap is: {(len(dorado_filtered_with_glori) / len(dorado_outer)) * 100:.2f}%')
# print(f'The percent of total aggregate GLORI sites that Dorado agrees with are: {((len(dorado_filtered_with_glori)/len(merged)) * 100):.2f}%')

In [None]:
%matplotlib inline
from matplotlib_venn import venn2
import matplotlib.pyplot as plt

# Create sets of unique identifiers for each dataset
# Using tuple of (Chromosome/Chr, End/Sites, Strand) as unique identifier
dorado_set = set(zip(dorado_filtered['Chromosome'], 
                     dorado_filtered['End'], 
                     dorado_filtered['Strand']))

merged_set = set(zip(merged['Chr'], 
                     merged['Sites'], 
                     merged['Strand']))

# Create the Venn diagram
plt.figure(figsize=(10, 8))
venn = venn2([dorado_set, merged_set], 
             set_labels=('GM12878', 'GLORI_Seq'))

# Customize the diagram
venn.get_label_by_id('10').set_text(f'08_01_25 HEK293 Only\n{len(dorado_set - merged_set)}')
venn.get_label_by_id('01').set_text(f'GLORI Seq Only\n{len(merged_set - dorado_set)}')
venn.get_label_by_id('11').set_text(f'Overlap\n{len(dorado_set & merged_set)}')

plt.title('Overlap between 08_07_24 GM12878 and GLORI Seq Datasets 10/10')
plt.savefig("/projects/Genometechlab/RNA002_vs_RNA004/RNA004/HEK293_Comparison/Plots/08_07_24_GM12878_Overlap_10_Percent_and_10_reads.pdf")
plt.show()

# Print statistics
print(f"dorado_filtered unique sites: {len(dorado_set - merged_set)}")
print(f"merged unique sites: {len(merged_set - dorado_set)}")
print(f"Overlapping sites: {len(dorado_set & merged_set)}")
print(f"Total unique sites: {len(dorado_set | merged_set)}")