In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from climate_learn.data.climate_dataset.args import ERA5Args
from climate_learn.data.task.args import PretrainingArgs, ForecastingArgs
from climate_learn.data.dataset import MapDatasetArgs
from climate_learn.data import DataModule

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
root_dir = "/data0/datasets/weatherbench/data/weatherbench/era5/5.625deg"
variables = ["geopotential_500", "temperature_850"]
history = 1
subsample = 6
pred_range = 72
train_years = range(1979, 2015)
val_years = range(2015, 2017)
test_years = range(2017, 2019)

In [4]:
climate_dataset_args = ERA5Args(
    root_dir,
    variables,
    train_years,
    split="train"
)

In [5]:
pretraining_args = PretrainingArgs(variables, subsample)
train_dataset_args = MapDatasetArgs(climate_dataset_args, pretraining_args)
val_dataset_args = train_dataset_args.create_copy({
    "climate_dataset_args": {"years": val_years, "split": "val"}
})
test_dataset_args = val_dataset_args.create_copy({
    "climate_dataset_args": {"years": test_years, "split": "test"}
})

In [6]:
dm = DataModule(train_dataset_args, val_dataset_args, test_dataset_args)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:01<00:00, 31.85it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 74.47it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 55.01it/s]


In [7]:
for batch in dm.train_dataloader():
    break

In [13]:
x, times = batch[0], batch[1].flatten()
var1 = x[:,0].unsqueeze(1)
var2 = x[:,1].unsqueeze(1)

In [14]:
var1.shape

torch.Size([64, 1, 32, 64])

In [15]:
var2.shape

torch.Size([64, 1, 32, 64])

In [16]:
times.shape

torch.Size([64])

In [17]:
var1.mean()

tensor(0.0159)

In [18]:
var1.var()

tensor(0.9963)

In [19]:
var2.mean()

tensor(0.0182)

In [20]:
var2.var()

tensor(1.0058)

In [47]:
import torch
import torch.nn.functional as F
import numpy as np

In [24]:
end_of_data = torch.from_numpy(
    np.array(["2018-12-31T23:00:00"], dtype="datetime64[h]").astype(float)
)

In [50]:
times.sort()

torch.return_types.sort(
values=tensor([ 91662.,  92490.,  96192., 100854., 101382., 113484., 120060., 120852.,
        132684., 135774., 136152., 145260., 146808., 147870., 153522., 158760.,
        159930., 168660., 171360., 183012., 192150., 193254., 193662., 196704.,
        197046., 197694., 198546., 202680., 203712., 206106., 213408., 218160.,
        224316., 224448., 224952., 233280., 238566., 257490., 270078., 278256.,
        282618., 288474., 289518., 290028., 293238., 293856., 296112., 298548.,
        302760., 304746., 307722., 316200., 318486., 319080., 321678., 327630.,
        331866., 351828., 362076., 367074., 372666., 373680., 374928., 379776.],
       dtype=torch.float64),
indices=tensor([26, 54, 23,  7, 38, 31, 39, 55,  5, 58, 33,  9, 25, 21, 42, 22,  0, 28,
        32, 24, 20, 43, 19, 60, 34, 27,  3, 17, 35, 48, 18, 46,  1, 16,  8, 36,
        15,  6, 30, 45, 61, 29, 37, 10, 13, 56, 40, 14, 62, 44, 47, 52, 53, 12,
         2, 51, 57, 11, 49,  4, 41, 50, 59, 63]))

In [83]:
new_times = torch.tensor([91662, 91668, 100854], dtype=float)
n = new_times / (end_of_data * scale)
n = n.repeat((3, 1))
n = 1 - torch.abs(n - n.T)
n

tensor([[ 1.0000e+00, -3.9689e-01, -2.1390e+03],
        [-3.9689e-01,  1.0000e+00, -2.1376e+03],
        [-2.1390e+03, -2.1376e+03,  1.0000e+00]], dtype=torch.float64)

In [84]:
F.softmax(n, dim=0)

tensor([[0.8017, 0.1983, 0.0000],
        [0.1983, 0.8017, 0.0000],
        [0.0000, 0.0000, 1.0000]], dtype=torch.float64)

In [82]:
scale = 1e-5
t = times / (end_of_data * scale)
t = t.repeat((x.shape[0], 1))
t = 1 - torch.abs(t - t.T)
t

tensor([[ 1.0000e+00, -1.4989e+04, -3.7656e+04,  ..., -2.8563e+04,
         -3.3252e+04, -5.1182e+04],
        [-1.4989e+04,  1.0000e+00, -2.2666e+04,  ..., -1.3573e+04,
         -1.8262e+04, -3.6192e+04],
        [-3.7656e+04, -2.2666e+04,  1.0000e+00,  ..., -9.0927e+03,
         -4.4034e+03, -1.3525e+04],
        ...,
        [-2.8563e+04, -1.3573e+04, -9.0927e+03,  ...,  1.0000e+00,
         -4.6883e+03, -2.2619e+04],
        [-3.3252e+04, -1.8262e+04, -4.4034e+03,  ..., -4.6883e+03,
          1.0000e+00, -1.7929e+04],
        [-5.1182e+04, -3.6192e+04, -1.3525e+04,  ..., -2.2619e+04,
         -1.7929e+04,  1.0000e+00]], dtype=torch.float64)

In [77]:
F.softmax(t, dim=0)[26,26]

tensor(0.8949, dtype=torch.float64)

In [78]:
F.softmax(t, dim=0)[26,54]

tensor(0.1051, dtype=torch.float64)

In [79]:
F.softmax(t, dim=0)[26,63]

tensor(0., dtype=torch.float64)

In [33]:
import datetime

In [80]:
datetime.datetime.fromtimestamp(193254*60*60)

datetime.datetime(1992, 1, 17, 22, 0)

In [31]:
times.min()

tensor(91662., dtype=torch.float64)

In [32]:
times.max()

tensor(379776., dtype=torch.float64)

In [81]:
datetime.datetime.fromtimestamp(193662*60*60)

datetime.datetime(1992, 2, 3, 22, 0)