In [None]:
import sys
import os
import warnings

from absl import app
from absl import flags
import matplotlib
matplotlib.use('Agg')
from matplotlib import figure
from matplotlib.backends import backend_agg

import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import tf_keras
from safetensors.tensorflow import save_file

In [None]:
# Switch TensorFlow to TF2 semantics. The previous line only named the
# function without calling it, which is a no-op expression statement.
tf.enable_v2_behavior()

# Suppress every Python warning for the remainder of the session.
warnings.simplefilter(action='ignore')

# Keep only the program name so that absl's flag parser is not handed extra
# command-line arguments (presumably the notebook kernel's own -- TODO confirm).
sys.argv = sys.argv[:1]

In [None]:
# seaborn is an optional dependency: HAS_SEABORN records whether the import
# succeeded so callers can decide whether the plotting helpers are usable.
try:
    import seaborn as sns
except ImportError:
    HAS_SEABORN = False
else:
    HAS_SEABORN = True

# Short alias for the TensorFlow Probability distributions module.
tfd = tfp.distributions

# Dataset geometry and sizes used throughout the notebook.
IMAGE_SHAPE = [28, 28, 1]
NUM_TRAIN_EXAMPLES = 60_000
NUM_HELDOUT_EXAMPLES = 10_000
NUM_CLASSES = 10

# Command-line flags (absl).
# NOTE(review): absl normally rejects a second definition of an existing flag,
# so re-running this cell in a live kernel is expected to fail -- restart the
# kernel instead. TODO confirm against the installed absl version.
flags.DEFINE_float('learning_rate',
                   default=0.001,
                   help='Initial learning rate.')

# Used as the bound of the epoch loop in main(), hence "epochs" not "steps".
flags.DEFINE_integer('num_epochs',
                     default=2,
                     help='Number of training epochs to run.')

flags.DEFINE_integer('batch_size',
                     default=10,
                     help='Batch size.')

flags.DEFINE_string('data_dir',
                    default=os.path.join(os.getenv('TEST_TMPDIR', 'tmp'),
                                         'bayesian_neural_network/data'),
                    help='Directory where data is stored (if using real data).')

# NOTE(review): this default directory name contains spaces, unlike the
# underscore form used by data_dir. It is left unchanged on purpose so that
# checkpoints already written under the old name are still found.
flags.DEFINE_string(
    'model_dir',
    default=os.path.join(os.getenv('TEST_TMPDIR', 'tmp'),
                         'bayesian neural network/'),
    help="Directory to put the model's fit.")

flags.DEFINE_integer('viz_steps',
                     default=400,
                     help='Frequency at which to save visualizations.')

flags.DEFINE_integer('num_monte_carlo',
                     default=50,
                     help='Network draws to compute predictive probabilities.')

flags.DEFINE_bool('fake_data',
                  default=False,
                  help='If true, use fake data. Defaults to real data.')

# File name only; main() joins it onto model_dir.
flags.DEFINE_string('checkpoint_path',
                    default='model_checkpoint.weights.h5',
                    help='Name of the file to save model weights.')

flags.DEFINE_string('final_path',
                    default='final_model.keras',
                    help='Path where the final model is stored.')

FLAGS = flags.FLAGS

In [None]:
def plot_weight_posteriors(names, qm_vals, qs_vals, fname):
    """Save a PNG plot with histograms of weight means and stddevs.

    Requires seaborn to have been imported as `sns`.

    Args:
        names: A python `iterable` of `str` variable names.
        qm_vals: A python `iterable`, the same length as `names`,
          whose elements are Numpy `array`s, of any shape, containing
          posterior means of weight variables.
        qs_vals: A python `iterable`, the same length as `names`,
          whose elements are Numpy `array`s, of any shape, containing
          posterior standard deviations of weight variables.
        fname: Python `str` filename to save the plot to.
    """
    fig = figure.Figure(figsize=(6, 3))
    canvas = backend_agg.FigureCanvasAgg(fig)

    # `sns.displot` is a figure-level function: it builds its own figure and
    # does not draw on a supplied `ax`, which left both panels of the saved
    # image empty. `sns.histplot` is the axes-level equivalent. Density
    # normalisation keeps layers with very different weight counts comparable
    # on one axis.
    ax = fig.add_subplot(1, 2, 1)
    for n, qm in zip(names, qm_vals):
        sns.histplot(qm.reshape([-1]), ax=ax, label=n, stat='density')
    ax.set_title('weight means')
    ax.set_xlim([-1.5, 1.5])
    ax.legend()

    ax = fig.add_subplot(1, 2, 2)
    # zip() is kept so this panel covers the same number of layers as the
    # first one; the names themselves are not needed here (no legend).
    for _, qs in zip(names, qs_vals):
        sns.histplot(qs.reshape([-1]), ax=ax, stat='density')
    ax.set_title('weight stddevs')
    ax.set_xlim([0, 1.])

    fig.tight_layout()
    canvas.print_figure(fname, format='png')
    print('saved {}'.format(fname))


In [None]:
def plot_heldout_prediction(input_vals, probs,
                            fname, n=10, title=''):
    """Save a PNG plot visualizing posterior uncertainty on heldout data.

    Each of the `n` rows shows the input image, the per-draw class
    probabilities overlaid as translucent bars, and the mean predictive
    probabilities. Requires seaborn to have been imported as `sns`.

    Args:
        input_vals: A `float`-like Numpy `array` of shape
          `[num_heldout] + IMAGE_SHAPE`, containing heldout input images.
        probs: A `float`-like Numpy array of shape `[num_monte_carlo,
          num_heldout, num_classes]` containing Monte Carlo samples of
          class probabilities for each heldout sample.
        fname: Python `str` filename to save the plot to.
        n: Python `int` number of datapoints to visualize.
        title: Python `str` title for the plot.
    """
    # Bar positions follow the class axis of `probs` instead of a fixed 10,
    # so the helper keeps working if the number of classes changes.
    class_ids = np.arange(np.shape(probs)[-1])

    fig = figure.Figure(figsize=(9, 3*n))
    canvas = backend_agg.FigureCanvasAgg(fig)
    for i in range(n):
        # Column 1: the input image, with the trailing channel axis dropped.
        ax = fig.add_subplot(n, 3, 3*i + 1)
        ax.imshow(input_vals[i, :].reshape(IMAGE_SHAPE[:-1]), interpolation='None')

        # Column 2: one translucent bar chart per Monte Carlo draw.
        ax = fig.add_subplot(n, 3, 3*i + 2)
        for prob_sample in probs:
            sns.barplot(x=class_ids, y=prob_sample[i, :], alpha=0.1, ax=ax)
        ax.set_ylim([0, 1])
        ax.set_title('posterior samples')

        # Column 3: probabilities averaged over the Monte Carlo draws.
        ax = fig.add_subplot(n, 3, 3*i + 3)
        sns.barplot(x=class_ids, y=np.mean(probs[:, i, :], axis=0), ax=ax)
        ax.set_ylim([0, 1])
        ax.set_title('predictive probs')
    fig.suptitle(title)
    fig.tight_layout()

    canvas.print_figure(fname, format='png')
    print('saved {}'.format(fname))


In [None]:
def create_model():
    """Creates a Keras model using LeNet-5 architecture.

    The network stacks three Flipout convolution layers (6, 16 and 120
    filters; max-pooling after the first two), a flatten step, a Flipout
    dense layer with 84 units and a Flipout softmax layer with `NUM_CLASSES`
    outputs. The learning rate is read from `FLAGS.learning_rate`, so flags
    must be parsed before this is called.

    Returns:
        model: Compiled Keras model.
    """

    # Each layer's KL(posterior || prior) term is divided by the number of
    # training examples before being added to the loss.
    kl_divergence_function = (lambda q, p, _: tfd.kl_divergence(q, p) /
                              tf.cast(NUM_TRAIN_EXAMPLES, dtype=tf.float32))

    model = tf_keras.models.Sequential([
        tfp.layers.Convolution2DFlipout(
            6, kernel_size=5, padding='SAME',
            kernel_divergence_fn=kl_divergence_function,
            activation=tf.nn.relu),
        tf_keras.layers.MaxPooling2D(
            pool_size=[2, 2], strides=[2, 2],
            padding='SAME'),
        tfp.layers.Convolution2DFlipout(
            16, kernel_size=5, padding='SAME',
            kernel_divergence_fn=kl_divergence_function,
            activation=tf.nn.relu),
        tf_keras.layers.MaxPooling2D(
            pool_size=[2, 2], strides=[2, 2],
            padding='SAME'),
        tfp.layers.Convolution2DFlipout(
            120, kernel_size=5, padding='SAME',
            kernel_divergence_fn=kl_divergence_function,
            activation=tf.nn.relu),
        tf_keras.layers.Flatten(),
        tfp.layers.DenseFlipout(
            84, kernel_divergence_fn=kl_divergence_function,
            activation=tf.nn.relu),
        tfp.layers.DenseFlipout(
            NUM_CLASSES, kernel_divergence_fn=kl_divergence_function,
            activation=tf.nn.softmax)
    ])

    # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
    optimizer = tf_keras.optimizers.legacy.Adam(
        learning_rate=FLAGS.learning_rate)
    model.compile(optimizer, loss='categorical_crossentropy',
                  metrics=['accuracy'], experimental_run_tf_function=False)
    return model

In [None]:
class MNISTSequence(tf_keras.utils.Sequence):
    """Produces a sequence of MNIST digits with labels.

    Indexing with a batch number returns an `(images, labels)` pair holding
    at most `batch_size` examples; `len()` is the number of such batches.
    """

    def __init__(self, data=None, batch_size=128, fake_data_size=None):
        """Initializes the sequence.

        Args:
          data: Tuple of numpy `array` instances, the first representing images and
            the second labels.
          batch_size: Integer, number of elements in each training batch.
          fake_data_size: Optional integer number of fake datapoints to generate.
            Only used when `data` is not given.

        Raises:
          ValueError: If neither `data` nor `fake_data_size` is provided.
        """
        # Compare against None explicitly: truth-testing `data` would break
        # for array-like inputs and silently treat an empty tuple as "absent".
        if data is not None:
            images, labels = data
        elif fake_data_size is not None:
            images, labels = MNISTSequence.__generate_fake_data(
                num_images=fake_data_size, num_classes=NUM_CLASSES)
        else:
            raise ValueError(
                'Either `data` or `fake_data_size` must be provided.')
        self.images, self.labels = MNISTSequence.__preprocessing(
            images, labels)
        self.batch_size = batch_size

    @staticmethod
    def __generate_fake_data(num_images, num_classes):
        """Generates fake data in the shape of the MNIST dataset for unittest.

        Args:
          num_images: Integer, the number of fake images to be generated.
          num_classes: Integer, the number of classes to be generated.
        Returns:
          images: Numpy `array` representing the fake image data. The
            shape of the array will be (num_images, 28, 28).
          labels: Numpy `array` of `num_images` random integers drawn from
            `[0, num_classes)`.
        """
        images = np.random.randint(low=0, high=256,
                                   size=(num_images, IMAGE_SHAPE[0],
                                         IMAGE_SHAPE[1]))
        labels = np.random.randint(low=0, high=num_classes,
                                   size=num_images)
        return images, labels

    @staticmethod
    def __preprocessing(images, labels):
        """Preprocesses image and labels data.

        Args:
          images: Numpy `array` representing the image data.
          labels: Numpy `array` representing the labels data (range 0-9).

        Returns:
          images: Numpy `array` representing the image data, normalized
            and expanded for convolutional network input.
          labels: Numpy `array` representing the labels data (range 0-9),
            as one-hot (categorical) values of width `NUM_CLASSES`.
        """
        # Map pixel values from [0, 255] to [-1, 1] and add a channel axis.
        images = 2 * (images / 255.) - 1.
        images = images[..., tf.newaxis]

        # Fix the one-hot width explicitly; otherwise it is inferred from the
        # largest label present, which can fall short of NUM_CLASSES for
        # small or fake datasets and then mismatch the model's output layer.
        labels = tf_keras.utils.to_categorical(labels, num_classes=NUM_CLASSES)
        return images, labels

    def __len__(self):
        """Returns the number of batches, counting a final partial batch."""
        # Integer ceiling division; avoids a TensorFlow op for a plain scalar.
        return -(-len(self.images) // self.batch_size)

    def __getitem__(self, idx):
        """Returns the `idx`-th `(images, labels)` batch."""
        start = idx * self.batch_size
        stop = start + self.batch_size
        return self.images[start:stop], self.labels[start:stop]
        

In [22]:

def _collect_export_tensors(model):
    """Builds the `name -> tensor` mapping written to the safetensors file.

    Layers are classified purely by how many arrays `get_weights()` returns:
    4 -> kernel mu/rho + bias mu/rho, 3 -> kernel mu/rho + bias,
    2 -> weight + bias. Layers with any other count are skipped.
    NOTE(review): this assumes `get_weights()` yields the arrays in exactly
    that order -- TODO confirm for the Flipout layers.
    """
    suffixes_by_count = {
        4: ('kernel_mu', 'kernel_rho', 'bias_mu', 'bias_rho'),
        3: ('kernel_mu', 'kernel_rho', 'bias'),
        2: ('weight', 'bias'),
    }
    tensors = {}
    for layer in model.layers:
        layer_weights = layer.get_weights()
        suffixes = suffixes_by_count.get(len(layer_weights))
        if suffixes is None:
            continue
        for suffix, value in zip(suffixes, layer_weights):
            tensors[f"{layer.name}.{suffix}"] = tf.convert_to_tensor(value)
    return tensors


def main(argv):
    """Builds the model, resumes from a checkpoint if present, and exports it.

    Steps: ensure `FLAGS.model_dir` exists, load real or fake MNIST data,
    build the model, load weights from the checkpoint file when it exists,
    run the epoch loop, then save the full model to `FLAGS.final_path` and
    the layer weights to `model.safetensors`. Both output paths are used
    as given, i.e. they are not joined onto `FLAGS.model_dir`.

    NOTE: the per-batch training code inside the epoch loop is currently
    commented out, so no training happens and no checkpoint is written; the
    function effectively converts an existing checkpoint into the two export
    files.

    Args:
        argv: Positional command-line arguments passed by `app.run`; unused.
    """
    del argv

    full_checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.checkpoint_path)

    if not tf.io.gfile.exists(FLAGS.model_dir):
        tf.compat.v1.logging.info(
            'Creating new log directory at {}'.format(FLAGS.model_dir))
        tf.io.gfile.makedirs(FLAGS.model_dir)
    else:
        tf.compat.v1.logging.info(
            'Directory {} already exists. Resuming if weights found.'.format(FLAGS.model_dir)
        )

    if FLAGS.fake_data:
        train_seq = MNISTSequence(batch_size=FLAGS.batch_size,
                                  fake_data_size=NUM_TRAIN_EXAMPLES)
        heldout_seq = MNISTSequence(batch_size=FLAGS.batch_size,
                                    fake_data_size=NUM_HELDOUT_EXAMPLES)
    else:
        train_set, heldout_set = tf_keras.datasets.mnist.load_data()
        train_seq = MNISTSequence(data=train_set, batch_size=FLAGS.batch_size)
        heldout_seq = MNISTSequence(data=heldout_set, batch_size=FLAGS.batch_size)

    model = create_model()
    # Same shape as before ([None, 28, 28, 1]), taken from the shared constant.
    model.build(input_shape=[None] + IMAGE_SHAPE)

    if tf.io.gfile.exists(full_checkpoint_path):
        tf.compat.v1.logging.info('>>> Resuming from last checkpoint: {}'.format(full_checkpoint_path))
        model.load_weights(full_checkpoint_path)

    print(' ... Training convolutional neural network')
    best_loss = float('inf')

    for epoch in range(FLAGS.num_epochs):
        epoch_accuracy, epoch_loss = [], []
        # --- Training is disabled: the original per-batch loop is kept below
        # --- for reference. Uncomment it to train again.
        # for step, (batch_x, batch_y) in enumerate(train_seq):
        #     batch_loss, batch_accuracy = model.train_on_batch(
        #         batch_x, batch_y)
        #     epoch_accuracy.append(batch_accuracy)
        #     epoch_loss.append(batch_loss)
        #
        #     if step % 100 == 0:
        #         print('Epoch: {}, Batch index: {}, '
        #               'Loss: {:.3f}, Accuracy: {:.3f}'.format(
        #                   epoch, step,
        #                   tf.reduce_mean(epoch_loss),
        #                   tf.reduce_mean(epoch_accuracy)))
        #
        #     if (step+1) % FLAGS.viz_steps == 0:
        #         print(' ... Running monte carlo inference')
        #         probs = tf.stack([model.predict(heldout_seq, verbose=1)
        #                           for _ in range(FLAGS.num_monte_carlo)], axis = 0)
        #         mean_probs = tf.reduce_mean(probs, axis=0)
        #         heldout_log_prob = tf.reduce_mean(tf.math.log(mean_probs))
        #         print(' ... Held-out nats: {:.3f}'.format(heldout_log_prob))
        #
        #         if HAS_SEABORN:
        #             names = [layer.name for layer in model.layers
        #                      if 'flipout' in layer.name]
        #             qm_values = [layer.kernel_posterior.mean().numpy()
        #                          for layer in model.layers
        #                          if 'flipout' in layer.name]
        #             qs_values = [layer.kernel_posterior.stddev().numpy()
        #                          for layer in model.layers
        #                          if 'flipout' in layer.name]
        #
        #             plot_weight_posteriors(names, qm_values, qs_values,
        #                                    fname=os.path.join(
        #                                        FLAGS.model_dir,
        #                                        'epoch{}_step{:05d}_weights.png'.format(
        #                                            epoch, step)
        #                                    ))
        #
        #             plot_heldout_prediction(heldout_seq.images, probs.numpy(),
        #                                     fname=os.path.join(
        #                                         FLAGS.model_dir,
        #                                         'epoch{}_step{}_pred.png'.format(
        #                                             epoch, step)
        #                                     ), title = 'mean heldout logprob {:.2f}'
        #                                     .format(heldout_log_prob))

        # Best-loss checkpointing belongs to each epoch. It previously sat
        # after the loop, where it reduced an empty list to NaN (so it could
        # never save) and referenced loop variables that do not exist when
        # num_epochs is 0. Skip it when no batch losses were recorded.
        if not epoch_loss:
            continue
        current_epoch_loss = tf.reduce_mean(epoch_loss)
        if current_epoch_loss < best_loss:
            best_loss = current_epoch_loss
            model.save_weights(full_checkpoint_path)
            tf.compat.v1.logging.info('... Epoch {}: Model improved. Saved weights to {}'.format(epoch, full_checkpoint_path))

    # Export the full Keras model, then the per-layer tensors as safetensors.
    model.save(FLAGS.final_path)
    tf.compat.v1.logging.info(FLAGS.final_path)

    save_file(_collect_export_tensors(model), "model.safetensors")
if __name__ == '__main__':
    try:
        app.run(main)
    except SystemExit as exit_signal:
        # Catch the SystemExit raised when app.run finishes so the cell does
        # not end in a traceback. A falsy exit code (None or 0) means success;
        # anything else is reported instead of being shown as "Done!".
        if exit_signal.code:
            print('Exited with status {}'.format(exit_signal.code))
        else:
            print('Done!')

INFO:tensorflow:Directory tmp/bayesian neural network/ already exists. Resuming if weights found.


I0221 07:40:32.777131 8434347072 839659448.py:12] Directory tmp/bayesian neural network/ already exists. Resuming if weights found.


INFO:tensorflow:>>> Resuming from last checkpoint: tmp/bayesian neural network/model_checkpoint.weights.h5


I0221 07:40:33.274456 8434347072 839659448.py:30] >>> Resuming from last checkpoint: tmp/bayesian neural network/model_checkpoint.weights.h5


 ... Training convolutional neural network
INFO:tensorflow:final_model.keras


I0221 07:40:33.363446 8434347072 839659448.py:91] final_model.keras


Done!


In [15]:
# from safetensors.tensorflow import load_file
# file_path = "fixed_model.safetensors"


# import keras
# from safetensors.tensorflow import save_file

# # 1. Load your model
# model = keras.models.load_model("final_model.keras")

# # 2. Extract weights as a dictionary of tensors
# # We convert them to a dict that safetensors understands
# weights_dict = {v.name: v.value() for v in model.weights}

# # 3. Save to a new file
# save_file(weights_dict, "fixed_model.safetensors")
# print("Successfully converted .keras weights to fixed_model.safetensors")




# try:
#     # This returns a dictionary of {name: tf.Tensor}
#     tensors = load_file(file_path)
    
#     print(f"--- Found {len(tensors)} tensors in {file_path} ---")
    
#     # Sort keys to make it easier to find the 'conv2d_flipout' group
#     for name in sorted(tensors.keys()):
#         tensor = tensors[name]
#         print(f"Name: {name:<50} | Shape: {tensor.shape}")

# except Exception as e:
#     print(f"Error reading safetensors: {e}")