In [2]:
CONFIG = {
    # Architecture
    'window': 5,
    'embedding': 64,
    'hidden_channels': (96, 192, 384),
    'hidden_blocks': (3, 3, 3),
    'kernel_size': 3,
    'activation': 'SiLU',
    # Training
    'epochs': 4096,
    'batch_size': 32,
    'optimizer': 'AdamW',
    'learning_rate': 2e-4,
    'weight_decay': 1e-3,
    'scheduler': 'linear',
}

In [3]:
from utils import TrajectoryDataset
from torch.utils.data import DataLoader
from pathlib import Path
# See Train Dimensions
PATH = Path('.')
trainset = TrajectoryDataset(PATH / 'data/train.h5', window=10)
validset = TrajectoryDataset(PATH / 'data/valid.h5', window=10)
batch_size = 5
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=1, persistent_workers=True)
validloader = DataLoader(validset, batch_size=batch_size, shuffle=True, num_workers=1, persistent_workers=True)

for i, (batch, _) in  enumerate(trainloader):
    print(f"{i} : {batch.shape}")

0 : torch.Size([5, 10, 4, 55, 66])
1 : torch.Size([5, 10, 4, 55, 66])
2 : torch.Size([5, 10, 4, 55, 66])
3 : torch.Size([5, 10, 4, 55, 66])
4 : torch.Size([4, 10, 4, 55, 66])


In [13]:
import os
import h5py
import math
import torch
#import wandb
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import trange
from pathlib import Path
from utils import TrajectoryDataset
from score import ScoreUNet, MCScoreWrapper
from score import VPSDE

PATH = Path('.')

with h5py.File(PATH / "data/mask.h5", "r") as f:
    mask = torch.tensor(f["dataset"][:], dtype=torch.float32).unsqueeze(0)  # Shape

def masked_vpsde_loss(sde, x, mask):
    w = mask.expand_as(x)
    return sde.loss(x, w=w)

CONFIG = {
    "epochs": 100,
    "batch_size": 5,
    "learning_rate": 1e-3,
    "weight_decay": 1e-3,
    "scheduler": "linear",  # Can be "cosine" or "exponential"
    "embedding": 32,
    "hidden_channels": (64,),
    "hidden_blocks": (3,),
    "activation": "SiLU",
}

trainset = TrajectoryDataset(PATH / "data/train.h5", window=10)
validset = TrajectoryDataset(PATH / "data/valid.h5", window=10)

trainloader = DataLoader(trainset, batch_size=CONFIG["batch_size"], shuffle=True, num_workers=1, persistent_workers=True)
validloader = DataLoader(validset, batch_size=CONFIG["batch_size"], shuffle=False, num_workers=1, persistent_workers=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

score_model = MCScoreWrapper(
    ScoreUNet(
        channels=4,
        embedding=CONFIG["embedding"],
        hidden_channels=CONFIG["hidden_channels"],
        hidden_blocks=CONFIG["hidden_blocks"],
        activation=nn.SiLU,
    )
).to(device)

sde = VPSDE(score_model, shape=(10, 4, 55, 66)).to(device)

optimizer = optim.AdamW(sde.parameters(), lr=CONFIG["learning_rate"], weight_decay=CONFIG["weight_decay"])

# Define Learning Rate Scheduler
if CONFIG["scheduler"] == "linear":
    lr_lambda = lambda t: 1 - (t / CONFIG["epochs"])
elif CONFIG["scheduler"] == "cosine":
    lr_lambda = lambda t: (1 + math.cos(math.pi * t / CONFIG["epochs"])) / 2
elif CONFIG["scheduler"] == "exponential":
    lr_lambda = lambda t: math.exp(-7 * (t / CONFIG["epochs"]) ** 2)
else:
    raise ValueError("Invalid scheduler type")

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

for epoch in (bar := trange(CONFIG["epochs"], ncols=88)):
    losses_train = []
    losses_valid = []

    ## Train
    sde.train()
    print('-------')
    print('[TRAIN LOOP]')
    for i, (batch, _) in enumerate(trainloader):
        print(f"{i}")
        batch = batch.to(device)
        optimizer.zero_grad()

        # **Apply Mask**
        mask_batch = mask.to(device).expand_as(batch)  # Expand mask to match batch size
        w = mask_batch.float()  # Convert mask to weight format

        # **Compute VPSDE Loss**
        loss = sde.loss(batch, w=w)
        loss.backward()
        optimizer.step()

        losses_train.append(loss.detach())
    print('[\TRAIN LOOP]')
    print('-------')
    ## Validation
    sde.eval()
    with torch.no_grad():
        print('-------')
        print('[VALID LOOP]')
        for batch, _ in validloader:
            batch = batch.to(device)
            mask_batch = mask.to(device).expand_as(batch)
            w = mask_batch.float()

            loss = sde.loss(batch, w=w)
            losses_valid.append(loss)
        print('-------')
        print('[VALID LOOP]')
    ## Compute Loss Stats
    loss_train = torch.stack(losses_train).mean().item()
    loss_valid = torch.stack(losses_valid).mean().item()
    lr = optimizer.param_groups[0]['lr']

    ## Step Scheduler
    scheduler.step()

    # Save model periodically
    if (epoch + 1) % 10 == 0:
        torch.save(sde.state_dict(), PATH / f"checkpoints/model_epoch{epoch+1}.pth")

  0%|                                                           | 0/100 [00:00<?, ?it/s]

-------
[TRAIN LOOP]


  0%|                                                           | 0/100 [00:00<?, ?it/s]


0


RuntimeError: Given groups=1, weight of size [64, 4, 3, 3], expected input[20, 10, 55, 66] to have 4 channels, but got 10 channels instead

In [28]:
import score
x_test = torch.randn(8, 10, 3, 64, 64)  # Batch de 8, longueur 10, 3 channels, 64x64
t_test = torch.tensor(0.5)  # Exemple de temps
c_test = torch.randn(3, 64, 64)  # Contexte

model = score.MCScoreNet(features=3, context=3, order=1, spatial=2)
output = model(x_test, t_test, c_test)



In [37]:
from torchsummaryX import summary
from nn import UNet

CONFIG = {
    "epochs": 100,
    "batch_size": 5,
    "learning_rate": 1e-3,
    "weight_decay": 1e-3,
    "scheduler": "linear",  # Can be "cosine" or "exponential"
    "embedding": 32,
    "hidden_channels": (64,),
    "hidden_blocks": (3,),
    "activation": "SiLU",
}
myUnet = UNet(3,3,0)
summary(myUnet, torch.zeros((4, 64, 64)), torch.zeros((4,32)) )



RuntimeError: Given groups=1, weight of size [32, 3, 3, 3], expected input[1, 4, 64, 64] to have 3 channels, but got 4 channels instead

<h1> Working of UNET </h1>

In [53]:

import torch
import torch.nn as nn
from nn import UNet
# Define the input dimensions
in_channels = 3  # Number of input channels (e.g., RGB image)
out_channels = 1  # Number of output channels (e.g., grayscale mask)
mod_features = 66  # Embedding dimension for the modulation vector y
hidden_channels = [32, 64, 128]  # Number of hidden channels at each depth
hidden_blocks = [2, 3, 5]  # Number of residual blocks at each depth
spatial = 2  # 2D spatial dimensions (x, y)

# Create the UNet model
unet = UNet(
    in_channels=in_channels,
    out_channels=out_channels,
    mod_features=mod_features,
    hidden_channels=hidden_channels,
    hidden_blocks=hidden_blocks,
    spatial=spatial,
)

# Define the input tensor x with shape (batch_size, channel, x, y)
batch_size = 4
x = torch.randn(batch_size, in_channels, 128, 128)  # Example input tensor (64x64 image)

# Define the modulation tensor y with shape (batch_size, mod_features)
y = torch.randn(batch_size, mod_features)  # Example modulation tensor

# Forward pass through the UNet
output = unet(x, y)

# Print the output shape
print(output.shape)  # Should be (batch_size, out_channels, x, y)

torch.Size([4, 1, 128, 128])
