In [1]:
import os
import sys
# Put the parent of the notebook's working directory on the import path,
# presumably so the project-local imports in the next cell (`config`,
# `se3_transformer`, `common`) resolve — this depends on where the kernel
# was started.
lib_base = os.path.dirname(os.getcwd())
if lib_base not in sys.path:
    # Mutate the existing list in place rather than rebinding `sys.path`, so
    # anything already holding a reference to the list sees the new entry.
    sys.path.insert(0, lib_base)

In [2]:
import torch
from config.se3_transformer_config import SE3TransformerConfig
from se3_transformer.tfn_transformer import TFNTransformer
from common.helpers.neighbor_utils import get_neighbor_info

# Set Up

## Input

In [11]:
# Batch size.
# NOTE(review): no cell visible here reads `b`; the cells below pass their own
# batch_size to get_model_input — confirm whether it is still needed.
b = 2
# Number of coordinates per batch.
n = 100
# Input dimension for scalar and point features.
d_in = (32, 4)
# Hidden dimension of scalar and point features.
d_hidden = (64, 8)
# Output dimension of scalar and point features.
d_out = (32, 8)
# Edge hidden dimension.
d_edge = 32
# Number of neighbors to consider per point.
N = 12

def get_model_input(batch_size):
    """Build one random input batch for the transformer.

    Reads the notebook-level constants ``n``, ``d_in``, ``d_edge`` and ``N``.

    Args:
        batch_size: size of the leading (batch) dimension of every tensor.

    Returns:
        A dict with three entries:
          * ``feats`` - ``{"0": (batch_size, n, d_in[0]) tensor,
            "1": (batch_size, n, d_in[1], 3) tensor}``
          * ``edges`` - ``(batch_size, n, n, d_edge)`` tensor
          * ``neighbor_info`` - whatever ``get_neighbor_info`` returns for
            the sampled coordinates with ``top_k=N`` and ``max_radius=10``.
    """
    leading = (batch_size, n)
    # The tensors are drawn in a fixed order (coordinates, scalar features,
    # point features, edges) so a seeded RNG yields the same batch every time.
    coords = torch.randn(*leading, 3).mul(10)  # standard normal, scaled by 10
    feats = {
        "0": torch.randn(*leading, d_in[0]),
        "1": torch.randn(*leading, d_in[1], 3),
    }
    edges = torch.randn(*leading, n, d_edge)
    return {
        "feats": feats,
        "edges": edges,
        "neighbor_info": get_neighbor_info(coords=coords, top_k=N, max_radius=10),
    }
    

## SE(3) - Equivariant Transformer

In [12]:
# Model hyper-parameters. The fiber sizes and the edge width reuse the
# d_in / d_hidden / d_out / d_edge constants so the model lines up with the
# tensors produced by get_model_input().
se3_config = SE3TransformerConfig(
    fiber_in=d_in,
    fiber_hidden=d_hidden,
    fiber_out=d_out,
    heads=(4, 4),  # presumably one entry per feature type (scalar, point) — TODO confirm
    dim_heads=(12, 4),  # presumably per-head width per feature type — TODO confirm
    edge_dim=d_edge,
    depth=2,
)

# Instantiate the model from the config.
transformer = TFNTransformer(se3_config)

## Run the model

In [13]:
# Smoke test: a single forward pass on a freshly sampled batch of 4; the
# input dict is unpacked into keyword arguments (feats, edges, neighbor_info).
out = transformer(
    **get_model_input(batch_size=4)
)

# Test Batching

In [14]:
def get_batch(batched_feats, batch_idx):
    """Extract a single example from a batched model input.

    Args:
        batched_feats: dict with a ``"feats"`` mapping holding entries
            ``"0"`` and ``"1"``, plus ``"edges"`` and ``"neighbor_info"``
            (the layout returned by ``get_model_input``).
        batch_idx: index of the example to pull out.

    Returns:
        A dict with the same three keys. The feature and edge tensors are
        indexed at ``batch_idx`` and given back a leading dimension of size 1;
        ``neighbor_info`` is indexed at ``batch_idx`` without ``unsqueeze``.
    """
    def _take(tensor):
        # Index out one example, then restore the batch axis.
        return tensor[batch_idx].unsqueeze(0)

    per_type = batched_feats["feats"]
    return {
        "feats": {"0": _take(per_type["0"]), "1": _take(per_type["1"])},
        "edges": _take(batched_feats["edges"]),
        # NOTE(review): relies on the neighbor-info object's own indexing to
        # produce something the model accepts for a single example — confirm.
        "neighbor_info": batched_feats["neighbor_info"][batch_idx],
    }

In [15]:
# Batching check: running the model on a whole batch must give the same
# result as running it on each example separately.
# NOTE(review): this assumes the forward pass is deterministic in the model's
# current mode (e.g. no active dropout) — confirm against the model config.
batch_size = 3
# Largest absolute elementwise deviation tolerated between the two paths; the
# recorded norms are on the order of 1e-6..1e-5, i.e. float32 round-off.
BATCH_ATOL = 1e-4

batched_input = get_model_input(batch_size)
batched_out = transformer(**batched_input)
expected_out = [transformer(**get_batch(batched_input,i)) for i in range(batch_size)]
# compare
for i in range(batch_size):
    # Slice example i out of the batched output, keeping a batch axis of 1.
    actual_i = {k:v[i].unsqueeze(0) for k,v in batched_out.items()}
    expected_i = expected_out[i]
    # Both paths must produce exactly the same feature types.
    assert set(actual_i) == set(expected_i), (
        f"batch {i}: feature types differ: {sorted(actual_i)} vs {sorted(expected_i)}"
    )
    for key in expected_i:
        actual, expected = actual_i[key],expected_i[key]
        assert actual.shape == expected.shape
        print(f"batch : {i}, feature_ty : {key}, norm :",torch.norm(actual-expected))
        # Fail loudly instead of relying on someone eyeballing the printed norm.
        assert torch.allclose(actual, expected, atol=BATCH_ATOL), (
            f"batch {i}, feature type {key}: batched and single-example "
            f"outputs differ by more than atol={BATCH_ATOL}"
        )

batch : 0, feature_ty : 0, norm : tensor(7.8868e-06, grad_fn=<CopyBackwards>)
batch : 0, feature_ty : 1, norm : tensor(1.5134e-06, grad_fn=<CopyBackwards>)
batch : 1, feature_ty : 0, norm : tensor(7.5952e-06, grad_fn=<CopyBackwards>)
batch : 1, feature_ty : 1, norm : tensor(1.5490e-06, grad_fn=<CopyBackwards>)
batch : 2, feature_ty : 0, norm : tensor(7.7446e-06, grad_fn=<CopyBackwards>)
batch : 2, feature_ty : 1, norm : tensor(1.7034e-06, grad_fn=<CopyBackwards>)


## View Config

In [None]:
# Print every attribute stored on the config instance, one "name : value"
# pair per line.
for field_name, field_value in vars(se3_config).items():
    print(field_name, ":", field_value)
    

In [None]:
# Print the attention settings derived from the config; the result of
# attn_config() is turned into a mapping via _asdict() (so it is presumably a
# namedtuple) and printed one "name : value" pair per line.
for option_name, option_value in se3_config.attn_config()._asdict().items():
    print(option_name, ":", option_value)