# Hyper-graph N-beddings
![Drag Racing](images/HyperGraph_Nbedding.jpg)

In [124]:
import torch
from torch.nn.modules.module import Module
from torch.nn.modules.activation import MultiheadAttention
from torch.nn.modules.linear import Linear
import torch.nn as nn

%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [36]:
src = torch.rand((3,8))
edges = torch.tensor([[0,1,0],
                      [0,0,0],
                      [2,2,0]])

In [37]:
class Edge_Attention(Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.Wq = Linear(embed_dim,embed_dim)
        self.Wk = Linear(embed_dim,embed_dim)
        self.Wv = Linear(embed_dim,embed_dim)
        
        self.edge_embedding = nn.Embedding(3, embed_dim)

    def forward(self, queries, keys, values, edge_matrix):
        # edge_matrix [L,L]
        seq_len = queries.shape[0]
        Q = self.Wq(queries) # [L,E]
        K = self.Wk(keys) # [L,E]
        V = self.Wv(values) # [L,E]
        
        attention_matrix = torch.zeros((seq_len,seq_len)) # [L,L]
        for i in range(seq_len):
            for j in range(seq_len):
                edge_ij_emb = self.edge_embedding(edge_matrix[i,j])
                attention_matrix[i,j] = torch.dot(Q[i]+edge_ij_emb, K[j]+edge_ij_emb)
                
        attention_matrix = torch.softmax(attention_matrix, dim=-1)
        output = torch.matmul(attention_matrix, V)
        print(output)

In [38]:
ea = Edge_Attention(8)
ea(src,src,src,edges)

tensor([[-0.0125,  0.1365,  0.0868, -0.0329,  0.0552, -0.2522,  0.0808,  0.0194],
        [-0.1271,  0.0905,  0.0298, -0.0840,  0.1488, -0.2664,  0.2652, -0.0683],
        [-0.0094,  0.1227,  0.0752, -0.0388,  0.0496, -0.2534,  0.0653,  0.0209]],
       grad_fn=<MmBackward>)


In [176]:
class Batch_Edge_Attention(Module):
    def __init__(self, embed_dim):
        '''
        Edge attation computation, provides structured information to the attention computation. 
        '''
        super().__init__()
        self.Wq = Linear(embed_dim,embed_dim)
        self.Wk = Linear(embed_dim,embed_dim)
        self.Wv = Linear(embed_dim,embed_dim)
        
        self.edge_embedding = nn.Embedding(3, embed_dim)

    def forward(self, queries, keys, values, edge_matrix):
        '''
        L sequence length, N Batch size, E embeding dim
        queries: [L,N,E]
        keys: [L,N,E]
        values: [L,N,E]
        edge_matrix: [N,L,L]
        
        returns: [L,N,E]
        >>> src = torch.rand((2,3,8))
        >>> edges = torch.tensor([[[0,1,0],
                                   [0,0,0],
                                   [2,2,0]],
                                  [[0,0,0],
                                   [1,0,0],
                                   [2,2,0]]])
        >>> ea = Batch_Edge_Attention(8)
        >>> ea(src,src,src,edges).shape
        '''
        queries = queries.permute(1,0,2)
        keys = keys.permute(1,0,2)
        values = values.permute(1,0,2)
        batch_size = edge_matrix.shape[0]
        seq_len = queries.shape[1]
        Q = self.Wq(queries) # [N,L,E]
        K = self.Wk(keys) # [N,L,E]
        V = self.Wv(values) # [N,L,E]
        
        attention_matrix = torch.zeros((batch_size,seq_len,seq_len)) # [N,L,L]
        for i in range(seq_len):
            for j in range(seq_len):
                edge_ij_emb = self.edge_embedding(edge_matrix[:,i,j])
                A = Q[:,i]+edge_ij_emb
                B = K[:,j]+edge_ij_emb
                attention_matrix[:,i,j] = torch.bmm(A.unsqueeze(dim=1), B.unsqueeze(dim=2)).squeeze()
                
        attention_matrix = torch.softmax(attention_matrix, dim=-1)
        output = torch.bmm(attention_matrix, V)
        output = output.permute(1,0,2)
        return output

In [177]:
seq_len = 3
embed_dim = 128
device = "cpu"
src = torch.rand((seq_len,2,embed_dim), device=device)
edges = torch.tensor([[[0,1,0],
                       [0,0,0],
                       [2,2,0]],
                      
                      [[0,0,0],
                       [1,0,0],
                       [2,2,0]]], device=device)
# edges = torch.randint(0,3,(2,embed_dim,embed_dim), device=device)
ea = Batch_Edge_Attention(embed_dim).to(device)
ea(src,src,src,edges).shape

torch.Size([3, 2, 128])

In [152]:
%time ea(src,src,src,edges).shape

CPU times: user 389 ms, sys: 4.92 ms, total: 394 ms
Wall time: 25.9 ms


torch.Size([16, 2, 128])

In [269]:
self_attn = MultiheadAttention(embed_dim, 1)

In [313]:
%time self_attn(src, src, src)[0].shape

CPU times: user 740 ms, sys: 4.06 ms, total: 744 ms
Wall time: 49.2 ms


torch.Size([4096, 2, 128])

# Improving speed with scatter and sparce matrices

In [347]:
class Fast_Edge_Attention(Module):
    def __init__(self, embed_dim):
        '''
        Edge attation computation, provides structured information to the attention computation. 
        '''
        super().__init__()
        self.Wq = Linear(embed_dim,embed_dim, bias=False)
        self.Wk = Linear(embed_dim,embed_dim, bias=False)
        self.Wv = Linear(embed_dim,embed_dim, bias=False)
        
        self.edge_embedding = nn.Embedding(3, embed_dim)

    def forward(self, queries, keys, values, edge_matrix):
        '''
        L sequence length, N Batch size, E embeding dim
        queries: [L,N,E]
        keys: [L,N,E]
        values: [L,N,E]
        edge_matrix: [N,L,L]
        
        returns: [L,N,E]
        >>> src = torch.rand((2,3,8))
        >>> edges = torch.tensor([[[0,1,0],
                                   [0,0,0],
                                   [2,2,0]],
                                  [[0,0,0],
                                   [1,0,0],
                                   [2,2,0]]])
        >>> ea = Batch_Edge_Attention(8)
        >>> ea(src,src,src,edges).shape
        '''
        
        queries = queries.permute(1,0,2)
        keys = keys.permute(1,0,2)
        values = values.permute(1,0,2)
        batch_size = edge_matrix.shape[0]
        seq_len = queries.shape[1]
        Q = self.Wq(queries) # [N,L,E]
        K = self.Wk(keys) # [N,L,E]
        V = self.Wv(values) # [N,L,E]
        
        attention_matrix = torch.bmm(Q,K.permute(0,2,1))
        
        sparse_edges = edges.to_sparse()
        sparse_edges_indices = sparse_edges.indices()
        query_indices = sparse_edges_indices[[True, True, False]]
        key_indices = sparse_edges_indices[[True, False, True]]
        
        query_edge_vectors = Q[query_indices[0],query_indices[1]]
        key_edge_vectors = Q[key_indices[0],key_indices[1]]
        
        indexed_edge_embeddings = self.edge_embedding(sparse_edges.values())
        query_edge_vectors += indexed_edge_embeddings
        key_edge_vectors += indexed_edge_embeddings
        
        edge_attention_values = torch.bmm(query_edge_vectors.unsqueeze(1),key_edge_vectors.unsqueeze(2)).squeeze()
        attention_matrix[sparse_edges_indices[0],sparse_edges_indices[1],sparse_edges_indices[2]] = edge_attention_values
        print(attention_matrix)
                
        attention_matrix = torch.softmax(attention_matrix, dim=-1)
        output = torch.bmm(attention_matrix, V)
        output = output.permute(1,0,2)
        return output

In [349]:
seq_len = 3
embed_dim = 128
device = "cpu"
src = torch.zeros((seq_len,2,embed_dim), device=device)
edges = torch.tensor([[[0,1,0],
                       [0,0,0],
                       [2,2,0]],
                      
                      [[0,0,0],
                       [1,0,0],
                       [2,2,0]]], device=device)
# edges = torch.randint(0,3,(2,embed_dim,embed_dim), device=device)
fea = Fast_Edge_Attention(embed_dim).to(device)
fea.edge_embedding.weight.data = torch.zeros(3,128)
fea.edge_embedding.weight.data[[1],[0]] = 1
fea.edge_embedding.weight.data[[2],[0]] = 2
fea(src,src,src,edges).shape

tensor([[[0., 1., 0.],
         [0., 0., 0.],
         [4., 4., 0.]],

        [[0., 0., 0.],
         [1., 0., 0.],
         [4., 4., 0.]]], grad_fn=<IndexPutBackward>)


torch.Size([3, 2, 128])

In [321]:
edges.to_sparse().indices().shape

torch.Size([3, 21836])

In [317]:
%time fea(src,src,src,edges).shape

CPU times: user 660 ms, sys: 213 ms, total: 872 ms
Wall time: 64.6 ms


torch.Size([4096, 2, 128])

# Parsing the Pen Treebank Dataset

In [352]:
import treebank
train_data = treebank.penn['train']

In [359]:
train_data[:100]

' aer banknote berlitz calloway centrust cluett fromstein gitano guterman hydro-quebec ipo kia memote'

In [361]:
from torchnlp.datasets import penn_treebank_dataset

In [374]:
dataset = penn_treebank_dataset(train=True)

In [379]:
dataset[:100]

['aer',
 'banknote',
 'berlitz',
 'calloway',
 'centrust',
 'cluett',
 'fromstein',
 'gitano',
 'guterman',
 'hydro-quebec',
 'ipo',
 'kia',
 'memotec',
 'mlx',
 'nahb',
 'punts',
 'rake',
 'regatta',
 'rubens',
 'sim',
 'snack-food',
 'ssangyong',
 'swapo',
 'wachter',
 '</s>',
 'pierre',
 '<unk>',
 'N',
 'years',
 'old',
 'will',
 'join',
 'the',
 'board',
 'as',
 'a',
 'nonexecutive',
 'director',
 'nov.',
 'N',
 '</s>',
 'mr.',
 '<unk>',
 'is',
 'chairman',
 'of',
 '<unk>',
 'n.v.',
 'the',
 'dutch',
 'publishing',
 'group',
 '</s>',
 'rudolph',
 '<unk>',
 'N',
 'years',
 'old',
 'and',
 'former',
 'chairman',
 'of',
 'consolidated',
 'gold',
 'fields',
 'plc',
 'was',
 'named',
 'a',
 'nonexecutive',
 'director',
 'of',
 'this',
 'british',
 'industrial',
 'conglomerate',
 '</s>',
 'a',
 'form',
 'of',
 'asbestos',
 'once',
 'used',
 'to',
 'make',
 'kent',
 'cigarette',
 'filters',
 'has',
 'caused',
 'a',
 'high',
 'percentage',
 'of',
 'cancer',
 'deaths',
 'among',
 'a',
 'gro