In [None]:
import os

import pandas as pd
import numpy as np

from tqdm.notebook import tqdm

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import colormaps

## mutation count annotations

In [None]:
# Read the naive and memory annotation tables and stack them into a single frame.
naive_path = "/home/jovyan/shared/AbLM_training_data/HDfluJOY_SAHD_naive_AbLM.csv"
mem_path = "/home/jovyan/shared/AbLM_training_data/HDfluJOY_SAHD_memory_AbLM.csv"
naive, memory = (
    pd.read_csv(csv_path, low_memory=False) for csv_path in (naive_path, mem_path)
)
all_annot = pd.concat([naive, memory], ignore_index=True)

# Keep only the rows whose "is_SouthAfrican" flag equals True.
all_annot = all_annot.loc[all_annot["is_SouthAfrican"].eq(True)]

In [None]:
# Restrict the annotation table to the columns used downstream.
# (To inspect everything that is available: [col for col in all_annot.columns])

# pair-level metadata
meta_cols = ['name', 'sample', 'timepoint', 'experiment', 'donor', 'cell_type']

# annotation fields carrying the ":0" chain suffix
chain0_cols = [
    'v_gene:0', 'd_gene:0', 'j_gene:0', 'junction_aa:0', 'cdr3_length:0',
    'fr1_aa:0', 'cdr1_aa:0', 'fr2_aa:0', 'cdr2_aa:0', 'fr3_aa:0', 'cdr3_aa:0', 'fr4_aa:0',
    'v_identity:0', 'v_identity_aa:0', 'v_mutations:0', 'v_mutations_aa:0',
    'v_insertions:0', 'v_deletions:0', 'isotype:0', 'locus:0', 'sequence:0', 'sequence_aa:0',
]

# the same fields carrying the ":1" chain suffix
chain1_cols = [
    'v_gene:1', 'd_gene:1', 'j_gene:1', 'junction_aa:1', 'cdr3_length:1',
    'fr1_aa:1', 'cdr1_aa:1', 'fr2_aa:1', 'cdr2_aa:1', 'fr3_aa:1', 'cdr3_aa:1', 'fr4_aa:1',
    'v_identity:1', 'v_identity_aa:1', 'v_mutations:1', 'v_mutations_aa:1',
    'v_insertions:1', 'v_deletions:1', 'isotype:1', 'locus:1', 'sequence:1', 'sequence_aa:1',
]

col_list = meta_cols + chain0_cols + chain1_cols
all_annot = all_annot[col_list]
all_annot

In [None]:
# Reshape the paired annotation table to one row per chain and count the
# amino-acid V-gene mutations recorded for each chain.

# chain ":0" columns, with the suffix stripped
h = all_annot.loc[:, ["name", "sequence_aa:0", "v_identity_aa:0", "v_mutations_aa:0", "locus:0"]]
h.columns = h.columns.str.replace(':0', '', regex=False)

# chain ":1" columns, with the suffix stripped.
# The ":1" suffix is removed from this frame's own columns; copying `h.columns` here
# would only be correct while both selections list their columns in the same order.
l = all_annot.loc[:, ["name", "sequence_aa:1", "v_identity_aa:1", "v_mutations_aa:1", "locus:1"]]
l.columns = l.columns.str.replace(':1', '', regex=False)

# stack both chains; the fresh RangeIndex keeps every row uniquely addressable
all_annot = pd.concat([h, l], ignore_index=True)

# The count is the number of ":" characters in the `v_mutations_aa` string
# (assumes one ":" per listed mutation -- TODO confirm against the annotation format).
# Non-string values such as NaN are counted as 0 mutations.
all_annot["v_mutation_count_aa"] = all_annot["v_mutations_aa"].map(
    lambda muts: muts.count(":") if isinstance(muts, str) else 0
)
all_annot = all_annot.rename(columns={"name": "sequence_id"})
all_annot

## shuffled pairs dataset

In [None]:
# Combined pairs table; later cells read its "name", "h_sequence", "l_sequence" and "label" columns.
path = "/home/jovyan/shared/mahdi/1_projects/model_optimization/08paired_classification/data/SA_donors_stuff/all_combined.csv"
df = pd.read_csv(filepath_or_buffer=path)

In [None]:
# "name" holds two ids joined by "|": the first becomes h_id, the second l_id.
# Each name is split once and both pieces are taken from that result.
id_parts = [pair_name.split("|") for pair_name in df["name"]]
df["h_id"] = [parts[0] for parts in id_parts]
df["l_id"] = [parts[1] for parts in id_parts]

## combine mutation counts and shuffled pairs dataset

In [None]:
# Attach per-chain mutation counts to the shuffled-pairs table.
# Annotation rows are split by locus: "IGH" versus everything else.
heavy_all = all_annot.loc[all_annot["locus"] == "IGH"]
light_all = all_annot.loc[all_annot["locus"] != "IGH"]

# heavy side: match on (id, sequence), then keep only the mutation count
heavy = (
    df[["name", "h_id", "h_sequence", "label"]]
    .rename(columns={"h_id": "sequence_id", "h_sequence": "sequence_aa"})
    .merge(heavy_all, on=["sequence_id", "sequence_aa"])
    .drop(columns=["v_identity_aa", "v_mutations_aa", "locus"])
    .rename(columns={"sequence_id":"h_id", "sequence_aa":"h_sequence", "v_mutation_count_aa":"h_mutation_count"})
)

# light side: same steps with the l_* columns
light = (
    df[["name", "l_id", "l_sequence", "label"]]
    .rename(columns={"l_id": "sequence_id", "l_sequence": "sequence_aa"})
    .merge(light_all, on=["sequence_id", "sequence_aa"])
    .drop(columns=["v_identity_aa", "v_mutations_aa", "locus"])
    .rename(columns={"sequence_id":"l_id", "sequence_aa":"l_sequence", "v_mutation_count_aa":"l_mutation_count"})
)

# one row per pair carrying both chains' sequences and mutation counts
paired = heavy.merge(light, on=["name", "label"])
data_mut = paired[["name", "label", "h_sequence", "h_mutation_count", "l_sequence", "l_mutation_count"]]
data_mut

## process data by mutation counts

In [None]:
# Model identifiers ("<size>-<suffix>"), grouped by suffix with sizes ascending in each group.
model_order = (
    ["8M-Q", "35M-Q", "150M-Q", "350M-Q", "650M-Q"]
    + ["8M-H", "35M-H", "150M-H", "350M-H", "650M-H"]
    + ["8M-F", "35M-F", "150M-F", "350M-F", "650M-F"]
)

# One predictions CSV per iteration (itr0 .. itr4).
pair_preds = [
    f"/home/jovyan/shared/mahdi/1_projects/model_optimization/08paired_classification/SA_donors_ANALYSIS/KN_analysis/results/all_predictions_itr{i}.csv"
    for i in range(5)
]
pair_preds

In [None]:
# For each iteration, compute the prediction accuracy and the mean probability assigned
# to the correct class, separately for subsets defined by the chains' mutation status.
pairs_types_accs = []
for itr, preds_path in enumerate(pair_preds):

    # Load predictions and attach mutation counts. A dedicated name is used so that the
    # annotation table `all_annot` built in an earlier cell is not overwritten here
    # (overwriting it makes the earlier cells fail when re-run).
    itr_preds = pd.read_csv(preds_path)
    itr_preds = itr_preds.merge(data_mut, on=["name", "label", "h_sequence", "l_sequence"], how="left")

    # Pairs with no match in `data_mut` get NaN counts from the left merge. `NaN != 0`
    # evaluates to True, so "mutated" is tested with `> 0`; this keeps unannotated pairs
    # out of the "Mutated" / "Different" subsets. They are still part of "All".
    h_unmutated = itr_preds["h_mutation_count"] == 0
    l_unmutated = itr_preds["l_mutation_count"] == 0
    h_mutated = itr_preds["h_mutation_count"] > 0
    l_mutated = itr_preds["l_mutation_count"] > 0

    # splits of the test dataset based on the mutation status of the chain pair (both classes)
    germline = itr_preds[h_unmutated & l_unmutated]
    mutated = itr_preds[h_mutated & l_mutated]
    diffs = itr_preds[(h_mutated & l_unmutated) | (h_unmutated & l_mutated)]

    datasets = {
        "Unmutated": germline,
        "Mutated": mutated,
        "Different": diffs,
        "All": itr_preds,
    }

    # calculate stats per sequence type
    for model in model_order:
        for name, pair_df in datasets.items():
            labels = pair_df["label"]

            # Accuracy = fraction of rows whose prediction equals the label. Comparing
            # the columns directly stays correct when a subset contains a single class;
            # the diagonal of a non-square crosstab would count the wrong cells.
            pair_acc = (labels == pair_df[f"{model}_prediction"]).mean()

            # probability towards the CORRECT class: p where label == 1, otherwise 1 - p
            probs = pair_df[f"{model}_probability"]
            correct_probs = probs.where(labels == 1, 1 - probs)

            pairs_types_accs.append({
                "itr": itr,
                "model": model,
                "pair_type": name,
                "acc": pair_acc,
                "correct_probability": correct_probs.mean()
            })

pair_accs_df = pd.DataFrame(pairs_types_accs)
pair_accs_df.groupby(["pair_type", "model"]).mean()

## plot overall accuracy

In [None]:
# Independent copy of the rows computed on the full ('All') test set.
all_df = pair_accs_df.loc[pair_accs_df['pair_type'].eq('All')].copy()

In [None]:
# Split "<size>-<suffix>" model ids into a size column ("model") and a suffix column ("datasets").
# `n` is passed by keyword: in pandas 2.x only `pat` may be given positionally to `str.split`.
# The ids are read from `pair_accs_df` (which shares its index with `all_df`), so re-running
# this cell does not try to split the already-shortened "model" column a second time.
all_df[["model", "datasets"]] = pair_accs_df.loc[all_df.index, "model"].str.split("-", n=1, expand=True)

In [None]:
# Three colours from seaborn's "colorblind" palette, keyed by model-id suffix.
pallete = sns.color_palette("colorblind", n_colors=3)
palette_dict = dict(zip(['F', 'H', 'Q'], pallete))

In [None]:
# accuracy barplot: mean accuracy per model size, one bar per dataset suffix,
# error bars showing the standard error across iterations
fig, ax = plt.subplots(figsize=(12, 7))

# draw on the explicit axes rather than relying on pyplot's "current axes"
sns.barplot(data=all_df,
            x="model", y="acc",
            hue="datasets",
            palette=palette_dict,
            errorbar="se",
            gap=0.025,
            width=0.75,
            saturation=0.85,
            ax=ax,
           )

# keep the hue handles for the figure-level legend, then hide seaborn's own legend
handles_ds, labels_ds = ax.get_legend_handles_labels()
ax.get_legend().set_visible(False)

# labels & params
ax.set_xlabel("Model Size (Parameters)", fontsize=15)
ax.set_ylabel("Average Accuracy", fontsize=15)
ax.xaxis.set_tick_params(labelsize = 12)
ax.yaxis.set_tick_params(labelsize = 12)

ax.set_ybound(0.48, 0.74)

# random guessing line (gets its own axes-level legend)
line = ax.axhline(y=0.5, color='black', linestyle='--', label='Random Guessing')
ax.legend(handles=[line], loc='upper left', fontsize=12)

# dataset legend, placed at the right edge of the figure
fig.legend(handles_ds, labels_ds,
           loc="center right",
           title="Dataset",
           bbox_to_anchor=(0.99, 0.5),
           fontsize=11, title_fontsize=12,
           ncols=1)

# save (create the output directory first so a fresh checkout does not fail here)
os.makedirs("./figures", exist_ok=True)
fig.savefig("./figures/avg_overall_accuracy.pdf", bbox_inches='tight', dpi=300)

## plot true positives / negatives

In [None]:
def itr_col(path, itr):
    """Read the CSV at `path` and return it with an "itr" column set to `itr` on every row."""
    return pd.read_csv(path).assign(itr=itr)

# Load every iteration's predictions, tagged with its iteration number, into one frame.
all_pair_preds = pd.concat(
    [itr_col(preds_path, itr) for itr, preds_path in enumerate(pair_preds)],
    ignore_index=True,
)
len(all_pair_preds)

In [None]:
# Long-format table with one row per (prediction, model); the split by model is kept
# even though the overall trend is the main interest.
pred_categories = pd.concat([
    pd.DataFrame({
        "category": all_pair_preds[f"{model}_category"],
        "probability": all_pair_preds[f"{model}_probability"],
        "itr": all_pair_preds["itr"],
        "model": model,
    })
    for model in model_order
])

# number of predictions per iteration, model and outcome category
category_sizes = pred_categories.groupby(["itr", "model", "category"]).size()
df_truefalse = category_sizes.reset_index(name="count")

# keep only the correct outcomes
is_correct_outcome = df_truefalse['category'].isin(['true_positive', 'true_negative'])
df_true = df_truefalse.loc[is_correct_outcome]

# models shown in the plot and their colours
order = ['350M-Q', '350M-H', '350M-F']
palette = dict(zip(['350M-F', '350M-H', '350M-Q'], pallete))

# Bar plot of correct-prediction counts per class for the 350M models,
# with standard-error bars across iterations.
fig, ax = plt.subplots(figsize=(4.5, 7))
sns.barplot(
    data=df_true[df_true['model'].isin(order)],
    x="category", y="count", hue="model", errorbar="se",
    order=["true_positive", "true_negative"],
    hue_order=order,
    palette=palette,
    gap=0.025,
    width=0.75,
    saturation=0.85,
    ax=ax,  # draw on the explicit axes rather than pyplot's "current axes"
)

# legend
ax.get_legend().set_visible(False)
ax.set_title("350M Models", fontsize=15)

# x axis: tick labels follow `order=` above (true_positive first, then true_negative)
ax.set_xlabel("Class", fontsize=15)
ax.set_xticklabels(["Shuffled", "Native"])
ax.xaxis.set_tick_params(labelsize = 12)

# y axis
ax.set_ylabel("Number of Correct Predictions", fontsize=15)
ax.yaxis.set_tick_params(labelsize = 12)

# save (create the output directory first so a fresh checkout does not fail here)
os.makedirs("./figures", exist_ok=True)
fig.savefig("./figures/true-positives-negatives.pdf", bbox_inches='tight', dpi=300)

## plot heatmaps
Accuracy for sequences with 15 or fewer mutations

In [None]:
# Keep only pairs where both chains have at most 15 mutations (the zoomed-in region that gets plotted).
within_plot_range = data_mut["h_mutation_count"].le(15) & data_mut["l_mutation_count"].le(15)
data_mut = data_mut.loc[within_plot_range]
len(data_mut)

In [None]:
# merge with mutation counts (add mutation counts to inference data);
# the inner join keeps only pairs that have a row in `data_mut`
test_data = all_pair_preds.merge(data_mut, on=["name", "label", "h_sequence", "l_sequence"], how="inner")

# remove donor, logits, and category columns to keep the frame small
# (skip this step if those columns are needed later)
unused_cols = [col_name for col_name in test_data.columns
               if any(keyword in col_name for keyword in ["donor", "category", "logits"])]
test_data = test_data.drop(columns=unused_cols)

# per-row stats that cannot be calculated inside a groupby object
is_label_one = test_data["label"] == 1
for model in tqdm(model_order):
    # correct prediction flag
    test_data[f"{model}_correct"] = test_data["label"] == test_data[f"{model}_prediction"]

    # prediction confidence towards the CORRECT label class: p where label == 1, else 1 - p.
    # Computed column-wise; a row-by-row `apply` is far slower on a frame of this kind.
    probs = test_data[f"{model}_probability"]
    test_data[f"{model}_confidence_towards_label"] = probs.where(is_label_one, 1 - probs)

test_data

In [None]:
# Colormap for the mutation plots, in its forward ("RdBu") and reversed ("RdBu_r") forms.
cmap_name = "RdBu"
cmap, cmap_r = (colormaps[registered_name] for registered_name in (cmap_name, f"{cmap_name}_r"))

In [None]:
# aggregate inference stats
def agg_stats(grouped_df):
    """Summarise one group of prediction rows as a flat Series.

    Reads the "label" column plus, for every model in the notebook-level `model_order`,
    the "<model>_correct", "<model>_confidence_towards_label" and "<model>_probability"
    columns.

    Returns a Series holding:
      * "count", "log_count": number of non-null labels and its natural log
      * "shuffled_count": sum of the labels (label 1 is the shuffled class)
      * "native_count": count minus the label sum
      * "<model>_acc": sum of the correct flags divided by the count
      * "<model>_confidence": mean confidence towards the correct label
      * "<model>_confidence_towards_shuffled": mean raw probability
    """
    labels = grouped_df["label"]
    counts = labels.count()
    label_sum = labels.sum()

    # group-level totals
    d = {
        "count": counts,
        "log_count": np.log(counts),
        "native_count": counts - label_sum,
        "shuffled_count": label_sum,
    }

    # model-specific metrics, added in model_order so the output column order is stable
    for model in model_order:
        d.update({
            # accuracy
            f"{model}_acc": grouped_df[f"{model}_correct"].sum()/counts,
            # confidence towards the CORRECT label (prob if label is 1, 1-prob if label is 0)
            f"{model}_confidence": grouped_df[f"{model}_confidence_towards_label"].mean(),
            # confidence towards the shuffled label (the raw probability)
            f"{model}_confidence_towards_shuffled": grouped_df[f"{model}_probability"].mean(),
        })

    return pd.Series(d)

In [None]:
# One row of aggregated stats per (heavy, light) mutation-count combination.
mutation_groups = test_data.groupby(["h_mutation_count", "l_mutation_count"])
agg_all = mutation_groups.apply(agg_stats).reset_index()

In [None]:
# Heatmaps of 350M-model accuracy over (heavy, light) mutation counts, one panel per model.
cols_350 = ["h_mutation_count", "l_mutation_count"] + [col for col in agg_all.columns if '350M' in col]
models_350 = ["350M-Q", "350M-H", "350M-F"]

# Largest mutation count drawn on either axis; keep in sync with the `<= 15` filter on `data_mut`.
max_mut = 15
mut_grid = range(max_mut + 1)
tick_positions = np.arange(0.5, max_mut + 1, 2)   # cell centres of every second count
tick_labels = list(range(0, max_mut + 1, 2))

fig, axs = plt.subplots(1, 3, figsize=(16, 5))
mappable = None
for ax, model in zip(axs, models_350):
    acc_grid = (
        agg_all[cols_350]
        .pivot(
            index="h_mutation_count",
            columns="l_mutation_count",
            values=f"{model}_acc",
        )
        # `pivot` only contains counts that occur in the data. Reindexing to the full
        # 0..max_mut grid keeps every cell at its own position, so the fixed tick labels
        # below stay correct even when a count is absent (those cells are left empty).
        .reindex(index=mut_grid, columns=mut_grid)
    )

    heatmap = sns.heatmap(
        acc_grid,
        vmin=0, vmax=1,
        center=0.5, cmap=cmap_r,
        square=True,
        ax=ax,
        cbar=False
    )
    # all panels share vmin/vmax, so the first mesh can drive the shared colorbar
    if mappable is None:
        mappable = heatmap.collections[0]

    ax.set_xlabel("Light Chain Mutation Count", fontsize=14)
    ax.set_xticks(tick_positions,
                  tick_labels,
                  rotation="horizontal",
                  fontsize=10)

    # the y label is drawn once for the whole figure below
    ax.set_ylabel("")
    ax.set_yticks(tick_positions,
                  tick_labels,
                  rotation="horizontal",
                  fontsize=10)

    ax.set_title(f"{model}", fontsize=15)

    # thin grey frame around each panel
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(0.25)
        spine.set_color('grey')

# shared y axis
fig.text(0.09, 0.5, "Heavy Chain Mutation Count",
         va='center', rotation='vertical', fontsize=14)

# shared colorbar
cbar = fig.colorbar(mappable, ax=axs, orientation='vertical', fraction=0.02, pad=0.03, label="Average Accuracy")
cbar.ax.tick_params(labelsize=10)
cbar.outline.set_visible(False)

# save (create the output directory first so a fresh checkout does not fail here)
os.makedirs("./figures", exist_ok=True)
fig.savefig("./figures/350M_mutation_counts_acc.pdf", bbox_inches='tight', dpi=300)