In [1]:
# !pip install tensorflow_probability


In [2]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import numpy as np
import pickle
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.preprocessing import RobustScaler
from predict import extract_features
os.sys.path.append('../../evaluation/')
import metrics


In [3]:
# training parameters
epochs = 30
batch_size = 16
n_dimensions = 2
opt = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.99)
input_path = '../../data/simulated_2d/'
output_path = './model/'


In [4]:
# load data
x_train = pickle.load(open(input_path+'/train/input_data.pkl', 'rb'))
y_train = pickle.load(open(input_path+'/train/target_data.pkl', 'rb'))
sizes_train = pickle.load(open(input_path+'/train/sizes.pkl', 'rb'))
x_val = pickle.load(open(input_path+'/val/input_data.pkl', 'rb'))
y_val = pickle.load(open(input_path+'/val/target_data.pkl', 'rb'))
sizes_val = pickle.load(open(input_path+'/val/sizes.pkl', 'rb'))


In [5]:
# nr = 8
# c = 0
# plt.figure(figsize=(3, 12), dpi=200)
# for i in range(0, nr*2, 2):
#     idx = np.random.choice(len(x_train))
#     plt.subplot(nr, 2, i+1)
#     plt.scatter(x_train[idx][:,0], x_train[idx][:,1], s=0.01, c='k', alpha=0.3)
#     plt.xticks([])
#     plt.yticks([])
#     plt.xlim([-4, 4])
#     plt.ylim([-4, 4])
#     plt.subplot(nr, 2, i+2)
#     plt.pcolormesh(extract_features(x_train[idx]).T)
#     plt.xticks([])
#     plt.yticks([])
# plt.tight_layout()


In [6]:
# extract features
dtype = tf.float32
x_train = [extract_features(i) for i in x_train]
x_train = tf.cast(np.array(x_train).squeeze(), dtype)
x_val = [extract_features(i) for i in x_val]
x_val = tf.cast(np.array(x_val).squeeze(), dtype)


In [7]:
# # optionally add reference fraction prediction

# y_train = np.array(y_train)
# y_train = np.hstack([
#     y_train, 
#     np.array([i[0]/sum(i) for i in sizes_train]).reshape(-1, 1)
# ])

# y_val = np.array(y_val)
# y_val = np.hstack([
#     y_val, 
#     np.array([i[0]/sum(i) for i in sizes_val]).reshape(-1, 1)
# ])


In [8]:
# # standardize outputs
# scalery = RobustScaler()

# y_train = scalery.fit_transform(y_train)
# y_val = scalery.transform(y_val)


In [9]:
x_train = x_train[..., np.newaxis]
x_val = x_val[..., np.newaxis]


In [10]:
x_train.shape


TensorShape([13000, 100, 100, 1])

In [11]:
y_train = tf.cast(np.array(y_train), dtype)
y_val = tf.cast(np.array(y_val), dtype)


In [12]:
def kl_divergence(mean_true, cov_true, mean_pred, cov_pred):
    """ Computes KL divergence between two multivariate Gaussians
    """
    # Get dimensionality
    num_features = tf.cast(tf.shape(mean_pred)[-1], tf.float32)

    # Compute the inverse of cov_true
    inv_cov_true = tf.linalg.inv(cov_true)

    # Compute the trace term: trace(inv_cov_true @ cov_pred)
    trace_term = tf.linalg.trace(tf.linalg.matmul(inv_cov_true, cov_pred, transpose_a=False, transpose_b=True))

    # Compute the Mahalanobis term
    diff_mean = mean_true - mean_pred
    diff_mean_expanded = tf.expand_dims(diff_mean, axis=-1)
    mahalanobis_term = tf.reduce_sum(tf.linalg.matmul(inv_cov_true, diff_mean_expanded) * diff_mean_expanded, axis=-2)

    # Compute the log-determinants
    log_det_cov_pred = tf.linalg.logdet(cov_pred)
    log_det_cov_true = tf.linalg.logdet(cov_true)

    # KL divergence computation
    kl = 0.5 * (trace_term + tf.squeeze(mahalanobis_term) - num_features + log_det_cov_true - log_det_cov_pred)
    
    # Return the average KL divergence over the batch dimension
    return tf.reduce_mean(kl)

# reshape correlations into an upper triangular matrix
corr_indices = [(i, j) for i in range(n_dimensions) for j in range(i + 1, n_dimensions)]
corr_indices = tf.constant(corr_indices, dtype=tf.int64)
def output_to_stats(batch_vectors, n_dimensions=n_dimensions):
    """ Converts a batch of arrays each consisting of means, std devs, and pairwise correlations
        to separate arrays of mean vectors and covariance matrices
    """
    batch_size = tf.shape(batch_vectors)[0]
    # number of unique correlations in the upper triangular part
    num_correlations = (n_dimensions* (n_dimensions- 1)) // 2
    # extract statistics
    means = batch_vectors[:, :n_dimensions]
    std_devs = batch_vectors[:, n_dimensions:(n_dimensions*2)]
    correlations = batch_vectors[:, (n_dimensions*2):]
    correlations = tf.maximum(-1.0, tf.minimum(1.0, correlations))
    # expand indices for batch
    batch_indices = tf.reshape(tf.range(batch_size, dtype=tf.int64), (-1, 1, 1))
    batch_indices = tf.tile(batch_indices, (1, tf.shape(corr_indices)[0], 1))
    expanded_indices = tf.concat([batch_indices, tf.tile(tf.expand_dims(corr_indices, 0), (batch_size, 1, 1))], axis=-1)
    # scatter correlations into the upper triangular part
    upper_triangular = tf.scatter_nd(expanded_indices, correlations, shape=(batch_size, n_dimensions, n_dimensions))
    upper_triangular = upper_triangular + tf.transpose(upper_triangular, perm=[0, 2, 1]) - tf.linalg.diag(tf.linalg.diag_part(upper_triangular))
    correlation_matrix = tf.linalg.set_diag(upper_triangular, tf.ones((batch_size, n_dimensions), dtype=batch_vectors.dtype))
    # compute the covariance matrix
    std_devs = tf.expand_dims(std_devs, axis=-1)  # shape (batch_size, n_dimensions, 1)
    covariance_matrices = correlation_matrix * (std_devs @ tf.transpose(std_devs, perm=[0, 2, 1]))
    return means, covariance_matrices

def kl_divergence_loss(y_true, y_pred):
    """ Wrapper for kl_divergence that formats target and output arrays
    """
#     y_true = tf.cast(scalery.inverse_transform(y_true), dtype)
#     y_pred = tf.cast(scalery.inverse_transform(y_pred), dtype)
    mean_true, cov_true = output_to_stats(y_true)
    mean_pred, cov_pred = output_to_stats(y_pred)
    return kl_divergence(mean_true, cov_true, mean_pred, cov_pred)


In [13]:
# define model

inputs = tf.keras.layers.Input(shape=x_train.shape[1:])
x = tf.keras.layers.Conv2D(32, kernel_size=10, activation='relu')(inputs)
x = tf.keras.layers.MaxPooling2D(pool_size=2)(x)
x = tf.keras.layers.Conv2D(64, kernel_size=10, activation='relu')(x)
x = tf.keras.layers.MaxPooling2D(pool_size=2)(x)
x = tf.keras.layers.Conv2D(64, kernel_size=10, activation='relu')(x)
x = tf.keras.layers.MaxPooling2D(pool_size=2)(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(y_train.shape[1], activation='linear')(x)
model = tf.keras.models.Model(inputs, outputs)

# define loss and optimizer
model.compile(loss = kl_divergence_loss,
              optimizer = opt)

model.summary()


Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 100, 100, 1)]     0         
                                                                 
 conv2d (Conv2D)             (None, 91, 91, 32)        3232      
                                                                 
 max_pooling2d (MaxPooling2  (None, 45, 45, 32)        0         
 D)                                                              
                                                                 
 conv2d_1 (Conv2D)           (None, 36, 36, 64)        204864    
                                                                 
 max_pooling2d_1 (MaxPoolin  (None, 18, 18, 64)        0         
 g2D)                                                            
                                                                 
 conv2d_2 (Conv2D)           (None, 9, 9, 64)          409664

In [14]:
# plot model

# !pip install pydot
# !pip install graphviz

# tf.keras.utils.plot_model(
#     model,
#     to_file="model.png",
#     show_shapes=True,
#     show_dtype=False,
#     show_layer_names=False,
# )


In [19]:
y_pred = model.predict(x_train)
y_pred.shape




(13000, 5)

In [20]:
kl_loss = kl_divergence_loss(y_train, y_pred)
kl_loss


<tf.Tensor: shape=(), dtype=float32, numpy=nan>

In [23]:
def check_kl_nan(y_true, y_pred):
    
    mean2, cov2 = output_to_stats(y_true)
    mean1, cov1 = output_to_stats(y_pred)
    
    # Get dimensionality
    num_features = tf.cast(tf.shape(mean1)[-1], tf.float32)
    print('num_features: ', num_features)
    
    # Compute the inverse of cov2
    inv_cov2 = tf.linalg.inv(cov2)
    print('inv_cov2: ',inv_cov2)
    
    # Compute the trace term: trace(inv_cov2 @ cov1)
    trace_term = tf.linalg.trace(tf.linalg.matmul(inv_cov2, cov1, transpose_a=False, transpose_b=True))
    print('trace_term: ',trace_term)
    
    # Compute the Mahalanobis term
    diff_mean = mean2 - mean1
    diff_mean_expanded = tf.expand_dims(diff_mean, axis=-1)
    mahalanobis_term = tf.reduce_sum(tf.linalg.matmul(inv_cov2, diff_mean_expanded) * diff_mean_expanded, axis=-2)
    print('mahalanobis_term: ',mahalanobis_term)
    
    # Compute the log-determinants
    log_det_cov1 = tf.linalg.logdet(cov1)
    log_det_cov2 = tf.linalg.logdet(cov2)
    print('log_det_cov1: ',log_det_cov1)
    print('log_det_cov2: ',log_det_cov2)
    print('\n')
    
check_kl_nan(y_train, y_pred)


num_features:  tf.Tensor(2.0, shape=(), dtype=float32)
inv_cov2:  tf.Tensor(
[[[ 7.3513975   3.844594  ]
  [ 3.844594    3.9993315 ]]

 [[ 3.7192318  -0.32302928]
  [-0.32302928  2.2530615 ]]

 [[ 3.774963    3.2884834 ]
  [ 3.2884834   5.8780155 ]]

 ...

 [[ 5.7599864  -3.2129445 ]
  [-3.2129445   4.2565007 ]]

 [[ 4.5998893  -2.4814246 ]
  [-2.4814246   3.593445  ]]

 [[ 1.8620318  -0.50413394]
  [-0.50413394  2.0677228 ]]], shape=(13000, 2, 2), dtype=float32)
trace_term:  tf.Tensor([nan nan nan ... nan nan nan], shape=(13000,), dtype=float32)
mahalanobis_term:  tf.Tensor(
[[nan]
 [nan]
 [nan]
 ...
 [nan]
 [nan]
 [nan]], shape=(13000, 1), dtype=float32)
log_det_cov1:  tf.Tensor([nan nan nan ... nan nan nan], shape=(13000,), dtype=float32)
log_det_cov2:  tf.Tensor([-2.6823754 -2.1132765 -2.4314327 ... -2.6528459 -2.339108  -1.2798263], shape=(13000,), dtype=float32)




In [18]:
# define best model checkpoint
checkpoint_path = output_path+'/model_checkpoint'
checkpoint_callback = ModelCheckpoint(
    checkpoint_path,
    save_weights_only=False,
    save_best_only=True,
    monitor='val_loss',
    mode='min',
    save_format='tf',
    verbose=1
)

# train
history = model.fit(
    x_train, 
    y_train,
    validation_data = (x_val, y_val),
    epochs = epochs,
    batch_size = batch_size,
    callbacks = [checkpoint_callback],
)

with open(output_path+'/model_history.npy', 'wb') as file_pi:
    pickle.dump(history.history, file_pi)
    

Epoch 1/30


I0000 00:00:1722400644.813940   29314 service.cc:145] XLA service 0x7f5552cbf820 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1722400644.813986   29314 service.cc:153]   StreamExecutor device (0): NVIDIA A10G, Compute Capability 8.6
I0000 00:00:1722400644.846263   29314 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


Epoch 1: val_loss did not improve from inf
Epoch 2/30
Epoch 2: val_loss did not improve from inf
Epoch 3/30
Epoch 3: val_loss did not improve from inf
Epoch 4/30
Epoch 4: val_loss did not improve from inf
Epoch 5/30
Epoch 5: val_loss did not improve from inf
Epoch 6/30
Epoch 6: val_loss did not improve from inf
Epoch 7/30
Epoch 7: val_loss did not improve from inf
Epoch 8/30
Epoch 8: val_loss did not improve from inf
Epoch 9/30
Epoch 9: val_loss did not improve from inf
Epoch 10/30
Epoch 10: val_loss did not improve from inf
Epoch 11/30
Epoch 11: val_loss did not improve from inf
Epoch 12/30
Epoch 12: val_loss did not improve from inf
Epoch 13/30
Epoch 13: val_loss did not improve from inf
Epoch 14/30
Epoch 14: val_loss did not improve from inf
Epoch 15/30
Epoch 15: val_loss did not improve from inf
Epoch 16/30
Epoch 16: val_loss did not improve from inf
Epoch 17/30
Epoch 17: val_loss did not improve from inf
Epoch 18/30
Epoch 18: val_loss did not improve from inf
Epoch 19/30
Epoch 19:


KeyboardInterrupt



In [None]:
plt.plot(history.history['loss'], label='Train')
plt.plot(history.history['val_loss'], label='Test')
plt.legend();
