In [2]:
import sys
import os

sys.path.append(os.path.dirname(os.getcwd()))

os.chdir("../")

# Number of files in train vs val

In [4]:
import torch
import math

In [5]:
samples_dir = "./data/era5_uk/samples/train"
num_files = len(os.listdir(samples_dir))
total_days = 31 * 7 + 28 + 30 * 4 # Jan, Mar, May, Jul, Aug, Oct, Dec

total_time_steps = total_days * 4
pred_length = 1
sample_length = pred_length + 2
total_samples = total_time_steps - sample_length + 1
total_batches = math.ceil(total_samples / 4) 

print("Training:")
print("num_files:", num_files)
print("total_time_steps:", total_time_steps)
print("total_days:", total_days)
print("total_samples:", total_samples)
print("total_batches:", total_batches)

Training:
num_files: 1460
total_time_steps: 1460
total_days: 365
total_samples: 1458
total_batches: 365


In [6]:
# Finding the validation files 

samples_dir = "./data/era5_uk/samples/val"
dirs = os.listdir(samples_dir)
num_files = len(dirs)
print(dirs)
print(num_files)

['10', '04', '01', '07']
4


In [7]:
# Verify number of validation samples
total_files = 0
for m in ['01', '04', '07', '10']:
    total_files += len(os.listdir(os.path.join(samples_dir, m)))
    

total_days = 31 + 30 + 31 + 31
total_time_steps = total_days * 4
pred_length = 28
sample_length = pred_length + 2
total_samples = total_time_steps - sample_length + 1
total_batches = math.ceil(total_samples / 4) 

print("Validation:")
print("total_files:", total_files)
print("total_time_steps:", total_time_steps)
print("total_days:", total_days)
print("total_samples:", total_samples)

Validation:
total_files: 492
total_time_steps: 492
total_days: 123
total_samples: 463


In [8]:
def steps(total_days):
    total_time_steps = total_days * 4
    pred_length = 28
    sample_length = pred_length + 2
    total_samples = total_time_steps - sample_length + 1
    return total_samples

days = [31, 30, 31, 31]
total_samples = sum([steps(d) for d in days])
total_batches = math.ceil(total_samples / 4)
print("total_batches:", total_batches)

total_batches: 94


In [9]:
import torch
from neural_lam.era5_dataset import ERA5UKDataset

In [10]:
dataset_name = "era5_uk"
batch_size = 4
n_workers = 4

train_set = ERA5UKDataset(
    dataset_name,
    pred_length=1,
    split="train",
    standardize=False,
)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size,
    shuffle=True,
    num_workers=n_workers,
)

print("training_batches:", len(train_loader))

training_batches: 365


In [12]:
dataiter = iter(train_loader)
init_states, target_states, forcing = next(dataiter)
init_states

tensor([[[[ 2.0544e+05,  1.3444e+05,  1.0096e+05,  ..., -2.0289e-01,
           -1.5528e-02,  3.0334e-02],
          [ 2.0545e+05,  1.3444e+05,  1.0093e+05,  ..., -1.7996e-01,
           -8.8169e-03,  3.5368e-02],
          [ 2.0545e+05,  1.3444e+05,  1.0091e+05,  ..., -1.5256e-01,
            2.0826e-02,  3.8724e-02],
          ...,
          [ 2.0565e+05,  1.3685e+05,  1.0477e+05,  ...,  2.5573e-01,
           -2.7840e-01, -6.5797e-03],
          [ 2.0565e+05,  1.3686e+05,  1.0476e+05,  ...,  1.0696e-01,
           -2.9741e-01, -2.0562e-02],
          [ 2.0564e+05,  1.3686e+05,  1.0475e+05,  ..., -2.2974e-01,
           -4.0424e-01,  2.0826e-02]],

         [[ 2.0540e+05,  1.3449e+05,  1.0143e+05,  ...,  4.9909e-02,
            8.9620e-02,  7.0044e-02],
          [ 2.0541e+05,  1.3449e+05,  1.0142e+05,  ...,  3.1453e-02,
            9.6891e-02,  6.8926e-02],
          [ 2.0541e+05,  1.3449e+05,  1.0139e+05,  ...,  1.1877e-02,
            1.0640e-01,  6.6688e-02],
          ...,
     

In [13]:
val_set = ERA5UKDataset(
    dataset_name,
    pred_length=28,
    split="val",
    standardize=False,
)
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size,
    shuffle=True,
    num_workers=n_workers,
)

print("validation_set batches:", len(val_loader))

validation_set batches: 94


In [14]:
init_states, target_states, forcing = val_set[0]

In [15]:
print("init_states:", init_states.shape)
print("target_states:", target_states.shape)
print("forcing:", forcing.shape)

init_states: torch.Size([2, 3705, 48])
target_states: torch.Size([28, 3705, 48])
forcing: torch.Size([28, 3705, 12])


# Test Multi Time Resolution Dataset

In [2]:
import numpy as np
import torch
from neural_lam.era5_dataset import ERA5MultiTimeDataset

In [3]:
subsample_steps = [2, 1]
dataset_name = "era5_uk"
pred_length = 28
train_set = ERA5MultiTimeDataset(
    dataset_name,
    subsample_steps=subsample_steps,
    split="train",
)

batch_size = 4
n_workers = 4
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=n_workers,
)
dataiter = iter(train_loader)

In [4]:
init_states, target_states, forcing = next(dataiter)

In [5]:
print(len(train_set.datasets[0]))
print(len(train_set.datasets[1]))

1401
1431


In [6]:
print(len(train_set.datasets[0].sample_names))
print(len(train_set.datasets[1].sample_names))

1460
1460


In [7]:
pred_length = 8
max_subsample_step = 2
sample_length = pred_length + 2 * max_subsample_step

total_step_samples = 50
total_samples = total_step_samples - sample_length + 1

print(sample_length)
print(total_samples)

12
39


In [8]:
n_samples = 12
pred_length = 8
subsample_step = 2

sample_length = (pred_length + 2) * subsample_step
assert sample_length <= n_samples


items = n_samples - sample_length + 1
print("items:", items)

AssertionError: 

In [None]:
items

3

In [None]:
sample_paths = train_set.datasets[1].sample_names
sample_paths

['20220101000000.npy',
 '20220101060000.npy',
 '20220101120000.npy',
 '20220101180000.npy',
 '20220102000000.npy',
 '20220102060000.npy',
 '20220102120000.npy',
 '20220102180000.npy',
 '20220103000000.npy',
 '20220103060000.npy',
 '20220103120000.npy',
 '20220103180000.npy',
 '20220104000000.npy',
 '20220104060000.npy',
 '20220104120000.npy',
 '20220104180000.npy',
 '20220105000000.npy',
 '20220105060000.npy',
 '20220105120000.npy',
 '20220105180000.npy',
 '20220106000000.npy',
 '20220106060000.npy',
 '20220106120000.npy',
 '20220106180000.npy',
 '20220107000000.npy',
 '20220107060000.npy',
 '20220107120000.npy',
 '20220107180000.npy',
 '20220108000000.npy',
 '20220108060000.npy',
 '20220108120000.npy',
 '20220108180000.npy',
 '20220109000000.npy',
 '20220109060000.npy',
 '20220109120000.npy',
 '20220109180000.npy',
 '20220110000000.npy',
 '20220110060000.npy',
 '20220110120000.npy',
 '20220110180000.npy',
 '20220111000000.npy',
 '20220111060000.npy',
 '20220111120000.npy',
 '202201111

In [None]:
meps_sample_path = "data/meps_example/samples/train/nwp_2022040100_mbr000.npy"
full_sample = torch.tensor(
    np.load(meps_sample_path), dtype=torch.float32
)  # (N_t', dim_x, dim_y, d_features')
full_sample.shape

torch.Size([65, 268, 238, 18])

# Test Time Resolution Dataset

In [2]:
dataset_name = "era5_uk"
batch_size = 4
n_workers = 4
subsample_step = 2

train_set = ERA5UKDataset(
    dataset_name,
    pred_length=1,
    subsample_step=subsample_step,
    split="train",
    standardize=False,
)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size,
    shuffle=True,
    num_workers=n_workers,
)

print("training_batches:", len(train_loader))

NameError: name 'ERA5UKDataset' is not defined

In [None]:
val_set = ERA5UKDataset(
    dataset_name,
    pred_length=28,
    subsample_step=subsample_step,
    split="val",
    standardize=False,
)
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size,
    shuffle=True,
    num_workers=n_workers,
)

print("validation_set batches:", len(val_loader))

validation_set batches: 33


# Test Multi Resolution Dataset + Dataloader

In [2]:
import torch 
from neural_lam.era5_dataset import ERA5MultiResolutionDataset

In [9]:
datasets = ["era5_uk_small", "era5_uk_big_coarse"]
train_set = ERA5MultiResolutionDataset(
    datasets,
    split="train",
)

batch_size = 4
n_workers = 4
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=n_workers,
)
dataiter = iter(train_loader)

In [10]:
init_states, target_states, forcing = next(dataiter)

In [13]:
init_states[0].shape

torch.Size([4, 2, 1840, 48])

# Test new dataset

In [3]:
import os
import glob
import torch
import numpy as np
import datetime as dt
import math

from neural_lam import utils, constants

In [4]:
class TestERA5Dataset(torch.utils.data.Dataset):
    """
    ERA5 UK dataset
    
    N_t' = 65
    N_t = 65//subsample_step (= 21 for 3h steps)
    N_x = 268 (width)
    N_y = 238 (height)
    N_grid = 268x238 = 63784 (total number of grid nodes)
    d_features = 17 (d_features' = 18)
    d_forcing = 5
    """
    def __init__(
        self,
        dataset_name,
        subsample_steps=[2, 1],
        pattern="*.npy",
        pred_length=28, 
        split="train", 
        year=2022,
        month=None,
        subsample_step=1,
        standardize=False,
        subset=False,
        control_only=False,
        args=None,
    ):
        super().__init__()
        assert split in ("train", "val", "test"), "Unknown dataset split"
        self.sample_dir_path = os.path.join("data", dataset_name, "samples", split)
        self.args = args
        self.split = split
        
        pattern = f"{year}{pattern}"
        if self.split == "train":
            sample_paths = glob.glob(os.path.join(self.sample_dir_path, pattern))
            # example name: '20200101000000.npy'
            self.sample_names = [os.path.basename(path) for path in sample_paths] 
            self.sample_names.sort()
            self.sample_times = [dt.datetime.strptime(n, '%Y%m%d%H%M%S.npy') for n in self.sample_names]

        else:
            assert month is not None, "Month must be specified for validation/test dataset"
            month_dir = os.path.join(self.sample_dir_path, month)
            sample_paths = glob.glob(os.path.join(month_dir, pattern))
            self.sample_names = [os.path.join(month, os.path.basename(path)) for path in sample_paths]
            self.sample_names.sort()
            self.sample_times = [dt.datetime.strptime(n[3:], '%Y%m%d%H%M%S.npy') for n in self.sample_names]

        if subset:
            self.sample_names = self.sample_names[:50] # Limit to 50 samples
        
        # 2 init states, pred_length target states
        self.subsample_steps = [4, 2, 1]
        self.pred_length = pred_length
        self.sample_length = pred_length + 2 * self.subsample_steps[0]
        self.length = len(self.sample_names) - self.sample_length + 1
        
        print("pred_length:", self.pred_length)
        print("sample_length", self.sample_length)
        print("length:", self.length)
        
        assert (
            self.length > 0
        ), "Requesting too long time series samples"

        # Set up for standardization
        self.standardize = standardize
        if standardize:
            ds_stats = utils.load_dataset_stats(dataset_name, "cpu")
            self.data_mean, self.data_std = (
                ds_stats["data_mean"],
                ds_stats["data_std"],
            )
        

    def __len__(self):
        return self.length

    def _get_sample(self, sample_name):
        sample_path = os.path.join(self.sample_dir_path, f"{sample_name}")
        try:
            full_sample = torch.tensor(np.load(sample_path),
                    dtype=torch.float32) # (N_lon*N_lat, N_vars*N_levels)
        except ValueError:
            print(f"Failed to load {sample_path}")
        
        return full_sample
    
    def __getitem__(self, idx):
        _init_states = []
        _target_states = []
        _forcing_features = []
        
        _start_idx = idx * self.subsample_steps[0]
        _end_idx = _start_idx + self.sample_length
        for i in range(len(self.subsample_steps)):
            subsample_step = self.subsample_steps[i]
            
            if i == 0:
                start_idx = _start_idx
            else:
                start_idx = _start_idx + 2 * self.subsample_steps[0] - 2 * subsample_step
            
            # print("idx:", idx)
            # print("subsample_step:", subsample_step)
            # print("start_idx:", start_idx)
            # print()
            
            # === Sample ===
            prev_prev_state = self._get_sample(self.sample_names[start_idx])
            prev_state = self._get_sample(self.sample_names[start_idx+subsample_step])        

            # N_grid = N_x * N_y; d_features = N_vars * N_levels
            init_states = torch.stack((prev_prev_state, prev_state), dim=0) # (2, N_grid, d_features)
            
            target_states = []
            for i in range(start_idx + 2 * subsample_step, _end_idx, subsample_step):
                target_states.append(self._get_sample(self.sample_names[start_idx + i]))
            target_states = torch.stack(target_states, dim=0) # (sample_len-2, N_grid, d_features)
            
            if self.standardize:
                # Standardize sample
                init_states = (init_states - self.data_mean) / self.data_std
                target_states = (target_states - self.data_mean) / self.data_std
            
            # === Forcing features ===
            # Each step is 6 hours long
            hour_inc = torch.arange(len(target_states) + 2) * 6 * subsample_step # (sample_len,)
            init_dt = self.sample_times[start_idx]
            
            init_hour = init_dt.hour
            hour_of_day = init_hour + hour_inc

            start_of_year = dt.datetime(init_dt.year, 1, 1)
            init_seconds_into_year = (init_dt - start_of_year).total_seconds()
            seconds_into_year = init_seconds_into_year + hour_inc * 3600

            hour_angle = (hour_of_day / 24) * 2 * torch.pi 
            year_angle = (seconds_into_year / constants.SECONDS_IN_YEAR) * 2 * torch.pi
            
            datetime_forcing = torch.stack(
                (
                    torch.sin(hour_angle),
                    torch.cos(hour_angle),
                    torch.sin(year_angle),
                    torch.cos(year_angle),
                ),
                dim=1,
            )  # (sample_len, 4)
            datetime_forcing = (datetime_forcing + 1) / 2 # Normalize to [0,1]
            datetime_forcing = datetime_forcing.unsqueeze(1).expand(
                -1, init_states.shape[1], -1
            )  # (sample_len, N_grid, 4)

            forcing = torch.cat(
                (
                    datetime_forcing[:-2],
                    datetime_forcing[1:-1],
                    datetime_forcing[2:],
                ),
                dim=2,
            ) # (sample_len-2, N_grid, 12)
            
            if self.args and self.args.no_forcing:
                forcing = torch.zeros(target_states.shape[0], target_states.shape[1], 0) # (sample_len-2, N_grid, d_forcing)
                
            _init_states.append(init_states)
            _target_states.append(target_states)
            _forcing_features.append(forcing)
        
        return _init_states, _target_states, _forcing_features
    
def era5_dataset(
    dataset_name,
    pattern="*.npy",
    pred_length=28, 
    split="train", 
    year=2022,
    subsample_step=1,
    standardize=False,
    subset=False,
    control_only=False,
    args=None,
):
    if split == "train":
        return TestERA5Dataset(
            dataset_name,
            pattern=pattern,
            pred_length=pred_length, 
            split=split, 
            year="2022",
            subsample_step=subsample_step,
            standardize=standardize,
            subset=subset,
            control_only=control_only,
            args=args,
        )
    else:
        datasets = []
        for month in constants.ERA5UKConstants.VAL_MONTHS:
            datasets.append(
                TestERA5Dataset(
                    dataset_name,
                    pattern=pattern,
                    pred_length=pred_length, 
                    split=split, 
                    year="2023",
                    month=month,
                    subsample_step=subsample_step,
                    standardize=standardize,
                    subset=subset,
                    control_only=control_only,
                    args=args,
                )
            )
        return torch.utils.data.ConcatDataset(datasets)
    
import neural_lam.era5_dataset as nl_era5 

In [5]:
from neural_lam.era5_dataset import ERA5MultiTimeDataset

dataset_name = "era5_uk"
batch_size = 2
n_workers = 4
split = "train"


train_set = ERA5MultiTimeDataset(
    dataset_name,
    subsample_steps=[4, 2, 1],
    pred_length=28,
    split=split,
)

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=n_workers,
)

# init_states, target_states, forcing = train_set[0]

print(f"{split} batches :", len(train_loader))

dataiter = iter(train_loader)
init_states, target_states, forcing = next(dataiter)

# print("init_states:", init_states.shape)
# print("target_states:", target_states.shape)
# print("forcing:", forcing.shape)


pred_length: 28
sample_length 36
length: 1425
train batches : 713


In [14]:
print(init_states[0].shape)
print(init_states[1].shape)
print(init_states[2].shape)

print(target_states[0].shape)
print(target_states[1].shape)
print(target_states[2].shape)

print(forcing[0].shape)
print(forcing[1].shape)
print(forcing[2].shape)

torch.Size([2, 2, 3705, 48])
torch.Size([2, 2, 3705, 48])
torch.Size([2, 2, 3705, 48])
torch.Size([2, 7, 3705, 48])
torch.Size([2, 14, 3705, 48])
torch.Size([2, 28, 3705, 48])
torch.Size([2, 7, 3705, 12])
torch.Size([2, 14, 3705, 12])
torch.Size([2, 28, 3705, 12])


In [15]:
(init_states[1][:, 1] == init_states[2][:, 0]).all()

(target_states[0][:, 1] == target_states[1][:, 0]).all()

(target_states[0][:, -1] == target_states[1][:, -4]).all()

tensor(True)

In [67]:
dataset_name = "era5_uk"
batch_size = 8
n_workers = 4

subsample_step = 1
# split, year = "val", "2022"
split = "val"

train_set = nl_era5.era5_dataset(
    dataset_name,
    pred_length=28,
    subsample_step=subsample_step,
    split=split,
    standardize=False,
)

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size,
    shuffle=False,
    num_workers=n_workers,
)

print(f"{split} batches :", len(train_loader))

dataiter = iter(train_loader)
init_states, target_states, forcing = next(dataiter)
print("init_states:", init_states.shape)
print("target_states:", target_states.shape)
print("forcing:", forcing.shape)

val batches : 47


init_states: torch.Size([8, 2, 3705, 48])
target_states: torch.Size([8, 28, 3705, 48])
forcing: torch.Size([8, 28, 3705, 12])


In [69]:
subsample_step = 2
# split, year = "val", "2022"
split = "val"

train_set = nl_era5.era5_dataset(
    dataset_name,
    pred_length=28,
    subsample_step=subsample_step,
    split=split,
    standardize=False,
)

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size,
    shuffle=False,
    num_workers=n_workers,
)

print(f"{split} batches :", len(train_loader))

dataiter = iter(train_loader)
subsampled_init_states, subsampled_target_states, subsampled_forcing = next(dataiter)
print("init_states:", subsampled_init_states.shape)
print("target_states:", subsampled_target_states.shape)
print("forcing:", subsampled_forcing.shape)

val batches : 32


init_states: torch.Size([8, 2, 3705, 48])
target_states: torch.Size([8, 28, 3705, 48])
forcing: torch.Size([8, 28, 3705, 12])


In [70]:
# (subsampled_init_states[:, 1] == target_states[:, 0]).all()
(subsampled_target_states[:, 0] == target_states[:, 2]).all()

tensor(True)

In [60]:
sub_sampled_init_states

tensor([[[[ 1.9002e+05,  1.2574e+05,  9.5088e+04,  ..., -1.7236e-01,
            1.0889e-01, -5.6854e-02],
          [ 1.9003e+05,  1.2575e+05,  9.5097e+04,  ..., -1.4230e-01,
            1.4538e-01, -5.3848e-02],
          [ 1.9004e+05,  1.2576e+05,  9.5110e+04,  ..., -1.5991e-01,
            1.1404e-01, -4.6549e-02],
          ...,
          [ 1.9874e+05,  1.3382e+05,  1.0307e+05,  ..., -2.2074e-02,
           -6.2407e-01,  2.8299e-03],
          [ 1.9878e+05,  1.3383e+05,  1.0311e+05,  ...,  1.5282e-02,
           -7.3656e-01,  1.6570e-02],
          [ 1.9882e+05,  1.3384e+05,  1.0313e+05,  ...,  4.7915e-02,
           -3.9607e-01, -6.2436e-02]],

         [[ 1.9083e+05,  1.2599e+05,  9.5139e+04,  ...,  2.3899e-01,
            2.8021e-01,  3.7610e-02],
          [ 1.9084e+05,  1.2600e+05,  9.5177e+04,  ...,  1.7759e-01,
            2.5359e-01,  4.9632e-02],
          [ 1.9085e+05,  1.2601e+05,  9.5212e+04,  ...,  1.2606e-01,
            2.0679e-01,  4.8344e-02],
          ...,
     