In [3]:
import os
import pandas as pd

# This script processes CSV files containing neighbor entries for different attributes,
# extracts unique neighbors, and saves them to separate text files for each attribute.
# Each output file `unique_<attribute>.txt` holds one neighbor per line, sorted; the
# FAISS index-building cell reads these files.

# Directory containing the input CSV files
directory = 'splits'
# Directory where the output text files will be saved
output_directory = 'faiss_index'

# Ensure the output directory exists
os.makedirs(output_directory, exist_ok=True)

# List of attributes to process
attributes = ['bioActivity', 'collectionSite', 'collectionSpecie', 'collectionType', 'name']

for attribute in attributes:
    # Initialize a set to store unique neighbor entries
    unique_neighbors = set()

    # Collect neighbors from every test/train split file for this attribute
    for filename in os.listdir(directory):
        if filename.startswith((f"test_doi_{attribute}", f"train_doi_{attribute}")):
            filepath = os.path.join(directory, filename)
            df = pd.read_csv(filepath)
            # pandas reads empty cells as NaN (a float). Drop them and coerce the rest
            # to str so `sorted` below never compares float with str (TypeError) and
            # no literal "nan" entry ends up in the index.
            neighbors = df['neighbor'].dropna().astype(str)
            unique_neighbors.update(neighbors.unique())

    # Save the unique neighbors to a text file. UTF-8 is explicit because entries can
    # contain non-ASCII characters (e.g. Greek letters in compound names), which the
    # platform default encoding cannot always represent.
    with open(os.path.join(output_directory, f'unique_{attribute}.txt'), 'w', encoding='utf-8') as f:
        for neighbor in sorted(unique_neighbors):
            f.write(f"{neighbor}\n")


In [1]:
#create faiss indexes
import os
from langchain.docstore.document import Document
from langchain_openai import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# Set your OpenAI API key (KeyError if it is not set)
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]

# Initialize OpenAI embeddings
embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)

# Directory containing the text files (one entry per line) and where indexes go
input_directory = 'faiss_index'
output_directory = 'faiss_index_trained'

# Ensure the output directory exists
os.makedirs(output_directory, exist_ok=True)

# Build one FAISS index per `.txt` file; each non-blank line becomes one Document.
for filename in os.listdir(input_directory):
    if not filename.endswith('.txt'):
        continue
    filepath = os.path.join(input_directory, filename)

    # Read the text file and create Document objects. UTF-8 matches how the entry
    # files are written; blank lines are skipped so no empty string gets embedded.
    entities = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            text = line.strip()
            if not text:
                continue
            entities.append(Document(page_content=text, metadata={'text': text}))

    # An empty file leaves nothing to embed (the FAISS store needs at least one
    # vector to size its index), so skip it rather than abort the whole loop.
    if not entities:
        print(f"Skipping {filename}: no entries to index.")
        continue

    # Create FAISS index from documents
    faiss_index = FAISS.from_documents(entities, embeddings)

    # Save the FAISS index locally as `<filename>.index` (e.g. unique_name.txt.index),
    # the path the search cells load from.
    index_path = os.path.join(output_directory, f'{filename}.index')
    faiss_index.save_local(index_path)

print("FAISS indices have been created and saved.")

FAISS indices have been created and saved.


In [None]:
#test Faiss indexes
import os
from langchain.docstore.document import Document
from langchain_openai import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS

# Read the OpenAI API key from the environment (KeyError if it is unset)
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]

# Embedding model handed to FAISS.load_local so queries are embedded the same way
embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)

# Where the per-attribute FAISS index directories live
index_directory = 'faiss_index_trained'

# Numeric shortcut (1-5) -> attribute name used in the index directory names
attribute_mapping = dict(enumerate(
    ['collectionSpecie', 'collectionSite', 'bioActivity', 'name', 'collectionType'],
    start=1,
))

def load_faiss_index(attribute_number):
    """Load the saved FAISS index for the attribute identified by `attribute_number`.

    The number is looked up in `attribute_mapping`; the index is read from
    `index_directory/unique_<attribute>.txt.index`.

    Raises:
        ValueError: if `attribute_number` has no entry in `attribute_mapping`.
        FileNotFoundError: if the index directory for the attribute does not exist.
    """
    attribute = attribute_mapping.get(attribute_number)
    if not attribute:
        raise ValueError(f"Invalid attribute number: {attribute_number}")

    index_path = os.path.join(index_directory, f'unique_{attribute}.txt.index')
    if os.path.exists(index_path):
        # LangChain requires this opt-in because the stored docstore is unpickled;
        # only load index directories produced by this notebook.
        return FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)
    raise FileNotFoundError(f"FAISS index for attribute '{attribute}' not found.")

def similarity_search(attribute_number, query, top_k=5):
    """Return the `top_k` (Document, score) pairs closest to `query` for one attribute.

    `attribute_number` is a key of `attribute_mapping`. The index is reloaded from
    disk on every call via `load_faiss_index`, which raises ValueError or
    FileNotFoundError.
    """
    faiss_index = load_faiss_index(attribute_number)
    # LangChain's FAISS API names the result count `k`. A `top_k=` keyword is not
    # recognised; it is swallowed by **kwargs and the default k=4 would be used.
    return faiss_index.similarity_search_with_score(query, k=top_k)

# Example usage: choose an attribute number (see attribute_mapping) and a query string.
attribute_number = 2  # Change this to the attribute number you want to search (1 to 5)
query = "test"  # Change this to your query string

try:
    results = similarity_search(attribute_number, query)
except (FileNotFoundError, ValueError) as e:
    # Unknown attribute number or index not built yet
    print(e)
else:
    for doc, score in results:
        print(f"Document: {doc.page_content}, Score: {score}")

In [None]:
#Updated and Optimized workflow to account for multiple simsearch results being inserted to the file (for hits@k != 1)
import os
import ast
import csv
import re
import logging
import shutil
from tqdm import tqdm
from langchain.docstore.document import Document
from langchain_openai import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS

# Timestamped INFO-level logging so per-row changes show up in the cell output
# (basicConfig does nothing if the root logger was already configured).
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

# API key comes from the environment (KeyError if unset)
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)

# Location of the per-attribute FAISS index directories
index_directory = 'faiss_index_trained'

def load_faiss_index(attribute):
    """Load the saved FAISS index for `attribute` from `index_directory`.

    Raises FileNotFoundError when `unique_<attribute>.txt.index` does not exist.
    """
    index_path = os.path.join(index_directory, f'unique_{attribute}.txt.index')
    if os.path.exists(index_path):
        # LangChain requires this opt-in because the stored docstore is unpickled;
        # only load index directories produced by this notebook.
        return FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)
    raise FileNotFoundError(f"FAISS index directory for attribute '{attribute}' not found.")

def similarity_search(faiss_index, query, top_k):
    """Return the `top_k` nearest (Document, score) pairs for `query` in `faiss_index`."""
    return faiss_index.similarity_search_with_score(query, k=top_k)

def clean_restored_field(restored_str):
    """Normalise the raw 'restored' CSV cell into a bracketed list literal.

    Whitespace is stripped and, unless the text already starts with '[' and ends
    with ']', it is wrapped in brackets so `ast.literal_eval` yields a list,
    e.g. "'a', ['b']" -> "['a', ['b']]".

    `csv.DictReader` fills cells missing from short rows with None. That is treated
    as an empty cell ("[]") instead of raising AttributeError, which `process_row`
    does not catch and which would abort processing of the whole file.
    """
    if restored_str is None:
        restored_str = ''
    restored_str = restored_str.strip()
    if not (restored_str.startswith('[') and restored_str.endswith(']')):
        restored_str = f"[{restored_str}]"
    return restored_str

def process_row(row, faiss_index, attribute, top_k):
    """Replace a row's predicted values with their nearest FAISS neighbours.

    'restored' is parsed as [key, [predictions...]]. The predictions are joined
    with spaces into one query, and the `top_k` matches are written back as ONE
    comma-joined string inside a one-element list: [key, ['m1, m2, ...']].
    Rows that fail to parse are logged and returned unchanged. `attribute` is not
    used here.

    NOTE(review): the later fine-tuned-model cells store the matches as a list of
    separate entries instead. Matches that themselves contain ", " cannot be split
    back out of this joined form, so confirm which format the evaluator expects.
    """
    try:
        parsed = ast.literal_eval(clean_restored_field(row['restored']))
        if len(parsed) < 2 or not isinstance(parsed[1], list):
            return row

        predictions = list(map(str, parsed[1]))
        hits = similarity_search(faiss_index, ' '.join(predictions), top_k)
        if not hits:
            return row

        old_value = row['restored']
        joined = ', '.join(doc.page_content for doc, _score in hits)
        new_restored_value = [parsed[0], [joined]]
        row['restored'] = str(new_restored_value)
        logging.info(f"changed {old_value} to {new_restored_value} in row {row['true']}")
    except (ValueError, SyntaxError, TypeError) as e:
        logging.error(f"Error parsing restored field in row {row}: {e}")
    return row

def update_restored_with_similarity_search(file_path, attribute, top_k):
    """Rewrite the 'restored' column of one CSV in place using FAISS neighbours.

    The current file is copied to `<file_path>.bak`. Rows are streamed through
    `process_row` into `<file_path>.temp`, which then replaces the original via
    `os.replace`. If the index cannot be loaded, the error is logged and the file is
    left untouched. If processing fails part-way, the temp file is removed and the
    exception re-raised. An empty file (no header row) is left unchanged.

    NOTE(review): every run re-copies the backup, so re-running overwrites `.bak`
    with the already-rewritten file.
    """
    backup_file_path = file_path + '.bak'
    shutil.copy(file_path, backup_file_path)

    try:
        faiss_index = load_faiss_index(attribute)
    except FileNotFoundError as e:
        logging.error(e)
        return

    temp_file = file_path + '.temp'
    batch_size = 100  # Adjust this based on your available memory

    try:
        with open(file_path, 'r', newline='', encoding='utf-8') as infile, \
             open(temp_file, 'w', newline='', encoding='utf-8') as outfile:
            reader = csv.DictReader(infile)
            # An empty file has no header row, and csv.DictWriter cannot write a
            # header without field names, so there is nothing to update.
            has_header = reader.fieldnames is not None
            if not has_header:
                logging.warning(f"No header row in {file_path}; file left unchanged.")
            else:
                writer = csv.DictWriter(outfile, fieldnames=reader.fieldnames, quoting=csv.QUOTE_ALL)
                writer.writeheader()

                batch = []
                for row in tqdm(reader, desc=f"Processing rows in {os.path.basename(file_path)}"):
                    batch.append(row)
                    if len(batch) >= batch_size:
                        processed_batch = [process_row(r, faiss_index, attribute, top_k) for r in batch]
                        writer.writerows(processed_batch)
                        batch = []

                if batch:
                    processed_batch = [process_row(r, faiss_index, attribute, top_k) for r in batch]
                    writer.writerows(processed_batch)
    except BaseException:
        # E.g. an embeddings API error or KeyboardInterrupt mid-file: don't leave a
        # half-written temp file behind; the original CSV has not been replaced yet.
        if os.path.exists(temp_file):
            os.remove(temp_file)
        raise

    if not has_header:
        os.remove(temp_file)
        return

    os.replace(temp_file, file_path)

def process_files_for_similarity_search(directory, attribute, top_k):
    """Rewrite every GPT-4 result CSV for `attribute` found in `directory`.

    Only names exactly of the form
    `llm_results_gpt4_0.8_doi_<attribute>_<n>_<m>(st|nd|rd|th).csv` are processed.
    The `.bak` and `.temp` files written next to them are ignored. Files are
    handled in sorted order; a warning is logged if none match.
    """
    # fullmatch (not match): `match` only anchors at the start, so the `.csv.bak`
    # backups from an earlier run would also match and be rewritten. The dots and
    # the attribute are escaped so they match literally.
    pattern = re.compile(
        rf'llm_results_gpt4_0\.8_doi_{re.escape(attribute)}_\d+_\d+(?:st|nd|rd|th)\.csv'
    )
    files = sorted(f for f in os.listdir(directory) if pattern.fullmatch(f))

    if not files:
        logging.warning(f"No files found matching pattern for attribute '{attribute}' in directory '{directory}'.")
        return

    for filename in tqdm(files, desc="Processing files"):
        file_path = os.path.join(directory, filename)
        logging.info(f"Processing file: {filename} with attribute: {attribute}")
        update_restored_with_similarity_search(file_path, attribute, top_k)

# Rewrite every matching result file, one attribute at a time.
# Other attributes and the top_k used for them (disabled for this run):
#   collectionSpecie: 50, collectionSite: 20, bioActivity: 5, collectionType: 1
directory = 'results/LLM_corr_k_full/'
attributes_and_k = {'name': 50}

for attribute, top_k in attributes_and_k.items():
    logging.info(f"Processing attribute: {attribute} with top_k: {top_k}")
    process_files_for_similarity_search(directory, attribute, top_k)


In [2]:
# for finetuned model. collectionSite 0 1st
import os
import ast
import csv
import re
import logging
import shutil
from tqdm import tqdm
from langchain.docstore.document import Document
from langchain_openai import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS

# Timestamped INFO-level logging for progress and per-row changes
# (basicConfig does nothing if the root logger was already configured).
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

# Fail fast with a clear message when the API key is unset or empty
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if OPENAI_API_KEY in (None, ""):
    raise ValueError("OpenAI API key not found. Please set the 'OPENAI_API_KEY' environment variable.")

# Embedding model handed to FAISS.load_local for query embedding
embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)

# Location of the per-attribute FAISS index directories
index_directory = 'faiss_index_trained'

def load_faiss_index(attribute):
    """
    Load the persisted FAISS index that belongs to one attribute.

    Parameters:
    - attribute (str): Attribute name used in the index directory name (e.g., 'collectionSite').

    Returns:
    - FAISS: The vector store restored from `index_directory/unique_<attribute>.txt.index`.

    Raises:
    - FileNotFoundError: If that index path does not exist.
    """
    index_path = os.path.join(index_directory, f'unique_{attribute}.txt.index')
    if os.path.exists(index_path):
        # LangChain requires this opt-in because the stored docstore is unpickled;
        # only load index directories produced by this notebook.
        return FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)
    raise FileNotFoundError(f"FAISS index file for attribute '{attribute}' not found at {index_path}.")

def similarity_search(faiss_index, query, top_k):
    """
    Look up the entries of a FAISS index closest to a query string.

    Parameters:
    - faiss_index (FAISS): Index to search.
    - query (str): Text to embed and compare against the index.
    - top_k (int): How many neighbours to return.

    Returns:
    - list of tuples: (Document, score) pairs from `similarity_search_with_score`.
    """
    return faiss_index.similarity_search_with_score(query, k=top_k)

def clean_restored_field(restored_str):
    """
    Clean and ensure the 'restored' field is in proper list format.

    Parameters:
    - restored_str (str or None): The string representation of the 'restored' field.
      `csv.DictReader` fills cells missing from short rows with None; that is
      treated as an empty field instead of raising AttributeError, which
      `process_row` does not catch and which would abort the whole file.

    Returns:
    - str: The stripped text, wrapped in square brackets unless it already starts
      with '[' and ends with ']' (e.g. "'a', ['b']" -> "['a', ['b']]").
    """
    if restored_str is None:
        restored_str = ''
    restored_str = restored_str.strip()
    if not (restored_str.startswith('[') and restored_str.endswith(']')):
        restored_str = f"[{restored_str}]"
    return restored_str

def process_row(row, faiss_index, attribute, top_k):
    """
    Replace a row's predicted values with the nearest entries from the FAISS index.

    'restored' must parse to exactly [key, [predictions...]]. The predictions are
    joined with spaces into one query, and the `top_k` matches replace them as a
    list: [key, [match1, match2, ...]]. Malformed rows log a warning; unparsable
    rows log an error. Either way the row is returned unchanged.

    Parameters:
    - row (dict): The CSV row as a dictionary (reads 'restored' and, for logging, 'true').
    - faiss_index (FAISS): The FAISS index object.
    - attribute (str): The attribute being processed (not used in this function).
    - top_k (int): Number of top similar documents to retrieve.

    Returns:
    - dict: The same row object, possibly with 'restored' updated.
    """
    try:
        parsed = ast.literal_eval(clean_restored_field(row['restored']))
        well_formed = (
            isinstance(parsed, list)
            and len(parsed) == 2
            and isinstance(parsed[1], list)
        )
        if not well_formed:
            logging.warning(f"Unexpected 'restored' format in row: {row}")
            return row

        key = parsed[0]
        predictions = [str(p) for p in parsed[1]]
        hits = similarity_search(faiss_index, ' '.join(predictions), top_k)
        if hits:
            new_restored_value = [key, [doc.page_content for doc, _score in hits]]
            old_value = row['restored']
            row['restored'] = str(new_restored_value)
            # NOTE(review): row['true'] appears to hold the whole true pair, not just the DOI
            logging.info(f"Updated 'restored' from {old_value} to {new_restored_value} for DOI {row['true']}")
    except (ValueError, SyntaxError, TypeError) as e:
        logging.error(f"Error parsing 'restored' field in row {row}: {e}")
    return row

def update_restored_with_similarity_search(file_path, attribute, top_k):
    """
    Update the 'restored' field in the CSV file using similarity search.

    The current file is copied to `<file_path>.bak`. Rows are streamed through
    `process_row` into `<file_path>.temp`, which then replaces the original via
    `os.replace`. If processing fails part-way, the temp file is removed, the
    original CSV is left untouched, and the exception is re-raised. An empty file
    (no header row) is left unchanged.

    NOTE(review): every run re-copies the backup, so re-running this cell
    overwrites `.bak` with the already-updated file.

    Parameters:
    - file_path (str): Path to the input CSV file.
    - attribute (str): The attribute being processed.
    - top_k (int): Number of top similar documents to retrieve.
    """
    backup_file_path = file_path + '.bak'
    shutil.copy(file_path, backup_file_path)
    logging.info(f"Backup created at {backup_file_path}")

    try:
        faiss_index = load_faiss_index(attribute)
    except FileNotFoundError as e:
        logging.error(e)
        return

    temp_file = file_path + '.temp'
    batch_size = 100  # Adjust this based on your available memory

    try:
        with open(file_path, 'r', newline='', encoding='utf-8') as infile, \
             open(temp_file, 'w', newline='', encoding='utf-8') as outfile:
            reader = csv.DictReader(infile)
            # An empty file has no header row, and csv.DictWriter cannot write a
            # header without field names, so there is nothing to update.
            has_header = reader.fieldnames is not None
            if not has_header:
                logging.warning(f"No header row in {file_path}; file left unchanged.")
            else:
                writer = csv.DictWriter(outfile, fieldnames=reader.fieldnames, quoting=csv.QUOTE_ALL)
                writer.writeheader()

                batch = []
                for row in tqdm(reader, desc=f"Processing rows in {os.path.basename(file_path)}"):
                    batch.append(row)
                    if len(batch) >= batch_size:
                        processed_batch = [process_row(r, faiss_index, attribute, top_k) for r in batch]
                        writer.writerows(processed_batch)
                        batch = []

                if batch:
                    processed_batch = [process_row(r, faiss_index, attribute, top_k) for r in batch]
                    writer.writerows(processed_batch)
    except BaseException:
        # E.g. an embeddings API error or KeyboardInterrupt mid-file: don't leave a
        # half-written temp file behind; the original CSV has not been replaced yet.
        if os.path.exists(temp_file):
            os.remove(temp_file)
        raise

    if not has_header:
        os.remove(temp_file)
        return

    os.replace(temp_file, file_path)
    logging.info(f"Updated file saved at {file_path}")

def process_files_for_similarity_search(directory, attribute, top_k):
    """
    Process all relevant CSV files in the specified directory for similarity search.

    Only the exact name `llm_results_ft_4o_0.8_doi_<attribute>_0_1st.csv` is
    processed; the `.bak` and `.temp` files written next to it are ignored.

    Parameters:
    - directory (str): Directory containing the CSV files.
    - attribute (str): The attribute being processed.
    - top_k (int): Number of top similar documents to retrieve.
    """
    # fullmatch (not match): `match` only anchors at the start, so the `.csv.bak`
    # backup from an earlier run would also match and be rewritten in place.
    pattern = re.compile(rf'llm_results_ft_4o_0\.8_doi_{re.escape(attribute)}_0_1st\.csv')
    files = [f for f in os.listdir(directory) if pattern.fullmatch(f)]

    if not files:
        logging.warning(f"No files found matching pattern for attribute '{attribute}' in directory '{directory}'.")
        return

    for filename in tqdm(files, desc="Processing files"):
        file_path = os.path.join(directory, filename)
        logging.info(f"Processing file: {filename} for attribute: {attribute}")
        update_restored_with_similarity_search(file_path, attribute, top_k)

# Fine-tuned model results: rewrite collectionSite predictions with the top 20 FAISS matches.
directory = 'llm_ft_results'
attributes_and_k = dict(collectionSite=20)

for attribute, top_k in attributes_and_k.items():
    logging.info(f"Starting similarity search processing for attribute: {attribute} with top_k: {top_k}")
    process_files_for_similarity_search(directory, attribute, top_k)

logging.info("Similarity search processing completed.")

2024-10-02 12:30:25,822 - INFO - Starting similarity search processing for attribute: collectionSite with top_k: 20
Processing files:   0%|          | 0/1 [00:00<?, ?it/s]2024-10-02 12:30:25,865 - INFO - Processing file: llm_results_ft_4o_0.8_doi_collectionSite_0_1st.csv for attribute: collectionSite
2024-10-02 12:30:25,867 - INFO - Backup created at llm_ft_results/llm_results_ft_4o_0.8_doi_collectionSite_0_1st.csv.bak
2024-10-02 12:30:26,354 - INFO - HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
2024-10-02 12:30:26,394 - INFO - Updated 'restored' from ['10.1002/cbdv.200590016', ['Fortaleza/CE']] to ['10.1002/cbdv.200590016', ['Recife/PE', 'Sao Paulo/SP', 'Rio De Janeiro/RJ', 'Sao Carlos/SP', 'Sao Sebastiao Do Passe/BA', 'Rio Claro/SP', 'Corumba/MS', 'Belem/PA', 'Manaus/AM', 'Araraquara/SP', 'Cunha/SP', 'Cuiaba/MT', 'Teodoro Sampaio/SP', 'Goiania/GO', 'Chapada Dos Guimaraes/MT', 'Campinas/SP', 'Piracicaba/SP', 'Ribeirao Preto/SP', 'Santarem/PA', 'Cordeiropol

In [7]:
#for finetuned model. name 0 1st
import os
import ast
import csv
import re
import logging
import shutil
from tqdm import tqdm
from langchain.docstore.document import Document
from langchain_openai import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS

# Timestamped INFO-level logging for progress and per-row changes
# (basicConfig does nothing if the root logger was already configured).
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

# Fail fast with a clear message when the API key is unset or empty
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if OPENAI_API_KEY in (None, ""):
    raise ValueError("OpenAI API key not found. Please set the 'OPENAI_API_KEY' environment variable.")

# Embedding model handed to FAISS.load_local for query embedding
embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)

# Location of the per-attribute FAISS index directories
index_directory = 'faiss_index_trained'

def load_faiss_index(attribute):
    """
    Load the persisted FAISS index that belongs to one attribute.

    Parameters:
    - attribute (str): Attribute name used in the index directory name (e.g., 'name').

    Returns:
    - FAISS: The vector store restored from `index_directory/unique_<attribute>.txt.index`.

    Raises:
    - FileNotFoundError: If that index path does not exist.
    """
    index_path = os.path.join(index_directory, f'unique_{attribute}.txt.index')
    if os.path.exists(index_path):
        # LangChain requires this opt-in because the stored docstore is unpickled;
        # only load index directories produced by this notebook.
        return FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)
    raise FileNotFoundError(f"FAISS index file for attribute '{attribute}' not found at {index_path}.")

def similarity_search(faiss_index, query, top_k):
    """
    Look up the entries of a FAISS index closest to a query string.

    Parameters:
    - faiss_index (FAISS): Index to search.
    - query (str): Text to embed and compare against the index.
    - top_k (int): How many neighbours to return.

    Returns:
    - list of tuples: (Document, score) pairs from `similarity_search_with_score`.
    """
    return faiss_index.similarity_search_with_score(query, k=top_k)

def clean_restored_field(restored_str):
    """
    Clean and ensure the 'restored' field is in proper list format.

    Parameters:
    - restored_str (str or None): The string representation of the 'restored' field.
      `csv.DictReader` fills cells missing from short rows with None; that is
      treated as an empty field instead of raising AttributeError, which
      `process_row` does not catch and which would abort the whole file.

    Returns:
    - str: The stripped text, wrapped in square brackets unless it already starts
      with '[' and ends with ']' (e.g. "'a', ['b']" -> "['a', ['b']]").
    """
    if restored_str is None:
        restored_str = ''
    restored_str = restored_str.strip()
    if not (restored_str.startswith('[') and restored_str.endswith(']')):
        restored_str = f"[{restored_str}]"
    return restored_str

def process_row(row, faiss_index, attribute, top_k):
    """
    Replace a row's predicted values with the nearest entries from the FAISS index.

    'restored' must parse to exactly [key, [predictions...]]. The predictions are
    joined with spaces into one query, and the `top_k` matches replace them as a
    list: [key, [match1, match2, ...]]. Malformed rows log a warning; unparsable
    rows log an error. Either way the row is returned unchanged.

    Parameters:
    - row (dict): The CSV row as a dictionary (reads 'restored' and, for logging, 'true').
    - faiss_index (FAISS): The FAISS index object.
    - attribute (str): The attribute being processed (not used in this function).
    - top_k (int): Number of top similar documents to retrieve.

    Returns:
    - dict: The same row object, possibly with 'restored' updated.
    """
    try:
        parsed = ast.literal_eval(clean_restored_field(row['restored']))
        well_formed = (
            isinstance(parsed, list)
            and len(parsed) == 2
            and isinstance(parsed[1], list)
        )
        if not well_formed:
            logging.warning(f"Unexpected 'restored' format in row: {row}")
            return row

        key = parsed[0]
        predictions = [str(p) for p in parsed[1]]
        hits = similarity_search(faiss_index, ' '.join(predictions), top_k)
        if hits:
            new_restored_value = [key, [doc.page_content for doc, _score in hits]]
            old_value = row['restored']
            row['restored'] = str(new_restored_value)
            # NOTE(review): row['true'] appears to hold the whole true pair, not just the DOI
            logging.info(f"Updated 'restored' from {old_value} to {new_restored_value} for DOI {row['true']}")
    except (ValueError, SyntaxError, TypeError) as e:
        logging.error(f"Error parsing 'restored' field in row {row}: {e}")
    return row

def update_restored_with_similarity_search(file_path, attribute, top_k):
    """
    Update the 'restored' field in the CSV file using similarity search.

    The current file is copied to `<file_path>.bak`. Rows are streamed through
    `process_row` into `<file_path>.temp`, which then replaces the original via
    `os.replace`. If processing fails part-way, the temp file is removed, the
    original CSV is left untouched, and the exception is re-raised. An empty file
    (no header row) is left unchanged.

    NOTE(review): every run re-copies the backup, so re-running this cell
    overwrites `.bak` with the already-updated file.

    Parameters:
    - file_path (str): Path to the input CSV file.
    - attribute (str): The attribute being processed.
    - top_k (int): Number of top similar documents to retrieve.
    """
    backup_file_path = file_path + '.bak'
    shutil.copy(file_path, backup_file_path)
    logging.info(f"Backup created at {backup_file_path}")

    try:
        faiss_index = load_faiss_index(attribute)
    except FileNotFoundError as e:
        logging.error(e)
        return

    temp_file = file_path + '.temp'
    batch_size = 100  # Adjust this based on your available memory

    try:
        with open(file_path, 'r', newline='', encoding='utf-8') as infile, \
             open(temp_file, 'w', newline='', encoding='utf-8') as outfile:
            reader = csv.DictReader(infile)
            # An empty file has no header row, and csv.DictWriter cannot write a
            # header without field names, so there is nothing to update.
            has_header = reader.fieldnames is not None
            if not has_header:
                logging.warning(f"No header row in {file_path}; file left unchanged.")
            else:
                writer = csv.DictWriter(outfile, fieldnames=reader.fieldnames, quoting=csv.QUOTE_ALL)
                writer.writeheader()

                batch = []
                for row in tqdm(reader, desc=f"Processing rows in {os.path.basename(file_path)}"):
                    batch.append(row)
                    if len(batch) >= batch_size:
                        processed_batch = [process_row(r, faiss_index, attribute, top_k) for r in batch]
                        writer.writerows(processed_batch)
                        batch = []

                if batch:
                    processed_batch = [process_row(r, faiss_index, attribute, top_k) for r in batch]
                    writer.writerows(processed_batch)
    except BaseException:
        # E.g. an embeddings API error or KeyboardInterrupt mid-file: don't leave a
        # half-written temp file behind; the original CSV has not been replaced yet.
        if os.path.exists(temp_file):
            os.remove(temp_file)
        raise

    if not has_header:
        os.remove(temp_file)
        return

    os.replace(temp_file, file_path)
    logging.info(f"Updated file saved at {file_path}")

def process_files_for_similarity_search(directory, attribute, top_k):
    """
    Run the similarity-search update on this attribute's fine-tuned-model result file.

    Looks in `directory` for `llm_results_ft_4o_0.8_doi_<attribute>_0_1st.csv`. A
    warning is logged (and nothing else happens) when no file matches.

    Parameters:
    - directory (str): Directory containing the CSV files.
    - attribute (str): The attribute being processed.
    - top_k (int): Number of top similar documents to retrieve.
    """
    # The trailing `$` keeps the `.csv.bak` / `.csv.temp` siblings from matching.
    pattern = re.compile(rf'llm_results_ft_4o_0\.8_doi_{re.escape(attribute)}_0_1st\.csv$')
    files = list(filter(pattern.match, os.listdir(directory)))

    if len(files) == 0:
        logging.warning(f"No files found matching pattern for attribute '{attribute}' in directory '{directory}'.")
        return

    for filename in tqdm(files, desc="Processing files"):
        logging.info(f"Processing file: {filename} for attribute: {attribute}")
        update_restored_with_similarity_search(os.path.join(directory, filename), attribute, top_k)

# Fine-tuned model results: rewrite name predictions with the top 50 FAISS matches.
directory = 'llm_ft_results'
attributes_and_k = dict(name=50)

for attribute, top_k in attributes_and_k.items():
    logging.info(f"Starting similarity search processing for attribute: {attribute} with top_k: {top_k}")
    process_files_for_similarity_search(directory, attribute, top_k)

logging.info("Similarity search processing completed.")

2024-10-02 14:10:30,542 - INFO - Starting similarity search processing for attribute: name with top_k: 50
Processing files:   0%|          | 0/1 [00:00<?, ?it/s]2024-10-02 14:10:30,545 - INFO - Processing file: llm_results_ft_4o_0.8_doi_name_0_1st.csv for attribute: name
2024-10-02 14:10:30,548 - INFO - Backup created at llm_ft_results/llm_results_ft_4o_0.8_doi_name_0_1st.csv.bak
2024-10-02 14:10:31,053 - INFO - HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
2024-10-02 14:10:31,184 - INFO - Updated 'restored' from ['10.1002/cbdv.200590016', ['Atractylenolide III']] to ['10.1002/cbdv.200590016', ['3-Acetyl aleuritolic acid', 'Oleanolic acid 3-sophoroside', 'Kaempferol-3-O-rutinoside', '24-methylenecycloartan-3b-ol', "4-Hydroxy-3-(3',7'-dimethylocta-2'-E-6'-dienyl)benzoic acid", "3,5,6,7,3',4',5'-heptamethoxyflavonol", 'Procyanidin B-3', 'Ourateacatechin', "3',4'-methylenedioxy-5,5',6,7-tetramethoxyflavone", '3-O-β-D-quinovopyranosyl cincholic acid', 'Kaempfero

In [12]:
# for finetuned model. bioActivity 0 1st
import os
import ast
import csv
import re
import logging
import shutil
from tqdm import tqdm
from langchain.docstore.document import Document
from langchain_openai import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS

# Timestamped INFO-level logging for progress and per-row changes
# (basicConfig does nothing if the root logger was already configured).
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

# Fail fast with a clear message when the API key is unset or empty
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if OPENAI_API_KEY in (None, ""):
    raise ValueError("OpenAI API key not found. Please set the 'OPENAI_API_KEY' environment variable.")

# Embedding model handed to FAISS.load_local for query embedding
embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)

# Location of the per-attribute FAISS index directories
index_directory = 'faiss_index_trained'

def load_faiss_index(attribute):
    """
    Load the persisted FAISS index that belongs to one attribute.

    Parameters:
    - attribute (str): Attribute name used in the index directory name (e.g., 'bioActivity').

    Returns:
    - FAISS: The vector store restored from `index_directory/unique_<attribute>.txt.index`.

    Raises:
    - FileNotFoundError: If that index path does not exist.
    """
    index_path = os.path.join(index_directory, f'unique_{attribute}.txt.index')
    if os.path.exists(index_path):
        # LangChain requires this opt-in because the stored docstore is unpickled;
        # only load index directories produced by this notebook.
        return FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)
    raise FileNotFoundError(f"FAISS index file for attribute '{attribute}' not found at {index_path}.")

def similarity_search(faiss_index, query, top_k):
    """
    Look up the entries of a FAISS index closest to a query string.

    Parameters:
    - faiss_index (FAISS): Index to search.
    - query (str): Text to embed and compare against the index.
    - top_k (int): How many neighbours to return.

    Returns:
    - list of tuples: (Document, score) pairs from `similarity_search_with_score`.
    """
    return faiss_index.similarity_search_with_score(query, k=top_k)

def clean_restored_field(restored_str):
    """
    Clean and ensure the 'restored' field is in proper list format.

    Parameters:
    - restored_str (str or None): The string representation of the 'restored' field.
      `csv.DictReader` fills cells missing from short rows with None; that is
      treated as an empty field instead of raising AttributeError, which
      `process_row` does not catch and which would abort the whole file.

    Returns:
    - str: The stripped text, wrapped in square brackets unless it already starts
      with '[' and ends with ']' (e.g. "'a', ['b']" -> "['a', ['b']]").
    """
    if restored_str is None:
        restored_str = ''
    restored_str = restored_str.strip()
    if not (restored_str.startswith('[') and restored_str.endswith(']')):
        restored_str = f"[{restored_str}]"
    return restored_str

def process_row(row, faiss_index, attribute, top_k):
    """
    Replace a row's predicted values with the nearest entries from the FAISS index.

    'restored' must parse to exactly [key, [predictions...]]. The predictions are
    joined with spaces into one query, and the `top_k` matches replace them as a
    list: [key, [match1, match2, ...]]. Malformed rows log a warning; unparsable
    rows log an error. Either way the row is returned unchanged.

    Parameters:
    - row (dict): The CSV row as a dictionary (reads 'restored' and, for logging, 'true').
    - faiss_index (FAISS): The FAISS index object.
    - attribute (str): The attribute being processed (not used in this function).
    - top_k (int): Number of top similar documents to retrieve.

    Returns:
    - dict: The same row object, possibly with 'restored' updated.
    """
    try:
        parsed = ast.literal_eval(clean_restored_field(row['restored']))
        well_formed = (
            isinstance(parsed, list)
            and len(parsed) == 2
            and isinstance(parsed[1], list)
        )
        if not well_formed:
            logging.warning(f"Unexpected 'restored' format in row: {row}")
            return row

        key = parsed[0]
        predictions = [str(p) for p in parsed[1]]
        hits = similarity_search(faiss_index, ' '.join(predictions), top_k)
        if hits:
            new_restored_value = [key, [doc.page_content for doc, _score in hits]]
            old_value = row['restored']
            row['restored'] = str(new_restored_value)
            # NOTE(review): row['true'] appears to hold the whole true pair, not just the DOI
            logging.info(f"Updated 'restored' from {old_value} to {new_restored_value} for DOI {row['true']}")
    except (ValueError, SyntaxError, TypeError) as e:
        logging.error(f"Error parsing 'restored' field in row {row}: {e}")
    return row

def update_restored_with_similarity_search(file_path, attribute, top_k):
    """
    Update the 'restored' field in the CSV file using similarity search.

    The current file is copied to `<file_path>.bak`. Rows are streamed through
    `process_row` into `<file_path>.temp`, which then replaces the original via
    `os.replace`. If processing fails part-way, the temp file is removed, the
    original CSV is left untouched, and the exception is re-raised. An empty file
    (no header row) is left unchanged.

    NOTE(review): every run re-copies the backup, so re-running this cell
    overwrites `.bak` with the already-updated file.

    Parameters:
    - file_path (str): Path to the input CSV file.
    - attribute (str): The attribute being processed.
    - top_k (int): Number of top similar documents to retrieve.
    """
    backup_file_path = file_path + '.bak'
    shutil.copy(file_path, backup_file_path)
    logging.info(f"Backup created at {backup_file_path}")

    try:
        faiss_index = load_faiss_index(attribute)
    except FileNotFoundError as e:
        logging.error(e)
        return

    temp_file = file_path + '.temp'
    batch_size = 100  # Adjust this based on your available memory

    try:
        with open(file_path, 'r', newline='', encoding='utf-8') as infile, \
             open(temp_file, 'w', newline='', encoding='utf-8') as outfile:
            reader = csv.DictReader(infile)
            # An empty file has no header row, and csv.DictWriter cannot write a
            # header without field names, so there is nothing to update.
            has_header = reader.fieldnames is not None
            if not has_header:
                logging.warning(f"No header row in {file_path}; file left unchanged.")
            else:
                writer = csv.DictWriter(outfile, fieldnames=reader.fieldnames, quoting=csv.QUOTE_ALL)
                writer.writeheader()

                batch = []
                for row in tqdm(reader, desc=f"Processing rows in {os.path.basename(file_path)}"):
                    batch.append(row)
                    if len(batch) >= batch_size:
                        processed_batch = [process_row(r, faiss_index, attribute, top_k) for r in batch]
                        writer.writerows(processed_batch)
                        batch = []

                if batch:
                    processed_batch = [process_row(r, faiss_index, attribute, top_k) for r in batch]
                    writer.writerows(processed_batch)
    except BaseException:
        # E.g. an embeddings API error or KeyboardInterrupt mid-file: don't leave a
        # half-written temp file behind; the original CSV has not been replaced yet.
        if os.path.exists(temp_file):
            os.remove(temp_file)
        raise

    if not has_header:
        os.remove(temp_file)
        return

    os.replace(temp_file, file_path)
    logging.info(f"Updated file saved at {file_path}")

def process_files_for_similarity_search(directory, attribute, top_k):
    """
    Process all relevant CSV files in the specified directory for similarity search.

    Only files whose whole name matches
    ``llm_results_ft_4o_0.8_doi_<attribute>_0_1st.csv`` are processed. A full match
    is required because update_restored_with_similarity_search writes
    ``<name>.csv.bak`` (and transiently ``<name>.csv.temp``) into the same
    directory; a prefix match would pick those up on the next run.

    Parameters:
    - directory (str): Directory containing the CSV files.
    - attribute (str): The attribute being processed.
    - top_k (int): Number of top similar documents to retrieve.
    """
    # Updated pattern to match the new filename format
    pattern = re.compile(rf'llm_results_ft_4o_0\.8_doi_{re.escape(attribute)}_0_1st\.csv')
    # fullmatch, not match: match() only anchors at the start of the name.
    files = sorted(f for f in os.listdir(directory) if pattern.fullmatch(f))

    if not files:
        logging.warning(f"No files found matching pattern for attribute '{attribute}' in directory '{directory}'.")
        return

    for filename in tqdm(files, desc="Processing files"):
        file_path = os.path.join(directory, filename)
        logging.info(f"Processing file: {filename} for attribute: {attribute}")
        update_restored_with_similarity_search(file_path, attribute, top_k)

# Process only the 'bioActivity' attribute with top_k=5
directory = 'llm_ft_results'
attributes_and_k = dict(bioActivity=5)  # attribute -> number of similar documents to retrieve

for attribute, top_k in attributes_and_k.items():
    logging.info(f"Starting similarity search processing for attribute: {attribute} with top_k: {top_k}")
    process_files_for_similarity_search(directory, attribute, top_k)

logging.info("Similarity search processing completed.")

2024-10-03 13:36:53,735 - INFO - Starting similarity search processing for attribute: bioActivity with top_k: 5
Processing files:   0%|          | 0/1 [00:00<?, ?it/s]2024-10-03 13:36:53,739 - INFO - Processing file: llm_results_ft_4o_0.8_doi_bioActivity_0_1st.csv for attribute: bioActivity
2024-10-03 13:36:53,743 - INFO - Backup created at llm_ft_results/llm_results_ft_4o_0.8_doi_bioActivity_0_1st.csv.bak
2024-10-03 13:36:54,169 - INFO - HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
2024-10-03 13:36:54,234 - INFO - Updated 'restored' from ['10.1002/cbdv.200800342', ['Cytotoxic']] to ['10.1002/cbdv.200800342', ['Cytotoxic', 'Genotoxic', 'Mutagenic', 'Anticancer', 'Antiviral']] for DOI ['10.1002/cbdv.200800342', 'Cytotoxic']
2024-10-03 13:36:54,509 - INFO - HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
2024-10-03 13:36:54,513 - INFO - Updated 'restored' from ['10.1002/hlca.200890147', ['Antifungal']] to ['10.1002/hlca.200890147', [

In [None]:
# Delete every file whose name ends in '.bak' anywhere under results/LLM_corr_k_full.
# NOTE(review): update_restored_with_similarity_search writes its '<file>.csv.bak'
# copies next to the files in 'llm_ft_results'; point backup_directory there to clean those.
backup_directory = 'results/LLM_corr_k_full'
for root, dirs, files in os.walk(backup_directory):
    for file in files:
        if file.endswith('.bak'):
            file_path = os.path.join(root, file)
            try:
                os.remove(file_path)
                logging.info(f"Deleted backup file: {file_path}")
            except OSError as e:
                # Best effort: report files that can't be removed and keep going.
                logging.error(f"Error deleting file {file_path}: {e}")



In [None]:
#process doi_name files
import csv
import re
import os
import glob
import ast
import shutil

def process_csv(input_file, output_file):
    """
    Rewrite a doi_name result CSV with normalised 'true' and 'restored' columns.

    The row on line 1 (the header) is copied unchanged. Each data row is expected
    to have exactly three fields (true, restored, edge type); the first two are
    passed through process_column and every field is written fully quoted.
    Rows with a different number of fields are reported and copied through
    unchanged instead of aborting the whole file.

    Parameters:
    - input_file (str): Path of the CSV to read.
    - output_file (str): Path the processed CSV is written to (overwritten).
    """
    with open(input_file, 'r', newline='', encoding='utf-8') as infile, \
         open(output_file, 'w', newline='', encoding='utf-8') as outfile:
        reader = csv.reader(infile)
        writer = csv.writer(outfile, quoting=csv.QUOTE_ALL)

        for row in reader:
            if reader.line_num == 1:
                writer.writerow(row)
                continue

            if len(row) != 3:
                # Malformed row: keep it as-is rather than crashing on the unpack below.
                print(f"Unexpected number of columns in row {reader.line_num}: {row}")
                writer.writerow(row)
                continue

            true_col, restored_col, edge_type = row
            writer.writerow([process_column(true_col), process_column(restored_col), edge_type])

    print(f"Processed file saved as: {output_file}")

def process_column(col):
    """
    Normalise a "[doi, name]" cell to the text '["doi", "name"]'.

    The outer brackets are dropped and the text is split on its first comma, so
    the name may itself contain commas. Surrounding whitespace and single/double
    quotes are stripped from both parts, and double quotes inside the name are
    escaped so the result stays a valid literal.

    Returns the original value unchanged when it is not a string or has no comma
    to split on.
    """
    try:
        # Remove outer brackets and split by the first comma
        content = col.strip()[1:-1]
        doi, name = content.split(',', 1)
    except (AttributeError, ValueError):
        # Not a string, or nothing to split: leave the cell as it was.
        return col

    # Clean up DOI and name
    doi = doi.strip().strip("'\"")
    name = name.strip().strip("'\"")

    # Escape double quotes in the name
    name = name.replace('"', '\\"')

    # Reconstruct the column with proper quoting
    return f'["{doi}", "{name}"]'

def process_all_files(input_directory, output_directory):
    """
    Run process_csv on every doi_name result file in input_directory.

    output_directory is deleted and recreated first, so it only holds the output
    of the latest run. A file is selected only when its whole name matches
    'processed_llm_results_gpt4_0.8_doi_name_<n>_<n><st|nd|rd|th>.csv'
    (e.g. ..._doi_name_0_1st.csv).

    Parameters:
    - input_directory (str): Folder scanned (non-recursively) for *.csv files.
    - output_directory (str): Folder receiving the processed copies.
    """
    # Clean the output directory
    if os.path.exists(output_directory):
        shutil.rmtree(output_directory)
    os.makedirs(output_directory)

    # Literal '.' in 0.8 and a full-name match, so similarly named files are not picked up.
    pattern = re.compile(r'processed_llm_results_gpt4_0\.8_doi_name_\d+_\d+(?:st|nd|rd|th)\.csv')

    for input_file in sorted(glob.glob(os.path.join(input_directory, '*.csv'))):
        if pattern.fullmatch(os.path.basename(input_file)):
            output_file = os.path.join(output_directory, os.path.basename(input_file))
            print(f"Processing {input_file}...")
            process_csv(input_file, output_file)

# Usage: normalise every doi_name result file into a freshly emptied subfolder.
input_directory = 'results/LLM_corr_k_full/'
output_directory = input_directory + 'modified_names/'
process_all_files(input_directory, output_directory)

In [None]:
# wrap the results in ' '
import os
import re

def process_file(input_file, output_file):
    """
    Re-quote the prediction lists in a results file and write the result to output_file.

    Every "['<doi>', [<items>]]" occurrence in the raw text is rewritten so that each
    ', '-separated item becomes its own quoted string, e.g.
    "['10.1/x', ['A, B']]" -> "['10.1/x', ['A', 'B']]".
    At most one surrounding single quote is removed from each item before it is
    re-quoted with repr(). As a result:
    - single-item and already well-formed lists come out unchanged, so re-running is harmless;
    - an item containing an apostrophe is emitted in double quotes and stays parseable by
      ast.literal_eval;
    - an empty list stays '[]'.

    Parameters:
    - input_file (str): File to read (UTF-8 text).
    - output_file (str): File to write the rewritten text to (overwritten).
    """
    with open(input_file, 'r', encoding='utf-8') as infile:
        content = infile.read()

    def _unquote(item):
        # Drop at most one leading and one trailing single quote.
        if item.startswith("'"):
            item = item[1:]
        if item.endswith("'"):
            item = item[:-1]
        return item

    def replace_function(match):
        doi = match.group(1)
        inner = match.group(2)
        if not inner:
            # Empty prediction list: nothing to quote.
            return f"['{doi}', []]"
        activities = [_unquote(activity) for activity in inner.split(', ')]
        return f"['{doi}', [{', '.join(repr(activity) for activity in activities)}]]"

    # Use regex to find and replace the specific part we want to change
    pattern = r"\['([^']+)', \[(.*?)\]\]"
    modified_content = re.sub(pattern, replace_function, content)

    with open(output_file, 'w', encoding='utf-8') as outfile:
        outfile.write(modified_content)

def process_folder(folder_path):
    """
    Run process_file on every CSV in folder_path, writing processed_<name>.csv alongside it.

    Files that already start with 'processed_' are outputs of a previous run and are
    skipped, so re-running the cell does not generate processed_processed_* files.

    Parameters:
    - folder_path (str): Folder scanned (non-recursively) for *.csv files.
    """
    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith('.csv') or filename.startswith('processed_'):
            continue
        input_file = os.path.join(folder_path, filename)
        output_file = os.path.join(folder_path, f"processed_{filename}")
        process_file(input_file, output_file)
        print(f"Processed {filename}")

# Re-quote every result CSV in this folder; outputs are written next to the
# inputs as processed_<name>.csv.
folder_path = 'results/LLM_corr_k_full'
process_folder(folder_path)

In [None]:
#EVAL from NatUKE
import pandas as pd
from ast import literal_eval
import numpy as np

def hits_at(k, true, list_pred):
    """
    Hits@k: fraction of queries whose true value is among the first k predictions.

    ``true[i][1]`` is the expected value of query i and ``list_pred[i][1]`` its
    ranked predictions; only positions 0..k-1 are compared. Returns the mean of
    the per-query 1/0 hit indicators.
    """
    outcomes = []
    for position, truth in enumerate(true):
        ranked = list_pred[position][1]
        # `rank < k` short-circuits, so nothing past the cut-off is compared.
        found = any(rank < k and truth[1] == candidate
                    for rank, candidate in enumerate(ranked))
        outcomes.append(1 if found else 0)
    return np.mean(outcomes)

def mrr(true, list_pred):
    """
    Mean Reciprocal Rank over all queries.

    For query i the reciprocal rank is 1/(position+1) of the first prediction in
    ``list_pred[i][1]`` equal to ``true[i][1]``, and 0 when the true value is not
    predicted at all. Misses must count as 0: averaging only over hits overstates
    MRR, and gives NaN when nothing is hit. This matches the mrr implementations
    used in the NAME evaluation cells.
    """
    rrs = []
    for index_t, t in enumerate(true):
        rr = 0.0
        # get the list of predicteds that's on the second argument
        for index_lp, lp in enumerate(list_pred[index_t][1]):
            if t[1] == lp:
                rr = 1 / (index_lp + 1)
                break
        rrs.append(rr)
    return np.mean(rrs)


import os

path = 'results/LLM_corr_k_full'
file_name = "llm_results"
splits = [0.8]
#edge_groups = ['doi_name', 'doi_bioActivity', 'doi_collectionSpecie', 'doi_collectionSite', 'doi_collectionType']
edge_group = 'doi_collectionSpecie'
#algorithms = ['bert', 'deep_walk', 'node2vec', 'metapath2vec', 'regularization']
algorithms = ['gpt4']
# top_k noted per attribute:
#   'collectionSpecie': 50, 'collectionSite': 20, 'bioActivity': 5, 'collectionType': 1
#   'name': 50 -- doesn't work in this cell (see the "EVAL from NatUKE -- for NAME" cell)
k_at = [1]
dynamic_stages = ['1st', '2nd', '3rd', '4th']
n_iterations = 10  # one result file per iteration 0..n_iterations-1 and dynamic stage

metric_dir = '{}/metric_results'.format(path)
os.makedirs(metric_dir, exist_ok=True)  # to_csv below fails if the folder is missing


def _load_restored(algorithm, split, iteration, dynamic_stage):
    """Read one processed result CSV and literal_eval its 'true' and 'restored' columns."""
    frame = pd.read_csv("{}/processed_{}_{}_{}_{}_{}_{}.csv".format(
        path, file_name, algorithm, split, edge_group, iteration, dynamic_stage))
    frame['true'] = frame['true'].apply(literal_eval)
    frame['restored'] = frame['restored'].apply(literal_eval)
    return frame


# Read and parse every file once; hits@k (for each k) and MRR both reuse these frames.
restored_frames = {
    (algorithm, split, iteration, dynamic_stage): _load_restored(algorithm, split, iteration, dynamic_stage)
    for algorithm in algorithms
    for split in splits
    for iteration in range(n_iterations)
    for dynamic_stage in dynamic_stages
}

# hits@k (rows keep the algorithm -> k -> split -> iteration -> stage order)
hitsatk_df = {'k': [], 'algorithm': [], 'edge_group': [], 'split': [], 'dynamic_stage': [], 'value': []}
for algorithm in algorithms:
    for k in k_at:
        for split in splits:
            for iteration in range(n_iterations):
                for dynamic_stage in dynamic_stages:
                    restored_df = restored_frames[(algorithm, split, iteration, dynamic_stage)]
                    hitsatk_df['k'].append(k)
                    hitsatk_df['algorithm'].append(algorithm)
                    hitsatk_df['split'].append(split)
                    hitsatk_df['edge_group'].append(edge_group)
                    hitsatk_df['dynamic_stage'].append(dynamic_stage)
                    hitsatk_df['value'].append(hits_at(k, restored_df.true.to_list(), restored_df.restored.to_list()))

hitsatk_df = pd.DataFrame(hitsatk_df)
hitsatk_df.to_csv('{}/full_dynamic_hits@k_{}_{}.csv'.format(metric_dir, edge_group, file_name), index=False)
hitsatk_df_mean = hitsatk_df.groupby(by=['k', 'algorithm', 'split', 'edge_group', 'dynamic_stage'], as_index=False).mean()
hitsatk_df_std = hitsatk_df.groupby(by=['k', 'algorithm', 'split', 'edge_group', 'dynamic_stage'], as_index=False).std()
hitsatk_df_mean['std'] = hitsatk_df_std['value']
print(hitsatk_df_mean)

# mrr
mrr_df = {'algorithm': [], 'edge_group': [], 'split': [], 'dynamic_stage': [], 'value': []}
for algorithm in algorithms:
    for split in splits:
        for iteration in range(n_iterations):
            for dynamic_stage in dynamic_stages:
                restored_df = restored_frames[(algorithm, split, iteration, dynamic_stage)]
                mrr_df['algorithm'].append(algorithm)
                mrr_df['split'].append(split)
                mrr_df['edge_group'].append(edge_group)
                mrr_df['dynamic_stage'].append(dynamic_stage)
                mrr_df['value'].append(mrr(restored_df.true.to_list(), restored_df.restored.to_list()))

mrr_df = pd.DataFrame(mrr_df)
mrr_df.to_csv('{}/full_dynamic_mrr_{}_{}.csv'.format(metric_dir, edge_group, file_name), index=False)
mrr_df_mean = mrr_df.groupby(by=['algorithm', 'edge_group', 'split', 'dynamic_stage'], as_index=False).mean()
mrr_df_std = mrr_df.groupby(by=['algorithm', 'edge_group', 'split', 'dynamic_stage'], as_index=False).std()
mrr_df_mean['std'] = mrr_df_std['value']
print(mrr_df_mean)

# saving files
hitsatk_df_mean.to_csv('{}/dynamic_hits@k_{}_{}.csv'.format(metric_dir, edge_group, file_name), index=False)
mrr_df_mean.to_csv('{}/dynamic_mrr_{}_{}.csv'.format(metric_dir, edge_group, file_name), index=False)

In [None]:
#EVAL from NatUKE -- for NAME
import pandas as pd
import numpy as np
import ast
import re

def parse_list(s):
    """
    Parse a stringified list such as "['doi', 'name']" into a Python object.

    Tries ast.literal_eval first. If that fails, falls back to splitting the
    first bracketed span on commas, stripping whitespace and double quotes from
    each item. Returns [] for non-string cells (e.g. NaN from an empty CSV
    field) and for text the fallback cannot handle.
    """
    if not isinstance(s, str):
        return []
    try:
        return ast.literal_eval(s)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # If literal_eval fails, try a regex-based approach
        match = re.match(r'\[([^]]+)\]', s)
        if match:
            return [item.strip().strip('"') for item in match.group(1).split(',')]
        return []

def hits_at(k, true, list_pred):
    """
    Hits@k for the name evaluation.

    A pair (t, p) scores 1 when ``t[1]`` is within the first k entries of ``p[1]``.
    ``p`` must be a list with more than one element, otherwise the pair scores 0.
    A string ``p[1]`` is parsed with parse_list first. Returns the mean over
    zip(true, list_pred).
    """
    def _is_hit(truth, prediction):
        if not isinstance(prediction, list) or len(prediction) <= 1:
            return 0
        candidates = prediction[1]
        if isinstance(candidates, str):
            candidates = parse_list(candidates)
        return int(truth[1] in candidates[:k])

    return np.mean([_is_hit(truth, prediction) for truth, prediction in zip(true, list_pred)])

def mrr(true, list_pred):
    """
    Mean Reciprocal Rank for the name evaluation.

    Each pair (t, p) contributes 1/rank of ``t[1]`` inside ``p[1]`` (parsed with
    parse_list when it is a string). It contributes 0 when ``t[1]`` is absent or
    when ``p`` is not a list with more than one element.
    """
    reciprocal_ranks = []
    for truth, prediction in zip(true, list_pred):
        score = 0
        if isinstance(prediction, list) and len(prediction) > 1:
            candidates = prediction[1]
            if isinstance(candidates, str):
                candidates = parse_list(candidates)
            try:
                score = 1 / (candidates.index(truth[1]) + 1)
            except ValueError:
                score = 0
        reciprocal_ranks.append(score)
    return np.mean(reciprocal_ranks)

import os

path = 'results/LLM_corr_k_full'
file_name = "llm_results"
splits = [0.8]
edge_group = 'doi_name'
algorithms = ['gpt4']
k_at = [1]
dynamic_stages = ['1st', '2nd', '3rd', '4th']

# hits@k
hitsatk_df = {'k': [], 'algorithm': [], 'edge_group': [], 'split': [], 'dynamic_stage': [], 'value': []}
# mrr
mrr_df = {'algorithm': [], 'edge_group': [], 'split': [], 'dynamic_stage': [], 'value': []}

for algorithm in algorithms:
    for split in splits:
        for iteration in range(10):
            for dynamic_stage in dynamic_stages:
                file_path = f"{path}/fixed_names/processed_{file_name}_{algorithm}_{split}_{edge_group}_{iteration}_{dynamic_stage}.csv"
                print(f"Processing file: {file_path}")
                try:
                    # Read and parse each file once, however many k values are evaluated.
                    restored_df = pd.read_csv(file_path)
                    restored_df['true'] = restored_df['true'].apply(parse_list)
                    restored_df['restored'] = restored_df['restored'].apply(parse_list)
                    true_values = restored_df.true.to_list()
                    predicted_values = restored_df.restored.to_list()

                    hits_values = [(k, hits_at(k, true_values, predicted_values)) for k in k_at]
                    # MRR does not depend on k: record it once per file, otherwise each
                    # extra k adds a duplicate row and skews the std computed below.
                    mrr_value = mrr(true_values, predicted_values)
                except Exception as e:
                    # Best effort: report the missing/broken file and continue with the rest.
                    print(f"Error processing file {file_path}: {e}")
                    continue

                for k, hits_value in hits_values:
                    print(f"Hits@{k}: {hits_value}")
                    hitsatk_df['k'].append(k)
                    hitsatk_df['algorithm'].append(algorithm)
                    hitsatk_df['split'].append(split)
                    hitsatk_df['edge_group'].append(edge_group)
                    hitsatk_df['dynamic_stage'].append(dynamic_stage)
                    hitsatk_df['value'].append(hits_value)

                print(f"MRR: {mrr_value}")
                mrr_df['algorithm'].append(algorithm)
                mrr_df['split'].append(split)
                mrr_df['edge_group'].append(edge_group)
                mrr_df['dynamic_stage'].append(dynamic_stage)
                mrr_df['value'].append(mrr_value)

os.makedirs('{}/metric_results'.format(path), exist_ok=True)  # to_csv fails if it is missing

hitsatk_df = pd.DataFrame(hitsatk_df)
hitsatk_df.to_csv('{}/metric_results/full_dynamic_hits@k_{}_{}.csv'.format(path, edge_group, file_name), index=False)
hitsatk_df_mean = hitsatk_df.groupby(by=['k', 'algorithm', 'split', 'edge_group', 'dynamic_stage'], as_index=False).mean()
hitsatk_df_std = hitsatk_df.groupby(by=['k', 'algorithm', 'split', 'edge_group', 'dynamic_stage'], as_index=False).std()
hitsatk_df_mean['std'] = hitsatk_df_std['value']
print(hitsatk_df_mean)

mrr_df = pd.DataFrame(mrr_df)
mrr_df.to_csv('{}/metric_results/full_dynamic_mrr_{}_{}.csv'.format(path, edge_group, file_name), index=False)
mrr_df_mean = mrr_df.groupby(by=['algorithm', 'edge_group', 'split', 'dynamic_stage'], as_index=False).mean()
mrr_df_std = mrr_df.groupby(by=['algorithm', 'edge_group', 'split', 'dynamic_stage'], as_index=False).std()
mrr_df_mean['std'] = mrr_df_std['value']
print(mrr_df_mean)

# saving files
hitsatk_df_mean.to_csv('{}/metric_results/dynamic_hits@k_{}_{}.csv'.format(path, edge_group, file_name), index=False)
mrr_df_mean.to_csv('{}/metric_results/dynamic_mrr_{}_{}.csv'.format(path, edge_group, file_name), index=False)


In [8]:
#finetune name fix
import os
import csv
import re
import ast
import shutil
from glob import glob
from tqdm import tqdm
import logging

# Configure root logging for this cell: INFO and above, as "<time> - <LEVEL> - <message>".
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

# A single quote that is not both preceded and followed by a word character.
_STRAY_SINGLE_QUOTE = re.compile(r"(?<!\w)'|'(?!\w)")


def fix_quotes(s):
    """
    Turn string-delimiting single quotes into double quotes.

    A quote is left alone only when it sits between two word characters (an
    apostrophe inside a word, as in "it's"); every other single quote becomes '"'.
    """
    return _STRAY_SINGLE_QUOTE.sub('"', s)

def process_true_restored_field(cell):
    """
    Normalise a 'true' or 'restored' cell into text that ast.literal_eval can parse.

    - Cell parses as a list: each element is converted with str() (so a nested list
      becomes its string repr) and the repr of that list is returned.
    - Cell parses to any other value: it is wrapped as a one-element list
      '["value"]', with backslashes and double quotes escaped so the result is a
      valid literal.
    - Cell does not parse: quotes are repaired with fix_quotes, brackets are added if
      missing, and the content is split on its first comma into '["doi", "name"]'.
    - If that split fails too, the error is logged and the original cell is returned.
    """
    try:
        # Attempt to parse the cell as a Python literal
        parsed = ast.literal_eval(cell)
        if isinstance(parsed, list):
            # Ensure all elements are strings
            fixed = [str(item) for item in parsed]
            return str(fixed)
        else:
            # Not a list: wrap it as a single string. Escape backslashes first, then
            # double quotes, so values like 'say "hi"' stay parseable inside "...".
            text = str(parsed).replace('\\', '\\\\').replace('"', '\\"')
            return f'["{text}"]'
    except (ValueError, SyntaxError):
        # If parsing fails, attempt to fix quotes and reformat
        fixed_str = fix_quotes(cell)
        # Remove any leading/trailing whitespace and ensure it starts and ends with brackets
        fixed_str = fixed_str.strip()
        if not (fixed_str.startswith('[') and fixed_str.endswith(']')):
            fixed_str = f'[{fixed_str}]'
        # Split by comma, assuming the first comma separates DOI and name
        try:
            content = fixed_str[1:-1]
            doi, name = content.split(',', 1)
            doi = doi.strip().strip('"')
            name = name.strip().strip('"')
            # Escape any existing double quotes in the name
            name = name.replace('"', '\\"')
            return f'["{doi}", "{name}"]'
        except Exception as e:
            logging.error(f"Failed to process cell '{cell}': {e}")
            return cell  # Return the original cell if all else fails

def process_csv(input_file, output_file):
    """
    Copy one fine-tuned 'name' results CSV to output_file, repairing its list columns.

    The row read from line 1 (the header) is written unchanged. Rows with exactly
    three fields (true, restored, edge_type) have their first two fields normalised
    by process_true_restored_field. Any other row is logged as a warning and
    copied verbatim. All output fields are fully quoted.
    """
    with open(input_file, 'r', newline='', encoding='utf-8') as source, \
         open(output_file, 'w', newline='', encoding='utf-8') as target:
        reader = csv.reader(source)
        sink = csv.writer(target, quoting=csv.QUOTE_ALL)

        for row in reader:
            if reader.line_num == 1:
                sink.writerow(row)  # header
            elif len(row) != 3:
                logging.warning(f"Unexpected number of columns in row {reader.line_num}: {row}")
                sink.writerow(row)
            else:
                true_cell, restored_cell, edge_cell = row
                repaired_true = process_true_restored_field(true_cell)
                repaired_restored = process_true_restored_field(restored_cell)
                sink.writerow([repaired_true, repaired_restored, edge_cell])

    logging.info(f"Processed file saved as: {output_file}")

def process_all_files(input_directory, output_directory):
    """
    Repair every fine-tuned 'name' result CSV found directly in input_directory.

    output_directory is wiped and recreated first. Each selected file is
    rewritten there by process_csv under the same basename. If no file matches,
    a warning is logged and nothing else happens.
    """
    # Start from an empty output directory on every run.
    if os.path.exists(output_directory):
        shutil.rmtree(output_directory)
    os.makedirs(output_directory, exist_ok=True)

    name_pattern = re.compile(r'llm_results_ft_4o_0\.8_doi_name_0_1st\.csv$')
    name_files = []
    for candidate in glob(os.path.join(input_directory, '*.csv')):
        if name_pattern.search(os.path.basename(candidate)):
            name_files.append(candidate)

    if not name_files:
        logging.warning(f"No 'name' CSV files found in directory '{input_directory}'.")
        return

    for input_file in tqdm(name_files, desc="Processing 'name' CSV files"):
        filename = os.path.basename(input_file)
        logging.info(f"Processing {filename}...")
        process_csv(input_file, os.path.join(output_directory, filename))

# Usage: repair the raw fine-tuned 'name' results into a 'processed_names' subfolder.
if __name__ == "__main__":
    input_directory = 'llm_ft_results'
    output_directory = input_directory + '/processed_names'
    process_all_files(input_directory, output_directory)

Processing 'name' CSV files:   0%|          | 0/1 [00:00<?, ?it/s]2024-10-02 14:13:15,481 - INFO - Processing llm_results_ft_4o_0.8_doi_name_0_1st.csv...
2024-10-02 14:13:15,534 - INFO - Processed file saved as: llm_ft_results/processed_names/llm_results_ft_4o_0.8_doi_name_0_1st.csv
Processing 'name' CSV files: 100%|██████████| 1/1 [00:00<00:00, 18.56it/s]


In [10]:
# eval for name
import pandas as pd
import numpy as np
import ast
import re
import os
from tqdm import tqdm

def parse_list(s):
    """
    Parse a string representation of a list into an actual Python list.

    Order of attempts:
    1. ast.literal_eval on the text as-is.
    2. literal_eval again after wrapping the text in brackets if it lacks them.
    3. A plain comma split, stripping whitespace and double quotes from each item.
    A top-level tuple (what literal_eval returns for bare "'a', 'b'") is converted
    to a list. Non-string cells (e.g. NaN from an empty CSV field) give [].
    """
    if not isinstance(s, str):
        return []
    try:
        parsed = ast.literal_eval(s)
    except (ValueError, SyntaxError):
        # Attempt to fix common formatting issues
        s = s.strip()
        if not (s.startswith('[') and s.endswith(']')):
            s = f'[{s}]'
        try:
            parsed = ast.literal_eval(s)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # As a last resort, split by comma
            return [item.strip().strip('"') for item in s.strip('[]').split(',')]
    return list(parsed) if isinstance(parsed, tuple) else parsed

def hits_at(k, true, list_pred):
    """
    Calculate the Hits@k metric.

    A query is a hit (1) when its expected value ``t[1]`` is among the first k
    entries of the prediction list ``p[1]``. A prediction that is not a list of
    at least two elements is a miss (0). A prediction list stored as a string is
    parsed with parse_list first. Returns the mean over all (true, prediction) pairs.
    """
    scores = []
    for expected, predicted in zip(true, list_pred):
        if not (isinstance(predicted, list) and len(predicted) >= 2):
            scores.append(0)
            continue
        ranked = predicted[1]
        if isinstance(ranked, str):
            ranked = parse_list(ranked)
        scores.append(int(expected[1] in ranked[:k]))
    return np.mean(scores)

def mrr(true, list_pred):
    """
    Calculate the Mean Reciprocal Rank (MRR) metric.

    Each (true, prediction) pair contributes 1/rank of ``t[1]`` within ``p[1]``.
    A string ``p[1]`` is parsed with parse_list first. The pair contributes 0
    when the value is absent or the prediction is not a list of at least two
    elements.
    """
    def reciprocal_rank(expected, predicted):
        if not isinstance(predicted, list) or len(predicted) < 2:
            return 0
        ranked = predicted[1]
        if isinstance(ranked, str):
            ranked = parse_list(ranked)
        try:
            return 1 / (ranked.index(expected[1]) + 1)
        except ValueError:
            return 0

    return np.mean([reciprocal_rank(e, p) for e, p in zip(true, list_pred)])

def evaluate_attribute(input_file, output_dir, attribute, k_at=(1, 3, 5, 10, 25, 50)):
    """
    Evaluate the specified attribute using Hits@k and MRR metrics.

    Reads input_file, parses its 'true' and 'restored' columns with parse_list,
    computes Hits@k for every k in k_at plus MRR, and writes
    'hits@k_<attribute>_evaluation.csv' and 'mrr_<attribute>_evaluation.csv'
    into output_dir (created if missing).

    Parameters:
    - input_file (str): CSV with 'true' and 'restored' columns.
    - output_dir (str): Directory for the two metric CSVs.
    - attribute (str): Label used in log messages and output filenames.
    - k_at (iterable of int): Hits@k cut-offs; defaults to (1, 3, 5, 10, 25, 50).

    Returns:
    - (hitsatk_df, mrr_df) DataFrames on success, or None when the file cannot be
      read or lacks the required columns (the problem is logged).
    """
    import logging  # the cell only imports logging inside its __main__ block

    logging.info(f"Starting evaluation for attribute: {attribute}")

    try:
        restored_df = pd.read_csv(input_file)
    except FileNotFoundError:
        logging.error(f"The file {input_file} does not exist. Please check the path.")
        return
    except Exception as e:
        logging.error(f"Error reading {input_file}: {e}")
        return

    # Ensure 'true' and 'restored' columns exist
    if 'true' not in restored_df.columns or 'restored' not in restored_df.columns:
        logging.error(f"The input CSV must contain 'true' and 'restored' columns.")
        return

    # Parse 'true' and 'restored' columns
    restored_df['true'] = restored_df['true'].apply(parse_list)
    restored_df['restored'] = restored_df['restored'].apply(parse_list)

    # Parsed rows; presumably [doi, true_value] and [doi, [pred1, pred2, ...]] — verify on the data
    true_values = restored_df['true'].to_list()
    predicted_values = restored_df['restored'].to_list()

    hitsatk_records = []
    for k in k_at:
        mean_hits = hits_at(k, true_values, predicted_values)
        hitsatk_records.append({
            'k': k,
            'metric': 'hits@k',
            'value': mean_hits
        })
        logging.info(f"Hits@{k}: {mean_hits}")

    mean_mrr = mrr(true_values, predicted_values)
    mrr_records = [{
        'metric': 'mrr',
        'value': mean_mrr
    }]
    logging.info(f"MRR: {mean_mrr}")

    hitsatk_df = pd.DataFrame(hitsatk_records)
    mrr_df = pd.DataFrame(mrr_records)

    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)

    hitsatk_output_path = os.path.join(output_dir, f'hits@k_{attribute}_evaluation.csv')
    mrr_output_path = os.path.join(output_dir, f'mrr_{attribute}_evaluation.csv')

    hitsatk_df.to_csv(hitsatk_output_path, index=False)
    mrr_df.to_csv(mrr_output_path, index=False)

    logging.info(f"Hits@k results saved to {hitsatk_output_path}")
    logging.info(f"MRR results saved to {mrr_output_path}")

    return hitsatk_df, mrr_df

# Usage
if __name__ == "__main__":
    import logging
    
    # Set up logging
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    
    # Define paths
    path = 'llm_ft_results'
    file_name = "llm_results_ft_4o_0.8_doi_name_0_1st.csv"
    # NOTE(review): this is the raw fine-tuned result file; the "finetune name fix" cell
    # writes a repaired copy to llm_ft_results/processed_names/ — confirm which one
    # should be evaluated.
    name_csv_path = os.path.join(path, file_name)
    output_dir = 'ft_evaluation_results'
    attribute = 'name'  # label used in log messages and in the output CSV filenames
    
    # Evaluate the attribute
    evaluate_attribute(name_csv_path, output_dir, attribute)

2024-10-02 14:15:47,622 - INFO - Starting evaluation for attribute: name
2024-10-02 14:15:47,675 - INFO - Hits@1: 0.13031161473087818
2024-10-02 14:15:47,696 - INFO - Hits@3: 0.2521246458923513
2024-10-02 14:15:47,719 - INFO - Hits@5: 0.3002832861189802
2024-10-02 14:15:47,744 - INFO - Hits@10: 0.36827195467422097
2024-10-02 14:15:47,766 - INFO - Hits@25: 0.42209631728045327
2024-10-02 14:15:47,787 - INFO - Hits@50: 0.45042492917847027
2024-10-02 14:15:47,806 - INFO - MRR: 0.20809734957332549
2024-10-02 14:15:47,808 - INFO - Hits@k results saved to evaluation_results/hits@k_name_evaluation.csv
2024-10-02 14:15:47,809 - INFO - MRR results saved to evaluation_results/mrr_name_evaluation.csv
