In [1]:
# Standard library imports
import os
import pickle
import time
from dataclasses import dataclass

# Standard imports
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# Import x to one hot encoding
import utils 

# Logomaker
import logomaker

# jax import
import jax.random as random
import jax.numpy as jnp
from jax.numpy import DeviceArray

# numpyro imports
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC, HMC, SVI
from numpyro.infer import Predictive
from numpyro.diagnostics import hpdi

# arviz
import arviz as az

# For visualization. Note that this notebook won't be making figure, so don't need to be picky. 
#%matplotlib notebook
%matplotlib inline

  from jax.numpy import DeviceArray


In [2]:
@dataclass
class args:
    """Sampler configuration read by the setup and inference cells.

    The attributes are annotated so that ``@dataclass`` registers them as
    real fields (un-annotated attributes are ignored by the decorator).
    Class-level access such as ``args.num_chains`` keeps working, and an
    overriding instance can be built with e.g. ``args(num_chains=2)``.
    """
    # Number of MCMC chains; also used as the host device count.
    num_chains: int = 4
    # Number of post-warm-up draws requested from the sampler.
    num_samples: int = 30000
    # Number of warm-up iterations requested from the sampler.
    num_warmup: int = 30000
    # Platform name handed to numpyro.set_platform.
    device: str = 'cpu'
    
# Pick the compute platform and expose one host device per chain
# (the inference cell runs its chains with chain_method="parallel").
numpyro.set_platform(args.device)
numpyro.set_host_device_count(args.num_chains)

# Split a fixed-seed root key (seed 0) into two keys:
# one for MCMC sampling and one for posterior prediction.
rng_key, rng_key_predict = random.split(random.PRNGKey(0), 2)

In [7]:
def load_data(sheet_name=None,
              file_name='../data/qPCR/control_dose_response_curves.xlsx'):
    """Load one qPCR sheet and reduce it to one dCt value per (conc, bio_rep).

    Parameters
    ----------
    sheet_name : str or None
        Name of the Excel sheet to load. When ``None`` no data are read and
        the list of sheet names in the workbook is returned instead.
    file_name : str
        Path of the Excel workbook. Defaults to the control dose-response
        workbook used throughout this notebook.

    Returns
    -------
    list or pandas.DataFrame
        The workbook's sheet names when ``sheet_name`` is ``None``; otherwise
        a frame with columns ``conc``, ``bio_rep`` and ``dCt``.

    Notes
    -----
    The sheet is read with a two-row column header (renamed to
    ``primers`` / ``tech_rep``) and a two-column row index (renamed to
    ``conc`` / ``bio_rep``). It must contain ``exclusion`` and ``inclusion``
    primer columns, otherwise a ``KeyError`` is raised.
    """
    # Return sheet names if none specified; the context manager makes sure
    # the workbook handle is closed again.
    if sheet_name is None:
        with pd.ExcelFile(file_name) as workbook:
            return workbook.sheet_names

    # Load data from Excel sheet
    data_df = pd.read_excel(file_name,
                            sheet_name=sheet_name, header=[0, 1], index_col=[0, 1])
    data_df.index.names = ['conc', 'bio_rep']
    data_df.columns.names = ['primers', 'tech_rep']

    # Melt to tidy (long) form: one row per (conc, bio_rep, primers, tech_rep)
    tidy_df = data_df.melt(value_name='cycles',
                           ignore_index=False).reset_index()

    # For each (conc, bio_rep, primers) average the cycles across tech_rep.
    # 'cycles' is selected BEFORE aggregating so that the tech_rep labels are
    # never averaged; with string labels pandas >= 2.0 raises a TypeError.
    mean_cycles = (tidy_df
                   .groupby(['conc', 'bio_rep', 'primers'])['cycles']
                   .mean()
                   .reset_index())
    df = mean_cycles.pivot(index=['conc', 'bio_rep'], columns='primers', values='cycles')

    # Compute dCt values (NOTE: exclusion - inclusion !)
    df['dCt'] = df['exclusion'] - df['inclusion']
    df = df.reset_index()[['conc', 'bio_rep', 'dCt']]
    df.columns.name = ''

    return df

In [8]:
def single_drug_model(x, y=None):
    """numpyro model of a single-drug dose-response curve.

    Five parameters are drawn from uniform priors on a log scale and exposed
    on the natural scale as deterministic sites (``S``, ``H``, ``NC``,
    ``sigma``, ``alpha``). The curve is

        R = S * (1 + (x/NC)**H) / (1 + (1/alpha) * (x/NC)**H)

    and ``yhat`` is Normal(log2(R), sigma), conditioned on ``y`` when given.

    Parameters
    ----------
    x : array-like
        Concentrations at which the curve is evaluated.
    y : array-like or None
        Observed values for the ``yhat`` site; ``None`` leaves it unobserved.
    """

    def _log_uniform(site, log_site, base, low, high):
        # Draw the exponent uniformly, then record base**exponent under `site`.
        exponent = numpyro.sample(log_site, dist.Uniform(low=low, high=high))
        return numpyro.deterministic(site, base**exponent)

    # Site order below is kept fixed: log10_S/S, log2_H/H, log10_NC/NC,
    # log10_sigma/sigma, log10_alpha/alpha.
    # (Earlier experiments pinned S=0.55, H=1.13 and alpha=97.9 instead.)
    context_strength = _log_uniform('S', "log10_S", 10.0, -3, 3)       # context strength
    hill_coeff = _log_uniform('H', "log2_H", 2.0, -2, 2)               # Hill coefficient
    norm_conc = _log_uniform('NC', "log10_NC", 10.0, -3, 3)            # normalizing concentration
    noise_sd = _log_uniform('sigma', "log10_sigma", 10.0, -2, 2)       # measurement noise
    amplitude = _log_uniform('alpha', "log10_alpha", 10.0, 0, 6)       # amplitude

    # NOTE(review): if x contains exact zeros and H < 1, the gradient of
    # (x/NC)**H may be non-finite under autodiff -- confirm sampler diagnostics.
    dose_term = (x/norm_conc)**hill_coeff

    # R = inclusion / exclusion
    ratio = numpyro.deterministic(
        'R',
        context_strength*(1 + dose_term)/(1 + (1/amplitude)*dose_term))
    log2_ratio = numpyro.deterministic('log2_R', jnp.log2(ratio))
    numpyro.sample('yhat', dist.Normal(log2_ratio, noise_sd), obs=y)

In [9]:
def plot_regression(model_prediction_dict, data_dict, ax):
    """Draw the data, the predictive mean and the predictive interval on ``ax``.

    Parameters
    ----------
    model_prediction_dict : dict
        Must provide ``'concs'`` (x values of the curve), ``'mean'`` (curve
        values) and ``'hdpi'`` (sic; a 2-row array whose row 0 is the lower
        bound and row 1 the upper bound of the band).
    data_dict : dict
        Must provide array values under ``'x'`` and ``'y'``; both are
        flattened before plotting.
    ax : matplotlib Axes (or any object with scatter/fill_between/plot)
        Target axes.

    Returns
    -------
    The same ``ax`` that was passed in.

    Notes
    -----
    Points with ``x == 0`` are additionally drawn as left-pointing triangles
    at one third of the smallest non-zero x. When there is no non-zero x to
    anchor them, the triangle markers are skipped instead of raising.
    """
    x = data_dict['x'].ravel()
    y = data_dict['y'].ravel()

    x_plot = model_prediction_dict['concs']
    mean = model_prediction_dict['mean']
    # Named `band` so the imported numpyro `hpdi` function is not shadowed.
    band = model_prediction_dict['hdpi']

    # Raw data
    ax.scatter(x, y, label='data', s=20, c='k')

    # Zero-concentration points, re-drawn left of the smallest positive dose.
    # Array reductions are used instead of the Python built-ins min()/sum(),
    # which iterate element by element.
    zero_mask = (x == 0)
    nonzero_x = x[~zero_mask]
    if nonzero_x.size > 0:
        x_left = nonzero_x.min()/3
        n_zero = int(zero_mask.sum())
        ax.scatter(x_left*np.ones(n_zero),
                   y[zero_mask],
                   s=20,
                   marker='<',
                   c='k')

    # Predictive interval and mean curve
    ax.fill_between(x_plot, band[0, :].ravel(),
                    band[1, :].ravel(), alpha=0.3,
                    interpolate=True, color='darkorange',
                    label=r'$95\%HDI$')
    ax.plot(x_plot, mean, label=r'$\hat{C_t}$', c='royalblue')

    return ax

In [10]:
# Sheets to process. load_data() with no argument lists every sheet in the
# workbook; the commented list records sheet names with their figure panels.
# sheet_names = load_data()
# sheet_names = [
#     't25a_ris_v2',          #5C  
#     'smn2_pt1_ris',         #5D
#     'smn2_pt2_ris',         #5E
#     'smn2_pt3_ris',         #5F
#     't25a_bran_v5',         #5G
#     'smn2_pt1_bran',        #5H 
#     'smn2_pt2_bran',        #5I
#     'smn2_pt3_bran',        #5J
#     't25a_asoi7_v2',        #6A
#     't25a_asoi6',           #6B
#     'elp1_rectas',          #6C
#     'elp1_aso',             #6D
#     't25a_ris_bran_v2',     #6E
#     't25a_ris_asoi6_v4',    #6F
#     't25a_ris_asoi7_v2',    #6G
#     't25a_bran_asoi6_v2',   #6H
#     't25a_bran_asoi7',      #6I
#     't25a_asoi6_asoi7_v3',  #6J
#     'elp1_rectas_aso'       #6K
# ]
sheet_names = ['control_smn2_asoi20']

# Do inference for every sheet name
for sheet_name in sheet_names:
    print(f'Processing {sheet_name}...')
    
    # Load data
    data_df = load_data(sheet_name)

    # Remove NaNs
    data_df = data_df.dropna()

    print('\nTraining using MCMC\n')
    start = time.time()
    data_dict = {
        "y": jnp.array(data_df['dCt'].values),
        "x": jnp.array(data_df['conc'].values)}

    # NUTS sampling with the chain/warm-up/sample counts from `args`
    kernel = NUTS(model=single_drug_model)
    sample_kwargs = dict(
        sampler=kernel, 
        num_warmup=args.num_warmup, 
        num_samples=args.num_samples, 
        num_chains=args.num_chains, 
        chain_method="parallel",
    )
    mcmc = MCMC(**sample_kwargs)
    mcmc.run(rng_key, **data_dict)
    print("\nMCMC elapsed time:", time.time() - start)

    # Save posterior samples. The output directory is created on demand so a
    # finished MCMC run is not lost to a missing folder.
    mcmc_samples = mcmc.get_samples()
    mcmc_file_name = f'../mcmc_samples/mcmc_{sheet_name}.pkl'
    os.makedirs(os.path.dirname(mcmc_file_name), exist_ok=True)
    with open(mcmc_file_name, 'wb') as f:
        pickle.dump(mcmc_samples, f)
        
    # Use Arviz for summary
    numpyro_inference_data = az.from_numpyro(mcmc)
    summary_df = az.summary(numpyro_inference_data)

    # Posterior predictive draws on a log-spaced concentration grid
    # (reuses the samples fetched above instead of calling get_samples() again)
    ppc = Predictive(single_drug_model, mcmc_samples)
    x_grid = np.logspace(-2,3,100)
    ppc_val = ppc(rng_key_predict, x=x_grid)
    
    # NOTE: the key is spelled 'hdpi' because plot_regression reads that key.
    model_prediction_dict = {
        "concs":x_grid,
        "mean": jnp.mean(ppc_val['yhat'], axis=0),
        "hdpi": numpyro.diagnostics.hpdi(ppc_val['yhat'], prob=0.95),
    }
    
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 4))
    ax = plot_regression(model_prediction_dict, data_dict, ax)
    ax.legend()
    ax.set_xscale('log')
    ax.set_xlabel('concentration')
    ax.set_ylabel(r'$\Delta C_t$')

    # Summary statistics for the title; the hdi_3% / hdi_97% columns come
    # from az.summary, not from the prob=0.95 band drawn in the figure.
    H_mean = summary_df.loc['H', 'mean']
    H_lo = summary_df.loc['H', 'hdi_3%']
    H_hi = summary_df.loc['H', 'hdi_97%']
    # S_mean is not used in the figure; kept in case later cells read it.
    S_mean = summary_df.loc['S', 'mean']
    ax.set_title(f'{sheet_name}: $H =${H_mean} [{H_lo}, {H_hi}]')

    # Save and close this specific figure (not the implicit "current" one)
    fig.tight_layout()
    fig_file_name = f'all_dose_response_curves/{sheet_name}.png'
    os.makedirs(os.path.dirname(fig_file_name), exist_ok=True)
    fig.savefig(fig_file_name)
    plt.close(fig)

Processing control_smn2_asoi20...

Training using MCMC



  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]


MCMC elapsed time: 3.2050938606262207
