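# Predict

Load the trained entanglement classifier from MLflow and sanity-check its predictions: Bell states should be scored as entangled and random separable states as not entangled.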

In [None]:
# Parameters cell: default values, presumably overridden by the pipeline
# (papermill/Airflow) at run time — TODO confirm.
WITNESS_NAME = "CHSH"
SIMULATION_PATH = "./simulated_data"

# MLflow tracking server and the run whose model is loaded.
# Set MLFLOW_RUN_ID to None to look the run up by AIRFLOW_DAG_RUN_ID instead.
MLFLOW_URL = "http://localhost:5000"
AIRFLOW_DAG_RUN_ID = "test-dm-chsh"
MLFLOW_RUN_ID = "1c81e50024c54dd69d8e48bc406c4dad"

# S3-compatible artifact store credentials and endpoint.
# NOTE(review): these look like local MinIO development defaults; never
# commit real credentials here — inject them at run time instead.
AWS_ACCESS_KEY_ID = "minio123"
AWS_SECRET_ACCESS_KEY = "minio123"
MLFLOW_S3_ENDPOINT_URL = "http://localhost:9990"

In [None]:
from os import environ

# Expose the artifact-store credentials/endpoint to MLflow through the process
# environment, and turn off MLflow's artifact-download progress bar so it does
# not clutter the notebook output.
environ.update({
    "AWS_ACCESS_KEY_ID": AWS_ACCESS_KEY_ID,
    "AWS_SECRET_ACCESS_KEY": AWS_SECRET_ACCESS_KEY,
    "MLFLOW_S3_ENDPOINT_URL": MLFLOW_S3_ENDPOINT_URL,
    "MLFLOW_ENABLE_ARTIFACTS_PROGRESS_BAR": "false",
})

In [None]:
import mlflow

# Point the MLflow client at the tracking server, if one is configured.
# MLFLOW_S3_ENDPOINT_URL is already exported to the environment by the
# previous cell, so it is not re-set here.
if MLFLOW_URL is not None:
    mlflow.set_tracking_uri(MLFLOW_URL)

In [None]:
import mlflow
import mlflow.tensorflow
import tensorflow as tf

EXPERIMENT_NAME = "ML Quantum Entanglement"

# Resolve which MLflow run to load the model from: an explicit MLFLOW_RUN_ID
# wins; otherwise use the most recent run in EXPERIMENT_NAME tagged with this
# Airflow DAG run id.
if MLFLOW_RUN_ID is None:
    client = mlflow.tracking.MlflowClient()

    experiments = client.search_experiments(
        filter_string=f"name = '{EXPERIMENT_NAME}'"
    )
    if not experiments:
        # Fail with context instead of a bare IndexError on experiments[0].
        raise LookupError(f"MLflow experiment {EXPERIMENT_NAME!r} not found")
    experiment_id = experiments[0].experiment_id

    runs = client.search_runs(
        experiment_ids=[experiment_id],
        filter_string=f"tags.airflow_dag_run_id = '{AIRFLOW_DAG_RUN_ID}'",
        # Make "latest run wins" explicit rather than relying on default ordering.
        order_by=["attributes.start_time DESC"],
        max_results=1,
    )
    if not runs:
        raise LookupError(
            f"No MLflow run in experiment {EXPERIMENT_NAME!r} "
            f"tagged airflow_dag_run_id={AIRFLOW_DAG_RUN_ID!r}"
        )
    run_id = runs[0].info.run_id
else:
    run_id = MLFLOW_RUN_ID

model_uri = f"runs:/{run_id}/model"
model = mlflow.tensorflow.load_model(model_uri)

In [None]:
from simulation_utils import flatten_density_matrix, create_random_separable, create_bell_states

# Sanity-check thresholds on the model's entanglement score (presumably a
# sigmoid probability in [0, 1] — confirm against the model's output layer).
ENTANGLED_MIN_SCORE = 0.9999
SEPARABLE_MAX_SCORE = 1e-5
N_SEPARABLE_SAMPLES = 4


def predict_entanglement(density_matrix):
    """Return the model's prediction array for a single density matrix.

    The matrix is flattened with ``flatten_density_matrix`` and fed to the
    model as a batch of one, so the score of interest is ``predictions[0][0]``.
    """
    flattened = flatten_density_matrix(density_matrix)
    return model.predict(tf.constant([flattened]))


# Maximally entangled pure Bell states must be classified as entangled.
for bell_dm in create_bell_states():
    predictions = predict_entanglement(bell_dm)
    print(f"Bell state \n{bell_dm}: \nentanglement prediction {predictions}")
    score = predictions[0][0]
    assert score > ENTANGLED_MIN_SCORE, (
        f"Bell state scored {score}, expected > {ENTANGLED_MIN_SCORE}"
    )

# Random separable states must be classified as not entangled.
# NOTE(review): the samples are not seeded here, so they differ between runs —
# seed the RNG used by create_random_separable if reproducibility matters.
for _ in range(N_SEPARABLE_SAMPLES):
    separable_dm = create_random_separable()
    predictions = predict_entanglement(separable_dm)
    print(f"Separable state \n{separable_dm}: \nentanglement prediction {predictions}")
    score = predictions[0][0]
    # Inclusive lower bound: a saturated output of exactly 0.0 is a perfect
    # "separable" answer and must not fail the check.
    assert 0 <= score < SEPARABLE_MAX_SCORE, (
        f"Separable state scored {score}, expected in [0, {SEPARABLE_MAX_SCORE})"
    )
