In [None]:
import sys
import os
# Make the parent of the notebook's working directory importable, so the
# `utils` module imported below can be found there if it lives one level up.
current_dir = os.getcwd()

sys.path.append(os.path.dirname(current_dir))

# Imported only after the sys.path change above — presumably because `utils`
# lives in the parent directory; verify.
from utils import setup_api_key

# NOTE(review): presumably loads API credentials (e.g. for ChatOpenAI) from this
# JSON config — confirm in utils.setup_api_key. The path is relative, so it
# likely depends on the current working directory.
setup_api_key(file_path='../../config.json')

In [None]:
import itertools
import json
import os
from typing import List

from datasets import load_from_disk
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from tqdm import tqdm

def get_model_id(model_type, run_name, project_name, checkpoint_id):
    """Return the checkpoint path for a trained model.

    The segments are joined with ``os.path.join`` in the order
    model_type / "model_output" / project_name / run_name / checkpoint_id.
    Note that project_name precedes run_name in the path even though the
    arguments list run_name first.
    """
    segments = [model_type, "model_output", project_name, run_name, checkpoint_id]
    return os.path.join(*segments)

# Per-project settings, keyed by project id: the project name and the on-disk
# locations of the train/test splits (read with `load_from_disk`).
project_config = {
    project_key: {
        "project_name": name,
        "train_dataset_path": train_path,
        "test_dataset_path": test_path,
    }
    for project_key, name, train_path, test_path in (
        (
            "survey-json",
            "survey-json-model-inst",
            "../datasets/survey_json_datasets_instruction_train",
            "../datasets/survey_json_datasets_instruction_test",
        ),
        (
            "schema",
            "schema-model-inst",
            "../datasets/schema_datasets/schema_data_train",
            "../datasets/schema_datasets/schema_data_test",
        ),
        (
            "paraloq",
            "paraloq-model-inst",
            "../datasets/paraloq/paraloq_data_train",
            "../datasets/paraloq/paraloq_data_test",
        ),
        (
            "nous",
            "nous-model-inst",
            "../datasets/nous/nous_data_train",
            "../datasets/nous/nous_data_test",
        ),
    )
}

def load_project(project="schema"):
    """Load the saved test and train splits for *project* with ``load_from_disk``.

    Paths come from ``project_config[project]``; an unknown project raises
    KeyError. Returns ``(test_dataset, train_dataset)`` — test split first.
    """
    paths = project_config[project]
    test_split = load_from_disk(paths["test_dataset_path"])
    train_split = load_from_disk(paths["train_dataset_path"])
    return test_split, train_split

# Task sentence stated once at the top of the few-shot prompt for each project.
_PROJECT_TASKS = {
    "schema": "Convert the raw data to ld+json format.",
    "survey-json": "Convert the question list to survey json.",
    "paraloq": "Generate the structured response for the given query.",
    "nous": "Generate the structured response for the given query.",
}


def _split_example(text, project):
    """Split one ``<s>[INST] instruction [/INST] response`` record.

    Returns ``(instruction, response)``, both stripped of surrounding
    whitespace. The ``<s>[INST]`` marker is removed from the instruction; for
    the "schema" project the task sentence is removed as well, because the
    prompt header already states it. Only the first ``[/INST]`` separates the
    two parts, so a response that itself contains ``[/INST]`` stays intact.

    Raises:
        ValueError: if *text* contains no ``[/INST]`` separator.
    """
    instruction, response = text.split("[/INST]", 1)
    instruction = instruction.replace("<s>[INST]", "").strip()
    if project == "schema":
        instruction = instruction.replace(_PROJECT_TASKS["schema"], "").strip()
    return instruction, response.strip()


def _build_prompt_header(task, examples):
    """Return the few-shot prefix: the task description plus numbered examples."""
    parts = [
        "Given few examples of instructions and responses from the training dataset. "
        f"The task is to generate a response for the given instruction. {task}\n\n"
    ]
    for number, (instruction, response) in enumerate(examples, start=1):
        parts.append(
            f"Example {number}:\nInstruction: {instruction}\nResponse: {response}\n\n"
        )
    return "".join(parts)


def get_result(project="schema", example_size=3):
    """Generate few-shot ChatOpenAI responses for a project's test split.

    The first ``example_size`` rows of the train split become few-shot
    examples. Every test row is turned into a prompt and sent to
    ``ChatOpenAI(temperature=0)``. The generated and reference responses are
    written to ``./{project}_instruction_generation.json``.

    Args:
        project: key into ``project_config`` (unknown keys raise KeyError in
            ``load_project``).
        example_size: number of few-shot examples taken from the train split;
            must be non-negative.

    Returns:
        dict with ``"generated_responses"`` (model outputs) and
        ``"actual_responses"`` (stripped reference responses), index-aligned
        with the test split — the same data that is written to disk.
    """
    test_dataset, train_dataset = load_project(project=project)

    examples = [
        _split_example(row["text"], project)
        for row in itertools.islice(train_dataset, example_size)
    ]
    header = _build_prompt_header(_PROJECT_TASKS.get(project, ""), examples)

    # Test instructions are parsed exactly like the few-shot examples so the
    # query has the same format as the examples preceding it.
    prompts = []
    actual_responses = []
    for row in tqdm(test_dataset):
        instruction, response = _split_example(row["text"], project)
        prompts.append(f"{header}\nInstruction: {instruction}\nResponse:")
        actual_responses.append(response)

    chain_model = ChatOpenAI(temperature=0)
    generated_responses = [chain_model.invoke(prompt).content for prompt in tqdm(prompts)]

    results = {
        "generated_responses": generated_responses,
        "actual_responses": actual_responses,
    }
    with open(f"./{project}_instruction_generation.json", "w", encoding="utf-8") as f:
        json.dump(results, f)

    # Previously returned the leftover loop variable (last prompt/response pair).
    return results

import re
import json

def extract_json(sample):
    """Decode the JSON inside the first ld+json ``<script>`` tag of *sample*.

    Matches only the exact opening tag ``<script type="application/ld+json">``.
    Whitespace around the payload is ignored, and the payload may span
    multiple lines.

    Returns:
        The decoded JSON value, or None when no such tag is found.

    Raises:
        json.JSONDecodeError: if the tag's contents are not valid JSON.
    """
    found = re.search(
        r'<script type="application/ld\+json">\s*(.*?)\s*</script>',
        sample,
        re.DOTALL,
    )
    if found is None:
        return None
    payload = found.group(1)
    return json.loads(payload)