#### Here, we are going to investigate the cached_escape function in cached_semantics.py to reproduce the AUC result

In [1]:
# 0. look up auc function in sklearn
from sklearn.metrics import auc
import numpy as np
# Example data (FPR, TPR)
fpr = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 1.0])
tpr = np.array([0.0, 0.4, 0.6, 0.7, 0.8, 1.0])

# Calculate AUC
roc_auc = auc(fpr, tpr)
print(f"ROC AUC: {roc_auc}")
# I finded out this function simply gets the area under the curve of the ROC curve described by the fpr and tpr arrays

ROC AUC: 0.75


In [2]:
import pandas as pd
import scipy.stats as ss
cov_cscs_path = "../results/cov/semantics/analyze_semantics_cov_bilstm_512.txt"
# ignore U B J X Z
rbd_cscs_path = "../results/cov/semantics/analyze_semantics_cov2rbd_bilstm_512.txt"
cov_df = pd.read_csv(cov_cscs_path, sep="\t")
# cov_df = cov_df[(cov_df["wt"] != "U") & (cov_df["wt"] != "B") & 
#               (cov_df["wt"] != "J") & (cov_df["wt"] != "X") & (cov_df["wt"] != "Z")]

df = cov_df[cov_df['wt'] == 'Z']
df
# print(cov_df)
# len(cov_df)

Unnamed: 0,pos,wt,mut,prob,change,is_viable,is_escape


In [18]:
# 1. before start test the main function in cached_semantics.py first
def get_aucs(df: pd.DataFrame):
    """What we need is sars-cov-2 related results, so we ignore codes not related 
    to sars-cov-2.
    wt_seq & seqs_escape are used to calculate AUC.
    """
    """get prob & orig_prob from table
    escape_rank_dist, escape_idx
    acquisition = CSCS score (rank base)
     = apply rank to the array(grammar and semantic change) and sum

    target values: norm_auc, norm_auc_prob, norm_auc_change

    for RBD, ignore pos < 330 or pos > 530
    """
    grammaticalities = df['prob']
    semantic_changes = df['change']

    gram_ranks = grammaticalities.rank(ascending=False)
    sem_ranks = semantic_changes.rank(ascending=False)

    escape_indices = df[df['is_escape'] == True].index
    # acquisition = cscs_scores
    cscs_scores = gram_ranks + sem_ranks
    len_probs = len(grammaticalities)
    num_to_consider = list(range(1, len(grammaticalities) + 1))
    # escape mutants rank of cscs score
    escape_ranks_cscs = cscs_scores.rank()[escape_indices]
    escape_ranks_gram = gram_ranks[escape_indices]
    escape_ranks_sem = sem_ranks[escape_indices]
    n_escape_cscs = [sum(escape_ranks_cscs <= i + 1) for i in range(len_probs)]
    n_escape_gram = [sum(escape_ranks_gram <= i + 1) for i in range(len_probs)]
    n_escape_sem = [sum(escape_ranks_sem <= i + 1) for i in range(len_probs)]

    norm = (len_probs + 1) * len(escape_ranks_cscs)

    # auc of cscs
    norm_auc_cscs = auc(num_to_consider, n_escape_cscs) / norm
    # auc of grammaticalities
    norm_auc_gram = auc(num_to_consider, n_escape_gram) / norm
    # auc of semantic change 
    norm_auc_sem = auc(num_to_consider, n_escape_sem) / norm
    
    return (norm_auc_cscs, norm_auc_gram, norm_auc_sem)
result = get_aucs(cov_df)
result

(0.8659576887100414, 0.8390053866706172, 0.6598051358553816)

In [4]:
def cached_escape_original(cache_fname):
    from escape import load_baum2020
    wt_seq, seqs_escape = load_baum2020('../data/cov/cov2_spike_wt.fasta')
    prob, change, escape_idx, viable_idx = [], [], [], []
    with open(cache_fname) as f:
        f.readline()
        for line in f:
            fields = line.rstrip().split("\t")
            pos = int(fields[0])
            if "rbd" in cache_fname:
                if pos < 330 or pos > 530:
                    continue
            if fields[2] in {"U", "B", "J", "X", "Z"}:
                continue
            aa_wt = fields[1]
            aa_mut = fields[2]
            assert wt_seq[pos] == aa_wt
            mut_seq = wt_seq[:pos] + aa_mut + wt_seq[pos + 1 :]
            if mut_seq not in seqs_escape:
                continue
            prob.append(float(fields[3]))
            change.append(float(fields[4]))
            viable_idx.append(fields[5] == "True")
            escape_idx.append(
                (mut_seq in seqs_escape)
                and (sum([m["significant"] for m in seqs_escape[mut_seq]]) > 0)
            )

    prob, orig_prob = np.array(prob), np.array(prob)
    change, orig_change = np.array(change), np.array(change)
    escape_idx = np.array(escape_idx)
    viable_idx = np.array(viable_idx)

    beta = 1.0
    acquisition = ss.rankdata(change) + (beta * ss.rankdata(prob))

    pos_change_idx = change > 0

    pos_change_escape_idx = np.logical_and(pos_change_idx, escape_idx)
    escape_prob = prob[pos_change_escape_idx]
    escape_change = change[pos_change_escape_idx]
    prob = prob[pos_change_idx]
    change = change[pos_change_idx]

    log_prob, log_change = np.log10(prob), np.log10(change)
    log_escape_prob, log_escape_change = (
        np.log10(escape_prob),
        np.log10(escape_change),
    )

    acq_argsort = ss.rankdata(-acquisition)
    escape_rank_dist = acq_argsort[escape_idx]

    size = len(prob)
    print(
        "Number of escape seqs: {} / {}".format(len(escape_rank_dist), sum(escape_idx))
    )
    print("Mean rank: {} / {}".format(np.mean(escape_rank_dist), size))
    print("Median rank: {} / {}".format(np.median(escape_rank_dist), size))
    print("Min rank: {} / {}".format(np.min(escape_rank_dist), size))
    print("Max rank: {} / {}".format(np.max(escape_rank_dist), size))
    print("Rank stdev: {} / {}".format(np.std(escape_rank_dist), size))

    max_consider = len(prob)
    n_consider = np.array([i + 1 for i in range(max_consider)])

    n_escape = np.array([sum(escape_rank_dist <= i + 1) for i in range(max_consider)])
    norm = max(n_consider) * max(n_escape)
    norm_auc = auc(n_consider, n_escape) / norm

    escape_rank_prob = ss.rankdata(-orig_prob)[escape_idx]
    n_escape_prob = np.array(
        [sum(escape_rank_prob <= i + 1) for i in range(max_consider)]
    )
    norm_auc_prob = auc(n_consider, n_escape_prob) / norm

    escape_rank_change = ss.rankdata(-orig_change)[escape_idx]
    n_escape_change = np.array(
        [sum(escape_rank_change <= i + 1) for i in range(max_consider)]
    )
    norm_auc_change = auc(n_consider, n_escape_change) / norm

    def compute_p(true_val, n_interest, n_total, n_permutations=10000):
        null_distribution = []
        norm = n_interest * n_total
        for _ in range(n_permutations):
            interest = set(np.random.choice(n_total, size=n_interest, replace=False))
            n_acquired = 0
            acquired, total = [], []
            for i in range(n_total):
                if i in interest:
                    n_acquired += 1
                acquired.append(n_acquired)
                total.append(i + 1)
            null_distribution.append(auc(total, acquired) / norm)
        null_distribution = np.array(null_distribution)
        return sum(null_distribution >= true_val) / n_permutations

    norm_auc_p = compute_p(norm_auc, sum(escape_idx), len(escape_idx))

    print("AUC (CSCS): {}, P = {}".format(norm_auc, norm_auc_p))
    print("AUC (semantic change only): {}".format(norm_auc_change))
    print("AUC (grammaticality only): {}".format(norm_auc_prob))

    print(
        "{:.4g} (mean log prob), {:.4g} (mean log prob escape), "
        "{:.4g} (p-value)".format(
            log_prob.mean(),
            log_escape_prob.mean(),
            ss.mannwhitneyu(log_prob, log_escape_prob, alternative="two-sided")[1],
        )
    )
    print(
        "{:.4g} (mean log change), {:.4g} (mean log change escape), "
        "{:.4g} (p-value)".format(
            change.mean(),
            escape_change.mean(),
            ss.mannwhitneyu(change, escape_change, alternative="two-sided")[1],
        )
    )
result = cached_escape_original(cov_cscs_path)


Number of escape seqs: 19 / 19
Mean rank: 3757.657894736842 / 24187
Median rank: 3700.0 / 24187
Min rank: 122.0 / 24187
Max rank: 7815.0 / 24187
Rank stdev: 2574.3142267849016 / 24187
AUC (CSCS): 0.8446610075442876, P = 0.0
AUC (semantic change only): 0.6364478090666364
AUC (grammaticality only): 0.7966774234963105
-6.262 (mean log prob), -4.557 (mean log prob escape), 7.539e-06 (p-value)
3717 (mean log change), 3958 (mean log change escape), 0.03944 (p-value)
