In [1]:
import sys
import os

sys.path.append(os.path.dirname(os.getcwd()))

os.chdir("../")

In [2]:
import os
import glob
import datetime as dt
import numpy as np
import torch

import neural_lam.constants as constants

In [3]:
# load time step data
# sample_dir_path = os.path.join(BASE_PATH, "data/era5_uk/samples/train")
sample_dir_path = "data/era5_uk/samples/train"

sample_files = glob.glob(f'{sample_dir_path}/*.npy')
sample_files.sort()
init_states = torch.tensor(np.load(sample_files[0]), dtype=torch.float32).unsqueeze(0)
sample_files = [os.path.basename(f) for f in sample_files]
sample_times = [dt.datetime.strptime(f, '%Y%m%d%H%M%S.npy') for f in sample_files]

In [25]:
init_states.shape

torch.Size([1, 3705, 48])

In [4]:
sample_files

['20220101000000.npy',
 '20220101060000.npy',
 '20220101120000.npy',
 '20220101180000.npy',
 '20220102000000.npy',
 '20220102060000.npy',
 '20220102120000.npy',
 '20220102180000.npy',
 '20220103000000.npy',
 '20220103060000.npy',
 '20220103120000.npy',
 '20220103180000.npy',
 '20220104000000.npy',
 '20220104060000.npy',
 '20220104120000.npy',
 '20220104180000.npy',
 '20220105000000.npy',
 '20220105060000.npy',
 '20220105120000.npy',
 '20220105180000.npy',
 '20220106000000.npy',
 '20220106060000.npy',
 '20220106120000.npy',
 '20220106180000.npy',
 '20220107000000.npy',
 '20220107060000.npy',
 '20220107120000.npy',
 '20220107180000.npy',
 '20220108000000.npy',
 '20220108060000.npy',
 '20220108120000.npy',
 '20220108180000.npy',
 '20220109000000.npy',
 '20220109060000.npy',
 '20220109120000.npy',
 '20220109180000.npy',
 '20220110000000.npy',
 '20220110060000.npy',
 '20220110120000.npy',
 '20220110180000.npy',
 '20220111000000.npy',
 '20220111060000.npy',
 '20220111120000.npy',
 '202201111

In [7]:
sample_times

[datetime.datetime(2022, 1, 1, 0, 0),
 datetime.datetime(2022, 1, 1, 6, 0),
 datetime.datetime(2022, 1, 1, 12, 0),
 datetime.datetime(2022, 1, 1, 18, 0),
 datetime.datetime(2022, 1, 2, 0, 0),
 datetime.datetime(2022, 1, 2, 6, 0),
 datetime.datetime(2022, 1, 2, 12, 0),
 datetime.datetime(2022, 1, 2, 18, 0),
 datetime.datetime(2022, 1, 3, 0, 0),
 datetime.datetime(2022, 1, 3, 6, 0),
 datetime.datetime(2022, 1, 3, 12, 0),
 datetime.datetime(2022, 1, 3, 18, 0),
 datetime.datetime(2022, 1, 4, 0, 0),
 datetime.datetime(2022, 1, 4, 6, 0),
 datetime.datetime(2022, 1, 4, 12, 0),
 datetime.datetime(2022, 1, 4, 18, 0),
 datetime.datetime(2022, 1, 5, 0, 0),
 datetime.datetime(2022, 1, 5, 6, 0),
 datetime.datetime(2022, 1, 5, 12, 0),
 datetime.datetime(2022, 1, 5, 18, 0),
 datetime.datetime(2022, 1, 6, 0, 0),
 datetime.datetime(2022, 1, 6, 6, 0),
 datetime.datetime(2022, 1, 6, 12, 0),
 datetime.datetime(2022, 1, 6, 18, 0),
 datetime.datetime(2022, 1, 7, 0, 0),
 datetime.datetime(2022, 1, 7, 6, 0),


In [26]:
i = 100
sample_length = 6 + 2

# times = sample_times[i:i+sample_length]

hour_inc = torch.arange(sample_length) * 6 # (N_t,)
init_hour = sample_times[i].hour
hour_of_day = init_hour + hour_inc
print("hour of day")
print(hour_of_day)

start_of_year = dt.datetime(sample_times[i].year, 1, 1)
print("start of year")
print(start_of_year)

init_seconds_into_year = (sample_times[i] - start_of_year).total_seconds()
seconds_into_year = init_seconds_into_year + hour_inc * 3600
print("seconds into year")
print(seconds_into_year)

hour_angle = (hour_of_day / 24) * 2 * torch.pi 
year_angle = (seconds_into_year / constants.SECONDS_IN_YEAR) * 2 * torch.pi
print("hour angle")
print(hour_angle)
print("year angle")
print(year_angle)

hour of day
tensor([ 0,  6, 12, 18, 24, 30, 36, 42])
start of year
2022-01-01 00:00:00
seconds into year
tensor([2160000., 2181600., 2203200., 2224800., 2246400., 2268000., 2289600.,
        2311200.])
hour angle
tensor([ 0.0000,  1.5708,  3.1416,  4.7124,  6.2832,  7.8540,  9.4248, 10.9956])
year angle
tensor([0.4304, 0.4347, 0.4390, 0.4433, 0.4476, 0.4519, 0.4562, 0.4605])


In [27]:
datetime_forcing = torch.stack(
    (
        torch.sin(hour_angle),
        torch.cos(hour_angle),
        torch.sin(year_angle),
        torch.cos(year_angle),
    ),
    dim=1,
)  # (N_t, 4)
print(datetime_forcing.shape)
print(datetime_forcing)
datetime_forcing = (datetime_forcing + 1) / 2 # Normalize to [0,1]
print(datetime_forcing)

datetime_forcing = datetime_forcing.unsqueeze(1).expand(
    -1, init_states.shape[1], -1
)  # (sample_len, N_grid, 4)
print(datetime_forcing.shape)

forcing = torch.cat(
    (
        datetime_forcing[:-2],
        datetime_forcing[1:-1],
        datetime_forcing[2:],
    ),
    dim=2,
)

print(forcing.shape)
print(forcing)

torch.Size([8, 4])
tensor([[ 0.0000e+00,  1.0000e+00,  4.1719e-01,  9.0882e-01],
        [ 1.0000e+00, -4.3711e-08,  4.2110e-01,  9.0701e-01],
        [-8.7423e-08, -1.0000e+00,  4.2500e-01,  9.0519e-01],
        [-1.0000e+00,  1.1925e-08,  4.2889e-01,  9.0336e-01],
        [ 1.7485e-07,  1.0000e+00,  4.3278e-01,  9.0150e-01],
        [ 1.0000e+00, -3.3777e-07,  4.3665e-01,  8.9963e-01],
        [-2.3850e-08, -1.0000e+00,  4.4052e-01,  8.9774e-01],
        [-1.0000e+00,  6.6361e-07,  4.4438e-01,  8.9584e-01]])
tensor([[0.5000, 1.0000, 0.7086, 0.9544],
        [1.0000, 0.5000, 0.7106, 0.9535],
        [0.5000, 0.0000, 0.7125, 0.9526],
        [0.0000, 0.5000, 0.7144, 0.9517],
        [0.5000, 1.0000, 0.7164, 0.9508],
        [1.0000, 0.5000, 0.7183, 0.9498],
        [0.5000, 0.0000, 0.7203, 0.9489],
        [0.0000, 0.5000, 0.7222, 0.9479]])
torch.Size([8, 3705, 4])
torch.Size([6, 3705, 12])
tensor([[[0.5000, 1.0000, 0.7086,  ..., 0.0000, 0.7125, 0.9526],
         [0.5000, 1.0000, 0.708