In [1]:
import torch
import tqdm
import numpy as np
from tqdm.notebook import tqdm
from torch_geometric.data import Data
from torch_geometric.nn import Linear, GATv2Conv
from torch_geometric.nn import pool
import torch.nn.functional as F

from arigin.expressions import generate
from arigin.graph.generation import (
    graph_from_expression, 
    generate_multiple_graphs
)
from arigin.features import node_features, edge_features
from arigin.preprocessing import GraphEntityToDataSet

In [2]:
max_numbers = 4
min_numbers = 2

In [None]:
graph_entities, y = generate_multiple_graphs(
    n_graphs=1000, 
    min_numbers=min_numbers, 
    max_numbers=max_numbers
)

dataset_create = GraphEntityToDataSet().fit(graph_entities, y)
data = dataset_create.transform(graph_entities, y)

100%|██████████| 1000/1000 [00:00<00:00, 1237.78it/s]


In [30]:
class GAT(torch.nn.Module):
    def __init__(
            self, 
            in_channels: int,
            emb_channels: int,
            hidden_channels: int, 
            out_channels: int, 
            activation = F.relu,
            dropout=0.2):

        super().__init__()

        self.linear_in = Linear(in_channels, emb_channels)
        self.linear_in_2 = Linear(emb_channels, emb_channels)
        self.linear_in_3 = Linear(emb_channels, hidden_channels)

        self.conv = GATv2Conv(
                hidden_channels,
                hidden_channels, 
                heads=1,
                fill_value=0,
                add_self_loops=True,
                dropout=dropout
        )

        self.pool = pool.SAGPooling(
            hidden_channels,
            GNN=GATv2Conv,
            ratio=0.5,
            heads=1,
            fill_value=0,
            add_self_loops=True
        )

        self.head = Linear(hidden_channels, hidden_channels)
        self.head_2 = Linear(hidden_channels, out_channels)

        self.activation = activation

    def forward(self, x, edge_index, edge_attr, batch):

        x = self.linear_in(x)
        x = self.activation(x)
        x = self.linear_in_2(x)
        x = self.activation(x)
        x = self.linear_in_3(x)
        x = self.activation(x)

        for i in range(3):
            x, edge_index, _, batch, _, _ = self.pool(x, edge_index, batch=batch)
        
            x = self.conv(x, edge_index, edge_attr)
            x = self.activation(x)

        if x.shape[0] != len(np.unique(batch)):
            print(x.shape, len(np.unique(batch)))

        x = self.head(x)
        x = F.relu(x)
        x = self.head_2(x)

        return x

In [31]:
model = GAT(
    in_channels=X.shape[1],
    emb_channels=1024,
    hidden_channels=32,
    out_channels=1,
    dropout=0.,
    activation=F.relu
)

sum(p.numel() for p in model.parameters() if p.requires_grad)

1095974

In [34]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-3)
optimizer.zero_grad()

model.train()
for epoch in tqdm(range(100), total=100):
    total_loss = 0
    out = model(
        x=data.x, 
        edge_index=data.edge_index, 
        edge_attr=None, 
        batch=data.batch
    )
    loss = F.l1_loss(out, data.y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    total_loss += loss.item()
    if epoch % 10 == 0:
        print(
            "Epoch {:05d} | Loss {:.6f} |".format(
                epoch, total_loss
            )
        )

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 00000 | Loss 1.526859 |
Epoch 00010 | Loss 1.526860 |
Epoch 00020 | Loss 1.526861 |
Epoch 00030 | Loss 1.526860 |
Epoch 00040 | Loss 1.526859 |
Epoch 00050 | Loss 1.526859 |
Epoch 00060 | Loss 1.526861 |
Epoch 00070 | Loss 1.526864 |
Epoch 00080 | Loss 1.526861 |
Epoch 00090 | Loss 1.526861 |


In [None]:
model.eval()

In [None]:
graph_entities = {"nodes": [], "relationships": []}
node_batches = []
relationship_batches = []
result_batches = []
batch_nos = []
n_in_batch = 0
for i in tqdm(range(10), total=10):
    if not n_in_batch:
        node_batch = []
        relationship_batch = []
        result_batch = []
        batch_no = []
    expr = generate(min_numbers, max_numbers)
    single_graph_entities = graph_from_expression(expr)
    try:
        y = eval(expr)
    except ZeroDivisionError:
        continue
    n_nodes = len(single_graph_entities["nodes"])
    graph_entities["nodes"] += single_graph_entities["nodes"]
    graph_entities["relationships"] += single_graph_entities["relationships"]

    node_batch += single_graph_entities["nodes"]
    relationship_batch += single_graph_entities["relationships"]
    result_batch.append(y)
    batch_no += [i] * n_nodes
    n_in_batch += 1
    if n_in_batch == 10:
        n_in_batch = 0
        node_batches.append(node_batch)
        relationship_batches.append(relationship_batch)
        result_batches.append(result_batch)
        batch_nos.append(batch_no)

dataset = []
for i, (nodes, relationships, y, batch_no) in enumerate(zip(node_batches, relationship_batches, result_batches, batch_nos)):

    id_index_mapping = {
        node.id: index
        for index, node in enumerate(nodes)
    }

    X = node_features.transform(nodes)
    y = np.array(y).reshape(-1, 1)

    # Get edge indices
    edge_index = [
        [
            id_index_mapping[relationship.source.id],
            id_index_mapping[relationship.target.id]
        ]
        for relationship in relationships
    ]
    edge_index = np.array(edge_index)
    edge_index = np.vstack((edge_index, edge_index[:, [1, 0]]))

    # Get edge features
    E = edge_features.fit_transform(relationships)
    Ez = np.zeros_like(E)
    E = np.vstack(
        (
            np.hstack((E, Ez)),
            np.hstack((Ez, E))
        )
    )
    edge_index = torch.tensor(edge_index.T, dtype=torch.long)
    X = torch.tensor(X, dtype=torch.float)
    E = torch.tensor(E, dtype=torch.float)
    y = torch.tensor(y, dtype=torch.float)
    batch_no = torch.tensor(batch_no, dtype=torch.long)

    dataset.append(Data(x=X, edge_index=edge_index, y=y, batch=batch_no))

In [88]:
data = dataset[0]
ypred = model(x=data.x, edge_index=data.edge_index, edge_attr=None, batch=data.batch)

In [None]:
ypred, y