<h1> Understand Batch Size Dimensions </h1>

In [34]:
from utils import TrajectoryDataset
from torch.utils.data import DataLoader
from pathlib import Path
# See Train Dimensions
PATH = Path('.')
window= 10
batch_size = 5


trainset = TrajectoryDataset(PATH / 'data/train.h5', window=window, flatten=True)
validset = TrajectoryDataset(PATH / 'data/valid.h5', window=window, flatten=True)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=1, persistent_workers=True)
validloader = DataLoader(validset, batch_size=batch_size, shuffle=True, num_workers=1, persistent_workers=True)

for i, (batch, _) in  enumerate(trainloader):
    print(f"{i} : {batch.shape}")

0 : torch.Size([5, 40, 128, 128])
1 : torch.Size([5, 40, 128, 128])
2 : torch.Size([5, 40, 128, 128])
3 : torch.Size([5, 40, 128, 128])
4 : torch.Size([4, 40, 128, 128])


<h1> Understand UNET dimensions </h1>

In [31]:

import torch
import torch.nn as nn
from nn import UNet

in_channels = 4
out_channels = 4
mod_features = 66
hidden_channels = [32, 64, 128]
hidden_blocks = [2, 3, 5]
spatial = 2

# UNET Forward : x (B, C, H, W)
unet = UNet(
    in_channels=in_channels,
    out_channels=out_channels,
    mod_features=mod_features,
    hidden_channels=hidden_channels,
    hidden_blocks=hidden_blocks,
    spatial=spatial,
)

)
batch_size = 4
x = torch.randn(batch_size, in_channels, 128, 128)


y = torch.randn(batch_size, mod_features)


output = unet(x, y)


print(output.shape)

torch.Size([4, 4, 128, 128])


<h1> Understand ScoreUNET dimensions </h1>

In [35]:
import torch
import torch.nn as nn
from score import ScoreUNet
batch_size = 5
MAR_channels = 4 # Example ['RF', 'T2m', 'U10m', MASK]
window = 10
y_dim = 128
x_dim = 128

CONFIG = { 'hidden_channels' : [32, 64, 128],  \
'hidden_blocks' : [2, 3, 5],  \
'spatial' : 2, \
'channels' : window*MAR_channels, \
'context' : 0,\
'embedding' : 64 }


score_unet = ScoreUNet(**CONFIG)




x = torch.randn([batch_size, window*MAR_channels, y_dim, x_dim]) # because of flatten
t = torch.rand(x.shape[0], dtype=x.dtype, device=x.device)

print(f"x:  {x.shape} , t: {t.shape}")
c = None

# Forward pass through the ScoreUNet
output = score_unet(x, t, c)

# Print the output shape
print(f"Output Shape : {output.shape}")  # Should be (batch_size, channels, x, y)

x:  torch.Size([5, 40, 128, 128]) , t: torch.Size([5])
Output Shape : torch.Size([5, 40, 128, 128])


<h1> Understand VPSDE dimensions </h1>

In [38]:
import torch
import torch.nn as nn
from score import ScoreUNet, VPSDE

batch_size = 5
MAR_channels = 4 # Example ['RF', 'T2m', 'U10m', MASK]
window = 10
y_dim = 128
x_dim = 128

CONFIG = { 'hidden_channels' : [32, 64, 128],  \
'hidden_blocks' : [2, 3, 5],  \
'spatial' : 2, \
'channels' : window*MAR_channels, \
'context' : 0,\
'embedding' : 64 }


score_unet = ScoreUNet(**CONFIG)
vpsde = VPSDE(score_unet, shape=(MAR_channels*window, y_dim, x_dim))



x = torch.randn([batch_size, window*MAR_channels, y_dim, x_dim]) # because of flatten

vpsde.loss(x)

tensor(1.3164, grad_fn=<MeanBackward0>)

<h1> Loop-Alike (Ensure Architecture design is well shaped)</h1>

In [42]:
import torch
import torch.nn as nn
from score import ScoreUNet, VPSDE

# Common dimensions
batch_size = 5
MAR_channels = 4 # Example ['RF', 'T2m', 'U10m', MASK]
window = 10
y_dim = 128
x_dim = 128

# Define the network
CONFIG = { 'hidden_channels' : [32, 64, 128],  \
'hidden_blocks' : [2, 3, 5],  \
'spatial' : 2, \
'channels' : window*MAR_channels, \
'context' : 0,\
'embedding' : 64 }

# Denoiser and Scheduler
score_unet = ScoreUNet(**CONFIG)
vpsde = VPSDE(score_unet, shape=(MAR_channels*window, y_dim, x_dim))

trainset = TrajectoryDataset(PATH / "data/train.h5", window=10, flatten=True)
validset = TrajectoryDataset(PATH / "data/valid.h5", window=10, flatten=True)


# Batch loop
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=1, persistent_workers=True)
validloader = DataLoader(validset, batch_size=batch_size, shuffle=False, num_workers=1, persistent_workers=True)
for i, (batch, _) in enumerate(trainloader):
        print(f"{i} batch : {batch.shape}")
        loss = vpsde.loss(batch)

0 batch : torch.Size([5, 40, 128, 128])
1 batch : torch.Size([5, 40, 128, 128])
2 batch : torch.Size([5, 40, 128, 128])
3 batch : torch.Size([5, 40, 128, 128])
4 batch : torch.Size([4, 40, 128, 128])


In [13]:
import os
import h5py
import math
import torch
#import wandb
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import trange
from pathlib import Path
from utils import TrajectoryDataset
from score import ScoreUNet, MCScoreWrapper, VPSDE
from score import VPSDE

PATH = Path('.')

with h5py.File(PATH / "data/mask.h5", "r") as f:
    mask = torch.tensor(f["dataset"][:], dtype=torch.float32).unsqueeze(0)  # Shape

def masked_vpsde_loss(sde, x, mask):
    w = mask.expand_as(x)
    return sde.loss(x, w=w)

CONFIG = {
    "epochs": 100,
    "batch_size": 5,
    "learning_rate": 1e-3,
    "weight_decay": 1e-3,
    "scheduler": "linear",  # Can be "cosine" or "exponential"
    "embedding": 32,
    "hidden_channels": (64,),
    "hidden_blocks": (3,),
    "activation": "SiLU",
}

trainset = TrajectoryDataset(PATH / "data/train.h5", window=10)
validset = TrajectoryDataset(PATH / "data/valid.h5", window=10)

trainloader = DataLoader(trainset, batch_size=CONFIG["batch_size"], shuffle=True, num_workers=1, persistent_workers=True)
validloader = DataLoader(validset, batch_size=CONFIG["batch_size"], shuffle=False, num_workers=1, persistent_workers=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

score_model = MCScoreWrapper(
    ScoreUNet(
        channels=4,
        embedding=CONFIG["embedding"],
        hidden_channels=CONFIG["hidden_channels"],
        hidden_blocks=CONFIG["hidden_blocks"],
        activation=nn.SiLU,
    )
).to(device)

sde = VPSDE(score_model, shape=(10, 4, 55, 66)).to(device)

optimizer = optim.AdamW(sde.parameters(), lr=CONFIG["learning_rate"], weight_decay=CONFIG["weight_decay"])

# Define Learning Rate Scheduler
if CONFIG["scheduler"] == "linear":
    lr_lambda = lambda t: 1 - (t / CONFIG["epochs"])
elif CONFIG["scheduler"] == "cosine":
    lr_lambda = lambda t: (1 + math.cos(math.pi * t / CONFIG["epochs"])) / 2
elif CONFIG["scheduler"] == "exponential":
    lr_lambda = lambda t: math.exp(-7 * (t / CONFIG["epochs"]) ** 2)
else:
    raise ValueError("Invalid scheduler type")

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

for epoch in (bar := trange(CONFIG["epochs"], ncols=88)):
    losses_train = []
    losses_valid = []

    ## Train
    sde.train()
    print('-------')
    print('[TRAIN LOOP]')
    for i, (batch, _) in enumerate(trainloader):
        print(f"{i}")
        batch = batch.to(device)
        optimizer.zero_grad()

        # **Apply Mask**
        mask_batch = mask.to(device).expand_as(batch)  # Expand mask to match batch size
        w = mask_batch.float()  # Convert mask to weight format

        # **Compute VPSDE Loss**
        loss = sde.loss(batch, w=w)
        loss.backward()
        optimizer.step()

        losses_train.append(loss.detach())
    print('[\TRAIN LOOP]')
    print('-------')
    ## Validation
    sde.eval()
    with torch.no_grad():
        print('-------')
        print('[VALID LOOP]')
        for batch, _ in validloader:
            batch = batch.to(device)
            mask_batch = mask.to(device).expand_as(batch)
            w = mask_batch.float()

            loss = sde.loss(batch, w=w)
            losses_valid.append(loss)
        print('-------')
        print('[VALID LOOP]')
    ## Compute Loss Stats
    loss_train = torch.stack(losses_train).mean().item()
    loss_valid = torch.stack(losses_valid).mean().item()
    lr = optimizer.param_groups[0]['lr']

    ## Step Scheduler
    scheduler.step()

    # Save model periodically
    if (epoch + 1) % 10 == 0:
        torch.save(sde.state_dict(), PATH / f"checkpoints/model_epoch{epoch+1}.pth")

  0%|                                                           | 0/100 [00:00<?, ?it/s]

-------
[TRAIN LOOP]


  0%|                                                           | 0/100 [00:00<?, ?it/s]


0


RuntimeError: Given groups=1, weight of size [64, 4, 3, 3], expected input[20, 10, 55, 66] to have 4 channels, but got 10 channels instead