In [7]:
import scipy.io
import numpy as np

In [188]:
def construct_adjacency_matrix(edge_list, graph_size):
    """Build a dense, symmetric adjacency matrix from an edge list.

    Parameters
    ----------
    edge_list : iterable of sequences
        Each entry is ``(u, v)`` or ``(u, v, weight)``, where ``u`` and ``v``
        are zero-based node indices. Entries without a third element get
        weight 1.
    graph_size : int
        Minimum number of nodes. The matrix grows beyond this when an edge
        references a larger node index.

    Returns
    -------
    numpy.ndarray
        ``(n, n)`` array of dtype ``int32`` with
        ``n = max(graph_size, largest node index + 1)``. Each edge is written
        in both directions; when the same pair appears more than once, the
        last occurrence wins.
    """
    # Materialise once so generators can be traversed twice.
    edges = list(edge_list)

    # Only the two endpoint columns determine the node count. Taking the
    # maximum over the whole array would let a large weight inflate the matrix.
    max_index = -1
    for edge in edges:
        max_index = max(max_index, edge[0], edge[1])

    # With no edges max_index stays -1, so the size falls back to graph_size.
    n = max(int(max_index) + 1, int(graph_size))

    adjacency_matrix = np.zeros((n, n), dtype=np.int32)
    for edge in edges:
        weight = edge[2] if len(edge) > 2 else 1
        source, target = edge[0], edge[1]
        adjacency_matrix[source, target] = weight
        adjacency_matrix[target, source] = weight

    return adjacency_matrix

def read_matrix(path):
    """Load the first user variable stored in a MATLAB ``.mat`` file.

    Keys beginning with ``'__'`` are skipped; the value under the first
    remaining key is returned, cast to ``int``.

    Raises
    ------
    TypeError
        If the file holds no key other than the ``'__'``-prefixed ones.
    """
    contents = scipy.io.loadmat(path)
    data_keys = (name for name in contents.keys() if not name.startswith('__'))
    first_key = next(data_keys, None)
    if first_key is None:
        raise TypeError('Matrix not found!')
    return contents[first_key].astype(int)

def construct_matrix(dataset, file_names=(), graph_size=0, construct_adjacency=True):
    """Load the layers and ground-truth labels of a dataset under ``data/``.

    Parameters
    ----------
    dataset : str
        Sub-directory of ``data/`` holding ``ground_truth.txt`` and the
        ``.mat`` files.
    file_names : iterable of str
        Base names (without ``.mat``) of the layer files to load. Ignored for
        the ``'imdb'`` dataset, which is read from ``data/imdb/imdb.mat``.
    graph_size : int
        Forwarded to ``construct_adjacency_matrix`` as the minimum node count.
    construct_adjacency : bool
        When true, each loaded file is passed to ``construct_adjacency_matrix``
        as an edge list; otherwise the loaded matrix is used as-is.

    Returns
    -------
    tuple
        For ``'imdb'``: ``(layers, features, true_labels)``, where ``layers``
        stacks the ``'MAM'`` and ``'MDM'`` entries along the last axis and
        ``features`` is the ``'feature'`` entry.
        For any other dataset: ``(layers, true_labels)``, where ``layers`` is
        the single loaded matrix, or all of them stacked along the last axis
        when more than one file is given.

    Raises
    ------
    ValueError
        If ``file_names`` is empty for a dataset other than ``'imdb'``.
    """
    file_names = list(file_names)
    # Fail early and clearly instead of hitting an IndexError after the label
    # file has already been read.
    if dataset != 'imdb' and not file_names:
        raise ValueError(
            f"file_names must list at least one .mat file for dataset {dataset!r}"
        )

    true_labels = np.loadtxt(f'data/{dataset}/ground_truth.txt')
    if dataset == 'imdb':
        imdb = scipy.io.loadmat('data/imdb/imdb.mat')
        return np.stack((imdb['MAM'], imdb['MDM']), axis=-1), imdb['feature'], true_labels

    loaded_matrices = []
    for file in file_names:
        layer = read_matrix(f'data/{dataset}/{file}.mat')
        if construct_adjacency:
            layer = construct_adjacency_matrix(layer, graph_size)
        loaded_matrices.append(layer)

    if len(loaded_matrices) > 1:
        return np.stack(loaded_matrices, axis=-1), true_labels
    return np.array(loaded_matrices[0]), true_labels

In [189]:
# ACM: load the 'PAP' and 'PLP' files as stored (construct_adjacency=False skips
# the edge-list conversion); the two layers are stacked along the last axis.
acm, true_labels = construct_matrix('acm', ['PAP', 'PLP'], construct_adjacency=False)
# The 'feature' file goes through the same loader; the labels returned by this
# second call are discarded.
acm_attributes, _ = construct_matrix('acm', ['feature'], construct_adjacency=False)
# Print the shapes of the layers, attributes and labels as a quick sanity check.
print(acm.shape)
print(acm_attributes.shape)
print(true_labels.shape)

(3025, 3025, 2)
(3025, 1870)
(3025,)


In [190]:
# DBLP: each of the four files is converted to an adjacency matrix
# (construct_adjacency defaults to True) with graph_size=8401 as the minimum
# node count, then the layers are stacked along the last axis.
# NOTE: this rebinds the notebook-level name `true_labels`.
dblp, true_labels = construct_matrix('dblp', ['APNet', 'citation', 'co_citation', 'coauthorNet'], 8401)
# Print the shapes of the stacked layers and the labels.
print(dblp.shape)
print(true_labels.shape)

(8401, 8401, 4)
(8401,)


In [191]:
# Flickr: both files are converted to adjacency matrices with the default
# graph_size=0, so the matrix size is derived from the loaded data itself.
# NOTE: this rebinds the notebook-level name `true_labels`.
flickr, true_labels = construct_matrix('flickr', ['layer0', 'layer1'])
# Print the shapes of the stacked layers and the labels.
print(flickr.shape)
print(true_labels.shape)

(10364, 10364, 2)
(10364,)


In [192]:
# IMDB: construct_matrix special-cases this dataset. It reads data/imdb/imdb.mat
# and returns three values (stacked 'MAM'/'MDM' layers, the 'feature' entry and
# the labels); the file_names argument is not used on this path.
# NOTE: this rebinds the notebook-level name `true_labels`.
imdb, imdb_attributes, true_labels = construct_matrix('imdb', [])
# Print the shapes of the layers, attributes and labels.
print(imdb.shape)
print(imdb_attributes.shape)
print(true_labels.shape)

(3550, 3550, 2)
(3550, 2000)
(3550,)
