In [1]:
import time

import torch
import torch.nn.functional as F
import numpy as np
from dgl import DGLGraph


from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

from models import GCN
from datasets import Cora, CiteseerM10, Dblp
from text_transformers import TFIDF


Using backend: pytorch


In [2]:
def get_masks(n, main_ids, main_labels, test_ratio, val_ratio, seed=1):
    """Build train/validation/test indicator arrays over ``n`` nodes.

    ``main_ids`` is split twice with stratification on the labels: first a
    ``test_ratio`` share is set aside for testing, then a ``val_ratio`` share
    of the *remaining* ids is set aside for validation. Whatever is left is
    the training set. Both splits use the same ``seed``.

    Args:
        n: length of every returned array.
        main_ids: ids to distribute; they are used as indices into the
            length-``n`` arrays.
        main_labels: labels aligned with ``main_ids``, used for stratification.
        test_ratio: ``test_size`` of the first split.
        val_ratio: ``test_size`` of the second split (relative to the
            non-test remainder, not to the whole set).
        seed: ``random_state`` passed to both splits.

    Returns:
        Tuple ``(train_mask, val_mask, test_mask)`` of float arrays holding 1
        at selected positions and 0 elsewhere. Positions not listed in
        ``main_ids`` stay 0 in all three.
    """
    # Stage 1: hold out the test ids.
    dev_ids, test_ids, dev_labels, _ = train_test_split(
        main_ids,
        main_labels,
        test_size=test_ratio,
        stratify=main_labels,
        random_state=seed,
    )

    # Stage 2: divide the remainder into training and validation ids.
    train_ids, val_ids, _, _ = train_test_split(
        dev_ids,
        dev_labels,
        test_size=val_ratio,
        stratify=dev_labels,
        random_state=seed,
    )

    masks = []
    for selected in (train_ids, val_ids, test_ids):
        indicator = np.zeros(n)
        indicator[selected] = 1
        masks.append(indicator)

    return tuple(masks)


def evaluate(model, features, labels, mask):
    """Return the micro-averaged F1 score of ``model`` on the masked rows.

    The model is switched to evaluation mode and run once on ``features``
    with gradient tracking disabled; the predicted class of each row is the
    index of its largest logit.

    Args:
        model: callable module; invoked as ``model(features)`` and expected
            to return a 2-D tensor of per-class scores (one row per node).
        features: input passed unchanged to ``model``.
        labels: ground-truth class ids, indexable by ``mask``.
        mask: index/boolean selector applied to both logits and labels.

    Returns:
        The value of ``f1_score(..., average='micro')`` for the selected rows.

    Note:
        The model is left in evaluation mode; callers that continue training
        must call ``model.train()`` again.
    """
    model.eval()
    with torch.no_grad():
        logits = model(features)[mask]
        _, predicted = torch.max(logits, dim=1)

    # Hand sklearn plain host arrays: implicit conversion of torch tensors
    # raises when they live on a non-CPU device.
    y_true = torch.as_tensor(labels[mask]).detach().cpu().numpy()
    y_pred = predicted.cpu().numpy()
    return f1_score(y_true, y_pred, average='micro')


def train_gcn(dataset,
              test_ratio=0.5,
              val_ratio=0.2,
              seed=1,
              n_hidden=16,
              n_epochs=200,
              lr=1e-2,
              weight_decay=5e-4,
              dropout=0.5,
              verbose=True):
    """Train a GCN on ``dataset`` and return its F1 score on the test split.

    Args:
        dataset: object whose ``get_data()`` returns a mapping with the keys
            'features', 'labels', 'ids', 'main_ids', 'main_labels', 'graph'
            and 'n_classes'.
        test_ratio: forwarded to ``get_masks``.
        val_ratio: forwarded to ``get_masks``.
        seed: forwarded to ``get_masks``. It only controls the data split;
            torch's random state is not seeded here.
        n_hidden: forwarded to the ``GCN`` constructor.
        n_epochs: number of full-graph optimisation steps.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        dropout: forwarded to the ``GCN`` constructor.
        verbose: when true, print one progress line per epoch (including the
            validation F1) and the final test F1. When false, the per-epoch
            validation pass is skipped because its result would be unused.

    Returns:
        The score returned by ``evaluate`` on the test mask after the last
        epoch.
    """
    data = dataset.get_data()
    features = torch.FloatTensor(data['features'])
    labels = torch.LongTensor(data['labels'])
    n = len(data['ids'])
    train_mask, val_mask, test_mask = get_masks(n,
                                                data['main_ids'],
                                                data['main_labels'],
                                                test_ratio=test_ratio,
                                                val_ratio=val_ratio,
                                                seed=seed)

    train_mask = torch.BoolTensor(train_mask)
    val_mask = torch.BoolTensor(val_mask)
    test_mask = torch.BoolTensor(test_mask)

    g = DGLGraph(data['graph'])
    n_edges = g.number_of_edges()

    # Per-node factor in_degree ** -0.5; zero-degree nodes would give inf,
    # so those entries are reset to 0.
    degs = g.in_degrees().float()
    norm = torch.pow(degs, -0.5)
    norm[torch.isinf(norm)] = 0

    # Stored on the graph as a column vector (presumably read by the GCN
    # layers -- confirm in models.GCN).
    g.ndata['norm'] = norm.unsqueeze(1)

    in_feats = features.shape[1]
    # + 1 for unknown class
    n_classes = data['n_classes'] + 1
    model = GCN(g,
                in_feats=in_feats,
                n_hidden=n_hidden,
                n_classes=n_classes,
                activation=F.relu,
                dropout=dropout)

    loss_fcn = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)

    # Wall-clock time of each training step; the first three epochs are
    # excluded from the statistics.
    dur = []
    for epoch in range(n_epochs):
        model.train()
        timed = epoch >= 3
        if timed:
            t0 = time.time()

        # Forward pass over the whole graph; the loss only sees train nodes.
        logits = model(features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if timed:
            dur.append(time.time() - t0)

        if verbose:
            f1 = evaluate(model, features, labels, val_mask)
            # np.mean([]) emits RuntimeWarnings; use an explicit NaN until
            # the first timed epoch so the printed text stays the same.
            mean_dur = np.mean(dur) if dur else float('nan')
            print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | F1 {:.4f} | "
                  "ETputs(KTEPS) {:.2f}".format(epoch, mean_dur, loss.item(),
                                                f1, n_edges / mean_dur / 1000))

    f1 = evaluate(model, features, labels, test_mask)

    if verbose:
        print()
        print("Test F1 {:.2}".format(f1))

    return f1

In [3]:
# Load the Cora dataset and run its features through the TFIDF transformer.
dataset = Cora()
transformer = TFIDF()
dataset.transform_features(transformer)
# Train with the default hyper-parameters; the returned test F1 is the last
# expression of the cell, so the notebook displays it below the epoch log.
train_gcn(dataset, verbose=True)

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


Epoch 00000 | Time(s) nan | Loss 2.0774 | F1 0.3727 | ETputs(KTEPS) nan
Epoch 00001 | Time(s) nan | Loss 2.0267 | F1 0.3506 | ETputs(KTEPS) nan
Epoch 00002 | Time(s) nan | Loss 1.9662 | F1 0.3469 | ETputs(KTEPS) nan
Epoch 00003 | Time(s) 0.0401 | Loss 1.9052 | F1 0.3395 | ETputs(KTEPS) 262.93
Epoch 00004 | Time(s) 0.0369 | Loss 1.8542 | F1 0.3395 | ETputs(KTEPS) 285.84
Epoch 00005 | Time(s) 0.0364 | Loss 1.7922 | F1 0.3395 | ETputs(KTEPS) 289.92
Epoch 00006 | Time(s) 0.0348 | Loss 1.7406 | F1 0.3469 | ETputs(KTEPS) 303.29
Epoch 00007 | Time(s) 0.0336 | Loss 1.6714 | F1 0.3469 | ETputs(KTEPS) 314.02
Epoch 00008 | Time(s) 0.0328 | Loss 1.6306 | F1 0.3506 | ETputs(KTEPS) 321.85
Epoch 00009 | Time(s) 0.0329 | Loss 1.5674 | F1 0.3579 | ETputs(KTEPS) 321.00
Epoch 00010 | Time(s) 0.0339 | Loss 1.5402 | F1 0.3616 | ETputs(KTEPS) 311.32
Epoch 00011 | Time(s) 0.0340 | Loss 1.5083 | F1 0.4022 | ETputs(KTEPS) 310.64
Epoch 00012 | Time(s) 0.0344 | Loss 1.4579 | F1 0.4539 | ETputs(KTEPS) 306.60
Epoc

Epoch 00108 | Time(s) 0.0300 | Loss 0.3468 | F1 0.8450 | ETputs(KTEPS) 352.27
Epoch 00109 | Time(s) 0.0299 | Loss 0.3532 | F1 0.8376 | ETputs(KTEPS) 352.60
Epoch 00110 | Time(s) 0.0300 | Loss 0.3306 | F1 0.8376 | ETputs(KTEPS) 352.33
Epoch 00111 | Time(s) 0.0300 | Loss 0.3184 | F1 0.8376 | ETputs(KTEPS) 351.70
Epoch 00112 | Time(s) 0.0300 | Loss 0.3076 | F1 0.8376 | ETputs(KTEPS) 351.30
Epoch 00113 | Time(s) 0.0301 | Loss 0.3142 | F1 0.8376 | ETputs(KTEPS) 350.13
Epoch 00114 | Time(s) 0.0301 | Loss 0.3289 | F1 0.8376 | ETputs(KTEPS) 350.16
Epoch 00115 | Time(s) 0.0301 | Loss 0.3236 | F1 0.8413 | ETputs(KTEPS) 350.19
Epoch 00116 | Time(s) 0.0303 | Loss 0.3301 | F1 0.8376 | ETputs(KTEPS) 348.83
Epoch 00117 | Time(s) 0.0303 | Loss 0.3250 | F1 0.8339 | ETputs(KTEPS) 348.45
Epoch 00118 | Time(s) 0.0303 | Loss 0.3270 | F1 0.8376 | ETputs(KTEPS) 348.68
Epoch 00119 | Time(s) 0.0303 | Loss 0.3038 | F1 0.8413 | ETputs(KTEPS) 348.72
Epoch 00120 | Time(s) 0.0303 | Loss 0.3096 | F1 0.8413 | ETputs(

0.8692762186115215