In [45]:
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
import math
import torch

In [46]:
def pack_atom_feats(bg, atom_feats):
    bg.ndata['h'] = atom_feats
    gs = dgl.unbatch(bg)
    edit_feats = [g.ndata['h'] for g in gs]
    masks = [torch.ones(g.num_nodes(), dtype=torch.uint8) for g in gs]
    padded_feats = pad_sequence(edit_feats, batch_first=True, padding_value= 0)
    masks = pad_sequence(masks, batch_first=True, padding_value= 0)

    return padded_feats, masks

In [47]:
#take a batched graph of two small molecules ['COC', 'CC']


In [48]:
feats = [torch.tensor([[-0.1575, -0.8111,  0.1657,  0.9680, -0.5555, 0.1657,  0.9680, -0.5555],
                       [-0.7875, -0.4606,  1.0383,  0.2772,  0.8936, 0.1657,  0.9680, -0.5555],
                       [-1.0585, -1.3510, -0.9072, -2.0390,  0.6899, 0.1657,  0.9680, -0.5555]]),
         torch.tensor([[ 0.8600,  0.8323,  1.8069,  0.1786,  2.1482, 0.1657,  0.9680, -0.5555],
                       [ 1.4895,  2.1187,  1.5305,  0.7236, -0.5476, 0.1657,  0.9680, -0.5555]])]

In [115]:
feats[0] #three atoms, three featurs set

tensor([[-0.1575, -0.8111,  0.1657,  0.9680, -0.5555,  0.1657,  0.9680, -0.5555],
        [-0.7875, -0.4606,  1.0383,  0.2772,  0.8936,  0.1657,  0.9680, -0.5555],
        [-1.0585, -1.3510, -0.9072, -2.0390,  0.6899,  0.1657,  0.9680, -0.5555]])

In [116]:
feats[1]

tensor([[ 0.8600,  0.8323,  1.8069,  0.1786,  2.1482,  0.1657,  0.9680, -0.5555],
        [ 1.4895,  2.1187,  1.5305,  0.7236, -0.5476,  0.1657,  0.9680, -0.5555]])

In [117]:
masks = [torch.ones(i.shape[0]) for i in feats]
masks

[tensor([1., 1., 1.]), tensor([1., 1.])]

In [118]:
padded_feats = pad_sequence(feats, batch_first=True, padding_value= 0)
padded_feats.shape

torch.Size([2, 3, 8])

In [119]:
padded_feats

tensor([[[-0.1575, -0.8111,  0.1657,  0.9680, -0.5555,  0.1657,  0.9680,
          -0.5555],
         [-0.7875, -0.4606,  1.0383,  0.2772,  0.8936,  0.1657,  0.9680,
          -0.5555],
         [-1.0585, -1.3510, -0.9072, -2.0390,  0.6899,  0.1657,  0.9680,
          -0.5555]],

        [[ 0.8600,  0.8323,  1.8069,  0.1786,  2.1482,  0.1657,  0.9680,
          -0.5555],
         [ 1.4895,  2.1187,  1.5305,  0.7236, -0.5476,  0.1657,  0.9680,
          -0.5555],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000]]])

In [120]:
masks = pad_sequence(masks, batch_first=True, padding_value= 0)
masks

tensor([[1., 1., 1.],
        [1., 1., 0.]])

In [121]:
def get_adm(mol, max_distance = 6):
    mol_size = mol.GetNumAtoms()
    distance_matrix = np.ones((mol_size, mol_size)) * max_distance + 1
    dm = Chem.GetDistanceMatrix(mol)
    dm[dm > 100] = -1 # remote (different molecule)
    dm[dm > max_distance] = max_distance # remote (same molecule)
    dm[dm == -1] = max_distance + 1
    distance_matrix[:dm.shape[0],:dm.shape[1]] = dm
    return distance_matrix

In [122]:
from rdkit import Chem
import numpy as np

adms = [get_adm(mol) for mol in [Chem.MolFromSmiles(k) for k in ['COC', 'CC']]]
adms

[array([[0., 1., 2.],
        [1., 0., 1.],
        [2., 1., 0.]]),
 array([[0., 1.],
        [1., 0.]])]

In [123]:
def pad_atom_distance_matrix(adm_list):
    max_size = max([adm.shape[0] for adm in adm_list])
    adm_list = [torch.tensor(np.pad(adm, (0, max_size - adm.shape[0]), 'maximum')).unsqueeze(0).long() for adm in adm_list]
    return torch.cat(adm_list, dim = 0)

In [124]:
adms = pad_atom_distance_matrix(adms)
adms

tensor([[[0, 1, 2],
         [1, 0, 1],
         [2, 1, 0]],

        [[0, 1, 1],
         [1, 0, 1],
         [1, 1, 1]]])

Now I am ready for Attention!!!

In [125]:


class Global_Reactivity_Attention(nn.Module):
    def __init__(self, d_model, heads = 8, n_layers = 3, positional_number = 5, dropout = 0.1):
        super(Global_Reactivity_Attention, self).__init__()
        self.n_layers = n_layers
        att_stack = []
        pff_stack = []
        for _ in range(n_layers):
            att_stack.append(MultiHeadAttention(heads, d_model, positional_number, dropout))
            pff_stack.append(FeedForward(d_model, dropout=dropout))
        self.att_stack = nn.ModuleList(att_stack)
        self.pff_stack = nn.ModuleList(pff_stack)

    def forward(self, x, rpm, mask = None):
        att_scores = {}
        for n in range(self.n_layers):
            m, att_score = self.att_stack[n](x, rpm, mask)
            x = x + self.pff_stack[n](x+m)
            att_scores[n] = att_score
        return x, att_scores
                                 

In [126]:

class GELU(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2/math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 
    
class FeedForward(nn.Module):
    def __init__(self, d_model, dropout = 0.1):
        super(FeedForward, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_model*2),
            GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model*2, d_model))
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        x = self.layer_norm(x)
        return self.net(x)
    
class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, positional_number = 5, dropout = 0.1):
        super(MultiHeadAttention, self).__init__()
        self.p_k = positional_number
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        if self.p_k != 0:
            self.relative_k = nn.Parameter(torch.randn(self.p_k, self.d_k))
        self.q_linear = nn.Linear(d_model, d_model, bias=False)
        self.k_linear = nn.Linear(d_model, d_model, bias=False)
        self.v_linear = nn.Sequential(
                            nn.Linear(d_model, d_model), 
                            nn.ReLU(), 
                            nn.Dropout(dropout),
                            nn.Linear(d_model, d_model))
        self.gating = nn.Linear(d_model, d_model)
        self.to_out = nn.Linear(d_model, d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.reset_parameters()
        
    def reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        nn.init.constant_(self.gating.weight, 0.)
        nn.init.constant_(self.gating.bias, 1.)
        
    def one_hot_embedding(self, labels):
        y = torch.eye(self.p_k)
        return y[labels]
                
    def forward(self, x, gpm, mask=None):
        bs, atom_size = x.size(0), x.size(1) #this line is done
        x = self.layer_norm(x) 
        k = self.k_linear(x) #this line is done
        q = self.q_linear(x) #this line is done
        v = self.v_linear(x) #this line is done
        k1 = k.view(bs, -1, self.h, self.d_k).transpose(1,2) #this line is done
        q1 = q.view(bs, -1, self.h, self.d_k).transpose(1,2) #this line is done
        v1 = v.view(bs, -1, self.h, self.d_k).transpose(1,2) #this line is done
        attn1 = torch.matmul(q1, k1.permute(0, 1, 3, 2)) #this line is done
        
        if self.p_k == 0:
            attn = attn1/math.sqrt(self.d_k)
        else:
            gpms = self.one_hot_embedding(gpm.unsqueeze(1).repeat(1, self.h, 1, 1)).to(x.device) #this line is done
            attn2 = torch.matmul(q1, self.relative_k.transpose(0, 1)) #this line is done
            attn2 = torch.matmul(gpms, attn2.unsqueeze(-1)).squeeze(-1) #this line is done
            attn = (attn1 + attn2) /math.sqrt(self.d_k) #this line is done
        
        if mask is not None:
            mask = mask.bool() #this line is done
            mask = mask.unsqueeze(1).repeat(1,mask.size(-1),1) #this line is done
            mask = mask.unsqueeze(1).repeat(1,attn.size(1),1,1) #this line is done
            attn[~mask] = float(-9e9) #this line is done
        attn = torch.softmax(attn, dim=-1) #this line is done
        attn = self.dropout1(attn) #this line is done
        v1 = v.view(bs, -1, self.h, self.d_k).permute(0, 2, 1, 3) #this line is done
        output = torch.matmul(attn, v1) #this line is done

        output = output.transpose(1,2).contiguous().view(bs, -1, self.d_model).squeeze(-1)  #this line is done
        output = self.to_out(output * self.gating(x).sigmoid()) # gate self attention
        return self.dropout2(output), attn
#         return output, attn

In [127]:
gra = Global_Reactivity_Attention(d_model=8, heads=2)

In [128]:
#check the instance variables
#gra.__dict__

In [129]:
padded_feats.shape

torch.Size([2, 3, 8])

In [134]:
atom_feat_new, att_scores = gra(x=padded_feats, rpm=adms, mask=masks)
atom_feat_new.shape, #att_scores

(torch.Size([2, 3, 8]),)

In [135]:
att_scores[0].shape

torch.Size([2, 2, 3, 3])

In [136]:
#understanding line by line

In [137]:
d_model = 8
q_linear = nn.Linear(d_model, d_model, bias=False)
q_linear

Linear(in_features=8, out_features=8, bias=False)

In [139]:
atom_feats = padded_feats

In [140]:
x_q = q_linear(atom_feats)
x_q.shape

torch.Size([2, 3, 8])

In [141]:
#same with k_linear
d_model = 8
k_linear = nn.Linear(d_model, d_model, bias=False)
k_linear

Linear(in_features=8, out_features=8, bias=False)

In [143]:
x_k = k_linear(atom_feats)
x_k.shape

torch.Size([2, 3, 8])

In [144]:
#v is a bit different; but the input and output dimension are the same
v_linear = nn.Sequential(nn.Linear(d_model, d_model),
                         nn.ReLU(),
                         nn.Dropout(0.1),
                         nn.Linear(d_model, d_model))

In [146]:
x_v = v_linear(atom_feats)
x_v.shape

torch.Size([2, 3, 8])

In [147]:
#now the fun part of matrix manipulation and shape changes

In [168]:
bs = 2
h = 2
d_k = 4
k1 = x_k.view(bs, -1, h, d_k)
k1.shape

torch.Size([2, 3, 2, 4])

In [169]:
atom_feats.shape

torch.Size([2, 3, 8])

In [170]:
k1 = k1.transpose(1,2)
k1.shape #(B, H, N, dk); B= batch size; H= num heads; N = num bonds/atoms/seq len etc; dk = new feat dimension

torch.Size([2, 2, 3, 4])

In [171]:
q1 = x_q.view(bs, -1, h, d_k).transpose(1,2)
v1 = x_v.view(bs, -1, h, d_k).transpose(1,2)
q1.shape, v1.shape

(torch.Size([2, 2, 3, 4]), torch.Size([2, 2, 3, 4]))

In [172]:
q1.shape

torch.Size([2, 2, 3, 4])

In [173]:
k1.shape

torch.Size([2, 2, 3, 4])

In [174]:
k1 = k1.permute(0,1,3,2)
k1.shape

torch.Size([2, 2, 4, 3])

In [175]:
q1.shape, k1.shape

(torch.Size([2, 2, 3, 4]), torch.Size([2, 2, 4, 3]))

In [179]:
attn1 = torch.matmul(q1, k1)
attn1.shape

torch.Size([2, 2, 3, 3])

In [180]:
#now understand the relative positional encoding
#gpms = self.one_hot_embedding(gpm.unsqueeze(1).repeat(1, self.h, 1, 1))

In [183]:
bcms = adms
bcms

tensor([[[0, 1, 2],
         [1, 0, 1],
         [2, 1, 0]],

        [[0, 1, 1],
         [1, 0, 1],
         [1, 1, 1]]])

In [184]:
bcms

tensor([[[0, 1, 2],
         [1, 0, 1],
         [2, 1, 0]],

        [[0, 1, 1],
         [1, 0, 1],
         [1, 1, 1]]])

In [185]:
bcms.shape

torch.Size([2, 3, 3])

In [186]:
bcms = bcms.unsqueeze(1) #expected output shape: [2,1,3,3]
bcms.shape

torch.Size([2, 1, 3, 3])

In [187]:
bcms = bcms.repeat(1, h, 1, 1) #expected output shape: [2,2,3,3]
bcms.shape #now you see why this is done? It is the same shape as that of attention scores i.e. attn1.shape(2,2,3,3) QK^T

torch.Size([2, 2, 3, 3])

In [188]:
bcms

tensor([[[[0, 1, 2],
          [1, 0, 1],
          [2, 1, 0]],

         [[0, 1, 2],
          [1, 0, 1],
          [2, 1, 0]]],


        [[[0, 1, 1],
          [1, 0, 1],
          [1, 1, 1]],

         [[0, 1, 1],
          [1, 0, 1],
          [1, 1, 1]]]])

In [189]:
#Now what happens when I apply the one_hot_embedding function? Let's see.
def one_hot_embedding(p_k, labels):
    y = torch.eye(p_k)
    return y[labels]

In [190]:
y = torch.eye(5)
y

tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]])

In [191]:
gpms = one_hot_embedding(5, bcms)
gpms.shape #omg, I didn't see this coming; what is it doing?, okay now got it, 

torch.Size([2, 2, 3, 3, 5])

In [192]:
gpms

tensor([[[[[1., 0., 0., 0., 0.],
           [0., 1., 0., 0., 0.],
           [0., 0., 1., 0., 0.]],

          [[0., 1., 0., 0., 0.],
           [1., 0., 0., 0., 0.],
           [0., 1., 0., 0., 0.]],

          [[0., 0., 1., 0., 0.],
           [0., 1., 0., 0., 0.],
           [1., 0., 0., 0., 0.]]],


         [[[1., 0., 0., 0., 0.],
           [0., 1., 0., 0., 0.],
           [0., 0., 1., 0., 0.]],

          [[0., 1., 0., 0., 0.],
           [1., 0., 0., 0., 0.],
           [0., 1., 0., 0., 0.]],

          [[0., 0., 1., 0., 0.],
           [0., 1., 0., 0., 0.],
           [1., 0., 0., 0., 0.]]]],



        [[[[1., 0., 0., 0., 0.],
           [0., 1., 0., 0., 0.],
           [0., 1., 0., 0., 0.]],

          [[0., 1., 0., 0., 0.],
           [1., 0., 0., 0., 0.],
           [0., 1., 0., 0., 0.]],

          [[0., 1., 0., 0., 0.],
           [0., 1., 0., 0., 0.],
           [0., 1., 0., 0., 0.]]],


         [[[1., 0., 0., 0., 0.],
           [0., 1., 0., 0., 0.],
           [0., 1

In [193]:
#now next step;
p_k = 5
d_k = 4
relative_k = nn.Parameter(torch.randn(p_k, d_k)) #it is just a random matrix of shape p_k, d_k; just that you have the requires_grad=True
relative_k.shape

torch.Size([5, 4])

In [194]:
relative_k

Parameter containing:
tensor([[-0.3844, -0.3392, -0.2740,  0.5339],
        [-0.5518,  1.3238, -1.3641,  0.8611],
        [ 1.6464,  0.0230,  0.4377,  2.0879],
        [-1.2769,  1.2342,  0.4180, -0.7119],
        [ 1.5270, -0.3405, -0.4975,  1.3918]], requires_grad=True)

In [195]:
q1.shape

torch.Size([2, 2, 3, 4])

In [196]:
attn2 = torch.matmul(q1, relative_k.T)
attn2.shape

torch.Size([2, 2, 3, 5])

In [197]:
attn2 = attn2.unsqueeze(-1)
attn2.shape

torch.Size([2, 2, 3, 5, 1])

In [198]:
gpms.shape

torch.Size([2, 2, 3, 3, 5])

In [199]:
attn2 = torch.matmul(gpms, attn2)
attn2.shape

torch.Size([2, 2, 3, 3, 1])

In [200]:
attn2 = attn2.squeeze(-1)
attn2.shape

torch.Size([2, 2, 3, 3])

In [201]:
attn1.shape

torch.Size([2, 2, 3, 3])

In [202]:
attn = (attn1 + attn2) /math.sqrt(d_k)
attn.shape #expected to be the same as that of attn1/attn2


torch.Size([2, 2, 3, 3])

In [203]:
attn

tensor([[[[-0.1167,  0.1943, -0.3456],
          [ 0.1609, -0.0497,  0.0408],
          [-0.1334,  0.0393,  0.0072]],

         [[ 0.1182,  0.2237, -0.7956],
          [ 0.1897,  0.3259, -0.0773],
          [-0.3982,  0.5244, -0.0920]]],


        [[[-0.0737, -0.2128, -0.2208],
          [-0.6171, -0.6275, -0.7236],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.7220,  0.0400, -0.1517],
          [ 0.5035,  0.0426,  0.2673],
          [ 0.0000,  0.0000,  0.0000]]]], grad_fn=<DivBackward0>)

In [204]:
#now the masking part

In [205]:
# mask = mask.bool()
# mask = mask.unsqueeze(1).repeat(1,mask.size(-1),1)
# mask = mask.unsqueeze(1).repeat(1,attn.size(1),1,1)

In [206]:
mask = torch.tensor([[1, 1, 1],[1, 1, 0]])

In [207]:
mask.bool()

tensor([[ True,  True,  True],
        [ True,  True, False]])

In [208]:
mask = mask.unsqueeze(1)
mask.shape

torch.Size([2, 1, 3])

In [209]:
mask = mask.repeat(1, mask.size(-1), 1)
mask.shape

torch.Size([2, 3, 3])

In [210]:
mask

tensor([[[1, 1, 1],
         [1, 1, 1],
         [1, 1, 1]],

        [[1, 1, 0],
         [1, 1, 0],
         [1, 1, 0]]])

In [211]:
mask = mask.unsqueeze(1)
mask

tensor([[[[1, 1, 1],
          [1, 1, 1],
          [1, 1, 1]]],


        [[[1, 1, 0],
          [1, 1, 0],
          [1, 1, 0]]]])

In [212]:
mask.shape

torch.Size([2, 1, 3, 3])

In [213]:
mask = mask.repeat(1, attn.size(1), 1, 1)
mask.shape #now mask has the same shape as that of attn

torch.Size([2, 2, 3, 3])

In [214]:
attn[~mask] = float(-9e9)

In [215]:
attn.shape

torch.Size([2, 2, 3, 3])

In [216]:
attn = torch.softmax(attn, dim=-1)
attn.shape

torch.Size([2, 2, 3, 3])

In [217]:
attn = nn.Dropout(0.1)(attn)
attn.shape

torch.Size([2, 2, 3, 3])

In [218]:
attn

tensor([[[[0.3704, 0.0000, 0.3704],
          [0.3704, 0.3704, 0.3704],
          [0.3704, 0.3704, 0.0000]],

         [[0.3704, 0.3704, 0.3704],
          [0.3704, 0.3704, 0.3704],
          [0.3704, 0.0000, 0.3704]]],


        [[[0.3704, 0.3704, 0.3704],
          [0.3704, 0.3704, 0.3704],
          [0.3704, 0.3704, 0.3704]],

         [[0.3704, 0.3704, 0.3704],
          [0.3704, 0.3704, 0.3704],
          [0.3704, 0.3704, 0.3704]]]], grad_fn=<MulBackward0>)

In [219]:
#now the values modification
#v.view(bs, -1, self.h, self.d_k).permute(0, 2, 1, 3); This I have done previously
v1 = x_v.view(bs, -1, h, d_k)
v1.shape

torch.Size([2, 3, 2, 4])

In [220]:
v1 = v1.permute(0, 2, 1, 3)
v1.shape

torch.Size([2, 2, 3, 4])

In [221]:
v1.shape

torch.Size([2, 2, 3, 4])

In [222]:
attn.shape

torch.Size([2, 2, 3, 3])

In [223]:
#now the final equation application (attention score * values) i.e. the weighted sum
output = torch.matmul(attn, v1)
output.shape

torch.Size([2, 2, 3, 4])

In [224]:
#now concatenate all the heads
#output = output.transpose(1,2).contiguous().view(bs, -1, self.d_model).squeeze(-1)

In [225]:
output = output.transpose(1,2)
output.shape

torch.Size([2, 3, 2, 4])

In [226]:
output = output.contiguous() #ask why contiguos is necessary here
output.shape

torch.Size([2, 3, 2, 4])

In [227]:
#now the combining
output = output.view(bs, -1, d_model) #d_model is 8; remember?
output.shape

torch.Size([2, 3, 8])

In [228]:
output = output.squeeze(-1) #why this squeeze, I don't find the reason
output.shape

torch.Size([2, 3, 8])

In [229]:
#now gated self attention; extra
#output = self.to_out(output * self.gating(x).sigmoid()) # gate self attention

In [232]:
atom_feats.shape #this is my x, just that I have not passed through layer_norm

torch.Size([2, 3, 8])

In [233]:
gating = nn.Linear(d_model, d_model)
gating

Linear(in_features=8, out_features=8, bias=True)

In [235]:
x_gate = gating(atom_feats)
x_gate.shape #same expected

torch.Size([2, 3, 8])

In [236]:
x_sigmoid = x_gate.sigmoid()
x_sigmoid.shape

torch.Size([2, 3, 8])

In [237]:
to_out = nn.Linear(d_model, d_model)
to_out

Linear(in_features=8, out_features=8, bias=True)

In [238]:
#first multiplication
out = output * x_sigmoid
out.shape

torch.Size([2, 3, 8])

In [239]:
output_final = to_out(out)
output_final.shape

torch.Size([2, 3, 8])

In [240]:
#final pass through another dropout layer

In [241]:
output_final = nn.Dropout(0.1)(output_final)
output_final.shape

torch.Size([2, 3, 8])

In [242]:
#YOU ARE DONE!!!