In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_probability as tfp
from tqdm import trange
from scipy import stats
tfk = tf.keras

from fishnets import *



Create training data

In [2]:
# Seed the global NumPy generator so a fresh "Restart & Run All" regenerates
# the same simulations; later cells that draw from np.random continue from
# this seeded state. The TensorFlow seed is set as well so TF-side randomness
# (presumably weight initialisation inside the model) is repeatable too.
random_seed = 0
np.random.seed(random_seed)
tf.random.set_seed(random_seed)

# data sizes: number of simulated data sets, and points per data set
n_sims = 10000
n_data = 100

# fiducial parameters (tensor for the model, plus a NumPy copy for NumPy-side maths)
theta_fid = tf.constant([0.,0.], dtype=tf.float32)
theta_fid_ = theta_fid.numpy()

# prior mean and inverse covariance (identity)
priorCinv = tf.convert_to_tensor(np.eye(2), dtype=tf.float32)
priormu = tf.constant([0.,0.], dtype=tf.float32)

# slopes and intercepts: one (m, c) pair per simulation, drawn from unit normals
m_ = np.random.normal(0, 1, n_sims).astype(np.float32)
c_ = np.random.normal(0, 1, n_sims).astype(np.float32)

# x-values, uniform on [0, 10)
x_ = np.random.uniform(0, 10, (n_sims, n_data)).astype(np.float32)

# noise std devs: each point gets its own sigma, uniform on [1, 10)
sigma_ = np.random.uniform(1, 10, (n_sims, n_data)).astype(np.float32)

# simulate "data": y = m*x + c + sigma * (unit normal noise)
y_ = m_[...,np.newaxis]*x_ + c_[...,np.newaxis] + np.random.normal(0, 1, sigma_.shape)*sigma_
y_ = y_.astype(np.float32)

# stack up the data and parameters
# data:  (n_sims, n_data, 3), last axis = (y, x, inverse noise variance)
# theta: (n_sims, 2),         last axis = (m, c)
data = tf.stack([y_, x_, 1./sigma_**2], axis=-1)
theta = tf.stack([m_, c_], axis=-1)

Create the masks, which are needed if we want to train over a variable number of data points N

In [6]:
# construct masks: 1 keeps a data point's contribution, 0 removes it.
# score_mask has one entry per parameter, fisher_mask one per Fisher-matrix element.
score_mask = np.ones((n_sims, n_data, 2))
fisher_mask = np.ones((n_sims, n_data, 2, 2))

# mask or not?
masked = True

# make the masks
if masked:
    for i in range(n_sims):

        # how many points to mask? randint's upper bound is exclusive, so this
        # is between 1 and n_data-6 and at least 6 points always survive
        n_mask = np.random.randint(1, n_data-5)

        # choose which points to mask (without repeats)
        idx = np.random.choice(np.arange(n_data), n_mask, replace=False)

        # zero the score and fisher masks for all chosen points in one
        # indexed assignment instead of looping over them one by one
        score_mask[i, idx, :] = 0
        fisher_mask[i, idx, ...] = 0

score_mask = tf.convert_to_tensor(score_mask, dtype=tf.float32)
fisher_mask = tf.convert_to_tensor(fisher_mask, dtype=tf.float32)

Construct the exact MLEs for comparison

In [5]:
# compute MLEs
# The linear model has a closed-form estimate obtained by one Newton step from
# the fiducial point: theta_hat = theta_fid + F^{-1} t, where F and t are the
# Fisher matrix and score summed over the unmasked points, including the
# Gaussian prior terms.

# Work with NumPy copies throughout rather than mixing TensorFlow tensors into
# NumPy expressions and relying on implicit conversion.
prior_Cinv_ = priorCinv.numpy()
prior_mu_ = priormu.numpy()
noise_var_ = sigma_**2
residual_ = y_ - (theta_fid_[0]*x_ + theta_fid_[1])

# Fisher matrix per simulation: sum of [[x^2, x], [x, 1]] / sigma^2 over
# unmasked points, plus the prior inverse covariance -> (n_sims, 2, 2)
fisher_terms_ = np.stack([x_**2 / noise_var_, x_ / noise_var_, x_ / noise_var_, 1. / noise_var_], axis=-1).reshape((n_sims, n_data, 2, 2))
F_ = np.sum(fisher_terms_ * fisher_mask.numpy(), axis=1) + prior_Cinv_

# score at the fiducial point: sum of [x, 1] * residual / sigma^2 over unmasked
# points, minus the prior gradient term -> (n_sims, 2)
score_terms_ = np.stack([x_ * residual_ / noise_var_, residual_ / noise_var_], axis=-1)
t_ = np.sum(score_terms_ * score_mask.numpy(), axis=1) - np.dot(prior_Cinv_, theta_fid_ - prior_mu_)

# Solve F d = t for each simulation instead of forming the inverse explicitly.
# The trailing axis makes t_ an explicit stack of column vectors, so the call
# does not depend on how solve() interprets a 2-D right-hand side.
pmle_ = theta_fid_ + np.linalg.solve(F_, t_[..., np.newaxis])[..., 0]

Make the Fishnet model

In [7]:
# Build the twin score/Fisher network.
# n_parameters=2 matches theta = (m, c); n_inputs=3 matches the three channels
# stacked into `data` (y, x, inverse noise variance).
# The optimizer uses the `learning_rate` keyword: the short `lr` alias is
# deprecated in tf.keras and rejected by newer Keras releases.
# The prior mean and inverse covariance are the tensors defined alongside the
# simulations, so the model and the exact-estimate comparison share one prior.
Model = FishnetTwin(n_parameters=2,
                n_inputs=3,
                n_hidden_score=[64, 64],
                activation_score=[tf.nn.leaky_relu, tf.nn.leaky_relu],
                n_hidden_fisher=[64, 64],
                activation_fisher=[tf.nn.leaky_relu, tf.nn.leaky_relu],
                optimizer=tf.keras.optimizers.Adam(learning_rate=5e-4),
                theta_fid=theta_fid,
                priormu=priormu,
                priorCinv=priorCinv)

Train the model

In [None]:
# Three back-to-back training stages of 500 epochs each, with the learning
# rate stepped down from one stage to the next.
training_inputs = (data, theta, score_mask, fisher_mask)
for stage_lr in (5e-4, 1e-4, 5e-5):
    Model.train(training_inputs, lr=stage_lr, epochs=500)

 47%|████▋     | 234/500 [03:58<04:29,  1.01s/it, loss=-1.07] 

In [None]:
# Short second-stage optimisation on the same training set after the Adam
# stages (presumably an L-BFGS refinement, going by the method name -- confirm
# against the fishnets source).
Model.lbfgs_optimize(
    data,
    theta,
    score_mask,
    fisher_mask,
    max_iterations=10,
    tolerance=1e-5,
)

Compute model predictions to compare to exact MLEs

In [None]:
# model MLEs
mle, F = Model.compute_mle_(data, score_mask, fisher_mask)

# One residual histogram per parameter (column 0 = m, column 1 = c), comparing
# the learned estimates with the exact ones. The smooth curve is a zero-mean
# Gaussian whose width is the standard deviation of the exact residuals.
# Labels are raw strings so that "\h" in "\hat" is not read as a (invalid)
# Python escape sequence.
for param_idx, axis_label in enumerate([r'$\hat{m} - m$', r'$\hat{c} - c$']):
    truth = theta[:,param_idx].numpy()
    exact_residual = pmle_[:,param_idx] - truth

    plt.hist(mle[:,param_idx].numpy() - truth, bins = 60, histtype='step', density=True, label='learned score MLE')
    plt.hist(exact_residual, bins = 60, histtype='step', density=True, label='exact MLE')

    std = np.std(exact_residual)
    x = np.linspace(-4*std, 4*std, 500)
    plt.plot(x, stats.norm.pdf(x, loc=0, scale=std), color='orange')

    plt.xlabel(axis_label)
    plt.legend(frameon=False)
    plt.show()

In [None]:
# Estimate vs. true parameter value, one figure per parameter. The learned
# estimates are drawn first and the exact estimates second (blue and orange
# under matplotlib's default colour cycle); the straight line is y = x.
x = np.linspace(-4,4,100)
for column, true_label in enumerate(('m true', 'c true')):
    true_values = theta[:,column].numpy()
    plt.scatter(true_values, mle[:,column].numpy(), s = 0.1)
    plt.scatter(true_values, pmle_[:,column], s = 0.1)
    plt.plot(x,x)
    plt.xlabel(true_label)
    plt.ylabel('MLE')
    plt.show()