In [1]:
%load_ext autoreload
%autoreload 2
from functools import partial

import keras

from datasets import example_datasets, to_numpy
from models import mixture_poissons,poisson_glm
from metrics import mixture_poi_loss, get_bpr_loss_func, mix_bpr, get_penalized_bpr_loss_func_mix
from experiments import training_loop
from plotting_funcs import plot_losses, plot_frontier

import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np

2024-05-08 12:07:49.960732: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-05-08 12:07:50.015171: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-08 12:07:50.015211: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-08 12:07:50.016504: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-08 12:07:50.027134: I tensorflow/core/platform/cpu_feature_guar

In [2]:
seed=360
num_components=4
learning_rate = 0.005
epochs=1250
outdir = '/cluster/home/kheuto01/testdir'
penalty = 0
threshold = 0.45
K=4
do_only=True
# tracts/distributions
S=12
# history/features
H = 3
# total timepoints
T= 500
perturbed_sigma=0.1

In [3]:
train_dataset, val_dataset, test_dataset = example_datasets(H, T, seed=seed)
train_X_THS, train_y_TS = to_numpy(train_dataset)
val_X_THS, val_y_TS = to_numpy(val_dataset)

input_shape = (H,S)

negative_bpr_K = get_bpr_loss_func(K, sigma=perturbed_sigma)
loss_func=negative_bpr_K

In [4]:
model, mix_weights  = mixture_poissons(poisson_glm, input_shape, num_components=num_components)

$$\nabla_\phi \theta^*(\phi) = \frac{1}{M}\sum_{m=1}^M f(y) \nabla_\phi \log p(y)$$
$$\nabla_\phi \mathcal{L}(\theta^*(\phi),y) = \nabla_\phi \theta^*(\phi) \nabla_{\theta^*}\mathcal{L}$$

In [5]:
num_epochs = 1
num_samples = 1

In [6]:
def cross_ratio_decision(predicted_y, location_axis =-1):
    denominator = tf.reduce_sum(predicted_y, axis=location_axis, keepdims=True)
    return predicted_y/denominator

decision_func = cross_ratio_decision

In [7]:
num_score_func_samples = 2

In [None]:
def score_function_trick(jacobian_MBSp, decision_MBS):
    num_param_dims = tf.rank(jacobian_MBSp)-3
    # expand decision to match jacobian
    decision_MBSp = tf.reshape(decision_MBS, decision_MBS.shape + [1]*num_param_dims.numpy())

    scaled_jacobian_MBSp = jacobian_MBSp*decision_MBSp

    # average over sample dims
    param_gradient_BSp = tf.reduce_mean(scaled_jacobian_MBSp, axis=0)

    return param_gradient_BSp

In [99]:
def overall_gradient_calculation(gradient_BSp, decision_gradient_BS):
    num_param_dims = tf.rank(gradient_BSp)-2

    decision_gradient_BSp = tf.reshape(decision_gradient_BS, decision_gradient_BS.shape + [1]*num_param_dims.numpy())

    overall_gradient_BSp = gradient_BSp*decision_gradient_BSp

    # sum over batch and location
    overall_gradient = tf.reduce_sum(overall_gradient_BSp, axis=[0,1])
    return overall_gradient

In [100]:
tf.debugging.enable_check_numerics()
for epoch in range(num_epochs):
    print(f'Epoch: {epoch}')
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        print(f'Step:{step}')


        with tf.GradientTape() as jacobian_tape, tf.GradientTape() as loss_tape:
            prob_params_BSK, mixture_weights_KS = model(x_batch_train, training=True)
            # for decision
            # dims are ( batch, location, component) and (component, location)
            # output is ( batch, location)
            model_prediction_BS = tf.einsum('ijk,kj->ij', prob_params_BSK, mixture_weights_KS)
            model_decision_BS = decision_func(model_prediction_BS)

            model_loss = loss_func(y_batch_train, model_decision_BS)
            

            poisson = tfp.distributions.Poisson(rate=prob_params_BSK)

            # add constant to avoid log 0
            sample_ys = poisson.sample(num_score_func_samples)+1e-13
            
            # for decision
            # dims are (samples, batch, location, component) and (component, location)
            # output is (samples, batch, location)
            sample_mixture_preds = tf.einsum('hijk,kj->hij', sample_ys, mixture_weights_KS)

            sample_decisions = decision_func(sample_mixture_preds)
            
            # expand for sample dimension
            # add constant to avoid log 0
            prob_params_like_sample = tf.ones_like(sample_ys)*tf.expand_dims(prob_params_BSK, axis=0) + 1e-13
            

            
            log_probs_MBSK =tf.nn.log_poisson_loss(sample_ys,
                                                tf.math.log(prob_params_like_sample),
                                                compute_full_loss=True)

            # swap 2 axes of mixture weights
            log_mixture_weights_SK = tf.math.log(tf.transpose(mixture_weights_KS, perm=[1,0]))
            log_probs_MBSK = log_probs_MBSK + log_mixture_weights_SK

            log_probs_MBS = tf.reduce_sum(log_probs_MBSK, axis=-1)

        jacobian_MBSp = jacobian_tape.jacobian(log_probs_MBS, model.trainable_weights)
        param_gradient_BSp = [score_function_trick(j, sample_decisions) for j in jacobian_MBSp]



        #decisions = decision_func(sample_mixture_preds) 
        loss_gradients_BS = loss_tape.gradient(model_loss, model_decision_BS)

        overall_gradient = [overall_gradient_calculation(g, loss_gradients_BS) for g in param_gradient_BSp]

INFO:tensorflow:Enabled check-numerics callback in thread MainThread
Epoch: 0
Step:0


In [101]:
overall_gradient

[<tf.Tensor: shape=(1, 3, 1), dtype=float32, numpy=
 array([[[0.02932102],
         [0.00203432],
         [0.02529929]]], dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.00073522], dtype=float32)>,
 <tf.Tensor: shape=(1, 3, 1), dtype=float32, numpy=
 array([[[-0.00080013],
         [-0.01321443],
         [ 0.00188697]]], dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.00027942], dtype=float32)>,
 <tf.Tensor: shape=(1, 3, 1), dtype=float32, numpy=
 array([[[ 0.01771369],
         [-0.02450514],
         [ 0.03534801]]], dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-0.00049132], dtype=float32)>,
 <tf.Tensor: shape=(1, 3, 1), dtype=float32, numpy=
 array([[[-0.00034262],
         [-0.01006141],
         [-0.00036543]]], dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-0.00051176], dtype=float32)>,
 <tf.Tensor: shape=(4, 12), dtype=float32, numpy=
 array([[-1.17291738e-05,  1.07722524e-04,  1.1

In [None]:
overall_grad

In [89]:
param_gradient_BSp = [score_function_trick(j, sample_decisions) for j in jacobian_MBSP]

In [63]:
sample_decisions.shape+[1,]

TensorShape([2, 300, 12, 1, 1, 1])

In [28]:
flattener = keras.layers.Flatten()
flat_jacobian = [flattener(j) for j in flat_jacobian]
flat_jacobian

[<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[  0.3403473, -37.81227  ,  -7.9126587]], dtype=float32)>,
 <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-3.5017319]], dtype=float32)>,
 <tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[-57.39868 , -30.255272, -13.497282]], dtype=float32)>,
 <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[1.0436039]], dtype=float32)>,
 <tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[-111.691956, -708.1121  , -400.58032 ]], dtype=float32)>,
 <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-67.77326]], dtype=float32)>,
 <tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[ -11.833933, -154.68434 ,   11.667574]], dtype=float32)>,
 <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-20.540833]], dtype=float32)>,
 <tf.Tensor: shape=(4, 12), dtype=float32, numpy=
 array([[  1.6635876 , -12.89687   , -10.895559  , -11.756701  ,
           4.4507465 ,  16.808748  ,   8.907337  ,   1.311276  ,
          -3.2423413 

In [17]:
jacobian[1]

<tf.Tensor: shape=(1,), dtype=float32, numpy=array([-3.5017319], dtype=float32)>

In [11]:
loss_gradients

<tf.Tensor: shape=(300, 12), dtype=float32, numpy=
array([[ 7.2934316e-04,  1.1512380e-03,  3.1422477e-04, ...,
        -4.1794814e-03,  3.2868035e-04, -3.5045113e-04],
       [ 7.2769716e-04,  7.0758373e-04,  5.4643239e-04, ...,
        -3.5533882e-04, -4.1297339e-03,  3.2435090e-04],
       [-1.1449133e-03, -4.1354785e-04, -4.5500961e-04, ...,
         8.2411140e-04,  4.1348182e-04,  1.1702622e-03],
       ...,
       [ 3.6757754e-04,  1.3102517e-04, -2.8178840e-06, ...,
        -5.0894036e-03,  2.2923536e-04,  5.0476985e-04],
       [ 6.7625387e-04,  3.7381738e-05, -2.6694720e-04, ...,
         5.3252059e-04,  8.1657775e-04,  2.1537433e-04],
       [-3.7140571e-04, -6.3346478e-04,  5.0142215e-04, ...,
         9.8311272e-04,  2.5420348e-04, -5.5796891e-03]], dtype=float32)>

In [9]:
model_loss

<tf.Tensor: shape=(), dtype=float32, numpy=-0.48294663>

In [89]:
tape.batch_jacobian(log_probs, model.trainable_weights)

AttributeError: 'list' object has no attribute 'shape'

In [94]:
gradient = tape.gradient(log_probs, model.trainable_weights)

In [93]:
jacobian = tape.jacobian(log_probs, model.trainable_weights)

In [96]:
jacobian[0].shape

TensorShape([1, 300, 12, 1, 3, 1])

In [100]:
np.allclose(gradient[0].numpy(), tf.reduce_sum(jacobian[0],axis=[0,1,2]).numpy(), rtol=1e-4)

True

In [99]:
gradient[0].numpy()

array([[[ 55.352844],
        [-27.209534],
        [-50.069534]]], dtype=float32)

In [98]:
tf.reduce_sum(jacobian[0],axis=[0,1,2]).numpy()

array([[[ 55.35452],
        [-27.2091 ],
        [-50.07021]]], dtype=float32)

In [53]:
tf.debugging.disable_check_numerics()
log_probs = poisson.log_prob(sample_ys)

INFO:tensorflow:Disabled check-numerics callback in thread MainThread


In [80]:
tf.math.log(tf.expand_dims(prob_params, axis=0)+1e-13)

<tf.Tensor: shape=(1, 300, 12, 4), dtype=float32, numpy=
array([[[[ -2.7427397 ,   3.1057425 ,   2.2696664 , -11.362669  ],
         [ -2.7427397 ,   3.1057425 ,   2.2696664 , -11.362669  ],
         [ -2.7427397 ,   3.1057425 ,   2.2696664 , -11.362669  ],
         ...,
         [-24.726921  ,   4.753423  ,   4.560614  , -29.933605  ],
         [ -0.36651292,  -0.36651292,  -0.36651292,  -0.36651292],
         [ -0.36651292,  -0.36651292,  -0.36651292,  -0.36651292]],

        [[ -2.7427397 ,   3.1057425 ,   2.2696664 , -11.362669  ],
         [ -2.7427397 ,   3.1057425 ,   2.2696664 , -11.362669  ],
         [ -2.7427397 ,   3.1057425 ,   2.2696664 , -11.362669  ],
         ...,
         [  3.4848926 ,   4.772131  ,   3.844227  , -29.933605  ],
         [ -0.36651292,  -0.36651292,  -0.36651292,  -0.36651292],
         [ -0.36651292,  -0.36651292,  -0.36651292,  -0.36651292]],

        [[ -2.7427397 ,   3.1057425 ,   2.2696664 , -11.362669  ],
         [ -2.7427397 ,   3.1057425 ,   

In [74]:
my_log_probs =tf.nn.log_poisson_loss(sample_ys, tf.math.log(tf.expand_dims(prob_params, axis=0)), compute_full_loss=True)

InvalidArgumentError: {{function_node __wrapped__CheckNumericsV2_device_/job:localhost/replica:0/task:0/device:CPU:0}} 

!!! Detected Infinity or NaN in output 0 of eagerly-executing op "Log" (# of outputs: 1) !!!
  dtype: <dtype: 'float32'>
  shape: (1, 300, 12, 4)
  # of -Inf elements: 6787

  Input tensor: tf.Tensor(
[[[[  0.  22.   7.   0.]
   [  0.  18.   9.   0.]
   [  0.  25.   3.   0.]
   ...
   [  0. 112.  84.   0.]
   [  0.   1.   0.   0.]
   [  3.   0.   1.   0.]]

  [[  0.  26.  12.   0.]
   [  0.  24.  10.   0.]
   [  0.  18.  11.   0.]
   ...
   [ 35. 119.  40.   0.]
   [  0.   0.   1.   2.]
   [  1.   0.   0.   1.]]

  [[  0.  23.  14.   0.]
   [  0.  22.   8.   0.]
   [  0.  23.  10.   0.]
   ...
   [  3.   1.   4.   0.]
   [  0.   1.   1.   1.]
   [  1.   0.   0.   1.]]

  ...

  [[  0.  26.   8.   0.]
   [  0.  26.  12.   0.]
   [  0.  27.   8.   0.]
   ...
   [ 34. 118.  44.   0.]
   [  5. 237. 147.   0.]
   [  0. 103. 101.   0.]]

  [[  0.  28.   7.   0.]
   [  0.  24.  10.   0.]
   [  0.  18.   7.   0.]
   ...
   [  1.   0.   0.   0.]
   [  1.   0.   0.   1.]
   [  1.   0.   0.   0.]]

  [[  0.  20.  13.   0.]
   [  0.  25.   7.   0.]
   [  0.  22.  11.   0.]
   ...
   [  2.   0.   0.   0.]
   [  0.   0.   0.   0.]
   [  2.   1.   1.   0.]]]], shape=(1, 300, 12, 4), dtype=float32)

 : Tensor had -Inf values [Op:CheckNumericsV2] name: 

In [60]:
prob_params

<tf.Tensor: shape=(300, 12, 4), dtype=float32, numpy=
array([[[6.4393684e-02, 2.2325790e+01, 9.6761732e+00, 1.1621321e-05],
        [6.4393684e-02, 2.2325790e+01, 9.6761732e+00, 1.1621321e-05],
        [6.4393684e-02, 2.2325790e+01, 9.6761732e+00, 1.1621321e-05],
        ...,
        [3.2618927e+01, 1.1817080e+02, 4.6722557e+01, 0.0000000e+00],
        [5.7442770e-21, 8.4788452e+01, 1.5882159e-02, 4.4143806e+01],
        [1.8148809e-11, 1.1598061e+02, 9.5642197e+01, 0.0000000e+00]],

       [[6.4393684e-02, 2.2325790e+01, 9.6761732e+00, 1.1621321e-05],
        [6.4393684e-02, 2.2325790e+01, 9.6761732e+00, 1.1621321e-05],
        [6.4393684e-02, 2.2325790e+01, 9.6761732e+00, 1.1621321e-05],
        ...,
        [6.9314718e-01, 6.9314718e-01, 6.9314718e-01, 6.9314718e-01],
        [6.9314718e-01, 6.9314718e-01, 6.9314718e-01, 6.9314718e-01],
        [1.8148809e-11, 1.1598061e+02, 9.5642197e+01, 0.0000000e+00]],

       [[6.4393684e-02, 2.2325790e+01, 9.6761732e+00, 1.1621321e-05],
      

In [59]:
log_probs

<tf.Tensor: shape=(1, 300, 12, 4), dtype=float32, numpy=
array([[[[-6.4393684e-02, -2.6565647e+00, -2.0510015e+00,
          -1.1621323e-05],
         [-6.4393684e-02, -2.6858273e+00, -2.0510015e+00,
          -1.1621323e-05],
         [-6.4393684e-02, -2.6565647e+00, -2.0510015e+00,
          -1.1621323e-05],
         ...,
         [-3.1611977e+00, -4.6970673e+00, -3.0369034e+00,
           0.0000000e+00],
         [-5.7442770e-21, -3.1557007e+00, -1.5882157e-02,
          -2.8735771e+00],
         [-1.8148809e-11, -3.2962418e+00, -3.2509613e+00,
           0.0000000e+00]],

        [[-6.4393684e-02, -2.8178692e+00, -2.1234436e+00,
          -1.1621323e-05],
         [-6.4393684e-02, -3.2547436e+00, -2.0839205e+00,
          -1.1621323e-05],
         [-6.4393684e-02, -3.0332394e+00, -2.7226753e+00,
          -1.1621323e-05],
         ...,
         [-6.9314718e-01, -6.9314718e-01, -1.0596601e+00,
          -6.9314718e-01],
         [-1.0596601e+00, -6.9314718e-01, -6.9314718e-01,
     

In [65]:
my_log_probs[0,0,:,:]

<tf.Tensor: shape=(12, 4), dtype=float32, numpy=
array([[6.4393684e-02, 2.6521797e+00, 2.0417461e+00, 1.1621323e-05],
       [6.4393684e-02, 2.6824989e+00, 2.0417461e+00, 1.1621323e-05],
       [6.4393684e-02, 2.6521797e+00, 2.0417461e+00, 1.1621323e-05],
       [6.4393684e-02, 3.2517700e+00, 2.4204483e+00, 1.1621323e-05],
       [2.0603862e-02, 5.5735474e+00, 2.2872677e+00, 8.9200782e-08],
       [8.0944866e-02, 2.1682796e+00, 2.0345354e+00, 3.9605020e-05],
       [2.0603862e-02, 2.9631653e+00, 2.3217087e+00, 8.9200782e-08],
       [2.0603862e-02, 3.1296272e+00, 2.4280634e+00, 8.9200782e-08],
       [6.9314718e-01, 6.9314718e-01, 6.9314718e-01, 1.0596601e+00],
       [3.1590118e+00, 4.6962585e+00, 3.0349274e+00,           nan],
       [5.7442770e-21, 3.1547241e+00, 1.5882157e-02, 2.8717499e+00],
       [1.8148809e-11, 3.2955322e+00, 3.2500610e+00,           nan]],
      dtype=float32)>

In [66]:
prob_params[0,:,:]

<tf.Tensor: shape=(12, 4), dtype=float32, numpy=
array([[6.4393684e-02, 2.2325790e+01, 9.6761732e+00, 1.1621321e-05],
       [6.4393684e-02, 2.2325790e+01, 9.6761732e+00, 1.1621321e-05],
       [6.4393684e-02, 2.2325790e+01, 9.6761732e+00, 1.1621321e-05],
       [6.4393684e-02, 2.2325790e+01, 9.6761732e+00, 1.1621321e-05],
       [2.0603864e-02, 3.1893986e+01, 1.3823015e+01, 8.9200782e-08],
       [8.0944851e-02, 1.1598071e+01, 9.5642900e+00, 3.9604991e-05],
       [2.0603864e-02, 3.1893986e+01, 1.3823015e+01, 8.9200782e-08],
       [2.0603864e-02, 3.1893986e+01, 1.3823015e+01, 8.9200782e-08],
       [6.9314718e-01, 6.9314718e-01, 6.9314718e-01, 6.9314718e-01],
       [3.2618927e+01, 1.1817080e+02, 4.6722557e+01, 0.0000000e+00],
       [5.7442770e-21, 8.4788452e+01, 1.5882159e-02, 4.4143806e+01],
       [1.8148809e-11, 1.1598061e+02, 9.5642197e+01, 0.0000000e+00]],
      dtype=float32)>

In [24]:
# Define your tensors as Python variables
A = tf.Variable(tf.random.normal([3, 1]))  # Example random initialization
B = tf.Variable(tf.random.normal([2, 200, 12]))  # Example random initialization

# Expand dimensions of A to match the shape of B
# The new shape will be (3, 1, 1)
A_expanded = tf.expand_dims(A, axis=-1)

# Multiply A_expanded by B
# Result shape will be (2, 200, 3, 12)
result = A_expanded * B
result.shape

InvalidArgumentError: {{function_node __wrapped__Mul_device_/job:localhost/replica:0/task:0/device:CPU:0}} Incompatible shapes: [3,1,1] vs. [2,200,12] [Op:Mul] name: 

In [29]:
loss_gradients

<tf.Tensor: shape=(300, 12), dtype=float32, numpy=
array([[ 9.5758314e-04, -1.4900744e-04, -4.4442251e-04, ...,
         1.3997008e-03,  1.5834172e-03,  8.3316190e-05],
       [ 4.2887675e-04,  1.8113095e-04,  3.8881294e-04, ...,
        -2.5158492e-04, -5.4394361e-03,  6.2560267e-04],
       [-1.3380859e-03, -4.1314977e-04,  2.1057956e-04, ...,
         2.4359687e-03,  1.0059469e-03,  5.8637129e-04],
       ...,
       [-1.6016867e-04,  6.7898765e-04,  3.1120621e-04, ...,
        -4.1304049e-03,  7.4906932e-04,  1.1425787e-03],
       [ 3.8880884e-04,  5.5720104e-04, -2.1608907e-04, ...,
         9.8572788e-04, -4.6086917e-03,  1.0958280e-03],
       [-1.2019272e-03, -1.9123746e-04, -5.3372973e-04, ...,
         3.8452941e-04,  1.2894482e-03,  1.0369413e-03]], dtype=float32)>