In [1]:
import torch
from torch_geometric.nn import SchNet, DimeNet
from torch_geometric.data import Data
import math

In [2]:
# Seed the global RNG immediately before weight initialisation so that
# re-running this cell (or Restart & Run All) rebuilds identical weights and
# the untrained model's printed prediction is reproducible.
torch.manual_seed(0)

# Untrained SchNet regressor; hyper-parameters are passed explicitly so the
# configuration is visible in the notebook.
schnet = SchNet(
    hidden_channels=128,
    num_filters=128,
    num_interactions=6,
    num_gaussians=50,
    cutoff=10.0,           # interatomic interaction cutoff, in the units of `pos`
    max_num_neighbors=32,  # cap on neighbours collected per atom within `cutoff`
    readout="add",         # aggregate per-atom outputs by summation
)

In [3]:
# Seed the global RNG right before weight initialisation so this cell is
# idempotent: re-running it alone rebuilds the same DimeNet weights.
torch.manual_seed(0)

# Untrained DimeNet regressor producing a single scalar (out_channels=1).
dimenet = DimeNet(
    hidden_channels=128,
    out_channels=1,
    num_blocks=6,
    num_bilinear=8,
    num_spherical=7,
    num_radial=6,
    cutoff=5.0,            # note: smaller than the SchNet cell's 10.0
    envelope_exponent=5,
    num_before_skip=1,
    num_after_skip=2,
    num_output_layers=3,
)

In [4]:
# Atomic numbers for one carbon (Z=6) and four hydrogens (Z=1).
z = torch.tensor([6, 1, 1, 1, 1], dtype=torch.long)
num_atoms = z.numel()  # derived from `z` so the two can never disagree

# Random 3D coordinates (not a physical geometry). A dedicated, seeded
# generator makes them reproducible without depending on global RNG state.
pos_rng = torch.Generator().manual_seed(0)
pos = torch.randn(num_atoms, 3, generator=pos_rng)

# Bundle atomic numbers and positions into a PyG Data object.
data = Data(z=z, pos=pos)

# Inference only: eval() switches the models to evaluation mode and
# no_grad() skips autograd bookkeeping.
schnet.eval()
dimenet.eval()
with torch.no_grad():
    # No `batch` vector is passed; presumably both models then treat all atoms
    # as a single structure (the `.item()` calls below require a
    # single-element output).
    y_schnet = schnet(data.z, data.pos)
    y_dimenet = dimenet(data.z, data.pos)

# Label each result so the two outputs can be told apart.
# NOTE(review): the untrained DimeNet printed exactly 0.0 in the saved output;
# presumably PyG zero-initialises DimeNet's final output projection by default
# (`output_initializer`), so this is expected before training — confirm
# against the installed torch_geometric version.
# The torch.cross deprecation UserWarning seen in the saved output is raised
# inside the installed torch_geometric DimeNet code, not by this cell.
print(f"SchNet prediction: {y_schnet.item()}")
print(f"DimeNet prediction: {y_dimenet.item()}")

Prediction: 2.3206124305725098
Prediction: 0.0


Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at ../aten/src/ATen/native/Cross.cpp:62.)
  b = torch.cross(pos_ji, pos_ki).norm(dim=-1)
