In [None]:
import numpy as np
import xarray as xr
import holoviews as hv
from dask.diagnostics import ProgressBar
from datetime import datetime
import cftime
import dask
import matplotlib.pyplot as plt

hv.extension('bokeh')

### Animation comparison

In [None]:
# CAM5 `nat_hist` zarr store; only the first entry along `height` is kept.
# NOTE(review): this store is read from /datadrive and without consolidated=True,
# whereas the other zarr stores opened below live under /datastore and use
# consolidated metadata -- confirm this path is still the intended one.
ds_cam = xr.open_zarr('/datadrive/cam5/nat_hist_zarr').isel(height=0)

In [None]:
# CAM5 output translated towards HadGEM3 (`nat_hist_to_hadgem3` store), opened
# with consolidated zarr metadata.
ds_cam2had = xr.open_zarr('/datastore/cam5/nat_hist_to_hadgem3_zarr', consolidated=True)

In [None]:
# Raw HadGEM3 `nat_hist` store, opened with consolidated zarr metadata.
# NOTE(review): unlike ds_cam, no height level is selected here -- presumably
# this store has no `height` dimension; confirm.
ds_had = xr.open_zarr('/datastore/hadgem3/nat_hist_zarr', consolidated=True)

In [None]:
def get_images(da_list, **kwargs):
    """Turn each data array into a styled holoviews image element.

    Parameters
    ----------
    da_list : iterable
        Data arrays to convert; each must expose 'lon' and 'lat' dimensions,
        which are used as the image axes.
    **kwargs
        Plot options forwarded unchanged to ``.options(...)`` on every
        converted element (e.g. ``height``, ``width``, ``cmap``, ``colorbar``).

    Returns
    -------
    list
        One converted holoviews object per input, in input order.
    """
    return [
        hv.Dataset(da).to(hv.Image, ['lon', 'lat']).options(**kwargs)
        for da in da_list
    ]
        
# Preview window: run 0, first 30 time steps, loaded eagerly so the animation
# does not hit the zarr store on every frame.
# tas is assumed to be in kelvin (the conversion below implies it); the exact
# kelvin -> degrees Celsius offset is 273.15, not 273.
images1 = get_images(
    [ds.isel(run=0, time=slice(0, 30)).tas.load() - 273.15 for ds in [ds_cam2had, ds_cam]],
    height=180,
    width=360,
    cmap='viridis',
    colorbar=True
)

# colorbars are matched if the variables have the same names, so the difference
# field is renamed to `tas_diff` to give it its own colour scale
images2 = get_images(
    [(ds_cam - ds_cam2had).isel(run=0, time=slice(0, 30)).tas.load().rename('tas_diff')],
    height=180,
    width=360,
    cmap='bwr',
    colorbar=True
)

# Side-by-side layout: translated CAM, raw CAM, and their difference with a
# fixed symmetric colour range of +/-20.
image = (
    images1[0].opts(title='CAM2Had')
    + images1[1].opts(title="CAM")
    + images2[0].opts(title="CAM - CAM2Had").redim.range(tas_diff=(-20, 20))
)

In [None]:
%%output holomap='scrubber'
# Display the layout with the holomap rendered in 'scrubber' mode.
# NOTE(review): the original note here read "['gif', ]" -- presumably a reminder
# that 'gif' is an alternative holomap value; confirm.
image

The image sequences look like they maintain spatial structures, which is what we were hoping for.

### Basic statistics comparison

In [None]:
# Aggregate-statistics files for each source, loaded fully into memory.
# The raw HadGEM3 and CAM5 files keep only the first entry along `height`.
# NOTE(review): the translated file is used as stored -- presumably it has no
# `height` dimension; confirm.
agg_had = xr.open_dataset("/datastore/hadgem3/nat_hist_agg.nc").load().isel(height=0)
agg_cam = xr.open_dataset("/datastore/cam5/nat_hist_agg.nc").load().isel(height=0)
agg_cam2had = xr.open_dataset("/datastore/cam5/nat_hist_to_hadgem3_agg.nc").load()

In [None]:
# Switch to the UNIT source directory so the following `from data import ...`
# resolves. The explicit `%cd` magic is used because bare `cd` depends on
# automagic, which only applies to single-line cells and can be turned off.
%cd ../../UNIT

In [None]:
from data import construct_regridders

In [None]:
# Return to the notebook's own directory now that the import is done.
# The explicit `%cd` magic is used because bare `cd` depends on automagic,
# which only applies to single-line cells and can be turned off.
%cd ../notebooks/data_analysis

In [None]:
# Build one regridder per dataset; a value of None means that dataset is left
# as it is.
rg_a, rg_b = construct_regridders(agg_had, agg_cam)

# Apply each regridder where one was returned, casting the result to float32.
if rg_a is not None:
    agg_had = rg_a(agg_had).astype(np.float32)
if rg_b is not None:
    agg_cam = rg_b(agg_cam).astype(np.float32)

In [None]:
# Order matters: it must match the `source` labels assigned when these are
# concatenated.
ds_list = [agg_had, agg_cam2had, agg_cam]

In [None]:
# Stack the three aggregate datasets along a new `source` dimension and label
# the entries in the same order as ds_list.
ds_allmod = xr.concat(ds_list, dim='source')
ds_allmod = ds_allmod.assign_coords(source=['had', 'cam2had', 'cam'])

# Difference of every source from the translated (cam2had) output.
ds_allmod_diff = ds_allmod - ds_allmod.sel(source='cam2had')

In [None]:
def plot_comp(stat):
    """Draw two faceted map rows for one aggregate statistic of ``tas``.

    The first figure shows the statistic for every entry of ``source`` in
    ``ds_allmod``; the second shows the same selection from
    ``ds_allmod_diff`` (each source minus cam2had).

    Parameters
    ----------
    stat : str
        Label to select along the ``aggregate_statistic`` dimension.

    Returns
    -------
    The object returned by the first faceted ``.plot(col='source')`` call.
    """
    # Panel row 1: the statistic itself, one column per source.
    tas_by_source = ds_allmod.tas.sel(aggregate_statistic=stat)
    grid = tas_by_source.plot(col='source')
    plt.suptitle(f'{stat} for each data source', y=1.05)

    # Panel row 2: the statistic's difference from cam2had.
    tas_diff_by_source = ds_allmod_diff.tas.sel(aggregate_statistic=stat)
    tas_diff_by_source.plot(col='source')
    plt.suptitle(f'{stat} for each x = data source - cam2had', y=1.05)

    return grid

In [None]:
# Maps of the 'mean' aggregate of tas per source, then each source minus cam2had.
plot_comp('mean')

The mean of the cam5$\rightarrow$hadgem3 translated output is very similar to that of the raw hadgem3 output. CAM has colder poles.

In [None]:
# Maps of the 'std' aggregate of tas per source, then each source minus cam2had.
plot_comp('std')

The std of the cam5$\rightarrow$hadgem3 translated output is also very similar to that of the raw hadgem3 output. CAM has more variable poles.

In [None]:
# Maps of the 'min' aggregate of tas per source, then each source minus cam2had.
plot_comp('min')

In [None]:
# Maps of the 'max' aggregate of tas per source, then each source minus cam2had.
plot_comp('max')

The mins and maxes of cam5$\rightarrow$hadgem3 also match HadGEM3 better than they match CAM5. This really highlights the extremes on land.

### Trend

In [None]:
def calculate_mean_trend(da):
    """Reduce ``da`` over 'lon' and 'lat' using cos(latitude) weights.

    The data are first averaged over 'lon'; the result is then multiplied by
    cos(lat) weights (normalised to sum to one) and summed over 'lat'.

    Parameters
    ----------
    da
        Object exposing a ``lat`` coordinate in degrees plus 'lon' and 'lat'
        dimensions (e.g. an xarray Dataset or DataArray).

    Returns
    -------
    The input reduced over 'lon' and 'lat'; any other dimensions are kept.
    """
    # cos latitude weights, normalised so they sum to one
    cos_lat = np.cos(np.deg2rad(da.lat))
    weight = cos_lat / cos_lat.sum()

    zonal_mean = da.mean(dim='lon')
    return (zonal_mean * weight).sum(dim='lat')

In [None]:
def _to_datetime(dt):
    return datetime(dt.year, dt.month,dt.day,dt.hour)

def to_compatible_datetime(dt, like_dt):
    """Rebuild ``dt`` in the same cftime calendar class as ``like_dt``.

    Parameters
    ----------
    dt
        Object with ``year``, ``month``, ``day`` and ``hour`` attributes.
    like_dt
        Example timestamp whose calendar class should be used; only
        ``cftime.DatetimeNoLeap`` and ``cftime.Datetime360Day`` are supported.

    Returns
    -------
    A new timestamp of the matching cftime class, at hour resolution.

    Raises
    ------
    ValueError
        If ``like_dt`` is not an instance of a supported calendar class.
    """
    # Checked in this order; the first matching calendar class wins.
    supported_calendars = (cftime.DatetimeNoLeap, cftime.Datetime360Day)
    for calendar_cls in supported_calendars:
        if isinstance(like_dt, calendar_cls):
            return calendar_cls(dt.year, dt.month, dt.day, dt.hour)
    raise ValueError(f"Not valid datetype {like_dt}")

def _time_overlap(ds_list):
    """Return the time window shared by every dataset in ``ds_list``.

    Each dataset's earliest and latest ``time`` values are converted with
    ``_to_datetime`` so that different calendars can be compared.

    Returns
    -------
    tuple
        ``(dt_min, dt_max)``: the latest of the start times and the earliest
        of the end times, as stdlib ``datetime`` objects.
    """
    starts = [_to_datetime(ds.time.values.min()) for ds in ds_list]
    ends = [_to_datetime(ds.time.values.max()) for ds in ds_list]
    return max(starts), min(ends)

def filter_to_time_overlap(ds_list):
    """Restrict every dataset to the time window they all share.

    The shared window comes from ``_time_overlap``. For each dataset the
    window bounds are rebuilt in that dataset's own calendar type (taken
    from its first ``time`` value) before slicing along ``time``.

    Returns
    -------
    list
        The time-sliced datasets, in the same order as ``ds_list``.
    """
    dt_min, dt_max = _time_overlap(ds_list)

    def _clip(ds):
        # The first timestamp tells us which calendar class this dataset uses.
        like = ds.time.values[0]
        window = slice(
            to_compatible_datetime(dt_min, like),
            to_compatible_datetime(dt_max, like),
        )
        return ds.sel(time=window)

    return [_clip(ds) for ds in ds_list]

In [None]:
# Latitude-weighted mean of each source over the shared time window.
# Order: cam, cam2had, had.
ds_trend_list = list(
    map(calculate_mean_trend, filter_to_time_overlap([ds_cam, ds_cam2had, ds_had]))
)

In [None]:
# Evaluate all the lazy trend objects in a single dask call with the
# 'processes' scheduler, showing a progress bar (dt=10).
# NOTE: dask.compute returns a tuple with one entry per positional argument,
# so ds_trend_list is rebound to a 1-tuple wrapping the computed list (it is
# unwrapped in the next cell). Re-running this cell on its own therefore
# wraps the result one level deeper each time.
with ProgressBar(dt=10):
    ds_trend_list = dask.compute(ds_trend_list, scheduler='processes')

In [None]:
# Unwrap the 1-tuple returned by dask.compute to get the list of computed
# trends (order: cam, cam2had, had).
ds_trend_list_ = ds_trend_list[0]

In [None]:
# Latitude-weighted mean tas of run 0 for every source, drawn against the
# time-step index.
for source, ds in zip(['cam', 'cam2had', 'had'], ds_trend_list_):
    run0_tas = ds.tas.isel(run=0).values
    plt.plot(run0_tas, label=source, alpha=0.7)
plt.legend()

I don't know what I expected to learn from this, but you can see the extremes line up better for cam5$\rightarrow$hadgem3 and hadgem3.

In [None]:
# Zoom on time steps 2000-4999 of run 0 for the two CAM-based series (the last
# list entry, had, is left out) and remove each series' own mean so that the
# shapes can be compared directly.
for source, ds in zip(['cam', 'cam2had',], ds_trend_list_[:-1]):
    y = ds.tas.isel(run=0, time=slice(2000, 5000)).values
    anomaly = y - y.mean()
    plt.plot(anomaly, label=source, alpha=0.7)
plt.legend()

Looks like we still get some of the extreme global events mapped across. The outlying points are preserved but are shifted to better align with hadgem3.

In [None]:
# Histogram of the latitude-weighted mean tas values for each source.
bins = 100
# First three colours of the active matplotlib colour cycle, one per source.
colours = plt.rcParams['axes.prop_cycle'].by_key()['color'][:3]
for ds, source, c in zip(ds_trend_list_, ['cam', 'cam2had', 'had'], colours):
    tas = ds.tas
    # Translucent filled histogram (carries the legend label) ...
    tas.plot.hist(bins=bins, color=c, alpha=0.4, label=source)
    # ... plus an outline in the same colour so overlapping shapes stay readable.
    tas.plot.hist(bins=bins, color=c, histtype='step')
plt.legend()