In [1]:
import os
import sys
# Make the parent of the current working directory importable. This assumes the
# kernel was started in the notebook's own directory and that the project-local
# packages imported below (`config`, `se3_transformer`, `common`) live one level
# up -- TODO confirm if the notebook is moved.
lib_base = os.path.dirname(os.getcwd())
if lib_base not in sys.path:
    # Mutate the existing list rather than rebinding `sys.path`, so any code
    # holding a reference to it sees the new entry; index 0 gives the local
    # checkout priority over an installed copy of the same packages.
    sys.path.insert(0, lib_base)

In [2]:
import torch
from config.se3_transformer_config import SE3TransformerConfig
from se3_transformer.tfn_transformer import TFNTransformer
from common.helpers.neighbor_utils import get_neighbor_info

# Setup

## Input

In [3]:
# Problem size:
#   b -- batch size
#   n -- number of points (coordinates) per batch element
b, n = 2, 100

# Feature dimensions, each given as (scalar channels, point channels):
d_in = (32, 4)      # input features
d_hidden = (64, 8)  # hidden features
d_out = (32, 8)     # output features
d_edge = 32         # edge feature (hidden) dimension

# Neighborhood size: number of neighbors to consider per point.
N = 12

In [4]:
# Seed the global torch RNG so the random inputs below -- and, when cells are
# run in order, the model's randomly initialised weights -- are reproducible.
torch.manual_seed(0)

# Random inputs; shapes follow directly from the sizes set in the setup cell.
coords = torch.randn(b, n, 3) * 10                # (b, n, 3) point coordinates
scalar_feats = torch.randn(b, n, d_in[0])         # (b, n, scalar channels)
coord_feats = torch.randn(b, n, d_in[1], 3)       # (b, n, point channels, 3)
edge_feats = torch.randn(b, n, n, d_edge)         # (b, n, n, d_edge) dense pairwise edge features

# Named instead of a bare literal; presumably a distance cutoff for the
# neighbor search (same order as the x10 coordinate scale above) -- confirm
# against get_neighbor_info.
max_radius = 10
neighbor_info = get_neighbor_info(coords=coords, top_k=N, max_radius=max_radius)

## SE(3)-Equivariant Transformer

In [5]:
# Build the model config. The fiber arguments take the (scalar, point) channel
# pairs from the setup cell; the printed config below shows they map to
# degrees 0 and 1 respectively. `heads` and `dim_heads` are given per degree
# in the same (degree-0, degree-1) order.
se3_config = SE3TransformerConfig(
    fiber_in=d_in,
    fiber_hidden=d_hidden,
    fiber_out=d_out,
    heads=(4, 4),       # degree-0, degree-1
    dim_heads=(12, 4),  # degree-0, degree-1
    edge_dim=d_edge,
    depth=2,
)
transformer = TFNTransformer(se3_config)

## Run the model

In [6]:
# Inference-only forward pass: nothing in this notebook calls backward, so
# disable autograd to avoid building (and holding memory for) an unused graph.
with torch.no_grad():
    out = transformer(
        # Keys match the fiber degrees in the config: "0" -> scalar features,
        # "1" -> point (3-vector) features.
        feats={"0": scalar_feats, "1": coord_feats},
        edges=edge_feats,
        neighbor_info=neighbor_info,
    )

## View Config

In [7]:
# Dump every attribute of the config object as one "name : value" line.
for attr_name, attr_value in vars(se3_config).items():
    print(f"{attr_name!s} : {attr_value!s}")

fiber_in : [(0, 32), (1, 4)]
fiber_hidden : [(0, 64), (1, 8)]
fiber_out : [(0, 32), (1, 8)]
global_feats_dim : None
max_degrees : 2
edge_dim : 32
depth : 2
conv_in_layers : 1
conv_out_layers : 1
project_out : True
norm_out : True
normalize_radial_dists : True
append_norm : True
pair_bias : True
dropout : 0.0
differentiable_coords : False
append_rel_dist : False
append_edge_attn : True
use_re_zero : True
radial_dropout : 0.0
radial_compress : False
radial_mult : 2
checkpoint_tfn : False
heads : [(0, 4), (1, 4)]
dim_head : [(0, 12), (1, 4)]
attend_self : True
use_null_kv : True
linear_proj_keys : False
fourier_encode_rel_dist : False
fourier_rel_dist_feats : 4
share_keys_and_values : False
hidden_mult : 2
share_attn_weights : True
use_dist_sim : False
learn_head_weights : True
use_coord_attn : True
use_dist_conv : False
pairwise_dist_conv : False
num_dist_conv_filters : 16
attn_ty : tfn
nonlin : RecursiveScriptModule(original_name=FusedGELU)


In [8]:
# Dump the attention settings returned by `attn_config()`, one "field : value" line each.
attn_fields = se3_config.attn_config()._asdict()
for field_name in attn_fields:
    print(f"{field_name!s} : {attn_fields[field_name]!s}")

heads : [(0, 4), (1, 4)]
dim_heads : [(0, 12), (1, 4)]
edge_dim : 32
global_feats_dim : None
attend_self : True
use_null_kv : True
share_attn_weights : True
use_dist_sim : False
learn_head_weights : True
use_coord_attn : True
append_edge_attn : True
use_dist_conv : False
pairwise_dist_conv : False
num_dist_conv_filters : 16
pair_bias : True
append_norm : True
append_hidden_dist : True
