In [None]:
from scipy.optimize import least_squares
from datetime import datetime
from scipy.interpolate import interp1d
from astropy.io import fits
from typing import Any, Tuple
import casatools
import numpy as np
import pandas as pd
import xarray as xr

In [None]:
# @title ATM model
def get_tau(
    f_min: float,
    f_max: float,
    f_step: float,
    pwv: float,
    **atm_params: Any,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute total (dry + wet) zenith opacities with the CASA ATM model.

    A single spectral window is centred on ``(f_min + f_max) / 2`` with width
    ``f_max - f_min + f_step`` and channel spacing ``f_step``. The opacity is
    evaluated for a user-set precipitable water vapor column.

    Args:
        f_min: Minimum frequency (in units of Hz).
        f_max: Maximum frequency (in units of Hz).
        f_step: Frequency step (in units of Hz).
        pwv: Precipitable water vapor (in units of mm).
        atm_params: Keyword arguments forwarded to ``initAtmProfile``.

    Returns:
        freq: Array of channel frequencies of the window (in units of Hz).
        tau: Array of zenith opacities at those frequencies.

    """
    atm_tool = casatools.atmosphere()
    quanta = casatools.quanta()

    # The width is padded by one step, so the window edges sit half a step
    # outside [f_min, f_max].
    centre = quanta.quantity((f_min + f_max) / 2, "Hz")
    width = quanta.quantity(f_max - f_min + f_step, "Hz")
    resolution = quanta.quantity(f_step, "Hz")
    water_column = quanta.quantity(pwv, "mm")

    atm_tool.initAtmProfile(**atm_params)
    atm_tool.initSpectralWindow(1, centre, width, resolution)
    atm_tool.setUserWH2O(water_column)

    window = atm_tool.getSpectralWindow()
    channel_freq = quanta.convert(window, "Hz")["value"]

    # Element [1] of each spectrum holds the per-channel opacities. The wet
    # one is wrapped in a quantity-like dict, hence the "value" lookup.
    dry_part = atm_tool.getDryOpacitySpec()[1]
    wet_part = atm_tool.getWetOpacitySpec()[1]["value"]
    total_tau = dry_part + wet_part

    return channel_freq, total_tau

In [None]:
# @title Main function

def _first_non_nan(values, default):
    """Return the first non-NaN element of ``values`` (flattened), else ``default``."""
    for value in np.ravel(values):
        if not np.isnan(value):
            return value
    return default


def _in_any_range(freq, ranges):
    """Return a boolean mask of the ``freq`` entries lying in at least one range.

    Each range is indexable as ``(fmin, fmax)``, and both bounds are inclusive.
    A ``None`` bound leaves that side open. ``ranges`` is never modified, so
    tuples are accepted and the caller's lists are left untouched.
    """
    freq = np.asarray(freq, dtype=float)
    mask = np.zeros(freq.shape, dtype=bool)
    for frange in ranges:
        lo = -np.inf if frange[0] is None else frange[0]
        hi = np.inf if frange[1] is None else frange[1]
        mask |= (freq >= lo) & (freq <= hi)
    return mask


def _filter_average(t_model, responses):
    """Average a model spectrum over each channel's filter response.

    ``responses`` is an ``(n, n)`` array. Row ``i`` is the response of channel
    ``i`` sampled on the same frequency grid as ``t_model`` (here, the channel
    centres). Channel ``i`` therefore gets
    ``sum_j R[i, j] * T[j] / sum_j R[i, j]``. A channel whose response sums to
    zero on the grid falls back to its unfiltered value ``T[i]``.
    """
    t_model = np.asarray(t_model, dtype=float)
    responses = np.asarray(responses, dtype=float)
    norm = responses.sum(axis=1)
    return np.divide(responses @ t_model, norm, out=t_model.copy(), where=norm != 0)


def ATM_fit(da, dbb, pwv0=None, dt=None, w=None, *ranges):
    """Fit the PWV of an ATM sky model to DEMS brightness temperatures.

    For every ``dt``-th time sample, the precipitable water vapor (PWV) is
    fitted with a robust (Huber) least-squares fit, constrained to PWV >= 0.
    The target is the model sky brightness ``T_atm * (1 - exp(-tau * secz))``,
    averaged over each KID's filter response, matched to the measured spectrum.

    Args:
        da: DEMS brightness temperatures (``xarray.DataArray``) with a ``time``
            dimension and the coordinates ``chan``, ``d2_mkid_frequency``
            (GHz, converted to Hz for the ATM model), ``pressure``,
            ``humidity``, ``temperature`` and ``secz``.
        dbb: Opened DBB FITS file (``astropy.io.fits.HDUList``) holding the
            ``KIDFILT`` table (``masterid``, ``Raw Toptica F``,
            ``Raw df resp.`` columns).
        pwv0: Initial PWV guess in mm (``None`` -> 2.0).
        dt: Fit every ``dt``-th time sample (``None`` -> 1, i.e. all samples).
        w: Residual weight applied to channels inside ``ranges``
            (``None`` -> 1).
        *ranges: ``(fmin, fmax)`` frequency ranges in GHz that receive weight
            ``w``. A ``None`` bound leaves that side open; only the first and
            last range may contain ``None``.

    Returns:
        A new DataArray holding the model brightness temperatures at the
        fitted time samples, with a ``PWV`` coordinate (mm) along ``time``.
        NaN inputs stay NaN. ``da`` itself is not modified.

    Raises:
        TypeError: if ``da``/``dbb``/numeric arguments have the wrong type, or
            ``dt`` is not a positive integer.
        ValueError: if a numeric argument is out of range, a range is
            malformed or outside the data band, fewer than two channels are
            present, or a DEMS channel has no filter response in ``dbb``.
    """
    import warnings

    number_types = (int, float, np.integer, np.floating)

    ## ERROR HANDLING ##

    if not isinstance(da, xr.DataArray):
        raise TypeError(
            f"Argument 'da' must be an xarray DataArray, got {type(da).__name__} instead."
        )
    if not isinstance(dbb, fits.HDUList):
        raise TypeError(
            f"Argument 'dbb' must be a .fits, got {type(dbb).__name__} instead."
        )

    # pwv0 and w are validated even when no ranges are given.
    for name, value in (("pwv0", pwv0), ("w", w)):
        if value is None:
            continue
        if not isinstance(value, number_types):
            raise TypeError(
                f"Argument '{name}' must be a number, got {type(value).__name__} instead."
            )
        if value < 0:
            raise ValueError(f"Argument '{name}' must be a positive number or null.")

    last = len(ranges) - 1
    for k, frange in enumerate(ranges):
        for name, value in (("fmin", frange[0]), ("fmax", frange[1])):
            if value is None:
                continue
            if not isinstance(value, number_types):
                raise TypeError(
                    f"Argument '{name}' must be a number, got {type(value).__name__} instead."
                )
            if value <= 0:
                raise ValueError(f"Argument '{name}' must be a strictly positive number.")
        # Only the first and last range may be open-ended.
        if 0 < k < last and (frange[0] is None or frange[1] is None):
            raise ValueError("Both fmin and fmax must be specified.")
        if frange[0] is not None and frange[1] is not None and frange[0] > frange[1]:
            raise ValueError("fmin must be less than or equal to fmax.")

    n_time = da.sizes["time"]
    if dt is not None:
        if not isinstance(dt, (int, np.integer)) or dt <= 0:
            raise TypeError(
                f"dt must be a positive integer, got {dt!r} ({type(dt).__name__}) instead."
            )
        if dt >= n_time:
            warnings.warn(
                "dt is greater than or equal to the time series length, "
                "which might not be intended."
            )

    ## CREATE ATM PARAMETERS ##

    pressure_value = _first_non_nan(da.pressure.values, 570)  # mbar
    humidity_value = _first_non_nan(da.humidity.values, 20)  # percent
    # NOTE(review): assumes the DEMS temperature is in deg C. If it is already
    # in K, this double-converts; confirm against the DEMS specification.
    temperature_value = _first_non_nan(da.temperature.values, 0) + 273.15
    airmass = _first_non_nan(da.secz.values, 1)
    t_atm = temperature_value

    atm_params = {
        "atmType": 1,
        "humidity": float(humidity_value),  # previously ignored (hardcoded 20)
        "temperature": f"{temperature_value:.0f} K",
        "altitude": "4800 m",
        "pressure": f"{pressure_value:.0f} mbar",
        "h0": "2.0 km",
    }

    ## LOAD DATA ##

    kidfilt = dbb["KIDFILT"]
    masterid = np.asarray(kidfilt.data["masterid"])
    chan = np.asarray(da.chan.values)
    missing = ~np.isin(chan, masterid)
    if missing.any():
        raise ValueError(
            f"No filter response in dbb for channel(s) {chan[missing].tolist()}."
        )
    # KIDFILT may list more channels, in a different order: pick the row of
    # each DEMS channel so that responses line up with da's channel axis.
    order = np.argsort(masterid, kind="stable")
    rows = order[np.searchsorted(masterid, chan, sorter=order)]
    nu = kidfilt.data["Raw Toptica F"][rows]
    resp = kidfilt.data["Raw df resp."][rows]

    # Work in ascending frequency; unsorted_indices restores da's order.
    freq_input = np.asarray(da.d2_mkid_frequency.values, dtype=float)
    sorted_indices = np.argsort(freq_input)
    unsorted_indices = np.argsort(sorted_indices)
    freq = freq_input[sorted_indices]
    nu = nu[sorted_indices]
    resp = resp[sorted_indices]

    if freq.size < 2:
        raise ValueError("At least two channels are required to fit the ATM model.")

    for frange in ranges:  # given ranges must lie inside the data band
        if frange[0] is not None and frange[0] < freq[0]:
            raise ValueError("fmin must be within the frequency range.")
        if frange[1] is not None and frange[1] > freq[-1]:
            raise ValueError("fmax must be within the frequency range.")

    # ATM grid spanning the band, with a step as close to the data as possible.
    f_lo, f_hi = freq[0], freq[-1]
    f_step = (f_hi - f_lo) / (freq.size - 1)

    init_pwv = 2.0 if pwv0 is None else pwv0
    time_step = 1 if dt is None else dt
    range_weight = 1 if w is None else w

    ## FITTING FUNCTIONS ##

    def model_brightness(pwv, freq_valid, responses):
        """Filter-averaged model sky brightness at the valid channels."""
        # least_squares passes PWV as a length-1 array; ATM expects a scalar.
        pwv_mm = float(np.ravel(pwv)[0])
        grid_hz, tau_grid = get_tau(
            f_lo * 1e9, f_hi * 1e9, f_step * 1e9, pwv_mm, **atm_params
        )
        # Interpolate from the full model grid. Its length may differ from
        # the channel count, so it must not be indexed with channel masks.
        tau = interp1d(
            grid_hz, tau_grid, kind="linear", bounds_error=False, fill_value="extrapolate"
        )(freq_valid * 1e9)
        t_sky = t_atm * (1 - np.exp(-tau * airmass))
        return _filter_average(t_sky, responses)

    def fit_one(t):
        """Fit one time sample; return (model spectrum in da order, PWV)."""
        tb = np.asarray(da.isel(time=t).values, dtype=float)[sorted_indices]
        valid = ~np.isnan(tb)
        tb_model = np.full(tb.shape, np.nan)  # keeps NaNs where the input has them
        if not valid.any():
            return tb_model[unsorted_indices], np.nan

        tb_valid = tb[valid]
        freq_valid = freq[valid]
        # Sample each valid channel's measured response on the channel-centre
        # grid. Outside the measured sweep, the response is taken as zero.
        responses = np.array([
            interp1d(nu_i, r_i, kind="linear", bounds_error=False, fill_value=0.0)(freq_valid)
            for nu_i, r_i in zip(nu[valid], resp[valid])
        ])

        weights = np.ones_like(tb_valid)
        weights[_in_any_range(freq_valid, ranges)] = range_weight

        def residuals(x):
            return (tb_valid - model_brightness(x, freq_valid, responses)) * weights

        # PWV is physically non-negative.
        result = least_squares(residuals, init_pwv, loss="huber", bounds=(0, np.inf))
        pwv = result.x[0]
        tb_model[valid] = model_brightness(pwv, freq_valid, responses)
        return tb_model[unsorted_indices], pwv

    ## LOOP OVER TIME ##

    fitted_t = np.arange(0, n_time, time_step)
    results = [fit_one(t) for t in fitted_t]
    tb_models = np.array([tb for tb, _ in results], dtype=float).reshape(
        fitted_t.size, freq.size
    )
    pwv_fit = np.array([p for _, p in results], dtype=float)

    ## CREATE THE OUTPUT DATA ARRAY ##

    # Work on a subset copy so that da is left untouched and shapes match dt.
    new_da = da.isel(time=fitted_t).transpose("time", ...).copy(data=tb_models)
    return new_da.assign_coords(
        PWV=(
            "time",
            pwv_fit,
            {
                "long_name": "Precipitable Water Vapor derived from ATM model",
                "units": "mm",
            },
        )
    )