In [None]:
import random
from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler
from torch_geometric.data import Data
from torch_geometric.nn import Linear, GATv2Conv
from tqdm.notebook import tqdm

from arigin.expressions import generate
from arigin.graph.elements import Result
from arigin.graph.generation import graph_from_expression
from arigin.features import node_features_emb, node_features_values, edge_features

In [None]:
# Bounds passed to `generate(min_numbers, max_numbers)`; presumably the
# minimum / maximum number of operands per expression — TODO confirm.
min_numbers = 2
max_numbers = 4

# Seed every RNG the notebook touches so Restart & Run All is reproducible.
# NOTE(review): if `arigin.expressions.generate` uses its own RNG instance it
# must be seeded separately — confirm.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Number of random expressions turned into graphs for the training set.
N_EXPRESSIONS = 50000

# Accumulate the nodes / relationships of every expression graph into one big
# disconnected graph (full-batch training below).
graph_entities = {"nodes": [], "relationships": []}
for _ in tqdm(range(N_EXPRESSIONS)):
    expr = generate(min_numbers, max_numbers)
    single_graph_entities = graph_from_expression(expr)
    graph_entities["nodes"].extend(single_graph_entities["nodes"])
    graph_entities["relationships"].extend(single_graph_entities["relationships"])

# Later cells read both names. They previously held a second, identical copy of
# every list; aliasing keeps the same contents and order without doubling memory.
node_batches = graph_entities["nodes"]
relationship_batches = graph_entities["relationships"]

In [None]:
# --- Node features ---------------------------------------------------------
# Fit the node encoders on every generated node.
node_features_emb.fit(graph_entities["nodes"])
node_features_values.fit(graph_entities["nodes"])
# Divide by the standard deviation only (no centring); fitted on the value features below.
scaler = StandardScaler(with_mean=False)

id_index_mapping = {node.id: index for index, node in enumerate(node_batches)}

Xemb = node_features_emb.transform(node_batches)
Xval = scaler.fit_transform(node_features_values.transform(node_batches))
# Target: the scaled last value column, taken BEFORE Result nodes are masked.
# `.copy()` makes explicit that the masking below must not touch y.
y = Xval[:, [-1]].copy()
# Built from node_batches — the same list the Xval rows come from — so the mask
# lines up row-for-row with Xval.
is_result = np.array([node.__class__ == Result for node in node_batches])
# Zero every value feature of Result nodes so their values are not model inputs.
Xval[is_result, :] = 0

# --- Edge indices ----------------------------------------------------------
# (2, num_edges) in node-row indices; reshape(-1, 2) keeps the shape valid even
# when there are no relationships at all.
forward_edges = torch.tensor(
    [
        (id_index_mapping[rel.source.id], id_index_mapping[rel.target.id])
        for rel in relationship_batches
    ],
    dtype=torch.long,
).reshape(-1, 2).T
# Append every edge reversed so messages flow in both directions.
edge_index = torch.cat([forward_edges, forward_edges.flip(0)], dim=1)

# --- Edge features ---------------------------------------------------------
# Forward edges get [E, 0] and reversed edges [0, E], so direction is encoded.
E_forward = edge_features.fit_transform(relationship_batches)
E_zeros = np.zeros_like(E_forward)
E = torch.tensor(
    np.vstack((np.hstack((E_forward, E_zeros)), np.hstack((E_zeros, E_forward)))),
    dtype=torch.float,
)
assert edge_index.shape[1] == E.shape[0], "expected one feature row per directed edge"

Xemb = torch.tensor(Xemb, dtype=torch.int)
Xval = torch.tensor(Xval, dtype=torch.float)
y = torch.tensor(y, dtype=torch.float)
assert len(Xval) == len(y) == len(node_batches)

dataset = Data(x=[Xemb, Xval], edge_index=edge_index, edge_attr=E, y=y)

In [None]:
# Distribution of the (scaled) regression targets. Read from `dataset` rather
# than the global `y`, which the evaluation cell rebinds.
fig, ax = plt.subplots()
ax.hist(dataset.y.numpy().ravel(), bins=200, log=True)
ax.set(
    title="Distribution of scaled node targets",
    xlabel="scaled target value",
    ylabel="count (log scale)",
)
plt.show()

In [None]:
class NodeEmbedding(torch.nn.Module):
    """Encode nodes from a categorical id plus continuous value features.

    ``X[0]`` is looked up in a learned ``torch.nn.Embedding``. The result is
    concatenated with ``X[1]`` and passed through three linear layers, each
    followed by ReLU.

    Args:
        num_embeddings: size of the categorical vocabulary of ``X[0]``.
        embedding_channels: dimension of the learned category embedding.
        out_channels: width of the hidden layers and of the output
            (exposed as ``self.out_channels`` for downstream layers).
        value_channels: number of continuous features in ``X[1]``. Defaults
            to 2, the value that used to be hard-coded.
    """

    def __init__(
            self,
            num_embeddings: int,
            embedding_channels: int = 2,
            out_channels: int = 6,
            value_channels: int = 2,
            ):

        super().__init__()
        # Sub-modules are created in the same order as before, so parameter
        # order and the RNG draws used for initialisation are unchanged.
        self.embedding = torch.nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_channels
        )
        self.output = Linear(out_channels, out_channels)
        self.intermed1 = Linear(embedding_channels + value_channels, out_channels)
        self.intermed2 = Linear(out_channels, out_channels)
        self.out_channels = out_channels

    def forward(self, X: List[torch.Tensor]) -> torch.Tensor:
        """Return node encodings of width ``out_channels`` (non-negative, final ReLU).

        Args:
            X: ``[categorical_ids, values]``. ``categorical_ids`` is presumably
                an (N, 1) integer tensor — the ``squeeze(dim=1)`` drops that
                singleton axis (TODO confirm). ``values`` is (N, value_channels).
        """
        embedded = self.embedding(X[0]).squeeze(dim=1)
        hidden = torch.concat([embedded, X[1]], dim=-1)
        hidden = F.relu(self.intermed1(hidden))
        hidden = F.relu(self.intermed2(hidden))
        return F.relu(self.output(hidden))

    
class GAT(torch.nn.Module):
    """GATv2 stack over node embeddings that predicts a correction to a node value.

    The forward pass embeds the nodes and runs them through the attention
    layers, each followed by activation and dropout. A linear head then
    produces the output, and the last column of the value features
    (``x[-1][:, -1]``) is added to it as a residual.

    Args:
        hidden_channels: per-head width of every GATv2 layer.
        out_channels: width of the prediction.
        heads: attention heads per layer (outputs are concatenated).
        embedding: node encoder; must expose ``out_channels``.
        num_gat_layers: number of GATv2 layers (at least one is always built).
        gat_activation: activation applied after each GATv2 layer.
        dropout_inter_layer: dropout between GATv2 layers.
        dropout_gat: attention-coefficient dropout inside GATv2Conv.
        edge_dim: width of ``edge_attr``. When ``None``, falls back to the
            notebook-global ``E.shape[1]`` (the previous behaviour). Pass it
            explicitly to avoid depending on global state.
    """

    def __init__(
            self,
            hidden_channels: int,
            out_channels: int,
            heads: int,
            embedding: NodeEmbedding,
            num_gat_layers: int = 8,
            gat_activation = F.elu,
            dropout_inter_layer=0.1,
            dropout_gat=0.2,
            edge_dim: Optional[int] = None):

        super().__init__()

        if edge_dim is None:
            # Backward-compatible fallback: the class used to read this global.
            edge_dim = E.shape[1]

        self.embedding = embedding

        # The first layer consumes the embedding width; later layers consume the
        # concatenated multi-head output of the previous layer.
        self.gatconv = torch.nn.ModuleList()
        in_channels = embedding.out_channels
        for _ in range(max(num_gat_layers, 1)):
            self.gatconv.append(
                GATv2Conv(
                    in_channels,
                    hidden_channels,
                    heads,
                    edge_dim=edge_dim,
                    residual=True,
                    add_self_loops=True,
                    fill_value=0,
                    dropout=dropout_gat
                )
            )
            in_channels = hidden_channels * heads

        self.head = Linear(hidden_channels * heads, out_channels)

        self.dropout_inter_layer = dropout_inter_layer
        self.gat_activation = gat_activation

    def forward(self, x, edge_index, edge_attr):
        """Predict per-node outputs.

        Args:
            x: ``[categorical_ids, values]`` as accepted by ``self.embedding``.
                The last column of ``values`` is added to the head output.
            edge_index: (2, num_edges) long tensor.
            edge_attr: (num_edges, edge_dim) float tensor.
        """
        values = x[-1]

        hidden = self.embedding(x)
        for layer in self.gatconv:
            hidden = layer(hidden, edge_index, edge_attr)
            hidden = self.gat_activation(hidden)
            hidden = F.dropout(hidden, p=self.dropout_inter_layer, training=self.training)

        # Residual: the network learns a correction on top of the input value.
        return values[:, [-1]] + self.head(hidden)
        

In [None]:
model = GAT(
    hidden_channels=12,
    num_gat_layers=4,
    heads=1,
    # Take the target width from `dataset`, not the global `y`: the evaluation
    # cell rebinds `y` to a 1-D array, which made `y.shape[1]` raise on re-run.
    out_channels=dataset.y.shape[1],
    dropout_inter_layer=0.,
    dropout_gat=0.,
    gat_activation=F.elu,
    # NOTE(review): 6 is presumably the vocabulary size of node_features_emb —
    # confirm it covers every id in dataset.x[0].
    embedding=NodeEmbedding(6, 2, out_channels=12)
)

# Number of trainable parameters.
sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
N_EPOCHS = 1000
LOG_EVERY = 50  # print a loss line every LOG_EVERY epochs; the progress bar shows it live

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0)
model.train()
progress = tqdm(range(N_EPOCHS), total=N_EPOCHS)
for epoch in progress:
    optimizer.zero_grad()
    # Full-batch step over the whole graph.
    # NOTE(review): the L1 loss averages over ALL nodes, including non-Result
    # nodes whose input value already equals the target; consider masking to
    # Result nodes if that is the quantity of interest.
    out = model(x=dataset.x, edge_index=dataset.edge_index, edge_attr=dataset.edge_attr)
    loss = F.l1_loss(out, dataset.y)
    loss.backward()
    optimizer.step()

    # Read the scalar once. The old `loss += loss.item()` doubled the reported loss.
    loss_value = loss.item()
    progress.set_postfix(loss=f"{loss_value:.6f}")
    if epoch % LOG_EVERY == 0 or epoch == N_EPOCHS - 1:
        print("Epoch {:05d} | Loss {:.6f} |".format(epoch, loss_value))

In [None]:
expr = "(0.2 * 0.3)"
ent = graph_from_expression(expr)
model.eval()


def _unscale_last_column(values) -> np.ndarray:
    """Invert `scaler` for values that live in its last feature column."""
    padded = np.zeros((len(values), scaler.n_features_in_))
    padded[:, -1] = np.ravel(values)
    return scaler.inverse_transform(padded)[:, -1]


# Everything below uses eval_* names so this cell does not overwrite the
# training globals (Xemb, Xval, E, edge_index, y) that other cells read.
eval_nodes = ent["nodes"]
eval_relationships = ent["relationships"]
eval_index = {node.id: index for index, node in enumerate(eval_nodes)}

# Node features: preprocess exactly as in training — same fitted encoders, same
# fitted scaler, Result nodes masked. (Previously the values were not scaled and
# not masked.)
eval_Xemb = torch.tensor(node_features_emb.transform(eval_nodes), dtype=torch.int)
eval_Xval = scaler.transform(node_features_values.transform(eval_nodes))
eval_is_result = np.array([node.__class__ == Result for node in eval_nodes])
eval_Xval[eval_is_result, :] = 0
eval_Xval = torch.tensor(eval_Xval, dtype=torch.float)

# Edge indices: both directions, shape (2, num_edges).
eval_forward = torch.tensor(
    [
        (eval_index[rel.source.id], eval_index[rel.target.id])
        for rel in eval_relationships
    ],
    dtype=torch.long,
).reshape(-1, 2).T
eval_edge_index = torch.cat([eval_forward, eval_forward.flip(0)], dim=1)

# Edge features: `transform`, NOT `fit_transform`. Refitting on one expression
# can change the encoding the model was trained with.
eval_E_forward = edge_features.transform(eval_relationships)
eval_E_zeros = np.zeros_like(eval_E_forward)
eval_E = torch.tensor(
    np.vstack((
        np.hstack((eval_E_forward, eval_E_zeros)),
        np.hstack((eval_E_zeros, eval_E_forward)),
    )),
    dtype=torch.float,
)

with torch.no_grad():
    pred_scaled = model(x=[eval_Xemb, eval_Xval], edge_index=eval_edge_index, edge_attr=eval_E)

# Predictions back in original units.
pred = _unscale_last_column(pred_scaled.numpy())
# Ground truth is already in original units (None -> 0), so it must NOT be
# inverse-transformed; the old code scaled it a second time.
y_true = np.array(
    [node.value if node.value is not None else 0 for node in eval_nodes], dtype=float
)
# MAE in original units. The old loss compared scaled predictions with unscaled targets.
eval_mae = float(np.mean(np.abs(pred - y_true)))

for node, p, t in zip(eval_nodes, pred, y_true):
    print(node, "| pred:", round(float(p), 3), "| true:", round(float(t), 3))

print(np.round(pred, 3))
print(np.round(y_true, 3))
print(eval_mae)