In [1]:
import sys
sys.path.insert(0, '../src/')

import os
import abc
import warnings
warnings.filterwarnings('ignore')

import time

#import tensorflow as tf
import torch
device = 'cpu'
dtype = torch.float32

import scipy.sparse as sp
import numpy as np
from matplotlib import pyplot as plt
from scipy.sparse import save_npz, load_npz, csr_matrix
from sklearn.metrics import roc_auc_score, average_precision_score
import time
import pandas as pd
%matplotlib inline

# from netgan.netgan import *
# from netgan import utils

from net.utils import *
from net import utils_netgan as utils
import net.net as net

In [2]:
# Edge shares held out for validation / test, and the seed for the split.
val_share = 0.1
test_share = 0.05
seed = 481516234

# Load Cora-ML (an alternative dataset tried earlier: load_npz('../data/gemsec.npz')).
_A_obs, _X_obs, _z_obs = utils.load_npz('../data/cora_ml.npz')
# Symmetrise the adjacency matrix, then clip entries above 1
# (e.g. edges stored in both directions) back to 1.
_A_obs = _A_obs + _A_obs.T
_A_obs[_A_obs > 1] = 1
# Keep only the largest connected component.
lcc = utils.largest_connected_components(_A_obs)
_A_obs = _A_obs[lcc, :][:, lcc]
_N = _A_obs.shape[0]

Selecting 1 largest connected components


In [3]:
train_ones, val_ones, val_zeros, test_ones, test_zeros = utils.train_val_test_split_adjacency(
    _A_obs, val_share, test_share, seed, undirected=True, connected=True, asserts=False)

# Build the training adjacency matrix with an explicit (N, N) shape: without it
# scipy infers the shape from the largest index present and would silently drop
# trailing nodes that have no training edge.
train_graph = sp.coo_matrix(
    (np.ones(len(train_ones)), (train_ones[:, 0], train_ones[:, 1])),
    shape=(_N, _N),
).tocsr()
# Sparse symmetry check; avoids materialising two dense N x N arrays.
assert (train_graph != train_graph.T).nnz == 0, 'training graph is not symmetric'

In [127]:
class Callback(abc.ABC):
    """
    Base class for training callbacks, called after every step of ``Net.train``.

    Subclasses implement :meth:`invoke`, which is forwarded to only when
    ``model.step`` is a multiple of ``invoke_every``. A subclass may call
    :meth:`stop_training` to set ``training_stopped``, which the trainer reads.
    """
    def __init__(self, invoke_every):
        # Guard here: a zero period would otherwise raise ZeroDivisionError on the first step.
        if invoke_every < 1:
            raise ValueError(f'invoke_every must be a positive integer, got {invoke_every}')
        self.training_stopped = False
        self.invoke_every = invoke_every

    def __call__(self, loss, model):
        """Forward to :meth:`invoke` when ``model.step`` is a multiple of ``invoke_every``."""
        if model.step % self.invoke_every == 0:
            self.invoke(loss, model)

    def stop_training(self):
        """Set the ``training_stopped`` flag to request the end of training."""
        self.training_stopped = True

    @abc.abstractmethod
    def invoke(self, loss, model):
        """Hook run every ``invoke_every`` steps with the current loss and model."""


class OverlapTracker(Callback):
    """
    Every ``invoke_every`` steps this callback samples a graph from the model and:
    - prints the step, the loss and the edge overlap with the model's graph
      (``utils.edge_overlap(...)`` divided by ``model.num_edges``),
    - requests a stop once that overlap reaches ``EO_limit``,
    - stores the sampled graph, either via ``save_npz`` in ``logdir`` or, when no
      ``logdir`` is given, in the in-memory list ``self.logs``.
    """
    def __init__(self, logdir=None, invoke_every=100, EO_limit=1.):
        super().__init__(invoke_every)
        self.logdir = logdir
        # Always defined so the in-memory fallback works for any falsy logdir
        # (e.g. ''), matching the `if self.logdir:` test in `invoke`.
        self.logs = []
        self.EO_limit = EO_limit

    def invoke(self, loss, model):
        sampled_graph = model.sample_graph()
        # TODO: tune edge_overlap func
        overlap = utils.edge_overlap(model.A.cpu().numpy(), sampled_graph) / model.num_edges

        step_str = f'{model.step:{model.step_str_len}d}'
        print(f'Step: {step_str}/{model.steps}, Loss: {loss:.5f}, Edge-Overlap: {overlap:.3f}')
        if overlap >= self.EO_limit:
            self.stop_training()

        if self.logdir:
            # Create the directory on first use; saving would fail if it is missing.
            os.makedirs(self.logdir, exist_ok=True)
            filename = f'graph_{model.step:0{model.step_str_len}d}'
            save_npz(file=os.path.join(self.logdir, filename),
                     matrix=sampled_graph)
        else:
            self.logs.append(sampled_graph)


class WeightWatcher(Callback):
    """
    Every ``invoke_every`` steps, saves the model's factor matrices ``W_down`` and
    ``W_up`` with ``np.savez`` to ``logdir/weights_<zero-padded step>.npz``.
    """
    def __init__(self, logdir, invoke_every=100):
        super().__init__(invoke_every)
        self.logdir = logdir

    def invoke(self, loss, model):
        # Create the directory on first use; np.savez fails if it is missing.
        os.makedirs(self.logdir, exist_ok=True)
        filename = f'weights_{model.step:0{model.step_str_len}d}'
        # np.savez appends the '.npz' extension itself.
        np.savez(file=os.path.join(self.logdir, filename),
                 W_down=model.W_down.detach().cpu().numpy(),
                 W_up=model.W_up.detach().cpu().numpy())

In [128]:
class Net(object):
    """
    Low-rank model of a random-walk transition matrix, fitted to a graph.

    The logits are ``W = W_down @ W_up`` (inner dimension ``H``). The row-wise
    softmax of ``W`` is the transition matrix. Training minimises the summed
    negative log-probability that this softmax assigns to the entries of ``A``,
    divided by ``num_edges``.

    Parameters
    ----------
    A: np.ndarray of shape (N, N)
            Dense adjacency matrix of the training graph.
    H: int
            Inner (rank) dimension of the factorisation.
    callbacks: list, optional
            Objects called as ``callback(loss=..., model=self)`` after every
            training step. Each must expose a ``training_stopped`` flag.
            Defaults to no callbacks.
    """
    def __init__(self, A, H, callbacks=None):
        # Sum of all entries; for a symmetric 0/1 matrix each undirected edge counts twice.
        self.num_edges = A.sum()
        # Same dtype/device as the weights. This avoids silently up-casting the
        # loss to float64 (which also doubles the memory of the dense N x N
        # matrix), and keeps A compatible if `device` is changed.
        self.A = torch.tensor(A, dtype=dtype, device=device)
        # Number of the next step to run (1-based); equals the current step while callbacks run.
        self.step = 1
        # A fresh list per instance: a mutable default would be shared across models.
        self.callbacks = callbacks if callbacks is not None else []
        self._optimizer = None
        N = A.shape[0]
        # Glorot-style scale for the two factor matrices.
        gamma = float(np.sqrt(2 / (N + H)))
        self.W_down = (gamma * torch.randn(N, H, device=device, dtype=dtype)).requires_grad_()
        self.W_up = (gamma * torch.randn(H, N, device=device, dtype=dtype)).requires_grad_()
        # Seconds accumulated inside optimiser steps only (callbacks excluded).
        self.total_time = 0

    def __call__(self):
        """Return the row-wise softmax of the logits (the transition matrix) as a NumPy array."""
        return torch.nn.functional.softmax(self.get_W(), dim=-1).detach().cpu().numpy()

    def get_W(self):
        """
        Return the logits ``W_down @ W_up`` with each row shifted by its maximum.

        The shift leaves the row-wise softmax unchanged and keeps ``exp(W) <= 1``.
        """
        W = torch.mm(self.W_down, self.W_up)
        # Out of place: avoids mutating a tensor that is part of the autograd graph.
        return W - W.max(dim=-1, keepdim=True).values

    def loss(self, W):
        """
        Computes the cross-entropy of the row-wise softmax of ``W``, weighted by ``A``.

        Parameters
        ----------
        W: torch.tensor of shape (N, N)
                Logits of learnable (low rank) transition matrix.

        Returns
        -------
        loss: torch.tensor (float)
                ``sum_ij A_ij * (-log softmax(W)_ij) / num_edges``.
        """
        # logsumexp stays numerically stable even for logits that were not max-shifted.
        log_normalizer = torch.logsumexp(W, dim=-1, keepdim=True)
        return torch.sum(self.A * (log_normalizer - W)) / self.num_edges

    def _closure(self):
        """Recompute the loss and its gradients; passed to ``optimizer.step``."""
        W = self.get_W()
        loss = self.loss(W=W)
        self._optimizer.zero_grad()
        loss.backward()
        return loss

    def _train_step(self):
        """Run one optimiser step and return ``(loss, elapsed_seconds)``."""
        time_start = time.perf_counter()
        loss = self._optimizer.step(self._closure)
        elapsed = time.perf_counter() - time_start
        return loss.item(), elapsed

    def train(self, steps, optimizer_fn, optimizer_args, EO_criterion=None):
        """
        Run up to ``steps`` optimisation steps, continuing from ``self.step``.

        Parameters
        ----------
        steps: int
                Maximum number of steps for this call.
        optimizer_fn: callable
                Called as ``optimizer_fn([W_down, W_up], **optimizer_args)``,
                e.g. ``torch.optim.Adam``. A new optimiser is created on every call.
        optimizer_args: dict
                Keyword arguments for ``optimizer_fn``.
        EO_criterion: unused
                Kept for backward compatibility only.

        Training stops early once any callback has ``training_stopped`` set.
        Afterwards ``self.step`` is the number of the next step to run, so a
        further call continues the numbering instead of repeating the last step.
        """
        self._optimizer = optimizer_fn([self.W_down, self.W_up], **optimizer_args)
        self.steps = steps
        first_step = self.step
        last_step = first_step + steps - 1
        # Width used to zero-pad step numbers in logs/filenames; must fit this call's last step.
        self.step_str_len = len(str(max(last_step, 1)))
        # NOTE(review): callbacks never reset `training_stopped`, so after an early
        # stop, a further call to train() runs only a single step.
        step = first_step - 1
        for step in range(first_step, last_step + 1):
            self.step = step
            loss, elapsed = self._train_step()
            self.total_time += elapsed
            for callback in self.callbacks:
                callback(loss=loss, model=self)
            if any(callback.training_stopped for callback in self.callbacks):
                break
        self.step = step + 1

    def sample_graph(self):
        """
        Sample a graph from the current transition matrix.

        Scores are built symmetrically, and ``num_edges`` is passed to
        ``utils.graph_from_scores`` (presumably the target edge count — confirm).
        """
        transition_matrix = self()
        scores_matrix = scores_matrix_from_transition_matrix(transition_matrix=transition_matrix,
                                                             symmetric=True)
        sampled_graph = utils.graph_from_scores(scores_matrix, self.num_edges)
        return sampled_graph

In [129]:
# Callbacks: every 5 steps, sample a graph (stopping once the edge-overlap ratio
# reaches 0.5) and checkpoint the factor matrices.
overlap_tracker = OverlapTracker(logdir='../logs/sampled_graphs', invoke_every=5, EO_limit=.5)
weight_watcher = WeightWatcher(logdir='../logs/weights', invoke_every=5)

# Rank-12 factorisation fitted to the dense training adjacency matrix.
netmodel = Net(A=train_graph.toarray(), H=12, callbacks=[overlap_tracker, weight_watcher])





In [130]:
# Wall-clock time of the whole call, including callback overhead (graph sampling
# and saving). `netmodel.total_time` accumulates only the optimiser-step time.
# perf_counter is monotonic, unlike time.time, so it is the right clock for durations.
start = time.perf_counter()
netmodel.train(steps=400,
               optimizer_fn=torch.optim.Adam,
               optimizer_args={'lr': 0.1,
                               'weight_decay': 1e-7})
total = time.perf_counter() - start
print(total)

Step:   5/400, Loss: 7.39126, Edge-Overlap: 0.006
Step:  10/400, Loss: 5.35504, Edge-Overlap: 0.092
Step:  15/400, Loss: 4.02432, Edge-Overlap: 0.355
Step:  20/400, Loss: 3.35688, Edge-Overlap: 0.456
Step:  25/400, Loss: 2.97855, Edge-Overlap: 0.571
14.361342430114746


In [9]:
# NOTE(review): scratch calculation. Numerator: time spent outside optimiser steps
# (callbacks and loop overhead), i.e. total wall-clock minus netmodel.total_time.
# The 6.3 is an unexplained hard-coded time, presumably seconds from an earlier
# reference run — TODO document or remove. This cell's execution count (9) is
# lower than the training cell's (130), so its displayed output is from an older run.
(total - netmodel.total_time) / (6.3-netmodel.total_time)

3.3384577906085626

In [10]:
# Seconds accumulated by Net.train inside optimiser steps only (callbacks excluded).
netmodel.total_time

2.943687915802002

In [13]:
# NOTE(review): leftover interactive documentation lookup — remove before sharing.
sp.save_npz?

In [9]:
# NOTE(review): leftover interactive documentation lookup — remove before sharing.
sp.csc_matrix?

In [124]:
# NOTE(review): leftover documentation lookup (np.savez is used by WeightWatcher) — remove before sharing.
help(np.savez)

Help on function savez in module numpy:

savez(file, *args, **kwds)
    Save several arrays into a single file in uncompressed ``.npz`` format.
    
    If arguments are passed in with no keywords, the corresponding variable
    names, in the ``.npz`` file, are 'arr_0', 'arr_1', etc. If keyword
    arguments are given, the corresponding variable names, in the ``.npz``
    file will match the keyword names.
    
    Parameters
    ----------
    file : str or file
        Either the file name (string) or an open file (file-like object)
        where the data will be saved. If file is a string or a Path, the
        ``.npz`` extension will be appended to the file name if it is not
        already there.
    args : Arguments, optional
        Arrays to save to the file. Since it is not possible for Python to
        know the names of the arrays outside `savez`, the arrays will be saved
        with names "arr_0", "arr_1", and so on. These arguments can be any
        expression.
    kwds 