# Check MWAS overlap for GWAS risk loci

## Explore

In [None]:
import os
import re
import pandas as pd
import numpy as np
from functools import reduce
import matplotlib.pyplot as plt
import seaborn as sns
from time import time
from concurrent.futures import ThreadPoolExecutor, as_completed

# List of summary statistics and stage 2 files
stats_path = [os.path.join("/expanse/lustre/projects/jhu152/naglemi/mwas/gwas", f) for f in os.listdir("/expanse/lustre/projects/jhu152/naglemi/mwas/gwas") if "stat" in f]
stage2_path = [os.path.join("/expanse/lustre/projects/jhu152/naglemi/mwas/CpGWAS/scripts", f) for f in os.listdir("/expanse/lustre/projects/jhu152/naglemi/mwas/CpGWAS/scripts") if "16a8" in f and "test" not in f]

def clean_and_standardize_colnames(summary_stats):
    start_time = time()
    summary_stats.columns = summary_stats.columns.str.replace(r'chr|#CHROM', 'CHR', regex=True)
    summary_stats.columns = summary_stats.columns.str.replace(r'pos|POS', 'BP', regex=True)
    summary_stats.columns = summary_stats.columns.str.replace(r'MarkerName|ID', 'SNP', regex=True)
    summary_stats.columns = summary_stats.columns.str.replace('LogOR', 'logOR')
    if 'logOR' not in summary_stats.columns and 'OR' in summary_stats.columns:
        summary_stats['logOR'] = np.log(summary_stats['OR'])
    summary_stats.columns = summary_stats.columns.str.replace('logOR', 'BETA')
    summary_stats.set_index('SNP', inplace=True)
    #print(f"clean_and_standardize_colnames executed in {time() - start_time:.2f} seconds.")
    return summary_stats

def load_and_sample(file_path, n=1000):
    start_time = time()
    total_rows = sum(1 for _ in open(file_path)) - 1
    skip_rows = sorted(np.random.choice(np.arange(1, total_rows+1), total_rows-n, replace=False))
    result = pd.read_csv(file_path, skiprows=skip_rows, sep='\s+', header=0)
    print(f"load_and_sample for {file_path} executed in {time() - start_time:.2f} seconds.")
    return result

def process_stage2_file(file_path, n=1000):
    start_time = time()
    total_rows = sum(1 for _ in open(file_path)) - 1
    skip_rows = sorted(np.random.choice(np.arange(1, total_rows+1), total_rows-n, replace=False))
    data = pd.read_csv(file_path, skiprows=skip_rows, nrows=n, header=0)
    data.columns = ['z', 'p', 'n', 'CHR', 'BP', 'population', 'region', 'stats', 'scaff']
    print(f"process_stage2_file for {file_path} executed in {time() - start_time:.2f} seconds.")
    return data

# Measure time for loading and cleaning summary stats data
start_time = time()
summary_stats_data = [clean_and_standardize_colnames(load_and_sample(path)) for path in stats_path]
print(f"Summary stats data loaded and cleaned in {time() - start_time:.2f} seconds.")

# Measure time for processing stage 2 files
start_time = time()
with ThreadPoolExecutor(max_workers=7) as executor:
    futures = {executor.submit(process_stage2_file, path): path for path in stage2_path}
    all_data = []
    for future in as_completed(futures):
        all_data.append(future.result())
        print(f"Stage 2 data processed for {futures[future]}")

print(f"All stage 2 data processed in {time() - start_time:.2f} seconds.")

In [28]:
len(all_data)

4

In [36]:
print(all_data[0]['stats'][[1]])

1    /expanse/lustre/projects/jhu152/naglemi/mwas/g...
Name: stats, dtype: object


In [44]:
import pandas as pd

# Set display option to ensure no truncation occurs for any string
pd.set_option('display.max_colwidth', None)

# Print the desired data
print(all_data[1]['stats'][[1]])


KeyError: 'stats'

In [40]:
stats_path

['/expanse/lustre/projects/jhu152/naglemi/mwas/gwas/gwas_stat_bp',
 '/expanse/lustre/projects/jhu152/naglemi/mwas/gwas/gwas_stat_mdd',
 '/expanse/lustre/projects/jhu152/naglemi/mwas/gwas/gwas_stat_scz']

In [38]:
all_data[1]

Unnamed: 0_level_0,CHR,BP,A1,A2,BETA,SE,PVAL,NGT,FCAS,FCON,IMPINFO,NEFFDIV2,NCAS,NCON,DIRE,CHR_BP
SNP,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
rs1219603,10,36254386,G,A,-0.001902,0.0128,0.885200,2,0.155,0.166,0.985,50981.48,41917,371549,---+--++-+++--+-++-++----+---+0-++--+-+---++-+-+-+++++-++,10_36254386
rs7923390,10,33161014,C,T,0.009396,0.0122,0.441100,2,0.819,0.836,0.983,50981.48,41917,371549,-+----+---++---+--+----+--++---+--++--+-++++-+++--++++++-,10_33161014
rs1742229,10,37748951,C,T,-0.011000,0.0169,0.516000,0,0.915,0.920,0.999,50186.29,41510,354340,++-++---+-+--+-+++---+++-+++-+--+---+--+++--+++-?++----++,10_37748951
rs11599167,10,34862783,A,G,-0.010505,0.0115,0.360500,1,0.745,0.749,0.947,47077.45,39945,178947,---+-+-+-+--+--+--+-+----+--------++-+++--++-++?++++--+-+,10_34862783
rs4251744,17,35822498,A,G,-0.010000,0.0214,0.639500,1,0.949,0.941,0.968,50619.46,41486,371081,-++--++-+-++---++-----+-++-+++----++-++-+-+++--++++-??+?+,17_35822498
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
rs17122346,14,22777601,T,C,-0.016404,0.0215,0.445300,26,0.939,0.942,0.852,50035.72,40997,370576,--+--++-+--+--------?+---++---+++-+++--++---+++++-+--++-+,14_22777601
rs2204931,14,21675620,G,T,-0.013602,0.0136,0.318300,1,0.862,0.884,0.993,50981.48,41917,371549,++--+--++++-+++--++-++-+---++--++-+-+---+++-++++-++--+--+,14_21675620
rs2884,14,28718198,C,T,-0.003496,0.0094,0.712600,1,0.537,0.529,0.989,50981.48,41917,371549,-+-+-++++-+--+-++-++++-+-++-+--+-+-+--+---+-++--+-++-+-+-,14_28718198
rs12435549,14,26057801,C,A,-0.021101,0.0122,0.084820,7,0.815,0.836,0.974,50981.48,41917,371549,++--+--++-+--++++--+--+++-+-+++---++-+-++++-+-++++--+-+--,14_26057801


In [None]:
# Measure time for combining data
start_time = time()
combined_data = pd.concat(all_data, ignore_index=True)
print(f"Combined data created in {time() - start_time:.2f} seconds.")

# Measure time for plotting
start_time = time()
sns.scatterplot(x='z.x', y='z.y', data=merged_data)
plt.title("z vs z Comparison")
plt.xlabel("z from GWAS")
plt.ylabel("z from Stage 2")
plt.show()
print(f"z vs z plot created in {time() - start_time:.2f} seconds.")

start_time = time()
sns.scatterplot(x='log10(p.x)', y='log10(p.y)', data=merged_data)
plt.title("Log(p) vs Log(p) Comparison")
plt.xlabel("Log(p) from GWAS")
plt.ylabel("Log(p) from Stage 2")
plt.show()
print(f"Log(p) vs Log(p) plot created in {time() - start_time:.2f} seconds.")