In [1]:
import json
import os
from typing import Any, Dict, List, Optional, Union

import numpy as np

from src.data_processing.indexing.embeddings_handler import (
    get_embeddings,
    get_model_info,
    save_embeddings_details_to_json,
    load_embeddings_details_from_json,
    find_most_similar,
    save_to_csv
)

In [None]:
def _build_embedding_details(
        texts: List[str],
        cur_file: str,
        annotation_model: str,
        model_infos: List[Any],
        contains_summary: bool
) -> List[Dict[str, Any]]:
    """Embed ``texts`` with every embedding model and wrap each result in a metadata record.

    Args:
        texts: Texts to embed (one per chunk).
        cur_file: Name of the chunk file the texts came from (recorded in the output).
        annotation_model: Annotation model whose chunk file is being processed.
        model_infos: ``(embedding_model, model_info)`` pairs, where ``model_info`` is
            what ``get_model_info`` returned: index 0 is the model name passed to
            ``get_embeddings`` and index 1 is its token limit.
        contains_summary: Whether ``texts`` were prefixed with the article summary.

    Returns:
        One details dict per embedding model, in ``model_infos`` order.
    """
    label = "with" if contains_summary else "without"
    details = []
    for embedding_model, model_info in model_infos:
        print(f"Processing for Embedding Model: {embedding_model} {label} article summary")
        embeddings = get_embeddings(
            model_name=model_info[0],
            token_limit=model_info[1],
            texts=texts
        )
        details.append({
            "file": cur_file,
            "chunks_count": len(texts),
            "annotation_model": annotation_model,
            "embeddings_model": embedding_model,
            "embeddings_model_token_limit": model_info[1],
            "contains_summary": contains_summary,
            "embeddings": embeddings
        })
    return details


def generate_embedding_details(
        chunk_dir: str,
        annotations_models: List[str],
        embedding_models: List[str],
        output_file_path: str,
        article_summary: Optional[str] = None
) -> List[Dict[str, Any]]:
    """Create embeddings for chunk files and save all embedding records to JSON.

    For every annotation model, each ``*.json`` file in ``chunk_dir`` whose name
    contains that model's name is loaded. Each file is expected to hold a list of
    dicts with a ``"merged_text"`` key. The merged texts are embedded with every
    model in ``embedding_models``. If ``article_summary`` is given, a second set
    of embeddings is produced for texts prefixed with the summary, emitted before
    the plain-text set for that file.

    Args:
        chunk_dir: Directory containing the chunk JSON files.
        annotations_models: Annotation model names. A file is processed for a
            model when the model name is a substring of the file name.
        embedding_models: Embedding model keys understood by ``get_model_info``.
        output_file_path: Destination JSON file. Its parent directory is created
            if missing.
        article_summary: Optional article summary. When ``None``, only the
            without-summary embeddings are generated.

    Returns:
        All embedding-detail dicts, in the same order they were written to
        ``output_file_path``.
    """
    # Resolve each embedding model once instead of once per file and per variant.
    model_infos = [(model, get_model_info(model)) for model in embedding_models]

    # Sort so the output order is reproducible (os.listdir order is arbitrary).
    chunk_files = sorted(
        name for name in os.listdir(chunk_dir)
        if name.endswith(".json") and os.path.isfile(os.path.join(chunk_dir, name))
    )

    all_embedding_details: List[Dict[str, Any]] = []

    for annotation_model in annotations_models:
        print("Processing for Annotation Model: ", annotation_model)
        for cur_file in chunk_files:
            if annotation_model not in cur_file:
                continue

            input_file_path = os.path.join(chunk_dir, cur_file)
            print(f"Processing {input_file_path}")
            # Load and close the file before the (potentially slow) embedding step.
            with open(input_file_path, "r", encoding="utf-8") as f:
                chunks = json.load(f)

            merged_texts_without_sum = [chunk["merged_text"] for chunk in chunks]

            if article_summary is not None:
                merged_texts_with_sum = [
                    f"Summary:\n{article_summary}\nText:\n{text}"
                    for text in merged_texts_without_sum
                ]
                all_embedding_details.extend(_build_embedding_details(
                    merged_texts_with_sum, cur_file, annotation_model, model_infos, True
                ))

            all_embedding_details.extend(_build_embedding_details(
                merged_texts_without_sum, cur_file, annotation_model, model_infos, False
            ))

    output_dir = os.path.dirname(output_file_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    save_embeddings_details_to_json(all_embedding_details, output_file_path)
    return all_embedding_details
