In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import os


# ===========================
# 1. User-defined file paths
# ===========================
# Placeholder locations: edit these literals before running (no argument
# parsing happens here, so they are the only configuration).
expression_file, metadata_file, output_dir = (
    "path/to/normalized_counts.csv",  # expression matrix, genes as rows
    "path/to/metadata.csv",           # sample annotations with a SampleID column
    "path/to/output",                 # destination folder for the saved figure
)
# Ensure the output folder exists; an already-existing folder is not an error.
os.makedirs(output_dir, exist_ok=True)


# ===========================
# 2. Load data
# ===========================
# Read expression data: rows = genes, columns = samples → transpose to samples x genes
expr = pd.read_csv(expression_file, index_col=0).T

# Read metadata (must contain a 'SampleID' column that matches sample names in expr).
# Sample names taken from a CSV header are always strings, so cast the IDs to str:
# purely numeric IDs would otherwise be parsed as integers and never match.
meta = pd.read_csv(metadata_file)
meta["SampleID"] = meta["SampleID"].astype(str)

# Inner-merge metadata onto the expression matrix; samples missing from either
# file are dropped. The merge inherits its index from `meta`, so reset it to a
# clean 0..n-1 index that lines up with positional arrays built from df later.
df = expr.merge(meta, left_index=True, right_on="SampleID").reset_index(drop=True)
if df.empty:
    raise ValueError(
        "No sample names in the expression matrix matched metadata 'SampleID'; "
        "check that both files use the same sample identifiers."
    )

# Keep only the gene-expression columns. Selecting them explicitly ensures no
# other metadata column (beyond 'SampleID'/'Group') leaks into the PCA input.
X = df[expr.columns].to_numpy()


# ===========================
# 3. PCA
# ===========================
# Project samples onto the first two principal components. sklearn's PCA
# centres each gene but does not rescale it.
pca = PCA(n_components=2, random_state=42)
pcs = pca.fit_transform(X)

# Rows of `pcs` are in the same order as the rows of `df`. Assign the annotation
# columns as plain arrays: assigning a Series aligns on index labels, and df's
# index (inherited from the merge) need not be 0..n-1, which would silently
# attach the wrong SampleID/Group to a point or fill it with NaN.
pca_df = pd.DataFrame(pcs, columns=["PC1", "PC2"])
pca_df["SampleID"] = df["SampleID"].to_numpy()
if "Group" in df.columns:
    pca_df["Group"] = df["Group"].to_numpy()

# ===========================
# 4. k-means clustering
# ===========================
# Cluster samples in the 2-D PC space. n_init is pinned explicitly because its
# default changed from 10 to "auto" in scikit-learn 1.4 (with a FutureWarning in
# 1.2-1.3); leaving it implicit makes the labels depend on the installed version
# even though random_state is fixed.
n_clusters = 2
kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
pca_df["Cluster"] = kmeans.fit_predict(pcs)

# ===========================
# 5. Plot PCA with clusters
# ===========================
# Apply the seaborn theme before creating the figure so the whole figure uses it.
sns.set(style="whitegrid")
fig, ax = plt.subplots(figsize=(10, 6))

# Scatter plot, samples coloured by k-means cluster
sns.scatterplot(data=pca_df, x="PC1", y="PC2",
                hue="Cluster", palette="Set2",
                s=120, edgecolor="k", alpha=0.7, ax=ax)

# Annotate each sample. The offset is given in screen points rather than data
# units, so labels sit beside their marker whatever the scale of the PC scores.
for pc1, pc2, sample in zip(pca_df["PC1"], pca_df["PC2"], pca_df["SampleID"]):
    ax.annotate(str(sample), (pc1, pc2), xytext=(5, 5), textcoords="offset points",
                fontsize=8, color="black")

# Plot centroids (mean PC position of each cluster's members)
centroids = pca_df.groupby("Cluster")[["PC1", "PC2"]].mean()
ax.scatter(centroids["PC1"], centroids["PC2"],
           s=300, c="black", marker="X", edgecolor="k", label="Centroid")

# Axis labels with explained variance %
ax.set_xlabel(f"PC1 ({pca.explained_variance_ratio_[0]*100:.1f}% variance)", fontsize=14)
ax.set_ylabel(f"PC2 ({pca.explained_variance_ratio_[1]*100:.1f}% variance)", fontsize=14)

ax.legend(title="K-means cluster", bbox_to_anchor=(1.05, 1), loc="upper left")
fig.tight_layout()

out_file = os.path.join(output_dir, "KMeans_Clustered_PCA_Plot.png")
fig.savefig(out_file, dpi=300, bbox_inches="tight")
plt.show()

print(f"PCA plot saved → {out_file}")