In [None]:
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.decomposition import PCA


def parse_metadata(file_path: str, paper: str) -> dict:
    """Map sample file ids to disease labels for a single publication.

    Reads the CSV at ``file_path``, keeps only the rows whose ``publication``
    column equals ``paper`` and returns ``{sample_file_id: sample_disease}``.
    When a sample id appears on several matching rows, the last row wins.
    """
    table = pd.read_csv(file_path)
    from_paper = table[table['publication'] == paper]
    return {
        sample_id: disease
        for sample_id, disease in zip(
            from_paper['sample_file_id'], from_paper['sample_disease']
        )
    }


def load_vectors(stat_name: str, metadata_map: dict, data_dir: str, dhs_files):
    """Build the per-sample feature table for one statistic.

    For every (sample, DHS) pair the vector stored under ``data_dir`` is
    loaded, flattened and reduced to two values:

    * for every statistic except ``'pfe'``: the first two principal components
      of a PCA fitted jointly on all loaded vectors;
    * for ``'pfe'``: the first element of the vector, used for both ``pc1``
      and ``pc2``.

    Parameters
    ----------
    stat_name : str
        Key into the module-level ``STATS`` mapping, which supplies the file
        name pattern and (for ``.npz`` files) the array key.
    metadata_map : dict
        ``{sample_id: disease}``. A disease other than ``'Healthy'`` is
        labelled ``'Cancerous'`` in the binary label.
    data_dir : str
        Directory containing the per-pair ``.npy`` / ``.npz`` files.
    dhs_files : iterable of str
        DHS names substituted into the file name pattern.

    Returns
    -------
    tuple
        ``(feature_df, loadings_df)``. ``feature_df`` is indexed by
        ``(sample, binary, disease)`` with ``('pc1' | 'pc2', dhs)`` columns.
        ``loadings_df`` has columns ``PC1`` / ``PC2`` (one row per vector
        element) with the explained-variance ratios in ``attrs['expl_var']``;
        it is ``None`` for ``'pfe'``. ``(None, None)`` is returned when no
        vector could be loaded, so callers can always unpack two values.
    """
    pattern, key = STATS[stat_name]

    entries = []
    for sid, disease in metadata_map.items():
        binary_label = 'Healthy' if disease == 'Healthy' else 'Cancerous'
        for dhs in dhs_files:
            # skip if preprocessing marked this pair as low coverage
            matrix_base = os.path.join(data_dir, f'{sid}__{dhs}_sorted.npy')
            if os.path.exists(matrix_base + '.skip'):
                continue

            path = os.path.join(data_dir, pattern.format(sid=sid, dhs=dhs))
            try:
                if path.endswith('.npy'):
                    vec = np.load(path)
                elif path.endswith('.npz'):
                    # Close the archive once the array has been read.
                    with np.load(path) as data:
                        vec = data[key]
                else:
                    continue
            except FileNotFoundError:
                # Not every (sample, DHS) pair has been computed.
                continue

            entries.append({
                "sample": sid,
                "dhs": dhs,
                "vector": vec.flatten(),
                "disease": disease,
                "binary": binary_label,
            })

    if not entries:
        return None, None

    # Vectors are stacked as-is (no standardisation); np.vstack requires all
    # of them to have the same length.
    all_vectors = np.vstack([e['vector'] for e in entries])

    loadings_df = None
    # Use the parameter, not the notebook-global loop variable `stat`.
    if stat_name != 'pfe':
        pca = PCA(n_components=2)
        pc_values = pca.fit_transform(all_vectors)
        expl_var = pca.explained_variance_ratio_
        loadings_df = pd.DataFrame(pca.components_.T, columns=['PC1', 'PC2'])
        loadings_df.attrs['expl_var'] = expl_var
        pc1_var, pc2_var = expl_var[0], expl_var[1]

        for entry, (pc1, pc2) in zip(entries, pc_values):
            entry['pc1'] = pc1
            entry['pc2'] = pc2
            entry['pc1_var'] = pc1_var
            entry['pc2_var'] = pc2_var
    else:
        # No PCA for 'pfe': the first vector element fills both slots so the
        # output has the same layout as the other statistics.
        for entry in entries:
            val = entry['vector'][0]
            entry['pc1'] = val
            entry['pc2'] = val

    df = pd.DataFrame(entries)
    feature_df = df.pivot(
        index=['sample', 'binary', 'disease'],
        columns='dhs',
        values=['pc1', 'pc2'],
    )
    return feature_df, loadings_df


# statistic name -> (file name pattern, array key inside the .npz or None for .npy)
STATS = {
    'ocf': ('{sid}__{dhs}_sorted_ocf.npy', None),
    'lwps': ('{sid}__{dhs}_sorted_lwps.npy', None),
    'ifs': ('{sid}__{dhs}_sorted_ifs.npz', 'ifs_scores'),
    'pfe': ('{sid}__{dhs}_sorted_pfe.npz', 'pfe_scores'),
    'fdi': ('{sid}__{dhs}_sorted_fdi.npz', 'overlapping_fdi_scores'),
}

# DHS names are the BED file names without their extension. Sorted so the
# processing order does not depend on the directory listing order.
DHS_FILES = sorted(
    os.path.splitext(os.path.basename(bed_path))[0]
    for bed_path in glob.glob('../../raw_data/dhs/*.bed')
)

metadata_filepath = '../../raw_data/cristiano_cfdnas/meta_data.csv'
paper = 'Genome-wide cell-free DNA fragmentation in patients with cancer'

metadata_map = parse_metadata(metadata_filepath, paper)
result_dir = '../../data/cristiano_cfdnas_dhs/'
feature_dir = '../../data/cristiano_cfdnas_dhs_RESULT/'

# Make sure the output directory exists before anything is written to it.
os.makedirs(feature_dir, exist_ok=True)

for stat in STATS:
    print(f'Processing: {stat}')
    result = load_vectors(stat, metadata_map, result_dir, DHS_FILES)
    # "No data" may be signalled by a bare None as well as by a pair whose
    # first item is None; normalise so the unpacking below cannot fail.
    df, loadings_df = result if result is not None else (None, None)

    if df is None:
        print(f'{stat}: no vectors found, nothing saved.')

    if df is not None:
        out_path = os.path.join(feature_dir, f'feature_matrix_{stat}.parquet')
        df.to_parquet(out_path)
        print(f"Saved: {out_path}")

    if loadings_df is not None:
        loadings_out_path = os.path.join(feature_dir, f'{stat}_pca_loadings.csv')
        loadings_df.to_csv(loadings_out_path)
        print(f"Saved: {loadings_out_path}")
        

In [None]:
def load_feature_df(stat: str, feature_dir: str) -> pd.DataFrame:
    """Read the stored feature matrix of ``stat`` from ``feature_dir``.

    Returns the contents of ``feature_matrix_<stat>.parquet``. When that file
    does not exist, a message is printed and ``None`` is returned instead.
    """
    parquet_path = os.path.join(feature_dir, f'feature_matrix_{stat}.parquet')
    if os.path.exists(parquet_path):
        return pd.read_parquet(parquet_path)

    print(f'Missing parquet: {parquet_path}')
    return None


def plot_stat(df: pd.DataFrame, stat_name: str, out_dir: str):
    """Save one figure per DHS for the given statistic.

    Parameters
    ----------
    df : pd.DataFrame
        Feature matrix with ``('pc1' | 'pc2', dhs)`` columns and a row index
        that contains a ``disease`` level.
    stat_name : str
        Statistic being plotted. ``'pfe'`` produces a violin plot of the
        ``pc1`` value per disease; every other statistic produces a PC1 vs
        PC2 scatter plot coloured by disease.
    out_dir : str
        Directory the PNG files are written to.
    """
    scatter_df = None
    if stat_name != 'pfe':
        # Flatten ('pc1', dhs) column tuples to 'pc1_<dhs>' names once, rather
        # than on every loop iteration, and expose the index levels as columns
        # so `hue='disease'` refers to an ordinary column.
        scatter_df = df.copy()
        if isinstance(scatter_df.columns, pd.MultiIndex):
            scatter_df.columns = [f"{pc}_{site}" for pc, site in scatter_df.columns]
        scatter_df = scatter_df.reset_index()

    for dhs in df['pc1'].columns:
        if stat_name == 'pfe':
            data = df['pc1'][dhs].reset_index()
            plt.figure(figsize=(10, 6))
            sns.violinplot(data=data, x='disease', y=dhs, inner='quart')
            # Title uses the parameter, not the notebook-global `stat`.
            plt.title(f'PFE Distribution by Disease (stat={stat_name}, DHS={dhs})')
            plt.xlabel('Disease')
            plt.ylabel('PFE value')
            plt.xticks(rotation=30, ha='right')
            plt.tight_layout()
            out_path = os.path.join(out_dir, f'{stat_name}_{dhs}_violin_disease.png')
            plt.savefig(out_path, dpi=200)
            plt.close()
            continue

        # pc1 vs pc2
        plt.figure(figsize=(10, 6))
        sns.scatterplot(
            data=scatter_df,
            x=f'pc1_{dhs}',
            y=f'pc2_{dhs}',
            hue='disease',
            alpha=0.55,
        )
        plt.xlabel('PC1')
        plt.ylabel('PC2')
        plt.title(f'PC1 vs PC2 (stat={stat_name}, DHS={dhs})')
        plt.legend(loc='best', fontsize='small')
        plt.grid(True, linestyle='--', alpha=0.4)
        plt.xticks(rotation=90, ha='right')
        plt.tight_layout()
        out_path_scatter = os.path.join(out_dir, f'{stat_name}_{dhs}_pca_by_disease.png')
        plt.savefig(out_path_scatter, dpi=200)
        plt.close()
        

def plot_loadings(stat_name: str, out_dir: str):
    """Plot the PC1 loadings of ``stat_name`` and save them as a PNG.

    Reads ``<stat_name>_pca_loadings.csv`` from ``out_dir``, draws the ``PC1``
    column against the row position with a dashed red vertical line at the
    midpoint, and writes ``<stat_name>_pc1_loadings.png`` to the same
    directory. Does nothing for ``'pfe'``.
    """
    if stat_name == 'pfe':
        return

    loadings_df = pd.read_csv(os.path.join(out_dir, f'{stat_name}_pca_loadings.csv'))
    midpoint = loadings_df.shape[0] // 2

    # pc1 weights
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(loadings_df['PC1'])
    ax.axvline(x=midpoint, color='red', linestyle='--', linewidth=2)
    ax.set_xlabel('Feature Index')
    ax.set_ylabel('Weight in PC1')
    ax.set_title(f'PC1 Feature Weights (stat={stat_name})')
    ax.grid(True, linestyle='--', alpha=0.4)
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, f'{stat_name}_pc1_loadings.png'), dpi=200)
    plt.close(fig)
        
# Directory holding the parquet feature matrices and the loadings CSVs; the
# figures are written to the same place.
feature_dir = '../../data/cristiano_cfdnas_dhs_RESULT/'

for stat in STATS:
    df = load_feature_df(stat, feature_dir)
    if df is not None and not df.empty:
        print(f"Processing: {stat}")
        plot_stat(df, stat, feature_dir)
        plot_loadings(stat, feature_dir)
    else:
        # Missing file or a matrix without rows: nothing to plot.
        print(f"{stat}: No parquet data found.")

In [None]:
# Inspect how many samples would be lost by requiring a PC1 value at every DHS.
df = load_feature_df('fdi', feature_dir)['pc1'].reset_index()
print(f'Shape before: {df.shape}')
# NOTE(review): original note was "there are 100+ samples" - presumably this
# site is missing for 100+ samples and is dropped for that reason; confirm.
df = df.drop(columns=['Stromal_A'])
n_unique_nan_samples = df.loc[df.isna().any(axis=1), 'sample'].nunique()
# The original print referenced an undefined name (`n_unique_samples`).
print(f'There are {n_unique_nan_samples} samples with at least 1 NaN')
df_clean = df.dropna(axis=0, how="any")
print(f'Shape after: {df_clean.shape}')

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.svm import SVC
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import StratifiedKFold, cross_val_score, cross_val_predict


def load_stat(stat: str, feature_dir: str):
    """Load the PC1 features of ``stat`` as a flat, NaN-free table.

    Reads ``feature_matrix_<stat>.parquet`` from ``feature_dir``, keeps the
    ``pc1`` block, turns the index levels into columns, drops the
    ``Stromal_A`` column and then every row that still contains a NaN.
    Prints a message and returns ``None`` when the parquet file is missing.
    """
    parquet_path = os.path.join(feature_dir, f'feature_matrix_{stat}.parquet')
    if os.path.exists(parquet_path):
        pc1 = pd.read_parquet(parquet_path)['pc1']
        return (
            pc1.reset_index()
            .drop(columns=['Stromal_A'])
            .dropna(axis=0, how="any")
        )

    print(f'Missing parquet: {parquet_path}')
    return None


def evaluate_stat_joint(stat: str, feature_dir: str, cv_splits: int = 10):
    """Cross-validate an SVM on the joint PC1 features of one statistic.

    Parameters
    ----------
    stat : str
        Statistic whose feature matrix is loaded via ``load_stat``.
    feature_dir : str
        Directory containing the parquet feature matrices.
    cv_splits : int, default 10
        Number of stratified, shuffled folds.

    Returns
    -------
    dict or None
        ``None`` when no usable data is available or only one class is
        present. Otherwise a dict with:

        * ``stat``, ``n_features``
        * ``auc_mean`` / ``auc_std`` - mean and standard deviation of the
          per-fold ROC AUC
        * ``auc_overall`` - ROC AUC of the pooled out-of-fold probabilities
        * ``y_true`` - encoded labels
        * ``y_pred`` - out-of-fold hard class predictions
        * ``y_pred_proba`` - out-of-fold probability of the class encoded as 1
    """
    df = load_stat(stat, feature_dir)
    if df is None or df.empty:
        return None

    # target labels; LabelEncoder assigns codes in sorted label order, so the
    # label that sorts last is encoded as 1.
    le = LabelEncoder()
    y = le.fit_transform(df['binary'])
    if len(le.classes_) < 2:
        # ROC AUC and stratified CV are undefined with a single class.
        print(f'{stat}: only one class present ({le.classes_[0]}), skipping.')
        return None

    # feature matrix: all DHS columns
    meta_cols = ['sample', 'binary', 'disease']
    feature_cols = [c for c in df.columns if c not in meta_cols]
    X = df[feature_cols].to_numpy()

    clf = SVC(probability=True, random_state=42)
    # A fixed random_state makes the folds identical across the calls below.
    cv = StratifiedKFold(n_splits=cv_splits, shuffle=True, random_state=42)

    fold_aucs = cross_val_score(clf, X, y, cv=cv, scoring='roc_auc', n_jobs=-1)
    y_pred = cross_val_predict(clf, X, y, cv=cv, n_jobs=-1)
    # ROC analysis needs continuous scores: scoring hard 0/1 predictions
    # collapses the curve to a single operating point.
    y_pred_proba = cross_val_predict(
        clf, X, y, cv=cv, method='predict_proba', n_jobs=-1
    )[:, 1]
    auc_overall = roc_auc_score(y, y_pred_proba)

    return {
        'stat': stat,
        'n_features': X.shape[1],
        'auc_mean': float(fold_aucs.mean()),
        'auc_std': float(fold_aucs.std()),
        'auc_overall': auc_overall,
        'y_true': y,
        'y_pred': y_pred,
        'y_pred_proba': y_pred_proba,
    }

# NOTE(review): the feature-extraction cell writes its parquet files to
# '../../data/cristiano_cfdnas_dhs_RESULT/', whereas this cell reads from a
# '_pca' directory - confirm this is the intended location.
feature_dir = '../../data/cristiano_cfdnas_dhs_pca/'

# statistic name -> (file name pattern, array key inside the .npz or None)
STATS = {
    'ocf': ('{sid}__{dhs}_sorted_ocf.npy', None),
    'lwps': ('{sid}__{dhs}_sorted_lwps.npy', None),
    'ifs': ('{sid}__{dhs}_sorted_ifs.npz', 'ifs_scores'),
    'pfe': ('{sid}__{dhs}_sorted_pfe.npz', 'pfe_scores'),
    'fdi': ('{sid}__{dhs}_sorted_fdi.npz', 'overlapping_fdi_scores'),
}

# Evaluate every statistic; ones without usable data come back as None and
# are left out.
results = []
for stat in STATS:
    result = evaluate_stat_joint(stat, feature_dir)
    if result is not None:
        results.append(result)

if results:
    stats_order = [r['stat'] for r in results]
    auc_means = [r['auc_mean'] for r in results]

    fig, ax = plt.subplots(figsize=(8, 6))
    bars = ax.bar(stats_order, auc_means, capsize=5)
    ax.set_ylabel('Mean ROC AUC (all DHS sites)')
    ax.set_xlabel('Test statistic')
    # Write each bar's value just above it.
    for bar, val in zip(bars, auc_means):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height(),
            f"{val:.2f}",
            ha='center',
            va='bottom',
            fontsize=9,
        )
    ax.set_title('Binary classification using joint DHS PC1 features + SVM')
    fig.tight_layout()
    plt.show()

In [None]:
# ROC curve per statistic from the cross-validated predictions in `results`.
plt.figure(figsize=(7, 6))

for r in results:
    # Use continuous scores / the pooled AUC when the result provides them.
    # Otherwise fall back to 'y_pred' and 'auc_mean', which are always
    # present; hard labels only give a single-operating-point curve.
    roc_scores = r.get("y_pred_proba", r["y_pred"])
    auc_value = r.get("auc_overall", r["auc_mean"])
    fpr, tpr, _ = roc_curve(r["y_true"], roc_scores)
    plt.plot(fpr, tpr, lw=1.8, label=f"{r['stat']} (AUC={auc_value:.3f})")

# Chance-level diagonal for reference.
plt.plot([0, 1], [0, 1], "k--", lw=1)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC curves for all statistics (cross-validated predictions)")
plt.legend()
plt.tight_layout()
plt.show()