SparseGCN has a serious performance issue: on the Tox21 example, its training time per epoch is almost 30 times that of PadGCN (measured on my local PC, CPU only).

The reason for this performance issue is related to jax-ml/jax#2242. SparseGCN uses jax.ops.index_add, and a large Python "for" loop that repeatedly calls jax.ops.index_add causes a serious slowdown.

According to the comments on that issue, I have to rewrite the training loop using lax.scan or lax.fori_loop to resolve this. I think rewriting the training loop with lax.scan or lax.fori_loop would improve the performance of not only SparseGCN but also PadGCN, so resolving this issue is really important.

However, lax.scan and lax.fori_loop follow a functional programming style and are difficult to work with, so rewriting the training loop is hard and I'm struggling with it. Below I explain what is blocking my work.
1. lax.scan and lax.fori_loop do not accept side effects
DeepChem's DiskDataset provides iterbatches, which is the natural way to drive the training loop. But lax.scan and lax.fori_loop do not accept side effects (such as pulling from an iterator/generator), so when I try to implement the loop as below, it doesn't work. I opened an issue about this topic; please see jax-ml/jax#3567.
```python
train_iterator = train_dataset.iterbatches(batch_size=batch_size)

def run_epoch(init_params):
    def body_fun(idx, params):
        # this iterator doesn't work... the batch value is always the same inside the loop
        batch = next(train_iterator)
        params, predict = forward(batch, params)
        return params
    return lax.fori_loop(0, train_num_batches, body_fun, init_params)

for epoch in range(num_epochs):
    params = run_epoch(params)
```
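One workaround I am considering (only a rough sketch; it assumes every batch can be materialized as a fixed-shape numeric array, which is not true for the raw graph featurization) is to load all batches before the loop and index into them with the loop counter, so no iterator is needed inside body_fun:

```python
import jax.numpy as jnp
from jax import lax

# Materialize the whole epoch up front; pad_batches=True keeps every batch the same size.
batches = list(train_dataset.iterbatches(batch_size=batch_size, pad_batches=True))
batch_X = jnp.stack([jnp.asarray(X) for X, y, w, ids in batches])  # (num_batches, batch_size, ...)
batch_y = jnp.stack([jnp.asarray(y) for X, y, w, ids in batches])  # (num_batches, batch_size, n_tasks)

def run_epoch(init_params):
    def body_fun(idx, params):
        # Pure array indexing instead of next(iterator): no side effect inside the loop.
        batch = (batch_X[idx], batch_y[idx])
        params, predict = forward(batch, params)
        return params
    return lax.fori_loop(0, batch_X.shape[0], body_fun, init_params)
```

The obvious downside is that the whole epoch has to fit in memory at once.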
2. Values in the body_fun of lax.scan or lax.fori_loop cannot change shape
All values that flow through lax.scan or lax.fori_loop (the return value, the arguments, and so on) must keep a fixed shape; see the documentation. This is a hard limitation of lax.scan and lax.fori_loop. (To be honest, there are some additional limitations as well, like jax-ml/jax#2962.)

One pain point is that accumulation operations (like appending a value to a list on each loop iteration) are difficult to express; one possible pattern is sketched below.
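This is only a rough sketch (the mean-squared-error metric and the (batch_X, batch_y) layout are placeholders): lax.scan stacks the second return value of its body function across steps, which can stand in for appending to a Python list:

```python
import jax.numpy as jnp
from jax import lax

def scan_body(params, batch):
    # lax.scan body: (carry, x) -> (carry, y); the per-step y values are stacked for us.
    x_b, y_b = batch
    params, predict = forward((x_b, y_b), params)
    loss = jnp.mean((predict - y_b) ** 2)  # placeholder metric
    return params, loss

# (batch_X, batch_y) have a leading num_batches axis, as in the earlier sketch.
final_params, losses = lax.scan(scan_body, init_params, (batch_X, batch_y))
# losses has shape (num_batches,): the "accumulated list" of per-batch metrics.
```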
This can become a problem as the number of metrics we want to collect grows, and sometimes a creative implementation is needed (see jax-ml/jax#1708).
Another pain point is that sparse-pattern mini-batches are incompatible with this limitation. In sparse-pattern modeling, the mini-batch data changes shape from batch to batch, as in the PyTorch Geometric example below (x is the node-feature matrix). Sparse-pattern modeling builds one big graph per mini-batch by unifying all of its graphs, so each mini-batch has a different shape.
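To illustrate what happens (this is only a mock-up with made-up graph sizes, not actual PyTorch Geometric output), concatenating the graphs of each mini-batch gives an x whose first dimension differs from batch to batch:

```python
import numpy as np

num_features = 75  # made-up feature size

# Hypothetical node counts for the graphs in two consecutive mini-batches.
for node_counts in ([12, 30, 25, 17], [40, 8, 22, 19]):
    # The sparse pattern stacks all graphs of the batch into one big graph.
    x = np.concatenate([np.zeros((n, num_features)) for n in node_counts])
    print(x.shape)  # (84, 75) for the first batch, (89, 75) for the second
```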
This is a serious problem for implementing the sparse-pattern model, and I'm still thinking about how to resolve this shape issue. One solution is to pad the mini-batch graph data, as sketched below. However, we have to be careful with the padding values, because they can affect the node aggregation of the sparse-pattern model.
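Here is a rough sketch of the padding idea (max_nodes and the mask handling are my assumptions, not a finished design):

```python
import jax.numpy as jnp

def pad_graph_batch(x, max_nodes):
    """Pad the node-feature matrix x of shape (num_nodes, num_features) to (max_nodes, num_features)."""
    num_nodes = x.shape[0]
    x_padded = jnp.pad(x, ((0, max_nodes - num_nodes), (0, 0)))  # padded rows are all zeros
    node_mask = jnp.arange(max_nodes) < num_nodes                # True only for real nodes
    return x_padded, node_mask

# The mask is what would let us keep the padded rows out of the node aggregation,
# e.g. by zeroing their contributions before the index_add-style update.
```

The node mask would also have to be threaded through the edge-list handling, which I have not worked out yet.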
3. It is difficult to debug the body_fun in lax.scan or lax.fori_loop
It is difficult to debug body_fun, for example by adding a print call inside lax.scan or lax.fori_loop. This point is also discussed in jax-ml/jax#999, but that issue is still open...
```python
train_iterator = train_dataset.iterbatches(batch_size=batch_size)

def run_epoch(init_params):
    def body_fun(idx, params):
        # this iterator doesn't work... the batch value is always the same inside the loop
        batch = next(train_iterator)
        params, predict = forward(batch, params)
        # the concrete values are never printed as expected
        print(predict)
        return params
    return lax.fori_loop(0, train_num_batches, body_fun, init_params)

for epoch in range(num_epochs):
    params = run_epoch(params)
```
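One partial workaround I am looking at (a sketch only; jax.experimental.host_callback is experimental and its API may change) is host_callback.id_print, which sends values back to the host from inside the compiled loop:

```python
from jax import lax
from jax.experimental import host_callback

def run_epoch(init_params):
    def body_fun(idx, params):
        params, predict = forward((batch_X[idx], batch_y[idx]), params)
        # id_print taps the traced value back out to the host, so something is
        # actually printed on every iteration instead of one abstract tracer.
        host_callback.id_print(predict)
        return params
    return lax.fori_loop(0, train_num_batches, body_fun, init_params)
```

I have not confirmed yet how well this interacts with jit and grad in the real training loop.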