### Experiment.py Tests

In [None]:
from experiments.Experiment import Experiment, ExperimentLogger

import yaml
import os
import numpy as np

import torch
from torch_geometric.data import Data
from collections.abc import Iterable

In [2]:
def test_experiment_class_init():
    """Check that an ``Experiment`` can be constructed from the test config.

    Loads ``test_config.yml``, builds an ``Experiment`` whose result prefix is
    ``<RESULT_DIR>/<config file name without extension>``, and asserts the
    experiment object exists and was placed on the CPU device.
    """
    RESULT_DIR = '/Users/maxperozek/GNN-research/GNN-exp-pipeline/result/'
    config_file = '/Users/maxperozek/GNN-research/GNN-exp-pipeline/config/test_config.yml'
    # Strip the real extension rather than a fixed-width slice: '.yml' is four
    # characters, so a [:-5] slice would turn 'test_config' into 'test_confi'.
    config_name = os.path.splitext(os.path.basename(config_file))[0]
    with open(config_file) as file:
        config = yaml.safe_load(file)

    exp = Experiment(config, os.path.join(RESULT_DIR, config_name))
    assert exp
    # NOTE(review): assumes test_config / the test machine resolve to CPU — confirm.
    assert exp.device == torch.device('cpu')
    

In [3]:
# Run the Experiment construction test directly in the notebook (outside pytest).
test_experiment_class_init()

In [21]:
def test_experiment_prep_data():
    """Check ``Experiment.prep_data`` on the test config.

    Asserts that the prepared dataset is an iterable of PyG ``Data`` objects
    with the expected number of graphs, and that the returned k-fold splitter
    is configured for 5 splits.
    """
    RESULT_DIR = '/Users/maxperozek/GNN-research/GNN-exp-pipeline/result/'
    config_file = '/Users/maxperozek/GNN-research/GNN-exp-pipeline/config/test_config.yml'
    # Strip the real extension rather than a fixed-width slice: '.yml' is four
    # characters, so a [:-5] slice would turn 'test_config' into 'test_confi'.
    config_name = os.path.splitext(os.path.basename(config_file))[0]
    with open(config_file) as file:
        config = yaml.safe_load(file)

    exp = Experiment(config, os.path.join(RESULT_DIR, config_name))
    dataset, kfold = exp.prep_data()
    # 2914 matches the size the 5G-vs-non-conspiracy transform test expects;
    # presumably test_config applies that transform — confirm.
    assert len(dataset) == 2914
    assert isinstance(dataset[0], Data)
    assert isinstance(dataset, Iterable)
    assert kfold.get_n_splits() == 5

In [23]:
# Run the data-preparation test directly in the notebook (outside pytest).
test_experiment_prep_data()

In [4]:
def test_experiment_run_e2e():
    """Smoke-test a full ``Experiment.run`` on the test config.

    Passes only if building and running the experiment raises no exception;
    the result prefix handed to ``Experiment`` is
    ``<project>/result/<config file name without extension>``.
    """
    project_root = '/Users/maxperozek/GNN-research/GNN-exp-pipeline'
    config_path = project_root + '/config/test_config.yml'
    with open(config_path) as fh:
        cfg = yaml.safe_load(fh)

    # e.g. 'test_config.yml' -> 'test_config'
    stem = os.path.splitext(os.path.basename(config_path))[0]
    runner = Experiment(cfg, project_root + '/result/' + stem)
    runner.run()

In [5]:
# Run the end-to-end test directly; this trains the model over all folds and
# took several minutes in the recorded output below.
test_experiment_run_e2e()

100%|███████████████████████████████████████████| 20/20 [00:46<00:00,  2.33s/it]
100%|███████████████████████████████████████████| 20/20 [00:41<00:00,  2.06s/it]
100%|███████████████████████████████████████████| 20/20 [00:38<00:00,  1.92s/it]
100%|███████████████████████████████████████████| 20/20 [00:39<00:00,  1.97s/it]
100%|███████████████████████████████████████████| 20/20 [00:37<00:00,  1.88s/it]
training model GIN on wico over 5 folds: 100%|████| 5/5 [03:23<00:00, 40.80s/it]

log saved to:  <_io.TextIOWrapper name='/Users/maxperozek/GNN-research/GNN-exp-pipeline/result/test_config_result.yaml' mode='w' encoding='UTF-8'>





In [6]:
def test_random_seed():
    """Placeholder for a random-seed reproducibility test.

    TODO: not implemented yet — this currently does nothing and always passes.
    """

### Test Transforms

In [60]:
from transforms.wico_transforms import WICOTransforms

In [108]:
def test_wico_5g_vs_non_conspiracy_transform():
    """Check the ``wico_5g_vs_non_conspiracy`` transform on the full WICO dataset.

    Loads the saved full WICO PyG dataset with ``torch.load``, applies the
    transform, and asserts the sizes before/after and that the transformed
    labels are exactly the two classes 0 and 1.
    """
    t = WICOTransforms.wico_5g_vs_non_conspiracy
    DATA_DIR = '/Users/maxperozek/GNN-research/data_pro/data/'
    full_wico_pyg = 'full_wico.pt'
    wico = torch.load(os.path.join(DATA_DIR, full_wico_pyg))
    wico_2_class = t(wico)

    assert len(wico) == 3511
    assert len(wico_2_class) == 2914
    # NOTE(review): [:, 2, 1] takes the value of the third (key, value) pair each
    # Data object yields when iterated — presumably the label `y`; confirm.
    labels = np.array(wico_2_class, dtype=object)[:, 2, 1].astype(int)
    # array_equal returns False on a length mismatch, whereas `==` between arrays
    # of different lengths can raise instead of failing the assertion cleanly.
    assert np.array_equal(np.unique(labels), np.arange(2))  # exactly 2 classes: 0 and 1

In [109]:
# Run the WICO transform test directly in the notebook (outside pytest).
test_wico_5g_vs_non_conspiracy_transform()

In [2]:
# Ad-hoc full experiment run using decent_config.yml. These are module-level
# names (RESULT_DIR, config_file, config_name, config, exp) that later notebook
# cells may reference, so they are kept as-is.
RESULT_DIR = '/Users/maxperozek/GNN-research/GNN-exp-pipeline/result/'
config_file = '/Users/maxperozek/GNN-research/GNN-exp-pipeline/config/decent_config.yml'
# Config file name without its extension, e.g. 'decent_config'.
config_name = os.path.splitext(os.path.basename(config_file))[0]
with open(config_file) as file:
    config = yaml.safe_load(file)

exp = Experiment(config, os.path.join(RESULT_DIR, config_name))
exp.run()

100%|███████████████████████████████████████████| 20/20 [01:22<00:00,  4.13s/it]
100%|███████████████████████████████████████████| 20/20 [01:24<00:00,  4.20s/it]
100%|███████████████████████████████████████████| 20/20 [01:23<00:00,  4.19s/it]
100%|███████████████████████████████████████████| 20/20 [01:23<00:00,  4.16s/it]
100%|███████████████████████████████████████████| 20/20 [01:23<00:00,  4.18s/it]
training model GIN on wico over 5 folds: 100%|████| 5/5 [06:59<00:00, 83.81s/it]

log saved to:  <_io.TextIOWrapper name='/Users/maxperozek/GNN-research/GNN-exp-pipeline/result/decent_config_result.yaml' mode='w' encoding='UTF-8'>





In [24]:
import itertools