In [1]:
# Make the repository root (the parent of the notebook's working directory)
# importable, so the project packages imported in the next cell resolve.
import os
import sys

lib_base = os.path.dirname(os.getcwd())
if lib_base not in sys.path:
    # Mutate sys.path in place rather than rebinding it, so anything holding a
    # reference to the list sees the change; the membership check keeps the
    # cell idempotent on re-run.
    sys.path.insert(0, lib_base)

In [2]:
import torch
from config.evoformer_config import EvoformerConfig
from evoformer.evoformer import Evoformer
from common.helpers.neighbor_utils import get_neighbor_info

# Setup

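## Input

Random coordinates and node/edge features, plus the neighbor information computed from the coordinates.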

In [12]:
# Problem sizes and feature widths for this smoke test.
b = 2  # batch size
n = 100  # number of coordinates (nodes) per batch element
node_in = 32  # input dimension of node features
node_hidden = 64  # hidden dimension of node features (NOTE(review): not passed to EvoformerConfig in the visible cells — confirm it is still needed)
node_out = 32  # output dimension of node features
edge_in = 32  # input dimension of edge features
edge_hidden = 64  # hidden dimension of edge features (NOTE(review): also unused in the visible cells)
edge_out = 64  # output dimension of edge features
N = min(n, 20)  # number of neighbors to consider per node; capped at n so it never exceeds the node count

In [13]:
# Seed torch's global RNG so the random inputs below are reproducible across
# Restart & Run All.
torch.manual_seed(0)

coords = torch.randn(b, n, 3) * 10  # (b, n, 3): random 3-D points, std 10 per axis
node_feats = torch.randn(b, n, node_in)  # (b, n, node_in)
edge_feats = torch.randn(b, n, n, edge_in)  # (b, n, n, edge_in): one feature vector per node pair
# NOTE(review): max_radius presumably uses the same units as coords — TODO confirm
# against get_neighbor_info.
neighbor_info = get_neighbor_info(coords=coords, top_k=N, max_radius=10, exclude_self=False)

## Evoformer

In [16]:
# Small 4-layer Evoformer; the node/edge input and output widths come from the
# size constants defined in the input cell.
evoformer_config = EvoformerConfig(
    depth=4,
    # node-feature settings
    node_dim_in=node_in,
    node_dim_out=node_out,
    node_dim_head=12,
    node_attn_heads=6,
    # edge-feature settings
    edge_dim_in=edge_in,
    edge_dim_out=edge_out,
    edge_dim_head=12,
    edge_attn_heads=4,
)

evoformer = Evoformer(evoformer_config)

## Run the model

Run one forward pass and check the shapes of the returned node and edge features.

In [17]:
# This notebook only runs a forward pass (no backward), so skip building the
# autograd graph to save memory and time.
with torch.no_grad():
    out = evoformer(
        node_feats=node_feats,
        edge_feats=edge_feats,
        nbr_info=neighbor_info,
    )

In [18]:
# Show the output shapes, then check them against the configured sizes so a
# shape regression fails the notebook instead of passing silently.
print(out[0].shape, out[1].shape)
assert out[0].shape == (b, n, node_out), out[0].shape
assert out[1].shape == (b, n, n, edge_out), out[1].shape

torch.Size([2, 100, 32]) torch.Size([2, 100, 100, 64])
