### 测试稀疏矩阵运算 $\Leftarrow$ `SparseTensor, spspmm`

In [2]:
import torch
import numpy as np
from torch_sparse import SparseTensor, spspmm

  from .autonotebook import tqdm as notebook_tqdm


In [78]:
# Toy edge list; row 0 / row 1 become the sparse row / col indices below.
edge_index = torch.LongTensor([[0, 0, 0, 1, 3, 3], [0, 1, 2, 0, 1, 2]])
# Number of nodes inferred from the largest node id in the edge list.
nodes = torch.max(edge_index).item() + 1
sp1 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(nodes, nodes))
print(sp1)
# One unit weight per edge.
vl1 = torch.ones(edge_index.shape[1])
# Seeded random 0/1 matrix of the same size, kept as a second sparse operand.
# NOTE(review): sp2 / vl2 are not used by the product in this cell.
np.random.seed(1)
dense = torch.from_numpy(np.random.randint(2, size=(nodes, nodes), dtype=np.int64))
sp2 = SparseTensor.from_dense(dense)
vl2 = torch.ones(torch.count_nonzero(dense))
# Sparse-sparse product of the adjacency with itself. The three size arguments
# are taken from `nodes` instead of a literal 4 so that the cell keeps working
# when the edge list above is changed.
eix, val = spspmm(edge_index, vl1, edge_index, vl1, nodes, nodes, nodes)


SparseTensor(row=tensor([0, 0, 0, 1, 3, 3]),
             col=tensor([0, 1, 2, 0, 1, 2]),
             size=(4, 4), nnz=6, density=37.50%)


In [81]:
# Row index of every non-zero entry of the product. It has to be bound before
# it is printed, otherwise the cell fails on a fresh kernel.
ei = eix[0]
print(ei, val)
# Copy of `ei` shifted right by one position (first element repeated), so that
# `ei - ei_` is non-zero exactly where the row index changes.
# NOTE(review): assumes the entries of `eix` are grouped by row - confirm.
ei_ = torch.cat([ei[0:1], ei[:-1]])

# flatten() instead of squeeze(): with exactly one boundary, squeeze() yields a
# 0-d tensor whose tolist() is a bare int, and the list concatenation below
# would then raise TypeError.
cutpoints = torch.nonzero(ei - ei_).flatten().tolist()
# Add both ends so that consecutive pairs delimit one row group each.
cutpoints = [0] + cutpoints + [ei.shape[0]]
cutpoints

tensor([0, 0, 0, 1, 1, 1, 3]) tensor([2., 1., 1., 1., 1., 1., 1.])


[0, 3, 6, 7]

In [65]:
# Split the product into one (index slice, value slice) pair per row group.
# The 2-row index tensor `eix` is sliced here: `ei` is only its first row
# (1-D), so `ei[:, start:end]` fails when the notebook is run top to bottom.
adj_raw = [
    (eix[:, start:end], val[start:end])
    for start, end in zip(cutpoints[:-1], cutpoints[1:])
]
adj_raw

[(tensor([[0, 0, 0],
          [0, 1, 2]]),
  tensor([2., 1., 1.])),
 (tensor([[1, 1, 1],
          [0, 1, 2]]),
  tensor([1., 1., 1.])),
 (tensor([[3],
          [0]]),
  tensor([1.]))]

In [75]:
# Threshold for each group: the mean of that group's values
# (the division by 1. leaves the mean unchanged).
mean_value = [weights.mean() / 1. for _, weights in adj_raw]
# Within each group keep only the columns whose value is strictly below the
# group's threshold.
adj_selected_raw = [
    edges[:, weights < cutoff]
    for (edges, weights), cutoff in zip(adj_raw, mean_value)
]
# Join the surviving columns of all groups into a single index tensor.
adj_selected = torch.cat(adj_selected_raw, dim=-1)

In [77]:
# Display the selected entries as a SparseTensor with the same (nodes, nodes) size.
SparseTensor.from_edge_index(adj_selected, sparse_sizes=(nodes, nodes))

SparseTensor(row=tensor([0, 0]),
             col=tensor([1, 2]),
             size=(4, 4), nnz=2, density=12.50%)

In [12]:
# Pick the columns of `b` that sit at the positions of the 5 smallest entries
# of `a`, keeping those columns in their original left-to-right order.
a = torch.tensor([1, 3, 14, 2, 2, 15, 444, 31, 2])
b = torch.tensor([
    [1, 3, 14, 2, 2, 15, 444, 31, 2],
    [1, 3, 14, 2, 2, 15, 444, 31, 2],
])
_, index = a.topk(5, largest=False)
print(index)
# Sort the positions so the selection below follows column order.
index = index.sort().values
b[:, index]

tensor([0, 4, 8, 3, 1])


tensor([[1, 3, 2, 2, 2],
        [1, 3, 2, 2, 2]])

### 测试 `plot_influence` $\Leftarrow$ `torch.autograd`

In [1]:
import torch
from torch_sparse import SparseTensor
from models.sognn import NodeLevelSOGNN
from models.sognn_layer import SOGNNConv
from models.gcn import NodeLevelGCN

from models.utils import get_jacobian
from torch_geometric.data import Data

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Edge list of a 10-node test graph; row 0 / row 1 are used as sparse row / col
# indices when the dense adjacency is built below.
edge_index = torch.LongTensor([
    [0, 1, 2, 3, 3, 3, 3, 4, 5, 5, 5, 5, 6, 7, 7, 8, 8, 9],
    [3, 3, 3, 0, 1, 2, 5, 5, 3, 4, 6, 7, 5, 5, 8, 7, 9, 8]
])
# One unit value per edge.
value = torch.ones(edge_index.size(1))
# Random 10 x 20 input that tracks gradients.
# NOTE(review): presumably 10 nodes x 20 features - confirm against the models.
x = torch.rand((10, 20), requires_grad=True)
data = Data(x=x, edge_index=edge_index, edge_attr=value)
# Dense integer adjacency and its third power, (A @ A) @ A.
adj_m = SparseTensor.from_edge_index(edge_index, sparse_sizes=(10, 10)).to_dense(dtype=torch.long)
adj_m_d = adj_m @ adj_m @ adj_m
adj_m_d_s = SparseTensor.from_dense(adj_m_d).cuda()
# Set on the class itself, not on an instance.
SOGNNConv.edge_index_distant = adj_m_d_s
sognn = NodeLevelSOGNN(20, 3, 10, 0, 0.5, 0.001, 0.01)
gcn = NodeLevelGCN(20, 3, 10, 0, 0.5, 0.001, 0.01)
# Print the get_jacobian result of both models for r = 0..7, SOGNN first.
for r in range(8):
    print(get_jacobian(sognn, data, 0, r))
    print(get_jacobian(gcn, data, 0, r))


  rank_zero_deprecation(
Global seed set to 1
Global seed set to 1


   influence  r
0        1.0  0
   influence  r
0        1.0  0
   influence  r
0        1.0  1
   influence  r
0        1.0  1
   influence  r
0   0.055019  2
1   0.055019  2
2   0.889963  2
   influence  r
0   0.379873  2
1   0.379873  2
2   0.240253  2
   influence  r
0   0.241916  3
1   0.343413  3
2   0.414671  3
   influence  r
0        0.0  3
1        0.0  3
2        0.0  3
   influence  r
0        1.0  4
   influence  r
0        0.0  4
   influence  r
0        0.0  5
   influence  r
0        0.0  5
Empty DataFrame
Columns: [influence, r]
Index: []
Empty DataFrame
Columns: [influence, r]
Index: []
Empty DataFrame
Columns: [influence, r]
Index: []
Empty DataFrame
Columns: [influence, r]
Index: []


In [2]:
# list.remove() mutates the list in place and returns None, so "None" is
# what gets printed here.
a = [1, 3, 4]
return_value = a.remove(3)
print(return_value)

None
