In [None]:
import os, json, time, pickle
import numpy as np
import random
import hashlib, uuid
from llama_index.core import Document
from llama_index.core.node_parser import SentenceSplitter, SemanticSplitterNodeParser

class SyntheticDataGenerator:
    """
    Generate synthetic training data from a list of text documents and query prompts.

    Each document is split into text chunks ("nodes"). For every prompt template in
    ``query_list`` (a "task") and every node, the LLM is queried with the template
    formatted with the node text. The raw completion is turned into a list of strings
    by ``task_parser``, which a subclass must implement.

    Output files written to ``output_dir``:

    * ``nodes.pkl`` (configurable) -- pickle cache of the parsed nodes.
    * ``corpus.json`` -- ``{node_id: node_text}`` for every node.
    * ``dataset_<task>.json`` -- one file per task, mapping a random hash id to
      ``{"response": ..., "hash_id": ..., "relevant_doc": node_id}``. The
      ``relevant_doc`` value links each generated example back to ``corpus.json``.
    """

    def __init__(self, embed_model, llm, query_list, max_gen_attempts=2, output_dir="./synth_dataset", parser=None):
        """
        Args:
            embed_model: embedding model for semantic splitting. When None (and no
                ``parser`` is given), a ``SentenceSplitter(chunk_size=500, chunk_overlap=50)``
                is used instead.
            llm: object whose ``complete(prompt)`` returns a value with a ``.text`` attribute.
            query_list: prompt templates, one per task. Each is formatted with the
                keyword arguments ``text`` and ``num_generate``.
            max_gen_attempts: maximum number of LLM calls per node when the response
                cannot be parsed. At least one call is always made.
            output_dir: directory for all output files. It is created if missing.
            parser: optional node parser (anything with ``get_nodes_from_documents``)
                that overrides the default splitter choice above.
        """
        self.embed_model = embed_model
        self.llm = llm
        self.query_list = query_list  # different types of queries to generate for each text chunk
        self.output_dir = output_dir
        if parser is not None:
            self.parser = parser
        elif embed_model is None:
            self.parser = SentenceSplitter(chunk_size=500, chunk_overlap=50)
        else:
            self.parser = SemanticSplitterNodeParser(buffer_size=1, embed_model=embed_model)
        self._ensure_output_dir_exists()
        self.max_gen_attempts = max_gen_attempts  # LLM calls allowed per chunk before giving up
        self.parse_attempt_counters = {}  # "{task}_{node_id}" -> number of failed parses so far

    def _ensure_output_dir_exists(self):
        """Create ``output_dir`` (and any parents) if it does not already exist."""
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs
        os.makedirs(self.output_dir, exist_ok=True)

    def generate_random_hash(self):
        """Return a random 64-character hex id (SHA-256 of a fresh uuid4)."""
        random_uuid = uuid.uuid4()
        hash_object = hashlib.sha256(random_uuid.bytes)
        return hash_object.hexdigest()

    def generate_synthetic_data(self, documents, num_generate=1, file_name="nodes.pkl"):
        """
        Split ``documents`` into nodes and generate synthetic data for every task.

        Args:
            documents: list of raw text strings, one per document.
            num_generate: number of synthetic examples requested per text chunk. It is
                passed to each prompt template as ``{num_generate}``.
            file_name: name of the node pickle cache inside ``output_dir``.

        If the node cache and/or ``corpus.json`` already exist they are reused as-is.
        In that case ``documents`` is not re-parsed.
        """
        print("Parsing documents...")
        nodes_path = os.path.join(self.output_dir, file_name)
        corpus_path = os.path.join(self.output_dir, "corpus.json")

        # The corpus is the same for all tasks, so the nodes are parsed once and cached.
        if os.path.exists(nodes_path):
            # NOTE: pickle.load can execute arbitrary code. Only load caches written by
            # this class; never point output_dir at files from an untrusted source.
            with open(nodes_path, 'rb') as f:
                nodes = pickle.load(f)
        else:
            # Document ids must be strings for llama_index's pydantic schema.
            docs = [Document(doc_id=str(i), text=doc) for i, doc in enumerate(documents)]
            nodes = self.parser.get_nodes_from_documents(docs, show_progress=False)
            with open(nodes_path, 'wb') as f:
                pickle.dump(nodes, f)

        # corpus.json is the first of two files needed for embedding fine-tuning.
        if os.path.exists(corpus_path):
            with open(corpus_path, 'r', encoding='utf-8') as f:
                corpus = json.load(f)
        else:
            # node.node_id is the connection between a chunk and the data generated from it
            corpus = {node.node_id: node.text for node in nodes}
            with open(corpus_path, 'w', encoding='utf-8') as f:
                json.dump(corpus, f, ensure_ascii=False, indent=2)
        print(f"Number of nodes: {len(nodes)}")

        # A task is a specific type of question or instruction, e.g. "Generate a query
        # from the text in the form of a socratic question". The number of tasks depends
        # on your final goal for the LLM.
        for task in self.query_list:
            self._process_task(task, nodes, num_generate, corpus)

    @staticmethod
    def _dataset_file_name(task):
        """
        Return a filesystem-safe ``dataset_<task>.json`` file name for ``task``.

        The task is the whole prompt template. It can contain path separators or be
        longer than a filename may be (255 bytes on common filesystems). Names that
        are already usable are returned unchanged. Path separators are replaced with
        "_". Over-long names are truncated and suffixed with a short hash of the task,
        so distinct tasks still map to distinct files.
        """
        name = str(task)
        for sep in (os.sep, os.altsep):
            if sep:
                name = name.replace(sep, "_")
        file_name = f"dataset_{name}.json"
        if len(file_name.encode("utf-8")) > 255:
            digest = hashlib.sha256(str(task).encode("utf-8")).hexdigest()[:16]
            # cut on a byte budget, dropping any multi-byte char split by the cut
            prefix = name.encode("utf-8")[:200].decode("utf-8", errors="ignore")
            file_name = f"dataset_{prefix}_{digest}.json"
        return file_name

    def _process_task(self, task, nodes, num_generate, corpus):
        """
        Generate data for every node for one task.

        Results are flushed periodically to that task's dataset file.
        """
        print(f"  Processing {task}...")
        # This is the second of the two files needed for embedding training;
        # one json file is produced per task.
        dataset_output_path = os.path.join(self.output_dir, self._dataset_file_name(task))

        collection = {}
        # Flush every `save_every` nodes (and after the last node). A crash then loses
        # at most one batch, and the in-memory collection stays small.
        save_every = max(5, len(nodes) // 10)
        last_index = len(nodes) - 1
        for node_counter, node in enumerate(nodes):
            self._process_one_node(task, node, num_generate, corpus, collection)
            if node_counter % save_every == 0 or node_counter == last_index:
                self._save_to_file(dataset_output_path, collection)
                collection = {}  # start a fresh batch; the saved data lives on disk now

    def task_parser(self, response):
        """
        Parse the raw LLM completion text into generated examples.

        Subclasses must implement this. Return a list of strings (ideally
        ``num_generate`` of them) or a single string. Return an empty list (or any
        other falsy value) to signal that the response could not be parsed, which
        triggers a retry.
        """
        raise NotImplementedError("task_parser method must be implemented in a subclass")

    def _process_one_node(self, task, node, num_generate, corpus, collection):
        """
        Query the LLM for one node until the response parses.

        Stops early once ``max_gen_attempts`` failed parses have occurred. Parsed
        examples are added to ``collection``.
        """
        node_key = f"{task}_{node.node_id}"
        self.parse_attempt_counters[node_key] = 0  # reset the counter for this node

        # num_generate is passed into the prompt so the LLM produces that many distinct
        # examples in one call, rather than the same response num_generate times.
        prompt = task.format(text=node.text, num_generate=num_generate)

        # Safeguard against the LLM producing responses that cannot be parsed.
        while True:
            response = self.llm.complete(prompt)
            clean_response = self.task_parser(response.text)
            # Any falsy result ([], None, "", ()) counts as a failed parse.
            if clean_response:
                self._handle_successful_parse(clean_response, node, task, corpus, collection)
                return
            if self._handle_failed_parse(task, node):
                return

    def _handle_failed_parse(self, task, node):
        """Record a failed parse; return True when no attempts remain for this node."""
        node_key = f"{task}_{node.node_id}"
        self.parse_attempt_counters[node_key] += 1

        # `>=` (not `==`) so that max_gen_attempts <= 0 still terminates after one call
        if self.parse_attempt_counters[node_key] >= self.max_gen_attempts:
            print(f"    Failed to parse response for node id {node.node_id}...")
            return True
        return False

    def _handle_successful_parse(self, clean_response, node, task, corpus, collection):
        """
        Store each parsed example in ``collection`` under a fresh random hash id.

        Each entry is linked to its source chunk via ``relevant_doc``. ``task`` is
        accepted for interface compatibility but is not stored.
        """
        if isinstance(clean_response, str):
            clean_response = [clean_response]

        # Keeps the in-memory corpus in sync with the node. corpus.json itself is
        # written only once, in generate_synthetic_data.
        corpus[node.node_id] = node.text
        for example in clean_response:
            hash_id = self.generate_random_hash()  # each generated example gets a unique id
            collection[hash_id] = {
                'response': example,
                'hash_id': hash_id,
                'relevant_doc': node.node_id,  # link between the generated data and corpus.json
            }

    def _save_to_file(self, file_path, data):
        """
        Merge ``data`` into the JSON object stored at ``file_path``.

        Keys in ``data`` overwrite existing keys. The file is replaced atomically.
        """
        existing_data = {}
        if os.path.exists(file_path):
            with open(file_path, 'r', encoding='utf-8') as f:
                existing_data = json.load(f)

        merged_data = {**existing_data, **data}

        # Write to a short-named temp file in the same directory, then swap it in with
        # os.replace. An interrupted write then cannot corrupt the batches already saved.
        tmp_path = os.path.join(os.path.dirname(file_path), f".{uuid.uuid4().hex}.tmp")
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(merged_data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, file_path)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

