In [1]:
from pathlib import Path

import pandas as pd
from scipy.stats import spearmanr

from sklearn.model_selection import KFold


from predictability.models import PottsRegressor
from predictability.constants import PROJECT_ROOT, DATA_ROOT

In [2]:
# Instantiate the Potts regressor, pointing it at the amylase MSA file.
# The constructor is given a plain string, so the joined path is converted with str().
amylase_msa_path = DATA_ROOT / "amylase/msa.a3m"
potts_model = PottsRegressor(msa_path=str(amylase_msa_path))

2023-06-13 15:18:39.506 | INFO     | unpredictability.models:__init__:24 - Downloading Potts model from: s3://sagemaker-us-east-1-118749263921/ecnet-25-09-05-01/output/model.tar.gz


In [4]:
# Notebook-level random seed, kept in one place so stochastic steps
# (e.g. shuffled cross-validation splits) can be made reproducible.
seed = 42

In [12]:
# Directory that receives this experiment's outputs. It is created together
# with any missing parent directories; an already-existing directory is fine.
results_dir = Path(PROJECT_ROOT / "results" / "amylase" / "feature_utilization" / "potts")
results_dir.mkdir(parents=True, exist_ok=True)

In [5]:
# Load the amylase combinatorials CSV into the `data` frame.
combinatorials_csv = DATA_ROOT / "amylase/combinatorials.csv"
data = pd.read_csv(combinatorials_csv)

In [7]:
# Name of the column in `data` that holds the regression target; it is passed to
# `potts_model.fit` and used to look up the true values when scoring.
# NOTE(review): this name shadows the built-in `property`. Later cells reference
# it, so any rename has to be applied across the whole notebook at once.
property = "stain_activity"

Removed 0 samples
Removed 0 samples


# Evaluation on held-out bins (extrapolation)

In [8]:
# Leave-one-bin-out evaluation: every distinct `bin_label` value is held out in
# turn as the validation set while the model is refit on all remaining rows.
# One entry is appended to each list per held-out bin.
experiment_results_extrapolate = {
    "n_val_samples": [],
    "spearman_train": [],
    "spearman_val": [],
    "eval": [],
}
# `df` is only an alias for `data` (no copy); it is read from, never written to, below.
df = data

for held_out_bin in df["bin_label"].unique():
    # Build the fold on a fresh frame (.assign returns a new DataFrame) so the
    # shared `data` frame is not modified as a side effect of running this cell.
    is_valid = df["bin_label"] == held_out_bin
    fold_df = df.assign(split=is_valid.map({True: "valid", False: "train"}))
    train_df = fold_df[~is_valid]
    valid_df = fold_df[is_valid]

    potts_model.fit(train_df, property)
    predictions_train = potts_model.predict(train_df)
    predictions_val = potts_model.predict(valid_df)

    # spearmanr returns (correlation, p-value); only the correlation is kept.
    spearman_train = spearmanr(train_df[property].values, predictions_train)[0]
    spearman_val = spearmanr(valid_df[property].values, predictions_val)[0]

    experiment_results_extrapolate["n_val_samples"].append(len(valid_df))
    experiment_results_extrapolate["spearman_train"].append(spearman_train)
    experiment_results_extrapolate["spearman_val"].append(spearman_val)
    experiment_results_extrapolate["eval"].append("extrapolate")

# Evaluation on randomly held-out data

In [10]:
# Random K-fold evaluation: rows are shuffled into 16 folds, each fold is held
# out once for validation, and the model is refit on the remaining folds.
# One entry is appended to each list per fold.
experiment_results_random = {
    "n_val_samples": [],
    "spearman_train": [],
    "spearman_val": [],
    "eval": [],
}
# `df` is only an alias for `data` (no copy); it is read from, never written to, below.
df = data

# Reuse the notebook-level `seed` instead of a second hard-coded literal so the
# shuffling is controlled from a single place.
kfold = KFold(n_splits=16, shuffle=True, random_state=seed)
print(kfold)

for train_index, val_index in kfold.split(df):
    # KFold yields positional row numbers, so rows are selected with .iloc;
    # label-based .loc would only be equivalent for a default 0..n-1 index.
    # .assign returns new frames, leaving the shared `data` frame unmodified.
    train_df = df.iloc[train_index].assign(split="train")
    valid_df = df.iloc[val_index].assign(split="valid")

    potts_model.fit(train_df, property)
    predictions_train = potts_model.predict(train_df)
    predictions_val = potts_model.predict(valid_df)

    # spearmanr returns (correlation, p-value); only the correlation is kept.
    spearman_train = spearmanr(train_df[property].values, predictions_train)[0]
    spearman_val = spearmanr(valid_df[property].values, predictions_val)[0]
    print(f"Spearman val: {spearman_val}")

    experiment_results_random["n_val_samples"].append(len(valid_df))
    experiment_results_random["spearman_train"].append(spearman_train)
    experiment_results_random["spearman_val"].append(spearman_val)
    experiment_results_random["eval"].append("random")

KFold(n_splits=16, random_state=42, shuffle=True)
Spearman val: 0.6461914879920567
Spearman val: 0.6606661453218562
Spearman val: 0.641316289893166
Spearman val: 0.6382357852208708
Spearman val: 0.7542822985993629
Spearman val: 0.6398999401057925
Spearman val: 0.7134431185778108
Spearman val: 0.684934455697177
Spearman val: 0.6900011844420536
Spearman val: 0.731132528782788
Spearman val: 0.6644320669606
Spearman val: 0.6970879967912958
Spearman val: 0.7799279239033368
Spearman val: 0.6367418126457061
Spearman val: 0.7190344032314469
Spearman val: 0.702015448634204


In [11]:
# Stack the held-out-bin ("extrapolate") and K-fold ("random") results into one table.
# ignore_index=True gives every row a unique 0..n-1 label; without it each half
# keeps its own 0-based index, so labels repeat in the frame and in the saved CSV.
results_df = pd.concat(
    [
        pd.DataFrame(experiment_results_extrapolate),
        pd.DataFrame(experiment_results_random),
    ],
    ignore_index=True,
)

In [15]:
# Persist the combined scores next to the other outputs of this experiment.
scores_csv = results_dir / "scores.csv"
results_df.to_csv(scores_csv)

In [14]:
# Show the combined results table as the cell output.
results_df

Unnamed: 0,n_val_samples,spearman_train,spearman_val,eval
0,278,0.748775,0.092019,extrapolate
1,233,0.746026,0.148118,extrapolate
2,196,0.737634,0.39648,extrapolate
3,247,0.72873,0.4967,extrapolate
4,209,0.741309,0.121375,extrapolate
5,255,0.72169,-0.053535,extrapolate
6,250,0.730354,0.196031,extrapolate
7,148,0.73428,0.338443,extrapolate
8,153,0.733104,0.181231,extrapolate
9,290,0.705252,0.082375,extrapolate
