/home/dfulu/repos/climateTranslation/UNIT


In [4]:
import xarray as xr
import numpy as np
import dask
import argparse
import sys
import os

from climatetranslation.unit.utils import get_config
from climatetranslation.unit.data import get_all_data_loaders
from climatetranslation.unit.trainer import UNIT_Trainer

import torch


# load post-processing - opposite of preprocessing
def post_process_constructor(config, a2b=None):
    """Build the function that undoes the data preprocessing on network output.

    Parameters
    ----------
    config : dict
        Experiment config. ``config['preprocess_method']`` selects the inverse
        transform; for ``'zeromean'`` the aggregate file
        ``config['agg_data_a']`` or ``config['agg_data_b']`` is loaded.
    a2b : bool-like, optional
        Translation direction. When truthy the ``'b'`` aggregate file is used,
        otherwise ``'a'``. If ``None`` (the default) the notebook-global
        ``args.a2b`` is read, preserving the original behaviour; pass it
        explicitly to avoid depending on a global defined in a later cell.

    Returns
    -------
    callable
        ``f(x)`` returning ``x`` with the inverse preprocessing applied.
    """
    if config['preprocess_method'] == 'zeromean':
        if a2b is None:
            # Backward-compatible fallback to the notebook-global `args`.
            a2b = args.a2b
        ab = 'b' if a2b else 'a'
        ds_agg = xr.load_dataset(config[f'agg_data_{ab}']).isel(height=0)
        # The mean field is constant, so build it once rather than on every call.
        # NOTE(review): assumes the order of data variables in the aggregate file
        # matches the 'variable' axis of the network output - confirm.
        mean = ds_agg.sel(variable='mean').to_array()

        def undo_zeromean(x):
            return x + mean
        return undo_zeromean

    def celcius_to_kelvin(x):
        # NOTE(review): adds 273 (not 273.15); kept on the assumption that the
        # preprocessing subtracted exactly 273 - confirm.
        return x + 273
    return celcius_to_kelvin


def network_translate_constructor(config, checkpoint, a2b):
    """Load a trained UNIT generator pair and return a single-sample translator.

    Parameters
    ----------
    config : dict
        Experiment config passed to ``UNIT_Trainer``.
    checkpoint : str or path-like
        Checkpoint whose state dict holds the generators under keys ``'a'``
        and ``'b'``.
    a2b : bool-like
        If truthy, encode with ``gen_a`` and decode with ``gen_b`` (a -> b);
        otherwise encode with ``gen_b`` and decode with ``gen_a`` (b -> a).

    Returns
    -------
    callable
        ``f(x)`` taking one sample (array-like, no batch axis) and returning the
        decoded sample as a numpy array, also without a batch axis.
    """
    # Load onto the CPU first: load_state_dict copies into the model anyway, so
    # this avoids a second copy of the weights on the GPU and does not depend on
    # which device the checkpoint was saved from.
    state_dict = torch.load(checkpoint, map_location='cpu')

    trainer = UNIT_Trainer(config)
    trainer.gen_a.load_state_dict(state_dict['a'])
    trainer.gen_b.load_state_dict(state_dict['b'])
    trainer.eval().cuda()
    encode = trainer.gen_a.encode if a2b else trainer.gen_b.encode  # encode function
    decode = trainer.gen_b.decode if a2b else trainer.gen_a.decode  # decode function

    def network_translate(x):
        # Add a length-1 batch axis for the network.
        batch = torch.from_numpy(np.array(x)[np.newaxis, ...]).cuda()
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            latent, _noise = encode(batch)
            out = decode(latent)
        # Drop the batch axis again so one sample in gives one sample out,
        # matching the per-sample core dims declared to xr.apply_ufunc.
        return out.detach().cpu().numpy()[0]
    return network_translate


def complete_translate_constructor(config, checkpoint, a2b):
    """Compose network translation and inverse preprocessing into one function.

    Parameters
    ----------
    config : dict
        Experiment config (see ``network_translate_constructor`` and
        ``post_process_constructor``).
    checkpoint : str or path-like
        Generator checkpoint to load.
    a2b : bool-like
        Translation direction; used for both the network and the
        post-processing so the two always agree.

    Returns
    -------
    callable
        ``translate(x)`` = ``post_process(network_translate(x))``.
    """
    network_translate = network_translate_constructor(config, checkpoint, a2b)
    # Pass the direction explicitly rather than relying on the global `args`.
    post_process = post_process_constructor(config, a2b=a2b)

    def translate(x):
        return post_process(network_translate(x))

    return translate

ModuleNotFoundError: No module named 'networks'

In [3]:
class Args:
    """Plain attribute container holding this notebook's run parameters.

    Stands in for command-line arguments (presumably mirroring an argparse
    interface - `argparse` is imported at the top of the notebook).

    Attributes: ``config``, ``output_zarr``, ``checkpoint``, ``a2b``, ``seed``,
    each stored exactly as passed.
    """

    def __init__(self, config, output_zarr, checkpoint, a2b, seed):
        # Paths: experiment config, destination zarr store, model checkpoint.
        self.config, self.output_zarr, self.checkpoint = config, output_zarr, checkpoint
        # Translation direction flag and RNG seed.
        self.a2b, self.seed = a2b, seed

# Run parameters, hard-coded here in place of command-line arguments.
# a2b=0 (falsy) selects the b -> a direction: domain-b data is encoded with gen_b
# and decoded with gen_a. Presumably a = HadGEM3 and b = CAM5, given the
# experiment name - confirm against the config.
# NOTE(review): output_zarr is an absolute machine-specific path; consider making
# it configurable.
args=Args(
    config=os.path.expanduser('~/model_outputs/outputs/hadgem3_to_cam5_nat-hist/config.yaml'),
    output_zarr="/datastore/cam5/nat_hist_to_hadgem3_zarr",
    checkpoint=os.path.expanduser('~/model_outputs/outputs/hadgem3_to_cam5_nat-hist/checkpoints/gen_00160000.pt'),
    a2b=0,
    seed=32213,
)


# Seed torch's CPU RNG and the current CUDA device's RNG for reproducibility.
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

# Load experiment setting
config = get_config(args.config)

# Setup model and data loader
# By constructing loader and extracting dataset from this we make sure all preprocessing is
# consistent
# loaders[0] provides domain a and loaders[2] domain b (presumably the
# train_a / train_b loaders - confirm against get_all_data_loaders).
loaders = get_all_data_loaders(config, downscale_consolidate=True)
# Source-domain dataset: domain a when translating a -> b, otherwise domain b.
ds = loaders[0].dataset.ds if args.a2b else loaders[2].dataset.ds
# Stack the data variables into a 'variable' dimension and order the axes so
# that (variable, lat, lon) are the trailing per-sample dims.
da = ds.to_array().transpose('run', 'time', 'variable', 'lat', 'lon')

# append number of variables
# These must be set before the trainer is built (inside
# complete_translate_constructor below), which reads them from `config`.
config['input_dim_a'] = loaders[0].dataset.shape[1]
config['input_dim_b'] = loaders[2].dataset.shape[1]
# The loaders are no longer needed; release them.
del loaders


# Build the end-to-end translator: network translation followed by inverse preprocessing.
translate = complete_translate_constructor(config, args.checkpoint, args.a2b)

Create weight file: bilinear_324x432_192x288_peri.nc
Remove file bilinear_324x432_192x288_peri.nc
using dimensions ('lat', 'lon') from data variable tas as the horizontal dimensions for this dataset.


In [4]:
# Apply the translator sample by sample over the (variable, lat, lon) core dims;
# vectorize=True loops over the remaining (run, time) dims. The result is a
# DataArray despite the `ds_` name.
# NOTE(review): only the first run and first 10 time steps are selected - looks
# like a debugging subset; widen the selection for a full translation.
ds_translated = xr.apply_ufunc(
    translate,
    da.isel({'run': slice(0, 1), 'time': slice(0, 10)}),
    input_core_dims=[['variable', 'lat', 'lon']],
    output_core_dims=[['variable', 'lat', 'lon']],
    output_dtypes=['float'],
    dask='parallelized',
    vectorize=True,
)



In [5]:
from dask.diagnostics import ProgressBar

# Split the 'variable' dimension back into separate data variables and write
# them to the zarr store, overwriting any existing store (mode='w'). The
# progress bar reports any dask computation triggered by the write.
with ProgressBar():
    (
        ds_translated
        .to_dataset(dim='variable')
        .to_zarr(args.output_zarr, mode='w', consolidated=True)
    )

[########################################] | 100% Completed |  0.5s
