In [None]:
import torch
import dgl
from sklearn.metrics import f1_score
from model_hetero import HAN
from utils import load_data, EarlyStopping

In [None]:
def score(logits, labels):
    """Compute accuracy, micro-F1 and macro-F1 for a batch of predictions.

    Parameters
    ----------
    logits : torch.Tensor
        Per-class scores; the predicted class of each row is the index of
        its maximum along dimension 1.
    labels : torch.Tensor
        Ground-truth class index for each row of ``logits``.

    Returns
    -------
    tuple
        ``(accuracy, micro_f1, macro_f1)`` where accuracy is the fraction of
        rows whose predicted class equals the label, and the F1 values come
        from ``sklearn.metrics.f1_score``.
    """
    # torch.max over a dim returns (values, indices); only the indices matter.
    predicted = torch.max(logits, dim=1)[1].long().cpu().numpy()
    truth = labels.cpu().numpy()

    n_correct = (predicted == truth).sum()
    accuracy = n_correct / len(predicted)

    f1_by_average = {
        average: f1_score(truth, predicted, average=average)
        for average in ('micro', 'macro')
    }
    return accuracy, f1_by_average['micro'], f1_by_average['macro']

In [None]:
def evaluate(model, g, features, labels, mask, loss_func):
    """Run the model in eval mode and report loss and metrics on a masked subset.

    Parameters
    ----------
    model : callable module
        Invoked as ``model(g, features)``; switched to eval mode first.
    g : graph object passed straight through to the model.
    features : input features passed straight through to the model.
    labels : torch.Tensor
        Ground-truth labels, indexed with ``mask``.
    mask : index/boolean tensor selecting the rows to evaluate.
    loss_func : callable
        Invoked as ``loss_func(selected_logits, selected_labels)``.

    Returns
    -------
    tuple
        ``(loss, accuracy, micro_f1, macro_f1)``.
    """
    model.eval()
    # Forward pass only; no autograd graph is recorded.
    with torch.no_grad():
        logits = model(g, features)

    selected_logits = logits[mask]
    selected_labels = labels[mask]

    loss = loss_func(selected_logits, selected_labels)
    metrics = score(selected_logits, selected_labels)
    return (loss, *metrics)

In [None]:
def main(args):
    """Train a HAN model with neighbor-sampled mini-batches and print test metrics.

    Parameters
    ----------
    args : dict
        Run configuration. Keys read here: ``'dataset'``, ``'device'``,
        ``'hidden_units'``, ``'num_heads'``, ``'dropout'``, ``'patience'``,
        ``'lr'``, ``'weight_decay'`` and ``'num_epochs'``.

    Notes
    -----
    Per-batch training loss and F1 scores are printed during training; the
    test loss and F1 scores are printed once at the end. Nothing is returned.
    """
    # load_data is project-local; only the order of its return values is relied on.
    g, features, labels, num_classes, train_idx, val_idx, test_idx, train_mask, \
        val_mask, test_mask = load_data(args['dataset'])

    print('args dataset', args['dataset'])
    if hasattr(torch, 'BoolTensor'):
        train_mask = train_mask.bool()
        val_mask = val_mask.bool()
        test_mask = test_mask.bool()

    # Sample the same number of neighbors along every edge type, for every layer.
    fanout = 3
    num_layers = 3
    edge_types = [
        ('author', 'ai', 'institution'),
        ('institution', 'ia', 'author'),
        ('author', 'ap', 'paper'),
        ('paper', 'pa', 'author'),
        ('paper', 'pP', 'paper'),
        ('paper', 'Pp', 'paper'),
        ('paper', 'pf', 'field_of_study'),
        ('field_of_study', 'fp', 'paper'),
    ]
    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [{etype: fanout for etype in edge_types}] * num_layers)
    collator = dgl.dataloading.NodeCollator(g, train_idx, sampler)
    dataloader = torch.utils.data.DataLoader(
        collator.dataset, collate_fn=collator.collate,
        batch_size=1024, shuffle=True, drop_last=False, num_workers=4)

    # NOTE(review): features, labels and the masks are intentionally left where
    # load_data put them (the original code that moved them to args['device']
    # was disabled). The final evaluate() call below therefore mixes a graph on
    # args['device'] with tensors that were never moved -- confirm this is
    # intended when the device is not the CPU.

    #meta_paths=[['pa', 'ap'], ['pf', 'fp']]
    model = HAN(meta_paths=[['pa', 'ap'], ['pf', 'fp'], ['ai', 'ia'], ['pP', 'Pp']],
                in_size=features.shape[1],
                hidden_size=args['hidden_units'],
                out_size=num_classes,
                num_heads=args['num_heads'],
                dropout=args['dropout']).to(args['device'])
    # The collator above keeps its own reference to the graph as loaded; only
    # the copy used for the final evaluation is moved.
    g = g.to(args['device'])

    stopper = EarlyStopping(patience=args['patience'])
    loss_fcn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'],
                                 weight_decay=args['weight_decay'])

    for epoch in range(args['num_epochs']):
        model.train()
        for batch, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
            # The model lives on args['device'], so the sampled blocks (and the
            # features/labels they carry) must be moved there too.
            blocks = [b.to(args['device']) for b in blocks]
            # NOTE(review): assumes these lookups yield plain tensors; on a
            # block with several node types DGL may return a dict keyed by
            # node type -- confirm against the dataset and the HAN model.
            input_features = blocks[0].srcdata['features']
            output_labels = blocks[-1].dstdata['labels']

            output_predictions = model(blocks, input_features)
            # CrossEntropyLoss expects (input, target): predictions first.
            loss = loss_fcn(output_predictions, output_labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # score(logits, labels): predictions first here as well.
            _, train_micro_f1, train_macro_f1 = score(output_predictions, output_labels)

            print('Epoch {:d} | Batch {:d} | Train Loss {:.4f} | Train Micro f1 {:.4f} | Train Macro f1 {:.4f} | '
                  .format(epoch + 1, batch + 1, loss.item(), train_micro_f1, train_macro_f1))

        # TODO: validation and early stopping are disabled. Re-enabling them
        # needs a validation pass feeding stopper.step(val_loss, val_acc, model)
        # and a `break` when it signals a stop.

    # NOTE(review): stopper.step() is never called in this function, so whether
    # a checkpoint exists for load_checkpoint() depends on EarlyStopping's
    # implementation -- verify this does not fail on a missing checkpoint.
    stopper.load_checkpoint(model)
    test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(model, g, features, labels, test_mask, loss_fcn)
    print('Test loss {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 {:.4f}'.format(
        test_loss.item(), test_micro_f1, test_macro_f1))


In [1]:
if __name__ == '__main__':
    # Entry point: build the CLI options, let utils.setup turn them into the
    # full run configuration, and start training.
    import argparse

    from utils import setup

    parser = argparse.ArgumentParser('HAN')
    parser.add_argument('-s', '--seed', type=int, default=1,
                        help='Random seed')
    parser.add_argument('-ld', '--log-dir', type=str, default='results',
                        help='Dir for saving training results')
    parser.add_argument('--acmraw', action='store_true',
                        help='Use metapath coalescing with DGL\'s own dataset')
    parser.add_argument('--mag', action='store_true',
                        help='Use metapath coalescing with DGL\'s own dataset')

    # A Jupyter kernel is launched with '-f <connection file>' in sys.argv,
    # which parse_args() rejects with SystemExit. parse_known_args() returns
    # the recognised options and sets the unrecognised ones aside, so this cell
    # runs both inside a notebook and as a plain script.
    known_args, _ignored = parser.parse_known_args()
    args = vars(known_args)

    args = setup(args)

    main(args)

ModuleNotFoundError: No module named 'utils'