In [None]:
import torch
import numpy as np
from tqdm.notebook import tqdm
import torch.nn.functional as F
import matplotlib.pyplot as plt

from arigin.graph.elements import Result, Node
from arigin.expressions import generate
from arigin.graph.generation import graph_from_expression
from arigin.features import node_features, edge_features

In [None]:
def func_to_approx(x, y):
    """Target surface for the regression demo: sin(4 * pi * x * y).

    Works element-wise, so ``x`` and ``y`` may be scalars or NumPy arrays
    that broadcast against each other.
    """
    phase = 4 * np.pi * (x * y)
    return np.sin(phase)

In [None]:
# Draw the 1000 training points from a locally seeded generator so that a
# fresh "Restart & Run All" reproduces the same samples without touching
# NumPy's global random state.
rng = np.random.default_rng(0)
x = rng.random(size=(2, 1000))
# One target per sample, kept as a column so it lines up with the (N, 1) model output.
y = func_to_approx(*x)[:, np.newaxis]

# Samples become rows: x is (1000, 2), y is (1000, 1).
x = torch.tensor(x.T, dtype=torch.float)
y = torch.tensor(y, dtype=torch.float)

In [None]:
class MLP(torch.nn.Module):
    """Small perceptron: 2 inputs -> three hidden layers of width 4 -> 1 output."""

    def __init__(self):
        super().__init__()
        linear = torch.nn.Linear
        self.lin1 = linear(2, 4)
        self.lin2 = linear(4, 4)
        self.lin3 = linear(4, 4)
        self.lin4 = linear(4, 1)

    def forward(self, x):
        """Apply the three ReLU-activated hidden layers, then the linear output layer."""
        hidden = x
        for layer in (self.lin1, self.lin2, self.lin3):
            hidden = F.relu(layer(hidden))
        # The last layer has no activation so the output is unbounded.
        return self.lin4(hidden)
    

# Instantiate the network; the cell output is its number of trainable scalars.
model = MLP()
sum(
    weight.numel()
    for weight in model.parameters()
    if weight.requires_grad
)

In [None]:
# Prefer Apple's Metal backend, then CUDA, and fall back to the CPU so the
# notebook still runs on machines without an accelerator.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Model and training tensors must live on the same device.
model = model.to(device)
x = x.to(device)
y = y.to(device)

In [None]:
%%time

# Adam over all model parameters with a small step size and weight decay.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=5e-4,
    weight_decay=1e-4,
)

# Full-batch training: every epoch is one gradient step on the whole data set.
model.train()
for epoch in tqdm(range(10000), total=10000):
    optimizer.zero_grad()
    out = model(x=x)
    loss = F.l1_loss(input=out, target=y)
    # Report the mean absolute error every 500 epochs.
    if not epoch % 500:
        print(epoch, loss)
    loss.backward()
    optimizer.step()

In [None]:
# Evaluate along the diagonal x == y of the unit square.
xtest = np.linspace(0, 1, 1000)[np.newaxis, :]
ytest = func_to_approx(xtest, xtest)

# Stack the grid twice (one row per input feature) and transpose to (N, 2).
diag_inputs = torch.tensor(
    np.vstack((xtest, xtest)), dtype=torch.float, device=device
).T
ypred = model(x=diag_inputs).cpu().detach().numpy()

# Overlay the true curve and the network's prediction.
for curve, label in ((ytest, "test"), (ypred, "pred")):
    plt.plot(xtest.flatten(), curve.flatten(), "-", lw=0.5, label=label)
plt.legend()

In [None]:
graph_entities = {"nodes": [], "relationships": []}
data = []
max_n_nodes = 0
# Generate 10000 expressions, convert each to a graph, and remember the
# largest node count seen.
for _ in tqdm(range(10000), total=10000):
    expr = generate(2, 5)
    single_graph_entities = graph_from_expression(expr)
    # Keep only the two entity lists needed downstream.
    data.append(
        {key: single_graph_entities[key] for key in ("nodes", "relationships")}
    )
    if len(single_graph_entities["nodes"]) > max_n_nodes:
        max_n_nodes = len(single_graph_entities["nodes"])

In [None]:
# Placeholder node used to pad every graph to the same node count.
n0 = Node(expression="")

nodes = []
for single in data:
    # Take the final node by position. list.remove() would delete the first
    # element that compares equal, which need not be the last one.
    # NOTE(review): presumably the last node is the graph's result node - confirm.
    result = single["nodes"].pop()
    # Pad with placeholders so that, once `result` is re-appended, the graph
    # has exactly max_n_nodes entries and `result` is always in the last slot.
    # A non-positive count yields an empty list, so full graphs are left as is.
    single["nodes"].extend([n0] * (max_n_nodes - 1 - len(single["nodes"])))
    single["nodes"].append(result)
    nodes += single["nodes"]

# Get node features
X = node_features.fit_transform(nodes)
# One row per graph: the features of its nodes are concatenated.
X = X.reshape(len(data), -1)
X = torch.tensor(X, dtype=torch.float)

In [None]:
# One target row per graph: the value of every node, with 0 standing in for
# nodes that carry no value.
y = torch.tensor(
    [[0 if node.value is None else node.value for node in single["nodes"]] for single in data],
    dtype=torch.float
).reshape(len(data), -1)
# Squash the targets with tanh so they are bounded; NaN values stay NaN.
y = torch.tanh(y)
# True wherever the target is a usable (non-NaN) number.
train_mask = ~torch.isnan(y)
# Usable targets per node position. Tensor.sum(dim=0) returns the same
# per-column counts as the builtin sum(), without a Python loop over rows.
n_train = train_mask.sum(dim=0)
n_train

In [None]:
class MLP(torch.nn.Sequential):
    """Funnel-shaped fully connected network with ReLU between the layers.

    Args:
        in_features: Width of the input vector. When omitted, it is read from
            the notebook-level ``X`` (``X.shape[1]``) at construction time.
        out_features: Width of the output vector. When omitted, it is read
            from the notebook-level ``y`` (``y.shape[1]``) at construction time.
    """

    # Widths of the hidden layers, from the input side to the output side.
    HIDDEN_WIDTHS = (512, 128, 64, 32, 32, 16, 16, 8)

    def __init__(self, in_features=None, out_features=None):
        if in_features is None:
            in_features = X.shape[1]
        if out_features is None:
            out_features = y.shape[1]

        widths = (in_features, *self.HIDDEN_WIDTHS, out_features)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(torch.nn.Linear(fan_in, fan_out))
            layers.append(torch.nn.ReLU())
        # The output layer stays linear: drop the activation added after it.
        layers.pop()
        super().__init__(*layers)

# Build the graph-feature network; the cell output is its trainable-parameter count.
model = MLP()
sum(
    tensor.numel()
    for tensor in model.parameters()
    if tensor.requires_grad
)

In [None]:
# Prefer Apple's Metal backend, then CUDA, and fall back to the CPU so the
# notebook still runs on machines without an accelerator.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Model, features and targets must live on the same device.
model = model.to(device)
X = X.to(device)
y = y.to(device)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)

# A single NaN target would turn the mean loss, and with it every gradient,
# into NaN. Recompute the non-NaN mask here so it sits on the same device as
# `y`, and score only those entries.
train_mask = ~torch.isnan(y)
y_valid = y[train_mask]

# Full-batch training: every epoch is one gradient step on the whole data set.
model.train()
for epoch in tqdm(range(20000), total=20000):
    optimizer.zero_grad()
    out = model(X)
    # Mean absolute error over the valid targets only; this equals the
    # unmasked loss when no target is NaN.
    loss = F.l1_loss(out[train_mask], y_valid)
    if epoch % 1000 == 0:
        print(
            "Epoch {:05d} | Loss {:.6f} |".format(
                epoch, loss
            )
        )
    loss.backward()
    optimizer.step()

In [None]:
expr = "0.1 * 0.35"
ent = graph_from_expression(expr)

nodes = ent["nodes"]
# Take the final node by position. list.remove() would delete the first
# element that compares equal, which need not be the last one.
result = nodes.pop()
# Pad with the placeholder node so `result` ends up in the last of
# max_n_nodes slots, the same layout used for the training graphs.
nodes.extend([n0] * (max_n_nodes - 1 - len(nodes)))
nodes.append(result)

# Get node features
Xtest = node_features.transform(nodes)
# A single row holding the concatenated features of all nodes.
Xtest = Xtest.reshape(1, -1)
Xtest = torch.tensor(Xtest, dtype=torch.float)

[node.value for node in nodes]

In [None]:
# Run inference on whatever device the model already lives on. Moving the
# model itself to the CPU would leave it on a different device from X and y
# and break a re-run of the training cell.
model_device = next(model.parameters()).device
with torch.no_grad():
    squashed = model(Xtest.to(model_device)).cpu()

# The output layer is linear, so predictions can fall outside (-1, 1), where
# atanh is NaN (or infinite at exactly +/-1). Clamp just inside the open
# interval before undoing the tanh applied to the targets.
eps = 1e-6
ypred = torch.atanh(squashed.clamp(-1 + eps, 1 - eps)).numpy()
print(np.round(ypred, 4))