In [None]:
import json
import re
import tempfile
import unittest
from pathlib import Path
from platform import python_version

import tensorboard
import tensorflow as tf
import tf2onnx

def get_major_minor(s):
    """Truncate a dotted version string to its first two components.

    ``'3.11.4'`` becomes ``'3.11'``; a string without a dot is returned as-is.
    """
    major, dot, remainder = s.partition('.')
    if not dot:
        return major
    # Keep only what precedes the second dot (if any).
    minor = remainder.partition('.')[0]
    return f'{major}.{minor}'

def load_expected_versions(lock_file='./expected_versions.json') -> dict:
    """Load the expected dependency versions from a JSON lock file.

    Args:
        lock_file: Path (``str`` or ``Path``) of the JSON file to read.
            Defaults to ``./expected_versions.json`` relative to the current
            working directory.

    Returns:
        The parsed JSON document. It is expected to map dependency names
        (e.g. ``'Python'``, ``'TensorFlow'``) to version strings.

    Raises:
        FileNotFoundError: if ``lock_file`` does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # Explicit encoding: JSON is UTF-8, and the platform default may not be.
    with open(Path(lock_file), 'r', encoding='utf-8') as file:
        return json.load(file)

def get_expected_version(dependency_name: str, versions=None) -> str:
    """Return the expected ``major.minor`` version of a dependency.

    The raw value is looked up in ``versions``, or in the module-level
    ``expected_versions`` mapping when ``versions`` is omitted. Any leading
    non-digit prefix (such as ``'v'`` or ``'~='``) is stripped, and the
    result is truncated to ``major.minor`` with ``get_major_minor``.

    Args:
        dependency_name: Key to look up, e.g. ``'Python'`` or ``'TensorFlow'``.
        versions: Optional mapping to use instead of the module-level
            ``expected_versions``.

    Returns:
        The ``major.minor`` string, e.g. ``'2.16'``.

    Raises:
        KeyError: if there is no entry for ``dependency_name``.
    """
    if versions is None:
        versions = expected_versions
    raw_value = versions.get(dependency_name)
    # Fail with a clear message rather than letting re.sub choke on None.
    if raw_value is None:
        raise KeyError(f"no expected version for {dependency_name!r} in the lock file")
    # str(): tolerate a bare JSON number as well as a string.
    raw_version = re.sub(r'^\D+', '', str(raw_value))
    return get_major_minor(raw_version)

class TestTensorflowNotebook(unittest.TestCase):
    """Smoke tests for the TensorFlow notebook environment.

    The tests check three things:

    - The installed Python and TensorFlow versions match the lock file
      (``major.minor`` only).
    - A Keras model converts to ONNX with tf2onnx.
    - Training works, both the MNIST quickstart and training with the
      TensorBoard callback.
    """

    def test_python_version(self):
        """The interpreter's major.minor matches the lock file's 'Python' entry."""
        expected_major_minor = get_expected_version('Python')
        actual_major_minor = get_major_minor(python_version())
        self.assertEqual(actual_major_minor, expected_major_minor, "incorrect version")

    def test_tensorflow_version(self):
        """TensorFlow's major.minor matches the lock file's 'TensorFlow' entry."""
        expected_major_minor = get_expected_version('TensorFlow')
        actual_major_minor = get_major_minor(tf.__version__)
        self.assertEqual(actual_major_minor, expected_major_minor, "incorrect version")

    def test_tf2onnx_conversion(self):
        """A small functional-API Keras model converts to an ONNX ModelProto."""
        # Around TF 2.17 the interplay between tf.keras and tf2onnx broke:
        # - A naively defined Sequential model no longer converts, so the
        #   functional API is used here instead:
        #   - https://github.com/tensorflow/tensorflow/issues/63867
        #   - https://github.com/onnx/tensorflow-onnx/issues/2319
        # - from_keras now needs an explicit input_signature:
        #   - https://github.com/onnx/tensorflow-onnx/issues/2329
        inputs = tf.keras.Input(shape=(10,))
        flattened = tf.keras.layers.Flatten()(inputs)
        outputs = tf.keras.layers.Dense(1)(flattened)
        model = tf.keras.Model(inputs=inputs, outputs=outputs)

        # from_keras returns a (model_proto, external_tensor_storage) tuple.
        # The tuple itself is never None, so unpack it and check the ModelProto.
        onnx_model, _ = tf2onnx.convert.from_keras(
            model, input_signature=[tf.TensorSpec(model.inputs[0].shape)]
        )

        self.assertIsNotNone(onnx_model)
        self.assertGreater(len(onnx_model.graph.node), 0, "converted ONNX graph is empty")

    def test_mnist_model(self):
        """Train and evaluate the TensorFlow beginner quickstart MNIST model.

        Adapted from https://www.tensorflow.org/tutorials/quickstart/beginner
        """
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
        x_train, x_test = x_train / 255.0, x_test / 255.0

        # Declare the input with tf.keras.Input; passing input_shape= to a
        # layer is deprecated in Keras 3.
        model = tf.keras.models.Sequential([
            tf.keras.Input(shape=(28, 28)),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10),
        ])

        # The untrained model already produces one logit per class.
        logits = model(x_train[:1]).numpy()
        self.assertEqual(logits.shape, (1, 10))

        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.assertTrue(bool(tf.math.is_finite(loss_fn(y_train[:1], logits))))

        model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
        model.fit(x_train, y_train, epochs=5)
        results = model.evaluate(x_test, y_test, verbose=2, return_dict=True)
        # The quickstart normally scores well above this after 5 epochs. This
        # is a loose floor that should only trip if training is broken.
        self.assertGreater(results['accuracy'], 0.9)

        probability_model = tf.keras.Sequential([
            model,
            tf.keras.layers.Softmax(),
        ])
        probabilities = probability_model(x_test[:5]).numpy()
        self.assertEqual(probabilities.shape, (5, 10))
        # Each softmax row is a probability distribution over the 10 classes.
        for row_sum in probabilities.sum(axis=1):
            self.assertAlmostEqual(float(row_sum), 1.0, places=5)

    def test_tensorboard(self):
        """Training with the TensorBoard callback writes event files to the log dir."""
        model = tf.keras.Sequential([
            tf.keras.Input(shape=(5,)),
            tf.keras.layers.Dense(10, activation='relu'),
            tf.keras.layers.Dense(1),
        ])
        model.compile(optimizer='adam', loss='mse')
        x_train = tf.random.normal((100, 5))
        y_train = tf.random.normal((100, 1))

        # Log to a throwaway directory rather than ./logs. This keeps the
        # working directory clean, and files from earlier runs cannot make
        # the assertion below pass spuriously.
        with tempfile.TemporaryDirectory() as log_dir:
            tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
            model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
            event_files = list(Path(log_dir).rglob('events.out.tfevents.*'))
            self.assertTrue(event_files, f"no TensorBoard event files written under {log_dir}")

# Module-level global: get_expected_version() reads it by default.
expected_versions = load_expected_versions()

suite = unittest.TestLoader().loadTestsFromTestCase(TestTensorflowNotebook)
result = unittest.TextTestRunner().run(suite)
# TextTestRunner only reports failures. Raise so that executing the notebook
# non-interactively ends in an error whenever a test fails or errors.
if not result.wasSuccessful():
    raise AssertionError(
        f"{len(result.failures)} test(s) failed and {len(result.errors)} errored"
    )
