# Overview

The object detection code identifies potential parasites and segments candidate patches into staining artifacts, hypnozoites, and schizonts. One concern: the object detector is trained on a limited set of human labeled examples. Drug candidates may affect parasite appearance in ways that lead to incorrect classification by the object detector. We are particularly concerned about damaged hypnozoites being mislabeled as staining artifacts.

In this Colab we do two things:

1) We build a simple model of parasite counts to get a rough estimate of the hypnozoite / artifact misclassification rate, then

2) We build a better hypnozoite / artifact classifier and use these estimates to calibrate it.


In [None]:
#@title Run this cell only the FIRST time you connect to the colab kernel
!pip install gcsfs
!git clone https://github.com/google/cell_img
!pip install --quiet -e cell_img
!pip install lightgbm
!pip install statsmodels --upgrade

In [None]:
#@title For Cloud VM kernel, run this after restarting before granting access

!ls /content/.config/
!rm /content/.config/gce
!rm /var/colab/mp

import os
# Remove the GCE-related environment overrides when present; a variable that
# is not set is simply left alone.
for _gce_var in ('NO_GCE_CHECK', 'GCE_METADATA_TIMEOUT'):
  os.environ.pop(_gce_var, None)

In [None]:
#@title Run this cell after restarting your kernel. It will pop up window to grant access.
from google.colab import auth
# Request credentials for this kernel (the cell title notes that a window
# pops up to grant access).
auth.authenticate_user()

In [None]:
# Placeholder values: replace with the real bucket and paths before running.
CLOUD_BUCKET = 'bucket_name'

# path to plate layout metadata
# (METADATA_PATH is used as a blob prefix inside the bucket; METADATA_PREFIX
# is the full gs:// prefix that the batch name is appended to)
METADATA_PATH = 'path/to/metadata'
METADATA_PREFIX = f'gs://{CLOUD_BUCKET}/{METADATA_PATH}/metadata_prefix'

# path to plate quality control annotations
PLATE_ANNOTATION_FILE = f'gs://{CLOUD_BUCKET}/path/to/annotations.csv'

# template for csv files with counts
# NOTE(review): this is formatted with `% batch` when the counts are loaded,
# so the real value must contain a `%s` placeholder for the batch name.
COUNT_CSV_TEMPLATE = f'gs://{CLOUD_BUCKET}/path/to/counts.csv'

# path on the bucket where output should be saved
OUTPUT_PATH = '/path/to/outputs'

# path to a parquet file containing patch embeddings
# NOTE(review): this is passed as `prefix` to load_parquet_to_emb_df, which
# formats it with `% batch` and lists blobs under the result, so the real
# value must contain a `%s` placeholder and is relative to the bucket.
PATCH_PARQUET_TEMPLATE='path/to/patches.parquet'

In [None]:
from typing import NamedTuple

import datetime
import fsspec
import gc
import os
import re
import time

from cell_img.analysis import optim_lib
import gcsfs
import jax
import jax.numpy as jnp
import lightgbm
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
import scipy
import seaborn as sns
import statsmodels.api as sm

from google.cloud import storage

In [None]:
# Utility methods for reading Cloud files

def glob_cloud(bucket, prefix, must_contain_str=None):
  """List blobs under a bucket prefix as full gs:// paths.

  Args:
    bucket: Name of the Cloud Storage bucket to list.
    prefix: Blob name prefix passed to the storage client listing call.
    must_contain_str: Optional substring; when given (and non-empty), only
      blob names containing it are returned.
  Returns:
    A list of 'gs://[bucket]/[blob name]' strings.
  """
  bucket_root = 'gs://%s' % bucket
  blob_names = [
      blob.name for blob in storage.Client().list_blobs(bucket, prefix=prefix)]
  if must_contain_str:
    blob_names = [name for name in blob_names if must_contain_str in name]
  return [os.path.join(bucket_root, name) for name in blob_names]


def format_plate_strings(plate_names):
  """Normalize plate names to zero-padded, five-character digit strings.

  Each name is converted to a string, anything from the first '.' onwards is
  discarded, and the remainder is left-padded with zeros to five characters.

  Args:
    plate_names: A pd.Series of plate names to format.
  Raises:
    ValueError: If a formatted name is longer than five characters or
      contains a non-digit character.
  Returns:
    formatted_plates: A pd.Series holding the formatted plate names.
  """
  def _pad(raw_name):
    # Keep only what precedes the first '.', then zero-pad to width 5.
    integer_part = raw_name.split('.')[0]
    return integer_part.zfill(5)

  formatted_plates = plate_names.astype(str).apply(_pad)
  names = list(formatted_plates.values)

  # Names that are too long cannot be represented in the five-digit scheme.
  if any(len(name) > 5 for name in names):
    raise ValueError('Plate name > 5 characters found')
  # Every name must consist of digits only.
  if not all(re.match(r'^\d+$', name) for name in names):
    raise ValueError('Plate with non-digit characters found')
  return formatted_plates


def compress_df(df: pd.DataFrame) -> pd.DataFrame:
  """Downcast 64-bit float and int columns to 32 bits.

  The dataframe is modified in place and also returned. Columns with any
  other dtype are left untouched.
  """
  narrowing = ((np.float64, np.float32), (np.int64, np.int32))
  for name in df.columns:
    current = df[name].dtype
    for wide, narrow in narrowing:
      if current == wide:
        df[name] = df[name].astype(narrow)
        break
  return df


def create_parquet_filter(field, values_to_keep):
  """Return a one-clause filter list: `field` must be in `values_to_keep`.

  The values are de-duplicated into a set.
  """
  allowed_values = set(values_to_keep)
  membership_clause = (field, 'in', allowed_values)
  return [membership_clause]


def load_parquet_to_emb_df(bucket, prefix, batch_list, filter_list=None):
  """Load patch-embedding parquet shards from Cloud Storage into one dataframe.

  Args:
    bucket: Name of the Cloud Storage bucket holding the shards.
    prefix: %-style template; `prefix % batch` is the blob prefix that is
      listed for each batch.
    batch_list: Iterable of batch names substituted into `prefix`.
    filter_list: Optional filters forwarded to `pd.read_parquet`.
  Raises:
    ValueError: If no shards are found, or if no rows remain after filtering.
  Returns:
    emb_df: Dataframe with one string-named column per embedding dimension
      ('0', '1', ...), indexed by all of the remaining columns.
  """
  # Collect the shard paths for every requested batch.
  filepaths = []
  for b in batch_list:
    filepaths.extend(glob_cloud(bucket, prefix % b))
  print('There are a total of %d shards' % len(filepaths))
  if not filepaths:
    # Without this guard pd.concat([]) fails with an uninformative message.
    raise ValueError('No parquet shards found for prefix %r' % (prefix,))

  emb_df_list = []
  t0 = time.time()
  for i, shard_path in enumerate(filepaths):
    if (i + 1) % 100 == 0:
      print(f'shard {i + 1}, {time.time() - t0:.1f} sec')
    with fsspec.open(shard_path) as f:
      if filter_list:
        shard_df = pd.read_parquet(f, filters=filter_list)
      else:
        shard_df = pd.read_parquet(f)
    # Downcast right away and collect garbage to limit peak memory.
    emb_df_list.append(compress_df(shard_df))
    gc.collect()

  emb_df = pd.concat(emb_df_list)
  if len(emb_df) == 0:
    raise ValueError('The embedding dataframe is empty, did you use the right filters?')

  for col in ['image', 'channel_order']:
    if col in emb_df.columns:
      emb_df.drop(columns=[col], inplace=True)

  # expand the embedding into one column per dimension; the width is taken
  # from the data instead of being fixed at 192
  tmp_df = pd.DataFrame([pd.Series(x) for x in emb_df.embedding])
  embedding_cols = [str(x) for x in range(tmp_df.shape[1])]
  tmp_df.columns = embedding_cols
  # reset_index() moves the (non-unique, per-shard) index into a column so the
  # rows line up positionally with tmp_df
  emb_df = pd.concat([emb_df.reset_index(), tmp_df], axis=1)
  emb_df.drop(columns=['embedding'], inplace=True)

  # expand the parasite stage inference values
  if 'parasite_stage_infer' in emb_df.columns:
    tmp_df = pd.DataFrame([pd.Series(x) for x in emb_df.parasite_stage_infer])
    tmp_df.columns = ['stage_infer_artifact', 'stage_infer_hypnozoite', 'stage_infer_schizont']
    emb_df = pd.concat([emb_df, tmp_df], axis=1)
    emb_df.drop(columns=['parasite_stage_infer', 'parasite_stage_names'], inplace=True)

  # set the index to every non-embedding column
  embedding_col_set = set(embedding_cols)
  emb_df.set_index([c for c in emb_df.columns if c not in embedding_col_set], inplace=True)

  return emb_df

## Load and preprocess the parasite counts

The code below is specific to our experiments.

In [None]:
# get screen plates that have passed QC as well as dose response plates
plate_ann_df = pd.read_csv(PLATE_ANNOTATION_FILE)
# Normalize plate names to zero-padded five-digit strings so they can be
# matched against the count and metadata tables.
plate_ann_df['plate'] = format_plate_strings(plate_ann_df['plate'])

# All single-point screening plates.
screen_plates = set(plate_ann_df.query('singlePointScreening == "yes"').plate)
# Screening plates with RZprime > 0 (the QC criterion applied here), and the
# experiments (batches) those plates belong to.
qc_screen_plates = set(plate_ann_df.query('singlePointScreening == "yes" and RZprime > 0').plate)
qc_screen_batches = set(plate_ann_df.query('singlePointScreening == "yes" and RZprime > 0').experiment)

# The annotation table is no longer needed once the sets are built.
del plate_ann_df

In [None]:
# load a dataframe of parasite and artifact counts for all plates
# Collect one dataframe per batch, then concatenate them into count_df.
count_df = []
for batch in qc_screen_batches:
  # COUNT_CSV_TEMPLATE is %-formatted with the batch name.
  batch_count_df = pd.read_csv(COUNT_CSV_TEMPLATE % batch)
  batch_count_df['plate'] = format_plate_strings(batch_count_df['plate'])
  # Keep only the plates that passed QC.
  batch_count_df = batch_count_df[batch_count_df['plate'].isin(qc_screen_plates)].copy()
  count_df.append(batch_count_df)
count_df = pd.concat(count_df, axis=0)

In [None]:
# Prefix the object-detector counts with 'ml_' (the metadata counts are later
# renamed with a 'cp_' prefix) and drop the total object count.
count_df = count_df.rename(columns={'num_artifact': 'ml_artifact', 'num_hypnozoite': 'ml_hypnozoite', 'num_schizont': 'ml_schizont'})
count_df = count_df.drop(columns=['num_obj'])

In [None]:
count_df.head()

In [None]:
count_df.shape

In [None]:
# Find metadata paths
# Start with an empty list of candidate files for every QC'd batch.
metadata_paths={batch:[] for batch in qc_screen_batches}
glob_list = glob_cloud(CLOUD_BUCKET, METADATA_PATH)
for blob_name in glob_list:
  for batch in qc_screen_batches:
    # NOTE(review): this is a plain prefix match, so a batch whose name is a
    # prefix of another batch's name would also collect that batch's files.
    if blob_name.startswith(f'{METADATA_PREFIX}{batch}'):
      metadata_paths[batch].append(blob_name)

# Fail early if any batch has no metadata file at all.
missing_batches = []
for batch, paths in metadata_paths.items():
  if len(paths) == 0:
    missing_batches.append(batch)

if missing_batches:
  raise Exception('Unable to find metadata for batches: %s' % missing_batches)

# Only keep the paths for the latest metadata
# ("latest" here means the lexicographically last path for each batch)
metadata_paths = {key:sorted(val)[-1] for key, val in metadata_paths.items()}

In [None]:
# Load metadata from cloud
# (one CSV per batch, concatenated into a single dataframe)
metadata_list = []
for met in metadata_paths.values():
  metadata_list.append(pd.read_csv(met))
metadata = pd.concat(metadata_list)
metadata.head()

In [None]:
# These are the columns in our metadata to keep and how to rename them.
METADATA_COLS = ['plate', 'Metadata_Well', 'actives', 'concentration', 'hepLot',
                 'blindedConcept', 'flag-IQCH30', 'flag-Nuclei',
                 'Count_Nuclei', 'hypnozoite', 'schizont']
# The count columns get a 'cp_' prefix to distinguish them from the 'ml_'
# object-detector counts.
METADATA_RENAME = {
    'Metadata_Well': 'well',
    'hepLot': 'hep_lot',
    'blindedConcept': 'blinded_concept',
    'flag-IQCH30': 'flag_iqch30',
    'flag-Nuclei': 'flag_nuclei',
    'Count_Nuclei': 'cp_hepatocyte',
    'hypnozoite': 'cp_hypnozoite',
    'schizont': 'cp_schizont'}

metadata = metadata[METADATA_COLS].rename(columns=METADATA_RENAME)
# Replace spaces with underscores in the treatment labels.
# NOTE(review): assumes every `actives` value is a string; a missing value
# would fail in .replace — confirm.
metadata['actives'] = metadata.actives.apply(
    lambda a: a.replace(' ', '_')
)
# Zero-pad plate names to five characters so they match qc_screen_plates.
metadata['plate'] = metadata.plate.apply(lambda x: str(int(x)).zfill(5))
metadata = metadata[metadata['plate'].isin(qc_screen_plates)].copy()

In [None]:
metadata.head()

In [None]:
# Join counts and metadata and clean up column names, fix NaNs, cast values to floats
# (inner join: wells missing from either table are dropped)
tmp_df = count_df.merge(metadata, on=['plate', 'well'])

STAGES = ['artifact', 'hypnozoite', 'schizont', 'hepatocyte']
for s in STAGES:
  for colname in ['ml_%s' % s, 'cp_%s' % s]:
    # Not every stage exists for both count sources.
    if colname not in tmp_df.columns:
      continue
    values = tmp_df[colname].to_numpy().astype(np.float32)
    # Treat missing / non-finite counts as zero.
    values[~np.isfinite(values)] = 0.
    tmp_df[colname] = values

for col in ['concentration']:
  tmp_df[col] = tmp_df[col].to_numpy().astype(np.float32)

# Convert the 'yes' / other flag strings to booleans.
for col in ['flag_iqch30', 'flag_nuclei']:
  tmp_df[col] = (tmp_df[col].to_numpy() == 'yes')

count_df = tmp_df.copy()

In [None]:
# force infected control concentration to be uniform
# (workaround for some possibly bad metadata in a few plates)
# For every plate, overwrite the concentration of all infected control wells
# with the plate's median infected-control concentration.
corrected_df = []
for _, plate_df in count_df.groupby('plate'):
  plate_df = plate_df.copy()
  is_infected_control = plate_df.actives == 'infected_control'
  if is_infected_control.any():
    # Assign through .loc on the copy rather than writing into the array
    # returned by to_numpy(), which can be a read-only view of the column.
    infected_conc = plate_df.loc[is_infected_control, 'concentration'].to_numpy()
    plate_df.loc[is_infected_control, 'concentration'] = np.median(infected_conc)
  corrected_df.append(plate_df)
count_df = pd.concat(corrected_df)
del corrected_df

In [None]:
count_df.head()

In [None]:
# replace ml_parasite, which currently counts all patches, with a count of all parasites
count_df['ml_parasite'] = count_df['ml_hypnozoite'] + count_df['ml_schizont']
count_df['cp_parasite'] = count_df['cp_hypnozoite'] + count_df['cp_schizont']

# get median parasite and hepatocyte counts for infected controls
COUNT_COLS = ['ml_parasite', 'ml_hypnozoite',
              'cp_parasite', 'cp_hypnozoite', 'cp_hepatocyte']
count_df_med = (count_df[count_df.actives == 'infected_control'].groupby(['plate'])
    [COUNT_COLS].median().rename(columns={col:col + '_med_neg' for col in COUNT_COLS}).reset_index())
# add in infected control medians so we can compute inhibition
# (inner merge: plates without infected control wells are dropped)
count_df = count_df.reset_index().merge(count_df_med, on='plate')

count_df_med = (count_df[count_df.actives == 'active_control'].groupby(['plate'])
    [COUNT_COLS].median().rename(columns={col:col + '_med_pos' for col in COUNT_COLS}).reset_index())
# add in active control medians so we can adjust inhibition
# (inner merge: plates without active control wells are dropped)
count_df = count_df.reset_index().merge(count_df_med, on='plate')

# Inhibition = 1 - count / (infected control median), with the ratio clipped to [0, 1].
for method in ['ml', 'cp']:
  count_df[f'inhibition_{method}_par'] = (1. - np.clip(count_df[f'{method}_parasite'] / count_df[f'{method}_parasite_med_neg'], 0, 1)).astype(np.float32)
  count_df[f'inhibition_{method}_hyp'] = (1. - np.clip(count_df[f'{method}_hypnozoite'] / count_df[f'{method}_hypnozoite_med_neg'], 0, 1)).astype(np.float32)
# Active-control-adjusted variants: counts are rescaled so that the active
# control median maps to 0 and the infected control median maps to 1 before
# clipping.
count_df[f'inhibition_cp_hyp_act'] = (1. - np.clip((count_df[f'cp_hypnozoite'] - count_df[f'cp_hypnozoite_med_pos'])/
                                                   (count_df[f'cp_hypnozoite_med_neg'] - count_df[f'cp_hypnozoite_med_pos']), 0, 1)).astype(np.float32)
count_df[f'inhibition_cp_par_act'] = (1. - np.clip((count_df[f'cp_parasite'] - count_df[f'cp_parasite_med_pos'])/
                                                     (count_df[f'cp_parasite_med_neg'] - count_df[f'cp_parasite_med_pos']), 0, 1)).astype(np.float32)

# The second reset_index() above adds a 'level_0' column because an 'index'
# column already exists; remove it.
if 'level_0' in count_df.columns:
  count_df = count_df.drop(columns=['level_0'])

In [None]:
count_df.head()

In [None]:
count_df.columns

In [None]:
# Set of plates / wells to use:
# * Plate must pass QC
# * Median hypnozoites for infected controls must be >= 10
# * Drop wells with flag_IQCH30 set (the flag indicates a bad well)
# * Restrict to control wells

# Boolean row mask implementing the four criteria listed above.
subset = ((count_df.plate.isin(qc_screen_plates)) &
          (count_df.ml_hypnozoite_med_neg >= 10) &
          (~count_df.flag_iqch30) &
          (count_df.actives.isin({'infected_control', 'uninfected_control', 'active_control'})))

# Control wells only, in a stable (plate, well) order.
control_df = count_df[subset].copy().sort_values(['plate', 'well']).reset_index()

In [None]:
# Combined count of objects the detector labeled hypnozoite or artifact; this
# is the quantity the count model below is fitted to.
control_df['ml_hyp_or_art'] = control_df['ml_hypnozoite'] + control_df['ml_artifact']

In [None]:
# Drop wells with unusual numbers of hypnozoites + artifacts
#
# Here we'll use Tukey's fence (Q3 + 3 * IQR) to flag outliers on each plate
# See https://en.wikipedia.org/wiki/Outlier#Tukey's_fences
control_df_filtered = []
for _, df_plate in control_df.groupby('plate'):
  for method in ['ml']:
    col = f'{method}_hyp_or_art'
    # Compute the fence from this plate's wells only, so that outliers are
    # flagged per plate rather than against the pooled distribution.
    q1, q3 = np.quantile(df_plate[col], (0.25, 0.75))
    iqr = q3 - q1
    threshold = q3 + 3. * iqr  # Tukey's fence for outliers
    # Only values strictly above the fence are outliers; keeping values equal
    # to it avoids emptying a plate whose IQR is zero.
    df_plate = df_plate[df_plate[col] <= threshold]
  # A plate is only usable if all three control types survive the filter.
  bad_plate = False
  for actives in ['infected_control', 'active_control', 'uninfected_control']:
    if np.sum(df_plate.actives == actives) == 0:
      bad_plate = True
  if not bad_plate:
    control_df_filtered.append(df_plate)
control_df = pd.concat(control_df_filtered)
del control_df_filtered

## Diagnostics

A couple of plots that show something is amiss:

1) First, we see that the object detector consistently finds more staining artifacts in infected control wells than in uninfected control wells. That suggests that we may be mislabeling some hypnozoites as staining artifacts.

2) Second, we see that the more hypnozoites we have in the infected control wells, the greater the gap in artifacts. Again, this suggests the object detector may be mislabeling a fraction of hypnozoites.

In [None]:
# Plot 1: per-plate mean artifact counts, uninfected vs. infected controls.
plt.figure(figsize=(8, 8))
art_uninf = control_df[control_df.actives == 'uninfected_control'].groupby('plate')['ml_artifact'].mean()
art_inf = control_df[control_df.actives == 'infected_control'].groupby('plate')['ml_artifact'].mean()

# Reuse the bins of the first histogram so the two are directly comparable.
_, bins, _ = plt.hist(art_uninf, label='Uninfected controls', alpha=0.5)
plt.hist(art_inf, label='Infected controls', bins=bins, alpha=0.5)
plt.title('Mean artifact counts')
plt.legend(loc='best')
plt.show()


# Plot 2: excess artifacts in infected controls vs. mean hypnozoite count.
plt.figure(figsize=(8, 8))
hyp = control_df[control_df.actives == 'infected_control'].groupby('plate')['ml_hypnozoite'].mean()

x = hyp
# Both series are indexed by plate, so the subtraction aligns on plate.
y = (art_inf - art_uninf)
vmax = max(np.max(x), np.max(y))
sns.regplot(x=x, y=y)
plt.title('Excess infected control artifacts as a function of hypnozoite counts')
plt.xlabel('Infected control hypnozoites')
plt.ylabel('Infected - uninfected artifacts')
# Zero line: no difference between infected and uninfected artifact counts.
plt.hlines(0, 0, vmax, ls='--', color='gray')
plt.show()

## Estimating the true number of artifacts and hypnozoites

We're going to use a simple model to estimate the true numbers of hypnozoites and artifacts.

Our model makes a few core assumptions:

1) Within each plate, the number of artifacts in a well is roughly consistent. We model the distribution of artifacts using a negative binomial distribution.

2) Within each plate, the number of hypnozoites in the infected control wells is also negative binomially distributed, as is the number of hypnozoites in the active control wells, and the two sets of wells share an overdispersion coefficient.

We can use these assumptions to compute plate-specific maximum likelihood estimates for the mean number of artifacts and hypnozoites in control wells.

We observe that the model estimates higher hypnozoite counts than does the object detector, which suggests that the object detector may be systematically misclassifying a subset of hypnozoites as staining artifacts.




### Model assumptions in detail

1) We assume the number of artifacts we observe in a well has a negative binomial distribution with mean given by

$$\text{mean}_{\text{a, plate, treatment}} = a_{\text{plate}}$$

and variance

$$\text{var}_{\text{a, plate, treatment}} = \theta_{a, \text{plate}} a_{\text{plate}}$$

where $\theta_{a, \text{plate}} \geq 1$ represents a plate-specific level of overdispersion.

2) We assume the number of hypnozoites is also negative binomially distributed with mean

$$\text{mean}_{\text{h, plate, treatment}} = h_{\text{plate}} (1 - \eta_{\text{treatment}})$$

and variance

$$\text{var}_{\text{h, plate, treatment}} = \theta_{h, \text{plate}} (1 - \eta_{\text{treatment}}) h_{\text{plate}}$$

where again $\theta_{h, \text{plate}} \geq 1$ represents a plate-specific level of overdispersion.

Here $\eta_{\text{treatment}}$ is the parasite inhibition for a treatment, i.e. the fraction of parasites that die after a treatment.

We assume the following:
* For the infected control wells, $\eta_{\text{infected}} = 0$.
* For the uninfected control wells, $\eta_{\text{uninfected}} = 1$.
* For the active control wells, we allow the treatment strength to
vary from plate to plate, i.e. $\eta_{\text{active}} = k_{\text{plate}}$.


3) The sum of two negative binomially-distributed random variables doesn't have a nice closed-form distribution except in the special case in which the variables have a common overdispersion. We'll approximate the sum with another negative binomial distribution whose mean and variance equal the mean and variance for the sum, i.e. a negative binomial distribution with mean

$$a_{\text{plate}} + (1 - \eta_{\text{treatment}}) h_{\text{plate}}$$

and variance

$$\theta_{a, \text{plate}} a_{\text{plate}} +
\theta_{h, \text{plate}} (1 - \eta_{\text{treatment}}) h_{\text{plate}}
$$


We will use gradient ascent to find parameter values that maximize the likelihood of the observed counts for control wells under the model assumptions.

In [None]:
# Assign each plate a dense integer index (in sorted plate order) so that the
# per-plate model parameters can be gathered with array indexing.
sorted_plates = sorted(set(control_df.plate))
plate_to_index = {p:i for i, p in enumerate(sorted_plates)}
# One entry per control well, in control_df row order.
plate_index = jnp.array([plate_to_index[p] for p in control_df.plate])
hyp_or_art_ml = jnp.array(control_df['ml_hyp_or_art'])

In [None]:
# Map each control well's treatment to an integer index:
# treatment_index 0 = infected control
# treatment_index 1 = uninfected control
# treatment_index 2 = active control
# Any other value of `actives` raises a ValueError (control_df was restricted
# to these three control types when it was built).

max_treatment = 2
treatment_index = []
for actives in control_df.actives:
  if actives == 'infected_control':
    treatment_index.append(0)
  elif actives == 'uninfected_control':
    treatment_index.append(1)
  elif actives == 'active_control':
    treatment_index.append(2)
  else:
    raise ValueError(actives)
# One entry per control well, in control_df row order.
treatment_index = jnp.array(treatment_index)

In [None]:
# Store the per-well indices alongside the counts so they can be looked up
# row by row later.
control_df['plate_index'] = plate_index
control_df['treatment_index'] = treatment_index

In [None]:
# Upper bound on the active control inhibition.
MAX_ETA = 0.999

class Model(NamedTuple):
  """Per-plate parameters of the artifact / hypnozoite count model.

  The fields hold unconstrained values, which is what the optimizer updates.
  The properties map them through invertible transforms onto the constrained
  quantities used by the likelihood.
  """
  a_unconstrained: jnp.ndarray  # (n_plates,)
  h_unconstrained: jnp.ndarray  # (n_plates,)
  theta_a_unconstrained: jnp.ndarray  # (n_plates,)
  theta_h_unconstrained: jnp.ndarray  # (n_plates,)
  eta_active_unconstrained: jnp.ndarray  # (n_plates)

  @property
  def a(self):
    """Artifact parameter, kept positive via exp."""
    log_a = self.a_unconstrained
    return jnp.exp(log_a)

  @property
  def h(self):
    """Hypnozoite parameter, kept positive via exp."""
    log_h = self.h_unconstrained
    return jnp.exp(log_h)

  @property
  def theta_a(self):
    """Artifact overdispersion, kept above 1 via 1 + exp."""
    excess = jnp.exp(self.theta_a_unconstrained)
    return 1 + excess

  @property
  def theta_h(self):
    """Hypnozoite overdispersion, kept above 1 via 1 + exp."""
    excess = jnp.exp(self.theta_h_unconstrained)
    return 1 + excess

  @property
  def eta_active(self):
    """Active control effect, squashed into (0, MAX_ETA) by a scaled sigmoid."""
    squashed = jax.scipy.special.expit(self.eta_active_unconstrained)
    return MAX_ETA * squashed

  def get_eta(self, treatment_index: jnp.ndarray, plate_index: jnp.ndarray) -> jnp.ndarray:
    """Return the inhibition for each (treatment, plate) pair.

    Treatment 0 (infected control) maps to 0, treatment 1 (uninfected control)
    maps to 1, and anything else uses the plate's active control effect.
    """
    active_eta = self.eta_active[plate_index]
    # Resolve the inner choice first: uninfected control vs. active control.
    not_infected_eta = jnp.where(treatment_index == 1, 1., active_eta)
    return jnp.where(treatment_index == 0, 0., not_infected_eta)


In [None]:
# Sizes are derived from the largest index present.
# (n_treatments is not used elsewhere in this cell)
n_plates = jnp.max(plate_index) + 1
n_treatments = jnp.max(treatment_index) + 1
# Start every unconstrained parameter at zero, i.e. a = h = 1,
# theta_a = theta_h = 2 and eta_active = MAX_ETA / 2 after the transforms.
model_init = Model(
    a_unconstrained=jnp.zeros((n_plates,)),
    h_unconstrained=jnp.zeros((n_plates,)),
    theta_a_unconstrained=jnp.zeros((n_plates,)),
    theta_h_unconstrained=jnp.zeros((n_plates,)),
    eta_active_unconstrained=jnp.zeros((n_plates,)),
)

In [None]:
def get_loss_fn(
    hyp_or_art: jnp.ndarray,
    plate_index: jnp.ndarray,
    treatment_index: jnp.ndarray):
  """Build the negative log-likelihood of the observed per-well counts.

  Args:
    hyp_or_art: Observed hypnozoite + artifact count for each well.
    plate_index: Plate index of each well.
    treatment_index: Treatment index of each well.
  Returns:
    A function mapping a Model to the summed negative log-probability of
    `hyp_or_art` under a moment-matched negative binomial distribution.
  """
  def loss(model: Model) -> jnp.ndarray:
    inhibition = model.get_eta(treatment_index=treatment_index, plate_index=plate_index)

    # Artifact component: mean a, variance theta_a * a.
    artifact_mean = model.a[plate_index]
    artifact_var = model.theta_a[plate_index] * artifact_mean

    # Hypnozoite component: mean h * (1 - eta), variance theta_h * mean.
    hypnozoite_mean = model.h[plate_index] * (1 - inhibition)
    hypnozoite_var = model.theta_h[plate_index] * hypnozoite_mean

    # Moments of the sum of the two components.
    total_mean = artifact_mean + hypnozoite_mean
    total_var = artifact_var + hypnozoite_var

    # Convert (mean, variance) to the (n, p) parameterization of
    # scipy.stats.nbinom, where k counts failures:
    #   mean(k) = n(1-p)/p
    #   var(k) = n(1-p)/p^2 = mean / p
    success_prob = total_mean / total_var
    n_successes = total_mean * success_prob / (1 - success_prob)

    well_log_prob = jax.scipy.stats.nbinom.logpmf(
        k=hyp_or_art, n=n_successes, p=success_prob)
    return -jnp.sum(well_log_prob)
  return loss

In [None]:
# Bind the observed control-well data into the loss and JIT-compile it; the
# resulting function takes only the Model.
loss_fn_ml = jax.jit(get_loss_fn(
    hyp_or_art=hyp_or_art_ml,
    plate_index=plate_index,
    treatment_index=treatment_index))

In [None]:
print('initial loss', loss_fn_ml(model_init))

In [None]:
# Fit the model parameters starting from model_init.
# NOTE(review): presumably runs Adam for `train_steps` steps and returns the
# fitted Model — confirm against cell_img.analysis.optim_lib.
model = optim_lib.adam_optimize(
    loss_fn_ml,
    model_init,
    learning_rate=0.01,
    train_steps=20000,
    verbose=1
)

A quick sanity check of the fitted parameters.

Make sure there are no crazy outliers.

In [None]:
# Histograms over plates of the fitted artifact parameters.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].hist(np.array(model.a), bins=50)
axes[0].set_title('Mean artifacts')

axes[1].hist(np.array(model.theta_a), bins=50)
axes[1].set_title('Overdispersion scaling')
plt.show()

In [None]:
# Histograms over plates of the fitted hypnozoite parameters.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].hist(np.array(model.h), bins=50)
axes[0].set_title('Mean hypnozoites')

axes[1].hist(np.array(model.theta_h), bins=50)
axes[1].set_title('Overdispersion hypnozoites')
plt.show()

### Modeling results

First, we see that the estimated mean number of artifacts for each plate is well approximated by the mean of the uninfected control artifacts.

In [None]:
# compute rough estimates of the fraction of hypnozoites in each control well
#
# The fitted model is evaluated for all wells at once. Separate names are used
# for the per-well index arrays so that the notebook-level `plate_index` and
# `treatment_index` arrays are not overwritten with scalars.
well_plate_index = jnp.asarray(
    control_df['plate_index'].to_numpy().astype(np.int32))
well_treatment_index = jnp.asarray(
    control_df['treatment_index'].to_numpy().astype(np.int32))

well_eta = model.get_eta(
    treatment_index=well_treatment_index, plate_index=well_plate_index)
mean_artifacts = model.a[well_plate_index]
mean_hypnozoites = model.h[well_plate_index] * (1 - well_eta)

# Expected hypnozoites as a fraction of expected hypnozoites + artifacts,
# stored as a plain numeric column.
estimated_hyp_fraction = np.asarray(
    mean_hypnozoites / (mean_hypnozoites + mean_artifacts))
control_df['estimated_hyp_fraction'] = estimated_hyp_fraction


In [None]:
# Per-plate means of the numeric columns for each control type.
# numeric_only=True is required because control_df also holds string columns
# (e.g. well, actives) that cannot be averaged.
uninfected_mean = control_df[control_df.actives == 'uninfected_control'].groupby('plate').mean(numeric_only=True)
infected_mean = control_df[control_df.actives == 'infected_control'].groupby('plate').mean(numeric_only=True)

In [None]:
# Compare mean uninfected control artifact counts to model estimated counts
# (x comes from groupby('plate'), which sorts by plate, and model.a is indexed
# by sorted plate, so the two line up as long as every plate has uninfected
# control wells)
x = np.array(uninfected_mean['ml_artifact'])
y = np.array(model.a)

plt.figure(figsize=(8, 8))
sns.regplot(
    x=x,
    y=y,
    marker='.')
# Identity line for reference, on equal axes.
vmax = max(np.max(x), np.max(y))
plt.plot((0, vmax), (0, vmax), linestyle='--', color='gray')
plt.xlim(0, vmax)
plt.ylim(0, vmax)

plt.title('Model')
plt.xlabel('Mean uninfected control artifacts')
plt.ylabel('Model uninfected control artifacts')
plt.show()

Second, we see that the model suggests that the object detector is underestimating the number of true hypnozoites by about 10%.

In [None]:
# Compare mean infected control hypnozoite counts to model estimated counts

# x is the per-plate mean from infected_mean, so the axis label says "Mean".
x = np.array(infected_mean['ml_hypnozoite'])
y = np.array(model.h)

plt.figure(figsize=(8, 8))
sns.regplot(
    x=x,
    y=y,
    marker='.')
plt.title('Model')
plt.xlabel('Mean infected control hypnozoites')
plt.ylabel('Model infected control hypnozoites')
# Identity line for reference, on equal axes.
vmax = max(np.max(x), np.max(y))
plt.plot((0, vmax), (0, vmax), linestyle='--', color='gray')
plt.xlim(0, vmax)
plt.ylim(0, vmax)
plt.show()

In [None]:
# Regress the model's hypnozoite estimate (y) on the detector's mean count (x).
# No constant column is added to x here, so — assuming statsmodels does not
# add one implicitly (TODO confirm) — the fit is a line through the origin and
# its single coefficient is the ratio of model to detector counts.
m = sm.OLS(y, x)
results = m.fit()
results.summary()

## A better artifact classifier

The analysis above suggests that the object detector misclassifies some hypnozoites as staining artifacts. Below we build a better classifier.

Our challenge is that our supply of labeled data is limited, and in some cases our expert labelers may be hard-pressed to determine whether an object is a hypnozoite or a staining artifact. Adding to the difficulty is the possibility that some drug candidates may change the appearance of parasites in ways we can't anticipate.

Rather than trying to create a better hand-labeled dataset to use for a better supervised classifier, we will take advantage of our experimental design. Each plate contains a set of uninfected control wells that contain no parasites. We can build a classifier that distinguishes between objects found in uninfected control wells and objects found in wells with parasites. For such a classifier, we have ground truth for all objects found by the object detector. If we are willing to assume that staining artifacts from infected and uninfected wells are indistinguishable, the only way such a classifier can differentiate between objects from infected and uninfected wells is by distinguishing parasites from staining artifacts. If we can identify an object as a parasite, we know that it must have come from an infected well. In contrast, under our assumptions, we will be unable to determine whether a staining artifact comes from an infected or uninfected well.

The resulting infected well/uninfected well classifier is effectively a hypnozoite/artifact classifier with miscalibrated probabilities. Objects that the classifier indicates have a high probability of coming from an infected well are likely parasites, while objects for which the probabilities of coming from an infected or an uninfected well are roughly equal are likely staining artifacts. To turn our infected / uninfected well classifier into a parasite / staining artifact classifier, we need to recalibrate the output probabilities.




In [None]:
# loading takes ~20 minutes
# Same plate / well criteria as for the control wells, but without the
# restriction to control treatments.
subset = ((count_df.plate.isin(qc_screen_plates)) &
          (count_df.ml_hypnozoite_med_neg >= 10) &
          (~count_df.flag_iqch30))

batches = sorted(set(count_df[subset]['batch']))
plates = sorted(set(count_df[subset]['plate']))

# Only rows from the selected plates are read from the parquet shards.
# NOTE(review): the name `filter` shadows the Python builtin for the rest of
# the notebook session.
filter = create_parquet_filter('plate', plates)
emb_df = load_parquet_to_emb_df(
    bucket=CLOUD_BUCKET,
    prefix=PATCH_PARQUET_TEMPLATE,
    batch_list=batches,
    filter_list=filter)

emb_df_bak = emb_df.copy()  # make a backup copy so we don't have to reload from disk!

In [None]:
emb_df = emb_df_bak.copy()  # recover from backup

In [None]:
emb_df.shape

In [None]:
emb_df.head()

In [None]:
# Attach each patch's treatment label by joining on (plate, well). The inner
# join drops patches from wells that are not in the selected subset, and the
# leftover 'index' column is removed.
emb_df = emb_df.reset_index().merge(
    count_df[subset].reset_index()[['plate', 'well', 'actives']],
    how='inner',
    on=['plate', 'well']).drop(columns=['index']).copy()  # copy to reduce fragmentation

In [None]:
emb_df.head()

### Classifying objects as coming from infected or uninfected wells

We use object embeddings as features for classification. We will use LightGBM as our classifier because it is fast and robust to overfitting. We regularize by limiting tree depth. We train on 80% of the data and use 10% each for validation and test.


In [None]:
# split plates into training / validation / test sets
# (the split is by plate, so all patches from a plate land in the same set)
np.random.seed(123)

DATA_FRACTION = 1.

fraction_train = 0.8
fraction_validation = 0.1
fraction_test = 1. - fraction_train - fraction_validation

plates = sorted(set(emb_df.plate))
n_plates = len(plates)
perm = np.random.permutation(n_plates)
# NOTE(review): random_subset is computed here but is not applied to any of
# the subsets built below in this cell — confirm whether the DATA_FRACTION
# subsampling was meant to be combined with train_subset.
random_subset = (emb_df['actives'] == 'uninfected_control') | (np.random.rand(emb_df.shape[0]) <= DATA_FRACTION)

# Consecutive slices of the permutation: train, then validation, then test.
train_plates = {plates[i] for i in perm[:int(fraction_train * n_plates)]}
validation_plates = {plates[i] for i in perm[int(fraction_train * n_plates):int((1. - fraction_test) * n_plates)]}
test_plates = {plates[i] for i in perm[int((1. - fraction_test) * n_plates):]}

# Sanity check: set sizes, and the overlaps between sets (expected to be 0).
print(len(train_plates), len(train_plates & validation_plates), len(train_plates & test_plates))
print(len(validation_plates), len(validation_plates & test_plates))
print(len(test_plates))

# Boolean row masks over emb_df.
train_subset = emb_df.plate.isin(train_plates)
validation_subset = emb_df.plate.isin(validation_plates)
test_subset = emb_df.plate.isin(test_plates)

In [None]:
# classification task: predict whether embedding came from uninfected control well
# Label 1 = patch from an uninfected control well, 0 = any other well.
y = (emb_df.actives == 'uninfected_control').to_numpy().astype(np.int32)
# Features: embedding columns '64' through '191' only.
# NOTE(review): the first 64 embedding dimensions are excluded; the reason is
# not visible here — confirm.
x = (emb_df[[str(x) for x in list(range(64, 192))]].to_numpy().astype(np.float32))

In [None]:
x.shape, y.shape

In [None]:
# We'll use LightGBM because it appears to work more reliably than the neural
# nets I tried, all of which overfit pretty badly. We regularize LightGBM by
# limiting the maximum tree depth to 2, adding L1 and L2 penalties, and
# requiring at least 40 observations per tree node.

# tuning lgbm parameters: https://neptune.ai/blog/lightgbm-parameters-guide

n_estimators = 200

lgbm_artifact = lightgbm.LGBMClassifier(
    objective='binary',
    n_estimators=n_estimators,
    class_weight='balanced',
    boosting='goss',
    learning_rate=0.25,
    # some regularization
    max_depth=2,
    min_data_in_leaf=40,
    lambda_l1=0.1,
    lambda_l2=0.5,
    feature_fraction=0.1,
    )

# Train on the training plates and monitor the validation plates. Early
# stopping is requested through the callback API because the
# `early_stopping_rounds` keyword of fit() is not accepted by LightGBM >= 4.
# Training stops once the validation metric has not improved for
# n_estimators // 10 rounds.
lgbm_artifact.fit(
    x[train_subset],
    y[train_subset],
    eval_set=[(x[validation_subset], y[validation_subset])],
    callbacks=[lightgbm.early_stopping(stopping_rounds=n_estimators // 10)],
)

We're going to treat the LightGBM output as an uncalibrated probability that an object is an artifact. Our next step is to improve the calibration.

In [None]:
# Get the LightGBM's uncalibrated P(artifact) estimate
p_artifact = lgbm_artifact.predict_proba(x)[:, 1]

In [None]:
plt.hist(p_artifact, bins=50)
plt.xlabel('P(artifact)')
plt.ylabel('Number of patches')
plt.show()

In [None]:
# Fraction of patches from uninfected control wells with P(artifact) > 0.5
np.mean(p_artifact[emb_df.actives == 'uninfected_control'] > 0.5)

In [None]:
# Fraction of patches from all other (not uninfected control) wells with
# P(artifact) > 0.5
np.mean(p_artifact[emb_df.actives != 'uninfected_control'] > 0.5)

In [None]:
emb_df['p_artifact'] = p_artifact

In [None]:
emb_df.head()

In [None]:
# Object-detector stage call per patch: index of the largest inferred score,
# with 0 = artifact, 1 = hypnozoite, 2 = schizont (column order below).
stage = np.argmax(emb_df[['stage_infer_artifact', 'stage_infer_hypnozoite', 'stage_infer_schizont']].to_numpy(), axis=-1)
emb_df['stage'] = stage

At the beginning of this Colab we estimated the mean number of artifacts and hypnozoites in the control wells on each plate. We'll use these control well estimates to calibrate our probabilities. Specifically, we'll use the mean number of hypnozoites and artifacts in each control well on a plate to estimate the fraction of each well's objects that are artifacts. We'll then use Platt scaling to adjust the output of our uninfected well / infected well classifier to better estimate the probability that a given object is an artifact.



In [None]:
# Now we need to calibrate the model

In [None]:
# limit to the control wells
# and drop patches the object detector called schizonts (stage 2), keeping
# only the columns needed for calibration
emb_df_control = emb_df[
    (emb_df['actives'].isin({'uninfected_control', 'infected_control', 'active_control'})) &
    (emb_df['stage'] != 2)
][['batch', 'plate', 'well', 'actives', 'p_artifact', 'stage']].copy()

In [None]:
emb_df_control.head()

In [None]:
control_df.head()

In [None]:
emb_df_control = emb_df_control.merge(control_df[['batch', 'plate', 'well', 'estimated_hyp_fraction']], on=['batch', 'plate', 'well'])

In [None]:
emb_df_control.head()

In [None]:
# For each well, get the estimated hypnozoite fraction and the uncalibrated P(artifact) for non-schizonts

# Implementation detail: we'll group together wells with the same number of
# non-schizont patches so we can estimate the calibration loss more efficiently
# below.

fraction = 1.  # fraction of wells to use for calibration
key = jax.random.PRNGKey(0)

# Both dicts are keyed by n, the number of non-schizont patches in a well.
hyp_frac_by_n = {}
proba_by_n = {}
for well_index, ((batch, plate, well, actives), well_df) in enumerate(emb_df_control.groupby(['batch', 'plate', 'well', 'actives'])):
  if actives not in {'uninfected_control', 'infected_control', 'active_control'}:
    continue
  well_df = well_df[well_df.stage != 2]  # drop schizonts
  n = well_df.shape[0]
  if n == 0:
    continue
  # Randomly keep roughly `fraction` of the wells; a fresh subkey is drawn for
  # every candidate well so the selection is reproducible.
  key, subkey = jax.random.split(key)
  if jax.random.uniform(subkey) > fraction:
    continue
  if not n in hyp_frac_by_n:
    hyp_frac_by_n[n] = []
    proba_by_n[n] = []
  # estimated_hyp_fraction is a per-well value, so the first row is enough.
  hyp_frac_by_n[n].append(well_df.estimated_hyp_fraction.to_numpy()[0])
  proba_by_n[n].append(jnp.array(well_df.p_artifact.to_numpy()))

# for each value of n, turn lists into ndarrays
# hyp_frac_by_n[n]: (n_wells,); proba_by_n[n]: (n_wells, n)
for n in hyp_frac_by_n.keys():
  hyp_frac_by_n[n] = jnp.array(hyp_frac_by_n[n])
  proba_by_n[n] = jnp.stack(proba_by_n[n], axis=0)

In [None]:
# We know for each well the approximate fraction of artifacts
# from the calibration Colab, but we don't know which specific
# patches are artifacts.
#
# We know from the classifier for each patch the uncalibrated P(artifact)
#
# Here we compute the log probability that we observe a particular fraction of
# artifacts given the uncalibrated P(artifacts). The exact probability is
# really expensive to compute, but we can approximate it: for each patch,
# we know P(artifact). In our idealized model, each patch's artifact state is
# an independent bernoulli random variable with mean P and variance P*(1-P).
# We'll approximate these Bernoulli r.v.'s with normal r.v.'s with the same
# mean and variance. The sum of these normal r.v.'s we can compute, and then
# we can use them to estimate the log likelihood of observing the fraction of
# artifacts we estimated from the calibration Colab.
def get_loss_fn(hyp_frac_by_n, proba_by_n):
  """Builds the negative log-likelihood used to fit the Platt scaling.

  Each patch's artifact state is modelled as an independent Bernoulli draw
  with probability p = expit(a + b * P(artifact)). The artifact fraction of a
  well with n patches is the mean of those draws, which is approximated by a
  normal distribution with
    mean     = mean(p)
    variance = sum(p * (1 - p)) / n**2 = mean(p * (1 - p)) / n.

  Args:
    hyp_frac_by_n: dict mapping well size n to an array of shape (wells,) with
      the estimated hypnozoite fraction of each well of that size.
    proba_by_n: dict with the same keys, mapping n to an array of shape
      (wells, n) with the uncalibrated P(artifact) of every patch.

  Returns:
    loss_fn(params), where params = (a, log(b)), returning the negative log
    likelihood summed over all wells.
  """
  # Floor on the variance. Without it the normal scale is exactly 0 once the
  # scaled probabilities saturate at 0 or 1, and the loss is no longer finite.
  min_variance = 1.e-12

  def loss_fn(params):
    # params = (a, log(b)) where a and b are the parameters for Platt scaling;
    # optimizing log(b) keeps the slope b positive.
    a, log_b = params
    b = jnp.exp(log_b)
    log_likelihood = 0.
    for n, hyp_frac in hyp_frac_by_n.items():
      frac_a = 1. - hyp_frac  # (?,) estimated artifact fraction per well
      probs = proba_by_n[n]  # (?, n)
      num_patches = probs.shape[-1]
      scaled_probs = jax.scipy.special.expit(a + b * probs)
      mean = jnp.mean(scaled_probs, axis=-1)
      # Variance of the mean of num_patches independent Bernoullis: the
      # per-patch variances are averaged and then divided by num_patches.
      var = jnp.mean(scaled_probs * (1. - scaled_probs), axis=-1) / num_patches
      var = jnp.maximum(var, min_variance)
      log_likelihood += jnp.sum(
          jax.scipy.stats.norm.logpdf(frac_a, loc=mean, scale=jnp.sqrt(var)),
          axis=0)
    return -log_likelihood
  return loss_fn

In [None]:
# The loss loops over every distinct well size, so fitting on all sizes is
# slow. To speed things up we only keep the sizes that are multiples of 10,
# from 10 up to (but not including) the largest size, and that actually occur.
n_max = max(hyp_frac_by_n)

kept_sizes = [n for n in np.arange(10, n_max, 10) if n in hyp_frac_by_n]
hyp_frac_by_n_subset = {n: hyp_frac_by_n[n] for n in kept_sizes}
proba_by_n_subset = {n: proba_by_n[n] for n in kept_sizes}

loss_fn = get_loss_fn(hyp_frac_by_n_subset, proba_by_n_subset)

# jit-compiled versions of the loss and of its gradient
loss_fn_jit = jax.jit(loss_fn)
loss_fn_grad_jit = jax.jit(jax.grad(loss_fn))

In [None]:
# Fit the Platt scaling parameters (a, log(b)) with Adam on the calibration
# loss. The starting point (0., 0.) means a = 0 and b = exp(0) = 1.
platt_scaling_params = (0., 0.)
optimizer = optax.adam(learning_rate=1.e-2)
opt_state = optimizer.init(platt_scaling_params)

for i in range(1000):
  grads = loss_fn_grad_jit(platt_scaling_params)
  updates, opt_state = optimizer.update(grads, opt_state, platt_scaling_params)
  platt_scaling_params = optax.apply_updates(platt_scaling_params, updates)
  if i % 100:
    continue
  # Every 100 steps, report the current (a, log(b)) and the loss.
  current_a, current_log_b = platt_scaling_params
  print((float(current_a), float(current_log_b)),
        loss_fn_jit(platt_scaling_params), flush=True)

We use Platt scaling to adjust the LightGBM classifier output probabilities so they are appropriate for use in an artifact / hypnozoite classifier.

In [None]:
# Apply the fitted Platt scaling, expit(a + b * p) with b = exp(log_b), to the
# classifier's P(artifact) for every patch.
platt_offset = platt_scaling_params[0]
platt_slope = jnp.exp(platt_scaling_params[1])
p_artifact_scaled = np.array(
    jax.scipy.special.expit(platt_offset + platt_slope * p_artifact))

Here we show the uncalibrated probabilities (in blue) and the calibrated probabilities (in orange). Scaling makes the model more confident. This increase in confidence is likely because the original model had a harder task: it had to determine whether an object came from an uninfected or an infected well, and for the case of an artifact, it couldn't tell.

In [None]:
# Overlay the distribution of P(artifact) before and after Platt scaling,
# using 0.01-wide bins on [0, 1]. Scaling makes our model more confident!
bins = np.arange(0, 1.01, 0.01)
plt.figure(figsize=(12, 6))
for label, probabilities in (('Unscaled', p_artifact),
                             ('Scaled', p_artifact_scaled)):
  plt.hist(np.array(probabilities), bins=bins, alpha=0.5, label=label)
plt.xlabel('P(artifact)')
plt.ylabel('Number of patches')
plt.legend(loc='best')
plt.show()

In [None]:
# Here we visualize calibration for control wells relative to the fraction of artifacts
# predicted by our initial model. The blue line shows that the uninfected well / infected
# well classifier is poorly calibrated as an artifact detector. In contrast, the model
# after Platt scaling, the orange line, is much better calibrated.
xval = []
yval_unscaled = []
yval_scaled = []
step = 1  # bin width, in percentage points of the estimated artifact fraction

# Unseeded 10% sample of the control patches, so the plot varies slightly
# from run to run.
emb_df_control_sample = emb_df_control.sample(frac=0.1)

est_art_frac = np.array(1. - emb_df_control_sample.estimated_hyp_fraction)
p_art = np.array(emb_df_control_sample.p_artifact)
p_art_scaled = np.array(scipy.special.expit(platt_scaling_params[0] + np.exp(platt_scaling_params[1]) * p_art))
for i in range(0, 100, step):
  # Bin edges are derived from the same percentage the loop steps over, so the
  # binning stays consistent with xval for any value of `step`.
  lower = i / 100
  upper = min(i + step, 100) / 100
  subset = (est_art_frac >= lower) & (est_art_frac < upper)
  if upper == 1.:
    # Close the last bin on the right so that patches whose estimated artifact
    # fraction is exactly 1 are not dropped from the plot.
    subset |= est_art_frac == 1.
  if np.sum(subset):
    xval.append(lower)
    yval_unscaled.append(np.mean(p_art[subset]))
    yval_scaled.append(np.mean(p_art_scaled[subset]))
plt.figure(figsize=(6, 6))
plt.plot(xval, yval_unscaled, label='Before Platt scaling')
plt.plot(xval, yval_scaled, label='After Platt scaling')
plt.plot((0, 1), (0, 1))  # y = x reference line
plt.xlabel('P(artifact) estimated from model')
plt.ylabel('Classifier probability')
plt.legend(loc='best')
plt.show()

In [None]:
# Summary of P(artifact) over patches from uninfected control wells: the mean
# probability, and the fraction of patches with probability above 0.5.
print('Uninfected controls')
is_uninfected = emb_df.actives == 'uninfected_control'
unscaled_uninfected = p_artifact[is_uninfected]
scaled_uninfected = p_artifact_scaled[is_uninfected]
print('Mean, unscaled', np.mean(unscaled_uninfected))
print('Mean, scaled', np.mean(scaled_uninfected))

print('Mean > 0.5, unscaled', np.mean(unscaled_uninfected > 0.5))
print('Mean > 0.5, scaled', np.mean(scaled_uninfected > 0.5))

In [None]:
# Same summary for every patch that is not from an uninfected control well:
# the mean probability, and the fraction of patches with probability above 0.5.
print('Infected wells')
is_not_uninfected = emb_df.actives != 'uninfected_control'
unscaled_infected = p_artifact[is_not_uninfected]
scaled_infected = p_artifact_scaled[is_not_uninfected]
print('Mean, unscaled', np.mean(unscaled_infected))
print('Mean, scaled', np.mean(scaled_infected))

print('Mean > 0.5, unscaled', np.mean(unscaled_infected > 0.5))
print('Mean > 0.5, scaled', np.mean(scaled_infected > 0.5))

In [None]:
# For each stage label (0, 1, 2), plot the distribution of the calibrated
# P(artifact) to see how often each stage gets classified as an artifact.
stage_titles = (
    (0, 'Previous classifier artifacts'),
    (1, 'Hypnozoite'),
    (2, 'Schizont'),
)
for stage_id, stage_title in stage_titles:
  plt.hist(p_artifact_scaled[stage == stage_id], bins=50)
  plt.title(stage_title)
  plt.xlabel('P(artifact)')
  plt.show()

In [None]:
# save the LightGBM model and Platt scaling parameters locally, then upload to cloud
client = storage.Client()
bucket = client.get_bucket(CLOUD_BUCKET)
today = datetime.datetime.now().strftime('%y-%m-%d')

# OUTPUT_PATH is written with a leading '/', which is not wanted in the blob
# name. Strip slashes instead of slicing off the first character, so a path
# without a leading slash is not truncated and a trailing slash does not
# produce a double '//' in the blob name.
output_prefix = OUTPUT_PATH.strip('/')

# LightGBM booster, saved with save_model to a local temp file first
tmpfile_lgbm = '/tmp/lgbm.txt'
lgbm_artifact.booster_.save_model(tmpfile_lgbm)
filename = f'{output_prefix}/artifact-classifier-{today}.lgbm'
blob = bucket.blob(filename)
blob.upload_from_filename(tmpfile_lgbm)
print(blob.public_url)

# Platt scaling parameters (a, log(b)), saved as a NumPy .npy file
tmpfile_platt = '/tmp/platt.npy'
with open(tmpfile_platt, 'wb') as outfile:
  np.save(outfile, np.array(platt_scaling_params))
filename = f'{output_prefix}/platt-scaling-{today}.npy'
blob = bucket.blob(filename)
blob.upload_from_filename(tmpfile_platt)
print(blob.public_url)
