In [1]:
import numpy as np
import pandas as pd
import torch
from model import DrBC
from utils import prepare_test, prepare_synthetic, preprocessing_data, validate
# Directory holding the trained DrBC checkpoints. Keep the trailing slash:
# the loading code builds paths by plain f-string concatenation with a filename.
MODEL_SAVED_PATH = 'saved_model/'

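### Experiment: evaluate the saved DrBC checkpoints on synthetic graphs at scales 5000, 10000 and 20000 (30 graphs per scale) and on the YouTube test graph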

In [2]:
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [3]:
# Trained DrBC checkpoints, keyed by the "N…" token in each filename
# (presumably the size of the training graphs — TODO confirm).
# NOTE(review): the N5000 checkpoint is E5000 while the other two are E10000;
# confirm that mismatch is intentional before comparing the models.
CHECKPOINTS = {
    200: 'DrBC_G1000_N200_E10000.pth',
    500: 'DrBC_G1000_N500_E10000.pth',
    5000: 'DrBC_G1000_N5000_E5000.pth',
}
N_TEST_GRAPHS = 30  # synthetic graphs generated per scale


def load_model(filename):
    """Build a DrBC model from a checkpoint file in MODEL_SAVED_PATH.

    The checkpoint tensors are mapped directly onto `device`, so weights saved
    on a GPU also load on a CPU-only machine. The module is moved to `device`
    once here, not on every evaluation. It is returned in eval mode because it
    is only used for inference in this notebook.

    Args:
        filename: checkpoint file name relative to MODEL_SAVED_PATH.

    Returns:
        The loaded DrBC module, on `device`, in eval mode.
    """
    net = DrBC()
    state = torch.load(f'{MODEL_SAVED_PATH}{filename}', map_location=device)
    net.load_state_dict(state)
    return net.to(device).eval()


model200 = load_model(CHECKPOINTS[200])
model500 = load_model(CHECKPOINTS[500])
model5000 = load_model(CHECKPOINTS[5000])

models = [(200, model200), (500, model500), (5000, model5000)]

# Each appended row is [scale, model level, graph id, acc1, acc5, acc10, kendall, time].
# NOTE(review): no random seed is set, so the synthetic graphs (and therefore the
# metrics) differ between runs. Consider seeding if prepare_synthetic honours it.
scales = [5000, 10000, 20000]
ls_metrics = []
for scale in scales:
    print('-'*15, scale)
    g_list, dg_list, bc_list = prepare_synthetic(N_TEST_GRAPHS, (scale, scale+1), parallel=True)
    for i, (g, dg, bc) in enumerate(zip(g_list, dg_list, bc_list)):
        # Preprocess each graph on its own so every metrics row is per graph.
        test_X, test_y, test_edge_index = preprocessing_data([g], [dg], [bc])
        t_data = [[test_X, test_y, test_edge_index]]
        for level, model in models:
            _acc1, _acc5, _acc10, _kendall, _time = validate(model, t_data)
            ls_metrics.append([scale, level, i, _acc1, _acc5, _acc10, _kendall, _time])


# prepare_test('y') presumably loads the YouTube graph (its rows are labelled
# 'youtube' below) — TODO confirm against utils.prepare_test.
t_data = prepare_test('y')
for level, model in models:
    _acc1, _acc5, _acc10, _kendall, _time = validate(model, t_data)
    ls_metrics.append(['youtube', level, 0, _acc1, _acc5, _acc10, _kendall, _time])

--------------- 5000


[Generating new training graph]:   0%|          | 0/30 [00:00<?, ?it/s]

--------------- 10000


[Generating new training graph]:   0%|          | 0/30 [00:00<?, ?it/s]

--------------- 20000


[Generating new training graph]:   0%|          | 0/30 [00:00<?, ?it/s]

  (2 * xtie * ytie) / m + x0 * y0 / (9 * m * (size - 2)))


In [4]:
# One row per (test graph, model) evaluation collected in ls_metrics; save to CSV without the index.
metric_columns = [
    'scale', 'model', 'test_graph_id',
    'test_acc1', 'test_acc5', 'test_acc10',
    'test_kendall', 'time',
]
df = pd.DataFrame(data=ls_metrics, columns=metric_columns)
df.to_csv(path_or_buf='test_scale_diff_result.csv', index=False)