In [1]:
import torch
import tqdm
import numpy as np
from tqdm.notebook import tqdm
import torch.nn.functional as F
from sklearn.preprocessing import FunctionTransformer
from arigin.expressions import generate
from arigin.graph.generation import (
    graph_from_expression, 
    generate_multiple_graphs
)
from arigin.graph.models import MathModel
from arigin.preprocessing import GraphEntityToDataSet

In [2]:
# Bounds forwarded to generate_multiple_graphs as min_numbers / max_numbers.
# NOTE(review): presumably the number of operands per generated expression — confirm.
min_numbers, max_numbers = 2, 4

In [None]:
# Generate the training expressions (as graph entities) together with their targets.
graph_entities, y = generate_multiple_graphs(
    n_graphs=5000,
    min_numbers=min_numbers,
    max_numbers=max_numbers,
)

# Squash targets with tanh; arctanh maps values back to the original scale.
# The NumPy ufuncs are passed directly: wrapping them in lambdas adds nothing
# and makes the fitted transformer impossible to pickle.
# NOTE(review): tanh saturates for large |y|, in which case arctanh returns
# +/-inf on the way back — confirm the target range keeps this invertible.
target_transformer = FunctionTransformer(
    func=np.tanh,
    inverse_func=np.arctanh,
    validate=False,
)

# Fit the converter on the generated graphs, then build the training data object.
dataset_create = GraphEntityToDataSet(target_transformer=target_transformer).fit(graph_entities, y)
data = dataset_create.transform(graph_entities, y)

In [None]:
# Input, edge and output widths are read from the second dimension of the
# corresponding dataset tensors; the remaining values are fixed hyperparameters.
model_config = dict(
    in_channels=data.x.shape[1],
    emb_channels=24,
    hidden_channels=6,
    heads=8,
    edge_dim=data.edge_attr.shape[1],
    out_channels=data.y.shape[1],
    dropout=0.,
    activation=F.gelu,
)
model = MathModel(**model_config)

# Number of trainable parameters (displayed as the cell output).
sum(param.numel() for param in model.parameters() if param.requires_grad)

In [None]:
# Full-batch training: every epoch is a single optimizer step over the whole
# `data` object, minimising the L1 loss between model output and data.y.
n_epochs = 2000   # number of optimisation steps
log_every = 100   # print the loss every `log_every` epochs

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

model.train()
for epoch in tqdm(range(n_epochs)):
    # Clear gradients at the start of the step so that re-running this cell
    # never accumulates onto gradients left over from a previous run.
    optimizer.zero_grad()
    out = model(
        x=data.x,
        edge_index=data.edge_index,
        edge_attr=data.edge_attr,
        batch=data.batch,
    )
    loss = F.l1_loss(out, data.y)
    loss.backward()
    optimizer.step()

    # There is one batch per epoch, so the epoch loss is simply loss.item().
    if epoch % log_every == 0:
        print(
            "Epoch {:05d} | Loss {:.6f} |".format(
                epoch, loss.item()
            )
        )

In [None]:
# Hand-written expression used as a single held-out example.
expression = "0.5 + 0.5 * 0.2 "

# eval() is applied only to the literal defined above to obtain the reference
# value; do not feed it externally supplied strings.
y_true = eval(expression)

# Build the graph for the expression and assign every node to batch index 0.
graph = graph_from_expression(expression)
node_count = len(graph["nodes"])
graph["batch"] = torch.tensor([0] * node_count)

# Convert with the converter that was fitted on the training graphs.
datatest = dataset_create.transform(graph, y_true)

In [44]:
# Inference on the single test graph. Autograd is switched off for the forward
# pass because no gradients are needed here, which avoids building the graph.
model.eval()
with torch.no_grad():
    y_pred = model(
        x=datatest.x,
        edge_index=datatest.edge_index,
        edge_attr=datatest.edge_attr,
        batch=datatest.batch,
    )

# Move to host memory before converting so this also works if the model was
# placed on an accelerator; .cpu() is a no-op for tensors already on the CPU.
y_pred = y_pred.detach().cpu().numpy()

In [None]:
# Signed difference between the first predicted value and the test target.
# NOTE(review): datatest.y presumably went through target_transformer, so this
# difference would be in the transformed space — confirm.
y_pred[0, 0] - datatest.y