In [None]:
import torch
import numpy as np
import tqdm
import pandas as pd
from tqdm.notebook import tqdm
from typing import List
from sklearn.preprocessing import StandardScaler
from torch_geometric.data import Data
from torch_geometric.nn import Linear, GATv2Conv, GatedGraphConv
import torch.nn.functional as F

from arigin.expressions import generate
from arigin.graph.elements import Result
from arigin.graph.generation import graph_from_expression
from arigin.features import node_features, node_features, edge_features

In [None]:
# Bounds handed to `generate(min_numbers, max_numbers)` in the data-generation
# cell — presumably the min/max count of numbers per generated expression.
min_numbers, max_numbers = 2, 2

In [None]:
# Generate 50,000 random expressions, turn each into a graph, and pool the
# entities of all graphs.  `graph_entities` holds the pooled nodes and
# relationships used by the feature-extraction cell; `node_batches` and
# `relationship_batches` are flat lists with the same entities (not read by
# any visible cell).
graph_entities = {"nodes": [], "relationships": []}
node_batches, relationship_batches = [], []
for _ in tqdm(range(50000), total=50000):
    expr = generate(min_numbers, max_numbers)
    single_graph_entities = graph_from_expression(expr)
    for entity_key, flat_list in (
        ("nodes", node_batches),
        ("relationships", relationship_batches),
    ):
        graph_entities[entity_key].extend(single_graph_entities[entity_key])
        flat_list.extend(single_graph_entities[entity_key])

In [None]:
# Turn the pooled graph entities into one PyG `Data` object.
# The last node-feature column is scaled, copied out as the regression target
# `y`, and then zeroed on Result nodes so the model must infer those values.
node_features.fit(graph_entities["nodes"])
scaler = StandardScaler(with_mean=False)  # no centring, so zeros stay zero after scaling

nodes = graph_entities["nodes"]
relationships = graph_entities["relationships"]

# Node id -> row index in X (and in edge_index).
id_index_mapping = {node.id: index for index, node in enumerate(nodes)}

# NOTE(review): assumes `transform` returns a float ndarray; an integer array
# would silently truncate the scaled values written back below — confirm.
X = node_features.transform(nodes)
X[:, -1:] = scaler.fit_transform(X[:, -1:])
is_result = np.array([node.__class__ == Result for node in nodes])
y = X[:, [-1]]          # advanced indexing -> copy, so the zeroing below leaves y intact
X[is_result, -1:] = 0   # hide the target value of Result nodes from the input

# Edge indices: one (source, target) column per relationship, followed by the
# reversed copy of every edge so messages flow in both directions.
# reshape(-1, 2) keeps the shape (2, 0) even when there are no relationships.
forward_edges = torch.tensor(
    [
        (id_index_mapping[rel.source.id], id_index_mapping[rel.target.id])
        for rel in relationships
    ],
    dtype=torch.long,
).reshape(-1, 2).T
edge_index = torch.cat([forward_edges, forward_edges.flip(0)], dim=1)

# Edge features; each reversed edge carries the negated feature vector.
E = edge_features.fit_transform(relationships)
E = np.vstack((E, -E))
assert E.shape[0] == edge_index.shape[1], "expected one edge-feature row per directed edge"

X = torch.tensor(X, dtype=torch.float)
E = torch.tensor(E, dtype=torch.float)
y = torch.tensor(y, dtype=torch.float)

dataset = Data(x=X, edge_index=edge_index, edge_attr=E, y=y)

In [None]:
class GAT(torch.nn.Module):
    """Two-layer GATv2 node regressor with an MLP node embedding.

    Pipeline: 3-layer ELU MLP embedding -> GATv2Conv (heads concatenated) ->
    GATv2Conv (heads averaged, ``concat=False``) -> linear head.  The head's
    output is added to the last input feature column, so the network learns a
    correction to that column rather than the raw value.

    Args:
        in_channels: Number of input node features.
        hidden_channels: Output width of the embedding and per-head width of
            both GAT layers.
        emb_channels: Width of the first two embedding layers.
        heads: Number of attention heads in each GAT layer.
        out_channels: Output width of the linear head.
        gat_activation: Activation applied after each GAT layer.
        dropout_inter_layer: Dropout probability applied after each GAT layer
            (only while ``self.training``).
        dropout_gat: ``dropout`` argument of both GATv2Conv layers (dropout on
            the normalised attention coefficients).
        edge_dim: Number of edge features.  If omitted, falls back to
            ``E.shape[1]`` from the notebook-global edge-feature matrix ``E``
            (the original behaviour).  Pass it explicitly to avoid depending on
            whatever ``E`` currently holds.
    """

    def __init__(
            self,
            in_channels: int,
            hidden_channels: int,
            emb_channels: int,
            heads: int,
            out_channels: int,
            gat_activation = F.elu,
            dropout_inter_layer=0.1,
            dropout_gat=0.2,
            edge_dim=None,
            ):

        super().__init__()

        if edge_dim is None:
            # Backward-compatible fallback on the notebook-global `E`; later
            # cells rebind `E`, so prefer passing `edge_dim` explicitly.
            edge_dim = E.shape[1]

        # Node-wise MLP: in_channels -> emb_channels -> emb_channels -> hidden_channels.
        self.embedding_1 = Linear(in_channels, emb_channels)
        self.embedding_2 = Linear(emb_channels, emb_channels)
        self.embedding_3 = Linear(emb_channels, hidden_channels)

        # First attention layer; heads are concatenated, giving
        # hidden_channels * heads output features.  Self-loops get zero edge features.
        self.gatconv_1 = GATv2Conv(
                hidden_channels,
                hidden_channels,
                heads,
                edge_dim=edge_dim,
                residual=True,
                add_self_loops=True,
                fill_value=0,
                dropout=dropout_gat
            )

        # Second attention layer; concat=False averages the heads back down to
        # hidden_channels output features.
        self.gatconv_2 = GATv2Conv(
                hidden_channels * heads,
                hidden_channels,
                heads,
                edge_dim=edge_dim,
                residual=True,
                add_self_loops=True,
                fill_value=0,
                dropout=dropout_gat,
                concat=False
            )

        self.head = Linear(hidden_channels, out_channels)

        self.dropout_inter_layer = dropout_inter_layer
        self.gat_activation = gat_activation

    def embedding(self, x):
        """Apply the three-layer ELU MLP to every node's feature vector.

        Returns a tensor with ``hidden_channels`` features per node.
        """
        x = self.embedding_1(x)
        x = F.elu(x)
        x = self.embedding_2(x)
        x = F.elu(x)
        x = self.embedding_3(x)
        x = F.elu(x)

        return x

    def forward(self, x, edge_index, edge_attr):
        """Predict per-node outputs.

        Args:
            x: Node features; the last column is added back to the head output.
            edge_index: Graph connectivity passed to both GAT layers.
            edge_attr: Edge features passed to both GAT layers.

        Returns:
            ``x[:, [-1]] + head(...)``, i.e. shape (num_nodes, out_channels).
        """
        x_ = self.embedding(x)

        x_ = self.gatconv_1(x_, edge_index, edge_attr)
        x_ = self.gat_activation(x_)
        x_ = F.dropout(x_, p=self.dropout_inter_layer, training=self.training)
        x_ = self.gatconv_2(x_, edge_index, edge_attr)
        x_ = self.gat_activation(x_)
        x_ = F.dropout(x_, p=self.dropout_inter_layer, training=self.training)

        x_ = self.head(x_)
        # Residual on the value column: the head predicts a correction to x[:, -1].
        x_ = x[:, [-1]] + x_

        return x_

In [None]:
# Seed torch so the weight initialisation below is reproducible across runs.
torch.manual_seed(0)

model = GAT(
    in_channels=X.shape[1],
    hidden_channels=8,
    emb_channels=8,
    heads=1,
    out_channels=y.shape[1],
    dropout_inter_layer=0.,
    dropout_gat=0.,
    gat_activation=F.elu
)

# Number of trainable parameters.
sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
# Full-batch training: the pooled graph is a single `Data` object, so each
# epoch is one forward/backward pass over every node with an L1 loss.
# NOTE(review): the loss averages over all nodes, including non-Result nodes
# whose target equals their own unmasked input value (which the residual
# `x[:, [-1]] + ...` reproduces trivially); consider masking to Result nodes — confirm intent.
n_epochs = 1000
log_every = 50  # print the loss every `log_every` epochs and on the final epoch

optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=5e-3)
model.train()
for epoch in tqdm(range(n_epochs), total=n_epochs):
    optimizer.zero_grad()
    out = model(x=dataset.x, edge_index=dataset.edge_index, edge_attr=dataset.edge_attr)
    loss = F.l1_loss(out, dataset.y)
    loss.backward()
    optimizer.step()
    # Report the plain scalar; adding `loss.item()` onto `loss` would double it.
    if epoch % log_every == 0 or epoch == n_epochs - 1:
        print("Epoch {:05d} | Loss {:.6f} |".format(epoch, loss.item()))

In [None]:
# Run the trained model on one hand-written expression and compare its
# predictions with the (scaled) ground-truth values.
# NOTE: this cell rebinds X, E, y, edge_index, nodes, ... — the inspection
# cells below read these names and therefore operate on this test graph.
expr = "0.2 + 0.95"
ent = graph_from_expression(expr)
model.eval()

nodes = ent["nodes"]
relationships = ent["relationships"]

# Node id -> row index in X (and in edge_index).
id_index_mapping = {node.id: index for index, node in enumerate(nodes)}

# Apply the encoders/scaler fitted on the training graphs (no refitting).
X = node_features.transform(nodes)
X[:, -1:] = scaler.transform(X[:, -1:])
is_result = np.array([node.__class__ == Result for node in nodes])
y = X[:, [-1]]          # advanced indexing -> copy, so the zeroing below leaves y intact
X[is_result, -1:] = 0   # hide the target value of Result nodes from the input

# Forward edges followed by their reversed copies, as in the training cell.
forward_edges = torch.tensor(
    [
        (id_index_mapping[rel.source.id], id_index_mapping[rel.target.id])
        for rel in relationships
    ],
    dtype=torch.long,
).reshape(-1, 2).T
edge_index = torch.cat([forward_edges, forward_edges.flip(0)], dim=1)

# `transform`, not `fit_transform`: refitting on this single graph would
# replace the edge-encoder state fitted on the training graphs.
E = edge_features.transform(relationships)
E = np.vstack((E, -E))
X = torch.tensor(X, dtype=torch.float)
E = torch.tensor(E, dtype=torch.float)
y = torch.tensor(y, dtype=torch.float)

# Pure inference: no autograd graph needed.
with torch.no_grad():
    pred = model(x=X, edge_index=edge_index, edge_attr=E)
    loss = F.l1_loss(pred, y)

pred = pred.numpy()

print(np.round(pred, 2))
print(np.round(y.numpy(), 2))
print(loss.item())

In [None]:
# Node embeddings of the current `X` (the test graph, if the inference cell
# ran last); inspection only, so no gradients are tracked.
with torch.no_grad():
    X_ = model.embedding(X)

In [None]:
# First-layer attention weights on the current graph, one row per directed
# edge (including the self-loops GATv2Conv adds), sorted by target node.
with torch.no_grad():
    Xa, A = model.gatconv_1(X_, edge_index, E, return_attention_weights=True)
index = A[0].numpy()                            # row 0: source ("from"), row 1: target ("to")
att = A[1].numpy().reshape(index.shape[1], -1)  # (num_edges, heads)
df = pd.DataFrame(
    index.T, columns=["from", "to"]
)
if att.shape[1] == 1:
    df["attention"] = att[:, 0]
else:
    # Multi-head layers: one column per head instead of a 2-D assignment.
    for head_idx in range(att.shape[1]):
        df[f"attention_{head_idx}"] = att[:, head_idx]
df.sort_values(["to", "from"]).set_index(["to", "from"])

In [None]:
# Display the first GAT layer's output node features (`Xa`, returned by the
# attention-inspection cell above alongside the attention weights).
Xa