In [None]:
import os.path as osp
import random

import pandas as pd
import torch
import torch.nn.functional as F
from torch.nn import Linear

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCN2Conv
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from rae import PROJECT_ROOT

In [None]:
from torch_geometric.datasets import CitationFull
from torch_geometric.transforms import RandomNodeSplit
from pytorch_lightning import seed_everything

dataset_name = "Cora"
# RandomNodeSplit below is stochastic: seed before the split is drawn.
seed_everything(0)

# Row-normalize node features, then hold out 10% of the nodes for validation (no test split).
transform = T.Compose(
    [
        T.NormalizeFeatures(),
        RandomNodeSplit(num_val=0.1, num_test=0),
    ]
)
dataset = Planetoid(
    PROJECT_ROOT / "data" / "pyg" / dataset_name,
    dataset_name,
    transform=transform,
)
data = dataset[0]

# Pre-compute the GCN edge normalization once (no self loops added); the GCN2Conv encoder
# is configured with normalize=False and receives these weights through data.edge_weight.
_, edge_weight = gcn_norm(data.edge_index, num_nodes=data.x.size(0), add_self_loops=False)
data.edge_weight = edge_weight
data

In [None]:
from rae.modules.attention import RelativeAttention

In [None]:
from rae.modules.enumerations import Output
from rae.modules.attention import AttentionOutput
from torch import nn
from typing import *
import numpy as np
from sklearn.model_selection import ParameterSampler, ParameterGrid
import logging
from tqdm import tqdm
from pprint import pprint
from rae.utils.utils import to_device
import functools
import itertools
import functools
from pytorch_lightning.utilities.seed import log as seed_log
from pytorch_lightning import seed_everything


import random
import torch.nn.functional as F


class Net(torch.nn.Module):
    """Graph encoder followed by a relative projection and a classification head.

    ``forward`` runs: dropout -> ``hidden_proj`` -> ``hidden_fn`` -> for every conv
    [dropout -> conv -> ``conv_fn``] -> ``conv_fc``. It then either returns that encoder
    output directly (``only_absolute=True``) or returns the class logits together with the
    similarities computed by ``relative_proj`` against the anchor rows.

    Args:
        hidden_proj: module mapping input node features to the encoder width.
        hidden_fn: activation applied after ``hidden_proj``; its output is also kept as ``x_0``.
        relative_proj: module called as ``relative_proj(x=..., anchors=...)`` whose result is
            indexed with ``AttentionOutput.SIMILARITIES``.
        class_proj: head mapping the encoder output to class logits.
        convs: message-passing layers, applied in order.
        conv_fn: activation applied after every conv.
        conv_out: width of the encoder output (size of ``conv_fc`` and ``layer_norm``).
        dropout: dropout probability used before ``hidden_proj`` and before every conv.
        relative: stored as ``self.relative``; not read by ``forward``.
    """

    def __init__(
        self,
        hidden_proj: nn.Module,
        hidden_fn,
        relative_proj: "RelativeAttention",
        class_proj: nn.Module,
        convs: nn.ModuleList,
        conv_fn,
        conv_out: int,
        dropout: float,
        relative: bool = True,
    ):
        super().__init__()

        self.hidden_proj: nn.Module = hidden_proj
        self.class_proj: nn.Module = class_proj

        self.hidden_fn = hidden_fn

        self.relative_proj = relative_proj

        self.convs = convs
        self.conv_fn = conv_fn
        self.conv_fc = nn.Linear(in_features=conv_out, out_features=conv_out)

        # NOTE(review): layer_norm is never used in forward; it is kept so the module's
        # parameters / state_dict layout stay unchanged.
        self.layer_norm = nn.LayerNorm(conv_out)

        self.dropout = dropout
        self.relative: bool = relative

    def forward(self, x, edge_index, edge_weight, anchor_idxs: torch.Tensor, only_absolute: bool = False):
        """Encode the graph and, unless ``only_absolute``, project onto the anchors.

        Args:
            x: node feature matrix.
            edge_index: graph connectivity, passed to every conv.
            edge_weight: edge weights; forwarded only to GCN2Conv layers.
            anchor_idxs: indices of the anchor rows in the encoder output; may be ``None``
                when ``only_absolute`` is True (it is not used in that case).
            only_absolute: if True, return the encoder output right after ``conv_fc``.

        Returns:
            The encoder output tensor if ``only_absolute``; otherwise a dict with
            ``Output.LOGITS`` (output of ``class_proj``) and ``Output.SIMILARITIES``
            (taken from the ``relative_proj`` result).
        """
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.hidden_proj(x)

        # x_0 is the initial representation that GCN2Conv layers mix back in at every layer.
        x = x_0 = self.hidden_fn(x)

        for conv in self.convs:
            x = F.dropout(x, self.dropout, training=self.training)
            params = {"edge_index": edge_index}
            # Decide per layer (not from convs[0]) so a mixed stack gets the right kwargs.
            if type(conv).__name__ == "GCN2Conv":
                params["x_0"] = x_0
                params["edge_weight"] = edge_weight
            x = conv(x, **params)
            x = self.conv_fn(x)

        x = self.conv_fc(x)

        if only_absolute:
            return x
        anchors: torch.Tensor = x[anchor_idxs, :]
        rel_out = self.relative_proj(x=x, anchors=anchors)

        x = self.class_proj(x)
        return {Output.LOGITS: x, Output.SIMILARITIES: rel_out[AttentionOutput.SIMILARITIES]}

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
from torch_geometric.nn import GATConv, GCN2Conv, GCNConv, GINConv

In [None]:
def encoder_factory(encoder_type, num_layers: int, in_channels: int, out_channels: int, **params):
    assert num_layers > 0
    if encoder_type == "GCN2Conv":
        convs = []
        for layer in range(num_layers):
            convs.append(GCN2Conv(layer=layer + 1, channels=out_channels, **params))
        return nn.ModuleList(convs)

    elif encoder_type == "GCNConv":
        convs = []
        # current_out_channels = in_channels
        #
        # for layer in range(num_layers):
        #     convs.append(
        #         GCNConv(
        #             in_channels=current_out_channels,
        #             out_channels=(current_out_channels := max(out_channels, current_out_channels // 2)),
        #             **params,
        #         )
        #     )
        convs = [
            GCNConv(
                in_channels=in_channels,
                out_channels=out_channels,
                **params,
            )
        ]
        in_channels = out_channels
        for layer in range(num_layers - 1):
            convs.append(
                GCNConv(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    **params,
                )
            )
        return nn.ModuleList(convs)

    elif encoder_type == "GATConv":
        convs = []

        # for layer in range(num_layers):
        #     convs.append(
        #         GATConv(
        #             in_channels=current_out_channels,
        #             out_channels=(current_out_channels := max(out_channels, current_out_channels // 2)),
        #             **params,
        #         )
        #     )

        convs = [
            GATConv(
                in_channels=in_channels,
                out_channels=out_channels,
                **params,
            )
        ]
        in_channels = out_channels
        for layer in range(num_layers - 1):
            convs.append(
                GATConv(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    **params,
                )
            )

        return nn.ModuleList(convs)

    elif encoder_type == "GINConv":
        convs = []
        # current_out_channels = in_channels
        #
        # for layer in range(num_layers):
        #     convs.append(
        #         GINConv(
        #             nn=nn.Linear(
        #                 in_features=current_out_channels,
        #                 out_features=(current_out_channels := max(out_channels, current_out_channels // 2)),
        #             )
        #         )
        #     )
        current_in_channels = in_channels
        for layer in range(num_layers):
            convs.append(
                GINConv(
                    nn=nn.Linear(
                        in_features=current_in_channels,
                        out_features=out_channels,
                    ),
                    **params,
                )
            )
            current_in_channels = out_channels
        return nn.ModuleList(convs)

    else:
        raise NotImplementedError

In [None]:
# Relative projection shared by Net and ModelHead below. Anchors are supplied per call
# (n_anchors is left as None); the similarity / normalization / values modes are
# interpreted by rae.modules.attention.RelativeAttention.
relative_proj = RelativeAttention(
    normalization_mode="l2",
    similarity_mode="inner",
    values_mode="similarities",
    n_classes=dataset.num_classes,
    n_anchors=None,
)

In [None]:
# TRAIN BEST MODEL
# A single-configuration "sweep": every entry holds exactly one value.
# Alternatives explored earlier:
#   out_channels in {10, 32, 64}
#   hidden_fn / conv_fn in {torch.relu, torch.tanh, torch.sigmoid}
#   encoders GCNConv, GATConv, GINConv via encoder_factory
sweep = dict(
    seed=[1],
    num_epochs=[50],
    in_channels=[128],
    out_channels=[128],
    num_layers=[32],
    dropout=[0.5],
    hidden_fn=[torch.nn.ReLU()],
    conv_fn=[torch.nn.ReLU()],
    optimizer=[torch.optim.Adam],
    lr=[0.02],
    encoder=[
        (
            "GCN2Conv",
            functools.partial(
                encoder_factory,
                encoder_type="GCN2Conv",
                alpha=0.1,
                theta=0.5,
                shared_weights=True,
                normalize=False,
            ),
        ),
    ],
    num_anchors=[300],
)


class Lambda(nn.Module):
    """Adapter that lets an arbitrary callable be used as a layer (e.g. inside ``nn.Sequential``)."""

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        """Apply the wrapped callable to ``x`` and return its result unchanged."""
        wrapped = self.func
        return wrapped(x)


def train_step(model, optimizer, data):
    """Run one full-graph optimization step and return the training loss as a float.

    The cross-entropy loss is computed only on the nodes selected by ``data.train_mask``;
    ``data.anchors`` provides the anchor indices passed to the model.
    """
    model.train()
    optimizer.zero_grad()
    logits = model(
        data.x,
        edge_index=data.edge_index,
        edge_weight=data.edge_weight,
        anchor_idxs=data.anchors,
    )[Output.LOGITS]
    train_mask = data.train_mask
    loss = F.cross_entropy(logits[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test(model, data):
    """Evaluate ``model`` on the whole graph without tracking gradients.

    Returns:
        The raw model output dict and a list with the accuracies on the nodes selected by
        ``data.train_mask`` and ``data.val_mask`` (callers unpack it as train, val).
    """
    model.eval()
    out = model(
        data.x,
        edge_index=data.edge_index,
        edge_weight=data.edge_weight,
        anchor_idxs=data.anchors,
    )
    predicted = out[Output.LOGITS].argmax(dim=-1)
    accs = [
        int((predicted[mask] == data.y[mask]).sum()) / int(mask.sum())
        for _, mask in data("train_mask", "val_mask")
    ]
    return out, accs


# The cell trains exactly one configuration; the resulting model is kept as best_model.
assert len(ParameterGrid(sweep)) == 1
for i, experiment in enumerate(pbar := tqdm(ParameterGrid(sweep), desc="Experiment")):
    # Seed all RNGs for this run while temporarily silencing seed_everything's logger.
    seed: int = experiment["seed"]
    temp_log_level = seed_log.getEffectiveLevel()
    seed_log.setLevel(logging.ERROR)
    seed_everything(seed)
    seed_log.setLevel(temp_log_level)

    # Sample the anchors among the training nodes (after seeding, so the draw is reproducible).
    # flatten() keeps a 1-D index list even when only one training node exists.
    num_anchors: int = experiment["num_anchors"]
    data.anchors = torch.as_tensor(random.sample(data.train_mask.nonzero().flatten().cpu().tolist(), num_anchors))

    encoder_name, encoder_build = experiment["encoder"]
    if encoder_name == "GCN2Conv":
        # GCN2Conv keeps a fixed width across layers; here that width is tied to the anchor count.
        experiment["out_channels"] = num_anchors
        experiment["in_channels"] = num_anchors

    hidden_proj = nn.Linear(dataset.num_features, experiment["in_channels"])

    convs = encoder_build(
        num_layers=experiment["num_layers"],
        in_channels=experiment["in_channels"],
        out_channels=experiment["out_channels"],
    )
    # InstanceNorm1d is applied to the transposed (features, nodes) matrix, so each feature
    # is standardized across nodes; the Lambda permutes transpose back and forth.
    class_proj = nn.Sequential(
        Lambda(lambda x: x.permute(1, 0)),
        nn.InstanceNorm1d(num_features=experiment["out_channels"]),
        Lambda(lambda x: x.permute(1, 0)),
        nn.Linear(in_features=experiment["out_channels"], out_features=64),
        nn.Tanh(),
        Lambda(lambda x: x.permute(1, 0)),
        nn.InstanceNorm1d(num_features=64),
        Lambda(lambda x: x.permute(1, 0)),
        nn.Linear(in_features=64, out_features=dataset.num_classes),
    )

    model = Net(
        hidden_proj=hidden_proj,
        hidden_fn=experiment["hidden_fn"],
        relative_proj=relative_proj,
        class_proj=class_proj,
        convs=convs,
        conv_fn=experiment["conv_fn"],
        conv_out=experiment["out_channels"],
        dropout=experiment["dropout"],
    ).to(DEVICE)
    # Data.to also moves the custom `anchors` tensor set above.
    data = data.to(DEVICE)
    optimizer = experiment["optimizer"](
        model.parameters(),
        lr=experiment["lr"],
    )

    best_val_acc = 0
    best_epoch = None
    for epoch in range(experiment["num_epochs"]):
        loss = train_step(model=model, optimizer=optimizer, data=data)
        model_out, (train_acc, val_acc) = test(model=model, data=data)

        # Always record the first epoch: with only a strict `>` against the initial 0, a run
        # whose first val_acc was exactly 0 left best_epoch as None and the progress-bar
        # update below crashed on best_epoch['epoch'].
        if best_epoch is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = {
                "epoch": epoch,
                "loss": loss,
                "train_acc": train_acc,
                "val_acc": val_acc,
            }

        pbar.set_description(
            f"Epoch: {best_epoch['epoch']:04d}, Loss: {best_epoch['loss']:.4f} Train: {best_epoch['train_acc']:.4f}, "
            f"Val: {best_epoch['val_acc']:.4f}"
        )

    model.cpu()

best_model = model

In [None]:
# Re-embed every node with the trained encoder only: only_absolute=True returns the output
# of conv_fc and skips the relative projection and the class head. The resulting latents
# are detached and reused as frozen inputs by the anchor-analysis sweep below.
# Use DEVICE (not a hard-coded .cuda()) so the cell also runs on CPU-only machines, and
# match the device that `data` was moved to.
best_model.to(DEVICE)
best_model.eval()
with torch.no_grad():
    best_absolute = best_model(
        data.x, edge_index=data.edge_index, edge_weight=data.edge_weight, anchor_idxs=None, only_absolute=True
    ).detach()
best_absolute

In [None]:
# General SWEEP
# Only `seed` (3 values) and `num_anchors` (1..49) vary. The model-related entries
# (hidden_fn, conv_fn, encoder, ...) are placeholders: the frozen-encoder loop below reads
# only seed, num_anchors, optimizer, lr and num_epochs.
sweep = dict(
    seed=[0, 1, 2],
    num_epochs=[50],
    in_channels=[128],
    out_channels=[128],
    num_layers=[32],
    dropout=[0.5],
    hidden_fn=[None],
    conv_fn=[None],
    optimizer=[torch.optim.Adam],
    lr=[0.02],
    encoder=[None],
    num_anchors=list(range(1, 50)),
)


# Wrap the sweep in a ParameterGrid (the Cartesian product of all value lists) and report
# how many configurations it yields.
from sklearn.model_selection import ParameterGrid

experiments = ParameterGrid(sweep)
f"Total available experiments={len(experiments)}"

In [None]:
from typing import *
import numpy as np
from sklearn.model_selection import ParameterSampler, ParameterGrid
import logging
from tqdm import tqdm
from pprint import pprint
from rae.utils.utils import to_device

# Column-oriented log: one fresh list per column, one row appended per (experiment, epoch).
stats = {
    "experiment": [],
    "epoch": [],
    "loss": [],
    "train_acc": [],
    "val_acc": [],
    "num_anchors": [],
}


class ModelHead(nn.Module):
    """Classifier trained on top of frozen absolute latents.

    ``forward`` turns the latents into relative representations with the module-level
    ``relative_proj``, standardizes each of the ``num_anchors`` relative features across
    nodes, and maps them linearly to ``num_classes`` logits.
    """

    def __init__(self, num_anchors, num_classes):
        super().__init__()
        # InstanceNorm1d sees the transposed (anchors, nodes) matrix, so every relative
        # feature is normalized across nodes; the permutes transpose there and back.
        # (An extra Linear(num_anchors, 300) + Tanh + InstanceNorm stage was tried and disabled.)
        self.model = nn.Sequential(
            Lambda(lambda t: t.permute(1, 0)),
            nn.InstanceNorm1d(num_features=num_anchors),
            Lambda(lambda t: t.permute(1, 0)),
            nn.Linear(in_features=num_anchors, out_features=num_classes),
        )

    def forward(self, x, anchors):
        """Return ``{Output.LOGITS: logits}`` for latents ``x`` relative to ``anchors``."""
        relative = relative_proj(x=x, anchors=anchors)[AttentionOutput.OUTPUT]
        logits = self.model(relative)
        return {Output.LOGITS: logits}


def train_step(model, optimizer, absolute_latents, anchors_idxs):
    """Run one optimization step of the head on frozen latents and return the loss as a float.

    The anchors are the rows ``absolute_latents[anchors_idxs]``. Labels and the training
    mask are read from the notebook-level ``data`` object.
    """
    model.train()
    optimizer.zero_grad()
    anchor_latents = absolute_latents[anchors_idxs]
    logits = model(absolute_latents, anchor_latents)[Output.LOGITS]
    train_mask = data.train_mask
    loss = F.cross_entropy(logits[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test(model, absolute_latents, anchors_idxs):
    """Evaluate the head on frozen latents without tracking gradients.

    Returns:
        The raw model output dict and a list with the accuracies on the nodes selected by
        the notebook-level ``data.train_mask`` and ``data.val_mask``.
    """
    model.eval()
    out = model(absolute_latents, absolute_latents[anchors_idxs])
    predicted = out[Output.LOGITS].argmax(dim=-1)
    accs = [
        int((predicted[mask] == data.y[mask]).sum()) / int(mask.sum())
        for _, mask in data("train_mask", "val_mask")
    ]
    return out, accs


# Alternative: random search with ParameterSampler(sweep, n_iter=50, random_state=42).
for i, experiment in enumerate(pbar := tqdm(ParameterGrid(sweep), desc="Experiment")):
    # Seed all RNGs for this run while temporarily silencing seed_everything's logger.
    seed: int = experiment["seed"]
    temp_log_level = seed_log.getEffectiveLevel()
    seed_log.setLevel(logging.ERROR)
    seed_everything(seed)
    seed_log.setLevel(temp_log_level)

    # Sample the anchors among the training nodes (after seeding, so the draw is reproducible).
    # flatten() keeps a 1-D index list even when only one training node exists.
    num_anchors: int = experiment["num_anchors"]
    anchors_idxs = torch.as_tensor(random.sample(data.train_mask.nonzero().flatten().cpu().tolist(), num_anchors))

    # Only the head is optimized; best_absolute (detached encoder output) stays frozen.
    model = ModelHead(num_anchors=num_anchors, num_classes=dataset.num_classes)
    model.to(DEVICE)

    data = data.to(DEVICE)
    optimizer = experiment["optimizer"](
        model.parameters(),
        lr=experiment["lr"],
    )

    best_val_acc = 0
    best_epoch = None
    for epoch in range(experiment["num_epochs"]):
        loss = train_step(model=model, optimizer=optimizer, absolute_latents=best_absolute, anchors_idxs=anchors_idxs)
        model_out, (train_acc, val_acc) = test(model=model, absolute_latents=best_absolute, anchors_idxs=anchors_idxs)

        # Log every epoch of every experiment.
        stats["experiment"].append(i)
        stats["epoch"].append(epoch)
        stats["loss"].append(loss)
        stats["train_acc"].append(train_acc)
        stats["val_acc"].append(val_acc)
        stats["num_anchors"].append(num_anchors)

        # Always record the first epoch: with only a strict `>` against the initial 0, a run
        # whose val_acc stayed at exactly 0 left best_epoch as None and the progress-bar
        # update below crashed on best_epoch['epoch'].
        if best_epoch is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = {
                "epoch": epoch,
                "loss": loss,
                "train_acc": train_acc,
                "val_acc": val_acc,
            }

    pbar.set_description(
        f"Epoch: {best_epoch['epoch']:04d}, Loss: {best_epoch['loss']:.4f} Train: {best_epoch['train_acc']:.4f}, "
        f"Val: {best_epoch['val_acc']:.4f}"
    )

    model.cpu()

# One row per (experiment, epoch), saved as a tab-separated file.
stats = pd.DataFrame(stats)
stats_path = (
    PROJECT_ROOT
    / "experiments"
    / "sec:anchor-analysis"
    / f"{dataset_name}_data_manifold_stats_anchors_analysis_frozen_encoder.tsv"
)
# to_csv does not create missing directories; make sure the target folder exists first.
stats_path.parent.mkdir(parents=True, exist_ok=True)
stats.to_csv(stats_path, sep="\t")

In [None]:
import plotly.express as px

# Per-experiment maximum of every logged column. For val_acc this is the best validation
# accuracy reached over the epochs; num_anchors is constant within an experiment.
# NOTE: each column is maximized independently, so e.g. `loss` is not the loss at the best epoch.
# `.max()` replaces `.agg([np.max]).droplevel(...)`, which pandas deprecates (numpy callables in agg).
best_step = stats.groupby("experiment").max()
px.scatter(
    best_step,
    x="num_anchors",
    y="val_acc",
    title=f"{dataset_name}: best validation accuracy vs. number of anchors (frozen encoder)",
    labels={"num_anchors": "Number of anchors", "val_acc": "Best validation accuracy"},
)

In [None]:
# Write the per-epoch statistics (tab-separated) to the same file as the sweep cell above.
stats.to_csv(
    path_or_buf=PROJECT_ROOT
    / "experiments"
    / "sec:anchor-analysis"
    / f"{dataset_name}_data_manifold_stats_anchors_analysis_frozen_encoder.tsv",
    sep="\t",
)