In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random
import torch
import pickle
import math

### Change the parameters below, then run the remaining cells as-is.

In [24]:
# Number of traces: how many distinct root-to-leaf traces are sampled from the tree.
N_SAMPLE = 4

# Fan-out: number of children attached to every non-leaf node of the tree
# (also the logarithm base used when computing the required depth).
N_FANOUT = 4

In [25]:
# Logarithm in an arbitrary base (defaults to the fan-out N_FANOUT).
def log(x, base=None):
    """Return the logarithm of ``x`` in the given ``base``.

    Parameters
    ----------
    x : number or array-like
        Value(s) whose logarithm is taken; forwarded to ``np.log``.
    base : number, optional
        Logarithm base. When omitted, the module-level ``N_FANOUT`` is used,
        which matches the original single-argument behaviour.

    Returns
    -------
    The floating-point ratio ``np.log(x) / np.log(base)``. Because it is a
    ratio of two rounded floats, the result for an exact power of ``base``
    may differ from the integer exponent by a rounding error.
    """
    if base is None:
        # Looked up at call time so that changing N_FANOUT takes effect.
        base = N_FANOUT
    return np.log(x) / np.log(base)

In [26]:
# Required depth: the smallest depth at which a complete N_FANOUT-ary tree has
# at least N_SAMPLE leaves, i.e. the smallest d with N_FANOUT**d >= N_SAMPLE.
# Exact integer arithmetic is used instead of ceil(log(...)): a floating-point
# log ratio can land just above an exact power (e.g. log(125)/log(5) evaluates
# to 3.0000000000000004) and ceil would then report one level too many.
if N_SAMPLE > 1 and N_FANOUT < 2:
    # With fewer than two children per node the tree only ever has one leaf,
    # so the requested number of distinct traces can never be reached.
    raise ValueError("N_FANOUT must be at least 2 to produce more than one trace")

depth = 0
leafCount = 1  # number of leaves of a complete tree of the current depth
while leafCount < N_SAMPLE:
    leafCount *= N_FANOUT
    depth += 1
print("Required DEPTH: ", depth)

Required DEPTH:  1


### Create graph

In [27]:
# Adjacency list of a complete N_FANOUT-ary tree: edgeList[n] holds the ids of
# the children of node n. Node 0 is the root; ids are handed out level by level.
edgeList = [[]]
# Id of the next node that still has to receive its children.
pointer = 0
for level in range(depth):
    # Every node on the current level (N_FANOUT**level of them) gets children.
    for _ in range(N_FANOUT**level):
        firstChild = len(edgeList)
        childIds = list(range(firstChild, firstChild + N_FANOUT))
        # Register the new children as (for now) leaf nodes, each with its own list.
        edgeList.extend([] for _ in childIds)
        # Link them to their parent.
        edgeList[pointer].extend(childIds)
        pointer += 1

### Create trace data with the required depth

In [32]:
# Collect N_SAMPLE distinct root-to-leaf traces by repeated random walks.
traces = set()
while len(traces) < N_SAMPLE:
    # Walk down from the root (node 0), picking a random child at each step,
    # until a node without children is reached.
    path = [0]
    children = edgeList[0]
    while children:
        path.append(random.choice(children))
        children = edgeList[path[-1]]

    # Keep only walks that span the full depth (depth edges -> depth + 1 nodes);
    # the set silently discards walks that were already collected.
    if len(path) == depth + 1:
        traces.add(tuple(path))

a = list(traces)
# Stack the traces into a (number of traces) x (depth + 1) float32 tensor of node ids.
traces = torch.tensor(a, dtype=torch.float32)

### Convert to Adjacency Matrix

In [33]:
# Dense adjacency matrix: entry [parent][child] is 1 when edgeList lists
# `child` among the children of `parent`, and 0 otherwise.
nodeCount = len(edgeList)
adjacencyMatrix = np.zeros((nodeCount, nodeCount))
for parent, children in enumerate(edgeList):
    for child in children:
        adjacencyMatrix[parent, child] = 1

# Convert the NumPy array to a float32 torch tensor.
adjacencyMatrix = torch.tensor(adjacencyMatrix, dtype=torch.float32)

### One hot encoding

In [34]:
# One-hot label per node: row i has a single 1 in column i, i.e. the
# identity matrix with one row/column per entry of edgeList.
nodeLabels = np.eye(len(edgeList))

# Convert the NumPy array to a float32 torch tensor.
nodeLabels = torch.tensor(nodeLabels, dtype=torch.float32)

In [35]:
# Display the adjacency matrix as the cell output for a quick visual check.
adjacencyMatrix

tensor([[0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

### Output the data

In [11]:
# Pickle the generated tensors into the current working directory.
# Each file is opened in a `with` block so its handle is flushed and closed
# as soon as the dump finishes, rather than being left open.
with open("adjacencyMatrix.pkl", "wb") as f:
    pickle.dump(adjacencyMatrix, f)
with open("nodeLabels.pkl", "wb") as f:
    pickle.dump(nodeLabels, f)
with open("traces.pkl", "wb") as f:
    pickle.dump(traces, f)