In [1]:
import torch

from model import *
from data import *

import itertools

In [2]:
# Project configuration read from config.json. The training cell below passes it to
# the dataset and model constructors and reads `dataSplit` / `batchSize` from it.
config = (
    Config()
    .load("config.json")
)

In [3]:
# Hyperparameters to tune
#
# Sequence lengths:
#   - History steps
#   - Future steps
#
# Basin and River projections:
#   - Discrete projection dimensions
#   - Projection dimension
#
# Graph attention:
#   - Hidden dimension
#   - Number of layers
#
# LSTM:
#   - Number of layers
#   - Hidden dimension
#
# CMAL:
#   - Hidden dimension
#   - Number of mixtures
#
# Probably not worth tuning:
#   - Dropout
#   - Batch size
#
# Open question: tune encoder and decoder settings separately (x2 the parameters above)?

In [4]:
def _repeatLoader(loader):
    """Yield batches from ``loader`` endlessly, starting a fresh pass each time it is exhausted.

    Unlike ``itertools.cycle`` this does not cache the batches it has yielded, so memory
    use does not grow with the size of the loader, and a shuffling loader is reshuffled
    on every pass. If a pass yields nothing (empty loader), the generator stops, so
    ``next()`` raises ``StopIteration`` instead of looping forever.
    """
    while True:
        empty = True
        for batch in loader:
            empty = False
            yield batch
        if empty:
            return


def _trainStep(model, optimizer, objective, inputs, targets):
    """Run one optimizer step on a training batch and return the loss as a Python float."""
    model.train()
    optimizer.zero_grad()

    hindcast, _forecast = model(inputs)
    # Only the last entry of the discharge history is used as the target here;
    # the forecast output, dischargeFuture and thresholds are currently unused.
    loss = objective(hindcast, targets.dischargeHistory[:, -1].unsqueeze(1))
    trainLoss = loss.item()

    loss.backward()
    optimizer.step()
    return trainLoss


def _evalStep(model, objective, inputs, targets):
    """Evaluate ``objective`` on one batch in eval mode without gradients; return it as a float."""
    with torch.no_grad():
        model.eval()
        hindcast, _forecast = model(inputs)
        loss = objective(hindcast, targets.dischargeHistory[:, -1].unsqueeze(1))
        return loss.item()


def trainModel(config, modelClass, dataClass, objective, epochs, criterion: dict[str, nn.Module], lr: float = 1e-4):
    """Train ``modelClass`` on ``dataClass`` and report train/test loss after every step.

    The dataset is split by index, not randomly. The first ``config.dataSplit`` fraction
    of samples is used for training and the remainder for testing. After every optimizer
    step, one batch from the test split is evaluated under ``torch.no_grad()``, so train
    and test loss can be tracked side by side on a single, overwritten progress line.

    Args:
        config: Configuration object. ``dataSplit`` (train fraction) and ``batchSize``
            are read here. It is also passed to ``dataClass`` and ``modelClass``.
        modelClass: Model constructor, called as ``modelClass(config)``. Calling the
            model on a batch of inputs must return ``(hindcast, forecast)``.
        dataClass: Dataset constructor, called as ``dataClass(config)``. It must provide
            ``info(sample)`` and yield ``(inputs, targets)``, where ``targets`` has a
            ``dischargeHistory`` attribute indexable as ``[:, -1]`` once batched.
        objective: Loss callable, applied as ``objective(hindcast, target)``.
        epochs: Number of passes over the training split.
        criterion: Named evaluation metrics.
            NOTE(review): accepted but not used anywhere in this function yet.
        lr: Adam learning rate (defaults to the previously hard-coded 1e-4).

    Returns:
        The trained model.
    """
    dataset = dataClass(config)
    dataset.info(dataset[0])

    totalSize = len(dataset)
    splitIndex = int(totalSize * config.dataSplit)
    # Contiguous index ranges keep the train and test subsets disjoint.
    train = torch.utils.data.Subset(dataset, range(splitIndex))
    test = torch.utils.data.Subset(dataset, range(splitIndex, totalSize))

    # `device` is expected to be provided by the star imports at the top of the notebook.
    train = DataLoader(train, batch_size=config.batchSize, shuffle=True, generator=torch.Generator(device))
    test = DataLoader(test, batch_size=config.batchSize, shuffle=True, generator=torch.Generator(device))

    print("Dataloader Shuffled")

    dataset.info(next(iter(train)))

    model = modelClass(config)
    print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")
    print(f"Dataset has {len(dataset)} samples")

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Endless stream of test batches that does not cache the whole test split in memory.
    testIter = _repeatLoader(test)
    stepsPerEpoch = len(train)

    for epoch in range(epochs):
        for progress, (inputs, targets) in enumerate(train, start=1):
            trainLoss = _trainStep(model, optimizer, objective, inputs, targets)

            testInputs, testTargets = next(testIter)
            testLoss = _evalStep(model, objective, testInputs, testTargets)

            # "\r" plus end="" overwrites one progress line instead of printing a line per step.
            print(
                f"\rEpoch {epoch + 1}/{epochs} | {progress}/{stepsPerEpoch} | "
                f"{(progress / stepsPerEpoch) * 100:.3f}% | Train Loss: {trainLoss} | Test Loss: {testLoss}",
                end="",
                flush=True,
            )
        # Terminate the progress line for this epoch.
        print()

    return model

In [None]:
# Named evaluation metrics passed to trainModel as `criterion`.
# NOTE(review): the "Precision" key maps to a CMAL loss, not a precision metric — confirm the intended name.
metrics = {
    "Precision": CMALLoss()
}

# Bind the trained model so later cells can use it, instead of relying on `_` / Out[N].
trainedModel = trainModel(config, InundationStation, InundationData, CMALLoss(), epochs=10, criterion=metrics)

Loading GeoPandas...
GeoPandas Loaded
2544/2544 GRDC files loaded
9640/9640 ERA5 files loaded
Total empty basins: 37
57622/57646 Basin Structures Appended to Graph
Upstream Basins Compiled | 1.0 | 17.43485477178423
Upstream Structures Compiled
Structure Tensors Complete
Index Mapping Complete
Static Input Scaling Complete

        Total Samples: 28683059
        Era5 History: torch.Size([1, 365, 7]) torch.float32
        Era5 Future: torch.Size([1, 7, 7]) torch.float32
        Basin Continuous: torch.Size([1, 277]) torch.float32
        Basin Discrete: torch.Size([1, 10]) torch.int64
        Structure: torch.Size([2, 1]) torch.int64
        River Continuous: torch.Size([258]) torch.float32
        River Discrete: torch.Size([14]) torch.int64
        Discharge History: torch.Size([365]) torch.float32
        Discharge Future: torch.Size([7]) torch.float32
        Thresholds: torch.Size([4]) torch.float32
        
Dataloader Shuffled

        Total Samples: 28683059
        Era5 History: