In [1]:
import asyncio
import re
import os
import pandas as pd
from openai import AsyncOpenAI
from dotenv import load_dotenv

# Load environment variables from a local .env file (via python-dotenv), so that
# CoT() can read API_KEY / BASE_URL without credentials being hardcoded here.
load_dotenv()

class CoT:
    """Chain-of-thought prompting client that asks an LLM for plotting code.

    The system prompt asks the model to think step by step and return Python
    (matplotlib/seaborn) code inside a fenced ```python block; that code is
    extracted from the reply and returned as a string.
    """

    # Upper bound (seconds) on the exponential back-off between retries.
    _MAX_RETRY_DELAY = 60.0

    # First fenced ```python block. `[^\n]*` tolerates trailing spaces or a '\r'
    # (CRLF line endings) after the language tag, which a literal `\n` rejected.
    _CODE_BLOCK_RE = re.compile(r"```python[^\n]*\n(.*?)```", flags=re.DOTALL)

    def __init__(self, api_key=None, base_url=None, model="gpt-4o-mini", system_prompt=None,
                 temperature=0.2, max_retries=None, retry_delay=1.0):
        """Create the client.

        Args:
            api_key: API key; falls back to the ``API_KEY`` environment variable.
            base_url: API base URL; falls back to the ``BASE_URL`` environment variable.
            model: Chat model name sent to the completions endpoint.
            system_prompt: System message; a default CoT plotting prompt is used when falsy.
            temperature: Sampling temperature.
            max_retries: Retries allowed after a failed API call; ``None`` retries
                indefinitely (the historical behaviour).
            retry_delay: Initial back-off in seconds, doubled after each failure and
                capped at ``_MAX_RETRY_DELAY``.
        """
        # Use environment variables if no parameters are provided
        self.api_key = api_key or os.getenv("API_KEY")
        self.base_url = base_url or os.getenv("BASE_URL")
        self.model = model
        self.system_prompt = system_prompt or "Think step by step. Based on the user's query, generate Python code using `matplotlib.pyplot` and 'seaborn' to create the requested plot. Ensure the code is outputted within the Markdown format like ```python\n...```."
        self.temperature = temperature
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )

    def describe_data(self, data_path):
        """Return a textual description of a dataset (currently just ``str(data_path)``)."""
        return str(data_path)

    async def _create_completion(self, full_query):
        """Send one chat-completion request, retrying failures with exponential back-off.

        Only the network call is retried. Once ``self.max_retries`` retries are
        exhausted, the last exception is re-raised. That never happens when
        ``max_retries`` is None.
        """
        attempt = 0
        while True:
            try:
                return await self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": self.system_prompt},
                        {"role": "user", "content": full_query}
                    ],
                    temperature=self.temperature
                )
            except Exception as e:
                attempt += 1
                if self.max_retries is not None and attempt > self.max_retries:
                    raise
                print(f"API call failed with error: {e}. Retrying...")
                # Back off instead of hammering the API (e.g. on rate limits); the
                # exponent is capped so the float multiply can never overflow.
                delay = self.retry_delay * (2 ** min(attempt - 1, 16))
                await asyncio.sleep(min(delay, self._MAX_RETRY_DELAY))

    async def call_openai_api(self, user_query, data_description):
        """Ask the model for plotting code and extract it from the reply.

        Args:
            user_query: Natural-language plotting request.
            data_description: Value interpolated after "Data Path:" in the user
                message (``get_code_content`` passes a list of paths or None).

        Returns:
            The stripped contents of the first ```python block, or None when the
            reply has no such block or no text content.
        """
        full_query = f"Data Path: {data_description}\nUser Query: {user_query}"
        response = await self._create_completion(full_query)

        # Some OpenAI-compatible backends omit usage; don't crash (and formerly
        # retry forever) on that.
        usage = response.usage
        if usage is not None:
            print({
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens
            })

        # `content` can be None (e.g. refusals); treat it as an empty reply.
        response_text = response.choices[0].message.content or ""
        match = self._CODE_BLOCK_RE.search(response_text)
        return match.group(1).strip() if match else None

    async def get_code_content(self, user_query, data_path_list, img_file_path=None):
        """Generate plotting code for ``user_query``.

        Args:
            user_query: Natural-language plotting request.
            data_path_list: Iterable of dataset paths (relative paths are resolved
                against the current working directory), or a falsy value for none.
            img_file_path: Currently unused; kept for interface compatibility.

        Returns:
            The generated code string, or None if the model returned no code block.
        """
        if data_path_list:
            # Resolve every path. Previously the append sat outside the loop, so
            # only the last path was ever sent to the model.
            data_description = [os.path.abspath(path) for path in data_path_list]
        else:
            data_description = None
        return await self.call_openai_api(user_query, data_description)

In [2]:
# Benchmark table with the natural-language plotting instructions.
# os.path.join keeps the relative path portable: the previous raw string used
# backslash separators, which only resolve as directories on Windows.
dataset_path = os.path.join("..", "..", "dataset", "matplotbench_data.csv")

dataset_df = pd.read_csv(dataset_path)

In [3]:
# Label-based (.loc) selection of benchmark rows 10, 20, 30, 40 and 50.
sample = list(range(10, 51, 10))
query_list = dataset_df.loc[sample, "simple_instruction"]

In [4]:
zs = CoT()
# user_query = "Scatter plot을 그리고 싶어. X축과 Y축 데이터를 받아서 시각화해줘."
for user_query in query_list:
    code_content = await zs.get_code_content(user_query, None)
# print("🔹 생성된 코드:\n", code_content)

{'prompt_tokens': 425, 'completion_tokens': 516, 'total_tokens': 941}
{'prompt_tokens': 201, 'completion_tokens': 414, 'total_tokens': 615}
{'prompt_tokens': 235, 'completion_tokens': 778, 'total_tokens': 1013}
{'prompt_tokens': 274, 'completion_tokens': 568, 'total_tokens': 842}
{'prompt_tokens': 314, 'completion_tokens': 618, 'total_tokens': 932}


prompt_tokens: 289.8

completion_tokens: 578.8

total_tokens: 868.6