<a href="https://colab.research.google.com/github/hadwin-357/ProteinMPNN_breakdown/blob/main/model_utils_function_6_positional_embedding_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#test model utils function
# positional embedding
'''
inputs:  num_embeddings - embedding dimension, max_relative_feature - cutoff for inter-residue interaction
offset: tensor of relative distances between residue indices
mask: binary tensor (1 = encode the offset, 0 = assign the dedicated extra class)

Function explain:
1. only [-max_relative_feature, max_relative_feature] is considered the effective distance range
2. offsets outside the effective range are clipped to its nearest boundary; positions with mask == 0 are labelled 2*max_relative_feature + 1
3. offsets in the effective range are shifted by +max_relative_feature to the non-negative range [0, 2*max_relative_feature]
4. generate a one-hot encoding with dimension 2*max_relative_feature + 1 + 1  # indices in [0, 2*max_relative_feature + 1]
5. apply a linear layer (nn.Linear) to produce a trainable positional embedding with num_embeddings dimensions

Return:
E: trainable positional embedding   [..., num_embeddings]
'''

In [3]:
class PositionalEncodings(nn.Module):
    """Trainable embedding of clipped relative position offsets.

    Every offset is mapped to one of ``2 * max_relative_feature + 2`` integer
    classes, one-hot encoded, and projected by a single linear layer.

    Args:
        num_embeddings: Size of the last dimension of the returned embedding.
        max_relative_feature: Largest absolute offset that gets its own class;
            offsets beyond it are clipped to the nearest boundary.
    """

    def __init__(self, num_embeddings, max_relative_feature=32):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.max_relative_feature = max_relative_feature
        # 2*max_relative_feature + 1 classes for the clipped offsets, plus one
        # extra class reserved for positions where mask == 0.
        self.linear = nn.Linear(2*max_relative_feature+1+1, num_embeddings)

    def forward(self, offset, mask):
        """Embed relative offsets.

        Args:
            offset: Integer-valued tensor of relative offsets.
            mask: Binary tensor combined elementwise with ``offset``; where it
                is 0 the offset is replaced by the extra class
                ``2 * max_relative_feature + 1``.

        Returns:
            Tensor with the shape of the combined ``offset``/``mask`` plus a
            trailing dimension of size ``num_embeddings``.
        """
        # one_hot only accepts int64 index tensors, and ``1 - mask`` is not
        # defined for bool tensors; casting up front lets float/bool masks and
        # int32 offsets work and is a no-op for int64 inputs.
        offset = offset.long()
        mask = mask.long()

        max_rel = self.max_relative_feature
        # Shift to non-negative values, then clip into [0, 2*max_rel].
        clipped = torch.clip(offset + max_rel, 0, 2*max_rel)
        # Keep the clipped class where mask == 1; use the extra class elsewhere.
        d = clipped*mask + (1 - mask)*(2*max_rel + 1)
        d_onehot = torch.nn.functional.one_hot(d, 2*max_rel + 1 + 1)
        # Match the linear layer's dtype so the module also works after
        # .double()/.half(); for the default float32 this equals .float().
        E = self.linear(d_onehot.to(self.linear.weight.dtype))
        return E

In [4]:
import torch
import torch.nn as nn

# Demo: push a tiny batch through PositionalEncodings and print the result.
num_embeddings = 10       # size of each output embedding vector
max_relative_feature = 3  # offsets beyond +/-3 are clipped before encoding

# Build the module under test.
pos_encodings = PositionalEncodings(
    num_embeddings, max_relative_feature=max_relative_feature
)

# Toy inputs: the 4 in the second row lies outside the clipping range, and the
# single 0 in the mask selects the extra class for that position.
offset = torch.tensor(
    [
        [0, 1, 2, 3],
        [-1, 0, 1, 4],
    ]
)
mask = torch.tensor(
    [
        [1, 1, 0, 1],
        [1, 1, 1, 1],
    ]
)

# Calling the module runs its forward pass.
output_embeddings = pos_encodings(offset, mask)

print("Output embeddings:", output_embeddings, sep="\n")


Output embeddings:
tensor([[[-0.1943, -0.0680, -0.0904, -0.0353, -0.2950, -0.4314,  0.0384,
          -0.0398, -0.2900,  0.1954],
         [-0.1738, -0.5044, -0.4461,  0.1775, -0.4461, -0.2965,  0.5042,
           0.1289, -0.2309,  0.0477],
         [-0.3550, -0.0270, -0.3596,  0.1300, -0.6334,  0.1781,  0.0190,
          -0.1367, -0.4597, -0.2480],
         [-0.4692, -0.2868,  0.1567, -0.1909, -0.1305,  0.1703,  0.0346,
          -0.2905, -0.4425, -0.2349]],

        [[-0.1740, -0.0297, -0.0584,  0.0007, -0.0610, -0.0790, -0.0027,
          -0.0373, -0.5962, -0.0238],
         [-0.1943, -0.0680, -0.0904, -0.0353, -0.2950, -0.4314,  0.0384,
          -0.0398, -0.2900,  0.1954],
         [-0.1738, -0.5044, -0.4461,  0.1775, -0.4461, -0.2965,  0.5042,
           0.1289, -0.2309,  0.0477],
         [-0.4692, -0.2868,  0.1567, -0.1909, -0.1305,  0.1703,  0.0346,
          -0.2905, -0.4425, -0.2349]]], grad_fn=<ViewBackward0>)


In [5]:
# Shift the offsets to be non-negative and clamp them to [0, 2*max_relative_feature];
# wherever mask == 0 the label becomes 2*max_relative_feature + 1 instead.
d = (
    mask * torch.clamp(offset + max_relative_feature, min=0, max=2 * max_relative_feature)
    + (2 * max_relative_feature + 1) * (1 - mask)
)
d

tensor([[3, 4, 7, 6],
        [2, 3, 4, 6]])

In [6]:
# Next, expand the integer labels in d into one-hot vectors along a new last axis.
d_onehot = torch.nn.functional.one_hot(d, num_classes=2 * max_relative_feature + 2)
d_onehot.shape

torch.Size([2, 4, 8])

In [7]:
# Finally, a linear layer projects the one-hot vectors to the trainable positional embedding:
# [2, 4, 8] ---> [2, 4, num_embeddings]
output_embeddings.shape


torch.Size([2, 4, 10])