In [1]:
import sys
# Make the project-local modules in ../src (e.g. `utils`) importable.
sys.path.insert(0, '../src/')

import os
from pathlib import Path
import pickle
import warnings
# NOTE(review): this silences *all* warnings for the whole kernel session, which can hide
# real problems (dtype/deprecation issues); consider a narrower filter.
warnings.filterwarnings('ignore')
import abc
import time

import numpy as np
import scipy.sparse as sp

import torch
# Device and dtype used for the model parameters created in Net.
device = 'cpu'
dtype = torch.float32

from scipy.sparse import load_npz, save_npz

# Project-local helpers: edge splitting, graph sampling, edge overlap, link-prediction metrics, pickling.
import utils

In [2]:
# Load the CORA-ML adjacency matrix as a scipy sparse matrix.
_A_obs = load_npz('../data/CORA_ML.npz')

In [3]:
# Edge-split parameters for utils.train_val_test_split_adjacency; the fixed seed makes the split reproducible.
val_share = 0.1
test_share = 0.05
seed = 481516234

train_ones, val_ones, val_zeros, test_ones, test_zeros = utils.train_val_test_split_adjacency(_A_obs, val_share, test_share, seed, undirected=True, connected=True, asserts=False)

# Training adjacency built from the training edge list. The explicit shape keeps it N x N
# (matching _A_obs and the validation/test indices) even if the highest-index nodes have
# no training edge, in which case the inferred shape would be too small.
train_graph = sp.csr_matrix((np.ones(len(train_ones)), (train_ones[:, 0], train_ones[:, 1])),
                            shape=_A_obs.shape)
# Symmetry check done on the sparse matrices instead of materialising two dense N x N copies.
assert (train_graph != train_graph.T).nnz == 0, 'train_graph must be symmetric'

In [19]:
class Callback(abc.ABC):
    """
    Base class for training hooks that fire every ``invoke_every`` steps.

    Calling an instance with ``(loss, model)`` forwards to :meth:`invoke` only when
    ``model.step`` is a multiple of ``invoke_every``. Subclasses may call
    :meth:`stop_training` to raise the ``training_stopped`` flag, which the training
    loop in this notebook polls after each step.
    """

    def __init__(self, invoke_every):
        self.invoke_every = invoke_every
        self.training_stopped = False

    def __call__(self, loss, model):
        # Skip every step that is not on the invocation grid.
        if model.step % self.invoke_every != 0:
            return
        self.invoke(loss, model)

    def stop_training(self):
        """Set the ``training_stopped`` flag."""
        self.training_stopped = True

    @abc.abstractmethod
    def invoke(self, loss, model):
        """Perform the callback's work for the current step (implemented by subclasses)."""


class OverlapTracker(Callback):
    """
    Periodic evaluation callback. Every ``invoke_every`` training steps it:
    - samples a graph from the model and saves it to ``logdir`` (or appends it to
      ``self.logs`` when no ``logdir`` is given),
    - tracks the edge overlap between the sampled and the training graph and stops
      training once it reaches ``EO_limit``,
    - tracks the validation ROC-AUC score and average precision (only when both
      entries of ``val_edges`` are given),
    - tracks the model's cumulative running time (``model.total_time``).

    With a ``logdir`` the metric pickles are rewritten on every invocation. They are
    therefore present even when the run ends on a step that is not a multiple of
    ``invoke_every``, or is stopped by another callback.

    Parameters
    ----------
    logdir: str or None
        Output folder; it is not created here. If falsy (None or ''), sampled graphs
        are kept in ``self.logs`` and nothing is written to disk.
    invoke_every: int
        Evaluation frequency in training steps.
    EO_limit: float
        Edge-overlap value at or above which training is stopped.
    val_edges: tuple
        ``(val_ones, val_zeros)`` passed to ``utils.link_prediction_performance``;
        ``(None, None)`` disables link-prediction tracking.
    """
    def __init__(self, logdir=None, invoke_every=100, EO_limit=1., val_edges=(None, None)):
        super().__init__(invoke_every)
        self.logdir = logdir
        # Same truthiness test as in invoke(): an empty-string logdir must also get an
        # in-memory log, otherwise invoke() would hit an AttributeError on self.logs.
        if not self.logdir:
            self.logs = []
        self.EO_limit = EO_limit
        self.overlap_dict = {}
        self.roc_auc_dict = {}
        self.avg_prec_dict = {}
        self.time_dict = {}
        (self.val_ones, self.val_zeros) = val_edges

    def _has_val_edges(self):
        """Whether both validation edge sets were supplied."""
        return self.val_ones is not None and self.val_zeros is not None

    def invoke(self, loss, model):
        # Only graph sampling + overlap computation are timed and charged to the model's clock.
        start = time.time()
        sampled_graph = model.sample_graph()
        overlap = utils.edge_overlap(model.A_sparse, sampled_graph) / model.num_edges
        self.overlap_dict[model.step] = overlap
        overlap_time = time.time() - start
        model.total_time += overlap_time
        self.time_dict[model.step] = model.total_time

        if self._has_val_edges():
            # Relies on model.sample_graph() (called above) having stored model._scores_matrix.
            roc_auc, avg_prec = utils.link_prediction_performance(model._scores_matrix,
                                                                  self.val_ones,
                                                                  self.val_zeros)
            self.roc_auc_dict[model.step] = roc_auc
            self.avg_prec_dict[model.step] = avg_prec

        step_str = f'{model.step:{model.step_str_len}d}'
        print(f'Step: {step_str}/{model.steps}, Loss: {loss:.5f}, Edge-Overlap: {overlap:.3f}')
        if overlap >= self.EO_limit:
            self.stop_training()

        if self.logdir:
            filename = f'graph_{model.step:0{model.step_str_len}d}'
            save_npz(file=os.path.join(self.logdir, filename),
                     matrix=sampled_graph)
            # Persist metrics every time rather than only on the (possibly never invoked) final step.
            self._save_metrics()
        else:
            self.logs.append(sampled_graph)

    def _save_metrics(self):
        """Pickle the tracked metric dicts into ``self.logdir`` (assumes utils.save_dict overwrites existing files)."""
        utils.save_dict(self.overlap_dict, os.path.join(self.logdir, 'overlap.pickle'))
        utils.save_dict(self.time_dict, os.path.join(self.logdir, 'timing.pickle'))
        if self._has_val_edges():
            utils.save_dict(self.roc_auc_dict, os.path.join(self.logdir, 'ROC-AUC.pickle'))
            utils.save_dict(self.avg_prec_dict, os.path.join(self.logdir, 'avg_prec.pickle'))


class WeightWatcher(Callback):
    """
    Saves the model's factor matrices ``W_down`` and ``W_up`` to disk.

    Every ``invoke_every`` steps, writes them with ``np.savez`` under the name
    ``weights_<step>`` in ``logdir``. The step is zero-padded to ``model.step_str_len``
    digits, and ``logdir`` is not created here.
    """
    def __init__(self, logdir, invoke_every=100):
        super().__init__(invoke_every)
        self.logdir = logdir

    def invoke(self, loss, model):
        filename = f'weights_{model.step:0{model.step_str_len}d}'
        # .cpu() lets this work when the weights live on a GPU; it is a no-op for CPU tensors.
        np.savez(file=os.path.join(self.logdir, filename),
                 W_down=model.W_down.detach().cpu().numpy(),
                 W_up=model.W_up.detach().cpu().numpy())

In [20]:
class Net(object):
    """
    Low-rank model of a transition matrix over the nodes of a graph.

    The transition logits are ``W_down @ W_up``, where ``W_down`` is N x H and ``W_up``
    is H x N. A row-wise softmax turns them into the transition matrix, which
    ``sample_graph`` converts into a graph via ``utils``.

    Parameters
    ----------
    A: scipy.sparse matrix of shape (N, N)
            Training adjacency matrix. ``A.sum() / 2`` is used as the number of
            edges, which assumes A is symmetric with unit weights.
    H: int
            Rank of the factorisation.
    callbacks: iterable of Callback, optional
            Called after every training step as ``callback(loss=..., model=self)``.
            Training stops once any of them has ``training_stopped`` set.
    """
    def __init__(self, A, H, callbacks=None):
        self.num_edges = A.sum()/2
        self.A_sparse = A
        # Dense copy used by the loss. It is placed on the weights' device and dtype so the
        # loss works off-CPU and is not silently promoted to A's dtype (float64 for the
        # graphs built in this notebook).
        self.A = torch.as_tensor(A.toarray(), device=device, dtype=dtype)
        # Number of the last completed training step; 0 means "not trained yet".
        self.step = 0
        # Private copy: avoids the shared mutable default and decouples from the caller's list.
        self.callbacks = list(callbacks) if callbacks is not None else []
        self._optimizer = None
        N = A.shape[0]
        # Scale of the random initialisation of both factors.
        gamma = np.sqrt(2/(N+H))
        self.W_down = (gamma * torch.randn(N, H, device=device, dtype=dtype)).clone().detach().requires_grad_()
        self.W_up = (gamma * torch.randn(H, N, device=device, dtype=dtype)).clone().detach().requires_grad_()
        # Wall-clock seconds spent in training steps (callbacks may add their own time).
        self.total_time = 0

    def __call__(self):
        """Row-wise softmax of the current logits, as an (N, N) numpy array."""
        # No autograd graph is needed for this read-only evaluation.
        with torch.no_grad():
            return torch.nn.functional.softmax(self.get_W(), dim=-1).cpu().numpy()

    def get_W(self):
        """Logits ``W_down @ W_up`` with each row shifted so that its maximum is 0 (numerical stability)."""
        W = torch.mm(self.W_down, self.W_up)
        #if self.force_W_symmetric:
        #    W = torch.max(W, W.T)
        return W - W.max(dim=-1, keepdim=True).values

    def loss(self, W):
        """
        Computes the A-weighted cross-entropy between the rows of A and softmax(W).

        ``0.5 * sum(A * (logsumexp(W) - W)) / num_edges``. Since ``num_edges == A.sum() / 2``,
        this is the A-weighted mean of ``-log_softmax(W)``. The row-wise shift applied by
        ``get_W`` does not change the value.

        Parameters
        ----------
        W: torch.tensor of shape (N, N)
                Logits of learnable (low rank) transition matrix.

        Returns
        -------
        loss: torch.tensor (float)
                Loss at logits.
        """
        d = torch.logsumexp(W, dim=-1, keepdim=True)
        loss = .5 * torch.sum(self.A * (d - W)) / self.num_edges
        return loss

    def _closure(self):
        """Forward + backward pass; returns the loss (closure form accepted by torch optimizers)."""
        W = self.get_W()
        loss = self.loss(W=W)
        self._optimizer.zero_grad()
        loss.backward()
        return loss

    def _train_step(self):
        """Runs one optimizer step and returns ``(loss, elapsed_seconds)``."""
        time_start = time.time()
        loss = self._optimizer.step(self._closure)
        time_end = time.time()
        return loss.item(), (time_end - time_start)

    def train(self, steps, optimizer_fn, optimizer_args, EO_criterion=None):
        """
        Runs up to ``steps`` optimisation steps, continuing the step count of earlier calls.

        Parameters
        ----------
        steps: int
                Number of steps to run in this call.
        optimizer_fn: callable
                Optimizer factory, called as ``optimizer_fn([W_down, W_up], **optimizer_args)``.
                A fresh optimizer is created on every call.
        optimizer_args: dict
                Keyword arguments for ``optimizer_fn``.
        EO_criterion: unused
                Kept for backward compatibility.

        Afterwards ``self.step`` is the last step executed. ``self.steps`` is the absolute
        step number this call was scheduled to end at, which callbacks print as the denominator.
        """
        self._optimizer = optimizer_fn([self.W_down, self.W_up], **optimizer_args)
        first_step = self.step + 1
        last_step = self.step + steps
        # Absolute end step (equal to `steps` on the first call), so "step == steps"
        # checks in callbacks still detect the end of a resumed run.
        self.steps = last_step
        self.step_str_len = len(str(last_step))
        for self.step in range(first_step, last_step + 1):
            loss, elapsed = self._train_step()
            self.total_time += elapsed
            stop = False
            # Every callback runs on each step; stop afterwards if any of them asked to.
            for callback in self.callbacks:
                callback(loss=loss, model=self)
                stop = stop or callback.training_stopped
            if stop:
                break

    def sample_graph(self):
        """Samples a graph with ``num_edges`` edges from the current model; also stores ``self._scores_matrix``."""
        transition_matrix = self()
        self._scores_matrix = utils.scores_matrix_from_transition_matrix(transition_matrix=transition_matrix,
                                                                         symmetric=True)
        sampled_graph = utils.graph_from_scores(self._scores_matrix, self.num_edges)
        return sampled_graph

In [21]:
# Create the log folders first: OverlapTracker and WeightWatcher write into them on every
# invocation and do not create them (start_experiments creates its own folders the same way).
Path('../logs/sampled_graphs').mkdir(parents=True, exist_ok=True)
Path('../logs/weights').mkdir(parents=True, exist_ok=True)

netmodel = Net(A=train_graph,
               H=12,
               callbacks=[OverlapTracker(logdir='../logs/sampled_graphs',
                                         invoke_every=5,
                                         EO_limit=.5,
                                         val_edges=(val_ones, val_zeros)),
                          WeightWatcher(logdir='../logs/weights',
                                        invoke_every=5)])

In [7]:
"""
start = time.time()
netmodel.train(steps=400,
               optimizer_fn=torch.optim.Adam,
               optimizer_args={'lr': 0.1,
                               'weight_decay': 1e-7})
total = time.time() - start
print(total)
"""

"\nstart = time.time()\nnetmodel.train(steps=400,\n               optimizer_fn=torch.optim.Adam,\n               optimizer_args={'lr': 0.1,\n                               'weight_decay': 1e-7})\ntotal = time.time() - start\nprint(total)\n"

In [22]:
def start_experiments(num_experiments,
                      experiment_root,
                      train_graph,
                      H,
                      optimizer,
                      optimizer_args,
                      invoke_every,
                      steps,
                      val_edges=(None, None),
                      EO_limit=1.,
                      seed=None):
    """
    Train ``num_experiments`` independent Net models on the same training graph.

    Each run gets its own folder ``<experiment_root>/Experiment_<i>``. Inside it are the
    ``sampled_graphs`` and ``weights`` subfolders that its OverlapTracker and
    WeightWatcher callbacks write into.

    Parameters
    ----------
    num_experiments: int
        Number of runs.
    experiment_root: str
        Root folder for all runs (created if missing).
    train_graph: scipy.sparse matrix
        Training adjacency matrix passed to every Net.
    H: int
        Rank passed to Net.
    optimizer: callable
        Optimizer factory, e.g. ``torch.optim.Adam``.
    optimizer_args: dict
        Keyword arguments for ``optimizer``.
    invoke_every: int
        Callback frequency in training steps.
    steps: int
        Training steps per run.
    val_edges: tuple
        ``(val_ones, val_zeros)`` forwarded to OverlapTracker.
    EO_limit: float
        Edge-overlap threshold passed to OverlapTracker (default 1., the previously
        hard-coded value).
    seed: int or None
        If given, ``torch.manual_seed(seed + i)`` is called before building run ``i``,
        making each run's weight initialisation reproducible. None leaves the RNG untouched.

    Returns
    -------
    list of Net
        The trained models, in run order.
    """
    Path(experiment_root).mkdir(parents=True, exist_ok=True)
    # Run indices are zero-padded to the digit count of num_experiments.
    width = len(str(num_experiments))
    netmodels = []
    for experiment in range(num_experiments):
        name = f'Experiment_{experiment:0{width}d}'
        path = os.path.join(experiment_root, name)

        path_graphs = os.path.join(path, 'sampled_graphs')
        Path(path_graphs).mkdir(parents=True, exist_ok=True)

        path_weights = os.path.join(path, 'weights')
        Path(path_weights).mkdir(parents=True, exist_ok=True)

        if seed is not None:
            # Net draws its initial weights with torch.randn.
            torch.manual_seed(seed + experiment)

        netmodel = Net(A=train_graph,
                       H=H,
                       callbacks=[OverlapTracker(logdir=path_graphs,
                                                 invoke_every=invoke_every,
                                                 EO_limit=EO_limit,
                                                 val_edges=val_edges),
                                  WeightWatcher(logdir=path_weights,
                                                invoke_every=invoke_every)])

        print(f'\n{name}')
        netmodel.train(steps=steps,
                       optimizer_fn=optimizer,
                       optimizer_args=optimizer_args)
        netmodels.append(netmodel)
    return netmodels

In [23]:
# 20 separately initialised runs on the same training split, evaluated every 5 steps.
models = start_experiments(
    num_experiments=20,
    experiment_root='../logs/CORA-ML/Ours',
    train_graph=train_graph,
    H=12,
    optimizer=torch.optim.Adam,
    optimizer_args=dict(lr=0.1, weight_decay=1e-7),
    invoke_every=5,
    steps=100,
    val_edges=(val_ones, val_zeros),
)


Experiment_00
Step:   5/100, Loss: 7.40340, Edge-Overlap: 0.006
Step:  10/100, Loss: 5.43552, Edge-Overlap: 0.085
Step:  15/100, Loss: 4.12632, Edge-Overlap: 0.336
Step:  20/100, Loss: 3.44316, Edge-Overlap: 0.445
Step:  25/100, Loss: 3.04343, Edge-Overlap: 0.541
Step:  30/100, Loss: 2.78372, Edge-Overlap: 0.622
Step:  35/100, Loss: 2.61644, Edge-Overlap: 0.664
Step:  40/100, Loss: 2.50347, Edge-Overlap: 0.714
Step:  45/100, Loss: 2.42417, Edge-Overlap: 0.737
Step:  50/100, Loss: 2.36605, Edge-Overlap: 0.764
Step:  55/100, Loss: 2.32168, Edge-Overlap: 0.778
Step:  60/100, Loss: 2.28721, Edge-Overlap: 0.800
Step:  65/100, Loss: 2.25938, Edge-Overlap: 0.814
Step:  70/100, Loss: 2.23643, Edge-Overlap: 0.829
Step:  75/100, Loss: 2.21721, Edge-Overlap: 0.838
Step:  80/100, Loss: 2.20099, Edge-Overlap: 0.854
Step:  85/100, Loss: 2.18709, Edge-Overlap: 0.862
Step:  90/100, Loss: 2.17499, Edge-Overlap: 0.865
Step:  95/100, Loss: 2.16449, Edge-Overlap: 0.872
Step: 100/100, Loss: 2.15506, Edge-

Step:  15/100, Loss: 4.07667, Edge-Overlap: 0.354
Step:  20/100, Loss: 3.39720, Edge-Overlap: 0.455
Step:  25/100, Loss: 3.01082, Edge-Overlap: 0.560
Step:  30/100, Loss: 2.76417, Edge-Overlap: 0.615
Step:  35/100, Loss: 2.60746, Edge-Overlap: 0.673
Step:  40/100, Loss: 2.49986, Edge-Overlap: 0.711
Step:  45/100, Loss: 2.42303, Edge-Overlap: 0.744
Step:  50/100, Loss: 2.36609, Edge-Overlap: 0.765
Step:  55/100, Loss: 2.32270, Edge-Overlap: 0.785
Step:  60/100, Loss: 2.28877, Edge-Overlap: 0.811
Step:  65/100, Loss: 2.26150, Edge-Overlap: 0.817
Step:  70/100, Loss: 2.23910, Edge-Overlap: 0.825
Step:  75/100, Loss: 2.22045, Edge-Overlap: 0.845
Step:  80/100, Loss: 2.20464, Edge-Overlap: 0.850
Step:  85/100, Loss: 2.19109, Edge-Overlap: 0.860
Step:  90/100, Loss: 2.17927, Edge-Overlap: 0.874
Step:  95/100, Loss: 2.16901, Edge-Overlap: 0.872
Step: 100/100, Loss: 2.16014, Edge-Overlap: 0.882

Experiment_09
Step:   5/100, Loss: 7.39870, Edge-Overlap: 0.004
Step:  10/100, Loss: 5.38821, Edge-

Step:  25/100, Loss: 2.99241, Edge-Overlap: 0.551
Step:  30/100, Loss: 2.74924, Edge-Overlap: 0.632
Step:  35/100, Loss: 2.59109, Edge-Overlap: 0.679
Step:  40/100, Loss: 2.48726, Edge-Overlap: 0.706
Step:  45/100, Loss: 2.41212, Edge-Overlap: 0.748
Step:  50/100, Loss: 2.35630, Edge-Overlap: 0.766
Step:  55/100, Loss: 2.31374, Edge-Overlap: 0.786
Step:  60/100, Loss: 2.28011, Edge-Overlap: 0.798
Step:  65/100, Loss: 2.25268, Edge-Overlap: 0.822
Step:  70/100, Loss: 2.22976, Edge-Overlap: 0.829
Step:  75/100, Loss: 2.21032, Edge-Overlap: 0.840
Step:  80/100, Loss: 2.19386, Edge-Overlap: 0.854
Step:  85/100, Loss: 2.17986, Edge-Overlap: 0.861
Step:  90/100, Loss: 2.16794, Edge-Overlap: 0.870
Step:  95/100, Loss: 2.15788, Edge-Overlap: 0.883
Step: 100/100, Loss: 2.14912, Edge-Overlap: 0.889

Experiment_17
Step:   5/100, Loss: 7.41197, Edge-Overlap: 0.005
Step:  10/100, Loss: 5.44516, Edge-Overlap: 0.090
Step:  15/100, Loss: 4.11053, Edge-Overlap: 0.349
Step:  20/100, Loss: 3.43992, Edge-

In [9]:
# NOTE(review): `total` is only assigned inside the string-wrapped (disabled) training cell,
# so this fails with NameError on a fresh kernel as far as this notebook shows. 6.3 is an
# unexplained hard-coded value (presumably a reference timing in seconds — confirm).
(total - netmodel.total_time) / (6.3-netmodel.total_time)

3.3384577906085626

In [10]:
# Seconds accumulated by netmodel: time spent in training steps plus OverlapTracker's sampling/overlap time.
netmodel.total_time

2.943687915802002

In [13]:
# NOTE(review): interactive IPython help lookup left over from development; remove before sharing.
sp.save_npz?

In [9]:
# NOTE(review): interactive IPython help lookup left over from development; remove before sharing.
sp.csc_matrix?