In [None]:
import json
import yaml

In [None]:
# Load the JSON throughput test configuration, then point it at the muon simtel
# test file, set the producer output to "test_file" and set its "update" flag.
json_config_file = "/Users/vdk/Software/ctasoft/calibpipe/calibpipe/tests/data/throughput/optical_throughput_config.json"
muon_simtel_file = "/Users/vdk/Software/ctasoft/calibpipe/calibpipe/tests/data/throughput/simtel_run501_muon_telescope_transmission_0.8.simtel.gz"

with open(json_config_file) as json_file:
    data = json.load(json_file)

data["EventSource"]["input_url"] = muon_simtel_file
producer_settings = data["ThroughputCalibrationProducer"]
producer_settings["output"] = "test_file"
producer_settings["update"] = True

In [None]:
# Load the example YAML configuration and point its EventSource at the muon
# simtel test file.
yaml_config_file = "/Users/vdk/Software/ctasoft/calibpipe/doc/source/examples/throughput/configurations/throughput_configuration.yaml"
with open(yaml_config_file) as file:
    data = yaml.safe_load(file)
event_source = data["EventSource"]
event_source["input_url"] = "/Users/vdk/Software/ctasoft/calibpipe/calibpipe/tests/data/throughput/simtel_run501_muon_telescope_transmission_0.8.simtel.gz"

In [None]:
# Reload the unmodified example YAML configuration and display it.
with open(
    "/Users/vdk/Software/ctasoft/calibpipe/doc/source/examples/throughput/configurations/throughput_configuration.yaml"
) as config_stream:
    data = yaml.safe_load(config_stream)

data

In [None]:
# Load the example YAML configuration, redirect it to the muon simtel test file,
# set the producer output to "test_file" and its "update" flag, then display it.
yaml_config_file = "/Users/vdk/Software/ctasoft/calibpipe/doc/source/examples/throughput/configurations/throughput_configuration.yaml"
with open(yaml_config_file) as file:
    data = yaml.safe_load(file)

data["EventSource"]["input_url"] = "/Users/vdk/Software/ctasoft/calibpipe/calibpipe/tests/data/throughput/simtel_run501_muon_telescope_transmission_0.8.simtel.gz"
producer_settings = data["ThroughputCalibrationProducer"]
producer_settings["output"] = "test_file"
producer_settings["update"] = True

data


In [None]:
from ctapipe.core import run_tool
import os
import yaml

from pathlib import Path
import pytest
from traitlets.config.loader import Config

from astropy.time import Time
from astropy.table import Table
from calibpipe.core.exceptions import (
    FailedThoughputCalibration,
    CorruptedMuonData,
    DataModelVersionMismatch,
)
from calibpipe.throughput.containers import TelescopeOpticalThoughtputContainer
from calibpipe.tools.throughput_calibrator import ThroughputCalibrationProducer
from ctapipe.containers import NAN_TIME
from calibpipe.database.connections import CalibPipeDatabase
from calibpipe.database.adapter.database_containers import version_control_sql_info
from calibpipe.database.interfaces import TableHandler
import sqlalchemy as sa


# Directory that the test-data, output and documentation-config paths are
# resolved against. When run as a module (e.g. under pytest) this is the
# directory of this file. An interactive kernel such as this notebook's usually
# does not define ``__file__``, so fall back to the current working directory
# there. NOTE(review): that fallback assumes the notebook is started from the
# test module's directory — confirm.
try:
    _base_dir = Path(__file__).parent
except NameError:
    _base_dir = Path.cwd()

data_path = _base_dir.joinpath("../../data/throughput/")
output_path = _base_dir.joinpath("output/")
config_path = _base_dir.joinpath(
    "../../../../doc/source/examples/throughput/configurations"
)


@pytest.mark.muon
def test_tool():
    """Run ThroughputCalibrationProducer end-to-end through ctapipe's ``run_tool``.

    The muon simtel test file and the example YAML configuration are passed on
    the command line, the output is pointed at ``throughput_test.ecsv`` in
    ``output_path`` and ``-u=True`` is given. The tool must exit with status 0.
    """
    simtel_file = data_path / "simtel_run501_muon_telescope_transmission_0.8.simtel.gz"
    config_file = config_path / "throughput_configuration.yaml"
    ecsv_file = output_path / "throughput_test.ecsv"

    output_path.mkdir(parents=True, exist_ok=True)
    # Remove any result left over from a previous run before invoking the tool.
    if ecsv_file.exists():
        ecsv_file.unlink()

    exit_code = run_tool(
        ThroughputCalibrationProducer(),
        argv=[f"-i{simtel_file}", f"-c{config_file}", f"-o{ecsv_file}", "-u=True"],
        cwd=None,
    )
    assert exit_code == 0


@pytest.mark.muon
def test_exception_handler():
    """Smoke-test ``ThroughputCalibrationProducer.handle_exception``.

    The producer is built from the example YAML configuration (muon simtel
    input, ``throughput_test.ecsv`` output, ``update`` set) and ``setup()`` is
    run. One instance of each calibpipe exception is then passed to
    ``handle_exception``. The test passes as long as none of these calls raises.
    """
    with open(config_path / "throughput_configuration.yaml") as yaml_file:
        data = yaml.safe_load(yaml_file)

    data["EventSource"]["input_url"] = (
        data_path / "simtel_run501_muon_telescope_transmission_0.8.simtel.gz"
    )
    producer_cfg = data["ThroughputCalibrationProducer"]
    producer_cfg["output"] = output_path / "throughput_test.ecsv"
    producer_cfg["update"] = True

    # NOTE(review): the file removed here is not the configured output above;
    # confirm this temp-file cleanup is still intended.
    stale_temp_file = output_path / "throughput_test_temp.ecsv"
    if stale_temp_file.exists():
        stale_temp_file.unlink()

    calibrator = ThroughputCalibrationProducer(config=Config(data))
    calibrator.setup()
    for exception_type in (
        FailedThoughputCalibration,
        CorruptedMuonData,
        DataModelVersionMismatch,
    ):
        calibrator.handle_exception(exception_type())


@pytest.mark.muon
def test_throughput_storage_in_ecsv():
    """Check that a throughput result round-trips through the ECSV output file.

    The producer is configured from the example YAML (muon simtel input,
    ``throughput_test.ecsv`` output, ``update`` set) and any previous output
    file is deleted. A hand-built container is then stored through
    ``update_throughput_table`` and ``finish()`` is called. The first row read
    back with astropy's ``Table.read`` must equal the container's ``as_dict()``.
    """
    ecsv_file = output_path / "throughput_test.ecsv"
    with open(config_path / "throughput_configuration.yaml") as yaml_file:
        data = yaml.safe_load(yaml_file)

    data["EventSource"]["input_url"] = (
        data_path / "simtel_run501_muon_telescope_transmission_0.8.simtel.gz"
    )
    producer_cfg = data["ThroughputCalibrationProducer"]
    producer_cfg["output"] = ecsv_file
    producer_cfg["update"] = True
    # Start from an empty file so the first row is the one written below.
    if ecsv_file.exists():
        ecsv_file.unlink()

    calibrator = ThroughputCalibrationProducer(config=Config(data))
    calibrator.setup()
    calibrator.throughputs.tel["1"].obs["0"] = TelescopeOpticalThoughtputContainer(
        optical_throughput_coefficient=0.5,
        optical_throughput_coefficient_std=0.2,
        validity_start=1.0,
        validity_end=2.0,
        obs_id=0,
        method="Muon Rings",
        tel_id=1,
    )
    calibrator.update_throughput_table(calibrator.throughputs)
    calibrator.finish()
    assert calibrator.output.exists()

    written_table = Table.read(calibrator.output)
    first_row = dict(zip(written_table.colnames, written_table[0]))
    assert calibrator.throughputs.tel["1"].obs["0"].as_dict() == first_row


@pytest.mark.db
@pytest.mark.muon
def test_throughput_storage_in_db():
    """Check that a throughput result is written to the calibpipe database.

    The producer is configured from the example YAML with ``write_db`` set, and
    its table version is forced to ``100.100``. A hand-built container is pushed
    through ``update_throughput_db``. The test then checks three things:
    - the data table and the version-control table exist;
    - a version-control row matches this table name and version;
    - the first row of the data table equals the container's ``as_dict()``.
    """
    with open(config_path / "throughput_configuration.yaml") as yaml_file:
        data = yaml.safe_load(yaml_file)
    data["EventSource"]["input_url"] = (
        data_path / "simtel_run501_muon_telescope_transmission_0.8.simtel.gz"
    )
    data["ThroughputCalibrationProducer"]["write_db"] = True

    calibrator = ThroughputCalibrationProducer(config=Config(data))
    calibrator.setup()
    major, minor = 100, 100
    calibrator.table_version = f"{major}.{minor}"
    calibrator.throughputs.tel["1"].obs["0"] = TelescopeOpticalThoughtputContainer(
        optical_throughput_coefficient=0.5,
        optical_throughput_coefficient_std=0.2,
        validity_start=1.0,
        validity_end=2.0,
        obs_id=0,
        method="Muon Rings",
        tel_id=1,
    )
    calibrator.update_throughput_db(calibrator.throughputs)

    # Read back from the same database the producer was configured to write to.
    with CalibPipeDatabase(
        user=calibrator.db_user,
        password=calibrator.db_pass,
        database=calibrator.db_name,
        host=calibrator.db_host,
        port=calibrator.db_port,
        autocommit=True,
    ) as connection:
        inspector = sa.inspect(connection.engine)
        assert inspector.has_table(calibrator.table.name)
        version_control_table = version_control_sql_info.get_table()
        assert inspector.has_table(version_control_table.name)

        version_row = connection.execute(
            sa.select(version_control_table).where(
                version_control_table.c.name == calibrator.table_name,
                version_control_table.c.version == calibrator.table_version,
            )
        ).first()
        assert version_row is not None

        stored_row = connection.execute(sa.select(calibrator.table)).first()
        expected = calibrator.throughputs.tel["1"].obs["0"].as_dict()
        assert expected == stored_row._asdict()


@pytest.mark.db
@pytest.mark.muon
def test_throughput_version_compatibility_in_db():
    """Check ``TableHandler.get_compatible_version`` and clean up the test tables.

    Presumably relies on ``test_throughput_storage_in_db`` having already
    registered ``optical_throughput`` at version ``100.100`` (TODO confirm test
    ordering). The test expects:
    - requesting ``100.101`` (same major) resolves to ``100.100``;
    - requesting ``101.101`` (new major) resolves to ``101.101``.

    Afterwards, even if an assertion fails, it drops the ``100.100`` data table
    and removes the version-control rows for the test versions. It also drops
    the version-control table when that leaves it empty.
    """
    with open(config_path.joinpath("throughput_configuration.yaml")) as yaml_file:
        data = yaml.safe_load(yaml_file)
        db_user = data["ThroughputCalibrationProducer"]["db_user"]
        db_name = data["ThroughputCalibrationProducer"]["db_name"]
        db_host = data["ThroughputCalibrationProducer"]["db_host"]
        db_pass = data["ThroughputCalibrationProducer"]["db_pass"]
    major = 100
    minor = 100
    table_name = "optical_throughput"
    old_version = f"{major}.{minor}"
    table_full_name = f"{table_name}_v{major}_{minor}"
    minor += 1
    same_major_version = f"{major}.{minor}"
    major += 1
    next_major_version = f"{major}.{minor}"
    # Exactly the versions this test touches. The previous `version >= "100.100"`
    # filter is a lexicographic comparison if `version` is stored as text, and
    # would then also delete unrelated entries such as "2.0".
    test_versions = (old_version, same_major_version, next_major_version)
    version_control_table = version_control_sql_info.get_table()
    with CalibPipeDatabase(
        user=db_user,
        password=db_pass,
        database=db_name,
        host=db_host,
        autocommit=True,
    ) as connection:
        try:
            comp_version = TableHandler.get_compatible_version(
                version_control_table, table_name, same_major_version, connection
            )
            assert comp_version == old_version
            comp_version = TableHandler.get_compatible_version(
                version_control_table, table_name, next_major_version, connection
            )
            assert comp_version == next_major_version
        finally:
            # Always delete the added tables/entries so a failed run does not
            # leave the database dirty for the next one.
            connection.execute(sa.text(f"drop table if exists {table_full_name}"))
            connection.execute(
                version_control_table.delete().where(
                    version_control_table.c.name == table_name,
                    version_control_table.c.version.in_(test_versions),
                )
            )
            remaining = connection.execute(
                sa.select(version_control_table.c.name)
            ).first()
            if remaining is None:
                connection.execute(sa.text("drop table version_control_table"))
