In [None]:
%config InlineBackend.figure_format='retina'

from functools import partial
from json import load
from multiprocessing import Pool
from pathlib import Path

import logging
logging.basicConfig(level=logging.WARNING)

import matplotlib.pyplot as plt
import hvplot.xarray
import xarray as xr

from process_cris import find_moon_intrusions, cris_wavenumbers

N_CPUS = 16

In [None]:
def get_scans_with_maximum(df, fieldofregard=0):
    """Collect the spectra located at each stored maximum position.

    Every entry of ``df["rad_for<fieldofregard>_maxind"]`` carries two index
    values (read via ``.values[0]`` and ``.values[1]``); they are used to
    build the name ``Rad_<fieldofregard>_<i>_<j>`` of the spectrum to fetch
    from ``df``.

    Parameters
    ----------
    df : mapping-like
        Result object holding the max-index variable and the ``Rad_*``
        spectra (presumably an xarray Dataset returned by
        ``find_moon_intrusions`` — TODO confirm).
    fieldofregard : int, default 0
        Field-of-regard number used in both variable names.

    Returns
    -------
    list
        One ``df[...]`` entry per max-index entry, in iteration order.
    """
    spectra = []
    for position in df[f"rad_for{fieldofregard}_maxind"]:
        idx = position.values
        spectra.append(df[f"Rad_{fieldofregard}_{idx[0]}_{idx[1]}"])
    return spectra

# Load CrIS file list

In [None]:
# JSON text is UTF-8 by specification (RFC 8259); decode it explicitly as such
# instead of relying on the platform's locale-dependent default encoding.
with open('cris_picks.json', 'r', encoding='utf-8') as infile:
    # A list-like collection of pick identifiers; each one names a directory
    # under ../SDR_data_npp in the next cell. TODO confirm the exact schema.
    filelist = load(infile)

# Find moon intrusions

In [None]:
# Fix the detection parameters once so the worker callable takes only a path.
# NOTE(review): wavelen_id=99 and threshold=20 are passed straight through to
# find_moon_intrusions; their meaning is defined there — TODO document them.
detect_intrusions = partial(
    find_moon_intrusions, wavelen_id=99, threshold=20
)

# Each pick has its own directory holding one lunar FSR spectra file.
nc_paths = [
    f"../SDR_data_npp/{pick}/Lunar_fsr_spectra_{pick}_npp.nc"
    for pick in filelist
]

# Process the files in parallel; results[i] belongs to filelist[i].
with Pool(processes=N_CPUS) as workers:
    results = workers.map(detect_intrusions, nc_paths)

#%%
# Variables a result must contain to be usable: the max-index positions for
# both fields of regard.
REQUIRED_MAXIND_VARS = ("rad_for0_maxind", "rad_for1_maxind")

# Pool.map returns one result per input, in input order, so results and
# filelist can be zipped together.
assert len(results) == len(filelist), "results do not line up with filelist"

# Sort every file into exactly one bucket. Previously, results lacking the
# max-index variables were dropped without being reported.
valid_results = []
no_results = []
incomplete_results = []
for name, res in zip(filelist, results):
    if res is None:
        no_results.append(name)
    elif all(var in res for var in REQUIRED_MAXIND_VARS):
        valid_results.append(res)
    else:
        incomplete_results.append(name)

print(f"Files without results: {len(no_results)} out of {len(filelist)}")
print(f"Files missing maximum-index variables: "
      f"{len(incomplete_results)} out of {len(filelist)}")

#%%
# Gather the spectra at the maximum positions, first field of regard 0 for
# every file, then field of regard 1 for every file. A single flat
# comprehension avoids the quadratic copying of `sum(list_of_lists, [])`.
max_scans = [
    scan
    for fieldofregard in (0, 1)
    for df in valid_results
    for scan in get_scans_with_maximum(df, fieldofregard=fieldofregard)
]
if not max_scans:
    # xr.concat would fail with an opaque message on an empty list.
    raise ValueError(
        "No scans fulfill the selection criteria; nothing to combine.")

combined_scans = xr.concat(max_scans, dim="scans").rename("combined_scans")
combined_scans.attrs["description"] = (
    "Combination of all scans that fulfill selection criteria")

# Relabel the channel axis with wavelength in µm. This assumes
# cris_wavenumbers() returns wavenumbers in cm⁻¹, so that 10000 / wavenumber
# gives µm. TODO confirm in process_cris.
wavelength_um = 10000 / cris_wavenumbers()
combined_scans = combined_scans.assign_coords(
    n_Channels=("n_Channels", wavelength_um, {"units": "µm"}))

# Plot mean spectra

In [None]:
# Wavelength window (µm) shown in the static figure. The bounds run from high
# to low, as in the original slice; presumably the wavelength coordinate
# decreases along n_Channels, since it is 10000 / wavenumber.
PLOT_WINDOW_UM = slice(9.0, 7.0)
# Set to True to overlay every channel as a point, which makes gaps in the
# frequency grid visible.
SHOW_GRID_POINTS = False
# Set to True to also write the combined scans to disk.
SAVE_NETCDF = False

# Compute the mean once and reuse it for both plot layers.
mean_spectrum = combined_scans.mean(dim="scans")

fig, ax = plt.subplots(figsize=(12, 8))
mean_spectrum.sel(n_Channels=PLOT_WINDOW_UM).plot(ax=ax)
if SHOW_GRID_POINTS:
    mean_spectrum.sel(n_Channels=slice(10.0, 7.0)).plot(
        ax=ax, marker='.', markersize=0.75, ls='None')
ax.set_title("Mean of all selected scans")
ax.set_xlabel("Wavelength [µm]")

fig.savefig("mean_scans.pdf")
if SAVE_NETCDF:
    combined_scans.to_netcdf("combined_scans.nc")

In [None]:
# Interactive counterpart of the static figure: the mean over all selected
# scans across the full wavelength range. The `.hvplot` accessor relies on
# `import hvplot.xarray` in the first cell.
mean_over_scans = combined_scans.mean(dim="scans")
mean_over_scans.hvplot(width=1000, height=500)