In [1]:
import torch
import numpy as np
from tqdm.notebook import tqdm
from typing import List
from sklearn.preprocessing import StandardScaler
from torch_geometric.data import Data
from torch_geometric.loader import DataListLoader
from torch_geometric.nn import Linear, GATv2Conv
import torch.nn.functional as F

from arigin.expressions import generate
from arigin.graph.generation import graph_from_expression
from arigin.features import node_features_emb, node_features_values, edge_features

In [2]:
# Arguments handed to `generate(min_numbers, max_numbers)` in the data
# generation cell; both are 2 here.
min_numbers = 2
max_numbers = 2

In [3]:
# Generate expressions, turn each one into graph entities, and group the
# entities into batches of GRAPHS_PER_BATCH expressions.
N_EXPRESSIONS = 80000    # number of expressions to generate
GRAPHS_PER_BATCH = 64    # expressions whose entities are merged into one batch

# Every node / relationship generated, regardless of batch (the feature
# transformers are fitted on this pool in the next cell).
graph_entities = {"nodes": [], "relationships": []}
# One list of nodes / relationships per batch.
node_batches = []
relationship_batches = []

node_batch = []
relationship_batch = []
n_in_batch = 0
for _ in tqdm(range(N_EXPRESSIONS), total=N_EXPRESSIONS):
    expr = generate(min_numbers, max_numbers)
    single_graph_entities = graph_from_expression(expr)

    graph_entities["nodes"] += single_graph_entities["nodes"]
    graph_entities["relationships"] += single_graph_entities["relationships"]

    node_batch += single_graph_entities["nodes"]
    relationship_batch += single_graph_entities["relationships"]
    n_in_batch += 1

    if n_in_batch == GRAPHS_PER_BATCH:
        node_batches.append(node_batch)
        relationship_batches.append(relationship_batch)
        # Start fresh lists; the appended ones must not be mutated further.
        node_batch = []
        relationship_batch = []
        n_in_batch = 0

# If N_EXPRESSIONS is not a multiple of GRAPHS_PER_BATCH, keep the remainder as
# a final, smaller batch instead of silently dropping it.
if n_in_batch:
    node_batches.append(node_batch)
    relationship_batches.append(relationship_batch)
    n_in_batch = 0

  0%|          | 0/80000 [00:00<?, ?it/s]

In [4]:
# Fit every feature transformer once, on the full pool of generated entities,
# so that all batches are encoded with the same column layout.
node_features_emb.fit(graph_entities["nodes"])
node_features_values.fit(graph_entities["nodes"])
# Fitting the edge transformer per batch could give batches different encodings.
# NOTE(review): assumes `edge_features` exposes `fit` / `transform` like the
# node feature transformers above — confirm.
edge_features.fit(graph_entities["relationships"])

# Target scaler: fitted on the values of all nodes (a missing value counts as 0),
# scaling without centring (with_mean=False).
all_values = np.array(
    [node.value if node.value is not None else 0 for node in graph_entities["nodes"]],
    dtype=float,
)
scaler = StandardScaler(with_mean=False).fit(all_values[:, np.newaxis])

# Build one `Data` object per batch of expressions.
dataset = []
for nodes, relationships in zip(node_batches, relationship_batches):

    # Position of each node inside this batch, keyed by node id.
    id_index_mapping = {
        node.id: index
        for index, node in enumerate(nodes)
    }

    # Node features: integer codes for the embedding and numeric value columns.
    Xemb = torch.tensor(node_features_emb.transform(nodes), dtype=torch.int)
    Xval = torch.tensor(node_features_values.transform(nodes), dtype=torch.float)

    # Edge indices: one (source, target) column per relationship, followed by
    # the same edges reversed. `reshape(-1, 2)` keeps the shape (2, 0) when a
    # batch has no relationships.
    directed = torch.tensor(
        [
            (
                id_index_mapping[relationship.source.id],
                id_index_mapping[relationship.target.id]
            )
            for relationship in relationships
        ],
        dtype=torch.long,
    ).reshape(-1, 2).T
    edge_index = torch.cat((directed, directed[[1, 0], :]), dim=1)

    # Edge features: original edges carry their features in the first half of
    # the columns, reversed edges in the second half; the other half is zero.
    E = edge_features.transform(relationships)
    Ez = np.zeros_like(E)
    E = np.vstack(
        (
            np.hstack((E, Ez)),
            np.hstack((Ez, E))
        )
    )
    E = torch.tensor(E, dtype=torch.float)

    # Scaled regression target, one row per node.
    y = np.array([node.value if node.value is not None else 0 for node in nodes], dtype=float)
    y = scaler.transform(y[:, np.newaxis])
    y = torch.tensor(y, dtype=torch.float)

    dataset.append(Data(x=[Xemb, Xval], edge_index=edge_index, edge_attr=E, y=y))

In [5]:
class NodeEmbedding(torch.nn.Module):

    def __init__(
            self, 
            num_embeddings, 
            embedding_channels: int = 2,
            out_channels: int = 16
            ):

        super().__init__()
        self.embedding = torch.nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_channels
        )
        self.output = Linear(
            out_channels,
            out_channels
        )

        self.intermed1 = Linear(embedding_channels + 1, out_channels)
        self.intermed2 = Linear(out_channels, out_channels)
        self.out_channels = out_channels

    def forward(self, X: List[torch.Tensor]):

        x = torch.concat(
            [
                self.embedding(X[0]).squeeze(dim=1),
                X[1]
            ],
            dim=-1
        )
        x = F.relu(self.intermed1(x))
        x = F.relu(self.intermed2(x))
        return F.relu(self.output(x))

    
class GAT(torch.nn.Module):
    def __init__(
            self,
            hidden_channels: int, 
            out_channels: int, 
            heads: int,
            embedding: NodeEmbedding,
            num_gat_layers: int = 8,
            gat_activation = F.elu,
            dropout_inter_layer=0.1,
            dropout_gat=0.2):

        super().__init__()

        self.embedding = embedding

        self.gatconv = torch.nn.ModuleList()

        self.gatconv.append(
            GATv2Conv(
                embedding.out_channels,
                hidden_channels, 
                heads, 
                edge_dim=E.shape[1], 
                residual=True,
                add_self_loops=True, 
                fill_value=0, 
                dropout=dropout_gat
            )
        )
        for _ in range(num_gat_layers - 1):
            self.gatconv.append(
                GATv2Conv(
                    hidden_channels * heads,
                    hidden_channels, 
                    heads, 
                    edge_dim=E.shape[1], 
                    residual=True,
                    add_self_loops=True, 
                    fill_value=0, 
                    dropout=dropout_gat
                )
            )

        self.head = Linear(hidden_channels * heads, out_channels)

        self.dropout_inter_layer = dropout_inter_layer
        self.gat_activation = gat_activation


    def forward(self, x, edge_index, edge_attr):

        value = x[-1]

        x_ = self.embedding(x)
        for layer in self.gatconv:
            x_ = layer(x_, edge_index, edge_attr)
            x_ = self.gat_activation(x_)
            x_ = F.dropout(x_, p=self.dropout_inter_layer, training=self.training)

        x_ = self.head(x_)
        x_ = value[:, [-1]] + x_

        return x_
        

In [6]:
model = GAT(
    hidden_channels=8,
    num_gat_layers=8, 
    heads=1,
    out_channels=y.shape[1],
    dropout_inter_layer=0.,
    dropout_gat=0.,
    gat_activation=F.elu,
    embedding=NodeEmbedding(6, 2, out_channels=8)
)

sum(p.numel() for p in model.parameters() if p.requires_grad)

2373

In [7]:
# Train with Adam on the L1 loss, visiting the graphs in shuffled order.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
optimizer.zero_grad()

# DataListLoader yields plain lists of `Data` objects; only the first element
# of each list is used below.
loader = DataListLoader(dataset=dataset, shuffle=True)

model.train()
for epoch in tqdm(range(2000), total=2000):
    running_loss = 0
    for step, graphs in enumerate(loader, start=1):
        graph = graphs[0]
        prediction = model(
            x=graph.x,
            edge_index=graph.edge_index,
            edge_attr=graph.edge_attr,
        )
        step_loss = F.l1_loss(prediction, graph.y)
        step_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        running_loss += step_loss.item()
    # Report the mean per-step loss after every epoch.
    print(
        "Epoch {:05d} | Loss {:.6f} |".format(
            epoch, running_loss / step
        )
    )

  0%|          | 0/2000 [00:00<?, ?it/s]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x4 and 3x8)

In [None]:
# Run the trained model on a single hand-written expression and compare its
# per-node predictions with the true node values.
expr = "(0.2 * 0.3)"
ent = graph_from_expression(expr)
model.eval()

# Position of each node inside this graph, keyed by node id.
id_index_mapping = {
    node.id: index
    for index, node in enumerate(ent["nodes"])
}

# Node features, encoded with the transformers fitted during dataset creation.
expr_Xemb = torch.tensor(node_features_emb.transform(ent["nodes"]), dtype=torch.int)
expr_Xval = torch.tensor(node_features_values.transform(ent["nodes"]), dtype=torch.float)

# Edge indices: every relationship once as given and once reversed.
directed = torch.tensor(
    [
        (
            id_index_mapping[relationship.source.id],
            id_index_mapping[relationship.target.id]
        )
        for relationship in ent["relationships"]
    ],
    dtype=torch.long,
).reshape(-1, 2).T
expr_edge_index = torch.cat((directed, directed[[1, 0], :]), dim=1)

# Edge features: use `transform`, not `fit_transform`, so the encoding matches
# what the model was trained on instead of being refit on this one expression.
# NOTE(review): assumes `edge_features` exposes `transform` — confirm.
expr_E = edge_features.transform(ent["relationships"])
expr_Ez = np.zeros_like(expr_E)
expr_E = torch.tensor(
    np.vstack(
        (
            np.hstack((expr_E, expr_Ez)),
            np.hstack((expr_Ez, expr_E))
        )
    ),
    dtype=torch.float,
)

# True node values (a missing value counts as 0), unscaled and scaled.
target = np.array(
    [node.value if node.value is not None else 0 for node in ent["nodes"]],
    dtype=float,
)[:, np.newaxis]
target_scaled = torch.tensor(scaler.transform(target), dtype=torch.float)

# No gradients are needed for evaluation.
with torch.no_grad():
    pred_scaled = model(x=[expr_Xemb, expr_Xval], edge_index=expr_edge_index, edge_attr=expr_E)
    # The model predicts scaled targets, so the loss is taken against the
    # scaled target, as in training.
    loss = F.l1_loss(pred_scaled, target_scaled)

# Bring the predictions back to the original units for display.
pred = scaler.inverse_transform(pred_scaled.numpy())

for node in ent["nodes"]:
    print(node)

print(np.round(pred.flatten(), 3))
print(np.round(target.flatten(), 3))
print(loss.item())