# Simulation notebook

In [None]:
# Parameters cell
# NOTE(review): these are defaults for local runs; presumably the pipeline
# (e.g. papermill parameter injection) overrides them — confirm.
WITNESS_NAME = "CHSH"  # set to None to simulate every registered witness
SIMULATION_PATH = "./simulated_data"  # directory receiving the .npz outputs
MLFLOW_URL = "http://localhost:5000"  # set to None to disable MLflow logging
AIRFLOW_DAG_RUN_ID = "test-dm-chsh"
# Local MinIO development credentials only — never commit real secrets here.
AWS_ACCESS_KEY_ID = "minio123"
AWS_SECRET_ACCESS_KEY = "minio123"
MLFLOW_S3_ENDPOINT_URL = "http://localhost:9990"

In [None]:
from os import environ

# Export the object-store credentials into the process environment;
# presumably consumed by MLflow's S3 artifact client when artifacts are logged.
environ.update(
    AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID,
    AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY,
)

In [None]:
from qutip import basis, tensor, rand_ket
import numpy as np
from entanglement_witnesses import witnesses
import mlflow

In [None]:
if MLFLOW_URL is not None:
    # Expose the S3-compatible artifact endpoint, then point MLflow at the
    # tracking server. Skipped entirely when tracking is disabled.
    environ.update(MLFLOW_S3_ENDPOINT_URL=MLFLOW_S3_ENDPOINT_URL)
    mlflow.set_tracking_uri(MLFLOW_URL)

In [None]:
from simulation_utils import get_simulation_methods_for_witness

def get_simulated_training_data(entanglement_witness, witness_name, samples_nb=500):
    """Simulate density matrices and label each one with a witness.

    For every simulation method registered for ``witness_name``, draw
    ``samples_nb`` states and evaluate ``entanglement_witness`` on each.

    Args:
        entanglement_witness: callable applied to each simulated state;
            its return value is used as the label (presumably a bool).
        witness_name: key used to look up the simulation methods.
        samples_nb: number of samples drawn per simulation method.

    Returns:
        A ``(states, labels)`` pair of parallel lists.
    """
    states, labels = [], []
    for simulate in get_simulation_methods_for_witness(witness_name):
        for _ in range(samples_nb):
            # Evaluate the witness right after each draw to keep the original
            # simulate/label call interleaving.
            density_matrix = simulate()
            states.append(density_matrix)
            labels.append(entanglement_witness(density_matrix))

    return states, labels

In [None]:
simulated_data = {}

# Restrict to the requested witness, or simulate every registered one.
if WITNESS_NAME is None:
    parameter_witnesses = witnesses
else:
    parameter_witnesses = {WITNESS_NAME: witnesses[WITNESS_NAME]}

for name, witness in parameter_witnesses.items():
    samples_states, samples_is_entangled = get_simulated_training_data(witness, name)
    simulated_data[name] = dict(
        states=samples_states,
        entanglement=samples_is_entangled,
    )

In [None]:
import os

from simulation_utils import flatten_density_matrix

# np.savez does not create missing directories; ensure the target exists.
os.makedirs(SIMULATION_PATH, exist_ok=True)

if MLFLOW_URL is not None:
    # Select the experiment once; every run below is logged under it.
    mlflow.set_experiment('ML Quantum Entanglement')

for name, data in simulated_data.items():
    states = data["states"]
    labels = np.array(data["entanglement"])

    flatten_states = [flatten_density_matrix(state) for state in states]

    file_path = os.path.join(SIMULATION_PATH, "simulation-{}.npz".format(name))
    np.savez(file_path, states=np.array(flatten_states), labels=labels)

    if MLFLOW_URL is not None:
        with mlflow.start_run():
            mlflow.set_tag("airflow_dag_run_id", AIRFLOW_DAG_RUN_ID)
            mlflow.set_tag("pipeline_step", "simulation")
            # Log the witness this run actually simulated: WITNESS_NAME is None
            # when all witnesses are simulated, which would mislabel every run.
            mlflow.log_param("witness", name)
            mlflow.log_artifact(file_path, artifact_path="simulated_data")