In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats

In [None]:
# Base data directory; later cells join experiment-relative paths onto it.
data_dir = "data"


In [None]:
from sklearn.preprocessing import scale, minmax_scale
from sklearn.metrics import root_mean_squared_error, ndcg_score
def calc_test(true_scores, pred_scores, k=10, include_rmse=False):
    """Score predictions against ground-truth values.

    Args:
        true_scores: Ground-truth scores (1-D array-like).
        pred_scores: Predicted scores, aligned element-wise with ``true_scores``.
        k: Cut-off for NDCG@k.
        include_rmse: If True, also report RMSE under the ``'rmse'`` key.
            Off by default so the returned keys match the previous behaviour.

    Returns:
        dict with ``'spearman'`` (Spearman rho; scipy returns NaN when an
        input is constant) and ``'ndcg'`` (NDCG@k), plus ``'rmse'`` when
        ``include_rmse`` is True.
    """
    rho, _ = stats.spearmanr(true_scores, pred_scores)

    # Rescale the true scores of this single "query" to [0, 5] so they are
    # valid (non-negative) relevance grades for ndcg_score; predictions only
    # drive the ranking and can stay on their original scale.
    relevance = minmax_scale([true_scores], (0, 5), axis=1)
    ndcg_val = ndcg_score(relevance, [pred_scores], k=k)

    result = {
        'spearman': rho,
        'ndcg': ndcg_val,
    }
    if include_rmse:
        # Only computed on request; it used to be evaluated and then discarded.
        result['rmse'] = root_mean_squared_error(true_scores, pred_scores)
    return result

In [None]:
import pandas as pd
from pygmo import hypervolume

def greedy_hypervolume_subset(points, n, ref_point):
    """Greedily select up to ``n`` points maximising the dominated hypervolume.

    At each step, the remaining candidate whose addition gives the largest
    hypervolume of the selected set (w.r.t. ``ref_point``) is added.
    Hypervolume is computed with pygmo, which treats objectives as minimised
    and expects every point to be <= ``ref_point`` in each objective.

    Args:
        points: Array-like of shape (n_points, n_objectives).
        n: Maximum number of points to select. If it exceeds the number of
            points, every point is selected.
        ref_point: Reference point passed to ``hypervolume.compute``.

    Returns:
        (selected, hv): indices into ``points`` in selection order, and the
        hypervolume of the selected set (0.0 when nothing is selected).
    """
    points = np.asarray(points)
    selected = []
    remaining = list(range(len(points)))
    selected_hv = 0.0

    # Never iterate past the number of available points. Previously, extra
    # iterations reset the reported hypervolume to -inf, and n == 0 left it
    # unbound (UnboundLocalError).
    for _ in range(min(n, len(points))):
        best_hv = -float('inf')
        best_idx = None

        for idx in remaining:
            # Hypervolume of the current selection plus this candidate.
            candidate_hv = hypervolume(points[selected + [idx]]).compute(ref_point)
            if candidate_hv > best_hv:
                best_hv = candidate_hv
                best_idx = idx

        if best_idx is None:
            # No candidate produced a comparable value (e.g. all NaN).
            break
        selected.append(best_idx)
        remaining.remove(best_idx)
        selected_hv = best_hv

    return selected, selected_hv

def normalize_score(score):
    """Robustly rescale a pandas Series using its 5th and 95th percentiles.

    The 5th percentile maps to 0 and the 95th to 1. Values outside that band
    fall below 0 or above 1. A tiny epsilon in the denominator avoids
    division by zero when both percentiles coincide.
    """
    lo = score.quantile(0.05)
    spread = score.quantile(0.95) - lo + 1e-10
    return (score - lo) / spread


In [None]:
import numpy as np
from typing import Union, Tuple

def calculate_mean_similarity(latent_matrix: np.ndarray):
    """Mean pairwise cosine similarity between the rows of a 2-D array.

    Self-similarities (the diagonal) are excluded, so the result averages
    over the N*(N-1) ordered pairs of distinct rows. All-zero rows are given
    a tiny norm instead of zero, which makes their similarity to every other
    row equal to 0.

    Raises:
        TypeError: if ``latent_matrix`` is not a numpy array.
        ValueError: if it is not 2-D or has fewer than two rows.
    """
    # Input validation.
    if not isinstance(latent_matrix, np.ndarray):
        raise TypeError("latent_matrix must be numpy.ndarray")
    if latent_matrix.ndim != 2:
        raise ValueError("latent_matrix must be 2-dimensional")
    n_rows = latent_matrix.shape[0]
    if n_rows < 2:
        raise ValueError("Number of samples must be greater than 1")

    # Unit-normalise each row, guarding zero-length rows against division by zero.
    row_norms = np.linalg.norm(latent_matrix, axis=1, keepdims=True)
    row_norms = np.where(row_norms == 0, 1e-8, row_norms)
    unit_rows = latent_matrix / row_norms

    # Pairwise cosine similarities with self-pairs zeroed out.
    cosine = unit_rows @ unit_rows.T
    np.fill_diagonal(cosine, 0)

    return cosine.sum() / (n_rows * (n_rows - 1))


In [None]:
import logomaker
def draw_logo(seqs, ax=None):
    """Draw a sequence logo for one or more aligned sequences.

    Args:
        seqs: A single sequence string, or a list of aligned sequences, passed
            to ``logomaker.alignment_to_matrix``.
        ax: Optional matplotlib Axes to draw on (forwarded to ``logomaker.Logo``).
    """
    seq_list = [seqs] if isinstance(seqs, str) else seqs
    counts = logomaker.alignment_to_matrix(seq_list)

    logo = logomaker.Logo(
        counts,
        shade_below=.5,
        fade_below=.5,
        color_scheme='NajafabadiEtAl2017',
        ax=ax,
    )
    # Hide the frame and the y ticks for a clean logo.
    for side in ('right', 'top', 'bottom', 'left'):
        logo.ax.spines[side].set_visible(False)
    logo.ax.set_yticks([])


In [None]:
# H3 sequence per target (presumably the CDR-H3 of each antibody; keys match `targets`).
h3_dict = dict([
    ("4D5_HER2_fitness_1N8Z", "SRWGGDGFYAMDY"),
    ("5A12_Ang2_fitness_4ZFG", "ARFVFFLPYAMDY"),
    # Same H3 as the Ang2 entry (both keys are antibody 5A12).
    ("5A12_VEGF_fitness_4ZFF", "ARFVFFLPYAMDY"),
])

In [None]:
def get_ddg(path):
    """Minimum ddG ``total_score`` per case from a flex ddG results CSV.

    Keeps only rows whose ``scored_state`` is ``"ddG"``, takes the minimum
    ``total_score`` within each ``case_name``, and returns a Series indexed
    by ``case_name`` in sorted order.

    Args:
        path: Anything ``pd.read_csv`` accepts (path or file-like object).
    """
    results = pd.read_csv(path)
    ddg_rows = results.loc[results["scored_state"] == "ddG"]
    per_case = ddg_rows.groupby("case_name")["total_score"].min()
    return per_case.sort_index()

In [None]:
# Experiment labels.
exps = [
    "greedy",
    "greedy_multi",
    "greedy_0.5",
    "greedy_multi_A",
    "greedy_multi_B",
]

In [None]:
font_size = 15  # presumably the base font size for figures
# Targets whose flex ddG results are loaded in the next cell.
targets = [
    "4D5_HER2_fitness_1N8Z",
    "5A12_Ang2_fitness_4ZFG",
    "5A12_VEGF_fitness_4ZFF",
]

In [None]:
# Flex ddG results and sampled mutations for the "bias"/"unbias" sets of each target.
flex_ddg_dfs = {}
sampled_seq_dfs = {}
flex_ddg_df_alls = {}  # NOTE(review): not filled in this cell
for target in targets:
    for mode in ["bias", "unbias"]:
        key = f"{target}_{mode}"
        base = f"flexddgs/{target}/{mode}"
        # Reuse get_ddg instead of duplicating its aggregation inline:
        # minimum ddG total_score per case_name, sorted by case_name.
        flex_ddg_df = get_ddg(f"{base}/outputs-results.csv")
        flex_ddg_dfs[key] = flex_ddg_df
        sampled_seq_dfs[key] = pd.read_csv(f"{base}/sampled_mutations.csv", index_col=0)

# Evaluation sets: the "bias" sampled mutations labelled with DMS_score = -ddG.
test_dfs = {target: pd.read_csv(f"flexddgs/{target}/bias/sampled_mutations.csv") for target in targets}
for target in test_dfs:
    # NOTE(review): this assignment is positional. It assumes the CSV rows are in
    # the same order as the sorted case_name index returned by get_ddg — confirm.
    test_dfs[target]["DMS_score"] = -flex_ddg_dfs[target + "_bias"].values

In [None]:
import yaml

In [None]:
# Objective columns used for multi-objective selection, and one
# reference-point coordinate per column (same order).
score_cols = ["flxddg_std", "ablang2_perplexity_std"]
ref_points = [2, 2]

In [None]:
# Load the last-cycle training data of every multi-objective job and derive the
# normalised objective columns used by the selection filters.
jobdf = pd.read_csv("jobs/job_multi.csv")
import os
# Number of cycles per job; cycle directories are 0..N_CYCLES-1. Previously this
# was hard-coded separately as MAXCYCLE=10, range(10) and the directory "9".
N_CYCLES = 10
LAST_CYCLE_DIR = str(N_CYCLES - 1)
cycles = {}  # NOTE(review): unused in this cell
dfs = []
# dfs2 = []
configs = []
for confpath in jobdf["CONFIG"]:
    with open(confpath) as f:
        data = yaml.safe_load(f)
    # data_dir is split as "<root>/<target>/<model_type>/<exp>...".
    parts = data["data_dir"].split("/")
    target, model_type, exp = parts[1], parts[2], parts[3]
    job_dir = os.path.join(data_dir, "..", data["data_dir"])

    df = pd.read_csv(os.path.join(job_dir, LAST_CYCLE_DIR, "train_data", "training_data.csv"))
    df["target"] = target
    df["model_type"] = model_type
    df["mutations"] = df["mutations"].fillna("")
    df["exp"] = exp
    df["flxddg"] = -df["DMS_score"]

    # Percentile-normalised objectives for hypervolume / sum selection
    # (IP_seq is negated before normalising).
    df["flxddg_std"] = normalize_score(df["flxddg"])
    df["ablang2_perplexity_std"] = normalize_score(df["ablang2_perplexity"])
    df["IP_seq_std"] = normalize_score(-df["IP_seq"])
    # Scale each objective by its reference-point coordinate.
    for score_col, ref_point in zip(score_cols, ref_points):
        df[score_col] *= ref_point
    df["sum_score"] = df[score_cols].sum(axis=1)

    df["#Mutation"] = df["mutations"].apply(lambda x: len(x.split(",")) if x != "" else 0)
    dfs.append(df)
    configs.append({
        "target": target,
        "MAXCYCLE": N_CYCLES,
        "model_type": model_type,
        "exp": exp,
        "data_dir": data["data_dir"],
    })

    # NOTE(review): ddgs_list is rebuilt for every job, so only the last job's
    # list survives this loop; it is not used within this cell.
    ddgs_list = []
    for cycle in range(N_CYCLES):
        ddgs = get_ddg(os.path.join(job_dir, str(cycle), "flex_ddG", "outputs-results.csv"))
        ddgs = ddgs.reset_index()
        ddgs["cycle"] = cycle
        ddgs_list.append(ddgs)


In [None]:
from fast_pareto import is_pareto_front, nondominated_rank


In [None]:
import yaml
from tqdm import tqdm

N = 40  # rows kept per cycle by the size-limited filters (passed as N=... below)

def get_filtered_df(df, cycle, filter_type, N=40, score_cols=None, ref_points=None,
                    perplexity_threshold=10):
    """Select rows of ``df`` for one cycle according to a filtering strategy.

    Only rows with ``df["cycle"] <= cycle`` are considered (cumulative pool).

    Strategies (``filter_type``):
        "all": the whole cumulative pool.
        "cycle": only the rows whose ``cycle`` equals ``cycle``.
        "top": the ``N`` rows with the highest ``DMS_score``.
        "filter": like "top", restricted to rows with
            ``ablang2_perplexity < perplexity_threshold``.
        "hv": ``N`` rows chosen greedily to maximise the hypervolume of
            ``score_cols`` w.r.t. ``ref_points``, among rows with every score
            <= its reference coordinate.
        "sum": the ``N`` rows with the lowest ``sum_score``.
        "round": the first ``N`` rows when sorted ascending by ``score_cols``
            rounded to one decimal (adds ``<col>_round`` columns).
        "non_dominated": the ``N`` rows with the lowest ``nondominated_rank``
            on ``score_cols``; ties keep their original row order.

    Raises:
        ValueError: for an unknown ``filter_type``, or when a strategy that needs
            ``score_cols`` / ``ref_points`` is called without them.
    """
    if filter_type in ("hv", "round", "non_dominated") and score_cols is None:
        raise ValueError(f"filter_type={filter_type!r} requires score_cols")
    if filter_type == "hv" and ref_points is None:
        raise ValueError("filter_type='hv' requires ref_points")

    cycle_df = df[df["cycle"] <= cycle]

    if filter_type == "all":
        return cycle_df
    if filter_type == "cycle":
        return cycle_df[cycle_df["cycle"] == cycle]
    if filter_type == "top":
        return cycle_df.sort_values("DMS_score", ascending=False).head(N)
    if filter_type == "filter":
        plausible = cycle_df[cycle_df["ablang2_perplexity"] < perplexity_threshold]
        # cycle_df = cycle_df[cycle_df["IP_seq"]>6]
        return plausible.sort_values("DMS_score", ascending=False).head(N)
    if filter_type == "hv":
        # Keep only points inside the reference box; pygmo's hypervolume expects
        # every point to be <= the reference point in each objective.
        in_box = cycle_df
        for i, col in enumerate(score_cols):
            in_box = in_box[in_box[col] <= ref_points[i]]
        selected_indices, _ = greedy_hypervolume_subset(in_box[score_cols].values, N, ref_points)
        return in_box.iloc[selected_indices]
    if filter_type == "sum":
        return cycle_df.sort_values("sum_score", ascending=True).head(N)
    if filter_type == "round":
        round_cols = [f"{col}_round" for col in score_cols]
        rounded = cycle_df.assign(
            **{rc: cycle_df[col].round(1) for rc, col in zip(round_cols, score_cols)}
        )
        return rounded.sort_values(round_cols, ascending=True).head(N)
    if filter_type == "non_dominated":
        ranks = nondominated_rank(cycle_df[score_cols].values)
        # Stable sort: integer ranks tie massively, so keep the original row
        # order within a rank instead of the unstable default order.
        return cycle_df.iloc[np.argsort(ranks, kind="stable")].head(N)
    raise ValueError(f"Unknown filter type: {filter_type}")

def process_df_by_type(df, CYCLE, filter_type, **kwargs):
    """Apply ``get_filtered_df`` for cycles 0..CYCLE-1 and stack the results.

    Returns a long frame with a ``CYCLE`` column (filter cycle + 1) and an
    ``index`` column holding the original row labels of ``df``, followed by
    the selected rows' columns. Extra keyword arguments are forwarded to
    ``get_filtered_df``.
    """
    per_cycle = {}
    for c in range(CYCLE):
        per_cycle[c + 1] = get_filtered_df(df, c, filter_type, **kwargs)
    stacked = pd.concat(per_cycle)
    stacked.index.names = ["CYCLE", "index"]
    return stacked.reset_index()

# Run every filtering strategy on every experiment.
# df_merges[ftype]: one stacked frame per experiment (same order as dfs/configs).
# df_merge_cats[ftype]: those frames concatenated across experiments.
filter_types = ["all", "top", "hv", "filter", "sum", "round", "non_dominated", "cycle"]
df_merges = dict((ft, []) for ft in filter_types)
df_merge_cats = {}

for i in range(len(dfs)):
    df = dfs[i]
    target, exp = configs[i]["target"], configs[i]["exp"]
    CYCLE = configs[i]["MAXCYCLE"]

    for filter_type in filter_types:
        kwargs = {"N": N}
        # Only these strategies consume the objective columns / reference point.
        if filter_type in ("hv", "round", "non_dominated"):
            kwargs["score_cols"] = score_cols
            kwargs["ref_points"] = ref_points
        df_merge = process_df_by_type(df, CYCLE, filter_type, **kwargs)
        df_merges[filter_type].append(df_merge)

for filter_type in filter_types:
    df_merge_cats[filter_type] = pd.concat(df_merges[filter_type])

# Per-strategy aliases kept for backward compatibility with other cells.
# The unpacking order must match filter_types.
(all_df_merges, top_df_merges, hv_df_merges, filter_df_merges,
 sum_df_merges, round_df_merges, non_dominated_df_merges, cycle_df_merges) = [
    df_merges[ft] for ft in filter_types]
(all_df_merge_cat, top_df_merge_cat, hv_df_merge_cat, filter_df_merge_cat,
 sum_df_merge_cat, round_df_merge_cat, non_dominated_df_merge_cat, cycle_df_merge_cat) = [
    df_merge_cats[ft] for ft in filter_types]
X = round_df_merge_cat

def calculate_diversity_metrics(df, emb, conf):
    """Per-cycle diversity and mutation-count summary for one experiment.

    For each cycle 1..conf["MAXCYCLE"] (default 10), ``Diversity`` is
    ``1 - calculate_mean_similarity`` of the embeddings of that cycle's
    selected rows. Embedding rows are looked up through the frame's ``index``
    column. Cycles with fewer than two selected rows get NaN diversity,
    because a mean pairwise similarity is undefined there; previously they
    raised ValueError.

    Args:
        df: Stacked frame from ``process_df_by_type`` (needs ``CYCLE``,
            ``index`` and ``#Mutation`` columns).
        emb: Embedding array indexed by the values of ``df["index"]``.
        conf: Config dict with ``target``, ``model_type``, ``exp`` and
            optionally ``MAXCYCLE``.

    Returns:
        DataFrame with columns CYCLE, Diversity, mean_mutation_num,
        median_mutation_num, target, model_type, exp (one row per cycle).
    """
    n_cycles = conf.get("MAXCYCLE", 10)
    rows = []
    for cycle in range(1, n_cycles + 1):
        cycle_rows = df[df["CYCLE"] == cycle]
        if len(cycle_rows) >= 2:
            diversity = 1 - calculate_mean_similarity(emb[cycle_rows["index"].values])
        else:
            diversity = np.nan
        rows.append({
            "CYCLE": cycle,
            "Diversity": diversity,
            "mean_mutation_num": cycle_rows["#Mutation"].mean(),
            "median_mutation_num": cycle_rows["#Mutation"].median(),
        })

    div_df = pd.DataFrame(
        rows, columns=["CYCLE", "Diversity", "mean_mutation_num", "median_mutation_num"]
    )
    div_df["target"] = conf["target"]
    div_df["model_type"] = conf["model_type"]
    div_df["exp"] = conf["exp"]
    return div_df

# Diversity / mutation-count summaries for every (strategy, experiment) pair.
all_divs, top_divs, hv_divs, filter_divs, sum_divs, round_divs, non_dominated_divs, cycle_divs = (
    [] for _ in range(8))

# Each summary list paired with the per-experiment merged frames it is built from.
_div_sources = (
    (all_divs, all_df_merges),
    (top_divs, top_df_merges),
    (hv_divs, hv_df_merges),
    (filter_divs, filter_df_merges),
    (sum_divs, sum_df_merges),
    (round_divs, round_df_merges),
    (non_dominated_divs, non_dominated_df_merges),
    (cycle_divs, cycle_df_merges),
)

for i in range(len(dfs)):
    conf = configs[i]
    # Embeddings from the cycle-9 train_data directory; rows are indexed by the
    # merged frames' "index" column (presumably aligned with training_data.csv).
    input_dir = os.path.join(data_dir, conf["target"], conf["model_type"],
                             conf["exp"], "9", "train_data")
    emb = np.load(os.path.join(input_dir, "embedding.npy"))
    for div_list, merged in _div_sources:
        div_list.append(calculate_diversity_metrics(merged[i], emb, conf))

(all_divs_cat, top_divs_cat, hv_divs_cat, filter_divs_cat,
 sum_divs_cat, round_divs_cat, non_dominated_divs_cat, cycle_divs_cat) = [
    pd.concat(div_list, ignore_index=True) for div_list, _ in _div_sources]


In [None]:
# One row per job configuration (target, MAXCYCLE, model_type, exp, data_dir).
configdf = pd.DataFrame(data=configs)

In [None]:
# Evaluate each cycle's model predictions on the target's test set
# (Spearman and NDCG@10 via calc_test).
all_test_scores = []
for conf in configs:
    target = conf["target"]
    # if target not in targets:
    #     continue
    test_df = test_dfs[target]  # loop-invariant across cycles
    # Use the configured cycle count instead of a hard-coded 10.
    for cycle in range(conf["MAXCYCLE"]):
        input_dir = os.path.join(data_dir, conf["target"], conf["model_type"], conf["exp"], str(cycle), "train_data")
        test_pred = np.load(os.path.join(input_dir, "test_inference_bias.npy"))
        test_df_ = test_df.copy()
        # Assigning through the frame also checks that len(test_pred) matches the test set.
        test_df_["Pred"] = test_pred
        all_test_scores.append({
            **calc_test(test_df_["DMS_score"], test_df_["Pred"]),
            "CYCLE": cycle + 1,
            "target": conf["target"],
            "model_type": conf["model_type"],
            "exp": conf["exp"],
        })
all_test_scores_cat = pd.DataFrame(all_test_scores)
# Spearman is NaN when an input is constant; treat that as zero correlation.
all_test_scores_cat["spearman"] = all_test_scores_cat["spearman"].fillna(0)

In [None]:
# Persist summary tables; create the output directory first because
# DataFrame.to_csv does not create missing parent directories.
results_dir = os.path.join("results", "flexddg_offline", "multi")
os.makedirs(results_dir, exist_ok=True)
sum_df_merge_cat.to_csv(os.path.join(results_dir, "sum_results.csv"), index=False)
all_test_scores_cat.to_csv(os.path.join(results_dir, "all_results_test.csv"), index=False)
all_df_merge_cat.to_csv(os.path.join(results_dir, "all_results.csv"), index=False)