A simple self-attention layer with masking and multi-head support.

In [None]:
# Torch imports
import torch
from torch import nn
import torch.autograd
import torch.nn.functional as F

import numpy as np
from typing import List, Optional, Dict, Tuple

# Local imports
from utils import *


from models import slice_triples

In [None]:
# Toy sizes for the scratch experiments in this notebook.
bs, n = 4, 5
ent_emb_dims, rel_emb_dims = 3, 4
out_features = 7
alpha_leaky = 0.2

In [None]:
def self_attention_template():
    """Walk through single-head attention on random data, printing each shape.

    A random ``(bs, n, 2*ent_emb_dims + rel_emb_dims)`` tensor is projected to
    ``out_features`` by a first linear layer, scored by a second linear layer
    followed by a LeakyReLU, normalised over the ``n`` axis with a softmax
    (no masking), and finally reduced to one ``out_features`` vector per batch
    element by a weighted sum over ``n``.

    Returns:
        None. The function only prints the intermediate shapes.
    """
    # Setting things up
    bs = 4
    n = 5
    ent_emb_dims = 3
    rel_emb_dims = 4
    out_features = 7
    alpha_leaky = 0.2
    in_features = 2 * ent_emb_dims + rel_emb_dims  # concat s,p,o.

    matrix = torch.randn(bs, n, in_features)
    print(f"shape of matrix is bs*n*emb_dim i.e {matrix.shape}")

    # passing it through layer1
    w1 = nn.Linear(in_features, out_features)
    nn.init.xavier_normal_(w1.weight.data, gain=1.414)

    c = w1(matrix)
    print(f"shape of c is {c.shape}")

    # passing it through layer2
    w2 = nn.Linear(out_features, 1)
    nn.init.xavier_normal_(w2.weight.data, gain=1.414)

    b = w2(c)
    leaky_relu = nn.LeakyReLU(alpha_leaky)
    # Drop only the trailing score axis; a bare squeeze() would also drop a
    # batch or neighbour axis of size 1.
    b = leaky_relu(b).squeeze(-1)
    print(f"shape of b is {b.shape}")

    # There will be no masking here. So simply a softmax and then multiply and sum across n.
    alphas = torch.softmax(b, dim=1)
    h = torch.sum((alphas.unsqueeze(-1) * c), dim=1)

    print(f"shape of final vector by {h.shape}")

In [None]:
# Run the single-head walkthrough; it prints the intermediate tensor shapes.
self_attention_template()

In [None]:
def self_attention_template_multi_head(num_head, final_layer=False):
    """Walk through multi-head attention on random data, printing each shape.

    Args:
        num_head: number of attention heads, i.e. the output width of the
            scoring layer. Any value >= 1 works.
        final_layer: when True the heads are averaged, giving a
            ``(bs, out_features)`` result; otherwise they are flattened to
            ``(bs, out_features * num_head)`` and passed through ELU.

    Returns:
        None. The function only prints the intermediate shapes.
    """
    # Setting things up
    bs = 4
    n = 5
    ent_emb_dims = 3
    rel_emb_dims = 4
    out_features = 7
    alpha_leaky = 0.2

    matrix = torch.randn(bs, n, 2 * ent_emb_dims + rel_emb_dims)  # concat s,p,o.
    print(f"shape of matrix is bs*n*emb_dim i.e {matrix.shape}")

    # passing it through layer1
    w1 = nn.Linear(2 * ent_emb_dims + rel_emb_dims, out_features)
    nn.init.xavier_normal_(w1.weight.data, gain=1.414)

    c = w1(matrix)
    print(f"shape of c is {c.shape}")

    # passing it through layer2
    w2 = nn.Linear(out_features, num_head)
    nn.init.xavier_normal_(w2.weight.data, gain=1.414)

    b = w2(c)
    leaky_relu = nn.LeakyReLU(alpha_leaky)
    # Keep b three-dimensional (bs, n, num_head). Squeezing here would drop the
    # head axis when num_head == 1 and torch.bmm below would reject the input.
    b = leaky_relu(b)

    print(f"shape of b is {b.shape}")

    # There will be no masking here. So simply a softmax and then multiply and sum across n.
    alphas = torch.softmax(b, dim=1)
    print(f"shape of alphas is {alphas.shape}")

    # (bs, out_features, n) @ (bs, n, num_head) -> (bs, out_features, num_head)
    h = torch.bmm(c.transpose(1, 2), alphas)
    print(f"shape of h is {h.shape}")
    if not final_layer:
        h = h.view(bs, -1)
        h = F.elu(h)
    else:
        h = torch.mean(h, dim=-1)

    print(f"shape of final vector by {h.shape}")

In [None]:
# Run the multi-head walkthrough with 8 heads, averaging the heads at the end.
self_attention_template_multi_head(num_head=8, final_layer=True)

In [None]:
class GraphAttentionLayerMultihead(nn.Module):
    """Multi-head attention over the (padded) neighbourhood of each batch item.

    Every neighbour vector is projected by ``w1``; ``w2`` followed by a
    LeakyReLU turns each projection into one score per head; the scores are
    normalised over the neighbour axis and used to take a weighted sum of the
    projected neighbours.

    Config keys read:
        ``EMBEDDING_DIM``: used for both the entity and relation widths, so the
            expected input width is ``3 * EMBEDDING_DIM``.
        ``GATARGS['OUT']``: output width of ``w1``.
        ``GATARGS['HEAD']``: number of attention heads.
        ``GATARGS['ALPHA']``: negative slope of the LeakyReLU.
    """

    def __init__(self, config: dict, final_layer: bool = False):
        """Build the two linear layers.

        Args:
            config: dict holding the keys listed in the class docstring.
            final_layer: if True, ``forward`` averages the heads; otherwise it
                flattens them and applies ELU.
        """
        super().__init__()

        # Parse params
        ent_emb_dim, rel_emb_dim = config['EMBEDDING_DIM'], config['EMBEDDING_DIM']
        out_features = config['GATARGS']['OUT']
        num_head = config['GATARGS']['HEAD']
        alpha_leaky = config['GATARGS']['ALPHA']

        self.w1 = nn.Linear(2 * ent_emb_dim + rel_emb_dim, out_features)
        self.w2 = nn.Linear(out_features, num_head)
        self.relu = nn.LeakyReLU(alpha_leaky)

        self.final = final_layer

        # Only the head count is kept; the other hyper-parameters live in the layers.
        self.heads = num_head

        # Not initializing here. Should be called by main module

    def initialize(self):
        """Xavier-normal initialise the weights of both linear layers."""
        nn.init.xavier_normal_(self.w1.weight.data, gain=1.414)
        nn.init.xavier_normal_(self.w2.weight.data, gain=1.414)

    def forward(self, data: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """Attend over the neighbours of every batch element.

        Args:
            data: size (batchsize, num_neighbors, 2*ent_emb+rel_emb) or (bs, n, emb)
            mask: size (batchsize, num_neighbors), forwarded (repeated per
                head) to ``masked_softmax``. When None, every neighbour takes
                part and a plain softmax is used.

            PS: num_neighbors is padded either with max neighbors or with a limit

        Returns:
            (bs, out_features) when this is the final layer, otherwise
            (bs, out_features * num_heads).
        """
                                                         #data: bs, n, emb
        c = self.w1(data)                                #c: bs, n, out_features
        # No squeeze: b must stay 3-D for bmm even when bs, n or num_heads is 1.
        b = self.relu(self.w2(c))                        #b: bs, n, num_heads

        if mask is None:
            alphas = torch.softmax(b, dim=1)             #α: bs, n, num_heads
        else:
            m = mask.unsqueeze(-1).repeat(1, 1, self.heads)  #m: bs, n, num_heads
            alphas = masked_softmax(b, m, dim=1)             #α: bs, n, num_heads

        # BMM simultaneously weighs the triples and sums across neighbors
        h = torch.bmm(c.transpose(1, 2), alphas)         #h: bs, out_features, num_heads

        if self.final:
            h = torch.mean(h, dim=-1)                    #h: bs, out_features
        else:
            # Take the batch size from the tensor itself rather than a global.
            h = F.elu(h.reshape(h.shape[0], -1))         #h: bs, out_features*num_heads

        return h

In [None]:
# Disabled smoke test for GraphAttentionLayerMultihead; the `if False:` guard
# means none of the lines below execute.
# NOTE(review): the constructor call passes separate dims plus `num_head`, which
# does not match `GraphAttentionLayerMultihead.__init__(config, final_layer)` as
# defined in this notebook — update the call before re-enabling this cell.
if False:
    bs = 4
    n = 5
    ent_emb_dims = 3
    rel_emb_dims = 4
    out_features = 7
    alpha_leaky = 0.2

    attn = GraphAttentionLayerMultihead(ent_emb_dims, rel_emb_dims, 
                                        out_features, alpha_leaky, num_head=8, final_layer=False)
    print(attn)

    # Random neighbour features, with the trailing neighbours of some batch
    # elements zeroed out (presumably to imitate padding — confirm).
    data = torch.randn(bs, n, 2*ent_emb_dims+rel_emb_dims)
    data[0][2:] = 0
    data[1][4:] = 0
    data[-1][1:] = 0

    # compute_mask is defined outside this notebook; its output is reduced over
    # the last axis before being handed to the layer.
    mask = compute_mask(data)
    mask_condensed = torch.mean(mask, dim=-1)

    print(data.shape)
    op = attn(data, mask_condensed)

    # Bare expression: evaluated and discarded here.
    op, op.shape

In [None]:
class KBGat(BaseModule):
    """Knowledge-base graph-attention model (work in progress).

    The class owns the relation embedding table and two
    ``GraphAttentionLayerMultihead`` layers. It relies on attributes and
    helpers that are not defined here and are presumably supplied by
    ``BaseModule`` — ``entity_embeddings``, ``embedding_dim``, ``num_entities``,
    ``num_relations`` and ``_compute_loss`` (TODO confirm).

    Scoring is not finished: ``embed`` and ``_score_triples`` raise
    ``NotImplementedError`` once they reach the unwritten part.
    """

    model_name = 'KBGAT'

    def __init__(self, config: dict, pretrained_embeddings=None) -> None:
        """Create the embeddings and attention layers, then initialise them.

        Args:
            config: model configuration. Keys read here: ``STATEMENT_LEN``,
                ``NORM_FOR_NORMALIZATION_OF_ENTITIES``, ``SCORING_FUNCTION_NORM``,
                ``NUM_RELATIONS``, ``EMBEDDING_DIM``, ``PROJECT_QUALIFIERS``
                (plus whatever ``BaseModule`` and the attention layers read).
            pretrained_embeddings: must be None; loading weights is not wired in.

        Raises:
            NotImplementedError: if ``pretrained_embeddings`` is not None.
        """
        self.margin_ranking_loss_size_average: bool = True
        self.entity_embedding_max_norm: Optional[int] = None
        self.entity_embedding_norm_type: int = 2
        self.model_name = 'KBGAT'
        super().__init__(config)
        self.statement_len = config['STATEMENT_LEN']

        # Embeddings
        self.l_p_norm_entities = config['NORM_FOR_NORMALIZATION_OF_ENTITIES']
        self.scoring_fct_norm = config['SCORING_FUNCTION_NORM']
        self.relation_embeddings = nn.Embedding(config['NUM_RELATIONS'], config['EMBEDDING_DIM'], padding_idx=0)

        self.config = config

        if self.config['PROJECT_QUALIFIERS']:
            self.proj_mat = nn.Linear(2*self.embedding_dim, self.embedding_dim, bias=False)

        self.gat1 = GraphAttentionLayerMultihead(self.config, final_layer=False)
        self.gat2 = GraphAttentionLayerMultihead(self.config, final_layer=True)

        # Put in weights
        self._initialize(pretrained_embeddings)

    def _initialize(self, pretrained_embeddings):
        """Randomly initialise the embeddings and the attention layers.

        Entity and relation embeddings are drawn uniformly from
        ``[-6/sqrt(dim), 6/sqrt(dim)]``; relation embeddings are then divided
        by their p-norm and row 0 (the padding index) of both tables is zeroed.

        Raises:
            NotImplementedError: if ``pretrained_embeddings`` is not None.
        """
        if pretrained_embeddings is None:
            embeddings_init_bound = 6 / np.sqrt(self.config['EMBEDDING_DIM'])
            nn.init.uniform_(
                self.entity_embeddings.weight.data,
                a=-embeddings_init_bound,
                b=+embeddings_init_bound,
            )
            nn.init.uniform_(
                self.relation_embeddings.weight.data,
                a=-embeddings_init_bound,
                b=+embeddings_init_bound,
            )

            norms = torch.norm(self.relation_embeddings.weight,
                               p=self.config['NORM_FOR_NORMALIZATION_OF_RELATIONS'], dim=1).data
            self.relation_embeddings.weight.data = self.relation_embeddings.weight.data.div(
                norms.view(self.num_relations, 1).expand_as(self.relation_embeddings.weight))

            self.relation_embeddings.weight.data[0] = torch.zeros(1, self.embedding_dim)
            self.entity_embeddings.weight.data[0] = torch.zeros(1, self.embedding_dim)  # zeroing the padding index

        else:
            raise NotImplementedError("Haven't wired in the mechanism to load weights yet fam")

        # Also initialise the weights of both attention layers
        self.gat1.initialize()
        self.gat2.initialize()

    def predict(self, triples_hops) -> torch.Tensor:
        """Placeholder; not implemented yet (returns None)."""
        pass

    def normalize(self) -> None:
        """Divide every entity embedding by its p-norm, then re-zero the padding row."""
        # Normalize embeddings of entities
        norms = torch.norm(self.entity_embeddings.weight, p=self.l_p_norm_entities, dim=1).data

        self.entity_embeddings.weight.data = self.entity_embeddings.weight.data.div(
            norms.view(self.num_entities, 1).expand_as(self.entity_embeddings.weight))

        # zeroing the padding index
        self.entity_embeddings.weight.data[0] = torch.zeros(1, self.embedding_dim)

    def forward(self, pos: List, neg: List) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """
            triples of size: (bs, 3)
               hop1 of size: (bs, n, 2) (s and r)
               hop2 of size: (bs, n, 3) (s and r1 and r2)

            (here n -> num_neighbors)
            (hop2 stands for four elements because it is <s r1 r2 o>)

            (pos has pos_triples, pos_hop1, pos_hop2. neg has same.)

            Returns ((positive_scores, negative_scores), loss).
        """
        pos_triples, pos_hop1, pos_hop2 = pos
        neg_triples, neg_hop1, neg_hop2 = neg

        self.normalize()

        positive_scores = self._score_triples(pos_triples, pos_hop1, pos_hop2)
        negative_scores = self._score_triples(neg_triples, neg_hop1, neg_hop2)

        loss = self._compute_loss(positive_scores=positive_scores, negative_scores=negative_scores)
        return (positive_scores, negative_scores), loss

    def _score_triples(self,
                       triples: torch.Tensor,
                       hop1: torch.Tensor,
                       hop2: torch.Tensor) -> torch.Tensor:
        """
            triples of size: (bs, 3)
            hop1 of size: (bs, n, 2)
            hop2 of size: (bs, n, 3)

            1. Embed all things so triples (bs, 3, emb), hop1 (bs, n, 3, emb), hop2 (bs, n, 4, emb)
            2. Concat hop1, hop2 to be (bs, n, 3*emb) and (bs, n, 4*emb) each
            3. Pass the baton to some other function.

            Raises NotImplementedError: step 3 has not been written yet.
        """
        triples, hop1, hop2 = self.embed(triples, hop1, hop2)

        # Flatten everything after the neighbour axis. The sizes come from the
        # tensor's own shape; indexing a batch element first (hop1[1]) would
        # pick the wrong axis and fail for a batch of one.
        hop1 = hop1.view(hop1.shape[0], hop1.shape[1], -1)
        hop2 = hop2.view(hop2.shape[0], hop2.shape[1], -1)

        # TODO: score the embedded triples using the attended neighbourhoods.
        raise NotImplementedError("Scoring of embedded triples is not implemented yet")

    def embed(self, tr, h1, h2):
        """ The obj is to pass things through entity and rel matrices as needed

            Currently only hop1 is embedded and run through the first attention
            layer; hop2 and the return value are still to be written, so the
            method raises NotImplementedError at that point.
        """
        # Triple
        s, p, o = slice_triples(tr, 3)                                  #*   : (bs, 1) — per original note, TODO confirm

        # Hop1
        num_neighbors = h1.shape[1]
        h1_s, h1_p = h1[:, :, 0], h1[:, :, 1]                           #h1_*: (bs, n)
        # The object of the central triple is shared by all of its neighbours.
        h1_o = self.entity_embeddings(o).reshape(h1.shape[0], 1, -1)    #h1_o: (bs, 1, emb)
        h1_o = h1_o.repeat(1, num_neighbors, 1)                         #h1_o: (bs, n, emb)
        h1_s = self.entity_embeddings(h1_s)                             #h1_s: (bs, n, emb)
        h1_p = self.relation_embeddings(h1_p)                           #h1_p: (bs, n, emb)

        h1 = torch.cat((h1_s, h1_p, h1_o), dim=-1)                      #h1  : (bs, n, 3*emb)

        # Compute Mask
        mask = compute_mask(h1)                                         #m   : (bs, n, 3*emb) — per original note
        # The attention layer expects one mask value per neighbour, so reduce
        # the feature axis (assumes compute_mask returns a float tensor shaped
        # like its input — TODO confirm).
        mask = torch.mean(mask, dim=-1)                                 #m   : (bs, n)

        gat1_op = self.gat1(h1, mask)                                   #op  : (bs, num_head*out_dim)

        # TODO: embed hop2, run the second attention layer and return the
        # (triples, hop1, hop2) embeddings that _score_triples unpacks.
        raise NotImplementedError("KBGat.embed is not finished: hop2 and the return value are missing")

    def _get_relation_embeddings(self, relations):
        """Look up relation embeddings and flatten them to (-1, embedding_dim)."""
        return self.relation_embeddings(relations).view(-1, self.embedding_dim)

        


In [None]:
# Toy index tensor of shape (2, 4, 3) with integer ids drawn from [0, 10).
h1 = torch.randint(low=0, high=10, size=(2, 4, 3))
# Split the last axis into its three id columns.
h1_s, h1_p, h1_o = h1.unbind(dim=-1)

h1_s.shape

In [None]:
# Embed each id column with one shared toy table (30 ids, 5 dims) and check
# the shape obtained by concatenating the three embeddings on the last axis.
emb = nn.Embedding(30, 5)
h1_s, h1_p, h1_o = emb(h1_s), emb(h1_p), emb(h1_o)

# h1_s.shape, h1_p.shape, h1_o.shape
# Concatenate subject, predicate and object (the object column was previously
# left out in favour of a second copy of the subject).
torch.cat((h1_s, h1_p, h1_o), dim=-1).shape
# torch.cat((_a, _b, _c), dim=-1).shape

In [None]:
# Show the documentation of compute_mask (not defined in this notebook; presumably from utils).
help(compute_mask)

In [None]:
#         tr_s, tr_p, tr_o = slice_triples(triples, slices = 3)      #each: (bs, 1)
#         tr_s = self.entity_embeddings(tr_s)
#         tr_p = self.relation_embeddings(tr_p)
#         tr_o = self.entity_embeddings(tr_o)
        
#         tr = torch.cat((tr_s, tr_p, tr_o), dim=-1)