In [None]:
import os

import numpy as np
import pandas as pd
import xarray as xr
import xenso
from scipy import signal

In [None]:
# Models to process, one name per line. Surrounding whitespace and blank
# lines are ignored so a stray trailing space cannot silently exclude a model
# from the membership filter applied to the input files.
with open("selected_models.txt", encoding="utf-8") as f:
    selected_models = [line.strip() for line in f.read().splitlines() if line.strip()]
selected_models

In [None]:
# Directory holding the "<model>.<member>.nc" input files. Set the
# ICS_CMIP6_DIR environment variable to run against a different location;
# the default is the original scratch path.
DATA_DIR = os.environ.get("ICS_CMIP6_DIR", "/glade/derecho/scratch/griverat/ics_CMIP6")

In [None]:
# Map each selected model to the sorted list of its ensemble members, based on
# input file names of the form "<model>.<member>.nc".
_selected = set(selected_models)
file_list = sorted(os.listdir(DATA_DIR))
models_members = {}
for _file in file_list:
    _parts = _file.rsplit(".", 2)
    # Skip anything that is not "<model>.<member>.nc" (stray/temporary files);
    # such names previously crashed the unpacking or produced unopenable members.
    if len(_parts) != 3 or _parts[2] != "nc":
        continue
    _name, _member, _ = _parts
    if _name in _selected:
        models_members.setdefault(_name, []).append(_member)
# Sort each member list once instead of after every append.
for _member_list in models_members.values():
    _member_list.sort()
models_members

In [None]:
def _fma_standardized(index):
    """Standardize the Feb-Mar-Apr (FMA) seasonal mean of a monthly index.

    The series is averaged in 3-month bins anchored on February
    (``"QS-FEB"``: FMA, MJJ, ASO, NDJ). Only the bins starting in February
    are kept, and they are converted to z-scores over ``time`` using the
    xarray default (``ddof=0``) standard deviation.

    Parameters
    ----------
    index : xarray.DataArray
        Series with a ``time`` dimension.

    Returns
    -------
    xarray.DataArray
        One standardized value per FMA season, labeled at the bin start
        (February).
    """
    seasonal = index.resample(time="QS-FEB").mean()
    fma = seasonal.sel(time=seasonal.time.dt.month.isin([2]))
    return (fma - fma.mean("time")) / fma.std("time")


def compute_coa_index(sst_anom, precip_anom, threshold=0.8, en34_max=0.5):
    """Flag strong COA events (presumably coastal El Niño), one flag per FMA season.

    A season is flagged when all of these standardized FMA means hold:

    * zone "12" SST, relative to the 30S-30N mean, is >= ``threshold``;
    * zone "12" precipitation is >= ``threshold``;
    * zone "34" SST, relative to the 30S-30N mean, is < ``en34_max``.

    The zones come from ``xenso.indices.enzones`` (presumably Niño 1+2 and
    Niño 3.4). Pass ``threshold=1.1`` for the stricter "extreme" criterion.

    Parameters
    ----------
    sst_anom : xarray.DataArray
        SST anomalies with ``time``, ``lat`` and ``lon`` dimensions. The
        ``lat=slice(-30, 30)`` selection assumes ascending latitudes.
        TODO: confirm for every model grid.
    precip_anom : xarray.DataArray
        Precipitation anomalies on the same grid. Presumably CMIP6 ``pr`` in
        kg m-2 s-1; confirm.
    threshold : float, default 0.8
        Minimum z-score for both zone "12" SST and precipitation.
    en34_max : float, default 0.5
        Zone "34" SST z-scores must stay strictly below this value.

    Returns
    -------
    xarray.DataArray of bool
        One flag per FMA season. ``time`` is relabeled to the 15th of the
        season's first month (February).
    """
    # Subtract the unweighted 30S-30N mean SST anomaly from every grid point.
    # NOTE(review): this mean is not area (cos-latitude) weighted; confirm intended.
    trop_mean_sst_anom = sst_anom.sel(lat=slice(-30, 30)).mean(dim=["lat", "lon"])
    ts_anom_adj = sst_anom - trop_mean_sst_anom

    en12_fma_stdzed = _fma_standardized(xenso.indices.enzones(ts_anom_adj, zone="12"))
    en34_fma_stdzed = _fma_standardized(xenso.indices.enzones(ts_anom_adj, zone="34"))

    # The * 86400 scaling presumably converts kg m-2 s-1 to mm/day. If enzones
    # is a plain spatial average, the scaling cancels in the z-score; it is kept
    # so results stay identical to earlier runs.
    en12_pr_fma_stdzed = _fma_standardized(
        xenso.indices.enzones(precip_anom * 86400, zone="12")
    )

    coa_events_strong = (
        (en12_fma_stdzed >= threshold)
        & (en12_pr_fma_stdzed >= threshold)
        & (en34_fma_stdzed < en34_max)
    )

    # Move labels from the bin start (Feb 1) to Feb 15 so that callers can
    # look them up in the monthly input with an exact ``.loc`` match.
    # NOTE(review): assumes the monthly data are time-stamped on day 15; confirm.
    coa_events_strong["time"] = (
        coa_events_strong.indexes["time"].to_series().apply(lambda x: x.replace(day=15))
    )
    return coa_events_strong

In [None]:
# Cutoff *period* of the low-pass Butterworth filter, in time steps of the
# input series (months, assuming monthly data).
BWCUT_OFF = 120

OUT_PATH = "/glade/derecho/scratch/griverat/ics_CMIP6_processed"
# to_netcdf does not create missing parent directories.
os.makedirs(OUT_PATH, exist_ok=True)

# The filter design does not depend on the member, so build it once.
# fs=1 (one sample per time step) makes Wn = 1 / BWCUT_OFF a frequency in
# cycles per step, i.e. a cutoff period of BWCUT_OFF steps. Without fs, SciPy
# reads Wn as a fraction of the Nyquist frequency (half-cycles per sample),
# which doubled the effective cutoff period to 2 * BWCUT_OFF.
sos = signal.butter(5, 1 / BWCUT_OFF, btype="lowpass", output="sos", fs=1)

for model_name, members in models_members.items():
    print(f"Doing model {model_name}")
    for member in members:
        print(f"Starting member: {member}")
        run_id = f"{model_name}.{member}"
        xfile_path = os.path.join(DATA_DIR, f"{run_id}.nc")

        # Anomalies relative to each calendar month's mean over the whole record.
        # Loading inside the context manager lets the file be closed right away
        # instead of staying open until garbage collection.
        with xr.open_dataset(xfile_path) as xfile:
            xfile_anom = (
                xfile.groupby("time.month")
                .map(lambda ds: ds - ds.mean("time"))
                .load()
            )

        # COA labels are computed from the *unfiltered* SST anomalies.
        coa_events = compute_coa_index(xfile_anom.tos, xfile_anom.pr)
        coa_mask = xr.full_like(
            xfile_anom.tos.isel(lat=0, lon=0, drop=True), False, dtype=bool
        )
        # NOTE(review): this is an exact timestamp lookup. It relies on the
        # monthly time stamps falling on day 15, which is how compute_coa_index
        # relabels its events; confirm for every model.
        coa_mask.loc[{"time": coa_events.time}] = coa_events.data
        # Spread each flagged February forward 1 month and back 11 months, so a
        # label covers the previous March through the following March.
        # NOTE(review): that trailing March is paired with the *next* February
        # in target_map below (bfill limit=11); confirm this is intended.
        coa_mask = (
            coa_mask.where(coa_mask)
            .ffill(dim="time", limit=1)
            .bfill(dim="time", limit=11)
            .fillna(0)
            .astype(bool)
        )

        # Remove the low-frequency component from the SST anomalies. Copying tos
        # gives the filtered array exactly tos' dims and coords, rather than
        # inferring dims from the order of the Dataset's coordinates.
        tos = xfile_anom.tos
        low_signal = signal.sosfiltfilt(
            sos,
            tos.values,
            axis=tos.get_axis_num("time"),
            padtype="even",
            padlen=12 * 5,
        )
        xfile_anom["tos"] = tos - tos.copy(data=low_signal)

        # Predictors: filtered tos plus the zos/uas/vas anomalies (not filtered),
        # with time as the leading dimension.
        stacked_vars = xr.concat(
            [xfile_anom[var] for var in ("tos", "zos", "uas", "vas")], dim="var"
        ).transpose("time", ...)

        # Target: each February map of filtered tos, assigned to that February
        # and the 11 steps before it. Steps still all-NaN afterwards (with
        # monthly data, those after the last February) are dropped.
        target_map = (
            xfile_anom.tos.where(xfile_anom.tos.time.dt.month == 2)
            .bfill(dim="time", limit=11)
            .dropna("time", how="all")
        )

        # Keep inputs and labels on exactly the target's time steps.
        stacked_vars = stacked_vars.sel(time=target_map.time)
        coa_mask = coa_mask.sel(time=target_map.time)

        for prefix, out in (
            ("target_map", target_map),
            ("input", stacked_vars),
            ("target_label", coa_mask),
        ):
            # Scalar coordinate recording which run each file came from.
            out["model"] = run_id
            out.to_netcdf(os.path.join(OUT_PATH, f"{prefix}_{run_id}.nc"))