In [43]:
import numpy as np
import pandas as pd
import os
import sys
sys.path.append("..")
import torch
import gcc
from gcc.datasets import (
    GRAPH_CLASSIFICATION_DSETS,
    GraphClassificationDataset,
    GraphClassificationDatasetLabeled,
    #LoadBalanceGraphDataset,
    NodeClassificationDataset,
    NodeClassificationDatasetLabeled,
    worker_init_fn,
)
from gcc.datasets.data_util import batcher
from gcc.models import GraphEncoder
from gcc.datasets import data_util
from collections import defaultdict, namedtuple


In [2]:
def print_model_args(args):
    """Print each attribute of ``args`` on its own line as ``name   value``."""
    for name, value in vars(args).items():
        print(name, " ", value)
        
def test_moco(train_loader, model, opt):
    """
    one epoch training for moco
    """

    model.eval()

    emb_list = []
    for idx, batch in enumerate(train_loader):
        graph_q, graph_k = batch
        bsz = graph_q.batch_size
        graph_q.to(opt.device)
        graph_k.to(opt.device)

        with torch.no_grad():
            feat_q = model(graph_q)
            feat_k = model(graph_k)

        assert feat_q.shape == (bsz, opt.hidden_size)
        emb_list.append(((feat_q + feat_k) / 2).detach().cpu())
    return torch.cat(emb_list)

Loading the pretrained GCC model (no training is performed in this notebook)

In [3]:
# The checkpoint holds the encoder weights ("model") together with the options
# object saved alongside them ("opt"). torch.load unpickles the file, so only
# point this at checkpoints you trust.
checkpoint_path = '../saved/Pretrain_moco_False_dgl_gin_layer_5_lr_0.005_decay_1e-05_bsz_256_hid_64_samples_2000_nce_t_0.07_nce_k_32_rw_hops_256_restart_prob_0.8_aug_1st_ft_False_deg_16_pos_32_momentum_0.999/current.pth'
checkpoint = torch.load(checkpoint_path, map_location="cpu")
args = checkpoint["opt"]
# print_model_args(args)  # uncomment to list the saved options
args.device = torch.device("cpu")  # embeddings are extracted on CPU

# Rebuild the encoder from the saved options so the stored weights line up.
model = GraphEncoder(
    gnn_model=args.model,
    num_layers=args.num_layer,
    node_hidden_dim=args.hidden_size,
    edge_hidden_dim=args.hidden_size,
    output_dim=args.hidden_size,
    degree_input=True,
    max_degree=args.max_degree,
    degree_embedding_size=args.degree_embedding_size,
    max_node_freq=args.max_node_freq,
    max_edge_freq=args.max_edge_freq,
    freq_embedding_size=args.freq_embedding_size,
    positional_embedding_size=args.positional_embedding_size,
    num_step_set2set=args.set2set_iter,
    num_layer_set2set=args.set2set_lstm_layer,
    norm=args.norm,
)
model.load_state_dict(checkpoint["model"])

<All keys matched successfully>

In [None]:
# Output directory for the exported embedding files. Set the GCC_RESULTS_DIR
# environment variable to redirect output instead of editing this
# machine-specific default path.
save_results_to = os.environ.get(
    "GCC_RESULTS_DIR",
    '/media/nedooshki/f4f0aea6-900a-437f-82e1-238569330477/GRL-course-project/results',
)
def create_dataframe_save_to_csv(embeddings, labels, dataset_name, model_name, save_path, sep='\t'):
    """
    Write embeddings plus a label column to ``<save_path>/<dataset_name>_<model_name>.csv``.

    Note: despite the ``.csv`` extension, the file is tab-separated by default.
    Read it back with ``pd.read_csv(path, sep='\\t')``.

    Args:
        embeddings: 2-D array-like (e.g. a CPU torch tensor or numpy array),
            one row per sample. Columns are written as ``emb1..embD``.
        labels: label values aligned row-by-row with ``embeddings``
            (list, numpy array or CPU tensor).
        dataset_name: tag used as the first part of the file name.
        model_name: tag used as the second part of the file name.
        save_path: output directory; it is created if it does not exist.
        sep: field separator passed to ``DataFrame.to_csv`` (default: tab).

    Returns:
        The path of the written file.

    Raises:
        ValueError: if the number of labels differs from the number of rows.
    """
    filename = '{}_{}.csv'.format(dataset_name, model_name)
    emb_df = pd.DataFrame(np.array(embeddings))
    # 1-based column names: emb1, emb2, ...
    emb_df.columns = ['emb' + str(e + 1) for e in range(emb_df.shape[1])]

    # Normalize to a plain numpy array so the label column holds numbers rather
    # than per-element tensor objects, whatever container the caller passes.
    label_values = np.asarray(labels)
    if label_values.ndim > 0 and label_values.shape[0] != len(emb_df):
        raise ValueError(
            'got {} labels for {} embedding rows'.format(label_values.shape[0], len(emb_df))
        )
    emb_df['label'] = label_values

    if save_path:
        os.makedirs(save_path, exist_ok=True)
    out_path = os.path.join(save_path, filename)
    emb_df.to_csv(out_path, sep=sep, index=False)
    return out_path

Loading each dataset, extracting GCC embeddings, and saving them with their labels

In [35]:
# imdb-binary: encode all graphs with the pretrained encoder and export the
# embeddings together with the graph labels.
imdb_binary_train_dataset = GraphClassificationDataset(
    dataset="imdb-binary",
    positional_embedding_size=args.positional_embedding_size,
    restart_prob=args.restart_prob,
    subgraph_size=args.subgraph_size,
    rw_hops=args.rw_hops,
)
# Batch size equals the dataset size, so the loader yields a single batch.
args.batch_size = len(imdb_binary_train_dataset)
imdb_binary_train_loader = torch.utils.data.DataLoader(
    imdb_binary_train_dataset,
    args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    collate_fn=batcher(),
)
imdb_binary_emb = test_moco(imdb_binary_train_loader, model, args)
imdb_binary_labels = data_util.create_graph_classification_dataset("imdb-binary").graph_labels
create_dataframe_save_to_csv(
    embeddings=imdb_binary_emb,
    labels=imdb_binary_labels,
    dataset_name="imdbb",
    model_name="GCC",
    save_path=save_results_to,
)

In [34]:
# imdb-multi: encode all graphs with the pretrained encoder and export the
# embeddings together with the graph labels.
imdb_multi_train_dataset = GraphClassificationDataset(
    dataset="imdb-multi",
    positional_embedding_size=args.positional_embedding_size,
    restart_prob=args.restart_prob,
    subgraph_size=args.subgraph_size,
    rw_hops=args.rw_hops,
)
# Batch size equals the dataset size, so the loader yields a single batch.
args.batch_size = len(imdb_multi_train_dataset)
imdb_multi_train_loader = torch.utils.data.DataLoader(
    imdb_multi_train_dataset,
    args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    collate_fn=batcher(),
)
imdb_multi_emb = test_moco(imdb_multi_train_loader, model, args)
imdb_multi_labels = data_util.create_graph_classification_dataset("imdb-multi").graph_labels
create_dataframe_save_to_csv(
    embeddings=imdb_multi_emb,
    labels=imdb_multi_labels,
    dataset_name="imdbm",
    model_name="GCC",
    save_path=save_results_to,
)

Downloading /home/nedooshki/.dgl/tu_IMDB-MULTI.zip from https://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/IMDB-MULTI.zip...
Extracting file to /home/nedooshki/.dgl/tu_IMDB-MULTI


In [6]:
# bbbp: encode all graphs with the pretrained encoder and export the
# embeddings together with the graph labels.
bbbp_train_dataset = GraphClassificationDataset(
    dataset="bbbp",
    positional_embedding_size=args.positional_embedding_size,
    restart_prob=args.restart_prob,
    subgraph_size=args.subgraph_size,
    rw_hops=args.rw_hops,
)
# Batch size equals the dataset size, so the loader yields a single batch.
args.batch_size = len(bbbp_train_dataset)
bbbp_train_loader = torch.utils.data.DataLoader(
    bbbp_train_dataset,
    args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    collate_fn=batcher(),
)
bbbp_emb = test_moco(bbbp_train_loader, model, args)
bbbp_labels = data_util.create_graph_classification_dataset("bbbp").graph_labels
create_dataframe_save_to_csv(
    embeddings=bbbp_emb,
    labels=bbbp_labels,
    dataset_name="bbbp",
    model_name="GCC",
    save_path=save_results_to,
)

In [36]:
# bace: encode all graphs with the pretrained encoder and export the
# embeddings together with the graph labels.
bace_train_dataset = GraphClassificationDataset(
    dataset="bace",
    positional_embedding_size=args.positional_embedding_size,
    restart_prob=args.restart_prob,
    subgraph_size=args.subgraph_size,
    rw_hops=args.rw_hops,
)
# Batch size equals the dataset size, so the loader yields a single batch.
args.batch_size = len(bace_train_dataset)
bace_train_loader = torch.utils.data.DataLoader(
    bace_train_dataset,
    args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    collate_fn=batcher(),
)
bace_emb = test_moco(bace_train_loader, model, args)
bace_labels = data_util.create_graph_classification_dataset("bace").graph_labels
create_dataframe_save_to_csv(
    embeddings=bace_emb,
    labels=bace_labels,
    dataset_name="bace",
    model_name="GCC",
    save_path=save_results_to,
)

Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv
Processing...
Done!
