In [36]:
import torch
from torch import Tensor
from typing import Tuple, Optional
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.inits import reset
from torch_scatter import scatter
import torch.nn as nn
import torch.nn.functional as F
class EGCL(MessagePassing):
    def __init__(self,
        in_dims: Tuple[int, Optional[int]],
        out_dims: Tuple[int, Optional[int]],
        hid_dims: int,
        num_radial: int,
        cutoff: float,
        eps: float = 1e-6,
        edges_in_df: int = 0,
        has_v_in: bool = True,
        basis: str = "bessel",
        vector_aggr: str = "mean"):
        super(EGCL, self).__init__(node_dim=0, aggr=None, flow="source_to_target")
        self.vector_aggr = vector_aggr
        self.in_dims = in_dims
        self.si, self.vi = in_dims
        self.out_dims = out_dims
        self.so, self.vo = out_dims
        self.has_v_in = has_v_in
        act_fn = nn.ReLU()
        print(2*self.si + num_radial + edges_in_df)
        self.phi_m = nn.Sequential(
            nn.Linear(2*self.si+num_radial+edges_in_df, hid_dims),
            act_fn,
            nn.Linear(hid_dims, hid_dims)
            )
        self.phi_x = nn.Sequential(
            nn.Linear(hid_dims, hid_dims),
            act_fn,
            nn.Linear(hid_dims,1,bias=False)
        )
        self.phi_h = nn.Sequential(
            nn.Linear(self.si+hid_dims, hid_dims),
            act_fn,
            nn.Linear(hid_dims, self.so)
        )
    def aggregate(self, inputs: Tuple[Tensor, Tensor], index: Tensor, dim_size: Optional[int] = None) -> Tuple[Tensor, Tensor]:
        print("It's my turn")
        print(inputs[0].shape, inputs[1].shape, index.shape)
        ms = scatter(inputs[0], index=index, dim=0, dim_size=dim_size, reduce="sum")
        mv = scatter(inputs[1], index=index, dim=0, dim_size=dim_size, reduce="mean")
        print("Done")
        return ms, mv
    def message(
            self,
            s_i: Tensor, 
            s_j: Tensor, 
            v_i: Tensor, 
            v_j: Tensor,
            d: Tensor,
            r: Tensor,
        ) -> Tensor:
        print(s_i.shape, s_j.shape, d.shape)
        dist = torch.pow(d, 2).unsqueeze(-1)
        a_ij = torch.cat([s_i, s_j, dist], dim=-1)
        ms_j = self.phi_m(a_ij)
        # 得到 X_j - X_i
        # print("V_j:", v_j)
        rel_pos = v_i - v_j
        print("Calulating...")
        # print(((v_j - v_i - (d.unsqueeze(-1) * r)) < 1e-2))
        assert ((v_j - v_i - (d.unsqueeze(-1) * r)) < 1e-3).all()
        print(ms_j.shape, rel_pos.shape)
        print(self.phi_x(ms_j).shape)
        mv_j = rel_pos * self.phi_x(ms_j)
        # print("Hello!")
        return ms_j, mv_j
    def forward(
        self,   
        x: Tuple[Tensor, Tensor],
        edge_index: Tensor,
        edge_attr: Tuple[Tensor, Tensor]):
        s, v = x
        # r 是相对的方向，单位向量
        d, r = edge_attr
        ms, mv = self.propagate(
            edge_index=edge_index,
            dim_size = s.size(0),
            s=s,
            v=v,
            d=d,
            r=r
        )
        
        s = self.phi_h(torch.cat([s, ms], dim=-1))
        v = v + mv
        print(s, v)

model = EGCL((16,3), (128,128), 64, 1, 6)
nodes = torch.randn((100,16))
pos = torch.randn((100,3))
edge_index = torch.randint(0, 99, size=(2,50))
rel_pos = pos[edge_index[0]] - pos[edge_index[1]]
dist = torch.pow(rel_pos, 2).sum(-1).sqrt()
# print("Pos_j:", pos[edge_index[0]])
r = F.normalize(rel_pos, dim=-1, eps=1e-6)
# print((rel_pos-(dist.unsqueeze(-1)*r) < 1e-3).all())
edge_attr = (dist, r)
x = (nodes, pos)
model(x, edge_index, edge_attr)

33
torch.Size([50, 16]) torch.Size([50, 16]) torch.Size([50])
Calulating...
torch.Size([50, 64]) torch.Size([50, 3])
torch.Size([50, 1])
It's my turn
torch.Size([50, 64]) torch.Size([50, 3]) torch.Size([50])
Done
tensor([[-0.2349, -0.1054,  0.0346,  ..., -0.0755,  0.0426, -0.0114],
        [-0.0837, -0.3273,  0.0299,  ..., -0.1072,  0.1049, -0.2252],
        [-0.2218, -0.2731,  0.1572,  ..., -0.1133,  0.1486, -0.1403],
        ...,
        [-0.1854, -0.1367,  0.1585,  ...,  0.0404, -0.0164, -0.1617],
        [-0.0670, -0.0375,  0.2003,  ..., -0.1988,  0.1194, -0.0788],
        [-0.1002, -0.3127,  0.1781,  ..., -0.0506, -0.1033, -0.1486]],
       grad_fn=<AddmmBackward0>) tensor([[ 0.3178,  1.0815,  0.8918],
        [ 0.6076,  1.1006, -1.5017],
        [ 0.9798,  0.7401,  0.8962],
        [-0.9316,  0.2238,  0.6606],
        [-2.3515, -0.4694, -1.3275],
        [-0.4054,  1.7196, -0.9626],
        [-0.4186, -0.2232,  0.8788],
        [-0.0943, -0.1446,  0.6532],
        [ 0.0376, -0.200

In [1]:
class test(object):
    def __init__(self) -> None:
        self._weight = "123"
    def print(self):
        print(self.weight)
        
a = test()
a.print()

<bound method test.print of <__main__.test object at 0x7f3c08232790>>

In [10]:
a = range(300)
b = a[:]
c = a[:-1]
d = a[1:]
print(a,b,c,d)        
print(a)
print(range(10)[2])

range(0, 300) range(0, 300) range(0, 299) range(1, 300)
range(0, 300)
2


In [1]:
import torch.nn as nn
import torch
linear = nn.Linear(21,100)
x = torch.ones(100,21)

print(linear(x))

tensor([[-0.0579,  0.1234,  0.6805,  ..., -0.1826,  0.8181,  0.3417],
        [-0.0579,  0.1234,  0.6805,  ..., -0.1826,  0.8181,  0.3417],
        [-0.0579,  0.1234,  0.6805,  ..., -0.1826,  0.8181,  0.3417],
        ...,
        [-0.0579,  0.1234,  0.6805,  ..., -0.1826,  0.8181,  0.3417],
        [-0.0579,  0.1234,  0.6805,  ..., -0.1826,  0.8181,  0.3417],
        [-0.0579,  0.1234,  0.6805,  ..., -0.1826,  0.8181,  0.3417]],
       grad_fn=<AddmmBackward0>)
