# YBS Presentation

In [None]:
import logging
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import umap
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

In [None]:
# Make sure the output directory targeted by this notebook's savefig calls
# exists; exist_ok=True makes re-running the cell a no-op.
os.makedirs(name="results/figures/ybs", exist_ok=True)

## Class Distribution Bar Chart
Counts of proteins per compartment (train vs. validation vs. test)

In [None]:
# Label tables for the three splits; each must provide a "localization" column.
# NOTE(review): labels are read from "data/processed/y/" here but from
# "data/processed/Y/" in the feature-engineering cell -- confirm the casing.
y_train = pd.read_csv("data/processed/y/train.csv")
y_val = pd.read_csv("data/processed/y/val.csv")
y_test = pd.read_csv("data/processed/y/test.csv")

# Merge rare classes based on threshold.
# Rarity is decided on the training split only; the same relabelling is then
# applied to every split so all three share one label set.
threshold = 50
counts = y_train["localization"].value_counts()
print(f"Pre-merge class counts: {counts.to_dict()}")
rare = [cls for cls, n_proteins in counts.items() if n_proteins < threshold]
if not rare:
    print("No rare classes to merge.")
else:
    to_other = dict.fromkeys(rare, "Other")
    for split_labels in (y_train, y_val, y_test):
        split_labels["localization"] = split_labels["localization"].replace(to_other)
    print(f"Merged rare classes: {rare} -> 'Other'")
print(f"Post-merge classes: {sorted(y_train['localization'].unique())}")

# Per-class counts for each split, ordered by class name.
train_counts, val_counts, test_counts = (
    split_labels["localization"].value_counts().sort_index()
    for split_labels in (y_train, y_val, y_test)
)

# One column per split, one row per class.
dist_data = dict(Train=train_counts, Validation=val_counts, Test=test_counts)
df_dist = pd.DataFrame(dist_data)

In [None]:
# Grouped bar chart of protein counts per class, one bar per column of df_dist.
ax = df_dist.plot(kind="bar", figsize=(12, 6))
# df_dist has Train, Validation and Test columns, so the title names all three
# (the previous title mentioned only train and test).
ax.set_title(
    "Protein Counts per Compartment (Train vs Validation vs Test)", fontsize=16
)
ax.set_xlabel("Subcellular Compartment", fontsize=14)
ax.set_ylabel("Number of Proteins", fontsize=14)
ax.set_yscale("log")
# tick_params stores the label size on the axis itself, replacing the manual
# per-label loops (one of which duplicated the font size set via plt.xticks).
ax.tick_params(axis="both", which="both", labelsize=12)
plt.xticks(rotation=45, ha="right")
ax.legend(fontsize=12)
plt.tight_layout()

# File name kept unchanged so existing references to the figure stay valid.
plt.savefig("results/figures/ybs/class_distribution_train_test.png")
plt.show()

In [None]:
# Load the class balance table; the code below relies on its "localization",
# "count" and "percent" columns.
class_balance_df = pd.read_csv("results/csv/class_balance.csv")
# Divide by 100 so "percent" holds a fraction instead of a percentage.
class_balance_df["percent"] = class_balance_df["percent"].astype(float) / 100.0
class_balance_df["count"] = class_balance_df["count"].astype(int)
class_balance_df = class_balance_df.sort_values("count", ascending=False)

# Classes with fewer than `threshold` proteins are folded into one "Other" row.
threshold = 50
rare_mask = class_balance_df["count"] < threshold
rare_classes = class_balance_df.loc[rare_mask, "localization"].tolist()

# Status messages use print, matching the class-distribution cell: no logging
# configuration is visible in this file, and INFO records are not emitted at
# Python's default WARNING level.
if rare_classes:
    is_rare = class_balance_df["localization"].isin(rare_classes)
    class_balance_df.loc[is_rare, "localization"] = "Other"
    # groupby orders rows by class name, so re-sort afterwards to keep the
    # count-descending order that the non-merge branch has.
    class_balance_df = (
        class_balance_df.groupby("localization")
        .sum()
        .reset_index()
        .sort_values("count", ascending=False)
    )
    print(f"Merged rare classes: {rare_classes} -> 'Other'")
else:
    print("No rare classes to merge.")

# print data
print(f"Number of classes: {len(class_balance_df)}")
print(f"Classes: {', '.join(class_balance_df['localization'].tolist())}")

## Feature engineering

In [None]:
# Feature tables for the three splits.
X_train, X_val, X_test = (
    pd.read_csv("data/processed/X/train.csv"),
    pd.read_csv("data/processed/X/val.csv"),
    pd.read_csv("data/processed/X/test.csv"),
)

# Matching label tables; squeeze() collapses a single-column frame to a Series.
# NOTE(review): labels are read from "data/processed/Y/" here but from
# "data/processed/y/" in the class-distribution cell -- confirm the casing.
# NOTE(review): this rebinds y_train / y_val / y_test, so the "Other" merge
# applied in the class-distribution cell does not carry over to the plots below.
y_train, y_val, y_test = (
    pd.read_csv("data/processed/Y/train.csv").squeeze(),
    pd.read_csv("data/processed/Y/val.csv").squeeze(),
    pd.read_csv("data/processed/Y/test.csv").squeeze(),
)

In [None]:
# 1. Sequence-length distributions
# Overlaid histograms (with KDE curves) of the "original_length" and
# "sequence_length" columns of the training features.
fig, ax = plt.subplots(figsize=(8, 4))
sns.histplot(
    X_train["original_length"], color="gray", label="Original", kde=True, ax=ax
)
sns.histplot(
    X_train["sequence_length"], color="blue", label="Cleaned", kde=True, ax=ax
)
ax.set_title("Original vs. Cleaned Sequence Length", fontsize=16)
ax.set_xlabel("Sequence Length (AA)", fontsize=14)
ax.set_ylabel("Frequency", fontsize=14)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
ax.legend()
fig.tight_layout()
fig.savefig("results/figures/ybs/seq_length_hist.png")
plt.close(fig)

# Violin plot of "sequence_length" grouped by the training labels.
fig, ax = plt.subplots(figsize=(10, 4))
sns.violinplot(x=y_train, y=X_train["sequence_length"], ax=ax)
ax.set_title("Cleaned Sequence Length by Localization", fontsize=16)
ax.set_xlabel("Localization", fontsize=14)
ax.set_ylabel("Sequence Length (AA)", fontsize=14)
plt.xticks(rotation=45, fontsize=12)
plt.yticks(fontsize=12)
fig.tight_layout()
fig.savefig("results/figures/ybs/seq_length_violin.png")
plt.close(fig)

# 2. Physico-chemical summaries (GRAVY & pI)
# Column name -> display label. The labels are spelled out because
# str.upper() would render "pI" as "PI" in axis labels and titles.
feature_labels = {"gravy": "GRAVY", "pI": "pI"}
for feat, label in feature_labels.items():
    # Overall density of the feature on the training split.
    plt.figure(figsize=(6, 4))
    sns.kdeplot(data=X_train, x=feat, fill=True)
    plt.xlabel(label, fontsize=14)
    plt.ylabel("Density", fontsize=14)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.title(f"{label} Distribution (Train)", fontsize=16)
    plt.tight_layout()
    plt.savefig(f"results/figures/ybs/{feat}_dist.png")
    plt.close()

    # Same feature, one box per training label.
    plt.figure(figsize=(10, 4))
    sns.boxplot(x=y_train, y=X_train[feat])
    plt.xticks(rotation=45, fontsize=12)
    plt.yticks(fontsize=12)
    plt.ylabel(label, fontsize=14)
    plt.xlabel("Localization", fontsize=14)
    plt.title(f"{label} by Localization", fontsize=16)
    plt.tight_layout()
    plt.savefig(f"results/figures/ybs/{feat}_by_loc.png")
    plt.close()

# 3. Amino-acid composition
# One-letter codes; each is used as a column name of X_train below.
aas = [*"ACDEFGHIKLMNPQRSTVWY"]

# Heatmap of the mean of every amino-acid column within each training label.
aa_means = X_train[aas].groupby(y_train).mean()
fig, ax = plt.subplots(figsize=(12, 6))
sns.heatmap(aa_means, cmap="viridis", cbar_kws={"label": "Fraction"}, ax=ax)
ax.set_title("Mean Amino-Acid Composition per Localization", fontsize=16)
ax.set_xlabel("Amino Acid", fontsize=14)
ax.set_ylabel("Localization", fontsize=14)
plt.xticks(rotation=45, fontsize=12)
plt.yticks(fontsize=12)
fig.tight_layout()
fig.savefig("results/figures/ybs/aa_composition_heatmap.png")
plt.close(fig)

# Bar chart of the mean of every amino-acid column over the whole training split.
fig, ax = plt.subplots(figsize=(14, 4))
global_means = X_train[aas].mean()
global_means.plot.bar(ax=ax)
ax.set_title("Global Amino-Acid Composition", fontsize=16)
ax.set_xlabel("Amino Acid", fontsize=14)
ax.set_ylabel("Mean Fraction", fontsize=14)
plt.xticks(rotation=45, fontsize=12)
plt.yticks(fontsize=12)
fig.tight_layout()
fig.savefig("results/figures/ybs/global_aa_composition.png")
plt.close(fig)

# 4. Top-10 Dipeptide frequencies
# Dipeptide features are the X_train columns whose name starts with "dp_".
dp_cols = [col for col in X_train.columns if col.startswith("dp_")]
dp_means = X_train[dp_cols].mean()
# The ten columns with the highest mean, in descending order.
top10 = dp_means.sort_values(ascending=False).head(10)

fig, ax = plt.subplots(figsize=(8, 4))
top10.plot.bar(color="tab:green", ax=ax)
ax.set_title("Top 10 Dipeptide Frequencies", fontsize=16)
ax.set_xlabel("Dipeptide", fontsize=14)
ax.set_ylabel("Mean Frequency", fontsize=14)
plt.xticks(rotation=45, fontsize=12)
plt.yticks(fontsize=12)
fig.tight_layout()
fig.savefig("results/figures/ybs/top10_dipeptides.png")
plt.close(fig)

# 5. Feature correlation matrix (limited to 50 dipeptides for readability)
# Columns included: the four scalar features, every amino-acid column and the
# first 50 dipeptide columns.
num_feats = [
    "sequence_length",
    "original_length",
    "gravy",
    "pI",
    *aas,
    *dp_cols[:50],
]
corr = X_train[num_feats].corr()

# Tick labels are switched off; only the overall structure is shown.
fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(
    corr, cmap="coolwarm", center=0, xticklabels=False, yticklabels=False, ax=ax
)
ax.set_title("Correlation Matrix of Numeric Features", fontsize=16)
ax.set_xlabel("Features", fontsize=14)
ax.set_ylabel("Features", fontsize=14)
fig.tight_layout()
fig.savefig("results/figures/ybs/feature_correlation.png")
plt.close(fig)

In [None]:
# 2-D projections of the amino-acid composition columns of the training
# features, one per method; every reducer is given random_state=42.
aa_features = X_train[aas]

pca_2 = PCA(n_components=2, random_state=42).fit_transform(aa_features)
tsne_2 = TSNE(n_components=2, random_state=42).fit_transform(aa_features)
umap_2 = umap.UMAP(n_components=2, random_state=42).fit_transform(aa_features)

# key -> (2-D coordinates, display title); insertion order is PCA, t-SNE, UMAP.
emb = {
    "pca": (pca_2, "PCA"),
    "tsne": (tsne_2, "t-SNE"),
    "umap": (umap_2, "UMAP"),
}

In [None]:
# One scatter plot per embedding, coloured by training label (legend hidden),
# saved under the embedding's key.
for key, (coords, title) in emb.items():
    fig, ax = plt.subplots(figsize=(6, 5))
    first_axis, second_axis = coords[:, 0], coords[:, 1]
    sns.scatterplot(
        x=first_axis,
        y=second_axis,
        hue=y_train,
        s=20,
        alpha=0.7,
        legend=False,
        ax=ax,
    )
    ax.set_title(f"{title} Projection of AA Composition")
    ax.set_xlabel(f"{title} 1")
    ax.set_ylabel(f"{title} 2")
    fig.tight_layout()
    fig.savefig(f"results/figures/ybs/{key}_aa_composition.png")
    plt.close(fig)