## Setup and Imports

In [1]:
import os
import pickle
import shutil
import gzip
import tempfile
from collections import defaultdict, Counter
from intervaltree import IntervalTree
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
import numpy as np
import pandas as pd
import pysam
import time
import subprocess
from datetime import datetime

## GTF Processing

In [2]:
def parse_gtf_attributes(attr_string):
    """Parse a GTF column-9 attribute string into a ``{key: value}`` dict.

    Entries are separated by ``;`` and each entry is ``key value``, split at
    the first space. Surrounding double quotes are removed from the value.
    Entries without a space are ignored. When a key repeats (e.g. several
    ``tag`` entries), the last occurrence wins.
    """
    parsed = {}
    entries = (piece.strip() for piece in attr_string.strip(';').split(';'))
    for key, separator, value in (entry.partition(' ') for entry in entries if entry):
        if separator:
            parsed[key] = value.strip('"')
    return parsed

def determine_utr_type(exon_number):
    """Label a UTR as 5' or 3' from its exon number alone.

    Returns "5' UTR" when ``str(exon_number)`` with surrounding whitespace
    removed is exactly ``'1'``, and "3' UTR" for anything else.

    NOTE(review): this is only a heuristic. A 5' UTR that continues past
    exon 1 is labelled 3'. A 3' UTR that lies in exon 1 (e.g. on a
    single-exon transcript) is labelled 5'. Classifying against the
    transcript's coding span is more reliable when that information is
    available.
    """
    is_first_exon = str(exon_number).strip() == '1'
    return "5' UTR" if is_first_exon else "3' UTR"

def _extract_quoted_attr(attr_string, key):
    """Return the value of ``key "value"`` in a GTF attribute string, or '' if absent."""
    marker = f'{key} "'
    pos = attr_string.find(marker)
    if pos == -1:
        return ''
    value_start = pos + len(marker)
    value_end = attr_string.find('"', value_start)
    if value_end == -1:  # unterminated quote: take the rest of the string
        value_end = len(attr_string)
    return attr_string[value_start:value_end]


def _extract_exon_number(attr_string):
    """Return the raw ``exon_number`` value (quotes stripped), or None if absent."""
    marker = 'exon_number '
    pos = attr_string.find(marker)
    if pos == -1:
        return None
    value_start = pos + len(marker)
    value_end = attr_string.find(';', value_start)
    if value_end == -1:
        value_end = len(attr_string)
    return attr_string[value_start:value_end].strip().strip('"')


def _classify_utr(start, end, strand, coding_span, exon_number):
    """Label a UTR "5' UTR"/"3' UTR" by its position relative to the transcript's coding span.

    ``coding_span`` is ``[min_start, max_end]`` over the transcript's CDS and
    start/stop codon records (0-based, half-open), or None if the transcript
    has none. Without a span (or for an unknown strand) this falls back to the
    exon-number heuristic; without an exon number the raw label 'UTR' is kept.
    """
    if coding_span is not None and strand in ('+', '-'):
        coding_start, coding_end = coding_span
        if strand == '+':
            # Plus strand: the transcript's 5' end is on the left.
            return "5' UTR" if start < coding_start else "3' UTR"
        # Minus strand: the transcript's 5' end is on the right.
        return "5' UTR" if end > coding_end else "3' UTR"
    if exon_number is None:
        return 'UTR'
    return determine_utr_type(exon_number)


def build_interval_trees_fast(gtf_file):
    """Build per-chromosome interval trees of GTF features.

    Reads the uncompressed GTF ``gtf_file`` and inserts one interval for every
    ``exon``, ``CDS``, ``UTR``, ``start_codon`` and ``stop_codon`` record into
    an ``IntervalTree`` keyed by the sequence name (column 1). GTF's 1-based
    closed coordinates are converted to 0-based half-open ``[start, end)``.

    ``UTR`` records are relabelled "5' UTR" or "3' UTR" from their position
    relative to the coding span (CDS + start/stop codons) of the same
    ``transcript_id``, taking strand into account. A UTR is buffered until the
    whole file has been read, so its transcript's coding records may appear
    anywhere in the file. Only transcripts with no coding records fall back to
    the exon-number heuristic (``determine_utr_type``).

    NOTE: pickled caches produced by an older build that used the exon-number
    heuristic for every UTR must be deleted so that they are rebuilt.

    Returns
    -------
    (dict, dict)
        ``{chrom: IntervalTree}``, where each interval's data is a dict with keys
        ``feature``, ``gene_id``, ``gene_name`` and ``strand``, and a
        ``{feature_label: count}`` dict.
    """
    trees = defaultdict(IntervalTree)
    feature_counts = Counter()

    valid_features = {'exon', 'CDS', 'UTR', 'start_codon', 'stop_codon'}
    coding_features = {'CDS', 'start_codon', 'stop_codon'}

    # transcript_id -> [min coding start, max coding end]
    coding_spans = {}
    # UTR records wait here until every coding record has been seen.
    pending_utrs = []

    print(f"Building interval trees from {gtf_file}...")

    with open(gtf_file, 'r') as f:
        for line in tqdm(f, desc="Loading GTF", unit="lines"):
            if line.startswith('#'):
                continue

            fields = line.split('\t', 9)
            if len(fields) < 9:
                continue

            feature_type = fields[2]
            if feature_type not in valid_features:
                continue

            chrom = fields[0]
            start = int(fields[3]) - 1
            end = int(fields[4])
            attr_string = fields[8]

            data = {
                'feature': feature_type,
                'gene_id': _extract_quoted_attr(attr_string, 'gene_id'),
                'gene_name': _extract_quoted_attr(attr_string, 'gene_name'),
                'strand': fields[6]
            }
            transcript_id = _extract_quoted_attr(attr_string, 'transcript_id')

            if feature_type in coding_features and transcript_id:
                span = coding_spans.get(transcript_id)
                if span is None:
                    coding_spans[transcript_id] = [start, end]
                else:
                    span[0] = min(span[0], start)
                    span[1] = max(span[1], end)

            if feature_type == 'UTR':
                pending_utrs.append(
                    (chrom, start, end, data, transcript_id,
                     _extract_exon_number(attr_string))
                )
                continue

            trees[chrom][start:end] = data
            feature_counts[feature_type] += 1

    for chrom, start, end, data, transcript_id, exon_number in pending_utrs:
        data['feature'] = _classify_utr(
            start, end, data['strand'], coding_spans.get(transcript_id), exon_number
        )
        trees[chrom][start:end] = data
        feature_counts[data['feature']] += 1

    print(f"Loaded {sum(feature_counts.values())} features")

    return dict(trees), dict(feature_counts)

## Annotation

In [3]:
def _field_or_dot(fields, index):
    """Return ``fields[index]``, or '.' when the line has too few columns."""
    return fields[index] if len(fields) > index else '.'


def process_bed_chunk_to_file(args):
    """Annotate one chunk of modkit BED lines and write the hits to a temp file.

    Parameters
    ----------
    args : tuple
        ``(chunk_data, trees, chunk_id, temp_dir)``, packed into one tuple so
        the function can be mapped with ``Pool.imap_unordered``.
        ``chunk_data`` is an iterable of raw tab-separated lines. ``trees``
        maps chromosome -> interval tree; the point query ``tree[pos]`` must
        yield intervals with ``begin``, ``end`` and a ``data`` dict holding
        ``feature``, ``gene_id``, ``gene_name`` and ``strand``.

    Output is written to ``<temp_dir>/chunk_<chunk_id:05d>.tmp``, one row per
    (site, strand-matching overlapping feature), with 15 tab-separated
    columns: chrom, start, end, mod, feature begin, feature end, score,
    strand, gene_id, gene_name, feature, then input columns 10, 11, 12 and 18
    (0-based; '.' when absent).

    Returns
    -------
    (str, collections.Counter)
        The temp-file path, and these counters:

        - ``annotated``: output rows written.
        - ``annotated_positions``: sites with at least one strand-matching feature.
        - ``no_overlap``: sites with no feature at the position.
        - ``no_strand_match``: sites that overlap features, but only on the other strand.
        - ``duplicates``: repeated (chrom, start, strand, mod) within this chunk.
        - ``skipped``: lines with fewer than 6 columns.
        - ``errors``: lines whose start or end is not an integer.
        - ``feature_<label>``: rows written per feature label.
    """
    chunk_data, trees, chunk_id, temp_dir = args

    temp_file = os.path.join(temp_dir, f"chunk_{chunk_id:05d}.tmp")
    stats = Counter()
    seen_positions = set()

    with open(temp_file, 'w') as outf:
        for line in chunk_data:
            stripped = line.strip()
            if not stripped:
                continue

            fields = stripped.split('\t')
            if len(fields) < 6:
                stats['skipped'] += 1
                continue

            chrom, modification, score, strand = fields[0], fields[3], fields[4], fields[5]
            try:
                start = int(fields[1])
                end = int(fields[2])
            except ValueError:
                stats['errors'] += 1
                continue

            # Deduplication only spans this chunk; repeats in other chunks are not detected.
            pos_key = (chrom, start, strand, modification)
            if pos_key in seen_positions:
                stats['duplicates'] += 1
                continue
            seen_positions.add(pos_key)

            tree = trees.get(chrom)
            overlaps = tree[start] if tree is not None else ()
            if not overlaps:
                stats['no_overlap'] += 1
                continue

            extra_columns = [_field_or_dot(fields, i) for i in (10, 11, 12, 18)]
            rows_for_site = 0
            for interval in overlaps:
                data = interval.data
                if data['strand'] != strand:
                    continue
                result = [
                    chrom, start, end, modification,
                    interval.begin, interval.end,
                    score, strand,
                    data['gene_id'],
                    data['gene_name'],
                    data['feature'],
                    *extra_columns
                ]
                outf.write('\t'.join(map(str, result)) + '\n')
                stats[f"feature_{data['feature']}"] += 1
                rows_for_site += 1

            stats['annotated'] += rows_for_site
            if rows_for_site:
                stats['annotated_positions'] += 1
            else:
                # Overlapping features exist, but none on the site's strand.
                stats['no_strand_match'] += 1

    return temp_file, stats

def annotate_bed_file(bed_file, trees, output_file, n_processes=None):
    """Annotate every site in a modkit BED file with overlapping GTF features.

    The file is read into memory and split into chunks. Each chunk is handled
    by ``process_bed_chunk_to_file`` in a process pool, and the per-chunk temp
    files are concatenated in input order into ``output_file``.

    ``output_file`` is first written as ``output_file + '.part'`` and renamed
    only after every chunk has been merged. The driver cell skips annotation
    whenever ``output_file`` already exists, so a crash mid-write must not
    leave a truncated file under that name.

    Parameters
    ----------
    bed_file : str
        Tab-separated modkit pileup file.
    trees : dict
        Chromosome -> interval tree, as returned by ``build_interval_trees_fast``.
    output_file : str
        Destination of the 15-column annotation table.
    n_processes : int, optional
        Worker count. Defaults to ``min(cpu_count(), 20)``. The pool never uses
        more workers than there are chunks.

    Returns
    -------
    collections.Counter
        Per-chunk statistics summed over all chunks.
    """
    if n_processes is None:
        n_processes = min(cpu_count(), 20)

    print(f"\nAnnotating {bed_file} using {n_processes} processes...")

    temp_dir = tempfile.mkdtemp(prefix="gencode_annotator_")
    partial_output = output_file + '.part'
    total_stats = Counter()

    try:
        with open(bed_file, 'r') as f:
            lines = f.readlines()

        print(f"Total lines: {len(lines):,}")

        # Aim for ~4 chunks per worker, clamped to 100k..5M lines per chunk.
        chunk_size = min(max(100000, len(lines) // (n_processes * 4)), 5000000)
        # Every task tuple carries `trees`, so the trees are pickled once per
        # chunk; fewer, larger chunks keep that overhead down.
        chunks = [
            (lines[offset:offset + chunk_size], trees, chunk_id, temp_dir)
            for chunk_id, offset in enumerate(range(0, len(lines), chunk_size))
        ]

        print(f"Split into {len(chunks)} chunks")

        temp_files = []
        # Pool() rejects 0 workers, and workers beyond the chunk count would idle.
        n_workers = max(1, min(n_processes, len(chunks)))
        with Pool(n_workers) as pool:
            for temp_file, stats in tqdm(
                pool.imap_unordered(process_bed_chunk_to_file, chunks),
                total=len(chunks),
                desc="Processing chunks"
            ):
                temp_files.append(temp_file)
                total_stats.update(stats)

        # imap_unordered yields in completion order; the zero-padded chunk ids
        # make lexical order equal to input order.
        temp_files.sort()

        print("Merging temporary files...")
        with open(partial_output, 'w') as outf:
            for temp_file in temp_files:
                with open(temp_file, 'r') as inf:
                    shutil.copyfileobj(inf, outf)
                os.remove(temp_file)
        os.replace(partial_output, output_file)

        # 'annotated' counts output rows (one per strand-matching feature), so it
        # can exceed the number of input sites.
        print(f"  Annotation rows written: {total_stats['annotated']:,}")
        print(f"  Sites without an overlapping feature: {total_stats['no_overlap']:,}")
        print(f"  Duplicate sites skipped: {total_stats['duplicates']:,}")
        print(f"  Malformed lines skipped: {total_stats['skipped'] + total_stats['errors']:,}")

    finally:
        if os.path.exists(partial_output):
            os.remove(partial_output)
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)

    return total_stats

## K-mer Extraction

In [4]:
def find_5mer_surround_pysam(fasta, reference, position):
    """Return the upper-case 5-mer centred on ``position`` of ``reference``.

    The window is ``[position - 2, position + 3)`` in the coordinate system of
    ``fasta.fetch`` (0-based, half-open for ``pysam.FastaFile``). Lower-case
    (soft-masked) bases are upper-cased, so masked and unmasked occurrences of
    a motif compare equal.

    Returns "N/A" when:

    - the window would start before base 0;
    - ``fetch`` raises ``KeyError`` or ``ValueError`` (e.g. unknown reference);
    - ``fetch`` returns fewer than 5 bases (e.g. the window is clipped at the
      chromosome end).
    """
    five_mer_start = position - 2
    five_mer_end = position + 3

    # Position 2 is the first valid centre (window starts at base 0).
    if five_mer_start < 0:
        return "N/A"

    try:
        five_mer = fasta.fetch(reference, five_mer_start, five_mer_end)
    except (KeyError, ValueError):
        return "N/A"

    if len(five_mer) != 5:
        return "N/A"
    return five_mer.upper()

# Case-preserving complement table; N maps to N. Built once at import time.
_COMPLEMENT_TABLE = str.maketrans("ACGTNacgtn", "TGCANtgcan")

def reverse_complement(seq):
    """Return the reverse complement of a DNA sequence.

    Both upper- and lower-case A/C/G/T are complemented, and case is
    preserved. N and any other characters are left as-is; only their position
    is reversed.
    """
    return seq.translate(_COMPLEMENT_TABLE)[::-1]

def process_kmer_chunk(chunk_data):
    """Attach centred 5-mers to one chunk of sites (worker for ``add_5mer_to_df_parallel``).

    Parameters
    ----------
    chunk_data : tuple
        ``(chunk_df, pos_column, fasta_path)``. ``chunk_df`` needs ``chrom``
        and ``strand`` columns plus ``pos_column``, whose values are passed as
        the position to ``find_5mer_surround_pysam``.

    Returns
    -------
    pandas.DataFrame
        A new frame (``chunk_df`` is not modified) that keeps only rows whose
        5-mer could be fetched. It adds a ``kmer`` column (the sequence as
        fetched) and an ``adjusted_kmer`` column: ``kmer`` on '+', its reverse
        complement on '-', and "N/A" for any other strand value.
    """
    chunk_df, pos_column, fasta_path = chunk_data

    # Opened per call, so every worker process uses its own handle.
    with pysam.FastaFile(fasta_path) as fasta:
        kmers = [
            find_5mer_surround_pysam(fasta, chrom, position)
            for chrom, position in zip(chunk_df["chrom"].tolist(),
                                       chunk_df[pos_column].tolist())
        ]

    with_kmer = chunk_df.assign(kmer=kmers)
    # .copy() detaches the filtered frame so the column assignment below does
    # not trigger pandas' SettingWithCopyWarning.
    with_kmer = with_kmer[with_kmer['kmer'] != "N/A"].copy()

    adjusted_kmers = []
    for kmer, strand in zip(with_kmer['kmer'], with_kmer['strand']):
        if strand == '+':
            adjusted_kmers.append(kmer)
        elif strand == '-':
            adjusted_kmers.append(reverse_complement(kmer))
        else:
            adjusted_kmers.append("N/A")
    with_kmer['adjusted_kmer'] = adjusted_kmers

    return with_kmer

def add_5mer_to_df_parallel(df, fasta_path, pos_column="drs_start", n_cores=8):
    """Add ``kmer`` and ``adjusted_kmer`` columns to ``df`` using a process pool.

    Rows are split into chunks and handed to ``process_kmer_chunk``. Rows
    whose 5-mer cannot be fetched are dropped. If the pool fails, the chunks
    are retried sequentially in this process.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``chrom``, ``strand`` and ``pos_column`` columns.
    fasta_path : str
        Reference FASTA opened by each worker.
    pos_column : str
        Column holding the site position.
    n_cores : int
        Requested workers, capped at 32 and at the number of chunks.

    Returns
    -------
    pandas.DataFrame
        A new frame with a fresh 0..n-1 index. An empty ``df`` yields an empty
        frame with the two k-mer columns appended.
    """
    n_cores = min(n_cores, 32)

    print(f"Processing k-mers for {len(df):,} rows using {n_cores} cores...")

    if len(df) == 0:
        # pd.concat([]) raises, so build the empty result directly.
        return df.assign(
            kmer=pd.Series(dtype=object),
            adjusted_kmer=pd.Series(dtype=object)
        ).reset_index(drop=True)

    if len(df) > 10_000_000:
        chunk_size = 50000
    elif len(df) > 1_000_000:
        chunk_size = 20000
    else:
        chunk_size = max(1000, len(df) // (n_cores * 4))

    chunks = [(df.iloc[i:i + chunk_size], pos_column, fasta_path)
              for i in range(0, len(df), chunk_size)]

    # Don't spawn workers that would never receive a chunk.
    n_workers = max(1, min(n_cores, len(chunks)))

    try:
        with Pool(n_workers) as pool:
            results = list(tqdm(
                pool.imap(process_kmer_chunk, chunks),
                total=len(chunks),
                desc="Processing k-mers",
                unit="chunk"
            ))
    except Exception as e:
        # Deliberate best-effort: retry in-process (e.g. if the pool can't start).
        print(f"Parallel processing error: {e}")
        print("Falling back to sequential processing...")
        results = [process_kmer_chunk(chunk)
                   for chunk in tqdm(chunks, desc="Processing k-mers", unit="chunk")]

    return pd.concat(results, ignore_index=True)

## Data Processing

In [5]:
def deduplicate_annotated(df):
    """Keep one row per site and gene, preferring the most specific feature.

    Rows sharing ``(chrom, drs_start, strand, mod, gene_id)`` are collapsed to
    the row with the lowest ``feature_type`` rank:

    - 3' UTR: 1
    - 5' UTR: 2
    - CDS: 3
    - stop_codon: 4
    - start_codon: 5
    - any unlisted label: 50
    - exon: 999

    Both the ``three_prime_UTR``/``five_prime_UTR`` and the "3' UTR"/"5' UTR"
    spellings are recognised. Returns a new frame sorted by the key columns
    with the original index labels kept. The input is not modified.
    """
    feature_rank = {
        'three_prime_UTR': 1, "3' UTR": 1,
        'five_prime_UTR': 2, "5' UTR": 2,
        'CDS': 3,
        'stop_codon': 4,
        'start_codon': 5,
        'exon': 999,
    }
    site_key = ['chrom', 'drs_start', 'strand', 'mod', 'gene_id']

    ranked = df.assign(
        priority=df['feature_type'].map(lambda label: feature_rank.get(label, 50))
    )
    ordered = ranked.sort_values(by=site_key + ['priority'])
    best_per_site = ordered.drop_duplicates(subset=site_key, keep='first')
    return best_per_site.drop(columns='priority')

def process_annotated_file(annotated_file, modkit_file, fasta_path, output_name,
                          output_dir, n_cores=8, min_score=20,
                          min_fp_adjusted_mod_percent=20):
    """Filter, deduplicate and k-mer-annotate one annotated table, then pickle it.

    Parameters
    ----------
    annotated_file : str
        The 15-column tab-separated output of ``annotate_bed_file``. An empty
        file is accepted and yields an empty result.
    modkit_file : str
        The original modkit pileup file. It is not read; it is kept in the
        signature so existing positional calls keep working.
    fasta_path : str
        Reference FASTA used for k-mer extraction.
    output_name : str
        Sample name. The result is written to
        ``<output_dir>/<output_name>_annotated_valid_kmer.pkl``.
    output_dir : str
        Created if missing.
    n_cores : int
        Worker count for k-mer extraction.
    min_score, min_fp_adjusted_mod_percent : float
        A row is kept only if ``score >= min_score`` and
        ``fp_adjusted_mod_percent >= min_fp_adjusted_mod_percent``. Non-numeric
        values count as failing.

    Returns
    -------
    pandas.DataFrame
        The filtered, deduplicated rows with ``kmer``/``adjusted_kmer`` columns.
    """
    print(f"\nProcessing {output_name}...")

    annotated_columns = [
        'chrom', 'drs_start', 'drs_end', 'mod', 'feature_start', 'feature_end',
        'score', 'strand', 'gene_id', 'gene_name', 'feature_type',
        'mod_percent', 'n_mod', 'n_canon', 'fp_adjusted_mod_percent'
    ]
    numeric_columns = ['score', 'mod_percent', 'n_mod', 'n_canon', 'fp_adjusted_mod_percent']
    # Identifier columns are read as text. `mod` mixes letter codes ('a', 'm')
    # with numeric codes (e.g. 17596). pandas' chunked type inference could
    # otherwise produce a mixed int/str column, which fails when sorted
    # during deduplication.
    text_dtypes = dict.fromkeys(
        ['chrom', 'mod', 'strand', 'gene_id', 'gene_name', 'feature_type'], str
    )

    # Load annotated data (empty when no site overlapped a feature)
    if os.path.getsize(annotated_file) == 0:
        annotated_df = pd.DataFrame(columns=annotated_columns)
    else:
        annotated_df = pd.read_csv(
            annotated_file, sep='\t', header=None,
            names=annotated_columns, dtype=text_dtypes
        )

    for col in numeric_columns:
        annotated_df[col] = pd.to_numeric(annotated_df[col], errors='coerce')

    # Apply quality filtering
    print(f"  Before filtering: {len(annotated_df):,} rows")
    annotated_valid = annotated_df[
        (annotated_df['score'] >= min_score) &
        (annotated_df['fp_adjusted_mod_percent'] >= min_fp_adjusted_mod_percent)
    ].copy()
    print(f"  After filtering: {len(annotated_valid):,} rows")

    # Deduplicate
    annotated_valid = deduplicate_annotated(annotated_valid)
    print(f"  After deduplication: {len(annotated_valid):,} rows")

    # Add k-mers
    annotated_valid_kmer = add_5mer_to_df_parallel(
        annotated_valid, fasta_path, pos_column="drs_start", n_cores=n_cores
    )
    print(f"  Final dataset: {len(annotated_valid_kmer):,} rows")

    # Save
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"{output_name}_annotated_valid_kmer.pkl")
    annotated_valid_kmer.to_pickle(output_file)
    print(f"  Saved to: {output_file}")

    return annotated_valid_kmer

## Reference Files, Paths, and Parameters

In [6]:
# Reference files are downloaded and decompressed only when missing.
GTF_FILE_GZ = './gencode.v47.annotation.gtf.gz'
FASTA_FILE_GZ = './GCA_000001405.15_GRCh38_full_analysis_set.fna.gz'
GTF_URL = ('https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/'
           'release_47/gencode.v47.annotation.gtf.gz')
FASTA_URL = ('https://ftp.ncbi.nlm.nih.gov/genomes/all/GCA/000/001/405/'
             'GCA_000001405.15_GRCh38/seqs_for_alignment_pipelines.ucsc_ids/'
             'GCA_000001405.15_GRCh38_full_analysis_set.fna.gz')

GTF_FILE = GTF_FILE_GZ[:-3]  # Remove .gz extension
FASTA_FILE = FASTA_FILE_GZ[:-3]


def _download_if_missing(url, dest, label):
    """Fetch ``url`` into ``dest`` with wget, unless ``dest`` already exists.

    wget writes to ``dest + '.part'``, which is renamed only after wget exits
    with status 0 (``check=True``). ``wget -O`` creates its output file even
    when the transfer fails. Writing straight to ``dest`` could therefore
    leave an empty or truncated file that every later run would accept as
    complete.
    """
    if os.path.exists(dest):
        return
    print(f"Downloading {label} file...")
    partial = dest + '.part'
    try:
        subprocess.run(['wget', '-q', '-O', partial, url], check=True)
        os.replace(partial, dest)
    finally:
        if os.path.exists(partial):
            os.remove(partial)
    print(f"  Downloaded to {dest}")


def _gunzip_if_missing(src, dest, label):
    """Decompress gzip file ``src`` into ``dest``, unless ``dest`` already exists.

    The output goes to ``dest + '.part'`` and is renamed only once
    decompression finishes, so an interrupted run cannot leave a truncated
    reference under the final name. ``src`` is kept.
    """
    if os.path.exists(dest):
        return
    print(f"Decompressing {label} file...")
    partial = dest + '.part'
    try:
        with gzip.open(src, 'rb') as fin, open(partial, 'wb') as fout:
            shutil.copyfileobj(fin, fout, 16 * 1024 * 1024)
        os.replace(partial, dest)
    finally:
        if os.path.exists(partial):
            os.remove(partial)
    print(f"  Decompressed to {dest}")


# The .gz archives are only needed when the decompressed file is absent.
if not os.path.exists(GTF_FILE):
    _download_if_missing(GTF_URL, GTF_FILE_GZ, 'GTF')
    _gunzip_if_missing(GTF_FILE_GZ, GTF_FILE, 'GTF')

if not os.path.exists(FASTA_FILE):
    _download_if_missing(FASTA_URL, FASTA_FILE_GZ, 'FASTA')
    _gunzip_if_missing(FASTA_FILE_GZ, FASTA_FILE, 'FASTA')

# Build the .fai index once, before any k-mer worker opens the FASTA.
# NOTE(review): pysam.FastaFile presumably builds a missing index on first
# open, which several worker processes could otherwise attempt concurrently.
if not os.path.exists(FASTA_FILE + '.fai'):
    print("Indexing FASTA file...")
    pysam.faidx(FASTA_FILE)

# Output directories
OUTPUT_DIR = '../../Exemplar_Data/annotated_output/'
PICKLE_DIR = '../../Exemplar_Data/annotated_output/pickle_output/'

# Input files - These are the modkit output BED files.
# Site-specific absolute paths: edit for your environment.
MODKIT_FILES = ['/scratch/stein.an/GM12878_Official/08_07_24_R9RNA_GM12878_mRNA_RT_sup_8mods_polyA_sorted_filtered.chr12-112000000-114000000_pileup_fp_adjusted.tsv']  # Modkit output BED files
SAMPLE_NAMES = ['08_07_24_GM12878_chr12-112000000-114000000']  # Names for output files (same order as MODKIT_FILES)

# Processing parameters
N_PROCESSES = 8  # For annotation
N_CORES = 8      # For k-mer extraction

# Cache file for GTF trees (speeds up repeated runs)
CACHE_FILE = './gencode.v47.annotation.tree.pkl'

## Build or Load the GTF Interval Trees

In [7]:
# Building the trees means parsing the whole GTF, so the result is pickled to
# CACHE_FILE and reused. The cache is written locally by this notebook, and
# pickle.load can execute arbitrary code, so never point CACHE_FILE at an
# untrusted file. Delete CACHE_FILE to force a rebuild (e.g. after changing
# GTF_FILE or the tree-building code).
trees, feature_counts = None, None

if os.path.exists(CACHE_FILE):
    print(f"Loading cached trees from {CACHE_FILE}")
    try:
        with open(CACHE_FILE, 'rb') as f:
            trees, feature_counts = pickle.load(f)
    except (EOFError, pickle.UnpicklingError, AttributeError, ImportError,
            TypeError, ValueError) as err:
        # A truncated or incompatible cache is just a cache miss.
        print(f"  Cache could not be read ({err!r}); rebuilding from {GTF_FILE}")
        trees, feature_counts = None, None

if trees is None:
    trees, feature_counts = build_interval_trees_fast(GTF_FILE)

    print(f"Caching trees to {CACHE_FILE}")
    # Dump to a side file and rename it, so an interrupted dump cannot leave a
    # truncated cache that later runs would try (and fail) to load.
    partial_cache = CACHE_FILE + '.part'
    with open(partial_cache, 'wb') as f:
        pickle.dump((trees, feature_counts), f, protocol=pickle.HIGHEST_PROTOCOL)
    os.replace(partial_cache, CACHE_FILE)

Loading cached trees from ./gencode.v47.annotation.tree.pkl


## Process All Data

In [8]:
# Fail fast if the per-sample lists are out of sync: zip() would otherwise
# silently drop the unmatched samples.
if len(MODKIT_FILES) != len(SAMPLE_NAMES):
    raise ValueError(
        f"MODKIT_FILES has {len(MODKIT_FILES)} entries but SAMPLE_NAMES has "
        f"{len(SAMPLE_NAMES)}; they must pair up one-to-one."
    )

print(f"\n{'='*60}")
print("RNA Modification Annotation and Processing Pipeline")
print(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"{'='*60}\n")

os.makedirs(OUTPUT_DIR, exist_ok=True)

for modkit_bed_file, sample_name in zip(MODKIT_FILES, SAMPLE_NAMES):
    print(f"\n{'='*50}")
    print(f"Processing sample: {sample_name}")
    print(f"{'='*50}")

    # Step 1: Annotate the modkit BED file with gene information.
    # An existing annotated file is reused; delete it to force re-annotation.
    annotated_file = os.path.join(OUTPUT_DIR, f"{sample_name}_annotated.bed")

    if not os.path.exists(annotated_file):
        print("\nStep 1: Annotating modkit BED file with gene information...")
        annotate_bed_file(modkit_bed_file, trees, annotated_file, N_PROCESSES)
    else:
        print(f"\nStep 1: Using existing annotated file: {annotated_file}")

    # Step 2: Process annotated file - filter, deduplicate, add k-mers
    print("\nStep 2: Processing annotated file (filtering, k-mers)...")
    result = process_annotated_file(
        annotated_file,     # The annotated version
        modkit_bed_file,    # The original modkit file
        FASTA_FILE,
        sample_name,
        PICKLE_DIR,
        N_CORES
    )

    # Print summary
    print(f"\nSummary for {sample_name}:")
    print(f"  Total rows: {len(result):,}")
    print(f"  Unique genes: {result['gene_id'].nunique():,}")
    print("  Modification types:")
    for mod, count in result['mod'].value_counts().items():
        print(f"    {mod}: {count:,}")

print(f"\n{'='*60}")
print("Pipeline Complete!")
print(f"Completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"{'='*60}")


RNA Modification Annotation and Processing Pipeline
Started at: 2025-11-04 15:36:00


Processing sample: 08_07_24_GM12878_chr12-112000000-114000000

Step 1: Annotating modkit BED file with gene information...

Annotating /scratch/stein.an/GM12878_Official/08_07_24_R9RNA_GM12878_mRNA_RT_sup_8mods_polyA_sorted_filtered.chr12-112000000-114000000_pileup_fp_adjusted.tsv using 8 processes...
Total lines: 670,373
Split into 7 chunks


Processing chunks: 100%|██████████| 7/7 [02:23<00:00, 20.48s/it]


Merging temporary files...
  Annotated positions: 1,088,323

Step 2: Processing annotated file (filtering, k-mers)...

Processing 08_07_24_GM12878_chr12-112000000-114000000...
  Before filtering: 1,088,323 rows
  After filtering: 1,024 rows
  After deduplication: 172 rows
Processing k-mers for 172 rows using 8 cores...


Processing k-mers: 100%|██████████| 1/1 [00:00<00:00, 34.91chunk/s]

  Final dataset: 172 rows
  Saved to: ../../Exemplar_Data/annotated_output/pickle_output/08_07_24_GM12878_chr12-112000000-114000000_annotated_valid_kmer.pkl

Summary for 08_07_24_GM12878_chr12-112000000-114000000:
  Total rows: 172
  Unique genes: 17
  Modification types:
    a: 105
    m: 44
    17596: 14
    19227: 4
    17802: 2
    69426: 1
    19229: 1
    19228: 1

Pipeline Complete!
Completed at: 2025-11-04 15:38:27



