In [None]:
import os

import dgl
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
from ogb.lsc import DglPCQM4MDataset
from ogb.utils import smiles2graph

In [None]:
# Root directory of the OGB-LSC PCQM4M download. Override it with the
# PCQM4M_ROOT environment variable instead of editing a hard-coded home path.
DATA_ROOT = os.environ.get('PCQM4M_ROOT', '/home/ksadowski/datasets')

dataset = DglPCQM4MDataset(root=DATA_ROOT, smiles2graph=smiles2graph)

In [None]:
# Show the dataset object's repr.
dataset

In [None]:
# Pick one molecule to explore. Fetch the sample once and reuse it rather
# than indexing the dataset repeatedly.
MOLECULE_IDX = 2736
raw_graph = dataset[MOLECULE_IDX][0]

molecule = dgl.to_homogeneous(raw_graph, ndata=raw_graph.ndata, edata=raw_graph.edata, store_type=False)
# Line graph: one node per molecule edge. With backtracking=False an edge
# u->v is not linked to its own reverse v->u.
molecule_lg = dgl.line_graph(molecule, backtracking=False)

# Optional experiment (disabled): molecule = dgl.add_self_loop(molecule)

print(f'Num nodes: {molecule.num_nodes()}')
print(f'Num edges: {molecule.num_edges()}')
print(f'Num source nodes: {molecule.num_src_nodes()}')

nx.draw_kamada_kawai(molecule.to_networkx(), with_labels=True)

In [None]:
# For every line-graph edge (source_edge -> destination_edge), print the two
# molecule edges it links. Iterate over the line graph's own edge list: it
# generally has a different number of edges than `molecule`, so indexing it
# with range(molecule.num_edges()) can skip pairs or raise IndexError.
lg_edge_src, lg_edge_dst = molecule_lg.edges()
for source_edge, destination_edge in zip(lg_edge_src.tolist(), lg_edge_dst.tolist()):
    print(molecule.find_edges(source_edge))
    print(molecule.find_edges(destination_edge))
    print()

# Reference listing: endpoints of every molecule edge, in edge-ID order.
for edge_id in range(molecule.num_edges()):
    print(molecule.find_edges(edge_id))


In [None]:
# Destination atom of every molecule edge, one row per edge in edge-ID order.
# A single `edges()` call replaces one `find_edges` call per edge plus a
# torch.tensor() rebuild of the collected results.
mask = molecule.edges()[1].unsqueeze(1)

for dst_node in mask.flatten().tolist():
    print(dst_node)

In [None]:
# Inspect the line graph: its (src, dst) edge-ID tensors and its dense adjacency.
# NOTE(review): whether rows of adjacency_matrix() index sources or
# destinations depends on the installed DGL version — verify.
print(molecule_lg.edges())
print(molecule_lg.adjacency_matrix().to_dense())

nx.draw_circular(molecule_lg.to_networkx(), with_labels=True)

In [None]:
# Node IDs of the molecule graph.
molecule.nodes()

In [None]:
# Node feature view; later cells read its 'feat' entry.
molecule.ndata

In [None]:
# Edge feature tensor, used below to compute edge attention.
molecule.edata['feat']

In [None]:
# calculate edge attention (exploratory): multiply the line-graph adjacency
# into the edge features (summing over line-graph neighbours; direction follows
# adjacency_matrix()'s orientation). Reduce each row to a scalar with an
# all-ones weight vector, then softmax over all edges.
# The weight length comes from the edge-feature width instead of a literal 3.
W_edge = torch.ones(molecule.edata['feat'].shape[-1])
edge_attention = torch.softmax(molecule_lg.adjacency_matrix() @ molecule.edata['feat'].float() @ W_edge, dim=0)

edge_attention

In [None]:
# calculate node attention (exploratory): the same recipe on the molecule graph.
# Sum adjacent atom features via the adjacency, reduce to a scalar with an
# all-ones weight vector sized from the atom-feature width (instead of a
# literal 9), then softmax over atoms. Result: one score per node, shape (N,).
W_node = torch.ones(molecule.ndata['feat'].shape[-1])
node_attention = torch.softmax(molecule.adjacency_matrix() @ molecule.ndata['feat'].float() @ W_node, dim=0)

node_attention

In [None]:
# Dense copy of the molecule's adjacency matrix (N x N).
# NOTE(review): row/column orientation depends on the DGL version's
# adjacency_matrix() convention — verify before relying on [src][dst] indexing.
adjacency = molecule.adjacency_matrix().to_dense()

adjacency

In [None]:
# attention adjacency matrix: overwrite entry (src, dst) of each molecule edge
# with that edge's attention score. Work on a copy so `adjacency` keeps its
# plain 0/1 values and re-running this cell gives the same result.
# NOTE(review): assumes rows of adjacency_matrix() index source atoms — confirm
# for the installed DGL version.
attention_adjacency = adjacency.clone()
edge_src, edge_dst = molecule.edges()

for edge in range(molecule.number_of_edges()):
    i = edge_src[edge].item()
    j = edge_dst[edge].item()
    attention = edge_attention[edge].item()

    print(i, j, attention)

    attention_adjacency[i, j] = attention

attention_adjacency

In [None]:
# Dense adjacency of the line graph (E x E, one row/column per molecule edge).
# NOTE(review): orientation follows DGL's adjacency_matrix() convention — verify.
adjacency_lg = molecule_lg.adjacency_matrix().to_dense()

print(adjacency_lg)

nx.draw_kamada_kawai(molecule_lg.to_networkx(), with_labels=True)

In [None]:
# attention lg adjacency matrix: a line-graph entry (i, j) is 1 when molecule
# edge i ends at the atom where molecule edge j starts. Replace each such entry
# with that connecting atom's node attention. Built on a copy so `adjacency_lg`
# keeps its 0/1 entries and the cell can be re-run.
edge_src, edge_dst = molecule.edges()
lg_rows, lg_cols = (adjacency_lg == 1).nonzero(as_tuple=True)

# Sanity check of the line-graph construction and of the row = source-edge
# orientation of adjacency_matrix().
assert torch.equal(edge_dst[lg_rows], edge_src[lg_cols])

attention_adjacency_lg = adjacency_lg.clone()
attention_adjacency_lg[lg_rows, lg_cols] = node_attention[edge_dst[lg_rows]]

# TODO: next step is propagating edge features, e.g.
# attention_adjacency_lg @ molecule.edata['feat'].float()
attention_adjacency_lg



In [None]:
# Line-graph edges as (src, dst) tensors of line-graph node IDs.
# Open question: is line-graph node k the same thing as molecule edge k, so
# that line-graph edge (a, b) meets at the destination atom of molecule edge a?
# NOTE(review): DGL's line_graph builds one node per original edge — confirm
# the ID correspondence against the DGL docs.
molecule_lg.edges()

In [None]:
# Connecting atom for line-graph edge 16: take that edge's source line-graph
# node, read it as a molecule edge ID, and look up that molecule edge's
# destination atom. 16 is an arbitrary example; it must be
# < molecule_lg.num_edges().
connecting_node = molecule.edges()[1][molecule_lg.edges()[0][16].item()].item()

connecting_node

In [None]:
class SelfAttention(nn.Module):
    """Scalar attention score for every row of a feature matrix.

    For a 2-D input X (``.t()`` requires at most 2 dims), the score column is
    ``softmax_over_rows(value_linear(X) @ X^T @ key_linear(X))``.
    ``key_linear`` maps each row to one scalar, so the output has one column.
    The softmax runs over dim 0, so the column sums to 1.
    """

    def __init__(self, in_feats):
        super().__init__()
        # in_feats -> 1: one scalar per row
        self.key_linear = nn.Linear(in_feats, 1)
        # in_feats -> in_feats: same feature width as the input
        self.value_linear = nn.Linear(in_feats, in_feats)

    def forward(self, inputs):
        projected_keys = self.key_linear(inputs)
        projected_values = self.value_linear(inputs)

        # row-by-row similarity between projected values and raw inputs
        similarity = torch.matmul(projected_values, inputs.t())
        scores = torch.matmul(similarity, projected_keys)

        return F.softmax(scores, dim=0)

class AttentionGraphConv(nn.Module):
    """Dense graph convolution whose adjacency is scaled by attention weights.

    Computes ``linear((adjacency * attention) @ inputs)``. ``attention`` is
    broadcast against ``adjacency``; for example, an (n, 1) column scales row i
    by attention[i]. Neighbour features are summed with those weights and the
    result goes through an in_feats -> in_feats linear layer.
    """

    def __init__(self, in_feats):
        # nn.Module.__init__ must run before any submodule is assigned,
        # otherwise `self.linear = ...` raises AttributeError.
        super().__init__()
        self.linear = nn.Linear(in_feats, in_feats)

    def forward(self, inputs: torch.Tensor, adjacency: torch.Tensor, attention: torch.Tensor):
        """Aggregate ``inputs`` over the attention-weighted ``adjacency``.

        Args:
            inputs: (n, in_feats) features, one row per graph element.
            adjacency: (n, n) adjacency. It must support elementwise
                broadcasting with ``attention``, so pass a dense tensor.
            attention: weights broadcastable to ``adjacency``'s shape.

        Returns:
            (n, in_feats) transformed, aggregated features.
        """
        # `*` and `@` share precedence and associate left-to-right; written out
        # explicitly so the order of operations is obvious.
        weighted_adjacency = adjacency * attention
        aggregated = weighted_adjacency @ inputs
        return self.linear(aggregated)


class Head(nn.Module):
    """One attention head over a molecule graph and its line graph (work in progress).

    Atoms and bonds are scored by separate SelfAttention modules. Edge queries
    are then aggregated over the line graph with AttentionGraphConv. Each
    line-graph row (molecule edge i) is weighted by the attention of the atom
    that edge i points to.
    """

    def __init__(self, node_in_feats, edge_in_feats):
        super().__init__()
        self.node_attention = SelfAttention(node_in_feats)
        self.node_query_linear = nn.Linear(node_in_feats, node_in_feats)
        self.node_query_conv = AttentionGraphConv(node_in_feats)
        self.edge_attention = SelfAttention(edge_in_feats)
        self.edge_query_linear = nn.Linear(edge_in_feats, edge_in_feats)
        self.edge_query_conv = AttentionGraphConv(edge_in_feats)

    def forward(
        self, 
        g: dgl.DGLGraph, 
        g_adjacency: torch.Tensor, 
        lg: dgl.DGLGraph, 
        lg_adjacency: torch.Tensor, 
        node_feats: torch.Tensor, 
        edge_feats: torch.Tensor,
    ):
        """Score atoms and bonds, and aggregate edge queries over the line graph.

        Args:
            g: molecule graph; only its edge list is read here.
            g_adjacency: node adjacency of ``g`` (not used yet).
            lg: line graph of ``g`` (not used yet).
            lg_adjacency: dense (E, E) line-graph adjacency. Assumes row i is
                molecule edge i — TODO confirm orientation.
            node_feats: atom features, expected (N, node_in_feats).
            edge_feats: bond features, expected (E, edge_in_feats).

        Returns:
            (node_attention, edge_attention) from the two SelfAttention modules.
        """
        node_attention = self.node_attention(node_feats)
        node_query = self.node_query_linear(node_feats)

        edge_attention = self.edge_attention(edge_feats)
        edge_query = self.edge_query_linear(edge_feats)

        # Attention of each edge's destination atom, one row per edge in
        # edge-ID order. Indexing keeps the autograd graph (a torch.tensor()
        # rebuild would detach it). With (N, 1) node attention this gives an
        # (E, 1) column that scales row i of lg_adjacency.
        _, edge_dst = g.edges()
        node_attention_to_edge = node_attention[edge_dst]

        edge_query = self.edge_query_conv(edge_query, lg_adjacency, node_attention_to_edge)

        # TODO: node_query / edge_query do not feed the output yet.
        return node_attention, edge_attention

# Run one Head on the example molecule. Feature widths come from the data
# (not hard-coded 9 / 3), and forward() receives every argument its signature
# declares. Adjacencies are passed dense because AttentionGraphConv scales
# them elementwise with a broadcast attention column.
node_feats = molecule.ndata['feat'].float()
edge_feats = molecule.edata['feat'].float()

torch.manual_seed(0)  # reproducible Linear initialisation
head = Head(node_feats.shape[1], edge_feats.shape[1])
attentions = head(
    molecule,
    molecule.adjacency_matrix().to_dense(),
    molecule_lg,
    molecule_lg.adjacency_matrix().to_dense(),
    node_feats,
    edge_feats,
)

print(attentions)


In [None]:
# Edge aggregation by hand, outside the module: scale row i of the dense
# line-graph adjacency by the attention of the atom that molecule edge i
# points to, then aggregate the edge features.
# A distinct name avoids clobbering the global `node_attention` (shape (N,))
# that earlier cells index.
head_node_attention = attentions[0]
edge_attention_from_node = head_node_attention[molecule.edges()[1]]

(molecule_lg.adjacency_matrix().to_dense() * edge_attention_from_node) @ molecule.edata['feat'].float()

