In [1]:
import functools
import glob
import itertools
import os
import random
from typing import Callable

import numpy as np
import pandas as pd
import regex as re

from pyspark import SparkConf, SparkContext
from pyspark.sql import DataFrame
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, explode, broadcast, count, lit, length, col
from pyspark.sql.types import IntegerType, LongType, ArrayType, StringType, DoubleType
from pyspark.sql.types import StructType

# Raise pandas' display limits so long/wide frames are not truncated in cell output.
for _option_name, _option_value in (
    ('display.max_rows', 500),
    ('display.max_columns', 500),
    ('display.width', 1000),
):
    pd.set_option(_option_name, _option_value)

In [2]:
# UPDATE HOME!
os.environ["SPARK_HOME"] = "/home/ec2-user/mambaforge/envs/2023_06_26_SRT_deconvolution_MS/lib/python3.7/site-packages/pyspark"
# THIS needs to be set-up before running the notebook
os.environ["SPARK_LOCAL_DIRS"] = "/temp"
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"

# Spark configuration used for the whole notebook.
spark_conf = SparkConf()
spark_conf.set("spark.ui.showConsoleProgress", "True")
spark_conf.set("spark.executor.instances", "2")
spark_conf.set("spark.executor.cores", "2")
spark_conf.set("spark.executor.memory", "16g")
spark_conf.set("spark.driver.memory", "64g")
spark_conf.set("spark.driver.maxResultSize", "32g")
# The Parquet push-down switch lives under the `spark.sql.` namespace;
# the previous key `spark.parquet.filterPushdown` is not a Spark setting.
spark_conf.set("spark.sql.parquet.filterPushdown", "true")
spark_conf.set("spark.local.dir", "/temp")

# getOrCreate() reuses a running context, so re-executing this cell does not
# fail with "Cannot run multiple SparkContexts at once".
sc = SparkContext.getOrCreate(conf=spark_conf)
sc.setLogLevel("ERROR")
spark = SparkSession(sc)



In [3]:
# Base name (without extension) of the region BED file.
REGIONS = 'deconvolution_v2.v23_conv.with_cpg_index'
# Column names given to the first six BED columns when the region file is read.
REGION_BED_COLS = [
    'region_chr',
    'region_start',
    'region_end',
    'region_cpg_index_min',
    'region_cpg_index_max',
    'region_id',
]
# Minimum number of observed CpGs per molecule (passed to compute_frag_scores).
FILTER_CG_COUNT = 3
# Minimum number of CpGs a region must span to be kept.
FILTER_CG_COUNT_REGION = 1

#--- Local paths
ROOT_DIR = '/analysis/gh-msun/projects'
PROJECT_SLUG = '2023_06_26_SRT_deconvolution_MS'
PROJECT_DIR = ROOT_DIR + '/{}'.format(PROJECT_SLUG)

# Regions
REGION_PATH = PROJECT_DIR + '/stage/panel_data/{regions}.bed'.format(regions=REGIONS)

# CpG map; genomic coordinate to CpG index;
CPG_MAP_PATH = PROJECT_DIR + '/stage/cpg_loci/cpg_loci_hg19.combined_annot.tsv.gz'

# BLUEPRINT HG38: s3://gh-bi-lunar/public_data/blueprint/hg38_20160816.pat.db_version.parquet/
PARQUET_PATH_LIST_HG38 = ['/analysis/hg38_20160816.pat.db_version.parquet']

#--- Where to store results
RESULTS_PATH = PROJECT_DIR + (
    '/output/meth_summaries/blueprint_meth_summaries_cg_count_geq_{k}_{regions}.tsv.gz'
).format(k=FILTER_CG_COUNT, regions=REGIONS)

### CpG Map

In [4]:
# Load the CpG lookup table, keeping only the coordinate and CpG-index columns.
cpg_map = pd.read_csv(
    CPG_MAP_PATH,
    sep='\t',
    usecols=['chr', 'start', 'end', 'cpg_index', 'cpg_index_hg38'],
)

In [5]:
%%time
# Keep only CpGs that have an hg38 index, then build lookups in both directions.
# The filter is applied once per column (instead of re-filtering the whole frame
# for every argument), and plain zip() pairs the two equal-length columns;
# zip_longest would silently pad with None if the lengths ever differed.
ridxs = ~cpg_map['cpg_index_hg38'].isna()
hg19_cpg_index = cpg_map.loc[ridxs, 'cpg_index']
hg38_cpg_index = cpg_map.loc[ridxs, 'cpg_index_hg38'].astype(int)
hg19_hg38_map = dict(zip(hg19_cpg_index, hg38_cpg_index))
hg38_hg19_map = dict(zip(hg38_cpg_index, hg19_cpg_index))


CPU times: user 15.5 s, sys: 4.95 s, total: 20.5 s
Wall time: 20.4 s


### Regions

In [6]:
# Read the first six BED columns, then derive the hg38 CpG-index bounds.
region_df = (
    pd.read_csv(REGION_PATH, sep='\t', names=REGION_BED_COLS, usecols=range(0, 6))
    # NOTE(review): the upper index is shifted down by one, presumably turning an
    # exclusive bound into an inclusive one - confirm against the BED producer.
    .assign(region_cpg_index_max=lambda df: df['region_cpg_index_max'] - 1)
    .sort_values('region_cpg_index_min')
    # Translate both bounds from hg19 to hg38; unmapped indices become NaN.
    .assign(
        region_cpg_index_min_hg38=lambda df: df['region_cpg_index_min'].map(hg19_hg38_map),
        region_cpg_index_max_hg38=lambda df: df['region_cpg_index_max'].map(hg19_hg38_map),
    )
)

# Row count vs. number of distinct region ids.
region_df.shape[0], region_df['region_id'].nunique()

(1658, 1658)

In [7]:
# Keep only regions whose lower AND upper CpG index both mapped to hg38.
ridxs = region_df[['region_cpg_index_min_hg38', 'region_cpg_index_max_hg38']].notna().all(axis=1)
region_df = region_df.loc[ridxs].copy()
region_df.shape[0], region_df['region_id'].nunique()

(1658, 1658)

In [8]:
# Number of CpGs spanned by each region in either index space (bounds inclusive).
cg_count_hg19 = region_df['region_cpg_index_max'].sub(region_df['region_cpg_index_min']).add(1)
cg_count_hg38 = region_df['region_cpg_index_max_hg38'].sub(region_df['region_cpg_index_min_hg38']).add(1)
# Keep regions spanning the same number of CpGs in hg19 and hg38, and at least
# FILTER_CG_COUNT_REGION CpGs.
ridxs = cg_count_hg19.eq(cg_count_hg38) & cg_count_hg19.ge(FILTER_CG_COUNT_REGION)
region_df = region_df.loc[ridxs].copy()
region_df.shape[0], region_df['region_id'].nunique()

(1657, 1657)

In [9]:
# The hg38 bounds came out of map() as floats (NaN-capable); no NaN is left
# after the filtering above, so cast them to integers.
region_df = region_df.astype({
    'region_cpg_index_min_hg38': int,
    'region_cpg_index_max_hg38': int,
})

In [10]:
### Restrict the region panel to the immune regions (subset chosen in the next cell)

In [11]:
#-------------- CHANGE HERE FOR DIFFERENT REGION SUBSET ----------------------
# BLUEPRINT immune regions
ATLAS_PATH = PROJECT_DIR + '/output/deconv_inhouse_v2.atlas.tsv.gz'
atlas = pd.read_csv(ATLAS_PATH, sep='\t')
subset_region_set = set(atlas.region_id)
#-----------------------------------------------------------------------------

# filter regions down to regions of interest
# .copy() detaches the subset from the parent frame, so the `batch` column
# added later does not raise SettingWithCopyWarning.
region_df = region_df[region_df['region_id'].isin(subset_region_set)].copy()
region_df.head()


Unnamed: 0,region_chr,region_start,region_end,region_cpg_index_min,region_cpg_index_max,region_id,region_cpg_index_min_hg38,region_cpg_index_max_hg38
0,chr1,1114771,1114971,20117,20129,Immune_Broad_B-chr1:1114772-1114971,21119,21131
1,chr1,1157450,1157720,21684,21703,Immune_Broad_NK-chr1:1157451-1157720,22686,22705
2,chr1,1157879,1158277,21710,21726,Immune_Broad_NK-chr1:1157880-1158277,22712,22728
14,chr1,6341182,6341377,140667,140681,Immune_Broad_Eosi-chr1:6341183-6341377,142368,142382
19,chr1,9147788,9147871,188605,188608,Immune_Broad_Neutro-chr1:9147789-9147871,190307,190310


In [12]:
# (rows, columns) of the region table after restricting it to the atlas regions.
region_df.shape

(280, 8)

### PAT PARQUET Files

In [13]:
# NOTE: this replaces the PARQUET_PATH_LIST_HG38 defined in the configuration
# cell; from here on a single mixture parquet under the project output
# directory is scored. The path is built from PROJECT_DIR instead of repeating
# the absolute project location.
PARQUET_PATH_LIST_HG38 = [
    PROJECT_DIR + '/output/mixture/mix_50B_50CD4_00CD8_00NK_00Mono_00Neutro_seed_888/mix0_seed_83723.parquet'
]

In [14]:
# Columns selected from each PAT parquet file.
# `sample_id` is not selected (the earlier column list that included it is disabled).
PAT_COLS = [
    'molecule_id',
    'chr',
    'number_molecules',
    'cpg_index_min',
    'cpg_index_max',
    'pat_string',
]

In [15]:
# Read every input parquet with the same column selection ...
pat_parquet_files = []
for ifile in PARQUET_PATH_LIST_HG38:
    pat_parquet_files.append(spark.read.parquet(ifile).select(*PAT_COLS))
# ... and stack them into one Spark DataFrame, matching columns by name.
pat_hg38_ddf = functools.reduce(DataFrame.unionByName, pat_parquet_files)
pat_hg38_ddf.printSchema()

root
 |-- molecule_id: string (nullable = true)
 |-- chr: string (nullable = true)
 |-- number_molecules: long (nullable = true)
 |-- cpg_index_min: long (nullable = true)
 |-- cpg_index_max: long (nullable = true)
 |-- pat_string: string (nullable = true)



In [16]:
# Preview the first six molecules of the combined PAT table.
pat_hg38_ddf.show(6)

+-----------+-----+----------------+-------------+-------------+--------------------+
|molecule_id|  chr|number_molecules|cpg_index_min|cpg_index_max|          pat_string|
+-----------+-----+----------------+-------------+-------------+--------------------+
|  389696038|chr16|               1|     21772138|     21772144|             TTCTTTT|
|  410430588|chr16|               1|     21772135|     21772141|             TTCCTCC|
|  410430713|chr16|               1|     21772142|     21772148|             CCC.CCC|
|  330413047|chr16|               1|     21773403|     21773421| CCCTCCTCCC.CCCCCCCC|
|  227772570|chr16|               1|     21773397|     21773418|CCCCCCCCCC..CCCCC...|
|  246240980|chr16|               1|     21773396|     21773407|        CCCCCCCCCCCC|
+-----------+-----+----------------+-------------+-------------+--------------------+
only showing top 6 rows



## Fragment Level Scoring

In [17]:
# Scoring parameters.
QUANTILES = [0.1, 0.25, 0.75, 0.9]
# k-mer lengths for which methylated / unmethylated / total k-mers are counted.
KMERS = [1, 3, 4]
# Thresholds on the per-molecule methylation fraction (alpha).
RATES_LEQ = [0.25]
RATES_GEQ = [0.75]

# Schema of the per-region frame returned by the scoring function; the column
# names must match what the scoring function produces for KMERS / RATES_* above.
RETURN_SCHEMA = (
    StructType()
    .add('region_id', 'string')
    .add('number_molecules', 'integer')
    .add('meth_k1', 'integer')
    .add('unmeth_k1', 'integer')
    .add('total_k1', 'integer')
    .add('meth_k3', 'integer')
    .add('unmeth_k3', 'integer')
    .add('total_k3', 'integer')
    .add('meth_k4', 'integer')
    .add('unmeth_k4', 'integer')
    .add('total_k4', 'integer')
    .add('frac_alpha_leq_25pct', 'float')
    .add('frac_alpha_geq_75pct', 'float')
)
# Disabled field: .add('sample_id', 'string')

def compute_frag_scores(cpg_number_cutoff: int) -> Callable[[pd.DataFrame], pd.DataFrame]:
    """
    Build the per-group scoring function that is passed to ``applyInPandas``.

    Parameters
    ----------
    cpg_number_cutoff : int
        Minimum number of observed CpGs ('C' or 'T' characters) a molecule must
        show inside the region, after trimming, in order to be scored.

    Returns
    -------
    Callable[[pd.DataFrame], pd.DataFrame]
        Function that takes the molecules of one group and returns one row per
        ``region_id`` with the columns listed in ``RETURN_SCHEMA`` (an empty
        frame with those columns when no molecule passes the filters).
    """

    def compute_frag_scores_inner(pat_df: pd.DataFrame) -> pd.DataFrame:
        """
        Score the molecules overlapping a region.

        Expects the columns ``pat_string``, ``cpg_index_min``, ``cpg_index_max``,
        ``number_molecules``, ``region_id``, ``region_cpg_index_min`` and
        ``region_cpg_index_max``. Assumes character ``i`` of ``pat_string``
        belongs to CpG index ``cpg_index_min + i`` (this is how the string is
        sliced below).
        """
        data = pat_df.copy()
        # Positions, relative to the start of the molecule, of the first and
        # last CpG that fall inside the region.
        data['offset_min'] = (data['region_cpg_index_min'] - data['cpg_index_min']).clip(lower=0)
        data['offset_max'] = np.minimum(
            data['region_cpg_index_max'] - data['cpg_index_min'],
            data['cpg_index_max'] - data['cpg_index_min'])
        # Slice each pattern to the region; a zip over the three columns avoids
        # the overhead of a row-wise DataFrame.apply.
        data['trimmed_pat'] = [
            pat[lo:(hi + 1)]
            for pat, lo, hi in zip(data['pat_string'], data['offset_min'], data['offset_max'])
        ]
        #--- Filter molecules based on observed CpG loci
        # Only 'C' and 'T' count as observed; any other character (e.g. '.') does not.
        observed_cpg_number = (data['trimmed_pat'].str.count('C') + data['trimmed_pat'].str.count('T'))
        ridxs = (observed_cpg_number >= cpg_number_cutoff)
        # Entries that stand for no molecule must not contribute; the previous
        # range()/explode expansion turned a count of 0 into one molecule.
        ridxs &= (data['number_molecules'] > 0)
        data = data[ridxs].copy()
        if data.shape[0] == 0:
            return pd.DataFrame(columns=RETURN_SCHEMA.names)

        # Compute k-mer methylation states: overlapping runs of k methylated
        # ('C'), k unmethylated ('T') or k observed ('T'/'C') CpGs.
        # `overlapped=True` is a feature of the `regex` module imported as `re`.
        for k in KMERS:
            for prefix, alphabet in (('meth', 'C'), ('unmeth', 'T'), ('total', 'TC')):
                pattern = '[%s]{%i}' % (alphabet, k)
                data['%s_k%i' % (prefix, k)] = data['trimmed_pat']\
                    .apply(lambda x, pattern=pattern: len(re.findall(pattern, x, overlapped=True)))
        # Compute alpha distribution metrics (alpha = methylated / observed CpGs)
        data['alpha'] = data['meth_k1']/data['total_k1']
        for rate in RATES_LEQ:
            data['frac_alpha_leq_%ipct'%(100*rate)] = np.where(data['alpha']<=rate, 1, 0)
        for rate in RATES_GEQ:
            data['frac_alpha_geq_%ipct'%(100*rate)] = np.where(data['alpha']>=rate, 1, 0)

        # Aggregate metrics. Each entry stands for `number_molecules` identical
        # molecules, so every metric is weighted by that count; this gives the
        # same sums as duplicating the rows, without materialising them.
        metric_cols = [
            name for name in RETURN_SCHEMA.names
            if name not in ('region_id', 'number_molecules')
        ]
        weights = data['number_molecules']
        weighted = data[metric_cols].multiply(weights, axis=0)
        weighted['number_molecules'] = weights
        #rv = data.groupby(['region_id', 'sample_id'])\
        rv = weighted.groupby(data['region_id']).sum().reset_index()
        # Turn the summed 0/1 flags into fractions of molecules.
        for frac_col in metric_cols:
            if frac_col.startswith('frac_alpha_'):
                rv[frac_col] = rv[frac_col]/rv['number_molecules']

        return rv[RETURN_SCHEMA.names]

    return compute_frag_scores_inner


# Scoring function applied per region group; molecules need at least
# FILTER_CG_COUNT observed CpGs inside the region to be scored.
compute_frag_scores_udf = compute_frag_scores(cpg_number_cutoff=FILTER_CG_COUNT)


### Compute for HG38 Data

In [18]:
%%time
BATCH_SIZE = 20
# Consecutive regions share a batch number; each batch is scored by one Spark job.
region_df['batch'] = np.arange(region_df.shape[0]) // BATCH_SIZE
rv_scores = []
for batch, batch_region_df in region_df.groupby('batch'):
    print('---> Processing batch %i...' % batch)
    # One filtered view per region: molecules whose CpG span overlaps the
    # region, tagged with the region id and its hg38 CpG-index bounds.
    rv_ov = [
        pat_hg38_ddf
        .filter(
            (col('cpg_index_min') <= row['region_cpg_index_max_hg38'])
            & (col('cpg_index_max') >= row['region_cpg_index_min_hg38'])
        )
        .withColumn('region_id', lit(row['region_id']))
        .withColumn('region_cpg_index_min', lit(row['region_cpg_index_min_hg38']))
        .withColumn('region_cpg_index_max', lit(row['region_cpg_index_max_hg38']))
        for _, row in batch_region_df.iterrows()
    ]
    # Stack the per-region views and score each region group in pandas.
    batch_ddf = functools.reduce(DataFrame.union, rv_ov)
    scores_df = (
        batch_ddf
        .groupby('region_id')
        .applyInPandas(compute_frag_scores_udf, schema=RETURN_SCHEMA)
        .toPandas()
    )
    rv_scores.append(scores_df)

---> Processing batch 0...
---> Processing batch 1...
---> Processing batch 2...
---> Processing batch 3...
---> Processing batch 4...
---> Processing batch 5...
---> Processing batch 6...
---> Processing batch 7...
---> Processing batch 8...
---> Processing batch 9...
---> Processing batch 10...
---> Processing batch 11...
---> Processing batch 12...
---> Processing batch 13...
CPU times: user 757 ms, sys: 135 ms, total: 891 ms
Wall time: 29.2 s


In [19]:
# Combine the per-batch results. ignore_index=True gives the combined frame a
# fresh 0..n-1 index; otherwise every batch contributes its own 0-based index
# and the labels repeat.
scores_df = pd.concat(rv_scores, ignore_index=True)


In [21]:
# (rows, columns) of the combined score table; one row per scored region.
scores_df.shape

(275, 13)

In [22]:
# Preview the first rows of the combined score table.
scores_df.head()

Unnamed: 0,region_id,number_molecules,meth_k1,unmeth_k1,total_k1,meth_k3,unmeth_k3,total_k3,meth_k4,unmeth_k4,total_k4,frac_alpha_leq_25pct,frac_alpha_geq_75pct
0,Immune_Broad_B-chr1:1114772-1114971,14,65,44,109,46,25,78,38,20,64,0.357143,0.428571
1,Immune_Broad_B-chr2:3258156-3258479,14,31,44,75,12,23,43,6,16,28,0.5,0.428571
2,Immune_Broad_B-chr2:9908120-9908241,6,12,19,31,7,11,19,5,7,13,0.666667,0.333333
3,Immune_Broad_CD4_plus_CD8-chr2:113753082-11375...,7,26,18,44,17,11,30,13,8,23,0.428571,0.571429
4,Immune_Broad_Dend_plus_Macro_plus_Mono-chr1:11...,1,3,0,3,1,0,1,0,0,0,0.0,1.0


In [23]:
%%time
# NOTE: this replaces the RESULTS_PATH defined in the configuration cell.
# The path is built from PROJECT_DIR instead of repeating the absolute
# project location.
RESULTS_PATH = PROJECT_DIR + '/output/score_matrix/test_mixture.tsv.gz'

# to_csv does not create missing directories, so make sure the target exists.
os.makedirs(os.path.dirname(RESULTS_PATH), exist_ok=True)

scores_df.to_csv(RESULTS_PATH,
                 sep='\t',
                 index=False)

CPU times: user 5.25 ms, sys: 0 ns, total: 5.25 ms
Wall time: 4.65 ms


In [None]:
# PCA sanity check