In [1]:
# Put the repository root on the module search path so that the `src.*` imports
# in the next cell resolve when the notebook is run from another directory.
import sys
sys.path.append('/home/mksoll/DL4WeatherAndClimate')

In [2]:
import os

import zarr
import numpy as np
import xarray as xr
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt

from src.era5_dataset import ERA5Dataset, TimeMode
from src.fuxi_ligthning import FuXi
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload
# The batch size must stay at 1: the write logic in the prediction loop below only stores the first sample of each batch.

In [3]:
# --- Run configuration ------------------------------------------------------
# Samples per batch. Keep at 1: the prediction loop only writes sample 0 of a batch.
bs = 1

# Rollout length, written as days x steps-per-day (the data is 6-hourly).
autoregression_steps = 4 * 10
# Number of consecutive forecasts to produce, again days x steps-per-day.
timesteps_cnt = 4 * 2

# Sizes of the dimensions used when laying out and filling the output store.
levels_cnt, vars_cnt = 2, 5
lats_cnt, lons_cnt = 121, 240

# Time window handed to the dataset.
start_time, end_time = "2019-12-31T12:00:00", "2020-12-31T18:00:00"

# Checkpoint of the trained model.
model_path = "/home/mksoll/DL4WeatherAndClimate/models/epoch=13-step=613533.ckpt"

In [4]:
# Restore the trained model from its checkpoint and configure the rollout length.
model: FuXi = FuXi.load_from_checkpoint(model_path)
model.set_autoregression_steps(autoregression_steps)
# This notebook only runs inference, so put the module into evaluation mode
# (disables dropout-style layers and freezes normalisation statistics).
model.eval()

# Dataset restricted to the configured time window; `test_ds` is kept as an alias
# of `dataset` because both names were defined by the original cell.
dataset = test_ds = ERA5Dataset(
    "/home/mksoll/DL4WeatherAndClimate/data/1959-2022-6h-240x121_equiangular_with_poles_conservative.zarr",
    TimeMode.BETWEEN,
    start_time=start_time,
    end_time=end_time,
    max_autoregression_steps=autoregression_steps,
    zarr_col_names='gcloud'
)

# os.cpu_count() may return None when the CPU count cannot be determined;
# fall back to a small value instead of failing with a TypeError on `None // 2`.
num_workers = (os.cpu_count() or 2) // 2

# shuffle=False keeps the samples in dataset order, which the prediction loop
# relies on when it appends forecasts along the `time` axis.
dl = DataLoader(dataset, batch_size=bs, shuffle=False, num_workers=num_workers, pin_memory=True)


In [5]:
# Fresh on-disk store for the predictions; an existing group at this path is overwritten.
store = zarr.DirectoryStore('./preds.zarr')
root = zarr.group(store=store, overwrite=True)

# Coordinate arrays are created empty (length 0) and filled by a later cell.
latitude = root.create_dataset('latitude', shape=(0,), chunks=(lats_cnt,), dtype=np.float64, fill_value=-99999)
levels = root.create_dataset('level', shape=(0,), chunks=(levels_cnt,), dtype=np.int32, fill_value=-99)
longitude = root.create_dataset('longitude', shape=(0,), chunks=(lons_cnt,), dtype=np.float64, fill_value=-99999)
pred_timedelta = root.create_dataset('prediction_timedelta', shape=(0,), chunks=(autoregression_steps,),
                                     dtype='timedelta64[ns]')
# TODO: the time values still need to be set properly, probably from the dataset.
time = root.create_dataset('time', shape=(0,), chunks=(timesteps_cnt,), dtype='datetime64[ns]', fill_value=-99999)

# Every forecast variable shares the same layout: an initially empty, growable
# leading `time` axis followed by (prediction_timedelta, level, latitude, longitude).
_forecast_dims = ('time', 'prediction_timedelta', 'level', 'latitude', 'longitude')
_forecast_tail_shape = (autoregression_steps, levels_cnt, lats_cnt, lons_cnt)


def _create_forecast_array(name):
    """Create one empty float64 forecast array called `name` in `root` and tag its dimension names."""
    arr = root.create_dataset(name, shape=(0,) + _forecast_tail_shape,
                              dtype=np.float64,
                              chunks=(16,) + _forecast_tail_shape, fill_value=-99999)
    # `_ARRAY_DIMENSIONS` is the attribute xarray reads to name the axes of a zarr array.
    arr.attrs['_ARRAY_DIMENSIONS'] = list(_forecast_dims)
    return arr


temp = _create_forecast_array('temperature')
humid = _create_forecast_array('specific_humidity')
uwind = _create_forecast_array('u_component_of_wind')
vwind = _create_forecast_array('v_component_of_wind')
geo = _create_forecast_array('geopotential')


In [6]:
# NOTE: this cell appends to the coordinate arrays, so it must run exactly once
# after the store has been (re)created; running it again duplicates the coordinates.

# Length of one time step. Stated explicitly instead of being read back from
# `timedelta[1]`, which raised IndexError for a single-step rollout.
time_step = np.timedelta64(6, 'h')

# Spatial coordinates: evenly spaced latitudes from pole to pole and longitudes from 0 to 358.5.
latitude.append(np.linspace(-90, 90, lats_cnt))
latitude.attrs['_ARRAY_DIMENSIONS'] = ['latitude']

# The two levels stored in the output arrays.
levels.append([500, 850])
levels.attrs['_ARRAY_DIMENSIONS'] = ['level']

longitude.append(np.linspace(0, 358.5, lons_cnt))
longitude.attrs['_ARRAY_DIMENSIONS'] = ['longitude']

# One lead time per rollout step: 0 h, 6 h, 12 h, ...
timedelta = np.arange(autoregression_steps) * time_step
pred_timedelta.append(timedelta)
pred_timedelta.attrs['_ARRAY_DIMENSIONS'] = ['prediction_timedelta']

# One time stamp per forecast written by the prediction loop; consecutive forecasts are one step apart.
# NOTE(review): the offset of two steps after `start_time` presumably accounts for
# the model consuming two input frames - confirm against ERA5Dataset.
times = np.datetime64(start_time) + (np.arange(timesteps_cnt) + 2) * time_step
time.append(times)
time.attrs['_ARRAY_DIMENSIONS'] = ['time']

In [11]:
# Per-variable extrema (five entries each); the prediction loop uses them to map
# the model output back via `value * (max - min) + min`.
maxs = torch.tensor([324.80637, 0.029175894, 113.785934, 89.834595, 109541.625], dtype=torch.float32)
mins = torch.tensor([193.48901, -3.3835982e-05, -65.45247, -96.98215, -6838.8906], dtype=torch.float32)

# Give range and offset two trailing singleton axes, i.e. shape (5, 1, 1), so they
# broadcast against the last three axes of the prediction tensor.
max_minus_min = (maxs - mins).reshape(-1, 1, 1)
mins = mins.reshape(-1, 1, 1)

In [12]:
# Run the model on the first `timesteps_cnt` samples and append each rollout to the store.
# NOTE: the arrays are grown with `append`, so re-running this cell without
# recreating the store adds the same forecasts a second time.

# Output arrays in the order of their index along the variable axis of `preds`.
forecast_arrays = (temp, humid, uwind, vwind, geo)

# Inference only: without no_grad() every step keeps an autograd graph alive and
# `.numpy()` refuses tensors that require grad.
with torch.no_grad():
    # zip() with range() stops after `timesteps_cnt` batches without requesting
    # another one from the loader; `total` lets tqdm show a real progress bar.
    for idx, batch in tqdm(zip(range(timesteps_cnt), dl), total=timesteps_cnt):
        x, y = batch
        batch = x.cuda(), y.cuda()
        # Only the first sample of the batch is used (batch size must be 1).
        out = model.forward(batch)[0, :, :, :, :]
        # Bring the result to host memory so it can be combined with the CPU-side
        # normalisation constants and converted to numpy (no-op if already there).
        out = out.detach().cpu()
        # Split the channel axis into two axes of size `vars_cnt` and `channels // vars_cnt`.
        # NOTE(review): the code below slices axis 1 as levels (2:4) and indexes axis 2 as the
        # variable, so the channel layout is presumably level-major; this only lines up with
        # the `vars_cnt` factor here if both axes have length 5 - confirm against ERA5Dataset.
        preds = torch.reshape(out, (autoregression_steps, vars_cnt, out.shape[1] // vars_cnt, lats_cnt, lons_cnt))
        # Undo the min-max scaling; the (5, 1, 1) constants broadcast along axis 2.
        preds = (preds * max_minus_min + mins).numpy()
        # Add a leading axis of length 1 (the new `time` entry), keep entries 2:4 of axis 1
        # (written to the `level` dimension) and pick one variable per array.
        for var_idx, target in enumerate(forecast_arrays):
            target.append(preds[None, :, 2:4, var_idx, :, :], axis=0)

8it [00:27,  3.38s/it]


In [13]:
# Collect the metadata of all arrays/groups in the store into one consolidated
# entry so readers can open the dataset without scanning every array directory.
zarr.consolidate_metadata(store)

<zarr.hierarchy.Group '/'>

In [14]:
import shutil

# Pack the finished store directory into a zip archive; make_archive appends the
# ".zip" suffix to the base name and returns the path of the file it created.
shutil.make_archive(
    base_name='preds_2020_more_steps.zarr',
    format='zip',
    root_dir='./preds.zarr',
)

'/home/mksoll/DL4WeatherAndClimate/notebooks/preds_2020_more_steps.zarr.zip'