In [None]:
import torch
import numpy as np
import tqdm
import pandas as pd
from tqdm.notebook import tqdm
from typing import List
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
from torch_geometric.data import Data
from torch_geometric.nn import Linear, GATv2Conv, GatedGraphConv, GCNConv
import torch.nn.functional as F

from arigin.expressions import generate
from arigin.graph.elements import Result
from arigin.graph.generation import graph_from_expression
from arigin.graph.models import GCN
from arigin.features import node_features, node_features, edge_features

In [None]:
import random

# Seed every RNG the notebook may touch so Restart & Run All is reproducible
# (expression sampling, weight initialisation, any NumPy randomness).
# NOTE(review): assumes `arigin.expressions.generate` draws from `random`
# and/or NumPy — confirm in arigin.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Bounds on the number of operands passed to `generate` for each sampled
# expression (inclusivity is defined by `arigin.expressions.generate`).
min_numbers = 2
max_numbers = 4

In [None]:
# Sample 5000 random expressions, convert each to a graph, and pool the
# resulting nodes and relationships. `node_batches` / `relationship_batches`
# receive exactly the same items, in the same order, as `graph_entities`.
graph_entities = {"nodes": [], "relationships": []}
node_batches = []
relationship_batches = []
for _ in tqdm(range(5000)):
    expr = generate(min_numbers, max_numbers)
    sample_entities = graph_from_expression(expr)
    sample_nodes = sample_entities["nodes"]
    sample_relationships = sample_entities["relationships"]

    graph_entities["nodes"].extend(sample_nodes)
    graph_entities["relationships"].extend(sample_relationships)
    node_batches.extend(sample_nodes)
    relationship_batches.extend(sample_relationships)

In [None]:
# Build the training graph from every pooled node/relationship: node feature
# matrix X, target y, bidirectional edge_index with signed weights E, all
# bundled into a torch_geometric `Data` object named `dataset`.
nodes = graph_entities["nodes"]
relationships = graph_entities["relationships"]

# Fit the node featuriser on all training nodes.
node_features.fit(nodes)
# NOTE(review): `scaler` is not used anywhere in this notebook; kept only so
# any external reference to the name keeps working.
scaler = StandardScaler(with_mean=False)

# Map each node id to its row position in X.
id_index_mapping = {node.id: index for index, node in enumerate(nodes)}

X = node_features.transform(nodes)

# The target is the last feature column. Fancy indexing (`[[-1]]`) returns a
# copy, so zeroing that column for Result nodes below leaves y intact.
y = X[:, [-1]]
# Zero the last column on Result nodes (presumably hiding the value the model
# has to predict). Explicit bool dtype keeps this a mask even if `nodes` is empty.
is_result = np.array([node.__class__ == Result for node in nodes], dtype=bool)
X[is_result, -1] = 0

# Add pairwise interaction terms (plus bias), then keep only columns with at
# least two distinct values; constant columns carry no information. `unique`
# holds the kept column positions and is reused when featurising new graphs.
X = PolynomialFeatures(2, interaction_only=True).fit_transform(X)
unique = pd.DataFrame(X).nunique() >= 2
unique = unique[unique].index
X = X[:, unique]

# Directed (source_row, target_row) pairs as a (2, num_edges) int64 array.
# reshape(-1, 2) keeps the shape (2, 0) when there are no relationships,
# instead of failing on the reverse-edge indexing below.
edge_index = np.asarray(
    [
        (id_index_mapping[rel.source.id], id_index_mapping[rel.target.id])
        for rel in relationships
    ],
    dtype=np.int64,
).reshape(-1, 2).T

# Edge weights: +1 for each original edge, -1 for its added reverse, so the
# two directions remain distinguishable after symmetrising.
E = np.ones(edge_index.shape[1])
edge_index = np.hstack((edge_index, edge_index[::-1]))
E = np.hstack((E, -E))

# Convert to tensors for torch_geometric.
edge_index = torch.tensor(edge_index, dtype=torch.long)
X = torch.tensor(X, dtype=torch.float)
E = torch.tensor(E, dtype=torch.float)
y = torch.tensor(y, dtype=torch.float)

dataset = Data(x=X, edge_index=edge_index, edge_attr=E, y=y)

In [None]:
# Graph model sized to the featurised input width and the target width
# (y is a single column, taken as X[:, [-1]] in the featurisation cell).
model = GCN(
    in_channels=X.shape[1],
    out_channels=y.shape[1],
    hidden_channels=128,
    emb_channels=128,
    dropout_inter_layer=0.0,
    gat_activation=F.relu,
)

# Display the number of trainable parameters.
sum(
    parameter.numel()
    for parameter in model.parameters()
    if parameter.requires_grad
)

In [None]:
# Full-batch training on the single pooled training graph with an L1 (MAE)
# objective. `loss_history` keeps every epoch's loss for later inspection.
optimizer = torch.optim.Adam(model.parameters(), lr=0.00002, weight_decay=1e-3)
loss_history = []
model.train()
for epoch in tqdm(range(1000)):
    optimizer.zero_grad()
    # Use the edge weights stored in `dataset`, not the global `E`: a later
    # cell rebinds `E` to the edges of a single test expression.
    out = model(
        x=dataset.x,
        edge_index=dataset.edge_index,
        edge_weight=dataset.edge_attr,
    )
    loss = F.l1_loss(out, dataset.y)
    loss.backward()
    optimizer.step()

    # Log the plain float value; never add it back into the loss tensor.
    epoch_loss = loss.item()
    loss_history.append(epoch_loss)
    if epoch % 50 == 0:
        print("Epoch {:05d} | Loss {:.6f} |".format(epoch, epoch_loss))

In [None]:
import matplotlib.pyplot as plt

# Per-node absolute error of the last training forward pass (`out`).
# Compare against `dataset.y`, not the global `y`: the next cell rebinds `y`
# to the targets of a single test expression.
# NOTE(review): this covers every node, not only Result nodes (whose value
# was zeroed in X); consider masking with the training `is_result`.
l1_error = np.abs((out - dataset.y).detach().numpy())

fig, ax = plt.subplots()
ax.hist(l1_error, bins=np.linspace(0, 2, 100), log=True)
ax.set(
    title="Training L1 error per node (errors above 2 fall outside the bins)",
    xlabel="|prediction - target|",
    ylabel="count (log scale)",
)
ax.grid()

In [None]:
# Score the trained model on one hand-written expression, featurised exactly
# like the training graph (same fitted `node_features`, same kept columns).
# Rebinds X, y, E and edge_index, which the following cell inspects.
expr = "0.95 * 0.3"
ent = graph_from_expression(expr)
model.eval()

nodes = ent["nodes"]
relationships = ent["relationships"]

id_index_mapping = {node.id: index for index, node in enumerate(nodes)}

X = node_features.transform(nodes)
# Fancy indexing copies, so zeroing the last column below leaves y intact.
y = X[:, [-1]]
is_result = np.array([node.__class__ == Result for node in nodes], dtype=bool)
X[is_result, -1] = 0
# With interaction_only=True and degree 2, the output columns depend only on
# the input width, so refitting reproduces the training layout before `unique`
# selects the same column subset.
X = PolynomialFeatures(2, interaction_only=True).fit_transform(X)
X = X[:, unique]

# (2, num_edges) directed pairs; reshape keeps shape (2, 0) for an expression
# with no relationships instead of raising on the reverse-edge indexing.
edge_index = np.asarray(
    [
        (id_index_mapping[rel.source.id], id_index_mapping[rel.target.id])
        for rel in relationships
    ],
    dtype=np.int64,
).reshape(-1, 2).T
# +1 weight on original edges, -1 on the added reverse edges.
E = np.ones(edge_index.shape[1])
edge_index = np.hstack((edge_index, edge_index[::-1]))
E = np.hstack((E, -E))

edge_index = torch.tensor(edge_index, dtype=torch.long)
X = torch.tensor(X, dtype=torch.float)
E = torch.tensor(E, dtype=torch.float)
y = torch.tensor(y, dtype=torch.float)

# Inference only: do not build an autograd graph.
with torch.no_grad():
    pred = model(x=X, edge_index=edge_index, edge_weight=E)
    loss = F.l1_loss(pred, y)

pred = pred.numpy()

print(np.round(pred, 4))
print(np.round(y.numpy(), 4))
print(loss.item())

In [None]:
from sklearn.metrics import pairwise_distances

# Inspect learned node embeddings for the graph currently bound to X /
# edge_index (left by the preceding cell). Embed once, also run
# `model.gatconv_1` on the embeddings (kept as `Xa`), then display pairwise
# Euclidean distances between node embeddings.
with torch.no_grad():
    X_ = model.embedding(X)
    Xa = model.gatconv_1(X_, edge_index)
V = X_.numpy()
pairwise_distances(V, V)