<a href="https://colab.research.google.com/github/mattybae/spatial-transcriptomics-project/blob/main/In_Silico_Staining_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install scanpy squidpy

In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import squidpy as sq
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold, cross_val_predict
# Reaching this line means every import above succeeded.
print("Work bench ready!")


In [None]:
# Fetch the bundled example: the expression table first, then its matching image crop.
adata = sq.datasets.visium_fluo_adata_crop()
img = sq.datasets.visium_fluo_image_crop()

# Scale every spot to the same total (10,000) and then log-transform, in place.
# NOTE(review): squidpy describes this dataset as "pre-processed" - confirm that
# adata.X still holds raw counts, otherwise these two steps normalise it twice.
sc.pp.normalize_total(
    adata,
    target_sum=1e4,
)
sc.pp.log1p(adata)

print("Data loaded and cleaned")

In [None]:
# Compute the "summary" image features for each spot. The result is written
# into adata.obsm under the key "image_features"; a single worker is used.
sq.im.calculate_image_features(adata, img, features="summary", key_added="image_features", n_jobs=1)

# Feature matrix for the regression: one row per spot, one column per feature.
image_feature_table = adata.obsm["image_features"]
X = pd.DataFrame(image_feature_table)

# Preview the first rows as a sanity check.
print(X.head())

In [None]:
# Gene whose (normalised) expression the image features should predict.
target_gene = 'Rpl37a'

# Fail early with a readable message instead of an opaque indexing error.
if target_gene not in adata.var_names:
    raise KeyError(f"Gene {target_gene!r} not found in adata.var_names")

# obs_vector returns a 1-D array whether adata.X is sparse or dense, whereas
# `.X.toarray()` raises AttributeError when X is already a dense ndarray.
y = np.asarray(adata.obs_vector(target_gene)).ravel()

print(f"Ready to predict {target_gene}. First 5 values: {y[:5]}")

In [None]:
# Final model, fitted on every spot. Later cells read its coefficients
# (model.coef_), so it stays a plain Ridge trained on the full data.
model = Ridge(alpha=1.0)
model.fit(X, y)

# Predicting the same rows the model was trained on gives an optimistic picture.
# Instead, produce out-of-fold predictions: every spot is predicted by a Ridge
# model that was fitted without that spot. The result has one value per spot,
# in the original row order, so it can be stored on adata.obs and compared to y.
# The fixed random_state keeps the fold assignment reproducible across runs.
cv = KFold(n_splits=5, shuffle=True, random_state=0)
y_pred = cross_val_predict(Ridge(alpha=1.0), X, y, cv=cv)

adata.obs["predicted_gene"] = y_pred

print(adata.obs["predicted_gene"].head())

print("AI Training Complete.")

In [None]:
# Side-by-side spatial maps: measured expression vs. the model's prediction.
sq.pl.spatial_scatter(
    adata,
    color=[target_gene, "predicted_gene"],  # Compare Real vs Predicted
    # Build the title from target_gene so it cannot drift from what is plotted.
    title=[f"TRUE Biology ({target_gene})", "AI PREDICTION (In Silico)"],
    cmap="viridis",  # Color map (Purple to Yellow)
)

In [None]:
# Pearson correlation between measured and predicted expression.
# np.corrcoef returns a 2x2 matrix; the off-diagonal entry is the y / y_pred correlation.
correlation_matrix = np.corrcoef(y, y_pred)
score = correlation_matrix[0, 1]
print(f"Model Accuracy (Pearson R): {score:.4f}")

In [None]:
import matplotlib.pyplot as plt

# 1. Extract the weights (coefficients) from the trained model
weights = model.coef_

# 2. Get the feature names (the column headers from X)
feature_names = X.columns

# 3. Put the coefficients on a common scale before comparing them.
#    A raw coefficient is "change in prediction per unit of the feature", so its
#    size depends on the feature's units. Multiplying by the feature's standard
#    deviation gives "change in prediction per one standard deviation", which is
#    comparable across features. A constant feature (std 0) correctly scores 0.
feature_spread = X.std(axis=0).to_numpy()
importance = pd.DataFrame({
    "Feature": feature_names,
    "Weight": weights * feature_spread,
})

# 4. Sort by absolute impact (ignoring positive/negative sign for sorting)
importance["Abs_Weight"] = importance["Weight"].abs()
top_features = importance.sort_values("Abs_Weight", ascending=False).head(10)

# 5. Visualize the Drivers
# Color code: Red = Positive Correlation, Blue = Negative Correlation
colors = ["red" if x > 0 else "blue" for x in top_features["Weight"]]

fig, ax = plt.subplots(figsize=(10, 6))
ax.barh(top_features["Feature"], top_features["Weight"], color=colors)
# barh draws the first row at the bottom; flip the axis so the most
# influential feature is the top bar.
ax.invert_yaxis()
ax.set_title(f"What features drive '{target_gene}'?")
ax.set_xlabel("Influence on Prediction")
plt.show()