# License

Licensed under the Apache License, Version 2.0 (the "License")
```
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
```

# Setup

In [None]:
# Uncomment to install the covid_vhh_design package

# !pip install git+https://github.com/google-research/google-research.git#subdirectory=covid_vhh_design

In [None]:
#@title Imports

import collections
import os

import altair as alt
from IPython import display
import immutabledict
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import gridspec
import seaborn as sns
import statsmodels.stats.multitest
import scipy
import numpy as np
import pandas as pd

In [None]:
from covid_vhh_design import helper
from covid_vhh_design import covid
from covid_vhh_design import plotting
from covid_vhh_design import utils

# Settings

In [None]:
plotting.update_rcparams()

%config InlineBackend.figure_format = 'retina'

pd.set_option('display.width', 200)
pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_rows', 200)

# Constants

In [None]:
# Target constants re-exported from the project modules for brevity.
COV_WT = covid.COV_WT
COV1_WT = covid.COV1_WT
COV2_WT = covid.COV2_WT
PARENT_NAME = covid.PARENT_NAME
SHUFFLED = utils.SHUFFLED

COLOR_PALETTE = plotting.PALETTE

# Long target name -> short display name, and the display order it implies.
TARGET_NAME_MAPPING = covid.TARGET_SHORT_NAME_MAPPING
TARGET_NAME_ORDER = tuple(TARGET_NAME_MAPPING.values())

# Utils

In [None]:
def concat_data_rounds(data_rounds):
  """Stacks per-round DataFrames into one, tagging each row with its round.

  Args:
    data_rounds: Mapping from round key to DataFrame.

  Returns:
    A single DataFrame with a new `round` column holding each row's key and a
    fresh RangeIndex.
  """
  tagged_frames = []
  for round_key, round_df in data_rounds.items():
    tagged_frames.append(round_df.assign(round=round_key))
  return pd.concat(tagged_frames, ignore_index=True)

# Initialize global variables

In [None]:
# Global containers: loaded data (`Data`, also reachable as `D`), derived
# computations (`C`), and scratch values (`T`).
Data = D = helper.Bunch()
C = helper.Bunch()
T = helper.Bunch()

# Load and clean data

## AlphaSeq data

In [None]:
# @title Load data

def load_alpha_seq_data():
  """Loads the AlphaSeq data of all three libraries.

  Returns:
    A dict keyed by 0-indexed round; key `i` holds
    `covid.load_alphaseq_data(i + 1)`.
  """
  return {round_idx: covid.load_alphaseq_data(round_idx + 1)
          for round_idx in (0, 1, 2)}


Data.raw_alpha_seq = load_alpha_seq_data()

In [None]:
#@title Fix annotation mistakes
def multiple_annotations_within_round(alphaseq_data):
  """Finds sequences that carry several standardized groups within one round.

  Args:
    alphaseq_data: Mapping from round key to that round's AlphaSeq DataFrame.

  Returns:
    A DataFrame with one row per (`source_seq`, `round`,
    `source_num_mutations`) that has more than one distinct
    `source_std_group`, plus a `unique` column listing those groups.
  """
  stacked = pd.concat(
      [round_df.assign(round=key) for key, round_df in alphaseq_data.items()],
      ignore_index=True)
  group_counts = (
      stacked
      .groupby(['source_seq', 'round', 'source_num_mutations'])
      ['source_std_group']
      .agg(['nunique', 'unique']))
  ambiguous = group_counts[group_counts['nunique'] != 1]
  return ambiguous['unique'].reset_index()


# This sequence is annotated both as `best_round1` and `baseline_r0`.
_SPECIAL_SEQ = (
    'QVQLQESGGGLVQAGGSLRLSCAASGFTFSEYAMGWFRQAPGKEREFVATISWSGRSTYYTD'
    'SVKGRFTISRDNAKNTVYLQMNSLKPDDTAVYYCASAGLFTYVSEWDYDYDYWGQGTQVTVSS')

# Find sequences whose `source_std_group` is ambiguous within a round. Conflicts
# on single mutants are handled later (see the comment on the correction
# below); the only other conflict we expect is `_SPECIAL_SEQ` in round key 2.
mistaken_annotations = multiple_annotations_within_round(Data.raw_alpha_seq)
mistaken_non_singles = mistaken_annotations[
    mistaken_annotations['source_num_mutations'] != 1]
assert len(mistaken_non_singles) == 1
assert helper.get_unique_value(
    mistaken_non_singles['source_seq']) == _SPECIAL_SEQ
assert helper.get_unique_value(
    mistaken_non_singles['round']) == 2

# Correct annotation (singles are fixed in the `annotate` call).
# NOTE(review): this writes `source_group`, while the conflict was detected on
# `source_std_group`; presumably `covid.annotate_alphaseq_data` derives the
# latter from the former. Confirm.
Data.raw_alpha_seq[2].loc[Data.raw_alpha_seq[2]['source_seq'] == _SPECIAL_SEQ,
                          'source_group'] = 'baseline_r0'

display.display(mistaken_annotations)

In [None]:
#@title Annotate & filter data
# Annotate the raw per-round AlphaSeq data, then filter it. `%time` reports how
# long this step takes.
%time Data.alpha_seq = covid.filter_alphaseq_data(covid.annotate_alphaseq_data(Data.raw_alpha_seq))

In [None]:
#@title Write updated annotations to disk
# NOTE(review): despite the title, this cell and the next one only display the
# updated annotations; nothing is written to disk.
def sequences_with_updated_annotations(alphaseq_data):
  """Tabulates the old annotations of sequences whose group changed.

  Args:
    alphaseq_data: Mapping from 0-indexed round key to an annotated AlphaSeq
      DataFrame with `source_std_group` and `new_source_std_group` columns.

  Returns:
    A DataFrame indexed by (`source_seq`, `new_source_std_group`) with one
    column per 1-indexed round, holding the first original
    `source_std_group` seen in that round.
  """
  frames = [frame.assign(round=key + 1) for key, frame in alphaseq_data.items()]
  combined = pd.concat(frames, ignore_index=True)
  changed = combined['source_std_group'] != combined['new_source_std_group']
  changed_seqs = combined.loc[changed, 'source_seq'].unique()
  combined = combined[combined['source_seq'].isin(changed_seqs)]
  first_old_group = (
      combined
      .groupby(['source_seq', 'round', 'new_source_std_group'])
      ['source_std_group']
      .first()
      .reset_index())
  return first_old_group.pivot(
      index=['source_seq', 'new_source_std_group'],
      columns='round',
      values='source_std_group')


# Preview the before/after annotation table.
updated_annos = (
    sequences_with_updated_annotations(Data.alpha_seq)
    .reset_index()
)
display.display(updated_annos.head())

## BLI data

In [None]:
# Load the two BLI result tables (v1 and v2).
Data.bli_v1, Data.bli_v2 = (
    covid.load_df('bli_v1.csv'), covid.load_df('bli_v2.csv'))

## Aggregate all sequence replicas

To compute the "representative" binding value of a sequence to a target, we
1. Standardize each experimental replicate to a mean of 0 and a standard deviation of 1 (`inf`s excluded).
2. Standardize by the parent, so that the parent binding for each experimental replicate has a median of 0 and an IQR of 1.
3. Aggregate per sequence by taking the median.

In [None]:
def get_metadata(alphaseq_data):
  """Returns `utils.get_metadata` of each round, keyed like the input."""
  return {key: utils.get_metadata(round_df)
          for key, round_df in alphaseq_data.items()}


def aggregate_affinities(alphaseq_data):
  """Returns `utils.aggregate_affinities` of each round, keyed like the input."""
  per_round = {}
  for key, round_df in alphaseq_data.items():
    per_round[key] = utils.aggregate_affinities(round_df)
  return per_round


def compute_pvalues(alphaseq_data,
                    alpha: float = 0.05,
                    correction_method: str = 'bonferroni',
                    min_replicas: int = 18):
  """Computes corrected binding p-values separately for every round.

  Args:
    alphaseq_data: Mapping from round key to that round's AlphaSeq DataFrame.
    alpha: Significance level forwarded to `utils.compute_pvalues`.
    correction_method: Multiple-testing correction forwarded to
      `utils.compute_pvalues`.
    min_replicas: Forwarded to `utils.compute_pvalues` as `min_replicas`.
      Presumably this is the minimum number of measurements a sequence needs
      before it is tested; confirm against `utils.compute_pvalues`.

  Returns:
    A dict mapping each round key to the output of `utils.compute_pvalues`.
  """
  # `utils.compute_pvalues` takes `min_replicas` (as in the direct last-round
  # call in this notebook), not `min_counts`.
  return {
      i: utils.compute_pvalues(
          alphaseq_data[i],
          alpha=alpha,
          correction_method=correction_method,
          min_replicas=min_replicas) for i in alphaseq_data
  }


# Per-round metadata and per-sequence aggregated affinities.
Data.metadata = get_metadata(Data.alpha_seq)
Data.aggregated = aggregate_affinities(Data.alpha_seq)

# p-values are only computed for the last round (key 2), restricted to the
# targets in `COV_WT`.
Data.r2_pvalues = utils.compute_pvalues(
    Data.alpha_seq[2].query('target_name in @COV_WT'),
    alpha=0.05,
    correction_method='bonferroni',
    min_replicas=18)

In [None]:
def join_aggregate_data_with_values(metadata, aggregated, pvalues):
  """Joins per-round metadata, aggregated affinities and p-values.

  Args:
    metadata: Mapping from round key to that round's metadata DataFrame.
    aggregated: Mapping from round key to that round's aggregated affinities.
    pvalues: Mapping from round key to that round's p-values, or None for
      rounds without p-values. Must contain every key of `metadata`.

  Returns:
    A dict mapping every key of `metadata` to the frame returned by
    `utils.join_aggregate_data_with_values`.
  """
  # Iterate over the rounds actually present rather than assuming exactly 3.
  return {
      i: utils.join_aggregate_data_with_values(
          metadata_df=metadata[i], agg_df=aggregated[i], pvalues_df=pvalues[i])
      for i in metadata
  }


# Only the last round (key 2) has p-values.
Data.main = join_aggregate_data_with_values(
    metadata=Data.metadata,
    aggregated=Data.aggregated,
    pvalues={0: None, 1: None, 2: Data.r2_pvalues},
)

In [None]:
# Aggregate affinities across all rounds, then attach the last-round p-values.
# `how='left'` keeps rows that have no p-value; their p-value columns are NaN.
Data.unified = utils.aggregate_over_rounds(Data.alpha_seq).merge(
    Data.r2_pvalues, how='left', on=['source_seq', 'target_name'])

# Standard deviation vs IQR

In [None]:
def _compare_std_and_iqr(raw_df, target_name):
  """Measures the spread of the parent's standardized values for one target.

  Args:
    raw_df: AlphaSeq DataFrame of a single round.
    target_name: Target whose measurements are analyzed.

  Returns:
    A dict with the parent values' standard deviation after max-imputing
    infs, their standard deviation after dropping infs, and their
    interquartile range (infs kept).
  """
  parent_df = (
      raw_df[raw_df['target_name'] == target_name]
      .copy()
      .pipe(utils.standardize_experimental_replicates)
      .pipe(utils.extract_parent_df))
  parent_values = parent_df['value']

  return {
      'std (max impute)': np.std(helper.max_impute_inf(parent_values)),
      'std (drop inf)': np.std(helper.drop_inf_df(parent_df, 'value')['value']),
      'IQR': scipy.stats.iqr(parent_values),
  }


def compare_std_and_iqr(raw_data, target_name):
  """Tabulates `_compare_std_and_iqr` for every round.

  Args:
    raw_data: Mapping from 0-indexed round key to AlphaSeq DataFrame.
    target_name: Target whose measurements are analyzed.

  Returns:
    A DataFrame indexed by 1-indexed `round`.
  """
  records = [
      {'round': key + 1, **_compare_std_and_iqr(raw_data[key], target_name)}
      for key in raw_data
  ]
  return pd.DataFrame(records).set_index('round')


compare_std_and_iqr(Data.alpha_seq, COV2_WT)

# AlphaSeq

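Rounds below are 0-indexed, matching the `round` column and the keys of `Data.alpha_seq`; figures and cell titles label them Round 1–3.

**Round 0**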
- 1 synthesis replica
- 3 experimental replicas

**Round 1**
- 1 synthesis replica
- 3 experimental replicas, but we drop the first one, leaving only `replica in [2, 3]`

**Round 2**
- 3 synthesis replicas
- 6 experimental replicas


In [None]:
#@title Experimental replicas {form-width: "30%"}

source_replica = 1  #@param
target_name = COV2_WT  #@param
# 0-indexed round key; named `round_idx` so the builtin `round` is not shadowed.
round_idx = 0  #@param


def _plot_experimental_replica_noise(ax, df, replica_x, replica_y):
  """Scatters one experimental replica's values against another's.

  Args:
    ax: Matplotlib axes to draw on.
    df: Long-format AlphaSeq rows with `source_seq`, `source_design`,
      `replica` and `value` columns.
    replica_x: Experimental replica plotted on the x-axis.
    replica_y: Experimental replica plotted on the y-axis.
  """
  df = df.pivot(
      index=['source_seq', 'source_design'], values='value', columns='replica')
  df.columns = [f'Exp. replica {c}' for c in df.columns]

  plotting.plot_correlations(
      ax,
      df,
      x_col=f'Exp. replica {replica_x}',
      y_col=f'Exp. replica {replica_y}',
      hue_col='source_design')


_, ax = plt.subplots(figsize=(5, 5))
_plot_experimental_replica_noise(
    ax,
    df=(
        Data.alpha_seq[round_idx]
        .query('target_name == @target_name')
        .query('source_replica == @source_replica')
    ),
    replica_x=2,
    replica_y=3)

In [None]:
#@title Synthesis replicas {form-width: "30%"}

experimental_replica = 1  #@param
target_name = COV2_WT  #@param
# 0-indexed round key; named `round_idx` so the builtin `round` is not shadowed.
round_idx = 2  #@param


def _plot_synthetic_replica_noise(ax, df, replica_x, replica_y):
  """Scatters one synthesis replica's values against another's.

  Args:
    ax: Matplotlib axes to draw on.
    df: Long-format AlphaSeq rows with `source_seq`, `source_design`,
      `source_replica` and `value` columns.
    replica_x: Synthesis replica plotted on the x-axis.
    replica_y: Synthesis replica plotted on the y-axis.
  """
  df = df.pivot(
      index=['source_seq', 'source_design'],
      values='value',
      columns='source_replica')
  df.columns = [f'Syn. replica {int(c)}' for c in df.columns]

  plotting.plot_correlations(
      ax,
      df,
      x_col=f'Syn. replica {replica_x}',
      y_col=f'Syn. replica {replica_y}',
      hue_col='source_design')


_, ax = plt.subplots(figsize=(5, 5))
_plot_synthetic_replica_noise(
    ax,
    df=(
        Data.alpha_seq[round_idx]
        .query('target_name == @target_name')
        .query('replica == @experimental_replica')),
    replica_x=1,
    replica_y=2)

In [None]:
#@title By round {form-width: "30%"}
round_x = 0  #@param
round_y = 2  #@param
target_name = COV2_WT  #@param


def _plot_experimental_round_noise(ax, df, round_x, round_y):
  """Scatters one round's aggregated values against another round's.

  Round keys are 0-indexed in `df` and in the arguments; labels are 1-indexed.
  """
  wide = df.pivot(
      index=['source_seq', 'source_design'], values='value', columns='round')
  wide.columns = [f'Round {c+1}' for c in wide.columns]
  plotting.plot_correlations(
      ax,
      wide,
      x_col=f'Round {round_x + 1}',
      y_col=f'Round {round_y + 1}',
      hue_col='source_design')


_, ax = plt.subplots(figsize=(5, 5))
rounds_df = pd.concat(
    [Data.main[idx].assign(round=idx) for idx in range(3)], ignore_index=True)
_plot_experimental_round_noise(
    ax,
    df=rounds_df.query('target_name==@target_name'),
    round_x=round_x,
    round_y=round_y)

In [None]:
#@title Source replicas {form-width: "30%"}

target_name = COV2_WT  #@param


def _plot_source_replica_noise(ax, df: pd.DataFrame, shrink=.8):
  """Draws a Spearman-correlation heatmap over all replica combinations.

  Each heatmap row/column is one (synthesis replica, experimental replica)
  pair; only synthesis replicas 1, 2 and 3 are kept.

  Args:
    ax: Matplotlib axes to draw on.
    df: Long-format AlphaSeq rows with `source_seq`, `source_replica`,
      `replica` and `value` columns.
    shrink: Shrink factor for the colorbar.
  """
  # Keep synthesis replicas 1-3 and cast them to int for clean labels.
  df = df[df['source_replica'].isin(
      [1, 2, 3])].assign(source_replica=df['source_replica'].astype(int))
  # One column per (source_replica, replica) pair; pivot_table averages any
  # duplicate entries (its default aggfunc is the mean).
  df = df.pivot_table(
      index=['source_seq'],
      values='value',
      columns=['source_replica', 'replica'])

  # Pairwise Spearman correlations between the replica columns.
  df = df.corr(method='spearman')
  df.columns = df.columns.to_flat_index()
  df.index = df.index.to_flat_index()

  # Experimental replica 1 is labeled "<source_replica> - 1"; the others only
  # by their experimental replica number. Presumably this marks where each
  # synthesis-replica group starts on the axes.
  df.columns = df.columns.map(lambda x: f'{x[0]} - {x[1]}'
                              if x[1] == 1 else x[1])
  df.index = df.index.map(lambda x: f'{x[0]} - {x[1]}' if x[1] == 1 else x[1])

  # The color scale is fixed to [0, 1], so negative correlations render as 0.
  sns.heatmap(
      df,
      linewidth=.5,
      vmin=0,
      vmax=1,
      cmap='rocket_r',
      cbar_kws={'shrink': shrink},
      ax=ax)
  ax.set_aspect('equal', adjustable='box')
  ax.set_xlabel('')
  ax.set_ylabel('')
  ax.xaxis.tick_top()  # x axis on top
  ax.tick_params(axis='x', labelrotation=90, labelsize=14)
  ax.tick_params(axis='y', labelrotation=0, labelsize=14)


_, ax = plt.subplots(figsize=(6, 6))
_plot_source_replica_noise(
    ax,
    df=Data.alpha_seq[2].query('target_name==@target_name'),
)

# $p$-value plots

In [None]:
def _plot_pvalue_vs_value(ax, df: pd.DataFrame, target_name=COV2_WT):
  """Scatters raw p-values against normalized log KD for one target.

  Points are colored by `source_design`.

  Returns:
    `ax`.
  """
  target_rows = df[df['target_name'] == target_name]
  sns.scatterplot(
      data=target_rows,
      x='value',
      y='pvalue',
      hue='source_design',
      palette=plotting.PALETTE,
      ax=ax,
  )
  ax.set_xlabel('Normalized log KD')
  ax.set_ylabel('Mann-Whitney U-test pvalue')
  return ax


def _plot_corrected_pvalue_vs_value(
    ax,
    df: pd.DataFrame,
    correction_method: str,
    target_name=COV2_WT,
):
  """Scatters multiple-testing-corrected p-values against normalized log KD.

  Rows without a p-value are dropped first, so the correction spans only the
  tested rows of `target_name`.

  Args:
    ax: Matplotlib axes to draw on.
    df: Aggregated data with `target_name`, `pvalue`, `value` and
      `source_design` columns.
    correction_method: Method name passed to statsmodels' `multipletests`.
    target_name: Target to plot.

  Returns:
    `ax`.
  """
  tested = df[(df['target_name'] == target_name) & df['pvalue'].notna()]
  corrected = statsmodels.stats.multitest.multipletests(
      tested['pvalue'], alpha=.05, method=correction_method)[1]
  tested = tested.assign(pvalue_corrected=corrected)

  sns.scatterplot(
      data=tested,
      x='value',
      y='pvalue_corrected',
      hue='source_design',
      palette=plotting.PALETTE,
      ax=ax)
  ax.set_xlabel('Normalized log KD')
  ax.set_ylabel(f'{correction_method.upper()}-corrected pvalue')
  return ax


f, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
_plot_pvalue_vs_value(ax1, Data.main[2])
_plot_corrected_pvalue_vs_value(
    ax2, Data.main[2], correction_method='bonferroni')

# Replace the two per-panel legends with one shared figure legend.
for ax in [ax1, ax2]:
  ax.legend().set_visible(False)

shared_handles_labels = ax1.get_legend_handles_labels()
handles, labels = shared_handles_labels
f.legend(handles, labels, ncol=4, bbox_to_anchor=(.7, 1.0), frameon=False)

# Percentage of Infs

In [None]:
def _compute_percentage_infs(values):
  is_infs = np.isinf(values)
  return (
      pd.DataFrame(dict(
        num_infs=is_infs.sum(),
        frac_infs=is_infs.mean(),
        num_samples=len(values)),
        index=[0])
      .assign(
          percentage_infs=lambda df: df['frac_infs'] * 100))


def compute_percentage_infs(df, group_vars):
  """Applies `_compute_percentage_infs` to the `value` column of each group.

  Args:
    df: DataFrame with a `value` column.
    group_vars: Grouping column name(s), normalized with `helper.to_list`.

  Returns:
    A DataFrame with the grouping columns moved out of the index and the inf
    summary columns.
  """
  group_cols = helper.to_list(group_vars)
  per_group = df.groupby(group_cols)['value'].apply(_compute_percentage_infs)
  return per_group.reset_index(group_cols)

## All rounds

In [None]:
def _plot_percentage_infs_rounds(ax, data):
  """Bar-plots the percentage of non-binders per round, colored by target.

  Args:
    ax: Matplotlib axes to draw on.
    data: `compute_percentage_infs` output grouped by `target_name` and
      `round`.

  Returns:
    `ax`.
  """
  labeled = helper.map_columns(
      data,
      target_name=covid.TARGET_SHORT_NAME_MAPPING,
      round=covid.ROUND_MAPPING)

  sns.barplot(
      data=labeled,
      x='round',
      y='percentage_infs',
      hue='target_name',
      order=covid.ROUND_MAPPING.values(),
      palette=plotting.PALETTE,
      ax=ax,
  )
  plotting.rotate_xlabels(ax, 30)
  ax.set_ylabel('% of non-binders')
  ax.legend(title='Target')
  return ax


# Percentage of inf (non-binding) measurements per target and round.
C.per_infs_rounds = compute_percentage_infs(
    df=concat_data_rounds(Data.alpha_seq).query('target_name in @COV_WT'),
    group_vars=['target_name', 'round'],
)

_, ax = plt.subplots(figsize=(6, 4))
_plot_percentage_infs_rounds(ax=ax, data=C.per_infs_rounds)

In [None]:
def _plot_percentage_infs_by_model(ax, data, model_mapping):
  """Bar-plots the percentage of non-binders per design model, by target.

  Args:
    ax: Matplotlib axes to draw on.
    data: `compute_percentage_infs` output grouped by `target_name` and
      `source_group`.
    model_mapping: Maps `source_group` values to display names. Groups that
      are not keys are dropped; its values set the x-axis order.

  Returns:
    `ax`.
  """
  relabeled = helper.map_columns(
      data.query('source_group in @model_mapping'),
      target_name=covid.TARGET_SHORT_NAME_MAPPING,
      source_group=model_mapping)

  sns.barplot(
      data=relabeled,
      x='source_group',
      y='percentage_infs',
      hue='target_name',
      order=model_mapping.values(),
      palette=plotting.PALETTE,
      ax=ax)
  plotting.rotate_xlabels(ax, 30)
  ax.set_ylabel('% of non-binders')
  # Legend centered above the axes, two columns, no frame.
  ax.legend(loc='center', bbox_to_anchor=(0.5, 1.1), ncol=2, frameon=False)

  return ax

In [None]:
#@title Round 1

C.per_infs_r0 = compute_percentage_infs(
    df=Data.alpha_seq[0].query('target_name in @COV_WT'),
    group_vars=['target_name', 'source_group'],
)

_, ax = plt.subplots(figsize=(10, 4))
_plot_percentage_infs_by_model(
    ax=ax,
    data=C.per_infs_r0,
    model_mapping=covid.ROUND0_SOURCE_GROUP_MAPPING,
)

In [None]:
# NOTE(review): hard-coded scratch values, not computed from data. A Pearson
# correlation cannot exceed 1 and a p-value cannot exceed 1, so these numbers
# cannot be the statistics their labels claim. This looks like a leftover
# formatting experiment; confirm, then remove it or compute the values.
x = 1.058224
y = 2.128489

print(f'Pearson score={x:.2f}\np-value={y:.2e}')

In [None]:
#@title Round 2

C.per_infs_r1 = compute_percentage_infs(
    df=Data.alpha_seq[1].query('target_name in @COV_WT'),
    group_vars=['target_name', 'source_group'],
)

_, ax = plt.subplots(figsize=(10, 4))
_plot_percentage_infs_by_model(
    ax=ax,
    data=C.per_infs_r1,
    model_mapping=covid.ROUND1_SOURCE_GROUP_MAPPING,
)

In [None]:
#@title Round 3

C.per_infs_r2 = compute_percentage_infs(
    df=Data.alpha_seq[2].query('target_name in @COV_WT'),
    group_vars=['target_name', 'source_group'],
)

_, ax = plt.subplots(figsize=(10, 4))
_plot_percentage_infs_by_model(
    ax=ax,
    data=C.per_infs_r2,
    model_mapping=covid.ROUND2_SOURCE_GROUP_MAPPING,
)

## By distance

In [None]:
def _plot_percentage_infs_by_distance(ax, data):
  """Bar-plots the percentage of non-binders by number of mutations.

  The parent (0 mutations) is excluded.

  Args:
    ax: Matplotlib axes to draw on.
    data: `compute_percentage_infs` output grouped by `target_name` and
      `source_num_mutations`.

  Returns:
    The axes returned by `sns.barplot`.
  """
  mutants = data.query('source_num_mutations > 0')
  mutants = helper.map_columns(
      mutants, target_name=covid.TARGET_SHORT_NAME_MAPPING)
  mutants = mutants.assign(
      source_num_mutations=mutants['source_num_mutations'].astype('int'))

  ax = sns.barplot(
      data=mutants,
      x='source_num_mutations',
      y='percentage_infs',
      hue='target_name',
      palette=plotting.PALETTE,
      ax=ax)
  ax.set_xlabel('Number of mutations from VHH-72')
  ax.set_ylabel('% of non-binders')
  ax.legend(loc='center', bbox_to_anchor=(0.5, 1.1), ncol=2, frameon=False)
  ax.set_xticklabels(ax.get_xticklabels(), ha='center')

  return ax


In [None]:
#@title Round 1
C.per_infs_r0_dist = compute_percentage_infs(
    df=Data.alpha_seq[0].query('target_name in @COV_WT'),
    group_vars=['target_name', 'source_num_mutations'],
)

_, ax = plt.subplots(figsize=(5.5, 4))
_plot_percentage_infs_by_distance(ax=ax, data=C.per_infs_r0_dist)
ax.legend().set_visible(False)

In [None]:
#@title Round 2
C.per_infs_r1_dist = compute_percentage_infs(
    df=Data.alpha_seq[1].query(
        'target_name in @COV_WT and source_design != @SHUFFLED'),
    group_vars=['target_name', 'source_num_mutations'],
)

_, ax = plt.subplots(figsize=(5.5, 4))
_plot_percentage_infs_by_distance(ax=ax, data=C.per_infs_r1_dist)

In [None]:
#@title Round 3
C.per_infs_r2_dist = compute_percentage_infs(
    df=Data.alpha_seq[2].query(
        'target_name in @COV_WT and source_design != @SHUFFLED'),
    group_vars=['target_name', 'source_num_mutations'],
)

_, ax = plt.subplots(figsize=(5.5, 4))
_plot_percentage_infs_by_distance(ax=ax, data=C.per_infs_r2_dist)
ax.legend().set_visible(False)

# log KD plots

## All rounds

In [None]:
def _plot_log_kd_by_round(ax, df, n=100, model_as_design: bool = False):
  """Box-plots each round's best CoV-2 log KDs, colored by design.

  Args:
    ax: Matplotlib axes to draw on.
    df: Cross-round aggregated data with `target_name`, `source_design`,
      `round`, `source_num_mutations` and `value` columns.
    n: Forwarded to `plotting.plot_log_kd` as `num_top_seqs` (the title calls
      these the top `n` sequences).
    model_as_design: If True, last-round ML designs are relabeled by model via
      `plotting.replace_ml_design_by_model`.

  Returns:
    The axes returned by the first `plotting.plot_log_kd` call.
  """
  df = df.query('target_name == @COV2_WT and source_design != @SHUFFLED')
  # Last-round rows plus every parent row (0 mutations).
  round2_df = df.query('round == 2 | source_num_mutations == 0')
  if model_as_design:
    round2_df = plotting.replace_ml_design_by_model(round2_df)

  # We only want new hues for the last round (assuming `model_as_design=True`),
  # so we have to be sneaky.
  # We draw two box plots. The first one covers the first two rounds, with
  # hues 'Baseline' and 'ML', in the simple call below.
  ax = plotting.plot_log_kd(
      ax,
      agg_df=df.query('round < 2'),
      max_impute_inf=True,
      num_top_seqs=n,
      x_col='round',
      hue_col='source_design',
      show_iqr=False,
  )

  # The second plotting call needs some manipulation to get the alignment to
  # work out: we only plot round 3, with its own hues ('Baseline', 'LGB', 'CNN'),
  # but to get these boxes correctly aligned on the x-axis, we "mock" Baseline
  # values (absurdly large numbers that are cut off by the y-limits set below)
  # for rounds 1 and 2.
  mocked_prev_round_values = pd.DataFrame(
      dict(
          round=[0, 1] * n,
          source_design='Baseline',
          value=100,
          source_num_mutations=1,
      )
  )
  plotting.plot_log_kd(
      ax,
      agg_df=pd.concat([round2_df, mocked_prev_round_values]),
      max_impute_inf=True,
      num_top_seqs=n,
      x_col='round',
      hue_col='source_design',
      show_iqr=True,
  )
  # These limits also hide the mocked values (100) from the second call.
  ax.set_ylim(-4.3, 1.3)

  ax.set_title(
      f'Log KDs of each round\'s top {n} sequences against CoV-2', y=-0.2
  )
  ax.set_ylabel('Normalized log KD against CoV-2')

  # Finally, some legend labels are duplicated (because of the two plotting
  # calls), so we deduplicate them. The dict keeps each label's first position
  # but its last handle.
  handles, labels = ax.get_legend_handles_labels()
  labels_to_handles = dict(zip(labels, handles))
  handles = list(labels_to_handles.values())
  labels = list(labels_to_handles.keys())
  ax.legend(
      handles=handles,
      labels=labels,
      bbox_to_anchor=(0.5, 1.1),
      ncol=5,
      loc='center',
      frameon=False,
  )
  return ax


_, ax = plt.subplots(figsize=(10, 5))
ax = _plot_log_kd_by_round(ax, Data.unified, n=100, model_as_design=True)

## All rounds by distance

In [None]:
def _plot_log_kd_by_mutations(ax, df, n=100):
  """Box-plots last-round CoV-2 log KDs by number of mutations and design.

  Keeps non-shuffled CoV-2 rows from the last round (key 2) plus the parent.

  Args:
    ax: Matplotlib axes to draw on.
    df: Cross-round aggregated data.
    n: Forwarded to `plotting.plot_log_kd` as `num_top_seqs`.
  """
  selected = df.query(
      'target_name == @COV2_WT and source_design != @SHUFFLED and (round == 2 or source_num_mutations == 0)'
  )

  plotting.plot_log_kd(
      ax,
      agg_df=selected,
      x_col='source_num_mutations',
      hue_col='source_design',
      num_top_seqs=n,
      max_impute_inf=True,
      show_iqr=True,
  )

  ax.set_title(f'Log KDs of top {n} round 3 sequences against CoV-2', y=-.3)
  ax.set_ylabel('Normalized log KD against CoV-2')


_, ax = plt.subplots(figsize=(12, 5))
_plot_log_kd_by_mutations(ax, Data.unified)

## All rounds separately by distance

In [None]:
def _plot_all_rounds_by_mutations(ax, data, *, include_shuffled, ml_only):
  """Box-plots CoV-2 log KDs by number of mutations, one hue per round.

  Args:
    ax: Matplotlib axes to draw on.
    data: Cross-round aggregated data with a 0-indexed `round` column.
    include_shuffled: Keep shuffled designs; ignored when `ml_only` is set.
    ml_only: Keep only ML designs plus the parent.
  """
  round_labels = data['round'].apply(lambda x: f'Round {x+1}')
  df = data.query('target_name == @COV2_WT').assign(round=round_labels)

  if ml_only:
    df = df.query('source_design == @utils.ML or source_num_mutations == 0')
  elif not include_shuffled:
    df = df.query('source_design != @SHUFFLED')

  ax = plotting.plot_log_kd(
      ax=ax,
      agg_df=df,
      x_col='source_num_mutations',
      hue_col='round',
      palette='Dark2',
      max_impute_inf=True,
      show_iqr=True,
  )
  ax.legend(loc='center', bbox_to_anchor=(0.5, 1.1), ncol=4, frameon=False)
  ax.set_ylabel('Normalized log KD against CoV-2')


_, ax = plt.subplots(figsize=(12, 5))
_plot_all_rounds_by_mutations(
    ax, Data.unified, include_shuffled=False, ml_only=True
)

## By models

In [None]:
def _plot_log_kd_by_model(ax, df: pd.DataFrame, model_mapping: dict[str, str]):
  """Box-plots log KDs against SARS-CoV-1 and SARS-CoV-2 for each model.

  Args:
    ax: Matplotlib axes to draw on.
    df: Aggregated data of one round with `target_short_name`,
      `source_group`, `source_num_mutations` and `value` columns.
    model_mapping: Maps `source_group` values to model display names; its
      values set the x-axis order.

  Returns:
    `ax`.

  Raises:
    AssertionError: If any parent value is not numerically zero.
  """
  df = df[df['target_short_name'].isin(['SARS-CoV-1', 'SARS-CoV-2'])].copy()

  # TODO: The text should note what mutants are and that we drop them here.
  df = df[df['source_group'] != 'mutant']
  df = df.assign(source_model=df['source_group'].map(model_mapping))

  # The parent's values differ slightly between the two targets (all close to
  # zero), so they are snapped to exactly 0. Check the absolute value: a plain
  # `max() < tol` would also accept large negative values.
  parent_values = utils.extract_parent_df(df)['value'].to_numpy()
  assert np.abs(parent_values).max() < 1e-10
  df.loc[df['source_num_mutations'] == 0, 'value'] = 0.

  plotting.plot_log_kd(
      ax,
      agg_df=df,
      max_impute_inf=True,
      x_col='source_model',
      order=model_mapping.values(),
      hue_col='target_short_name',
      show_iqr=True,
  )
  ax.set_xlabel('')
  ax.set_xticklabels(labels=ax.get_xticklabels(), rotation=30, ha='right')
  ax.legend(bbox_to_anchor=(.5, 1.1), loc='center', ncol=3, frameon=False)
  return ax


In [None]:
#@title Round 1
_, ax = plt.subplots(figsize=(10, 4))
_plot_log_kd_by_model(
    ax=ax,
    df=Data.main[0],
    model_mapping=covid.ROUND0_SOURCE_GROUP_MAPPING,
)

In [None]:
#@title Round 2
_, ax = plt.subplots(figsize=(10, 4))
_plot_log_kd_by_model(
    ax=ax,
    df=Data.main[1],
    model_mapping=covid.ROUND1_SOURCE_GROUP_MAPPING,
)

In [None]:
#@title Round 3
_, ax = plt.subplots(figsize=(6, 4))
_plot_log_kd_by_model(
    ax=ax,
    df=Data.main[2],
    model_mapping=covid.ROUND2_SOURCE_GROUP_MAPPING,
)

## By distance

In [None]:
def _plot_round_by_mutations(ax, df: pd.DataFrame):
  """Box-plots SARS-CoV-1/2 log KDs by number of mutations from the parent.

  Args:
    ax: Matplotlib axes to draw on.
    df: Aggregated data of one round with `target_short_name`,
      `source_num_mutations` and `value` columns.

  Raises:
    AssertionError: If any parent value is not numerically zero.
  """
  df = df[df['target_short_name'].isin(['SARS-CoV-1', 'SARS-CoV-2'])].copy()

  # Since we plot multiple targets, the parent binding values may differ
  # slightly; all must be ~0 (checked in absolute value, so large negative
  # values fail too) before they are snapped to exactly 0.
  parent_values = utils.extract_parent_df(df)['value'].to_numpy()
  assert np.abs(parent_values).max() < 1e-10
  df.loc[df['source_num_mutations'] == 0, 'value'] = 0.

  plotting.plot_log_kd(
      ax=ax,
      agg_df=df,
      x_col='source_num_mutations',
      hue_col='target_short_name',
      max_impute_inf=True,
      show_iqr=True,
  )

In [None]:
#@title Round 1
_, ax = plt.subplots(figsize=(7, 5))
_plot_round_by_mutations(ax=ax, df=Data.main[0])

In [None]:
#@title Round 2
_, ax = plt.subplots(figsize=(6, 4))
_plot_round_by_mutations(
    ax=ax, df=Data.main[1].query('source_design != @SHUFFLED'))

In [None]:
#@title Round 3
_, ax = plt.subplots(figsize=(10, 5))
_plot_round_by_mutations(
    ax=ax, df=Data.main[2].query('source_design != @SHUFFLED'))

# Supplemental figures for each round

Using gridspec to avoid manually dealing with alignment.

In [None]:
def supplemental_round_analysis_figure(alphaseq_data, agg_data, round_idx: int):
  """Draws the four-panel supplemental figure for one round.

  Panels: log KD by model (top left), log KD by number of mutations (top
  right), % non-binders by model (bottom left) and % non-binders by number of
  mutations (bottom right), under one shared target legend.

  Args:
    alphaseq_data: Mapping from 0-indexed round key to AlphaSeq DataFrame.
    agg_data: Mapping from 0-indexed round key to aggregated data.
    round_idx: 0-indexed round to plot (0, 1 or 2).
  """
  # Rounds 0 and 1 give the left (by-model) column more room; round 2 the right.
  width_ratios = [3, 2] if round_idx in [0, 1] else [2, 3]

  # Round 2's mapping is extended with display names for the baseline groups
  # carried over from earlier rounds.
  model_mappings = {
      0:
          covid.ROUND0_SOURCE_GROUP_MAPPING,
      1:
          covid.ROUND1_SOURCE_GROUP_MAPPING,
      2:
          covid.ROUND2_SOURCE_GROUP_MAPPING | {
              'baseline_r0': 'Recomb. & \nsingles (R1)',
              'baseline_r1': 'Recomb. & \nsingles (R2)',
          },
  }

  # Two GridSpecs with the same row layout: `legend_spec` only hosts the legend
  # row (its `top=.8` lowers the top of that grid); `figure_spec` hosts the four
  # panels with the round-specific column widths.
  fig = plt.figure(figsize=(20, 16))
  legend_spec = gridspec.GridSpec(
      3,
      2,
      top=.8,
      height_ratios=[.15, 1, 1],
  )
  figure_spec = gridspec.GridSpec(
      3,
      2,
      width_ratios=width_ratios,
      hspace=0.5,
      height_ratios=[.15, 1, 1],
  )

  legend_ax = fig.add_subplot(legend_spec[0, :])
  plotting.make_target_legend(ax=legend_ax, with_parent=True)

  agg_df = agg_data[round_idx]
  alphaseq_df = alphaseq_data[round_idx]

  # Each panel hides its own legend in favor of the shared one above.
  ax1 = fig.add_subplot(figure_spec[1, 0])
  _plot_log_kd_by_model(ax1, agg_df, model_mapping=model_mappings[round_idx])
  ax1.legend().set_visible(False)

  ax2 = fig.add_subplot(figure_spec[1, 1])
  _plot_round_by_mutations(
      ax2, agg_df.query('source_design != @SHUFFLED'))
  ax2.legend().set_visible(False)

  ax3 = fig.add_subplot(figure_spec[2, 0])
  per_infs = compute_percentage_infs(
      alphaseq_df.query('target_name in @COV_WT'),
      group_vars=['target_name', 'source_group'])
  _plot_percentage_infs_by_model(
      ax3, per_infs, model_mapping=model_mappings[round_idx])
  ax3.legend().set_visible(False)

  ax4 = fig.add_subplot(figure_spec[2, 1])
  per_infs_dist = compute_percentage_infs(
      alphaseq_df.query(
          'target_name in @COV_WT and source_design != @SHUFFLED'),
      group_vars=['target_name', 'source_num_mutations'])
  _plot_percentage_infs_by_distance(ax4, per_infs_dist)
  ax4.legend().set_visible(False)


supplemental_round_analysis_figure(Data.alpha_seq, Data.main, round_idx=0)

In [None]:
supplemental_round_analysis_figure(
    alphaseq_data=Data.alpha_seq, agg_data=Data.main, round_idx=1)

In [None]:
supplemental_round_analysis_figure(
    alphaseq_data=Data.alpha_seq, agg_data=Data.main, round_idx=2)

# Hit rate

In [None]:
#@title By distance

def _plot_hits_by_mutations(ax, df, model_as_hue: bool = False):
  """Plots the last-round hit rate against SARS-CoV-2 by number of mutations.

  Keeps CoV-2 rows that have a corrected p-value (the `x == x` comparison in
  the query drops NaNs) from the last round (key 2), plus the parent
  (`source_num_mutations == 0`).

  Args:
    ax: Matplotlib axes to draw on.
    df: Cross-round aggregated data with `target_name`, `round`,
      `source_num_mutations`, `source_design` and `pvalue_corrected` columns.
    model_as_hue: If True, ML designs are relabeled by model via
      `plotting.replace_ml_design_by_model`, giving each model its own hue.
  """
  df = df.query(
      'target_name == @COV2_WT and pvalue_corrected == pvalue_corrected '
      'and (round == 2 or source_num_mutations == 0)'
      )

  if model_as_hue:
    df = plotting.replace_ml_design_by_model(df)

  plotting.plot_hit_rate(
      ax=ax,
      agg_df=df,
      how='pvalue',
      x_col='source_num_mutations',
      hue_col='source_design',
      show_counts=True,
  )
  ax.set_title(
      r'Hits against SARS-CoV-2 (p $\leq$ 0.05) in round 3', y=-.3)


_, ax = plt.subplots(figsize=(12, 5))
_plot_hits_by_mutations(ax, Data.unified, model_as_hue=True)

In [None]:
#@title Comparison by pvalue correction
f, axes = plt.subplots(1, 3, figsize=(30, 5))

df = Data.unified.query(
    'target_name == @COV2_WT and pvalue_corrected == pvalue_corrected '
    'and (round == 2 or source_num_mutations == 0)')
# Re-correct the raw p-values with each method so that the methods can be
# compared side by side; element [1] of `multipletests` holds the corrected
# p-values.
df = df.assign(**{
    correction: statsmodels.stats.multitest.multipletests(
        df['pvalue'], alpha=.05, method=correction)[1]
    for correction in ('fdr_by', 'fdr_bh', 'bonferroni')
})

title = 'Round 3'

for ax, method in zip(axes, ['fdr_by', 'fdr_bh', 'bonferroni']):
  plotting.plot_hit_rate(
      ax=ax,
      agg_df=df.assign(pvalue_corrected=df[method]),
      how='pvalue',
      x_col='source_num_mutations',
      hue_col='source_design',
      show_counts=True)
  ax.set_ylabel(method.upper())

f.suptitle(title, y=-.03)

## Hits in last round, aggregated

In [None]:
def analyze_hits(df):
  """Displays, per design, how many last-round CoV-2 sequences are hits.

  A hit has `pvalue_corrected <= 0.05`; rows with a NaN corrected p-value
  count as non-hits. Only the last round (key 2) and the parent are used.

  Args:
    df: Cross-round aggregated data with `round`, `source_num_mutations`,
      `target_name`, `source_design` and `pvalue_corrected` columns.
  """
  # Read the p-values from the argument itself (not a global frame), so the
  # function works on any input and the index alignment is always correct.
  df = df.assign(significant_hit=df['pvalue_corrected'] <= .05)
  df = df[(df['round'] == 2) | (df['source_num_mutations'] == 0)]
  df = df[df['target_name'] == COV2_WT]

  display.display(
      df.groupby('source_design')['significant_hit'].agg(
          ['sum', 'count', 'mean']))


analyze_hits(Data.unified)

In [None]:
#@title By targets

def _cov1_hits_by_iqr(row):
  """Returns a stand-in corrected p-value for one row.

  For CoV-1 rows a "hit" is mocked from the normalized value instead of a
  p-value: 0 (a hit at any threshold) when `value <= 0`, i.e. at least as
  strong as the parent, else 1. Other targets keep their `pvalue_corrected`.
  """
  if row['target_name'] != COV1_WT:
    return row['pvalue_corrected']
  return 0 if row['value'] <= 0 else 1


def add_joint_targets(df, how: str):
  """Adds rows corresponding to hits to both CoV1 and CoV2.

  Keeps only CoV-1/CoV-2 rows (renamed to short target names) and appends,
  per sequence, a 'CoV1 & CoV2' row holding the worse (larger) of its two
  measurements, so a joint hit requires a hit against both targets.

  Args:
    df: Data with `target_name`, `source_seq`, `source_design`,
      `source_num_mutations` and the measurement column.
    how: 'pvalue' to use `pvalue_corrected`, or 'iqr' to use `value`.

  Raises:
    ValueError: If `how` is neither 'pvalue' nor 'iqr'.
  """
  pair = df[df['target_name'].isin([COV1_WT, COV2_WT])].copy()
  pair['target_name'] = pair['target_name'].map(covid.TARGET_SHORT_NAME_MAPPING)

  measurement_cols = {'pvalue': 'pvalue_corrected', 'iqr': 'value'}
  if how not in measurement_cols:
    raise ValueError('`how` argument must be either "pvalue" or "iqr".')
  measurement_col = measurement_cols[how]

  joint = (
      pair
      .groupby(['source_seq', 'source_design', 'source_num_mutations'])
      [measurement_col]
      .max()
      .reset_index()
      .assign(target_name='CoV1 & CoV2'))
  return pd.concat([pair, joint], ignore_index=True)


def _plot_hits_by_targets(ax, df, how, cov1_as_iqr: bool):
  """Plots last-round hit rates per target (CoV-1, CoV-2 and both).

  Args:
    ax: Matplotlib axes to draw on.
    df: Cross-round aggregated data.
    how: 'pvalue' or 'iqr', forwarded to `add_joint_targets` and
      `plotting.plot_hit_rate`.
    cov1_as_iqr: If True, CoV-1 hits are mocked from values via
      `_cov1_hits_by_iqr`.

  Returns:
    The frame that was plotted, including the joint-target rows.
  """
  last_round = df[(df['round'] == 2) | (df['source_num_mutations'] == 0)]
  if cov1_as_iqr:
    last_round = last_round.assign(
        pvalue_corrected=last_round.apply(_cov1_hits_by_iqr, axis=1))

  by_model = plotting.replace_ml_design_by_model(last_round)
  with_joint = add_joint_targets(by_model, how=how)

  plotting.plot_hit_rate(
      ax,
      agg_df=with_joint,
      how=how,
      x_col='target_name',
      hue_col='source_design',
      order=('SARS-CoV-1', 'SARS-CoV-2', 'CoV1 & CoV2'),
      show_counts=True)
  ax.set_title('Hits by target in round 3.', y=-.2)
  return with_joint


_, ax = plt.subplots(figsize=(7, 5.5))
df = _plot_hits_by_targets(ax, Data.unified, how='pvalue', cov1_as_iqr=True)

In [None]:
# @title Hits analysis


def describe_mutations(seq, ipos_to_pos, ipos_to_region, parent_seq=None):
  """Lists where `seq` differs from the parent sequence.

  Sequences are compared index by index; the comparison stops at the end of
  the shorter one.

  Args:
    seq: Sequence to describe.
    ipos_to_pos: Maps a 0-based index into `seq` to its position label.
    ipos_to_region: Maps a 0-based index into `seq` to its region name.
    parent_seq: Reference sequence; defaults to `covid.PARENT_SEQ`.

  Returns:
    A tuple `(positions, regions)`: the position labels of all mutated
    indices in order, and the sorted distinct regions they fall into.
  """
  if parent_seq is None:
    parent_seq = covid.PARENT_SEQ
  mutated_pos = []
  mutated_regions = set()
  for i, (mut, wt) in enumerate(zip(seq, parent_seq)):
    if mut != wt:
      mutated_pos.append(ipos_to_pos[i])
      mutated_regions.add(ipos_to_region[i])
  return mutated_pos, sorted(mutated_regions)


def describe_hits(df, target_name):
  """Describes sequences binding `target_name` >= 1 IQR better than the parent.

  Args:
    df: Aggregated data of one round with `target_name`, `source_seq`,
      `source_model`, `source_group` and `value` columns.
    target_name: Target to analyze.

  Returns:
    The hits (rows with `value <= parent value - 1`) plus the columns `pos`
    (mutated position labels) and `regions` (mutated regions), sorted by
    `value`, so the strongest binders come first. Returns an empty frame with
    these columns when there are no hits.
  """
  df = df[df['target_name'] == target_name].copy()
  parent_seq = covid.load_aligned_parent_seq(offset_ipos=0)
  ipos_to_pos = dict(zip(parent_seq['ipos'], parent_seq['pos']))
  ipos_to_region = dict(zip(parent_seq['ipos'], parent_seq['region']))

  parent_binding = helper.get_unique_value(
      utils.extract_parent_df(df)['value']
  )
  # Values are normalized so one unit is one parent IQR; a lower log KD means
  # stronger binding.
  hits = df[df['value'] <= parent_binding - 1].copy()

  # Build the descriptions explicitly: `zip(*...)` unpacking into two columns
  # raises a ValueError when there are no hits.
  descriptions = [
      describe_mutations(
          seq, ipos_to_pos=ipos_to_pos, ipos_to_region=ipos_to_region)
      for seq in hits['source_seq']
  ]
  hits['pos'] = pd.Series(
      [pos for pos, _ in descriptions], index=hits.index, dtype=object)
  hits['regions'] = pd.Series(
      [regions for _, regions in descriptions], index=hits.index, dtype=object)

  return hits[[
      'source_seq',
      'source_model',
      'source_group',
      'pos',
      'regions',
      'value',
  ]].sort_values(by='value')


# `design_round` is the 0-indexed round key into `Data.main`.
design_round = 0
header = (
    f'Sequences with >= 1 IQR improvement over parent in round {design_round}'
)
print(header)
describe_hits(Data.main[design_round], target_name=COV2_WT)

# Diversity

### t-SNE

In [None]:
# t-SNE embedding of the CoV-2 rows of the cross-round data.
# NOTE(review): t-SNE is usually stochastic; confirm that
# `utils.compute_tsne_embedding` fixes a random seed so reruns reproduce the
# figures below.
Data.tsne = utils.compute_tsne_embedding(
    Data.unified[Data.unified['target_name'] == COV2_WT])

In [None]:
# Flag t-SNE points whose sequence also appears in each BLI table. `isin`
# checks membership in one vectorized pass instead of recomputing `unique()`
# for every row.
Data.tsne['bli_v1'] = Data.tsne['source_seq'].isin(Data.bli_v1['source_seq'])
Data.tsne['bli_v2'] = Data.tsne['source_seq'].isin(Data.bli_v2['source_seq'])

In [None]:
# t-SNE embedding highlighted for the first round (key 0).
_, ax = plt.subplots(figsize=(7, 5))
plotting.plot_tsne_by_round(agg_df=Data.tsne, round_idx=0, ax=ax)

In [None]:
def make_tsne_figure():
  """Draws the t-SNE embedding for each round side by side under one legend."""
  spec = gridspec.GridSpec(
      nrows=2,
      ncols=3,
      height_ratios=[.05, 1],
      width_ratios=[1, 1, 1],
      hspace=0.2,
      wspace=0.25)

  fig = plt.figure(figsize=(20, 5.5))
  plotting.make_tsne_legend(
      ax=fig.add_subplot(spec[0, :]), with_initial_bli=True)

  # One panel per round: 0-indexed keys, 1-indexed titles.
  for round_idx, title in enumerate(('Round 1', 'Round 2', 'Round 3')):
    panel = fig.add_subplot(spec[1, round_idx])
    plotting.plot_tsne_by_round(panel, Data.tsne, round_idx=round_idx)
    panel.set_title(title, y=-.3)


make_tsne_figure()

## Distance by log KD

In [None]:
def _plot_log_kd_by_distance(ax, df, n, round_idx=2):
  """Plots log KD against distance to the parent for the top `n` sequences.

  Args:
    ax: Matplotlib axes to draw on.
    df: Measurements with `round` and `source_num_mutations` columns.
    n: Number of top sequences, forwarded to
      `plotting.plot_log_kd_by_distance` and shown in the title.
    round_idx: 0-indexed design round to plot. Defaults to 2 (the third round).
      The title is derived from it, so filter and title cannot drift apart.
  """
  # Keep the requested round plus all unmutated rows (presumably so the parent
  # is available as a reference point).
  in_round = df['round'] == round_idx
  is_unmutated = df['source_num_mutations'] == 0
  plotting.plot_log_kd_by_distance(
      ax=ax, df=df[in_round | is_unmutated], target_name=COV2_WT, n=n)
  # Titles are 1-indexed, matching the t-SNE panels.
  ax.set_title(f'Top {n} sequences in round {round_idx + 1}', y=-.3)

_, ax = plt.subplots(figsize=(7, 5))
_plot_log_kd_by_distance(ax, Data.unified, n=100)

In [None]:
def _plot_log_kd_by_mpd(ax, df, n: int):
  """Plots log KD against mean pairwise distance, ignoring shuffled designs.

  Args:
    ax: Matplotlib axes to draw on.
    df: Measurements with a `source_design` column.
    n: Forwarded to `plotting.plot_log_kd_by_mean_pairwise_distance`.
  """
  is_shuffled = df['source_design'] == SHUFFLED
  plotting.plot_log_kd_by_mean_pairwise_distance(
      ax, df[~is_shuffled], n=n, target_name=COV2_WT)

_, ax = plt.subplots(figsize=(7, 5))
_plot_log_kd_by_mpd(ax, Data.unified.query('round == 2'), n=50)

## Dendrogram

In [None]:
# Dendrogram of the sequences in `Data.bli_v2`. (The helper's name keeps the
# original "dendogram" spelling.)
_, ax = plt.subplots(figsize=(5, 12))
plotting.plot_dendogram(ax, Data.bli_v2)

# BLI and neutralization experiments

In [None]:
def join_curves_with_bli_data(
    curves: pd.DataFrame, bli: pd.DataFrame
) -> pd.DataFrame:
  """Attaches BLI summary statistics to every curve row.

  Rows are matched on the (`target_name`, `source_pid`) pair using
  `helper.safe_merge`.
  """
  join_keys = ['target_name', 'source_pid']
  return helper.safe_merge(curves, bli, on=join_keys)


def aggregate_curves_over_replicas(curves: pd.DataFrame) -> pd.DataFrame:
  """Averages BLI or neutralization curves across replicas.

  The curve type is detected from the columns. Neutralization curves carry a
  `meas_replica` column, which is kept as a grouping key. BLI curves are
  grouped by `time` instead. The `replica` and `source_key` columns are
  dropped (if present) before averaging.
  """
  is_neutralization = 'meas_replica' in curves.columns
  group_keys = (
      ['target_name', 'source_pid', 'meas_replica', 'conc']
      if is_neutralization
      else ['target_name', 'source_pid', 'conc', 'time']
  )
  trimmed = curves.drop(columns=['replica', 'source_key'], errors='ignore')
  averaged = trimmed.groupby(group_keys).mean()
  return averaged.reset_index()


def preprocess_curves(curves: pd.DataFrame, bli: pd.DataFrame) -> pd.DataFrame:
  """Averages curves over replicas, then joins in the BLI summary statistics.

  Args:
    curves: Raw BLI or neutralization curves.
    bli: BLI summary table, joined on (`target_name`, `source_pid`).

  Returns:
    The averaged curves with the BLI columns attached.
  """
  averaged = aggregate_curves_over_replicas(curves)
  return join_curves_with_bli_data(averaged, bli=bli)


# Store the replica-averaged, BLI-annotated curves on the shared `Data`
# container (the same object that holds `Data.bli_v2`).
Data.bli_curves = preprocess_curves(
    covid.load_df('bli_v2_curves.csv.gz', compression='gzip'), Data.bli_v2
)
Data.neutralization_curves = preprocess_curves(
    covid.load_df('bli_v2_neutralization_curves.csv.gz', compression='gzip'),
    Data.bli_v2,
)

## BLI charts

In [None]:
# X-axis titles of the BLI bar charts, keyed by metric column.
BLI_METRIC_MAPPING = immutabledict.immutabledict({
    'kd': 'KD (nM)',
    'ic50': 'IC50 (nM)',
    'prod': 'Expression (mg/L)',
})


# Color palette used for the bars of each BLI metric.
BLI_COLOR_PALETTES = immutabledict.immutabledict({
    'kd': 'spectral_a',
    'ic50': 'plasma_d',
    'prod': 'viridis',
})


# Maps `source_key` to the display label of each BLI-tested sequence. The
# insertion order also defines the row order of the BLI bar charts.
BLI_LABELS_MAPPING = collections.OrderedDict([
    (63207, 'Seq1'),
    (63411, 'Seq2'),
    (64284, 'Seq3'),
    (63052, 'Seq4'),
    (61433, 'Seq5'),
    (62541, 'Seq6'),
    (61300, 'Seq7'),
    (63385, 'Seq8'),
    (60461, 'Seq9'),
    (60441, 'Seq10'),
    (62851, 'Seq11'),
    (9643, 'Seq12'),
    (60008, 'VHH-72'),
])


# IC50 values equal to one of these are treated as lower bounds and rendered
# with a '>' prefix. NOTE(review): presumably the highest tested
# concentrations (nM) — confirm.
IC50_LOWER_BOUNDS = (500, 1000)

# Bar and label color used for the parent sequence.
PARENT_COLOR = '#636363'


def bli_value_is_lower_bound(metric: str, record: pd.Series) -> bool:
  """Returns whether `record[metric]` is a censored IC50 value.

  Only IC50 can be censored; any other metric always yields False. An IC50 is
  a lower bound when it equals one of IC50_LOWER_BOUNDS.
  """
  if metric != 'ic50':
    return False
  return record[metric] in IC50_LOWER_BOUNDS


def _plot_bli_bar(
    data,
    metric,
    max_value,
    labels,
    show_labels,
    width=70,
    height=210):
  """Plots the BLI measurement of a single target (one column).

  Builds a layered Altair chart: horizontal bars of `metric` per sequence, a
  value label next to each bar, and (optionally) sequence names on the left.

  Args:
    data: Rows for a single target. Must contain `key`, `source_is_parent`,
      `target_name` and the `metric` column; `{metric}_fold` is optional.
    metric: Column to plot, e.g. 'kd', 'ic50' or 'prod'.
    max_value: Upper end of the x scale. Bars are clamped there and larger
      values are labelled `>max_value`.
    labels: Mapping whose values define the row (y) order.
    show_labels: Whether to draw sequence names left of the bars.
    width: Chart width passed to `chart.properties`.
    height: Chart height passed to `chart.properties`.

  Returns:
    The layered Altair chart.
  """
  # Rank rows by `metric` (ascending, ties broken by row order). The rank-1
  # (lowest-value) row is marked as best, except for expression ('prod'),
  # where no row is highlighted.
  df = data.assign(
      rank_metric=lambda df: df[f'{metric}'].rank(method='first'),
      is_best=lambda df: metric != 'prod' and df['rank_metric'] == 1)

  # Assumes at least one parent row; `.iloc[0]` raises IndexError otherwise.
  parent_record = df.query('source_is_parent')
  parent_record = parent_record.iloc[0]
  parent_is_lower_bound = bli_value_is_lower_bound(metric, parent_record)
  metric_fold = f'{metric}_fold'

  # Formats the value label of one row. Clamped values and censored IC50s get
  # a '>' prefix. The best row also shows its fold change, itself prefixed
  # with '>' when the parent's value is only a lower bound.
  def _get_text_label(row):
    value = row[metric]
    if value > max_value:
        label = f'>{int(max_value):d}'
    elif bli_value_is_lower_bound(metric, row):
      label = f'>{int(value):d}'
    else:
      label = f'{value:.1f}'
    if row['is_best'] and metric_fold in row:
      fold_value = row[metric_fold]
      fold_label = f'{fold_value:.1f}x'
      if parent_is_lower_bound:
        fold_label = f'>{fold_label}'
      label = f'{label} ({fold_label})'
    return label

  df = df.assign(text=lambda df: df.apply(_get_text_label, axis=1))

  # Tooltip fields, restricted to the columns that exist in `df`.
  tooltip = helper.ordered_intersection(
      ['key', 'source_std_group', 'source_num_mutations',
       'kd', 'kd_fold', 'ic50', 'ic50_fold', 'ka', 'kdis'],
      df.columns)

  # Shared y encoding: one row per sequence label, ordered like `labels`.
  def _get_y(axis=None):
    return alt.Y(
        'key:N',
        title='',
        sort=list(labels.values()),
        axis=axis)

  base = alt.Chart(df).encode(
      y=_get_y(axis=alt.Axis(labels=False)),
      x=alt.X(metric, title=BLI_METRIC_MAPPING.get(metric, metric)),
      tooltip=tooltip,
  )

  # The parent bar uses a fixed grey; other bars are colored by `metric`. The
  # x scale is fixed to [0, max_value] with clamping.
  bars = base.mark_bar().encode(
      x=alt.X(metric, scale=alt.Scale(domain=(0, max_value), clamp=True)),
      color=alt.condition(
          alt.datum.source_is_parent,
          alt.value(PARENT_COLOR),
          plotting.get_alt_color(
            metric,
            palette=BLI_COLOR_PALETTES.get(metric),
            legend=alt.Legend(
                orient='none',
                direction='horizontal',
                legendX=0,
                legendY=-50),
            title=helper.get_unique_value(df['target_name'])))
  )

  # Sequence-name labels, right-aligned at x=-5 (left of the plot area); only
  # rows matching `condition` are visible.
  def _get_labels(condition, color='black', **kwargs):
    return alt.Chart(df).mark_text(
        align='right',
        **kwargs
    ).encode(
        x=alt.value(-5),
        y=_get_y(),
        text='key:N',
        color=alt.condition(condition, alt.value(color), alt.value('black')),
        opacity=alt.condition(
            condition, alt.value(1), alt.value(0)
        ) if show_labels else alt.value(0)
    )

  # Value labels placed 3 px right of each bar end; only rows matching
  # `condition` are visible.
  def _get_text(condition, **kwargs):
    return base.mark_text(
        align='left',
        baseline='middle',
        dx=3,
        **kwargs
    ).encode(
        text='text',
        opacity=alt.condition(condition, alt.value(1), alt.value(0))
    )

  # Two complementary text layers: regular labels, plus a bold red one for the
  # best row.
  chart = (
      bars
      + _get_text(~alt.datum.is_best)
      + _get_text(alt.datum.is_best, color='red', fontWeight='bold')
  )
  if show_labels:
    chart += (
      _get_labels(~alt.datum.source_is_parent, color='black')
      + _get_labels(
          alt.datum.source_is_parent, color=PARENT_COLOR, fontWeight='bold'))
  return chart.properties(width=width, height=height)


def plot_bli_bars(bli, metric, max_value, width=70, **kwargs):
  """Plots the BLI measurement of multiple targets, one column per target.

  Targets follow TARGET_NAME_ORDER. Targets whose `metric` is entirely missing
  are skipped. Only the first drawn column shows the sequence names.

  Args:
    bli: BLI table with `source_key`, `target_name` and `metric` columns.
    metric: Column to plot, e.g. 'kd' or 'ic50'.
    max_value: Clamp value for the bars (see `_plot_bli_bar`).
    width: Width of each column; also used as the legend gradient length.
    **kwargs: Forwarded to `_plot_bli_bar`.

  Returns:
    The horizontally concatenated Altair chart, with independent color scales.
  """
  labelled = bli.assign(
      key=lambda frame: frame['source_key'].map(BLI_LABELS_MAPPING))
  ordered_targets = helper.ordered_intersection(
      TARGET_NAME_ORDER, labelled['target_name'])

  columns = []
  for target_name in ordered_targets:
    per_target = labelled.query('target_name == @target_name')
    if per_target[metric].isna().all():
      continue
    columns.append(
        _plot_bli_bar(
            per_target,
            labels=BLI_LABELS_MAPPING,
            # Sequence names only on the first (left-most) drawn column.
            show_labels=not columns,
            metric=metric,
            max_value=max_value,
            width=width,
            **kwargs))

  combined = alt.hconcat(*columns).resolve_scale(color='independent')
  return combined.configure_legend(gradientLength=width)

In [None]:
# KD (nM) per target; bars are clamped at 50.
plot_bli_bars(
    helper.map_columns(Data.bli_v2, target_name=TARGET_NAME_MAPPING),
    metric='kd',
    max_value=50,
)

In [None]:
# IC50 (nM) per target; bars are clamped at 500.
plot_bli_bars(
    helper.map_columns(Data.bli_v2, target_name=TARGET_NAME_MAPPING),
    metric='ic50',
    max_value=500,
)

## BLI curves

In [None]:
def plot_bli_curves(data, labels, with_legend=True, figsize=(6, 4)):
  """Plots BLI response-over-time curves, one line per sequence label.

  Args:
    data: BLI curve rows with `time`, `data` and `label` columns.
    labels: Labels to draw; also fixes the hue (color/legend) order.
    with_legend: If True, draw an untitled legend anchored at
      `bbox_to_anchor=(1.0, 1.0)`; otherwise hide the legend.
    figsize: Figure size passed to `plt.subplots`.

  Returns:
    The matplotlib axes the curves were drawn on.
  """
  _, ax = plt.subplots(figsize=figsize)
  sns.lineplot(
      data=data,
      x='time',
      y='data',
      hue='label',
      hue_order=labels,
      palette='tab10',
      linewidth=3,
      alpha=0.8,
      ax=ax,
  )
  ax.set(xlabel='Time (s)', ylabel='Response (nm)')
  if with_legend:
    ax.legend(title='', ncol=1, bbox_to_anchor=(1.0, 1.0))
  else:
    # `get_legend()` is None when seaborn drew no legend (e.g. no rows), in
    # which case there is nothing to hide.
    legend = ax.get_legend()
    if legend is not None:
      legend.remove()
  return ax


def plot_bli_curves_facet(data, labels, concentation=30):
  """Draws one BLI-curve figure per target at a single concentration.

  Only the figure of the last target in TARGET_NAME_ORDER gets a legend.

  Args:
    data: BLI curves with `conc`, `label` and `target_name` columns.
    labels: Sequence labels to include; also fixes the hue order.
    concentation: `conc` value to plot (parameter name kept as-is for
      compatibility with existing callers).
  """
  selected = data.query('conc == @concentation and label in @labels').pipe(
      helper.map_columns, target_name=TARGET_NAME_MAPPING
  )
  ordered_targets = helper.ordered_intersection(
      TARGET_NAME_ORDER, selected['target_name']
  )
  for target_name in ordered_targets:
    is_last = target_name == ordered_targets[-1]
    ax = plot_bli_curves(
        data=selected.query('target_name == @target_name'),
        labels=labels,
        with_legend=is_last,
    )
    ax.set_title(target_name)

In [None]:
# BLI curves (default concentration) of three designed sequences and VHH-72.
plot_bli_curves_facet(
    Data.bli_curves, labels=['Seq1', 'Seq2', 'Seq3', 'VHH-72']
)

## Neutralization curves

In [None]:
def plot_neutralization_curves(data, labels, with_legend=True, figsize=(6, 4)):
  """Plots % infectivity vs. concentration, one line per sequence label.

  Seaborn aggregates repeated measurements at each `conc` to their mean and
  draws standard-deviation error bars. The x axis is log-scaled.

  Args:
    data: Neutralization curve rows with `conc`, `value` and `label` columns.
    labels: Labels to draw; also fixes the hue (color/legend) order.
    with_legend: If True, draw an untitled legend anchored at
      `bbox_to_anchor=(1.0, 1.0)`; otherwise hide the legend.
    figsize: Figure size passed to `plt.subplots`.

  Returns:
    The matplotlib axes the curves were drawn on.
  """
  _, ax = plt.subplots(figsize=figsize)
  sns.lineplot(
      data=data,
      x='conc',
      y='value',
      hue='label',
      hue_order=labels,
      palette='tab10',
      linewidth=3,
      marker='o',
      dashes=True,
      err_style='bars',
      errorbar='sd',
      ax=ax,
  )
  ax.set(xlabel='Concentration (nM)', xscale='log', ylabel='% Infectivity')
  if with_legend:
    ax.legend(title='', ncol=1, bbox_to_anchor=(1.0, 1.0))
  else:
    # `get_legend()` is None when seaborn drew no legend (e.g. no rows), in
    # which case there is nothing to hide.
    legend = ax.get_legend()
    if legend is not None:
      legend.remove()
  return ax


def plot_neutralization_curves_facet(data, labels):
  """Draws one neutralization-curve figure per target.

  Only the figure of the last target in TARGET_NAME_ORDER gets a legend.

  Args:
    data: Neutralization curves with `label` and `target_name` columns.
    labels: Sequence labels to include; also fixes the hue order.
  """
  subset = data.query('label in @labels').pipe(
      helper.map_columns, target_name=TARGET_NAME_MAPPING
  )
  ordered_targets = helper.ordered_intersection(
      TARGET_NAME_ORDER, subset['target_name']
  )
  for target_name in ordered_targets:
    is_last = target_name == ordered_targets[-1]
    ax = plot_neutralization_curves(
        data=subset.query('target_name == @target_name'),
        labels=labels,
        with_legend=is_last,
    )
    ax.set_title(target_name)

In [None]:
# Neutralization curves of three designed sequences and VHH-72.
plot_neutralization_curves_facet(
    Data.neutralization_curves, labels=['Seq1', 'Seq2', 'Seq3', 'VHH-72']
)

# Model evaluation

## Constants

In [None]:
# Datasets that were used for training.
CONFIG_DSET_MAPPING = immutabledict.immutabledict({
    'r1': 'Round 1',
    'r2': 'Round 1+2',
    'r3': 'Round 1+2+3',
})

CONFIG_DSET_ORDER = tuple(CONFIG_DSET_MAPPING.values())

CONFIG_MODEL_MAPPING = immutabledict.immutabledict({
    'cnn': 'CNN',
    'lgb': 'LGB',
    'linear': 'Baseline',
})

MODEL_ORDER =  ('Baseline', 'CNN', 'LGB')

METRIC_MAPPING = immutabledict.immutabledict({
    'spearmanr': 'Spearman R',
    'roc_auc_score_parent': 'ROC AUC VHH-72',
})

MODEL_COLUMN_MAPPING = immutabledict.immutabledict({
    'config_dset': 'Training set',
    'config_model': 'Model',
})


# Targets that were used to train both r1 and r2 models.
JOINT_MODEL_TARGETS = (
    'SARS-CoV1_RBD',
    'SARS-CoV2_RBD',
    'SARS-CoV2_RBD_G502D',
    'SARS-CoV2_RBD_N439K',
    'SARS-CoV2_RBD_N501D',
    'SARS-CoV2_RBD_N501F',
    'SARS-CoV2_RBD_S477N',
    'SARS-CoV2_RBD_V367F',
)

def map_model_columns(df, ignore_missing=True, **kwargs):
  """Replaces model-config codes in `df` with their display names.

  Applies CONFIG_MODEL_MAPPING, METRIC_MAPPING and CONFIG_DSET_MAPPING to the
  `config_model`, `metric` and `config_dset` columns via `helper.map_columns`.

  Args:
    df: Model-evaluation results.
    ignore_missing: Unused; kept for signature compatibility.
      NOTE(review): possibly meant to be forwarded to `helper.map_columns` —
      confirm against its signature.
    **kwargs: Additional `column=mapping` pairs for `helper.map_columns`.
  """
  del ignore_missing  # Unused.
  column_mappings = dict(
      config_model=CONFIG_MODEL_MAPPING,
      metric=METRIC_MAPPING,
      config_dset=CONFIG_DSET_MAPPING,
  )
  return helper.map_columns(df, **column_mappings, **kwargs)

## Read data

In [None]:
def _read_model_eval_df(name):
  """Loads `<covid.DATA_DIR>/model_scores/<name>.csv.gz`.

  The path is printed with `helper.tprint` before it is read.
  """
  path = os.path.join(covid.DATA_DIR, 'model_scores', f'{name}.csv.gz')
  helper.tprint(path)
  scores = helper.read_csv(path)
  return scores


# Container for the model-evaluation tables read from
# `covid.DATA_DIR/model_scores/`.
M = helper.Bunch()
# Hold-out model evaluation scores.
M.scores = _read_model_eval_df('scores')
# Hold-out model evaluation scores per distance.
M.scores_dist = _read_model_eval_df('scores_dist')

In [None]:
def _read_cv_scores():
  """Loads the cross-validation scores of every training-set configuration.

  Returns:
    One data frame with a `config_dset` column ('r1'..'r3') identifying the
    training set, and the model column renamed from `config` to
    `config_model`.
  """
  # Training-set codes are 1-based ('r1'..'r3'), while the score files are
  # named with 0-based round indices ('cv_scores_r0'..'cv_scores_r2').
  cv_files = (
      ('r1', 'cv_scores_r0'),
      ('r2', 'cv_scores_r1'),
      ('r3', 'cv_scores_r2'),
  )
  frames = [
      _read_model_eval_df(file_name).assign(config_dset=config_dset)
      for config_dset, file_name in cv_files
  ]
  combined = pd.concat(frames, ignore_index=True)
  return combined.rename(columns={'config': 'config_model'})

# Cross-validation scores.
M.cv_scores = _read_cv_scores()

## Per round

In [None]:
def plot_model_performance_per_round(
    scores,
    metric,
    metric_label,
    config_dset=None,
    ylim=None,
    figsize=(5, 5),
    **kwargs):
  """Bar plot of one evaluation metric per training set and model.

  Only targets in JOINT_MODEL_TARGETS and models in CONFIG_MODEL_MAPPING are
  kept. When a `source_dset` column exists, only the "r2" evaluation set is
  used.

  Args:
    scores: Evaluation scores with `metric`, `target_name`, `config_model`,
      `config_dset` and `value` columns (optionally `source_dset`).
    metric: Raw metric name to plot, e.g. 'spearmanr'.
    metric_label: Y-axis label.
    config_dset: Optional training-set code or codes to keep,
      e.g. ['r1', 'r2'].
    ylim: Optional y-axis limits passed to `ax.set_ylim`.
    figsize: Figure size passed to `plt.subplots`.
    **kwargs: Forwarded to `sns.barplot`.
  """
  df = scores.query(
      'metric == @metric and target_name in @JOINT_MODEL_TARGETS and '
      'config_model in @CONFIG_MODEL_MAPPING')
  if 'source_dset' in df.columns:
    # The dataset that was used for evaluation. "r2" corresponds to the 3rd
    # library, which was not used for model-training.
    df = df.query('source_dset == "r2"')
  if config_dset:
    df = df.query('config_dset == @config_dset')
  df = map_model_columns(df)

  _, ax = plt.subplots(figsize=figsize)
  ax = sns.barplot(
      data=df,
      x='config_dset',
      order=helper.ordered_intersection(CONFIG_DSET_ORDER, df['config_dset']),
      y='value',
      hue='config_model',
      # Seaborn only treats `dict` instances as a hue->color lookup. `dict()`
      # accepts any mapping and is a no-op for a plain dict; this matches
      # `plot_model_performance_per_distance`.
      palette=dict(COLOR_PALETTE),
      hue_order=MODEL_ORDER,
      ax=ax,
      **kwargs)
  ax.set_xlabel('Training set')
  ax.set_ylabel(metric_label)
  if ylim is not None:
    ax.set_ylim(ylim)
  ax.legend(bbox_to_anchor=(.5, 1.1), loc='center', ncol=3, frameon=False)

### Hold-out

In [None]:
plot_model_performance_per_round(
    # Plot the performance of the joint regressor-classifier model ("joint").
    # config_type == "reg"/"cla" selects the results of the
    # regressor/classifier.
    M.scores.query('config_type == "joint"'),
    metric='spearmanr',
    metric_label='Hold-out Spearman R')

In [None]:
# Hold-out ROC AUC of the joint model. The y axis starts at 0.5, the chance
# level of ROC AUC.
plot_model_performance_per_round(
    M.scores.query('config_type == "joint"'),
    metric='roc_auc_score_parent',
    metric_label='Hold-out ROC AUC VHH-72',
    ylim=(0.5, None))

### Cross-validation

In [None]:
# Cross-validation Spearman R for models trained on rounds 1 and 1+2.
plot_model_performance_per_round(
    M.cv_scores,
    metric='spearmanr',
    metric_label='CV Spearman R',
    config_dset=['r1', 'r2'],
    ylim=(0.5, None))

In [None]:
# Cross-validation ROC AUC for models trained on rounds 1 and 1+2.
plot_model_performance_per_round(
    M.cv_scores,
    metric='roc_auc_score_parent',
    metric_label='CV ROC AUC VHH-72',
    config_dset=['r1', 'r2'],
    ylim=(0.75, None))

## Per distance

In [None]:
def plot_model_performance_per_distance(
    all_scores,
    metric='spearmanr',
    with_num_samples=True,
    ylim=None,
    ylabel=None,
    ylabel_prefix=None,
    figsize=(12, 5.5),
    **kwargs):
  """Line plot of an evaluation metric vs. number of mutations from the parent.

  Only joint models are included. Models must be in CONFIG_MODEL_MAPPING,
  targets in JOINT_MODEL_TARGETS, and the evaluation set must be "r2" (the 3rd
  library, not used for model training).

  Args:
    all_scores: Per-distance evaluation scores (e.g. `M.scores_dist`).
    metric: Raw metric name to plot, e.g. 'spearmanr'.
    with_num_samples: If True, display a table of evaluation sample counts and
      add the mean count per distance to the x tick labels.
    ylim: Optional y-axis limits passed to `ax.set_ylim`.
    ylabel: Y-axis label; defaults to the display name of `metric`.
    ylabel_prefix: Optional text prepended to the y-axis label.
    figsize: Figure size passed to `plt.subplots`.
    **kwargs: Forwarded to `sns.lineplot`.
  """
  scores = all_scores.query(
      'target_name in @JOINT_MODEL_TARGETS and config_type == "joint" and '
      'config_model in @CONFIG_MODEL_MAPPING and source_dset == "r2"')
  df = (
      scores
      .query('metric == @metric')
      .pipe(map_model_columns)
      .rename(columns=MODEL_COLUMN_MAPPING)
      # Seaborn orders string-valued x categories by first appearance, so sort
      # (stably) by distance to keep the x axis in increasing mutation order.
      .sort_values('source_num_mutations', kind='mergesort'))

  if with_num_samples:
    # Show the number of samples that were used for the evaluation, both as a
    # table and in the x tick labels.
    evaluator = helper.get_unique_value(df['evaluator'])
    df_samples = (
        scores.query('metric == "samples" and evaluator == @evaluator'))
    display.display(
        df_samples.pivot_table(
            index=['config_dset', 'config_model'],
            columns='source_num_mutations',
            values='value'))
    samples_per_dist = (
        df_samples.groupby('source_num_mutations')['value'].mean()
        .astype('int').to_dict())

    # Tick label: the distance, then the mean sample count on a second line;
    # counts above 1000 are shown truncated to thousands (e.g. '2k').
    def _get_xlabel(dist):
      n = samples_per_dist[dist]
      if n > 1000:
        n = f'{int(n // 1000)}k'
      return f'{dist:d}\nn={n}'

    df['x'] = df['source_num_mutations'].apply(_get_xlabel)
  else:
    df['x'] = df['source_num_mutations']

  _, ax = plt.subplots(figsize=figsize)
  sns.lineplot(
      data=df,
      x='x',
      y='value',
      hue='Model',
      hue_order=MODEL_ORDER,
      palette=dict(COLOR_PALETTE),
      style='Training set',
      style_order=['Round 1+2', 'Round 1'],
      markers=True,
      dashes=True,
      markersize=10,
      err_style='bars',
      ax=ax,
      **kwargs)
  ax.tick_params(axis='x', which='major', labelsize=15)
  ax.set_xlabel(f'Number of mutations from {PARENT_NAME}')

  if ylabel is None:
    ylabel = helper.get_unique_value(df['metric'])
  if ylabel_prefix is not None:
    ylabel = f'{ylabel_prefix} {ylabel}'
  ax.set_ylabel(ylabel)
  if ylim is not None:
    ax.set_ylim(ylim)
  ax.legend(
      loc='lower left', ncol=2, frameon=True)

In [None]:
# Hold-out Spearman R as a function of the number of mutations from the parent.
plot_model_performance_per_distance(
    M.scores_dist,
    metric='spearmanr',
    ylabel_prefix='Hold-out',
    ylim=(0.2, None))

In [None]:
# Hold-out ROC AUC as a function of the number of mutations from the parent.
plot_model_performance_per_distance(
    M.scores_dist,
    metric='roc_auc_score_parent',
    ylabel_prefix='Hold-out',
    ylim=(0.2, None))