In [1]:
import numpy as np
import xarray as xr
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.layers import *
import tensorflow.keras.backend as K
from src.score import *
import re

import generate_data as gd

In [2]:
# resnet/whole_mm_indiv_data.py using 9 blocks

In [3]:
# Root directory of the WeatherBench data. The trailing slash matters: the
# paths below are formed by appending directly to this string.
DATADIR = '/rds/general/user/mc4117/home/WeatherBench/data/'

# Reference fields for 'z' (geopotential_500 directory) and 't'
# (temperature_850 directory). load_test_data is presumably provided by the
# star import from src.score -- confirm.
z500_valid = load_test_data(DATADIR + 'geopotential_500', 'z')
t850_valid = load_test_data(DATADIR + 'temperature_850', 't')

# Both reference variables combined into one dataset.
valid = xr.merge([z500_valid, t850_valid])


In [4]:
# Variable selection for the data generator.
# Each entry maps a long name (also the sub-directory under DATADIR) to
# (short variable name, list of levels or None); the 'constants' entry instead
# lists the names of the constant fields.
var_dict = {
    'geopotential': ('z', [500, 850]),
    'temperature': ('t', [500, 850]),
    'specific_humidity': ('q', [850]),
    '2m_temperature': ('t2m', None),
    'potential_vorticity': ('pv', [50, 100]),
    'constants': ['lsm', 'orography']
}

# The generator expects every variable in one dataset: open one multi-file
# dataset per var_dict key, then merge them.
ds = [
    xr.open_mfdataset(f'{DATADIR}/{name}/*.nc', combine='by_coords')
    for name in var_dict
]
ds_whole = xr.merge(ds, compat='override')

# Split by year: 1979-2016 is the pool for training/validation,
# 2017-2018 is held out for testing.
ds_train = ds_whole.sel(time=slice('1979', '2016'))
ds_test = ds_whole.sel(time=slice('2017', '2018'))

class DataGenerator(keras.utils.Sequence):
    """Keras Sequence that serves normalized (input, target) batches as numpy arrays."""

    def __init__(self, ds, var_dict, lead_time, batch_size=32, shuffle=True, load=True,
                 mean=None, std=None, output_vars=None):
        """
        Data generator for WeatherBench data.
        Template from https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
        Args:
            ds: Dataset containing all variables
            var_dict: Dictionary of the form {'long_name': ('var', levels)}. Use None for
                levels if data is of single level. The special key 'constants' maps to a
                list of variable names that are broadcast along time.
            lead_time: Offset between input and target, applied as a number of steps along
                the time axis (this equals hours only if the data is hourly).
            batch_size: Batch size
            shuffle: bool. If True, data is shuffled.
            load: bool. If True, dataset is loaded into RAM.
            mean: If None, compute mean from data.
            std: If None, compute standard deviation from data.
            output_vars: Optional list of regex patterns, matched with re.match against the
                channel names ('<var>_<level>', or the bare variable name for single-level
                and constant fields). Matching channels form the target y. If None, all
                channels are used.
        """

        self.ds = ds
        self.var_dict = var_dict
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.lead_time = lead_time

        data = []
        level_names = []
        # Dummy level coordinate given to fields that have no 'level' dimension so that
        # everything can be concatenated along 'level'.
        generic_level = xr.DataArray([1], coords={'level': [1]}, dims=['level'])
        for long_var, params in var_dict.items():
            if long_var == 'constants':
                for var in params:
                    # Constants get both a level and a time dimension added.
                    data.append(ds[var].expand_dims(
                        {'level': generic_level, 'time': ds.time}, (1, 0)
                    ))
                    level_names.append(var)
            else:
                var, levels = params
                try:
                    data.append(ds[var].sel(level=levels))
                    level_names += [f'{var}_{level}' for level in levels]
                except ValueError:
                    # Fallback used when the level selection raises ValueError
                    # (single-level variables): add the dummy level instead.
                    data.append(ds[var].expand_dims({'level': generic_level}, 1))
                    level_names.append(var)

        # Channels-last layout: (time, lat, lon, level).
        self.data = xr.concat(data, 'level').transpose('time', 'lat', 'lon', 'level')
        self.data['level_names'] = xr.DataArray(
            level_names, dims=['level'], coords={'level': self.data.level})
        if output_vars is None:
            # Use this generator's own channels. The previous code read the notebook
            # global `dg_valid` here, which fails (or silently uses another generator's
            # channel count) whenever output_vars is omitted.
            self.output_idxs = list(range(len(self.data.level)))
        else:
            self.output_idxs = [i for i, l in enumerate(self.data.level_names.values)
                                if any([bool(re.match(o, l)) for o in output_vars])]

        # Normalize with one mean/std per channel, taken over time and space.
        self.mean = self.data.mean(('time', 'lat', 'lon')).compute() if mean is None else mean
#         self.std = self.data.std('time').mean(('lat', 'lon')).compute() if std is None else std
        self.std = self.data.std(('time', 'lat', 'lon')).compute() if std is None else std
        self.data = (self.data - self.mean) / self.std

        # The last lead_time steps have no target, so they cannot start a sample.
        self.n_samples = self.data.isel(time=slice(0, -lead_time)).shape[0]
        self.init_time = self.data.isel(time=slice(None, -lead_time)).time
        self.valid_time = self.data.isel(time=slice(lead_time, None)).time

        self.on_epoch_end()

        # For some weird reason calling .load() earlier messes up the mean and std computations
        if load:
            print('Loading data into RAM')
            self.data.load()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.ceil(self.n_samples / self.batch_size))

    def __getitem__(self, i):
        'Generate one batch of data'
        idxs = self.idxs[i * self.batch_size:(i + 1) * self.batch_size]
        X = self.data.isel(time=idxs).values
        # Targets are the selected output channels lead_time steps after each input.
        y = self.data.isel(time=idxs + self.lead_time, level=self.output_idxs).values
        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.idxs = np.arange(self.n_samples)
        if self.shuffle:
            np.random.shuffle(self.idxs)
            
import re

# Settings shared by all three generators.
bs = 32
lead_time = 72
output_vars = ['z_500', 't_850']

# Training generator: no mean/std passed, so it computes its own statistics.
dg_train = DataGenerator(
    ds_train.sel(time=slice('1979', '2013')),
    var_dict,
    lead_time,
    batch_size=bs,
    load=True,
    output_vars=output_vars,
)

# Validation generator: reuses the training mean and std, unshuffled.
dg_valid = DataGenerator(
    ds_train.sel(time=slice('2015', '2016')),
    var_dict,
    lead_time,
    batch_size=bs,
    mean=dg_train.mean,
    std=dg_train.std,
    shuffle=False,
    output_vars=output_vars,
)

# Test generator, also normalized with the training statistics.
# Important: shuffle must be False!
dg_test = DataGenerator(
    ds_test,
    var_dict,
    lead_time,
    batch_size=bs,
    mean=dg_train.mean,
    std=dg_train.std,
    shuffle=False,
    output_vars=output_vars,
)

Loading data into RAM
Loading data into RAM
Loading data into RAM


In [6]:
# Saved validation-period ensemble predictions, one array per model variant.
# The first axis is indexed below to pick the 'z' / 't' field.
pred_ensemble_temp = np.load(
    '/rds/general/user/mc4117/ephemeral/saved_pred/whole_res_valid_do_9_temp_[300, 400, 500, 600, 700, 850].npy'
)
pred_ensemble_geo = np.load(
    '/rds/general/user/mc4117/ephemeral/saved_pred/whole_res_valid_do_9_geo_[300, 400, 500, 600, 700, 850].npy'
)

In [11]:
# Number of ensemble members per saved prediction array.
samples = 30

# Every prediction field is laid out as (time, lat, lon, ens).
pred_dims = ['time', 'lat', 'lon', 'ens']

# Grid coordinates taken from the validation generator; the first 72 time
# steps are skipped.
valid_grid = {
    'time': dg_valid.data.time[72:],
    'lat': dg_valid.data.lat,
    'lon': dg_valid.data.lon,
}
# Member labels: 60.. for the temperature run, 80.. for the geopotential run.
temp_coords = {**valid_grid, 'ens': 60 + np.arange(samples)}
geo_coords = {**valid_grid, 'ens': 80 + np.arange(samples)}

# NOTE(review): here 'z' comes from index 1 of the temperature array and index 0
# of the geopotential array; the test-period cell uses the opposite assignment
# for both. Confirm which field order each saved file actually has.
preds_temp = xr.Dataset({
    'z': xr.DataArray(pred_ensemble_temp[1, ...], dims=pred_dims, coords=temp_coords),
    't': xr.DataArray(pred_ensemble_temp[0, ...], dims=pred_dims, coords=temp_coords),
})

preds_geo = xr.Dataset({
    'z': xr.DataArray(pred_ensemble_geo[0, ...], dims=pred_dims, coords=geo_coords),
    't': xr.DataArray(pred_ensemble_geo[1, ...], dims=pred_dims, coords=geo_coords),
})

In [13]:
# Sum the ensemble members of preds_temp one at a time, starting from a copy of
# member 0. The result is a SUM over all members, not yet a mean.
ens_4 = preds_temp.isel(ens = 0).copy()

for i in range(1, len(preds_temp.ens)):
    # In-place add of member i onto the running total.
    ens_4 += preds_temp.isel(ens = i).copy()
    
# Same accumulation for the preds_geo ensemble.
ens_5 = preds_geo.isel(ens = 0).copy()

for i in range(1, len(preds_geo.ens)):
    ens_5 += preds_geo.isel(ens = i).copy()
    

In [15]:
# Previously saved validation-period prediction files for the two 9-block
# multi-level runs.
temp_levels = xr.open_dataset(
    '/rds/general/user/mc4117/ephemeral/saved_pred/9_temp_[300, 400, 500, 600, 700, 850]_preds_newval.nc'
)
geo_levels = xr.open_dataset(
    '/rds/general/user/mc4117/ephemeral/saved_pred/9_geo_[300, 400, 500, 600, 700, 850]_preds_newval.nc'
)

In [8]:
# NOTE(review): `stop` is a bare name that is not defined anywhere in the visible
# notebook, so evaluating it raises NameError. Presumably it is used on purpose to
# halt "Run All" before the stale code below -- confirm, and consider removing
# this cell.
stop
# NOTE(review): ens_1, ens_2 and ens_3 are not defined anywhere in the visible
# notebook; these lines only work with leftover kernel state.
ens_1_avg = ens_1/20
ens_1_avg.to_netcdf('/rds/general/user/mc4117/ephemeral/saved_pred/sh_5_preds_newval.nc')

ens_2_avg = ens_2/20
ens_2_avg.to_netcdf('/rds/general/user/mc4117/ephemeral/saved_pred/pv_5_preds_newval.nc')

ens_3_avg = ens_3/20
ens_3_avg.to_netcdf('/rds/general/user/mc4117/ephemeral/saved_pred/const_5_preds_newval.nc')

# NOTE(review): preds_temp / preds_geo are built with samples = 30 ensemble
# coordinates, so ens_4 / ens_5 from the summation cell are sums over 30 members.
# Dividing by 20 would then not give the ensemble mean -- confirm the divisor.
ens_4_avg = ens_4/20
ens_4_avg.to_netcdf('/rds/general/user/mc4117/ephemeral/saved_pred/temp_5_preds_newval.nc')

ens_5_avg = ens_5/20
ens_5_avg.to_netcdf('/rds/general/user/mc4117/ephemeral/saved_pred/geo_5_preds_newval.nc')

## test outputs

In [9]:
# Saved test-period ensemble predictions, one array per model variant
# (specific humidity, potential vorticity, constants, temperature, geopotential).
pred_ensemble_sh_test = np.load(
    '/rds/general/user/mc4117/ephemeral/saved_pred/whole_res_testing_do_5_specific_humidity.npy'
)
pred_ensemble_pv_test = np.load(
    '/rds/general/user/mc4117/ephemeral/saved_pred/whole_res_testing_do_5_pot_vort.npy'
)
pred_ensemble_const_test = np.load(
    '/rds/general/user/mc4117/ephemeral/saved_pred/whole_res_testing_do_5_const.npy'
)
pred_ensemble_temp_test = np.load(
    '/rds/general/user/mc4117/ephemeral/saved_pred/whole_res_testing_do_5_temp.npy'
)
pred_ensemble_geo_test = np.load(
    '/rds/general/user/mc4117/ephemeral/saved_pred/whole_res_testing_do_5_geo.npy'
)

In [14]:
# Number of ensemble members per saved test prediction array.
samples = 20

# Every prediction field is laid out as (time, lat, lon, ens).
pred_dims = ['time', 'lat', 'lon', 'ens']

# Grid coordinates taken from the test generator; the first 72 time steps are
# skipped.
test_grid = {
    'time': dg_test.data.time[72:],
    'lat': dg_test.data.lat,
    'lon': dg_test.data.lon,
}
# Each variant gets its own block of member labels: 0.., 20.., 40.., 60.., 80..
sh_coords = {**test_grid, 'ens': np.arange(samples)}
pv_coords = {**test_grid, 'ens': 20 + np.arange(samples)}
const_coords = {**test_grid, 'ens': 40 + np.arange(samples)}
temp_coords_test = {**test_grid, 'ens': 60 + np.arange(samples)}
geo_coords_test = {**test_grid, 'ens': 80 + np.arange(samples)}

preds_sh_test = xr.Dataset({
    'z': xr.DataArray(pred_ensemble_sh_test[0, ...], dims=pred_dims, coords=sh_coords),
    't': xr.DataArray(pred_ensemble_sh_test[1, ...], dims=pred_dims, coords=sh_coords),
})

preds_pv_test = xr.Dataset({
    'z': xr.DataArray(pred_ensemble_pv_test[0, ...], dims=pred_dims, coords=pv_coords),
    't': xr.DataArray(pred_ensemble_pv_test[1, ...], dims=pred_dims, coords=pv_coords),
})

preds_const_test = xr.Dataset({
    'z': xr.DataArray(pred_ensemble_const_test[0, ...], dims=pred_dims, coords=const_coords),
    't': xr.DataArray(pred_ensemble_const_test[1, ...], dims=pred_dims, coords=const_coords),
})

preds_temp_test = xr.Dataset({
    'z': xr.DataArray(pred_ensemble_temp_test[0, ...], dims=pred_dims, coords=temp_coords_test),
    't': xr.DataArray(pred_ensemble_temp_test[1, ...], dims=pred_dims, coords=temp_coords_test),
})

# NOTE(review): only this variant takes 'z' from index 1 and 't' from index 0;
# the four above use the opposite order. Confirm against how the geo file was saved.
preds_geo_test = xr.Dataset({
    'z': xr.DataArray(pred_ensemble_geo_test[1, ...], dims=pred_dims, coords=geo_coords_test),
    't': xr.DataArray(pred_ensemble_geo_test[0, ...], dims=pred_dims, coords=geo_coords_test),
})

In [15]:
# For each test-period ensemble, add up the members one at a time, starting from
# a copy of member 0. Each ens_N_test is a SUM over members, not yet a mean.

# Specific-humidity variant.
ens_1_test = preds_sh_test.isel(ens = 0).copy()

for i in range(1, len(preds_sh_test.ens)):
    # In-place add of member i onto the running total.
    ens_1_test += preds_sh_test.isel(ens = i).copy()
    
# Potential-vorticity variant.
ens_2_test = preds_pv_test.isel(ens = 0).copy()

for i in range(1, len(preds_pv_test.ens)):
    ens_2_test += preds_pv_test.isel(ens = i).copy()

# Constants variant.
ens_3_test = preds_const_test.isel(ens = 0).copy()

for i in range(1, len(preds_const_test.ens)):
    ens_3_test += preds_const_test.isel(ens = i).copy()
    
# Temperature variant.
ens_4_test = preds_temp_test.isel(ens = 0).copy()

for i in range(1, len(preds_temp_test.ens)):
    ens_4_test += preds_temp_test.isel(ens = i).copy()

# Geopotential variant.
ens_5_test = preds_geo_test.isel(ens = 0).copy()

for i in range(1, len(preds_geo_test.ens)):
    ens_5_test += preds_geo_test.isel(ens = i).copy()

In [16]:
# Turn each member sum into an ensemble mean and save it to NetCDF.
# The divisor is each ensemble's own member count rather than a hard-coded 20,
# so the mean stays correct if the number of samples changes.
ens_1_avg_test = ens_1_test / len(preds_sh_test.ens)
ens_1_avg_test.to_netcdf('/rds/general/user/mc4117/ephemeral/saved_pred/sh_5_preds_test.nc')

ens_2_avg_test = ens_2_test / len(preds_pv_test.ens)
ens_2_avg_test.to_netcdf('/rds/general/user/mc4117/ephemeral/saved_pred/pv_5_preds_test.nc')

ens_3_avg_test = ens_3_test / len(preds_const_test.ens)
ens_3_avg_test.to_netcdf('/rds/general/user/mc4117/ephemeral/saved_pred/const_5_preds_test.nc')

ens_4_avg_test = ens_4_test / len(preds_temp_test.ens)
ens_4_avg_test.to_netcdf('/rds/general/user/mc4117/ephemeral/saved_pred/temp_5_preds_test.nc')

ens_5_avg_test = ens_5_test / len(preds_geo_test.ens)
ens_5_avg_test.to_netcdf('/rds/general/user/mc4117/ephemeral/saved_pred/geo_5_preds_test.nc')