In [1]:
from data_loader import *
from data_util import *

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


# Example: loading as in `gcc`

In [2]:
import torch

# Example TUDataset

In [3]:
from dgl.data import TUDataset

In [4]:
dataset = TUDataset("REDDIT-BINARY")

In [5]:
graph, _ = dataset[0]

In [6]:
dataset

Dataset("REDDIT-BINARY", num_graphs=2000, save_path=/Users/nmm/.dgl/REDDIT-BINARY)

## From GraphMAE

In [7]:
from collections import namedtuple, Counter

def load_graph_classification_dataset(dataset_name, deg4feat=False):
    dataset_name = dataset_name.upper()
    dataset = TUDataset(dataset_name)
    graph, _ = dataset[0]

    if "attr" not in graph.ndata:
        if "node_labels" in graph.ndata and not deg4feat:
            print("Use node label as node features")
            feature_dim = 0
            for g, _ in dataset:
                feature_dim = max(feature_dim, g.ndata["node_labels"].max().item())
            
            feature_dim += 1
            for g, l in dataset:
                node_label = g.ndata["node_labels"].view(-1)
                feat = F.one_hot(node_label, num_classes=feature_dim).float()
                g.ndata["attr"] = feat
        else:
            print("Using degree as node features")
            feature_dim = 0
            degrees = []
            for g, _ in dataset:
                feature_dim = max(feature_dim, g.in_degrees().max().item())
                degrees.extend(g.in_degrees().tolist())
            MAX_DEGREES = 400

            oversize = 0
            for d, n in Counter(degrees).items():
                if d > MAX_DEGREES:
                    oversize += n
            # print(f"N > {MAX_DEGREES}, #NUM: {oversize}, ratio: {oversize/sum(degrees):.8f}")
            feature_dim = min(feature_dim, MAX_DEGREES)

            feature_dim += 1
            for g, l in dataset:
                degrees = g.in_degrees()
                degrees[degrees > MAX_DEGREES] = MAX_DEGREES
                
                feat = F.one_hot(degrees, num_classes=feature_dim).float()
                g.ndata["attr"] = feat
    else:
        print("******** Use `attr` as node features ********")
        feature_dim = graph.ndata["attr"].shape[1]

    labels = torch.tensor([x[1] for x in dataset])
    
    num_classes = torch.max(labels).item() + 1
    dataset = [(g.remove_self_loop().add_self_loop(), y) for g, y in dataset]

    print(f"******** # Num Graphs: {len(dataset)}, # Num Feat: {feature_dim}, # Num Classes: {num_classes} ********")

    return dataset, (feature_dim, num_classes)

In [8]:
mae_data, _ = load_graph_classification_dataset("REDDIT-BINARY")

Using degree as node features
******** # Num Graphs: 2000, # Num Feat: 401, # Num Classes: 2 ********


In [10]:
mae_data[1]#[1]

(Graph(num_nodes=717, num_edges=4041,
       ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(401,), dtype=torch.float32)}
       edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}),
 tensor([1]))

# `GraphDataset` from `graph-tool` library

In [11]:
pretext_keys = ["ceo_club", "bison", "copenhagen/calls"]
fake_labels = ["label2", "label1", "label2"]

In [12]:
gtdata = GraphDataset(dgl_graphs=graphtoolkeys2dglgraphs(pretext_keys)
                      , graphs_labels=fake_labels
                      , verbosity=True
                     )

🦦 converting 3 graphs from graph-tool library to DGL format.


100%|█████████████████████████████████████████████| 3/3 [00:00<00:00,  6.75it/s]

┌------------------------------------------------------------┐
|                        GraphDataset                        |
├------------------------------------------------------------┤
|number of graphs                       |                   3|
|nodes — tot                            |                 602|
|nodes — mean                           |  200.66666666666666|
|nodes — median                         |                40.0|
|nodes — min                            |                  26|
|nodes — max                            |                 536|
|edges — tot                            |                1428|
|edges — mean                           |               476.0|
|edges — median                         |               314.0|
|edges — min                            |                 190|
|edges — max                            |                 924|
|number of labels                       |                   2|
|   - label2                            |          2 (6




# `GCCDataset`

In [13]:
pretext_keys = ["ceo_club", "bison", "copenhagen/calls"]
fake_labels = ["label2", "label1", "label2"]

In [14]:
test_gcc = GCCDataset(dgl_graphs=graphtoolkeys2dglgraphs(pretext_keys)
                      , graphs_labels=fake_labels
                      , verbosity=True
                     )

🦦 converting 3 graphs from graph-tool library to DGL format.


100%|███████████████████████████████████████████| 3/3 [00:00<00:00, 1258.92it/s]

┌------------------------------------------------------------┐
|                         GCCDataset                         |
├------------------------------------------------------------┤
|number of graphs                       |                   3|
|nodes — tot                            |                 602|
|nodes — mean                           |  200.66666666666666|
|nodes — median                         |                40.0|
|nodes — min                            |                  26|
|nodes — max                            |                 536|
|edges — tot                            |                1428|
|edges — mean                           |               476.0|
|edges — median                         |               314.0|
|edges — min                            |                 190|
|edges — max                            |                 924|
|number of labels                       |                   2|
|   - label2                            |          2 (6




In [15]:
test_gcc.display_statistics()

┌------------------------------------------------------------┐
|                         GCCDataset                         |
├------------------------------------------------------------┤
|number of graphs                       |                   3|
|nodes — tot                            |                 602|
|nodes — mean                           |  200.66666666666666|
|nodes — median                         |                40.0|
|nodes — min                            |                  26|
|nodes — max                            |                 536|
|edges — tot                            |                1428|
|edges — mean                           |               476.0|
|edges — median                         |               314.0|
|edges — min                            |                 190|
|edges — max                            |                 924|
|number of labels                       |                   2|
|   - label2                            |          2 (6

## GCC dataloader

In [16]:
torch.utils.data.DataLoader(test_gcc)

<torch.utils.data.dataloader.DataLoader at 0x17e05b5b0>

In [17]:
def batcher():
    def batcher_dev(batch):
        graph_q, graph_k = zip(*batch)
        graph_q, graph_k = dgl.batch(graph_q), dgl.batch(graph_k)
        return graph_q, graph_k

    return batcher_dev

def worker_init_fn(worker_id):
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset
    dataset.graphs, _ = dgl.data.utils.load_graphs(
        dataset.dgl_graphs_file, dataset.jobs[worker_id]
    )
    dataset.length = sum([g.number_of_nodes() for g in dataset.graphs])
    np.random.seed(worker_info.seed % (2 ** 32))

In [166]:
train_loader = torch.utils.data.DataLoader(
    dataset=test_gcc,#gtdata,
    batch_size=16,
    collate_fn=labeled_batcher() if False else batcher(),
    shuffle=True if False else False,
    num_workers=0,
    worker_init_fn=None if False or False else worker_init_fn,
)

In [167]:
gcc_graph_loader = GraphDataLoader(
     dataset=test_gcc,#gtdata,
    batch_size=16,
    collate_fn=labeled_batcher() if False else batcher(),
    shuffle=True if False else False,
    num_workers=0,
    worker_init_fn=None if False or False else worker_init_fn,
)

In [168]:
for idx, batch in enumerate(train_loader):
    graph_q, graph_k = batch

In [171]:
idx, batch

(0,
 (Graph(num_nodes=602, num_edges=1428,
        ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'pos_undirected': Scheme(shape=(32,), dtype=torch.float32), 'seed': Scheme(shape=(), dtype=torch.int64)}
        edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}),
  Graph(num_nodes=602, num_edges=1428,
        ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'pos_undirected': Scheme(shape=(32,), dtype=torch.float32), 'seed': Scheme(shape=(), dtype=torch.int64)}
        edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})))

In [172]:
for idx, batch in enumerate(gcc_graph_loader):
    graph_q, graph_k = batch
    
idx, batch

(0,
 (Graph(num_nodes=602, num_edges=1428,
        ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'pos_undirected': Scheme(shape=(32,), dtype=torch.float32), 'seed': Scheme(shape=(), dtype=torch.int64)}
        edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}),
  Graph(num_nodes=602, num_edges=1428,
        ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'pos_undirected': Scheme(shape=(32,), dtype=torch.float32), 'seed': Scheme(shape=(), dtype=torch.int64)}
        edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})))

# `GraphMAE` dataset

In [173]:
pretext_keys = ["ceo_club", "bison", "copenhagen/calls"]
fake_labels = ["label2", "label1", "label2"]

In [174]:
test_graphmae = GraphMAEDataset(dgl_graphs=graphtoolkeys2dglgraphs(pretext_keys)
                                , graphs_labels=fake_labels
                                , verbosity=True
                               )

🦦 converting 3 graphs from graph-tool library to DGL format.


100%|████████████████████████████████████████████| 3/3 [00:00<00:00, 225.95it/s]

┌-----------------------------------------------------------┐
|                      GraphMAEDataset                      |
├-----------------------------------------------------------┤
|number of graphs                       |                  3|
|nodes — tot                            |                602|
|nodes — mean                           | 200.66666666666666|
|nodes — median                         |               40.0|
|nodes — min                            |                 26|
|nodes — max                            |                536|
|edges — tot                            |               1428|
|edges — mean                           |              476.0|
|edges — median                         |              314.0|
|edges — min                            |                190|
|edges — max                            |                924|
|number of labels                       |                  2|
|   - label2                            |         2 (66.7 %)|
|   - la




In [179]:
test_graphmae.graphs[0]

Graph(num_nodes=40, num_edges=190,
      ndata_schemes={'attr': Scheme(shape=(22,), dtype=torch.float32)}
      edata_schemes={})

In [180]:
attrs = vars(test_graphmae)
attrs.keys()

dict_keys(['name', 'graphs', 'num_labels', 'graphs_labels', 'degrees', 'feature_dim', 'oversize'])

## GraphMAE dataloader

In [181]:
from torch.utils.data.sampler import SubsetRandomSampler

In [178]:
def collate_fn(batch):
    # graphs = [x[0].add_self_loop() for x in batch]
    graphs = [x[0] for x in batch]
    labels = [x[1] for x in batch]
    batch_g = dgl.batch(graphs)
    labels = torch.cat(labels, dim=0)
    return batch_g, labels

In [145]:
train_idx = torch.arange(len(test_graphmae))
train_sampler = SubsetRandomSampler(train_idx)

In [146]:
batch_size = 32

In [161]:
from dgl.dataloading.dataloader import GraphCollator

In [163]:
test_train_loader = torch.utils.data.DataLoader(
    test_graphmae
    , sampler=train_sampler
    , collate_fn = GraphCollator().collate#, collate_fn=collate_fn
    , batch_size=batch_size
    , pin_memory=True
)

mae_train_loader = GraphDataLoader(mae_data#test_graphmae
                               , sampler=SubsetRandomSampler(torch.arange(len(mae_data)))
                               , collate_fn=collate_fn
                               , batch_size=batch_size
                               , pin_memory=True
                              )

In [164]:
for idx, batch in enumerate(mae_train_loader):
    #print(batch)
    batch = batch
    
batch, (idx+1)*batch_size

((Graph(num_nodes=8789, num_edges=50353,
        ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(401,), dtype=torch.float32)}
        edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}),
  tensor([1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0])),
 2016)

In [165]:
for idx, batch in enumerate(test_train_loader):
    #print(batch)
    batch = batch
    
batch, (idx+1)*batch_size

([Graph(num_nodes=602, num_edges=1428,
        ndata_schemes={'attr': Scheme(shape=(22,), dtype=torch.float32)}
        edata_schemes={}),
  ('label2', 'label1', 'label2')],
 32)

# 🧪 Integration experiments

## GCC

### Using pretrained models

In [186]:
import sys  
sys.path.insert(0, '../../dist_sim/')

from gcc.generate import *

In [187]:
PATH_MODEL = "../../dist_sim/gcc/pretrained/Pretrain_moco_True_dgl_gin_layer_5_lr_0.005_decay_1e-05_bsz_32_hid_64_samples_2000_nce_t_0.07_nce_k_16384_rw_hops_256_restart_prob_0.8_aug_1st_ft_False_deg_16_pos_32_momentum_0.999/current.pth"

In [188]:
checkpoint = torch.load(PATH_MODEL, map_location="cpu")
args = checkpoint["opt"]

# create model and optimizer
model = GraphEncoder(
    positional_embedding_size=args.positional_embedding_size,
    max_node_freq=args.max_node_freq,
    max_edge_freq=args.max_edge_freq,
    max_degree=args.max_degree,
    freq_embedding_size=args.freq_embedding_size,
    degree_embedding_size=args.degree_embedding_size,
    output_dim=args.hidden_size,
    node_hidden_dim=args.hidden_size,
    edge_hidden_dim=args.hidden_size,
    num_layers=args.num_layer,
    num_step_set2set=args.set2set_iter,
    num_layer_set2set=args.set2set_lstm_layer,
    gnn_model=args.model,
    norm=args.norm,
    degree_input=True,
)


model.load_state_dict(checkpoint["model"])

del checkpoint

In [191]:
parser = argparse.ArgumentParser("argument for training")
# fmt: off
parser.add_argument("--load-path", type=str, help="path to load model")
parser.add_argument("--dataset", type=str
                    , default="rdt-b"
                    , choices=["dgl", "wikipedia", "blogcatalog", "usa_airport", "brazil_airport", "europe_airport", "cora", "citeseer", "pubmed", "kdd", "icdm", "sigir", "cikm", "sigmod", "icde", "h-index-rand-1", "h-index-top-1", "h-index"] + GRAPH_CLASSIFICATION_DSETS)
parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.")

args_test = parser.parse_args(args=[])

args.gpu = 0
args.num_workers = 0
args.num_copies = 1
args.gpu, args.num_workers, args.num_copies

(0, 0, 1)

In [190]:
gcc_graph_loader = GraphDataLoader(
     dataset=test_gcc,#gtdata,
    batch_size=16,
    collate_fn=labeled_batcher() if False else batcher(),
    shuffle=True if False else False,
    num_workers=0,
    worker_init_fn=None if False or False else worker_init_fn,
)

In [196]:
%%time 
emb = test_moco(gcc_graph_loader, model, args)

1it [00:00,  4.29it/s]

CPU times: user 1.04 s, sys: 434 ms, total: 1.47 s
Wall time: 235 ms





In [197]:
emb

tensor([[-1.4525e-02,  1.0907e-01, -3.2119e-03, -1.2183e-01,  9.7231e-02,
         -6.3740e-02,  8.0004e-02, -2.0031e-01,  1.5975e-01, -2.3301e-01,
          8.9989e-02, -1.4166e-01, -3.2889e-02, -1.0339e-01, -3.1904e-01,
         -9.6499e-02,  1.3998e-01, -4.1098e-02, -3.7734e-02, -2.4100e-01,
          2.8044e-01, -1.3764e-02, -8.6480e-02,  3.2001e-03, -1.9461e-01,
         -1.0765e-01,  2.9378e-02, -1.4074e-01,  2.7304e-01,  7.3826e-02,
          1.7877e-01,  1.3722e-02,  6.5938e-02, -7.9772e-02,  1.5075e-02,
         -1.0509e-01, -9.4612e-02,  1.0878e-01,  1.3448e-01,  1.5198e-01,
         -3.0072e-02, -1.0517e-01,  1.4195e-01,  2.3302e-01,  3.2797e-02,
          9.0370e-02, -7.7570e-02, -5.1710e-02, -4.8910e-02, -1.6310e-02,
          9.4236e-03, -3.9813e-02, -3.9938e-02,  1.1781e-03, -3.1850e-02,
          2.7156e-01,  3.3454e-02,  1.4332e-01, -8.5373e-02,  1.1330e-01,
         -6.1122e-02,  6.7042e-02,  5.3721e-02,  1.0419e-01],
        [-1.3897e-01, -1.9041e-01,  1.8808e-02, -1

## GraphMAE

In [137]:
def pretrain(model, pooler, dataloaders, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob=True, logger=None):
    train_loader, eval_loader = dataloaders

    epoch_iter = tqdm(range(max_epoch))
    for epoch in epoch_iter:
        model.train()
        loss_list = []
        for batch in train_loader:
            batch_g, _ = batch
            batch_g = batch_g.to(device)

            feat = batch_g.ndata["attr"]
            model.train()
            loss, loss_dict = model(batch_g, feat)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_list.append(loss.item())
            if logger is not None:
                loss_dict["lr"] = get_current_lr(optimizer)
                logger.note(loss_dict, step=epoch)
        if scheduler is not None:
            scheduler.step()
        epoch_iter.set_description(f"Epoch {epoch} | train_loss: {np.mean(loss_list):.4f}")

    return model

In [138]:
# simulate the instantiation of GraphMAEDataste from a list of graphs.

tudata = TUDataset("REDDIT-BINARY")


tu_graphs = [tu[0] for tu in tudata]
tu_labs = [tu[1][0] for tu in tudata]

In [139]:
graphmae = GraphMAEDataset(dgl_graphs=tu_graphs
                           , graphs_labels=tu_labs
                           , verbosity=True
                          )

┌-----------------------------------------------------------┐
|                      GraphMAEDataset                      |
├-----------------------------------------------------------┤
|number of graphs                       |               2000|
|nodes — tot                            |             859254|
|nodes — mean                           |            429.627|
|nodes — median                         |              304.5|
|nodes — min                            |                  6|
|nodes — max                            |               3782|
|edges — tot                            |            3982032|
|edges — mean                           |           1991.016|
|edges — median                         |             1516.0|
|edges — min                            |                 16|
|edges — max                            |              16284|
|number of labels                       |               2000|
|   - 1                                 |        1000 (50 %)|
|   - 0 

In [140]:
batch_size=32

train_idx = torch.arange(len(graphmae))
train_sampler = SubsetRandomSampler(train_idx)

test_train_loader = GraphDataLoader(graphmae
                                    #, sampler=train_sampler
                                    #, collate_fn=collate_fn
                                    , batch_size=batch_size
                                    , pin_memory=True
                                   )

In [141]:
for idx, batch in enumerate(test_train_loader):
    #print(batch)
    batch = batch
    
batch, (idx+1)*batch_size

([Graph(num_nodes=1600, num_edges=7536,
        ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(401,), dtype=torch.float32)}
        edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}),
  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])],
 2016)