In [None]:
import json
import os
import re
import ssl
import tempfile
import unittest
from pathlib import Path
from platform import python_version

import tensorboard
import tensorflow as tf
import tf2onnx

def get_major_minor(s):
    """Return the ``major.minor`` prefix of a dotted version string.

    Everything from the second ``.`` onward is dropped, so ``"3.11.4"``
    becomes ``"3.11"``; strings with fewer than two dots are returned
    unchanged.
    """
    major, dot, remainder = s.partition('.')
    minor = remainder.partition('.')[0]
    return major + dot + minor

def load_expected_versions(path: Path = Path('./expected_versions.json')) -> dict:
    """Load the expected dependency versions from a JSON lock file.

    Args:
        path: Location of the lock file (``str`` or ``Path``). Defaults to
            ``./expected_versions.json``, resolved against the current
            working directory.

    Returns:
        The parsed JSON document. It is presumably an object that maps
        dependency names to version strings; that shape is not validated
        here.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    lock_file = Path(path)
    if not lock_file.exists():
        # For the default path this renders as "expected_versions.json not found."
        raise FileNotFoundError(f"{lock_file} not found.")
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with lock_file.open('r', encoding='utf-8') as file:
        return json.load(file)

def get_expected_version(dependency_name: str) -> str:
    """Return the expected ``major.minor`` version for *dependency_name*.

    The raw value is read from the module-level ``expected_versions``
    mapping. Leading non-digit characters are stripped first (for example
    a ``v`` or ``~=`` prefix), and the result is then truncated to
    ``major.minor``.

    The value is expected to be a string; ``re.sub`` raises ``TypeError``
    otherwise.

    Raises:
        KeyError: if *dependency_name* has no entry in ``expected_versions``.
    """
    raw_value = expected_versions.get(dependency_name)
    if raw_value is None:
        # Without this check re.sub(None) raises an unhelpful TypeError.
        raise KeyError(f"{dependency_name!r} not found in expected_versions.json")
    raw_version = re.sub(r'^\D+', '', raw_value)
    return get_major_minor(raw_version)

class TestTensorflowNotebook(unittest.TestCase):
    """Smoke tests for the TensorFlow notebook environment.

    Verifies the Python and TensorFlow ``major.minor`` versions against the
    lock file, and checks that tf2onnx conversion, a Keras MNIST training
    run and the TensorBoard callback all work end to end.
    """

    def test_python_version(self):
        """The running interpreter matches the expected Python major.minor."""
        expected_major_minor = get_expected_version('Python')
        actual_major_minor = get_major_minor(python_version())
        self.assertEqual(actual_major_minor, expected_major_minor, "Incorrect Python version")

    def test_tensorflow_version(self):
        """The installed TensorFlow matches the expected major.minor."""
        expected_major_minor = get_expected_version('TensorFlow')
        actual_major_minor = get_major_minor(tf.__version__)
        self.assertEqual(actual_major_minor, expected_major_minor, "Incorrect TensorFlow version")

    def test_tf2onnx_conversion(self):
        """A small functional Keras model converts to a non-empty ONNX graph."""
        # Workaround for known issues with tf2onnx + tf.keras
        inputs = tf.keras.Input(shape=(10,))
        flatten_layer = tf.keras.layers.Flatten()(inputs)
        outputs = tf.keras.layers.Dense(1)(flatten_layer)
        model = tf.keras.Model(inputs=inputs, outputs=outputs)
        # from_keras returns a (model_proto, external_tensor_storage) tuple.
        # The tuple itself is never None, so assert on the proto instead.
        onnx_model, _ = tf2onnx.convert.from_keras(
            model, input_signature=[tf.TensorSpec(model.inputs[0].shape)]
        )
        self.assertIsNotNone(onnx_model)
        self.assertGreater(len(onnx_model.graph.node), 0, "Converted ONNX graph has no nodes")

    def test_mnist_model(self):
        """Train a small MNIST classifier and sanity-check its outputs."""
        mnist = tf.keras.datasets.mnist
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
        # Scale the uint8 pixel values (0-255) into [0, 1].
        x_train, x_test = x_train / 255.0, x_test / 255.0
        # An explicit Input replaces the `input_shape=` layer argument, which
        # Keras 3 deprecates.
        model = tf.keras.models.Sequential([
            tf.keras.Input(shape=(28, 28)),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10)
        ])

        # Untrained forward pass: one row of 10 logits.
        predictions = model(x_train[:1]).numpy()
        self.assertEqual(predictions.shape, (1, 10))
        initial_probs = tf.nn.softmax(predictions).numpy()
        self.assertAlmostEqual(float(initial_probs.sum()), 1.0, places=5)
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        initial_loss = float(loss_fn(y_train[:1], predictions).numpy())
        self.assertGreater(initial_loss, 0.0)

        model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
        model.fit(x_train, y_train, epochs=5)
        # evaluate() returns [loss, accuracy] because one metric was compiled in.
        _, test_accuracy = model.evaluate(x_test, y_test, verbose=2)
        # Deliberately loose threshold: this only catches a broken training loop.
        self.assertGreater(test_accuracy, 0.9, "MNIST model failed to learn")

        probability_model = tf.keras.Sequential([model, tf.keras.layers.Softmax()])
        probabilities = probability_model(x_test[:5]).numpy()
        self.assertEqual(probabilities.shape, (5, 10))
        for row_sum in probabilities.sum(axis=1):
            self.assertAlmostEqual(float(row_sum), 1.0, places=5)

    def test_tensorboard(self):
        """fit() with the TensorBoard callback writes event files to its log dir."""
        model = tf.keras.Sequential([
            tf.keras.Input(shape=(5,)),
            tf.keras.layers.Dense(10, activation='relu'),
            tf.keras.layers.Dense(1)
        ])
        model.compile(optimizer='adam', loss='mse')
        x_train = tf.random.normal((100, 5))
        y_train = tf.random.normal((100, 1))
        # Log to a throwaway directory rather than ./logs. Repeated runs then
        # neither pollute the working directory nor see stale event files.
        with tempfile.TemporaryDirectory() as log_dir:
            tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
            model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
            event_files = list(Path(log_dir).rglob('events.out.tfevents.*'))
            self.assertTrue(event_files, f"No TensorBoard event files written under {log_dir}")

@unittest.skip("RHAIENG-509: TestSecurity tests all fail")
class TestSecurity(unittest.TestCase):
    """Checks that the Jupyter server is password-protected, uses TLS and binds to loopback.

    Certificate/key paths and the host come from the ``JUPYTER_*``
    environment variables, falling back to the class defaults below.
    """

    # Fallback locations used when JUPYTER_SSL_CERT / JUPYTER_SSL_KEY are unset.
    DEFAULT_CERT_FILE = "/etc/jupyter/ssl/cert.pem"
    DEFAULT_KEY_FILE = "/etc/jupyter/ssl/key.pem"
    # Loopback hosts (IPv4 and IPv6) that are not publicly reachable.
    LOOPBACK_HOSTS = ("localhost", "127.0.0.1", "::1")

    def _cert_file(self):
        """Return the configured TLS certificate path."""
        return os.environ.get("JUPYTER_SSL_CERT", self.DEFAULT_CERT_FILE)

    def _key_file(self):
        """Return the configured TLS private-key path."""
        return os.environ.get("JUPYTER_SSL_KEY", self.DEFAULT_KEY_FILE)

    def test_jupyter_password_env(self):
        self.assertIn("JUPYTER_PASSWORD", os.environ, "Missing JUPYTER_PASSWORD env variable for login protection")

    def test_ssl_files_exist(self):
        cert_file = self._cert_file()
        key_file = self._key_file()
        self.assertTrue(os.path.exists(cert_file), f"SSL cert not found: {cert_file}")
        self.assertTrue(os.path.exists(key_file), f"SSL key not found: {key_file}")

    def test_ssl_certificate_validity(self):
        cert_file = self._cert_file()
        key_file = self._key_file()
        # Purpose.CLIENT_AUTH builds a server-side context, i.e. the side that
        # presents this certificate to connecting clients.
        context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        try:
            # The key is a separate file (see test_ssl_files_exist). Without
            # keyfile=, load_cert_chain looks for the private key in cert_file.
            context.load_cert_chain(certfile=cert_file, keyfile=key_file)
        except (ssl.SSLError, OSError) as e:
            self.fail(f"Invalid SSL certificate: {e}")

    def test_host_not_public(self):
        host = os.environ.get("JUPYTER_HOST", "localhost")
        self.assertIn(host, self.LOOPBACK_HOSTS, f"Jupyter host is publicly exposed: {host}")

# Load the version lock file up front: get_expected_version() reads the
# module-level ``expected_versions`` mapping while the tests run.
expected_versions = load_expected_versions()

# Only the notebook smoke tests are collected; TestSecurity is not part of
# this suite (it is also marked skipped, RHAIENG-509).
_loader = unittest.TestLoader()
suite = _loader.loadTestsFromTestCase(TestTensorflowNotebook)
_runner = unittest.TextTestRunner()
_runner.run(suite)
