In [1]:
import numpy as np
import scipy.sparse as sp
import torch

In [2]:
def encode_onehot(labels):
    """One-hot encode a sequence of class labels.

    Classes are assigned to columns in sorted order of the distinct label
    values, so the class -> column mapping is identical on every run.
    (Iterating a ``set`` of strings, as an unordered container, can yield a
    different order in each Python process because of hash randomization.)

    Args:
        labels: 1-D sequence or array of mutually sortable label values.

    Returns:
        np.ndarray of dtype int32 and shape (len(labels), num_classes), with a
        single 1 per row in the column of that row's class.
    """
    # np.unique sorts the distinct classes and gives each label's class index.
    classes, class_index = np.unique(np.asarray(labels), return_inverse=True)
    class_index = np.asarray(class_index).reshape(-1)
    onehot = np.zeros((len(class_index), len(classes)), dtype=np.int32)
    onehot[np.arange(len(class_index)), class_index] = 1
    return onehot
# Read cora.content. Each row holds the paper id (first column), the paper's
# feature columns, and its class label (last column), all parsed as strings.
dataset = "cora"
path = "./data/cora/"
idx_features_labels = np.genfromtxt(
    "{}{}.content".format(path, dataset),
    dtype=np.dtype(str),
)

In [3]:
# Show the shape and container type of the raw string table.
print(idx_features_labels.shape, type(idx_features_labels), sep="\n")

(2708, 1435)
<class 'numpy.ndarray'>


In [4]:
# Node feature matrix (all columns between the id and the label) and one-hot
# label matrix (last column).
# Cast to float32 *before* building the CSR matrix: the table is still an array
# of strings, and numpy's nonzero() treats the string "0" as truthy. Building
# the CSR matrix from the string array directly would therefore store every
# zero as an explicit entry, so the "sparse" matrix would hold all values.
features = sp.csr_matrix(idx_features_labels[:, 1:-1].astype(np.float32), dtype=np.float32)
labels = encode_onehot(idx_features_labels[:, -1])

In [5]:
# Check the container type of the features and the shape of the one-hot labels.
print(type(features), labels.shape, sep="\n")
# Small dense matrix used to illustrate sparse (CSR) storage.
a = np.array([
    [1, 2, 0],
    [2, 1, 0],
])
print(a)

<class 'scipy.sparse.csr.csr_matrix'>
(2708, 7)
[[1 2 0]
 [2 1 0]]


In [6]:
# Convert the dense example to CSR. Printing a sparse matrix lists only the
# stored (row, col) -> value entries, so the all-zero column does not appear.
# a.shape still reports the original (logical) matrix shape, not the size of
# the sparse storage.
a = sp.csr_matrix(a, dtype=np.float32)
print(a, a.shape, sep="\n")

  (0, 0)	1.0
  (0, 1)	2.0
  (1, 0)	2.0
  (1, 1)	1.0
(2, 3)


In [7]:
# Paper ids: the first column of the content table.
idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
# Re-number papers 0..N-1 in the order they appear in the content file; these
# are the row indices of features/labels.
idx_map = {j: i for i, j in enumerate(idx)}
# A repeated paper id would silently keep only its last row index.
assert len(idx_map) == len(idx), "duplicate paper ids in the content file"

In [8]:
# Read the citation list: each row of the .cites file is a pair of raw paper ids.
edges_unordered = np.genfromtxt(
    "{}{}.cites".format(path, dataset),
    dtype=np.int32,
)
print(edges_unordered.shape)

(5429, 2)


In [9]:
# Translate every raw paper id in the citation list into its row index from
# idx_map, keeping the same shape as edges_unordered.
flat_ids = edges_unordered.flatten()
# Fail with a readable message if a cited paper is absent from the content file
# (otherwise the lookup yields None and the int32 conversion fails opaquely).
unknown_ids = sorted({int(pid) for pid in flat_ids if pid not in idx_map})
if unknown_ids:
    raise KeyError(
        "{} paper id(s) in {}.cites are missing from {}.content, e.g. {}".format(
            len(unknown_ids), dataset, dataset, unknown_ids[:5]))
edges = np.array([idx_map[pid] for pid in flat_ids], dtype=np.int32).reshape(edges_unordered.shape)
print(edges.shape)

(5429, 2)


In [10]:
# Build the adjacency matrix: one entry of value 1 at (row, col) for every row
# of `edges` (repeated pairs are summed when the COO matrix is converted).
adj = sp.coo_matrix(
    (np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
    shape=(labels.shape[0], labels.shape[0]),
    dtype=np.float32,
)
# Build a symmetric adjacency matrix by taking the elementwise maximum of adj
# and its transpose. This equals adj + adj.T.multiply(adj.T > adj)
# - adj.multiply(adj.T > adj), but computes the comparison only once.
adj = adj.maximum(adj.T)

In [21]:
# Check the symmetrization formula on a tiny 5x5 directed graph.
a = sp.coo_matrix(
    (np.ones(5), ([1, 2, 3, 1, 3], [2, 3, 1, 4, 2])),
    shape=(5, 5),
    dtype=np.float32,
)
# b marks positions where the transposed entry is larger than the entry itself.
b = a.T > a
print(b)
# Entries of a.T to add at those positions.
print(a.T.multiply(b))
# Entries of a to remove at those positions (empty for this 0/1 example).
print(a.multiply(b))
a = a + a.T.multiply(b) - a.multiply(b)
print(a)

  (1, 3)	True
  (2, 1)	True
  (4, 1)	True
  (1, 3)	1.0
  (2, 1)	1.0
  (4, 1)	1.0

  (1, 2)	1.0
  (1, 3)	1.0
  (1, 4)	1.0
  (2, 1)	1.0
  (2, 3)	1.0
  (3, 1)	1.0
  (3, 2)	1.0
  (4, 1)	1.0
