In [237]:
import cv2
import pandas as pd
%load_ext autoreload
%autoreload
%matplotlib widget
%matplotlib inline

import os
import torch
import zarr
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs

from torch.utils.data import DataLoader
from tqdm import tqdm
from src.era5_dataset import ERA5Dataset, TimeMode
from src.fuxi_ligthning import FuXi
from tqdm import tqdm
from ipywidgets import interact, IntSlider


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
model = FuXi.load_from_checkpoint('/Users/ksoll/git/DL4WeatherAndClimate/models/epoch=13-step=613533.ckpt',
                                  map_location=torch.device("cpu"))
model: FuXi
model.set_autoregression_steps(12)
dataset = test_ds = ERA5Dataset(
    "/Users/ksoll/git/DL4WeatherAndClimate/data/era5_6hourly.zarr",
    TimeMode.BETWEEN,
    start_time="2011-01-01T00:00:00",
    end_time="2011-01-31T18:00:00",
    max_autoregression_steps=12,
    zarr_col_names='lessig'
)
dl = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=os.cpu_count() // 2, pin_memory=True)

In [240]:
total_params = sum(p.numel() for p in model.parameters())
total_params

26975248

In [9]:
# range muss noch durch den dataloader ersetzt werden
for idx, batch in tqdm(enumerate(dl)):
    out = model.forward(batch)[0, :, :, :, :]
    preds = torch.reshape(out, (12, 5, 5, 121, 240)).numpy()
    temp_pred = preds[:, 0, 2:4, :, :]
    humid_pred = preds[None, :, 1, 2:4, :, :]
    uwind_pred = preds[None, :, 2, 2:4, :, :]
    vwind_pred = preds[None, :, 3, 2:4, :, :]
    geo_pred = preds[None, :, 4, 2:4, :, :]
    break

0it [00:04, ?it/s]


<zarr.hierarchy.Group '/'>

In [233]:
ds_forecast = xr.open_dataset("/Users/ksoll/git/DL4WeatherAndClimate/data/preds_2020_more_steps.zarr")
mins = np.array([193.48901, -3.3835982e-05, -65.45247, -96.98215, -6838.8906])
maxs = np.array([324.80637, 0.029175894, 113.785934, 89.834595, 109541.625])
plot_mins = [-40 + 273.15, 0, -10, -30, -5000]
plot_maxs = [30 + 273.15, 0.015, 30, 30, 100_000]
maxs_minus_mins = maxs - mins
var_to_idx = {
    "temperature": 0,
    "t": 0,
    "specific_humidity": 1,
    "q": 1,
    "u_component_of_wind": 2,
    "u": 2,
    "v_component_of_wind": 3,
    "v": 3,
    "geopotential": 4,
    "z": 4
}

In [234]:
def save_forecast_timestep(variable, time_idx, pred_time_idx):
    var_idx = var_to_idx[variable]
    forecast = ds_forecast[variable].isel(time=time_idx).isel(prediction_timedelta=pred_time_idx).isel(level=1)
    forecast = forecast * maxs_minus_mins[var_idx] + mins[var_idx]
    fig, ax = plt.subplots(figsize=(12, 8), subplot_kw={'projection': ccrs.PlateCarree()})
    ax.coastlines()
    im = forecast.plot(ax=ax, add_colorbar=False, transform=ccrs.PlateCarree(), vmin=plot_mins[var_idx], vmax=plot_maxs[var_idx])
    cbar = plt.colorbar(im, ax=ax, orientation='vertical', shrink=0.5)
    cbar.set_label(variable)
    im.colorbar.set_label(variable)
    time = np.datetime64("2020-01-01T00:00:00") + np.timedelta64(time_idx, "D") + 6*np.timedelta64(pred_time_idx, "h")
    ax.set_title(f'Prediction of {variable} on {time}')
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))
    data = cv2.cvtColor(data, cv2.COLOR_RGBA2BGR)
    plt.close(fig)
    return data


fps = 12
frame_size = (1920, 1080)

# Verwendungszweck:
for variable in ["temperature", "specific_humidity", "u_component_of_wind", "v_component_of_wind"]:
    for time_idx in range(1):
        path = f"/Users/ksoll/git/DL4WeatherAndClimate/figures/2020-01/{variable}_pred_{time_idx}"
        out = cv2.VideoWriter(path + '.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, frame_size)
        for pred_time_idx in tqdm(range(40)):
            data = save_forecast_timestep(variable, time_idx, pred_time_idx)
            bild_resized = cv2.resize(data, frame_size)
            cv2.imwrite(path + f'_{pred_time_idx}.png', bild_resized)
            out.write(bild_resized)
        out.release()

100%|██████████| 40/40 [00:13<00:00,  2.88it/s]
100%|██████████| 40/40 [00:14<00:00,  2.74it/s]
100%|██████████| 40/40 [00:13<00:00,  2.87it/s]
100%|██████████| 40/40 [00:13<00:00,  2.99it/s]


In [223]:
store = zarr.DirectoryStore("/Users/ksoll/git/DL4WeatherAndClimate/data/era5_6hourly_cpy.zarr")
sources = zarr.group(store=store)

In [224]:
les_to_gcloud = {
    "t": "temperature",
    "q": "specific_humidity",
    "u": "u_component_of_wind",
    "v": "v_component_of_wind",
    "z": "geopotential"
}


def save_obs_timestep(variable, time_idx, pred_time_idx):
    # Daten laden
    var_idx = var_to_idx[variable]
    obs = np.array(sources[variable])[time_idx + pred_time_idx, 1, :, :]
    obs = np.flipud(obs)
    obs = np.array(cv2.resize(obs, dsize=(360, 180), interpolation=cv2.INTER_CUBIC))
    obs = xr.DataArray(obs, dims=['latitude', 'longitude'])
    obs['latitude'] = np.linspace(-90, 90, 180)
    obs['longitude'] = np.linspace(0, 360, 360)
    fig, ax = plt.subplots(figsize=(12, 8), subplot_kw={'projection': ccrs.PlateCarree()})
    ax.coastlines()
    im = obs.plot(ax=ax, add_colorbar=False, transform=ccrs.PlateCarree(), vmin=plot_mins[var_idx], vmax=plot_maxs[var_idx])
    cbar = plt.colorbar(im, ax=ax, orientation='vertical', shrink=0.5)
    cbar.set_label(les_to_gcloud[variable])
    im.colorbar.set_label(les_to_gcloud[variable])
    ax.set_title(f'Vorhersage von {les_to_gcloud[variable]}')
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))
    data = cv2.cvtColor(data, cv2.COLOR_RGBA2BGR)
    plt.close(fig)
    return data


fps = 12
frame_size = (1920, 1080)

for variable in ["t", "q", "u", "v"]:
    for time_idx in range(1):
        video_path = f"/Users/ksoll/git/DL4WeatherAndClimate/figures/{les_to_gcloud[variable]}_obs_{time_idx}.mp4"
        out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, frame_size)
        for pred_time_idx in tqdm(range(40)):
            data = save_obs_timestep(variable, time_idx, pred_time_idx)
            bild_resized = cv2.resize(data, frame_size)
            out.write(bild_resized)
        out.release()

100%|██████████| 40/40 [00:35<00:00,  1.13it/s]
100%|██████████| 40/40 [00:36<00:00,  1.11it/s]
100%|██████████| 40/40 [00:36<00:00,  1.10it/s]
100%|██████████| 40/40 [00:34<00:00,  1.15it/s]


In [ ]:
from src.score_torch import *

In [230]:
def rmse_plots(time_idx):
    var_idx = var_to_idx[variable]
    day_rmse = []
    rmse_sum = 0
    weights = np.cos(np.deg2rad(np.linspace(-90, 90, 121)))
    for pred_time_idx in range(40):
        obs = np.array(sources[variable])[time_idx + pred_time_idx, 1, :, :]
        obs = np.flipud(obs)
        
        forecast = ds_forecast[variable].isel(time=time_idx).isel(prediction_timedelta=pred_time_idx).isel(level=1)
        forecast = np.array(forecast * maxs_minus_mins[var_idx] + mins[var_idx])
        
        rmse = compute_weighted_rmse(forecast, obs, weights)
        rmse_sum += rmse
        print(rmse)
        
        if (pred_time_idx+1)%4 == 0:
            day_rmse.append(rmse_sum)
            rmse_sum = 0

rmse_plots(0)

KeyError: 'v_component_of_wind'

In [238]:
metrics = xr.open_dataset("/Users/ksoll/git/DL4WeatherAndClimate/data/deterministic.nc")
metrics

ValueError: found the following matches with the input file in xarray's IO backends: ['netcdf4', 'scipy']. But their dependencies may not be installed, see:
https://docs.xarray.dev/en/stable/user-guide/io.html 
https://docs.xarray.dev/en/stable/getting-started-guide/installing.html