In [None]:
import os, sys
import argparse
from datetime import datetime as dt
import numpy as np
import xarray as xr
from unet_model import build_unet
from wgan_model import WGAN 
from wgan_model import critic_model
from handle_data_unet import HandleUnetData

In [None]:
# Reload `wgan_model` so that edits to the module are picked up without restarting the kernel.
# Every name previously pulled in via `from wgan_model import ...` must be re-imported:
# `importlib.reload` does not re-bind such names, so `critic_model` would otherwise still
# refer to the object from the old module while `WGAN` comes from the new one.
import importlib
importlib.reload(sys.modules['wgan_model'])
from wgan_model import WGAN, critic_model

In [None]:
# Set directories and (hyper-)parameters for the WGAN
datadir = "/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_ifs/netcdf_data/all_files/"
outdir = "/p/project/deepacf/maelstrom/langguth1/downscaling_jsc_repo/downscaling_unet/trained_models"

z_branch = True             # passed to WGAN under the "z_branch" key

lr_gen = 5.e-05             # passed as "lr_gen"; presumably the generator's initial learning rate
lr_gen_end = lr_gen/10.     # one tenth of lr_gen; presumably the end value when lr_decay is True — confirm in wgan_model
lr_critic = 1.e-06          # passed as "lr_critic"
lr_decay = True
nepochs = 1                 # passed as "train_epochs"
d_steps = 6                 # passed as "d_steps"; presumably critic updates per generator update — confirm in wgan_model


# Read training and validation data (no test data is loaded in this notebook)
ds_train = xr.open_dataset(os.path.join(datadir, "era5_to_ifs_train_corrected.nc"))
ds_val = xr.open_dataset(os.path.join(datadir, "era5_to_ifs_val_corrected.nc"))

print("Datasets for training and validation loaded.")

wgan_model = WGAN(build_unet, critic_model,
                  {"lr_decay": lr_decay, "lr_gen": lr_gen, "lr_critic": lr_critic, "lr_gen_end": lr_gen_end,
                   "train_epochs": nepochs, "d_steps": d_steps, "z_branch": z_branch})

In [None]:
# prepare data
def reshape_ds(ds):
    """Turn a Dataset into one DataArray whose trailing dimension enumerates the data variables.

    The data variables of ``ds`` are combined along a new ``"variables"`` dimension via
    ``to_array``; ``transpose(..., "variables")`` then moves that dimension to the end while
    keeping the relative order of all other dimensions.

    :param ds: dataset whose data variables should be stacked
    :return: DataArray with ``"variables"`` as its last dimension
    """
    return ds.to_array(dim="variables").transpose(..., "variables")

da_train, da_val = reshape_ds(ds_train), reshape_ds(ds_val)

# Normalise with HandleUnetData.z_norm_data over `norm_dims`. The statistics are computed on
# the training data only (return_stat=True) and reused for the validation data, so no
# information from the validation set enters the normalisation.
norm_dims = ["time", "lat", "lon"]
da_train, mu_train, std_train = HandleUnetData.z_norm_data(da_train, dims=norm_dims, return_stat=True)
da_val = HandleUnetData.z_norm_data(da_val, mu=mu_train, std=std_train)
print("Data prepared successfully!")

In [None]:
# Cast training and validation data to float32 (training first, then validation) and hand them
# to `compile`, which returns the two objects later passed to `fit`.
print("Start compiling WGAN-model.")
train_iter, val_iter = wgan_model.compile(*(da.astype(np.float32) for da in (da_train, da_val)))

In [None]:
# Train the WGAN on the objects returned by `compile`; the return value of `fit`
# is kept in `history` for later inspection.
print("Start training of WGAN...")
history = wgan_model.fit(train_iter, val_iter)

In [None]:
# Derive the model name from the hyperparameters set in the configuration cell, so that the
# output directory reflects the run that was actually trained (the former hard-coded name
# said "lr1e-05" although lr_gen = 5e-05). `.0e` formatting yields e.g. "5e-05".
model_name = f"wgan_lr{lr_gen:.0e}_epochs{nepochs}_opt_split_era5_ifs"

# Save into the configured output directory rather than a path relative to the notebook's
# current working directory.
savedir = os.path.join(outdir, model_name)
os.makedirs(savedir, exist_ok=True)

In [None]:
# NOTE(review): leftover debugging cell. The two commented-out lines below are a disabled
# shallow copy of `wgan_model`; the print only shows the object returned by `compile`.
# Consider removing this cell from the final notebook.
#import copy
#wgan_model_save = copy.copy(wgan_model)

print(train_iter)

The solution for saving the model, used in the cell below, was posted here: https://www.reddit.com/r/tensorflow/comments/szqsgd/keras_how_to_save_the_vae_from_the_official/

In [None]:
# Save the generator and the critic as two separate models inside `savedir`.
# The file names are derived from `model_name`, so they match the directory and the trained
# configuration. The previous hard-coded names said "epochs30" although nepochs = 1, and they
# saved the critic as "..._gen_critic".
wgan_model.generator.save(os.path.join(savedir, f"{model_name}_gen"))
wgan_model.critic.save(os.path.join(savedir, f"{model_name}_critic"))