In [None]:
import mlflow
import mlflow.keras
from mlflow.models import infer_signature
import numpy as np
from nightingale.model.classifier_head import ClassifierHead

# Train a ClassifierHead on *synthetic* data purely to exercise the MLflow
# log -> load round trip. The data is random, so the fitted weights carry no meaning.
# Autologging is not used; the model is logged explicitly at the end of this cell.

SEED = 0
N_SAMPLES = 128
N_FEATURES = 1024  # NOTE(review): assumed to match ClassifierHead's expected input width — confirm
N_CLASSES = 2  # labels are drawn from [0, N_CLASSES)
KERAS_PIN = "keras==3.10.0"  # pinned into the logged model's environment

# Seeded generator so Restart & Run All reproduces the same synthetic data.
# NOTE(review): this does not seed Keras weight initialisation; training is still
# not bit-for-bit reproducible unless Keras' own seed is set as well.
rng = np.random.default_rng(SEED)

model = ClassifierHead()
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

input_data = rng.random((N_SAMPLES, N_FEATURES))
labels = rng.integers(0, N_CLASSES, size=(N_SAMPLES,))

model.fit(
    x=input_data,
    y=labels,
    epochs=1,
)

# Predictions on the training inputs are used only to infer the output schema.
sample_output = model.predict(input_data)

# summary() prints the table itself and returns None; wrapping it in print()
# would emit a stray "None" line.
model.summary()

# Infer the MLflow model signature (input/output schema) from real arrays.
signature = infer_signature(input_data, sample_output)

with mlflow.start_run() as run:
    model_info = mlflow.keras.log_model(
        model,
        name="model",
        signature=signature,
        pip_requirements=[KERAS_PIN],
    )



In [None]:
# Reload the model just logged to MLflow from its URI to verify it deserialises.
loaded_model = mlflow.keras.load_model(model_info.model_uri)
print(f"Model uri: {model_info.model_uri}")

# summary() prints directly and returns None, so it is not wrapped in print().
loaded_model.summary()

In [None]:
# Round-trip check: the reloaded model must produce the same output as the
# in-memory model for identical, previously unseen inputs.
# A fixed seed makes the check reproducible. The feature width is taken from the
# training data so both stay in sync.
check_rng = np.random.default_rng(1)
test_input = check_rng.random((128, input_data.shape[1]))

original_preds = model.predict(test_input)
reloaded_preds = loaded_model.predict(test_input)

np.testing.assert_allclose(
    original_preds,
    reloaded_preds,
    err_msg="Model reloaded from MLflow disagrees with the in-memory model",
)