In [None]:
import json
import os
from datetime import datetime

import pandas as pd
import seaborn as sns
import yaml

In [None]:
# Number of DMS tasks analysed in this notebook (tasks are addressed as 0 .. TASKNUM-1).
TASKNUM = 3

In [None]:
# Task indices that the job-loading and metric cells accept.
dms_indices = [*range(TASKNUM)]

In [None]:
# Task table; row i supplies the DMS_id used to locate the output files of task i.
train_df = pd.read_csv("../bindinggym_offline/input/BindingGYM_AL.csv")
idx2name = {
    task_idx: train_df.loc[task_idx, 'DMS_id']
    for task_idx in range(TASKNUM)
}

In [None]:
# Model-zoo output tables (ESM2 and ProteinMPNN), one DataFrame per task, in task-index order.
esm_dfs = [
    pd.read_csv(f"../bindinggym_offline/modelzoo/esm2/output/{idx2name[i]}.csv")
    for i in range(TASKNUM)
]
proteinmpnn_dfs = [
    pd.read_csv(f"../bindinggym_offline/modelzoo/proteinmpnn//output/{idx2name[i]}.csv")
    for i in range(TASKNUM)
]


In [None]:
# Cycle-0 prediction tables of the AbLang2 runs (greedy_0.0 outputs), one DataFrame per task.
ablang_dfs = [
    pd.read_csv(f"outputs/greedy_0.0/dms_ablang2_{i}_N-50_ini-1/predictions_cycle_0.csv")
    for i in range(TASKNUM)
]

In [None]:
import numpy as np
from typing import Union, Tuple

def calculate_mean_similarity(latent_matrix: np.ndarray):
    """Return the mean pairwise cosine similarity between the rows of a matrix.

    Parameters
    ----------
    latent_matrix : np.ndarray
        2-D array of shape (N, H); every row is treated as one vector.

    Returns
    -------
    The cosine similarity averaged over all N * (N - 1) ordered pairs of
    distinct rows (self-similarity is excluded).

    Raises
    ------
    TypeError
        If ``latent_matrix`` is not a ``numpy.ndarray``.
    ValueError
        If ``latent_matrix`` is not 2-dimensional or has fewer than two rows.
    """
    # Input validation.
    if not isinstance(latent_matrix, np.ndarray):
        raise TypeError("latent_matrix must be numpy.ndarray")
    if latent_matrix.ndim != 2:
        raise ValueError("latent_matrix must be 2-dimensional")

    n_samples = latent_matrix.shape[0]
    if n_samples < 2:
        raise ValueError("Number of samples must be greater than 1")

    # L2 norm of each row; zero norms are swapped for a tiny constant so the
    # division below never divides by zero.
    row_norms = np.linalg.norm(latent_matrix, axis=1, keepdims=True)
    safe_norms = np.where(row_norms == 0, 1e-8, row_norms)
    unit_rows = latent_matrix / safe_norms

    # Cosine similarity of every row pair; the diagonal (a row with itself)
    # is zeroed so it does not contribute to the average.
    pairwise = np.dot(unit_rows, unit_rows.T)
    np.fill_diagonal(pairwise, 0)

    n_ordered_pairs = n_samples * (n_samples - 1)
    return pairwise.sum() / n_ordered_pairs


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import spearmanr
from sklearn.metrics import root_mean_squared_error, ndcg_score

from sklearn.preprocessing import scale, minmax_scale

def calc_test(true_scores, pred_scores, k=10):
    """Compare predictions with ground truth using three metrics.

    Parameters
    ----------
    true_scores, pred_scores
        Equal-length 1-D sequences of measured and predicted scores.
    k : int
        Cut-off passed to ``ndcg_score`` (NDCG@k).

    Returns
    -------
    dict with the keys ``'spearman'`` (Spearman correlation), ``'rmse'``
    (root mean squared error) and ``'ndcg'`` (NDCG@k).
    """
    # Spearman correlation; the second element of the result (p-value) is discarded.
    spearman_rho = spearmanr(true_scores, pred_scores)[0]

    # RMSE
    rmse_value = root_mean_squared_error(true_scores, pred_scores)

    # NDCG@k: the true scores are min-max scaled into [0, 5] to serve as
    # relevance values; both inputs are wrapped in a list so they form a
    # single-row 2-D input.
    relevance = minmax_scale([true_scores], (0, 5), axis=1)
    ndcg_at_k = ndcg_score(relevance, [pred_scores], k=k)

    return {
        'spearman': spearman_rho,
        'rmse': rmse_value,
        'ndcg': ndcg_at_k,
    }

def compute_test_performance(df, cycle_list, k=10):
    """Evaluate each cycle's predictions on the rows flagged as test data.

    Parameters
    ----------
    df : pd.DataFrame
        Table with ``is_test``, ``DMS_score`` and ``cycle_<c>_preds`` columns.
    cycle_list : list
        Cycle numbers to evaluate, e.g. [1, 2, 3, ...].
    k : int
        Cut-off forwarded to ``calc_test`` for NDCG@k.

    Returns
    -------
    pd.DataFrame with one row per evaluated cycle: the metrics returned by
    ``calc_test`` plus a ``cycle`` column. Cycles whose prediction column is
    absent are skipped.
    """
    held_out = df.loc[df['is_test'] == True].copy()

    rows = []
    for cycle in cycle_list:
        column = f'cycle_{cycle}_preds'
        if column in held_out.columns:
            metrics = calc_test(held_out['DMS_score'].values, held_out[column].values, k=k)
            rows.append({**metrics, "cycle": cycle})

    return pd.DataFrame(rows)

def calc_recall_precision(d, df, p=None, k=None, topk=10):
    """Per-cycle recall/precision of the selected pool rows against the top-scoring pool rows.

    Parameters
    ----------
    d : mapping
        Job config; ``d["active_learning"]["M_cycles"]`` sets the number of cycles.
    df : pd.DataFrame
        Table with ``is_test`` (bool), ``DMS_score`` and ``selected_cycle`` columns.
    p : float, optional
        Rows scoring strictly above the (1 - p) quantile of the pool are labelled "top".
    k : int, optional
        Used when ``p`` is None: rows scoring strictly above the k-th largest
        pool score are labelled "top".
    topk : int
        Base size for the ``top_mean_1/2/3`` columns (mean of the best
        ``topk``, ``2*topk`` and ``3*topk`` selected scores).

    Returns
    -------
    pd.DataFrame with one row per cycle (numbered from 1), cumulative over
    all rows with ``0 <= selected_cycle <= cycle``.

    Raises
    ------
    ValueError
        If neither ``p`` nor ``k`` is given.
    """
    pool = df[~df["is_test"]].copy()
    scores = pool["DMS_score"]

    if p is not None:
        threshold = scores.quantile(1-p)
    elif k is not None:
        threshold = scores.sort_values(ascending=False).iloc[k-1]
    else:
        raise ValueError("Either p or k must be provided")
    pool["top"] = scores > threshold
    # Fixed second label: strictly above the 90th percentile of the pool.
    pool["top_10"] = scores > scores.quantile(0.9)

    n_top = pool["top"].sum()
    n_top_10 = pool["top_10"].sum()

    acquired = pool[pool["selected_cycle"]>=0].copy()
    rows = []
    for cycle in range(d["active_learning"]["M_cycles"]+1):
        so_far = acquired[acquired["selected_cycle"]<=cycle]
        hits = so_far["top"].sum()
        hits_10 = so_far["top_10"].sum()
        n_selected = len(so_far)
        # One descending sort serves all three top-mean columns.
        best_first = so_far["DMS_score"].sort_values(ascending=False)
        rows.append({
            "cycle": cycle+1,
            "recall": hits/n_top,
            "precision": hits/n_selected,
            "recall_10": hits_10/n_top_10,
            "precision_10": hits_10/n_selected,
            "top_mean_1": best_first.head(topk).mean(),
            "top_mean_2": best_first.head(topk*2).mean(),
            "top_mean_3": best_first.head(topk*3).mean(),
        })
    return pd.DataFrame(rows)


def calc_recall_precision_greedy(d, df, p=None, k=None, N=600, topk=10):
    """Per-cycle recall/precision when the remaining budget is filled greedily.

    For every cycle the rows selected so far are topped up to ``N`` rows with
    the not-yet-selected pool rows that have the highest
    ``cycle_<cycle+1>_preds`` value; all metrics are computed on that
    combined set.

    Parameters
    ----------
    d : mapping
        Job config; ``d["active_learning"]["M_cycles"]`` sets the number of cycles.
    df : pd.DataFrame
        Table with ``is_test`` (bool), ``DMS_score``, ``selected_cycle`` and
        ``cycle_<c>_preds`` columns.
    p : float, optional
        Rows scoring strictly above the (1 - p) quantile of the pool are labelled "top".
    k : int, optional
        Used when ``p`` is None: rows scoring strictly above the k-th largest
        pool score are labelled "top".
    N : int
        Total budget (selected rows + greedy fill).
    topk : int
        Base size for the ``top_mean_1/2/3`` columns.

    Returns
    -------
    pd.DataFrame with one row per cycle (numbered from 1).

    Raises
    ------
    ValueError
        If neither ``p`` nor ``k`` is given.
    """
    df_pool = df[~df["is_test"]].copy()
    if p is not None:
        df_pool["top"] = df_pool["DMS_score"]>df_pool["DMS_score"].quantile(1-p)
    elif k is not None:
        df_pool["top"] = df_pool["DMS_score"]>df_pool["DMS_score"].sort_values(ascending=False).iloc[k-1]
    else:
        raise ValueError("Either p or k must be provided")
    # Fixed second label: strictly above the 90th percentile of the pool.
    df_pool["top_10"] = df_pool["DMS_score"]>df_pool["DMS_score"].quantile(0.9)

    n_top = df_pool["top"].sum()
    n_top_10 = df_pool["top_10"].sum()

    results = []
    df_selected_all = df_pool[df_pool["selected_cycle"]>=0].copy()
    for cycle in range(d["active_learning"]["M_cycles"]):
        df_cycle = df_selected_all[df_selected_all["selected_cycle"]<=cycle].copy()
        df_not_selected = df_pool[(df_pool["selected_cycle"]<0) | (df_pool["selected_cycle"]>cycle)].copy()

        # Remaining budget. Clamped at 0 because DataFrame.head() with a
        # negative count returns "all but the last |M| rows", not nothing.
        M = max(N - len(df_cycle), 0)
        df_fill = df_not_selected.sort_values(f"cycle_{cycle+1}_preds", ascending=False).head(M)
        df_cycle_greedy = pd.concat([df_cycle, df_fill])

        # Every metric, including the top-10% ones, is measured on the
        # combined (selected + greedy fill) set.
        n_rows = len(df_cycle_greedy)
        hits = df_cycle_greedy["top"].sum()
        hits_10 = df_cycle_greedy["top_10"].sum()
        best_first = df_cycle_greedy.sort_values("DMS_score", ascending=False)["DMS_score"]

        results.append({
            "cycle": cycle+1,
            "recall": hits/n_top,
            "precision": hits/n_rows,
            "recall_10": hits_10/n_top_10,
            "precision_10": hits_10/n_rows,
            "top_mean_1": best_first.head(topk).mean(),
            "top_mean_2": best_first.head(topk*2).mean(),
            "top_mean_3": best_first.head(topk*3).mean(),
        })
    return pd.DataFrame(results)


In [None]:
# Load, for every job listed in jobs.csv that belongs to a selected task,
# its parsed config and the prediction table of its last cycle.
al_job_df = pd.read_csv("../bindinggym_offline/jobs.csv")
dfs = {}        # job directory -> predictions_cycle_<M>.csv as a DataFrame
config_ds = []  # configs of the jobs whose prediction table was loaded
for config_path in al_job_df["config"]:
    wdir = os.path.dirname(config_path)
    with open(config_path) as f:
        config_d = yaml.safe_load(f)
    try:
        if int(config_d["dms_index"]) not in dms_indices:
            continue
        M = config_d["active_learning"]["M_cycles"]
        config_d["ablang"] = "bo4_ablang" in config_path
        dfs[wdir] = pd.read_csv(os.path.join(wdir, f"predictions_cycle_{M}.csv"))
    except Exception as e:
        # Report the job and keep going; KeyboardInterrupt/SystemExit are no
        # longer swallowed. The printed number is the directory's file count
        # minus 3 -- presumably a progress indicator for unfinished jobs
        # (TODO confirm what the three excluded files are).
        print(wdir, len(os.listdir(wdir))-3, repr(e))
        continue
    # Register the config only after its predictions loaded, so every entry
    # of config_ds has a matching entry in dfs.
    config_ds.append(config_d)

In [None]:
# One row per loaded job config.
config_df = pd.DataFrame(config_ds)
# Expand the nested "active_learning" dicts into their own columns
# (the original dict column is kept as well).
alconf = pd.DataFrame(list(config_df["active_learning"]))
config_df[alconf.columns] = alconf


In [None]:
# For each task, the tmp_path of the first job configured for it; used below
# as a key into dfs (assumes tmp_path equals the job directory -- TODO confirm).
tmps = [
    config_df.loc[config_df["dms_index"] == task_idx, "tmp_path"].values[0]
    for task_idx in range(TASKNUM)
]

In [None]:
# Test-set metrics of the three baseline scores (ESM2, ProteinMPNN, AbLang2 cycle 0).
esm_results = []
proteinmpnn_results = []
ablang2_results = []

# (per-task score tables, score column, list collecting the metric dicts)
test_baselines = (
    (esm_dfs, "esm2_t33_650M_UR50D", esm_results),
    (proteinmpnn_dfs, "design_score", proteinmpnn_results),
    (ablang_dfs, "cycle_0_preds", ablang2_results),
)

for i in range(TASKNUM):
    tmp_path = tmps[i]
    # Row labels of the test split, taken from the job's prediction table.
    test_indices = dfs[tmp_path][dfs[tmp_path]["is_test"]].index

    for model_dfs, score_col, collected in test_baselines:
        # Label-based selection already returns the rows in test_indices
        # order, so no sorting is needed beforehand. Each baseline is scored
        # against the DMS_score column of its own table.
        test_rows = model_dfs[i].loc[test_indices]
        result = calc_test(test_rows["DMS_score"], test_rows[score_col])
        result["dms_index"] = i
        collected.append(result)

esm_results = pd.DataFrame(esm_results)
esm_results_mean = esm_results.mean(axis=0)
proteinmpnn_results = pd.DataFrame(proteinmpnn_results)
proteinmpnn_results_mean = proteinmpnn_results.mean(axis=0)

ablang2_results = pd.DataFrame(ablang2_results)
ablang2_results_mean = ablang2_results.mean(axis=0)

In [None]:
def _baseline_pool_curve(pool_df, score_col, n_relevant, dms_index, max_n=600, top_n=10):
    """Recall / precision / top-mean of the ``n`` best-scored pool rows for n = 1 .. max_n.

    ``pool_df`` needs a boolean ``top`` column and a ``DMS_score`` column;
    rows are ranked by ``score_col`` in descending order. ``n_relevant`` is
    the recall denominator. Returns a list of record dicts (one per ``n``,
    stored under the key ``"cycle"``).
    """
    ranked = pool_df.sort_values(score_col, ascending=False)
    records = []
    for n in range(1, max_n + 1):
        picked = ranked.head(n)
        hits = picked["top"].sum()
        records.append({
            "dms_index": dms_index,
            "cycle": n,
            "recall": hits / n_relevant,
            "precision": hits / len(picked),
            "top_mean": picked.sort_values("DMS_score", ascending=False).head(top_n)["DMS_score"].mean(),
        })
    return records


# Fraction of the pool labelled "top" (strictly above the 1 - p quantile).
p = 0.01

# name -> (per-task score tables, column to rank by)
pool_baselines = {
    "esm": (esm_dfs, "esm2_t33_650M_UR50D"),
    "proteinmpnn": (proteinmpnn_dfs, "design_score"),
    "ablang2": (ablang_dfs, "cycle_0_preds"),
}
pool_records = {name: [] for name in pool_baselines}

for i in range(TASKNUM):
    df = dfs[tmps[i]]
    df_pool = df[~df["is_test"]].copy()
    df_pool["top"] = df_pool["DMS_score"]>df_pool["DMS_score"].quantile(1-p)
    pool_indices = df_pool.index
    # Recall denominator: number of "top" rows in the job's own pool table.
    n_relevant = df_pool["top"].sum()

    for name, (model_dfs, score_col) in pool_baselines.items():
        model_pool = model_dfs[i].loc[pool_indices].copy()
        model_pool["top"] = model_pool["DMS_score"]>model_pool["DMS_score"].quantile(1-p)
        pool_records[name].extend(_baseline_pool_curve(model_pool, score_col, n_relevant, i))

pool_esm_results = pd.DataFrame(pool_records["esm"])
pool_proteinmpnn_results = pd.DataFrame(pool_records["proteinmpnn"])
pool_ablang2_results = pd.DataFrame(pool_records["ablang2"])

# Average over tasks for every selection size.
pool_esm_results_mean = pool_esm_results.groupby("cycle").mean()
pool_proteinmpnn_results_mean = pool_proteinmpnn_results.groupby("cycle").mean()
pool_ablang2_results_mean = pool_ablang2_results.groupby("cycle").mean()

In [None]:
def _tag_with_job_settings(metrics_df, d):
    """Add the settings of job ``d`` (a config_df row) as constant columns to ``metrics_df``.

    The frame is modified in place and returned.
    """
    al = d["active_learning"]
    metrics_df["N_init"] = al["N_init"]
    metrics_df["N_per_cycle"] = al["N_per_cycle"]
    metrics_df["strategy"] = al["strategy"]
    metrics_df["dms_index"] = d["dms_index"]
    metrics_df["model_type"] = d["model_type"]
    metrics_df["tmp_path"] = d["tmp_path"]
    metrics_df["noise_level"] = d["noise_level"]
    metrics_df["ablang"] = d["ablang"]
    metrics_df["use_dropout"] = d["use_dropout"]
    metrics_df["logit_mode"] = "bo4_logits" in d["tmp_path"]
    return metrics_df


# --- Test-set metrics per job and cycle ---
test_metrics_dfs = []
for i, d in config_df.iterrows():
    try:
        df = dfs[d["tmp_path"]]
        if d["dms_index"] not in dms_indices:
            continue
        cycle_list = list(range(d["active_learning"]["M_cycles"]+1))
        test_metrics_df = compute_test_performance(df, cycle_list)
        test_metrics_dfs.append(_tag_with_job_settings(test_metrics_df, d))
    except Exception as e:
        # A job that cannot be evaluated is reported and skipped.
        print(i, e)
        continue

test_metrics_df_merge = pd.concat(test_metrics_dfs)
test_metrics_df_merge["Training size"] = (
    (test_metrics_df_merge["cycle"] - 1) * test_metrics_df_merge["N_per_cycle"]
    + test_metrics_df_merge["N_init"]
)

# --- Pool recall/precision per job and cycle (plain and greedy-filled) ---
pool_metrics_dfs = []
for i, d in config_df.iterrows():
    try:
        df = dfs[d["tmp_path"]]
        if d["dms_index"] not in dms_indices:
            continue
        pool_metrics_df = calc_recall_precision(d, df, p=0.01)
        pool_metrics_df_greedy = calc_recall_precision_greedy(d, df, p=0.01)
        # Greedy columns get the "_g" suffix; cycles without a greedy row stay NaN.
        pool_metrics_dfx = pd.merge(pool_metrics_df, pool_metrics_df_greedy, on=["cycle"], how="left", suffixes=["", "_g"])
        pool_metrics_dfs.append(_tag_with_job_settings(pool_metrics_dfx, d))
    except Exception as e:
        # Same best-effort policy as above, but the cause is now shown and
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        print(d["tmp_path"], repr(e))

pool_metrics_df_merge = pd.concat(pool_metrics_dfs)
pool_metrics_df_merge["Training size"] = (
    (pool_metrics_df_merge["cycle"] - 1) * pool_metrics_df_merge["N_per_cycle"]
    + pool_metrics_df_merge["N_init"]
)



In [None]:
# Collect the baseline scores per metric and task and store them as JSON.
# Resulting layout: {metric: [ {baseline name: value or list}, ... one dict per task ]}.
# NOTE: json.dump below needs `import json` in the notebook's import cell.
test_baseline_tables = {
    "sequence": esm_results,
    "proteinmpnn": proteinmpnn_results,
    "ablang2": ablang2_results,
}
pool_baseline_tables = {
    "sequence": pool_esm_results,
    "proteinmpnn": pool_proteinmpnn_results,
    "ablang2": pool_ablang2_results,
}

baseline_scores_metric_list = {}
for metric in ["ndcg", "spearman", "top_mean", "recall", "precision"]:
    baseline_scores_list = []
    for dms_index in range(TASKNUM):
        if metric in ["ndcg", "spearman"]:
            # Test-set metrics: a single value per task.
            baseline_scores = {
                name: table[table["dms_index"]==dms_index][metric].values[0]
                for name, table in test_baseline_tables.items()
            }
        elif metric in ["top_mean", "recall", "precision"]:
            # Pool metrics: one value per selection size, in row order.
            baseline_scores = {
                name: table[table["dms_index"]==dms_index][metric].tolist()
                for name, table in pool_baseline_tables.items()
            }
        else:
            raise ValueError(f"Invalid metric: {metric}")
        baseline_scores_list.append(baseline_scores)
    baseline_scores_metric_list[metric] = baseline_scores_list
# Alias so the baseline can be looked up under the same key as the "top_mean_1" column.
baseline_scores_metric_list["top_mean_1"] = baseline_scores_metric_list["top_mean"]

baseline_json_path = "../results/bindinggym_offline/baseline_scores_metric_list.json"
# Create the results directory on a fresh checkout instead of failing in open().
os.makedirs(os.path.dirname(baseline_json_path), exist_ok=True)
with open(baseline_json_path, 'w') as f:
    json.dump(baseline_scores_metric_list, f, indent=2)

In [None]:
# Persist the merged pool and test metric tables (row index is not written).
pool_metrics_df_merge.to_csv(
    "../results/bindinggym_offline/pool_metrics_df_merge.csv",
    index=False,
)
test_metrics_df_merge.to_csv(
    "../results/bindinggym_offline/test_metrics_df_merge.csv",
    index=False,
)

# 