## CondensateML
A deep-learning pipeline for condensate phenotyping and image-based clustering of high-content screens.

### Features
- ResNet18-based feature extraction (512-dimensional embeddings)
- UMAP projection and anchor-based (cosine) similarity scoring
- ClusterProfiler-based GO enrichment integration

In [None]:
import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models
from skimage.measure import label, regionprops
from skimage.filters import threshold_otsu
from skimage.transform import resize
from skimage.io import imread

from sklearn.metrics import accuracy_score, roc_auc_score
import kornia.augmentation as K
from torch.optim.lr_scheduler import ReduceLROnPlateau
import cv2

# ==== Inference: Embedding similarity vs RNF26/ZNF335 ====
import os, re, numpy as np, pandas as pd, cv2
from tqdm import tqdm
import tifffile as tiff
from skimage.filters import threshold_otsu
from skimage.measure import label, regionprops_table
from collections import defaultdict
from sklearn.metrics.pairwise import cosine_similarity

import torch
import torch.nn as nn
from torchvision.models import resnet18

#torch.backends.cudnn.benchmark = True

import os
# Debug / determinism environment settings for CUDA and cuBLAS.
# NOTE(review): these are presumably only honoured if set before CUDA/cuBLAS
# is first initialised in this process (i.e. before any GPU work) -- confirm.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes every kernel launch synchronous,
# which slows training/inference noticeably; consider removing it outside of
# debugging sessions.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"   # forces sync; gives accurate stack traces
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # optional, helps determinism

In [None]:
# ==== Preprocessing & Labeling (Regression: RNF=-1, Neutral=0, ZNF=+1) ====
import os, re
import numpy as np
import pandas as pd
from tqdm import tqdm
from skimage.io import imread
from skimage.filters import threshold_otsu
from skimage.measure import label, regionprops
from skimage.transform import resize

# -----------------------------
# CONFIG
# -----------------------------
input_dir = "path_to_training_data"

# Each nucleus is cut out as a crop_size x crop_size window centred on its
# centroid (half_crop on either side), then resized to resize_size x
# resize_size for the ResNet backbone.
crop_size = 100
half_crop = crop_size // 2
resize_size = 224

# Channels (keys of channel_map) stacked into each training patch.
use_channels = ["GFP"]

# Filename suffix for each channel. A field of view is identified by its
# filename with the DAPI suffix removed. Update as needed.
channel_map = dict(
    DAPI="_w1.TIF",
    GFP="_w2.TIF",
    CellMask="_w3.TIF",
)

# -----------------------------
# Load images
# -----------------------------
stack = []
# Candidate FOVs are identified by their DAPI file; the prefix is the filename
# with the DAPI suffix removed.
candidate_prefixes = sorted(
    fn.replace(channel_map["DAPI"], "")
    for fn in os.listdir(input_dir)
    if fn.endswith(channel_map["DAPI"])
)

# Keep only FOVs whose requested channels all exist on disk and whose images
# match the DAPI shape. `stack` and the kept prefixes are built in lock-step so
# `fov_filenames` stays index-aligned with `stack` for the zip() in the
# cropping step. Otherwise any skipped FOV would shift every later name (and
# therefore label) onto the wrong image.
kept_prefixes = []
n_missing_channel = n_shape_mismatch = 0
for prefix in tqdm(candidate_prefixes, desc="Loading FOVs", unit="FOV"):
    channel_paths = {ch: os.path.join(input_dir, prefix + channel_map[ch]) for ch in use_channels}
    if not all(os.path.exists(path) for path in channel_paths.values()):
        # A missing channel would otherwise surface as a KeyError during cropping.
        n_missing_channel += 1
        continue

    img_w1 = imread(os.path.join(input_dir, prefix + channel_map["DAPI"]))
    imgs = {"DAPI": img_w1}
    for ch, path in channel_paths.items():
        imgs[ch] = imread(path)

    if any(img.shape != img_w1.shape for img in imgs.values()):
        n_shape_mismatch += 1
        continue
    stack.append(imgs)
    kept_prefixes.append(prefix)

fov_filenames = kept_prefixes
print(f"Loaded {len(stack)} FOVs")
if n_missing_channel or n_shape_mismatch:
    print(f"Skipped {n_missing_channel} FOVs with missing channels and "
          f"{n_shape_mismatch} FOVs with mismatched image shapes")

# -----------------------------
# Segment nuclei + crop patches
# -----------------------------
def _nucleus_windows(dapi_img, half):
    """Yield (top, bottom, left, right) crop bounds around each nucleus.

    Nuclei are the connected components of the Otsu-thresholded DAPI image.
    Windows extending past the image border are not yielded.
    """
    n_rows, n_cols = dapi_img.shape[0], dapi_img.shape[1]
    foreground = dapi_img > threshold_otsu(dapi_img)
    for reg in regionprops(label(foreground), intensity_image=dapi_img):
        row_c, col_c = map(int, reg.centroid)
        top, bottom = row_c - half, row_c + half
        left, right = col_c - half, col_c + half
        fully_inside = top >= 0 and left >= 0 and bottom <= n_rows and right <= n_cols
        if fully_inside:
            yield top, bottom, left, right


# One (C, resize_size, resize_size) float32 patch per in-bounds nucleus, and a
# metadata row naming its source FOV.
all_patches, output_rows = [], []
fov_iter = tqdm(zip(stack, fov_filenames), total=len(stack), desc="Cropping patches", unit="FOV")
for fov_imgs, fov_name in fov_iter:
    for top, bottom, left, right in _nucleus_windows(fov_imgs["DAPI"], half_crop):
        resized = [
            resize(fov_imgs[ch][top:bottom, left:right],
                   (resize_size, resize_size),
                   preserve_range=True)
            for ch in use_channels
        ]
        all_patches.append(np.stack(resized, axis=0).astype(np.float32))
        output_rows.append({"name": fov_name})

patches = np.stack(all_patches, axis=0)
print(f"Patch tensor shape: {patches.shape}")

# -----------------------------
# Build metadata + assign labels
# -----------------------------
meta_df = pd.DataFrame(output_rows)
np.save("resnet18_patches.npy", patches)
meta_df.to_csv("resnet18_patch_metadata.csv", index=False)

# Reload (safe practice): continue from the on-disk copies.
patches = np.load("resnet18_patches.npy")
meta_df = pd.read_csv("resnet18_patch_metadata.csv")
if len(meta_df) != len(patches):
    raise ValueError(f"Metadata rows ({len(meta_df)}) and patches ({len(patches)}) are out of sync")

# Extract the Plate_Well merge key from each FOV name (e.g. "AC0559_B02").
# Names without a key get NaN and therefore remain unlabeled (dropped below)
# instead of crashing on `.group()` of a failed match.
_merge_key_re = re.compile(r"(AC\d+_\w\d+)")


def _merge_key_or_nan(name):
    """Return the first Plate_Well key found in *name*, or NaN if none."""
    match = _merge_key_re.search(str(name))
    return match.group(1) if match else np.nan


meta_df["merge_key"] = meta_df["name"].apply(_merge_key_or_nan)

# Plate-to-gene mapping (first three columns of names.csv: Gene, Plate, Well)
raw = pd.read_csv("names.csv", skiprows=2).iloc[:, [0, 1, 2]].copy()
raw.columns = ["Gene", "Plate", "Well"]
raw = raw.dropna(subset=["Gene", "Plate", "Well"])
raw["merge_key"] = raw["Plate"].astype(str).str.strip() + "_" + raw["Well"].astype(str).str.strip()

# A merge key listed more than once would duplicate patch rows in the left
# merge, so meta_df would no longer line up with `patches`. Keep the first entry.
n_dup_keys = int(raw["merge_key"].duplicated().sum())
if n_dup_keys:
    print(f"WARNING: {n_dup_keys} duplicate Plate_Well entries in names.csv; keeping the first of each")
raw = raw.drop_duplicates(subset="merge_key", keep="first")

meta_df = meta_df.merge(raw[["merge_key", "Gene"]], on="merge_key", how="left",
                        validate="many_to_one")

# Labels: RNF=-1, Neutral=0, ZNF=+1 (gene labels are assigned after the B02
# wells, so they take precedence)
meta_df["label"] = np.nan
meta_df.loc[meta_df["merge_key"].str.endswith("B02", na=False), "label"] = 0
meta_df.loc[meta_df["Gene"] == "ZNF335", "label"] = 1
meta_df.loc[meta_df["Gene"] == "RNF26", "label"] = -1

# Keep labeled rows only; one boolean mask filters patches and metadata alike
is_labeled = meta_df["label"].notna().to_numpy()
patches = patches[is_labeled]
meta_df = meta_df.loc[is_labeled].reset_index(drop=True)

# Clean
meta_df["Gene"] = meta_df["Gene"].replace(["", " ", "nan", "NaN"], np.nan).fillna("NT")

print("Positive genes in training:", meta_df[meta_df.label == 1].Gene.unique())
print("Negative genes in training:", meta_df[meta_df.label == -1].Gene.unique())
print("Neutral wells in training:", meta_df[meta_df.label == 0].merge_key.unique())


In [None]:
# ==== Training (Regression: RNF=-1, Neutral=0, ZNF=+1) ====
from sklearn.model_selection import train_test_split
import torch, torch.nn as nn
from torchvision.models import resnet18
import numpy as np, random
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import mean_absolute_error, roc_auc_score
from scipy.stats import pearsonr
from sklearn.utils import resample

# -----------------------------
# Reproducibility
# -----------------------------
SEED = 42
# Seed Python, NumPy and torch (CPU + all CUDA devices), in that order.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Prefer deterministic cuDNN kernels over autotuned (possibly nondeterministic) ones.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# -----------------------------
# Model (regression, 1 output, same as inference)
# -----------------------------
class ResNetRegression(nn.Module):
    """ImageNet-pretrained ResNet18 backbone with a small scalar regression head.

    The backbone's final FC layer is replaced by ``nn.Identity`` so that
    ``self.base(x)`` returns the pooled backbone features. The head
    (fc1 -> ReLU -> dropout -> fc2) maps them to one output, trained against
    RNF26=-1, neutral=0, ZNF335=+1.

    Args:
        hidden_dim: width of the hidden layer in the regression head.
        freeze_until_layer: children of the torchvision ResNet18 whose 1-based
            position is strictly below this value get ``requires_grad=False``.
    """

    def __init__(self, hidden_dim=64, freeze_until_layer=6):
        super().__init__()
        # Use pretrained ImageNet weights
        base = resnet18(weights="IMAGENET1K_V1")

        # Freeze early layers. ResNet18's children are, in order: conv1, bn1,
        # relu, maxpool, layer1, layer2, layer3, layer4, avgpool, fc. With the
        # default of 6, conv1..layer1 are frozen and layer2 onward is trained.
        # NOTE(review): frozen BatchNorm layers still update their running
        # statistics while the model is in train() mode.
        for position, child in enumerate(base.children(), start=1):
            if position < freeze_until_layer:
                for param in child.parameters():
                    param.requires_grad = False

        # Replace final FC
        num_ftrs = base.fc.in_features
        base.fc = nn.Identity()
        self.base = base

        # Head
        self.fc1 = nn.Linear(num_ftrs, hidden_dim)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(hidden_dim, 1)

        # Initialize head
        nn.init.kaiming_normal_(self.fc1.weight, nonlinearity="relu")
        nn.init.constant_(self.fc1.bias, 0)
        nn.init.kaiming_normal_(self.fc2.weight, nonlinearity="linear")
        nn.init.constant_(self.fc2.bias, 0)

    def forward(self, x, return_embed=False):
        """Return the regression output, or the backbone features.

        Args:
            x: image batch, presumably (N, C, H, W) with C <= 3 -- confirm.
            return_embed: if True, return the pooled backbone features instead
                of the scalar prediction.
        """
        n_channels = x.shape[1]
        if n_channels == 1:
            # Replicate grayscale to 3 channels, matching the 3x channel
            # replication used to build the training tensors. Zero-padding
            # would hand the network an input distribution it never saw.
            x = x.repeat(1, 3, 1, 1)
        elif n_channels < 3:
            pad = torch.zeros((x.shape[0], 3 - n_channels, x.shape[2], x.shape[3]),
                              device=x.device, dtype=x.dtype)
            x = torch.cat([x, pad], dim=1)
        feats = self.base(x)
        if return_embed:
            return feats
        x = self.fc1(feats)
        x = self.relu(x)
        x = self.dropout(x)
        return self.fc2(x)

model = ResNetRegression(hidden_dim=64, freeze_until_layer=6).to(device)


# -----------------------------
# Dataset prep
# -----------------------------
labels = meta_df["label"].values   # already -1,0,1

# Global z-score normalisation over every pixel of every patch. The statistics
# are also written to disk so the inference step can reproduce exactly the
# same input transform.
px_mean, px_std = patches.mean(), patches.std()
np.savez("resnet18_norm_stats.npz", mean=float(px_mean), std=float(px_std))
X = torch.tensor((patches - px_mean) / px_std, dtype=torch.float32)
# Replicate channels for the 3-channel backbone: (N,1,H,W) -> (N,3,H,W).
# NOTE(review): assumes a single entry in use_channels; with C channels
# repeat() yields 3*C channels, which conv1 would reject.
X = X.repeat(1, 3, 1, 1)

# Separate classes
idx_rnf = np.where(labels == -1)[0]
idx_neu = np.where(labels == 0)[0]
idx_znf = np.where(labels == 1)[0]
print("Before undersampling:", len(idx_rnf), len(idx_neu), len(idx_znf))

# Undersample Neutral (0) to the size of the smaller anchor class, capped at
# the number of neutral cells available (sampling without replacement cannot
# draw more than that).
min_size = min(len(idx_rnf), len(idx_znf))
n_neu = min(min_size, len(idx_neu))
idx_neu_down = resample(idx_neu, replace=False, n_samples=n_neu, random_state=SEED)

# Combine balanced
balanced_idx = np.hstack([idx_rnf, idx_neu_down, idx_znf])
np.random.shuffle(balanced_idx)

X_bal = X[balanced_idx]
y_bal = torch.tensor(labels[balanced_idx], dtype=torch.float32)
print("After undersampling:", (y_bal == -1).sum().item(),
      (y_bal == 0).sum().item(), (y_bal == 1).sum().item())

# Train/val split
# NOTE(review): the split is per cell, so cells from the same well/FOV can end
# up in both train and val, which likely inflates validation metrics. A split
# grouped by merge_key would be stricter.
train_idx, val_idx = train_test_split(
    np.arange(len(y_bal)),
    test_size=0.2,
    stratify=y_bal.numpy(),
    random_state=SEED
)
train_ds = TensorDataset(X_bal[train_idx], y_bal[train_idx])
val_ds   = TensorDataset(X_bal[val_idx],   y_bal[val_idx])

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=0, pin_memory=True)
val_loader   = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=0, pin_memory=True)

# -----------------------------
# Optimizer, Loss, Scheduler
# -----------------------------
criterion = nn.SmoothL1Loss()

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-2)
# NOTE(review): T_max=50 is half of max_epochs=100; beyond epoch 50 the cosine
# schedule climbs back toward the initial LR. Confirm this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)

# -----------------------------
# Training loop (with history logging)
# -----------------------------
train_history = []  # one metrics dict per epoch, appended by train_with_early_stopping

def train_with_early_stopping(max_epochs=100, patience=25, ckpt_path="similarity_model_best.pt"):
    """Train the global ``model`` with early stopping on validation loss.

    Uses the module-level ``model``, ``train_loader``, ``val_loader``,
    ``criterion``, ``optimizer``, ``scheduler`` and ``device``. Appends one
    metrics dict per epoch to the global ``train_history``.

    Args:
        max_epochs: maximum number of epochs to run.
        patience: stop after this many consecutive epochs without a new best
            validation loss.
        ckpt_path: where the best ``state_dict`` is saved. If any checkpoint
            was written, its weights are loaded back into ``model`` before
            returning, so ``model`` ends in its best state, not its last-epoch
            state.
    """
    best_val_loss, patience_counter = float("inf"), 0
    saved_checkpoint = False
    for epoch in range(max_epochs):
        # --- Train ---
        model.train()
        train_losses = []
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device).unsqueeze(1)
            optimizer.zero_grad()
            preds = model(xb)
            loss = criterion(preds, yb)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
        train_loss = np.mean(train_losses)

        # --- Validate ---
        model.eval()
        val_preds, val_truth, val_losses = [], [], []
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device).unsqueeze(1)
                preds = model(xb)
                val_losses.append(criterion(preds, yb).item())
                val_preds.extend(preds.cpu().numpy().ravel())
                val_truth.extend(yb.cpu().numpy().ravel())
        val_loss = np.mean(val_losses)

        # --- Metrics ---
        val_truth = np.array(val_truth)
        val_preds = np.array(val_preds)
        mae = mean_absolute_error(val_truth, val_preds)
        corr = pearsonr(val_truth, val_preds)[0]
        # Classify by rounding the regression output to the nearest label.
        acc = (np.rint(val_preds).astype(int) == val_truth).mean()

        # RNF26-vs-ZNF335 AUC on the anchor cells only; ROC AUC is undefined
        # unless both classes are present.
        auc = np.nan
        mask = np.isin(val_truth, [-1, 1])
        if mask.any() and np.unique(val_truth[mask]).size == 2:
            truth_bin = (val_truth[mask] == 1).astype(int)
            scores = (np.clip(val_preds[mask], -1, 1) + 1) / 2
            auc = roc_auc_score(truth_bin, scores)

        print(f"Epoch {epoch+1:03d}/{max_epochs} | TrainLoss={train_loss:.4f} | "
              f"ValLoss={val_loss:.4f} | MAE={mae:.3f} | Corr={corr:.3f} | "
              f"Acc={acc:.3f} | AUC={auc:.3f}")

        # --- Save to history ---
        train_history.append({
            "epoch": epoch+1,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_acc": acc,
            "mae": mae,
            "corr": corr,
            "auc": auc
        })

        # --- Early stopping ---
        if val_loss < best_val_loss:
            best_val_loss, patience_counter = val_loss, 0
            torch.save(model.state_dict(), ckpt_path)
            saved_checkpoint = True
            print(f"Saved best model (ValLoss={val_loss:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping")
                break
        scheduler.step()

    # Leave the model in its best state for whatever runs next (evaluation,
    # embedding extraction, ...).
    if saved_checkpoint:
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print(f"Restored best model (ValLoss={best_val_loss:.4f})")


train_with_early_stopping(max_epochs=100, patience=25)


In [None]:
# ==== Full Evaluation & Training Plots (6-panel + individual saves) ====
import os, numpy as np, pandas as pd, matplotlib.pyplot as plt, torch
from sklearn.metrics import (
    mean_absolute_error, r2_score, confusion_matrix, ConfusionMatrixDisplay,
    accuracy_score, roc_curve, roc_auc_score
)
from scipy.stats import pearsonr

# -----------------------------
# Config
# -----------------------------
CKPT = "similarity_model_best.pt"
custom_colors = ["#f2704e", "#e13c3d", "#921d5c", "#651f56"]

# Evaluate the best checkpoint saved during training, not whatever weights the
# in-memory model holds after the final epoch.
if os.path.exists(CKPT):
    model.load_state_dict(torch.load(CKPT, map_location=device))
    print(f"Loaded best checkpoint: {CKPT}")
else:
    print(f"WARNING: {CKPT} not found; evaluating the current in-memory weights")

# NOTE(review): val_loader is the same split used for early stopping and
# checkpoint selection, so these metrics are optimistically biased.

# -----------------------------
# Collect predictions
# -----------------------------
y_true, y_pred = [], []
model.eval()
with torch.no_grad():
    for xb, yb in val_loader:
        xb, yb = xb.to(device), yb.to(device).unsqueeze(1)
        preds = model(xb)
        y_true.extend(yb.cpu().numpy().ravel())
        y_pred.extend(preds.cpu().numpy().ravel())
y_true, y_pred = np.array(y_true), np.array(y_pred)

# -----------------------------
# Metrics
# -----------------------------
mae  = mean_absolute_error(y_true, y_pred)
r2   = r2_score(y_true, y_pred)
corr = pearsonr(y_true, y_pred)[0] if len(y_true) > 1 else np.nan
y_pred_class = np.rint(y_pred).astype(int)   # nearest label
acc = accuracy_score(y_true, y_pred_class)
cm  = confusion_matrix(y_true, y_pred_class, labels=[-1,0,1])

# Binary RNF26-vs-ZNF335 ROC on anchor cells; needs both classes present.
mask_bin = np.isin(y_true, [-1,1])
roc_auc, fpr, tpr = np.nan, None, None
if mask_bin.any() and np.unique(y_true[mask_bin]).size == 2:
    y_true_bin = (y_true[mask_bin] == 1).astype(int)
    y_score_bin = y_pred[mask_bin]
    roc_auc = roc_auc_score(y_true_bin, y_score_bin)
    fpr, tpr, _ = roc_curve(y_true_bin, y_score_bin)

print("="*60)
print(f"MAE={mae:.4f}, R²={r2:.4f}, Corr={corr:.4f}, Acc={acc:.4f}, AUC={roc_auc:.4f}")

# -----------------------------
# Training history
# -----------------------------
hist = None
if "train_history" in globals() and len(train_history) > 0:
    hist = pd.DataFrame(train_history)

# Per-class (value, display name) pairs. The loop variable below is not called
# `label`, because that would rebind skimage.measure.label in the notebook
# namespace.
class_display = [(-1, "RNF26 (-1)"), (0, "Neutral (0)"), (1, "ZNF335 (+1)")]


def _mark_extreme(ax, x, y, text):
    """Highlight one point of a training curve with a boxed annotation."""
    ax.scatter(x, y, color="black")
    ax.text(x, y, text, fontsize=9, va="bottom", ha="left",
            bbox=dict(facecolor="white", edgecolor="black", boxstyle="round,pad=0.3"))


# -----------------------------
# 6-panel summary
# -----------------------------
fig, axes = plt.subplots(3, 2, figsize=(12, 15))

# 1. Scatter
color_map = {-1: custom_colors[0], 0: custom_colors[1], 1: custom_colors[2]}
axes[0, 0].scatter(y_true, y_pred, c=[color_map[int(lbl)] for lbl in y_true],
                   alpha=0.6, edgecolor="k")
for h in [-1, 0, 1]:
    axes[0, 0].axhline(h, color='grey', linestyle='--')
    axes[0, 0].axvline(h, color='grey', linestyle='--')
axes[0, 0].set_title(f"True vs Predicted (Corr={corr:.2f})")

# 2. Histogram
for i, (cls, cls_name) in enumerate(class_display):
    axes[0, 1].hist(y_pred[y_true == cls], bins=20, alpha=0.8,
                    label=cls_name, color=custom_colors[i])
axes[0, 1].set_title("Prediction Distribution")
axes[0, 1].legend()

# 3. ROC (only drawn when both anchor classes were present)
if fpr is not None:
    axes[1, 0].plot(fpr, tpr, lw=2, color=custom_colors[3])
    axes[1, 0].plot([0, 1], [0, 1], '--', color='grey')
    axes[1, 0].text(0.05, 0.95, f"AUC = {roc_auc:.3f}",
                    transform=axes[1, 0].transAxes, va="top", ha="left",
                    bbox=dict(facecolor="white", edgecolor="black", boxstyle="round,pad=0.3"))
    axes[1, 0].set_title("ROC Curve (RNF26 vs ZNF335)")

# 4. Confusion Matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=[name for _, name in class_display])
disp.plot(ax=axes[1, 1], cmap="BuPu", colorbar=False)
axes[1, 1].set_title("Confusion Matrix")

# 5. Loss Curve (minimum validation loss marked)
if hist is not None:
    axes[2, 0].plot(hist["epoch"], hist["train_loss"], label="Train Loss", color=custom_colors[3])
    axes[2, 0].plot(hist["epoch"], hist["val_loss"], label="Val Loss", color=custom_colors[0])
    min_idx = hist["val_loss"].idxmin()
    best_epoch, best_loss = hist.loc[min_idx, "epoch"], hist.loc[min_idx, "val_loss"]
    _mark_extreme(axes[2, 0], best_epoch, best_loss,
                  f"Min Val Loss={best_loss:.3f}\n(Epoch {best_epoch})")
    axes[2, 0].set_title("Loss Curve")
    axes[2, 0].legend()

# 6. Accuracy Curve (maximum validation accuracy marked)
if hist is not None and "val_acc" in hist:
    axes[2, 1].plot(hist["epoch"], hist["val_acc"], label="Val Accuracy", color=custom_colors[2])
    max_idx = hist["val_acc"].idxmax()
    top_epoch, top_acc = hist.loc[max_idx, "epoch"], hist.loc[max_idx, "val_acc"]
    _mark_extreme(axes[2, 1], top_epoch, top_acc,
                  f"Max Val Acc={top_acc:.3f}\n(Epoch {top_epoch})")
    axes[2, 1].set_title("Validation Accuracy")
    axes[2, 1].legend()

plt.tight_layout()
fig.savefig("evaluation_training_summary_full.pdf", dpi=300)
plt.close(fig)

# -----------------------------
# Individual plots (regenerate fresh)
# -----------------------------
def save_single_plot(fname, plot_func):
    """Draw one panel on a fresh 6x5-inch figure, save it, and close it.

    Args:
        fname: output path, saved at 300 dpi with a tight bounding box.
        plot_func: callable that draws onto the given Axes; its return value
            is ignored.
    """
    fig_obj, target_ax = plt.subplots(figsize=(6, 5))
    plot_func(target_ax)
    fig_obj.savefig(fname, bbox_inches="tight", dpi=300)
    plt.close(fig_obj)

def _draw_scatter(ax):
    """True vs. predicted score, coloured by true class."""
    ax.scatter(y_true, y_pred, c=[color_map[int(lbl)] for lbl in y_true],
               alpha=0.6, edgecolor="k")


def _draw_hist(ax):
    """Per-class histograms of the predicted score."""
    for i, (cls, cls_name) in enumerate([(-1, "RNF26"), (0, "Neutral"), (1, "ZNF335")]):
        ax.hist(y_pred[y_true == cls], bins=20, alpha=0.8, label=cls_name, color=custom_colors[i])
    ax.legend()  # the histograms carry class labels; show them


def _draw_roc(ax):
    """RNF26-vs-ZNF335 ROC curve with the chance diagonal and AUC box."""
    ax.plot(fpr, tpr, lw=2, color=custom_colors[3])
    ax.plot([0, 1], [0, 1], '--', color='grey')
    ax.text(0.05, 0.95, f"AUC={roc_auc:.3f}", transform=ax.transAxes,
            va="top", ha="left",
            bbox=dict(facecolor="white", edgecolor="black", boxstyle="round,pad=0.3"))


def _draw_cm(ax):
    """Three-class confusion matrix."""
    ConfusionMatrixDisplay(
        confusion_matrix=cm,
        display_labels=['RNF26 (-1)', 'Neutral (0)', 'ZNF335 (+1)']
    ).plot(ax=ax, cmap="BuPu", colorbar=False)


def _draw_loss(ax):
    """Train and validation loss per epoch."""
    ax.plot(hist["epoch"], hist["train_loss"], label="Train", color=custom_colors[3])
    ax.plot(hist["epoch"], hist["val_loss"], label="Val", color=custom_colors[0])
    ax.legend()


def _draw_acc(ax):
    """Validation accuracy per epoch."""
    ax.plot(hist["epoch"], hist["val_acc"], label="Val Acc", color=custom_colors[2])
    ax.legend()


save_single_plot("scatter_only.pdf", _draw_scatter)
save_single_plot("hist_only.pdf", _draw_hist)
if fpr is not None:
    save_single_plot("roc_only.pdf", _draw_roc)
save_single_plot("cm_only.pdf", _draw_cm)
if hist is not None:
    save_single_plot("loss_only.pdf", _draw_loss)
    if "val_acc" in hist:
        save_single_plot("acc_only.pdf", _draw_acc)


In [None]:
# ==== Inference (Regression: RNF=-1, Neutral=0, ZNF=+1) ====
import os, re, cv2
import numpy as np, pandas as pd
from tqdm import tqdm
import tifffile as tiff
from skimage.filters import threshold_otsu
from skimage.measure import label, regionprops_table
from collections import defaultdict
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import mean_absolute_error, roc_auc_score, accuracy_score
from scipy.stats import pearsonr

import torch, torch.nn as nn
from torchvision.models import resnet18

# -----------------------------
# Config
# -----------------------------
test_dir    = "New_image_directory"
output_xlsx = "gene_similarity_predictions_inferenceFullList.xlsx"
names_csv   = "names.csv"
model_path  = "similarity_model_best.pt"
resize_size, half_crop, batch_size = 224, 50, 64
model_channels = ["GFP"]

#Update channel mapping as needed:
channel_map = {"DAPI": "_w1.TIF","GFP": "_w2.TIF","CellMask": "_w3.TIF"}

# -----------------------------
# Gene mapping -- MAY NEED TO BE CUSTOMIZED
# -----------------------------
raw = pd.read_csv(names_csv, skiprows=2).iloc[:, [0, 1, 2]].copy()
raw.columns = ["Gene", "Plate", "Well"]
# Drop incomplete rows, as the training cell does. A well without a gene name
# then falls back to its Plate_Well key via plate2gene.get(key, key) instead of
# being grouped under a NaN gene.
raw = raw.dropna(subset=["Gene", "Plate", "Well"])
raw["merge_key"] = raw["Plate"].astype(str).str.strip() + "_" + raw["Well"].astype(str).str.strip()
plate2gene = dict(zip(raw["merge_key"], raw["Gene"]))

MERGE_KEY_REGEX = re.compile(r"(AC\d+_[A-Za-z]\d+)")


def extract_merge_key(p):
    """Return the first ``AC<digits>_<letter><digits>`` Plate_Well key in *p*.

    Raises:
        ValueError: if *p* contains no such key.
    """
    match = MERGE_KEY_REGEX.search(p)
    if match is None:
        raise ValueError(f"No Plate_Well merge key found in {p!r}")
    return match.group(1)

# -----------------------------
# Model (same as training)
# -----------------------------
class ResNetRegression(nn.Module):
    """Inference copy of the training ResNet18 regression model.

    The architecture must match the training definition so the saved
    ``state_dict`` loads. The ImageNet weights fetched here are overwritten by
    ``load_state_dict``.

    Args:
        hidden_dim: width of the hidden layer in the regression head.
        freeze_until_layer: children of ResNet18 with a 1-based position below
            this value get ``requires_grad=False`` (irrelevant at inference).
    """

    def __init__(self, hidden_dim=64, freeze_until_layer=6):
        super().__init__()
        base = resnet18(weights="IMAGENET1K_V1")
        for position, child in enumerate(base.children(), start=1):
            if position < freeze_until_layer:
                for param in child.parameters():
                    param.requires_grad = False
        num_ftrs = base.fc.in_features
        base.fc = nn.Identity()
        self.base = base
        self.fc1 = nn.Linear(num_ftrs, hidden_dim)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x, return_embed=False):
        """Return the scalar prediction, or the backbone features if *return_embed*.

        Single-channel input is replicated to 3 channels, matching the 3x
        channel replication used for the training tensors (zero-padding would
        produce inputs the model never saw). 2-channel input is zero-padded.
        """
        n_channels = x.shape[1]
        if n_channels == 1:
            x = x.repeat(1, 3, 1, 1)
        elif n_channels < 3:
            pad = torch.zeros((x.shape[0], 3 - n_channels, x.shape[2], x.shape[3]),
                              device=x.device, dtype=x.dtype)
            x = torch.cat([x, pad], dim=1)
        feats = self.base(x)
        if return_embed:
            return feats
        x = self.fc1(feats)
        x = self.relu(x)
        x = self.dropout(x)
        return self.fc2(x)

# Rebuild the architecture, move it to the GPU when one is available, restore
# the trained weights, and switch to eval mode (disables dropout).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = ResNetRegression(hidden_dim=64, freeze_until_layer=6)
model.to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
print("Model loaded")

# -----------------------------
# Storage
# -----------------------------
gene_preds = defaultdict(list)   # gene -> flat list of per-cell regression outputs
gene_embeds = defaultdict(list)  # gene -> list of per-FOV (n_cells, n_features) arrays

# -----------------------------
# Input normalisation (must match training)
# -----------------------------
# Training z-scored the patches with global mean/std and replicated the single
# channel 3x. Use the saved statistics when they exist; otherwise fall back to
# the legacy 1/65535 scaling and warn that it does not match training.
norm_stats_path = "resnet18_norm_stats.npz"
if os.path.exists(norm_stats_path):
    with np.load(norm_stats_path) as norm_stats:
        norm_mean, norm_std = float(norm_stats["mean"]), float(norm_stats["std"])
    print(f"Using training normalisation: mean={norm_mean:.4g}, std={norm_std:.4g}")
else:
    norm_mean, norm_std = 0.0, 65535.0
    print(f"WARNING: {norm_stats_path} not found; scaling by 1/65535, which does "
          "not match the z-score normalisation used in training")
# NOTE(review): training resized with skimage (float, preserve_range) while
# this uses cv2.resize on the raw dtype; small interpolation differences remain.

# -----------------------------
# Loop over images
# -----------------------------
file_list = sorted(f for f in os.listdir(test_dir) if f.endswith(channel_map["DAPI"]))
# Optional: file_list = file_list[::4] to subsample for speed.
print(f"Processing {len(file_list)} images")

n_missing_channel = 0
for fname in tqdm(file_list, desc="Scanning images"):
    prefix = fname.replace(channel_map["DAPI"], "")
    channel_paths = {ch: os.path.join(test_dir, prefix + channel_map[ch]) for ch in model_channels}
    if not all(os.path.exists(path) for path in channel_paths.values()):
        n_missing_channel += 1
        continue
    dapi = tiff.imread(os.path.join(test_dir, prefix + channel_map["DAPI"]))
    imgs = {ch: tiff.imread(path) for ch, path in channel_paths.items()}
    if any(img.shape != dapi.shape for img in imgs.values()):
        continue

    # Segment nuclei on DAPI (Otsu) and crop one window per in-bounds centroid.
    thresh = threshold_otsu(dapi)
    labeled = label(dapi > thresh)
    props = regionprops_table(labeled, properties=("centroid",))

    # Named fov_patches so the training `patches` array is not overwritten.
    fov_patches = []
    for cy, cx in zip(props["centroid-0"], props["centroid-1"]):
        y1, y2, x1, x2 = int(cy-half_crop), int(cy+half_crop), int(cx-half_crop), int(cx+half_crop)
        if y1 < 0 or y2 > dapi.shape[0] or x1 < 0 or x2 > dapi.shape[1]:
            continue
        chans = [cv2.resize(imgs[ch][y1:y2, x1:x2], (resize_size, resize_size)) for ch in model_channels]
        fov_patches.append((np.stack(chans, axis=0).astype(np.float32) - norm_mean) / norm_std)
    if not fov_patches:
        continue

    fov_tensor = torch.from_numpy(np.stack(fov_patches))
    if fov_tensor.shape[1] == 1:
        fov_tensor = fov_tensor.repeat(1, 3, 1, 1)  # same channel replication as training

    # One backbone pass per mini-batch; the head is applied to the same features
    # (identical to model(xb) in eval mode, where dropout is a no-op).
    feat_chunks, pred_chunks = [], []
    with torch.no_grad():
        for start in range(0, len(fov_tensor), batch_size):
            xb = fov_tensor[start:start + batch_size].to(device)
            emb = model(xb, return_embed=True)
            out = model.fc2(model.dropout(model.relu(model.fc1(emb))))
            feat_chunks.append(emb.cpu().numpy())
            pred_chunks.append(out.cpu().numpy().ravel())

    merge_key = extract_merge_key(prefix)
    gene = plate2gene.get(merge_key, merge_key)
    gene_embeds[gene].append(np.concatenate(feat_chunks))
    gene_preds[gene].extend(np.concatenate(pred_chunks))

if n_missing_channel:
    print(f"Skipped {n_missing_channel} images with missing channel files")


# -----------------------------
# Aggregate by gene (with ≤500 cells cap)
# -----------------------------
max_cells = 500
rng = np.random.default_rng(42)  # reproducible

gene_centroids = {}   # gene -> mean embedding of its (possibly subsampled) cells
gene_outputs = {}     # gene -> per-cell predictions used for that centroid

for gene, pred_list in gene_preds.items():
    cell_preds = np.array(pred_list)
    cell_embeds = np.vstack(gene_embeds[gene])

    # Genes with more than max_cells cells are randomly subsampled without
    # replacement so that well-populated genes do not dominate.
    n_cells = len(cell_preds)
    if n_cells > max_cells:
        keep = rng.choice(n_cells, size=max_cells, replace=False)
        cell_preds, cell_embeds = cell_preds[keep], cell_embeds[keep]

    gene_centroids[gene] = cell_embeds.mean(axis=0)
    gene_outputs[gene] = cell_preds


# Anchor centroids
znf = gene_centroids.get("ZNF335")
rnf = gene_centroids.get("RNF26")
b02 = gene_centroids.get("AC0559_B02")  # <-- replace with correct merge_key/Gene name for B02 neutral anchor
for anchor_name, anchor_vec in (("ZNF335", znf), ("RNF26", rnf), ("AC0559_B02", b02)):
    if anchor_vec is None:
        print(f"WARNING: anchor {anchor_name!r} not found; its similarity columns will be NaN")


def _cos_sim(vec, anchor):
    """Cosine similarity of two 1-D vectors, or NaN when *anchor* is missing."""
    return cosine_similarity([vec], [anchor])[0][0] if anchor is not None else np.nan


# -----------------------------
# Results
# -----------------------------
results = []
for gene, preds in gene_outputs.items():
    cent = gene_centroids[gene]

    # Cosine similarity of the gene centroid to each anchor centroid
    sim_zn = _cos_sim(cent, znf)
    sim_rn = _cos_sim(cent, rnf)
    sim_b0 = _cos_sim(cent, b02)

    # Relative scores (NaN propagates automatically when an anchor is missing)
    results.append({
        "Gene": gene,
        "sim_ZNF335": sim_zn,
        "sim_B02": sim_b0,
        "sim_RNF26": sim_rn,
        "relative_score": sim_zn - sim_rn,
        "diff_ZNF335_B02": sim_zn - sim_b0,
        "diff_B02_RNF26": sim_b0 - sim_rn,
        "n_cells": len(preds)
    })

gene_df = pd.DataFrame(results).sort_values("relative_score", ascending=False)
gene_df.to_excel(output_xlsx, index=False)
print(f"Saved {output_xlsx}")

# Rankings ignore genes without a relative score. tail() on the sorted frame
# would otherwise list NaN rows as "closest to RNF26".
scored = gene_df.dropna(subset=["relative_score"])

print("\n Top 10 genes closest to ZNF335:")
print(scored.head(10).to_string(index=False))

print("\n Top 10 genes closest to RNF26:")
print(scored.nsmallest(10, "relative_score").to_string(index=False))
 


In [None]:
# -----------------------------
# Results
# -----------------------------
# Re-runs the scoring step with the existing gene_centroids / gene_outputs and
# anchor centroids (znf, rnf, b02) computed earlier in the notebook.


def _cos_sim(vec, anchor):
    """Cosine similarity of two 1-D vectors, or NaN when *anchor* is missing."""
    return cosine_similarity([vec], [anchor])[0][0] if anchor is not None else np.nan


results = []
for gene, preds in gene_outputs.items():
    cent = gene_centroids[gene]

    # Cosine similarity of the gene centroid to each anchor centroid
    sim_zn = _cos_sim(cent, znf)
    sim_rn = _cos_sim(cent, rnf)
    sim_b0 = _cos_sim(cent, b02)

    # Relative scores (NaN propagates automatically when an anchor is missing)
    results.append({
        "Gene": gene,
        "sim_ZNF335": sim_zn,
        "sim_B02": sim_b0,
        "sim_RNF26": sim_rn,
        "relative_score": sim_zn - sim_rn,
        "diff_ZNF335_B02": sim_zn - sim_b0,
        "diff_B02_RNF26": sim_b0 - sim_rn,
        "n_cells": len(preds)
    })

gene_df = pd.DataFrame(results).sort_values("relative_score", ascending=False)
gene_df.to_excel(output_xlsx, index=False)
print(f"Saved {output_xlsx}")

# Rankings ignore genes without a relative score. tail() on the sorted frame
# would otherwise list NaN rows as "closest to RNF26".
scored = gene_df.dropna(subset=["relative_score"])

print("\n Top 10 genes closest to ZNF335:")
print(scored.head(10).to_string(index=False))

print("\n Top 10 genes closest to RNF26:")
print(scored.nsmallest(10, "relative_score").to_string(index=False))

In [None]:
import pandas as pd
import numpy as np

# Load both files
sim_df = pd.read_excel("gene_similarity_predictions_inferenceFullList.xlsx")
zscore_df = pd.read_excel("Zscore.xlsx")

# Standardize gene column: the first column of each sheet is treated as "Gene"
sim_df = sim_df.rename(columns={sim_df.columns[0]: "Gene"})
zscore_df = zscore_df.rename(columns={zscore_df.columns[0]: "Gene"})

# Merge in Z-Score
merged = pd.merge(sim_df, zscore_df[["Gene", "Z-Score"]], on="Gene", how="left")

# Keep only desired columns
keep_cols = [
    "Gene", "sim_ZNF335", "sim_B02", "sim_RNF26",
    "relative_score", "diff_ZNF335_B02", "diff_B02_RNF26",
    "n_cells", "Z-Score"
]
merged = merged[keep_cols]

# Keep only one row per Gene: highest Z-Score if present,
# otherwise just the first available row
merged = (
    merged.sort_values("Z-Score", ascending=False, na_position="last")
    .groupby("Gene")
    .head(1)
    .reset_index(drop=True)
)

# Add normalization (0–100) and rank columns
for col in ["sim_ZNF335", "sim_B02", "sim_RNF26"]:
    # Linear min-max normalization to 0–100 (NaN if the column is constant or all-NaN)
    col_min, col_max = merged[col].min(), merged[col].max()
    merged[f"{col}_norm100"] = 100 * (merged[col] - col_min) / (col_max - col_min)

    # Integer rank (1 = most similar, highest value first). Nullable Int64 keeps
    # NaN similarities (e.g. a missing anchor) as <NA>; a plain int cast raises.
    merged[f"{col}_rank"] = merged[col].rank(method="min", ascending=False).astype("Int64")

# Save
out_file = "final_similarity_with_ranks.xlsx"
merged.to_excel(out_file, index=False)

print(f"Saved {out_file}")


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from scipy.spatial import ConvexHull

cluster_colors = ["#651f56", "#e13c3d", "#f2704e"]

# ---- Load gene_df from your processed file ----
gene_df = pd.read_excel("final_similarity_with_ranks.xlsx")

# ---- Build embedding: one row per gene centroid from the inference step ----
genes = list(gene_centroids)
X = np.vstack([gene_centroids[g] for g in genes])

# ---- Dimensionality reduction ----
# UMAP is preferred. Any failure (package missing or a runtime error) falls
# back to PCA, and the actual reason is reported. The columns keep the
# "UMAP1"/"UMAP2" names either way because later cells rely on them.
try:
    import umap
    reducer = umap.UMAP(n_neighbors=15, min_dist=0.1,
                        metric="cosine", random_state=42)
    Z = reducer.fit_transform(X)
    embed_method = "UMAP"
except Exception as err:  # deliberate best-effort fallback
    print(f"UMAP not available ({type(err).__name__}: {err}); falling back to PCA.")
    Z = PCA(n_components=3, random_state=42).fit_transform(X)
    embed_method = "PCA"

# ---- Build df ----
df_umap = pd.DataFrame({"Gene": genes, "UMAP1": Z[:, 0], "UMAP2": Z[:, 1]})
df_umap = df_umap.merge(gene_df[["Gene", "Z-Score"]], on="Gene", how="left")

# ---- Mask Z-scores for B02 wells ----
df_umap.loc[df_umap["Gene"].str.endswith("B02", na=False), "Z-Score"] = np.nan

# ---- KMeans clustering on the 2-D embedding ----
# NOTE(review): n_init is left at sklearn's default, which changed between
# versions; pin it if results must be reproducible across installs.
kmeans = KMeans(n_clusters=3, random_state=42).fit(df_umap[["UMAP1", "UMAP2"]].values)
df_umap["Cluster"] = kmeans.labels_

# ---- Anchors ----
anchors = ["RNF26", "ZNF335", "AC0559_B02"]
anchor_colors = {"RNF26": "orange", "ZNF335": "lightblue", "AC0559_B02": "green"}

# ---- Z-Score colormap ----
zscore_colors = ["#651f56", "#921d5c", "#e13c3d", "#f2704e"]
zscore_cmap = LinearSegmentedColormap.from_list("zscore_cmap", zscore_colors, N=256)

# ---- Scatter ----
# NOTE(review): genes with NaN Z-Score use the colormap's "bad" colour (often
# transparent), and vmin=2 is hard-coded; confirm both are intended.
plt.figure(figsize=(10, 7))
sc = plt.scatter(
    df_umap["UMAP1"], df_umap["UMAP2"],
    c=df_umap["Z-Score"], cmap=zscore_cmap,
    s=30, alpha=1, vmin=2
)

# ---- Cluster boundaries (convex hull of each cluster) ----
for cluster_id in np.unique(df_umap["Cluster"]):
    points = df_umap.loc[df_umap["Cluster"] == cluster_id, ["UMAP1", "UMAP2"]].values
    if len(points) < 3:
        continue
    try:
        hull = ConvexHull(points)
    except RuntimeError as err:  # scipy's QhullError (e.g. collinear points) subclasses RuntimeError
        print(f"Skipping hull for cluster {cluster_id}: {err}")
        continue
    hull_color = cluster_colors[cluster_id % len(cluster_colors)]
    plt.fill(points[hull.vertices, 0], points[hull.vertices, 1],
             facecolor=hull_color, edgecolor=hull_color,
             alpha=0.3, linewidth=2, label=f"Cluster {cluster_id}")

# ---- Highlight anchors (stars only, no text) ----
for _, row in df_umap[df_umap["Gene"].isin(anchors)].iterrows():
    g = row["Gene"]
    plt.scatter(row["UMAP1"], row["UMAP2"], s=200, marker="*",
                color=anchor_colors[g], edgecolor="black", zorder=5)

# ---- Final touches ----
plt.colorbar(sc, label="Z-Score (NaN for B02 genes)")
plt.legend()
plt.xlabel(f"{embed_method}1")
plt.ylabel(f"{embed_method}2")
plt.tight_layout()

# Save as PDF
plt.savefig("umap_gene_embeddings_Zscore_clusters.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
#Save relationships of each gene relative to their cluster and distance from anchor point (RNF26/ZNF335/NT Ctrl)
from sklearn.metrics.pairwise import euclidean_distances

# ---- Map each anchor to its cluster ----
anchor_clusters = {}
for anchor in anchors:
    if anchor in df_umap["Gene"].values:
        anchor_clusters[anchor] = df_umap.loc[df_umap["Gene"] == anchor, "Cluster"].iloc[0]

# ---- Reference anchor (and its UMAP coordinates) for each cluster ----
# The first anchor in `anchors` order that falls into a cluster is used for it.
cluster_anchor_xy = {}
for anchor, cluster in anchor_clusters.items():
    if cluster in cluster_anchor_xy:
        print(f"NOTE: anchor {anchor!r} shares cluster {cluster} with an earlier anchor; ignored for distances")
        continue
    cluster_anchor_xy[cluster] = (
        df_umap.loc[df_umap["Gene"] == anchor, ["UMAP1", "UMAP2"]].to_numpy(dtype=float)[0]
    )

# ---- Euclidean distance (UMAP space) to the anchor of the same cluster ----
# Vectorised: every row gets its cluster's anchor coordinates (NaN when the
# cluster contains no anchor), replacing a per-row DataFrame scan.
anchor_xy = np.full((len(df_umap), 2), np.nan)
for cluster, xy in cluster_anchor_xy.items():
    anchor_xy[(df_umap["Cluster"] == cluster).to_numpy()] = xy
gene_xy = df_umap[["UMAP1", "UMAP2"]].to_numpy(dtype=float)
df_umap["Dist_to_Anchor"] = np.linalg.norm(gene_xy - anchor_xy, axis=1)

# ---- Save results ----
df_umap.to_excel("gene_cluster_anchor_distances.xlsx", index=False)

In [None]:
#Heatmap plotting/clustering

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Set global font to Arial
mpl.rcParams['font.family'] = 'Arial'

# NOTE(review): `rank_df` (rows = genes, columns = rank metrics; presumably
# built from the *_rank columns of final_similarity_with_ranks.xlsx) and
# `custom_cmap` are not defined in this part of the notebook. Create them
# before running this cell.
_undefined = [name for name in ("rank_df", "custom_cmap") if name not in globals()]
if _undefined:
    raise NameError(f"Define {', '.join(_undefined)} before running the heatmap cell")

g = sns.clustermap(
    rank_df,
    cmap=custom_cmap,
    metric="euclidean",
    method="ward",
    col_cluster=False,
    figsize=(7, 18),
    cbar_kws={"label": "Rank"},
    dendrogram_ratio=(0.35, 0.05),
    colors_ratio=0.01,
    tree_kws={"linewidths": 0.8}
)

# Heatmap cell i spans [i, i+1], so ticks must sit at i + 0.5. Integer tick
# positions would draw every label half a cell away from its row/column.
n_rows, n_cols = g.data2d.shape

# Row labels (in clustered order)
g.ax_heatmap.set_yticks(np.arange(n_rows) + 0.5)
g.ax_heatmap.set_yticklabels(
    g.data2d.index,
    fontsize=12,
    rotation=0
)

# Column labels (explicit ticks so every column gets exactly one label)
g.ax_heatmap.set_xticks(np.arange(n_cols) + 0.5)
g.ax_heatmap.set_xticklabels(
    g.data2d.columns,
    fontsize=20,
    ha="right"
)

# Colorbar font
g.cax.yaxis.label.set_size(12)
g.cax.tick_params(labelsize=10)
g.savefig("clustermap2.pdf", format="pdf", bbox_inches="tight")

# Or if you want to also see it interactively:
plt.show()

In [None]:
#For elbow method plot:
#OPTIONAL

from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_samples, silhouette_score
from matplotlib.colors import LinearSegmentedColormap
from scipy.spatial import ConvexHull

# ----------------------------------
# 0. Inputs for the k sweep
# ----------------------------------
# Same 2-D embedding the main KMeans step clustered (df_umap UMAP1/UMAP2).
# The k range is named `k_values`, not `K`, because `K` is bound to
# kornia.augmentation at the top of the notebook.
# NOTE(review): k = 1..10 is an assumed search range; adjust as needed.
X_umap = df_umap[["UMAP1", "UMAP2"]].to_numpy()
k_values = list(range(1, min(10, len(X_umap)) + 1))
inertias = [KMeans(n_clusters=k, random_state=42).fit(X_umap).inertia_ for k in k_values]

# ----------------------------------
# 1. Elbow Plot
# ----------------------------------
with PdfPages("combined_panels.pdf") as pdf:
    plt.figure(figsize=(8,6))
    plt.plot(k_values, inertias, "o-", lw=2, color="#651f50")
    plt.xlabel("Number of clusters (k)")
    plt.ylabel("Within-cluster sum of squares (Inertia)")
    plt.title("Elbow Method for Optimal k")
    plt.xticks(k_values)
    plt.grid(True, linestyle="--", alpha=0.6)
    pdf.savefig()
    plt.close()

    # ----------------------------------
    # 2. Silhouette plot (k=3 only)
    # ----------------------------------
    # Local names (sil_*) avoid overwriting the notebook's `kmeans`, `labels`
    # and `cluster_colors` globals used by earlier cells.
    fig, ax = plt.subplots(figsize=(7,5))
    sil_kmeans = KMeans(n_clusters=3, random_state=42).fit(X_umap)
    sil_labels = sil_kmeans.labels_
    sil_vals = silhouette_samples(X_umap, sil_labels)
    sil_avg = silhouette_score(X_umap, sil_labels)

    y_lower = 10
    sil_colors = ["#ff6c20", "#e13c3d", "#651f50"]  # 3 cluster colors
    for i in range(3):
        ith_vals = np.sort(sil_vals[sil_labels == i])
        size = len(ith_vals)
        y_upper = y_lower + size
        color = sil_colors[i % len(sil_colors)]
        ax.fill_betweenx(np.arange(y_lower, y_upper), 0, ith_vals,
                         facecolor=color, alpha=0.8)
        ax.text(-0.05, y_lower + 0.5 * size, str(i))
        y_lower = y_upper + 10  # 10-unit gap between cluster bands
    ax.axvline(x=sil_avg, color="#651f50", linestyle="--", lw=2)
    ax.set_title(f"Silhouette plot for k=3 (avg={sil_avg:.3f})")
    ax.set_xlabel("Silhouette coefficient")
    ax.set_ylabel("Cluster")
    pdf.savefig()
    plt.close()