# Cell Aggregation

In [3]:
import scanpy as sc
from sklearn.neighbors import kneighbors_graph
import numpy as np
import pandas as pd
import os
import shutil

### Step 1: Find Cell Neighbors Using K-Nearest Neighbors (KNN)

In this step, we use the K-Nearest Neighbors (KNN) algorithm on the DOLPHIN latent embedding (`adata.obsm['X_z']`) to find the nearest neighboring cells for each cell in the dataset.


In [None]:
# Load the DOLPHIN model results from the h5ad file; the latent embedding used
# for the KNN search below is read from adata.obsm['X_z'].
# `anndata` is never imported in this notebook, so read through scanpy, which
# exposes the same reader as `sc.read_h5ad`.
adata = sc.read_h5ad("DOLPHIN_Z.h5ad")

# Number of neighbors per cell (default 10, counting the cell itself).
n_neighbor = 10

In [None]:
# Build the KNN graph in the DOLPHIN latent space. Use n_neighbor (instead of a
# hard-coded 10) so that changing it above actually changes the neighborhood.
# The graphs are densified because the next cell reads neighbor indices with
# np.nonzero on dense rows; note that dense (n_cells x n_cells) arrays grow
# quadratically in memory with the number of cells.
cell_conn_new = kneighbors_graph(adata.obsm['X_z'], n_neighbor, mode='connectivity', include_self=True, n_jobs=20).toarray()
cell_dist_new = kneighbors_graph(adata.obsm['X_z'], n_neighbor, mode='distance', include_self=True, n_jobs=20).toarray()

In [None]:
# Flatten the KNN connectivity matrix into (main cell, neighbor) name pairs.
# Because the graph was built with include_self=True, a cell's neighbor list
# normally contains the cell itself.
main_name, combine_name = [], []
for row_idx, main_cell in enumerate(adata.obs.index):
    print("main_sample", main_cell)
    for col_idx in np.flatnonzero(cell_conn_new[row_idx]):
        neighbor_cell = adata.obs.index[col_idx]
        print(neighbor_cell)
        main_name.append(main_cell)
        combine_name.append(neighbor_cell)

In [None]:
# One row per (main cell, neighbor) pair; consumed by the aggregation step below.
neighbor_table = pd.DataFrame({"main_name": main_name, "neighbor": combine_name})
neighbor_table.to_csv("DOLPHIN_aggregation_KNN10.csv", index=False)

### Step 2: Cell Aggregation - Adding Junction Reads from Neighboring Cells

In this step, we perform cell aggregation by adding confident junction reads from neighboring cells, i.e. reads overlapping junctions that occur in at least half of a cell's neighbors. This strengthens the signal for alternative splicing analysis and reduces noise by taking into account the junction read patterns of nearby cells.

### Get the Number of Reads per BAM File for Library Size Normalization

```bash
find ./02_exon_std -type f -name "*.bam" | while read file
do  
    echo "$file"
    echo "$file" >> get_single_bam_num.txt
    samtools flagstat $file >> get_single_bam_num.txt
done
```


In [None]:
## Process the number of reads per bam files.
# The log written by the bash loop alternates an echoed BAM path with the
# `samtools flagstat` report for that file. The report length differs between
# samtools versions, so records are located by content instead of a fixed
# 14-line stride (which silently misaligns samples and counts otherwise).
records = []
pending_sample = None
with open("./get_single_bam_num.txt") as f:
    for raw_line in f:
        line = raw_line.strip()
        if line.endswith(".bam"):
            # Echoed path, e.g. ./02_exon_std/<sample>/<sample>.std....bam
            pending_sample = line.split("/")[-1].split(".")[0]
        elif pending_sample is not None and "in total" in line:
            # First flagstat line: "<QC-passed> + <QC-failed> in total (...)";
            # keep the QC-passed count, as before.
            records.append((pending_sample, int(line.split(" ")[0])))
            pending_sample = None
samples_line = [sample for sample, _ in records]
cnt = [count for _, count in records]
pd_cnt = pd.DataFrame({"sample": samples_line, "num_seqs": cnt})
pd_cnt.to_csv("single_bam_count.txt", index=False)

In [None]:
# Metadata table: its first column is taken as the list of samples to aggregate.
# NOTE(review): the file is parsed as tab-separated even though it is named
# .csv -- adjust `sep` if your metadata is comma-separated.
metadata = "your_metaData.csv"
pd_gt = pd.read_csv(metadata, sep="\t")
sample_list = pd_gt.iloc[:, 0].tolist()
# Per-BAM read counts and the KNN neighbor pairs produced by earlier cells.
pd_single_size = pd.read_csv("./single_bam_count.txt")
pd_aggr = pd.read_csv("./DOLPHIN_aggregation_KNN10.csv")

In [None]:
## src_path: input directory with per-sample <sample>/<sample>.std.* BAM and SJ.out.tab files.
## dist_path: working/output directory for the aggregated BAM files.
src_path, dist_path = "./02_exon_std/", "./DOLPHIN_aggregation/"


In [None]:
def _run(cmd):
    """Run a shell command and raise RuntimeError if it exits non-zero.

    os.system only returns the exit status; ignoring it lets a failed samtools
    call leave missing or empty files that break later steps silently.
    """
    status = os.system(cmd)
    if status != 0:
        raise RuntimeError(f"Command failed with status {status}: {cmd}")


def _subsample_bam(src, dst, fraction):
    """Write a random `fraction` (0 < fraction < 1) of the alignments in src to dst.

    `samtools view -s` parses its argument as SEED.FRACTION (seed 0 here, as in
    the original pipeline), so the fraction is written in fixed-point notation;
    Python's default formatting switches to e.g. "1e-05" for small values, which
    does not have that form. A fraction part of zero disables subsampling in
    samtools (every read would be kept), so that case writes nothing.

    Returns:
        True if dst was written, False if the fraction rounds to zero.
    """
    digits = f"{min(fraction, 0.99999999):.8f}".split(".")[1]
    if int(digits) == 0:
        return False
    _run(f"samtools view -b -s 0.{digits} {src} > {dst}")
    return True


def _majority_junctions(neighbors, src_dir):
    """Return (chr, first_base, last_base) rows seen in at least half the neighbors.

    Each neighbor's <name>.std.SJ.out.tab is outer-merged on the junction
    coordinates; a junction is supported by a neighbor when that neighbor's
    column 8 value is present (not NA) after the merge.
    """
    if not neighbors:
        raise ValueError("No neighbors found; cannot vote on junctions.")
    key = ["chr", "first_base", "last_base"]
    df_merge = None
    for _temp_n in neighbors:
        _df_junc = pd.read_csv(os.path.join(src_dir, _temp_n, _temp_n + ".std.SJ.out.tab"), sep="\t",
                               usecols=[0, 1, 2, 7], names=key + ["multi_map" + _temp_n])
        df_merge = _df_junc if df_merge is None else pd.merge(df_merge, _df_junc, how="outer", on=key)
    support = df_merge.drop(columns=key).notna().sum(axis=1)
    # "At least half" of the neighbors, rounded up (5 of 10 with the defaults).
    min_support = (len(neighbors) + 1) // 2
    return df_merge.loc[support >= min_support, key]


def _normalize_bam(src, work_dir, name, n_seq, target_size):
    """Copy BAM src to <work_dir>/<name>.norm.bam rescaled to ~target_size reads.

    Upsampling concatenates whole copies plus a random remainder; downsampling
    keeps a random target_size / n_seq fraction. Returns the .norm.bam path.
    """
    n_seq, target_size = int(n_seq), int(target_size)
    raw_bam = os.path.join(work_dir, name + ".bam")
    norm_bam = os.path.join(work_dir, name + ".norm.bam")
    shutil.copyfile(src, raw_bam)
    if n_seq == target_size or n_seq == 0:
        # Already the right size (or empty: nothing to scale, avoid dividing by 0).
        os.rename(raw_bam, norm_bam)
    elif n_seq < target_size:
        ##===== Upsampling: n_copies whole copies + a sampled remainder.
        n_copies, remainder = divmod(target_size, n_seq)
        parts = [raw_bam] * n_copies
        sample_bam = os.path.join(work_dir, name + ".sample.bam")
        # Skip the remainder when it is zero instead of passing "-s 0.0",
        # which would add one more full copy.
        if _subsample_bam(raw_bam, sample_bam, remainder / n_seq):
            parts.append(sample_bam)
        _run(f"samtools merge -f {norm_bam} {' '.join(parts)}")
        if os.path.exists(sample_bam):
            os.remove(sample_bam)
        os.remove(raw_bam)
    else:
        ##===== Downsampling
        if not _subsample_bam(raw_bam, norm_bam, target_size / n_seq):
            raise ValueError(f"Downsampling fraction for {name} rounds to zero.")
        os.remove(raw_bam)
    return norm_bam


def _frequent_junction_reads(norm_bam, work_dir, name, bed_path):
    """Keep only spliced reads (CIGAR contains N) overlapping the BED regions.

    NOTE(review): `-L` keeps any read overlapping a frequent intron region, not
    only reads whose own junction matches it exactly.
    """
    junction_bam = os.path.join(work_dir, name + ".junction.norm.bam")
    mj_bam = os.path.join(work_dir, name + ".mj.junction.norm.bam")
    _run(f"samtools view -h {norm_bam} | awk '$0 ~ /^@/ || $6 ~ /N/' | samtools view -b - > {junction_bam}")
    _run(f"samtools index {junction_bam}")
    # -b: write BAM (without it samtools emits SAM text into a ".bam" file).
    _run(f"samtools view -b -h -L {bed_path} {junction_bam} > {mj_bam}")
    return mj_bam


final_dir = os.path.join(dist_path, "final_bam")
# The final merge writes here; the directory was never created before.
os.makedirs(final_dir, exist_ok=True)

for target in sample_list:
    print(target)
    target_size = pd_single_size[pd_single_size["sample"] == target].iloc[0]["num_seqs"]
    _neighbor = list(pd_aggr[pd_aggr["main_name"] == target]["neighbor"])
    work_dir = os.path.join(dist_path, target)
    # exist_ok so a rerun after an interrupted run does not crash here.
    os.makedirs(work_dir, exist_ok=True)

    # Majority voting: junctions supported by at least half of the neighbors.
    junctions = _majority_junctions(_neighbor, src_path)
    bed_path = os.path.join(work_dir, "keep_junction.bed")
    # SJ.out.tab first_base is 1-based; BED starts are 0-based.
    junctions.assign(first_base=junctions["first_base"] - 1).to_csv(
        bed_path, sep="\t", index=False, header=False)

    # Library-size normalization of every neighbor BAM to the target's size;
    # for neighbors other than the target, keep only frequent-junction reads.
    merge_inputs = []
    for _n in _neighbor:
        _n_seq = pd_single_size[pd_single_size["sample"] == _n].iloc[0]["num_seqs"]
        src_bam = os.path.join(src_path, _n, _n + ".std.Aligned.sortedByCoord.out.bam")
        norm_bam = _normalize_bam(src_bam, work_dir, _n, _n_seq, target_size)
        if _n != target:
            merge_inputs.append(_frequent_junction_reads(norm_bam, work_dir, _n, bed_path))

    # Final merge: neighbors' frequent-junction reads + the target's own
    # normalized BAM. An explicit list replaces the shell glob, which stayed
    # literal (and failed) when no neighbor files existed.
    merge_inputs.sort()
    merge_inputs.append(os.path.join(work_dir, target + ".norm.bam"))
    final_bam = os.path.join(final_dir, target + ".final.bam")
    _run(f"samtools merge -f {final_bam} {' '.join(merge_inputs)}")
    shutil.rmtree(work_dir)