In [1]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

from climatetranslation.unit.data import get_all_data_loaders, CustomTransformer
from climatetranslation.unit.utils import get_config

In [2]:
conf = get_config("/home/dfulu/model_outputs/outputs/hadgem3_to_cam5_nat-hist-v6/config.yaml")
conf

{'image_save_iter': 10000,
 'image_display_iter': 100,
 'display_size': 16,
 'snapshot_save_iter': 10000,
 'log_iter': 10,
 'max_iter': 1000000,
 'batch_size': 1,
 'weight_decay': 0.0001,
 'beta1': 0.5,
 'beta2': 0.999,
 'init': 'kaiming',
 'lr': 0.0001,
 'lr_policy': 'step',
 'step_size': 100000,
 'gamma': 0.5,
 'gan_w': 4,
 'recon_x_w': 10,
 'recon_h_w': 0,
 'recon_kl_w': 0.01,
 'recon_x_cyc_w': 10,
 'recon_kl_cyc_w': 0.01,
 'vgg_w': 0,
 'gen': {'dim': 64,
  'mlp_dim': 256,
  'style_dim': 8,
  'activ': 'relu',
  'n_downsample': 2,
  'n_res': 4,
  'pad_type': 'zero',
  'upsample': 'bilinear',
  'output_activ': ['relu', 'none', 'none', 'none']},
 'dis': {'dim': 64,
  'norm': 'none',
  'activ': 'lrelu',
  'n_layer': 4,
  'gan_type': 'lsgan',
  'num_scales': 3,
  'pad_type': 'reflect'},
 'num_workers': 5,
 'data_zarr_a': '/datadrive/hadgem3/nat_hist_zarr',
 'data_zarr_b': '/datadrive/cam5/nat_hist_zarr',
 'agg_data_a': '/datadrive/hadgem3/nat_hist_agg.nc',
 'agg_data_b': '/datadrive/cam5

# Check for reasonable looking results when loaded
- [x] units
- [x] zeromean
- [x] normalise
- [x] custom_nofield
- [x] custom_allfield
- [x] custom_tasfield
- [ ] custom_prfield

In [3]:
# units
#conf['preprocess_method'] = 'custom_nofield' # [units, zeromean, normalise, custom_nofield, custom_allfield, custom_tasfield, custom_prfield]
print(conf['preprocess_method'])

custom_nofield


In [4]:
loaders = get_all_data_loaders(conf)

KeyError: 'split_at'

Have a look at individual samples of preprocessed data and the time means

In [None]:
loaders[0].dataset.ds.lat.shape, loaders[2].dataset.ds.lat.shape

In [None]:
plt.figure(figsize=(16,8))
loaders[0].dataset.ds.tas.isel(run=0, time=0).plot()
plt.title(f"train dataset {conf['data_zarr_a']}")
plt.show()

In [None]:
# First time step of run 0 from loaders[2] (the domain-b train loader in the
# train_a, test_a, train_b, test_b ordering used by get_all_data_loaders).
plt.figure(figsize=(16,8))
# Don't bind the plot's return value: `ds` is used for an xarray Dataset in later cells.
loaders[2].dataset.ds.tas.isel(run=0, time=0).plot()
plt.title(f"dataset {conf['data_zarr_b']} train")
plt.show()

In [None]:
ds = loaders[2].dataset.ds.isel(run=0, time=slice(0, 700, 11)).compute()

In [None]:
import matplotlib.colors as colors
plt.figure(figsize=(12,4))
plt.subplot(121)
(ds.tas.isel(time=0)+70).plot(norm=colors.LogNorm())
plt.subplot(122)
ds.tas.isel(time=0).plot(norm=colors.Normalize(vmin=-5, vmax=5))

For the temperature data above, processed as $T_{ij} \rightarrow (T_{ij} - 273.15)/\sigma_{ij}$ (with $T_{ij}$ in kelvin), this looks like a bad transform. However, remember that the network is driven by the differences between predicted and real values. The real values have unit variance even though their means differ widely. This preprocessing maps 0 degrees Celsius to 0 everywhere. It also stops the poles, which have higher variance, from dominating the errors.

On the other hand, the convolutional filters may struggle with how spatially inhomogeneous this transform makes the data. So either `custom_tasfield` or `custom_nofield` could turn out to be the better choice.

In [None]:
plt.figure(figsize=(12,6))
ds.pr.isel(time=1).plot()

In [None]:
ds.pr.mean(dim='time').plot()

In [None]:
ds.pr.std(dim='time').plot()

In [None]:
ds.tas.std(dim='time').plot()

# Check for ability to translate back

In [None]:
# Rebuild the pre/post-processing transformer named by conf['preprocess_method']
# (units, zeromean, normalise, custom_nofield, custom_allfield, custom_tasfield,
# custom_prfield), fit it, and take one preprocessed time step from each domain
# so the inverse transform can be compared against the raw data below.
#
# Only CustomTransformer is imported in the first cell; import the other
# transformers here so the non-custom branches don't raise NameError.
from climatetranslation.unit.data import UnitModifier, ZeroMeaniser, Normaliser

# (tas_field_norm, pr_field_norm) for each CustomTransformer variant
CUSTOM_FIELD_NORMS = {
    'custom_allfield': (True, True),
    'custom_tasfield': (True, False),
    'custom_prfield': (False, True),
    'custom_nofield': (False, False),
}

preprocess_method = conf['preprocess_method']
if preprocess_method == 'zeromean':
    trans = ZeroMeaniser(conf, downscale_consolidate=True)
elif preprocess_method == 'normalise':
    trans = Normaliser(conf, downscale_consolidate=True)
elif preprocess_method == 'units':
    trans = UnitModifier(conf, downscale_consolidate=True)
elif preprocess_method in CUSTOM_FIELD_NORMS:
    tas_norm, pr_norm = CUSTOM_FIELD_NORMS[preprocess_method]
    trans = CustomTransformer(conf, downscale_consolidate=True, tas_field_norm=tas_norm, pr_field_norm=pr_norm)
else:
    raise ValueError(f"Unrecognised preprocess_method : {conf['preprocess_method']}")

# Time index 1 of the first run from loaders[0] (train a) and loaders[2] (train b)
ds_a = loaders[0].dataset.ds.isel(time=1, run=0).compute()
ds_b = loaders[2].dataset.ds.isel(time=1, run=0).compute()

trans.fit(trans.ds_agg_a, trans.ds_agg_b)

In [None]:
# preprocessed data
plt.figure(figsize=(24,10))
ax = plt.subplot(121)
ds_a.tas.plot(ax=ax)
ax = plt.subplot(122)
ds_b.tas.plot(ax=ax)

In [None]:
# de-processed data
trans.inverse_a(ds_a).tas.plot()

In [None]:
# original data
xr.open_zarr(conf['data_zarr_a']).isel(time=1, run=0, height=1).tas.plot()

## Test full train loading scheme

In [None]:
from climatetranslation.unit.utils import prepare_sub_folder, write_html, write_loss, get_config, write_2images, Timer
from climatetranslation.unit.data import get_all_data_loaders

import os
import sys
import shutil

import torch
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import tensorboardX


# replace the argparser in original
class blank:
    """Empty attribute holder that replaces the training script's argparse namespace."""

opts = blank()
opts.config = '/home/dfulu//repos/climateTranslation/climatetranslation/unit/configs/hadgem3_to_cam5_nat-hist-v6.yaml'
opts.output_path = '/home/dfulu/tmp'
opts.resume = False

cudnn.benchmark = True

# Load experiment setting
config = get_config(opts.config)

display_size = 2 # config['display_size']
config['batch_size'] = 2
#config['preprocess_method'] = 'custom_nofield' # works : 'normalise', 'units' no-works: 'custom_allfield'

# data loaders
train_loader_a, test_loader_a, train_loader_b, test_loader_b = get_all_data_loaders(config, downscale_consolidate=True)

In [None]:
# Selection of climate fields to display after a number of updates
def generate_n(generator, n):
    """Concatenate just enough batches from `generator` to cover `n` samples, then trim to `n`.

    Expects `generator` to expose a `batch_size` attribute (e.g. a torch DataLoader)
    and to yield tensors that can be concatenated along dim 0. If the generator runs
    out early, the result simply holds every sample it produced.
    """
    n_batches = (n - 1) // generator.batch_size + 1
    source = iter(generator)
    batches = []
    for _ in range(n_batches):
        # Check the batch budget before pulling, so no extra batch is fetched.
        try:
            batches.append(next(source))
        except StopIteration:
            break
    return torch.cat(batches)[:n]

def generate_batch(generator, *args):
    """Return the first batch yielded by `generator`.

    Extra positional arguments are accepted but ignored, so the returned batch
    has whatever size the loader produces. Raises IndexError if the generator is empty.
    """
    for first_batch in generator:
        return first_batch
    raise IndexError('list index out of range')

train_display_images_a = generate_batch(train_loader_a, display_size)#.cuda()
train_display_images_b = generate_batch(train_loader_b, display_size)#.cuda()
test_display_images_a  = generate_batch(test_loader_a, display_size)#.cuda()
test_display_images_b  = generate_batch(test_loader_b, display_size)#.cuda()

In [None]:
def examine(imgs):
    """Summarise a tensor as (contains_nan, max, min), each converted to a Python scalar."""
    contains_nan = torch.isnan(imgs).any().item()
    largest = imgs.max().item()
    smallest = imgs.min().item()
    return contains_nan, largest, smallest

print('set    ', 'isnan, max, min')
print('train a', examine(train_display_images_a))
print('train b', examine(train_display_images_b))
print('test  a', examine(test_display_images_a))
print('test  b', examine(test_display_images_b))

In [None]:
import torchvision.utils as vutils
import matplotlib.pyplot as plt
plt.figure(figsize=(24,8))
image_outputs = [images[:,:3].expand(-1, 3, -1, -1) for images in [train_display_images_a, train_display_images_b]]
image_tensor = torch.cat([images[:display_size] for images in image_outputs], 0)
image_grid = vutils.make_grid(image_tensor.data, nrow=config['batch_size'], padding=0, normalize=True)
plt.imshow(image_grid.permute(1, 2, 0))

In [None]:

# A small number of datetimes have all-NaN data
def all_nan_last_two_axis_any_channel(x):
    """Return a 0-d bool tensor that is True if any slice x[..., :, :] is entirely NaN.

    The last two axes are reduced with `all`, then every remaining axis with `any`.
    For a (batch, channel, H, W) tensor this flags a batch in which any channel of
    any sample is an all-NaN map.
    """
    nan_mask = torch.isnan(x)
    fully_nan = torch.all(torch.all(nan_mask, dim=-1), dim=-1)
    return torch.any(fully_nan)

# Start training
# This cell does not train: it pulls the first pair of batches from the two train
# loaders in which no slice over the last two axes is entirely NaN, moves that
# pair to the GPU, and stops.
# NOTE(review): `trainer` and `checkpoint_directory` are not defined in any cell
# visible here; this line only works while opts.resume is False.
iterations = trainer.resume(checkpoint_directory, hyperparameters=config) if opts.resume else 0


for it, (images_a, images_b) in enumerate(zip(train_loader_a, train_loader_b)):
    # Skip NaN fields
    if all_nan_last_two_axis_any_channel(images_a) or all_nan_last_two_axis_any_channel(images_b):
        print('Skipped on it = {}'.format(it))
        continue

    # .cuda() requires a CUDA device to be available
    images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach()
    break    

In [None]:
images_a.shape, images_b.shape

In [None]:
import torchvision.utils as vutils
import matplotlib.pyplot as plt
plt.figure(figsize=(24,8))
image_outputs = [images[:,:3].cpu().expand(-1, 3, -1, -1) for images in [images_a, images_b]]
image_tensor = torch.cat([images for images in image_outputs], 0)
image_grid = vutils.make_grid(image_tensor.data, nrow=config['batch_size'], padding=0, normalize=True)
plt.imshow(image_grid.permute(1, 2, 0))

## Test translation

In [None]:
%cd ../../climatetranslation/unit

In [None]:
ls -l /home/dfulu/model_outputs/outputs/hadgem3_to_cam5_nat-hist-v6/checkpoints

In [None]:
import xarray as xr
import numpy as np
import progressbar
import torch
import matplotlib.pyplot as plt
import holoviews as hv

hv.extension('bokeh')

from translate import network_translate_constructor
from utils import get_config
from data import (get_dataset, 
                  CustomTransformer, 
                  UnitModifier, 
                  ZeroMeaniser, 
                  Normaliser
)

In [None]:
# stand in for arg-parser
class argsob:
    """Empty attribute holder used in place of the translate script's argparse namespace."""

def _build_prepost_transformer(config):
    """Construct the (unfitted) pre/post-processing transformer named by config['preprocess_method'].

    Raises ValueError for an unrecognised method name.
    """
    method = config['preprocess_method']
    if method == 'zeromean':
        return ZeroMeaniser(config, downscale_consolidate=True)
    if method == 'normalise':
        return Normaliser(config, downscale_consolidate=True)
    if method == 'units':
        return UnitModifier(config, downscale_consolidate=True)
    # (tas_field_norm, pr_field_norm) for each CustomTransformer variant
    custom_flags = {
        'custom_allfield': (True, True),
        'custom_tasfield': (True, False),
        'custom_prfield': (False, True),
        'custom_nofield': (False, False),
    }
    if method in custom_flags:
        tas_field_norm, pr_field_norm = custom_flags[method]
        return CustomTransformer(config, downscale_consolidate=True,
                                 tas_field_norm=tas_field_norm, pr_field_norm=pr_field_norm)
    # Report the method from *this* config, not the notebook-global `conf`.
    raise ValueError(f"Unrecognised preprocess_method : {method}")


def get_translation(args):
    """Translate a strided sample of one domain's data through a trained generator.

    Parameters
    ----------
    args : object with attributes
        seed : int -- seed for torch's CPU and CUDA RNGs.
        config : str -- path to the experiment YAML config.
        checkpoint : str -- path to the generator checkpoint.
        x2x : str -- direction such as 'b2a'; the first character selects the
            source domain, the last character the target domain.
        n_times : int, optional -- number of time steps to sample (default 100).
        time_stride : int, optional -- step between sampled time indices (default 11).

    Returns
    -------
    ds_translated : xarray.Dataset
        Network output, inverse-transformed with the target domain's inverse.
    ds_prepost : xarray.Dataset
        The same input sample preprocessed with the source transform and then
        inverse-transformed with the target inverse, with no network in between.

    Raises
    ------
    ValueError
        If config['preprocess_method'] is not recognised.
    """
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Load experiment setting; the generator's upsampling mode is overridden to 'nearest'.
    config = get_config(args.config)
    config['gen']['upsample'] = 'nearest'

    # load the datasets
    ds_a = get_dataset(config['data_zarr_a'], config['level_vars'])
    ds_b = get_dataset(config['data_zarr_b'], config['level_vars'])

    # load pre/post processing transformer and fit it on both datasets
    prepost_trans = _build_prepost_transformer(config)
    prepost_trans.fit(ds_a, ds_b)

    pre_trans = prepost_trans.transform_a if args.x2x[0] == 'a' else prepost_trans.transform_b
    post_trans = prepost_trans.inverse_a if args.x2x[-1] == 'a' else prepost_trans.inverse_b

    # load model; input dims are set to the number of data variables in each dataset
    config['input_dim_a'] = len(ds_a.keys())
    config['input_dim_b'] = len(ds_b.keys())
    net_trans = network_translate_constructor(config, args.checkpoint, args.x2x)

    ds = ds_a if args.x2x[0] == 'a' else ds_b

    # n_times time steps, every time_stride-th one, from the first run only
    n_times = getattr(args, 'n_times', 100)
    time_stride = getattr(args, 'time_stride', 11)
    ds_sample = ds.isel(time=slice(0, time_stride * n_times, time_stride), run=slice(0, 1))

    # pre-process and stack the variables into one array along 'variable'
    da_pre = (
        pre_trans(ds_sample)
        .to_array()
        .transpose('run', 'time', 'variable', 'lat', 'lon')
    )

    # transform through network; with vectorize=True, net_trans is called once per
    # (run, time) slice with core dims (variable, lat, lon)
    da_post = xr.apply_ufunc(
        net_trans,
        da_pre,
        vectorize=True,
        dask='allowed',
        output_dtypes=['float'],
        input_core_dims=[['variable', 'lat', 'lon']],
        output_core_dims=[['variable', 'lat', 'lon']]
    )

    # fix chunking
    da_post = da_post.chunk(dict(run=1, time=1, lat=-1, lon=-1))

    # post-process
    ds_translated = post_trans(da_post.to_dataset(dim='variable')).compute()

    # baseline: pre-process then post-process, skipping the network
    ds_prepost = post_trans(da_pre.to_dataset(dim='variable')).compute()

    return ds_translated, ds_prepost

In [None]:
def plot2(tn, var='tas', vmin=None, vmax=None, **kwargs):
    """Plot one time step of `var` for the v5 and v6 translations and the pre-then-post baseline.

    The three panels are stacked vertically on a shared colour scale. This function
    reads the notebook globals ``ds_translated_v5``, ``ds_translated_v6`` and
    ``da_prepost``. Both translation cells assign ``da_prepost``, so it holds
    whichever of them ran last.

    Parameters
    ----------
    tn : int
        Time index passed to ``.isel(time=tn)``.
    var : str
        Data variable to plot.
    vmin, vmax : float, optional
        Colour limits. When None, they are taken from the min/max over the three
        fields *after* the ``sel(**kwargs)`` selection, so the scale matches the
        region actually drawn.
    **kwargs
        Passed to ``.sel``, e.g. ``lat=slice(-90, 90)``.
    """
    panels = [
        ("v5", ds_translated_v5),
        ("v6", ds_translated_v6),
        ("pre then post", da_prepost),
    ]
    fields = [dataset.isel(time=tn).sel(**kwargs)[var] for _, dataset in panels]

    if vmin is None:
        vmin = min(float(field.min()) for field in fields)
    if vmax is None:
        vmax = max(float(field.max()) for field in fields)

    plt.figure(figsize=(12, 21))
    for position, ((title, _), field) in enumerate(zip(panels, fields), start=1):
        plt.subplot(3, 1, position)
        field.plot(vmin=vmin, vmax=vmax)
        plt.title(title)
    plt.tight_layout()
    plt.show()

In [7]:
!head /home/dfulu/model_outputs/outputs/hadgem3_to_cam5_nat-hist-v6/checkpoints/gen_00110000.pt

��
   little_endianq�Xq (X   protocol_versionqM�X
   type_sizesq}q(X   shortqKX   intqKX   longqKuu.�}q (X   aqccollections
OrderedDict
q)Rq(X   enc.model.0.conv.weightqctorch._utils
_rebuild_tensor_v2
q((X   storageqctorch
FloatStorage
qX   94301127180880X   cuda:0q	M 1Ntq
tqRqX   enc.model.0.conv.biasqh((hhX   94300235451152qX   cuda:0qK@NtqQK K@�qK�q�h)RqtqRqX   enc.model.1.conv.weightqh((hhX   94300027945792qX   cuda:0qJ   NtqQK (K�K@KKtq(M KKKtq�h)Rqtq Rq!X   enc.model.1.conv.biasq"h((hhX   94300125215184q#X   cuda:0q$K�Ntq%QK K��q&K�q'�h)Rq(tq)Rq*X   enc.model.2.conv.weightq+h((hhX   94300027946816q,X   cuda:0q-J  Ntq.QK (M K�KKtq/(MKKKtq0�h)Rq1tq2Rq3X   enc.model.2.conv.biasq4h((hhX   94300001134944q5X   cuda:0q6M Ntq7QK M �q8K�q9�h)Rq:tq;Rq<X'   enc.model.3.model.0.model.0.conv.weightq=h((hhX   94300027967328q>X   cuda:0q?J  	 Ntq@QK (M M KKtqA(M 	K	KKtqB�h)RqCtqDRqEX

In [None]:
model_n = '00110000'
args = argsob()
args.config = '/home/dfulu//repos/climateTranslation/climatetranslation/unit/configs/hadgem3_to_cam5_nat-hist-v6.yaml'
args.checkpoint = f'/home/dfulu/model_outputs/outputs/hadgem3_to_cam5_nat-hist-v6/checkpoints/gen_{model_n}.pt'
args.x2x = 'b2a'
args.seed=1

ds_translated_v6, da_prepost = get_translation(args)

In [None]:
model_n = '00069000'
args = argsob()
args.config = '/home/dfulu//repos/climateTranslation/climatetranslation/unit/configs/hadgem3_to_cam5_nat-hist-v5.yaml'
args.checkpoint = f'/home/dfulu/model_outputs/outputs/hadgem3_to_cam5_nat-hist-v5/checkpoints/gen_{model_n}.pt'
args.x2x = 'b2a'
args.seed=1

ds_translated_v5, da_prepost = get_translation(args)

In [None]:
#ds_translated_v6 = ds_translated_v6.mean(dim='time', keepdims=True)
#ds_translated_v5 = ds_translated_v5.mean(dim='time', keepdims=True)
#da_prepost = da_prepost.mean(dim='time', keepdims=True)

In [None]:
plot2(0, var='tasmax',  lat=slice(-90, 90), lon=slice(0, 360))

In [None]:
plot2(10, lat=slice(-90, 90), lon=slice(None, None))

In [None]:
tn=10

tmin = float(ds_translated.isel(time=tn).tas.min())
tmax = float(ds_translated.isel(time=tn).tas.max())

plt.figure(figsize=(20, 10))
ds_translated.isel(time=tn).sel(lat=slice(-90, 90)).tas.plot(vmin=tmin, vmax=tmax)
plt.show()

In [None]:
tn=10

tmin = min(float(ds_sample_regrid.isel(time=tn).tas.min()), float(ds_translated.isel(time=tn).tas.min()))
tmax = max(float(ds_sample_regrid.isel(time=tn).tas.max()), float(ds_translated.isel(time=tn).tas.max()))

plt.figure(figsize=(20, 6))
plt.subplot(121)
ds_sample_regrid.isel(time=tn).tas.plot(vmin=tmin, vmax=tmax)
plt.title("regridded sample")
plt.subplot(122)
ds_translated.isel(time=tn).tas.plot(vmin=tmin, vmax=tmax)
plt.title("translated sample")
plt.tight_layout()
plt.show()

In [None]:
tn = 0
pmin = min(float(ds_sample_regrid.isel(time=tn).pr.min()), float(ds_translated.isel(time=tn).pr.min()))
pmax = max(float(ds_sample_regrid.isel(time=tn).pr.max()), float(ds_translated.isel(time=tn).pr.max()))

plt.figure(figsize=(14, 4))
plt.subplot(121)
ds_sample_regrid.isel(time=tn).pr.plot(vmin=pmin, vmax=pmax)
plt.subplot(122)
ds_translated.isel(time=tn).pr.plot(vmin=pmin, vmax=pmax)
plt.tight_layout()
plt.show()

In [None]:
# Compare one preprocessed input field with the corresponding network output.
# NOTE(review): da_pre / da_post are locals of get_translation and are not
# assigned at top level in any visible cell -- this relies on earlier kernel state.
tn = 0
var_idx = 0  # index along the 'variable' dimension; used for both limits and plots

pre_field = da_pre.isel(time=tn, variable=var_idx)
post_field = da_post.isel(time=tn, variable=var_idx)

# shared colour limits across both panels
p4min = min(float(pre_field.min()), float(post_field.min()))
p4max = max(float(pre_field.max()), float(post_field.max()))

plt.figure(figsize=(14, 4))
plt.subplot(121)
pre_field.plot(vmin=p4min, vmax=p4max)
plt.subplot(122)
post_field.plot(vmin=p4min, vmax=p4max)
plt.tight_layout()
plt.show()

In [None]:
(ds_sample_regrid.tas-ds_translated.tas).isel(time=0).plot()

In [None]:
def get_images(da_list, **kwargs):
    """Wrap each DataArray as a HoloViews Dataset and convert it to hv.Image over (lon, lat).

    `kwargs` are applied to every resulting element via `.options`. Returns a list
    in the same order as `da_list`.
    """
    return [
        hv.Dataset(da).to(hv.Image, ['lon', 'lat']).options(**kwargs)
        for da in da_list
    ]

def animate_compare(da_a, da_b, name_a=None, name_b=None):
    """Build a HoloViews layout showing `da_a`, `da_b` and their difference side by side.

    Parameters
    ----------
    da_a, da_b : xarray.DataArray
        Fields with 'lon' and 'lat' dimensions. Other dimensions (e.g. time) are
        left to HoloViews when the images are built.
    name_a, name_b : str, optional
        Panel titles; default to 'a' and 'b'.

    Returns
    -------
    The layout ``image_a + image_b + diff``. The diff panel (named 'diff') uses the
    'bwr' colormap on a range symmetric about zero.
    """
    name_a = 'a' if name_a is None else name_a
    name_b = 'b' if name_b is None else name_b

    diff = (da_a - da_b).rename('diff')

    images1 = get_images(
        [da_a, da_b],
        height=180,
        width=360,
        cmap='viridis',
        colorbar=True
    )

    # colorbars are matched if the variables have the same names
    images2 = get_images(
        [diff],
        height=180,
        width=360,
        cmap='bwr',
        colorbar=True
    )

    # Symmetric limits covering the largest absolute difference, negative OR positive,
    # so zero sits at the centre of the diverging colormap.
    vmax = float(abs(diff).max())
    vmin = -vmax

    image = images1[0].opts(title=name_a) \
          + images1[1].opts(title=name_b) \
          + images2[0].opts(title=f"{name_a} - {name_b}").redim.range(diff=(vmin, vmax))

    return image


image = animate_compare(ds_sample_regrid.isel(time=slice(0, 30)).tas, 
                        ds_translated.isel(time=slice(0, 30)).tas, 
                        'original', 
                        'translated')

In [None]:
%%output holomap='scrubber'
# ['gif', ]
image

In [None]:
image = animate_compare(ds_sample_regrid.isel(time=slice(0, 30)).pr, 
                        ds_translated.isel(time=slice(0, 30)).pr, 
                        'original', 
                        'translated')

In [None]:
%%output holomap='scrubber'
image

In [None]:
image = animate_compare(da_pre.isel(time=slice(0, 30), variable=0),
                        da_post.isel(time=slice(0, 30), variable=0),
                        'original', 
                        'translated')

In [None]:
%%output holomap='scrubber'
image