In [1]:
import numpy as np
import torch

from protein_utils.features.feature_config import InputFeatureConfig
from protein_utils.features.input_embedding import InputEmbedding
from protein_utils.features.input_features import get_input_features
from protein_utils.features.constants import SMALL_SEP_BINS
from protein_utils.constants import AA_ALPHABET


In [2]:
# Input-feature configuration for this demo. Per feature group, the flags
# select a learned embedding (embed_*), one-hot encoding (one_hot_*) and/or
# Fourier features (fourier_encode_*); a group may enable several at once
# (bb_dihedral below enables both embedding and Fourier features).
# SCALAR groups produce per-residue tensors, PAIR groups per residue-pair
# tensors (see the shapes printed further down).
config = InputFeatureConfig(
    # Residue type (SCALAR): one-hot only.
    embed_res_ty=False,
    res_ty_embed_dim=32,
    one_hot_res_ty=True,

    # Residue relative position (SCALAR): one-hot only.
    embed_res_rel_pos=False,
    res_rel_pos_embed_dim=6,
    res_rel_pos_bins=10,
    one_hot_res_rel_pos=True,

    # Backbone dihedrals (SCALAR): learned embedding + Fourier features.
    embed_bb_dihedral=True,
    bb_dihedral_embed_dim=6,
    fourier_encode_bb_dihedral=True,
    n_bb_dihedral_fourier_feats=2,
    one_hot_bb_dihedral=False,
    bb_dihedral_bins=36,

    # Centrality (SCALAR): one-hot only.
    embed_centrality=False,
    centrality_embed_bins=6,
    centrality_embed_dim=6,
    one_hot_centrality=True,

    # Relative sequence separation (PAIR): neither embedded nor one-hot here;
    # presumably consumed by the joint pair/sep embedding at the end.
    embed_rel_sep=False,
    rel_sep_embed_dim=32,
    one_hot_rel_sep=False,
    rel_sep_embed_bins=len(SMALL_SEP_BINS),

    # Relative distance (PAIR): one-hot only. The atom types are presumably
    # read as pairs, (CA, CA) and (N, CA) -- the printed rel_dist feature has
    # 2 channels. TODO confirm against get_input_features.
    embed_rel_dist=False,
    rel_dist_embed_dim=16,
    one_hot_rel_dist=True,
    rel_dist_atom_tys=["CA", "CA", "N", "CA"],
    rel_dist_embed_bins=32,

    # trRosetta orientation (PAIR): learned embedding + Fourier features.
    embed_tr_rosetta_ori=True,
    tr_rosetta_ori_embed_dim=6,
    tr_rosetta_ori_embed_bins=36,
    fourier_encode_tr_rosetta_ori=True,
    tr_rosetta_fourier_feats=2,
    one_hot_tr_rosetta_ori=False,

    # Joint embedding of residue pair and relative separation (PAIR).
    joint_embed_res_pair_rel_sep=True,
    joint_embed_res_pair_rel_sep_embed_dim=48,
)

In [3]:
# Synthetic inputs: random atom coordinates and a random amino-acid sequence
# of length n. Seeded so that Restart & Run All reproduces the same inputs.
SEED = 0
torch.manual_seed(SEED)
rng = np.random.default_rng(SEED)

atom_tys = ["N", "CA", "CB", "C"]
# b: presumably batch size (not used below); n: number of residues;
# a: atoms per residue, derived from atom_tys so that coords and atom_ty_map
# always agree on the atom count.
b, n, a = 1, 30, len(atom_tys)
coords = torch.randn(n, a, 3)  # (n, a, 3): 3 coordinates per atom
atom_ty_map = {atom_ty: idx for idx, atom_ty in enumerate(atom_tys)}
seq = "".join(rng.choice(list(AA_ALPHABET), size=n))

In [4]:
# Build the per-feature inputs (each exposing raw and encoded tensors) from
# the sequence, coordinates, atom-index map and config.
input_feats = get_input_features(
    seq=seq,
    coords=coords,
    atom_ty_to_coord_idx=atom_ty_map,
    config=config,
)

In [5]:
# Show the raw and encoded tensor shape of every computed feature.
for name, feature in input_feats.items():
    print(
        f" ----------- {name} -----------",
        f"raw: {feature.raw_shape}",
        f"encoded: {feature.encoded_shape}",
        sep="\n",
    )

 ----------- res_ty -----------
raw: None
encoded: torch.Size([30, 1])
 ----------- rel_pos -----------
raw: torch.Size([30, 1])
encoded: torch.Size([30, 1])
 ----------- bb_dihedral -----------
raw: torch.Size([30, 3])
encoded: torch.Size([30, 3])
 ----------- centrality -----------
raw: torch.Size([30, 1])
encoded: torch.Size([30, 1])
 ----------- rel_sep -----------
raw: torch.Size([30, 30, 1])
encoded: torch.Size([30, 30, 1])
 ----------- tr_ori -----------
raw: torch.Size([30, 30, 3])
encoded: torch.Size([30, 30, 3])
 ----------- rel_dist -----------
raw: torch.Size([30, 30, 2])
encoded: torch.Size([30, 30, 2])


In [6]:
# Embedding module built from the same config; called below on input_feats
# to produce the scalar (per-residue) and pair input tensors.
input_emb = InputEmbedding(config)

In [7]:
# Embed all features. The second argument `(n,)` is presumably the leading
# (residue) shape of the inputs -- TODO confirm against InputEmbedding.
scalar_in, pair_in = input_emb(input_feats, (n,))

In [8]:
# Sanity-check the leading dims: scalar features are (n, d_scalar) and pair
# features (n, n, d_pair), as in the shapes printed by this cell.
assert scalar_in.shape[0] == n, scalar_in.shape
assert tuple(pair_in.shape[:2]) == (n, n), pair_in.shape
print(scalar_in.shape, pair_in.shape)

torch.Size([30, 67]) torch.Size([30, 30, 142])


In [9]:
# `dims` is printed as a (scalar, pair) pair of widths; verify it agrees with
# the last-dimension sizes of the tensors the module actually produced.
print(input_emb.dims)
assert tuple(input_emb.dims) == (scalar_in.shape[-1], pair_in.shape[-1]), input_emb.dims

(67, 142)
