# CMIP6 on JASMIN

In [None]:
import xarray as xr
import numpy as np
import glob
import re
import geopandas as gpd
import regionmask
import os

# work from the project directory: the shapefile and output paths used in this notebook are relative
os.chdir("/home/users/clairb/00_notebooks/24-04_sahel-heat")

import dask
dask.config.set(**{'array.slicing.split_large_chunks': True})

# broader region for the spatial pattern, as (lon min, lon max, lat min, lat max)
Xn, Xx, Yn, Yx = -18, 52, 0, 35

# region for the West Africa time series analysis, same ordering
xn, xx, yn, yx = -10, 20, 10, 17

# bounds joined with underscores, used as a tag in output filenames
box_str = "_".join(map(str, (xn, xx, yn, yx)))

# shapefile region and the tag identifying it in output filenames
sf_str = "malibf"
sf = gpd.read_file("sf_malibf")

In [None]:
def wrap_lon(ds):
    """Wrap longitudes from (0, 360) to (-180, 180) and sort the horizontal axes.

    Parameters
    ----------
    ds : xarray Dataset or DataArray
        Object whose horizontal coordinates are named either
        ``longitude``/``latitude`` or ``lon``/``lat``.

    Returns
    -------
    xarray Dataset or DataArray
        A new object with longitudes wrapped (only if any value exceeds 180)
        and, when longitude is a dimension, sorted by ascending longitude and
        (when latitude is also a dimension) ascending latitude. If neither
        coordinate name is present the input is returned unchanged. The
        caller's object is never modified in place.
    """
    if "longitude" in ds.coords:
        lon, lat = "longitude", "latitude"
    elif "lon" in ds.coords:
        lon, lat = "lon", "lat"
    else:
        # no recognised longitude coordinate, so there is nothing to wrap
        return ds

    if ds[lon].max() > 180:
        wrapped = ((ds[lon].values + 180) % 360) - 180
        # assign_coords returns a new object, so the dataset passed in by the
        # caller keeps its original longitudes
        ds = ds.assign_coords({lon: (ds[lon].dims, wrapped, ds[lon].attrs)})

    if lon in ds.dims:
        # sortby reorders the existing labels, so unlike reindex it does not
        # fail if wrapping produced duplicate longitudes
        ds = ds.sortby(lon)
        if lat in ds.dims:
            ds = ds.sortby(lat)
    return ds

## Identify models with both historical & SSP data

In [None]:
varnm = "tasmax"


def first_variant_per_model(variants):
    """Return one variant per model: the first in plain string order.

    Parameters
    ----------
    variants : iterable of str
        Names of the form ``institution_model_member``.

    Returns
    -------
    list of str
        One entry per distinct ``institution_model`` prefix, in sorted order.

    Notes
    -----
    "First" means first in string order, so e.g. ``r10i1p1f1`` is chosen
    ahead of ``r1i1p1f1``.
    """
    first = {}
    for variant in sorted(variants):
        # group on the exact "institution_model" prefix: a substring test would
        # let e.g. "X_CM6-1" also match the variants of "X_CM6-1-HR"
        model = "_".join(variant.split("_")[:2])
        first.setdefault(model, variant)
    return list(first.values())


# list all models with available historical data
# (path components 6-9 are institution / model / experiment / ensemble member)
fl_hist = glob.glob('/badc/cmip6/data/CMIP6/CMIP/*/*/historical/*/day/'+varnm)
mdl_hist = [re.sub("_historical", "", "_".join(fnm.split("/")[6:10])) for fnm in fl_hist]

# list all models with available SSP585 data
fl_ssp = glob.glob('/badc/cmip6/data/CMIP6/ScenarioMIP/*/*/ssp585/*/day/'+varnm)
mdl_ssp = [re.sub("_ssp585", "", "_".join(fnm.split("/")[6:10])) for fnm in fl_ssp]

# list all model variants for which both historical & SSP are available
model_vars = sorted(set(mdl_hist) & set(mdl_ssp))

# identify unique models & get first ensemble member for each (deterministic order)
models = first_variant_per_model(model_vars)

In [None]:
# cut out time series & spatial pattern for individual files

# per-file outputs are written to cmip6/, so make sure it exists before the first write
os.makedirs("cmip6", exist_ok = True)

for mdl in models:

    print(mdl, end = "")
    inst, gcm, em = mdl.split("_")

    sp_fnm = "spatial/"+varnm+"_"+mdl+"_"+box_str+"_spatial-monthly.nc"
    box_fnm = "daily/"+varnm+"_"+mdl+"_"+sf_str+"_daily.nc" if False else "daily/"+varnm+"_"+mdl+"_"+box_str+"_daily.nc"
    sf_fnm = "daily/"+varnm+"_"+mdl+"_"+sf_str+"_daily.nc"

    # NOTE(review): this checks daily/ while the loop below writes per-file output to cmip6/ -
    # presumably a later step combines those files into daily/; confirm
    # if os.path.exists(sp_fnm) and os.path.exists(box_fnm) and os.path.exists(sf_fnm): 
    if os.path.exists(sf_fnm): 
        print("Already processed")
        continue

    # list all relevant files
    fl_h = glob.glob("/badc/cmip6/data/CMIP6/CMIP/"+inst+"/"+gcm+"/historical/"+em+"/day/"+varnm+"/*/latest/*.nc")
    fl_s = glob.glob("/badc/cmip6/data/CMIP6/ScenarioMIP/"+inst+"/"+gcm+"/ssp585/"+em+"/day/"+varnm+"/*/latest/*.nc")
    # keep scenario files whose start year is <= 2050 (a kept file may still run past 2050);
    # assumes filenames end in "YYYYMMDD-YYYYMMDD.nc" so that [-20:-16] is the start year
    fl_s = [fnm for fnm in fl_s if int(fnm[-20:-16]) <= 2050]

    fl = fl_h + fl_s
    if not fl:
        # finish the progress line started by print(mdl, end = "") above
        print(" no files found")
        continue

    for fnm in fl:

        # open inside a context manager so the file handle is released after each file;
        # everything that reads the data must therefore happen inside this block
        with xr.open_dataset(fnm) as ds:

            # load the data, cropped to the broader region
            da = wrap_lon(ds).reset_coords(drop = True)[varnm].sel(lon = slice(Xn,Xx), lat = slice(Yn,Yx))

            # can't convert units or calendar on Jasmin so will have to handle that later

            # # monthly spatial pattern
            # sp = da.groupby("time.month").mean()
            # sp.to_netcdf("cmip6/"+re.sub("day", "spatial", fnm.split("/")[-1]))

            # # get daily time series over rectangular region
            # ts_box = da.sel(lon = slice(xn,xx), lat = slice(yn,yx)).mean(["lat", "lon"])
            # ts_box.to_netcdf("cmip6/"+re.sub("day", "daily-box", fnm.split("/")[-1]))

            # get daily time series over shapefile region:
            # collapse the per-region masks into one, then average the cells where it equals 1
            rm = regionmask.mask_3D_geopandas(sf, da.lon, da.lat).sum("region")
            ts_sf = da.where(rm == 1).mean(["lat", "lon"])
            ts_sf.to_netcdf("cmip6/"+re.sub("day", "daily-sf", fnm.split("/")[-1]))

        # one dot per processed file
        print(".", end = "")
    print("")