In [19]:
import numpy as np
import pandas as pd
import networkx as nx

import sys

sys.path.append("../")
import utils

from torch.utils.data import Dataset
from torch_geometric.utils.convert import from_networkx
from os import path
import torch
import os
import numpy as np

import torch
import torch.nn.functional as F

from torch_geometric.loader import DataLoader
from torch_geometric.nn import MLP, GINConv, global_add_pool
from tqdm.notebook import tqdm

In [20]:
class GraphDataset(Dataset):
    """Time-series classification dataset served as one graph per series.

    The series are read from ``<dirpath>/<dataset>/<dataset>.txt``. Each series
    is converted to a graph (quantile graph or visibility graph, via the
    project-local ``utils`` helpers), converted with ``from_networkx`` and
    stored as ``<idx>.pt`` in a sub-directory next to the text file. Items are
    loaded back from those files on access.

    Attributes:
        X_ts: ``pandas.DataFrame`` holding the series, one column per row of
            the input file (the file matrix is transposed).
        labels: ``torch.long`` tensor, 1 where the first column of the input
            file equals 1 and 0 otherwise.
        path: directory containing the cached ``<idx>.pt`` graph files.
    """

    def __init__(
        self,
        dirpath,
        dataset,
        quantile: bool = True,
        n_quantiles: int = 25,
        cache: bool = True,
    ):
        """Read the series file and make sure every graph file exists on disk.

        Args:
            dirpath: directory that contains the ``<dataset>`` folder.
            dataset: dataset name; used both as folder name and file stem.
            quantile: build quantile graphs when True, visibility graphs
                otherwise.
            n_quantiles: number of quantiles for quantile graphs; also part of
                the cache directory name (``quantile_<n>``). Not used for
                visibility graphs.
            cache: accepted for interface compatibility; nothing in this class
                reads it. NOTE(review): intended semantics unknown - confirm
                before wiring it up.
        """
        X_ts, labels = GraphDataset.readucr(
            path.join(dirpath, dataset, f"{dataset}.txt")
        )
        self.X_ts = pd.DataFrame(X_ts.T)
        self.labels = torch.tensor(labels, dtype=torch.long)

        subdirname = "quantile_" + str(n_quantiles) if quantile else "visibility"
        self.path = path.join(dirpath, dataset, subdirname)
        os.makedirs(self.path, exist_ok=True)

        # Check the cache per file rather than per directory: an interrupted
        # earlier run leaves the directory behind with only some graphs in it,
        # and trusting the directory alone would make __getitem__ fail later.
        missing = [
            (idx, col)
            for idx, col in enumerate(self.X_ts.columns)
            if not path.exists(self._graph_file(idx))
        ]
        if not missing:
            return

        for idx, col in tqdm(missing, total=len(missing)):
            if quantile:
                graph = utils.df_to_quantile_graph(
                    self.X_ts, y_col=col, n_quantiles=n_quantiles
                )
            else:
                graph = utils.df_to_visibility_graph(self.X_ts, y_col=col)

            # Write to a temporary name and rename afterwards so that a crash
            # in the middle of a save never leaves a truncated "<idx>.pt".
            target = self._graph_file(idx)
            partial = target + ".tmp"
            torch.save(from_networkx(graph), partial)
            os.replace(partial, target)

    def _graph_file(self, idx):
        """Return the path of the cached graph file for item ``idx``."""
        return path.join(self.path, f"{idx}.pt")

    def __len__(self):
        """Return the number of series (one label per series)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load graph ``idx`` from disk and attach its label.

        Returns the loaded object with ``y`` set to the label and ``x``
        given an extra axis at position 1.
        """
        # The cached files are pickled graph objects written by __init__, not
        # plain tensors, so full unpickling is requested explicitly (newer
        # PyTorch releases default to weights_only=True and would refuse them).
        # Only ever point this at cache directories produced locally.
        data = torch.load(self._graph_file(idx), weights_only=False)
        data.y = self.labels[idx]
        # assumes data.x is 1-D (one value per node) so this yields a
        # (num_nodes, 1) feature matrix - TODO confirm against utils
        data.x = data.x.unsqueeze(1)
        return data

    @staticmethod
    def readucr(filename):
        """Parse a whitespace-separated numeric file with the label in column 0.

        Args:
            filename: path of the text file to read.

        Returns:
            ``(x, y)`` where ``x`` is the 2-D array of all columns after the
            first and ``y`` is a boolean array that is True where the first
            column equals 1.
        """
        # ndmin=2 keeps a single-row file two-dimensional so the column
        # slicing below still works.
        data = np.loadtxt(filename, ndmin=2)
        y = data[:, 0] == 1
        x = data[:, 1:]
        return x, y

In [21]:
# Adapted from the PyG example:
#   https://github.com/pyg-team/pytorch_geometric/blob/master/examples/compile/gin.py
# which was found via the PyG compile tutorial:
#   https://pytorch-geometric.readthedocs.io/en/latest/tutorial/compile.html?highlight=compile#basic-usage

# quantile=False selects visibility graphs; n_quantiles is passed along but is
# not used by GraphDataset in that mode.
train_dataset = GraphDataset(
    dirpath="../data/",
    dataset="FordA_TRAIN",
    quantile=False,
    n_quantiles=5,
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = GraphDataset(
    dirpath="../data/",
    dataset="FordA_TEST",
    quantile=False,
    n_quantiles=5,
)
test_loader = DataLoader(test_dataset, batch_size=32)


class GIN(torch.nn.Module):
    """Graph-level classifier built from stacked ``GINConv`` layers.

    Each convolution wraps its own two-layer ``MLP``; node embeddings pass
    through the convolutions with a ReLU after each one, are summed per graph
    with ``global_add_pool`` and finally mapped to ``out_channels`` values by
    a head ``MLP`` (no normalisation, with dropout).

    Args:
        in_channels: number of input features per node.
        out_channels: size of the per-graph output vector.
        hidden_channels: width of every hidden layer (default 32).
        num_layers: number of ``GINConv`` layers (default 5, must be >= 1).
        dropout: dropout probability passed to the head ``MLP`` (default 0.5).

    Raises:
        ValueError: if ``num_layers`` is smaller than 1.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels=32,
        num_layers=5,
        dropout=0.5,
    ):
        super().__init__()

        # With no convolution the head would receive in_channels-wide input
        # while being built for hidden_channels, so reject that up front.
        if num_layers < 1:
            raise ValueError("num_layers must be at least 1")

        self.convs = torch.nn.ModuleList()
        layer_in = in_channels
        for _ in range(num_layers):
            mlp = MLP([layer_in, hidden_channels, hidden_channels])
            self.convs.append(GINConv(mlp, train_eps=False))
            # Every layer after the first consumes the hidden width.
            layer_in = hidden_channels

        self.mlp = MLP(
            [hidden_channels, hidden_channels, out_channels],
            norm=None,
            dropout=dropout,
        )

    def forward(self, x, edge_index, batch):
        """Return one ``out_channels``-sized row per graph in the batch.

        Args:
            x: node feature matrix.
            edge_index: edge index tensor forwarded to each ``GINConv``.
            batch: tensor assigning each node to its graph, used for pooling.
        """
        for conv in self.convs:
            x = conv(x, edge_index).relu()
        x = global_add_pool(x, batch)
        return self.mlp(x)


# One input feature per node, two output logits.
model = GIN(in_channels=1, out_channels=2)

# Note carried over from the upstream example about compiling the model into
# an optimized version: `compile(model, dynamic=True)` does not work yet in
# PyTorch 2.0, so upstream uses `transforms.Pad` and static compilation as a
# workaround (neither is applied in this notebook).
# See: https://github.com/pytorch/pytorch/issues/94640
#
# Compilation did not work in this environment, so it stays disabled:
# model = torch.compile(model)

optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)


def train():
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``model``, ``optimizer`` and ``train_loader``.

    Returns:
        The cross-entropy loss averaged over ``len(train_loader.dataset)``,
        with each batch loss weighted by its number of graphs.
    """
    model.train()

    weighted_loss_sum = 0
    for batch in train_loader:
        optimizer.zero_grad()
        logits = model(batch.x, batch.edge_index, batch.batch)
        batch_loss = F.cross_entropy(logits, batch.y)
        batch_loss.backward()
        optimizer.step()
        # Batch losses are means, so scale by batch size before summing.
        weighted_loss_sum += batch_loss.item() * batch.num_graphs

    n_graphs = len(train_loader.dataset)
    return weighted_loss_sum / n_graphs


@torch.no_grad()
def test(loader):
    model.eval()

    total_correct = 0
    for data in loader:
        pred = model(data.x, data.edge_index, data.batch).argmax(dim=-1)
        total_correct += int((pred == data.y).sum())
    return total_correct / len(loader.dataset)


# 100 epochs: train once, then report accuracy on both splits after each epoch.
for epoch in range(1, 101):
    loss = train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(
        "Epoch: {:03d}, Loss: {:.4f}, Train: {:.4f}, Test: {:.4f}".format(
            epoch, loss, train_acc, test_acc
        )
    )

[False  True False ... False  True False]
[False False False ...  True  True  True]
