In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pygrib
import cartopy.crs as ccrs
import cartopy.feature as cfeature

In [2]:
# Open the GRIB file (path is relative to the notebook's working directory).
grib_file = '../data2021-2022.grib'
grbs = pygrib.open(grib_file)
# Rewind the message iterator to the start of the file before reading.
grbs.seek(0)
# pygrib numbers messages from 1, so this is the first message in the file.
grb = grbs[1]
# Latitude/longitude grids of that message. draw_poland contours every panel
# against these notebook-level `lats` / `lons`, so the arrays it plots must
# have the same shape.
# NOTE(review): `grbs` is never closed; presumably kept open for later reads —
# call `grbs.close()` once it is no longer needed.
lats,lons = grb.latlons()

In [None]:
def draw_poland(ax, X, matrix_name, cmap, grid_lons=None, grid_lats=None,
                extent=(14, 25, 49, 55)):
    """Draw a filled-contour map of ``X`` over Poland on a cartopy GeoAxes.

    Parameters
    ----------
    ax : cartopy GeoAxes
        Target axes; must have been created with a cartopy ``projection``.
    X : 2-D array
        Field to draw; must have the same shape as the coordinate grids.
    matrix_name : str
        Panel title (matplotlib mathtext is allowed).
    cmap : str or matplotlib Colormap
        Colormap passed to ``contourf``.
    grid_lons, grid_lats : 2-D arrays, optional
        Longitude / latitude (degrees) of every point of ``X``. When omitted,
        the notebook-level ``lons`` / ``lats`` read from the GRIB file are
        used (looked up at call time).
    extent : sequence of 4 numbers, optional
        ``(lon_min, lon_max, lat_min, lat_max)`` of the map window in
        degrees. Defaults to a box around Poland.
    """
    data_crs = ccrs.PlateCarree()
    if grid_lons is None:
        grid_lons = lons
    if grid_lats is None:
        grid_lats = lats

    # The extent is given in lon/lat degrees, so declare that CRS explicitly;
    # without `crs=` cartopy interprets it in the axes' own projection, which
    # is wrong for any non-PlateCarree map projection.
    ax.set_extent(extent, crs=data_crs)
    ax.add_feature(cfeature.COASTLINE.with_scale('50m'))
    ax.add_feature(cfeature.BORDERS)
    contour_plot = ax.contourf(grid_lons, grid_lats, X, transform=data_crs, cmap=cmap)
    ax.set_title(matrix_name)
    # Attach the colorbar to this axes' own figure and take space from this
    # axes only, instead of relying on pyplot's global "current" figure/axes.
    ax.figure.colorbar(contour_plot, ax=ax)
    
def plot_matrices(features, fh, y_test, y_hat, max_samples=1, map_crs=None):
    """Plot truth, prediction and absolute-error maps for every feature/step.

    For each of the first ``max_samples`` samples one figure is drawn with a
    row per feature and three columns per forecast step: ``X`` (from
    ``y_test``), ``X_hat`` (from ``y_hat``) and ``|X - X_hat|``.

    Parameters
    ----------
    features : int
        Number of feature channels to plot (rows); indexes the last axis.
    fh : int
        Number of forecast steps to plot; indexes axis ``-2``.
    y_test, y_hat : ndarray
        Ground truth and prediction, indexed ``[sample, ..., step, feature]``.
        The ``...`` part is handed to ``draw_poland`` and must match the shape
        of its lat/lon grid.
    max_samples : int, optional
        Maximum number of samples to plot; clamped to ``len(y_test)``.
    map_crs : cartopy CRS, optional
        Projection of the map axes. Defaults to ``ccrs.PlateCarree()``.

    Returns
    -------
    ndarray or None
        The 2-D ``(features, 3 * fh)`` array of axes of the last figure drawn,
        or ``None`` if no sample was plotted.
    """
    if map_crs is None:
        map_crs = ccrs.PlateCarree()

    n_cols = 3 * fh  # truth / prediction / error for every forecast step
    axs = None  # stays None when there is nothing to plot
    for i in range(min(max_samples, len(y_test))):
        y_test_sample, y_hat_sample = y_test[i], y_hat[i]
        fig, axs = plt.subplots(
            features,
            n_cols,
            figsize=(10 * fh, 3 * features),
            subplot_kw={'projection': map_crs},
            squeeze=False,  # keep axs 2-D so axs[j, k] works when features == 1
        )

        for j in range(features):
            cur_feature = f'f{j}'
            for k in range(n_cols):
                ts, panel = divmod(k, 3)
                truth = y_test_sample[..., ts, j]
                pred = y_hat_sample[..., ts, j]
                if panel == 0:
                    title = rf"$X_{{{cur_feature},t+{ts+1}}}$"
                    value = truth
                    cmap = plt.cm.coolwarm
                elif panel == 1:
                    title = rf"$\hat{{X}}_{{{cur_feature},t+{ts+1}}}$"
                    value = pred
                    cmap = plt.cm.coolwarm
                else:
                    title = rf"$|X - \hat{{X}}|_{{{cur_feature},t+{ts+1}}}$"
                    value = np.abs(truth - pred)
                    cmap = "binary"

                draw_poland(axs[j, k], value, title, cmap)

        fig.tight_layout()
        plt.show()
    return axs

# Smoke-test the plotting with random data: one sample on the GRIB lat/lon
# grid, `features` channels and `fh` forecast steps.
features = 6
fh = 1

# The spatial size must equal the lat/lon grid read from the GRIB file,
# because draw_poland contours every panel against `lons` / `lats`.
latitude, longitude = lats.shape

rng = np.random.default_rng(42)  # fixed seed -> reproducible figure
y_hat = rng.random((1, latitude, longitude, fh, features))
y_test = rng.random((1, latitude, longitude, fh, features))

# Trailing ';' hides the repr of the returned axes array.
plot_matrices(features, fh, y_test, y_hat);