In [1]:
import numpy as np
import networkx as nx
import data, network, torch, pickle
import matplotlib.pyplot as plt
import torch_geometric.utils as utils
import torch.nn.functional as F
from torch_geometric.utils import to_networkx
from torch_geometric.loader import DataLoader as batchLoader

device = "cpu"
for modelIdx in range(33, 49):

    path = "./trainedModel/"
    metaPath = f"{modelIdx}_metadata.pkl"
    modelPath = f"{modelIdx}.pt"
    paraPath = f"{modelIdx}_para.pkl"

    with open(path + metaPath, "rb") as f:
        metadata = pickle.load(f)
    with open(path + paraPath, "rb") as f:
        modelPara, paraDict = pickle.load(f)
    paraDict["device"] = torch.device(device)
    checkPoint = torch.load(path + modelPath, map_location=torch.device(device))
    model = network.modelSelect(modelPara).to(device)
    model.load_state_dict(checkPoint["model_state_dict"])
    model.eval()
    print("Model loaded.")

    FTG = data.FTGenerator(paraDict, metadata)
    Dataset, metadata = FTG.createDataset()
    _, location = FTG._constructRawData()
    testSet = Dataset[-730:]
    # testSet = [Dataset[-1]]
    nClass = metadata["nClass"]
    nNodes = testSet[0].x.shape[0]
    print(f"Model {modelIdx}. Test dataset loaded.")

    testLoader = batchLoader(testSet, batch_size=len(testSet), shuffle=False)

    for batch in testLoader:
        _yHat, _weights = model(batch.x, batch.edge_index, batch.edge_attr)
        # Pick the last three days only
        _yHat = _yHat[:, -6:]
        _truth = batch.y[:, -3:].reshape(-1)
        _yHat = _yHat.reshape(-1, nClass)
        _pred = torch.argmax(_yHat, axis=-1)
        _correct = _pred == _truth

        _accu = int(_correct.sum()) / len(_truth)

        _tP = (_truth == 1).reshape(-1, nNodes)
        _tN = (_truth == 0).reshape(-1, nNodes)
        _pP = (_pred == 1).reshape(-1, nNodes)
        _pN = (_pred == 0).reshape(-1, nNodes)

        _TP = torch.sum(_tP & _pP, axis=0)
        _FN = torch.sum(_tP & _pN, axis=0)
        _FP = torch.sum(_tN & _pP, axis=0)
        _TN = torch.sum(_tN & _pN, axis=0)

        _recall = _TP / (_TP + _FN)
        _prec = _TP / (_TP + _FP)
        _f1 = 2 * _TP / (2 * _TP + _FP + _FN)
        print(_accu, _recall.mean().item(), _prec.mean().item(), _f1.mean().item())
        with open("record2", "a") as f:
            print(
                _accu,
                _recall.mean().item(),
                _prec.mean().item(),
                _f1.mean().item(),
                file=f,
            )

Model loaded.
Model 33. Test dataset loaded.
0.9370861872146119 0.5921064019203186 0.6516919732093811 0.6179358959197998
Model loaded.
Model 34. Test dataset loaded.
0.9082429604261796 0.48991861939430237 0.47854506969451904 0.483335942029953
Model loaded.
Model 35. Test dataset loaded.
0.8934693683409437 0.3521733283996582 0.3831295669078827 0.36521780490875244
Model loaded.
Model 36. Test dataset loaded.
0.8844796423135465 0.24855542182922363 0.3049909174442291 0.2716773450374603
Model loaded.
Model 37. Test dataset loaded.
0.8754899162861491 0.2441893219947815 0.2721286714076996 0.25608769059181213
Model loaded.
Model 38. Test dataset loaded.
0.8535625951293759 0.3007812201976776 0.23998399078845978 0.2666744291782379
Model loaded.
Model 39. Test dataset loaded.
0.8530679223744292 0.290447860956192 0.23485024273395538 0.2593574821949005
Model loaded.
Model 40. Test dataset loaded.
0.8488489345509893 0.3183199167251587 0.23895250260829926 0.27240967750549316
Model loaded.
Model 41. T