In [None]:
cd ..

In [None]:
import os
# Expose only the GPU with index 3 to CUDA. This is set here, before torch is
# imported in the next cell, so the restriction is in place when CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = '3'

In [None]:
import torch
import pandas as pd
import numpy as np
import pickle
import pytorch_lightning as pl
from tqdm.auto import tqdm
from rga.data.diag_repr_graph_data_module import DiagonalRepresentationGraphDataModule
from rga.data.graph_loaders import RealGraphLoader, SyntheticGraphLoader
from rga.models.autoencoder_components import GraphEncoder
from rga.models.edge_encoders import MemoryEdgeEncoder
from rga.util.load_model import *
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import classification_report
from sklearn.linear_model import LogisticRegression
from sklearn.decomposition import PCA
from sklearn.metrics import *
from sklearn.neural_network import MLPClassifier
from rga.util.adjmatrix.diagonal_block_representation import diagonal_block_to_adj_matrix_representation

In [None]:
# Fix the global random seeds through PyTorch Lightning so repeated runs are comparable.
pl.seed_everything(0)

In [None]:
# Root folder of the trained checkpoints. The cells below read
# `<checkpoints_folder><dataset>/<index>.ckpt` and `<index>_hparams.yaml` from it;
# include a trailing '/' when setting a non-empty value.
checkpoints_folder = ''
# Dataset names; they are used both as sub-folder names and as keys of the
# `classification_results` accumulator.
datasets = [
    'IMDB-BINARY',
    'IMDB-MULTI',
    'COLLAB',
    'REDDIT-BINARY',
    'REDDIT-MULTI-5K',
    'REDDIT-MULTI-12K',
]
# Root folder of the pickled datasets (`<dataset_folder><dataset>/<index>.pkl`);
# include a trailing '/' when setting a non-empty value.
dataset_folder = ''

In [None]:
class RealSaver(DiagonalRepresentationGraphDataModule):
    """Diagonal-representation data module configured to use ``RealGraphLoader``."""
    # Loader class the base data module is pointed at.
    graphloader_class = RealGraphLoader
    
def prepare_model(model_path, hparams):
    """Build a ``GraphEncoder`` and load its weights from a training checkpoint.

    Args:
        model_path: Path to a checkpoint file containing a ``"state_dict"`` entry.
        hparams: Mapping of hyper-parameters forwarded as keyword arguments to
            ``GraphEncoder``.

    Returns:
        The encoder with the checkpoint weights loaded, on CPU and switched to
        evaluation mode so the embeddings it produces are deterministic.
    """
    encoder = GraphEncoder(edge_encoder_class = MemoryEdgeEncoder, **hparams)

    # Always deserialize onto CPU, independent of the device used for training.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

    # Keep only entries whose key mentions "encoder" and strip the
    # "encoder.edge_encoder." prefix so the names line up with the standalone
    # encoder. NOTE(review): the filter is a substring match - confirm that no
    # unrelated checkpoint key contains "encoder".
    encoder_checkpoint = {
        k.replace("encoder.edge_encoder.", "edge_encoder."): v
        for (k, v) in checkpoint["state_dict"].items()
        if "encoder" in k
    }
    encoder.load_state_dict(encoder_checkpoint)

    # A freshly constructed module is in training mode; switch to inference mode
    # so layers such as dropout / batch-norm (if present) behave deterministically.
    encoder.eval()

    return encoder

def prepare_dataset(dataset_path, hparams):
    """Create the ``RealSaver`` data module for one pickled dataset.

    Args:
        dataset_path: Path to the pickled dataset file.
        hparams: Hyper-parameter mapping; only ``hparams['block_size']`` is read.

    Returns:
        A ``RealSaver`` instance with labels enabled, BFS ordering on,
        deduplication off, batch size 32 for every split and no subgraph scheduler.
    """
    common_batch_size = 32
    module_kwargs = dict(
        pickled_dataset_path=dataset_path,
        use_labels=True,
        bfs=True,
        deduplicate_train=False,
        deduplicate_val_test=False,
        batch_size=common_batch_size,
        batch_size_val=common_batch_size,
        batch_size_test=common_batch_size,
        workers=0,
        block_size=hparams['block_size'],
        subgraph_scheduler_name='none',
        subgraph_scheduler_params={},
    )
    return RealSaver(**module_kwargs)

def get_embeddings(model, dataloader, add_features = False):
    """Run ``model`` over every batch of ``dataloader`` and collect the outputs.

    Args:
        model: Callable applied to a whole batch; its result must be a tensor
            that can be converted with ``.detach().numpy()``.
        dataloader: Iterable of batches. ``batch[3]`` is collected as the labels.
            When ``add_features`` is true, ``batch[0][j]`` and ``batch[2][j]`` are
            passed to ``diagonal_block_to_adj_matrix_representation`` for each
            graph ``j`` (presumably the block representation and the node
            count - confirm against the data module).
        add_features: When true, additionally compute four hand-crafted
            statistics per graph: ``batch[2][j]``, the squared mean of the
            square-rooted node degrees, the mean node degree and the
            root-mean-square node degree.

    Returns:
        Tuple ``(X, Y, extra)`` of concatenated embeddings, concatenated labels
        and the concatenated per-graph statistics (``None`` when
        ``add_features`` is false).
    """
    embeddings = []
    labels = []
    extra_features = []

    # Pure inference: disabling autograd avoids building a graph for every batch.
    with torch.no_grad():
        # Iterating the loader directly lets tqdm show a total when it has a length.
        for batch in tqdm(dataloader, desc='Embeddings'):
            embeddings.append(model(batch).detach().numpy())
            labels.append(batch[3])
            if not add_features:
                continue

            batch_features = []
            for j in range(batch[0].shape[0]):
                # First channel of the reconstructed adjacency matrix, with
                # negative entries clamped to zero.
                A = diagonal_block_to_adj_matrix_representation(
                    batch[0][j], batch[2][j]
                )[:, :, 0].clamp(min = 0)
                # Symmetrize before summing so each edge counts for both endpoints.
                node_degrees = (A + A.T).sum(axis = 0)
                batch_features.append([
                    batch[2][j],
                    np.power(node_degrees.sqrt().mean().detach().numpy(), 2),
                    node_degrees.mean().detach().numpy(),
                    np.sqrt(node_degrees.square().mean().detach().numpy())
                ])

            extra_features.append(batch_features)

    return (
        np.concatenate(embeddings),
        np.concatenate(labels),
        np.concatenate(extra_features) if add_features else None,
    )

In [None]:
def process_model(hparams_path, model_path, dataset_path, PCA_dim = None, add_features = False):
    """Embed one dataset with a trained encoder and score sklearn classifiers on it.

    Args:
        hparams_path: Path handed to ``load_hparams``.
        model_path: Checkpoint path handed to ``prepare_model``.
        dataset_path: Pickled dataset path handed to ``prepare_dataset``.
        PCA_dim: If not ``None``, the embeddings are projected to this many
            principal components (PCA is fitted on the training split only).
        add_features: If true, hand-crafted graph statistics are computed as
            well and two extra feature sets are evaluated (statistics only, and
            embeddings concatenated with statistics).

    Returns:
        ``{'train': {...}, 'val': {...}, 'test': {...}}`` where every inner
        dict maps ``"<classifier name><mode>"`` to the accuracy on that split.
        The mode suffix is ``' normal'`` and, when ``add_features`` is true,
        also ``' only_stats'`` and ``' augmented'``.
    """
    hparams = load_hparams(hparams_path)

    model = prepare_model(model_path, hparams)
    dataset = prepare_dataset(dataset_path, hparams)

    train_X, train_Y, train_extra = get_embeddings(
        model, dataset.train_dataloader(shuffle=False), add_features = add_features
    )
    # The val/test accessors are indexed with [0]; presumably they return a
    # sequence of loaders - confirm against the data module.
    val_X, val_Y, val_extra = get_embeddings(
        model, dataset.val_dataloader(shuffle=False)[0], add_features = add_features
    )
    test_X, test_Y, test_extra = get_embeddings(
        model, dataset.test_dataloader(shuffle=False)[0], add_features = add_features
    )

    if PCA_dim is not None:
        # Fit on the training split only, then apply the same projection everywhere.
        pca = PCA(n_components=PCA_dim)
        pca.fit(train_X)
        train_X = pca.transform(train_X)
        val_X = pca.transform(val_X)
        test_X = pca.transform(test_X)

    # Feature sets to evaluate. The statistics-based sets exist only when
    # add_features is true; previously they were referenced unconditionally,
    # which raised NameError for add_features=False.
    feature_sets = [(' normal', train_X, val_X, test_X)]
    if add_features:
        # Scale each statistic by its mean over the training split.
        train_mean = train_extra.mean(axis = 0)
        train_stats = train_extra / train_mean
        val_stats = val_extra / train_mean
        test_stats = test_extra / train_mean

        feature_sets.append((' only_stats', train_stats, val_stats, test_stats))
        feature_sets.append((
            ' augmented',
            np.concatenate([train_X, train_stats], axis = 1),
            np.concatenate([val_X, val_stats], axis = 1),
            np.concatenate([test_X, test_stats], axis = 1),
        ))

    # The dictionary keys end up verbatim in the returned statistics, so they
    # are left exactly as they were (including their spelling).
    sklearn_models = {
        'NB': GaussianNB(),
        'SVM': SVC(),
        'Logistic regression': LogisticRegression(),
        'xgboost': GradientBoostingClassifier(min_samples_leaf=20, verbose=False),
        'Random forset': RandomForestClassifier(min_samples_leaf=20, verbose=False),
    }

    stats = {
        'train': {},
        'val': {},
        'test': {}
    }

    for mode, train_features, val_features, test_features in tqdm(feature_sets, desc='Modes'):
        for name, sklearn_model in tqdm(sklearn_models.items(), desc='Models', leave=True):
            # fit() re-trains from scratch, so reusing the estimator across modes is fine.
            sklearn_model.fit(train_features, train_Y)

            key = name + mode
            stats['train'][key] = accuracy_score(train_Y, sklearn_model.predict(train_features))
            stats['val'][key] = accuracy_score(val_Y, sklearn_model.predict(val_features))
            stats['test'][key] = accuracy_score(test_Y, sklearn_model.predict(test_features))
    return stats

In [None]:
# Accumulator for all runs: split -> dataset name -> {(run index, "<classifier><mode>"): accuracy}.
# Re-running this cell discards everything collected so far.
classification_results = {
    split: {name: {} for name in datasets}
    for split in ('train', 'val', 'test')
}

In [None]:
dataset_name = 'IMDB-BINARY'
add_features = True
for pickle_file in tqdm([0, 1, 2, 3, 4]):
    # os.path.join keeps the paths valid whether or not the configured folders
    # end with a separator (plain concatenation silently glued the names together).
    run_prefix = os.path.join(checkpoints_folder, dataset_name, str(pickle_file))
    hparams_path = run_prefix + '_hparams.yaml'
    model_path = run_prefix + '.ckpt'
    dataset_path = os.path.join(dataset_folder, dataset_name, str(pickle_file) + '.pkl')

    stats = process_model(hparams_path, model_path, dataset_path, add_features=add_features)

    # Store the accuracies of this run under (run index, "<classifier><mode>").
    for dset in ['train', 'val', 'test']:
        selected_stats = {(pickle_file, k): v for (k, v) in stats[dset].items()}
        classification_results[dset][dataset_name].update(selected_stats)


In [None]:
dataset_name = 'IMDB-MULTI'
add_features = True
for pickle_file in tqdm([0, 1, 2, 3, 4]):
    # os.path.join keeps the paths valid whether or not the configured folders
    # end with a separator (plain concatenation silently glued the names together).
    run_prefix = os.path.join(checkpoints_folder, dataset_name, str(pickle_file))
    hparams_path = run_prefix + '_hparams.yaml'
    model_path = run_prefix + '.ckpt'
    dataset_path = os.path.join(dataset_folder, dataset_name, str(pickle_file) + '.pkl')

    stats = process_model(hparams_path, model_path, dataset_path, add_features=add_features)

    # Store the accuracies of this run under (run index, "<classifier><mode>");
    # one loop over the splits replaces three copy-pasted stanzas.
    for dset in ['train', 'val', 'test']:
        selected_stats = {(pickle_file, k): v for (k, v) in stats[dset].items()}
        classification_results[dset][dataset_name].update(selected_stats)


In [None]:
dataset_name = 'REDDIT-BINARY'
add_features = True
PCA_dim = None
for pickle_file in tqdm([0, 1, 2, 3, 4]):
    # os.path.join keeps the paths valid whether or not the configured folders
    # end with a separator (plain concatenation silently glued the names together).
    run_prefix = os.path.join(checkpoints_folder, dataset_name, str(pickle_file))
    hparams_path = run_prefix + '_hparams.yaml'
    model_path = run_prefix + '.ckpt'
    dataset_path = os.path.join(dataset_folder, dataset_name, str(pickle_file) + '.pkl')

    stats = process_model(hparams_path, model_path, dataset_path, add_features=add_features, PCA_dim=PCA_dim)

    # Store the accuracies of this run under (run index, "<classifier><mode>");
    # one loop over the splits replaces three copy-pasted stanzas.
    for dset in ['train', 'val', 'test']:
        selected_stats = {(pickle_file, k): v for (k, v) in stats[dset].items()}
        classification_results[dset][dataset_name].update(selected_stats)


In [None]:
dataset_name = 'COLLAB'
add_features = True
PCA_dim = None
for pickle_file in tqdm([0, 1, 2, 3, 4]):
    # os.path.join keeps the paths valid whether or not the configured folders
    # end with a separator (plain concatenation silently glued the names together).
    run_prefix = os.path.join(checkpoints_folder, dataset_name, str(pickle_file))
    hparams_path = run_prefix + '_hparams.yaml'
    model_path = run_prefix + '.ckpt'
    dataset_path = os.path.join(dataset_folder, dataset_name, str(pickle_file) + '.pkl')

    stats = process_model(hparams_path, model_path, dataset_path, add_features=add_features, PCA_dim=PCA_dim)

    # Store the accuracies of this run under (run index, "<classifier><mode>");
    # one loop over the splits replaces three copy-pasted stanzas.
    for dset in ['train', 'val', 'test']:
        selected_stats = {(pickle_file, k): v for (k, v) in stats[dset].items()}
        classification_results[dset][dataset_name].update(selected_stats)


In [None]:
dataset_name = 'REDDIT-MULTI-5K'
add_features = True
PCA_dim = None
for pickle_file in tqdm([0, 1, 2, 3, 4]):
    # os.path.join keeps the paths valid whether or not the configured folders
    # end with a separator (plain concatenation silently glued the names together).
    run_prefix = os.path.join(checkpoints_folder, dataset_name, str(pickle_file))
    hparams_path = run_prefix + '_hparams.yaml'
    model_path = run_prefix + '.ckpt'
    dataset_path = os.path.join(dataset_folder, dataset_name, str(pickle_file) + '.pkl')

    stats = process_model(hparams_path, model_path, dataset_path, add_features=add_features, PCA_dim=PCA_dim)

    # Store the accuracies of this run under (run index, "<classifier><mode>");
    # one loop over the splits replaces three copy-pasted stanzas.
    for dset in ['train', 'val', 'test']:
        selected_stats = {(pickle_file, k): v for (k, v) in stats[dset].items()}
        classification_results[dset][dataset_name].update(selected_stats)

In [None]:
dataset_name = 'REDDIT-MULTI-12K'
add_features = True
PCA_dim = None
for pickle_file in tqdm([0, 1, 2, 3, 4]):
    # os.path.join keeps the paths valid whether or not the configured folders
    # end with a separator (plain concatenation silently glued the names together).
    run_prefix = os.path.join(checkpoints_folder, dataset_name, str(pickle_file))
    hparams_path = run_prefix + '_hparams.yaml'
    model_path = run_prefix + '.ckpt'
    dataset_path = os.path.join(dataset_folder, dataset_name, str(pickle_file) + '.pkl')

    stats = process_model(hparams_path, model_path, dataset_path, add_features=add_features, PCA_dim=PCA_dim)

    # Store the accuracies of this run under (run index, "<classifier><mode>");
    # one loop over the splits replaces three copy-pasted stanzas.
    for dset in ['train', 'val', 'test']:
        selected_stats = {(pickle_file, k): v for (k, v) in stats[dset].items()}
        classification_results[dset][dataset_name].update(selected_stats)

In [None]:
# Per split: one table with mean and standard deviation over the runs for every
# "<classifier><mode>" row and one column group per dataset.
for dset in ['train', 'val', 'test']:
    rule = '=' * 40
    print(rule + '   ' + dset + '   ' + rule)

    # Result keys are (run index, "<classifier><mode>") tuples, so after
    # reset_index() they become the columns 'level_0' and 'level_1'.
    per_run = pd.DataFrame(classification_results[dset]).reset_index()
    aggregated = per_run.groupby('level_1').agg(['mean', 'std'])

    # Aggregating the run index itself is meaningless - drop it before display.
    display(aggregated.drop('level_0', axis = 1).round(3))

In [None]:
# Persist the collected accuracies in the current working directory so they
# survive a kernel restart.
with open('classification_results.backup', 'wb') as backup_file:
    pickle.dump(classification_results, backup_file)