In [43]:
import pysam
import pandas as pd
from collections import defaultdict
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines

def parse_attr(attr, key):
    """Return the value of one key from a GTF attribute string.

    The attribute string is split on ';' and each item is read as
    ``<key> <value>``. The key must match exactly (not merely as a prefix),
    and the whole value is returned with double quotes removed, so quoted
    values containing spaces are kept intact.

    Parameters
    ----------
    attr : str
        Raw attribute column, e.g. ``gene_id "X"; gene_name "Y";``.
        Non-string input (such as NaN from pandas) yields ``None``.
    key : str
        Attribute name to look up.

    Returns
    -------
    str or None
        The first matching value, ``''`` if the key is present without a
        value, or ``None`` if the key is absent.
    """
    if not isinstance(attr, str):
        return None
    for item in attr.split(';'):
        # Split once on any whitespace: first token is the key, rest is the value.
        fields = item.strip().split(None, 1)
        if not fields or fields[0] != key:
            continue
        value = fields[1] if len(fields) > 1 else ''
        return value.strip().replace('"', '')
    return None

def base_counts_to_df(base_counts, total_counts, region_start, region_end):
    """Tabulate per-position base counts and coverage for an inclusive range.

    Parameters
    ----------
    base_counts : mapping
        ``{pos: {base: count}}``. Positions or bases that are missing count
        as zero. The mapping is only read, never modified, so plain dicts
        and defaultdicts are both accepted.
    total_counts : mapping
        ``{pos: coverage}``. Missing positions count as zero.
    region_start, region_end : int
        First and last position to emit (both inclusive).

    Returns
    -------
    pandas.DataFrame
        One row per position with columns ``pos``, ``A_count``, ``T_count``,
        ``C_count``, ``G_count`` and ``total_count``. The columns are present
        even when the range is empty.
    """
    bases = ('A', 'T', 'C', 'G')
    columns = ['pos'] + [f'{base}_count' for base in bases] + ['total_count']

    records = []
    for pos in range(region_start, region_end + 1):
        # .get() avoids KeyError on plain dicts and avoids inserting empty
        # entries into a caller-owned defaultdict.
        per_base = base_counts.get(pos, {})
        row = {'pos': pos}
        for base in bases:
            row[f'{base}_count'] = per_base.get(base, 0)
        row['total_count'] = total_counts.get(pos, 0)
        records.append(row)

    # Explicit columns keep the schema stable for an empty range.
    return pd.DataFrame(records, columns=columns)

def _count_region_coverage(bamfile, region_chr, region_start, region_end, region_strand):
    """Count per-base calls and read coverage on one strand of a region.

    ``region_start``/``region_end`` are treated as 1-based inclusive, matching
    how they are compared with the GTF and BED-derived positions elsewhere.
    pysam works in 0-based half-open coordinates, so the start is shifted by
    one for the queries and every reported position is shifted back to 1-based.

    Returns ``(base_counts, total_counts)`` keyed by 1-based position.
    """
    base_counts = defaultdict(lambda: defaultdict(int))
    total_counts = defaultdict(int)
    start0 = max(region_start - 1, 0)

    # Context manager closes the BAM even if iteration raises.
    with pysam.AlignmentFile(bamfile, "rb") as bam:
        for pileupcolumn in bam.pileup(region_chr, start0, region_end, truncate=True):
            pos = pileupcolumn.reference_pos + 1
            for pileupread in pileupcolumn.pileups:
                read = pileupread.alignment
                if read.is_supplementary or read.is_secondary:
                    continue
                if ('-' if read.is_reverse else '+') != region_strand:
                    continue
                if pileupread.is_del or pileupread.is_refskip:
                    continue
                # NOTE(review): query_sequence is as stored in the BAM; for
                # reverse-strand reads that is presumably reference-oriented,
                # not transcript-oriented -- confirm before using base counts.
                base = read.query_sequence[pileupread.query_position]
                base_counts[pos][base] += 1

        # Coverage pass: supplementary alignments are kept here (only
        # secondary ones are skipped), as in the original implementation.
        for read in bam.fetch(region_chr, start0, region_end):
            if read.is_secondary:
                continue
            if ('-' if read.is_reverse else '+') != region_strand:
                continue
            for ref_pos in read.get_reference_positions(full_length=False):
                pos = ref_pos + 1
                if region_start <= pos <= region_end:
                    total_counts[pos] += 1

    return base_counts, total_counts


def _load_modification_counts(bed_file, region_chr, region_start, region_end, region_strand):
    """Sum ``modified_call`` per position and modification code within the region.

    Reads the first 12 tab-separated columns of ``bed_file`` and uses the BED
    ``end`` column as the position. Returns a frame with a ``pos`` column and
    one column per modification code found (codes are read as strings).
    """
    bed_cols = [
        'chrom', 'start', 'end', 'mod', 'score', 'strand',
        'thickStart', 'thickEnd', 'itemRgb',
        'total_call', 'mod_ratio', 'modified_call'
    ]
    df_bed = pd.read_csv(
        bed_file,
        sep='\t',
        names=bed_cols,
        header=None,
        usecols=range(12),
        # Keep identifiers textual so purely numeric codes/contig names
        # still compare equal to string arguments.
        dtype={'chrom': str, 'mod': str, 'strand': str},
    )
    df_bed['pos'] = df_bed['end']

    in_region = (
        (df_bed['chrom'] == region_chr) &
        (df_bed['pos'] >= region_start) &
        (df_bed['pos'] <= region_end) &
        (df_bed['strand'] == region_strand)
    )
    return (
        df_bed[in_region]
        .groupby(['pos', 'mod'])['modified_call']
        .sum()
        .unstack(fill_value=0)
        .sort_index()
        .reset_index()
    )


def _load_region_exons(gtf_file, region_chr, region_start, region_end, region_strand):
    """Return exon rows overlapping the region on the given strand.

    Attributes are parsed only for the retained rows (``gene_name`` and
    ``transcript_id``), which avoids parsing every line of the annotation.
    """
    gtf_cols = ["chrom", "source", "feature", "start", "end", "score", "strand", "frame", "attribute"]
    gtf = pd.read_csv(
        gtf_file, sep='\t', comment='#', header=None, names=gtf_cols,
        dtype={'chrom': str},
    )

    exons = gtf[
        (gtf['feature'] == 'exon') &
        (gtf['chrom'] == region_chr) &
        (gtf['start'] <= region_end) &
        (gtf['end'] >= region_start) &
        (gtf['strand'] == region_strand)
    ].copy()

    exons['gene_name'] = exons['attribute'].apply(lambda x: parse_attr(x, 'gene_name'))
    exons['transcript_id'] = exons['attribute'].apply(lambda x: parse_attr(x, 'transcript_id'))
    return exons


def _draw_transcripts(ax, exons, region_start, region_end):
    """Draw one row of exon boxes and intron connectors per transcript on ``ax``."""
    strand_color = {'+': 'skyblue', '-': 'salmon'}
    intron_line_color = 'black'
    y_step = 1
    y_base = 0
    exon_height = 0.4
    label_offset = 0.1
    region_center = (region_start + region_end) / 2

    transcripts = exons.groupby('transcript_id')
    for idx, (transcript_id, group) in enumerate(transcripts):
        # Sort by genomic start so connectors join neighbouring exons; the
        # file order is not guaranteed to be ascending (e.g. '-' strand).
        group = group.sort_values('start')
        strand = group['strand'].iloc[0]
        gene_name = group['gene_name'].iloc[0] or transcript_id
        y = y_base + idx * y_step

        exon_starts = group['start'].tolist()
        exon_ends = group['end'].tolist()

        for exon_start, exon_end in zip(exon_starts, exon_ends):
            ax.add_patch(mpatches.Rectangle(
                (exon_start, y),
                exon_end - exon_start,
                exon_height,
                facecolor=strand_color.get(strand, 'grey'),
                edgecolor='black'
            ))

        for i in range(len(exon_starts) - 1):
            ax.add_line(mlines.Line2D(
                [exon_ends[i], exon_starts[i + 1]],
                [y + exon_height / 2, y + exon_height / 2],
                color=intron_line_color,
                linewidth=0.5
            ))

        ax.text(region_center, y + label_offset, gene_name, ha='center', fontsize=8)

    ax.set_ylim(-0.5, transcripts.ngroups * y_step)
    ax.axis('off')


def plot_modifications(
    bamfile, gtf_file, bed_file,modcode,
    region_chr, region_start, region_end, region_strand,
    output_figure_path
):
    """Plot strand-specific coverage, modified-call counts and transcripts.

    The upper panel shows read coverage (grey) with a stem at every position
    that has at least one modified call for ``modcode``; the lower panel shows
    the exon structure of transcripts overlapping the region. The figure is
    written to ``output_figure_path`` and then closed.

    Parameters
    ----------
    bamfile : str
        Coordinate-sorted, indexed BAM of aligned reads.
    gtf_file : str
        Gene annotation in GTF format.
    bed_file : str
        Tab-separated per-site modification table; the first 12 columns are
        read, with column 4 the modification code, column 6 the strand and
        column 12 the modified-call count.
    modcode : str
        Modification code (value of BED column 4) to plot.
    region_chr : str
        Contig name.
    region_start, region_end : int
        Region bounds, treated as 1-based inclusive.
    region_strand : str
        '+' or '-'; only reads, sites and exons on this strand are used.
    output_figure_path : str
        Destination image file.
    """
    print('Counting Modified and Unmodified Counts')

    base_counts, total_counts = _count_region_coverage(
        bamfile, region_chr, region_start, region_end, region_strand
    )
    coverage_df = base_counts_to_df(base_counts, total_counts, region_start, region_end)

    modification_counts = _load_modification_counts(
        bed_file, region_chr, region_start, region_end, region_strand
    )
    merged_df = pd.merge(coverage_df, modification_counts, on='pos', how='left').fillna(0)

    print('Extracting gene annotations')

    exons = _load_region_exons(gtf_file, region_chr, region_start, region_end, region_strand)

    positions = merged_df['pos']
    base_total = merged_df['total_count']

    mod_column = str(modcode)
    if mod_column in merged_df.columns:
        mod_total = merged_df[mod_column]
    else:
        # No site carries this code in the region: use zeros that share the
        # frame's index so the boolean mask below aligns with `positions`.
        mod_total = pd.Series(0, index=merged_df.index)

    print('Start drawing')

    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=(20, 8), sharex=True,
        gridspec_kw={'height_ratios': [3, 1]}
    )
    try:
        ax1.plot(positions, base_total, color='lightgrey', label='Total base count')
        ax1.fill_between(positions, 0, base_total, color='lightgrey', alpha=0.3)

        nonzero_mask = mod_total > 0
        if nonzero_mask.any():
            markerline, stemlines, _ = ax1.stem(
                positions[nonzero_mask],
                mod_total[nonzero_mask],
                basefmt=' ',
                label=f'Modified calls ({mod_column})')
            markerline.set_color('#ffb178')  # marker colour
            stemlines.set_color('#ffb178')   # stem line colour
            plt.setp(markerline, markersize=2)
            plt.setp(stemlines, linewidth=1)

        ax1.set_ylabel('Counts')
        # The lower axis is switched off entirely, so the shared x label has
        # to live on the upper panel to be visible.
        ax1.set_xlabel('Genomic Position')
        ax1.legend()
        ax1.tick_params(axis='x', labelbottom=True)

        _draw_transcripts(ax2, exons, region_start, region_end)

        region_xlim = (region_start, region_end)
        ax1.set_xlim(region_xlim)
        ax2.set_xlim(region_xlim)

        fig.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.15)

        fig.savefig(output_figure_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

In [39]:
# Camk2a locus, '+' strand.
# `modcode` is a required argument of plot_modifications; 'a' is assumed here
# (the same code the un-suffixed Camk2b last-exon figure uses) -- confirm.
plot_modifications(
    bamfile="/scratch/lb4489/project/dRNA/HIP_mum_mapped.sorted.bam",
    gtf_file="/scratch/lb4489/bioindex/gencode.vM33.annotation.gtf",
    bed_file="/scratch/lb4489/project/dRNA/modkit_all_filted.bed",
    modcode='a',
    region_chr="chr18",
    region_start=61058690,
    region_end=61121224,
    region_strand='+',
    output_figure_path="/scratch/lb4489/project/dRNA/Camk2a_figure.png"
)

Counting Modified and Unmodified Counts
Extracting gene annotations
Start drawing


In [40]:
# Camk2a last-exon window, '+' strand.
# `modcode` is a required argument of plot_modifications; 'a' is assumed here
# (the same code the un-suffixed Camk2b last-exon figure uses) -- confirm.
plot_modifications(
    bamfile="/scratch/lb4489/project/dRNA/HIP_mum_mapped.sorted.bam",
    gtf_file="/scratch/lb4489/bioindex/gencode.vM33.annotation.gtf",
    bed_file="/scratch/lb4489/project/dRNA/modkit_all_filted.bed",
    modcode='a',
    region_chr="chr18",
    region_start=61115649,
    region_end=61122596,
    region_strand='+',
    output_figure_path="/scratch/lb4489/project/dRNA/Camk2a_lastexon_figure.png"
)

Counting Modified and Unmodified Counts
Extracting gene annotations
Start drawing


In [41]:
# Camk2b locus, '-' strand.
# `modcode` is a required argument of plot_modifications; 'a' is assumed here
# (the same code the un-suffixed Camk2b last-exon figure uses) -- confirm.
plot_modifications(
    bamfile="/scratch/lb4489/project/dRNA/HIP_mum_mapped.sorted.bam",
    gtf_file="/scratch/lb4489/bioindex/gencode.vM33.annotation.gtf",
    bed_file="/scratch/lb4489/project/dRNA/modkit_all_filted.bed",
    modcode='a',
    region_chr="chr11",
    region_start=5919644,
    region_end=6015748,
    region_strand='-',
    output_figure_path="/scratch/lb4489/project/dRNA/Camk2b_figure.png"
)

Counting Modified and Unmodified Counts
Extracting gene annotations
Start drawing


In [42]:
# Camk2b last-exon window, '-' strand, modification code 'a'.
plot_modifications(
    bamfile="/scratch/lb4489/project/dRNA/HIP_mum_mapped.sorted.bam",
    gtf_file="/scratch/lb4489/bioindex/gencode.vM33.annotation.gtf",
    bed_file="/scratch/lb4489/project/dRNA/modkit_all_filted.bed",
    modcode='a',
    region_chr="chr11",
    region_start=5919080,
    region_end=5922638,
    region_strand='-',
    output_figure_path="/scratch/lb4489/project/dRNA/Camk2b_lastexon_figure.png",
)


Counting Modified and Unmodified Counts
Extracting gene annotations
Start drawing


In [44]:
# Camk2b last-exon window, '-' strand, modification code 'm'
# (written to the "_m5c" figure).
plot_modifications(
    bamfile="/scratch/lb4489/project/dRNA/HIP_mum_mapped.sorted.bam",
    gtf_file="/scratch/lb4489/bioindex/gencode.vM33.annotation.gtf",
    bed_file="/scratch/lb4489/project/dRNA/modkit_all_filted.bed",
    modcode='m',
    region_chr="chr11",
    region_start=5919080,
    region_end=5922638,
    region_strand='-',
    output_figure_path="/scratch/lb4489/project/dRNA/Camk2b_lastexon_figure_m5c.png",
)

Counting Modified and Unmodified Counts
Extracting gene annotations
Start drawing
