In [1]:
# Dependencies

import torch
import networkx as nx
from sklearn.base import BaseEstimator, ClassifierMixin
from datasets import load_dataset
from sklearn.model_selection import cross_val_score
from graph import process_dataset
import sys

sys.path.append("../")

import thdc
from hdc import pm


# Create all subsequent tensors on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    torch.set_default_device("cuda")
else:
    torch.set_default_device("cpu")

In [2]:
# Benchmark under evaluation: the "train" split of the PROTEINS graph dataset.
# To run on MUTAG instead, load "graphs-datasets/MUTAG" here.
_dataset_splits = load_dataset("graphs-datasets/PROTEINS")
DATASET = _dataset_splits["train"]

In [3]:
FOLDS, DIMENSIONS = 10, 2000
# Number of random base hypervectors (rows of VECTORS). `encode` indexes VECTORS
# with node identifiers, so this presumably bounds the largest identifier that
# `process_dataset` can emit -- TODO confirm against `graph.process_dataset`.
NUM_VECTORS = 101
# Fixed seed so that Restart & Run All reproduces the same base hypervectors.
SEED = 0
# Resolve the device once instead of calling `.cuda()` unconditionally, which
# raises on machines without a GPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

(graphs, labels) = process_dataset(DATASET)

torch.manual_seed(SEED)
# Draw bits in {0, 1} and map them to bipolar values {-1, +1} without mutating in place.
VECTORS = (
    torch.randint(
        0, 2, (NUM_VECTORS, DIMENSIONS), dtype=torch.float64, device=DEVICE
    )
    * 2
    - 1
)

# `pm(DIMENSIONS)` returns a NumPy array (presumably a DIMENSIONS x DIMENSIONS
# permutation matrix -- verify in `hdc.pm`); move it next to VECTORS so that
# `torch.matmul` sees both operands on the same device.
MAT = torch.from_numpy(pm(DIMENSIONS)).to(DEVICE)

In [4]:
# Sanity check: multiply the first base vector by MAT and display the result.
VECTORS[0] @ MAT

tensor([-1., -1., -1.,  ...,  1., -1., -1.], device='cuda:0',
       dtype=torch.float64)

In [5]:
def encode(graph, vectors, mat, *, sources=(1,)):
    """Encode a graph as one vector by folding its BFS layers together.

    The layers produced by ``nx.bfs_layers`` are visited in order. For each
    layer, the rows of ``vectors`` selected by the layer's node identifiers
    are summed. The running encoding is right-multiplied by ``mat`` before
    each new layer sum is added, i.e.
    ``G = layer_0``, then ``G = G @ mat + layer_k`` for every later layer.

    Args:
        graph: Graph accepted by ``nx.bfs_layers``. Node identifiers are used
            directly as row indices into ``vectors``.
        vectors: 2-D tensor with one row per node identifier.
        mat: 2-D tensor applied to the running encoding between layers
            (presumably a permutation matrix -- verify against the caller).
        sources: Iterable of nodes the BFS starts from. Defaults to node ``1``,
            the value that used to be hard-coded.

    Returns:
        A 1-D tensor with ``vectors.shape[1]`` elements, or ``None`` when the
        traversal yields no non-empty layer.
    """
    encoding = None
    for layer in nx.bfs_layers(graph, list(sources)):
        if len(layer) == 0:
            continue

        # Build the index on the same device as `vectors` so the lookup does
        # not depend on the process-wide default device.
        index = torch.tensor(layer, device=vectors.device)
        layer_sum = torch.index_select(vectors, 0, index).sum(dim=0)

        if encoding is None:
            encoding = layer_sum
        else:
            # Same value as summing the concatenation of the transformed
            # encoding and the layer rows, without building the temporary.
            encoding = torch.matmul(encoding, mat) + layer_sum
    return encoding

In [6]:
class GraphClassifier(BaseEstimator, ClassifierMixin):
    """scikit-learn compatible classifier backed by a ``thdc.ItemMemory``.

    Every training graph is encoded with ``encode`` (using the module-level
    ``VECTORS`` and ``MAT``) and stored in the item memory under its label.
    Prediction encodes the query graph and asks the memory to clean it up.
    """

    def __init__(self):
        # Takes no hyper-parameters, so each new instance starts with an
        # empty memory.
        self.memory = thdc.ItemMemory()

    def fit(self, X, y):
        """Store the encoding of every graph in ``X`` under its label from ``y``.

        Args:
            X: Sequence of graphs accepted by ``encode``.
            y: Sequence of labels aligned with ``X``; each label is stored
                under its ``str()`` form.

        Returns:
            ``self``, as scikit-learn expects from ``fit``.
        """
        # NOTE(review): vectors are added to the existing memory; calling
        # `fit` twice on the same instance does not reset it -- confirm that
        # this is intended.
        for graph, label in zip(X, y):
            # `encode` takes exactly (graph, vectors, mat); the stray fourth
            # positional argument that used to be passed raised TypeError.
            self.memory.add_vector(str(label), encode(graph, VECTORS, MAT))

        return self

    def predict(self, X):
        """Return one integer label per graph in ``X``.

        The first element of the tuple returned by ``memory.cleanup`` is taken
        as the label (presumably the closest stored entry -- verify in
        ``thdc.ItemMemory``) and converted back to ``int``.
        """
        predictions = []
        for query in X:
            (label, _, _) = self.memory.cleanup(encode(query, VECTORS, MAT))
            predictions.append(int(label))

        return predictions

In [7]:
def main(cv=5):
    """Cross-validate ``GraphClassifier`` on the module-level graphs and print the mean score.

    Args:
        cv: Cross-validation setting forwarded to ``cross_val_score``.
            Defaults to 5, the value that used to be hard-coded.
            NOTE(review): the notebook also defines ``FOLDS = 10``, which is
            not used here -- confirm which fold count is intended.
    """
    classifier = GraphClassifier()
    fold_scores = cross_val_score(
        classifier,
        graphs,
        labels,
        cv=cv,
        n_jobs=1,
        verbose=4,
        # Surface failures in fit/predict instead of recording them as a score.
        error_score="raise",
    )
    print("Acc =>", fold_scores.mean())


main()