Skip to content

Commit

Permalink
use csr only for readonly graph.
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da committed Jan 14, 2019
1 parent 4042b3d commit 8e24bb0
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions python/dgl/graph_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,37 +551,37 @@ def adjacency_matrix(self, transpose, ctx):
if not isinstance(transpose, bool):
raise DGLError('Expect bool value for "transpose" arg,'
' but got %s.' % (type(transpose)))
fmt = F.get_preferred_sparse_format()
rst = _CAPI_DGLGraphGetAdj(self._handle, transpose, fmt)
if fmt == "csr":
if self.is_readonly():
fmt = F.get_preferred_sparse_format()
rst = _CAPI_DGLGraphGetAdj(self._handle, transpose, fmt)
indptr = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx)
indices = F.copy_to(utils.toindex(rst(1)).tousertensor(), ctx)
shuffle = utils.toindex(rst(2))
dat = F.ones(indices.shape, dtype=F.float32, ctx=ctx)
return F.sparse_matrix(dat, ('csr', indices, indptr),
(self.number_of_nodes(), self.number_of_nodes()))[0], shuffle
elif fmt == "coo":
#src, dst, _ = self.edges(False)
#src = src.tousertensor(ctx) # the index of the ctx will be cached
#dst = dst.tousertensor(ctx) # the index of the ctx will be cached
#src = F.unsqueeze(src, dim=0)
#dst = F.unsqueeze(dst, dim=0)
#if transpose:
# idx = F.cat([src, dst], dim=0)
#else:
# idx = F.cat([dst, src], dim=0)
#print(idx.shape)
else:
src, dst, _ = self.edges(False)
src = src.tousertensor(ctx) # the index of the ctx will be cached
dst = dst.tousertensor(ctx) # the index of the ctx will be cached
src = F.unsqueeze(src, dim=0)
dst = F.unsqueeze(dst, dim=0)
if transpose:
idx = F.cat([src, dst], dim=0)
else:
idx = F.cat([dst, src], dim=0)

## FIXME(minjie): data type
idx = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx)
#idx = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx)
m = self.number_of_edges()
idx = F.reshape(idx, (2, m))
#idx = F.reshape(idx, (2, m))
dat = F.ones((m,), dtype=F.float32, ctx=ctx)
n = self.number_of_nodes()
adj, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, n))
shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
return adj, shuffle_idx
else:
raise Exception("unknown format")
#else:
# raise Exception("unknown format")

@utils.cached_member(cache='_cache', prefix='inc')
def incidence_matrix(self, typestr, ctx):
Expand Down

0 comments on commit 8e24bb0

Please sign in to comment.