In [1]:
# NOTE: the PyPI distribution is called "scikit-learn"; the legacy "sklearn" stub package no longer installs.
!pip install torch networkx numpy tqdm scikit-learn matplotlib wandb

In [None]:
%%bash
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cpu.html

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!wandb login

In [4]:
import networkx as nx
import numpy as np
import torch
from torch import nn
from torch.nn.parameter import Parameter
from torch_geometric.datasets import TUDataset
from torch_geometric.utils.convert import to_networkx
import math
import gzip
import copy
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm
import random

def decode_temporal_graphs(encoded_graphs):
    """
    Decode a flat sequence of integers into a list of directed temporal graphs.

    Each edge is encoded as ``start, end, timestep, 254`` and every graph is
    terminated by ``255``. Edges are stored with a ``timestep`` attribute.
    A graph is only emitted when its terminator is read, so values after the
    last ``255`` are not part of the result.

    Args:
        encoded_graphs: iterable of integers in the encoding described above.

    Returns:
        list of ``nx.DiGraph`` objects, one per graph terminator.

    Raises:
        ValueError: if a graph terminator shows up in the middle of an edge, or
            if the three fields of an edge are not followed by the edge separator.
    """
    GRAPH_END = 255
    EDGE_END = 254

    decoded = []
    graph = nx.DiGraph()
    # Fields of the edge currently being read, in order: start, end, timestep.
    fields = []
    for value in encoded_graphs:
        if value == GRAPH_END:
            if fields:
                # A graph may only end on an edge boundary.
                raise ValueError("Invalid encoded graph")
            decoded.append(graph)
            graph = nx.DiGraph()
        elif len(fields) < 3:
            fields.append(value)
        elif value == EDGE_END:
            source, target, step = fields
            graph.add_edge(source, target, timestep=step)
            fields = []
        else:
            # Three fields were read but no edge separator followed.
            raise ValueError("Invalid encoded graph")
    return decoded

def load_target_graphs(path):
    """
    Load target graphs from a NumPy file and convert them to ``nx.DiGraph`` objects.

    Every entry ``g`` of the loaded array is treated as a square structure where
    ``g[i][j]`` is a vector; the index of the first element equal to 1 becomes the
    ``timestep`` attribute of the edge ``i -> j``. Entries without any 1 produce
    no edge.

    Graphs that end up without edges are not dropped: a placeholder edge
    ``0 -> 1`` with ``timestep=1`` is added instead, so the returned list has
    one graph per loaded entry. Consequently the "filtered" counter printed at
    the end stays at 0.

    Args:
        path: path of the ``.npy`` file to read.

    Returns:
        list of ``nx.DiGraph``, in file order.
    """
    # NOTE(review): allow_pickle=True can execute arbitrary code while loading;
    # only use this on trusted files.
    _graphs = np.load(path, allow_pickle=True)
    graphs = []
    not_connected = 0
    for g in _graphs:
        current_graph = nx.DiGraph()
        current_graph.add_nodes_from(range(len(g)))
        for i, row in enumerate(g):
            for j, edge_enc in enumerate(row):
                timestep = np.where(edge_enc == 1)[0]
                if len(timestep) > 0:
                    current_graph.add_edge(i, j, timestep=int(timestep[0]))
        if len(current_graph.edges()) == 0:
            # NOTE: quick-fix to handle disconnected graph
            current_graph.add_edge(0, 1, timestep=1)
        graphs.append(current_graph)

    # Guard the percentage against an empty file (previously a ZeroDivisionError).
    total = len(_graphs)
    percentage = (not_connected / total) * 100 if total else 0.0
    print(f"{not_connected} or {percentage}% of graphs have been filtered as they have no edges" )
    return graphs

def load_target_graphs_dict(path, max_graphs=1000):
    """
    Load target graphs from a keyed NumPy container and convert them to ``nx.DiGraph``.

    Only the first ``max_graphs`` keys (in the container's own order) are read.
    Every entry ``g`` is treated as a square structure where ``g[i][j]`` is a
    vector; the index of the first element equal to 1 becomes the ``timestep``
    attribute of the edge ``i -> j``. Graphs without any edge are dropped and
    counted in the printed summary.

    Args:
        path: path of the file to read; the loaded object must expose ``keys()``
            and item access (presumably an ``.npz`` archive - confirm against
            the data that is actually used).
        max_graphs: maximum number of entries to read (default 1000, the
            previously hard-coded limit).

    Returns:
        list of ``nx.DiGraph`` that contain at least one edge.
    """
    # NOTE(review): allow_pickle=True can execute arbitrary code while loading;
    # only use this on trusted files.
    graph_dict = np.load(path, allow_pickle=True)
    try:
        keys = list(graph_dict.keys())[:max_graphs]
        _graphs = [graph_dict[key] for key in keys]
    finally:
        # Archives returned by np.load keep the file open until closed.
        close = getattr(graph_dict, "close", None)
        if close is not None:
            close()

    graphs = []
    not_connected = 0
    for g in _graphs:
        current_graph = nx.DiGraph()
        current_graph.add_nodes_from(range(len(g)))
        for i, row in enumerate(g):
            for j, edge_enc in enumerate(row):
                timestep = np.where(edge_enc == 1)[0]
                if len(timestep) > 0:
                    current_graph.add_edge(i, j, timestep=int(timestep[0]))
        if len(current_graph.edges()) > 0:
            graphs.append(current_graph)
        else:
            not_connected += 1

    # Guard the percentage against an empty container (previously a ZeroDivisionError).
    percentage = (not_connected / len(keys)) * 100 if keys else 0.0
    print(f"{percentage}% of graphs filtered as they have no edges" )
    return graphs


def load_source_graphs(path):
    """
    Read the encoded source graphs stored at `path` and decode them.

    The file is loaded with ``np.load`` and handed to ``decode_temporal_graphs``.
    The decoded graphs are returned as they are; no reverse edges are added.
    """
    return decode_temporal_graphs(np.load(path))

def create_mlp(dims, dropout = None):
    """
    Build a feed-forward network as an ``nn.Sequential``.

    Args:
        dims: layer widths; consecutive pairs ``(dims[k], dims[k+1])`` become
            ``nn.Linear`` layers.
        dropout: if not None, an ``nn.Dropout(dropout)`` is placed in front of
            the first linear layer.

    Returns:
        ``nn.Sequential`` with an in-place ReLU after every linear layer except
        the last one.
    """
    modules = []
    if dropout is not None:
        modules.append(nn.Dropout(dropout))

    last_position = len(dims) - 2
    for position, (fan_in, fan_out) in enumerate(zip(dims[:-1], dims[1:])):
        modules.append(nn.Linear(fan_in, fan_out))
        # No activation behind the output layer.
        if position != last_position:
            modules.append(nn.ReLU(inplace=True))
    return nn.Sequential(*modules)

def perform_conv(convs, act, emb, edge_index, edge_feats = None):
    """
    Run `emb` through every layer in `convs`.

    Each layer is called as ``layer(emb, edge_index, edge_feats)``. The
    activation `act` is applied between layers but not after the last one.
    With no layers, `emb` is returned unchanged.
    """
    for position, layer in enumerate(convs, start=1):
        emb = layer(emb, edge_index, edge_feats)
        is_last = position >= len(convs)
        if not is_last:
            emb = act(emb)
    return emb

class MetricsLogger():
  """
  Collects losses, predictions and labels over an epoch and tracks the best accuracy.

  `reset` clears the per-epoch state (loss sum, step count, predictions,
  labels); the best accuracy and its epoch survive a reset.
  """

  def __init__(self):
    self.reset()
    self._best_acc = 0
    self._best_acc_epoch = 0

  def log_loss(self, loss):
    """Add one batch loss to the running sum."""
    self._loss += loss
    self._step += 1

  def log_prediction(self, preds, labels, num_classes=None):
    """
    Store the predictions and labels of one batch.

    Args:
      preds: array of model outputs for the batch.
      labels: array of labels; may be an empty sequence when `num_classes` is None.
      num_classes: None stores `preds` unchanged (they are treated as
        embeddings); a value > 1 stores the argmax over the last axis of both
        `preds` and `labels`; any other value stores `preds > 0` as the
        predicted class.
    """
    # no classification: keep the raw outputs
    if num_classes is None:
      self._preds += preds.tolist()
      self._labels += labels.tolist() if len(labels) > 0 else []
    # multi-class classification: class scores / one-hot labels -> class indices
    elif num_classes > 1:
      self._preds += np.argmax(preds, axis=-1).tolist()
      self._labels += np.argmax(labels, axis=-1).tolist()
    # binary classification: a positive logit means the positive class
    else:
      self._labels += labels.tolist()
      self._preds += (preds > 0).tolist()

  def reset(self):
    """Clear everything that is accumulated per epoch."""
    self._loss = 0
    self._step = 0
    self._embeddings = []
    self._preds = []
    self._labels = []

  def get_embeddings(self):
    """Return the stored predictions as a NumPy array."""
    return np.array(self._preds)

  def get_loss(self):
    """
    Return the mean logged loss, truncated to four decimals.

    Returns NaN when no loss has been logged since the last reset (this used
    to raise ZeroDivisionError).
    """
    if self._step == 0:
      return float("nan")
    return math.floor((self._loss / self._step) * 10000) / 10000

  def log_best_acc(self, acc, epoch):
    """Remember `acc` and `epoch` if `acc` beats the best accuracy so far."""
    if acc > self._best_acc:
      self._best_acc = acc
      self._best_acc_epoch = epoch


In [None]:
import torch_geometric.data as geom_data
from torch_geometric.nn import TransformerConv, global_mean_pool, GATConv, GatedGraphConv, GINEConv, SplineConv
from torch_geometric.utils import to_dense_adj, from_networkx
from sklearn.metrics import accuracy_score
from torch import optim
import torch
import torch.nn as nn
import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import wandb
import os
from torch_geometric.data import InMemoryDataset, download_url, DataLoader
from torch_geometric.datasets import TUDataset
import networkx as nx
import torch.nn.functional as F
import wandb

# Seed every random number generator used in the notebook from one base value.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed + 2)
random.seed(seed + 1)

# cuDNN is left in autotuning (non-deterministic) mode.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

# Directory (on the mounted drive) that holds the input graphs and receives the outputs.
PATH = "/content/drive/MyDrive/graph_matching/"

class TUGraphDataset():
  """
  Wraps a TU benchmark dataset with an 80/20 train/test split and data loaders.

  Attributes set by the constructor:
    adj_dim: largest number of rows of `x` over all graphs.
    node_feat_dim / edge_feat_dim: feature sizes reported by the dataset.
    train_loader: shuffled loader over the training split.
    unshuffled_train_loader: loader over the training split in fixed order.
    test_loader: loader over the test split.
    num_classes: number of classes reported by the dataset.
  """
  def __init__(self, name, batch_size):
    print(f"Load {name} dataset")
    dataset = TUDataset(root=f"/tmp/{name.upper()}", name=name.upper(), use_edge_attr=True).shuffle()

    # The first 80% of the shuffled graphs are used for training, the rest for testing.
    cut = int(len(dataset) * 0.8)
    train_part, test_part = dataset[:cut], dataset[cut:]

    self.adj_dim = max(graph.x.shape[0] for graph in dataset)
    # Sum of the first label entry of every graph; only used for the printed ratio.
    pos = sum(graph.y[0] for graph in dataset)

    self.num_classes = dataset.num_classes
    self.node_feat_dim = dataset.num_node_features
    self.edge_feat_dim = dataset.num_edge_features

    self.train_loader = DataLoader(train_part, batch_size=batch_size, shuffle=True)
    self.unshuffled_train_loader = DataLoader(train_part, batch_size=batch_size)
    self.test_loader = DataLoader(test_part, batch_size=batch_size)

    print(f"Ratio is {pos/len(dataset)} and max nodes are {self.adj_dim}")

class AnomalyDataset(InMemoryDataset):
    """
    In-memory dataset that holds the source graphs followed by the target graphs.

    The processed file stores the collated data together with `split`, the
    number of source graphs, i.e. the index at which the target graphs start.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        collated, self.split = torch.load(self.processed_paths[0])
        self.data, self.slices = collated

    @property
    def raw_file_names(self):
        # No raw files are declared; the inputs are read from PATH in `process`.
        return []

    @property
    def processed_file_names(self):
        return ['processed_anomaly_graphs.pt']

    def download(self):
        # Nothing to download.
        pass

    def process(self):
        """Convert all source and target graphs to PyG data objects and save them."""
        source_graphs = load_source_graphs(PATH + 'apg_unique_subgraphs_temporal_node2-17.npy')
        target_graphs = load_target_graphs(PATH + '100k_frontend_graphs_temporal_unique.npy')
        print(f"Found {len(source_graphs)} source graphs and {len(target_graphs)} target graphs")

        converted = []
        for graph in tqdm(source_graphs + target_graphs, 'Transform graphs'):
            # TODO: try to set node label as attributes
            # Every node receives the same constant 8-dimensional feature vector.
            nx.set_node_attributes(graph, [1., 1., 1., 1., 1., 1., 1., 1.], "labels")
            # Edge timesteps are cast to float before the conversion.
            float_steps = {
                edge: float(step)
                for edge, step in nx.get_edge_attributes(graph, 'timestep').items()
            }
            nx.set_edge_attributes(graph, float_steps, 'timestep')
            converted.append(
                from_networkx(graph, group_node_attrs=['labels'], group_edge_attrs=['timestep'])
            )

        torch.save((self.collate(converted), len(source_graphs)), self.processed_paths[0])

class TGAnomalyDataset():
    """
    Graph anomaly dataset.

    The graphs before `dataset.split` (the source graphs) form the test set,
    the remaining graphs (the target graphs) form the training set. Both sets
    are also kept as networkx graphs in `source_graphs` / `target_graphs`.
    `num_classes` is None because no class labels are used.
    """
    def __init__(self, batch_size):
        print("Load anomaly dataset")
        dataset = AnomalyDataset(root=f"/tmp/anomaly")
        boundary = dataset.split
        test_part, train_part = dataset[:boundary], dataset[boundary:]

        self.num_classes = None
        self.adj_dim = max(graph.x.shape[0] for graph in dataset)
        self.node_feat_dim = dataset.num_node_features
        self.edge_feat_dim = dataset.num_edge_features

        self.train_loader = DataLoader(train_part, batch_size=batch_size, shuffle=True)
        self.unshuffled_train_loader = DataLoader(train_part, batch_size=batch_size)
        self.test_loader = DataLoader(test_part, batch_size=batch_size)

        self.source_graphs = [to_networkx(graph) for graph in test_part]
        self.target_graphs = [to_networkx(graph) for graph in train_part]

class TGDecoder(nn.Module):
    """
    Decodes a graph embedding into a dense adjacency tensor.

    The embedding is expanded by `graph_mlp` into one `emb_dim` vector per node
    slot; the adjacency scores are the inner products of these vectors. When
    the graphs carry edge features, `edge_mlp` maps every adjacency row to
    `edge_feat_dim` values per node pair.

    Args:
        adj_dim: number of node slots of the decoded adjacency.
        node_feat_dim: accepted for interface symmetry; not used here.
        edge_feat_dim: number of edge features; values <= 0 mean "no edge features".
        hidden_dims: hidden layer widths of `graph_mlp`.
        emb_dim: size of the incoming graph embedding.
        dp: dropout applied in front of `graph_mlp`.
    """
    def __init__(self, adj_dim, node_feat_dim, edge_feat_dim, hidden_dims, emb_dim, dp):
        super(TGDecoder, self).__init__()
        self.adj_dim = adj_dim
        self.emb_dim = emb_dim
        self.edge_feat_dim = edge_feat_dim if edge_feat_dim > 0 else None

        if self.edge_feat_dim is None:
            # Not used in forward(); still created so the parameter layout is unchanged.
            edge_dims = [adj_dim, adj_dim, adj_dim]
        else:
            # The output must hold edge_feat_dim values per node pair, otherwise the
            # reshape in forward() fails (this was the case for edge_feat_dim == 2).
            # max(1, ...) keeps the hidden layer non-empty for edge_feat_dim == 1.
            edge_dims = [adj_dim,
                         adj_dim * max(1, self.edge_feat_dim // 2),
                         adj_dim * self.edge_feat_dim]
        self.edge_mlp = create_mlp(edge_dims)
        self.graph_mlp = create_mlp([emb_dim, *hidden_dims, emb_dim * adj_dim], dp)

    def forward(self, emb):
        """
        Args:
            emb: tensor of shape (batch, emb_dim).

        Returns:
            (batch, adj_dim, adj_dim) without edge features, otherwise
            (batch, adj_dim, adj_dim, edge_feat_dim).
        """
        batch_size = emb.shape[0]
        emb = self.graph_mlp(emb)
        emb = torch.reshape(emb, (batch_size, self.adj_dim, self.emb_dim))
        # Pairwise inner products of the per-node vectors.
        adj = torch.bmm(emb, torch.transpose(emb, 1, 2))
        if self.edge_feat_dim is not None:
            adj = self.edge_mlp(adj)
            adj = torch.reshape(adj, (batch_size, self.adj_dim, self.adj_dim, self.edge_feat_dim))
        return adj

class TGEncoder(nn.Module):
    """
    Encodes a batch of graphs into one embedding vector per graph.

    A stack of graph convolutions (ReLU in between, none after the last one)
    is followed by mean pooling over the nodes of each graph and a final
    linear projection to `emb_dim`.

    Args:
        node_feat_dim: size of the input node features.
        edge_feat_dim: size of the edge features; values <= 0 mean "none".
        conv_dims: output widths of the convolution layers.
        emb_dim: size of the resulting graph embedding.
        attention_heads: number of heads for TransformerConv / GATConv.
        linear_dp: dropout in front of the final projection.
        conv_dp: dropout inside TransformerConv / GATConv.
        gnn_layer: "TransformerConv", "GATConv" or "GINEConv".

    Raises:
        RuntimeError: if `gnn_layer` is none of the supported names.
    """
    def __init__(self, node_feat_dim, edge_feat_dim, conv_dims, emb_dim, attention_heads, linear_dp, conv_dp, gnn_layer):
        super(TGEncoder, self).__init__()
        conv_dims = [node_feat_dim, *conv_dims]
        self.edge_feat_dim = edge_feat_dim if edge_feat_dim > 0 else None
        self.gnn_layer = gnn_layer

        # (input width, output width) of every convolution layer
        layer_dims = list(zip(conv_dims[:-1], conv_dims[1:]))
        if gnn_layer == "TransformerConv":
            # concat is only enabled for a single head, so the layer output
            # width is not multiplied by the number of heads.
            layers = [TransformerConv(in_dim, out_dim,
                                      edge_dim=self.edge_feat_dim, heads=attention_heads,
                                      concat=attention_heads == 1, beta=False, dropout=conv_dp)
                      for in_dim, out_dim in layer_dims]
        elif gnn_layer == "GATConv":
            layers = [GATConv(in_dim, out_dim,
                              edge_dim=self.edge_feat_dim, heads=attention_heads,
                              concat=attention_heads == 1, dropout=conv_dp)
                      for in_dim, out_dim in layer_dims]
        elif gnn_layer == "GINEConv":
            # GINEConv is always given edge features; forward() supplies
            # constant ones of width 1 when the data has none.
            layers = [GINEConv(nn=nn.Sequential(*[nn.Linear(in_dim, out_dim)]),
                               edge_dim=edge_feat_dim if edge_feat_dim > 0 else 1)
                      for in_dim, out_dim in layer_dims]
        else:
            raise RuntimeError('Given gnn layer not defined')

        # Sub-modules belong in a ModuleList; ParameterList is meant for tensors.
        self.convs = nn.ModuleList(layers)
        self.act = nn.ReLU()
        self.emb_mlp = create_mlp([conv_dims[-1], emb_dim], linear_dp)

    def forward(self, node_feats, edge_index, batch_idx, edge_feats):
        """Return one embedding of size `emb_dim` per graph in the batch."""
        if self.gnn_layer == "GINEConv" and edge_feats is None:
            # Constant edge features, created on the device of the batch.
            edge_feats = torch.ones((edge_index.shape[1], 1), dtype=torch.float, device=edge_index.device)
        emb = perform_conv(self.convs, self.act, node_feats, edge_index, edge_feats)
        emb = global_mean_pool(emb, batch_idx)
        emb = self.emb_mlp(emb)
        return emb

class TGAutoEncoder(nn.Module):
    """
    Graph auto-encoder: a `TGEncoder` produces a graph embedding and a
    `TGDecoder` reconstructs a dense adjacency tensor from it.
    """
    def __init__(self, adj_dim, node_feat_dim, edge_feat_dim, emb_dim, enc_hidden_dims, dec_hidden_dims, attention_heads,
                 linear_dp, conv_dp, gnn_layer):
        super().__init__()
        self.encoder = TGEncoder(node_feat_dim, edge_feat_dim, enc_hidden_dims, emb_dim,
                                 attention_heads, linear_dp, conv_dp, gnn_layer)
        self.decoder = TGDecoder(adj_dim, node_feat_dim, edge_feat_dim, dec_hidden_dims, emb_dim, linear_dp)

    @staticmethod
    def _as_percent(score):
        """Convert a score in [0, 1] to a percentage truncated to two decimals."""
        return math.floor(score * 10000) / 100

    def forward(self, node_feats, edge_index, batch_idx, edge_feats):
        """Return the tuple (graph embedding, decoded adjacency)."""
        embedding = self.encoder(node_feats, edge_index, batch_idx, edge_feats)
        return embedding, self.decoder(embedding)

    def classify(self, train_preds, train_labels, eval_preds, eval_labels):
        """Return (train accuracy, eval accuracy) in percent for the given predictions."""
        train_score = accuracy_score(train_labels, train_preds)
        eval_score = accuracy_score(eval_labels, eval_preds)
        return self._as_percent(train_score), self._as_percent(eval_score)

    def classify_emb(self, train_embeddings, train_labels, eval_embeddings, eval_labels):
        """
        Fit a logistic regression on the training embeddings and return its
        (train accuracy, eval accuracy) in percent.
        """
        classifier = LogisticRegression(max_iter=2000)
        classifier.fit(train_embeddings, train_labels)
        train_score = classifier.score(train_embeddings, train_labels)
        eval_score = classifier.score(eval_embeddings, eval_labels)
        return self._as_percent(train_score), self._as_percent(eval_score)

def get_dataset(name, batch_size):
  """
  Create the dataset wrapper registered under `name`.

  Args:
    name: "anomaly", "proteins", "enzymes", "mutag" or "bzr".
    batch_size: batch size of the data loaders created by the wrapper.

  Returns:
    A `TGAnomalyDataset` for "anomaly", otherwise a `TUGraphDataset`.

  Raises:
    ValueError: if `name` is not one of the supported names (previously the
      function returned None, which only failed later on attribute access).
  """
  # notebook name -> name of the TU dataset that is loaded
  tu_names = {
    "proteins": "proteins",
    "enzymes": "enzymes",
    "mutag": "mutag",
    "bzr": "bzr_md",
  }
  if name == "anomaly":
    return TGAnomalyDataset(batch_size=batch_size)
  if name in tu_names:
    return TUGraphDataset(name=tu_names[name], batch_size=batch_size)
  supported = ", ".join(["anomaly", *tu_names])
  raise ValueError(f"Unknown dataset '{name}', expected one of: {supported}")

# --- Experiment configuration -------------------------------------------------
# Everything below is read by the model construction and the training loop.
batch_size = 64
dataset_name = "anomaly"
# Loads (and, if necessary, processes) the dataset right away.
dataset = get_dataset(dataset_name, batch_size)
# Default objective: reconstruction error between decoded and original adjacency.
criterion = nn.MSELoss()
# "linear": the encoder output is trained directly against the labels,
# "log": a logistic regression is fitted on the embeddings after every epoch,
# None: reconstruction only.
CLASSIFICATION = "log" # "linear", "log", None
emb_dim = 16
lr = 1e-4
num_epochs = 5
enc_hidden_dims = [128, 128, 128]
dec_hidden_dims = [32, 64, 128]
attention_heads = 1
# Dropout in front of the MLPs of encoder and decoder.
linear_dp = 0.6
# Dropout inside the TransformerConv / GATConv layers.
conv_dp = 0.
gnn_layer = "TransformerConv" # "TransformerConv", "GATConv", "GINEConv"
# Enables Weights & Biases logging.
LOG = False

# Overrides that depend on the chosen dataset / mode; the order matters.
# NOTE(review): the reason for 17 is not visible here - presumably tied to the
# graph sizes of the anomaly data; confirm.
if dataset_name == "anomaly":
  emb_dim = 17
# Datasets without class information cannot be used for classification.
if dataset.num_classes is None:
  CLASSIFICATION = None
# In "linear" mode the embedding itself holds the logits: a single value for
# two classes, otherwise one value per class, with the matching loss.
if CLASSIFICATION == "linear":
  emb_dim = 1 if dataset.num_classes==2 else dataset.num_classes
  criterion = nn.BCEWithLogitsLoss() if dataset.num_classes==2 else nn.CrossEntropyLoss()

model = TGAutoEncoder(adj_dim=dataset.adj_dim, node_feat_dim=dataset.node_feat_dim, edge_feat_dim=dataset.edge_feat_dim, emb_dim=emb_dim, 
                      enc_hidden_dims=enc_hidden_dims, dec_hidden_dims=dec_hidden_dims, attention_heads=attention_heads, linear_dp=linear_dp, 
                      conv_dp=conv_dp, gnn_layer=gnn_layer)
optimizer = optim.AdamW(model.parameters(), lr=lr)

# Optional Weights & Biases run setup; skipped entirely unless LOG is set.
if LOG:
  wandb.init(project="graph-matching-autoencoder", entity="schti", tags=[dataset_name])
  # Record the hyper-parameters of this run.
  wandb.config.update({
    "dataset": dataset_name,
    "classification_type": CLASSIFICATION,
    "learning_rate": lr,
    "epochs": num_epochs,
    "batch_size": batch_size,
    "encoder_hidden_dims": enc_hidden_dims,
    "decoder_hidden_dims": dec_hidden_dims,
    "embedding_dim": emb_dim,
    "attention_heads": attention_heads,
    "linear_dropout": linear_dp,
    "conv_dropout": conv_dp,
    "gnn_layer": gnn_layer,
  })
  wandb.run.name = f"run_{dataset_name}_{num_epochs}"
  # Train and eval losses are plotted against their own step counters, which
  # the training loop logs together with the losses.
  wandb.define_metric("train_step")
  wandb.define_metric("train_loss", step_metric="train_step")
  wandb.define_metric("eval_step")
  wandb.define_metric("eval_loss", step_metric="eval_step")

def _run_batch(data):
  """
  Run the model on one batch and compute its loss.

  In "linear" mode the encoder output is compared with the labels, otherwise
  the decoded adjacency is compared with the dense adjacency of the batch.

  Returns:
    (loss, preds, labels), where `preds`/`labels` are the tensors that were
    compared in "linear" mode and the raw encoder output / raw `data.y`
    otherwise.
  """
  node_feats, edge_index, batch_idx, edge_feats, labels = data.x, data.edge_index, data.batch, data.edge_attr, data.y
  preds, adj = model(node_feats, edge_index, batch_idx, edge_feats)
  if CLASSIFICATION == "linear":
    preds = preds.float().squeeze(dim=-1)
    if emb_dim == 1:
      labels = labels.float()
    else:
      # num_classes is given explicitly: a batch that lacks the highest class
      # would otherwise yield a narrower one-hot tensor than `preds`.
      labels = F.one_hot(labels, num_classes=emb_dim).float()
    loss = criterion(preds, labels)
  else:
    o_adj = to_dense_adj(edge_index=edge_index, batch=batch_idx, edge_attr=edge_feats, max_num_nodes=dataset.adj_dim)
    loss = criterion(adj, o_adj)
  return loss, preds, labels


def _log_batch(logger, loss, preds, labels):
  """Store the loss and the predictions/labels of one batch in `logger`."""
  logger.log_loss(loss.item())
  logger.log_prediction(preds.detach().cpu().numpy(),
                        [] if labels is None else labels.detach().cpu().numpy(),
                        num_classes=emb_dim if CLASSIFICATION == "linear" else None)


print(f"Start training for {num_epochs} epochs")
train_logger = MetricsLogger()
eval_logger = MetricsLogger()
train_step = 0
eval_step = 0
for epoch in range(num_epochs):
  model.train()
  train_logger.reset()
  for data in tqdm(dataset.train_loader, f"Epoch {epoch +1}"):
    optimizer.zero_grad()
    loss, preds, labels = _run_batch(data)
    if LOG:
      wandb.log({'train_loss': loss.item(), "train_step": train_step})
    _log_batch(train_logger, loss, preds, labels)
    loss.backward()
    optimizer.step()
    train_step+=1

  model.eval()
  eval_logger.reset()
  # No gradients are needed for evaluation.
  with torch.no_grad():
    for data in dataset.test_loader:
      loss, preds, labels = _run_batch(data)
      if LOG:
        wandb.log({'eval_loss': loss.item(), "eval_step": eval_step})
      _log_batch(eval_logger, loss, preds, labels)
      eval_step+=1

  train_loss = train_logger.get_loss()
  eval_loss = eval_logger.get_loss()

  results = {}
  if CLASSIFICATION in ("linear", "log"):
    # "linear": the logged predictions are class decisions and are scored directly.
    # "log": the logged predictions are embeddings; a logistic regression is fitted on them.
    score = model.classify if CLASSIFICATION == "linear" else model.classify_emb
    train_acc, eval_acc = score(train_logger._preds, train_logger._labels, eval_logger._preds, eval_logger._labels)
    results = {'train_acc': train_acc, 'eval_acc': eval_acc}
    eval_logger.log_best_acc(eval_acc, epoch)

  # Metrics are separated by " - " (they used to be glued together).
  result = " - ".join([f"{key}: {value}" for key, value in results.items()])
  print(f"Epoch {epoch + 1}/{num_epochs} - train_loss: {train_loss} - eval_loss: {eval_loss}" + (f" && {result}" if len(result) > 0 else ''))

if CLASSIFICATION is not None:
  print(f"Best acc was {eval_logger._best_acc} after {eval_logger._best_acc_epoch +1} epochs!")

def save_embeddings(embeddings, file):
  """Write `embeddings` in NumPy binary format to `<PATH><file>.npy`."""
  destination = f'{PATH}{file}.npy'
  with open(destination, 'wb') as handle:
    np.save(handle, embeddings)

def load_embeddings(file):
  """Read and return the array stored at `<PATH><file>.npy`."""
  location = f'{PATH}{file}.npy'
  with open(location, 'rb') as handle:
    stored = np.load(handle)
  return stored

def _collect_embeddings(loader, description):
  """
  Encode every batch of `loader` with the trained model.

  Returns:
    All graph embeddings stacked in loader order, or an empty array when the
    loader yields no batch.
  """
  batches = []
  # Inference only: no gradients are tracked.
  with torch.no_grad():
    for data in tqdm(loader, description):
      emb, _ = model(data.x, data.edge_index, data.batch, data.edge_attr)
      batches.append(emb.detach().cpu().numpy())
  # A single concatenation instead of growing the array batch by batch.
  return np.concatenate(batches) if batches else np.array([])


name = f"embeddings_{emb_dim}_dims_{num_epochs}_epochs"
model.eval()
# The unshuffled loader keeps the embeddings aligned with the order of the training graphs.
target_embeddings = _collect_embeddings(dataset.unshuffled_train_loader, "Save target embeddings")
source_embeddings = _collect_embeddings(dataset.test_loader, "Save source embeddings")
save_embeddings(target_embeddings, dataset_name + "_target_" + name)
save_embeddings(source_embeddings, dataset_name + "_source_" + name)
print("Embeddings saved!")


"""
rankings = []
for source in eval_logger.get_embeddings():
  distances = []
  for target in train_logger.get_embeddings():
    dist = np.linalg.norm(target-source)
    distances.append(dist)
  rankings.append(np.argsort(distances))

s_idx = 3
t_idxs = range(5)
s = [(dataset.source_graphs[s_idx], "source")]
t = [(dataset.target_graphs[rankings[s_idx][t_idx]], f"target_{t_idx + 1}" ) for t_idx in t_idxs]
#pos = nx.spring_layout(dataset.source_graphs[s_idx])
for i, (g, name) in enumerate([*s, *t]):
  pos = nx.spring_layout(g)
  fig = plt.figure(figsize=(10,10))
  ax = plt.subplot()
  ax.set_title(name)
  nx.draw(g, pos)
  nx.draw_networkx_labels(g, pos)
  edge_labels = nx.get_edge_attributes(g, 'timestep')
  nx.draw_networkx_edge_labels(g, pos, edge_labels)
  plt.tight_layout()
  plt.savefig(f"/content/drive/MyDrive/graph_matching/{name}.png", format="PNG")
"""