# Processed data

In [None]:
%reload_ext autoreload

%autoreload 2

import os

import iris
import matplotlib.pyplot as plt
import metpy
import numpy as np
import xarray as xr

from ml_downscaling_emulator import UKCPDatasetMetadata
from ml_downscaling_emulator.utils import cp_model_rotated_pole, platecarree

## Vorticity

In [None]:
# Root of the derived "moose" data. os.environ[...] raises a KeyError naming the
# missing variable; os.getenv would return None and only fail later inside
# os.path.join with an unrelated-looking TypeError.
moose_dir = os.path.join(os.environ["DERIVED_DATA"], "moose")

# Daily vorticity850 over the London domain at the "2.2km-coarsened-4x" resolution.
vort_meta = UKCPDatasetMetadata(
    moose_dir,
    frequency="day",
    domain="london",
    resolution="2.2km-coarsened-4x",
    variable="vorticity850",
)

# Open all files reported by existing_filepaths() as one combined dataset.
vort_ds = xr.open_mfdataset(vort_meta.existing_filepaths())
vort_ds

In [None]:
# Map of the first vorticity850 timestep on the rotated-pole projection,
# with coastlines overlaid for orientation.
first_vort = vort_ds["vorticity850"].isel(time=0)
ax = plt.axes(projection=cp_model_rotated_pole)
first_vort.plot(ax=ax)
ax.coastlines()

In [None]:
# Normalised (density) histogram of every vorticity850 value, pooled over
# all timesteps and grid cells.
vort_values = vort_ds["vorticity850"]
hist = vort_values.plot.hist(density=True, bins=50)

## Target pr

In [None]:
# Daily precipitation ("pr") over the London domain at 2.2km resolution.
pr_meta = UKCPDatasetMetadata(
    moose_dir,
    frequency="day",
    domain="london",
    resolution="2.2km",
    variable="pr",
)

# Rename the variable to target_<variable> so it stays distinguishable from
# the predictor once the datasets are combined below.
target_var = f'target_{pr_meta.variable}'
pr_ds = xr.open_mfdataset(pr_meta.existing_filepaths())
pr_ds = pr_ds.rename({pr_meta.variable: target_var})
pr_ds

In [None]:
# Normalised histogram of every target_pr value (all timesteps and grid cells).
# Call .plot.hist explicitly instead of relying on DataArray.plot() choosing a
# histogram for >2-D data, and use density=True to match the other histograms.
pr_hist = pr_ds['target_pr'].plot.hist(bins=50, density=True)

In [None]:
# Map of the first target_pr timestep on the rotated-pole projection,
# with coastlines overlaid for orientation.
first_pr = pr_ds["target_pr"].isel(time=0)
ax = plt.axes(projection=cp_model_rotated_pole)
first_pr.plot(ax=ax)
ax.coastlines()

In [None]:
# Merge predictor (vorticity850) and target (target_pr) into one dataset,
# keeping only coordinate values present in both inputs (join="inner").
# NOTE(review): vort_ds is loaded at "2.2km-coarsened-4x" and pr_ds at "2.2km";
# if their spatial coordinates differ, the inner join intersects them —
# confirm the resulting grid is the intended one.
combined_ds = xr.combine_by_coords(
    [vort_ds, pr_ds],
    compat='no_conflicts',
    combine_attrs="drop_conflicts",
    coords="all",
    join="inner",
    data_vars="all",
)

# Season index per timestep: months 12,1,2 -> 0; 3-5 -> 1; 6-8 -> 2; 9-11 -> 3
# (% 12 wraps December onto the same group as January/February).
season_index = combined_ds['time.month'].data % 12 // 3
combined_ds = combined_ds.assign_coords(season=("time", season_index))
combined_ds

## Dataset splits (train/val/test, from sample)

In [None]:
# Splits to open. The "extreme_val" and "extreme_test" splits are currently not loaded.
splits = ["train", "val", "test"]

# Prepared dataset directory under the moose derived-data root; one NetCDF file per split.
dataset_dir = os.path.join(moose_dir, "nc-datasets", "2.2km-coarsened-8x_london_random_london_8x_vorticity850_random")

data_splits = {}
for split in splits:
    data_splits[split] = xr.open_dataset(os.path.join(dataset_dir, f"{split}.nc"))

In [None]:
# For each split, plot one randomly chosen timestep: target_pr in the left
# column and vorticity850 in the right column, one row per split.
# The grid is sized from len(splits) so it keeps working if splits change, and
# squeeze=False keeps `axes` 2-D even when only a single split is listed.
fig, axes = plt.subplots(
    nrows=len(splits),
    ncols=2,
    figsize=(12, 5 * len(splits)),
    squeeze=False,
    subplot_kw=dict(projection=cp_model_rotated_pole),
)

for row_axes, split in zip(axes, splits):
    split_ds = data_splits[split]
    # Unseeded: a different timestep may be drawn each time the cell runs.
    ts = np.random.choice(split_ds.time.values)
    sample = split_ds.sel(time=ts)

    for ax, var in zip(row_axes, ['target_pr', 'vorticity850']):
        sample[var].plot(ax=ax)
        ax.coastlines()


In [None]:
# Overlay the normalised value distributions of each split, one panel for the
# target (target_pr) and one for the predictor (vorticity850), so the splits
# can be compared by eye.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(24, 5))

panels = [('target_pr', "Target pr"), ('vorticity850', "Vorticity@850")]
for ax, (var, title) in zip(axes, panels):
    for split in splits:
        data_splits[split][var].plot.hist(ax=ax, bins=50, density=True, alpha=0.3, label=split)
    ax.set_title(title)
    ax.legend()