In [None]:
import sys
import os
import logging
import subprocess
import pandas as pd
import numpy as np
import openai
from pathlib import Path

utils_path = os.path.abspath('/home/wadmin/embed_norm/apps/embed_norm/src')
if utils_path not in sys.path:
    sys.path.append(utils_path)

# root_path = subprocess.check_output(['git', 'rev-parse', '--show-toplevel']).decode().strip()
# os.chdir(Path(root_path) / 'pipelines' / 'matrix')

# configure_logging()
# utils_path = Path.cwd()
# if utils_path not in sys.path:
#     sys.path.append(utils_path)

%load_ext autoreload
%autoreload 2

# Import functions/classes from main.py
from test import (
    configure_logging,
    setup_environment,
    CacheManager,
    Config,
    DataLoader,
    Normalizer,
    EmbeddingGenerator,
    LLMEnhancer,
    load_data,
)


In [None]:
# Project logging setup; presumably configures the handlers used by the
# logging.info/error calls in later cells — confirm in the project module.
configure_logging()
# Project environment setup, handed the source directory that the first cell
# added to sys.path.
setup_environment(utils_path=utils_path)

In [None]:
# Repository root, assumed to be two directory levels above the notebook's
# working directory. NOTE(review): depends on where the kernel was started.
root_path = Path.cwd().parents[1]

# Cache layout: <root>/apps/embed_norm/cached_datasets/{embeddings,datasets}
cache_dir = root_path / "apps" / "embed_norm" / "cached_datasets"
cache_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"Cache directory set at '{cache_dir}'.")

for subdir in ["embeddings", "datasets"]:
    subdir_path = cache_dir / subdir
    subdir_path.mkdir(parents=True, exist_ok=True)
    logging.info(f"Subdirectory '{subdir}' created at '{subdir_path}'.")

# Run parameters handed to Config below.
pos_seed = 54321
neg_seed = 67890
dataset_name = "rtx_kg2.int"
nodes_dataset_name = "integration.int.rtx.nodes"
edges_dataset_name = "integration.int.rtx.edges"
categories = ["All Categories"]
model_names = ["OpenAI", "PubMedBERT", "SapBERT", "BlueBERT", "BioBERT"]

openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    # Raise rather than sys.exit(1): inside a Jupyter kernel IPython catches
    # SystemExit and only prints a terse "An exception has occurred" banner.
    # A regular exception still halts Run All and shows the reason.
    logging.error("OPENAI_API_KEY environment variable is not set.")
    raise RuntimeError("OPENAI_API_KEY environment variable is not set.")
openai.api_key = openai_api_key
logging.info("OpenAI API key is set.")

total_sample_size = 1000
positive_ratio = 0.2
if not 0.0 <= positive_ratio <= 1.0:
    raise ValueError(f"positive_ratio must be within [0, 1], got {positive_ratio!r}")
# round() instead of int(): int() truncates float products such as
# 100 * 0.29 == 28.999999999999996 down to 28.
positive_n = round(total_sample_size * positive_ratio)
negative_n = total_sample_size - positive_n
cache_suffix = f"_pos_{positive_n}_neg_{negative_n}"

config = Config(
    cache_dir=cache_dir,
    pos_seed=pos_seed,
    neg_seed=neg_seed,
    dataset_name=dataset_name,
    nodes_dataset_name=nodes_dataset_name,
    edges_dataset_name=edges_dataset_name,
    categories=categories,
    model_names=model_names,
    total_sample_size=total_sample_size,
    positive_ratio=positive_ratio,
    positive_n=positive_n,
    negative_n=negative_n,
    cache_suffix=cache_suffix,
    use_llm_enhancement=False  # Set to True to enable LLM enhancement
)
logging.info("Configuration variables are set.")

cache_manager = CacheManager(config.cache_dir)

In [None]:
# Unpack the four values load_data returns: categories, the positive and
# negative datasets, and the node table (used as a DataFrame in the next cell).
# NOTE: this rebinds `categories`, replacing the ["All Categories"] list set in
# the configuration cell (config itself still holds the original list).
categories, positive_datasets, negative_datasets, nodes_df = load_data(config, cache_manager)

In [None]:
# Preview the node table returned by load_data. pandas/numpy come from the
# first cell's imports, so they are not re-imported here.
display(nodes_df.head())

# 1. Count missing values per column
missing_counts = nodes_df.isnull().sum()

print("\nMissing values per column:")
display(missing_counts)

# 2. Identify columns with unhashable types
def is_column_unhashable(col):
    """Return True if any non-null value in ``col`` cannot be hashed.

    Hash-based pandas operations such as ``drop_duplicates`` fail with
    ``TypeError`` on such columns. Every non-null value is checked, not only the
    first one. A column whose first entry is a string but whose later entries
    are lists is therefore still flagged. The scan stops at the first
    unhashable value. An empty or all-null column is reported as hashable.

    Args:
        col: pandas Series to inspect.

    Returns:
        bool: True if at least one non-null element is unhashable.
    """
    def _is_unhashable(value):
        # hash() raises TypeError for lists, dicts, sets, ndarrays and for
        # tuples that contain any of those.
        try:
            hash(value)
        except TypeError:
            return True
        return False

    return any(_is_unhashable(value) for value in col.dropna())

# Identify unhashable columns (hash-based operations such as drop_duplicates
# raise TypeError on them)
unhashable_columns = [col for col in nodes_df.columns if is_column_unhashable(nodes_df[col])]

print("\nColumns with unhashable types:", unhashable_columns)


def _to_hashable(value):
    """Return a hashable equivalent of ``value`` for duplicate detection.

    ndarrays are first turned into (nested) Python lists. Then, recursively,
    lists/tuples become tuples, sets become frozensets and dicts become
    frozensets of (key, value) pairs, so nested containers are handled too.
    Dict equality ignores key order, and so does a frozenset.
    Any other value (including None/NaN) is returned unchanged.
    """
    if isinstance(value, np.ndarray):
        value = value.tolist()
    if isinstance(value, (list, tuple)):
        return tuple(_to_hashable(item) for item in value)
    if isinstance(value, (set, frozenset)):
        return frozenset(_to_hashable(item) for item in value)
    if isinstance(value, dict):
        return frozenset((key, _to_hashable(item)) for key, item in value.items())
    return value


# 3. Convert unhashable columns on a COPY: nodes_df keeps its original values,
# and re-running this cell reports the same unhashable columns every time.
nodes_hashable = nodes_df.copy()
for col in unhashable_columns:
    nodes_hashable[col] = nodes_hashable[col].map(_to_hashable)

print("\nConverted unhashable columns to hashable equivalents (on a copy of nodes_df).")

# 4. Now, count unique rows
unique_rows_count = nodes_hashable.drop_duplicates().shape[0]
print(f"\nTotal number of unique rows: {unique_rows_count}")
# The copy is only needed for the count; release it.
del nodes_hashable

# 5. Number of rows per category (a missing category gets its own row)
rows_per_category = nodes_df['category'].value_counts(dropna=False)
print("\nNumber of rows per category:")
display(rows_per_category)

# 6. Missing values per column for each category.
# Grouping the boolean null-mask is vectorized and avoids the pandas >= 2.2
# DeprecationWarning for groupby().apply() touching the grouping column.
# dropna=False keeps a missing-category group so this lines up with step 5.
missing_counts_per_category = nodes_df.isnull().groupby(nodes_df['category'], dropna=False).sum()

print("\nMissing values per column for each category:")
display(missing_counts_per_category)

In [None]:
# Which node text gets embedded: the single LLM-enhanced field when enhancement
# is switched on in config, otherwise the raw descriptive columns.
text_fields = (
    ["llm_enhanced_text"]
    if config.use_llm_enhancement
    else ["name", "category", "labels", "all_categories"]
)

# A single generator serves every configured model; it shares the cache
# manager created in the configuration cell.
embedding_generator = EmbeddingGenerator(config, cache_manager)
embeddings_dict_all_models = embedding_generator.process_models(
    # What to embed.
    model_names=config.model_names,
    text_fields=text_fields,
    positive_datasets=positive_datasets,
    negative_datasets=negative_datasets,
    # Naming / cache bookkeeping.
    dataset_name=config.dataset_name,
    cache_suffix=config.cache_suffix,
    # NOTE(review): the negative-sampling seed is reused here — confirm intended.
    seed=config.neg_seed,
    # Optional behaviours, all disabled for this run.
    label_generation_func=None,
    use_ontogpt=False,
    use_combinations=False,
    combine_fields=False,
)
logging.info("Embeddings for models processed successfully using process_models().")

# Same object, under the shorter name.
embeddings_dict = embeddings_dict_all_models