In [None]:
# Notebook authorship metadata
__author__ = "Matteo Pariset"

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from functools import reduce

In [1]:
import sys
# Make the project root (two levels up) importable so that `utils` can be found;
# this must run before the imports below
sys.path.append("../../")

# Useful for Pylance
# The star import exposes helpers used later (e.g. get_dataset_dir). NOTE(review): later cells
# also use `os` and `re` without importing them, presumably relying on this star import — confirm.
from utils import *

import utils
# Re-apply utils' definitions onto this notebook module (presumably so edits to utils are picked
# up without restarting the kernel — verify against utils.refresh)
utils.refresh(sys.modules[__name__])

# Convert raw 'developability-X' files into a curated dataset
Version 0.9.1

In [None]:
# TODO: PARAMs: Put the names (not the paths) of the files containing the DPs (chain vs structural respectively)
# (both files are read from get_dataset_dir(dataset_name))
chain_params_filename = "AbChain_whole_mAbs_developability_final.csv"
struct_params_filename = "AbStruc_whole_mAbs_developability_final.csv"
########################

# The name of the processed dataset (suggested: developability-thera)
# It selects the input/output directory via get_dataset_dir and prefixes the output files
dataset_name = "developability-thera"
########################

In [None]:
# Both raw DP tables (chain-based and structure-based) live in the dataset directory
raw_files_dir = get_dataset_dir(dataset_name)
chain_dset = pd.read_csv(os.path.join(raw_files_dir, chain_params_filename))
struct_dset = pd.read_csv(os.path.join(raw_files_dir, struct_params_filename))

In [None]:
# Harmonize columns: give both tables the same names for species and isotype
cols_subst = {'identity_species': 'species', 'corrected_isotype': 'isotype'}
chain_dset = chain_dset.rename(columns=cols_subst)
struct_dset = struct_dset.rename(columns=cols_subst)

In [None]:
# Take care of missing columns

# No isotype annotation in the raw file: add an empty-string placeholder column
if 'isotype' not in chain_dset.columns:
    chain_dset = chain_dset.assign(isotype="")

# No explicit row identifier: derive one from the current index
if 'rowid' not in chain_dset.columns:
    chain_dset = chain_dset.reset_index().rename(columns={'index': 'rowid'})

In [None]:
# Let pandas infer more specific dtypes for object columns, then display the table
chain_dset = chain_dset.infer_objects()
chain_dset

In [None]:
# Sanity check: Sequences in Chain & Struct files coincide
# (Series are compared label by label, so both tables must also share length and index)
assert (chain_dset['aaSeqAbChain'] == struct_dset['aaSeqAbChain']).all(), "Sequence mismatch"

In [None]:
# Sequence regions in N-to-C order: FR1, CDR1, FR2, CDR2, FR3, CDR3, FR4
seq_chunks_names = ['aaSeqFR1', 'aaSeqCDR1', 'aaSeqFR2', 'aaSeqCDR2', 'aaSeqFR3', 'aaSeqCDR3', 'aaSeqFR4']
# CDRs occupy the odd positions of that list, framework regions the even ones
cdr_names = [name for pos, name in enumerate(seq_chunks_names) if pos % 2 == 1]
fr_names = [name for pos, name in enumerate(seq_chunks_names) if pos % 2 == 0]

In [None]:
# Sanity check: CDRs+FRs decompose each sequence
# (missing regions are treated as empty strings before concatenating them in order)
assert (chain_dset['aaSeqAbChain'] == chain_dset[seq_chunks_names].fillna("").apply("".join, axis=1)).all(), "Cannot decompose sequence"

### Inspect sequences

In [None]:
# Number of rows sharing each full chain sequence (most frequent first)
unq_sequences = chain_dset['aaSeqAbChain'].value_counts()
unq_sequences

In [None]:
# Sequences occurring in more than one row, and how many there are
repeated_seqs = unq_sequences[unq_sequences > 1].index.values
len(repeated_seqs)

In [None]:
# All rows whose chain sequence appears more than once
repeated_dev = chain_dset.query('aaSeqAbChain in @repeated_seqs')
repeated_dev

In [None]:
# For each repeated sequence: mean and variance of every AbChain/AbStruc metric
repeated_metric_cols = repeated_dev.columns[repeated_dev.columns.str.match(r"Ab(Chain|Struc)")].tolist()
repeated_metrics = (
    repeated_dev[['aaSeqAbChain'] + repeated_metric_cols]
    .groupby('aaSeqAbChain')
    .agg(['mean', 'var'])
)
repeated_metrics

In [None]:
# Sanity check: Verify that max metric variance among metrics calculated for the same sequence is 0
# If no sequence is repeated the table is empty; .max().max() would then be NaN and the check
# would fail spuriously, so an empty table passes trivially.
repeated_vars = repeated_metrics.loc[:, repeated_metrics.columns.get_level_values(1) == "var"]
assert repeated_metrics.empty or np.isclose(repeated_vars.max().max(), 0), "Metrics computed for the same sequence vary"

### Compute position and length of CDRs

In [None]:
def _cdr_starts(row):
    # 0-based start of each CDR inside the full chain sequence (str.find gives -1 if absent)
    return [row['aaSeqAbChain'].find(row[cdr]) for cdr in cdr_names]

cdr_idxs = chain_dset.apply(_cdr_starts, axis=1, result_type='expand')
cdr_idxs = cdr_idxs.rename({pos: cdr + "_idx" for pos, cdr in enumerate(cdr_names)}, axis=1)
cdr_idxs

In [None]:
# Length (in residues) of each CDR. DataFrame.applymap is deprecated since pandas 2.1, so use the
# vectorized per-column .str.len() instead (missing CDRs give NaN rather than raising).
cdr_lengths = (chain_dset[cdr_names]
               .apply(lambda col: col.str.len())
               .rename({cdr: cdr + "_length" for cdr in cdr_names}, axis=1))
cdr_lengths

### Inspect info on Ab origin

#### Species

In [None]:
# Number of rows for each (chain, species) combination
chain_dset[['chain', 'species']].astype('category').value_counts()

#### Isotype

In [None]:
# Isotype histogram with one bin per distinct value (including "" if the isotype column was missing)
chain_dset['isotype'].hist(bins=chain_dset['isotype'].unique().shape[0])

#### VJ genes

In [None]:
# Share of rows with missing V and J gene annotations, shown as a (v, j) tuple
print("Fraction of nan values:")
chain_dset['v_gene'].isna().mean(), chain_dset['j_gene'].isna().mean()

In [None]:
# Sanity check: V-gene annotations agree between the chain and structure tables.
# Missing values must sit on the same rows too. Comparing two independently dropna'd Series would
# raise a ValueError instead of a readable assertion when they do not.
v_gene_mismatch = chain_dset['v_gene'].fillna("<NA>") != struct_dset['v_gene'].fillna("<NA>")
assert not v_gene_mismatch.any(), f"{100 * v_gene_mismatch.mean():.2f}% of entries differ!"

In [None]:
# Sanity check: J-gene annotations agree between the chain and structure tables.
# Missing values must sit on the same rows too. Comparing two independently dropna'd Series would
# raise a ValueError instead of a readable assertion when they do not.
j_gene_mismatch = chain_dset['j_gene'].fillna("<NA>") != struct_dset['j_gene'].fillna("<NA>")
assert not j_gene_mismatch.any(), f"{100 * j_gene_mismatch.mean():.2f}% of entries differ!"

In [None]:
# Provenance columns, namespaced with the "AbOrig_" prefix
orig_dset = chain_dset[['chain', 'species', 'isotype', 'v_gene', 'j_gene']].add_prefix("AbOrig_")

In [None]:
import re

# Gene prefix of each V gene: leading text up to the last "IG" + 2 characters + digits
# (e.g. "IGHV3-23*01" -> "IGHV3")
ontology_regexp = re.compile(r".*IG..[0-9]+")

def _v_gene_prefix(v_gene):
    """Return the ontology prefix of `v_gene`, or None if the name does not match the pattern."""
    match = ontology_regexp.match(v_gene)
    # Previously `match(x)[0]` raised TypeError on the first non-matching name
    return match[0] if match is not None else None

orig_dset['AbOrig_v_gene_prefix'] = orig_dset['AbOrig_v_gene'].dropna().astype(str).apply(_v_gene_prefix)

### Compose curated dataset

In [None]:
# Parse clinical trial phase
if "highest_clin_trial" in chain_dset.columns:
    # Show the labels that actually occur, to spot any the mapping below does not cover
    print(chain_dset['highest_clin_trial'].unique())

    # Development-stage label -> ordinal phase score; combined phases take the lower phase.
    # Looked up with dict.get, so unlisted labels (like 'Unknown') yield None.
    # NOTE(review): 'Preregistration' scores 0 like 'Preclinical' — confirm this is intended.
    clin_trial_conversion = {
        'Preclinical': 0,
        'Preregistration': 0,
        'Preregistration (w)': 0,
        'Phase-I': 1,
        'Phase-I/II': 1,
        'Phase-II': 2,
        'Phase-II/III': 2,
        'Phase-III': 3,
        'Approved': 4,
        'Approved (w)': 4,
        'Unknown': None,
    }

In [None]:
# Assemble the curated table: identifiers + CDR geometry + provenance + DP metrics
curated_parts = [
    cdr_idxs,
    cdr_lengths,
    orig_dset,                                                          # Orig info
    chain_dset.loc[:, chain_dset.columns.str.startswith("AbChain")],    # Chain-based metrics
    struct_dset.loc[:, struct_dset.columns.str.startswith("AbStruc")],  # Struct-based metrics
]
curated_dset = chain_dset[['rowid', 'aaSeqAbChain', 'chain']].join(curated_parts)

# Attach the ordinal clinical phase when the raw data provides it
if "highest_clin_trial" in chain_dset.columns:
    clin_phase = chain_dset['highest_clin_trial'].apply(clin_trial_conversion.get)
    curated_dset = curated_dset.join(clin_phase)


In [None]:
def _base_chain(label):
    # Keep only the part before the first "_" (labels without "_" are returned unchanged)
    return label.split("_")[0]

curated_dset['source_datasets'] = dataset_name

# Column renaming & type enforcement
curated_dset = curated_dset.rename(columns={'aaSeqAbChain': 'sequence', 'chain': 'chain_type'})
curated_dset['rowid'] = curated_dset['rowid'].astype(int)
# Harmonize chain types
for chain_col in ('chain_type', 'AbOrig_chain'):
    curated_dset[chain_col] = curated_dset[chain_col].apply(_base_chain)
curated_dset

In [None]:
# Persist the curated table (index included) as <dataset dir>/<dataset_name>.csv
curated_dset.to_csv(os.path.join(get_dataset_dir(dataset_name), f"{dataset_name}.csv"))

### Transform metrics into embeddings

In [None]:
# True: keep only the metrics listed in ./reproducibility/extended_mwds_metrics.csv ("mwds");
# False: keep every AbChain/AbStruc metric ("metrics")
use_mdws = True

In [None]:
if not use_mdws:
    # Select all metrics (as a boolean column mask)
    metrics_embs_name = "metrics"
    selected_metrics = curated_dset.columns.str.contains(r"(AbChain|AbStruc)")
else:
    # use MWDS: metric names from column '0' of the reproducibility file, restricted to those
    # present in this dataset (np.intersect1d also returns them sorted)
    metrics_embs_name = "mwds"
    mwds_candidates = pd.read_csv("./reproducibility/extended_mwds_metrics.csv").loc[:,'0'].to_numpy()
    selected_metrics = np.intersect1d(mwds_candidates, curated_dset.columns)

In [None]:
# Raw values of the selected metrics; the following cells center, fill, scale and clip them
metrics_df = curated_dset.loc[:,selected_metrics]
metrics_df

In [None]:
# To compare DP embeddings across datasets, use the same list of metrics as for the native Abs.
# The f-prefix is required: without it the literal "{metrics_embs_name}" ended up in the path.
native_dset = pd.read_csv(f"./reproducibility/developability_processed_{metrics_embs_name}.csv")

# AbChain/AbStruc metric names present in the processed native table
metrics_without_nans = native_dset.loc[:,native_dset.columns.str.contains(r"(AbChain|AbStruc)")].columns.values

In [None]:
# Load std & mean from native Abs (since the same PCA will be used on this dataset too)
native_dev_path = os.path.join(get_dataset_dir("developability"), "developability.csv")
native_labels = pd.read_csv(native_dev_path).loc[:, metrics_without_nans]

# Per-metric statistics of the native Abs, used to standardize this dataset
native_means, native_stds = native_labels.mean(), native_labels.std()

In [None]:
# Quick look at the native distribution of one structural metric
# (assumes 'AbStruc_weak_hbonds' is among metrics_without_nans — otherwise this raises KeyError)
native_labels['AbStruc_weak_hbonds'].hist()

In [None]:
# Center on the native means (columns are aligned by name; unmatched ones become all-NaN).
# NOTE: not idempotent — re-running this cell subtracts the means a second time.
metrics_df = metrics_df - native_means
# Every metric must still vary across this dataset
assert not np.isclose(metrics_df.std(), 0).any(), "Some metrics have zero variance, remove them"

In [None]:
# Shape before restricting to the native metric list
metrics_df.shape

In [None]:
# Remove metrics not in original DP
# (this also reorders the columns to match the native metric list)
metrics_df = metrics_df.loc[:,metrics_without_nans]

In [None]:
# Shape after restricting to the native metric list
metrics_df.shape

In [None]:
# Total number of missing values left across all metrics
print(f"{metrics_df.isna().sum().sum()} nan entries remaining")

In [None]:
# Fill residual NaNs with the mean computed on the native Abs dataset (per metric column)
# NOTE(review): metrics_df is already centered on native_means at this point, whereas native_dset
# is the *processed* native table. Confirm both are on the same scale: if native_dset is
# standardized, the native mean in the current centered space is simply 0.
metrics_df = metrics_df.fillna(native_dset[metrics_df.columns].mean())

In [None]:
# Sanity check: no metric value may still be missing
assert not metrics_df.isna().any().any(), "NaNs entries left"

In [None]:
# Per-metric std after centering and NaN filling, before scaling
metrics_df.std()

In [None]:
# Scale by the native standard deviations (columns aligned by name).
# NOTE: not idempotent — re-running this cell divides a second time.
metrics_df /= native_stds

In [None]:
# Compare the raw distribution of one metric: native Abs vs this dataset.
# AbStruc metrics come from struct_dset, so chain_dset may not contain distr_metric. curated_dset
# holds both AbChain and AbStruc raw values on the same rows, so read the metric from there.
distr_metric = metrics_df.columns[0]
native_labels[distr_metric].hist(label="native", alpha=0.6)
curated_dset[distr_metric].hist(label=dataset_name, alpha=0.6)
plt.yscale("log")
plt.title(f"{distr_metric} comparison")
plt.legend();

In [None]:
# Residual per-metric spread after scaling, on a log axis with vertical metric labels
std_ax = metrics_df.std().plot.line()
std_ax.set_yscale("log")
std_ax.tick_params(axis="x", labelrotation=90);

In [None]:
# Metrics whose std is still >= 10 after standardization by the native stds
print("Metrics with high (>= 10) residual stds:")
metrics_df.std()[metrics_df.std() >= 10]

In [None]:
# Per-metric std after scaling by the native stds
metrics_df.std()

In [None]:
# Support of each metric: row 0 holds the column minima, row 1 the column maxima
metrics_support = pd.DataFrame([metrics_df.min(), metrics_df.max()], index=[0, 1])
metrics_support

In [None]:
def draw_metrics_support(metrics_df):
    """Draw a horizontal bar chart of each metric's support (minimum and maximum).

    For every column of ``metrics_df`` one bar extends to the column minimum and an overlaid bar to
    the column maximum. Bars are labelled with the metric names so that metrics with extreme values
    can be identified (the previous version only showed integer positions).

    Args:
        metrics_df: DataFrame whose columns are metrics; NaNs are skipped by min/max.
    """
    metric_names = metrics_df.columns
    positions = np.arange(len(metric_names))
    supp_fig, supp_ax = plt.subplots()
    # Grow the figure with the number of metrics so the tick labels stay legible
    width, height = supp_fig.get_size_inches()
    supp_fig.set_size_inches(width, max(height, 0.25 * len(metric_names)))
    supp_ax.barh(positions, metrics_df.min().values, label="min")
    supp_ax.barh(positions, metrics_df.max().values, label="max")
    supp_ax.set_yticks(positions)
    supp_ax.set_yticklabels(metric_names)
    supp_ax.set_xlabel("value")
    supp_ax.legend()

In [None]:
# Support of the standardized metrics, before clipping
draw_metrics_support(metrics_df)

In [None]:
# Metrics whose minimum or maximum exceeds the threshold in absolute value (native-std units)
aberration_threshold = 15
aberrant_metrics = metrics_support.abs().gt(aberration_threshold).any(axis=0)
aberrant_metrics.sum()

In [None]:
# Clip aberrant values into [-aberration_threshold, aberration_threshold]
metrics_df = metrics_df.clip(lower=-aberration_threshold, upper=aberration_threshold)

In [None]:
# Every clipped metric plotted against the row index (one line per metric)
plt.plot(metrics_df);

In [None]:
# Support after clipping: every bar should stay within +/- aberration_threshold
draw_metrics_support(metrics_df)

In [None]:
# Per-metric std of the final (standardized and clipped) metrics
plt.plot(metrics_df.std(), 'x');
plt.xticks(rotation=90);
plt.title(f"Metrics stds for {dataset_name}");

In [None]:
# Save processed metrics, for later inspection
# (sequence column joined with the standardized, clipped metrics on curated_dset's index)
curated_metrics_dset = curated_dset[['sequence']].join(metrics_df)
curated_metrics_dset.to_csv(f"./reproducibility/{dataset_name}_processed_{metrics_embs_name}.csv")

In [None]:
# Save embeddings (all chains together)
emb_dir = get_dataset_dir(dataset_name)
np.save(os.path.join(emb_dir, f"{dataset_name}_{metrics_embs_name}_seq.npy"), metrics_df.values)

# ... also by chain type (np.save appends the ".npy" extension to these names)
metrics_embeddings_filename_template = dataset_name + "_%s_%s_seq"
for chain_kind in ("heavy", "light"):
    chain_rows = curated_dset['chain_type'] == chain_kind
    np.save(os.path.join(emb_dir, metrics_embeddings_filename_template % (chain_kind, metrics_embs_name)),
            metrics_df.loc[chain_rows].values)

### Transform AA counts into embeddings

In [None]:
def _to_one_hot(i):
    """Return a length-20 float vector that is 1 at index `i` and 0 elsewhere."""
    ohe = np.zeros(20)
    ohe[i] = 1
    return ohe

# The 20 standard amino acids in alphabetical order; this order defines the embedding columns
aa_list = sorted(['A','R','F','N','D','C','E','Q','G','H','I','L','K','M','P','S','T','W','Y','V'])
# Residue letter -> one-hot vector
aa_dict = dict([(x,_to_one_hot(i)) for i,x in enumerate(aa_list)])

def aa_freq_embedding(dataframe, seq_column, skip_unknown=False):
    """Embed each sequence as the relative frequencies of the 20 standard amino acids.

    Args:
        dataframe: table holding the sequences.
        seq_column: name of the column of `dataframe` containing amino-acid strings.
        skip_unknown: if True, residues outside `aa_list` (e.g. 'X') are ignored and the
            frequencies are computed over standard residues only. If False (default), they
            raise a ValueError.

    Returns:
        Float array of shape (n_rows, 20). Column j is the frequency of aa_list[j] and each row
        sums to 1; an empty dataframe gives shape (0, 20).

    Raises:
        ValueError: on a non-standard residue (unless `skip_unknown`), or on a sequence with no
            standard residue, whose frequencies would be undefined.
    """
    def _count_residues(seq):
        # Sum of the one-hot vectors of the residues in `seq`
        counts = np.zeros(len(aa_list))
        for residue in seq:
            one_hot = aa_dict.get(residue)
            if one_hot is None:
                if skip_unknown:
                    continue
                raise ValueError(f"Non-standard residue {residue!r} in sequence {seq!r}")
            counts += one_hot
        if counts.sum() == 0:
            raise ValueError(f"Sequence {seq!r} contains no standard amino acid")
        return counts

    count_rows = [_count_residues(seq) for seq in dataframe[seq_column]]
    if not count_rows:
        # np.vstack cannot stack an empty list: return an empty embedding instead
        return np.zeros((0, len(aa_list)))
    aa_count_array = np.vstack(count_rows)
    aa_freq_array = aa_count_array / aa_count_array.sum(axis=1, keepdims=True)
    return aa_freq_array

In [None]:
# Amino-acid frequency embedding: one row per curated sequence, columns ordered as aa_list
curated_aa_freqs = aa_freq_embedding(curated_dset, "sequence")

In [None]:
# Average amino-acid composition across all curated sequences
plt.bar(aa_list, curated_aa_freqs.mean(axis=0));
plt.title("Fraction of AAs");

In [None]:
# Save embeddings
# (written as <dataset dir>/<dataset_name>_aas_seq.npy)
np.save(os.path.join(get_dataset_dir(dataset_name), f"{dataset_name}_aas_seq.npy"), curated_aa_freqs)