In [4]:
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
import math
import torch

def get_bdm(bonds, max_size):  # bond distance matrix
    temp = torch.eye(max_size)
    for i, bond1 in enumerate(bonds):
        for j, bond2 in enumerate(bonds):
            if i >= j:
                continue
            if torch.unique(torch.cat([bond1, bond2])).size(0) < 4:  # at least on overlap
                temp[i][j], temp[j][i] = 1, 1 # connect
    return temp.unsqueeze(0).long()


def pack_bond_feats(bonds_feats, pooled_bonds):
    #print('pooled bonds are:\n',pooled_bonds)
    masks = [torch.ones(len(feats), dtype=torch.uint8) for feats in bonds_feats]
    padded_feats = pad_sequence(bonds_feats, batch_first=True, padding_value= 0)
    bdms = [get_bdm(bonds, padded_feats.size(1)) for bonds in pooled_bonds]
    #print('bdms looks like this:\n', bdms)
    masks = pad_sequence(masks, batch_first=True, padding_value= 0)
    return padded_feats, masks, torch.cat(bdms, dim = 0)


In [5]:
pooled_bonds = [torch.tensor([[1,3],[3,1],[2,4]]), torch.tensor([[1,3],[4,2]])]
pooled_bonds

[tensor([[1, 3],
         [3, 1],
         [2, 4]]),
 tensor([[1, 3],
         [4, 2]])]

In [6]:
#torch.rand(2,3,8)

In [140]:
bonds_feats = [torch.tensor([[0.7613, 0.0477, 0.9314, 0.3316, 0.4118, 0.8059, 0.3685, 0.1201],
                             [0.6702, 0.5265, 0.1083, 0.9956, 0.5293, 0.2533, 0.9514, 0.9963],
                             [0.6594, 0.6021, 0.8603, 0.9002, 0.3984, 0.3284, 0.3642, 0.1200]]),
               torch.tensor([[0.1940, 0.9269, 0.1595, 0.5202, 0.1896, 0.2814, 0.2581, 0.0624],
                            [0.6663, 0.5718, 0.8896, 0.6519, 0.4934, 0.9678, 0.3778, 0.5658]])]

In [141]:

bond_feats, mask, bcms = pack_bond_feats(bonds_feats, pooled_bonds)

In [142]:
bond_feats

tensor([[[0.7613, 0.0477, 0.9314, 0.3316, 0.4118, 0.8059, 0.3685, 0.1201],
         [0.6702, 0.5265, 0.1083, 0.9956, 0.5293, 0.2533, 0.9514, 0.9963],
         [0.6594, 0.6021, 0.8603, 0.9002, 0.3984, 0.3284, 0.3642, 0.1200]],

        [[0.1940, 0.9269, 0.1595, 0.5202, 0.1896, 0.2814, 0.2581, 0.0624],
         [0.6663, 0.5718, 0.8896, 0.6519, 0.4934, 0.9678, 0.3778, 0.5658],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])

In [143]:
mask

tensor([[1, 1, 1],
        [1, 1, 0]], dtype=torch.uint8)

In [144]:
bcms

tensor([[[1, 1, 0],
         [1, 1, 0],
         [0, 0, 1]],

        [[1, 0, 0],
         [0, 1, 0],
         [0, 0, 1]]])

Now I am ready for Attention!!!

In [145]:


class Global_Reactivity_Attention(nn.Module):
    def __init__(self, d_model, heads = 8, n_layers = 3, positional_number = 5, dropout = 0.1):
        super(Global_Reactivity_Attention, self).__init__()
        self.n_layers = n_layers
        att_stack = []
        pff_stack = []
        for _ in range(n_layers):
            att_stack.append(MultiHeadAttention(heads, d_model, positional_number, dropout))
            pff_stack.append(FeedForward(d_model, dropout=dropout))
        self.att_stack = nn.ModuleList(att_stack)
        self.pff_stack = nn.ModuleList(pff_stack)

    def forward(self, x, rpm, mask = None):
        att_scores = {}
        for n in range(self.n_layers):
            m, att_score = self.att_stack[n](x, rpm, mask)
            x = x + self.pff_stack[n](x+m)
            att_scores[n] = att_score
        return x, att_scores
                                 

In [146]:

class GELU(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2/math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 
    
class FeedForward(nn.Module):
    def __init__(self, d_model, dropout = 0.1):
        super(FeedForward, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_model*2),
            GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model*2, d_model))
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        x = self.layer_norm(x)
        return self.net(x)
    
class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, positional_number = 5, dropout = 0.1):
        super(MultiHeadAttention, self).__init__()
        self.p_k = positional_number
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        if self.p_k != 0:
            self.relative_k = nn.Parameter(torch.randn(self.p_k, self.d_k))
        self.q_linear = nn.Linear(d_model, d_model, bias=False)
        self.k_linear = nn.Linear(d_model, d_model, bias=False)
        self.v_linear = nn.Sequential(
                            nn.Linear(d_model, d_model), 
                            nn.ReLU(), 
                            nn.Dropout(dropout),
                            nn.Linear(d_model, d_model))
        self.gating = nn.Linear(d_model, d_model)
        self.to_out = nn.Linear(d_model, d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.reset_parameters()
        
    def reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        nn.init.constant_(self.gating.weight, 0.)
        nn.init.constant_(self.gating.bias, 1.)
        
    def one_hot_embedding(self, labels):
        y = torch.eye(self.p_k)
        return y[labels]
                
    def forward(self, x, gpm, mask=None):
        bs, atom_size = x.size(0), x.size(1) #this line is done
        x = self.layer_norm(x) 
        k = self.k_linear(x) #this line is done
        q = self.q_linear(x) #this line is done
        v = self.v_linear(x) #this line is done
        k1 = k.view(bs, -1, self.h, self.d_k).transpose(1,2) #this line is done
        q1 = q.view(bs, -1, self.h, self.d_k).transpose(1,2) #this line is done
        v1 = v.view(bs, -1, self.h, self.d_k).transpose(1,2) #this line is done
        attn1 = torch.matmul(q1, k1.permute(0, 1, 3, 2)) #this line is done
        
        if self.p_k == 0:
            attn = attn1/math.sqrt(self.d_k)
        else:
            gpms = self.one_hot_embedding(gpm.unsqueeze(1).repeat(1, self.h, 1, 1)).to(x.device) #this line is done
            attn2 = torch.matmul(q1, self.relative_k.transpose(0, 1)) #this line is done
            attn2 = torch.matmul(gpms, attn2.unsqueeze(-1)).squeeze(-1) #this line is done
            attn = (attn1 + attn2) /math.sqrt(self.d_k) #this line is done
        
        if mask is not None:
            mask = mask.bool() #this line is done
            mask = mask.unsqueeze(1).repeat(1,mask.size(-1),1) #this line is done
            mask = mask.unsqueeze(1).repeat(1,attn.size(1),1,1) #this line is done
            attn[~mask] = float(-9e9) #this line is done
        attn = torch.softmax(attn, dim=-1) #this line is done
        attn = self.dropout1(attn) #this line is done
        v1 = v.view(bs, -1, self.h, self.d_k).permute(0, 2, 1, 3) #this line is done
        output = torch.matmul(attn, v1) #this line is done

        output = output.transpose(1,2).contiguous().view(bs, -1, self.d_model).squeeze(-1)  #this line is done
        output = self.to_out(output * self.gating(x).sigmoid()) # gate self attention
        return self.dropout2(output), attn
#         return output, attn

In [150]:
gra = Global_Reactivity_Attention(d_model=8, heads=2)

In [151]:
#check the instance variables
#gra.__dict__

In [152]:
bond_feat_new, att_scores = gra(x=bond_feats, rpm=bcms, mask=mask)

torch.Size([2, 2, 3, 3]) torch.Size([2, 2, 3, 4])
torch.Size([2, 2, 3, 3]) torch.Size([2, 2, 3, 4])
torch.Size([2, 2, 3, 3]) torch.Size([2, 2, 3, 4])


In [17]:
bond_feat_new.shape

torch.Size([2, 3, 8])

In [18]:
att_scores[0].shape

torch.Size([2, 8, 3, 3])

In [19]:
#understanding line by line

In [21]:
d_model = 8
q_linear = nn.Linear(d_model, d_model, bias=False)
q_linear

Linear(in_features=8, out_features=8, bias=False)

In [23]:
x_q = q_linear(bond_feats)
x_q.shape

torch.Size([2, 3, 8])

In [24]:
#same with k_linear
d_model = 8
k_linear = nn.Linear(d_model, d_model, bias=False)
k_linear

Linear(in_features=8, out_features=8, bias=False)

In [25]:
x_k = k_linear(bond_feats)
x_k.shape

torch.Size([2, 3, 8])

In [27]:
#v is a bit different; but the input and output dimension are the same
v_linear = nn.Sequential(nn.Linear(d_model, d_model),
                         nn.ReLU(),
                         nn.Dropout(0.1),
                         nn.Linear(d_model, d_model))

In [28]:
x_v = v_linear(bond_feats)
x_v.shape

torch.Size([2, 3, 8])

In [29]:
#now the fun part of matrix manipulation and shape changes

In [35]:
bs = 2
h = 2
d_k = 4
k1 = x_k.view(bs, -1, h, d_k)
k1.shape

torch.Size([2, 3, 2, 4])

In [34]:
bond_feats.shape

torch.Size([2, 3, 8])

In [37]:
k1 = k1.transpose(1,2)
k1.shape #(B, H, N, dk); B= batch size; H= num heads; N = num bonds/atoms/seq len etc; dk = new feat dimension

torch.Size([2, 2, 3, 4])

In [38]:
q1 = x_q.view(bs, -1, h, d_k).transpose(1,2)
v1 = x_v.view(bs, -1, h, d_k).transpose(1,2)
q1.shape, v1.shape

(torch.Size([2, 2, 3, 4]), torch.Size([2, 2, 3, 4]))

In [39]:
q1.shape

torch.Size([2, 2, 3, 4])

In [40]:
k1.shape

torch.Size([2, 2, 3, 4])

In [41]:
k1 = k1.permute(0,1,3,2)
k1.shape

torch.Size([2, 2, 4, 3])

In [42]:
attn1 = torch.matmul(q1, k1)
attn1.shape

torch.Size([2, 2, 3, 3])

In [43]:
#now understand the relative positional encoding
#gpms = self.one_hot_embedding(gpm.unsqueeze(1).repeat(1, self.h, 1, 1))

In [44]:
bcms

tensor([[[1, 1, 0],
         [1, 1, 0],
         [0, 0, 1]],

        [[1, 0, 0],
         [0, 1, 0],
         [0, 0, 1]]])

In [45]:
bcms.shape

torch.Size([2, 3, 3])

In [47]:
bcms = bcms.unsqueeze(1) #expected output shape: [2,1,3,3]
bcms.shape

torch.Size([2, 1, 3, 3])

In [49]:
bcms = bcms.repeat(1, h, 1, 1) #expected output shape: [2,2,3,3]
bcms.shape #now you see why this is done? It is the same shape as that of attention scores i.e. attn1.shape(2,2,3,3) QK^T

torch.Size([2, 2, 3, 3])

In [70]:
bcms

tensor([[[[1, 1, 0],
          [1, 1, 0],
          [0, 0, 1]],

         [[1, 1, 0],
          [1, 1, 0],
          [0, 0, 1]]],


        [[[1, 0, 0],
          [0, 1, 0],
          [0, 0, 1]],

         [[1, 0, 0],
          [0, 1, 0],
          [0, 0, 1]]]])

In [71]:
#Now what happens when I apply the one_hot_embedding function? Let's see.
def one_hot_embedding(p_k, labels):
    y = torch.eye(p_k)
    return y[labels]

In [72]:
y = torch.eye(5)
y

tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]])

In [73]:
gpms = one_hot_embedding(5, bcms)
gpms.shape #omg, I didn't see this coming; what is it doing?, okay now got it, 

torch.Size([2, 2, 3, 3, 5])

In [74]:
gpms

tensor([[[[[0., 1., 0., 0., 0.],
           [0., 1., 0., 0., 0.],
           [1., 0., 0., 0., 0.]],

          [[0., 1., 0., 0., 0.],
           [0., 1., 0., 0., 0.],
           [1., 0., 0., 0., 0.]],

          [[1., 0., 0., 0., 0.],
           [1., 0., 0., 0., 0.],
           [0., 1., 0., 0., 0.]]],


         [[[0., 1., 0., 0., 0.],
           [0., 1., 0., 0., 0.],
           [1., 0., 0., 0., 0.]],

          [[0., 1., 0., 0., 0.],
           [0., 1., 0., 0., 0.],
           [1., 0., 0., 0., 0.]],

          [[1., 0., 0., 0., 0.],
           [1., 0., 0., 0., 0.],
           [0., 1., 0., 0., 0.]]]],



        [[[[0., 1., 0., 0., 0.],
           [1., 0., 0., 0., 0.],
           [1., 0., 0., 0., 0.]],

          [[1., 0., 0., 0., 0.],
           [0., 1., 0., 0., 0.],
           [1., 0., 0., 0., 0.]],

          [[1., 0., 0., 0., 0.],
           [1., 0., 0., 0., 0.],
           [0., 1., 0., 0., 0.]]],


         [[[0., 1., 0., 0., 0.],
           [1., 0., 0., 0., 0.],
           [1., 0

In [76]:
#now next step;
p_k = 5
d_k = 4
relative_k = nn.Parameter(torch.randn(p_k, d_k)) #it is just a random matrix of shape p_k, d_k; just that you have the requires_grad=True
relative_k.shape

torch.Size([5, 4])

In [77]:
relative_k

Parameter containing:
tensor([[ 0.2876,  1.6704,  0.0232,  2.8696],
        [ 0.2369,  0.8154, -1.8616, -0.2084],
        [ 1.0022,  1.4363,  0.6675, -0.9453],
        [-0.9335,  0.3653, -0.5655, -0.5385],
        [ 2.0957, -0.9150, -0.3297, -0.7697]], requires_grad=True)

In [79]:
q1.shape

torch.Size([2, 2, 3, 4])

In [80]:
attn2 = torch.matmul(q1, relative_k.T)
attn2.shape

torch.Size([2, 2, 3, 5])

In [82]:
attn2 = attn2.unsqueeze(-1)
attn2.shape

torch.Size([2, 2, 3, 5, 1])

In [83]:
gpms.shape

torch.Size([2, 2, 3, 3, 5])

In [84]:
attn2 = torch.matmul(gpms, attn2)
attn2.shape

torch.Size([2, 2, 3, 3, 1])

In [85]:
attn2 = attn2.squeeze(-1)
attn2.shape

torch.Size([2, 2, 3, 3])

In [86]:
attn1.shape

torch.Size([2, 2, 3, 3])

In [88]:
attn = (attn1 + attn2) /math.sqrt(d_k)
attn.shape #expected to be the same as that of attn1/attn2


torch.Size([2, 2, 3, 3])

In [89]:
attn

tensor([[[[-0.2112, -0.2452,  0.1531],
          [ 0.0787, -0.0205,  0.1865],
          [ 0.5586,  0.5176, -0.2876]],

         [[ 0.0474,  0.0353,  0.2991],
          [ 0.3091,  0.2454, -0.2282],
          [-0.2213, -0.2191, -0.0989]]],


        [[[-0.0184,  0.4468,  0.4570],
          [ 0.3119, -0.2268,  0.3394],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.1037, -0.2947, -0.3047],
          [-0.1731,  0.1510, -0.2842],
          [ 0.0000,  0.0000,  0.0000]]]], grad_fn=<DivBackward0>)

In [90]:
#now the masking part

In [91]:
# mask = mask.bool()
# mask = mask.unsqueeze(1).repeat(1,mask.size(-1),1)
# mask = mask.unsqueeze(1).repeat(1,attn.size(1),1,1)

In [104]:
mask = torch.tensor([[1, 1, 1],[1, 1, 0]])

In [105]:
mask.bool()

tensor([[ True,  True,  True],
        [ True,  True, False]])

In [106]:
mask = mask.unsqueeze(1)
mask.shape

torch.Size([2, 1, 3])

In [107]:
mask = mask.repeat(1, mask.size(-1), 1)
mask.shape

torch.Size([2, 3, 3])

In [108]:
mask

tensor([[[1, 1, 1],
         [1, 1, 1],
         [1, 1, 1]],

        [[1, 1, 0],
         [1, 1, 0],
         [1, 1, 0]]])

In [109]:
mask = mask.unsqueeze(1)
mask

tensor([[[[1, 1, 1],
          [1, 1, 1],
          [1, 1, 1]]],


        [[[1, 1, 0],
          [1, 1, 0],
          [1, 1, 0]]]])

In [110]:
mask.shape

torch.Size([2, 1, 3, 3])

In [111]:
mask = mask.repeat(1, attn.size(1), 1, 1)
mask.shape #now mask has the same shape as that of attn

torch.Size([2, 2, 3, 3])

In [112]:
attn[~mask] = float(-9e9)

In [114]:
attn.shape

torch.Size([2, 2, 3, 3])

In [116]:
attn = torch.softmax(attn, dim=-1)
attn.shape

torch.Size([2, 2, 3, 3])

In [119]:
attn = nn.Dropout(0.1)(attn)
attn.shape

torch.Size([2, 2, 3, 3])

In [120]:
attn

tensor([[[[0.3704, 0.3704, 0.3704],
          [0.3704, 0.3704, 0.3704],
          [0.3704, 0.3704, 0.3704]],

         [[0.3704, 0.3704, 0.3704],
          [0.3704, 0.3704, 0.3704],
          [0.0000, 0.0000, 0.3704]]],


        [[[0.3704, 0.0000, 0.3704],
          [0.3704, 0.3704, 0.3704],
          [0.3704, 0.3704, 0.3704]],

         [[0.3704, 0.3704, 0.3704],
          [0.3704, 0.3704, 0.3704],
          [0.3704, 0.3704, 0.3704]]]], grad_fn=<MulBackward0>)

In [132]:
#now the values modification
#v.view(bs, -1, self.h, self.d_k).permute(0, 2, 1, 3); This I have done previously
v1 = x_v.view(bs, -1, h, d_k)
v1.shape

torch.Size([2, 3, 2, 4])

In [133]:
v1 = v1.permute(0, 2, 1, 3)
v1.shape

torch.Size([2, 2, 3, 4])

In [154]:
v1.shape

torch.Size([2, 2, 3, 4])

In [134]:
attn.shape

torch.Size([2, 2, 3, 3])

In [155]:
#now the final equation application (attention score * values) i.e. the weighted sum
output = torch.matmul(attn, v1)
output.shape

torch.Size([2, 2, 3, 4])

In [156]:
#now concatenate all the heads
#output = output.transpose(1,2).contiguous().view(bs, -1, self.d_model).squeeze(-1)

In [157]:
output = output.transpose(1,2)
output.shape

torch.Size([2, 3, 2, 4])

In [160]:
output = output.contiguous() #ask why contiguos is necessary here
output.shape

torch.Size([2, 3, 2, 4])

In [163]:
#now the combining
output = output.view(bs, -1, d_model) #d_model is 8; remember?
output.shape

torch.Size([2, 3, 8])

In [166]:
output = output.squeeze(-1) #why this squeeze, I don't find the reason
output.shape

torch.Size([2, 3, 8])

In [167]:
#now gated self attention; extra
#output = self.to_out(output * self.gating(x).sigmoid()) # gate self attention

In [170]:
bond_feats.shape #this is my x, just that I have not passed through layer_norm

torch.Size([2, 3, 8])

In [171]:
gating = nn.Linear(d_model, d_model)
gating

Linear(in_features=8, out_features=8, bias=True)

In [172]:
x_gate = gating(bond_feats)
x_gate.shape #same expected

torch.Size([2, 3, 8])

In [175]:
x_sigmoid = x_gate.sigmoid()
x_sigmoid.shape

torch.Size([2, 3, 8])

In [176]:
to_out = nn.Linear(d_model, d_model)
to_out

Linear(in_features=8, out_features=8, bias=True)

In [177]:
#first multiplication
out = output * x_sigmoid
out.shape

torch.Size([2, 3, 8])

In [178]:
output_final = to_out(out)
output_final.shape

torch.Size([2, 3, 8])

In [179]:
#final pass through another dropout layer

In [181]:
output_final = nn.Dropout(0.1)(output_final)
output_final.shape

torch.Size([2, 3, 8])

In [None]:
#YOU ARE DONE!!!