Skip to content

Commit

Permalink
[FEAT][TemplateEmbedder]
Browse files Browse the repository at this point in the history
  • Loading branch information
kyegomez committed May 15, 2024
1 parent 5cd150f commit ade2e1e
Showing 1 changed file with 85 additions and 0 deletions.
85 changes: 85 additions & 0 deletions alphafold3/template_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import torch
import torch.nn as nn
from alphafold3.pairformer import PairFormer


class TemplateEmbedder(nn.Module):
def __init__(
self,
dim: int = None,
depth: int = 2,
seq_len: int = None,
heads: int = 64,
dim_head: int = 64,
attn_dropout: float = 0.0,
ff_dropout: float = 0.0,
global_column_attn: bool = False,
c: int = 64,
Ntemplates: int = 1,
*args,
**kwargs
):
super(TemplateEmbedder, self).__init__()
# Define layers used in the embedding
self.layer_norm_z = nn.LayerNorm(c)
self.layer_norm_v = nn.LayerNorm(c)
self.linear_no_bias_z = nn.Linear(c, c, bias=False)
self.linear_no_bias_a = nn.Linear(c, c, bias=False)
self.pairformer = PairFormer(
dim=dim,
seq_len=seq_len,
heads=heads,
dim_head=dim_head,
attn_dropout=attn_dropout,
ff_dropout=ff_dropout,
depth=depth,
*args,
**kwargs
)
self.relu = nn.ReLU()
self.final_linear = nn.Linear(c, c, bias=False)

def forward(self, f, zij, Ntemplates):
# Step 1-3: Compute various masks and concatenate
template_backbone_frame_mask = f # Placeholder operation
template_pseudo_beta_mask = f # Placeholder operation
template_distogram = f # Placeholder operation
template_unit_vector = f # Placeholder operation

atij = torch.cat(
[
template_distogram,
template_backbone_frame_mask,
template_unit_vector,
template_pseudo_beta_mask,
],
dim=-1,
)

# Step 4-5: Apply masking based on asym_id and concatenate restypes
asym_id_mask = (
f == f
) # Placeholder for actual asym_id comparison logic
atij = atij * asym_id_mask
restype = f # Placeholder for restype feature
atij = torch.cat([atij, restype, restype], dim=-1)

# Initialize uij
uij = torch.zeros_like(atij)

# Step 7-11: Iterate over templates
for _ in range(Ntemplates):
vij = self.linear_no_bias_z(
self.layer_norm_z(zij)
) + self.linear_no_bias_a(atij)
for layer in self.pairformer_stack:
vij = layer(
vij
) # Assuming some residual connection or similar logic in actual Pairformer
uij += self.layer_norm_v(vij)

# Step 12-13: Normalize and apply final transformation
uij /= Ntemplates
uij = self.final_linear(self.relu(uij))

return uij

0 comments on commit ade2e1e

Please sign in to comment.