In [1]:
# Standard imports
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import time

# Make a local mavenn checkout importable: resolve the directory two levels
# above the current working directory and put it at the front of sys.path,
# so it is searched before any other location when `mavenn` is imported.
import os
import sys
abs_path_to_mavenn = os.path.abspath('../../')
sys.path.insert(0, abs_path_to_mavenn)

# Load mavenn and print its package path to confirm which copy was imported
import mavenn
print(mavenn.__path__)

['/Users/jkinney/github/mavenn/mavenn']


In [2]:
# Load the 'mpsa' example dataset through mavenn's example-dataset loader
data_df = mavenn.load_example_dataset('mpsa')
# Preview the first rows (columns shown below: training_set, y, dy, x)
data_df.head()

Unnamed: 0,training_set,y,dy,x
0,False,-3.751854,0.4442,AAAGCAAAA
1,True,-2.697741,0.369972,AAAGCAAAC
2,True,-2.242947,0.575121,AAAGCAAAG
3,False,-3.067251,0.357014,AAAGCAAAT
4,False,-2.987074,0.472637,AAAGCAACA


In [3]:
# Split into training and test data using the boolean 'training_set' column
ix = data_df['training_set']
train_df, test_df = data_df[ix], data_df[~ix]

# Sequence length, read off the first sequence in the dataset
L = len(data_df['x'][0])

# Report the size of each split
print(f'training N: {len(train_df):,}')
print(f'testing N: {len(test_df):,}')

training N: 17,498
testing N: 4,431


In [4]:
# Keyword arguments passed to mavenn.Model(...) when the model is defined
model_kwargs = dict(
    regression_type='GE',
    L=L,
    alphabet='dna',
    ge_nonlinearity_type='nonlinear',
    gpmap_type='mlp',
    ge_noise_model_type='SkewedT',
    ge_heteroskedasticity_order=2,
)

# Keyword arguments passed to model.fit(...)
# NOTE(review): early_stopping_patience equals epochs, so early stopping
# presumably cannot trigger before the final epoch -- confirm this is intended.
fit_kwargs = dict(
    learning_rate=0.005,
    epochs=100,
    batch_size=200,
    early_stopping=True,
    early_stopping_patience=100,
    linear_initialization=False,
)

# Name under which the trained model is saved and later reloaded
file_name = 'mpsa_ge_blackbox'

In [5]:
# Toggle for the training run below (100 epochs, see fit_kwargs).
# Set to False to skip training and reuse the model previously saved
# under `file_name` instead.
TRAIN_MODEL = True

if TRAIN_MODEL:

    # Set seed so that repeated runs start from the same random state
    mavenn.set_seed(0)

    # Define model
    model = mavenn.Model(**model_kwargs)

    # Set training data
    model.set_data(x=train_df['x'],
                   y=train_df['y'])

    # Fit model to data
    model.fit(**fit_kwargs)

    # Save model
    model.save(file_name)

else:

    # Training was skipped: load the saved model so that the cells below
    # still have a `model` to work with.
    model = mavenn.load(file_name)

N = 17,498 observations set as training data.
Data shuffled.
Time to set data: 0.366 sec.
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/

In [6]:
# Inspect the input shape of the model's `layer_gpmap` layer.
# NOTE(review): the (None, 36) shown below is presumably L=9 positions x 4 DNA
# characters (one-hot) with a leading batch dimension -- confirm.
model.layer_gpmap.input_shape

(None, 36)

In [7]:
# List the keys of the parameter dictionary returned by get_theta()
model.get_theta().keys()

dict_keys(['L', 'C', 'alphabet', 'theta_0', 'theta_lc', 'theta_lclc', 'theta_mlp', 'logomaker_df'])

In [8]:
# Display the 'theta_mlp' entry of the parameter dictionary.
# NOTE(review): this prints a nested list of full arrays (presumably the MLP
# weights); consider showing only their shapes to keep the output short.
model.get_theta()['theta_mlp']

[[array([[ 1.58156380e-01, -9.93106589e-02, -8.18898901e-03,
           1.91798314e-01,  8.13778713e-02,  2.73752189e-03,
           9.29782242e-02,  1.95636362e-01,  1.92756012e-01,
          -2.63394924e-14],
         [ 8.24619159e-02, -4.97937910e-02,  5.57352193e-02,
          -1.61574967e-02,  1.82693928e-01, -9.38864127e-02,
          -7.76444376e-02,  1.11436360e-01,  5.56413457e-02,
          -1.05564558e-35],
         [-3.57295901e-01, -1.31085768e-01, -4.23249006e-02,
           1.28747839e-02,  4.84974273e-02,  2.35908836e-01,
           1.07815169e-01, -3.38197909e-02, -3.58532995e-01,
           3.24944029e-37],
         [-2.79000495e-02, -8.41639191e-02, -3.37245725e-02,
           1.60978828e-02, -8.68158400e-01, -8.56827758e-03,
           1.18786044e-01,  1.19688578e-01,  2.88222462e-01,
          -8.92286093e-36],
         [ 1.39934361e-01, -8.98744911e-03, -7.43143782e-02,
           5.84351793e-02,  5.93201339e-01, -1.52412683e-01,
          -2.39134610e-01,  1.0261

In [9]:
# NOTE(review): deliberate hard stop -- this always raises AssertionError, so a
# "Run All" halts here and the cells below (all still In [None]) never execute.
# Remove this once the downstream cells are meant to run.
assert False

AssertionError: 

In [None]:
# Load model
# Reads back from the same `file_name` that the training cell passed to
# model.save(...), and rebinds `model` to the loaded object.
model = mavenn.load(file_name)

In [None]:
# Extract the 'x', 'y', and 'dy' columns of the test set as arrays
x_test, y_test, dy_test = (test_df[col].values for col in ('x', 'y', 'dy'))

In [None]:
# Evaluate information metrics on the held-out test data, then plot them
# against the training history.
print('On test data:')

# Compute variational information.
# Fix: the original called `model.I_varlihood`, which is not a method of the
# model; the quantity tracked in model.history['I_var'] is computed by
# `I_variational`.
I_var, dI_var = model.I_variational(x=x_test, y=y_test)
print(f'I_var_test: {I_var:.3f} +- {dI_var:.3f} bits')

# Compute predictive information
I_pred, dI_pred = model.I_predictive(x=x_test, y=y_test)
print(f'I_pred_test: {I_pred:.3f} +- {dI_pred:.3f} bits')

# Compute intrinsic information
I_intr, dI_intr = mavenn.I_intrinsic(y_values=y_test, dy_values=dy_test)
print(f'I_intrinsic: {I_intr:.3f} +- {dI_intr:.3f} bits')

# Compute percent info explained (I_pred / I_intr); the uncertainty combines
# the two relative errors in quadrature (first-order propagation for a ratio)
pct = 100*I_pred/I_intr
dpct = 100*np.sqrt((dI_pred/I_intr)**2 + (dI_intr*I_pred/I_intr**2)**2)
print(f'percent info explained: {pct:.1f}% +- {dpct:.1f}%')

# Per-epoch variational information recorded during training
I_var_hist = model.history['I_var']
val_I_var_hist = model.history['val_I_var']

# Plot training/validation curves with the test-set values as reference lines
fig, ax = plt.subplots(1,1,figsize=[4,4])
ax.plot(I_var_hist, label='I_var_train')
ax.plot(val_I_var_hist, label='I_var_val')
ax.axhline(I_var, color='C2', linestyle=':', label='I_var_test')
ax.axhline(I_pred, color='C3', linestyle=':', label='I_pred_test')
ax.axhline(I_intr, color='C4', linestyle=':', label='I_intrinsic')
ax.legend()
ax.set_xlabel('epochs')
ax.set_ylabel('bits')
ax.set_title('training history')
# Leave 20% headroom above the intrinsic-information line
ax.set_ylim([0, I_intr*1.2]);

In [None]:
# Predict latent phenotype values (phi) on test data
phi_test = model.x_to_phi(x_test)

# Predict measurement values (yhat) on test data
yhat_test = model.x_to_yhat(x_test)

# Set phi lims and create grid in phi space
phi_lim = [min(phi_test)-.5, max(phi_test)+.5]
phi_grid = np.linspace(phi_lim[0], phi_lim[1], 1000)

# Compute yhat at each phi gridpoint
yhat_grid = model.phi_to_yhat(phi_grid)

# Compute the central interval of y for each yhat (90% for q = [0.05, 0.95])
q = [0.05, 0.95] #[0.16, 0.84]
yqs_grid = model.yhat_to_yq(yhat_grid, q=q)

# Derive the legend label from q so it cannot drift out of sync with the
# quantiles actually plotted (the label was previously hard-coded as '68% CI'
# while q described a 90% interval).
ci_label = f'{100*(q[1] - q[0]):.0f}% CI'

# Create figure
fig, ax = plt.subplots(1, 1, figsize=[4, 4])

# Illustrate measurement process with GE curve.
# Raw strings keep the LaTeX backslashes literal and avoid Python's
# invalid-escape-sequence warning for '\h' and '\p'.
ax.scatter(phi_test, y_test, color='C0', s=5, alpha=.2, label='test data')
ax.plot(phi_grid, yhat_grid, linewidth=2, color='C1',
        label=r'$\hat{y} = g(\phi)$')
ax.plot(phi_grid, yqs_grid[:, 0], linestyle='--', color='C1', label=ci_label)
ax.plot(phi_grid, yqs_grid[:, 1], linestyle='--', color='C1')
ax.set_xlim(phi_lim)
ax.set_xlabel(r'latent phenotype ($\phi$)')
ax.set_ylabel('measurement ($y$)')
ax.set_title('measurement process')
ax.set_ylim([-6, 3])
ax.legend()

# Fix up plot
fig.tight_layout()
plt.show()

In [None]:
# Draw a heatmap of the pairwise parameters, but only for the G-P map types
# ('pairwise' or 'neighbor') this cell is restricted to
has_pairwise_params = model.gpmap_type in ('pairwise', 'neighbor')
if has_pairwise_params:
    theta = model.get_theta()
    fig, ax = plt.subplots(1, 1, figsize=[8, 4])
    pairwise_values = theta['theta_lclc']
    characters = theta['alphabet']
    mavenn.heatmap_pairwise(values=pairwise_values, alphabet=characters, ax=ax)