In [None]:
import torch

from model import *
from data import *
from vis import *
import gc

import itertools

In [None]:
# Build the run configuration from config.json. `Config` is not defined in the
# visible cells; it presumably arrives through one of the star imports above
# (TODO confirm which module). The resulting `config` is passed to trainModel
# and evalModel in later cells.
config = Config().load("config.json")

In [None]:
def splitDataset(dataset, trainSplit=0.8, shuffle=True):
    """Split ``dataset`` into two contiguous index ranges and wrap each in a DataLoader.

    The first ``int(len(dataset) * trainSplit)`` samples become the train
    subset; the remaining samples become the test subset. Samples are not
    shuffled before the cut, only (optionally) inside each loader.

    Args:
        dataset: Sized, indexable dataset.
        trainSplit: Fraction of samples assigned to the train subset.
        shuffle: Forwarded to both loaders.

    Returns:
        ``(trainLoader, testLoader)``.
    """
    total = len(dataset)
    boundary = int(total * trainSplit)

    # TODO: More stratified subsets using dataset.lengths and geographic information
    def buildLoader(indices):
        # Each loader gets its own generator built from the notebook-level `device`.
        subset = torch.utils.data.Subset(dataset, indices)
        return DataLoader(subset, generator=torch.Generator(device), shuffle=shuffle)

    trainLoader = buildLoader(range(boundary))
    testLoader = buildLoader(range(boundary, total))
    return trainLoader, testLoader

In [None]:
def itertoolsBetter(dataIter):
    """Yield items from ``dataIter`` forever, restarting it after every full pass.

    Unlike ``itertools.cycle``, nothing is cached: ``dataIter`` itself is
    iterated again on each pass, so an iterable that produces a new order on
    every ``iter()`` call yields a fresh order per pass.

    Args:
        dataIter: A re-iterable object (for example a list or a data loader).

    Yields:
        The items of ``dataIter``, pass after pass.

    Raises:
        ValueError: If a complete pass over ``dataIter`` yields nothing. This
            happens for an empty iterable, and for a one-shot iterator once it
            is exhausted. Without this check the loop would spin forever
            without yielding and the consumer's ``next()`` would hang.
    """
    while True:
        yieldedAny = False
        for batch in dataIter:
            yieldedAny = True
            yield batch
        if not yieldedAny:
            raise ValueError("itertoolsBetter: dataIter produced no items on a full pass")


def _evaluateCriteria(criterion, prefix, forecast, future, thresholds, deviations):
    """Run every metric in ``criterion`` and return the results keyed as ``prefix + name``."""
    return {
        prefix + name: metric(forecast, future, thresholds=thresholds, deviations=deviations)
        for name, metric in criterion.items()
    }


def trainModel(config, modelClass, dataClass, objective, epochs, criterion: dict[str, nn.Module]):
    """Train ``modelClass(config)`` on ``dataClass(config)`` and stream losses/metrics to a client.

    After every optimisation step one batch from the test loader is evaluated,
    so train and test curves are reported at the same cadence. Batches whose
    loss is NaN or infinite are reported and skipped without an optimiser step.

    Args:
        config: Run configuration; passed to ``dataClass`` and ``modelClass``,
            and ``config.dataSplit`` is forwarded to ``splitDataset``.
        modelClass: Callable building the model from ``config``. The model is
            called as ``model(inputs)`` and must return ``(hindcast, forecast)``.
        dataClass: Callable building the dataset from ``config``.
        objective: Loss called as ``objective(forecast, future)``.
        epochs: Number of passes over the train loader.
        criterion: Mapping of display name to metric; each metric is called as
            ``metric(forecast, future, thresholds=..., deviations=...)``.

    Returns:
        ``(model, dataset)``. On ``KeyboardInterrupt`` the model weights are
        written to ``checkpoint.pt`` (if a model exists) and whatever was
        built so far is returned, so either element may be ``None``.
    """
    model = None
    dataset = None
    try:
        dataset = dataClass(config)

        dataset.info(dataset[0])

        train, test = splitDataset(dataset, config.dataSplit)

        dataset.info(next(iter(train)))

        model = modelClass(config)
        print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")
        print(f"Dataset has {len(dataset)} samples")
        print(next(model.parameters()).is_cuda)

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

        # Endless stream of test batches so one can be drawn after every train step.
        testIter = itertoolsBetter(test)

        # NOTE(review): Client is presumably a live-plotting/metrics sink defined in
        # one of the star-imported modules -- confirm, and close it if it needs closing.
        client = Client("127.0.0.1", 12945)

        for epoch in range(epochs):
            # `progress` counts every batch drawn, including skipped NaN batches,
            # so the progress display can reach len(train).
            for progress, (inputs, targets) in enumerate(train, start=1):
                # eval() is switched on for the test step below, so re-enable train mode.
                model.train()
                optimizer.zero_grad()

                future = targets.dischargeFuture
                thresholds, deviations = targets.thresholds, targets.deviation.unsqueeze(-1)
                _, forecast = model(inputs)
                loss = objective(forecast, future)

                # TODO: Determine source of nans
                if torch.isnan(loss) or torch.isinf(loss):
                    print("\nNAN FOUND IN LOSS")
                    dataset.info((inputs, targets))
                    print("\n")
                    continue

                # Metrics are for reporting only: compute them without autograd so the
                # stored values do not keep references to the training graph.
                with torch.no_grad():
                    metrics = _evaluateCriteria(criterion, "Train ", forecast, future, thresholds, deviations)

                trainLoss = loss.detach().item()

                loss.backward()

                # Clip by global norm first, then clamp individual gradient values.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
                torch.nn.utils.clip_grad_value_(model.parameters(), 10)

                optimizer.step()

                torch.cuda.empty_cache()

                with torch.no_grad():
                    model.eval()
                    inputs1, targets1 = next(testIter)

                    future1 = targets1.dischargeFuture
                    thresholds1, deviations1 = targets1.thresholds, targets1.deviation.unsqueeze(-1)
                    _, forecast1 = model(inputs1)
                    loss1 = objective(forecast1, future1)
                    testLoss = loss1.detach().item()

                    metrics.update(
                        _evaluateCriteria(criterion, "Test ", forecast1, future1, thresholds1, deviations1)
                    )

                # Periodic explicit collection, every tenth batch.
                if progress % 10 == 0:
                    gc.collect()

                client.send("Train Loss", trainLoss)
                client.send("Test Loss", testLoss)

                for metric, value in metrics.items():
                    client.send(metric, value)

                print(f"\r{progress}/{len(train)} | {(progress / len(train)) * 100:.3f}% |  Train Loss: {trainLoss} | Test Loss: {testLoss} | Memory: {torch.cuda.memory_allocated()}", end="")
            print()

        return model, dataset

    except KeyboardInterrupt:
        # Manual stop: keep the weights trained so far instead of losing the run.
        if model is not None:
            torch.save(model.state_dict(), "checkpoint.pt")
        return model, dataset

In [None]:
# Reporting-only metrics, keyed by the display name sent to the metrics client.
metrics = dict([
    ("NMAE", CMALNormalizedMeanAbsolute()),
    ("1 Year Flood Precision", CMALPrecision()),
    ("1 Year Flood Recall", CMALRecall()),
])

# Train for ten epochs; `model` and `dataset` are reused by the evaluation cell.
model, dataset = trainModel(
    config,
    FloodHub,
    FloodHubData,
    CMALLoss(),
    epochs=10,
    criterion=metrics,
)

In [None]:
def evalModel(config, model, dataset):
    """Score ``model`` on the test split and plot F1 per basin and per forecast step.

    The test loader comes from ``splitDataset(dataset, config.dataSplit,
    shuffle=False)``. For each batch a point forecast is compared with the
    observed future discharge against every threshold; true positives, false
    positives and false negatives are accumulated per gauge id
    (``past.grdcID``) and per forecast step. Two rows of plots are then drawn:
    mean F1 against upstream node count (scatter) and the distribution of F1
    across basins at each forecast step (box plot).

    Args:
        config: Run configuration; ``dataSplit`` and ``future`` are read here.
        model: Trained network, called as ``model(inputs)`` and expected to
            return ``(hindcast, forecast)``.
        dataset: Dataset handed to ``splitDataset``.

    Returns:
        None. Figures are displayed with ``plt.show()``.
    """
    labels = ["1 Year Return Period", "2 Year Return Period", "5 Year Return Period", "10 Year Return Period"]
    colors = ["blue", "green", "yellow", "orange"]
    # One plot column per threshold; assumes targets.thresholds has this many
    # entries on its last axis -- TODO confirm.
    numThresholds = len(labels)

    model.eval()
    _, test = splitDataset(dataset, config.dataSplit, shuffle=False)

    metrics = {}

    with torch.no_grad():
        for progress, (inputs, targets) in enumerate(test, start=1):
            future = targets.dischargeFuture
            thresholds, deviations = targets.thresholds, targets.deviation.unsqueeze(-1)
            _, forecast = model(inputs)

            # Consider sampling?
            # Point forecast: product of two forecast components summed over the last
            # axis (presumably mixture locations and weights -- TODO confirm ordering).
            medianPrediction = torch.sum(forecast[0] * forecast[3], dim=-1)
            mae = torch.abs(medianPrediction - future) / deviations

            # Add a trailing threshold axis so predictions and observations
            # broadcast against every threshold at every forecast step.
            medianPrediction = medianPrediction.unsqueeze(-1)
            future = future.unsqueeze(-1)
            thresholds = thresholds.unsqueeze(1).expand(-1, config.future, -1)
            tp = (medianPrediction > thresholds).float() * (future > thresholds).float()
            fp = (medianPrediction > thresholds).float() * (future < thresholds).float()
            fn = (medianPrediction < thresholds).float() * (future > thresholds).float()

            past, _ = inputs
            # TODO: Update for FloodHubData
            for n, name in enumerate(past.grdcID):
                if name not in metrics:
                    metrics[name] = {
                        "iter": 0,
                        "mae": torch.zeros([config.future]),
                        "tp": torch.zeros([config.future, thresholds.shape[-1]]),
                        "fp": torch.zeros([config.future, thresholds.shape[-1]]),
                        "fn": torch.zeros([config.future, thresholds.shape[-1]]),
                        # TODO: Calculate size of upstream basin GRDC -> PFAF -> Upstream -> Sum
                        "size": None,
                        # TODO: Populate with the upstream node count; the scatter plot
                        # reads this key and omits basins where it is still None.
                        "nodes": None,
                    }

                entry = metrics[name]
                # Running mean of the normalised absolute error for this basin.
                entry["mae"] = (entry["mae"] * entry["iter"] + mae[n]) / (entry["iter"] + 1)
                entry["tp"] += tp[n]
                entry["fp"] += fp[n]
                entry["fn"] += fn[n]
                entry["iter"] += 1

            print(f"\r{progress}/{len(test)} | {(progress / len(test)) * 100:.2f}% Complete", end="")
    print()

    if not metrics:
        print("No test batches were evaluated; nothing to plot.")
        return

    numBasins = len(metrics)
    # NaN marks "node count unknown"; matplotlib leaves such points out of the scatter.
    nodeX = np.full([numBasins, numThresholds], np.nan)
    nodeY = np.zeros([numBasins, numThresholds])
    # precisionBox / recallBox are collected per basin but not plotted yet.
    precisionBox = np.zeros([numBasins, config.future, numThresholds])
    recallBox = np.zeros([numBasins, config.future, numThresholds])
    f1Box = np.zeros([numBasins, config.future, numThresholds])

    missingNodes = 0
    for i, name in enumerate(metrics):
        nodes = metrics[name]["nodes"]
        if nodes is None:
            missingNodes += 1
        else:
            nodeX[i, :] = nodes

        tp = metrics[name]["tp"]
        fp = metrics[name]["fp"]
        fn = metrics[name]["fn"]

        # 0/0 -> NaN where a basin had no predicted (precision) or observed (recall) events.
        recall = tp / (tp + fn)
        precision = tp / (tp + fp)

        # Equivalent to 2PR / (P + R) but stays defined (0) when tp == 0 and
        # fp + fn > 0; NaN only when no event was predicted or observed at all.
        f1 = 2 * tp / (2 * tp + fp + fn)

        nodeY[i, :] = torch.mean(f1, dim=0).cpu().numpy()

        precisionBox[i] = precision.cpu().numpy()
        recallBox[i] = recall.cpu().numpy()
        f1Box[i] = f1.cpu().numpy()

    if missingNodes:
        print(f"Upstream node counts unavailable for {missingNodes}/{numBasins} basins; the scatter plots omit them.")

    for i in range(numThresholds):
        plt.subplot(2, numThresholds, i + 1)
        plt.title(labels[i])
        plt.scatter(nodeX[:, i], nodeY[:, i], alpha=0.3, c=colors[i])

        plt.grid()
        plt.xlabel("Total Upstream Basin Nodes")
        plt.ylabel("F1 Score")

    for i in range(numThresholds):
        plt.subplot(2, numThresholds, i + 1 + numThresholds)
        plt.title(labels[i])
        # One box per forecast step: each row of the transposed slice holds that
        # step's F1 across basins. Undefined (NaN) scores are dropped so they do
        # not blank out the whole box.
        perHorizon = [scores[~np.isnan(scores)] for scores in f1Box[:, :, i].T]
        plt.boxplot(perHorizon)

        plt.grid()
        plt.xlabel("Forecast Horizon")
        plt.ylabel("F1 Score")

    plt.show()

    # TODO: Evaluate metrics against basin size, nodes in graph, find other correlations in basin characteristics
    # TODO: Evaluate on more stratified data (other continents?)
    # TODO: Maybe evaluate at New Madrid for fun

In [None]:
# Evaluate and plot using the `model` and `dataset` returned by the training cell.
# NOTE(review): `model` can be None if training was interrupted before the model
# was built; evalModel calls model.eval() first, so it would fail in that case.
evalModel(config, model, dataset)