In [None]:
import os
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler, RobustScaler, normalize
import matplotlib.pyplot as plt

In [None]:
# Sample metadata CSV; parse_metadata() keeps only the rows whose `publication`
# column equals CRISTIANO_PAPER and maps sample_file_id -> sample_disease.
METADATA_PATH = '../../raw_data/cristiano_cfdnas/meta_data.csv'
# Exact publication title used to select the Cristiano cohort from the CSV.
CRISTIANO_PAPER = 'Genome-wide cell-free DNA fragmentation in patients with cancer'

def parse_metadata(file_path, paper):
    """Build a sample-id -> disease-label lookup for one publication.

    Reads the metadata CSV at ``file_path`` (anything ``pd.read_csv`` accepts),
    keeps only the rows whose ``publication`` column equals ``paper``, and maps
    each row's ``sample_file_id`` to its ``sample_disease``.

    Returns:
        dict of sample_file_id -> sample_disease. If a sample id appears more
        than once, the last matching row wins.
    """
    table = pd.read_csv(file_path)
    rows = table.loc[table["publication"] == paper]
    return {
        sample_id: disease
        for sample_id, disease in zip(rows["sample_file_id"], rows["sample_disease"])
    }

# sample_file_id -> sample_disease for the selected publication; its items drive
# which samples load_vectors() looks for and which group label each row gets.
metadata = parse_metadata(METADATA_PATH, CRISTIANO_PAPER)

In [None]:
DATA_DIR = "../../data/test/"
PCA_PLOT_DIR = "../../data/pca_plots/"
DHS_FOLDER = '../../raw_data/dhs'

os.makedirs(PCA_PLOT_DIR, exist_ok=True)

# DHS set names = each file's name up to its first '.' (e.g. "x.bed.gz" -> "x").
# os.listdir() returns entries in arbitrary order, and several files can share a
# stem (e.g. "x.bed" and "x.bed.tbi"), which would otherwise yield duplicate
# names and hence duplicate rows in load_vectors(). De-duplicate, skip hidden
# files such as ".DS_Store" (whose stem would be ""), and sort for a stable order.
DHS_FILES = sorted({f.split('.')[0] for f in os.listdir(DHS_FOLDER) if not f.startswith('.')})


# Per-statistic file layout in DATA_DIR: stat name -> (filename pattern, array key).
# `{sid}` is replaced by the sample id and `{dhs}` by a DHS set name; the key
# selects the array inside an .npz archive and is None for plain .npy files.
STATS = {
    "ocf":   ("{sid}__{dhs}_sorted_ocf.npy", None),
    "lwps":  ("{sid}__{dhs}_sorted_lwps.npy", None),
    "ifs":   ("{sid}__{dhs}_sorted_ifs.npz", "ifs_scores"),
    "pfe":   ("{sid}__{dhs}_sorted_pfe.npz", "pfe_scores"),
    "fdi":   ("{sid}__{dhs}_sorted_fdi.npz", "overlapping_fdi_scores"),
}


def load_vectors(stat_name, metadata_cache, data_dir=None, dhs_files=None, stats=None):
    """Load one statistic for every (sample, DHS set) pair into a row matrix.

    Each file that exists becomes one row of the returned matrix (flattened to
    1-D). Missing files are skipped, so a sample/DHS pair without output for
    this statistic simply contributes no row.

    Args:
        stat_name: key into ``stats`` (e.g. "ocf", "ifs").
        metadata_cache: mapping of sample id -> group label (e.g. the dict
            built by ``parse_metadata``).
        data_dir: directory holding the per-sample files; defaults to ``DATA_DIR``.
        dhs_files: DHS set names substituted for ``{dhs}``; defaults to ``DHS_FILES``.
        stats: mapping stat name -> (filename pattern, npz key or None);
            defaults to ``STATS``.

    Returns:
        ``(matrix, labels, samples)``: ``matrix`` has one row per loaded file,
        and ``labels`` / ``samples`` hold the group label / sample id of each
        row. Returns ``(None, None, None)`` if no file was found.

    Raises:
        KeyError: if ``stat_name`` is not in ``stats``, or an .npz archive
            lacks the expected array key.
        ValueError: if the loaded vectors do not all have the same length and
            therefore cannot be stacked into a matrix.
    """
    data_dir = DATA_DIR if data_dir is None else data_dir
    dhs_files = DHS_FILES if dhs_files is None else dhs_files
    stats = STATS if stats is None else stats
    # Loop-invariant: resolve the pattern/key once (unknown stats fail fast).
    pattern, key = stats[stat_name]

    vectors, labels, samples = [], [], []
    for sid, group_name in metadata_cache.items():
        for dhs_name in dhs_files:
            path = os.path.join(data_dir, pattern.format(sid=sid, dhs=dhs_name))
            try:
                if path.endswith(".npy"):
                    vec = np.load(path)
                elif path.endswith(".npz"):
                    # np.load on an .npz returns an open NpzFile; the context
                    # manager closes the underlying file once the array is read.
                    with np.load(path) as archive:
                        vec = archive[key]
                else:
                    continue  # unsupported file extension
            except FileNotFoundError:
                continue  # this sample/DHS pair has no output for this stat

            vectors.append(vec.flatten())
            labels.append(group_name)
            samples.append(sid)

    if not vectors:
        return None, None, None

    lengths = sorted({v.size for v in vectors})
    if len(lengths) > 1:
        raise ValueError(
            f"Cannot stack '{stat_name}' vectors of differing lengths {lengths}: "
            "every (sample, DHS set) row must have the same number of values."
        )
    matrix = np.vstack(vectors)
    # NOTE: two transforms are currently disabled, so raw values are returned:
    #   - row-wise L1 normalisation, normalize(matrix, norm='l1', axis=1),
    #     previously noted as fixing high-coverage outliers;
    #   - a StandardScaler fit/transform.
    return matrix, labels, samples
    
    
# matrix, labels, samples = load_vectors('ocf', metadata)
# plt.figure(figsize=(8, 6))
# pca = PCA(n_components=50, svd_solver="randomized", random_state=42)
# X_pca = pca.fit_transform(matrix)
# expl_var = pca.explained_variance_ratio_
# for group in set(labels):
#     mask = np.array(labels) == group
#     plt.scatter(
#         X_pca[mask, 0],
#         X_pca[mask, 1],
#         label=group,
#     )

# plt.xlabel(f"PC1 ({expl_var[0]*100:.1f}% var)")
# plt.ylabel(f"PC2 ({expl_var[1]*100:.1f}% var)")
# plt.title(f"PCA of Cancer_epithelial (OCF)")
# plt.legend()
# plt.grid(True, linestyle="--", alpha=0.4)
# plt.tight_layout()
# out_path = os.path.join(PCA_PLOT_DIR, f"Cancer_epithelial_OCF_pca.png")
# plt.savefig(out_path, dpi=200)
# plt.close()

# tsne = TSNE(n_components=2, verbose=1, perplexity=3, n_iter=300, random_state=42)
# X_tsne = tsne.fit_transform(matrix)
# plt.figure(figsize=(8, 6))
# for group in set(labels):
#     mask = np.array(labels) == group
#     plt.scatter(
#         X_tsne[mask, 0],
#         X_tsne[mask, 1],
#         label=group,
#     )

# plt.xlabel("t-SNE dimension 1")
# plt.ylabel("t-SNE dimension 2")
# plt.title(f"T-SNE of OCF")
# plt.legend()
# plt.grid(True, linestyle="--", alpha=0.4)
# plt.tight_layout()
# out_path = os.path.join(PCA_PLOT_DIR, f"Cancer_epithelial_OCF_tsne.png")
# plt.savefig(out_path, dpi=200)
# plt.close()

for stat in STATS.keys():
    print(f"\nProcessing: {stat}")
    matrix, labels, samples = load_vectors(stat, metadata)
    if matrix is not None:
        plot_pca(matrix, labels, samples, stat)
    else:
        print(f"Skipping {dhs} | {stat} — no data found.")

In [None]:
def plot_pca(matrix, labels, samples, stat_name):
    labels_set = set(labels)
    if stat_name == 'pfe':
        return
#         plt.figure(figsize=(8, 6))
#         data_per_group = [matrix.flatten()[np.array(labels) == g] for g in labels_set]
#         plt.violinplot(data_per_group, showmedians=True)
#         plt.xticks(range(1, len(labels)+1), labels)
#         plt.ylabel("PFE value")
#         plt.title(f"PFE Values per Sample Type ({dhs_name})")
#         out_path = os.path.join(PCA_PLOT_DIR, f"{dhs_name}_{stat_name}_violin.png")
#         plt.grid(True, linestyle="--", alpha=0.4)
#         plt.tight_layout()
#         plt.savefig(out_path, dpi=200)
#         plt.close()
    else:
        plt.figure(figsize=(8, 6))
        pca = PCA(n_components=10)
        X_pca = pca.fit_transform(matrix)
        expl_var = pca.explained_variance_ratio_

        for group in labels_set:
            mask = np.array(labels) == group
            plt.scatter(
                X_pca[mask, 0],
                X_pca[mask, 1],
                label=group,
            )

        plt.xlabel(f"PC1 ({expl_var[0]*100:.1f}% var)")
        plt.ylabel(f"PC2 ({expl_var[1]*100:.1f}% var)")
        plt.title(f"PCA of {stat_name.upper()}")
        plt.legend()
        plt.grid(True, linestyle="--", alpha=0.4)
        plt.tight_layout()
        out_path = os.path.join(PCA_PLOT_DIR, f"{stat_name}_pca.png")
        plt.savefig(out_path, dpi=200)
        plt.close()

#         plt.figure(figsize=(8, 6))
#         pca1 = PCA(n_components=1)
#         pc1 = pca1.fit_transform(matrix).flatten()
#         data_per_group = [pc1[np.array(labels) == g] for g in labels_set]
#         plt.violinplot(data_per_group, showmedians=True)
#         plt.xticks(range(1, len(labels)+1), labels)
#         plt.ylabel("PC1 score")
#         plt.title(f"PC1 Violin Plot ({dhs_name} | {stat_name.upper()})")
#         plt.grid(True, linestyle="--", alpha=0.4)
#         plt.tight_layout()
#         out_path = os.path.join(PCA_PLOT_DIR, f"{dhs_name}_{stat_name}_violin.png")
#         plt.savefig(out_path, dpi=200)
#         plt.close()