In [None]:
import os.path as osp
import random

import pandas as pd
import torch
import torch.nn.functional as F
from torch.nn import Linear

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCN2Conv
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from rae import PROJECT_ROOT

In [None]:
from torch_geometric.datasets import CitationFull
from torch_geometric.transforms import RandomNodeSplit
from pytorch_lightning import seed_everything

# Load Cora with feature normalisation and a random node split (10% validation, no test nodes).
dataset_name = "Cora"
# PyG applies `transform` on every `dataset[idx]` access, so seeding here fixes the split drawn below.
seed_everything(0)
transform = T.Compose(
    [
        T.NormalizeFeatures(),
        RandomNodeSplit(num_val=0.1, num_test=0),
    ]
)
dataset = Planetoid(
    root=PROJECT_ROOT / "data" / "pyg" / dataset_name,
    name=dataset_name,
    transform=transform,
)
data = dataset[0]

# Pre-compute the GCN edge normalisation once (no self-loops added). The weights are stored on the
# graph and forwarded to GCN2Conv layers, which the sweep builds with normalize=False.
_, edge_weight = gcn_norm(data.edge_index, num_nodes=data.x.size(0), add_self_loops=False)
data.edge_weight = edge_weight
data

In [None]:
from rae.modules.attention import RelativeAttention

In [None]:
from rae.modules.enumerations import Output
from rae.modules.attention import AttentionOutput
from torch import nn


class Net(torch.nn.Module):
    """Graph encoder with an anchor-based (relative) projection and a linear classification head.

    Forward pipeline: dropout -> ``hidden_proj`` -> ``hidden_fn`` ->
    [dropout -> conv -> ``conv_fn``] for each conv -> ``conv_fc`` -> ``relative_proj`` ->
    (relative output if ``relative`` else the absolute latent) -> L2 normalisation -> ``class_proj``.

    Args:
        relative: if True, the classifier consumes ``relative_proj``'s output instead of the absolute latent.
        hidden_proj: projection applied to the (dropped-out) input node features.
        hidden_fn: activation after ``hidden_proj``; its output is also passed as ``x_0`` to GCN2Conv layers.
        relative_proj: RelativeAttention module called with the latents and the anchor latents.
        class_proj: classifier applied to the L2-normalised representation.
        convs: graph convolution layers, applied in order.
        conv_fn: activation applied after every convolution.
        conv_out: feature size of the last convolution's output (width of ``conv_fc``).
        dropout: dropout probability used on the input and before every convolution.
    """

    def __init__(
        self,
        relative: bool,
        hidden_proj: nn.Module,
        hidden_fn,
        relative_proj: RelativeAttention,
        class_proj: nn.Module,
        convs: nn.ModuleList,
        conv_fn,
        conv_out: int,
        dropout: float,
    ):
        super().__init__()

        self.hidden_proj: nn.Module = hidden_proj
        self.class_proj: nn.Module = class_proj

        self.hidden_fn = hidden_fn

        self.relative_proj = relative_proj

        self.convs = convs
        self.conv_fn = conv_fn
        # Square linear map applied to the final conv output before the anchors are gathered.
        self.conv_fc = nn.Linear(in_features=conv_out, out_features=conv_out)

        # Not used in forward(); kept so the parameter list / state_dict layout stays unchanged.
        self.layer_norm = nn.LayerNorm(conv_out)

        self.dropout = dropout
        self.relative: bool = relative

    def forward(self, x, edge_index, edge_weight, anchor_idxs: torch.Tensor):
        """Encode the graph and classify every node.

        Args:
            x: node feature matrix (presumably [num_nodes, num_features] — confirm with the dataset).
            edge_index: graph connectivity, passed to every convolution.
            edge_weight: pre-computed edge weights, only forwarded to GCN2Conv layers.
            anchor_idxs: indices of the nodes whose latents are used as anchors.

        Returns:
            dict with ``Output.LOGITS`` (output of ``class_proj``) and ``Output.SIMILARITIES``
            (the ``AttentionOutput.SIMILARITIES`` entry of ``relative_proj``'s output).
        """
        x = F.dropout(x, self.dropout, training=self.training)
        x = x_0 = self.hidden_fn(self.hidden_proj(x))

        for conv in self.convs:
            x = F.dropout(x, self.dropout, training=self.training)
            # Dispatch on the layer actually being applied: GCN2Conv also needs the initial
            # representation and the pre-computed edge weights.
            if type(conv).__name__ == "GCN2Conv":
                x = conv(x, edge_index=edge_index, x_0=x_0, edge_weight=edge_weight)
            else:
                x = conv(x, edge_index=edge_index)
            x = self.conv_fn(x)

        x = self.conv_fc(x)
        anchors: torch.Tensor = x[anchor_idxs, :]

        # Always computed: the similarities are reported even when the absolute latent is classified.
        rel_out = self.relative_proj(x=x, anchors=anchors)
        if self.relative:
            x = rel_out[AttentionOutput.OUTPUT]

        x = F.normalize(x, p=2, dim=-1)
        logits = self.class_proj(x)
        return {Output.LOGITS: logits, Output.SIMILARITIES: rel_out[AttentionOutput.SIMILARITIES]}

In [None]:
# Sample `num_anchors` distinct training nodes (without replacement, order kept as drawn) to act as anchors.
num_anchors: int = 300
# flatten() rather than squeeze(): a single selected node must still yield a list, not a bare int.
train_node_ids = data.train_mask.nonzero().flatten().cpu().tolist()
data.anchors = torch.as_tensor(random.sample(train_node_ids, num_anchors))

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
from torch_geometric.nn import GATConv, GCN2Conv, GCNConv, GINConv

In [None]:
def encoder_factory(encoder_type, num_layers: int, in_channels: int, out_channels: int, **params):
    assert num_layers > 0
    if encoder_type == "GCN2Conv":
        convs = []
        for layer in range(num_layers):
            convs.append(GCN2Conv(layer=layer + 1, channels=out_channels, **params))
        return nn.ModuleList(convs)

    elif encoder_type == "GCNConv":
        convs = []
        # current_out_channels = in_channels
        #
        # for layer in range(num_layers):
        #     convs.append(
        #         GCNConv(
        #             in_channels=current_out_channels,
        #             out_channels=(current_out_channels := max(out_channels, current_out_channels // 2)),
        #             **params,
        #         )
        #     )
        convs = [
            GCNConv(
                in_channels=in_channels,
                out_channels=out_channels,
                **params,
            )
        ]
        in_channels = out_channels
        for layer in range(num_layers - 1):
            convs.append(
                GCNConv(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    **params,
                )
            )
        return nn.ModuleList(convs)

    elif encoder_type == "GATConv":
        convs = []

        # for layer in range(num_layers):
        #     convs.append(
        #         GATConv(
        #             in_channels=current_out_channels,
        #             out_channels=(current_out_channels := max(out_channels, current_out_channels // 2)),
        #             **params,
        #         )
        #     )

        convs = [
            GATConv(
                in_channels=in_channels,
                out_channels=out_channels,
                **params,
            )
        ]
        in_channels = out_channels
        for layer in range(num_layers - 1):
            convs.append(
                GATConv(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    **params,
                )
            )

        return nn.ModuleList(convs)

    elif encoder_type == "GINConv":
        convs = []
        # current_out_channels = in_channels
        #
        # for layer in range(num_layers):
        #     convs.append(
        #         GINConv(
        #             nn=nn.Linear(
        #                 in_features=current_out_channels,
        #                 out_features=(current_out_channels := max(out_channels, current_out_channels // 2)),
        #             )
        #         )
        #     )
        current_in_channels = in_channels
        for layer in range(num_layers):
            convs.append(
                GINConv(
                    nn=nn.Linear(
                        in_features=current_in_channels,
                        out_features=out_channels,
                    ),
                    **params,
                )
            )
            current_in_channels = out_channels
        return nn.ModuleList(convs)

    else:
        raise NotImplementedError

In [None]:
import itertools
import functools
from pytorch_lightning.utilities.seed import log as seed_log
from pytorch_lightning import seed_everything


import random
import torch.nn.functional as F

from sklearn.model_selection import ParameterGrid

# General sweep: each key maps to the list of values tried; ParameterGrid enumerates the
# Cartesian product of all of them.
sweep = dict(
    seed=[0, 1, 2, 3, 4],
    num_epochs=[10, 30, 50],
    # Both widths are tied to the number of anchors.
    in_channels=[num_anchors],
    out_channels=[num_anchors],
    num_layers=[64, 32],
    dropout=[0.1, 0.5],
    # hidden_fn and conv_fn each get their own module instances.
    hidden_fn=[torch.nn.ReLU(), torch.nn.Tanh()],
    conv_fn=[torch.nn.ReLU(), torch.nn.Tanh()],
    optimizer=[torch.optim.Adam, torch.optim.SGD],
    lr=[0.01, 0.02],
    # (display name, layer-stack builder with encoder-specific kwargs pre-bound)
    encoder=[
        (
            "GCN2Conv",
            functools.partial(
                encoder_factory,
                encoder_type="GCN2Conv",
                alpha=0.1,
                theta=0.5,
                shared_weights=True,
                normalize=False,
            ),
        ),
        # Also supported by encoder_factory, currently disabled:
        # ("GCNConv", functools.partial(encoder_factory, encoder_type="GCNConv")),
        # ("GATConv", functools.partial(encoder_factory, encoder_type="GATConv")),
        ("GINConv", functools.partial(encoder_factory, encoder_type="GINConv")),
    ],
)

# Best model config (swap these single values into `sweep` to reproduce the best run):
#   seed=[1], num_epochs=[500], in_channels=out_channels=[num_anchors], num_layers=[32],
#   dropout=[0.5], hidden_fn=[torch.nn.ReLU()], conv_fn=[torch.nn.ReLU()],
#   optimizer=[torch.optim.Adam], lr=[0.02],
#   encoder=[("GCN2Conv", functools.partial(encoder_factory, encoder_type="GCN2Conv",
#                                           alpha=0.1, theta=0.5, shared_weights=True, normalize=False))]

# relative_proj = RelativeAttention(
#     n_anchors=num_anchors,
#     n_classes=dataset.num_classes,
#     similarity_mode="inner",
#     values_mode="similarities",
#     normalization_mode="l2",
# )

experiments = ParameterGrid(sweep)
f"Total available experiments={len(experiments)}"

In [None]:
def train_step(model, optimizer, data):
    """Run one full-batch optimisation step, with the loss computed on the training nodes only.

    Args:
        model: network returning a dict with an ``Output.LOGITS`` entry.
        optimizer: optimizer over ``model``'s parameters.
        data: graph with x, edge_index, edge_weight, anchors, y and train_mask attributes.

    Returns:
        The training cross-entropy loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    logits = model(
        data.x,
        edge_index=data.edge_index,
        edge_weight=data.edge_weight,
        anchor_idxs=data.anchors,
    )[Output.LOGITS]
    train_mask = data.train_mask
    loss = F.cross_entropy(logits[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

In [None]:
# Inspect the split stored in `data`: indexing `dataset[0]` again would re-run the random
# transform and draw a different split.
data.train_mask.shape

In [None]:
@torch.no_grad()
def test(model, data):
    """Evaluate ``model`` on the whole graph without tracking gradients.

    Returns:
        ``(out, [train_acc, val_acc])``: the raw model output dict, plus the fraction of correctly
        predicted nodes within ``train_mask`` and ``val_mask`` respectively.
    """
    model.eval()
    out = model(
        data.x,
        edge_index=data.edge_index,
        edge_weight=data.edge_weight,
        anchor_idxs=data.anchors,
    )
    is_correct = out[Output.LOGITS].argmax(dim=-1) == data.y
    accs = [int(is_correct[mask].sum()) / int(mask.sum()) for _, mask in data("train_mask", "val_mask")]
    return out, accs

In [None]:
# Similarities ("rel_x") of the best earlier run on the same dataset; the sweep measures how close
# each new run's similarities are to them.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted files. Recent PyTorch
# releases default to weights_only=True, which may reject this file if it stores more than
# tensors — confirm against the installed torch version.
old_best = torch.load(PROJECT_ROOT / "experiments" / "sec:data-manifold" / f"{dataset_name}_best_run.pt")
reference_latents = [old_best["best_epoch"]["rel_x"]]
reference_latents[0].shape

In [None]:
from typing import *
import numpy as np


def get_distance(latents1: torch.Tensor, latents_ref: Sequence[torch.Tensor]):
    """Mean cosine similarity between ``latents1`` and each reference latent matrix.

    For each reference, the row-wise cosine similarity with ``latents1`` (along dim 1, the
    ``F.cosine_similarity`` default) is averaged. The result is the mean of those per-reference
    averages. Despite the name, this is a similarity: higher means closer.

    Args:
        latents1: latent matrix to compare.
        latents_ref: a non-tensor sequence (e.g. a list) of reference matrices.

    Returns:
        The mean similarity as a numpy float.
    """
    # A bare tensor/array would be iterated row by row, so require an explicit container.
    assert not isinstance(latents_ref, (np.ndarray, torch.Tensor))
    per_reference = []
    for reference in latents_ref:
        row_similarity = F.cosine_similarity(latents1, reference)
        per_reference.append(row_similarity.mean().item())
    return np.mean(per_reference)

In [None]:
# Model mode switch: when True, Net classifies the RelativeAttention output instead of the absolute latent.
relative: bool = False

In [None]:
from sklearn.model_selection import ParameterSampler, ParameterGrid
import logging
from tqdm import tqdm
from pprint import pprint
from rae.utils.utils import to_device

def _seed_quietly(seed: int) -> None:
    """Call ``seed_everything(seed)`` with Lightning's seed logger muted; its level is always restored."""
    previous_level = seed_log.getEffectiveLevel()
    seed_log.setLevel(logging.ERROR)
    try:
        seed_everything(seed)
    finally:
        seed_log.setLevel(previous_level)


def _build_model(experiment: dict):
    """Instantiate a fresh, untrained ``Net`` (parameters on CPU) for one sweep configuration."""
    _, encoder_build = experiment["encoder"]
    hidden_proj = nn.Linear(dataset.num_features, experiment["in_channels"])
    convs = encoder_build(
        num_layers=experiment["num_layers"],
        in_channels=experiment["in_channels"],
        out_channels=experiment["out_channels"],
    )
    class_proj = nn.Linear(experiment["out_channels"], dataset.num_classes)
    # Built per experiment (after seeding) so no module state is shared between runs; previously
    # `relative_proj` was never defined on a fresh kernel.
    # NOTE(review): arguments taken from the commented-out configuration in the sweep cell —
    # confirm they match RelativeAttention's constructor.
    relative_proj = RelativeAttention(
        n_anchors=num_anchors,
        n_classes=dataset.num_classes,
        similarity_mode="inner",
        values_mode="similarities",
        normalization_mode="l2",
    )
    return Net(
        relative=relative,
        hidden_proj=hidden_proj,
        hidden_fn=experiment["hidden_fn"],
        relative_proj=relative_proj,
        class_proj=class_proj,
        convs=convs,
        conv_fn=experiment["conv_fn"],
        conv_out=experiment["out_channels"],
        dropout=experiment["dropout"],
    )


# Completed sweep configurations, each annotated with its "best_epoch" summary.
experiments = []
stats = {x: [] for x in ("experiment", "epoch", "loss", "train_acc", "val_acc", "reference_distance")}
data = data.to(DEVICE)  # loop-invariant: move the graph to the device once

for i, experiment in enumerate(pbar := tqdm(ParameterGrid(sweep), desc="Experiment")):
    encoder_name, _ = experiment["encoder"]
    if encoder_name == "GCN2Conv":
        # GCN2Conv layers have a single width, so both sizes are pinned to the anchor count.
        experiment["out_channels"] = num_anchors
        experiment["in_channels"] = num_anchors

    _seed_quietly(experiment["seed"])
    model = _build_model(experiment).to(DEVICE)
    optimizer = experiment["optimizer"](model.parameters(), lr=experiment["lr"])

    # Start below any achievable accuracy so best_epoch is populated even if val_acc stays at 0.
    best_val_acc = -1.0
    best_epoch = None
    for epoch in range(experiment["num_epochs"]):
        loss = train_step(model=model, optimizer=optimizer, data=data)
        model_out, (train_acc, val_acc) = test(model=model, data=data)
        # Blocking copy: a non_blocking device->host copy may not have finished when read below.
        similarities = model_out[Output.SIMILARITIES].cpu()
        reference_distance = (
            get_distance(latents1=similarities, latents_ref=reference_latents)
            if reference_latents is not None
            else None
        )

        stats["experiment"].append(i)
        stats["epoch"].append(epoch)
        stats["loss"].append(loss)
        stats["train_acc"].append(train_acc)
        stats["val_acc"].append(val_acc)
        stats["reference_distance"].append(reference_distance)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = {
                "rel_x": similarities,
                "epoch": epoch,
                "loss": loss,
                "train_acc": train_acc,
                "val_acc": val_acc,
                "reference_distance": reference_distance,
            }

    experiment["best_epoch"] = best_epoch
    if best_epoch is not None:  # only None when num_epochs == 0
        # Every figure shown refers to the best epoch, including the reference distance.
        pbar.set_description(
            f"Epoch: {best_epoch['epoch']:04d}, Loss: {best_epoch['loss']:.4f} Train: {best_epoch['train_acc']:.4f}, "
            f"Val: {best_epoch['val_acc']:.4f} Dist: {best_epoch['reference_distance']}"
        )

    experiments.append(experiment)
    model.cpu()  # move parameters off the device before the next configuration

stats = pd.DataFrame(stats)

In [None]:
# Persist the sweep configurations with their best epochs (torch pickle) and the per-epoch stats (TSV).
output_dir = PROJECT_ROOT / "experiments" / "sec:data-manifold"
experiments_path = output_dir / f"{dataset_name}_data_manifold_experiments.pt"
stats_path = output_dir / f"{dataset_name}_data_manifold_stats.tsv"
torch.save(experiments, experiments_path)
stats.to_csv(stats_path, sep="\t")

In [None]:
# torch.save(experiments[0], PROJECT_ROOT / "experiments" / "sec:data-manifold" / f"{dataset_name}_best_run.pt")