In [None]:
import datetime
import enum
import itertools
import logging
import re
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable

import astropy
import humanize
import matplotlib
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from astropy.time import Time
from matplotlib.ticker import FuncFormatter

from lsst.utils.iteration import ensure_iterable

import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from astro_metadata_translator import ObservationInfo
from astropy.time import Time
import lsst.daf.butler as dafButler
import lsst.summit.utils.butlerUtils as butlerUtils
from lsst.summit.utils.utils import dayObsIntToString
from lsst.summit.utils.efdUtils import getEfdData, makeEfdClient
from lsst.summit.utils.tmaUtils import filterBadValues
from matplotlib.ticker import FuncFormatter
from lsst.ts.observing.utilities.decorated_logger import DecoratedLogger
# Notebook-wide logger; it is passed explicitly as ``logger=`` to calculateMountErrors in the cells below.
logger = DecoratedLogger.get_decorated_logger()

In [None]:
def getAzElRotDataForExposure(client, expRecord, prePadding=0, postPadding=0):
    """Fetch mount and rotator telemetry from the EFD for a single exposure.

    The azimuth, elevation and rotator position frames each gain a tracking
    error column (``azError``, ``elError`` and ``rotError`` respectively),
    computed as ``(actualPosition - demandPosition) * 3600``, i.e. degrees
    converted to arcseconds.

    Parameters
    ----------
    client : EFD client
        Client used for the queries (e.g. from ``makeEfdClient()``).
    expRecord : exposure dimension record
        Exposure whose timespan bounds the queries.
    prePadding, postPadding : optional
        Extra time before/after the exposure; forwarded unchanged to
        ``getEfdData``.

    Returns
    -------
    azimuthData, elevationData, rotationData, rotationTorques : pandas.DataFrame
        Telemetry from ``lsst.sal.MTMount.azimuth``, ``lsst.sal.MTMount.elevation``,
        ``lsst.sal.MTRotator.rotation`` (each with its error column added) and
        ``lsst.sal.MTRotator.motors`` (returned as queried).
    """
    topics = (
        "lsst.sal.MTMount.azimuth",
        "lsst.sal.MTMount.elevation",
        "lsst.sal.MTRotator.rotation",
        "lsst.sal.MTRotator.motors",
    )
    # Every topic is queried over the same exposure window.
    azimuthData, elevationData, rotationData, rotationTorques = (
        getEfdData(
            client,
            topic,
            expRecord=expRecord,
            prePadding=prePadding,
            postPadding=postPadding,
        )
        for topic in topics
    )

    # Tracking error = actual - demand, scaled from degrees to arcseconds.
    for frame, errorColumn in (
        (azimuthData, "azError"),
        (elevationData, "elError"),
        (rotationData, "rotError"),
    ):
        positionDelta = frame["actualPosition"].values - frame["demandPosition"].values
        frame[errorColumn] = positionDelta * 3600

    return azimuthData, elevationData, rotationData, rotationTorques
    

In [None]:
def tickFormatter(value: float, tick_number: float) -> str:
    """Render an axis tick value with exactly two decimal places.

    Intended as a ``matplotlib.ticker.FuncFormatter`` callback, so it must
    accept the tick position as its second argument even though that argument
    is ignored. The value is converted to a string directly, without
    subtracting any large offset.
    """
    del tick_number  # required by the FuncFormatter callback signature; unused
    return "{:.2f}".format(value)

def getPlotTime(time: pd.Timestamp | Time | datetime.datetime) -> datetime.datetime:
    """Get the right time to plot a point from the various time formats."""
    match time:
        case pd.Timestamp():
            return time.to_pydatetime()
        case astropy.time.Time():
            return time.utc.datetime
        case datetime.datetime():
            return time
        case _:
            raise ValueError(f"Unknown type for commandTime: {type(time)}")


In [None]:
# EFD client and Butler for ComCam raw data; the exposure below is used for
# the quick data-retrieval test.
client = makeEfdClient()
butler = dafButler.Butler(
    "/repo/embargo_new",
    collections=["LSSTComCam/raw/all", "LSSTComCam/calib"],
)
expId = 2024112000272
dataId = dict(exposure=expId, detector=4, instrument="LSSTComCam")
expRecord = butlerUtils.getExpRecordFromDataId(butler, dataId)

## Test the data retrieval only

In [None]:
# Fetch the telemetry (with tracking-error columns) for the exposure above.
azimuthData, elevationData, rotationData, rotationTorques = getAzElRotDataForExposure(
    client, expRecord=expRecord
)

In [None]:
# Quick look at the rotator tracking error (arcsec) for the exposure above.
rotationData["rotError"].plot()

In [None]:
def _plotMountErrors(figure, expRecord, azimuthData, elevationData, rotationData,
                     rotationTorques, imageImpactRms, filterCounts=None):
    """Draw the 3x2 mount/rotator diagnostic panels for one exposure into ``figure``.

    Left column: azimuth/elevation positions, their tracking errors, and
    their torques. Right column: rotator position, tracking error, and
    motor torques. Dashed vertical lines mark the exposure start (green) and
    end (red).

    Parameters
    ----------
    figure : matplotlib.figure.Figure
        Figure to draw into.
    expRecord : exposure dimension record
        Supplies the title information and the exposure timespan.
    azimuthData, elevationData, rotationData, rotationTorques : pandas.DataFrame
        As returned by ``getAzElRotDataForExposure``.
    imageImpactRms : float
        Combined image impact RMS (arcsec) printed on the error panel.
    filterCounts : tuple of int, optional
        ``(nReplacedAz, nReplacedEl)`` when residual filtering was applied.
    """
    [[ax1, ax4], [ax2, ax5], [ax3, ax6]] = figure.subplots(
        3,
        2,
        sharex="col",
        sharey=False,
        gridspec_kw={
            "wspace": 0.25,
            "hspace": 0,
            "height_ratios": [2.5, 1, 1],
            "width_ratios": [1.5, 1],
        },
    )
    axs = [ax1, ax2, ax3, ax4, ax5, ax6]
    # Use the native color cycle for the lines. Because they're on different
    # axes they don't cycle by themselves.
    colors = itertools.cycle(p["color"] for p in plt.rcParams["axes.prop_cycle"])

    # --- Left column: azimuth and elevation ---
    ax1.plot(azimuthData["actualPosition"], label="Azimuth position", c=next(colors))
    ax1.yaxis.set_major_formatter(FuncFormatter(tickFormatter))
    ax1.set_ylabel("Azimuth (degrees)")
    ax1_twin = ax1.twinx()
    ax1_twin.plot(elevationData["actualPosition"], label="Elevation position", c=next(colors))
    ax1_twin.yaxis.set_major_formatter(FuncFormatter(tickFormatter))
    ax1_twin.set_ylabel("Elevation (degrees)")
    ax1.set_title("Azimuth and Elevation")

    ax2.plot(azimuthData["azError"], label="Azimuth tracking error", c=next(colors))
    ax2.plot(elevationData["elError"], label="Elevation tracking error", c=next(colors))
    ax2.axhline(0.01, ls="-.", color="black")
    ax2.axhline(-0.01, ls="-.", color="black")
    ax2.yaxis.set_major_formatter(FuncFormatter(tickFormatter))
    ax2.set_ylabel("Tracking error (arcsec)")
    ax2.set_ylim(-0.05, 0.05)
    ax2.set_yticks([-0.04, -0.02, 0.0, 0.02, 0.04])
    ax2.legend()
    ax2.text(
        0.1, 0.9, f"Image impact RMS = {imageImpactRms:.3f} arcsec (with rot).", transform=ax2.transAxes
    )
    if filterCounts is not None:
        nReplacedAz, nReplacedEl = filterCounts
        ax2.text(
            0.1,
            0.8,
            f"{nReplacedAz} bad az values and {nReplacedEl} bad el values were replaced",
            transform=ax2.transAxes,
        )

    ax3_twin = ax3.twinx()
    ax3.plot(azimuthData["actualTorque"], label="Azimuth torque", c=next(colors))
    ax3_twin.plot(elevationData["actualTorque"], label="Elevation torque", c=next(colors))
    ax3.set_ylabel("Azimuth torque (Nm)")
    ax3_twin.set_ylabel("Elevation torque (Nm)")

    # --- Right column: rotator ---
    ax4.plot(rotationData["actualPosition"], label="Rotator position", c=next(colors))
    ax4.yaxis.set_major_formatter(FuncFormatter(tickFormatter))
    ax4.set_ylabel("Rotator angle (degrees)")
    ax4.set_title("Rotator")

    ax5.plot(rotationData["rotError"], c=next(colors))
    ax5.axhline(0.1, ls="-.", color="black")
    ax5.axhline(-0.1, ls="-.", color="black")
    ax5.yaxis.set_major_formatter(FuncFormatter(tickFormatter))
    ax5.set_ylabel("Tracking error (arcsec)")
    ax5.set_ylim(-1.0, 1.0)
    ax5.set_yticks([-0.5, 0.0, 0.5])

    ax6.plot(rotationTorques["torque0"], label="Torque0", c=next(colors))
    ax6.plot(rotationTorques["torque1"], label="Torque1", c=next(colors))
    ax6.set_ylabel("Rotator torque (Nm)")
    ax6.legend()

    for ax in (ax4, ax5, ax6):
        ax.yaxis.tick_right()
        ax.yaxis.set_label_position("right")

    # --- Time axes ---
    # Axes in a sharex="col" column share one tick locator, so calling
    # set_xticks([]) on an upper panel would also wipe the bottom panel's time
    # ticks. Hide the upper panels' ticks per axis with tick_params instead.
    for ax in (ax1, ax2, ax4, ax5):
        ax.tick_params(axis="x", which="both", bottom=False, labelbottom=False)
    for ax in (ax3, ax6):
        ax.set_xlabel("Time (UTC)")  # yes, it really is UTC, matplotlib converts this automatically!
        ax.xaxis.set_major_locator(mdates.AutoDateLocator())
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M:%S"))
        ax.tick_params(axis="x", rotation=45)

    # Combine the position and torque legends into one opaque legend drawn on
    # the elevation twin axis of the top-left panel.
    handles, labels = [], []
    for ax in (ax1, ax1_twin, ax3, ax3_twin):
        axHandles, axLabels = ax.get_legend_handles_labels()
        handles += axHandles
        labels += axLabels
    ax1_twin.legend(handles, labels, facecolor="white", framealpha=1)

    # Title on the supplied figure (not whatever pyplot considers "current"),
    # using this exposure's own ID rather than a notebook-global variable.
    title = f"ComCam - {expRecord.id} - Exposure time = {expRecord.exposure_time:.1f}s"
    figure.suptitle(title, fontsize=18, y=0.95)

    # Mark the exposure start and end on every panel.
    exposureBegin = getPlotTime(expRecord.timespan.begin)
    exposureEnd = getPlotTime(expRecord.timespan.end)
    for ax in axs:
        ax.axvline(exposureBegin, ls="--", color="green")
        ax.axvline(exposureEnd, ls="--", color="red")


def calculateMountErrors(dataId, butler, client, figure=None, maxDelta=0.1,
                         saveFilename=None, logger=None, doFilterResiduals=True):
    """Query the EFD for an exposure and compute the RMS tracking errors.

    Optionally plots the telemetry into ``figure`` and saves it.

    Parameters
    ----------
    dataId : dict
        Data ID of the exposure, passed to
        ``butlerUtils.getExpRecordFromDataId``.
    butler : lsst.daf.butler.Butler
        Butler used to look up the exposure record.
    client : EFD client
        Client used for the EFD queries.
    figure : matplotlib.figure.Figure, optional
        If given, a 3x2 diagnostic plot is drawn into it.
    maxDelta : float, optional
        Threshold passed to ``filterBadValues`` when filtering residuals.
    saveFilename : str, optional
        If given together with ``figure``, the figure is saved to this path.
    logger : logging.Logger-like, optional
        Logger for progress messages. Defaults to ``logging.getLogger(__name__)``.
    doFilterResiduals : bool, optional
        If True, replace bad azimuth/elevation error values with
        ``filterBadValues`` before computing the RMS.

    Returns
    -------
    dict or bool
        ``False`` if the exposure was skipped (a BIAS/FLAT image, or an
        exposure time below 1.99 s). Otherwise a dict with keys ``az_rms``,
        ``el_rms``, ``rot_rms``, ``image_az_rms``, ``image_el_rms``,
        ``image_rot_rms`` and ``image_impact_rms`` (arcsec).
    """
    NON_TRACKING_IMAGE_TYPES = ["BIAS", "FLAT"]
    # Radius of the ComCam field: converts a rotator angle error into an
    # on-sky displacement at the edge of the field.
    COMCAM_ANGLE_TO_EDGE_OF_FIELD_ARCSEC = 1800.0
    MIN_EXPOSURE_TIME = 1.99  # seconds; shorter exposures are skipped

    if logger is None:
        # The default used to be dereferenced directly and crash.
        logger = logging.getLogger(__name__)

    start = time.time()
    expRecord = butlerUtils.getExpRecordFromDataId(butler, dataId)
    dayString = dayObsIntToString(expRecord.day_obs)
    dataIdString = f"{dayString} - seqNum {expRecord.seq_num}"
    imgType = expRecord.observation_type.upper()
    if imgType in NON_TRACKING_IMAGE_TYPES:
        logger.info(f"Skipping mount torques for non-tracking image type {imgType} for {dataIdString}")
        return False

    if expRecord.exposure_time < MIN_EXPOSURE_TIME:
        logger.info("Skipping sub 2s exposure")
        return False

    azimuthData, elevationData, rotationData, rotationTorques = getAzElRotDataForExposure(
        client, expRecord=expRecord
    )
    elevation = 90 - expRecord.zenith_angle
    logger.debug(f"dataId={dataIdString}, imgType={imgType}")

    # filterBadValues modifies its argument in place, and the array behind a
    # DataFrame column may be read-only (pandas copy-on-write), so work on
    # copies and write the filtered values back below.
    azError = azimuthData["azError"].to_numpy(copy=True)
    elError = elevationData["elError"].to_numpy(copy=True)
    rotError = rotationData["rotError"].to_numpy()
    filterCounts = None
    if doFilterResiduals:
        nReplacedAz = filterBadValues(azError, maxDelta)
        nReplacedEl = filterBadValues(elError, maxDelta)
        filterCounts = (nReplacedAz, nReplacedEl)
        # Store the filtered values so the plots show what the RMS used.
        azimuthData["azError"] = azError
        elevationData["elError"] = elError

    az_rms = np.sqrt(np.mean(np.square(azError)))
    el_rms = np.sqrt(np.mean(np.square(elError)))
    rot_rms = np.sqrt(np.mean(np.square(rotError)))

    # Image impact RMS: azimuth errors shrink by cos(elevation) on the sky;
    # the rotator error (arcsec of angle) is turned into a displacement at
    # the field edge.
    image_az_rms = az_rms * np.cos(np.deg2rad(elevation))
    image_el_rms = el_rms
    image_rot_rms = np.deg2rad(rot_rms / 3600.0) * COMCAM_ANGLE_TO_EDGE_OF_FIELD_ARCSEC
    image_impact_rms = np.sqrt(image_az_rms**2 + image_el_rms**2 + image_rot_rms**2)

    elapsed = time.time() - start
    logger.debug(f"Elapsed time for butler and EFD queries = {elapsed}")

    if figure is not None:
        _plotMountErrors(
            figure,
            expRecord,
            azimuthData,
            elevationData,
            rotationData,
            rotationTorques,
            image_impact_rms,
            filterCounts=filterCounts,
        )
        if saveFilename is not None:
            figure.savefig(saveFilename)

    return dict(
        az_rms=az_rms,
        el_rms=el_rms,
        rot_rms=rot_rms,
        image_az_rms=image_az_rms,
        image_el_rms=image_el_rms,
        image_rot_rms=image_rot_rms,
        image_impact_rms=image_impact_rms,
    )


## Test the calculation with no figure supplied (nothing is plotted)

In [None]:
# Run it with no figure: nothing is plotted, only the RMS dict is returned.
expId = 2024111900364
dataId = dict(exposure=expId, detector=4, instrument="LSSTComCam")
myDict = calculateMountErrors(dataId, butler, client, logger=logger)
myDict

## Now make a plot

In [None]:
# Now make a plot
%matplotlib inline
figure = plt.figure(figsize=(10,8))
expId = 2024111900364
dataId = {'exposure': expId, 'detector': 4, 'instrument': 'LSSTComCam'}
myDict = calculateMountErrors(dataId, butler, client, figure=figure, logger=logger)
plt.savefig(f"/home/c/cslage/u/MTMount/mount_plots/ComCam_Mount_Plot_Old_{expId}.png")