# Imports

In [None]:
import warnings

warnings.filterwarnings("ignore", "Wswiglal-redir-stdio")
warnings.filterwarnings("ignore", ".*dubious year.*")
warnings.filterwarnings(
    "ignore", "Tried to get polar motions for times after IERS data is valid.*"
)

In [None]:
from astropy.coordinates import (
    ICRS,
    SkyCoord,
    get_body,
    UnitSphericalRepresentation,
    EarthLocation,
)
from astropy.table import QTable, vstack
from astropy_healpix import HEALPix
from astropy.time import Time
from astropy import units as u
from m4opt import fov
from m4opt.missions import uvex as mission
from m4opt.dynamics import nominal_roll
from m4opt.skygrid._geodesic import for_subdivision as geodesic_for_subdivision
from m4opt.synphot import observing
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from m4opt.utils.optimization import partition_graph, solve_tsp
import networkx as nx
from tqdm.auto import tqdm
import numpy as np
from regions import Region, Regions, CircleSkyRegion
import synphot
from ligo.skymap.plot import earth, sun, moon, mellinger
from operator import itemgetter
from astropy.visualization import make_rgb, PercentileInterval
from astropy.wcs import WCS
from reproject import reproject_adaptive
import pymetis
import itertools

# FOV

For any given target sky position, UVEX's roll angle rotates through a full 360° over the course of a year. Plot the circle inscribed within the FOV, which the FOV covers at every roll angle.

In [None]:
# Circle centered on the pointing direction whose radius is half the FOV width.
inscribed_fov = CircleSkyRegion(
    center=SkyCoord(0 * u.deg, 0 * u.deg),
    radius=mission.fov.width / 2,
)

In [None]:
width, height = plt.rcParams["figure.figsize"]
# Square figure, zoomed to 3 degrees around the pointing.
fig = plt.figure(figsize=(width, width))
ax = fig.add_subplot(projection="astro degrees zoom", radius=3 * u.deg)

# Static layer: the inscribed circle and a coordinate grid, no axis labels.
ax.add_patch(inscribed_fov.to_pixel(ax.wcs).as_artist())
ax.grid()
for coord_name in ("ra", "dec"):
    ax.coords[coord_name].set_auto_axislabel(False)

# Animated layer: the FOV footprint at 50 rolls in [0, 90) degrees.
blit = []
rolls = np.linspace(0, 90, num=50, endpoint=False) * u.deg
regions = fov.footprint(mission.fov, SkyCoord(0 * u.deg, 0 * u.deg), rolls)

with tqdm(total=len(regions)) as progress:

    def animate(region: Region):
        """Swap the previously drawn footprint for ``region``."""
        for stale_artist in blit:
            stale_artist.remove()
        blit.clear()
        patch = region.to_pixel(ax.wcs).as_artist(linewidth=2)
        blit.append(ax.add_patch(patch))
        progress.update()
        return blit

    fov_animation = FuncAnimation(fig, animate, regions, blit=True, interval=50)
    fov_animation.save("fov.mp4")

Save the FOV footprint and the inscribed circle as DS9 regions.

In [None]:
# Each region is written with a DS9 text label describing it.
for region_to_save, region_label, region_path in [
    (mission.fov, "UVEX field of view", "fov.ds9"),
    (
        inscribed_fov,
        "Circle inscribed within UVEX field of view",
        "fov-inscribed-circle.ds9",
    ),
]:
    region_to_save.copy(meta={"text": region_label}).write(region_path, overwrite=True)

## Sky grid

Plot the existing sky grid.

In [None]:
# Color every field by its index and spin a globe view once around in RA.
field_ids = np.arange(len(mission.skygrid))
fig = plt.figure(dpi=300)
ra0 = np.linspace(0, 360, 500, endpoint=False) * u.deg

with tqdm(total=len(ra0)) as progress:

    def animate(center_ra):
        """Redraw the globe from scratch, centered at (``center_ra``, +30 deg)."""
        fig.clf()
        ax = plt.axes(projection="astro globe", center=SkyCoord(center_ra, 30 * u.deg))
        points = ax.scatter(
            mission.skygrid.ra.deg,
            mission.skygrid.dec.deg,
            c=field_ids,
            transform=ax.get_transform("world"),
            s=4,
        )
        colorbar = plt.colorbar(points, orientation="vertical")
        colorbar.set_label("Field ID")
        ax.coords["ra"].set_ticklabel(exclude_overlapping=True)
        ax.grid()
        progress.update()

    FuncAnimation(fig, animate, ra0, interval=50).save("skygrid.mp4")

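Check that this is the coarsest grid that covers the entire sky with the inscribed circle. To do so, search geodesic grids with between 4000 and 6000 points for the one with the fewest points whose inscribed-circle footprints cover every HEALPix pixel.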

In [None]:
def skygrid_coverage_fraction(skygrid, region, nside=512):
    """Fraction of the sky covered by ``region`` placed at every grid point.

    The sky is discretized into HEALPix pixels at resolution ``nside``. The
    result is the number of distinct pixels that appear in any of the
    footprints from ``fov.footprint_healpix``, divided by the total number of
    pixels.
    """
    hpx = HEALPix(nside=nside, frame=ICRS())
    footprints = fov.footprint_healpix(hpx, region, skygrid)
    covered_pixels = np.unique(np.concatenate(footprints))
    return covered_pixels.size / hpx.npix


n_lower = 4000
n_upper = 6000

# Number of grid points for subdivision indices (b, c) of each base
# polyhedron, as used below: base_count * (b^2 + b*c + c^2) + 2.
base_counts = {
    "icosahedron": 10,
    "octahedron": 4,
    "tetrahedron": 2,
}

best_n = np.inf

for b in tqdm(range(10, 50)):
    for c in range(b + 1):
        t = b * b + b * c + c * c
        for base, base_count in base_counts.items():
            n = base_count * t + 2
            # Only evaluate grids in the size window that would beat the best so far.
            if not (n_lower <= n <= n_upper) or n >= best_n:
                continue
            points = geodesic_for_subdivision(b, c, base)
            if skygrid_coverage_fraction(points, inscribed_fov) >= 1:
                best_n, best_b, best_c, best_base = n, b, c, base
                print(f"{n=} {base=} {b=} {c=}")

In [None]:
def skygrid_multiple_coverage_fraction(time, region, nside=1024):
    """Histogram of how many times each sky pixel is covered by the sky grid.

    ``region`` is placed at every field of ``mission.skygrid``, using the
    nominal roll for a geocentric observer at ``time``. The HEALPix pixels
    (resolution ``nside``) in all footprints are then tallied.

    Returns
    -------
    numpy.ndarray
        Element ``k`` is the fraction of all pixels that occur exactly ``k``
        times across the footprints. The length is at least 10. Element 0 is
        always 0 because pixels outside every footprint are never tallied.
    """
    hpx = HEALPix(nside=nside, frame=ICRS())
    geocenter_location = EarthLocation.from_geocentric(0 * u.m, 0 * u.m, 0 * u.m)
    roll = nominal_roll(geocenter_location, mission.skygrid, time)
    footprints = fov.footprint_healpix(hpx, region, mission.skygrid, roll)
    _, times_covered = np.unique(np.concatenate(footprints), return_counts=True)
    return np.bincount(times_covered, minlength=10) / hpx.npix


t0 = Time("2025-01-01")
coverage_fraction = skygrid_multiple_coverage_fraction(t0, mission.fov)
coverage_fraction_inscribed = skygrid_multiple_coverage_fraction(t0, inscribed_fov)

width, height = plt.rcParams["figure.figsize"]
fig, axs = plt.subplots(1, 2, tight_layout=True, figsize=(2 * height, height))

# One pie per footprint shape. Wedge k shows the fraction of the sky covered
# exactly k + 1 times: multiplicities 1-3 for the FOV, 1-2 for the circle.
pie_panels = [
    (axs[0], coverage_fraction[1:4], "Instantaneous FOV"),
    (axs[1], coverage_fraction_inscribed[1:3], "Inscribed circle"),
]
for panel_ax, values, panel_title in pie_panels:
    panel_ax.pie(
        values,
        labels=[
            f"$n = {multiplicity}$\n{np.around(100 * value):g}%"
            for multiplicity, value in enumerate(values, start=1)
        ],
        labeldistance=0.6,
        wedgeprops=dict(lw=1, ec="black"),
    )
    panel_ax.set_title(panel_title)
fig.savefig("coverage-fraction.pdf")

In [None]:
center = SkyCoord("0d 0d")
radius = 10 * u.deg
width, _ = plt.rcParams["figure.figsize"]
fig = plt.figure(figsize=(width, width))
ax = fig.add_subplot(projection="astro zoom", center=center, radius=radius)
# Bare panel: no axis labels, tick labels, or ticks on either coordinate.
for coord in (ax.coords["ra"], ax.coords["dec"]):
    coord.set_auto_axislabel(False)
    coord.set_ticklabel_visible(False)
    coord.set_ticks_visible(False)

# 100 evenly spaced epochs over one year.
t0 = Time("2025-01-01")
times = t0 + np.linspace(0, 1, 100, endpoint=False) * u.year
# Fields near the panel; the 1.1 * sqrt(2) factor presumably reaches past the
# panel corners.
centers = mission.skygrid[
    mission.skygrid.separation(center) <= 1.1 * np.sqrt(2) * radius
]
geocenter = EarthLocation.from_geocentric(0 * u.m, 0 * u.m, 0 * u.m)

# Static layer: roll-independent inscribed circles for each nearby field.
for region in fov.footprint(inscribed_fov, centers):
    circle_patch = region.to_pixel(ax.wcs).as_artist(edgecolor="black", alpha=0.5)
    ax.add_patch(circle_patch)

artists = []

with tqdm(total=len(times)) as progress:

    def animate(t):
        """Redraw each nearby field's FOV at its nominal roll for time ``t``."""
        for stale_artist in artists:
            stale_artist.remove()
        artists.clear()
        rolls = nominal_roll(geocenter, centers, t)
        for footprint in fov.footprint(mission.fov, centers, rolls):
            fov_patch = footprint.to_pixel(ax.wcs).as_artist(
                fill=True, facecolor="gray", edgecolor="black", alpha=0.25
            )
            artists.append(ax.add_patch(fov_patch))
        progress.update()

    FuncAnimation(fig, animate, times, interval=50).save("skygrid-overlap.mp4")

# Sky blocks

Partition the sky grid into disjoint groups of adjacent fields.

In [None]:
# Construct distance matrix
# Construct distance matrix: pairwise angular separations (radians) between fields.
distance_matrix = (
    mission.skygrid[:, np.newaxis]
    .separation(mission.skygrid[np.newaxis, :])
    .to_value(u.rad)
)

# Remove loops
np.fill_diagonal(distance_matrix, np.inf)

# Remove edges between non-adjacent vertices.
# NOTE(review): this assumes mission.skygrid is an icosahedral geodesic grid
# with subdivision indices (21, 4). Such a grid is taken to have 30 * t edges,
# with t = n^2 + n*m + m^2. Each edge appears twice in the symmetric matrix.
n = 21
m = 4
t = n * n + n * m + m * m
n_edges = 2 * 30 * t
distances = distance_matrix.ravel()
# Put the threshold halfway between the n_edges-th smallest distance (the
# longest edge) and the next one (the shortest non-edge). np.partition only
# guarantees sorted positions for the kth indices it is given, so both must
# be requested. Passing n_edges alone leaves index n_edges - 1 arbitrary.
smallest_distances = np.partition(distances, [n_edges - 1, n_edges])
split = np.mean(smallest_distances[n_edges - 1 : n_edges + 1])
distance_matrix[distance_matrix >= split] = np.inf

# Convert distance matrix to adjacency matrix
adjacency_matrix = np.isfinite(distance_matrix)

edge_weights = adjacency_matrix.astype(np.intp)
# # Construct edge weights that reward strips in right ascension.
# d_lon, d_lat = mission.skygrid[:, np.newaxis].spherical_offsets_to(
#     mission.skygrid[np.newaxis, :]
# )
# lon_weight = 1
# lat_weight = 1000
# power = 3
# edge_weights = np.ceil(
#     (lon_weight * np.abs(d_lon) ** power + lat_weight * np.abs(d_lat) ** power).value
# ).astype(int)

# Reward keeping each of the Magellanic Clouds in single, contiguous partitions.
# Every edge touching a field inside either cloud gets a very large weight, so
# the partitioner is strongly discouraged from cutting it.
lmc_region = CircleSkyRegion(SkyCoord.from_name("LMC"), radius=4 * u.deg)
smc_region = CircleSkyRegion(SkyCoord.from_name("SMC"), radius=2.5 * u.deg)
mc_regions = Regions([lmc_region, smc_region])
in_mc = np.logical_or.reduce(
    [
        mission.skygrid.separation(mc_region.center) <= mc_region.radius
        for mc_region in mc_regions.regions
    ]
)
edge_weights[in_mc, :] = edge_weights[:, in_mc] = 1_000_000

# Non-edges must carry zero weight.
edge_weights[~adjacency_matrix] = 0
edge_weights

In [None]:
graph = nx.from_numpy_array(adjacency_matrix)

max_num_per_partition = 20
n_partitions_min = 265
n_partitions_max = 300
seed_min = 1
seed_max = 10


def _partition_is_contiguous(graph, partition, n_partitions):
    """Return True if every part label in ``range(n_partitions)`` is non-empty and connected.

    Empty parts are rejected explicitly. networkx raises on connectivity
    queries for the null graph, and an unused label would also break the
    ``n_partitions == partition.max() + 1`` assertion in a later cell.
    """
    for label in range(n_partitions):
        members = np.flatnonzero(partition == label)
        if members.size == 0 or not nx.is_connected(graph.subgraph(members)):
            return False
    return True


# Scan (n_partitions, seed) pairs in order. Stop at the first partition whose
# parts are all contiguous and contain at most max_num_per_partition fields.
for n_partitions, seed in tqdm(
    itertools.product(
        range(n_partitions_min, n_partitions_max + 1), range(seed_min, seed_max + 1)
    ),
    # Both ranges include their upper bounds.
    total=(n_partitions_max - n_partitions_min + 1) * (seed_max - seed_min + 1),
):
    partition = partition_graph(edge_weights, n_partitions, seed=seed, recursive=False)

    if not _partition_is_contiguous(graph, partition, n_partitions):
        continue

    _, partition_counts = np.unique(partition, return_counts=True)
    if partition_counts.max() <= max_num_per_partition:
        break
else:
    raise RuntimeError("No suitable partition found")

n_partitions, partition_counts

Visualize the partitions.

In [None]:
assert n_partitions == partition.max() + 1

# Two partitions are adjacent if any field in one is adjacent to any field in
# the other. Scatter each field-level edge onto its pair of partition labels.
# This is O(#edges), rather than scanning all n_partitions^2 submatrices.
partition_labels = np.asarray(partition, dtype=np.intp)
field_u, field_v = np.nonzero(adjacency_matrix)
partition_adjacency_matrix = np.zeros((n_partitions, n_partitions), dtype=bool)
partition_adjacency_matrix[partition_labels[field_u], partition_labels[field_v]] = True
# Edges inside a single partition do not make it adjacent to itself.
np.fill_diagonal(partition_adjacency_matrix, False)

# Greedy coloring, so that adjacent partitions get different colors.
graph = nx.from_numpy_array(partition_adjacency_matrix)
coloring = nx.algorithms.coloring.greedy_color(
    graph, strategy="connected_sequential", interchange=True
)
# Colors indexed by partition label; the graph's nodes are 0..n_partitions-1.
partition_colors = np.asarray([coloring[label] for label in range(n_partitions)])
n_partition_colors = partition_colors.max() + 1
n_partition_colors

In [None]:
# Spinning globe view with each field colored by its partition's color.
field_colors = partition_colors[partition]
fig = plt.figure(dpi=300)
ra0 = np.linspace(0, 360, 500, endpoint=False) * u.deg

with tqdm(total=len(ra0)) as progress:

    def animate(center_ra):
        """Redraw the globe from scratch, centered at (``center_ra``, +30 deg)."""
        fig.clf()
        ax = plt.axes(projection="astro globe", center=SkyCoord(center_ra, 30 * u.deg))
        ax.scatter(
            mission.skygrid.ra.deg,
            mission.skygrid.dec.deg,
            c=field_colors,
            transform=ax.get_transform("world"),
            s=4,
            cmap="cool",
        )
        ax.coords["ra"].set_ticklabel(exclude_overlapping=True)
        ax.grid()
        progress.update()

    FuncAnimation(fig, animate, ra0, interval=50).save("skyblocks.mp4")

In [None]:
# South-polar globe of the survey blocks, with the circles used to keep each
# Magellanic Cloud within one partition overlaid. An explicit figure avoids
# drawing onto whatever figure happens to be current.
fig = plt.figure()
ax = fig.add_subplot(projection="astro globe", center="0deg -90deg")
ax.scatter(
    mission.skygrid.ra.deg,
    mission.skygrid.dec.deg,
    s=10,
    c=partition_colors[partition],
    cmap="cool",
    transform=ax.get_transform("world"),
)
for region in mc_regions.regions:
    ax.add_patch(region.to_pixel(ax.wcs).as_artist())
# The grid only needs to be enabled once, not once per region.
ax.grid()
fig.savefig("survey-blocks-magellanic-clouds.png")

Save the survey grid (field IDs, block IDs, and target coordinates) to a data file.

In [None]:
# One row per sky-grid field: its index, its survey block, and its pointing.
fields = QTable(
    [np.arange(len(mission.skygrid)), partition, mission.skygrid],
    names=["field_id", "block_id", "target_coord"],
)
fields.write("fields.ecsv", overwrite=True)

In [None]:
lmc_coord = SkyCoord.from_name("LMC")
smc_coord = SkyCoord.from_name("SMC")
# Center the panel on the direction of the vector sum of the two clouds.
center = SkyCoord(lmc_coord.cartesian + smc_coord.cartesian)
radius = 20 * u.deg
width, _ = plt.rcParams["figure.figsize"]
fig = plt.figure(figsize=(width, width), dpi=300)
ax = fig.add_subplot(projection="astro zoom", center=center, radius=radius)
for key in ["ra", "dec"]:
    coord = ax.coords[key]
    coord.set_auto_axislabel(False)
    coord.set_ticklabel_visible(False)
    coord.set_ticks_visible(False)

# Mellinger all-sky color backdrop, reprojected onto the panel one color
# layer at a time and then combined into an RGB image.
backdrop = mellinger()
backdrop_wcs = WCS(backdrop.header).dropaxis(-1)
backdrop_reprojected = np.stack(
    [
        reproject_adaptive((channel, backdrop_wcs), ax.header)[0]
        for channel in backdrop.data
    ]
)
backdrop_rgb = make_rgb(*backdrop_reprojected, interval=PercentileInterval(98))
ax.imshow(backdrop_rgb, origin="lower")

# Overlay inscribed-circle footprints of the fields near the panel.
nearby_fields = mission.skygrid[
    mission.skygrid.separation(center) <= np.sqrt(2) * radius
]
for region in fov.footprint(inscribed_fov, nearby_fields):
    ax.add_patch(region.to_pixel(ax.wcs).as_artist(color="white", alpha=0.25))

fig.savefig("magellanic-clouds.pdf")

## Survey Schedule

In [None]:
# Fixed durations of the two non-slew action types in the schedule.
downlink_duration = u.Quantity(30, u.min)  # per downlink
observe_duration = u.Quantity(900, u.s)  # per field observation

Get the times and pointing directions for downlinks.

In [None]:
def get_downlinks(time: Time) -> SkyCoord:
    """
    Get the target coordinates for a downlink segment.

    Get the target coordinates that point UVEX's high-gain antenna (HGA) at
    the Earth. This is accomplished by constructing a frame for pointing at
    the Earth while maintaining the nominal roll angle, and then pitching by
    135°.

    Parameters
    ----------
    time:
        Time of the downlink. May be an array of times, in which case one
        pointing is returned per time.

    Returns
    -------
    SkyCoord
        Spacecraft pointing for the downlink, expressed in the frame of the
        Earth's position as seen from the observer. Only the pointing is
        returned, not a roll angle.
    """
    observer_location = mission.observer_location(time)
    earth_coord = get_body("earth", time, observer_location)
    # Sky frame centered on the Earth, rotated by the nominal roll there.
    offset_frame = earth_coord.skyoffset_frame(
        nominal_roll(observer_location, earth_coord, time)
    )
    # (lon, lat) = (180°, 45°) in the offset frame lies 135° from its origin
    # (the Earth direction): cos(d) = cos(45°) * cos(180°).
    return SkyCoord(180 * u.deg, 45 * u.deg, frame=offset_frame).transform_to(
        earth_coord.frame
    )


# One downlink every 0.25 day over a 102-day span starting 2030-01-01.
downlink_cadence_days = 0.25
downlink_times = Time("2030-01-01") + np.arange(0, 102, downlink_cadence_days) * u.day
downlink_target_coords = get_downlinks(downlink_times)
# Keep only the direction of each pointing (no distance).
# NOTE(review): ra/dec come from the frame returned by get_downlinks and are
# reinterpreted here in SkyCoord's default frame -- confirm this is intended.
downlink_pointings = SkyCoord(
    downlink_target_coords.ra,
    downlink_target_coords.dec,
    representation_type=UnitSphericalRepresentation,
)
downlinks = QTable(
    {
        "start_time": downlink_times,
        "target_coord": downlink_pointings,
        "observer_location": mission.observer_location(downlink_times),
    }
)
downlinks["action"] = "downlink"
downlinks["duration"] = downlink_duration
downlinks


Determine the time blocks during which every field of each partition is observable.

In [None]:
# Each block is evaluated halfway between consecutive downlinks.
block_times = downlink_times[:-1] + 0.5 * downlink_cadence_days * u.day
block_observer_locations = mission.observer_location(block_times)
# Fields along axis 0 and blocks along axis 1, from the broadcast inputs.
observable = mission.constraints(
    block_observer_locations[np.newaxis, :],
    mission.skygrid[:, np.newaxis],
    block_times[np.newaxis, :],
)
# A partition is observable in a block only if all of its fields are.
partition_observable = np.asarray(
    [
        np.all(observable[partition == label, :], axis=0)
        for label in range(partition.max() + 1)
    ]
)
partition_observable

Calculate a weight for each partition in each time block. The weight is the total exposure time needed to reach 24.5 AB mag in the FUV band, summed over the partition's fields and clipped to at most one day.

In [None]:
# Per-field, per-block exposure time (seconds) for a flat-spectrum
# 24.5 AB mag source in FUV. The first argument of get_exptime is presumably
# the target signal-to-noise ratio.
with observing(
    block_observer_locations[np.newaxis, :],
    mission.skygrid[:, np.newaxis],
    block_times[np.newaxis, :],
):
    flat_source = synphot.SourceSpectrum(synphot.ConstFlux1D, amplitude=24.5 * u.ABmag)
    exptime = mission.detector.get_exptime(5, flat_source, "FUV").to_value(u.s)

# Total exposure time per partition and block.
partition_exptime = np.asarray(
    [
        exptime[partition == label, :].sum(axis=0)
        for label in range(partition.max() + 1)
    ]
)
# Cap at one day (86400 s).
partition_exptime = np.clip(partition_exptime, 0, 86400)
partition_exptime

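Find the minimum-weight full matching: assign each partition to a distinct time block in which it is observable, so that the total exposure time is minimized. This favors observing each partition at a low-background time.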

In [None]:
# Bipartite graph: nodes 0..P-1 are partitions and node P + k is time block k.
# An edge means the partition is fully observable in that block. Its weight is
# the partition's total exposure time in that block.
n_partition_nodes = partition_observable.shape[0]

# A partition with no observable block has no edges, which would otherwise
# surface as an obscure networkx failure. Report it clearly instead.
never_observable = np.flatnonzero(~partition_observable.any(axis=1))
if never_observable.size:
    raise RuntimeError(
        f"Partitions never fully observable in any block: {never_observable.tolist()}"
    )

graph = nx.Graph()
graph.add_weighted_edges_from(
    [
        (i, j + n_partition_nodes, partition_exptime[i, j])
        for i, j in zip(*np.nonzero(partition_observable))
    ]
)
matching = nx.bipartite.minimum_weight_full_matching(
    graph, np.arange(n_partition_nodes)
)
# The matching maps in both directions; keep only partition -> block entries.
i = np.asarray(list(matching.keys()))
j = np.asarray(list(matching.values())) - n_partition_nodes
keep = j >= 0
i = i[keep]
j = j[keep]
# Order the assignments by block (that is, chronologically).
sort = np.argsort(j)
matching_j = j[sort]
matching_i = i[sort]

In [None]:
# Observation targets: one row per field. Pointings are rebuilt from RA/Dec
# as bare directions, matching the downlink rows.
field_directions = SkyCoord(
    fields["target_coord"].ra,
    fields["target_coord"].dec,
    representation_type=UnitSphericalRepresentation,
)
targets = QTable(
    {
        "field_id": fields["field_id"],
        "block_id": fields["block_id"],
        "target_coord": field_directions,
    }
)
targets["action"] = "observe"
targets["duration"] = observe_duration
targets

In [None]:
def plan_block(downlink: QTable, targets: QTable) -> QTable:
    """
    Build a time-ordered plan for one downlink followed by its block of targets.

    The steps are:

    1. Stack the downlink row(s) on top of ``targets``.
    2. Give every row the nominal roll evaluated at the first row's observer
       location and start time.
    3. Choose the visiting order by solving a traveling-salesman problem on
       the matrix of slew times.
    4. Interleave slew actions between consecutive visits.
    5. Fill in start times by accumulating durations from the first row's
       start time, then recompute observer locations.

    Parameters
    ----------
    downlink:
        Downlink row to start from. The caller passes a one-row slice
        ``downlinks[j : j + 1]``. Its ``start_time``, ``observer_location``,
        ``target_coord`` and ``duration`` are used.
    targets:
        Observation targets for one partition. Their ``target_coord`` and
        ``duration`` are used.

    Returns
    -------
    QTable
        Visits and ``"slew"`` rows, alternating in execution order and ending
        with a slew. ``start_time`` and ``observer_location`` are filled in
        for every row.
    """
    # Row 0 is the downlink; the remaining rows are the targets.
    slew_targets = vstack((downlink, targets))
    # Nominal roll for every row, evaluated at the downlink's observer
    # location and start time.
    slew_targets["roll"] = nominal_roll(
        slew_targets["observer_location"][0],
        slew_targets["target_coord"],
        slew_targets["start_time"][0],
    )

    # Pairwise slew times between all rows, accounting for their rolls.
    slew_time = mission.slew.time(
        slew_targets["target_coord"][:, np.newaxis],
        slew_targets["target_coord"][np.newaxis, :],
        slew_targets["roll"][:, np.newaxis],
        slew_targets["roll"][np.newaxis, :],
    )
    # Visiting order minimizing total slew time; it must start at the downlink.
    seq, _ = solve_tsp(slew_time.to_value(u.s), verbose=False)
    assert seq[0] == 0

    # Slew durations between consecutive entries of the tour.
    # NOTE(review): seq appears to be a closed tour that returns to row 0, so
    # seq[:-1] visits each row once -- confirm against solve_tsp.
    slews = QTable({"duration": slew_time[seq[:-1], seq[1:]]})
    slews["action"] = "slew"

    # Interleave slews with observations: visit k gets sort key k and the slew
    # that leaves it gets k + 0.5.
    slew_targets = slew_targets[seq[:-1]]
    slew_targets["i"] = np.arange(len(slew_targets))
    slews["i"] = np.arange(len(slews)) + 0.5
    plan = vstack((slew_targets, slews))
    plan.sort("i")
    del plan["i"]

    # Fill in observation times, observer locations: each row starts when the
    # previous row ends.
    plan["start_time"][1:] = plan["start_time"][0] + np.cumsum(plan["duration"][:-1])
    plan["observer_location"] = mission.observer_location(plan["start_time"])

    # Done!
    return plan


# Plan each matched (partition, downlink) pair and concatenate the results.
# matching_j is sorted, so the blocks come out in chronological order.
block_plans = []
for partition_id, downlink_index in zip(tqdm(matching_i), matching_j):
    block_plans.append(
        plan_block(
            downlinks[downlink_index : downlink_index + 1],
            targets[partition == partition_id],
        )
    )
plan = vstack(block_plans)
plan

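Sanity check: no actions overlap in time. That is, the smallest gap between the end of one action and the start of the next should be non-negative.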

In [None]:
# Gap between the end of each action and the start of the next one. A
# negative minimum means two consecutive actions overlap.
end_time = plan["start_time"] + plan["duration"]
gaps = plan["start_time"][1:] - end_time[:-1]
gaps.min()

In [None]:
# Columns exported for the initial survey schedule, in file order.
survey_columns = [
    "start_time",
    "duration",
    "observer_location",
    "action",
    "target_coord",
    "roll",
    "field_id",
    "block_id",
]
survey_table = plan[survey_columns]
survey_table.write("initial-survey.ecsv", overwrite=True)

In [None]:
# Animate the survey. A field's marker grows once its observation has
# started, and is drawn fully opaque while it is observable or once it has
# started. Sun, Earth, and Moon markers track their positions.
duration = plan["start_time"][-1] + plan["duration"][-1] - plan["start_time"][0]

observations = plan[plan["action"] == "observe"]

fig = plt.figure(dpi=300)
ax = fig.add_subplot(projection="astro mollweide")

frames = plan["start_time"][0] + np.linspace(0, 1, 10_000) * duration
# Observer locations are needed both for the constraints and for the body
# positions. Evaluate them once for all frames instead of once per use.
frame_observer_locations = mission.observer_location(frames)

# Frames along axis 0 and observations along axis 1, from the broadcast inputs.
observable = mission.constraints(
    frame_observer_locations[:, np.newaxis],
    observations["target_coord"],
    frames[:, np.newaxis],
)

body_markers = [sun, earth, moon(-110)]
body_positions = [
    get_body(body, frames, frame_observer_locations)
    for body in ["sun", "earth", "moon"]
]

body_artists = [
    ax.plot(
        position[0].ra.deg,
        position[0].dec.deg,
        marker=marker,
        color="red",
        transform=ax.get_transform("world"),
    )[0]
    for marker, position in zip(body_markers, body_positions)
]

artist = ax.scatter(
    observations["target_coord"].ra.deg,
    observations["target_coord"].dec.deg,
    transform=ax.get_transform("world"),
    linewidths=0,
)

with tqdm(total=len(frames)) as progress:

    def animate(i):
        """Update field markers and body positions for frame index ``i``."""
        t = frames[i]
        # Observations that have already started by this frame.
        keep = observations["start_time"] < t
        artist.set_sizes(keep * 20 + 1)
        artist.set_alpha(0.6 * (observable[i] | keep) + 0.4)
        for body_artist, position in zip(body_artists, body_positions):
            body_artist.set_data([position[i].ra.deg], [position[i].dec.deg])
        progress.update()
        return [artist, *body_artists]

    FuncAnimation(
        fig,
        animate,
        np.arange(len(frames)),
        blit=True,
        interval=25,
    ).save("initial-survey.mp4")