# Demonstration: Maximising Persistence

We demonstrate how we can find a spectral wavelet, parametrised by a basis of chebyshev polynomials, such that the total persistence of a graph's filtration is maximised. 

In [1]:
import numpy as np
import networkx as nx

import torch
import torch.optim as optim
import torch.nn as nn

import numpy as np
from scipy.linalg import eigh

import pickle

from models import models
from models import utils
import numpy.polynomial.chebyshev as cheby

## Choosing a Model

We consider a wavelet spanned by a degree 6 chebyshev polynomial, with coefficients normalised to 1. We truncate the barcode to the 25 most persistent intervals as a computational necessity

In [2]:
cheby_degree = 6
max_intervals = 25

## Loading the data


In [3]:
dataset_name = 'MUTAG'

graph_list = pickle.load(open('data_example/' + dataset_name + '/networkx_graphs.pkl', 'rb'))

print('The ' + dataset_name + ' dataset has ', len(graph_list), ' graphs.')

The MUTAG dataset has  188  graphs.


## Preprocess the data

The ChebyshevWavelets pytorch module takes in a list of dictionaries, each dictionary representing the necessary data to compute spectral wavelets.

The dictionary contains the following fields:
- 'chebyshev': an intermediary matrix (num vertices) x (chebyshev degree) for computations; the vertex values of the filtration are given by the product of this matrix with the vector of chebyshev coefficients
- 'simplex_tree': the Gudhi representation of a simplicial complex

In [4]:
data = []
for i in range(len(graph_list)):
    
    G = graph_list[i]
    datum = dict()
    L = nx.normalized_laplacian_matrix(G)
    w, v = eigh(L.todense()) #computes eigenvalues w and eigenvectors
    vandermonde = cheby.chebvander(w.flatten()-1, cheby_degree)
    datum['chebyshev'] = torch.from_numpy(np.matmul(v**2, vandermonde[:, 1:])).float()

    hks = np.matmul(v**2,  np.exp(-0.1*w)).flatten() #random initial filtration for the simplex_tree
    st = utils.simplex_tree_constructor([list(e) for e in G.edges()])
    datum['simplex_tree'] = utils.filtration_update(st, hks)
    data.append(datum)
print('Finished initial processing')
del graph_list


Finished initial processing


## Experiment Design

### Tenfolding

We find a set of chebyshev coefficients that maximise the average $L^2$-persistence of the graph barcodes across the dataset

We perform a ten-fold cross validation. In a ten-fold, we randomly partition the dataset into 10 portions. We perform the maximisation across 9 portions and then validate the learnt parameters on the remaining portion. We cycle through the ten portions so that each portion is the validation set once.

Normally in machine learning practices, we conduct 10 ten-folds and average across all 100 validation measures, but in the interest of time we only perform one ten-fold in this demo.

In [10]:
data_len = len(data)
test_size = data_len // 10
train_size = data_len - test_size

### Optimisation using PyTorch tools
We specify the batch size and the number of epochs. We use stochastic gradient descent.

In [6]:
### training parameters #####
batch_size = 20
train_batches = np.ceil((data_len-test_size)/batch_size).astype(int)
max_epoch = 25

print('num points = ', data_len, ' number of batches = ', train_batches, ' batch size = ', batch_size, ' test size ', test_size)

num points =  188  number of batches =  9  batch size =  20  test size  18


In [15]:
####### torch random seeds #######
shuffidx = list(range(data_len)) # data indexer

torch.manual_seed(99)
rng_state= torch.get_rng_state() #seed init to ensure same initial conditions for each training

p_tracker = []
tt_loss = []
tn_loss = []

for fold in range(10):
    print ('> fold ', fold)

    param_tracker = []
    test_loss = []
    train_loss = []

    test_bottom = fold * test_size
    test_top = (1+fold) * test_size
    test_indices = shuffidx[test_bottom : test_top]
    train_indices = shuffidx[0:test_bottom] + shuffidx[test_top :]

    torch.set_rng_state(rng_state) #fix init state
    barcodes = models.ChebyshevWavelets(cheby_degree = cheby_degree, max_intervals = max_intervals)
    param_tracker.append(list(barcodes.cheby_params.detach().flatten().numpy()))

    optimizer = optim.SGD(barcodes.parameters(), lr=1e-4, weight_decay = 0.0)

    for epoch in range(max_epoch):
        
        barcodes.train()
        np.random.shuffle(train_indices)
        for b in range(train_batches):

            train_indices_batch = train_indices[b*batch_size : (b+1)*batch_size ]
            optimizer.zero_grad()
            births, deaths = barcodes([data[i] for i in train_indices])
            loss = -torch.sum((deaths - births)**2)/train_size
            loss.backward()
            
            optimizer.step()

        barcodes.eval()
        param_tracker.append(list(barcodes.cheby_params.detach().flatten().numpy()))


        barcodes.eval()
        b,d = barcodes([data[i] for i in train_indices])
        tnl = torch.sum((d- b)**2)/train_size
        b,d  = barcodes([data[i] for i in test_indices])
        ttl = torch.sum((d- b)**2)/test_size
        test_loss.append(ttl.detach().numpy())
        train_loss.append(tnl.detach().numpy())
        
        if epoch % 5 == 0:
            print(epoch, param_tracker[-1])
            print('train: ', train_loss[-1])
            print('test: ',test_loss[-1])

    p_tracker.append(param_tracker)
    tt_loss.append(test_loss)
    tn_loss.append(train_loss)

> fold  0
0 [0.19873825, -0.38596588, -0.24803096, -0.21418718, 0.24152116, -0.2106402]
train:  0.94126076
test:  0.84419525
5 [0.1942889, -0.40188175, -0.24247803, -0.19857855, 0.23643611, -0.21257612]
train:  1.069699
test:  0.9619286
10 [0.18935049, -0.41823336, -0.23631474, -0.1813835, 0.23071557, -0.21411334]
train:  1.2166607
test:  1.0961475
15 [0.18388596, -0.43480247, -0.22949494, -0.16251977, 0.22434296, -0.21524544]
train:  1.3831999
test:  1.2487042
20 [0.1778703, -0.45137846, -0.22198732, -0.142042, 0.21725197, -0.21585584]
train:  1.5694593
test:  1.4189979
> fold  1
0 [0.19876146, -0.38589865, -0.24805996, -0.21424843, 0.24155203, -0.21060948]
train:  0.9147942
test:  1.0885763
5 [0.1944401, -0.40147173, -0.24266678, -0.1989808, 0.2366406, -0.21239157]
train:  1.0373375
test:  1.2296659
10 [0.18965077, -0.4174744, -0.23668957, -0.18219972, 0.23111643, -0.21378535]
train:  1.1772486
test:  1.3898085
15 [0.18435724, -0.43371433, -0.23008308, -0.16378641, 0.22496746, -0.214

In [16]:
np.array(tt_loss

[[array(0.84419525, dtype=float32),
  array(0.8664122, dtype=float32),
  array(0.88925713, dtype=float32),
  array(0.91281945, dtype=float32),
  array(0.9370448, dtype=float32),
  array(0.9619286, dtype=float32),
  array(0.9874242, dtype=float32),
  array(1.0134817, dtype=float32),
  array(1.0402309, dtype=float32),
  array(1.0677447, dtype=float32),
  array(1.0961475, dtype=float32),
  array(1.1253159, dtype=float32),
  array(1.1551636, dtype=float32),
  array(1.1856643, dtype=float32),
  array(1.2168387, dtype=float32),
  array(1.2487042, dtype=float32),
  array(1.2812785, dtype=float32),
  array(1.314585, dtype=float32),
  array(1.3486603, dtype=float32),
  array(1.383465, dtype=float32),
  array(1.4189979, dtype=float32),
  array(1.455358, dtype=float32),
  array(1.4924461, dtype=float32),
  array(1.5301964, dtype=float32),
  array(1.5685762, dtype=float32)],
 [array(1.0885763, dtype=float32),
  array(1.1152787, dtype=float32),
  array(1.142703, dtype=float32),
  array(1.1709275, d