In [1]:
import pysam
import pandas as pd
import numpy as np
import multiprocessing
# Disable pandas' display truncation: every column and every row of a displayed frame is shown.
# NOTE(review): with max_rows unset, displaying a large frame prints all of it — use .head() on big tables.
pd.options.display.max_columns = None
pd.options.display.max_rows = None

In [2]:
# Fixed number of CpG sites per bin. The binned table is sliced in steps of
# binsize*2 rows further down (the table lists '+' and '-' strand rows per CpG).
binsize=25
# Largest allowed distance between consecutive CpG start coordinates inside one bin;
# in the (commented-out) binning code a larger gap starts a new bin. Also used in output file names.
maxgap=500

In [None]:
# # all CpGs in the human genome
# df=pd.read_csv("/t2/ny/ref/hg38_autoxy_cg.bed",sep="\t",header=None,dtype={0:str},usecols=[0,1,2,5])
# df.columns=["chro","start","end","strand"]

# # binning
# nbin = 1
# bin_column = [1]
# chromStart = df['start'].values
# chrom = df['chro'].values

# present_size=1
# for i in range(1, df.shape[0]):
#     present_size+=1
#     if chromStart[i] - chromStart[i-1] > maxgap or chrom[i] != chrom[i-1] or present_size>binsize*2:
#         nbin += 1
#         present_size=1
#     bin_column.append(nbin)
# df['bin'] = bin_column

# # remove bins that do not reach binsize
# df_only_in_binsize = df.groupby('bin').filter(lambda x: len(x) == binsize*2)
# total_bin_num=len(df_only_in_binsize.bin.unique())

In [3]:
# Directly use the ready-made binned bed
df_only_in_binsize=pd.read_csv("/t4/like/aml/bin25CpG_500gap.bed",sep="\t",header=None,dtype={0:str})
df_only_in_binsize.columns=["chro","start","end","bin",'strand']
total_bin_num=df_only_in_binsize["bin"].nunique(dropna=False)

# Bins are later extracted positionally (rows i*binsize*2 ... (i+1)*binsize*2), so the
# table must consist of consecutive blocks of exactly binsize*2 rows, one bin id per block.
# Fail loudly here rather than silently mixing rows of different bins downstream.
_rows_per_bin=binsize*2
_bin_ids=df_only_in_binsize["bin"].to_numpy()
if (
    _bin_ids.size != total_bin_num*_rows_per_bin
    or (_bin_ids.reshape(-1, _rows_per_bin) != _bin_ids[::_rows_per_bin, None]).any()
):
    raise ValueError(
        f"binned table is not made of contiguous blocks of {_rows_per_bin} rows per bin; "
        f"check that binsize={binsize} matches the bed file"
    )

In [4]:
def vectorized_similarity(data):
    """Pairwise agreement between the rows of a 2-D array that may contain NaN.

    Parameters
    ----------
    data : array-like, shape (n_rows, n_cols)
        Numeric values; NaN marks a missing observation.

    Returns
    -------
    similarity : ndarray, shape (n_rows, n_rows), float
        For each pair of rows, the fraction of columns observed (non-NaN) in
        both rows whose values are equal. NaN when the two rows share no
        observed column.
    valid_counts : ndarray, shape (n_rows, n_rows), int
        Number of columns observed in both rows (0 when there is no overlap).

    Notes
    -----
    Builds an (n_rows, n_rows, n_cols) boolean intermediate, so memory grows
    quadratically with the number of rows.
    """
    data = np.asarray(data, dtype=float)
    observed = ~np.isnan(data)

    # Broadcast rows against each other: axis 0 = first row, axis 1 = second row.
    both_observed = observed[:, None, :] & observed[None, :, :]
    agree = (data[:, None, :] == data[None, :, :]) & both_observed

    valid_counts = np.sum(both_observed, axis=2)
    equal_counts = np.sum(agree, axis=2)

    # Pairs without any shared observation have an undefined similarity: leave
    # them as NaN (instead of 0) and keep their true count of 0 so that
    # NaN-aware aggregation by the caller ignores them.
    similarity = np.full(valid_counts.shape, np.nan)
    np.divide(equal_counts, valid_counts, out=similarity, where=valid_counts > 0)

    return similarity,valid_counts
    
def get_read_m_array(read,l_cg_pos):
    """Extract per-CpG 'm' modification calls of one read for the given positions.

    Parameters
    ----------
    read : object exposing ``modified_bases`` (dict keyed by tuples such as
        ``('C', 0, 'm')`` whose values are ``(query_index, score)`` pairs),
        ``get_reference_positions(full_length=True)`` and, for error
        reporting, ``reference_name`` / ``reference_start`` / ``reference_end``.
    l_cg_pos : list
        Reference positions of the bin. Every second entry is used, starting
        at offset 0 for ``('C', 0, 'm')`` reads and offset 1 for
        ``('C', 1, 'm')`` reads.

    Returns
    -------
    (l_methy, strand_flag)
        ``l_methy`` holds one value per selected position: 1 (score > 204),
        0 (score < 51) or NaN (no call, or score in between);
        ``strand_flag`` is 0 or 1.
        ``[False, False]`` when the read carries no modification calls.

    Raises
    ------
    Exception
        When the read has modification calls but none keyed
        ``('C', 0, 'm')`` or ``('C', 1, 'm')``.
    """
    d_m_bases=read.modified_bases
    if ('C', 0, 'm') in d_m_bases:
        strand_flag=0
    elif ('C', 1, 'm') in d_m_bases:
        strand_flag=1
    elif d_m_bases=={}:
        return [False,False]
    else:
        # Report the location from the read itself; the function has no other
        # notion of chromosome / coordinates.
        raise Exception(
            f"cant handle {d_m_bases} in "
            f"{read.reference_name}:{read.reference_start}-{read.reference_end}"
        )

    ref_positions = read.get_reference_positions(full_length=True)

    # Only take the pos of the corresponding strand
    d_methy=dict.fromkeys(l_cg_pos[strand_flag::2], np.nan)

    # NOTE(review): thresholds 204 / 51 presumably correspond to ~0.8 / ~0.2 on a
    # 0-255 score scale — confirm. Scores in between stay NaN (ambiguous call).
    for query_idx, score in d_m_bases[('C', strand_flag, 'm')]:
        pos=ref_positions[query_idx]
        # Dict lookup instead of scanning the position list for every call.
        if pos in d_methy:
            if score>204:
                d_methy[pos]=1
            elif score<51:
                d_methy[pos]=0
    l_methy=list(d_methy.values())
    return l_methy,strand_flag

def get_read_matrix(dfbin,bam,l_cg_pos,strand='both'):
    """Collect the per-read modification calls of one bin into a 2-D array.

    Parameters
    ----------
    dfbin : DataFrame
        Rows of the bin; only the first value of its ``'chro'`` column is used.
    bam : object with a pysam-style ``fetch(contig, start, stop)`` method.
    l_cg_pos : list
        Reference positions of the bin, in ascending order as taken from the table.
    strand : 'both', 0 or 1
        Keep every read ('both') or only reads whose strand flag (as returned
        by ``get_read_m_array``) equals this value.

    Returns
    -------
    ndarray
        One row per retained read (values as produced by ``get_read_m_array``);
        an empty array when no read qualifies.
    """
    if len(l_cg_pos)==0:
        return np.array([])

    chrom=dfbin['chro'].iloc[0]
    start_pos=l_cg_pos[0]
    # fetch() uses a half-open [start, stop) interval, so stop must be one past
    # the last position or reads overlapping only that position are missed.
    end_pos=l_cg_pos[-1]+1

    read_matrix=[]
    for read in bam.fetch(chrom, start_pos, end_pos):
        if read.is_unmapped or read.is_secondary or read.is_supplementary:
            continue
        # Reads without an MM tag carry no modification information.
        if not read.has_tag("MM"):
            continue
        l_methy,strand_flag=get_read_m_array(read,l_cg_pos)
        if not l_methy:
            continue
        if strand!='both' and strand_flag!=strand:
            continue
        read_matrix.append(l_methy)

    return np.array(read_matrix)

def get_similarity_summary(dfbin,bam,l_cg_pos,strand='both'):
    """Summarise read-to-read similarity for one bin.

    Builds the read matrix with ``get_read_matrix`` and scores every pair of
    reads with ``vectorized_similarity``; only distinct pairs (upper triangle,
    diagonal excluded) enter the statistics.

    Returns
    -------
    (mean_value, overlap_read_count, reads, total_overlap_point)
        NaN-ignoring mean of the pair similarities (NaN when there is no
        pair), number of pairs with a non-NaN similarity, number of reads,
        and the sum of the pairwise valid counts.
        ``(nan, 0, 0, 0)`` when the read matrix is empty.
    """
    matrix = get_read_matrix(dfbin, bam, l_cg_pos, strand=strand)

    # Nothing to compare: no usable read in this bin.
    if not np.size(matrix):
        return np.nan, 0, 0, 0

    pairwise_similarity, pairwise_valid = vectorized_similarity(matrix)
    n_reads = len(pairwise_similarity)

    # Indices of the strict upper triangle, i.e. each unordered pair once.
    upper = np.triu_indices(n_reads, k=1)
    pair_similarities = pairwise_similarity[upper]
    pair_valid_counts = pairwise_valid[upper]

    # A single read yields no pair at all; avoid averaging an empty array.
    mean_similarity = np.nanmean(pair_similarities) if pair_similarities.size > 0 else np.nan
    n_scored_pairs = np.count_nonzero(~np.isnan(pair_similarities))
    n_shared_points = np.sum(pair_valid_counts)

    return mean_similarity, n_scored_pairs, n_reads, n_shared_points

def get_mpercentage(dfbin,bam,l_cg_pos):
    """Fraction of called sites in the bin's read matrix that are methylated.

    The read matrix is built once via ``get_read_matrix`` (entries equal to 1
    count as methylated, entries equal to 0 as unmethylated; anything else is
    ignored).

    Returns
    -------
    float
        ``m / (m + unm)``, or NaN when the matrix is empty or holds no 0/1 entry.
    """
    read_matrix=get_read_matrix(dfbin,bam,l_cg_pos)
    if not np.size(read_matrix):
        return np.nan

    m=np.sum(read_matrix == 1)
    unm=np.sum(read_matrix == 0)
    if m+unm==0:
        return np.nan
    return m/(m+unm)
    
    
def get_summary(dfbin,bamfile):
    """Compute the similarity summary of one bin from one BAM file.

    Parameters
    ----------
    dfbin : DataFrame
        Rows of a single bin with at least the columns 'chro', 'start', 'end'.
    bamfile : str
        Path of the BAM file to read (opened with pysam in 'rb' mode).

    Returns
    -------
    dict
        Keys: 'chr', 'start', 'end', 'length', 'similarity_mean',
        'n_overlapped_reads_pairs', 'n_reads', 'total_overlapped_points'.
    """
    l_cg_pos=list(dfbin['start'])
    # Positional access: first row's start and last row's end delimit the bin.
    start=dfbin['start'].iloc[0]
    end=dfbin['end'].iloc[-1]

    # Context manager guarantees the BAM handle is closed even if processing fails;
    # this function runs once per bin, so leaked handles would pile up quickly.
    with pysam.AlignmentFile(bamfile, 'rb') as bam:
        mean_value,overlap_read_count,reads,total_overlap_point=get_similarity_summary(dfbin,bam,l_cg_pos,strand='both')
        # Optional extras (disabled): per-strand summaries via
        # get_similarity_summary(..., strand=0) / (..., strand=1), reported under
        # 'f_' / 'r_' prefixed keys, and 'mpercentage' via get_mpercentage(dfbin,bam,l_cg_pos).

    summary={
    'chr':dfbin['chro'].iloc[0],
    'start':start,
    'end':end,
    'length':end-start,

    'similarity_mean':mean_value,
    'n_overlapped_reads_pairs':overlap_read_count,
    'n_reads':reads,
    'total_overlapped_points':total_overlap_point,
    }

    return summary

In [7]:
import os
# path='/t2/ny/dac/dec_cel_bammod'
path='/t2/ny/dnmt/modbam_drd'

# Sample names are the BAM file names with this suffix stripped.
bam_suffix='_M_s.bam'
# Sorted so the processing order does not depend on the directory listing order.
name_list=sorted(i[:-len(bam_suffix)] for i in os.listdir(path) if i.endswith(bam_suffix))

# Map every sample to the directory holding its BAM; derived from `path` so the
# two can never get out of sync.
d_sample={name: path for name in name_list}

In [8]:
# List the directory used as `path` to see which BAM files and .bai indexes are present.
ls /t2/ny/dnmt/modbam_drd

MT1340_M_s.bam          MT2658_M_s.bam      WT561_M_s.bam.bai
MT1340_M_s.bam.bai      MT2658_M_s.bam.bai  WT621_M_s.bam
MT2136_5hmcM.bam        MT614_M_s.bam       WT621_M_s.bam.bai
MT2136_5hmcM_s.bam      MT614_M_s.bam.bai   WT730_M_s.bam
MT2136_5hmcM_s.bam.bai  MT615_M_s.bam       WT730_M_s.bam.bai
MT2136_M_s.bam          MT615_M_s.bam.bai   WT881_M_s.bam
MT2136_M_s.bam.bai      WT1027_M_s.bam      WT881_M_s.bam.bai
MT2137_M_s.bam          WT1027_M_s.bam.bai  WT898_M_s.bam
MT2137_M_s.bam.bai      WT2795_M_s.bam      WT898_M_s.bam.bai
MT2652_M_s.bam          WT2795_M_s.bam.bai
MT2652_M_s.bam.bai      WT561_M_s.bam


In [10]:
# import os
# path='/t2/ny/dac/dec_cel_modbam_drdsup'
# name_list=[i[:-8] for i in os.listdir(path) if i.endswith('K052_DMSO_M_s.bam')]

# d_sample={}
# for i in name_list:
#     d_sample[i]=path

In [None]:
# d_sample

In [5]:
# nonres=['LK_2671_D10','CWH_2713_D10','WCMA_2756_D0','CWH_2713_D0','HMS_2075_D0','LK_2671_D0','LCK_1341_D0','HMS_2075_D10','LCK_1341_D10','WCMA_2756_D10']
# res=['NKW_2685_D10', 'TTK_2749_D0', 'LRL_2827_D0', 'LCK_2643_D10', 'MYK_3054_D0', 'CBM_3053_D10', 'NKW_2685_D0', 'LCK_2643_D0', 'CBM_3053_D0', 'LWK_2966_D0', 'WLM_3033_D10', 'TTK_2749_D10', 'CWH_3113_D10', 'LLC_2861_D0', 'CWFS_2816_D0', 'LLC_2861_D10', 'FYC_3026_D10', 'LWK_2966_D10', 'CWH_3113_D0', 'FYC_3026_D0', 'LRL_2827_D10', 'WLM_3033_D0', 'CWFS_2816_D10', 'SKK_2719_D0', 'SKK_2719_D10', 'MYK_3054_D10']
# pbsc=['PBSC1', 'PBSC3', 'PBSC2']

# d_sample={}
# for i in nonres:
#     d_sample[i]='/t2/ny/dac/modbamall_drd/nonres'
# for i in res:
#     d_sample[i]='/t2/ny/dac/modbamall_drd/res'
# for i in pbsc:
#     d_sample[i]='/t2/ny/dac/modbamall_drd/pbsc'

In [None]:
# Number of worker processes used for each sample.
n_processes=128
# Each bin occupies this many consecutive rows of the binned table.
rows_per_bin=binsize*2

for sample_name in name_list:
# for sample_name in nonres+res+pbsc:

    bamfile=f'{d_sample[sample_name]}/{sample_name}_M_s.bam'

    # The context manager tears the pool down even when a worker raises, so a
    # failing sample does not leave worker processes behind.
    with multiprocessing.Pool(processes=n_processes) as pool:
        pending=[
            pool.apply_async(get_summary, (df_only_in_binsize[i*rows_per_bin:(i+1)*rows_per_bin], bamfile))
            for i in range(total_bin_num)
        ]
        # get() blocks until the task is done and re-raises worker exceptions;
        # results are collected in bin order.
        l_summary=[task.get() for task in pending]

    df_summary = pd.DataFrame(l_summary)

    df_summary.to_csv(f'bin{binsize}CpG_{maxgap}gap_{sample_name}.csv',index=False,encoding='utf_8_sig')

In [5]:
# Example: build the read matrix of a single bin step by step, for inspection.
# The statements below repeat what get_read_matrix does with strand='both',
# but leave every intermediate (bam, reads, read, l_methy, ...) in the notebook namespace.

# NOTE(review): this sample/path is hard-coded and differs from the directory
# used by the batch cell — confirm it is the intended input.
sample_name='LK_2671_D10'
bamfile=f'/t2/ny/dac/modbamall_drd/nonres/{sample_name}_M_s.bam'
# sample_name=name_list[0]
# bamfile=f'{d_sample[sample_name]}/{sample_name}_s.bam'

# Pick one bin by its ordinal position in the binned table.
i=45
dfbin=df_only_in_binsize[i*binsize*2:(i+1)*binsize*2]

# NOTE(review): the BAM handle stays open after this cell; close it when done.
bam = pysam.AlignmentFile(bamfile, 'rb')

l_cg_pos=list(dfbin['start'])

# Region spanned by the bin: first and last listed position on the bin's chromosome.
start_pos=l_cg_pos[0]
end_pos=l_cg_pos[-1]
chrom=dfbin['chro'].iloc[0]
reads = bam.fetch(chrom, start_pos, end_pos)

read_matrix=[]

for read in reads: 
    # Keep primary, mapped alignments only.
    if read.is_unmapped or read.is_secondary or read.is_supplementary:
        continue
    # Only reads carrying an MM tag are passed to get_read_m_array.
    if read.has_tag("MM"):
        l_methy,strand_flag=get_read_m_array(read,l_cg_pos)            
        # get_read_m_array returns a falsy first element for reads without calls.
        if not l_methy:
            continue
        read_matrix.append(l_methy)

# One row per retained read.
read_matrix=np.array(read_matrix)
        

In [6]:
# Show the first rows of the example bin selected above.
dfbin.head()

Unnamed: 0,chro,start,end,bin,strand
2250,1,134236,134237,80,+
2251,1,134237,134238,80,-
2252,1,134264,134265,80,+
2253,1,134265,134266,80,-
2254,1,134295,134296,80,+
