# In [None]:
import os
import argparse

import pandas as pd
from tqdm import tqdm
from dotenv import load_dotenv
from openai import OpenAI
from prompts import baseline_prompt
from promptTemplate import process
import re

# Compiled once at import time; equivalent to passing the raw pattern to re.search.
_ANSWER_PATTERN = re.compile(r"\[Answer\]:\s*(.*)")


def extract_answer(text):
    """Return the answer text that follows the first ``[Answer]:`` marker.

    Any whitespace after the marker (line breaks included) is skipped, and the
    remainder of the line where the answer starts is returned with surrounding
    whitespace stripped. Returns ``""`` when *text* contains no marker.
    """
    found = _ANSWER_PATTERN.search(text)
    if found is None:
        return ""
    return found.group(1).strip()


def _correct_sentence(text):
    """Return the corrected form of *text*, falling back to *text* itself.

    The fallback covers blank input (no API call is made), any exception
    raised while calling ``process`` or parsing its response, and a response
    without an ``[Answer]:`` marker (``extract_answer`` returns ``""``).
    """
    if not text.strip():
        return text  # nothing to correct; skip the API call
    try:
        corrected = extract_answer(process(text))
    except Exception as e:  # best-effort: one failing row must not abort the batch
        # tqdm.write keeps the message from breaking the progress bar.
        tqdm.write(f"Error processing: {text[:50]}... - {e}")
        return text
    if not corrected:
        # An empty correction would blank out this row in the submission;
        # keep the input sentence instead, as the error path does.
        tqdm.write(f"No [Answer] found for: {text[:50]}... - keeping original")
        return text
    return corrected


def main(args=None):
    """Correct every ``err_sentence`` in the input CSV and write a submission CSV.

    Each sentence goes through ``process`` and the corrected text is taken from
    the response with ``extract_answer``. When that fails, or yields nothing,
    the original sentence is kept so every output row carries a usable value.

    Args:
        args: Optional mapping overriding any of the default settings
            ``input`` (input CSV path), ``model`` (model name; only printed
            here, not passed to ``process``) and ``output`` (output CSV path).
            ``None`` uses the defaults.

    Raises:
        ValueError: If the input CSV lacks an ``err_sentence`` column, or
            ``UPSTAGE_API_KEY`` is not set in the environment.
    """
    defaults = {
        "input": "data/train_dataset.csv",
        "model": "solar-pro2",
        "output": "eval_submission.csv",
    }
    args = {**defaults, **(args or {})}

    # Load environment variables (e.g. UPSTAGE_API_KEY) from a .env file.
    load_dotenv()

    df = pd.read_csv(args["input"])
    if "err_sentence" not in df.columns:
        raise ValueError("Input CSV must contain 'err_sentence' column")

    # Fail fast before the long loop. The key is not used in this function;
    # NOTE(review): presumably ``process`` reads it itself — confirm.
    api_key = os.getenv("UPSTAGE_API_KEY")
    if not api_key:
        raise ValueError("UPSTAGE_API_KEY not found in environment variables")

    print(f"Model: {args['model']}")
    print(f"Output: {args['output']}")

    # Empty CSV cells are read as NaN (a float); normalise to str so the
    # prompt and the error-message slicing in _correct_sentence get text.
    err_sentences = ["" if pd.isna(s) else str(s) for s in df["err_sentence"]]
    cor_sentences = [
        _correct_sentence(s) for s in tqdm(err_sentences, desc="Generating")
    ]

    # Save results with the required column names.
    out_df = pd.DataFrame({"err_sentence": err_sentences, "cor_sentence": cor_sentences})
    out_df.to_csv(args["output"], index=False)
    print(f"Wrote {len(out_df)} rows to {args['output']}")


# In [5]:

if __name__ == "__main__":
    # Run the batch correction only when executed as a script, not on import.
    main()


# Cell output from the recorded run:
# Model: solar-pro2
# Output: eval_submission.csv
#
# Generating: 100%|██████████| 254/254 [10:49<00:00,  2.56s/it]
#
# Wrote 254 rows to eval_submission.csv



