In [None]:
import pysam
import pandas as pd
import numpy as np
import multiprocessing
# Lift pandas' display truncation: None means "no limit" for both options,
# so every column and every row is rendered when a frame is displayed.
pd.set_option("display.max_columns", None)
pd.set_option("display.max_rows", None)

In [None]:
# Binning parameters.
binsize = 25   # fixed number of CpG sites per bin
maxgap = 500   # largest allowed distance between consecutive CpG starts inside one bin

In [None]:
# Sample identifier; the input BAM path is derived from it.
sample_name = 'FYC_3026_D0_M_s'
bamfile = f'/t2/ny/dac/modbamall_drd/res/{sample_name}.bam'

In [None]:
# All CpGs in the human genome, read from a headerless tab-separated BED file.
# Only columns 0, 1, 2 and 5 are loaded; column 0 is forced to str.
df = pd.read_csv(
    "/t2/ny/ref/hg38_autoxy_cg.bed",
    sep="\t",
    header=None,
    dtype={0: str},
    usecols=[0, 1, 2, 5],
)
df.columns = ["chro", "start", "end", "strand"]

# Binning (vectorized; assigns the same bin ids as a row-by-row scan).
# A new bin starts at row i when any of the following holds:
#   * start[i] - start[i-1] > maxgap
#   * the chromosome differs from the previous row
#   * the current bin already holds binsize*2 rows
# NOTE(review): the factor 2 presumably reflects one row per strand for each
# CpG in the BED file -- confirm against the reference file.
chromStart = df['start'].values
chrom = df['chro'].values
rows_per_bin = binsize * 2
n_rows = len(df)

# Breaks forced by a large gap or a chromosome change. Row 0 never counts as
# a break: it simply opens bin 1.
segment_break = np.zeros(n_rows, dtype=bool)
segment_break[1:] = (np.diff(chromStart) > maxgap) | np.asarray(
    chrom[1:] != chrom[:-1], dtype=bool
)

# Offset of every row from the first row of its gap/chromosome segment.
row_idx = np.arange(n_rows)
segment_first_row = np.maximum.accumulate(np.where(segment_break, row_idx, 0))
offset_in_segment = row_idx - segment_first_row

# Inside a segment the size limit opens a new bin every `rows_per_bin` rows.
size_break = (offset_in_segment > 0) & (offset_in_segment % rows_per_bin == 0)

# Bin ids are 1-based and increase by one at every break.
bin_column = np.cumsum(segment_break | size_break) + 1
nbin = int(bin_column[-1]) if n_rows else 1
df['bin'] = bin_column

In [None]:
# Keep only complete bins, i.e. bins holding exactly binsize*2 rows.
rows_in_bin = df.groupby('bin')['bin'].transform('size')
df_only_in_binsize = df[rows_in_bin == binsize * 2]
total_bin_num = df_only_in_binsize['bin'].nunique()

In [None]:
def pairwise_similarity(d_A,d_B,l_cg_pos):
    """Compare the methylation calls of two reads over a list of CpG positions.

    Only every second entry of ``l_cg_pos`` is visited; for each visited
    position ``p`` a read is considered covered when it has a call at ``p``
    or at ``p + 1``. The read's status at that site is the largest of its
    available calls, floored at 0.

    Parameters
    ----------
    d_A, d_B : dict
        Mapping position -> call for each read.
    l_cg_pos : list
        CpG positions; entries 0, 2, 4, ... are used as site anchors.

    Returns
    -------
    tuple
        ``(overlap, score)``: the number of sites covered by both reads and
        the number of those sites where both statuses are equal.
    """
    def _site_status(calls, site):
        # None when the read has no call on either position of the site.
        found = [calls[p] for p in (site, site + 1) if p in calls]
        if not found:
            return None
        return max(0, *found)

    n_shared = 0
    n_equal = 0
    for site in l_cg_pos[::2]:
        state_a = _site_status(d_A, site)
        if state_a is None:
            continue
        state_b = _site_status(d_B, site)
        if state_b is None:
            continue
        n_shared += 1
        if state_a == state_b:
            n_equal += 1
    return n_shared, n_equal
    
    

def cal_matrix(dfbin,bam):
    """Build the pairwise read-similarity matrix for one CpG bin.

    Parameters
    ----------
    dfbin : pandas.DataFrame
        Rows of one bin; the ``'start'`` and ``'chro'`` columns are used.
    bam : object
        Alignment source exposing ``fetch(chrom, start, end)`` (a
        ``pysam.AlignmentFile`` in this notebook).

    Returns
    -------
    tuple
        ``(similarity_matrix, total_overlap_point)``. The matrix is a
        symmetric ``n_reads x n_reads`` float array holding
        ``identical / overlap`` for each read pair, or NaN when the pair
        shares no site. ``total_overlap_point`` is the sum of the overlap
        counts over all pairs with ``i <= j``, so each read's comparison
        with itself is included.

    Notes
    -----
    * Unmapped, secondary and supplementary alignments are skipped.
    * Reads without an ``MM`` tag are skipped: they carry no calls, and
      storing them would either raise ``NameError`` or reuse the previous
      read's calls.
    * Reads are keyed by ``query_name``; a later alignment with the same
      name replaces the earlier one.
    """
    l_cg_pos = list(dfbin['start'])
    cg_pos_set = set(l_cg_pos)  # constant-time membership test in the inner loop
    start_pos, end_pos = l_cg_pos[0], l_cg_pos[-1]
    chrom = dfbin['chro'].iloc[0]

    d_reads = {}
    for read in bam.fetch(chrom, start_pos, end_pos):
        if read.is_unmapped or read.is_secondary or read.is_supplementary:
            continue
        if not read.has_tag("MM"):
            continue

        # NOTE(review): the tag is parsed as one comma-separated list of
        # integer offsets that are accumulated in reference coordinates
        # starting at reference_start. This presumably matches how the input
        # BAM was produced; it is not the read-base skip count of the
        # standard SAM MM tag -- confirm against the upstream pipeline.
        mm_tag = read.get_tag("MM")
        mm_tag_l = [int(i) for i in mm_tag.strip(';').split(',')[1:]]

        # Every bin position spanned by the read starts as unmethylated (0).
        lower = read.reference_start - 1
        upper = read.reference_end - 1
        d_methy = {p: 0 for p in l_cg_pos if lower <= p <= upper}

        # Walk the offsets and flag the bin positions they land on.
        pos = lower
        for delta in mm_tag_l:
            pos += delta + 1
            if pos in cg_pos_set:
                d_methy[pos] = 1

        d_reads[read.query_name] = d_methy

    keys = list(d_reads)
    n_reads = len(keys)
    similarity_matrix = np.zeros((n_reads, n_reads))
    total_overlap_point = 0
    for i in range(n_reads):
        for j in range(i, n_reads):
            n_overlap, identical_score = pairwise_similarity(
                d_reads[keys[i]], d_reads[keys[j]], l_cg_pos
            )
            if n_overlap:
                value = identical_score / n_overlap
                total_overlap_point += n_overlap
            else:
                value = np.nan
            similarity_matrix[i, j] = similarity_matrix[j, i] = value

    return similarity_matrix, total_overlap_point
    
def get_summary(dfbin,bamfile):
    """Summarise read-to-read methylation similarity for one CpG bin.

    Parameters
    ----------
    dfbin : pandas.DataFrame
        Rows of one bin with ``'chro'``, ``'start'`` and ``'end'`` columns.
    bamfile : str
        Path of the BAM file to read alignments from.

    Returns
    -------
    dict
        Keys ``'chr'``, ``'start'``, ``'end'``, ``'length'``,
        ``'similarity_mean'``, ``'n_overlapped_reads_pairs'``, ``'n_reads'``
        and ``'total_overlapped_points'``. ``'similarity_mean'`` is the mean
        of the non-NaN upper-triangle entries (diagonal excluded) of the
        similarity matrix, or NaN when there is no such entry.
    """
    # The file is opened per call and closed as soon as the matrix is built,
    # so handles do not pile up while many bins are processed.
    with pysam.AlignmentFile(bamfile, 'rb') as bam:
        similarity_matrix, total_overlap_point = cal_matrix(dfbin, bam)

    reads = len(similarity_matrix)
    # Distinct read pairs only: strict upper triangle.
    half_diagonal = similarity_matrix[np.triu_indices(reads, k=1)]
    overlap_read_count = np.count_nonzero(~np.isnan(half_diagonal))
    # nanmean on an empty or all-NaN array returns NaN with a RuntimeWarning;
    # return NaN directly in that case.
    mean_value = np.nanmean(half_diagonal) if overlap_read_count else np.nan

    # First row's start and last row's end, selected by position.
    start = dfbin['start'].iloc[0]
    end = dfbin['end'].iloc[-1]

    summary = {
        'chr': dfbin['chro'].iloc[0],
        'start': start,
        'end': end,
        'length': end - start,
        'similarity_mean': mean_value,
        'n_overlapped_reads_pairs': overlap_read_count,
        'n_reads': reads,
        'total_overlapped_points': total_overlap_point,
    }
    return summary

In [None]:
# Fan the per-bin summaries out to a pool of worker processes.
# Each task receives one bin: df_only_in_binsize contains only bins of exactly
# binsize*2 rows, so consecutive positional slices of that length are the bins.
pool = multiprocessing.Pool(processes=128)
l_summary = []
rows_per_bin = binsize * 2
try:
    for i in range(total_bin_num):
        dfbin = df_only_in_binsize.iloc[i * rows_per_bin:(i + 1) * rows_per_bin]
        l_summary.append(pool.apply_async(get_summary, (dfbin, bamfile)))
    pool.close()
    pool.join()
except BaseException:
    # On any failure, including a kernel interrupt, stop the workers instead
    # of leaving them running in the background, then re-raise.
    pool.terminate()
    pool.join()
    raise

In [None]:
# Resolve the async results into summary dicts; .get() re-raises any
# exception that occurred in the worker.
# Entries that are already dicts are passed through unchanged, so re-running
# this cell after l_summary has been resolved does not fail.
l_summary = [res if isinstance(res, dict) else res.get() for res in l_summary]
df_summary = pd.DataFrame(l_summary)

In [None]:
# Write the per-bin summary table to a CSV whose name encodes the binning
# parameters and the sample; the index is not written.
output_csv = f'bin{binsize}CpG_{maxgap}gap_{sample_name}.csv'
df_summary.to_csv(output_csv, index=False, encoding='utf_8_sig')