In [1]:
#!pip install dgl dglgo -f https://data.dgl.ai/wheels/repo.html
import dgl
import torch as th
import torch
import numpy as np
from dgl import save_graphs, load_graphs
import torch.nn as nn
import dgl.nn as dglnn
import torch
from sklearn.metrics import roc_auc_score
import torch.nn.functional as F
import dgl.function as fn
#from torchmetrics.classification import BinaryAUROC
import torch.nn as nn
from dgl.dataloading.negative_sampler import _BaseNegativeSampler
from dgl import backend as b


In [2]:
glist, label_dict = load_graphs("./graphs/hetero_graphs_primary_word2vec_w_zeros.bin")
train_hetero_graph = glist[0]
val_hetero_graph = glist[1]
test_hetero_graph = glist[2]

In [3]:
class PerSourceUniformCustom(_BaseNegativeSampler):

    def __init__(self, k):
        self.k = k

    def _generate(self, g, eids, canonical_etype):
        unique_authors = torch.unique(g.edges(etype = "authored")[1])
        #print(len(unique_authors))
        _, _, vtype = canonical_etype
        shape = b.shape(eids)
        dtype = b.dtype(eids)
        ctx = b.context(eids)
        shape = (shape[0] * self.k,)
        src, _ = g.find_edges(eids, etype=canonical_etype)
        src = b.repeat(src, self.k, 0)
        dst_indexes = th.randint(0, len(unique_authors), shape, dtype=dtype, device=ctx)
        dst = unique_authors[dst_indexes]
        return src, dst

In [4]:
def construct_negative_graph(graph, k, etype):
    utype, edge_type, vtype = etype
    src, dst = graph.edges(etype=etype)
    eids = graph.edge_ids(src, dst, etype=edge_type)
    #eids = torch.unique(train_hetero_graph.edges(etype="authored")[0])
    neg_sampler = PerSourceUniformCustom(k)
    #neg_sampler = dgl.dataloading.negative_sampler.PerSourceUniform(k)
    neg_src, neg_dst = neg_sampler(graph, {edge_type: eids})[etype]
    return dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})

In [80]:
class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, num_classes_papers, num_classes_authors):

        super().__init__()

        self.conv1 = dglnn.HeteroGraphConv({
            "writes": dglnn.GraphConv(2731, 512),
        }, aggregate="stack")

        self.conv2 = dglnn.HeteroGraphConv({
            "cites": dglnn.GraphConv(2731, 512),
        })
        self.dropout = nn.Dropout(p=0.2)
        self.linear = nn.Linear(512, 2731)

    def forward(self, graph, inputs):
        classes_paper = inputs["paper"]
        classes_author = inputs["author"]

        #h_writes = self.conv1(graph["writes"], {"paper": classes_paper, "author": classes_author})["paper"]
        h_cites = self.conv2(graph["cites"], {"paper": classes_paper, "author": classes_author})["paper"]
        #h_cites = self.dropout(h_cites)
        h_cites = F.relu(h_cites)

        #h_cites = torch.cat([classes_paper, h_cites], dim=1)
        """
        h_cites = self.conv2(graph["cites"], {"paper": classes_paper, "author": classes_author})["paper"].flatten(1)
        h_cites = self.dropout(h_cites)
        h_cites = F.relu(h_cites)
        h_cites = self.linear(h_cites)
        """
        h_cites = self.linear(h_cites)
        return {"paper": h_cites, "author": classes_author}
    
class HeteroDotProductPredictor(nn.Module):
    def forward(self, graph, h, etype):
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            return graph.edges[etype].data['score']
        
class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, rel_names, num_classes_paper, num_classes_author):
        super().__init__()
        # Encoder
        self.sage = RGCN(in_features, hidden_features, out_features, num_classes_paper, num_classes_author)
        # Decoder
        self.pred = HeteroDotProductPredictor()
    def forward(self, g, neg_g, x, etype):
        h = self.sage(g, x)
        return h

    def scores(self, g, neg_g, x, etype):
      h = self(g, neg_g, x, etype)
      return self.pred(g, h, etype), self.pred(neg_g, h, etype)

In [81]:
def compute_auc(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score]).squeeze(1).numpy()
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
    ).numpy()
    return roc_auc_score(labels, scores)

def compute_loss(pos_score, neg_score):
    # Margin loss
    n_edges = pos_score.shape[0]
    return (1 - pos_score + neg_score.view(n_edges, -1)).clamp(min=0).mean()

def compute_loss_logits(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score]).to("cuda")
    
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
    ).to(device)
    return F.binary_cross_entropy_with_logits(scores.squeeze(1), labels)

def accuracy(logits, graph):
  with torch.no_grad():
    all_papers = torch.unique(graph.edges(etype="authored")[0])
    src, dst = graph.edges(etype="authored")
    unique_authors = torch.unique(dst)
    tst = 0
    author_logits = logits["author"][unique_authors]

    for idx, index_paper in enumerate(all_papers):
      #if idx % 25000 == 0:
      #  print(f"{idx}/{len(all_papers)}")
      current_logits = logits["paper"][index_paper]
      dot_product_all = torch.sum(current_logits * author_logits, dim=-1)

      #not_in_authors = torch.where(torch.logical_not(torch.isin(torch.arange(len(dot_product_all.to(device))).to(device), unique_authors.to(device))))
      #dot_product_all[not_in_authors] = -10

      max = torch.argmax(dot_product_all)
      max = unique_authors[max].item()
      filter_acc = src == index_paper
      if max in dst[filter_acc]:
        tst += 1
    return tst/len(all_papers)

In [82]:
# Load data into GPU memory
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [83]:
# Sum Pooling Genders
train_hetero_graph.nodes["gender"].data["h"] = train_hetero_graph.nodes('gender').type(torch.float32)
train_hetero_graph.nodes["author"].data["h"] = th.zeros(len(train_hetero_graph.nodes("author")))
train_hetero_graph["gendered"].update_all(fn.copy_u('h', 'm'), fn.max('m', 'h'))
author_genders = train_hetero_graph.nodes["author"].data["h"].view(-1,1)

In [84]:
# Sum Pooling Countries
country_feats = F.one_hot(train_hetero_graph.nodes('country')).type(torch.float32)
train_hetero_graph.nodes["country"].data["h"] = country_feats
train_hetero_graph.nodes["affiliation"].data["h"] = th.zeros((len(train_hetero_graph.nodes("affiliation")), country_feats.shape[0]))
train_hetero_graph["contains"].update_all(fn.copy_u('h', 'm'), fn.max('m', 'h'))
affiliation_countries = train_hetero_graph.nodes["affiliation"].data["h"]

In [85]:
# Sum Pooling Affiliations
affiliation_feats = F.one_hot(train_hetero_graph.nodes('affiliation')).type(torch.float32)
train_hetero_graph.nodes["affiliation"].data["h"] = torch.concat([affiliation_feats, affiliation_countries], dim=1)
train_hetero_graph.nodes["author"].data["h"] = th.zeros((len(train_hetero_graph.nodes("author")), country_feats.shape[0]))
train_hetero_graph["affiliated"].update_all(fn.copy_u('h', 'm'), fn.max('m', 'h'))
author_affiliation_country = train_hetero_graph.nodes["author"].data["h"]

In [93]:
author_feats = train_hetero_graph.nodes['author'].data['feature'].type(torch.float32)
author_feats = torch.concat([author_feats, author_genders, author_affiliation_country], dim=1)

In [94]:
#author_feats = train_hetero_graph.nodes['author'].data['feature'].type(torch.float32)

In [95]:
#author_feats = train_hetero_graph.nodes['author'].data['feature']
train_hetero_graph.nodes["author"].data["h"] = author_feats
train_hetero_graph.nodes["paper"].data["h"] = th.zeros((len(train_hetero_graph.nodes['paper'].data['feature']), author_feats.shape[0]))
train_hetero_graph["writes"].update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h'))
paper_feats = train_hetero_graph.nodes["paper"].data["h"].type(torch.float32)

In [96]:
#paper_feats = train_hetero_graph.nodes['paper'].data['feature']

In [97]:
"""
user_feats_train = train_hetero_graph.nodes['author'].data['feature']
item_feats_train = train_hetero_graph.nodes['paper'].data['feature']
node_features_train = {'author': user_feats_train.to(device), 'paper': item_feats_train.to(device)}
"""
#author_feats = train_hetero_graph.nodes['author'].data['feature'].type(torch.float32)
#author_feats = torch.concat([author_feats, author_genders, author_affiliation_country], dim=1)
#paper_feats = train_hetero_graph.nodes['paper'].data['feature']

node_features = {'paper': paper_feats.to(device), "author": author_feats.type(torch.float32).to(device)}

In [98]:
node_features["paper"].shape

torch.Size([36099, 2731])

In [99]:
k = 8
loss_training_epoch = []
loss_validation_epoch = []

auc_training_epoch = []
auc_validation_epoch = []

acc_training = []
acc_validation = []

pred = HeteroDotProductPredictor()

In [100]:
model = Model(256, 1024, 512, train_hetero_graph.etypes, len(train_hetero_graph.nodes("paper")), len(train_hetero_graph.nodes("author"))).to(device)

opt = torch.optim.Adam(model.parameters())

for epoch in range(5000000):
    negative_graph = construct_negative_graph(train_hetero_graph, k, ('paper', 'authored', 'author'))
    pos_score, neg_score = model.scores(train_hetero_graph.to(device), negative_graph.to(device), node_features, ('paper', 'authored', 'author'))
    loss = compute_loss_logits(pos_score.to(device), neg_score.to(device))
    opt.zero_grad()
    loss.backward()
    opt.step()
    if epoch % 50 == 0:
        with torch.no_grad():
          logits_train = model(train_hetero_graph.to(device), negative_graph.to(device), node_features, ('paper', 'authored', 'author'))
          acc_train = accuracy(logits_train, train_hetero_graph.to(device))
          auc_train = compute_auc(pos_score.cpu(), neg_score.cpu())
          loss_train = loss.item()

          loss_training_epoch.append(loss_train)
          auc_training_epoch.append(auc_train)

          logits_val = model(val_hetero_graph.to(device), "x", node_features, ('paper', 'authored', 'author'))
          acc_val = accuracy(logits_val, val_hetero_graph.to(device))

          negative_graph = construct_negative_graph(val_hetero_graph, k, ('paper', 'authored', 'author'))
          pos_score_eval, neg_score_eval = model.scores(val_hetero_graph.to(device), negative_graph.to(device), node_features, ('paper', 'authored', 'author'))
          loss_val = compute_loss_logits(pos_score_eval.to(device), neg_score.to(device)).item()
          auc_val = compute_auc(pos_score_eval.cpu(), neg_score_eval.cpu())

          loss_validation_epoch.append(loss_val)
          auc_validation_epoch.append(auc_val)

          acc_training.append(acc_train)
          acc_validation.append(acc_val)

          print(f"EPOCH: {epoch}; Loss: {loss_train}, AUC: {auc_train}, Acc Train: {acc_train}; Loss {loss_val}, AUC: {auc_val} Acc Evaluation: {acc_val}")

    """
    if epoch % 50 == 0:

      with torch.no_grad():
        total_list  = loss_training_epoch + loss_validation_epoch + auc_training_epoch + auc_validation_epoch + acc_training + acc_validation
        with open(f'./drive/MyDrive/Bachelor_thesis/models/author_metrics_{epoch}_final.txt', 'w') as f:
          for item in total_list:
              f.write(str(item) + '\n')
        torch.save(model, f"./drive/MyDrive/Bachelor_thesis/models/author_{epoch}_final.pt")
    """

EPOCH: 0; Loss: 0.7121787667274475, AUC: 0.5082325789097338, Acc Train: 0.0012583892617449664; Loss 0.721087634563446, AUC: 0.5047201388888889 Acc Evaluation: 0.0033333333333333335
EPOCH: 50; Loss: 0.34297725558280945, AUC: 0.6056159975550087, Acc Train: 0.01552013422818792; Loss 0.16357916593551636, AUC: 0.5434965277777777 Acc Evaluation: 0.01
EPOCH: 100; Loss: 0.28469377756118774, AUC: 0.8340620126204901, Acc Train: 0.06501677852348993; Loss 0.13837355375289917, AUC: 0.7309215277777777 Acc Evaluation: 0.02666666666666667
EPOCH: 150; Loss: 0.21881690621376038, AUC: 0.9288082594463651, Acc Train: 0.1950503355704698; Loss 0.11247862875461578, AUC: 0.8153958333333334 Acc Evaluation: 0.14666666666666667
EPOCH: 200; Loss: 0.16750603914260864, AUC: 0.9665566382496847, Acc Train: 0.30956375838926176; Loss 0.09374386072158813, AUC: 0.8464368055555554 Acc Evaluation: 0.20666666666666667
EPOCH: 250; Loss: 0.13086068630218506, AUC: 0.9826726467135489, Acc Train: 0.39093959731543626; Loss 0.08320

KeyboardInterrupt: 

In [None]:
 torch.no_grad():
  torch.save(model, f"./drive/MyDrive/Bachelor_thesis/models/author.pt")