This Colab estimates inhibition based on both counts and appearance (i.e. Joint Inhibition) for:
* hypnozoites
* parasites

The input dataset is designed to be a set of screening plates, though the joint inhibition can be calculated for dose response or other types of plates also.

In [None]:
#@title Run this cell only the FIRST time you connect to the colab kernel
!pip install gcsfs
!git clone https://github.com/google/cell_img
!pip install --quiet -e cell_img
!pip install jax
!pip install lightgbm
!pip install optax

In [None]:
#@title For Cloud VM kernel, run this after restarting before granting access

!ls /content/.config/
!rm /content/.config/gce 
!rm /var/colab/mp

import os

# Clear the GCE-related environment variables before authenticating (see the
# cell title). pop() with a default makes this cell safe to re-run: a plain
# lookup or `del` raises KeyError once the variable is already gone.
os.environ.pop('NO_GCE_CHECK', None)
os.environ.pop('GCE_METADATA_TIMEOUT', None)

In [None]:
#@title Run this cell after restarting your kernel. It will pop up a window to grant access.
# Start Colab's interactive authentication flow (the pop-up mentioned in the
# cell title). Presumably this supplies the credentials used for the gs://
# reads later in the notebook -- confirm.
from google.colab import auth
auth.authenticate_user()

# Train the model

In [None]:
#@title Choose type of inhibition to estimate
INHIBITION_TYPE = 'hypnozoite'  #@param ['hypnozoite', 'parasite'] {allow-input: true}

In [None]:
print(INHIBITION_TYPE)

In [None]:
import copy
import datetime
import dataclasses
import fsspec
import gc
import os
import pathlib
import re
from google.cloud import storage

import jax
import jax.numpy as jnp
import optax

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import scipy
import scipy.stats
import seaborn as sns
import sklearn.metrics
import cycler
import matplotlib

import lightgbm

import statsmodels.nonparametric.smoothers_lowess

import tensorflow as tf

import cell_img
from cell_img.analysis import jax_tree
from cell_img.common import image_lib
from cell_img.dose_response import constrain
from cell_img.predict_inhibition import joint_model
from cell_img.predict_inhibition import spline
from cell_img.malaria_liver import metadata_lib

In [None]:
# Set up matplotlib styles
# Default color cycle: the ten 'tab10' colors. plt.get_cmap is used because
# matplotlib.cm.get_cmap was removed in matplotlib 3.9.
matplotlib.rcParams['axes.prop_cycle'] = cycler.cycler(
    'color', plt.get_cmap('tab10').colors)

# rcParams overrides for showing images: no grid, no axes frame, invisible
# ticks, and tight subplot margins.
IMAGE_STYLE = {
    'axes.grid': False,
    'axes.linewidth': 0,

    'xtick.labelsize': 0,
    'xtick.color': 'none',
    'xtick.major.size': 0,
    'xtick.minor.size': 0,

    'ytick.labelsize': 0,
    'ytick.color': 'none',
    'ytick.major.size': 0,
    'ytick.minor.size': 0,

    'image.cmap': 'viridis',
    'image.interpolation': 'none',

    'figure.subplot.hspace': 0.02,
    'figure.subplot.wspace': 0.04,

    'figure.subplot.left': 0.01,
    'figure.subplot.bottom': 0.01,
    'figure.subplot.right': 0.98,
    'figure.subplot.top': 0.98,
    'font.size': 22,
}

# The ten tab10 colors as an array of RGBA rows, indexed by category number in
# the per-hep-lot plots.
COLORS = plt.cm.tab10(np.linspace(0, 1, 10))

In [None]:
assert INHIBITION_TYPE in {'parasite', 'hypnozoite'}

In [None]:
# filtered count data and embeddings as generated by prepare_data.ipynb

# Name of the GCS bucket; every gs:// path below is built from it.
CLOUD_BUCKET = 'bucket'

# Path prefix inside the bucket. It is concatenated directly with file names
# in files_to_load, so it needs its trailing '/'.
PREDICT_INHIBITION_PATH = 'your/path/'

# parasite embeddings
# One parquet file per batch, given relative to the bucket root.
batches = [ 'batch1', 'batch2' ]
files_to_load = [
    f'{PREDICT_INHIBITION_PATH}prefix_{batch}.parquet'
    for batch in batches]


# count data
# DF_DATE selects which dated count_df-<date>.parquet file is read.
DF_DATE = 'date-count-data-generated'
COUNT_DF_PATH = os.path.join(
    f'gs://{CLOUD_BUCKET}', PREDICT_INHIBITION_PATH,
    f'count_df-{DF_DATE}.parquet')

# output file
# The output name embeds the inhibition type and today's date (ISO format).
today_str = datetime.date.today().isoformat()
INHIBITION_FILE = os.path.join(
    f'gs://{CLOUD_BUCKET}', PREDICT_INHIBITION_PATH,
    f'joint_inhibition_{INHIBITION_TYPE}-{today_str}.parquet')

# dose-response estimates
DOSE_RESPONSE_PATH = f'gs://{CLOUD_BUCKET}/path/to/dose_response'
DOSE_RESPONSE_FILE = os.path.join(DOSE_RESPONSE_PATH, f'dr_df.parquet')

# TensorStore parameters for patch images
TENSORSTORE_SHORTNAME = 'short_name'
TENSORSTORE_PATH = f'gs://{CLOUD_BUCKET}/tensorstore/{TENSORSTORE_SHORTNAME}'
METADATA_ROOT_PATH = f'gs://{CLOUD_BUCKET}/tensorstore/{TENSORSTORE_SHORTNAME}/metadata/'
# Presumably the image channel names to show as red, green and blue, in that
# order -- confirm against the image-loading code.
CHANNEL_TO_RGB = ['w3', 'w2', 'w1']

In [None]:
def pd_read_parquet(path, filter_list):
  """Read a parquet file through fsspec into a DataFrame.

  Args:
    path: Location of the parquet file; opened with fsspec.open.
    filter_list: Row filters forwarded to pd.read_parquet as `filters`. Any
      falsy value (None, empty list) reads the file unfiltered.
  Returns:
    The DataFrame produced by pd.read_parquet.
  """
  read_kwargs = {'filters': filter_list} if filter_list else {}
  with fsspec.open(path) as handle:
    return pd.read_parquet(handle, **read_kwargs)


def compress_df(df: pd.DataFrame) -> pd.DataFrame:
  """Downcast 64-bit float and int columns to their 32-bit counterparts.

  The frame is modified in place; columns of any other dtype are untouched.

  Args:
    df: The DataFrame to shrink.
  Returns:
    The same DataFrame object that was passed in.
  """
  for name in df.columns:
    dtype = df[name].dtype
    if dtype == np.float64:
      df[name] = df[name].astype(np.float32)
      continue
    if dtype == np.int64:
      df[name] = df[name].astype(np.int32)
  return df


def format_plate_strings(plate_names):
  """Normalize plate names to zero-padded, five-digit strings.

  Each name is converted to a string, anything from the first '.' onward is
  dropped (so '123.0' becomes '123'), and the result is left-padded with
  zeros to five characters.

  Args:
    plate_names: A pd.Series of plate names to format.
  Raises:
    ValueError: If a formatted name is longer than five characters, or if it
      contains anything other than digits.
  Returns:
    formatted_plates: A pd.Series of the formatted plate names.
  """
  def _to_five_chars(raw):
    return raw.split('.')[0].zfill(5)

  formatted_plates = plate_names.astype(str).apply(_to_five_chars)
  names = formatted_plates.values
  # The length check runs first, so an over-long name is reported as such
  # even if it also contains non-digit characters.
  if any(len(name) > 5 for name in names):
    raise ValueError('Plate name > 5 characters found')
  if not all(re.match(r'^\d+$', name) for name in names):
    raise ValueError('Plate with non-digit characters found')
  return formatted_plates


def glob_cloud(bucket, path):
  """List the names of all blobs in `bucket` whose name starts with `path`.

  Args:
    bucket: Bucket (or bucket name) passed to storage.Client.list_blobs.
    path: Name prefix used to select blobs.
  Returns:
    A list of blob names.
  """
  client = storage.Client()
  return [blob.name for blob in client.list_blobs(bucket, prefix=path)]


def load_subset_dataset_to_emb_df(files_to_load,
                                  expand_embedding=True, filter_list=None):
  """Load per-patch embedding parquet files from the bucket into one frame.

  Args:
    files_to_load: Iterable of parquet paths relative to gs://CLOUD_BUCKET.
    expand_embedding: If True, the `embedding` column of each file is spread
      over 192 separate columns and the original column is dropped. The
      file's index is also moved into a column by reset_index().
    filter_list: Optional row filters forwarded to pd_read_parquet.
  Returns:
    The concatenation of all loaded files, compressed with compress_df, with
    the embedding columns named by the integers 0..191.
  """
  frames = []
  for relative_path in files_to_load:
    frame = pd_read_parquet(
        os.path.join(f'gs://{CLOUD_BUCKET}', relative_path), filter_list)
    if expand_embedding:
      # One row per patch, one column per embedding dimension.
      expanded = pd.DataFrame([pd.Series(vec) for vec in frame.embedding])
      expanded.columns = [str(dim) for dim in range(192)]
      frame = pd.concat([frame.reset_index(), expanded], axis=1)
      frame.drop(columns=['embedding'], inplace=True)
    frames.append(compress_df(frame))
  emb_df = pd.concat(frames)

  print('Loaded %d rows across %d batches in %d files!' % (
      len(emb_df), emb_df.batch.nunique(), len(frames)
  ))

  # Embedding columns carry string names while loading; hand them back as ints.
  return emb_df.rename(columns={str(dim): dim for dim in range(192)})

In [None]:
# Note: takes 6-7 minutes to run
patch_df = load_subset_dataset_to_emb_df(files_to_load)
# The loader's reset_index() leaves each file's original index in an 'index'
# column (assuming the parquet index is unnamed); use it as the row index again.
patch_df.set_index('index', inplace=True)

In [None]:
# constants for variable names, embedding dimensions, and whether to weight data
# Reject an unknown type up front, before any constant is defined.
if INHIBITION_TYPE not in ('hypnozoite', 'parasite'):
  raise ValueError('Invalid INHIBITION_TYPE %s' % INHIBITION_TYPE)

_is_hypnozoite = INHIBITION_TYPE == 'hypnozoite'

# count column and count-based inhibition column for the selected type
COUNT_COL = 'ml_hypnozoite' if _is_hypnozoite else 'ml_parasite'
INHIBITION_COL = 'inhibition_ml_hyp' if _is_hypnozoite else 'inhibition_ml_par'

# embedding columns we use (currently identical for both types)
EMB_COLS = list(range(64, 192))  # keep non-DAPI stains
WEIGHT_SAMPLES = True

EMB_DIM = len(EMB_COLS)

In [None]:
gc.collect()

## Step 1: predict hypnozoite inhibition in a well given a single parasite image

We build a regression model that predicts well-level inhibition from individual parasite embeddings. This is going to be a weak predictor, but we will later show that it can be improved by (1) combining predictions from multiple parasites in a well, and (2) combining parasite appearance with parasite counts.

We use LightGBM to generate our predictions. Our primary input features are embeddings for parasite-specific stains (i.e. no DAPI). There is considerable variation in the embeddings from plate to plate and batch to batch, so we tried
a few different approaches to controlling for it.

The biggest source of variation appears to be the hepatocyte lot in which the
parasites were cultivated. One straightforward way to control for the effects of hep lot on appearance is to include hep lot as a categorical predictor. We found that while including hep lot improved the model's predictions, it led to difficulties later when we combine appearance and count-based measures of inhibition. An alternative approach that we found helpful was to combine an
embedding for a treated hypnozoite with an embedding for an untreated control hypnozoite from the same plate. We use the latter approach below.

In [None]:
# Partition the plates (not individual patches) into training, validation, and
# test sets, so that all patches from a plate land in the same subset.

np.random.seed(123)

DATA_FRACTION = 1.

fraction_train = 0.8
fraction_validation = 0.1
fraction_test = 1. - fraction_train - fraction_validation

plates = sorted(set(patch_df.plate))
n_plates = len(plates)
perm = np.random.permutation(n_plates)

# Cut points in the shuffled plate order: [0, train) / [train, validation) /
# [validation, end).
_train_end = int(fraction_train * n_plates)
_validation_end = int((1. - fraction_test) * n_plates)

train_plates = {plates[i] for i in perm[:_train_end]}
validation_plates = {plates[i] for i in perm[_train_end:_validation_end]}
test_plates = {plates[i] for i in perm[_validation_end:]}

# sanity check - every overlap count printed below should be 0
print('n_train', len(train_plates),
      'overlap', len(train_plates & validation_plates), len(train_plates & test_plates))
print('n_validation', len(validation_plates),
      'overlap', len(validation_plates & test_plates))
print('n_test', len(test_plates))

In [None]:
# Probability bounds that correspond to the joint model's logit limits.
MIN_P = scipy.special.expit(joint_model.MIN_LOGIT_INHIBITION)
MAX_P = scipy.special.expit(joint_model.MAX_LOGIT_INHIBITION)

def clipped_logit(x: np.array) -> np.array:
  """Return logit(x) after clipping x into [MIN_P, MAX_P].

  Clipping first keeps the result between joint_model.MIN_LOGIT_INHIBITION
  and joint_model.MAX_LOGIT_INHIBITION instead of running off to +/-inf.
  """
  bounded = np.clip(x, MIN_P, MAX_P)
  return scipy.special.logit(bounded)

In [None]:
# Get features and objective for LightGBM
def get_xy(df: pd.DataFrame,
           use_infected_controls: bool,
           use_active_controls: bool,
           use_hep_lot: bool) -> pd.DataFrame:
  """Get x, y, and weights for model training.

  Rows whose `actives` is 'uninfected_control' are dropped. The remaining rows
  are processed one plate at a time so that any paired control embedding
  comes from the same plate as the parasite it is paired with.

  Control pairing draws from the global NumPy random state, so callers seed
  np.random before calling to get reproducible pairings.

  Args:
    df: One row per parasite patch. Must contain the EMB_COLS embedding
      columns plus `plate`, `actives` and INHIBITION_COL (and `hep_lot_index`
      when use_hep_lot is True).
    use_infected_controls: If True, append the embedding of a randomly chosen
      'infected_control' patch from the same plate to every row.
    use_active_controls: If True, append the embedding of a randomly chosen
      'active_control' patch from the same plate to every row.
    use_hep_lot: If True, add a `hep_lot_index` column.
  Returns:
    A DataFrame indexed like the retained rows of `df`. Feature columns are
    named by consecutive integers starting at 0, followed by the optional
    `hep_lot_index`, the target `y` (clipped logit of INHIBITION_COL) and the
    per-row `weight`.
  """
  features_df = []
  for plate, plate_df in df[df.actives != 'uninfected_control'].groupby(['plate']):

    n = plate_df.shape[0]
    # the first set of features is the embedding vector for a parasite
    features = [plate_df[EMB_COLS].to_numpy().astype(np.float32)]

    if use_infected_controls:
      # Optional: we want to measure how stressed treated parasites are *relative to untreated parasites*
      # One way to get at this is to add a random untreated (infected_control) parasite from the same plate
      # for contrast
      # NOTE(review): a plate without any infected_control rows would make
      # np.random.choice fail here -- confirm every plate has controls.
      infected_controls = plate_df[plate_df.actives == 'infected_control'][EMB_COLS].to_numpy().astype(np.float32)
      n_infected_control = infected_controls.shape[0]
      features.append(infected_controls[np.random.choice(n_infected_control, size=n, replace=True)])

    if use_active_controls:
      # Optional: another point of reference for how damaged parasites look is how parasites treated with
      # active controls on the same plate look. Here we add a random active_control parasite for contrast.
      active_controls = plate_df[plate_df.actives == 'active_control'][EMB_COLS].to_numpy().astype(np.float32)
      n_active_control = active_controls.shape[0]
      features.append(active_controls[np.random.choice(n_active_control, size=n, replace=True)])

    # Stack the sample and control embeddings side by side; the resulting
    # columns get default integer names 0, 1, 2, ...
    plate_features_df = pd.DataFrame(np.column_stack(features), index=plate_df.index)
    if use_hep_lot:
      # Optional: Hepatocyte lot can make a big difference to parasite embeddings. One way to correct for
      # hep lot effects is to add the hep lot as a covariate. Note that this feature needs to be treated
      # as a categorical variable (which we do below).
      plate_features_df['hep_lot_index'] = plate_df.hep_lot_index
    plate_features_df['y'] = clipped_logit(plate_df[INHIBITION_COL])
    if WEIGHT_SAMPLES:
      # Wells with low inhibition are overrepresented in the data because they
      # have more parasites. We weight by 1/(1-inhibition) to compensate.
      # For example, parasites from a well with 0% inhibition have weight 1,
      # parasites from a well with 50% inhibition have weight 2 (since there are
      # half as many of them), etc.
      # NOTE(review): an inhibition of exactly 1 would give an infinite
      # weight -- presumably such wells contribute no patches; confirm.
      plate_features_df['weight'] = 1. / (1. - plate_df[INHIBITION_COL])
    else:
      plate_features_df['weight'] = 1.

    features_df.append(plate_features_df)

  return pd.concat(features_df, axis=0)

In [None]:
# Build the LightGBM training, validation, and test frames
np.random.seed(123)

USE_INFECTED_CONTROLS = True
USE_ACTIVE_CONTROLS = False
USE_HEP_LOT = False

# Every treated parasite is paired with a randomly drawn
# infected control and/or active control before predicting.
# At prediction time we repeat this N_PREDICTIONS times,
# each with a different random pairing, and average
# the results.
N_PREDICTIONS = 4

# One embedding for the parasite itself plus one per enabled control type.
_n_embeddings_per_row = 1 + int(USE_INFECTED_CONTROLS) + int(USE_ACTIVE_CONTROLS)
FEATURE_COLS = list(range(EMB_DIM * _n_embeddings_per_row))
if USE_HEP_LOT:
  FEATURE_COLS += ['hep_lot_index']

_xy_options = dict(
    use_infected_controls=USE_INFECTED_CONTROLS,
    use_active_controls=USE_ACTIVE_CONTROLS,
    use_hep_lot=USE_HEP_LOT,
)

# The order train -> validation -> test matters: get_xy draws from the
# global NumPy random state seeded above.
train_df = get_xy(patch_df[patch_df.plate.isin(train_plates)], **_xy_options)
validation_df = get_xy(patch_df[patch_df.plate.isin(validation_plates)], **_xy_options)
test_df = get_xy(patch_df[patch_df.plate.isin(test_plates)], **_xy_options)

In [None]:
# make sure the train_df has all hep lots
if USE_HEP_LOT:
  print(set(train_df['hep_lot_index']), set(validation_df['hep_lot_index']), set(test_df['hep_lot_index']))

In [None]:
train_df.shape, validation_df.shape, test_df.shape

In [None]:
gc.collect()

In [None]:
# Predict inhibition from individual parasites

In [None]:
# Upper bound on the number of boosting rounds; the fit cell also derives its
# early-stopping patience from this value.
n_estimators = 1000
# Heavily regularized regressor for the per-parasite clipped-logit target.
lgbm_inhibition = lightgbm.LGBMRegressor(
    n_estimators=n_estimators,  # controls the maximum size of the model
    boosting='goss',

    # Regularization
    max_depth=2,  # this is probably the strongest constraint
    min_data_in_leaf=1000,  # default is 20
    lambda_l1=10.,
    lambda_l2=10.,

    # Miscellaneous
    feature_fraction=1.,  # Setting smaller fractions reduces training time but worsens results
    learning_rate=0.1,
    )

In [None]:
# Fit the model
kwargs = {}
if USE_HEP_LOT:
  # hep_lot_index is an arbitrary lot id, so it must be treated as categorical.
  kwargs['categorical_feature'] = ['hep_lot_index']

# Early stopping is requested through the callback API: the
# `early_stopping_rounds` argument of fit() was removed in LightGBM 4.0.
# Training stops once the validation metric has not improved for
# n_estimators // 10 consecutive rounds.
lgbm_inhibition.fit(
    train_df[FEATURE_COLS],
    train_df['y'],
    sample_weight=train_df['weight'],
    eval_set=[(validation_df[FEATURE_COLS],
               validation_df['y'])],
    eval_sample_weight=[validation_df['weight']],
    callbacks=[lightgbm.early_stopping(stopping_rounds=n_estimators // 10)],
    **kwargs,
    )

In [None]:
if INHIBITION_TYPE in {'hypnozoite', 'parasite'}:
  # in general stain #3 looks a little more important than stain #2 (we've dropped DAPI, stain #1)
  # the model puts considerably more weight on the treated parasite than the paired control parasite
  # EMB_COLS holds the two non-DAPI stains back to back, so each stain
  # occupies half of an embedding block of EMB_DIM features.
  _stain_dim = EMB_DIM // 2
  _importances = lgbm_inhibition.feature_importances_

  print('sample')
  print('stain 2', np.mean(_importances[:_stain_dim]))
  print('stain 3', np.mean(_importances[_stain_dim:EMB_DIM]))

  # The second embedding block is the infected control only when that
  # pairing is enabled; otherwise these slices would be empty or belong to a
  # different control type.
  if USE_INFECTED_CONTROLS:
    print('infected control')
    print('stain 2', np.mean(_importances[EMB_DIM:EMB_DIM + _stain_dim]))
    print('stain 3', np.mean(_importances[EMB_DIM + _stain_dim:2 * EMB_DIM]))

In [None]:
if USE_INFECTED_CONTROLS:
  # Do the features that matter for the sample embedding (first EMB_DIM
  # features) also matter for the paired infected control (next EMB_DIM)?
  # (Not really!)
  _sample_importance = lgbm_inhibition.feature_importances_[:EMB_DIM]
  _control_importance = lgbm_inhibition.feature_importances_[EMB_DIM:2*EMB_DIM]
  print(np.corrcoef(_sample_importance, _control_importance)[0, 1])
  plt.scatter(_sample_importance, _control_importance)

In [None]:
# Train vs. test mean squared error on the clipped-logit target - they look comparable
y_pred_train = lgbm_inhibition.predict(train_df[FEATURE_COLS])
_residual_train = train_df['y'] - y_pred_train
mse_train = np.mean(_residual_train ** 2)
print('Train MSE', mse_train)

y_pred_test = lgbm_inhibition.predict(test_df[FEATURE_COLS])
_residual_test = test_df['y'] - y_pred_test
mse_test = np.mean(_residual_test ** 2)
print('Test MSE', mse_test)


In [None]:
del train_df, validation_df, test_df, y_pred_train, y_pred_test

In [None]:
gc.collect()

### Generate predicted inhibition estimates for the whole dataset

We paired each treated parasite with a random untreated parasite to help the model control for plate-to-plate variations in appearance. The randomness in the paired control hypnozoite introduces some noise into our estimates. Now we'll average over several random pairings to try to reduce that noise. (We found that this step didn't make a huge difference.)

In [None]:
# The model uses a treated parasite plus a random infected control and/or active control parasite as features.
# Make N_PREDICTIONS, each of which uses different paired control parasites and then average the predictions.

def _make_predictions(patch_df):
  """Predict logit inhibition per patch, averaged over random control pairings.

  For each seed in range(N_PREDICTIONS) the global NumPy RNG is reseeded,
  features are rebuilt with get_xy (which redraws the paired controls), and
  lgbm_inhibition predicts from them. The per-seed predictions are joined on
  the patch index and averaged.

  Args:
    patch_df: Per-patch DataFrame in the format expected by get_xy.
  Returns:
    A DataFrame with the single column `pred_inhibition_logit`, indexed like
    the rows get_xy retains.
  """
  seed_cols = []
  merged = None
  for seed in range(N_PREDICTIONS):
    seed_col = f'pred_inhibition_logit_{seed}'
    seed_cols.append(seed_col)
    np.random.seed(seed)
    print(seed, flush=True)
    feature_df = get_xy(
        patch_df,
        use_infected_controls=USE_INFECTED_CONTROLS,
        use_active_controls=USE_ACTIVE_CONTROLS,
        use_hep_lot=USE_HEP_LOT)
    seed_pred = pd.DataFrame(
        {seed_col: lgbm_inhibition.predict(feature_df[FEATURE_COLS])},
        index=feature_df.index)

    merged = seed_pred if merged is None else merged.join(seed_pred)
    # Release the large intermediate frames before the next round.
    del feature_df, seed_pred
    gc.collect()
  # Collapse the per-seed columns into their row-wise mean.
  merged['pred_inhibition_logit'] = merged[seed_cols].to_numpy().mean(axis=-1)
  return merged.drop(columns=seed_cols).copy()

# make predictions one batch at a time to reduce memory spikes
prediction_df_list = []
for b in patch_df.batch.unique():
  print('Making predictions for %s' % b)
  # Select the batch with a boolean mask instead of a string-built query: the
  # query form breaks on batch names containing a double quote and matches
  # nothing when the batch column is not a string.
  prediction_df_list.append(_make_predictions(patch_df[patch_df.batch == b]))
prediction_df = pd.concat(prediction_df_list)

# Keep an untouched copy; prediction_df itself is extended with more columns
# in later cells.
prediction_df_saved = prediction_df.copy()

In [None]:
gc.collect()

In [None]:
prediction_df.sample(10)

In [None]:
prediction_df.shape

## Appearance based predictions

Here we look at some plots of our predictions.

First, we plot inhibition as estimated by hypnozoite counts against our model's predictions of inhibition based on single hypnozoite appearance. Our model is quite weak - the Spearman correlation between count based inhibition estimates and appearance-based estimates is only 0.12 - but there is some signal.

Note that the variance increases as the count-based inhibition increases. A reasonable interpretation is that as parasites are damaged by treatment, their appearances start to change in a wide variety of ways.

In [None]:
# Look at how predictions correspond to ML inhibition for individual sample parasite embeddings.

# Attach the averaged prediction to each patch and keep only sample wells.
joined_df = patch_df[['actives', INHIBITION_COL]].join(prediction_df[['pred_inhibition_logit']])
joined_df = joined_df[joined_df.actives == 'sample']
print('Spearman correlation between ML inhibition and predictions:',
      scipy.stats.spearmanr(joined_df[INHIBITION_COL],
                            joined_df.pred_inhibition_logit))

plt.figure(figsize=(8, 4))
alpha=0.1
# Plot and smooth a random subsample of 50000 rows (np.random.choice draws
# with replacement by default, so this works for smaller frames too).
subset = np.random.choice(joined_df.shape[0], 50000)
x = joined_df[INHIBITION_COL].to_numpy()[subset]
y = joined_df.pred_inhibition_logit.to_numpy()[subset]
sns.regplot(x=x,
            y=y,
            marker='.',
            scatter_kws={'alpha': alpha},
            line_kws={'color': 'orange'},
            fit_reg=False,
            )
sub = x > 0.005  # lowess has problems with big pile-up on the left

# Smooth the mean, then smooth the squared residuals around that mean to get a
# local variance estimate. return_sorted=False keeps the outputs aligned with
# the input order, so they are sorted by x explicitly below.
lowess_mean = statsmodels.nonparametric.smoothers_lowess.lowess(y[sub], x[sub], frac=0.1, return_sorted=False)
y_var = (y[sub] - lowess_mean) ** 2
lowess_var = statsmodels.nonparametric.smoothers_lowess.lowess(y_var, x[sub], frac=0.1, return_sorted=False)
idx = np.argsort(x[sub])

x = x[sub][idx]
lowess_mean = lowess_mean[idx]
lowess_var = lowess_var[idx]

# Prepend one point at x=0 that summarizes everything excluded from the
# lowess fit (x <= 0.005) by its plain mean and variance.
x = np.concatenate([np.zeros(1), x])
lowess_mean = np.concatenate([np.array([np.mean(y[~sub])]), lowess_mean])
lowess_var = np.concatenate([np.array([np.var(y[~sub])]), lowess_var])

# Smoothed mean with a +/- 2 standard deviation band.
plt.plot(x, lowess_mean, color='red', lw=2)
plt.plot(x, lowess_mean + 2.*np.sqrt(lowess_var), color='red', ls='--')
plt.plot(x, lowess_mean - 2.*np.sqrt(lowess_var), color='red', ls='--')
plt.xlabel('Actual inhibition')
plt.ylabel('Predicted logit inhibition')
plt.title(f'Predicted inhibition from a single {INHIBITION_TYPE} embedding')
plt.show()

plt.plot(x, lowess_var)
plt.title('Variance of predictions as function of inhibition')
plt.xlabel('Actual inhibition')
plt.ylabel('Variance of predicted logit inhibition')
plt.show()

# Keep the single-parasite curves under separate names; x, lowess_mean and
# lowess_var are reused by later plotting cells.
x_single = x
lowess_mean_single = lowess_mean
lowess_var_single = lowess_var

In [None]:
joined_df = None

In [None]:
gc.collect()

### Aggregating individual parasite predictions within a well

Now we'll average the predicted inhibition for all hypnozoites in a well to see if we are better able to predict the count-based inhibition.

In [None]:
# wait to read in the count df until later to save memory
count_df = pd.read_parquet(COUNT_DF_PATH)

In [None]:
patch_df.sample(4)

In [None]:
# Enrich prediction_df with per-patch metadata and the well-level counts/inhibition
_patch_metadata_cols = ['batch', 'plate', 'well', 'site', 'center_row',
                        'center_col', 'actives', 'hep_lot', 'hep_lot_index']
_well_key = ['plate', 'well']

# Per-patch metadata comes from patch_df (aligned on the patch index) ...
prediction_df = prediction_df[['pred_inhibition_logit']].join(
    patch_df[_patch_metadata_cols])
# ... while the count and count-based inhibition are looked up per well.
prediction_df = prediction_df.merge(
    count_df[_well_key + [COUNT_COL, INHIBITION_COL]],
    on=_well_key, how='left')
prediction_df.sort_values(by=['plate', 'well', 'site'], inplace=True)
prediction_df.head()

In [None]:
# Remove per-well statistic columns left in count_df, presumably by an earlier
# run of the merge cell below (plain names after one run, pandas' _x/_y merge
# suffixes after a repeated run).
_stat_cols = ['n_x', 'sum_x', 'sum_x2', 'mean_x', 'var_x']

if 'sum_x' in count_df.columns:
  count_df = count_df.drop(columns=_stat_cols)

if 'sum_x_x' in count_df.columns:
  count_df = count_df.drop(
      columns=[f'{col}{suffix}' for suffix in ('_x', '_y') for col in _stat_cols])

In [None]:
# Per-well summary statistics of the patch predictions, merged into count_df
prediction_df['pred_inhibition_logit_sq'] = prediction_df.pred_inhibition_logit ** 2.
prediction_df['one'] = 1.  # summing this column counts the patches per well

_by_well = prediction_df.groupby(['plate', 'well'])
pred_sum_df = _by_well[['one', 'pred_inhibition_logit', 'pred_inhibition_logit_sq']].sum().rename(
    columns={'one': 'n_x', 'pred_inhibition_logit': 'sum_x', 'pred_inhibition_logit_sq': 'sum_x2'})
pred_mean_df = _by_well[['pred_inhibition_logit']].mean().rename(
    columns={'pred_inhibition_logit': 'mean_x'})
pred_var_df = _by_well[['pred_inhibition_logit']].var().rename(
    columns={'pred_inhibition_logit': 'var_x'})
prediction_df = prediction_df.drop(columns=['one'])

for _stats_df in (pred_sum_df, pred_mean_df, pred_var_df):
  count_df = count_df.merge(_stats_df, on=['plate', 'well'], how='left')

# Wells without any predicted patch get zero counts and sums; mean_x and var_x
# stay NaN for them.
for _col in ('n_x', 'sum_x', 'sum_x2'):
  count_df[_col] = count_df[_col].fillna(0.)

In [None]:
gc.collect()

Look at how well the well mean of the appearance-based inhibition estimates predicts the count-based estimate for the well.

The predictive power is still weak - the Spearman correlation is 0.27 - but it's an improvement over the 0.12 we got for a single hypnozoite.

In [None]:
# Look at how well mean predictions correspond to ML inhibition for samples.

# Sample wells with count-based inhibition below 1.
# NOTE(review): mean_x is NaN for wells without any predicted patch; if such
# wells pass this filter the correlation below comes out NaN -- confirm.
subset = (count_df[INHIBITION_COL] < 1.) & (count_df.actives == 'sample')
x = count_df[INHIBITION_COL][subset].to_numpy()
y = count_df.mean_x[subset].to_numpy()
print(scipy.stats.spearmanr(x, y))

plt.figure(figsize=(8, 4))
alpha=0.1
# Plot and smooth a random subsample of 50000 wells (drawn with replacement).
subset = np.random.choice(x.shape[0], 50000)
x = x[subset]
y = y[subset]
sns.regplot(x=x,
            y=y,
            marker='.',
            scatter_kws={'alpha': alpha},
            line_kws={'color': 'orange'},
            fit_reg=False,
            )
sub = x > 0.005  # lowess has problems with big pile-up on the left

# Smooth the mean, then the squared residuals around it (a local variance).
# return_sorted=False keeps outputs in input order; they are sorted by x below.
lowess_mean = statsmodels.nonparametric.smoothers_lowess.lowess(y[sub], x[sub], frac=0.1, return_sorted=False)
y_var = (y[sub] - lowess_mean) ** 2
lowess_var = statsmodels.nonparametric.smoothers_lowess.lowess(y_var, x[sub], frac=0.1, return_sorted=False)
idx = np.argsort(x[sub])

x = x[sub][idx]
lowess_mean = lowess_mean[idx]
lowess_var = lowess_var[idx]

# Prepend one point at x=0 summarizing the wells excluded from the lowess fit.
x = np.concatenate([np.zeros(1), x])
lowess_mean = np.concatenate([np.array([np.mean(y[~sub])]), lowess_mean])
lowess_var = np.concatenate([np.array([np.var(y[~sub])]), lowess_var])

# Smoothed mean with a +/- 2 standard deviation band.
plt.plot(x, lowess_mean, color='red', lw=3)
plt.plot(x, lowess_mean + 2.*np.sqrt(lowess_var), color='red', ls='--')
plt.plot(x, lowess_mean - 2.*np.sqrt(lowess_var), color='red', ls='--')
plt.xlabel('Actual inhibition')
plt.ylabel('Mean(predicted logit inhibition)')
plt.title(f'Mean predicted {INHIBITION_TYPE} inhibition per well')
plt.show()

plt.plot(x, lowess_var)
plt.xlabel('Actual inhibition')
plt.ylabel('Var(predicted logit inhibition)')
plt.title(f'Variance of predicted {INHIBITION_TYPE} inhibition per well')
plt.ylim(0, None)
plt.show()


In [None]:
# plot prediction variance vs ML inhibition
# Only sample wells with a count above 1 are used.
subset = (count_df[COUNT_COL] > 1) & (count_df.actives == 'sample')
x = count_df[INHIBITION_COL][subset].to_numpy()
y_var = count_df.var_x[subset].to_numpy()

plt.figure(figsize=(8, 4))
alpha=0.1
# The scatter shows a random subsample of 50000 wells (drawn with
# replacement); the lowess fit below uses all selected wells.
subset = np.random.choice(x.shape[0], 50000)
sns.regplot(x=x[subset],
            y=y_var[subset],
            marker='.',
            scatter_kws={'alpha': alpha},
            line_kws={'color': 'orange'},
            fit_reg=False,
            )


# Wells with inhibition <= 0.001 are left out of the lowess fit.
sub = x > 0.001
lowess_var = statsmodels.nonparametric.smoothers_lowess.lowess(y_var[sub], x[sub], frac=0.1, return_sorted=False)
# return_sorted=False keeps input order, so sort by x before drawing the line.
idx = np.argsort(x[sub])
x = x[sub][idx]
lowess_var = lowess_var[idx]
plt.plot(x, lowess_var, color='red', lw=3)
plt.xlabel('Actual inhibition')
plt.ylabel('Var(predicted logit inhibition)')
plt.title(f'Within well variance of predicted {INHIBITION_TYPE} inhibition')
plt.ylim(0, 2.5)
plt.show()

### Covariates

Here we break out appearance-based predictions for different hepatocyte lots for the control wells. We see that the predictions vary considerably from one hepatocyte lot to the next.

In [None]:
# Break out within-well prediction means by hepatocyte lot.
# There are big differences between lots!
# Only non-sample (control) wells are shown, one color per hep lot.
plt.figure(figsize=(8, 4))
for i, (hep_lot, hep_df) in enumerate(count_df[count_df.actives!='sample'].groupby('hep_lot')):
  # Wrap around the palette so more than len(COLORS) lots cannot raise
  # IndexError.
  color = COLORS[i % len(COLORS)]
  x = hep_df[INHIBITION_COL]
  y = hep_df.mean_x
  # Wells with inhibition <= 0.001 are plotted but left out of the lowess fit.
  sub = x > 0.001
  sns.regplot(x=x,
              y=y,
              label=hep_lot,
              marker='.', color=color,
              fit_reg=False,
              scatter_kws={'alpha': 0.25})
  # With the default return_sorted=True, lowess returns (x, fitted) columns
  # sorted by x.
  lowess = statsmodels.nonparametric.smoothers_lowess.lowess(
      y[sub], x[sub])
  plt.plot(lowess[:, 0], lowess[:, 1], color=color, lw=3)
legend = plt.legend(title='Hep lot', loc='best')
# Legend.legendHandles was renamed to legend_handles in matplotlib 3.7 and the
# old name was removed in 3.9; support both.
legend_handles = (legend.legend_handles if hasattr(legend, 'legend_handles')
                  else legend.legendHandles)
# Make the legend markers opaque even though the scatter points are not.
for lh in legend_handles:
  lh.set_alpha(1)
plt.xlabel('Actual inhibition')
plt.ylabel('Mean(predicted logit inhibition)')
plt.title(f'Mean predicted {INHIBITION_TYPE} inhibition by hepatocyte lot')

plt.show()

In [None]:
# Break out within-well prediction variances by hepatocyte lot.
# There are big differences between lots!
# Only non-sample (control) wells are shown, one color per hep lot.
plt.figure(figsize=(8, 4))
for i, (hep_lot, hep_df) in enumerate(count_df[count_df.actives!='sample'].groupby('hep_lot')):
  # Wrap around the palette so more than len(COLORS) lots cannot raise
  # IndexError.
  color = COLORS[i % len(COLORS)]
  sns.regplot(x=hep_df[INHIBITION_COL],
              y=hep_df.var_x, label=hep_lot,
              marker='.', color=color,
              fit_reg=False,
              scatter_kws={'alpha': 0.25})
  # With the default return_sorted=True, lowess returns (x, fitted) columns
  # sorted by x.
  lowess = statsmodels.nonparametric.smoothers_lowess.lowess(
      hep_df.var_x, hep_df[INHIBITION_COL])
  plt.plot(lowess[:, 0], lowess[:, 1], color=color, lw=3)
legend = plt.legend(title='Hep lot', loc='best')
# Legend.legendHandles was renamed to legend_handles in matplotlib 3.7 and the
# old name was removed in 3.9; support both.
legend_handles = (legend.legend_handles if hasattr(legend, 'legend_handles')
                  else legend.legendHandles)
# Make the legend markers opaque even though the scatter points are not.
for lh in legend_handles:
  lh.set_alpha(1)
plt.ylim(0, 2.5)
plt.xlabel('Actual inhibition')
plt.ylabel('Var(predicted logit inhibition)')
plt.title(f'Within-well variance of predicted {INHIBITION_TYPE} inhibition by hepatocyte lot')

plt.show()

In [None]:
count_df.head()

## Combining appearance with counts: motivation

The graph below shows the distribution of predicted inhibitions for individual hypnozoites as a function of the count-based estimates of inhibition in the well.

The key thing to note is that in wells where the count inhibition is low (< 0.25), the logits of the appearance predictions are mostly below -4. As the count inhibition increases, the appearance predictions shift rightward, with more and more predictions above -4.

Below we will model the joint variation of appearance-based and count-based inhibition estimates to get a combined estimate of a drug candidate's impact on the parasites.

In [None]:
# Per-patch predictions for sample wells, together with the count-based
# inhibition of the well each patch came from.
tmp_df = prediction_df[prediction_df['actives'] == 'sample'][[INHIBITION_COL, 'pred_inhibition_logit']].copy()

plt.figure(figsize=(12, 8))
# Split count inhibition into n equal-width bins on [0, 1) and draw one density
# curve of the appearance predictions per bin.
n = 4
for i in range(n):
  lower = i/n
  upper = (i+1)/n
  subset = (tmp_df[INHIBITION_COL] >= lower) & (tmp_df[INHIBITION_COL] < upper)
  # Legend text: the first bin omits its lower bound, the last its upper bound.
  label = ['count inhibition']
  if i > 0:
    label = [f'{lower} <='] + label
  if i < n-1:
    label = label + [f'< {upper}']
  label = ' '.join(label)
  # `color` sets the curve color directly; `palette` only applies together
  # with a `hue` variable, which this plot does not use.
  sns.kdeplot(data=tmp_df[subset],
              x='pred_inhibition_logit',
              color=sns.color_palette('tab10')[i],
              lw=4,
              label=label)
plt.title('Appearance inhibition by count inhibition', fontsize=18)
plt.xlabel('logit(predicted inhibition)', fontsize=16)
plt.ylabel('Density', fontsize=16)
plt.legend(loc='best', fontsize=14)
plt.show()

In [None]:
# Put together some JAX arrays that will be used by our inhibition model

# Give every well a treatment label:
#   * all infected controls share the single label 'infected_control',
#   * active controls get one label per plate,
#   * every sample well is its own treatment.
treatment = []
for (plate, well, actives) in zip(count_df.plate, count_df.well, count_df.actives):
  if actives == 'infected_control':
    treatment.append('infected_control')
  elif actives == 'active_control':
    treatment.append(f'active_control_{plate}')
  elif actives == 'sample':
    treatment.append(f'sample_{plate}_{well}')
  else:
    raise ValueError(f'Bad actives {actives}')
count_df['treatment'] = treatment

# Integer index for each plate, in sorted plate order.
plates = sorted(set(count_df.plate))
plate_to_plate_index = {plate: i for i, plate in enumerate(plates)}
plate_index_to_plate = dict(enumerate(plates))
plate_index = [plate_to_plate_index[p] for p in count_df.plate]
count_df['plate_index'] = plate_index

# Integer index for each treatment, with 'infected_control' forced to index 0.
# remove() raises ValueError if count_df has no infected controls at all.
treatments = sorted(set(count_df.treatment))
treatments.remove('infected_control')
treatments = ['infected_control'] + treatments
treatment_to_treatment_index = {treatment: i for i, treatment in enumerate(treatments)}
# Inverse mapping, index -> treatment label.
treatment_index_to_treatment = dict(enumerate(treatments))
treatment_index = [treatment_to_treatment_index[t] for t in count_df.treatment]
count_df['treatment_index'] = treatment_index

In [None]:
# Create a pytree containing the model parameters we'll optimize over
n_plates = len(plates)
n_treatments = len(treatments)
n_hep_lots = len(set(count_df.hep_lot))

# Split the root key into one carried-forward key plus seven independent
# subkeys, one per randomly initialized parameter array below.
key = jax.random.PRNGKey(0)
key, subkey0, subkey1, subkey2, subkey3, subkey4, subkey5, subkey6 = jax.random.split(key, num=8)

# spline parameters for the map between true inhibition and prediction score mean
# Five knots evenly spaced over the allowed logit-inhibition range.
spline_order_mean = 3
knots_mean = spline.set_knot_multiplicity(
    k=spline_order_mean,
    knots=tuple(np.linspace(joint_model.MIN_LOGIT_INHIBITION, joint_model.MAX_LOGIT_INHIBITION, num=5).tolist()))
n_splines_mean = spline.n_basis_fns(k=spline_order_mean, knots=knots_mean)

# spline parameters for the map between true inhibition and prediction score variance
# Same order and knot placement as the mean spline.
spline_order_var = 3
knots_var = spline.set_knot_multiplicity(
    k=spline_order_var,
    knots=tuple(np.linspace(joint_model.MIN_LOGIT_INHIBITION, joint_model.MAX_LOGIT_INHIBITION, num=5).tolist()))
n_splines_var = spline.n_basis_fns(k=spline_order_var, knots=knots_var)
print(n_splines_mean, n_splines_var)

# rough initial guess for the log(count) in an untreated well
mean_items = np.log(np.mean(count_df[COUNT_COL][count_df.actives == 'infected_control']))

# initial parameters that we will optimize
# Shapes: per-plate arrays have length n_plates; the spline coefficient arrays
# have one row per hep lot. inhibition has n_treatments-1 entries --
# presumably because 'infected_control' (treatment index 0) is the reference
# with no free inhibition parameter; confirm against joint_model.
initial_params = joint_model.InhibitionParams(
    mean_items_unconstrained=jax.random.normal(key=subkey0, shape=(n_plates,)) + mean_items,
    overdispersion_unconstrained=jax.random.normal(key=subkey1, shape=(n_plates,)),
    inhibition_unconstrained=jax.random.normal(key=subkey2, shape=(n_treatments-1,)),
    mean_coeffs_unconstrained=jax.random.normal(key=subkey3, shape=(n_hep_lots, n_splines_mean,)),
    var_coeffs_unconstrained=jax.random.normal(key=subkey4, shape=(n_hep_lots, n_splines_var,)),
    mean_offset_unconstrained=jax.random.normal(key=subkey5, shape=(n_plates,)),
    var_offset_unconstrained=jax.random.normal(key=subkey6, shape=(n_plates,)),
    spline_order_mean=spline_order_mean,
    knots_mean=knots_mean,
    spline_order_var=spline_order_var,
    knots_var=knots_var,
    min_logit_inhibition=joint_model.MIN_LOGIT_INHIBITION,
    max_logit_inhibition=joint_model.MAX_LOGIT_INHIBITION,
    max_mean_coeff=joint_model.MAX_MEAN_COEFF,
    max_var_coeff=joint_model.MAX_VAR_COEFF,
    constant_fields=set(),
)

In [None]:
# Step 1: Estimate inhibition using only the counts.
#
# Runs 2000 Adam-scaled steps (learning rate 0.1) on the count-only objective,
# starting from `initial_params`. The fitted parameters are kept in
# `count_params` and the per-well estimate is written to
# count_df['inhibition_count'].

# get the objective function for optimizing for counts only
objective_count = joint_model.get_objective_count(count_df, COUNT_COL)

# make sure the objective function doesn't blow up for the initial values
print(objective_count(initial_params))

# jit-compiled function returning (objective value, gradient w.r.t. params)
value_and_grad_fn = jax.jit(jax.value_and_grad(objective_count))

learning_rate = 0.1
# Adam-normalised gradient followed by a *positive* scale. Since
# optax.apply_updates adds the updates to the params, each step moves up the
# gradient, i.e. the objective is being maximised.
opt_init, opt_update = optax.chain(
    optax.scale_by_adam(),
    optax.scale(learning_rate)
)

params = initial_params
state = opt_init(params)
for i in range(2000):
  value, grad = value_and_grad_fn(params)
  # log progress every 100 steps
  if i % 100 == 0:
    print(i, value, flush=True)
  updates, state = opt_update(grad, state, params)
  params = optax.apply_updates(params, updates)
# snapshot the result so later steps that rebind `params` cannot affect it
count_params = copy.deepcopy(params)
# look up each well's inhibition via its treatment index
# (presumably `.inhibition` has one entry per treatment index -- confirm)
count_df['inhibition_count'] = np.asarray(count_params.inhibition)[treatment_index]

In [None]:
# Step 2: now we have some initial "true" inhibition estimates from above.
# Use those to estimate the parameters for the splines that map true inhibition
# to means and variances of prediction scores.
#
# Starts from `count_params`, marks 'inhibition_unconstrained' as a constant
# field, and runs two optimization stages (200 steps at 0.05, then 1000 steps
# at 0.01). The result is kept in `spline_params`.

objective_joint = joint_model.get_objective_joint(count_df, COUNT_COL)
# make sure the objective function doesn't blow up for the initial values
print(objective_joint(count_params))

# jit-compiled function returning (objective value, gradient w.r.t. params)
value_and_grad_fn = jax.jit(jax.value_and_grad(objective_joint))

# Rebuild the params with the inhibition estimates listed as constant
# (presumably InhibitionParams excludes constant fields from optimization --
# confirm against joint_model).
d = dataclasses.asdict(count_params)
d['constant_fields'] = {'inhibition_unconstrained'}
params = joint_model.InhibitionParams(**d)

for learning_rate, steps in [(0.05, 200), (0.01, 1000)]:
  opt_init, opt_update = optax.chain(
      optax.scale_by_adam(),
      optax.scale(learning_rate)
  )

  # the optimizer state is re-initialised at the start of every stage
  state = opt_init(params)
  for i in range(steps):
    value, grad = value_and_grad_fn(params)
    # log progress every 100 steps (the step counter restarts per stage)
    if i % 100 == 0:
      print(i, value, flush=True)
    updates, state = opt_update(grad, state, params)
    params = optax.apply_updates(params, updates)
# snapshot the result so later steps that rebind `params` cannot affect it
spline_params = copy.deepcopy(params)

In [None]:
# Sorted hepatocyte lot names, plus the sorted plate indices that belong to
# each hep lot index.
hep_lots = sorted(set(count_df.hep_lot))
plate_index_by_hep_lot_index = {
    lot_index: sorted(
        set(count_df[count_df.hep_lot_index == lot_index].plate_index))
    for lot_index in range(n_hep_lots)
}

In [None]:
# Step 3: now that we have initial estimates for "true" inhibition and for
# the spline parameters, do joint optimization for all parameters.
#
# Starts from `spline_params` with no constant fields and runs 2000 steps at
# learning rate 0.01. The result is kept in `joint_params` and the per-well
# estimate is written to count_df['inhibition_joint'].
value_and_grad_fn = jax.jit(jax.value_and_grad(objective_joint))

# Rebuild the params with an empty constant set so every field is optimized.
d = dataclasses.asdict(spline_params)
d['constant_fields'] = set()
params = joint_model.InhibitionParams(**d)

for learning_rate, steps in [(0.01, 2000)]:
  opt_init, opt_update = optax.chain(
      optax.scale_by_adam(),
      optax.scale(learning_rate)
  )

  state = opt_init(params)
  for i in range(steps):
    value, grad = value_and_grad_fn(params)
    # log progress every 100 steps
    if i % 100 == 0:
      print(i, value, flush=True)
    updates, state = opt_update(grad, state, params)
    params = optax.apply_updates(params, updates)
# snapshot the result so later steps that rebind `params` cannot affect it
joint_params = copy.deepcopy(params)
# look up each well's joint inhibition via its treatment index
count_df['inhibition_joint'] = np.asarray(joint_params.inhibition)[treatment_index]

In [None]:
# Visualize the splines mapping inhibition to mean/variance of predictions
x = jnp.arange(joint_model.MIN_LOGIT_INHIBITION, joint_model.MAX_LOGIT_INHIBITION, 0.01)
n_hep_lots = len(hep_lots)


def _plot_splines_by_hep_lot(pred_fn, logit_x, k, knots, ylabel, ylim):
  """Draws one panel per hep lot with one fitted curve per plate in that lot.

  Args:
    pred_fn: callable such as `joint_params.pred_mean` / `joint_params.pred_var`,
      called as pred_fn(logit_x, spline_index=..., offset_index=..., k=...,
      knots=...).
    logit_x: inhibition values on the logit scale at which to evaluate pred_fn.
    k: spline order passed through to pred_fn.
    knots: spline knots passed through to pred_fn.
    ylabel: y-axis label, shown on the first panel only.
    ylim: (low, high) y-axis limits applied to every panel.

  Returns:
    The (figure, 1-D array of axes) that were drawn.
  """
  # squeeze=False keeps `axes` an array even when there is a single hep lot;
  # without it plt.subplots returns a bare Axes and axes[i] fails.
  fig, axes = plt.subplots(1, n_hep_lots, figsize=(2*n_hep_lots, 2),
                           sharex=True, sharey=True, squeeze=False)
  axes = axes[0]
  # The x-axis shows inhibition on the probability scale; it is the same for
  # every curve so compute it once.
  expit_x = jax.scipy.special.expit(logit_x)
  for i, hep_lot in enumerate(hep_lots):
    color = COLORS[i]
    for plate_index in plate_index_by_hep_lot_index[i]:
      pred = pred_fn(
          logit_x,
          spline_index=i,
          offset_index=plate_index,
          k=k,
          knots=knots)
      axes[i].plot(expit_x, pred, color=color, label=hep_lot)
    axes[i].set_xlabel('inhibition')
    if i == 0:
      axes[i].set_ylabel(ylabel)
    axes[i].set_ylim(*ylim)
    axes[i].set_title(f'{hep_lot}')
  plt.show()
  return fig, axes


# Mean of the prediction score as a function of inhibition.
fig, axes = _plot_splines_by_hep_lot(
    joint_params.pred_mean, x, k=spline_order_mean, knots=knots_mean,
    ylabel='mean_score(inhibition)', ylim=(-6, 0.))

# Variance of the prediction score as a function of inhibition.
fig, axes = _plot_splines_by_hep_lot(
    joint_params.pred_var, x, k=spline_order_var, knots=knots_var,
    ylabel='var_score(inhibition)', ylim=(0, 2.5))


In [None]:
# Step 4 (bonus): Get an appearance inhibition score by finding the true
# inhibition that maximizes just the appearance loss (we'll use the spline
# parameters fit above)
#
# Starts from `joint_params`, marks every field except
# 'inhibition_unconstrained' as constant, and runs 1000 steps at learning rate
# 0.1. The per-well estimate is written to count_df['inhibition_appearance'].
objective_appearance = joint_model.get_objective_appearance(count_df)
value_and_grad_fn = jax.jit(jax.value_and_grad(objective_appearance))

# Freeze everything but the inhibition values. Note the constant set is built
# from all dataclass field names, so it also contains the non-array fields
# (and 'constant_fields' itself).
d = dataclasses.asdict(joint_params)
constant_fields = set(d.keys())
constant_fields.remove('inhibition_unconstrained')
d['constant_fields'] = constant_fields
params = joint_model.InhibitionParams(**d)

for learning_rate, steps in [(0.1, 1000)]:
  opt_init, opt_update = optax.chain(
      optax.scale_by_adam(),
      optax.scale(learning_rate)
  )

  state = opt_init(params)
  for i in range(steps):
    value, grad = value_and_grad_fn(params)
    # log progress every 100 steps
    if i % 100 == 0:
      print(i, value, flush=True)
    updates, state = opt_update(grad, state, params)
    params = optax.apply_updates(params, updates)
# snapshot the result so later cells that rebind `params` cannot affect it
appearance_params = copy.deepcopy(params)

# look up each well's appearance inhibition via its treatment index
inhibition_appearance = np.asarray(appearance_params.inhibition)[treatment_index]
# wells with a zero count are assigned an appearance inhibition of exactly 1
# (presumably because there are no objects whose appearance could be scored)
inhibition_appearance[count_df[COUNT_COL] == 0] = 1.
count_df['inhibition_appearance'] = inhibition_appearance

In [None]:
# Scatter count-based against appearance-based inhibition for the sample
# wells, with the identity line for reference, then report their Spearman
# rank correlation.
subset = count_df.actives == 'sample'
count_inhibition = count_df.inhibition_count[subset]
appearance_inhibition = count_df.inhibition_appearance[subset]

plt.figure(figsize=(8, 4))
plt.scatter(count_inhibition, appearance_inhibition, marker='.', alpha=0.25)
plt.plot((0, 1), (0, 1), color='gray', ls='--')
plt.xlabel('Count inhibition')
plt.ylabel('Appearance inhibition')
plt.title(f'Count based vs appearance based inhibition, {INHIBITION_TYPE}')
plt.show()

print(scipy.stats.spearmanr(count_inhibition, appearance_inhibition))

In [None]:
# Look at relationship between count inhibition and appearance inhibition
# broken out by hepatocyte lot: one scatter plus LOWESS trend per hep lot,
# followed by the overall Spearman rank correlation.

xcol = 'inhibition_count'
ycol = 'inhibition_appearance'

# sample wells with a finite y value
subset = (count_df.actives == 'sample') & (np.isfinite(count_df[ycol]))
plt.figure(figsize=(8, 4))
for i, (hep_lot, hep_df) in enumerate(count_df[subset].groupby('hep_lot')):
  color = COLORS[i]
  x = hep_df[xcol]
  y = hep_df[ycol]
  # Pass the colour explicitly so the scatter always matches the trend line
  # instead of depending on the current matplotlib colour cycle.
  sns.regplot(x=x,
              y=y,
              marker='.',
              fit_reg=False,
              color=color,
              scatter_kws={'alpha': 0.25},)
  lowess = statsmodels.nonparametric.smoothers_lowess.lowess(y, x, frac=0.25)
  plt.plot(lowess[:, 0], lowess[:, 1], color=color, lw=3, label=hep_lot)
plt.xlabel('Count inhibition')
plt.ylabel('Appearance inhibition')
plt.title(f'Count based vs appearance based inhibition, {INHIBITION_TYPE}')
plt.ylim(0, 1)
legend = plt.legend(loc='best', title='Hep lot')
# Matplotlib renamed Legend.legendHandles to Legend.legend_handles and later
# removed the old name; support both so the cell runs on old and new versions.
legend_handles = getattr(legend, 'legend_handles', None)
if legend_handles is None:
  legend_handles = legend.legendHandles
for lh in legend_handles:
  lh.set_alpha(1)
plt.show()
print(scipy.stats.spearmanr(count_df[xcol][subset], count_df[ycol][subset]))

In [None]:
# Scatter count-based against joint inhibition for the sample wells, with the
# identity line for reference, then report their Spearman rank correlation.
subset = count_df.actives == 'sample'
count_inhibition = count_df.inhibition_count[subset]
joint_inhibition = count_df.inhibition_joint[subset]

plt.figure(figsize=(8, 4))
plt.scatter(count_inhibition, joint_inhibition, marker='.', alpha=0.25)
plt.plot((0, 1), (0, 1), color='gray', ls='--', lw=2)
plt.xlabel('Count inhibition')
plt.ylabel('Joint inhibition')
plt.title(f'Count based vs joint inhibition, {INHIBITION_TYPE}')
plt.show()

print(scipy.stats.spearmanr(count_inhibition, joint_inhibition))

In [None]:
# Look at relationship between count inhibition and joint inhibition
# broken out by hepatocyte lot: one scatter plus LOWESS trend per hep lot,
# followed by the overall Spearman rank correlation.

xcol = 'inhibition_count'
ycol = 'inhibition_joint'

# sample wells with a finite y value
subset = (count_df.actives == 'sample') & (np.isfinite(count_df[ycol]))
plt.figure(figsize=(8, 4))
for i, (hep_lot, hep_df) in enumerate(count_df[subset].groupby('hep_lot')):
  color = COLORS[i]
  x = hep_df[xcol]
  y = hep_df[ycol]
  # Pass the colour explicitly so the scatter always matches the trend line
  # instead of depending on the current matplotlib colour cycle.
  sns.regplot(x=x,
              y=y,
              marker='.',
              fit_reg=False,
              color=color,
              scatter_kws={'alpha': 0.25},)
  lowess = statsmodels.nonparametric.smoothers_lowess.lowess(y, x, frac=0.15)
  plt.plot(lowess[:, 0], lowess[:, 1], color=color, lw=3, label=hep_lot)
plt.xlabel('Count inhibition')
plt.ylabel('Joint inhibition')
plt.title(f'Count based vs joint inhibition, {INHIBITION_TYPE}')
legend = plt.legend(loc='best', title='Hep lot')
# Matplotlib renamed Legend.legendHandles to Legend.legend_handles and later
# removed the old name; support both so the cell runs on old and new versions.
legend_handles = getattr(legend, 'legend_handles', None)
if legend_handles is None:
  legend_handles = legend.legendHandles
for lh in legend_handles:
  lh.set_alpha(1)
plt.show()
print(scipy.stats.spearmanr(count_df[xcol][subset], count_df[ycol][subset]))

In [None]:
# load dose response estimates
dr_df = pd.read_parquet(DOSE_RESPONSE_FILE)

# NOTE(review): COLS_TO_KEEP is not referenced by any of the cells visible
# after this one -- confirm whether it is still needed.
COLS_TO_KEEP = ['batch', 'plate', 'well', 'embedding', 'actives', 'hep_lot',
                'ml_hypnozoite', 'inhibition_ml_hyp', 'ml_parasite',
                'inhibition_ml_par', 'hep_lot_index']

In [None]:
# Collapse the dose response table to one row per compound: rename the
# 'compound' column to 'blinded_concept' (the key used by the screening data)
# and take the per-compound median of every numeric column.
# numeric_only=True is required on pandas >= 2.0, where median() raises a
# TypeError if any non-numeric column is present; older pandas silently
# dropped such columns, so the result is unchanged there.
dr_df = (
    dr_df.rename(columns={'compound': 'blinded_concept'})
    .groupby('blinded_concept')
    .median(numeric_only=True)
    .reset_index())

In [None]:
# Eyeball ten random rows of the dose response table
dr_df.sample(10)

In [None]:
# Measures of inhibition to compare as predictors of IC50/CC50.
# Each entry is a count_df column name; the list is used to select columns,
# so it must stay a list.
# NOTE(review): the first three names are hypnozoite-specific ('_hyp') even
# when INHIBITION_TYPE is 'parasite' -- confirm this is intended.
MEASURES = ['inhibition_cp_hyp_act',  # cell-profiler based counts with correction based on active control inhibition
            'inhibition_cp_hyp',      # cell-profiler based counts
            'inhibition_ml_hyp',      # ml based counts
            'inhibition_count',       # ml based counts (maximum likelihood estimate)
            'inhibition_appearance',  # appearance based estimate
            'inhibition_joint',       # joint estimate based on ml counts and appearance
            ]

In [None]:
# Per-compound median of every screening measure, restricted to sample wells
# that carry a real compound label.
is_labelled_sample = (
    (count_df.actives == 'sample') & (count_df.blinded_concept != 'none'))
sample_df = (
    count_df[is_labelled_sample]
    .groupby('blinded_concept')[MEASURES]
    .median()
    .reset_index())

# Join screen measures and IC50 estimates; only compounds present in both
# tables are kept.
joined_df = sample_df.merge(
    dr_df, how='inner', on=['blinded_concept']).reset_index()
joined_df.shape

In [None]:
# For several potency thresholds, report how well each screening measure
# ranks compounds whose IC50/CC50 falls below the threshold (ROC AUC).
TARGET_COL = 'log_ic50'
TARGET_NAME = 'IC50'

for t in (10, 3, 1, 0.3):
  # binary label: is the compound's log target value below log(threshold)?
  y = joined_df[TARGET_COL] < np.log(t)
  print(f'Threshold {t}: {np.sum(y)} / {y.shape[0]} compounds have {TARGET_NAME} < threshold')
  for measure in MEASURES:
    auc = sklearn.metrics.roc_auc_score(y, joined_df[measure])
    print(f'AUC, {measure:25s} {auc:.3f}')
  print()

In [None]:
# Overlaid density histograms of the per-patch prediction score (logit scale)
# for sample wells, split into four bands of count-based inhibition.
BINS = 200
plt.figure(figsize=(8, 4))

is_sample = prediction_df.actives == 'sample'
count_inhibition = prediction_df[INHIBITION_COL]

# (lower bound inclusive, upper bound exclusive, alpha, legend label);
# None means the band is open on that side.
inhibition_bands = [
    (None, 0.25, 0.7, 'count inhibition < 0.25'),
    (0.25, 0.5, 0.6, '0.25 <= count inhibition < 0.5'),
    (0.5, 0.75, 0.5, '0.5 <= count inhibition < 0.75'),
    (0.75, None, 0.4, '0.75 <= count inhibition'),
]
for lower, upper, band_alpha, band_label in inhibition_bands:
  if lower is None:
    subset = (count_inhibition < upper)
  elif upper is None:
    subset = (count_inhibition >= lower)
  else:
    subset = (count_inhibition >= lower) & (count_inhibition < upper)
  plt.hist(prediction_df.pred_inhibition_logit[is_sample & subset],
           bins=BINS, alpha=band_alpha, density=True, label=band_label)

plt.legend(loc='best')
plt.xlabel('Model prediction (logit scale)')
plt.ylabel(f'{INHIBITION_TYPE.capitalize()} density')
plt.title('Predicted inhibition')
plt.show()

# Visualize parasites with different scores

In [None]:
from cell_img.common import image_lib
from cell_img.malaria_liver import metadata_lib

In [None]:
# set up the object to create images based on metadata
# (the cells below call meta_ts.contact_sheet_for_df to render patch grids)
meta_ts = metadata_lib.MetadataIndex(TENSORSTORE_PATH, CHANNEL_TO_RGB,
                                     METADATA_ROOT_PATH)

In [None]:
# Eyeball three random rows of prediction_df
prediction_df.sample(3)

In [None]:
# Contact sheet: a random sample of patches from sample wells
nrows, ncols = 8, 8
query_str = 'actives == "sample"'
n_patches = ncols*nrows

subset_df = prediction_df.query(query_str)
print('There are %d patches with %s, subsetting %d' % (
    len(subset_df), query_str, n_patches))

# fixed seed so the same patches are drawn on every run
np.random.seed(123)
sampled_subset = subset_df.sample(n_patches)

_ = meta_ts.contact_sheet_for_df(
    example_df=sampled_subset,
    patch_size=50,
    ncols=ncols,
    nrows=nrows,
    name_for_x_col='center_col',
    name_for_y_col='center_row',
    norm_then_stack=True)

In [None]:
# Contact sheet: very happy hypnozoites (pred_inhibition_logit < -5)
nrows, ncols = 8, 8
query_str = 'actives == "sample" and pred_inhibition_logit < -5'
n_patches = ncols*nrows

subset_df = prediction_df.query(query_str)
print('There are %d patches with %s, subsetting %d' % (
    len(subset_df), query_str, n_patches))

# fixed seed so the same patches are drawn on every run
np.random.seed(123)
sampled_subset = subset_df.sample(n_patches)

_ = meta_ts.contact_sheet_for_df(
    example_df=sampled_subset,
    patch_size=50,
    ncols=ncols,
    nrows=nrows,
    name_for_x_col='center_col',
    name_for_y_col='center_row',
    norm_then_stack=True)

In [None]:
# Contact sheet: happy hypnozoites (pred_inhibition_logit < -4)
nrows, ncols = 8, 8
query_str = 'actives == "sample" and pred_inhibition_logit < -4'
n_patches = ncols*nrows

subset_df = prediction_df.query(query_str)
print('There are %d patches with %s, subsetting %d' % (
    len(subset_df), query_str, n_patches))

# fixed seed so the same patches are drawn on every run
np.random.seed(123)
sampled_subset = subset_df.sample(n_patches)

_ = meta_ts.contact_sheet_for_df(
    example_df=sampled_subset,
    patch_size=50,
    ncols=ncols,
    nrows=nrows,
    name_for_x_col='center_col',
    name_for_y_col='center_row',
    norm_then_stack=True)

In [None]:
# Contact sheet: stressed hypnozoites (pred_inhibition_logit > -4)
nrows, ncols = 8, 8
query_str = 'actives == "sample" and pred_inhibition_logit > -4'
n_patches = ncols*nrows

subset_df = prediction_df.query(query_str)
print('There are %d patches with %s, subsetting %d' % (
    len(subset_df), query_str, n_patches))

# fixed seed so the same patches are drawn on every run
np.random.seed(123)
sampled_subset = subset_df.sample(n_patches)

_ = meta_ts.contact_sheet_for_df(
    example_df=sampled_subset,
    patch_size=50,
    ncols=ncols,
    nrows=nrows,
    name_for_x_col='center_col',
    name_for_y_col='center_row',
    norm_then_stack=True)

In [None]:
# Contact sheet: very stressed hypnozoites (pred_inhibition_logit > -3)
nrows, ncols = 8, 8
query_str = 'actives == "sample" and pred_inhibition_logit > -3'
n_patches = ncols*nrows

subset_df = prediction_df.query(query_str)
print('There are %d patches with %s, subsetting %d' % (
    len(subset_df), query_str, n_patches))

# fixed seed so the same patches are drawn on every run
np.random.seed(123)
sampled_subset = subset_df.sample(n_patches)

_ = meta_ts.contact_sheet_for_df(
    example_df=sampled_subset,
    patch_size=50,
    ncols=ncols,
    nrows=nrows,
    name_for_x_col='center_col',
    name_for_y_col='center_row',
    norm_then_stack=True)

In [None]:
# Hep lots used as the "stressed" and "happy" examples in this cell and the
# next one.
if INHIBITION_TYPE in {'hypnozoite', 'parasite'}:
  STRESSED_HEP_LOT = 'KOG'
  HAPPY_HEP_LOT = 'HTV'

# things look more stressed in some hep lots
np.random.seed(123)
nrows = 8
ncols = 8

# patches from sample wells of the stressed hep lot whose inhibition is in
# the band (0.1, 0.2]
subset = ((prediction_df.actives == 'sample') &
          (prediction_df.hep_lot == STRESSED_HEP_LOT) &
          (prediction_df[INHIBITION_COL] > 0.1) &
          (prediction_df[INHIBITION_COL] <= 0.2))
subset_df = prediction_df[subset]
# Describe the filter actually applied here; the message previously reused
# `query_str` left over from the preceding cell and so named the wrong filter.
subset_desc = (f'actives == "sample" and hep_lot == "{STRESSED_HEP_LOT}" '
               f'and 0.1 < {INHIBITION_COL} <= 0.2')
print('There are %d patches with %s, subsetting %d' % (
    len(subset_df), subset_desc, ncols*nrows))

sampled_subset = subset_df.sample(ncols*nrows)

_ = meta_ts.contact_sheet_for_df(
    example_df=sampled_subset,
    patch_size=50, ncols=ncols, nrows=nrows,
    name_for_x_col='center_col', name_for_y_col='center_row',
    norm_then_stack=True)

In [None]:
# things look less stressed in some hep lots
# (nrows / ncols are reused from the preceding contact-sheet cell)
np.random.seed(123)

# patches from sample wells of the happy hep lot whose inhibition is in the
# band (0.1, 0.2]
subset = ((prediction_df.actives == 'sample') &
          (prediction_df.hep_lot == HAPPY_HEP_LOT) &
          (prediction_df[INHIBITION_COL] > 0.1) &
          (prediction_df[INHIBITION_COL] <= 0.2))
subset_df = prediction_df[subset]
# Describe the filter actually applied here; the message previously reused
# `query_str` left over from an earlier cell and so named the wrong filter.
subset_desc = (f'actives == "sample" and hep_lot == "{HAPPY_HEP_LOT}" '
               f'and 0.1 < {INHIBITION_COL} <= 0.2')
print('There are %d patches with %s, subsetting %d' % (
    len(subset_df), subset_desc, ncols*nrows))

sampled_subset = subset_df.sample(ncols*nrows)

_ = meta_ts.contact_sheet_for_df(
    example_df=sampled_subset,
    patch_size=50, ncols=ncols, nrows=nrows,
    name_for_x_col='center_col', name_for_y_col='center_row',
    norm_then_stack=True)

# Save out predictions

In [None]:
# Write the per-well identifiers, raw ML counts and all inhibition estimates
# to INHIBITION_FILE as parquet.
well_id_cols = ['batch', 'plate', 'well', 'actives', 'blinded_concept', 'hep_lot']
ml_count_cols = ['ml_hypnozoite', 'ml_schizont', 'ml_parasite']
inhibition_cols = ['inhibition_ml_hyp', 'inhibition_ml_par',
                   'inhibition_count', 'inhibition_appearance',
                   'inhibition_joint']

print(f'Writing to:\n{INHIBITION_FILE}')
count_df[well_id_cols + ml_count_cols + inhibition_cols].to_parquet(INHIBITION_FILE)