# Embedding Analysis Notebook

This notebook demonstrates how to use the updated functions from `embedding_utils.py` to load and cache datasets, embeddings, and labels; generate embeddings with different models; and visualize them in the dashboard app.

## How to Use This Notebook

1. **Import Modules and Functions**: Import necessary functions from `embedding_utils.py`, including the new functions for loading and caching embeddings and labels.
2. **Set Variables**: Specify variables such as dataset names, seeds, and categories.
3. **Load Datasets and Embeddings**: Use the provided functions to load datasets and embeddings, handling caching automatically.
4. **Define Custom Functions**: Define custom text representation and label generation functions as needed.
5. **Process Models**: Generate embeddings or load them from cache using the provided functions.
6. **Run the Dashboard App**: Launch the dashboard app to visualize your embeddings.

Note: Caching of datasets, embeddings, and labels is handled by the functions imported from `embedding_utils.py`. To use a different dataset, category set, or embedding model, change the variables in the **Set Up Variables** section.
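The caching behind these helpers amounts to "load from disk if a cached result exists, otherwise compute it and save it". The sketch below is illustrative only. `load_or_compute` and the pickle file format are assumptions, not the actual implementation in `embedding_utils.py`.

```python
import os
import pickle
import tempfile


def load_or_compute(cache_path, compute_fn):
    """Hypothetical helper: return the cached object at cache_path, or compute and cache it."""
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    result = compute_fn()
    os.makedirs(os.path.dirname(cache_path) or '.', exist_ok=True)
    with open(cache_path, 'wb') as f:
        pickle.dump(result, f)
    return result


# Usage: the second call reads from disk instead of calling the compute function again
calls = []

def expensive():
    calls.append(1)
    return {'node_1': [0.1, 0.2]}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'cached_datasets', 'example.pkl')
    first = load_or_compute(path, expensive)
    second = load_or_compute(path, expensive)

print(len(calls))  # 1
```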

In [None]:
# Standard library
import logging
import os

# Third-party
import openai  # Only needed if you use OpenAI embeddings
import pandas as pd
from kedro.framework.project import configure_project

# Project utilities
from embed_norm_test.embedding_utils import (
    process_model,
    process_model_combinations,
    embedding_models_info,
    parse_list_string,
    load_datasets,
    load_categories,
    load_embeddings_and_labels,
    missing_data_rows_dict,
)

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s %(levelname)s %(message)s')

## Set Up Variables

Specify the cache directory, sampling seeds, dataset names, categories, and embedding model.
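The two seeds make sampling reproducible. The same seed selects the same rows on every run, so cached samples stay valid. The minimal pandas sketch below assumes that `load_datasets` samples rows with a fixed `random_state`. The toy `nodes` frame is hypothetical data, not the real nodes dataset.

```python
import pandas as pd

# Toy node table standing in for the real nodes dataset (hypothetical data)
nodes = pd.DataFrame({'id': [f'N{i}' for i in range(10)],
                      'category': ['biolink:Gene', 'biolink:Drug'] * 5})

seed1 = 54321  # seed for positive samples, as in the cell below
seed2 = 67890  # seed for negative samples

sample_a = nodes.sample(n=3, random_state=seed1)
sample_b = nodes.sample(n=3, random_state=seed1)  # same seed -> same rows
sample_c = nodes.sample(n=3, random_state=seed2)  # different seed -> usually different rows

print(sample_a['id'].tolist() == sample_b['id'].tolist())  # True
```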

In [None]:
# Set up variables
cache_dir = 'cached_datasets'
os.makedirs(cache_dir, exist_ok=True)

# Seeds for sampling
seed1 = 54321  # Seed for positive samples
seed2 = 67890  # Seed for negative samples

# Dataset configuration
dataset_name = 'rtx_kg2'  # Replace with your dataset name
nodes_dataset_name = 'ingestion.raw.rtx_kg2.nodes@pandas'  # Replace with your nodes dataset name
edges_dataset_name = 'ingestion.raw.rtx_kg2.edges@pandas'  # Replace with your edges dataset name

# Categories to process
categories = ['All Categories']  # Specify categories or use 'All Categories'

# Model configuration
model_name = 'OpenAI'
model_info = embedding_models_info[model_name]

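# OpenAI API key (only needed for OpenAI embeddings)
openai.api_key = os.getenv("OPENAI_API_KEY")
if model_info.get('type') == 'openai' and not openai.api_key:
    logging.warning("OPENAI_API_KEY is not set; OpenAI embedding requests will fail.")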

## Load Datasets and Embeddings

Use the provided functions to load datasets and embeddings, handling caching automatically.
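A cached result is reused only when every input that affects it matches. The hypothetical sketch below derives a cache file name from the kinds of parameters passed in the next cell. The naming scheme is an assumption; `embedding_utils.py` may use a different one.

```python
import hashlib
import os


def cache_path(cache_dir, dataset_name, model_name, category, seed):
    """Hypothetical: build a deterministic cache file path from the inputs."""
    key = f"{dataset_name}|{model_name}|{category}|{seed}"
    digest = hashlib.sha256(key.encode('utf-8')).hexdigest()[:16]
    return os.path.join(cache_dir, f"{dataset_name}_{model_name}_{digest}.pkl")


p1 = cache_path('cached_datasets', 'rtx_kg2', 'OpenAI', 'All Categories', 67890)
p2 = cache_path('cached_datasets', 'rtx_kg2', 'OpenAI', 'All Categories', 54321)
print(p1 != p2)  # True: changing the seed points at a different cache file
```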

In [None]:
# Load datasets using caching functions

# Configure Kedro project
configure_project('matrix')  # Replace 'matrix' with your Kedro project name if different

# Load categories (note: this overwrites the `categories` value set in the previous cell)
categories = load_categories(
    cache_dir=cache_dir,
    dataset_name=dataset_name,
    nodes_dataset_name=nodes_dataset_name
)

# Load datasets
positive_datasets, datasets = load_datasets(
    cache_dir=cache_dir,
    dataset_name=dataset_name,
    nodes_dataset_name=nodes_dataset_name,
    edges_dataset_name=edges_dataset_name,
    categories=categories,
    seed1=seed1,
    seed2=seed2
)

# Load embeddings and labels (if cached)
embeddings_dict, labels_dict = load_embeddings_and_labels(
    cache_dir=cache_dir,
    dataset_name=dataset_name,
    model_name=model_name,
    categories=categories,
    seed=seed2,
    combinations=False
)

# Check if embeddings are loaded; if not, they will be generated in the next steps
if not embeddings_dict:
    print("Embeddings not found in cache; they will be generated.")
else:
    print("Embeddings loaded from cache.")

## Define Custom Functions

Customize how each node row is converted into the text that gets embedded, and how its label is generated, by defining custom text representation and label generation functions.
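The self-contained sketch below shows what a text representation function produces, without loading the graph. `parse_list_string_sketch` is a hypothetical stand-in for `parse_list_string`. It assumes list fields arrive either as Python-style list strings such as `"['aspirin', 'ASA']"` or as plain strings.

```python
import ast


def parse_list_string_sketch(value):
    """Hypothetical stand-in for parse_list_string: turn a list-like string into a list of strings."""
    if value is None or value == '':
        return []
    if isinstance(value, list):
        return [str(v) for v in value]
    text = str(value).strip()
    if text.startswith('['):
        try:
            parsed = ast.literal_eval(text)
            if isinstance(parsed, (list, tuple)):
                return [str(v) for v in parsed]
        except (ValueError, SyntaxError):
            pass  # not a valid literal; fall through and keep the raw string
    return [text]


def node_to_string_sketch(row, text_fields):
    """Join the parsed values of text_fields into one space-separated string."""
    values = []
    for field in text_fields:
        values.extend(parse_list_string_sketch(row.get(field, '')))
    return ' '.join(values).strip()


row = {'all_names:string[]': "['aspirin', 'acetylsalicylic acid']",
       'all_categories:string[]': 'biolink:Drug'}
print(node_to_string_sketch(row, ['all_names:string[]', 'all_categories:string[]']))
# aspirin acetylsalicylic acid biolink:Drug
```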

In [None]:
# # Define custom text representation function
# def node_to_string(row):
#     fields = [row.get('name', ''), row.get('description', '')]
#     text_values = []
#     for field_value in fields:
#         if pd.notnull(field_value):
#             parsed_list = parse_list_string(field_value)
#             text_values.extend(parsed_list)
#     return ' '.join(text_values).strip()

# Define custom text representation function: joins the parsed values of
# `text_fields` and records rows with missing fields in `missing_data_rows_dict`
def node_to_string(row, text_fields=None):
    if text_fields is None:
        text_fields = ['all_names:string[]', 'all_categories:string[]']
    fields = [row.get(field, '') for field in text_fields]

    # Record rows whose field is null or blank, keyed by field name
    for field, value in zip(text_fields, fields):
        if pd.isnull(value) or not str(value).strip():
            missing_data_rows_dict.setdefault(field, []).append(row)

    text_values = []
    for field_value in fields:
        if pd.isnull(field_value):
            continue  # skip nulls instead of passing NaN to parse_list_string
        text_values.extend(parse_list_string(field_value))
    text_representation = ' '.join(text_values).strip()

    if not text_representation:
        logging.warning("Empty text representation for row with index %s", row.name)
    # Log per-row output at DEBUG level so large graphs do not flood the notebook
    logging.debug("Text representation for row with index %s: %s", row.name, text_representation)
    return text_representation

In [None]:
# Define custom label generation function
def label_func(row):
    # Alternative: an HTML label listing every field, each truncated to 200 characters
    # return '<br>'.join(f"{k}: {str(v)[:200]}" for k, v in row.items())
    return (row['id'], row['name'], 'wow this is a label')

## Process Models

Generate embeddings for your datasets with `process_model` (or with `process_model_combinations` in the optional section below). If embeddings are already cached, they are loaded; otherwise, they are generated and then cached.
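Conceptually, the two callbacks are applied to every row, the resulting texts are embedded, and the labels are kept alongside the vectors. The toy sketch below shows that flow, assuming row-wise application. `process_rows_sketch` and `embed_texts_sketch` are hypothetical. The "embedding" is a deterministic hash stand-in, not how OpenAI or Hugging Face models compute vectors.

```python
import hashlib


def embed_texts_sketch(texts, dim=4):
    """Hypothetical embedding: derive a fixed-length vector from each text's SHA-256 digest."""
    vectors = []
    for text in texts:
        digest = hashlib.sha256(text.encode('utf-8')).digest()
        vectors.append([b / 255.0 for b in digest[:dim]])
    return vectors


def process_rows_sketch(rows, text_representation_func, label_generation_func):
    """Apply the callbacks row by row, then embed all texts in one batch."""
    texts = [text_representation_func(row) for row in rows]
    labels = [label_generation_func(row) for row in rows]
    return embed_texts_sketch(texts), labels


rows = [{'id': 'N1', 'name': 'aspirin'}, {'id': 'N2', 'name': 'ibuprofen'}]
embeddings, labels = process_rows_sketch(
    rows,
    text_representation_func=lambda row: row['name'],
    label_generation_func=lambda row: (row['id'], row['name']),
)
print(len(embeddings), len(embeddings[0]), labels[0])  # 2 4 ('N1', 'aspirin')
```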

### Generate or Load Embeddings for Datasets

In [None]:
# Process embeddings for the datasets
model_name, embeddings_dict = process_model(
    model_name=model_name,
    model_info=model_info,
    datasets=datasets,
    cache_dir=cache_dir,
    seed=seed2,
    text_representation_func=node_to_string,
    label_generation_func=label_func,
    dataset_name=dataset_name
)

# Process embeddings for positive datasets (sampled with seed1)
model_name, embeddings_dict_pos = process_model(
    model_name=model_name,
    model_info=model_info,
    datasets=positive_datasets,
    cache_dir=cache_dir,
    seed=seed1,
    text_representation_func=node_to_string,
    label_generation_func=label_func,
    dataset_name=dataset_name
)


### Generate or Load Embeddings with Combinations (Optional)
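In combination mode, the text representation function returns several strings per row, one per (name, category) pair, as in the commented-out `node_to_strings` below. The pairing is a Cartesian product. Here is a minimal sketch with hypothetical input lists:

```python
from itertools import product

names = ['aspirin', 'acetylsalicylic acid']              # hypothetical parsed names
category_list = ['biolink:Drug', 'biolink:ChemicalEntity']  # hypothetical parsed categories

# One text per (name, category) pair: len(names) * len(category_list) strings
texts = [' '.join(pair).strip() for pair in product(names, category_list)]
print(texts)
# ['aspirin biolink:Drug', 'aspirin biolink:ChemicalEntity',
#  'acetylsalicylic acid biolink:Drug', 'acetylsalicylic acid biolink:ChemicalEntity']
```

The number of texts per row is the product of the list lengths. Combination mode can therefore multiply the number of embedding requests considerably.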

In [None]:
# # Define custom text representations for combinations
# def node_to_strings(row):
#     names_field = 'name'
#     categories_field = 'category'
#     names_field_value = row.get(names_field, '')
#     categories_field_value = row.get(categories_field, '')
    
#     names_list = parse_list_string(names_field_value)
#     categories_list = parse_list_string(categories_field_value)
    
#     from itertools import product
#     combinations = list(product(names_list, categories_list))
    
#     text_representations = [' '.join(combination).strip() for combination in combinations]
#     return text_representations

# # Generate embeddings with combinations
# model_name, embeddings_dict = process_model_combinations(
#     model_name=model_name,
#     model_info=model_info,
#     datasets=datasets,
#     cache_dir=cache_dir,
#     seed=seed2,
#     text_representation_func=node_to_strings,
#     label_generation_func=label_func,
#     dataset_name=dataset_name
# )

# # Process embeddings for positive datasets with combinations
# model_name, embeddings_dict_pos = process_model_combinations(
#     model_name=model_name,
#     model_info=model_info,
#     datasets=positive_datasets,
#     cache_dir=cache_dir,
#     seed=seed1,
#     text_representation_func=node_to_strings,
#     label_generation_func=label_func,
#     dataset_name=dataset_name
# )

### Process Positive Datasets (Optional)

The cells above already run the same functions on `positive_datasets` (sampled with `seed1`). Comment out those calls if you only need embeddings for `datasets`.

## Run the Dashboard App

After generating the embeddings, you can run the dashboard app to visualize them.
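Embedding vectors typically have hundreds or thousands of dimensions, so they must be projected to 2D before they can be plotted. This notebook does not say whether `app.py` uses PCA, t-SNE, UMAP, or another method. The NumPy sketch below shows PCA only as one common choice. The input is random hypothetical data.

```python
import numpy as np


def pca_2d(embeddings):
    """Project an (n_samples, n_dims) array onto its first two principal components."""
    X = np.asarray(embeddings, dtype=float)
    X_centered = X - X.mean(axis=0)
    # Rows of Vt are principal directions, ordered by decreasing singular value
    _, _, Vt = np.linalg.svd(X_centered, full_matrices=False)
    return X_centered @ Vt[:2].T


rng = np.random.default_rng(0)
points = pca_2d(rng.normal(size=(50, 1536)))  # 50 hypothetical 1536-dimensional embeddings
print(points.shape)  # (50, 2)
```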

### Instructions to Run the Dashboard App

1. Ensure that the embeddings have been generated and saved in the cache directory.
2. Navigate to the directory containing `app.py`.
3. Run the app using the command:
   ```bash
   python app.py
   ```
4. Open the URL printed by the app in your web browser (usually `http://0.0.0.0:3000`; if your browser cannot open that address, use `http://localhost:3000`).

The dashboard should now display the embeddings and allow you to interact with them.

## Adding New Models

To add new models, update the `embedding_models_info` dictionary in `embedding_utils.py` with the details of the new model.

### Example: Adding a New Hugging Face Model

```python
# In embedding_utils.py
embedding_models_info = {
    'OpenAI': {
        'type': 'openai',
    },
    'YourModelName': {
        'type': 'hf',
        'tokenizer_name': 'your-model-tokenizer-name',
        'model_name': 'your-model-name'
    },
}
```

After updating, you can regenerate embeddings for the new model using the same process.
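Before regenerating, a quick check can catch incomplete entries. The validator below is hypothetical. It assumes the required keys for each `type` are exactly those in the example above; `embedding_utils.py` may need more.

```python
# Hypothetical: required keys per model type, taken from the example entries above
REQUIRED_KEYS = {
    'openai': set(),
    'hf': {'tokenizer_name', 'model_name'},
}


def validate_models_info(models_info):
    """Return a list of problems found in an embedding_models_info-style dict."""
    problems = []
    for name, info in models_info.items():
        model_type = info.get('type')
        if model_type not in REQUIRED_KEYS:
            problems.append(f"{name}: unknown type {model_type!r}")
            continue
        missing = REQUIRED_KEYS[model_type] - set(info)
        if missing:
            problems.append(f"{name}: missing {sorted(missing)}")
    return problems


example = {
    'OpenAI': {'type': 'openai'},
    'YourModelName': {'type': 'hf', 'model_name': 'your-model-name'},
}
print(validate_models_info(example))  # ["YourModelName: missing ['tokenizer_name']"]
```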