In [1]:
# Standard imports
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from dataclasses import dataclass
import time
import pickle
import pingouin as pg

# Import x to one hot encoding
import utils 

# Logomaker
import logomaker

# jax import
import jax.random as random
import jax.numpy as jnp
from jax.numpy import DeviceArray

# numpyro imports
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC, HMC, SVI
from numpyro.infer import Predictive
from numpyro.diagnostics import hpdi

# arviz
import arviz as az

# For visualization. Note that this notebook won't be making figure, so don't need to be picky. 
#%matplotlib notebook
%matplotlib inline

  return warn(
  return warn(


In [2]:
@dataclass
class args:
    num_chains = 4
    num_samples = 30000
    num_warmup = 30000
    device = 'cpu'
    
numpyro.set_platform(args.device)
numpyro.set_host_device_count(args.num_chains)

# generate random keys for training and predictions
rng_key, rng_key_predict = random.split(random.PRNGKey(0))

In [3]:
def load_data(sheet_name=None):
    # Name of excel file
    file_name = '../data/qPCR/linear_mixture_curves.xlsx'
    
    # Return sheet names if none specified
    if sheet_name is None:
        return pd.ExcelFile(file_name).sheet_names
    
    data_df = pd.read_excel(file_name,
                            sheet_name=sheet_name)
    data_df.columns = ['drug1','drug2','rep','deltaCt']
    
    return data_df

In [4]:
load_data(sheet_name='R_Ai6')

Unnamed: 0,drug1,drug2,rep,deltaCt
0,10,0,2,5.106479
1,10,0,1,5.243567
2,8,10,2,6.186193
3,8,10,1,6.173736
4,6,20,2,6.15159
5,6,20,1,6.167484
6,5,25,2,6.114139
7,5,25,1,5.962021
8,4,30,2,5.777797
9,4,30,1,5.703455


In [5]:
def ramp_model(x, y=None):
    a = numpyro.sample("a", dist.Normal(loc=0, scale=10))
    b = numpyro.sample("b", dist.Normal(loc=0, scale=10))
    c = numpyro.sample("c", dist.Normal(loc=0, scale=10))
    log10_sigma = numpyro.sample("log10_sigma", dist.Uniform(low=-3, high=-1))
    sigma = numpyro.deterministic('sigma', 10.0**log10_sigma)
    yhat = numpyro.deterministic('yhat', a + b*(x-.5) + c*(x-.5)**2)
    numpyro.sample('y', dist.Normal(yhat, sigma), obs=y)
    

In [6]:
def plot_regression(model_prediction_dict, data_dict, ax):
    x = data_dict['x'].ravel()
    y = data_dict['y'].ravel()
    
    # Sort values for plotting by x axis
    idx = jnp.argsort(x)
    
    x_plot = model_prediction_dict['concs'] #x[idx]
    mean   = model_prediction_dict['mean'] #y_hat_mean[idx]
    hpdi   = model_prediction_dict['hdpi'] #y_hat_hpdi[:, idx]
    
    # Plot
    ax.scatter(x, y, label='data', s=20, c='k')
    ax.fill_between(x_plot, hpdi[0,:].ravel(), 
                    hpdi[1,:].ravel(), alpha=0.3, 
                    interpolate=True, color='darkorange',
                    label=r'$95\%HDI$')
    
    ax.plot(x_plot, mean, label=r'$\hat{C_t}$', c='royalblue')

    return ax

In [7]:
# def draw(ax, sheet_name, color):
    
#     drug_names_dict = {
#         'R':'risdiplam',
#         'B':'branaplam',
#         'Ai6':'ASO i6',
#         'Ai7':'ASO i7'
#     }
    
#     ticks_dict = {
#         'R':[0, 2, 4, 6, 8, 10],
#         'B':[0, 2, 4, 6, 8, 10],
#         'Ai6':[0, 10, 20, 30, 40, 50],
#         'Ai7':[0, 5, 10, 15, 20, 25]
#     }
    
#     ax1 = ax
#     df = pd.read_excel('22.12.09 Linear combinations R B Ai7 Ai6.xlsx',
#                        sheet_name=sheet_name)
#     drug1, drug2 = sheet_name.split('_')
#     x1 = df[drug1].values.reshape(-1,1)
#     x2 = df[drug2].values.reshape(-1,1)
#     y = df['deltaCt']

#     ticks1=ticks_dict[drug1]
#     ticks2=ticks_dict[drug2]
    
#     xmin1 = min(x1)
#     xmax1 = max(x1)
#     xspan1 = xmax1-xmin1
#     xlabel1 = drug_names_dict[drug1]+' [nM]'
    
#     xmin2 = min(x2)
#     xmax2 = max(x2)
#     xspan2 = xmax2-xmin2
#     xlabel2 = drug_names_dict[drug2]+' [nM]'
    
#     xgrid1 = np.linspace(xmin1, xmax1, 100)

#     x = np.concatenate([x1, x1**2], axis=1)
#     pgdf = pg.linear_regression(x, y)
#     a = pgdf.loc[0,'coef']
#     b = pgdf.loc[1,'coef']
#     c = pgdf.loc[2,'coef']
#     ygrid1 = a + b*xgrid1 + c*xgrid1**2
    
#     ax1.semilogy(x1, 2**y, 'o', alpha=.5, linewidth=0, color=color) 
#     ax1.semilogy(xgrid1, 2**ygrid1, '-', color=color)
    
#     ax1.set_xlim([xmin1-.1*xspan1,xmax1+.1*xspan1])
#     ax1.set_ylim(ylim)
#     ax1.set_xlabel(xlabel1, labelpad=5)
#     ax1.set_xticks(ticks1)
#     ax1.set_xticklabels(ticks1[::-1])
    
#     ax2 = ax1.twiny()
#     ax2.set_xlim([xmin2-.1*xspan2,xmax2+.1*xspan2])
#     ax2.set_xlabel(xlabel2, labelpad=5)
#     ax2.set_xticks(ticks2)
#     ax2.set_xticklabels(ticks2[::-1])
    
#     s = f'$P=${pgdf.loc[2,"pval"]:.1e}'
#     ax1.text(s=s,x=xmin1, y=2**6.8, ha='left', va='top')
    
#     return pgdf

In [8]:
# Get list of all sheet names
sheet_names = [
    'R_B',      #6E
    'R_Ai6',    #6F
    'B_Ai6',    #6G
    'R_Ai7_v2', #6H
    'B_Ai7_v2', #6I
    'Ai6_Ai7',  #6J
    'REC_i20'   #6K
]

# Do inference for every sheet name
for sheet_name in sheet_names:
    print(f'Processing {sheet_name}...')
    
    # Load data
    data_df = load_data(sheet_name)

    drug1, drug2 = sheet_name.split('_')[:2]
    
    # Remove NaNs
    data_df = data_df.dropna()

    print('\nTraining using MCMC\n')
    start = time.time()
    x = data_df['drug1'].values
    x = x/max(x)
    data_dict = {
        "y": jnp.array(data_df['deltaCt'].values),
        "x": jnp.array(x)}

    kernel = NUTS(model=ramp_model)
    sample_kwargs = dict(
        sampler=kernel, 
        num_warmup=args.num_warmup, 
        num_samples=args.num_samples, 
        num_chains=args.num_chains, 
        chain_method="parallel",
    )
    mcmc = MCMC(**sample_kwargs)
    mcmc.run(rng_key, **data_dict)
    print("\nMCMC elapsed time:", time.time() - start)

    #Save posterior samples
    mcmc_samples = mcmc.get_samples()
    mcmc_file_name = f'../mcmc_samples/mcmc_{sheet_name}.pkl'
    with open(mcmc_file_name, 'wb') as f:
        pickle.dump(mcmc_samples, f)
        
    # Use Arviz for summary
    numpyro_inference_data = az.from_numpyro(mcmc);
    summary_df = az.summary(numpyro_inference_data)

    # Plot posterior samples
    ppc = Predictive(ramp_model, mcmc.get_samples())
    x_grid = np.linspace(0,1,100)
    ppc_val = ppc(rng_key_predict, x=x_grid)
    
    model_prediction_dict = {
        "concs":x_grid,
        "mean": jnp.mean(ppc_val['y'], axis=0),
        "hdpi": numpyro.diagnostics.hpdi(ppc_val['y'], prob=0.95),
    }
    
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 4))
    #ax = draw(ax, sheet_name, color='k')
    ax = plot_regression(model_prediction_dict, data_dict, ax)
    ax.legend()
    ax.set_xlabel('concentraion')
    ax.set_ylabel(r'$\Delta C_t$')
    plt.tight_layout()
    plt.savefig(f'all_linear_mixture_curves/{sheet_name}.png')
    plt.close()

Processing R_B...

Training using MCMC



  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]


MCMC elapsed time: 2.3576390743255615
Processing R_Ai6...

Training using MCMC



  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]


MCMC elapsed time: 2.298980236053467
Processing B_Ai6...

Training using MCMC



  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]


MCMC elapsed time: 2.366780996322632
Processing R_Ai7_v2...

Training using MCMC



  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]


MCMC elapsed time: 2.257889986038208
Processing B_Ai7_v2...

Training using MCMC



  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]


MCMC elapsed time: 2.287207841873169
Processing Ai6_Ai7...

Training using MCMC



  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]


MCMC elapsed time: 2.2408971786499023
Processing REC_i20...

Training using MCMC



  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]


MCMC elapsed time: 2.3007030487060547
