# **Validation Suite**

## Purpose and Setup
This notebook lets you test the output from WRF-GHG against surface, upper-air, and satellite observations.

The next cell imports the modules the notebook needs.

In [None]:
import os
import wrf
import numpy as np
import pandas as pd
from datetime import datetime as dt
from netCDF4 import Dataset as Ds #type: ignore
from termcolor import cprint
import pytz
from numpy import unravel_index
import collections.abc as c
import numpy.typing as npt
import metpy as mp
from metpy.units import units
from metpy.calc import wind_components
from xarray import DataArray, Dataset
import xarray as xr
import netCDF4
from enum import IntFlag, auto
import matplotlib.pyplot as plt
# Register the concentration units used for the greenhouse-gas comparisons further down.
# ppb is defined as the dimensionless factor 1e-9.
units.define('ppb = 1e-9')
# NOTE(review): pint alias directive; presumably makes `ppmv` another name for an existing `ppm` unit - confirm `ppm` is already in the registry.
units.define('@alias ppm = ppmv')

### *Bit-flags for Validation failure*

This class is used for setting a bit-flag to communicate which variable didn't validate.

In [None]:
class Bitflag(IntFlag):
    '''
    Bitflag: one bit per validated variable, so that several failures can be OR-ed into a single value.
    '''
    # Upper-air variables
    PRES = 1 << 0
    TEMP = 1 << 1
    DEWP = 1 << 2
    U = 1 << 3
    V = 1 << 4
    # Surface variables
    T2 = 1 << 5
    TD2 = 1 << 6
    SLP = 1 << 7
    U10 = 1 << 8
    V10 = 1 << 9
    # Trace-gas variables
    XCO2 = 1 << 10
    XCO = 1 << 11
    XCH4 = 1 << 12

### *Base class*

This class is used for all the following classes in order to give some basic location data to each class.

In [None]:
class Base_point:
    '''
    Base_point: parent class for validation. Stores the location name and provides the error statistics shared by every point class.
    '''
    def __init__(self, loc: str, **kwargs) -> None:
        # Extra keyword arguments are accepted and ignored.
        self.loc: str = loc
    @staticmethod
    def _rmse(x:npt.ArrayLike, y:npt.ArrayLike) -> float:
        '''
        Root-mean-square error: square root of the mean of (x - y) squared.
        '''
        return np.mean((x - y)**2)**(1/2)
    @staticmethod
    def _mean_bias(x:npt.ArrayLike, y:npt.ArrayLike) -> float:
        '''
        Mean bias: mean of (x - y).
        '''
        difference = x - y
        return np.mean(difference)

### *Upper air (UA) class*

This class is used to set the attributes for any data objects that care about upper air data. This is usually the WRF data that you're testing or the radiosonde data you use for validation. It inherits from the base class.

In [None]:
class UA_point(Base_point):
    '''
    UA_point: parent class for objects that carry upper-air data (p, t, td, wdir, wspd). Inherits from Base_point.
    '''
    def __init__(self, loc: str, **kwargs) -> None:
        super().__init__(loc, **kwargs)
    def __eq__(self, other: object) -> bool:
        '''
        Validate the upper-air profile of self against other.

        Resets the module-level ``flags`` to 0, then ORs in the Bitflag member of every check that failed.
        Returns True only when all five checks pass.

        When an attribute is missing and self is a WRF_point, the comparison is handed to
        ``other.__eq__(self)``; that result, and whatever flags that call set, are returned untouched.

        Raises NotImplementedError when an attribute is missing and self is not a WRF_point.
        '''
        global flags
        flags = 0
        # Flag reported for each slot of ``results``.
        # NOTE(review): wdir/wspd are reported through the U/V bits because Bitflag has no direction/speed members.
        fail_bits = (Bitflag.PRES, Bitflag.TEMP, Bitflag.DEWP, Bitflag.U, Bitflag.V)
        try:
            results: npt.NDArray[np.bool_] = np.empty(5, np.bool_)
            # NOTE(review): no limit is applied yet, so each slot only stores the truthiness of the RMSE.
            results[0] = self._rmse(self.p, other.p) #! add validation limit here
            results[1] = self._rmse(self.t, other.t) #! add validation limit here
            results[2] = self._rmse(self.td, other.td) #! add validation limit here
            results[3] = self._rmse(self.wdir, other.wdir) #! add validation limit here
            results[4] = self._rmse(self.wspd, other.wspd) #! add validation limit here
        except AttributeError as err:
            # Return/raise here directly: a ``return`` in ``finally`` would swallow the exception
            # and the partially filled ``results`` must not be turned into flags.
            if isinstance(self, WRF_point):
                return other.__eq__(self)
            raise NotImplementedError('Compairison not implemented') from err
        for bit, passed in zip(fail_bits, results):
            if not passed:
                flags = flags | bit
        return bool(results.all())

### *Satellite (Sat) class*

This class is used to set the attributes for any data objects that care about satellite data. This is usually the WRF data or TROPOMI or OCO-2 data. It inherits from the base class. **Note**: The respective classes for TROPOMI and OCO-2 are still in development (9-9-2024)

In [None]:
class Sat_point(Base_point):
    '''
    Sat_point: parent class for objects that carry satellite column data. Inherits from Base_point.
    '''
    def __init__(self, loc: str, **kwargs) -> None:
        super().__init__(loc, **kwargs)
    def sat_loc(self, ulat: float, ulon: float, lats: c.Iterable[float], lons: c.Iterable[float]) -> tuple[int, int] | int:
        '''
        Find the pixel closest to (ulat, ulon) using the haversine great-circle distance.

        lats/lons are the pixel coordinates in degrees. Returns a single index for 1-D
        coordinates, otherwise the (row, column) pair of the closest pixel.
        '''
        R: int = 6371000 # sphere radius used for the distance, in metres
        # asanyarray lets plain lists work while keeping array subclasses (e.g. masked arrays) intact.
        lats = np.asanyarray(lats)
        lons = np.asanyarray(lons)
        lat1 = np.radians(ulat)
        lat2 = np.radians(lats)
        delta_lat = np.radians(lats-ulat)
        delta_lon = np.radians(lons-ulon)
        a = np.sin(delta_lat/2)**2 + np.cos(lat1)*np.cos(lat2)*np.sin(delta_lon/2)**2
        # Named ``arc`` rather than ``c`` so the collections.abc alias is not shadowed.
        arc = 2*np.arctan2(np.sqrt(a),np.sqrt(1-a))
        d = R*arc
        if d.ndim == 1:
            return d.argmin()
        x, y = unravel_index(d.argmin(),d.shape)
        return x,y
    def __eq__(self, other: object) -> bool:
        '''
        Validate the satellite columns of self against other (TROPOMI XCH4 and XCO only).

        Resets the module-level ``flags`` to 0, then ORs in Bitflag.XCH4 / Bitflag.XCO for each
        failed check. Returns True only when both checks pass.

        When the pair is not a TROPOMI comparison, or an attribute is missing, a WRF_point hands
        the comparison to ``other.__eq__(self)`` and returns that result with its flags.

        Raises NotImplementedError in those same situations when self is not a WRF_point.
        '''
        global flags
        flags = 0
        tropomi_pair = isinstance(self, Tropomi_point) or (isinstance(self, WRF_point) and isinstance(other, Tropomi_point))
        if not tropomi_pair:
            #! OCO-2 comparison goes here (a failed XCO2 check should set Bitflag.XCO2)
            if isinstance(self, WRF_point):
                return other.__eq__(self)
            raise NotImplementedError('Compairison not implemented')
        try:
            results: npt.NDArray[np.bool_] = np.empty(2, np.bool_)
            results[0] = self._rmse(self.xch4, other.xch4) <= 1.e-1
            results[1] = self._rmse(self.xco, other.xco) <= 1.e-1
        except AttributeError as err:
            # Leave immediately so the incomplete ``results`` is never converted into flags.
            if isinstance(self, WRF_point):
                return other.__eq__(self)
            raise NotImplementedError('Compairison not implemented') from err
        for bit, passed in zip((Bitflag.XCH4, Bitflag.XCO), results):
            if not passed:
                flags = flags | bit
        return bool(results.all())

### *Surface class*

This class is used to set the attributes for any data object that cares about surface data. This is usually the WRF data or the ASOS observation data. It inherits from the base class.

In [None]:
class Surface_point(Base_point):
    '''
    Surface_point: parent class for objects that carry surface data. Inherits from Base_point.
    '''
    def __init__(self, loc: str, **kwargs) -> None:
        super().__init__(loc, **kwargs)
    def __eq__(self, other: object) -> bool:
        '''
        Validate the surface data of an observation (self) against other.

        Only an Obs_point drives the comparison: with ``met`` set, T2, td2, slp, u10 and v10 are
        checked; otherwise, with ``chem`` set, ch4 and co2 are checked against the first column of
        other's arrays. The module-level ``flags`` is reset to 0 and the Bitflag member of every
        failed check is OR-ed in. Returns True only when every check passes.

        For any other self, NotImplemented is returned so that ``a == b`` falls back to
        ``b.__eq__(a)`` (e.g. ``wrf == obs`` is evaluated by the observation).

        Raises NotImplementedError when a required attribute is missing.
        '''
        global flags
        if isinstance(self, Obs_point) and self.met:
            flags = 0
            try:
                results: npt.NDArray[np.bool_] = np.empty(5, np.bool_)
                results[0] = self._rmse(self.T2, other.T2) <= 5. * units.kelvin
                results[1] = self._rmse(self.td2, other.td2) <= 5. * units.kelvin
                results[2] = self._rmse(self.slp, other.slp) <= 2.0 * units.hPa
                results[3] = self._rmse(self.u10, other.u10) <= 2.24 * units('m/s')
                results[4] = self._rmse(self.v10, other.v10) <= 2.24 * units('m/s')
            except AttributeError as err:
                # Leave immediately: the partially filled ``results`` must not become flags.
                if isinstance(self, WRF_point):
                    return other.__eq__(self)
                raise NotImplementedError('Comparison not implemented') from err
            met_bits = (Bitflag.T2, Bitflag.TD2, Bitflag.SLP, Bitflag.U10, Bitflag.V10)
            for bit, passed in zip(met_bits, results):
                if not passed:
                    flags = flags | bit
            return bool(results.all())
        if isinstance(self, Obs_point) and self.chem:
            flags = 0
            try:
                results = np.empty(2, np.bool_)
                results[0] = self._rmse(self.ch4, other.ch4[:,0]) <= 10. * units.ppb
                results[1] = self._rmse(self.co2, other.co2[:,0]) <= 10. * units.ppm #! Check values!!!
                #results[2] = self._rmse(self.xco, other.xco) <= 100. * units.ppb
            except AttributeError as err:
                if isinstance(self, WRF_point):
                    return other.__eq__(self)
                raise NotImplementedError('Comparison not implemented') from err
            # A failed XCO check would map to Bitflag.XCO once it is enabled above.
            for bit, passed in zip((Bitflag.XCH4, Bitflag.XCO2), results):
                if not passed:
                    flags = flags | bit
            return bool(results.all())
        # Not an observation with met/chem data: let Python try the reflected comparison.
        return NotImplemented


### *WRF class*

This class ingests the wrfout data that you need to run the validation suite. Since it's what we're testing, it inherits from UA, Sat, and Surface classes.

In [None]:
class WRF_point(Surface_point,Sat_point,UA_point):
    '''
    WRF_point: sets up validation point using WRF data. Reads T2, TD2, SLP, 10 m wind, pressure, temperature, dew point, wind and surface pressure at the grid indices that wrf.ll_to_xy returns for (lat, lon). With chem=True it also reads the CH4/CO2 tracers and their column averages. Inherits from Surface_point, Sat_point and UA_point.
    '''
    def __init__(self, wrffile: Dataset, lat: float, lon: float, loc: str, chem: bool | None = None, **kwargs) -> None:
        # NOTE(review): ``wrffile`` is annotated with the xarray Dataset name, but it is handed to wrf-python
        # and, in _memory_loop, sliced and measured with len() - so a list of netCDF4 datasets is expected there.
        super().__init__(loc,**kwargs)
        wrf.omp_set_num_threads(wrf.omp_get_num_procs())
        self.lat: float = lat
        self.lon: float = lon
        self.x: float | npt.NDArray | DataArray
        self.y: float | npt.NDArray | DataArray
        self.x, self.y = wrf.ll_to_xy(wrffile,self.lat,self.lon)
        self.vars: list[str] = ['T2', 'td2', 'slp','uvmet10','p','temp','td','uvmet']
        self.T2: DataArray = self._load(wrffile, 'T2')
        self.td2: DataArray = self._load(wrffile, 'td2', 'K')
        self.slp: DataArray = self._load(wrffile, 'slp', 'hPa')
        self.u10, self.v10 = self._load(wrffile, 'uvmet10') #type: tuple[DataArray, DataArray]
        self.p: DataArray = self._load(wrffile, 'p', 'hPa')
        self.t: DataArray = self._load(wrffile, 'temp')
        self.td: DataArray = self._load(wrffile, 'td', 'K')
        self.u, self.v = self._load(wrffile, 'uvmet') #type: tuple[DataArray, DataArray]
        ## ! Next block needs to be edited depending on WRF-GHG output structure (include converting units) ! ##
        self.sfc_pres: DataArray = self._load(wrffile, 'PSFC') #! may need to calculate or may be in WRF output
        if chem:
            # Hand over the fields already read so _extract_ghg does not rebuild a second WRF_point.
            column_inputs = dict(lat=lat, lon=lon, loc=loc, p=self.p, x=self.x, y=self.y, sfc_pres=self.sfc_pres)
            self.xch4: npt.ArrayLike = self._extract_ghg(wrffile, 'xch4', **column_inputs)
           # self.xco: npt.ArrayLike = self._extract_ghg(wrffile, 'xco')
            self.xco2: npt.ArrayLike = self._extract_ghg(wrffile, 'xco2', **column_inputs)
            self.ch4 = self._tracer_anomaly(wrffile, 'CH4', self.x, self.y)
            self.co2 = self._tracer_anomaly(wrffile, 'CO2', self.x, self.y)
    def _load(self, wrffile, var: str, unit: str | None = None):
        '''
        Read ``var`` for all times at (self.y, self.x); on MemoryError, retry in chunks via _memory_loop.

        ``unit`` is forwarded to wrf.getvar as ``units`` when given. For 'uvmet10'/'uvmet' the
        result unpacks into the two wind components.
        '''
        unit_kw = {} if unit is None else {'units': unit}
        try:
            return wrf.getvar(wrffile, var, wrf.ALL_TIMES, **unit_kw)[..., self.y, self.x].metpy.quantify()
        except MemoryError:
            return self._memory_loop(wrffile, var, x=self.x, y=self.y, units=unit)
    @staticmethod
    def _tracer_anomaly(wrffile, species: str, x, y):
        '''
        Return <species>_ANT + <species>_BIO - <species>_BCK at grid point (y, x) for all times.
        '''
        ant, bio, bck = (
            wrf.getvar(wrffile, f'{species}_{part}', wrf.ALL_TIMES)[..., y, x].metpy.quantify()
            for part in ('ANT', 'BIO', 'BCK')
        )
        return ant + bio - bck
    @staticmethod
    def _extract_ghg(wrffile: Dataset, chem: str,*,lat=None,lon=None,loc=None,p=None,x=None,y=None,sfc_pres=None) -> npt.ArrayLike:
        '''
        Pressure-weighted column average of a tracer ('xch4' or 'xco2').

        p, x, y and sfc_pres are the pressure profile, grid indices and surface pressure to use.
        If any of them is missing they are read from a temporary WRF_point built from
        lat, lon and loc, which must then all be given.

        Raises ValueError for an unsupported ``chem`` or when neither set of inputs is complete.
        '''
        if p is None or x is None or y is None or sfc_pres is None:
            # ``is None`` rather than truthiness, so a latitude/longitude of 0 is accepted.
            if lat is None or lon is None or loc is None:
                raise ValueError('_extract_ghg needs either p, x, y and sfc_pres or lat, lon and loc')
            data = WRF_point(wrffile,lat,lon,loc)
            p = data.p
            x = data.x
            y = data.y
            sfc_pres = data.sfc_pres
            del data
        if chem == 'xch4':
            species = 'CH4'
        elif chem == 'xco2':
            species = 'CO2'
        #elif chem == 'xco': CO has no BIO tracer; a zero array shaped like CO_BCK was planned instead
        else:
            raise ValueError(f'Unsupported column gas: {chem!r}')
        _ghg = WRF_point._tracer_anomaly(wrffile, species, x, y)
        if ((len(_ghg) == len(p)) and (_ghg.ndim == 1)):
            # Single profile: layer boundaries are built upward from the surface pressure.
            # NOTE(review): p_layer_diff below has one element fewer than _ghg - confirm before relying on this branch.
            pres_bound: DataArray = xr.DataArray(np.empty_like(p), p.coords, p.dims)
            for i, pres in enumerate(p):
                if i == 0:
                    pres_bound[i] = sfc_pres
                    pres_bound[i+1] = pres_bound[i] + (2*(pres-pres_bound[i]))
                else:
                    try:
                        pres_bound[i+1] = pres_bound[i] + (2*(pres-pres_bound[i]))
                    except IndexError:
                        pass
            p_layer_diff: npt.NDArray = np.array([pres_bound[i]-pres_bound[i-1] for i in range(1,len(pres_bound))]) #! This may need a value at beginning for xch4[0]
            p_diff: float = pres_bound[0] - pres_bound[-1]
            return (np.sum(_ghg*p_layer_diff)/p_diff)
        elif ((len(_ghg[0,:]) == len(p[0,:])) and (_ghg.ndim == 2)):
            # One profile per row: same boundary construction, row by row.
            pres_bound: DataArray = xr.DataArray(np.empty_like(p), p.coords, p.dims) *units.hPa
            pres_bound[:,0] = sfc_pres
            for j in range(p.shape[0]):
                for i, pres in enumerate(p[j,:]):
                    try:
                        pres_bound[j,i+1] = pres_bound[j,i] + (2*(pres-pres_bound[j,i]))
                    except IndexError:
                        pass
            p_layer_diff = xr.DataArray([pres_bound[:,i-1]-pres_bound[:,i] for i in range(1,pres_bound.shape[1])],coords=p.T.coords) * units.hPa
            p_layer_diff = np.insert(p_layer_diff,[0], sfc_pres-p[:,0],axis=0).T * units.hPa#! This may need a value at beginning for xch4[0]
            p_diff = pres_bound[:,0] - pres_bound[:,-1]
            return (np.sum(_ghg*p_layer_diff,axis=1)/p_diff)
    @staticmethod
    def _memory_loop(wrffile, var, x=None, y=None, units=None):
        '''
        Read ``var`` at (y, x) chunk by chunk and concatenate the pieces along 'Time'.

        ``wrffile`` must support len() and slicing. Its length has to be a multiple of 24,
        or a multiple of 24 plus one; in the latter case the leftover files form a final chunk.
        ``units`` is forwarded to wrf.getvar when given. Returns one array, or a (u, v) pair
        for 'uvmet10'/'uvmet'.

        Raises ValueError when x or y is missing, RuntimeError for any other length.
        '''
        if x is None or y is None:
            # This is a static method, so there is no instance to take the indices from.
            raise ValueError('_memory_loop needs the grid indices x and y')
        length = len(wrffile)
        if length % 24 == 0:
            base = length
        elif (length - 1) % 24 == 0:
            base = length - 1
        else:
            raise RuntimeError(f'loop length not defined ({length = })')
        loops = base // 24
        if loops == 1:
            loops = 3
        elif loops == 24:
            loops = 48
        loop_seg = base // loops if loops else 0
        chunks = [wrffile[n*loop_seg:(n+1)*loop_seg] for n in range(loops)]
        if base != length:
            chunks.append(wrffile[loops*loop_seg:])
        if not chunks:
            raise RuntimeError(f'loop length not defined ({length = })')
        unit_kw = {} if units is None else {'units': units}
        pieces = [wrf.getvar(chunk, var, wrf.ALL_TIMES, **unit_kw)[..., y, x].metpy.quantify() for chunk in chunks]
        if var not in ['uvmet10','uvmet']:
            return xr.concat(pieces, dim='Time')
        u_array = []
        v_array = []
        for piece in pieces:
            u_temp, v_temp = piece
            u_array.append(u_temp)
            v_array.append(v_temp)
        return xr.concat(u_array, dim='Time'), xr.concat(v_array, dim='Time')
    def __str__(self) -> str:
        # Speed and direction are not stored on the instance, so derive them from the 10 m components.
        wspd10 = mp.calc.wind_speed(self.u10, self.v10)
        wdir10 = mp.calc.wind_direction(self.u10, self.v10)
        return f'{self.loc} WRF Point has a temperature of {self.T2} K, a dewpoint of {self.td2} K, a slp of {self.slp} hPa, and the wind is {wspd10} m s^-1 at {wdir10} degrees.'

### *Obs_point Class*

This class ingests surface observation data from ASOS stations to test against WRF data. This inherits from Surface point.

In [None]:
class Obs_point(Surface_point):
    '''
    Obs_point: sets up validation point for observation. With met=True reads T2, TD2, slp and the 10 m wind components from an ASOS csv; with chem=True reads CH4 and CO2 from a chemistry csv. Inherits from Surface_point.
    '''
    def __init__(self, loc: str, obsfile: str, obstime: dt | list[dt], *, met=True, chem=False, **kwargs) -> None:
        super().__init__(loc, **kwargs)
        self.met = met
        self.chem = chem
        utc = pytz.utc
        # Timezone-aware copies are used only to match rows; the naive values label the Time coordinate.
        # Keeping them separate means the chem branch still matches correctly after the met branch has run.
        if isinstance(obstime, list):
            wanted = [utc.localize(date) for date in obstime]
            time_coord = [date.replace(tzinfo=None) for date in wanted]
        else:
            wanted = [utc.localize(obstime)]
            time_coord = wanted[0].replace(tzinfo=None)
        if self.met:
            data: pd.DataFrame = pd.read_csv(obsfile,na_values='M',parse_dates=['valid'],date_format='%Y-%m-%d %H:%M')
            data.valid = pd.to_datetime(data.valid).dt.tz_localize(utc)
            # Offset in seconds between the requested time and the report that is accepted for it.
            if loc in ['JFK Airport','JFK','NYC/JFK', 'LaGuardia Airport', 'LGA', 'NYC/LGA', 'Central Park','NYC']:
                delta = 540
            else:
                delta = 420
            ds = self._select_rows(data, data.valid, wanted, delta, time_coord)
            del data
            self.T2: DataArray = (ds.tmpc * units.degC).metpy.convert_to_base_units()
            self.td2: DataArray = (ds.dwpc * units.degC).metpy.convert_to_base_units()
            self.slp: DataArray = ds.mslp * units.hPa
            # 0.44704 converts the reported speed to m/s (the mph-to-m/s factor).
            self.u10, self.v10 = wind_components((ds.sped * 0.44704) * units('m/s'), ds.drct * units.deg)
            del ds
        if self.chem:
            cols = ['UTC_time','CH4_ppb','CO2_ppm']
            data = pd.read_csv(obsfile, na_values='NA', skiprows=6, usecols=cols,parse_dates=['UTC_time'],date_format='%Y-%m-%d %H:%M')
            data.UTC_time = pd.to_datetime(data.UTC_time).dt.tz_localize(utc)
            # Chemistry rows must match the requested time exactly (offset 0).
            ds = self._select_rows(data, data.UTC_time, wanted, 0, time_coord)
            del data
            self.ch4 = (ds.CH4_ppb * units.ppb).metpy.convert_units('ppm')
            self.co2 = ds.CO2_ppm * units.ppm
            #self.xco = (ds.CO_ppb * units.ppb).metpy.convert_units('ppm')
            del ds
    @staticmethod
    def _select_rows(data, stamps, wanted, delta, time_coord):
        '''
        Pick, for every requested time, the first row whose timestamp differs from it by exactly
        ``delta`` seconds (earlier or later), and return those rows as a Dataset indexed by 'Time'.
        Requested times without a matching row contribute no row.
        '''
        idx = []
        for want in wanted:
            target = want.timestamp()
            for i, obs in enumerate(stamps):
                if abs(obs.timestamp() - target) == delta:
                    idx.append(i)
                    break
        ds = Dataset.from_dataframe(data.iloc[idx])
        return ds.rename({'index':'Time'}).assign_coords({'Time':time_coord})
    def __str__(self) -> str:
        # Speed and direction are not stored on the instance, so derive them from the 10 m components.
        wspd10 = mp.calc.wind_speed(self.u10, self.v10)
        wdir10 = mp.calc.wind_direction(self.u10, self.v10)
        return f'{self.loc} Observation Point has a temperature of {self.T2} K, a dewpoint of {self.td2} K, a slp of {self.slp} hPa, and the wind is {wspd10} m s^-1 at {wdir10} degrees.'

### *UObs_point Class*

This class ingests sounding data to test against WRF data. This inherits from UA point.

In [None]:
class UObs_point(UA_point):
    '''
    UObs_point: sets up validation point from a sounding csv, sampled at the sounding levels that np.digitize assigns to the WRF pressure levels above (lat, lon). Inherits from UA_point.
    '''
    def __init__(self, loc: str, ua_file: str, wrffile: netCDF4._netCDF4.Dataset, lat: float, lon: float, **kwargs) -> None:
        '''
        Raises ValueError when the sounding file contains no levels.
        '''
        super().__init__(loc, **kwargs)
        x, y = wrf.ll_to_xy(wrffile, lat, lon)#type: tuple(float, float)
        wrf_p: DataArray = wrf.getvar(wrffile, 'p', units='hPa')[:, y, x]
        data: pd.DataFrame = pd.read_csv(ua_file, na_values='M', parse_dates=['validUTC'], date_format='%Y-%m-%d %H:%M')
        p: npt.NDArray = data.pressure_mb.to_numpy()
        if p.size == 0:
            raise ValueError(f'No sounding levels found in {ua_file}')
        # digitize returns len(p) for values beyond the last bin, which is not a valid row:
        # clamp so those WRF levels use the last sounding level instead of raising IndexError.
        idx: npt.NDArray = np.clip(np.digitize(wrf_p, p), 0, len(p) - 1)
        data = data.iloc[idx]
        self.p: npt.NDArray = data.pressure_mb.to_numpy() #* mb == hPa
        self.t: npt.NDArray = data.tmpc.to_numpy() + 273.15 #* deg C -> K
        self.td: npt.NDArray = data.dwpc.to_numpy() + 273.15 #* deg C -> K
        self.wdir: npt.NDArray = data.drct.to_numpy() #* deg
        self.wspd: npt.NDArray = data.speed_kts.to_numpy() * 0.514444 #* kt -> m s^-1

### *Tropomi_point class*

This class ingests TROPOMI satellite data to test against WRF data. This inherits from Sat point.

In [None]:
class Tropomi_point(Sat_point): #! Find some way to do both xch4 and xco, will probably need two files
    '''
    Tropomi_point: sets up validation point from TROPOMI files, one for methane and one for carbon monoxide. Inherits from Sat_point.
    '''
    def __init__(self, loc: str, xch4_f: str, xco_f: str, ulat: float, ulon: float, **kwargs) -> None:
        super().__init__(loc, **kwargs)
        self.xch4: float = self._ghg(xch4_f, 'xch4', ulat, ulon)
        self.xco: float = self._ghg(xco_f, 'xco', ulat, ulon)
    def _ghg(self, tropomi_f: str, chem: str, ulat: float, ulon: float) -> float:
        '''
        Return the value of the requested product at the pixel closest to (ulat, ulon).

        The pixel value is replaced by the mean of its 3x3 neighbourhood when it lies within one
        standard deviation of that mean. Fill values and pixels with qa_value <= 0.5 are ignored.

        Raises ValueError for an unknown ``chem`` and RuntimeError when the location is outside
        the file or has no valid value.
        '''
        grp: str = 'PRODUCT'
        ch4_names = ['ch4','CH4','xch4','XCH4','methane','Methane','METHANE']
        co_names = ['co','CO','xco','XCO','carbon monoxide','Carbon Monoxide', 'CARBON MONOXIDE']
        if chem in ch4_names: #!! find a way to do both
            self.xch4_sds = sds = 'methane_mixing_ratio'
            self.xch4_unit: str = 'ppb'
        elif chem in co_names:
            self.xco_sds = sds = 'carbonmonoxide_total_column_corrected' # ! is this right?
            self.xco_unit: str = 'ppb' # * will have to convert for co
        else:
            raise ValueError(f'Unsupported TROPOMI product: {chem!r}')
        # netCDF4.Dataset (not the xarray Dataset name) opens the file; ``with`` closes it again.
        with netCDF4.Dataset(tropomi_f, 'r') as ds:
            product = ds.groups[grp].variables
            lats = product['latitude'][0][:][:]
            lons = product['longitude'][0][:][:]
            qas: npt.NDArray = np.array(product['qa_value'][0][:][:])
            data = product[sds] #units: ppb (ch4) mol m^-2 --> ppb (co)
            fv = data._FillValue
            dA: npt.NDArray = np.array(data[0][:][:])
        # Discard a pixel when it is a fill value OR its quality is too low.
        dA[(dA==fv) | (qas<=0.5)] = np.nan
        if chem in co_names:
            dA = dA * 28.01 * 0.0001 * 1000 # converts (mol m^-2) * (g mol^-1) * (m^-1) == g m^-3 == ppm --> ppb
        #TODO: check about averging kernel
        min_lat = np.min(lats)
        max_lat = np.max(lats)
        min_lon = np.min(lons)
        max_lon = np.max(lons)
        if not min_lat <= ulat <= max_lat:
            raise RuntimeError(f'User Latitude is not within TROPOMI file. {ulat}, {min_lat}, {max_lat}')
        if not min_lon <= ulon <= max_lon:
            raise RuntimeError(f'User Longitude is not within TROPOMI file. {ulon}, {min_lon}, {max_lon}')
        px, py = self.sat_loc(ulat, ulon, lats, lons) #type: tuple[int, int]
        target = dA[px,py]
        if np.isnan(target):
            raise RuntimeError('No valid value at desired location.')
        # 3x3 window around the target, shifted inward at the edges so it always contains the target pixel.
        row0 = min(max(px - 1, 0), max(dA.shape[0] - 3, 0))
        col0 = min(max(py - 1, 0), max(dA.shape[1] - 3, 0))
        t_b_t: npt.NDArray = dA[row0:row0+3,col0:col0+3].astype(float)
        nnan: int = np.count_nonzero(~np.isnan(t_b_t))
        if nnan == 0:
            raise RuntimeError('No valid pixels in 3x3 grid.')
        grid_avg: float = np.nanmean(t_b_t)
        grid_std: float = np.nanstd(t_b_t)
        if np.abs(grid_avg-target) <= grid_std:
            return grid_avg
        return target
    def __str__(self) -> str:
        return f'TROPOMI Satelite products at {self.loc} are {self.xch4:.2e} {self.xch4_unit} of CH4 and {self.xco:.2e} {self.xco_unit} of CO'

### *OCO_point Class*

This class ingests OCO-2/3 satellite data to compare against WRF data. This inherits from Sat_point. **NB:** Not finished yet (10/24/2024)

In [None]:
class OCO_Point(Sat_point):
    '''
    OCO_Point: unfinished validation point for OCO data. Inherits from Sat_point.
    '''
    def __init__(self, loc, oco_f, ulat, ulon, **kwargs):
        super().__init__(loc, **kwargs)
        # NOTE(review): `Dataset` is the name imported from xarray in this file; the netCDF4 reader is imported as `Ds` - confirm which one is meant here.
        ds = Dataset(oco_f, 'r')
        lats = ds['latitude'][:]
        lons = ds['longitude'][:]
        # Index of the sounding closest to the requested location (see Sat_point.sat_loc).
        idx = self.sat_loc(ulat, ulon, lats, lons)
        # Placeholder: the variable name is still empty and nothing is stored on the instance yet.
        _xco2 = ds['']


### *Helper Functions*

These functions help illustrate a pass or fail of the testing suite.

In [None]:
def gprint(x: str) -> None:
    '''Print x via cprint in bold white on green, passing end=' '.'''
    return cprint(x, color='white', on_color='on_green', attrs=['bold'], end=' ')
def rprint(x: str) -> None:
    '''Print `x` in blinking bold white on a red background, followed by a space instead of a newline.'''
    style = {'attrs': ['blink', 'bold'], 'end': ' '}
    return cprint(x, 'white', 'on_red', **style)

## Plotting

### *Common Setup*

These options are used for each or most of the plots below

In [None]:
# Hours 0..23, used as the x-axis of the hourly-mean plots.
hrs = list(range(24))
# Golden ratio: every figure is 12 wide and 12/golden tall.
golden = (np.sqrt(5.) + 1.) / 2.
figsize = (12., 12. / golden)

### *Meteorological Setup*

Adds the lats, lons and locations for the validation in question. Also sets up the observation files.

In [None]:
# One entry per station, in the same order across all four station lists.
locs = ['JFK Airport', 'LaGuardia Airport', 'Central Park']
lats = [40.64, 40.78, 40.78]
lons = [-73.78, -73.87, -73.97]
obs_files = ['./sfc/JFK.csv', './sfc/LGA.csv', './sfc/NYC.csv']
# Day-of-month and month lists are zipped pairwise downstream:
# days 20-31 go with month 7, days 1-12 with month 8.
days = [*range(20, 32), *range(1, 13)]
mons = [7] * 12 + [8] * 12

### *Meteorological Daily Validation*

This cell loops over the simulation and plots the daily time series. It also checks the RMSE for that day's worth of data against the observation and spits out a validation result.

In [None]:
# Per-variable plot metadata; the five lists are zipped together below and
# must stay in the same order as wrf_var_list / obs_var_list.
var_title = ['Temperature','Dew Point','Sea Level Pressure','U10 Wind Component','V10 Wind Component']
# Half-width of the shaded band drawn around the observations and around zero bias.
var_limit = [5., 5., 2., 2.24, 2.24]
var_units = ['K', 'K', 'hPa', 'm/s', 'm/s']
var_short = ['t2','td2','slp','u10','v10']
# The loop variable must be `obs_file` (singular): binding `obs_files` here
# would overwrite the list itself and leave `obs_file` undefined in the body.
for lat, lon, loc, obs_file in zip(lats, lons, locs, obs_files):
    # Daily RMSE / mean-bias values for this station; they are collected here
    # but not consumed anywhere in this cell.
    t2_rmses = []
    td2_rmses = []
    slp_rmses = []
    u10_rmses = []
    v10_rmses = []
    t2_mbias = []
    td2_mbias = []
    slp_mbias = []
    u10_mbias = []
    v10_mbias = []
    print(f'{loc = }')
    for month, day in zip(mons, days):
        # All domain-2 output files whose name starts with this calendar day.
        wrf_files = sorted(f'./wrfout/{f}' for f in os.listdir('./wrfout') if f.startswith(f'wrfout_d02_2023-{month:02}-{day:02}'))
        data_sets = [Ds(wrf_f) for wrf_f in wrf_files]
        try:
            obs_dates = [dt(2023, month, day, h, 0, 0) for h in range(0,24)]
            wrf_data = WRF_point(data_sets, lat, lon, loc)
            obs_data = Obs_point(loc, obs_file, obs_dates)
            u10_rmses.append(wrf_data._rmse(wrf_data.u10, obs_data.u10))
            v10_rmses.append(wrf_data._rmse(wrf_data.v10, obs_data.v10))
            slp_rmses.append(wrf_data._rmse(wrf_data.slp, obs_data.slp))
            td2_rmses.append(wrf_data._rmse(wrf_data.td2, obs_data.td2))
            t2_rmses.append(wrf_data._rmse(wrf_data.T2, obs_data.T2))
            u10_mbias.append(wrf_data._mean_bias(wrf_data.u10, obs_data.u10))
            v10_mbias.append(wrf_data._mean_bias(wrf_data.v10, obs_data.v10))
            slp_mbias.append(wrf_data._mean_bias(wrf_data.slp, obs_data.slp))
            td2_mbias.append(wrf_data._mean_bias(wrf_data.td2, obs_data.td2))
            t2_mbias.append(wrf_data._mean_bias(wrf_data.T2, obs_data.T2))
            validate = obs_data == wrf_data
            # NOTE(review): `flags` is not assigned anywhere in this cell;
            # presumably the comparison above publishes it - confirm, otherwise
            # both prints raise NameError.
            if not validate:
                print(f'Failed validation: {repr(flags)}')
            else:
                print(f'Validation passed! {repr(flags)}')
            wrf_var_list = [wrf_data.T2, wrf_data.td2, wrf_data.slp, wrf_data.u10, wrf_data.v10]
            obs_var_list = [obs_data.T2, obs_data.td2, obs_data.slp, obs_data.u10, obs_data.v10]
            # One two-panel figure per variable: values on top, model-minus-observation below.
            for wrf_var, obs_var, title, limit, unit, short in zip(wrf_var_list, obs_var_list, var_title, var_limit, var_units, var_short):
                fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=figsize)
                fig.suptitle(f'Time series of {title} at {loc} for {month:02}-{day:02}-2023')
                ax1.set_title('Values')
                ax2.set_title('Bias')
                ax1.plot(obs_dates, obs_var, 'k', label='Observed')
                ax1.fill_between(obs_dates, obs_var.metpy.dequantify()-limit, obs_var.metpy.dequantify()+limit, color='b', alpha=.5)
                ax1.plot(obs_dates, wrf_var, 'r', label='WRF Model')
                ax1.xaxis_date()
                ax1.set(ylabel=f'{title} ({unit})', xlabel='Time (UTC)')
                ax2.xaxis_date()
                ax2.set(ylabel=f'{title} Bias ({unit})', xlabel='Time (UTC)')
                ax2.plot(obs_dates, wrf_var-obs_var, 'g', label='Bias')
                ax2.plot(obs_dates, np.zeros(len(obs_dates)), 'm--')
                ax2.fill_between(obs_dates, -limit, limit, color='b', alpha=.5)
                fig.legend(loc='outside lower center', ncols=3)
                plt.savefig(f'{short}_timeseries_{loc.replace(" ","")}_2023-{month:02}-{day:02}.png')
                plt.close()
        finally:
            # Release the netCDF handles for this day before opening the next
            # day's files, even if loading or plotting failed.
            for data_set in data_sets:
                data_set.close()


In [None]:
var_title = ['Temperature','Dew Point','Sea Level Pressure',None]
var_limit = [5., 5., 2., 2.24]
var_units = ['K', 'K', 'hPa', 'm/s']
var_short = ['t2','td2', 'slp', None]
var_wrf_getvar = ['T2', 'td2', 'slp','uvmet10']
for lat, lon, loc, obs_file in zip(lats, lons, locs, obs_files):
    print(f'{loc = }')
    for title, limit, unit, short, wrf_getvar in zip(var_title, var_limit, var_units, var_short, var_wrf_getvar):
        if title is not None:
            print(title)
            obs_hr_means = []
            wrf_hr_means = []
        else:
            print('winds')
            obs_u10_hr_means = []
            obs_v10_hr_means = []
            wrf_v10_hr_means = []
            wrf_u10_hr_means = []
        for hr in hrs:
            wrf_files = [f'./wrfout/{f}' for f in os.listdir('./wrfout/') if f'_{hr:02}:' in f and f.startswith('wrfout_d02')]
            wrf_files.sort()
            data_sets = [Ds(wrf_f) for wrf_f in wrf_files]
            x, y = wrf.ll_to_xy(data_sets, lat, lon)
            obs_dates = [dt(2023, month, day, hr, 0, 0) for month, day in zip(mons, days)]
            if title is not None:
                if title != 'Temperature':
                    try:
                        wrf_data = wrf.getvar(data_sets, wrf_getvar, wrf.ALL_TIMES,units=unit)[..., y, x].metpy.quantify()
                    except MemoryError:
                        wrf_data = WRF_point._memory_loop(data_sets, wrf_getvar, x, y, unit)
                else:
                    try:
                        wrf_data = wrf.getvar(data_sets, wrf_getvar, wrf.ALL_TIMES)[..., y, x].metpy.quantify()
                    except MemoryError:
                        wrf_data = WRF_point._memory_loop(data_sets, wrf_getvar, x, y)
            else:
                try:
                    wrf_u10, wrf_v10 = wrf.getvar(data_sets, wrf_getvar, wrf.ALL_TIMES)[..., y, x].metpy.quantify()
                except MemoryError:
                    wrf_u10, wrf_v10 = WRF_point._memory_loop(data_sets, wrf_getvar, x, y)
            obs_data = Obs_point(loc, obs_file, obs_dates)
            match title:
                case 'Temperature':
                    wrf_hr_means.append(np.nanmean(wrf_data))
                    obs_hr_means.append(np.nanmean(obs_data.T2))
                case 'Dew Point':
                    wrf_hr_means.append(np.nanmean(wrf_data))
                    obs_hr_means.append(np.nanmean(obs_data.td2))
                case 'Sea Level Pressure':
                    wrf_hr_means.append(np.nanmean(wrf_data))
                    obs_hr_means.append(np.nanmean(obs_data.slp))
                case None:
                    wrf_v10_hr_means.append(np.nanmean(wrf_v10))
                    wrf_u10_hr_means.append(np.nanmean(wrf_u10))
                    obs_v10_hr_means.append(np.nanmean(obs_data.v10))
                    obs_u10_hr_means.append(np.nanmean(obs_data.u10))
                case _:
                    raise RuntimeError('Variable not valid')
        if title is not None:
            obs_var = np.array(obs_hr_means)
            wrf_var = np.array(wrf_hr_means)
            fig, (ax1, ax2) = plt.subplots(2, 1, figsize = figsize,sharex=True)
            fig.suptitle(f'Mean Time Series of {title} at {loc}')
            ax1.set_title('Values')
            ax2.set_title('Bias')
            ax1.plot(hrs, obs_var, 'k', label='Observed')
            ax1.fill_between(hrs, obs_var-limit, obs_var+limit, color='b', alpha=.5)
            ax1.plot(hrs, wrf_var, 'r', label='WRF Model')
            #ax1.xaxis_date()
            ax1.set(ylabel=f'{title} ({unit})', xticks=hrs)
            #ax2.xaxis_date()
            ax2.set(ylabel=f'{title} Bias ({unit})', xlabel='Time (UTC)', xticks=hrs)
            ax2.plot(hrs, wrf_var-obs_var, 'g', label='Bias')
            ax2.plot(hrs, np.zeros(len(hrs)), 'm--', label='Zero Bias')
            ax2.fill_between(hrs, -limit, limit, color='b', alpha=.5)
            fig.legend(loc='outside lower center', ncols=4)
            plt.savefig(f'{short}_timeseries_{loc.replace(" ","")}_2023-average.png')
            plt.close()
            del obs_hr_means, wrf_hr_means, obs_var, wrf_var
        else:
            obs_v10 = np.array(obs_v10_hr_means)
            wrf_v10 = np.array(wrf_v10_hr_means)
            obs_u10 = np.array(obs_u10_hr_means)
            wrf_u10 = np.array(wrf_u10_hr_means)
            wind_title = ['U10 Wind Component','V10 Wind Component']
            wind_short = ['u10', 'v10']
            obs_wind = [obs_u10, obs_v10]
            wrf_wind = [wrf_u10, wrf_v10]
            for title, short, obs_var, wrf_var in zip(wind_title, wind_short, obs_wind, wrf_wind):
                fig, (ax1, ax2) = plt.subplots(2, 1, figsize = figsize, sharex=True)
                fig.suptitle(f'Mean Time Series of {title} at {loc}')
                ax1.set_title('Values')
                ax2.set_title('Bias')
                ax1.plot(hrs, obs_var, 'k', label='Observed')
                ax1.fill_between(hrs, obs_var-limit, obs_var+limit, color='b', alpha=.5)
                ax1.plot(hrs, wrf_var, 'r', label='WRF Model')
                #ax1.xaxis_date()
                ax1.set(ylabel=f'{title} ({unit})', xticks=hrs)
                #ax2.xaxis_date()
                ax2.set(ylabel=f'{title} Bias ({unit})', xlabel='Time (UTC)',xticks=hrs)
                ax2.plot(hrs, wrf_var-obs_var, 'g', label='Bias')
                ax2.plot(hrs, np.zeros(len(hrs)), 'm--', label='Zero Bias')
                ax2.fill_between(hrs, -limit, limit, color='b', alpha=.5)
                fig.legend(loc='outside lower center', ncols=4)
                plt.savefig(f'{short}_timeseries_{loc.replace(" ","")}_2023-average.png')
                plt.close()

In [None]:
# Greenhouse-gas observation sites; list order is shared across all four lists.
locs = ['Rutgers', 'LDEO', 'ASRC']
# NOTE(review): these coordinates are the same values used for the JFK /
# LaGuardia / Central Park stations in the meteorological setup - confirm
# they are really the intended positions for these sites.
lats = [40.64, 40.78, 40.78]
lons = [-73.78, -73.87, -73.97]
# One text file per site, named after the site.
obs_files = [f'./sfc/{site}.txt' for site in locs]
var_title = ['CH4 Concentration', 'CO2 Concentration']

In [None]:
# Per-gas plot metadata, zipped below with `var_title` (CH4 first, then CO2).
# The limits are the half-width of the shaded band around the observations
# and around zero bias.
var_limit = [10. * units.ppb, 10. * units.ppm]
var_units = ['ppb', 'ppm']
var_short = ['ch4','co2']
for lat, lon, loc, obs_file in zip(lats, lons, locs, obs_files):
    print(f'{loc = }')
    for month, day in zip(mons, days):
        # All domain-2 output files whose name starts with this calendar day.
        wrf_files = sorted(f'./wrfout/{f}' for f in os.listdir('./wrfout') if f.startswith(f'wrfout_d02_2023-{month:02}-{day:02}'))
        data_sets = [Ds(wrf_f) for wrf_f in wrf_files]
        try:
            obs_dates = [dt(2023, month, day, h, 0, 0) for h in range(0,24)]
            wrf_data = WRF_point(data_sets, lat, lon, loc, chem=True)
            obs_data = Obs_point(loc, obs_file, obs_dates, met=False, chem=True)
            validate = obs_data == wrf_data
            # NOTE(review): `flags` is not assigned anywhere in this cell;
            # presumably the comparison above publishes it - confirm, otherwise
            # both prints raise NameError.
            if not validate:
                print(f'Failed validation: {repr(flags)}')
            else:
                print(f'Validation passed! {repr(flags)}')
            # Index 0 along the second axis of the model fields is compared
            # with the observations (assumed to be the lowest model level - TODO confirm).
            wrf_var_list = [wrf_data.ch4[:,0], wrf_data.co2[:,0]]
            obs_var_list = [obs_data.ch4, obs_data.co2]
            # One two-panel figure per gas: values on top, model-minus-observation below.
            for wrf_var, obs_var, title, limit, unit, short in zip(wrf_var_list, obs_var_list, var_title, var_limit, var_units, var_short):
                fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=figsize)
                fig.suptitle(f'Time series of {title} at {loc} for {month:02}-{day:02}-2023')
                ax1.set_title('Values')
                ax2.set_title('Bias')
                ax1.plot(obs_dates, obs_var.metpy.convert_units(unit), 'k', label='Observed')
                ax1.fill_between(obs_dates, obs_var.metpy.convert_units(unit)-limit, obs_var.metpy.convert_units(unit)+limit, color='b', alpha=.5)
                ax1.plot(obs_dates, wrf_var.metpy.convert_units(unit), 'r', label='WRF Model')
                ax1.xaxis_date()
                ax1.set(ylabel=f'{title} ({unit})', xlabel='Time (UTC)')
                ax2.xaxis_date()
                ax2.set(ylabel=f'{title} Bias ({unit})', xlabel='Time (UTC)')
                ax2.plot(obs_dates, wrf_var.metpy.convert_units(unit)-obs_var.metpy.convert_units(unit), 'g', label='Bias')
                ax2.plot(obs_dates, np.zeros(len(obs_dates)), 'm--')
                ax2.fill_between(obs_dates, -limit, limit, color='b', alpha=.5)
                fig.legend(loc='outside lower center', ncols=3)
                plt.savefig(f'{short}_timeseries_{loc.replace(" ","")}_2023-{month:02}-{day:02}.png')
                plt.close()
        finally:
            # Release the netCDF handles for this day before opening the next
            # day's files, even if loading or plotting failed.
            for data_set in data_sets:
                data_set.close()

In [None]:
var_wrf_getvar = ['ch4','co2']
for lat, lon, loc, obs_file in zip(lats, lons, locs, obs_files):
    print(f'{loc = }')
    for title, limit, unit, short, wrf_getvar in zip(var_title, var_limit, var_units, var_short, var_wrf_getvar):
        print(f'{title = }')
        obs_hr_means = []
        wrf_hr_means = []
        for hr in hrs:
            wrf_files = [f'./wrfout/{f}' for f in os.listdir('./wrfout/') if f'_{hr:02}:' in f and f.startswith('wrfout_d02')]
            wrf_files.sort()
            data_sets = [Ds(wrf_f) for wrf_f in wrf_files]
            x, y = wrf.ll_to_xy(data_sets, lat, lon)
            obs_dates = [dt(2023, month, day, hr, 0, 0) for month, day in zip(mons, days)]
            wrf_data = WRF_point(data_sets, lat, lon, loc, chem=True)
            obs_data = Obs_point(loc, obs_file, obs_dates, met=False, chem=True)
            match title:
                case 'CO2 Concentration':
                    wrf_hr_means.append(wrf_data.co2[:,0].mean(skipna=True))
                    obs_hr_means.append(obs_data.co2.mean(skipna=True))
                case 'CH4 Concentration':
                    wrf_hr_means.append(wrf_data.ch4[:,0].mean(skipna=True))
                    obs_hr_means.append(obs_data.ch4.mean(skipna=True))
                case _:
                    raise RuntimeError('Variable not valid.')
        obs_var = xr.DataArray(obs_hr_means) * units.ppm
        wrf_var = xr.DataArray(wrf_hr_means) * units.ppm
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize = figsize,sharex=True)
        fig.suptitle(f'Mean Time Series of {title} at {loc}')
        ax1.set_title('Values')
        ax2.set_title('Bias')
        ax1.plot(hrs, obs_var.metpy.convert_units(unit), 'k', label='Observed')
        ax1.fill_between(hrs, obs_var.metpy.convert_units(unit)-limit, obs_var.metpy.convert_units(unit)+limit, color='b', alpha=.5)
        ax1.plot(hrs, wrf_var.metpy.convert_units(unit), 'r', label='WRF Model')
        #ax1.xaxis_date()
        ax1.set(ylabel=f'{title} ({unit})', xticks=hrs)
        #ax2.xaxis_date()
        ax2.set(ylabel=f'{title} Bias ({unit})', xlabel='Time (UTC)', xticks=hrs)
        ax2.plot(hrs, wrf_var.metpy.convert_units(unit)-obs_var.metpy.convert_units(unit), 'g', label='Bias')
        ax2.plot(hrs, np.zeros(len(hrs)), 'm--', label='Zero Bias')
        ax2.fill_between(hrs, -limit, limit, color='b', alpha=.5)
        fig.legend(loc='outside lower center', ncols=4)
        plt.savefig(f'{short}_timeseries_{loc.replace(" ","")}_2023-average.png')
        plt.close()
        del obs_hr_means, wrf_hr_means, obs_var, wrf_var