In [None]:
import concurrent.futures
import csv
import json

import IPython.display
import ipywidgets
import openai
import tqdm


# --- Input form --------------------------------------------------------------
# Every control below is read (via its ``.value``) by the "Process CSV" button
# callback at the moment the button is clicked.


def _auto_width():
    """Return a new layout that lets a widget stretch to the available width."""
    return ipywidgets.Layout(width="auto")


# Path of the CSV file to read.
csv_file = ipywidgets.Text(value="input.csv", description="CSV File: ", layout=_auto_width())

# Comma-separated input columns; their values fill the prompt template's %s slots in order.
from_columns = ipywidgets.Text(value="question,context", description="From Columns: ", layout=_auto_width())

# Column that receives the generated result.
to_column = ipywidgets.Text(value="answer", description="To Column: ")

# Optional prompt template file; when left blank a default template is built.
prompt_file = ipywidgets.Text(
    value="",
    description="Prompt File: ",
    placeholder="(Optional) Path to prompt template file",
    layout=_auto_width(),
)

# Model name forwarded to the chat-completions API.
model = ipywidgets.Dropdown(options=["gpt-4o", "gpt-4o-mini", "o1", "o1-mini"], value="gpt-4o", description="Model: ")

# Maximum number of rows processed in parallel (thread-pool size).
concurrency = ipywidgets.IntSlider(min=1, max=20, step=1, value=5, description="Concurrency: ")

# When unchecked, rows whose target column already has a value are skipped.
overwrite = ipywidgets.Checkbox(value=False, description="Overwrite existing values")

# When non-empty, results are appended to the existing value using this delimiter.
append_by = ipywidgets.Text(
    value="",
    description="Append By: ",
    placeholder="(Optional) Delimiter to append results",
)

# When checked, results the model marks as not resolved are discarded.
exclude_unresolved = ipywidgets.Checkbox(value=False, description="Exclude unresolved results")

# Path of the CSV file to write.
output_file = ipywidgets.Text(value="output.csv", description="Output File: ")

# Captures prints/progress emitted by the button callback.
output = ipywidgets.Output()

# Starts the batch run.
button = ipywidgets.Button(description="Process CSV")


# Structured-output format requested from the API: an object with a string
# "result" and a boolean "resolved", no other keys ("strict" asks the API to
# enforce the schema exactly).
_STRUCTURED_RESULT_FORMAT = {
    "type": "json_schema",
    "json_schema": {
        "name": "structured_result",
        "description": "Response with result and resolved status",
        "schema": {
            "type": "object",
            "properties": {
                "result": {
                    "type": "string",
                    "description": "The generated result",
                },
                "resolved": {
                    "type": "boolean",
                    "description": "Whether the generated result is resolved user task or not",
                },
            },
            "required": ["result", "resolved"],
            "additionalProperties": False,
        },
        "strict": True,
    },
}


def prompt_by_openai(client, model_name, prompt):
    """Send *prompt* as a single user message and return the parsed JSON reply.

    Args:
        client: OpenAI-compatible client exposing ``chat.completions.create``.
        model_name: Model identifier passed through to the API.
        prompt: Full prompt text sent as the only user message.

    Returns:
        The decoded JSON object, expected to contain ``result`` (str) and
        ``resolved`` (bool) per ``_STRUCTURED_RESULT_FORMAT``.

    Raises:
        ValueError: If the model returned no message content (e.g. a refusal).
        json.JSONDecodeError: If the content is not valid JSON.
    """
    response = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        response_format=_STRUCTURED_RESULT_FORMAT,
    )

    message = response.choices[0].message
    content = message.content
    if content is None:
        # With structured outputs a refusal arrives in ``message.refusal`` and
        # ``content`` is None; json.loads(None) would give an opaque TypeError.
        refusal = getattr(message, "refusal", None)
        detail = f" (refusal: {refusal})" if refusal else ""
        raise ValueError(f"Model returned no content{detail}")

    return json.loads(content)


def process_record(client, record, headers, from_cols, to_col, prompt_template, model_name, should_overwrite, append_delimiter, should_exclude_unresolved):
    """Fill ``record[to_col]`` with a model-generated value and return the record.

    The record dict is mutated in place and also returned. Failures are
    reported with ``print`` and the record is returned unchanged, so one bad
    row does not abort the whole batch.

    Args:
        client: OpenAI-compatible client forwarded to ``prompt_by_openai``.
        record: One CSV row as a dict. Values may be ``None`` (e.g. cells
            missing from short rows read by ``csv.DictReader``).
        headers: CSV header list; not used here, kept for call compatibility.
        from_cols: Columns whose values fill the template's ``%s`` slots, in order.
        to_col: Column that receives the result.
        prompt_template: printf-style template with one ``%s`` per ``from_cols`` entry.
        model_name: Model identifier passed to the API.
        should_overwrite: If False, rows whose ``to_col`` is already non-empty are skipped.
        append_delimiter: If non-empty, the result is appended to an existing
            value with this delimiter instead of replacing it.
        should_exclude_unresolved: If True, results whose ``resolved`` flag is
            false are discarded.

    Returns:
        The same ``record`` object.
    """
    # Treat a missing key and a None cell the same as an empty string.
    existing = record.get(to_col) or ""
    if existing and not should_overwrite:
        return record

    try:
        # Formatting lives inside the try so a template/column mismatch fails
        # only this row instead of propagating out of the worker thread.
        prompt = prompt_template % tuple(record.get(col) or "" for col in from_cols)
        llm_response = prompt_by_openai(client, model_name, prompt)

        if should_exclude_unresolved and not llm_response.get("resolved", False):
            return record

        result = llm_response["result"]
        if append_delimiter and existing:
            record[to_col] = append_delimiter.join([existing, result])
        else:
            # Append mode with nothing to append to: store the bare result
            # rather than producing a leading delimiter.
            record[to_col] = result
    except Exception as e:
        # Deliberate best-effort per row: report and keep the row as-is.
        print(f"Error processing record: {e}")

    return record


def _read_api_token(path):
    """Return the token stored in the file at *path*, without surrounding whitespace."""
    # Stripping keeps a trailing newline out of the Authorization header.
    with open(path, "r", encoding="utf-8") as f:
        return f.read().strip()


def _read_csv(path):
    """Return ``(headers, records)`` for the CSV at *path*; headers is ``[]`` for an empty file."""
    # newline="" is what the csv module expects, so quoted fields containing
    # line breaks are parsed correctly.
    with open(path, "r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        headers = list(reader.fieldnames or [])
        records = list(reader)
    return headers, records


def _build_default_prompt_template(from_cols, to_col):
    """Build the fallback prompt: one ``name: %s`` line per input column, then ``to_col:``."""

    # The template is later filled with the ``%`` operator, so literal percent
    # signs in column names must be doubled.
    def esc(name):
        return name.replace("%", "%%")

    lines = [f"You are an AI assistant. Given `{', '.join(esc(c) for c in from_cols)}`, your task is to generate `{esc(to_col)}`."]
    lines.extend(f"{esc(col)}: %s" for col in from_cols)
    lines.append(f"{esc(to_col)}:")
    return "\n".join(lines)


def _write_csv(path, fieldnames, records):
    """Write *records* (dicts) to *path* as CSV with the given column order."""
    with open(path, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(records)


@output.capture()
def process_csv(b):
    """Button callback: generate ``to_column`` for every row of the input CSV.

    Reads all settings from the form widgets, runs ``process_record`` on every
    row in a thread pool, and writes the rows, in input order, to the output
    file. Errors are printed (with traceback) into the ``output`` widget. The
    button is disabled while a run is in progress.

    Args:
        b: The clicked ``Button`` widget.
    """
    b.disabled = True  # prevent double-submission while a run is in progress

    try:
        # The base URL looks like an in-cluster Kubernetes service endpoint,
        # authenticated with the pod's service-account token.
        client = openai.OpenAI(
            base_url="http://cortex-api.cortex-api.svc.cluster.local:8080/v1",
            api_key=_read_api_token("/var/run/secrets/kubernetes.io/serviceaccount/token"),
        )

        headers, records = _read_csv(csv_file.value)

        # Ignore empty entries such as the one left by a trailing comma.
        from_cols = [col.strip() for col in from_columns.value.split(",") if col.strip()]
        to_col = to_column.value.strip()

        missing = [col for col in from_cols if col not in headers]
        if missing:
            print(f"Warning: columns not found in {csv_file.value}: {', '.join(missing)}")

        prompt_path = prompt_file.value.strip()
        if prompt_path:
            with open(prompt_path, "r", encoding="utf-8") as f:
                prompt_template = f.read()
        else:
            prompt_template = _build_default_prompt_template(from_cols, to_col)

        with concurrent.futures.ThreadPoolExecutor(max_workers=concurrency.value) as executor:
            futures = [
                executor.submit(
                    process_record,
                    client,
                    record,
                    headers,
                    from_cols,
                    to_col,
                    prompt_template,
                    model.value,
                    overwrite.value,
                    append_by.value,
                    exclude_unresolved.value,
                )
                for record in records
            ]

            # as_completed only drives the progress bar; its order is arbitrary.
            for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing records"):
                future.result()  # surface unexpected worker errors

        # Collect in submission order so output rows line up with input rows.
        results = [future.result() for future in futures]

        # Add the target column if the input lacked it; DictWriter would
        # otherwise reject every row that received a value.
        fieldnames = headers if to_col in headers else [*headers, to_col]
        _write_csv(output_file.value, fieldnames, results)

        print(f"Successfully processed {len(results)} records and saved to {output_file.value}")

    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()

    finally:
        b.disabled = False


# Run the batch job each time the button is clicked.
button.on_click(process_csv)

# Render the form top-to-bottom, then the button and the log/progress area.
_form_widgets = (
    csv_file,
    from_columns,
    to_column,
    prompt_file,
    model,
    concurrency,
    overwrite,
    append_by,
    exclude_unresolved,
    output_file,
    button,
    output,
)
IPython.display.display(*_form_widgets)