## Training and Testing

In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F
import torch.optim as optim
from helpers import model_summary, set_device
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from models.implicit_neural_representations.inr_models.siren_model import SirenModel, FinerModel


In [2]:
def train(model, train_loader, optimizer, criterion, epoch):
    model.train()
    train_loss = 0
    
    # Create the tqdm progress bar for batches
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch}")

    for batch_idx, (data, target) in progress_bar:
        data, target = data.to('cpu'), target.to('cpu')
        optimizer.zero_grad()
        output = model(data).to('cpu')
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # Update tqdm progress bar with current loss every 10 steps
        if batch_idx % 10 == 0:
            progress_bar.set_postfix(loss=loss.item(), refresh=True)

    train_loss /= len(train_loader)
    return train_loss


In [3]:
ds = xr.open_dataset('../data/era5_singlehour_1var.nc')
ds

In [4]:
lon = ds.longitude.values
lat = ds.latitude.values

lon_grid, lat_grid = np.meshgrid(lon, lat)  # longitude first!


In [5]:
lon_grid

array([[0.0000e+00, 2.5000e-01, 5.0000e-01, ..., 3.5925e+02, 3.5950e+02,
        3.5975e+02],
       [0.0000e+00, 2.5000e-01, 5.0000e-01, ..., 3.5925e+02, 3.5950e+02,
        3.5975e+02],
       [0.0000e+00, 2.5000e-01, 5.0000e-01, ..., 3.5925e+02, 3.5950e+02,
        3.5975e+02],
       ...,
       [0.0000e+00, 2.5000e-01, 5.0000e-01, ..., 3.5925e+02, 3.5950e+02,
        3.5975e+02],
       [0.0000e+00, 2.5000e-01, 5.0000e-01, ..., 3.5925e+02, 3.5950e+02,
        3.5975e+02],
       [0.0000e+00, 2.5000e-01, 5.0000e-01, ..., 3.5925e+02, 3.5950e+02,
        3.5975e+02]])

In [6]:
lat_grid

array([[ 90.  ,  90.  ,  90.  , ...,  90.  ,  90.  ,  90.  ],
       [ 89.75,  89.75,  89.75, ...,  89.75,  89.75,  89.75],
       [ 89.5 ,  89.5 ,  89.5 , ...,  89.5 ,  89.5 ,  89.5 ],
       ...,
       [-89.5 , -89.5 , -89.5 , ..., -89.5 , -89.5 , -89.5 ],
       [-89.75, -89.75, -89.75, ..., -89.75, -89.75, -89.75],
       [-90.  , -90.  , -90.  , ..., -90.  , -90.  , -90.  ]])

In [7]:
coords = np.stack([lat_grid, lon_grid], axis=-1).reshape(-1, 2)  # shape (N, 2)
coords

array([[ 9.0000e+01,  0.0000e+00],
       [ 9.0000e+01,  2.5000e-01],
       [ 9.0000e+01,  5.0000e-01],
       ...,
       [-9.0000e+01,  3.5925e+02],
       [-9.0000e+01,  3.5950e+02],
       [-9.0000e+01,  3.5975e+02]])

In [27]:
class GeoDataset(Dataset):
    def __init__(self, path='../data/era5_singlehour_1var.nc', shuffle=True):
        # Open dataset with Dask backend
        self.data_array = xr.open_dataset(path, chunks={'latitude': 50, 'longitude': 50})

        # Normalize longitude to [-180, 180]
        lon = self.data_array.longitude.values
        lon = ((lon + 180) % 360) - 180
        self.data_array = self.data_array.assign_coords(longitude=lon)
        self.data_array = self.data_array.sortby('longitude')  # Important for .sel

        # Store min/max for normalization
        # Compute min and max over the entire dataset (use .compute() because of Dask)
        self.min = float(self.data_array.t.min().compute())
        self.max = float(self.data_array.t.max().compute())

        # Create normalized 't' array
        self.data_array['t_norm'] = 2 * (self.data_array.t - self.min) / (self.max - self.min) - 1

        # Create [lat, lon] coordinate pairs
        lat = self.data_array.latitude.values
        lon = self.data_array.longitude.values
        lon_grid, lat_grid = np.meshgrid(lon, lat)
        coords = np.vstack([lat_grid.ravel(), lon_grid.ravel()]).T

        # Shuffle if requested
        if shuffle:
            np.random.shuffle(coords)

        # 75% train split (you can remove this if dataset should be full)
        length = len(coords)
        split = int(length * 0.75)
        self.inputs = torch.tensor(coords[:split], dtype=torch.float32)

    def __len__(self):
        return self.inputs.shape[0]

    def __getitem__(self, idx):
        coord = self.inputs[idx]
        lat, lon = float(coord[0]), float(coord[1])

        # Fetch target value using .sel with Dask, and compute the result
        target = self.data_array.t_norm.sel(latitude=lat, longitude=lon, method="nearest").values
        target = torch.tensor([target], dtype=torch.float32).squeeze().unsqueeze(0)  # Ensure target is a tensor
        return coord, target


In [31]:
dataset = GeoDataset(shuffle=True)
loader = DataLoader(dataset, batch_size=1024, shuffle=True)

  self.data_array = xr.open_dataset(path, chunks={'latitude': 50, 'longitude': 50})
  self.data_array = xr.open_dataset(path, chunks={'latitude': 50, 'longitude': 50})


In [32]:
next(iter(loader))

[tensor([[ -65.2500,   61.7500],
         [  -7.0000,  172.5000],
         [   5.5000, -107.5000],
         ...,
         [ -80.2500,  -26.5000],
         [ -55.0000,    1.2500],
         [   7.2500, -105.0000]]),
 tensor([[-0.4953],
         [ 0.5286],
         [ 0.4561],
         ...,
         [-0.6503],
         [-0.2619],
         [ 0.4634]])]

In [19]:
siren_model = SirenModel(
    in_features=2,
    out_features=1,
    hidden_layers=5,
    hidden_features=512,
    first_omega_0=30.0,
    hidden_omega_0=30.0,
    residual_net=False,
    encoding='dfs',
    r_min=0.001,
    r_max=1.0,
    scale=[10, 10]
)

In [20]:
finer_model = FinerModel(
    in_features=2,
    out_features=1,
    bias=True,
    hidden_layers=5,
    hidden_features=128,
    first_omega_0=30.,
    hidden_omega_0=30.,
    first_k=10,
    hidden_k=10,
    residual_net=False,
    encoding='dfs',
    r_min=0.001,
    r_max=1.0,
    scale=[10, 10]
)

In [33]:
train(siren_model, loader, torch.optim.Adam(siren_model.parameters(), lr=1e-3), nn.MSELoss(), 10)

Epoch 10: 100%|██████████| 761/761 [1:11:42<00:00,  5.65s/it, loss=0.188]


0.17182525107331095