In [1]:
cd ..

/home/adam/phd/recurrent-graph-autoencoder


In [2]:
import gc
import torch
import random
import numpy as np
from tqdm.auto import tqdm

from graph_nn_vae.data.diag_repr_graph_data_module import DiagonalRepresentationGraphDataModule
from graph_nn_vae.data.graph_loaders import RealGraphLoader, SyntheticGraphLoader
from graph_nn_vae.experiments.decorators import add_graphloader_args
from graph_nn_vae.models.autoencoder_components import GraphEncoder
from graph_nn_vae.models.edge_encoders import MemoryEdgeEncoder

In [3]:
class RealSaver(DiagonalRepresentationGraphDataModule):
    """Diagonal-representation data module whose graph loader is ``RealGraphLoader``.

    Only the ``graphloader_class`` attribute is overridden; everything else is
    inherited unchanged from ``DiagonalRepresentationGraphDataModule``.
    """

    # Presumably instantiated by the base class to read the graphs - confirm there.
    graphloader_class = RealGraphLoader

In [4]:
# Fractions of the dataset assigned to the train / validation / test splits
# (with 1500 graphs this yields the 1050 / 225 / 225 sizes reported below).
train_val_test_split = [0.7, 0.15, 0.15]
# Presumably the share of graph permutations given to each split (train only
# here) - TODO confirm against the data module.
train_val_test_permutation_split = [1, 0, 0.0]
# NOTE(review): the data-module construction further down passes a literal 1
# for `num_dataset_graph_permutations` instead of this value - confirm which
# one is intended.
num_dataset_graph_permutations = 10

In [19]:
# Build the IMDB-MULTI data module that feeds the downstream classification probe.
from pathlib import Path

# Resolve the datasets directory against the working directory rather than a
# machine-specific absolute path. The first cell changes into the repository
# root, so this points at `<repo>/datasets`.
DATASETS_DIR = Path.cwd() / 'datasets'

# Far larger than any split (the full dataset holds 1500 graphs), so every
# dataloader delivers its whole split in a single batch.
FULL_BATCH_SIZE = 500000

dataset = RealSaver(
    datasets_dir=str(DATASETS_DIR),
    dataset_name='IMDB-MULTI',
    use_labels=True,
    max_graph_size=None,
    # NOTE(review): literal 1, not the notebook-level
    # `num_dataset_graph_permutations` - confirm this is intended.
    num_dataset_graph_permutations=1,
    train_val_test_split=train_val_test_split,
    train_val_test_permutation_split=train_val_test_permutation_split,
    # save_dataset_to_pickle=to_save_path+'/'+dataset_name+'/'+str(i)+'.pkl',
    bfs=True,
    deduplicate_train=False,
    deduplicate_val_test=False,
    batch_size=FULL_BATCH_SIZE,
    batch_size_val=FULL_BATCH_SIZE,
    batch_size_test=FULL_BATCH_SIZE,
    workers=0,
    # Must agree with the `block_size` the encoder is built with.
    block_size=6,
    subgraph_scheduler_name='none',
    subgraph_scheduler_params={}
)

reading edges: 0it [00:00, ?it/s]

Statistic of set:  Full original dataset
             Dataset size : 1500
                   Labels : True
           Min node count : 7
       Average node count : 13.0
           Max node count : 89
           Min edge count : 12.0
       Average edge count : 65.94
           Max edge count : 1467.0
     Min filling fraction : 0.13
 Average filling fraction : 0.77
     Max filling fraction : 1.0
          Label "1" count : 500
          Label "2" count : 500
          Label "3" count : 500
----------------------------------------------------------------
Statistic of set:  Train dataset
             Dataset size : 1050
                   Labels : True
           Min node count : 7
       Average node count : 12.99
           Max node count : 89
           Min edge count : 12.0
       Average edge count : 67.04
           Max edge count : 1467.0
     Min filling fraction : 0.14
 Average filling fraction : 0.78
     Max filling fraction : 1.0
          Label "1" count : 345
          La

preparing dataset train for autoencoder:   0%|          | 0/1050 [00:00<?, ?it/s]

preparing dataset val 0 for autoencoder:   0%|          | 0/225 [00:00<?, ?it/s]

preparing dataset test 0 for autoencoder:   0%|          | 0/225 [00:00<?, ?it/s]

In [20]:
# Display the data module's repr to confirm it was constructed.
dataset

<__main__.RealSaver at 0x7ff1f643d4f0>

In [21]:
# Rebuild the graph encoder and restore its weights from a trained
# autoencoder checkpoint so it can be used as a fixed feature extractor.
from pathlib import Path

# Resolved against the working directory (the repository root after the first
# cell) instead of a machine-specific absolute path.
checkpoint_path = str(
    Path.cwd()
    / 'tb_logs/RecurrentGraphAutoencoder/IMDB-MULTI/version_0/checkpoints'
    / 'epoch=329-step=14189-v1.ckpt'
)

encoder = GraphEncoder(
    edge_encoder_class = MemoryEdgeEncoder,
    embedding_size = 104,
    edge_size = 1,
    block_size= 6,
    loss_function = 'BCEWithLogits',
    loss_weight = None,
    # NOTE(review): passed as a string, as in the original - confirm that
    # GraphEncoder parses it to a float.
    learning_rate = '0.0001',
    optimizer = 'AdamWAMSGrad',
    lr_scheduler_name = 'NoSched',
    lr_scheduler_params = {},
    lr_scheduler_metric = 'loss/train_avg',
    metrics = [],
    encoder_hidden_layer_sizes=[1024, 768],
    encoder_activation_function='ELU'
)

# map_location='cpu' lets the checkpoint load on machines without a GPU even
# if it was saved from CUDA tensors; the embeddings are converted to NumPy
# later, so the encoder has to live on the CPU anyway.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

# Keep only the entries whose key mentions "encoder" and drop the leading
# "encoder." from the edge-encoder keys so they line up with this module's
# own parameter names.
encoder_checkpoint = {
    k.replace("encoder.edge_encoder.", "edge_encoder."): v
    for (k, v) in checkpoint["state_dict"].items()
    if "encoder" in k
}

# Inference only: switch off any training-mode behaviour (e.g. dropout) the
# encoder may contain before it is used to compute embeddings.
encoder.eval()
encoder.load_state_dict(encoder_checkpoint)

<All keys matched successfully>

In [36]:
def _first_batch_and_labels(loader):
    """Pull the first batch from ``loader`` and print how many labels it holds.

    Returns the batch together with its element at index 3, which this
    notebook uses as the label tensor.
    """
    batch = next(iter(loader))
    labels = batch[3]
    print(len(labels))
    return batch, labels


# The validation and test accessors return a sequence of loaders; only the
# first one is used here.
train_batch, train_batch_labels = _first_batch_and_labels(dataset.train_dataloader())

val_batch, val_batch_labels = _first_batch_and_labels(dataset.val_dataloader()[0])

test_batch, test_batch_labels = _first_batch_and_labels(dataset.test_dataloader()[0])


1050
225
225


In [44]:
# Embed each split with the restored encoder. No gradients are needed for
# feature extraction, so the forward passes run under no_grad to avoid
# building (and holding in memory) an autograd graph for every graph.
with torch.no_grad():
    train_batch_X = encoder(train_batch).detach().numpy()
    val_batch_X = encoder(val_batch).detach().numpy()
    test_batch_X = encoder(test_batch).detach().numpy()

In [45]:
# Sanity check: one row per training graph, one column per embedding dimension.
train_batch_X.shape

(1050, 104)

In [42]:
# Peek at the training labels that will be used as classification targets.
train_batch_labels

tensor([3, 1, 2,  ..., 3, 1, 2])

In [63]:
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import classification_report

In [65]:
# Fit a random-forest probe on the encoder embeddings and report per-class
# metrics on the training and validation splits.

# Fixed seed so the forest (bootstrap samples and feature sub-sampling) and
# therefore the reported metrics are reproducible across runs.
RANDOM_STATE = 42

# scikit-learn works on NumPy arrays; convert the label tensors once instead
# of relying on implicit conversion at every call.
y_train = np.asarray(train_batch_labels)
y_val = np.asarray(val_batch_labels)

model = RandomForestClassifier(
    n_estimators=500,
    min_samples_leaf=5,
    min_samples_split=4,
    random_state=RANDOM_STATE,
)
# model = GradientBoostingClassifier()
model.fit(train_batch_X, y_train)

# Training-set report: an optimistic reference for how well the probe fits.
train_batch_labels_pred = model.predict(train_batch_X)
print(classification_report(y_train, train_batch_labels_pred))

# Validation-set report: the number that reflects embedding quality.
val_batch_labels_pred = model.predict(val_batch_X)
print(classification_report(y_val, val_batch_labels_pred))

              precision    recall  f1-score   support

           1       0.86      0.47      0.61       345
           2       0.62      0.74      0.67       352
           3       0.60      0.74      0.66       353

    accuracy                           0.65      1050
   macro avg       0.69      0.65      0.65      1050
weighted avg       0.69      0.65      0.65      1050

              precision    recall  f1-score   support

           1       0.44      0.40      0.42        80
           2       0.46      0.42      0.44        69
           3       0.49      0.58      0.53        76

    accuracy                           0.47       225
   macro avg       0.46      0.47      0.46       225
weighted avg       0.46      0.47      0.46       225

