In [3]:
import sys
import os

sys.path.append(os.path.dirname(os.getcwd()))

os.chdir("../")

# Number of files in train vs val

In [4]:
import torch
import math

In [5]:
samples_dir = "./data/era5_uk/samples/train"
num_files = len(os.listdir(samples_dir))
total_days = 31 * 7 + 28 + 30 * 4 # Jan, Mar, May, Jul, Aug, Oct, Dec

total_time_steps = total_days * 4
pred_length = 1
sample_length = pred_length + 2
total_samples = total_time_steps - sample_length + 1
total_batches = math.ceil(total_samples / 4) 

print("Training:")
print("num_files:", num_files)
print("total_time_steps:", total_time_steps)
print("total_days:", total_days)
print("total_samples:", total_samples)
print("total_batches:", total_batches)

Training:
num_files: 1460
total_time_steps: 1460
total_days: 365
total_samples: 1458
total_batches: 365


In [6]:
# Finding the validation files 

samples_dir = "./data/era5_uk/samples/val"
dirs = os.listdir(samples_dir)
num_files = len(dirs)
print(dirs)
print(num_files)

['10', '04', '01', '07']
4


In [7]:
# Verify number of validation samples
total_files = 0
for m in ['01', '04', '07', '10']:
    total_files += len(os.listdir(os.path.join(samples_dir, m)))
    

total_days = 31 + 30 + 31 + 31
total_time_steps = total_days * 4
pred_length = 28
sample_length = pred_length + 2
total_samples = total_time_steps - sample_length + 1
total_batches = math.ceil(total_samples / 4) 

print("Validation:")
print("total_files:", total_files)
print("total_time_steps:", total_time_steps)
print("total_days:", total_days)
print("total_samples:", total_samples)

Validation:
total_files: 492
total_time_steps: 492
total_days: 123
total_samples: 463


In [8]:
def steps(total_days):
    total_time_steps = total_days * 4
    pred_length = 28
    sample_length = pred_length + 2
    total_samples = total_time_steps - sample_length + 1
    return total_samples

days = [31, 30, 31, 31]
total_samples = sum([steps(d) for d in days])
total_batches = math.ceil(total_samples / 4)
print("total_batches:", total_batches)

total_batches: 94


In [9]:
import torch
from neural_lam.era5_dataset import ERA5UKDataset

In [10]:
dataset_name = "era5_uk"
batch_size = 4
n_workers = 4

train_set = ERA5UKDataset(
    dataset_name,
    pred_length=1,
    split="train",
    standardize=False,
)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size,
    shuffle=True,
    num_workers=n_workers,
)

print("training_batches:", len(train_loader))

training_batches: 365


In [12]:
dataiter = iter(train_loader)
init_states, target_states, forcing = next(dataiter)
init_states

tensor([[[[ 2.0544e+05,  1.3444e+05,  1.0096e+05,  ..., -2.0289e-01,
           -1.5528e-02,  3.0334e-02],
          [ 2.0545e+05,  1.3444e+05,  1.0093e+05,  ..., -1.7996e-01,
           -8.8169e-03,  3.5368e-02],
          [ 2.0545e+05,  1.3444e+05,  1.0091e+05,  ..., -1.5256e-01,
            2.0826e-02,  3.8724e-02],
          ...,
          [ 2.0565e+05,  1.3685e+05,  1.0477e+05,  ...,  2.5573e-01,
           -2.7840e-01, -6.5797e-03],
          [ 2.0565e+05,  1.3686e+05,  1.0476e+05,  ...,  1.0696e-01,
           -2.9741e-01, -2.0562e-02],
          [ 2.0564e+05,  1.3686e+05,  1.0475e+05,  ..., -2.2974e-01,
           -4.0424e-01,  2.0826e-02]],

         [[ 2.0540e+05,  1.3449e+05,  1.0143e+05,  ...,  4.9909e-02,
            8.9620e-02,  7.0044e-02],
          [ 2.0541e+05,  1.3449e+05,  1.0142e+05,  ...,  3.1453e-02,
            9.6891e-02,  6.8926e-02],
          [ 2.0541e+05,  1.3449e+05,  1.0139e+05,  ...,  1.1877e-02,
            1.0640e-01,  6.6688e-02],
          ...,
     

In [13]:
val_set = ERA5UKDataset(
    dataset_name,
    pred_length=28,
    split="val",
    standardize=False,
)
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size,
    shuffle=True,
    num_workers=n_workers,
)

print("validation_set batches:", len(val_loader))

validation_set batches: 94


In [14]:
init_states, target_states, forcing = val_set[0]

In [15]:
print("init_states:", init_states.shape)
print("target_states:", target_states.shape)
print("forcing:", forcing.shape)

init_states: torch.Size([2, 3705, 48])
target_states: torch.Size([28, 3705, 48])
forcing: torch.Size([28, 3705, 12])


# Test Time Resolution Dataset

In [19]:
dataset_name = "era5_uk"
batch_size = 4
n_workers = 4
subsample_step = 2

train_set = ERA5UKDataset(
    dataset_name,
    pred_length=1,
    subsample_step=subsample_step,
    split="train",
    standardize=False,
)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size,
    shuffle=True,
    num_workers=n_workers,
)

print("training_batches:", len(train_loader))

training_batches: 182


In [20]:
val_set = ERA5UKDataset(
    dataset_name,
    pred_length=28,
    subsample_step=subsample_step,
    split="val",
    standardize=False,
)
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size,
    shuffle=True,
    num_workers=n_workers,
)

print("validation_set batches:", len(val_loader))

validation_set batches: 33


# Test Multi Resolution Dataset + Dataloader

In [2]:
import torch 
from neural_lam.era5_dataset import ERA5MultiResolutionDataset

In [9]:
datasets = ["era5_uk_small", "era5_uk_big_coarse"]
train_set = ERA5MultiResolutionDataset(
    datasets,
    split="train",
)

batch_size = 4
n_workers = 4
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=n_workers,
)
dataiter = iter(train_loader)

In [10]:
init_states, target_states, forcing = next(dataiter)

In [13]:
init_states[0].shape

torch.Size([4, 2, 1840, 48])