In [None]:
import random

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from torch_geometric.data import Data
from torch_geometric.loader import DataListLoader
from torch_geometric.nn import Linear, GATv2Conv
from tqdm.notebook import tqdm

from arigin.expressions import generate
from arigin.features import node_features, edge_features
from arigin.graph.elements import Result
from arigin.graph.generation import graph_from_expression

In [None]:
# Expression-generation settings, passed to `generate(min_numbers, max_numbers)`.
# NOTE(review): presumably bounds on how many numbers each expression contains
# (inferred from the names) — see arigin.expressions.generate.
max_numbers = 2
min_numbers = 2

# Seed the RNGs used for expression generation, weight initialisation and
# loader shuffling so "Restart & Run All" reproduces the same results.
# NOTE(review): assumes `generate` draws from `random` or NumPy's global RNG — confirm.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Generate random expressions, convert each to a graph, and group the graphs
# into batches of BATCH_SIZE expressions.
#   graph_entities        : every node / relationship (used to fit transforms)
#   node_batches          : list of per-batch node lists
#   relationship_batches  : list of per-batch relationship lists
N_EXPRESSIONS = 2000
BATCH_SIZE = 64

graph_entities = {"nodes": [], "relationships": []}
node_batches = []
relationship_batches = []

node_batch, relationship_batch = [], []
n_in_batch = 0
for _ in tqdm(range(N_EXPRESSIONS), total=N_EXPRESSIONS):
    expr = generate(min_numbers, max_numbers)
    single_graph_entities = graph_from_expression(expr)
    graph_entities["nodes"] += single_graph_entities["nodes"]
    graph_entities["relationships"] += single_graph_entities["relationships"]

    node_batch += single_graph_entities["nodes"]
    relationship_batch += single_graph_entities["relationships"]
    n_in_batch += 1
    if n_in_batch == BATCH_SIZE:
        node_batches.append(node_batch)
        relationship_batches.append(relationship_batch)
        node_batch, relationship_batch = [], []
        n_in_batch = 0

# Keep the trailing partial batch: N_EXPRESSIONS need not be a multiple of
# BATCH_SIZE, and dropping it would train on fewer graphs than the transforms
# below are fitted on.
if n_in_batch:
    node_batches.append(node_batch)
    relationship_batches.append(relationship_batch)

In [None]:
# Build one PyG `Data` object per batch of expression graphs.
#
# Every fitted transform (node features, interaction expansion + constant-column
# mask, edge features, target scaler) is fitted ONCE on all training graphs.
# Fitting per batch could give different batches different feature widths,
# which would no longer match the model's `in_channels` / `edge_dim`.
node_features.fit(graph_entities["nodes"])
y = np.array([node.value if node.value is not None else 0 for node in graph_entities["nodes"]], dtype=float)
scaler = StandardScaler(with_mean=False).fit(y[:, np.newaxis])
# Fit the edge encoder on every relationship; the output is not needed here.
# NOTE(review): assumes edge_features also exposes a sklearn-style `transform` — confirm.
edge_features.fit_transform(graph_entities["relationships"])


def node_matrix(nodes):
    """Return node features for `nodes` with the last column zeroed on `Result` nodes.

    NOTE(review): the last column is presumably the node value, i.e. the
    prediction target for Result nodes — confirm against `node_features`.
    """
    mat = node_features.transform(nodes)
    is_result = np.array([node.__class__ == Result for node in nodes])
    mat[is_result, -1] = 0
    return mat


# Degree-2 interaction expansion, then drop columns that are constant over the
# whole training set. The same `poly` + `keep_columns` are applied to every batch.
X_all = node_matrix(graph_entities["nodes"])
poly = PolynomialFeatures(2, interaction_only=True).fit(X_all)
keep_columns = (pd.DataFrame(poly.transform(X_all)).nunique() >= 2).to_numpy()


def batch_to_data(nodes, relationships):
    """Convert one batch of nodes/relationships into a PyG `Data` object.

    Edges are mirrored (each edge is also added reversed). Forward copies get
    edge features ``[E, 0]`` and reversed copies ``[0, E]``, presumably so the
    model can tell direction apart. Targets are node values (``None`` -> 0)
    scaled with the training `scaler`.
    """
    id_index_mapping = {node.id: index for index, node in enumerate(nodes)}

    X = poly.transform(node_matrix(nodes))[:, keep_columns]

    pairs = [
        (id_index_mapping[rel.source.id], id_index_mapping[rel.target.id])
        for rel in relationships
    ]
    edge_index = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).T
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

    E = edge_features.transform(relationships)
    Ez = np.zeros_like(E)
    E = np.vstack((np.hstack((E, Ez)), np.hstack((Ez, E))))

    y = np.array([node.value if node.value is not None else 0 for node in nodes], dtype=float)
    y = scaler.transform(y[:, np.newaxis])

    return Data(
        x=torch.tensor(X, dtype=torch.float),
        edge_index=edge_index,
        edge_attr=torch.tensor(E, dtype=torch.float),
        y=torch.tensor(y, dtype=torch.float),
    )


dataset = [
    batch_to_data(nodes, relationships)
    for nodes, relationships in zip(node_batches, relationship_batches)
]

# Later cells read `X.shape`, `E.shape` and `y.shape` (model construction and the
# GAT edge_dim). Keep exposing the last batch's tensors under those names.
X, E, y = dataset[-1].x, dataset[-1].edge_attr, dataset[-1].y

In [None]:
class GAT(torch.nn.Module):
    """Embedding MLP -> stack of GATv2 layers -> linear head, with an input skip.

    The head output is added to the last input feature column (``x[:, [-1]]``),
    so the network predicts a correction to that column.
    NOTE(review): when inputs come from a PolynomialFeatures(interaction_only)
    expansion, the last column is an interaction term rather than a raw feature.
    Confirm that this is the intended skip column.

    Args:
        in_channels: number of input node features.
        emb_channels: width of the embedding layers.
        hidden_channels: per-head output width of each GATv2 layer.
        out_channels: width of the prediction head.
        heads: attention heads per GATv2 layer. The next layer's input width is
            ``hidden_channels * heads``.
        num_emb_layers: number of embedding layers (clamped to at least 1).
        num_gat_layers: number of GATv2 layers (clamped to at least 1).
        emb_activation: activation applied after each embedding layer.
        gat_activation: activation applied after each GATv2 layer.
        dropout_inter_layer: dropout applied after every embedding / GATv2 layer.
        dropout_gat: passed as ``dropout`` to every GATv2Conv.
        edge_dim: edge-feature width. If None, it falls back to ``E.shape[1]`` of
            the notebook-global ``E`` (legacy behaviour); prefer passing it explicitly.
    """

    def __init__(
            self,
            in_channels: int,
            emb_channels: int,
            hidden_channels: int,
            out_channels: int,
            heads: int,
            num_emb_layers: int = 3,
            num_gat_layers: int = 8,
            emb_activation = F.relu,
            gat_activation = F.elu,
            dropout_inter_layer=0.1,
            dropout_gat=0.2,
            edge_dim=None):

        super().__init__()

        num_emb_layers = max(1, num_emb_layers)
        num_gat_layers = max(1, num_gat_layers)
        if edge_dim is None:
            # Legacy fallback kept for existing callers: read the global `E`.
            edge_dim = E.shape[1]

        # torch.nn.ModuleList rather than a plain list: this registers the
        # sub-layers, so their parameters reach model.parameters() and the
        # optimizer, state_dict and .to(device), and train()/eval() propagate.
        self.emb = torch.nn.ModuleList(
            [Linear(in_channels, emb_channels)]
            + [Linear(emb_channels, emb_channels) for _ in range(num_emb_layers - 1)]
        )

        self.gatconv = torch.nn.ModuleList([
            GATv2Conv(
                emb_channels if i == 0 else hidden_channels * heads,
                hidden_channels,
                heads,
                edge_dim=edge_dim,
                residual=True,
                add_self_loops=True,
                fill_value=0,
                dropout=dropout_gat
            )
            for i in range(num_gat_layers)
        ])

        self.head = Linear(hidden_channels * heads, out_channels)

        self.dropout_inter_layer = dropout_inter_layer
        self.emb_activation = emb_activation
        self.gat_activation = gat_activation

    def embedding(self, x):
        """Apply each embedding layer, followed by activation and dropout."""
        for layer in self.emb:
            x = layer(x)
            x = self.emb_activation(x)
            x = F.dropout(x, p=self.dropout_inter_layer, training=self.training)

        return x

    def forward(self, x, edge_index, edge_attr):
        """Predict per-node outputs.

        Args:
            x: node feature matrix; its last column is added to the head output.
            edge_index: connectivity passed to every GATv2 layer.
            edge_attr: edge features passed to every GATv2 layer.

        Returns:
            ``x[:, [-1]] + head(gat_layers(embedding(x)))``.
        """
        x_ = self.embedding(x)
        for layer in self.gatconv:
            x_ = layer(x_, edge_index, edge_attr)
            x_ = self.gat_activation(x_)
            x_ = F.dropout(x_, p=self.dropout_inter_layer, training=self.training)

        x_ = self.head(x_)
        # Skip connection from the raw input's last column.
        x_ = x[:, [-1]] + x_

        return x_

In [None]:
# Small configuration: 2 embedding layers, 3 single-head GATv2 layers, no dropout.
# Input and output widths come from the `X` / `y` tensors left by the dataset cell.
model = GAT(
    in_channels=X.shape[1],
    emb_channels=32,
    hidden_channels=32,
    out_channels=y.shape[1],
    heads=1,
    num_emb_layers=2,
    num_gat_layers=3,
    dropout_inter_layer=0.0,
    dropout_gat=0.0,
)

# Trainable parameter count (displayed as the cell output).
sum(param.numel() for param in model.parameters() if param.requires_grad)

In [None]:
# Train with Adam on an L1 loss. Each `Data` in `dataset` already holds a whole
# batch of expression graphs. DataListLoader (default batch_size=1) therefore
# yields a one-element list per step.
N_EPOCHS = 1000
LOG_EVERY = 50  # print the mean epoch loss every LOG_EVERY epochs (and the last one)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=5e-4)
loader = DataListLoader(dataset=dataset, shuffle=True)
assert len(loader) > 0, "dataset is empty; nothing to train on"

model.train()
for epoch in tqdm(range(N_EPOCHS), total=N_EPOCHS):
    total_loss = 0.0
    for batch in loader:
        data = batch[0]
        optimizer.zero_grad()
        out = model(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr)
        loss = F.l1_loss(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    if epoch % LOG_EVERY == 0 or epoch == N_EPOCHS - 1:
        print(
            "Epoch {:05d} | Loss {:.6f} |".format(
                epoch, total_loss / len(loader)
            )
        )

In [None]:
# Predict the node values of one hand-written expression.
#
# The graph goes through the same pipeline as the training batches:
#   - the last node-feature column is zeroed on Result nodes;
#   - degree-2 interaction features are restricted to the training columns;
#   - edges are mirrored;
#   - edge features use the already-fitted encoder;
#   - the loss is computed in the `scaler`-scaled target space used in training.
expr = "0.2 * 0.95"
ent = graph_from_expression(expr)
model.eval()

id_index_mapping = {node.id: index for index, node in enumerate(ent["nodes"])}

# Re-derive the interaction expansion and the non-constant-column mask from the
# training nodes, so the inference columns line up with the model's inputs.
X_train = node_features.transform(graph_entities["nodes"])
is_result_train = np.array([node.__class__ == Result for node in graph_entities["nodes"]])
X_train[is_result_train, -1] = 0
poly = PolynomialFeatures(2, interaction_only=True).fit(X_train)
keep_columns = (pd.DataFrame(poly.transform(X_train)).nunique() >= 2).to_numpy()

# Node features (Result node's last column zeroed, as during training).
X = node_features.transform(ent["nodes"])
is_result = np.array([node.__class__ == Result for node in ent["nodes"]])
X[is_result, -1] = 0
X = torch.tensor(poly.transform(X)[:, keep_columns], dtype=torch.float)
assert X.shape[1] == model.emb[0].in_channels, (
    f"inference features have {X.shape[1]} columns, "
    f"model expects {model.emb[0].in_channels}"
)

# Targets: raw values for display, scaled values for a loss comparable to training.
y = np.array([node.value if node.value is not None else 0 for node in ent["nodes"]], dtype=float)
y = torch.tensor(y[:, np.newaxis], dtype=torch.float)
y_scaled = torch.tensor(scaler.transform(y.numpy()), dtype=torch.float)

# Edge indices, mirrored so every edge also appears reversed.
edge_pairs = [
    (id_index_mapping[rel.source.id], id_index_mapping[rel.target.id])
    for rel in ent["relationships"]
]
edge_index = torch.tensor(edge_pairs, dtype=torch.long).reshape(-1, 2).T
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

# Edge features from the encoder fitted on training data (do NOT refit on one graph).
# NOTE(review): assumes edge_features exposes a sklearn-style `transform` — confirm.
E = edge_features.transform(ent["relationships"])
Ez = np.zeros_like(E)
E = np.vstack((np.hstack((E, Ez)), np.hstack((Ez, E))))
E = torch.tensor(E, dtype=torch.float)

with torch.no_grad():
    pred_scaled = model(x=X, edge_index=edge_index, edge_attr=E)
loss = F.l1_loss(pred_scaled, y_scaled)

pred = scaler.inverse_transform(pred_scaled.numpy())

for node in ent["nodes"]:
    print(node)

print(np.round(pred.flatten(), 3))
print(np.round(y.numpy().flatten(), 3))
print(loss.item())


In [None]:
# Attention coefficients of the first GATv2 layer for the expression above.
# The model has no `conv1` attribute; its GATv2 layers are stored in `model.gatconv`.
with torch.no_grad():
    X1 = model.embedding(X)
    X1, A = model.gatconv[0](X1, edge_index, E, return_attention_weights=True)

In [None]:
# Display the second value returned by the first GAT layer when called with
# return_attention_weights=True in the previous cell.
# NOTE(review): per PyG docs this is presumably (edge_index incl. added
# self-loops, attention coefficients) — confirm for the installed version.
A