In [1]:
import os
from tqdm.auto import tqdm
import re
import pickle
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from parameters import parse_args
from graph_utils import get_graph_data
from data_generation_utils import get_kfold_lp_data
from transformers import AutoTokenizer

from nltk.tokenize import word_tokenize

import sys; sys.argv=['']; del sys

In [2]:
args = parse_args()
data_dir = args.data_dir
args.graphs_file = os.path.join(data_dir, args.graphs_file)


graph_data = get_graph_data(args.graphs_file)
label_map, super_type_map = graph_data['entities_encoder'], graph_data['super_types_encoder']
inverse_label_map = {v: k for k, v in label_map.items()}
inverse_super_type_map = {v: k for k, v in super_type_map.items()}


Masking graphs:   0%|          | 0/6219 [00:00<?, ?it/s]

Adding node strings to graphs:   0%|          | 0/6219 [00:00<?, ?it/s]

Masking graphs:   0%|          | 0/328 [00:00<?, ?it/s]

Adding node strings to graphs:   0%|          | 0/328 [00:00<?, ?it/s]

Getting node triples:   0%|          | 0/6219 [00:00<?, ?it/s]

Getting node triples:   0%|          | 0/328 [00:00<?, ?it/s]

Sample Train triples [('StateMachine', 'states SMState', ''), ('SMState', 'transitions SMTransition', '')]
Sample Test triples [('Workbench', 'things Thing, thoughts Thoughts, systemView System, functionProperties FunctionProperty', ''), ('RelatedTo', 'fromThing Thing', '')]
Total entities: 51290
Total super types: 7084
Sample Train triples [('StateMachine', 'states SMState', ''), ('SMState', 'transitions SMTransition', '')]
Sample Test triples [('Workbench', 'things Thing, thoughts Thoughts, systemView System, functionProperties FunctionProperty', ''), ('RelatedTo', 'fromThing Thing', '')]
Total train triples: 105784
Total test triples: 6839


In [3]:
label_map, super_type_map = graph_data['entities_encoder'], graph_data['super_types_encoder']
for i, data in enumerate(get_kfold_lp_data(graph_data)):
    break

Train graphs:  5597 Test graphs:  622 Unseen graphs:  328


In [4]:
train, test, unseen = data['train'], data['test'], data['unseen']

In [5]:
from models import UMLGPT

pth = 'models/super_PT_gpt2_s_pre_tok=bert-base-cased/super_PT_gpt2_s_pre_tok=bert-base-cased_best_model.pt'
model = UMLGPT.from_pretrained(pth)

In [6]:
from data_generation_utils import SPECIAL_TOKENS
from trainers import get_tokenizer

args.trainer = 'PT'
args.special_tokens = SPECIAL_TOKENS

tokenizer = get_tokenizer('bert-base-cased', args)

2023-12-18 22:16:30.915287: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Creating pretrained LM tokenizer...
Vocab size:  29008
Done!


In [7]:
train_graph = train[0]

In [8]:
from data_generation_utils import promptize_triple

promptize_node = lambda g, n: promptize_triple((n, g.nodes[n]['references'] if 'references' in g.nodes[n] else '', g.nodes[n]['super_types'] if 'super_types' in g.nodes[n] else ''))

In [9]:
[(e, promptize_node(train_graph, e[0]), promptize_node(train_graph, e[1])) for e in train_graph.edges() if train_graph.edges[e]['masked']]

[(('SMState', 'StateMachine'),
  '<s> <superType>  </superType> <entity> SMState </entity> <relations> transitions SMTransition </relations> </s>',
  '<s> <superType>  </superType> <entity> StateMachine </entity> <relations> states SMState </relations> </s>'),
 (('SMInstance', 'SMState'),
  '<s> <superType>  </superType> <entity> SMInstance </entity> <relations> stateMachine StateMachine target EObject transitionInstances SMTransitionInstance </relations> </s>',
  '<s> <superType>  </superType> <entity> SMState </entity> <relations> transitions SMTransition </relations> </s>')]

In [10]:
print([(e, tokenizer.encode(promptize_node(train_graph, e[0])), tokenizer.encode(promptize_node(train_graph, e[1]))) for e in train_graph.edges() if train_graph.edges[e]['masked']])

[(('SMState', 'StateMachine'), [101, 28998, 29002, 29003, 29004, 25345, 10237, 29005, 29006, 26829, 19293, 1942, 4047, 5053, 2116, 29007, 28999, 102], [101, 28998, 29002, 29003, 29004, 1426, 2107, 19226, 1673, 29005, 29006, 2231, 25345, 10237, 29007, 28999, 102]), (('SMInstance', 'SMState'), [101, 28998, 29002, 29003, 29004, 19293, 2240, 22399, 3923, 29005, 29006, 1352, 2107, 19226, 1673, 1426, 2107, 19226, 1673, 4010, 142, 2346, 24380, 6468, 2240, 22399, 3923, 1116, 19293, 1942, 4047, 5053, 2116, 2240, 22399, 3923, 29007, 28999, 102], [101, 28998, 29002, 29003, 29004, 25345, 10237, 29005, 29006, 26829, 19293, 1942, 4047, 5053, 2116, 29007, 28999, 102])]


In [27]:
from dgl.data import DGLDataset
import dgl
from tqdm.auto import tqdm

from data_generation_utils import get_encoding_size
from models import get_embedding


def get_pos_neg_graphs(nxg, tr=0.2):
    g = dgl.from_networkx(nxg, edge_attrs=['masked'])
    u, v = g.edges()
    test_mask = torch.where(g.edata['masked'])[0]
    train_mask = torch.where(~g.edata['masked'])[0]
    test_size = int(g.number_of_edges() * tr)
    test_pos_u, test_pos_v = u[test_mask], v[test_mask]
    train_pos_u, train_pos_v = u[train_mask], v[train_mask]

    # Find all negative edges and split them for training and testing
    adj = g.adjacency_matrix()
    adj_neg = 1 - adj.to_dense() - np.eye(g.number_of_nodes())
    neg_u, neg_v = np.where(adj_neg != 0)

    neg_eids = np.random.choice(len(neg_u), g.number_of_edges())
    test_neg_u, test_neg_v = neg_u[neg_eids[:test_size]], neg_v[neg_eids[:test_size]]
    train_neg_u, train_neg_v = neg_u[neg_eids[test_size:]], neg_v[neg_eids[test_size:]]

    train_g = dgl.remove_edges(g, test_mask)

    train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=g.number_of_nodes())
    train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=g.number_of_nodes())

    test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.number_of_nodes())
    test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.number_of_nodes())
    
    graphs = {
        'train_pos_g': train_pos_g,
        'train_neg_g': train_neg_g,
        'test_pos_g': test_pos_g,
        'test_neg_g': test_neg_g,
        'train_g': train_g
    }
    return graphs


class LinkPredictionDataset(DGLDataset):
    def __init__(self, graphs, tokenizer, model, test_size=0.2, raw_dir='datasets/LP', save_dir='datasets/LP'):
        self.raw_graphs = graphs
        self.tokenizer = tokenizer
        self.model = model
        self.test_size = test_size
        
        super().__init__(name='link_prediction', raw_dir=raw_dir, save_dir=save_dir)
        """
        Load dataset of graphs if exists, otherwise create it.
        """
        
        
    def __getitem__(self, idx):
        return self.graphs[idx]
    
    def __len__(self):
        return len(self.graphs)
    
    def process(self):
        self.graphs = self._prepare()

    def _prepare(self):
        prepared_graphs = [self._prepare_graph(g) for g in tqdm(self.raw_graphs, desc='Preparing graphs')]
        return prepared_graphs
    
    def _prepare_graph(self, g):

        node_strs = [promptize_node(g, n) for n in g.nodes()]
        max_token_length = get_encoding_size(node_strs, tokenizer)
        node_encodings = self.tokenizer(node_strs, padding=True, truncation=True, max_length=max_token_length, return_tensors='pt')
        node_embeddings = get_embedding(model, node_encodings)
        pos_neg_graphs = get_pos_neg_graphs(g, self.test_size)        
        
        dgl_graph = pos_neg_graphs['train_g']
        dgl_graph.ndata['h'] = node_embeddings

        return pos_neg_graphs

    
    def save(self):
        """Save list of DGLGraphs using DGL save_graphs."""
        print("Saving graphs to cache...")
        keys = ['train_pos_g', 'train_neg_g', 'test_pos_g', 'test_neg_g', 'train_g']
        graphs = {k: [g[k] for g in self.graphs] for k in keys}
        for k, v in graphs.items():
            dgl.save_graphs(os.path.join(self.save_dir, f'{self.name}_{k}.dgl'), v)
    
    
    def load(self):
        """Load list of DGLGraphs using DGL load_graphs."""
        print("Loading graphs from cache...")
        
        keys = ['train_pos_g', 'train_neg_g', 'test_pos_g', 'test_neg_g', 'train_g']
        k_graphs = {k: [] for k in keys}
        for k in keys:
            k_graphs[k] = dgl.load_graphs(os.path.join(self.save_dir, f'{self.name}_{k}.dgl'))[0]
        
        self.graphs = list()
        for i in range(len(k_graphs['train_g'])):
            self.graphs.append({k: v[i] for k, v in k_graphs.items()})
        
        print(f'Loaded {len(self.graphs)} graphs.')

        
    def has_cache(self):
        return os.path.exists(os.path.join(self.save_dir, f'{self.name}_train_g.dgl'))

In [28]:
dataset = LinkPredictionDataset(train[:4], tokenizer, model)

Preparing graphs:   0%|          | 0/4 [00:00<?, ?it/s]

Encoding size:  38
Encoding size:  19
Encoding size:  86
Encoding size:  44
Saving graphs to cache...


In [29]:
from dgl.dataloading import GraphDataLoader

def collate_graphs(graphs):
    collated_graph = {k: list() for k in graphs[0].keys()}
    for g in graphs:
        for k, v in g.items():
            collated_graph[k].append(v)
    
    for k, v in collated_graph.items():
        collated_graph[k] = dgl.batch(v)
    return collated_graph

loader = GraphDataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_graphs)

In [30]:
for batch in loader:
    print(batch)
    break

{'train_pos_g': Graph(num_nodes=22, num_edges=36,
      ndata_schemes={}
      edata_schemes={}), 'train_neg_g': Graph(num_nodes=22, num_edges=36,
      ndata_schemes={}
      edata_schemes={}), 'test_pos_g': Graph(num_nodes=22, num_edges=8,
      ndata_schemes={}
      edata_schemes={}), 'test_neg_g': Graph(num_nodes=22, num_edges=8,
      ndata_schemes={}
      edata_schemes={}), 'train_g': Graph(num_nodes=22, num_edges=36,
      ndata_schemes={'h': Scheme(shape=(128,), dtype=torch.float32)}
      edata_schemes={'masked': Scheme(shape=(), dtype=torch.bool)})}


In [38]:
import itertools
from sklearn.metrics import roc_auc_score

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



def compute_loss(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score])
    labels = torch.cat([torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])])
    return torch.nn.BCEWithLogitsLoss()(scores.float(), labels.float())

def compute_auc(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score]).detach().numpy()
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]).detach().numpy()
    return roc_auc_score(labels, scores)



class GNNLinkPredictionTrainer:
    def __init__(self, model, predictor, args) -> None:
        self.model = model
        self.predictor = predictor
        self.model.to(device)
        self.optimizer = torch.optim.Adam(itertools.chain(model.parameters(), predictor.parameters()), lr=args.lr)

        
        self.edge2index = lambda g: torch.stack(list(g.edges())).contiguous()
        self.args = args
        print("GNN Trainer initialized.")

    def train(self, dataloader):
        self.model.train()
        self.predictor.train()

        epoch_loss, epoch_acc = 0, 0
        for batch in dataloader:
            self.optimizer.zero_grad()
            self.model.zero_grad()
            self.predictor.zero_grad()
            
            h = self.get_logits(batch['train_g'])

            pos_score = self.predictor(batch['train_pos_g'], h)
            neg_score = self.predictor(batch['train_neg_g'], h)
            loss = compute_loss(pos_score, neg_score)

            loss.backward()
            self.optimizer.step()

            epoch_loss += loss.item()
            epoch_acc += compute_auc(pos_score, neg_score)

        epoch_loss /= len(dataloader)
        epoch_acc /= len(dataloader)
        print(f"Epoch Train Loss: {epoch_loss} and Train Accuracy: {epoch_acc}")
        return epoch_loss, epoch_acc
    

    def test(self, dataloader):
        self.model.eval()
        self.predictor.eval()
        with torch.no_grad():
            epoch_loss, epoch_acc = 0, 0
            for batch in dataloader:            
                h = self.get_logits(batch['train_g'])

                pos_score = self.predictor(batch['test_pos_g'], h)
                neg_score = self.predictor(batch['test_neg_g'], h)
                loss = compute_loss(pos_score, neg_score)

                epoch_loss += loss.item()
                epoch_acc += compute_auc(pos_score, neg_score)

            epoch_loss /= len(dataloader)
            epoch_acc /= len(dataloader)
            print(f"Epoch Test Loss: {epoch_loss} and Test Accuracy: {epoch_acc}")
            return epoch_loss, epoch_acc


    def get_logits(self, g):
        edge_index = self.edge2index(g).to(device)
        x = g.ndata['h'].float()
        h = self.model(x, edge_index)
        return h


    def get_prediction(self, h, g):
        edge_index = self.edge2index(g).to(device)
        out = self.predictor(h, edge_index)
        return out


    def run_epochs(self, dataloader, num_epochs):
        max_val_acc, max_train_acc = 0, 0
        outputs = list()
        for epoch in tqdm(range(num_epochs), desc="Epochs"):
        # for epoch in range(num_epochs):
            train_loss, train_acc = self.train(dataloader)
            
            if epoch % 10 == 0:
                print(f"Epoch {epoch} Train Loss: {train_loss}")
            
            test_loss, test_acc = self.test(dataloader)

            if test_acc > max_val_acc:
                max_val_acc = test_acc
                max_train_acc = train_acc
                outputs.append({
                    'epoch': epoch,
                    'train_loss': train_loss,
                    'test_loss': test_loss,
                    'test_acc': test_acc
                })

            
        
        print(f"Max Test Accuracy: {max_val_acc}")
        print(f"Max Train Accuracy: {max_train_acc}")
        max_output = max(outputs, key=lambda x: x['test_acc'])
        return max_output


In [32]:
from models import GNNModel, MLPPredictor

gnn_model = GNNModel(
    model_name='SAGEConv', 
    input_dim=128, 
    hidden_dim=256, 
    out_dim=256,
    num_layers=2, 
    residual=True,
)

predictor = MLPPredictor(
    h_feats=256,
    num_layers=2,
)

In [33]:
lp_trainer = GNNLinkPredictionTrainer(gnn_model, predictor, args)

GNN Trainer initialized.


In [40]:
lp_trainer.run_epochs(loader, 5)

Epochs:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch Train Loss: 0.5127556025981903 and Train Accuracy: 0.894345137293505
Epoch 0 Train Loss: 0.5127556025981903
Epoch Test Loss: 0.5765130072832108 and Test Accuracy: 0.774725
Epoch Train Loss: 0.37719155848026276 and Train Accuracy: 0.9280618372050121
Epoch Test Loss: 0.6184224039316177 and Test Accuracy: 0.7858947681331747
Epoch Train Loss: 0.38555996119976044 and Train Accuracy: 0.9286165557199211
Epoch Test Loss: 0.6855111718177795 and Test Accuracy: 0.75985
Epoch Train Loss: 0.2889115735888481 and Train Accuracy: 0.9468441212120338
Epoch Test Loss: 0.7037681937217712 and Test Accuracy: 0.8326025564803805
Epoch Train Loss: 0.3423616588115692 and Train Accuracy: 0.9424386711045365
Epoch Test Loss: 0.7048476040363312 and Test Accuracy: 0.8343861474435196
Max Test Accuracy: 0.8343861474435196
Max Train Accuracy: 0.9424386711045365


{'epoch': 4,
 'train_loss': 0.3423616588115692,
 'test_loss': 0.7048476040363312,
 'test_acc': 0.8343861474435196}