# In [None]:  (Jupyter cell prompt left over from the notebook export)
import openai
import os
import pandas as pd
import numpy as np
import logging
import json

# Load the run settings from config.json (looked up relative to the current
# working directory). JSON is read explicitly as UTF-8 so the result does not
# depend on the platform's default locale encoding.
with open('config.json', 'r', encoding='utf-8') as f:
    config = json.load(f)

# All artefacts of this run (log file, embedding tables) live in a dedicated
# sub-directory of the configured working directory.
path = config['working_dir']
output_dir = os.path.join(path, 'output_embedding_openai')  # output directory
os.makedirs(output_dir, exist_ok=True)

# Send INFO-and-above records of the root logger to a log file inside the
# output directory, prefixed with timestamp and level.
logging.basicConfig(filename=os.path.join(output_dir, 'output_openai.log'),
                    level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

def get_embedding(text, model="text-embedding-3-small"):
    """Return the mean OpenAI embedding of ``text``, or ``None`` if unavailable.

    The text is cut into consecutive slices of at most 8191 characters. Each
    slice is embedded with its own ``openai.embeddings.create`` call and the
    resulting vectors are averaged element-wise (unweighted, so a short final
    slice counts as much as a full one).

    Args:
        text: The string to embed. Non-string values (for example a missing
            cell read as NaN) and the empty string yield ``None`` without any
            API call.
        model: Name of the embedding model passed to the API.

    Returns:
        The averaged embedding as a list of floats, or ``None`` when there is
        nothing to embed or when any slice fails to embed.
    """
    # Avoid useless API calls when there is nothing to embed.
    if not isinstance(text, str) or not text:
        return None

    # NOTE(review): slices are measured in characters, not tokens; 8191 is
    # presumably chosen to stay within the model's input limit - confirm
    # against the API documentation for inputs with many multi-token characters.
    max_length = 8191
    text_chunks = [text[i:i + max_length] for i in range(0, len(text), max_length)]

    embeddings = []
    for chunk in text_chunks:
        try:
            response = openai.embeddings.create(
                input=chunk,
                model=model,
                # TODO: pass dimensions=<n> here to request a custom vector size.
            )
            embeddings.append(response.data[0].embedding)
        except Exception as e:
            # One failed slice invalidates the whole mean, so give up on this
            # text. Record the failure (with traceback) in the log as well as
            # on stdout, otherwise it is lost once the console output is gone.
            print(f"Errore nel calcolo di un frammento della sequenza: {e}")
            logging.exception("Errore nel calcolo di un frammento della sequenza")
            return None

    # At least one slice was embedded at this point; average across slices.
    return np.mean(embeddings, axis=0).tolist()

# Configure the OpenAI client. The key itself is deliberately NOT printed:
# anything echoed here ends up in saved notebook output and console history.
openai.api_key = config['openai_api_key']
print(f"OpenAI API key configurata: {bool(openai.api_key)}")

# Load the dataset (tab-separated node table).
file_path = config['nodes_file_path']
df = pd.read_csv(file_path, sep="\t")
print(df.shape)
logging.info(df.shape)

# Split by node type into two subsets that are processed separately.
df_text = df[df["type"].isin(["Phenotype", "Disease", "Genomic feature"])]
print(df_text.shape)
logging.info(df_text.shape)
df_sequence = df[df["type"].isin(["Gene", "miRNA"])]
print(df_sequence.shape)
logging.info(df_sequence.shape)

# Combined size of the two subsets, for comparison with the full table above.
print(df_text.shape[0]+df_sequence.shape[0])
logging.info(df_text.shape[0]+df_sequence.shape[0])

# Index both subsets by node name (the "name" column is kept as well).
df_text.index = df_text["name"]
df_sequence.index = df_sequence["name"]

def _read_checkpoint_names(output_file):
    """Return the set of names found in the first column of ``output_file``."""
    names = set()
    with open(output_file, 'r', encoding='utf-8') as f:
        next(f, None)  # skip the header line
        for line in f:
            if line.strip():
                names.add(line.rstrip("\n").split("\t", 1)[0])
    return names


def process_in_batches(df, embedding_column, get_embedding_func, output_file, batch_size=5):
    """Embed one column of ``df`` batch by batch, appending rows to a TSV file.

    The output file doubles as a checkpoint. If it already exists and is not
    empty, every row whose ``name`` is already present in it is skipped, so an
    interrupted run can be resumed. Rows whose embedding failed in an earlier
    run were never written and are therefore retried. Resuming by name (rather
    than by counting checkpoint rows) keeps the resume point correct even when
    some rows failed, which would otherwise duplicate rows already written.

    Note: on resume a name found in the checkpoint is skipped wherever it
    occurs, so names are expected to be unique within ``df``.

    Args:
        df: DataFrame with at least the columns ``name``, ``type`` and
            ``embedding_column``.
        embedding_column: Name of the column whose values are embedded.
        get_embedding_func: Callable taking the cell value and returning the
            embedding, or ``None`` when it could not be computed.
        output_file: Path of the tab-separated output / checkpoint file with
            columns ``name``, ``type``, ``len_seq``, ``embedding``.
        batch_size: Number of rows handled per open/append/close cycle of the
            output file.

    Returns:
        None. Results are written to ``output_file``; progress goes to stdout
        and to the log.
    """
    # An existing but empty file is treated as "no checkpoint" so that the
    # header is still written.
    if os.path.exists(output_file) and os.path.getsize(output_file) > 0:
        done_names = _read_checkpoint_names(output_file)
        print(f"Checkpoint trovato. Riprendo da riga {len(done_names)}")
        logging.info(f"Checkpoint trovato. Riprendo da riga {len(done_names)}")
    else:
        print("Nessun checkpoint trovato. Inizio dall'inizio.")
        logging.info("Nessun checkpoint trovato. Inizio dall'inizio.")
        done_names = set()
        # Create the output file with its header row.
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write("name\ttype\tlen_seq\tembedding\n")

    total = len(df)
    for start in range(0, total, batch_size):
        batch = df.iloc[start:start + batch_size]

        # Keep each row's position in df so progress messages stay meaningful
        # after a resume.
        pending = [
            (start + offset, row)
            for offset, (_, row) in enumerate(batch.iterrows())
            if str(row['name']) not in done_names
        ]
        if not pending:
            continue

        last = min(start + batch_size, total) - 1
        print(f"Processando righe da {start} a {last}...")
        logging.info(f"Processando righe da {start} a {last}...")

        # Re-opening the file for every batch closes (and thus flushes) it
        # regularly, so an interruption loses at most the current batch.
        with open(output_file, 'a', encoding='utf-8') as f:
            for position, row in pending:
                sequence_name = row['name']
                sequence_type = row['type']
                value = row[embedding_column]
                # Non-string cells (e.g. missing values) count as length 0.
                sequence_len = len(value) if isinstance(value, str) else 0

                if sequence_len == 0:
                    logging.warning(f"Processing {position + 1}/{total} ({sequence_type}) len {sequence_len} - {sequence_name} -- Embedding non calcolato per la sequenza {sequence_name} di lunghezza {str(sequence_len)}.")
                else:
                    logging.info(f"Processing {position + 1}/{total} ({sequence_type}) len {sequence_len} - {sequence_name}")

                embedding = get_embedding_func(value)

                # A non-empty value without an embedding is a failure: leave
                # the row out of the file so that a later run retries it.
                if sequence_len != 0 and embedding is None:
                    logging.warning(f"Embedding non calcolato per la sequenza {sequence_name} di lunghezza {str(sequence_len)}.")
                    continue

                # Write the embedding straight to the file.
                f.write(f"{sequence_name}\t{sequence_type}\t{sequence_len}\t{embedding}\n")

    print("Processing completato.")
    logging.info("Processing completato.")


# Generate the embeddings in batches: first the text subset (from its
# "Description" column), then the sequence subset (from its "Sequence" column).
# Each job writes its own TSV file inside the output directory.
for frame, column, filename in (
    (df_text, "Description", "embedded_text.tsv"),
    (df_sequence, "Sequence", "embedded_sequence.tsv"),
):
    process_in_batches(frame, column, get_embedding, os.path.join(output_dir, filename))