In [None]:
import torch

from model import *
from data import *
from vis import *
import gc

import itertools

In [None]:
# Build the run configuration by calling Config().load on "config.json";
# the resulting `config` is what gets passed to trainModel later in the notebook.
config = Config().load("config.json")

In [None]:
def _repeatLoader(loader):
    """Yield batches from `loader` forever, restarting it each time it runs out.

    Unlike `itertools.cycle`, no batches are cached across passes: every pass
    re-iterates `loader`, so previously yielded batches can be freed. If a pass
    produces no batches at all the generator stops instead of spinning forever,
    which makes `next()` raise `StopIteration` just as `itertools.cycle` would
    on an empty iterable.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            return


def trainModel(config, modelClass, dataClass, objective, epochs, criterion: dict[str, torch.nn.Module]):
    """Train `modelClass(config)` on `dataClass(config)` and return the model.

    The dataset is split by index: the first `config.dataSplit` fraction is used
    for training and the remainder for testing. Both subsets are batched with
    `GraphSizeSampler` using `config.nodesPerBatch`. After every training step
    one test batch is evaluated under `torch.no_grad()`, and the train/test loss
    and NMAE values are sent through a `Client("127.0.0.1")`.

    Args:
        config: Configuration object; `dataSplit` and `nodesPerBatch` are read
            here, and it is also passed to `dataClass` and `modelClass`.
        modelClass: Callable building the model from `config`. The model is
            called as `model(inputs)` and must return `(hindcast, forecast)`.
        dataClass: Callable building the dataset from `config`.
        objective: Loss callable invoked as `objective(forecast, future)`.
        epochs: Number of passes over the training loader.
        criterion: Mapping of metric names to modules. Currently accepted but
            not used by this function.

    Returns:
        The model. If training is interrupted with Ctrl-C, the model's
        `state_dict` is first written to "checkpoint.pt"; if the interrupt
        arrives before the model was constructed, `None` is returned.
    """
    model = None
    try:
        dataset = dataClass(config)

        dataset.info(dataset[0])

        datasetSize = len(dataset)
        splitIndex = int(datasetSize * config.dataSplit)
        # TODO: More stratified subsets using dataset.lengths and geographic information
        train = torch.utils.data.Subset(dataset, range(splitIndex))
        test = torch.utils.data.Subset(dataset, range(splitIndex, datasetSize))

        trainSampler = GraphSizeSampler(train, nodesPerBatch=config.nodesPerBatch)
        testSampler = GraphSizeSampler(test, nodesPerBatch=config.nodesPerBatch)

        train = DataLoader(train, batch_sampler=trainSampler, generator=torch.Generator(device))
        test = DataLoader(test, batch_sampler=testSampler, generator=torch.Generator(device))

        dataset.info(next(iter(train)))

        model = modelClass(config)
        print(f"Model has {sum([p.numel() for p in model.parameters()])} parameters")
        print(f"Dataset has {len(dataset)} samples")
        print(next(model.parameters()).is_cuda)

        # TODO: precision/recall metrics (based on targets.thresholds) are not wired in yet.
        mae = CMALNormalizedMeanAbsolute()

        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

        # Endless stream of test batches that does not keep old batches alive
        # (itertools.cycle would retain a copy of every batch it has seen).
        testIter = _repeatLoader(test)

        client = Client("127.0.0.1")

        for epoch in range(epochs):
            progress = 0
            for inputs, targets in train:
                # The evaluation step below switches to eval mode, so switch back every step.
                model.train()
                optimizer.zero_grad()

                future = targets.dischargeFuture
                deviations = targets.deviation.unsqueeze(1)
                hindcast, forecast = model(inputs)
                loss = objective(forecast, future)
                trainLoss = loss.detach().item()

                trainMAE = mae(forecast, future, deviations).detach().item()

                loss.backward()
                optimizer.step()

                # Drop every reference to this step's tensors before evaluating,
                # so their memory can be reclaimed.
                del loss, future, deviations, hindcast, forecast, inputs, targets

                torch.cuda.empty_cache()

                with torch.no_grad():
                    model.eval()
                    inputs, targets = next(testIter)

                    future = targets.dischargeFuture
                    deviations = targets.deviation.unsqueeze(1)
                    hindcast, forecast = model(inputs)
                    loss = objective(forecast, future)
                    testLoss = loss.detach().item()

                    testMAE = mae(forecast, future, deviations).detach().item()

                    del loss, future, deviations, hindcast, forecast, inputs, targets

                torch.cuda.empty_cache()

                client.send("Train Loss", trainLoss)
                client.send("Test Loss", testLoss)
                client.send("Train NMAE", trainMAE)
                client.send("Test NMAE", testMAE)

                progress += 1
                batchCount = len(train)
                print(f"\r{progress}/{batchCount} | {(progress / batchCount) * 100:.3f}% |  Train Loss: {trainLoss} | Test Loss: {testLoss}", end="")
                gc.collect()

            # Finish the carriage-return progress line for this epoch.
            print()

        return model

    except KeyboardInterrupt:
        # Manual stop: keep whatever has been learned so far.
        if model is not None:
            torch.save(model.state_dict(), "checkpoint.pt")
        return model

In [None]:
# Named modules handed to trainModel through its `criterion` argument.
# NOTE(review): the entry labelled "Precision" holds a CMALLoss instance — confirm this is intended.
metrics = dict(Precision=CMALLoss())

# Train InundationStation on InundationData for 10 epochs with CMALLoss as the objective.
model = trainModel(
    config,
    InundationStation,
    InundationData,
    CMALLoss(),
    epochs=10,
    criterion=metrics,
)