# Read the results of the retrained models

In [1]:
import seml
import pandas as pd
import matplotlib as mpl
from matplotlib import pyplot as plt
import torch
import seaborn as sns
import json
import os
import sys
import time
from torch_geometric.loader import DataLoader
from statistics import mean, median
from tqdm import tqdm
import dgl
import networkx as nx
import torch_geometric
sys.path.append(os.path.dirname(os.getcwd()))

from src.baseline.model_gcn import GIN
from src.ppgn.ppgn import PPGN
from src.I2GNN.I2GNN import I2GNN
from src.baseline.dataset_gcn import GraphDataset
from src.I2GNN.I2GNN_dataset import I2GNNDataset, I2GNNDataLoader, I2GNNDatasetRobustness
from src.metrics.L1_based import L1LossCount, L1LossStd
from src.I2GNN.utils import create_subgraphs2

In [2]:
# Experiment selection and evaluation settings shared by the cells below.
# `arch`, `old_dataset` and `new_dataset` only serve to build the name of the
# seml collection the retraining results are read from.
old_dataset = 'er_10'
new_dataset = 'er_2'
arch = "PPGN"
# Use the GPU when one is visible, otherwise fall back to the CPU so the
# notebook can still be re-run on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 32
retrain_experiment = f'retrain_{arch}_{old_dataset}_{new_dataset}'
# Index the results by experiment id while keeping `_id` available as a column.
results: pd.DataFrame = seml.get_results(retrain_experiment, to_data_frame=True).set_index("_id", drop=False)

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  parsed = pd.io.json.json_normalize(parsed, sep='.')


In [3]:
def evaluate_epoch(dataloader, gnn, loss_fn, device)->torch.Tensor:
    """Evaluate ``gnn`` on every batch of ``dataloader`` and average each loss over the batches.

    Args:
        dataloader: sized iterable of batches; every batch must offer
            ``.to(device)`` and expose its targets as ``.y``.
        gnn: model invoked as ``gnn(batch)``; it is switched to eval mode.
        loss_fn: list or tuple of callables ``f(pred, y)``. A single callable
            is accepted as well and treated as a one-element list.
        device: device the batches and the accumulator are placed on.

    Returns:
        1-D tensor on ``device`` with one entry per loss function (same order
        as ``loss_fn``): the per-batch loss values summed and divided by the
        number of batches.
    """
    # Use the functions that were actually passed in rather than whatever a
    # notebook-level variable happens to hold at call time.
    loss_fns = list(loss_fn) if isinstance(loss_fn, (list, tuple)) else [loss_fn]

    gnn.eval()
    with torch.no_grad():
        num_batches = len(dataloader)
        loss = torch.zeros(len(loss_fns), device=device)
        for data in dataloader:
            data = data.to(device)
            y = data.y
            pred = gnn(data)
            for i, fn in enumerate(loss_fns):
                loss[i] += fn(pred, y)

        # Every batch contributes with the same weight, whatever its size.
        loss = loss / num_batches
    return loss

# Per-pattern value passed as `hops` to the I2GNN dataset
# (presumably the radius used when extracting subgraphs — TODO confirm).
hops = dict((
    ("Triangle", 1),
    ("2-Path", 2),
    ("4-Clique", 1),
    ("Chordal cycle", 2),
    ("Tailed triangle", 2),
    ("3-Star", 2),
    ("4-Cycle", 2),
    ("3-Path", 3),
    ("3-Star not ind.", 2),
))

## Loss of the original model on the original dataset

In [8]:
# Evaluate the originally trained models (one per seed and subgraph pattern)
# on the test set named by `config.test_original_dataset`.
# NOTE(review): `[1]` is a lookup on the frame's index, which was set to `_id`
# when the results were loaded — presumably integer ids, so this reads the run
# with id 1; confirm that run always exists.
model_folder = results["config.model_folder"][1]
model = results["config.model"][1]
dataset = results["config.test_original_dataset"][1]
n_seeds = results["config.n_seeds"][1]
subgraphs = results['config.subgraph'].unique()

# Fail once, before any file is opened, if the stored architecture is unknown.
model_classes = {'GIN': GIN, 'PPGN': PPGN, 'I2GNN': I2GNN}
if model not in model_classes:
    raise ValueError("The architecture is not supported!")


def pre_transform(g, hops):
    """Forward a graph and its hop value to `create_subgraphs2` (used by the I2GNN dataset)."""
    return create_subgraphs2(g, hops)


metric_columns = ["L1 avg", "L1 std avg", "L1 count avg", "L1", "L1 std", "L1 count"]
rows = []
for subgraph in subgraphs:
    start = time.time()

    # load original models: one checkpoint (.pth) and one hyper-parameter file (.json) per seed
    gnns = []
    for i in range(n_seeds):
        model_dict = f"{model_folder}/{model}_{subgraph}_{i}.pth"
        model_params = f"{model_folder}/{model}_{subgraph}_{i}.json"
        with open(model_params, 'r') as fp:
            h_params = json.load(fp)
        gnn = model_classes[model](**h_params).to(device)
        gnn.load_state_dict(torch.load(model_dict, map_location=torch.device(device)))
        gnns.append(gnn)

    # load original dataset
    if model == 'I2GNN':
        test = I2GNNDataset(root=os.path.dirname(dataset), dataset=os.path.basename(dataset), subgraph_type=subgraph, pre_transform=pre_transform, hops=hops[subgraph])
        dataloader = I2GNNDataLoader(dataset=test, batch_size=batch_size, shuffle=False)
        std = torch.std(test.data.y)
    else:  # GIN and PPGN read the same dataset class
        test = GraphDataset(dataset, subgraph, in_channels=1)
        dataloader = DataLoader(dataset=test, batch_size=batch_size, shuffle=False)
        std = torch.std(test.labels)

    l1 = torch.nn.L1Loss()
    l1_std = L1LossStd(std)
    l1_count = L1LossCount()

    # evaluate; the order of `l` must match the unpacking of `err` below
    l = [l1, l1_std, l1_count]
    l1_err = []
    l1_std_err = []
    l1_count_err = []
    for gnn in gnns:
        err = evaluate_epoch(dataloader, gnn, l, device)
        l1_err.append(err[0].item())
        l1_std_err.append(err[1].item())
        l1_count_err.append(err[2].item())

    # Collect one record per subgraph and build the frame once at the end:
    # assigning a row that mixes scalars and lists through `.loc` makes pandas
    # convert a ragged sequence with numpy, which emits a deprecation warning.
    rows.append({
        "L1 avg": mean(l1_err),
        "L1 std avg": mean(l1_std_err),
        "L1 count avg": mean(l1_count_err),
        "L1": l1_err,
        "L1 std": l1_std_err,
        "L1 count": l1_count_err,
    })
    print(f'{subgraph}: {mean(l1_err)}, time: {time.time() - start}')

result_dataset = pd.DataFrame(rows, index=pd.Index(subgraphs), columns=metric_columns)
display(result_dataset)

Triangle: 0.005825826153159142, time: 3.3609139919281006
2-Path: 0.01601270940154791, time: 3.570786952972412
4-Clique: 0.009579430520534515, time: 3.4671690464019775
Chordal cycle: 0.10421011745929717, time: 3.4854495525360107
Tailed triangle: 0.27803733944892883, time: 3.461744785308838
3-Star: 0.11732048988342285, time: 3.345783233642578
4-Cycle: 0.05844668671488762, time: 3.460584878921509
3-Path: 0.28193793892860414, time: 3.4192745685577393
3-Star not ind.: 0.03771090656518936, time: 3.5073368549346924


Unnamed: 0,L1 avg,L1 std avg,L1 count avg,L1,L1 std,L1 count
Triangle,0.005826,0.001218,0.000783,"[0.0064351060427725315, 0.005726082716137171, ...","[0.0013453572755679488, 0.001197125413455069, ...","[0.0007883830112405121, 0.0007761426386423409,..."
2-Path,0.016013,0.00173,0.000472,"[0.01598486490547657, 0.014843310229480267, 0....","[0.0017273830017074943, 0.001604021992534399, ...","[0.00047385680954903364, 0.000443336961325258,..."
4-Clique,0.009579,0.005716,0.002665,"[0.008784443140029907, 0.009193134494125843, 0...","[0.005241316743195057, 0.005485165398567915, 0...","[0.0024962821044027805, 0.002486540237441659, ..."
Chordal cycle,0.10421,0.014371,0.011043,"[0.09791170805692673, 0.10903941839933395, 0.1...","[0.013502541929483414, 0.0150371128693223, 0.0...","[0.010211536660790443, 0.011264696717262268, 0..."
Tailed triangle,0.278037,0.022938,0.010512,"[0.25291532278060913, 0.2878391146659851, 0.29...","[0.020865697413682938, 0.02374693751335144, 0....","[0.009684931486845016, 0.010799292474985123, 0..."
3-Star,0.11732,0.020817,0.011555,"[0.10288278758525848, 0.12160579115152359, 0.1...","[0.01825486496090889, 0.021576952189207077, 0....","[0.010355697944760323, 0.011695235967636108, 0..."
4-Cycle,0.058447,0.013447,0.010366,"[0.06346682459115982, 0.05979559198021889, 0.0...","[0.014601532369852066, 0.013756902888417244, 0...","[0.012175307609140873, 0.010429250076413155, 0..."
3-Path,0.281938,0.026432,0.009002,"[0.2753278315067291, 0.280529260635376, 0.2613...","[0.025811880826950073, 0.02629951946437359, 0....","[0.008840369060635567, 0.008810547180473804, 0..."
3-Star not ind.,0.037711,0.001171,0.00101,"[0.03888349235057831, 0.03946198895573616, 0.0...","[0.001206965302117169, 0.0012249229475855827, ...","[0.0008388894493691623, 0.0010124592809006572,..."


## Loss of the original model on the new dataset

In [10]:
# Evaluate the originally trained models (one per seed and subgraph pattern)
# on the test set named by `config.test_dataset`, i.e. the new dataset.
# NOTE(review): `[1]` is a lookup on the frame's index, which was set to `_id`
# when the results were loaded — presumably integer ids, so this reads the run
# with id 1; confirm that run always exists.
model_folder = results["config.model_folder"][1]
model = results["config.model"][1]
dataset = results["config.test_dataset"][1]
n_seeds = results["config.n_seeds"][1]
subgraphs = results['config.subgraph'].unique()

# Fail once, before any file is opened, if the stored architecture is unknown.
model_classes = {'GIN': GIN, 'PPGN': PPGN, 'I2GNN': I2GNN}
if model not in model_classes:
    raise ValueError("The architecture is not supported!")


def pre_transform(g, hops):
    """Forward a graph and its hop value to `create_subgraphs2` (used by the I2GNN dataset)."""
    return create_subgraphs2(g, hops)


metric_columns = ["L1 avg", "L1 std avg", "L1 count avg", "L1", "L1 std", "L1 count"]
rows = []
for subgraph in subgraphs:
    start = time.time()

    # load original models: one checkpoint (.pth) and one hyper-parameter file (.json) per seed
    gnns = []
    for i in range(n_seeds):
        model_dict = f"{model_folder}/{model}_{subgraph}_{i}.pth"
        model_params = f"{model_folder}/{model}_{subgraph}_{i}.json"
        with open(model_params, 'r') as fp:
            h_params = json.load(fp)
        gnn = model_classes[model](**h_params).to(device)
        gnn.load_state_dict(torch.load(model_dict, map_location=torch.device(device)))
        gnns.append(gnn)

    # load the new dataset
    if model == 'I2GNN':
        test = I2GNNDataset(root=os.path.dirname(dataset), dataset=os.path.basename(dataset), subgraph_type=subgraph, pre_transform=pre_transform, hops=hops[subgraph])
        dataloader = I2GNNDataLoader(dataset=test, batch_size=batch_size, shuffle=False)
        std = torch.std(test.data.y)
    else:  # GIN and PPGN read the same dataset class
        test = GraphDataset(dataset, subgraph, in_channels=1)
        dataloader = DataLoader(dataset=test, batch_size=batch_size, shuffle=False)
        std = torch.std(test.labels)

    l1 = torch.nn.L1Loss()
    l1_std = L1LossStd(std)
    l1_count = L1LossCount()

    # evaluate; the order of `l` must match the unpacking of `err` below
    l = [l1, l1_std, l1_count]
    l1_err = []
    l1_std_err = []
    l1_count_err = []
    for gnn in gnns:
        err = evaluate_epoch(dataloader, gnn, l, device)
        l1_err.append(err[0].item())
        l1_std_err.append(err[1].item())
        l1_count_err.append(err[2].item())

    # Collect one record per subgraph and build the frame once at the end:
    # assigning a row that mixes scalars and lists through `.loc` makes pandas
    # convert a ragged sequence with numpy, which emits a deprecation warning.
    rows.append({
        "L1 avg": mean(l1_err),
        "L1 std avg": mean(l1_std_err),
        "L1 count avg": mean(l1_count_err),
        "L1": l1_err,
        "L1 std": l1_std_err,
        "L1 count": l1_count_err,
    })
    print(f'{subgraph}: {mean(l1_err)}, time: {time.time() - start}')

results_dataset = pd.DataFrame(rows, index=pd.Index(subgraphs), columns=metric_columns)
display(results_dataset)



  arr_value = np.asarray(value)


Triangle: 2.9661218166351317, time: 3.3656005859375
2-Path: 5.370835685729981, time: 3.4094550609588623
4-Clique: 7.8150018692016605, time: 3.3911688327789307
Chordal cycle: 19.98771381378174, time: 3.4019222259521484
Tailed triangle: 28.441026878356933, time: 3.3941831588745117
3-Star: 5.4437649011611935, time: 3.317873954772949
4-Cycle: 5.4002622127532955, time: 3.541440725326538
3-Path: 13.611257457733155, time: 3.396723747253418
3-Star not ind.: 11.468862771987915, time: 3.4053115844726562


Unnamed: 0,L1 avg,L1 std avg,L1 count avg,L1,L1 std,L1 count
Triangle,2.966122,0.207382,0.040729,"[3.4586639404296875, 3.656454086303711, 2.1099...","[0.24181875586509705, 0.2556476593017578, 0.14...","[0.04715724289417267, 0.04874846339225769, 0.0..."
2-Path,5.370836,0.622969,0.15233,"[1.3477202653884888, 1.8963526487350464, 5.546...","[0.15632350742816925, 0.2199600338935852, 0.64...","[0.04128839075565338, 0.0609329454600811, 0.15..."
4-Clique,7.815002,0.286641,0.116855,"[9.962668418884277, 11.018967628479004, 4.4120...","[0.3654138147830963, 0.4041573405265808, 0.161...","[0.14211338758468628, 0.16380423307418823, 0.0..."
Chordal cycle,19.987714,1.512661,0.244781,"[37.8895149230957, 20.984554290771484, 12.6142...","[2.8674604892730713, 1.5881016254425049, 0.954...","[0.448074072599411, 0.2517739236354828, 0.1630..."
Tailed triangle,28.441027,1.966026,2.052349,"[11.80980396270752, 30.221813201904297, 44.406...","[0.8163692951202393, 2.0891246795654297, 3.069...","[0.7736030220985413, 1.745348572731018, 3.5882..."
3-Star,5.443765,1.290124,4.121051,"[3.2707149982452393, 13.54274845123291, 4.8272...","[0.7751302123069763, 3.209510326385498, 1.1440...","[2.1512773036956787, 10.553454399108887, 3.874..."
4-Cycle,5.400262,0.91326,1.173153,"[8.442290306091309, 3.373133897781372, 3.61650...","[1.4277104139328003, 0.5704444646835327, 0.611...","[1.8547978401184082, 0.5130177140235901, 0.782..."
3-Path,13.611257,1.735703,5.198219,"[6.362536907196045, 16.700332641601562, 21.288...","[0.8113481998443604, 2.129619598388672, 2.7146...","[1.2426958084106445, 8.534914016723633, 7.2176..."
3-Star not ind.,11.468863,0.114679,0.023371,"[11.919388771057129, 3.3888866901397705, 5.699...","[0.11918392777442932, 0.033886030316352844, 0....","[0.024615278467535973, 0.0068615577183663845, ..."


## Loss of the retrained model on the new dataset

In [5]:
# Evaluate the retrained models on the test set named by `config.test_dataset`.
# Each row of `results` is one retraining run for one subgraph pattern and
# lists the checkpoints and hyper-parameter files it produced.
# NOTE(review): `[1]` is a lookup on the frame's index, which was set to `_id`
# when the results were loaded — presumably integer ids, so this reads the run
# with id 1; confirm that run always exists.
model = results["config.model"][1]
dataset = results["config.test_dataset"][1]
n_seeds = results["config.n_seeds"][1]
subgraphs = results['config.subgraph'].unique()

# The architecture stored with the experiment drives both model construction
# and dataset selection, so the two can never disagree. An unknown value is
# rejected up front instead of silently reusing a model from an earlier
# iteration.
model_classes = {'GIN': GIN, 'PPGN': PPGN, 'I2GNN': I2GNN}
if model not in model_classes:
    raise ValueError("The architecture is not supported!")


def pre_transform(g, hops):
    """Forward a graph and its hop value to `create_subgraphs2` (used by the I2GNN dataset)."""
    return create_subgraphs2(g, hops)


metric_columns = ["L1 avg", "L1 std avg", "L1 count avg", "L1", "L1 std", "L1 count"]
rows = {}
for _, line in results.iterrows():
    dict_models = line['result.model_paths']
    params_models = line['result.h_param_paths']
    subgraph = line['config.subgraph']
    start = time.time()

    # load retrained models
    gnns = []
    for dict_model, params_model in zip(dict_models, params_models):
        with open(params_model, 'r') as f:
            h_params = json.load(f)
        gnn = model_classes[model](**h_params).to(device)
        gnn.load_state_dict(torch.load(dict_model, map_location=torch.device(device)))
        gnns.append(gnn)

    # load dataset
    if model == 'I2GNN':
        test = I2GNNDataset(root=os.path.dirname(dataset), dataset=os.path.basename(dataset), subgraph_type=subgraph, pre_transform=pre_transform, hops=hops[subgraph])
        dataloader = I2GNNDataLoader(dataset=test, batch_size=batch_size, shuffle=False)
        std = torch.std(test.data.y)
    else:  # GIN and PPGN read the same dataset class
        test = GraphDataset(dataset, subgraph, in_channels=1)
        dataloader = DataLoader(dataset=test, batch_size=batch_size, shuffle=False)
        std = torch.std(test.labels)

    l1 = torch.nn.L1Loss()
    l1_std = L1LossStd(std)
    l1_count = L1LossCount()

    # evaluate; the order of `l` must match the unpacking of `err` below
    l = [l1, l1_std, l1_count]
    l1_err = []
    l1_std_err = []
    l1_count_err = []
    for gnn in gnns:
        err = evaluate_epoch(dataloader, gnn, l, device)
        l1_err.append(err[0].item())
        l1_std_err.append(err[1].item())
        l1_count_err.append(err[2].item())

    # Keyed by subgraph so a later run for the same pattern replaces an earlier
    # one. The frame is built once at the end: assigning a row that mixes
    # scalars and lists through `.loc` makes pandas convert a ragged sequence
    # with numpy, which emits a deprecation warning.
    rows[subgraph] = {
        "L1 avg": mean(l1_err),
        "L1 std avg": mean(l1_std_err),
        "L1 count avg": mean(l1_count_err),
        "L1": l1_err,
        "L1 std": l1_std_err,
        "L1 count": l1_count_err,
    }
    print(f'{subgraph}: {mean(l1_err)}, time: {time.time() - start}')

results_dataset = pd.DataFrame(list(rows.values()), index=pd.Index(list(rows.keys())), columns=metric_columns)
display(results_dataset)
print(results["result.count_mean"])

  arr_value = np.asarray(value)


Triangle: 0.05817360281944275, time: 2.4045417308807373
2-Path: 0.11299102753400803, time: 2.3890392780303955
4-Clique: 0.3135445713996887, time: 2.3078620433807373
Chordal cycle: 1.0700857400894166, time: 2.422668933868408
Tailed triangle: 1.040301263332367, time: 2.369826316833496
3-Star: 0.24227949976921082, time: 2.401632785797119
4-Cycle: 0.2899557322263718, time: 2.4262750148773193
3-Path: 0.7225114226341247, time: 2.2972583770751953
3-Star not ind.: 0.6967951059341431, time: 2.379390239715576


Unnamed: 0,L1 avg,L1 std avg,L1 count avg,L1,L1 std,L1 count
Triangle,0.058174,0.004067,0.000977,"[0.05962110683321953, 0.04445173591375351, 0.0...","[0.004168517421931028, 0.0031079244799911976, ...","[0.0010356515413150191, 0.0007291014189831913,..."
2-Path,0.112991,0.013106,0.002682,"[0.11834253370761871, 0.07676807790994644, 0.1...","[0.013726676814258099, 0.008904412388801575, 0...","[0.002628588117659092, 0.001702932990156114, 0..."
4-Clique,0.313545,0.0115,0.006685,"[0.244696706533432, 0.2504911422729492, 0.3125...","[0.008975066244602203, 0.009187594056129456, 0...","[0.00577513175085187, 0.005842967424541712, 0...."
Chordal cycle,1.070086,0.080984,0.013391,"[0.9687739014625549, 1.1567442417144775, 1.205...","[0.07331638038158417, 0.0875418558716774, 0.09...","[0.012509196996688843, 0.01468413881957531, 0...."
Tailed triangle,1.040301,0.071912,0.047402,"[0.8548049330711365, 0.9839093685150146, 1.128...","[0.05908958986401558, 0.06801410764455795, 0.0...","[0.03359950706362724, 0.054910220205783844, 0...."
3-Star,0.242279,0.057418,0.07934,"[0.15281909704208374, 0.3498614430427551, 0.31...","[0.03621676564216614, 0.08291404694318771, 0.0...","[0.03194578364491463, 0.12903477251529694, 0.1..."
4-Cycle,0.289956,0.049036,0.042455,"[0.37575170397758484, 0.2742007374763489, 0.28...","[0.06354491412639618, 0.0463712103664875, 0.04...","[0.05259224399924278, 0.042126964777708054, 0...."
3-Path,0.722511,0.092134,0.100972,"[0.5731582045555115, 0.8539143204689026, 0.718...","[0.07308891415596008, 0.1088908240199089, 0.09...","[0.0668974369764328, 0.11965062469244003, 0.10..."
3-Star not ind.,0.696795,0.006967,0.001591,"[1.4685271978378296, 0.3790729343891144, 0.306...","[0.014684036374092102, 0.003790411865338683, 0...","[0.003306184196844697, 0.0008619489963166416, ..."


KeyError: 'resilt.count_mean'