In [2]:
import torch

# Run on the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
import numpy as np
from utils.visualization import plot_prediction_vs_real
from experiments.graphs.graph_experiments import get_pignn_config, get_dataset, create_data_loaders
from architecture.pignn.pignn import FlowPIGNN
from architecture.pignn.deconv import DeConvNet, FCDeConvNet
import os
import json

# Evaluation methods for non-temporal methods
def load_config(config_path):
    """Load an experiment's JSON configuration file.

    Args:
        config_path: Path to the ``config.json`` file.

    Returns:
        The parsed JSON document (a dict for the experiment configs read in this notebook).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON is UTF-8 by specification; pinning the encoding avoids locale-dependent
    # decoding (e.g. cp1252 on Windows) of any non-ASCII text in the config.
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)

# Mean-squared error averaged over all elements; used by both the temporal and
# non-temporal evaluation helpers. ('mean' is PyTorch's default reduction.)
criterion = torch.nn.MSELoss(reduction='mean')

def _plot_test_examples(pred, target, output_size, sample_indices=(0, 16, 32)):
    """Plot prediction-vs-target images for the requested rows that exist in this batch."""
    for plot_idx, sample_idx in enumerate(sample_indices):
        # Small (e.g. final) batches may not contain every requested row.
        if sample_idx >= pred.size(0):
            continue
        plot_prediction_vs_real(
            pred[sample_idx, :].reshape(*output_size).cpu(),
            target[sample_idx, :].reshape(*output_size).cpu(),
            number=plot_idx + 3,
        )


def calculate_test_loss(model, test_loader, plot_examples=False, output_size=(128, 128)):
    """Compute the per-batch MSE of a non-temporal model over a test loader.

    The model is put in eval mode for the pass, so dropout and batch-norm run in
    inference behaviour. Its previous train/eval mode is restored afterwards.

    Args:
        model: Either an ``FCDeConvNet``, which receives one flattened feature vector
            per sample, or a graph model called as
            ``model(batch, node_feats, edge_feats, global_feats)``.
        test_loader: Iterable of batches exposing ``x``, ``pos``, ``edge_attr``,
            ``global_feats`` and ``y``.
        plot_examples: If True, plot up to three prediction/target pairs (rows 0, 16
            and 32, when the batch is large enough) for every batch.
        output_size: ``(height, width)`` each prediction/target row is reshaped to
            for plotting.

    Returns:
        ``(mean, std)`` of the per-batch MSE values.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            test_losses = []
            for batch in test_loader:
                batch = batch.to(device)
                x = batch.x.to(device).float()
                pos = batch.pos.to(device).float()
                ef = batch.edge_attr.to(device).float()
                gf = batch.global_feats.to(device).float()
                batch_size = gf.size(0)

                if isinstance(model, FCDeConvNet):
                    # The fully-connected baseline ignores graph structure: every
                    # per-sample feature is concatenated into one flat vector.
                    x_cat = torch.cat((
                        x.reshape(batch_size, -1),
                        pos.reshape(batch_size, -1),
                        ef.reshape(batch_size, -1),
                        gf.reshape(batch_size, -1),
                    ), dim=-1)
                    pred = model(x_cat).float()
                else:
                    pred = model(batch, torch.cat((x, pos), dim=-1), ef, gf)

                # One flattened target row per sample, aligned with the prediction rows.
                target = batch.y.to(device).reshape(-1, pred.size(1))
                test_loss = criterion(pred, target)

                if plot_examples:
                    _plot_test_examples(pred, target, output_size)
                test_losses.append(test_loss.item())

            return np.mean(test_losses), np.std(test_losses)
    finally:
        model.train(was_training)

def evaluate_model(experiment_dir):
    """Rebuild a non-temporal model from an experiment folder and score it on the test split.

    Expects ``config.json`` and ``pignn_best.pt`` (a model state dict) inside
    ``experiment_dir``. When ``config['use_graph']`` is truthy, a ``FlowPIGNN`` with a
    ``DeConvNet`` head is built; otherwise the fully-connected ``FCDeConvNet`` baseline
    is built.

    Returns:
        ``(mean, std)`` of the per-batch test MSE, as computed by ``calculate_test_loss``.
    """
    model_config = get_pignn_config()
    config = load_config(os.path.join(experiment_dir, 'config.json'))

    if config['use_graph']:
        model = FlowPIGNN(
            **model_config,
            deconv_model=DeConvNet(1, [64, 128, 256, 1], output_size=128),
        )
    else:
        # NOTE(review): these layer sizes must match the ones used at training time;
        # TODO confirm, and consider storing them in the experiment config.
        model = FCDeConvNet(212, 650, 656, 500)
    model = model.to(device)

    # map_location lets GPU-saved checkpoints load on a CPU-only machine.
    # weights_only restricts unpickling to tensors and containers, which is all a
    # state dict needs.
    model_path = os.path.join(experiment_dir, 'pignn_best.pt')
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.eval()

    # NOTE(review): (False, 1) presumably selects the non-temporal dataset with sequence
    # length 1, mirroring the temporal path's (True, seq_length); confirm against get_dataset.
    dataset = get_dataset(config['dataset_dirs'], False, 1)
    _, _, test_loader = create_data_loaders(dataset, config['batch_size'], 1)
    return calculate_test_loss(model, test_loader)

In [None]:
# Evaluate non-temporal methods
base_dir = "results"

# Iterate in sorted order so the report is reproducible across runs and filesystems.
for experiment_name in sorted(os.listdir(base_dir)):
    experiment_dir = os.path.join(base_dir, experiment_name)

    # Only folders holding a config.json are experiments. This skips stray files and
    # container folders such as "results/temporal", which presumably holds only
    # per-experiment subfolders.
    if not os.path.isfile(os.path.join(experiment_dir, 'config.json')):
        continue

    try:
        mse, std = evaluate_model(experiment_dir)
        print(f"Loaded model from {experiment_name} has MSE on test set: {mse} +- {std}")
    except Exception as e:
        # Failures can come from loading, dataset construction or evaluation.
        print(f"Failed to evaluate model from {experiment_name}: {e}")

In [10]:
from architecture.windspeedLSTM.windspeedLSTM import WindSpeedLSTMDeConv, WindspeedLSTM

# Evaluation methods for temporal methods
def evaluate_temporal_model(experiment_dir):
    """Rebuild a graph + temporal model pair from an experiment folder and score it on the test split.

    Reads ``config.json``, ``pignn_best.pt`` (graph-model state dict) and
    ``unet_lstm_best.pt`` (temporal-model state dict) from ``experiment_dir``.

    If ``config['direct_lstm']`` is truthy, the graph model is built without a
    deconvolution head. Its output is viewed as (50, 10) frames and fed to
    ``WindSpeedLSTMDeConv``. Otherwise, the graph model decodes to
    ``image_size x image_size`` frames that are fed to ``WindspeedLSTM``.

    NOTE: reads the notebook globals ``image_size``, ``seq_length`` and ``device`` at call
    time; they must be defined before this function is called.

    Returns:
        ``(mean, std)`` of the per-batch test MSE.
    """
    model_config = get_pignn_config()
    config = load_config(os.path.join(experiment_dir, 'config.json'))
    is_direct_lstm = config['direct_lstm']

    # Graph model: per-timestep spatial encoder.
    deconv_model = None if is_direct_lstm else DeConvNet(1, [64, 128, 256, 1], output_size=image_size)
    graph_model = FlowPIGNN(**model_config, deconv_model=deconv_model).to(device)
    graph_model_path = os.path.join(experiment_dir, 'pignn_best.pt')
    # map_location: allow GPU-saved checkpoints on CPU; weights_only: tensors only.
    graph_model.load_state_dict(torch.load(graph_model_path, map_location=device, weights_only=True))
    graph_model.eval()

    # Temporal model: consumes the stacked per-timestep graph outputs.
    if is_direct_lstm:
        temporal_model = WindSpeedLSTMDeConv(seq_length, [64, 128, 256, 1], image_size)
        # NOTE(review): (50, 10) must match the head-less graph model's output size; TODO confirm.
        embedding_size = (50, 10)
    else:
        temporal_model = WindspeedLSTM(seq_length)
        embedding_size = (image_size, image_size)
    temporal_model = temporal_model.to(device)
    temporal_model_path = os.path.join(experiment_dir, 'unet_lstm_best.pt')
    temporal_model.load_state_dict(torch.load(temporal_model_path, map_location=device, weights_only=True))
    temporal_model.eval()

    dataset = get_dataset(config['dataset_dirs'], True, seq_length)
    _, _, test_loader = create_data_loaders(dataset, config['batch_size'], seq_length)
    return calculate_temporal_test_loss(test_loader, graph_model, temporal_model, embedding_size)

def calculate_temporal_test_loss(test_loader, graph_model, temporal_model, embedding_size, output_size=(128, 128), plot_examples=False):
    """Compute the per-batch MSE of a graph + temporal model pair over a test loader.

    For each batch, every timestep's graph is encoded by ``graph_model`` and reshaped to
    ``embedding_size``. The encodings are stacked along dim 1 and passed through
    ``temporal_model``. The result is compared element-wise against the stacked
    targets, which have shape (batch, timesteps, *output_size).

    Both models are put in eval mode for the pass. Their previous modes are restored
    afterwards.

    Args:
        test_loader: Iterable of batches where ``batch[0]`` holds the input graph for
            each timestep and ``batch[1][i].y`` the matching target.
        graph_model: Per-timestep graph encoder, called as ``graph_model(seq, nf, ef, gf)``.
        temporal_model: Model applied to the stacked per-timestep encodings.
        embedding_size: ``(h, w)`` each graph output is reshaped to.
        output_size: ``(height, width)`` of each target frame.
        plot_examples: If True, plot the last timestep of the first sample of every batch.

    Returns:
        ``(mean, std)`` of the per-batch MSE values.
    """
    was_training = (graph_model.training, temporal_model.training)
    graph_model.eval()
    temporal_model.eval()
    try:
        with torch.no_grad():
            test_losses = []
            for j, batch in enumerate(test_loader):
                generated_img = []
                target_img = []
                for i, seq in enumerate(batch[0]):
                    # Process graphs in parallel at each timestep for the entire batch
                    seq = seq.to(device)
                    nf = torch.cat((seq.x.to(device), seq.pos.to(device)), dim=-1).float()
                    ef = seq.edge_attr.to(device).float()
                    gf = seq.global_feats.to(device).float()
                    graph_output = graph_model(seq, nf, ef, gf).reshape(-1, embedding_size[0], embedding_size[1])
                    generated_img.append(graph_output)
                    target_img.append(batch[1][i].y.to(device).reshape(-1, output_size[0], output_size[1]))

                temporal_img = torch.stack(generated_img, dim=1)
                # (batch, timesteps, H, W)
                target_seq = torch.stack(target_img, dim=1)
                # Reshaping to the target layout keeps the exact element pairing the
                # flattened comparison uses. It also keeps the tensor indexable for plotting.
                output_seq = temporal_model(temporal_img).reshape(target_seq.shape)
                test_loss = criterion(output_seq.flatten(), target_seq.flatten())

                if plot_examples:
                    # Last timestep of the first sample in the batch.
                    plot_prediction_vs_real(output_seq[0, -1].cpu(), target_seq[0, -1].cpu(), number=j + 6)
                test_losses.append(test_loss.item())
        return np.mean(test_losses), np.std(test_losses)
    finally:
        graph_model.train(was_training[0])
        temporal_model.train(was_training[1])

In [11]:
# Evaluation for temporal methods
# NOTE: `image_size` and `seq_length` are read as globals by `evaluate_temporal_model`,
# so they must be defined before it is called.
image_size = 128
seq_length = 50
base_dir = "results/temporal"

# Iterate through each experiment folder in sorted order (for a reproducible report)
# and evaluate its saved graph + temporal models.
for experiment_name in sorted(os.listdir(base_dir)):
    experiment_dir = os.path.join(base_dir, experiment_name)

    # Only folders containing a config.json are treated as experiments.
    if not os.path.isfile(os.path.join(experiment_dir, 'config.json')):
        continue

    try:
        mse, std = evaluate_temporal_model(experiment_dir)
        print(f"Loaded model from {experiment_name} has MSE on test set: {mse} +- {std}")
    except Exception as e:
        # Failures can come from loading, dataset construction or evaluation.
        print(f"Failed to evaluate model from {experiment_name}: {e}")

  graph_model.load_state_dict(torch.load(graph_model_path))
  temporal_model.load_state_dict(torch.load(temporal_model_path))


Loaded datasets, 2300 samples
Loaded model from 20241026180328_Case01_False_pignn_unet_lstm_30_50_case_01_sliding has MSE on test set: 0.25863130403601603 +- 0.021885890681178796
Loaded datasets, 2300 samples
Loaded model from 20241026232226_Case01_False_pignn_lstm_deconv_30_50_case_01_sliding has MSE on test set: 0.26188334498716437 +- 0.023512369145551625
