# Embedding graphs in hyperbolic spaces

## Imports

In [8]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from bokeh.io import output_notebook
from bokeh.models import HoverTool, LabelSet
from bokeh.plotting import figure, show, ColumnDataSource
from gensim.models.poincare import PoincareModel

## Utils

`GeometryUtils.poincare_disk_to_halfspace` maps points from the [Poincaré disk model](https://en.wikipedia.org/wiki/Poincar%C3%A9_disk_model) to the [Poincaré half-plane model](https://en.wikipedia.org/wiki/Poincar%C3%A9_half-plane_model). The map is an isometry, so hyperbolic distances are preserved; I wanted to check whether the visualization works better in the half-plane.

In [24]:
class RoamUtils:
    """Helpers for turning exported org-roam link records into graph edges."""

    @staticmethod
    def get_roam_relations(org_roam_records_df):
        """
        Build the edge list of the org-roam graph.

        Parameters
        ----------
        org_roam_records_df : pandas.DataFrame
            Must contain the columns ``source_name`` and ``destination_name``;
            any other columns are ignored.

        Returns
        -------
        list of tuple
            One ``(source_name, destination_name)`` pair per row, in row order.
        """
        edge_columns = org_roam_records_df[["source_name", "destination_name"]]
        # index=False / name=None yields plain (source, destination) tuples directly
        return list(edge_columns.itertuples(index=False, name=None))

class GeometryUtils:
    """Coordinate changes between 2-D models of the hyperbolic plane."""

    @staticmethod
    def poincare_disk_to_halfspace(poincare_vectors):
        """
        Map points of the Poincare disk to the upper half-plane model.

        Each row ``(x, y)`` is read as the complex number ``z = x + iy`` and sent
        through the Moebius transformation ``w = -i (z - i) / (z + i)``.
        Under this map the disk centre ``z = 0`` goes to ``w = i``, the boundary
        point ``z = i`` goes to ``w = 0`` and the boundary point ``z = -i`` goes
        to infinity.

        Parameters
        ----------
        poincare_vectors : array-like of shape (n, 2)
            Disk coordinates; only the first two columns are read.

        Returns
        -------
        numpy.ndarray of shape (n, 2)
            Columns are the real and imaginary parts of ``w``.
        """
        # Accept plain lists as well as arrays.
        poincare_vectors = np.asarray(poincare_vectors)
        z_poincare = poincare_vectors[:, 0] + 1j * poincare_vectors[:, 1]
        # z = -i is the pole of the map; that single boundary point has no finite image.
        w_half_space = -1j * (z_poincare - 1j) / (z_poincare + 1j)
        return np.column_stack([np.real(w_half_space), np.imag(w_half_space)])


# Backward-compatible alias for the original (misspelled) class name.
GometryUtils = GeometryUtils


def plot_model_annotated_scatterplot(kv_model, shown_text_indices=None):
    """
    Scatter-plot the vectors of a keyed-vectors object, using its keys as hover text.

    Parameters
    ----------
    kv_model : object exposing ``vectors`` and ``index_to_key``
        For example ``model.kv`` of a trained ``PoincareModel``.
    shown_text_indices : sequence of int, optional
        Row indices whose key is drawn as a permanent label next to the point.
    """
    embedding = kv_model.vectors
    point_names = kv_model.index_to_key
    text_annotated_scatterplot(embedding, point_names, shown_text_indices)


def text_annotated_scatterplot(data, text, shown_text_indices=None, **kwargs):
    """
    Display 2-D points as an interactive Bokeh scatterplot in the notebook.

    The i-th point shows ``text[i]`` (together with its index and cursor
    coordinates) on hover.

    Parameters
    ----------
    data : array of shape (n, >=2)
        Point coordinates; column 0 is plotted as x and column 1 as y.
    text : sequence of str
        One description per row of ``data``.
    shown_text_indices : iterable of int, optional
        Rows that additionally get a permanent text label next to the point.
        ``None`` (the default) draws no labels.
    **kwargs
        Extra keyword arguments forwarded to ``bokeh.plotting.figure``. They
        override the defaults used here (``width``, ``height``, ``tools``,
        ``toolbar_location``).
    """
    # Route Bokeh output to the notebook rather than to an HTML file
    output_notebook()

    # Column data shared by the scatter glyph and the hover tool
    source = ColumnDataSource(data=dict(
        x=data[:, 0],
        y=data[:, 1],
        desc=text,
    ))

    # Defaults reproduce the original hard-coded figure; callers may override via kwargs
    figure_options = dict(
        width=1280,
        height=1280,
        tools="pan,wheel_zoom,xbox_select,reset",
        toolbar_location=None,
    )
    figure_options.update(kwargs)
    p = figure(**figure_options)
    p.scatter('x', 'y', size=5, source=source)

    # Hover tool: "@desc" reads the `desc` column of the data source
    hover = HoverTool()
    hover.tooltips = [
        ("index", "$index"),
        ("(x,y)", "($x, $y)"),
        ("desc", "@desc"),
    ]
    p.add_tools(hover)

    # Permanent labels for the requested subset of points
    if shown_text_indices is not None:
        # Materialise once so any iterable works for both the array index and the text lookup
        label_indices = list(shown_text_indices)
        label_source = ColumnDataSource(data=dict(
            x=data[label_indices, 0],
            y=data[label_indices, 1],
            desc=[text[i] for i in label_indices],
        ))
        labels = LabelSet(x='x', y='y', text='desc', level='glyph',
                          x_offset=5, y_offset=5,
                          source=label_source)
        p.add_layout(labels)

    show(p)

In [17]:
# Load the exported org-roam records from a dated parquet snapshot (path is relative to the notebook)
org_roam_records_df = pd.read_parquet("../data/org_roam_records_2023_09_10.parquet")

The graph is exported from org-roam files; see `Org_roam_graph.livemd` for the export.

In [18]:
# Summarise columns, dtypes and non-null counts of the loaded records
org_roam_records_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 621 entries, 0 to 620
Data columns (total 8 columns):
 #   Column                 Non-Null Count  Dtype 
---  ------                 --------------  ----- 
 0   destination_date       621 non-null    object
 1   destination_name       621 non-null    object
 2   destination_path       621 non-null    object
 3   destination_root_path  621 non-null    object
 4   source_date            621 non-null    object
 5   source_name            621 non-null    object
 6   source_path            621 non-null    object
 7   source_root_path       621 non-null    object
dtypes: object(8)
memory usage: 38.9+ KB


## Poincare embeddings from gensim

These are the easiest to set up, so we will start with them.

In [19]:
# get_roam_relations is a static method of RoamUtils; the unqualified name
# is not defined at module level and raises NameError on a fresh kernel.
roam_relations = RoamUtils.get_roam_relations(org_roam_records_df)
len(roam_relations)

621

In [20]:
# size=2 keeps the embedding two-dimensional, so the vectors can be passed straight
# to the scatterplot helpers (which plot columns 0 and 1).
# negative / alpha: negative samples per relation and learning rate — see the gensim PoincareModel docs.
model = PoincareModel(roam_relations, negative=4, size=2, alpha=0.01)
model.train(epochs=50)

In [21]:
# Mapping from node name (org file name) to its integer index in the model
model.kv.key_to_index # vocabulary

{'substructural_logic.org': 0,
 'logic.org': 1,
 'curry_howard.org': 2,
 'rust_types.org': 3,
 'c_algebras.org': 4,
 'renormalization.org': 5,
 'math.org': 6,
 'ml_general.org': 7,
 'general_nlp.org': 8,
 'ml_articles.org': 9,
 'metrics_mrr.org': 10,
 'retrieval_nlp.org': 11,
 'mgr_lcsh.org': 12,
 'mgr.org': 13,
 'readlog_logic.org': 14,
 'wittgenstein.org': 15,
 'logic_intensionality.org': 16,
 'modal_logic.org': 17,
 'epistemic_logic.org': 18,
 'elixir_metaprogramming.org': 19,
 'elixir.org': 20,
 'org_ai_finance.org': 21,
 'options_data_api.org': 22,
 'multiple_instance_learning.org': 23,
 'llms_interesting.org': 24,
 'llms_rwkv.org': 25,
 'llms_llama.org': 26,
 'llms_news.org': 27,
 'redpajama.org': 28,
 'ggml.org': 29,
 'llms.org': 30,
 'nlp_piqa.org': 31,
 'llms_evaluation.org': 32,
 'apteki_intents.org': 33,
 'apteki_intents_training_data.org': 34,
 'pl_en_translation.org': 35,
 'polish_nlp.org': 36,
 'transformers_onnx.org': 37,
 'zenml.org': 38,
 'mlops.org': 39,
 'zettelkaste

In [22]:
# NOTE(review): bare name only displays the function's repr/signature — leftover inspection cell
text_annotated_scatterplot

<function __main__.text_annotated_scatterplot(data, text, shown_text_indices=None, **kwargs)>

In [23]:
# Sanity check: the number of keys should equal the number of embedding vectors
len(model.kv.index_to_key), len(model.kv.vectors)

(482, 482)

In [29]:
# Plot every embedded node; permanently label only indices 0, 20, ..., 380 to limit clutter
plot_model_annotated_scatterplot(model.kv, list(range(0, 400, 20)))