# Embedding Analysis Notebook

This notebook provides a template for analyzing embeddings (for example with the Embedding Projector). It uses helper functions from `embedding_utils.py` to load and cache datasets, embeddings, and labels, and visualizes the embeddings as a 2-D PCA scatter plot.

## How to Use This Notebook

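1. Run the first code cell to import the required libraries and the helper functions from `embedding_utils.py`.
2. Edit the variables in the "Set Up Variables" section to choose the dataset, embedding model, seeds, sample sizes, and cache directory.
3. Run the "Load Datasets and Embeddings" section to load the node table and build the datasets.
4. Optionally adjust `node_to_string` and `text_fields` in the "Define Text Representation and Label Generation Functions" section, then run the "Process Models" section to generate (or load cached) embeddings.
5. Run the "Visualize Embeddings" section to plot a 2-D projection of the embeddings.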

Note: Caching of datasets and embeddings is handled by the functions imported from `embedding_utils.py`. To use a different dataset or change how embeddings are produced, edit the variables in "Set Up Variables" and the text-representation functions.

In [None]:
import os
import sys
import logging
import pandas as pd
import openai
import subprocess
from pathlib import Path
from kedro.framework.project import configure_project
from kedro.framework.session import KedroSession

utils_path = os.path.abspath('/home/wadmin/embed_norm/apps/embed_norm/embed_norm_test')
if utils_path not in sys.path:
    sys.path.append(utils_path)

root_path = subprocess.check_output(['git', 'rev-parse', '--show-toplevel']).decode().strip()
os.chdir(Path(root_path) / 'pipelines' / 'matrix')

%load_ext autoreload
%autoreload 2

from embedding_utils import (
    process_model,
    embedding_models_info,
    parse_list_string,
    load_datasets,
    load_embeddings_and_labels,
    missing_data_rows_dict,
    generate_candidate_pairs,
    refine_candidate_mappings_with_llm,
    find_additional_mappings_with_curategpt
)

# Emit INFO-and-above records prefixed with a timestamp and level name.
# force=True: basicConfig is a silent no-op when the root logger already has
# handlers, which can be the case after importing kedro.framework.project.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(message)s',
    force=True,
)

## Set Up Variables

Specify the cache directory, random seeds, dataset names, embedding model, and sample sizes.

In [None]:
# Cached artefacts live under <repo>/apps/embed_norm/cached_datasets, with one
# sub-directory per artefact type. makedirs also creates cache_dir itself.
cache_dir = os.path.join(root_path, 'apps', 'embed_norm', 'cached_datasets')
for subdir in ['categories', 'embeddings', 'datasets']:
    os.makedirs(os.path.join(cache_dir, subdir), exist_ok=True)

# Seeds passed to load_datasets (seed1/seed2) and process_model below.
pos_seed = 54321
neg_seed = 67890

dataset_name = 'rtx_kg2.int'
nodes_dataset_name = 'integration.int.rtx.nodes'
edges_dataset_name = 'integration.int.rtx.edges'

# Placeholder: the category list is replaced by the value load_datasets returns.
categories = ['All Categories']

model_name = 'OpenAI'
model_info = embedding_models_info[model_name]

openai.api_key = os.getenv('OPENAI_API_KEY')
if not openai.api_key:
    # NOTE(review): OpenAI embedding requests are expected to fail without a key.
    logging.warning('OPENAI_API_KEY is not set; OpenAI embedding requests will likely fail.')

total_sample_size = 1000
positive_ratio = 0.2
if not 0 <= positive_ratio <= 1:
    raise ValueError(f"positive_ratio must be between 0 and 1, got {positive_ratio}")

# Split the sample into positive/negative counts; the suffix keys cached files
# to this particular split.
positive_n = int(total_sample_size * positive_ratio)
negative_n = total_sample_size - positive_n
cache_suffix = f"_pos_{positive_n}_neg_{negative_n}"

## Load Datasets and Embeddings

Load the node table from the Kedro catalog and build the datasets with `load_datasets`, which handles caching automatically. Embeddings are generated in the "Process Models" section.

In [None]:
# Register the 'matrix' Kedro project before opening a session on it.
configure_project('matrix')

# Pull the node table out of the Kedro data catalog.
with KedroSession.create() as session:
    context = session.load_context()
    catalog = context.catalog
    nodes_df = catalog.load(nodes_dataset_name)

# Build the category list plus the positive and full datasets from the nodes,
# seeded with pos_seed / neg_seed and cached under cache_dir/datasets.
datasets_cache_dir = os.path.join(cache_dir, 'datasets')
categories, positive_datasets, datasets = load_datasets(
    nodes_df=nodes_df,
    dataset_name=dataset_name,
    cache_dir=datasets_cache_dir,
    total_sample_size=total_sample_size,
    positive_ratio=positive_ratio,
    seed1=pos_seed,
    seed2=neg_seed,
)

## Define Text Representation and Label Generation Functions

Customize the text that is embedded for each node by editing `node_to_string` and the `text_fields` list below; `label_func` shows how a custom label could be generated.

In [None]:
# def node_to_string(row, text_fields=None):
#     if text_fields is None:
#         text_fields = ['all_names:string[]', 'all_categories:string[]']
#     global missing_data_rows_dict
#     fields = [row.get(field, '') for field in text_fields]
#     missing_fields = [field for field, value in zip(text_fields, fields)
#                       if pd.isnull(value) or not str(value).strip()]
#     for missing_field in missing_fields:
#         if missing_field not in missing_data_rows_dict:
#             missing_data_rows_dict[missing_field] = []
#         missing_data_rows_dict[missing_field].append(row)
#     text_values = []
#     for field_value in fields:
#         parsed_list = parse_list_string(field_value)
#         text_values.extend(parsed_list)
#     text_representation = ' '.join(text_values).strip()
#     if not text_representation:
#         logging.warning(f"Empty text representation for row with index {row.name}")
#     return text_representation


def node_to_string(row, text_fields):
    """Build the text that is embedded for a single node.

    Args:
        row: A node record supporting ``.get(field, default)`` (e.g. a pandas
            Series taken from ``nodes_df``).
        text_fields: Field names whose values are concatenated, in order.

    Returns:
        The collected text values joined by single spaces, with surrounding
        whitespace stripped. Null fields are skipped; absent fields are read
        as ``''``.
    """
    text_values = []
    for field in text_fields:
        field_value = row.get(field, '')
        if pd.api.types.is_list_like(field_value):
            # List/array-valued cells: pd.notnull() on them is element-wise and
            # its truth value is ambiguous, so take the non-null items directly.
            text_values.extend(str(item) for item in field_value if pd.notnull(item))
        elif pd.notnull(field_value):
            text_values.extend(parse_list_string(field_value))
    return ' '.join(text_values).strip()

def label_func(row):
    """Return a node's display label: its id, its name, then a fixed tag."""
    node_id, node_name = row['id'], row['name']
    return '{}, {}, custom label'.format(node_id, node_name)

# Fields concatenated, in this order, by node_to_string to form each node's text.
text_fields = ['name', 'description', 'category', 'labels', 'all_categories', 'equivalent_identifiers']

# node_to_string reads absent fields as '' without complaint, so surface any
# configured field that nodes_df does not actually have.
missing_text_fields = [field for field in text_fields if field not in nodes_df.columns]
if missing_text_fields:
    logging.warning('text_fields not present in nodes_df: %s', missing_text_fields)

print(nodes_df.columns)

## Process Models

Generate and cache embeddings for your datasets using the `process_model` function.

### Generate or Load Embeddings for Datasets

In [None]:
# Arguments common to both process_model calls; only the datasets and the seed
# differ between the two runs.
embedding_cache_dir = os.path.join(cache_dir, 'embeddings')
shared_process_kwargs = dict(
    model_info=model_info,
    cache_dir=embedding_cache_dir,
    text_fields=text_fields,
    text_representation_func=node_to_string,
    # label_generation_func=label_func,
    dataset_name=dataset_name,
    use_ontogpt=False,
    cache_suffix=cache_suffix,
)

# Embeddings for the full datasets, seeded with neg_seed.
model_name, embeddings_dict = process_model(
    model_name=model_name,
    datasets=datasets,
    seed=neg_seed,
    **shared_process_kwargs,
)

# Embeddings for the positive datasets, seeded with pos_seed; model_name is the
# value returned by the previous call.
model_name, embeddings_dict_pos = process_model(
    model_name=model_name,
    datasets=positive_datasets,
    seed=pos_seed,
    **shared_process_kwargs,
)

### Visualize Embeddings

In [None]:
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

category = 'All Categories'
embeddings = embeddings_dict[category]
# TODO: colour points by label once a labels dict is produced for this category.

# PCA's default svd_solver='auto' may switch to a randomized SVD depending on the
# data size, so pin random_state to keep the projection identical across re-runs.
reduced_embeddings = PCA(n_components=2, random_state=0).fit_transform(embeddings)

fig, ax = plt.subplots(figsize=(10, 10))
ax.scatter(reduced_embeddings[:, 0], reduced_embeddings[:, 1], alpha=0.5)
ax.set(
    title=f'Embeddings Visualization for {category}',
    xlabel='Component 1',
    ylabel='Component 2',
)
plt.show()