In [None]:
# Reload imported modules automatically before each cell executes, so edits to
# local helper modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Set before any later import can load TensorFlow. Level "3" presumably silences
# TensorFlow's native (C++) log output — confirm against the TensorFlow docs.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

In [None]:
import logging

import matplotlib.pyplot as plt
import torch
from pykeen.datasets import Nations
from pykeen.pipeline import pipeline
from statistics import mean, median
import pandas as pd
import numpy as np
import seaborn as sns
import networkx as nx
from pyvis.network import Network
from ipywidgets import interact
import ipywidgets as widgets
from itertools import chain
from pathlib import Path
from rich import print

from utils import prepare_for_visualization

## Settings

In [None]:
import warnings


def _format_float(value):
    """Render a float with exactly five decimal places for DataFrame display."""
    return "%.5f" % value


# Keep cell outputs readable: drop every Python warning for this session.
warnings.simplefilter("ignore")
warnings.filterwarnings("ignore")

# Show floats in DataFrames with a fixed five-decimal format.
pd.set_option("display.float_format", _format_float)

# Only let CRITICAL messages from the "pykeen" logger through.
pykeen_logger = logging.getLogger("pykeen")
pykeen_logger.setLevel(logging.CRITICAL)

## Loading data

In [None]:
dataset = Nations()

# Group the relation labels by ordered (head, tail) pair first. A DiGraph holds
# at most one edge per pair, so adding the triples one at a time would overwrite
# the edge attributes and leave only the last relation seen in "title".
relations_by_pair = {}
for h, r, t in dataset.training.triples:
    relations_by_pair.setdefault((h, t), []).append(r)

# One directed edge per pair; "title" lists every relation linking head -> tail.
# Pairs are inserted in order of first appearance, so node order is unchanged.
g = nx.DiGraph()
g.add_edges_from(
    (h, t, {"title": ", ".join(relations)})
    for (h, t), relations in relations_by_pair.items()
)

## Visualization

In [None]:
# Render the whole training graph as an interactive pyvis network.
# cdn_resources="in_line" matches the other Network built in this notebook and
# keeps the generated HTML self-contained for display inside the notebook.
nt = Network(
    "500px", "500px", directed=True, notebook=True, cdn_resources="in_line"
)
nt.inherit_edge_colors(False)
nt.from_nx(g)
nt.toggle_physics(False)  # static layout
nt.show("basic.html")

In [None]:
_list_nodes = list(g.nodes)


@interact
def visualize(
    nodes=widgets.SelectMultiple(options=_list_nodes, value=[_list_nodes[0]]),
    k=[0, 1, 2, 3],
    toggle_physics=False,
):
    """Display the neighbourhood of the selected nodes as an interactive graph.

    Args:
        nodes: Selected start nodes (labels of nodes in ``g``).
        k: Maximum shortest-path length from any start node; nodes further
            away than ``k`` are left out.
        toggle_physics: Whether the pyvis physics simulation is enabled.
    """
    # Union of all nodes within distance k of at least one selected node.
    reachable = set(
        chain.from_iterable(
            nx.single_source_shortest_path_length(g, source, cutoff=k)
            for source in nodes
        )
    )
    subgraph = nx.subgraph_view(g, filter_node=lambda node: node in reachable)

    nt = Network(
        "500px", "500px", directed=True, notebook=True, cdn_resources="in_line"
    )
    nt.inherit_edge_colors(False)
    nt.from_nx(subgraph)
    nt.toggle_physics(toggle_physics)
    # Write to a dedicated file so the full-graph rendering stored in
    # "basic.html" is not overwritten by each widget interaction.
    display(nt.show("subgraph.html"))

## EDA

In [None]:
# Summarise each dataset split: triple count plus the number of distinct
# entities (head and tail columns) and relations (middle column) it contains.
data = []
for subset_name in ("training", "validation", "testing"):
    # getattr is the idiomatic way to look up an attribute by name.
    triples = getattr(dataset, subset_name).triples
    data.append(
        {
            "subset": subset_name,
            "num_triples": len(triples),
            "num_entities": len(np.unique(triples[:, [0, 2]])),
            "num_relations": len(np.unique(triples[:, 1])),
        }
    )

pd.DataFrame(data)

In [None]:
# Structural statistics of the training graph.
# The undirected copy and its component sizes are computed once and reused,
# instead of copying the graph and enumerating components for every metric.
undirected_g = g.to_undirected()
component_sizes = [len(c) for c in nx.connected_components(undirected_g)]

metrics = {
    "n_connected_components": len(component_sizes),
    "mean_size_of_connected_components": mean(component_sizes),
    "median_size_of_connected_components": median(component_sizes),
    "density": nx.density(g),
    "number_of_selfloops": nx.number_of_selfloops(g),
    "average_clustering": nx.average_clustering(g),
}

pd.DataFrame({"training": metrics})

In [None]:
# Degree of every node (left) and the distribution of those degrees (right).
# The degree view is unpacked once and reused for both panels.
node_labels, node_degrees = zip(*g.degree())

fig, axes = plt.subplots(1, 2, figsize=(10, 5))

axes[0].barh(y=node_labels, width=node_degrees)
axes[0].set(title="Degree of nodes", xlabel="Degree", ylabel="Node")

sns.histplot(list(node_degrees), ax=axes[1])
axes[1].set(title="Degree Histogram", xlabel="Degree", ylabel="Number of nodes")

plt.tight_layout()

## Training

In [None]:
# Train a TransE model on Nations with a fixed seed so the run is repeatable.
# NOTE(review): the hyperparameter values look tuned; their origin is not
# recorded in this notebook.
result = pipeline(
    random_seed=123,
    dataset="Nations",
    # Model
    model="TransE",
    model_kwargs={"embedding_dim": 64},
    # Loss and negative sampling
    loss="nssa",
    loss_kwargs={"adversarial_temperature": 0.34, "margin": 9},
    negative_sampler_kwargs={"num_negs_per_pos": 33},
    # Optimisation
    optimizer="Adam",
    optimizer_kwargs={"lr": 0.004},
    training_kwargs={
        "num_epochs": 50,
        "batch_size": 512,
        "use_tqdm_batch": False,
    },
)

In [None]:
# Persist the pipeline result to disk and report which files were written.
save_location = Path("results/nations")
save_location.mkdir(parents=True, exist_ok=True)
result.save_to_directory(save_location)

saved_files = os.listdir(save_location)
print(f"Saved: {saved_files}")

## Metrics

In [None]:
# Plot the loss values recorded by the pipeline run.
result.plot_losses()
plt.show()

In [None]:
# Convert the pipeline's metric results into a DataFrame for filtering.
metrics = result.metric_results.to_df()

In [None]:
# Keep only the rows reported for Side "both" with Type "realistic".
is_both_sides = metrics["Side"] == "both"
is_realistic = metrics["Type"] == "realistic"
metrics[is_both_sides & is_realistic]

## Embeddings visualization

In [None]:
# Keep a handle on the trained model and show its repr as the cell output.
model = result.model
model

In [None]:
# Look up the learned representation of every entity, ordered by entity ID.
# The indices are created on the model's own device because the trained model
# may not live on the CPU; mixing devices would raise an error.
entity_ids = torch.arange(
    dataset.num_entities,
    device=next(result.model.parameters()).device,
)
with torch.no_grad():  # inference only, no autograd graph needed
    # Move the result to the CPU so it can be converted to NumPy later on.
    embeddings = result.model.entity_representations[0](entity_ids).cpu()

# Human-readable label for each entity ID, in the same order as `embeddings`.
labels = [dataset.training.entity_id_to_label[i] for i in range(dataset.num_entities)]

In [None]:
# Hand the embeddings and labels to the visualisation helper for "logs/nations".
# .cpu() is required before .numpy() whenever the tensor lives on an accelerator;
# it is a no-op for tensors that are already on the CPU.
prepare_for_visualization(
    embeddings.detach().cpu().numpy(), labels, Path("logs/nations")
)

In [None]:
!tensorboard --logdir=logs/nations