Error: Tensors used as indices must be long, byte or bool tensors #3

Closed
wenhao-gao opened this issue Jan 19, 2022 · 3 comments

@wenhao-gao

Dear authors, thanks for sharing the code for this wonderful work!

I am trying to run the naive GFlowNet training code in the molecular docking setting by running
python gflownet.py
under the mols directory. I have unzipped the datasets and installed all the requirements, and I have successfully run the model in the toy grid environment.

However, I get this error when I run it in the mols environment:

Exception while sampling:
tensors used as indices must be long, byte or bool tensors

Looking further, the problem seems to occur around line 70 in model_block.py. I printed stem_block_batch_idx, and it does not look like something that can be cast directly to the long type that an index requires:

tensor([[-8.4156e-02, -4.2767e-02, -7.2483e-02, -3.3011e-02, -1.1865e-02,
2.0981e-03, 1.3293e-02, -7.3515e-03, -4.1853e-02, 2.1048e-02,
3.8597e-02, -1.5558e-02, 2.1581e-02, 4.9257e-03, 9.5167e-02,
4.0965e-02, 2.0146e-02, -5.5610e-02, -3.5318e-02, -3.1394e-02,
7.2078e-02, 1.8894e-02, -3.0249e-02, 2.9740e-02, 5.6950e-02,
-3.8425e-02, 2.8620e-02, 9.2052e-02, -8.5357e-03, 1.6788e-02,
7.7801e-02, -4.2119e-02, 1.3606e-02, 7.5316e-02, 4.7131e-02,
-4.3429e-03, 1.4157e-04, 2.0939e-02, -2.3499e-02, -6.5888e-02,
-2.8960e-02, 3.1548e-02, -9.2680e-03, 5.4192e-02, -9.6579e-03,
2.0602e-02, 1.8935e-02, 4.1228e-03, -6.3467e-02, 3.6747e-02,
1.4168e-02, -6.1473e-03, -1.9472e-02, -3.3970e-02, -5.7308e-03,
-4.6021e-02, -3.8956e-02, 4.7375e-02, -8.4562e-02, -1.0087e-02,
2.0478e-02, -6.8286e-02, 5.4663e-02, -5.1468e-02, 1.2617e-02,
2.4625e-02, 5.2167e-02, 5.7779e-02, -5.7788e-02, -1.3323e-02,
1.3913e-02, -7.4439e-02, -4.0981e-02, 5.0797e-02, -5.6230e-02,
-5.0963e-02, -5.5488e-02, -2.7339e-02, 1.0469e-02, 3.4695e-02,
-3.2623e-02, 7.6694e-03, -5.8748e-03, 7.0495e-02, -2.2805e-02,
-5.4334e-03, -2.1636e-02, 1.9597e-02, 6.2370e-02, -2.4995e-02,
1.6165e-02, -4.6878e-03, 2.9743e-02, 1.2653e-02, -5.4271e-02,
1.1247e-02, -3.8340e-03, -4.7489e-02, 1.5719e-02, 3.2552e-02,
6.0665e-02, -1.2330e-02, 2.6115e-02, -2.7376e-02, 3.4152e-02,
-1.0086e-02, -2.4257e-02, 3.2202e-02, -3.2659e-02, 8.6094e-02,
-3.1996e-02, 7.8751e-02, 4.5367e-02, -3.8693e-02, -3.6531e-02,
6.7311e-03, 3.2884e-02, -3.2774e-02, -3.8855e-02, 2.8814e-02,
4.3942e-02, -1.3374e-02, 3.0905e-02, -7.0064e-02, -5.7230e-03,
4.5093e-02, 3.8167e-02, -3.0602e-02, -4.0387e-02, -1.5985e-02,
-9.5962e-02, -1.1354e-02, 2.0879e-02, 1.4092e-02, -3.8405e-02,
1.4337e-02, -6.0682e-02, -9.0190e-03, -5.0898e-02, -4.7344e-02,
4.1045e-02, -6.7031e-02, 8.8112e-02, 3.2149e-02, 3.7748e-02,
-4.0757e-02, 1.4378e-02, -1.0749e-01, 6.1679e-02, -6.7268e-03,
-2.7889e-02, -5.9315e-02, -5.5883e-02, -2.6489e-02, 7.3640e-02,
1.8273e-02, -5.2330e-02, -7.7003e-05, 6.8413e-04, -1.4364e-01,
-1.9389e-02, 4.5649e-02, -4.0468e-02, -4.2819e-02, 4.5874e-02,
-1.6481e-02, 1.2627e-02, -8.4941e-02, -3.7458e-02, 2.1359e-02,
-9.2863e-02, -3.4932e-03, 7.1990e-02, 6.2144e-02, 8.1462e-02,
-2.0569e-02, 5.9194e-02, 1.6996e-03, 8.0618e-03, 6.1753e-02,
4.1602e-02, 1.0910e-02, 2.0523e-02, -9.9781e-04, 1.9131e-02,
-1.0267e-02, -9.4474e-02, -3.5725e-02, 9.9953e-03, -4.3195e-02,
-7.9051e-02, -3.1881e-02, 9.2158e-03, -9.6167e-04, -2.7508e-02,
7.1478e-02, -5.4107e-02, 8.0026e-02, -1.8887e-02, 4.6941e-02,
6.5166e-02, 1.2000e-02, 3.9906e-02, -2.8206e-02, 3.7483e-02,
3.5408e-02, -2.5863e-02, 2.3528e-02, 7.1814e-03, 8.0863e-02,
-1.3736e-02, -8.5978e-02, -4.1238e-02, -1.2545e-02, 5.5479e-02,
7.3487e-03, 8.9125e-02, -3.4814e-02, -4.5358e-02, 4.9893e-02,
3.5286e-02, 3.2084e-02, 5.0868e-02, 2.3549e-02, -9.2907e-02,
-6.9315e-03, -1.3088e-02, 8.7066e-02, 1.1554e-02, 1.3771e-02,
-1.7489e-02, -5.2921e-02, 9.2110e-03, 1.6766e-02, 4.8030e-02,
1.4481e-02, 2.9254e-03, 3.5795e-02, 1.0397e-01, -2.0675e-03,
-2.9916e-02, -5.3299e-02, -2.1396e-02, -5.3189e-02, 3.2805e-02,
-2.6538e-03, -2.6352e-02, -1.2823e-02, 6.1972e-02, 5.4822e-02,
4.5579e-02, -3.6638e-02, 8.1013e-03, -5.6014e-02, 1.5187e-02,
-6.5561e-02]], device='cuda:0', dtype=torch.float64,
grad_fn=)

Am I running the code correctly? Is this index value expected, and if so, do you know what is going wrong?

@dongqian0206

@wenhao-gao
Hi Wenhao, did you resolve the problem above?

I now get a different error, object has no attribute '__slices__', which is raised around lines 69-71 in model_block.py:

stem_block_batch_idx = (
        torch.tensor(graph_data.__slices__['x'], device=out.device)[graph_data.stems_batch]
        + graph_data.stems[:, 0]
)

@wenhao-gao
Author

Hi Dong, my problem was the PyG version. The author provided the following package versions, which are slightly different but work:
torch 1.8.0
torch-cluster 1.5.9
torch-geometric 1.6.3
torch-scatter 2.0.6
torch-sparse 0.6.9
torch-spline-conv 1.2.1

Hope it helps!
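
For anyone comparing their environment against the versions listed above, here is a minimal sketch using only the standard library. The helper name `find_version_mismatches` and the choice to ignore local build tags such as `+cu111` are assumptions for illustration, not part of the repository.

```python
from importlib import metadata

# Versions quoted in this thread as known to work.
EXPECTED_VERSIONS = {
    "torch": "1.8.0",
    "torch-cluster": "1.5.9",
    "torch-geometric": "1.6.3",
    "torch-scatter": "2.0.6",
    "torch-sparse": "0.6.9",
    "torch-spline-conv": "1.2.1",
}


def find_version_mismatches(expected, get_version=metadata.version):
    """Return {package: (wanted, installed)} for every package that differs.

    `installed` is None when the package is not installed at all.
    `get_version` is injectable so the function can be exercised without
    the real packages being present (hypothetical helper, for illustration).
    """
    mismatches = {}
    for package, wanted in expected.items():
        try:
            installed = get_version(package)
        except metadata.PackageNotFoundError:
            installed = None
        # Assumption: a local build tag such as "+cu111" does not matter here.
        if installed is None or installed.split("+")[0] != wanted:
            mismatches[package] = (wanted, installed)
    return mismatches


if __name__ == "__main__":
    for package, (wanted, installed) in find_version_mismatches(EXPECTED_VERSIONS).items():
        print(f"{package}: expected {wanted}, found {installed}")
```

Running the script prints one line per package whose installed version differs from the list, and nothing if the environment matches.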

@dongqian0206

dongqian0206 commented Jul 15, 2022

Hi Wenhao, thanks for your message. I had completely overlooked the version issue. For the latest torch-geometric, I fixed it with the following code:

stem_block_batch_idx = (
        torch.tensor(graph_data._slice_dict['x'], device=out.device)[graph_data.stems_batch]
        + graph_data.stems[:, 0]
)
