In [21]:
import os,sys
sys.path.insert(0,'../../libs')
import openai
from llm_utils import BSAgent
from data_utils import train_val_test_split,load_split_climate_data
from utils import donload_hf_model
import pandas as pd
import re,json,copy
from tqdm import tqdm
from prompts import short_cot_pt,short_cot_pt_2label,long_cot_pt,long_cot_pt_2label,long_fewshotcot_pt_2label
import pprint

In [22]:
from pydantic import BaseModel
from typing import Literal

In [23]:
from dotenv import load_dotenv
# Pull the variables defined in the shared .env file into the process
# environment, then fail fast if the OpenAI key is absent or empty.
env_path = '../../../.env'
load_dotenv(dotenv_path=env_path)

api_key = os.environ.get("OPENAI_API_KEY")
if api_key is None or api_key == "":
    raise ValueError("OPENAI_API_KEY not found in environment variables. Please check your .env file.")


In [24]:
# Folder that holds the climate paragraph CSVs.
data_folder = '~/data/Fund/Climate'

# One-off split generation, kept for reference (presumably how the three CSVs
# read below were produced):
# data_path = os.path.join(data_folder,'Climate training paragraphs.csv')
# ds = load_split_climate_data(data_path,merge_neutral=True,verbose=True)
# ds['test'].to_csv(data_folder+'/test.csv')
# ds['validation'].to_csv(data_folder+'/validation.csv')
# ds['train'].to_csv(data_folder+'/train.csv')

# Read the persisted splits, in the order test -> validation -> train.
test_data, val_data, train_data = [
    pd.read_csv(f"{data_folder}/{split_name}.csv")
    for split_name in ('test', 'validation', 'train')
]

In [25]:
# Structured-output schema with three allowed labels. This is the class the
# notebook's get_climate_classifications helpers pass as `response_format`.
# NOTE: documented with `#` comments on purpose -- pydantic copies a model's
# docstring into its JSON schema, so adding a docstring here would not be a
# pure documentation change.
class ClimateClassification(BaseModel):
    # Free-text reasoning. Declared before the label (presumably so the model
    # writes its rationale first -- confirm against the prompt design).
    justification: str
    # Predicted stance; must be exactly one of these three strings.
    classification: Literal["favorable", "unfavorable", "neutral"]

# Two-label variant of the structured-output schema (no "neutral" option).
# NOTE(review): nothing in the visible notebook passes this class to an agent;
# the *_2label prompts are run against ClimateClassification instead.
# Documented with `#` comments because pydantic copies a model's docstring
# into its JSON schema.
class ClimateClassification_2label(BaseModel):
    # Free-text reasoning; declared before the label.
    justification: str
    # Predicted stance; must be exactly one of these two strings.
    classification: Literal["favorable", "unfavorable"]


In [26]:
## download models
# donload_hf_model('meta-llama/Llama-3.1-8B-Instruct', '/home/xiong/data/hf_cache/llama-3.1-8B-Instruct',hf_token=os.getenv('huggingface_token'))
# donload_hf_model('Qwen/Qwen2.5-7B-Instruct', '/home/xiong/data/hf_cache/Qwen2.5-7B-Instruct',hf_token=os.getenv('huggingface_token'))
# donload_hf_model('deepseek-ai/DeepSeek-V2-Lite-Chat', '/home/xiong/data/hf_cache/DeepSeek-V2-Lite-Chat',hf_token=os.getenv('huggingface_token'))
# donload_hf_model('Qwen/Qwen2.5-14B-Instruct', '/home/xiong/data/hf_cache/Qwen2.5-14B-Instruct',hf_token=os.getenv('huggingface_token'))
# donload_hf_model('microsoft/phi-4', '/home/xiong/data/hf_cache/phi-4',hf_token=os.getenv('huggingface_token'))

In [27]:
# Option 1: use an OpenAI-hosted model.
# agent = BSAgent(model='gpt-4o-mini')
# Option 2: use an open-source model served locally through vLLM's
# OpenAI-compatible server, started in a separate shell with e.g.:
# python -m vllm.entrypoints.openai.api_server --model /home/xiong/data/hf_cache/llama-3.1-8B-Instruct --dtype auto --served-model-name llama-3.1-8b-Instruct
agent = BSAgent(base_url='http://localhost:8000/v1',api_key=None)
# Take the id of the first model the server lists instead of hard-coding it.
agent.model = agent.client.models.list().data[0].id
print(agent.model) 

llama-3.1-8b-Instruct


In [28]:
def get_climate_classifications(agent, dataset, prompt_template, response_format=None):
    """Classify every paragraph in `dataset` with the agent, one row at a time.

    Args:
        agent: Object exposing
            ``get_response_content(prompt_template=..., response_format=...)``
            whose return value has ``.classification`` and ``.justification``.
        dataset: DataFrame with ``paragraph`` and ``label`` columns.
        prompt_template: Dict whose ``'user'`` entry contains a ``{PARAGRAPH}``
            placeholder. It is deep-copied for every row and never mutated.
        response_format: Schema class handed to the agent. ``None`` (default)
            keeps the previous behaviour and uses ``ClimateClassification``.
            Pass ``ClimateClassification_2label`` together with a two-label
            prompt so the schema cannot return "neutral".

    Returns:
        DataFrame with the columns ``paragraph``, ``true_label``,
        ``predicted_label`` and ``justification``, one row per input row in
        input order. Rows whose request raised have ``predicted_label`` set to
        None and ``justification`` set to ``"Error: <message>"``. The columns
        are present even when `dataset` is empty.
    """
    if response_format is None:
        # Resolved at call time so the default schema is only needed when the
        # caller does not supply one.
        response_format = ClimateClassification

    result_columns = ['paragraph', 'true_label', 'predicted_label', 'justification']
    results = []
    for i in tqdm(range(len(dataset))):
        row = dataset.iloc[i]  # look the row up once instead of per field
        structured_prompt = copy.deepcopy(prompt_template)
        structured_prompt['user'] = structured_prompt['user'].format(PARAGRAPH=row.paragraph)
        try:
            response = agent.get_response_content(prompt_template=structured_prompt, response_format=response_format)
            predicted_label = response.classification
            justification = response.justification
        except Exception as e:
            # Deliberate best-effort: record the failure and keep going so one
            # bad response does not discard the rest of the run.
            print(f"Error processing row {i}: {str(e)}")
            predicted_label = None
            justification = f"Error: {str(e)}"
        results.append({
            'paragraph': row.paragraph,
            'true_label': row.label,
            'predicted_label': predicted_label,
            'justification': justification
        })
    # Pin the columns so an empty dataset still yields the expected schema.
    return pd.DataFrame(results, columns=result_columns)

In [29]:
# Classify the validation split and persist the raw predictions.
val_results = get_climate_classifications(agent, val_data, long_fewshotcot_pt_2label)
val_results.to_csv(f"{data_folder}/val_results.csv")

# Rows whose request failed carry a null prediction; a null never equals the
# true label, so those rows count as misclassified in the accuracy.
val_accuracy = val_results['true_label'].eq(val_results['predicted_label']).mean()

print("\nValidation Results:")
print(f"Total samples: {val_results.shape[0]}")
print(f"Successfully processed: {val_results['predicted_label'].notna().sum()}")
print(f"Validation Accuracy: {val_accuracy:.2%}")


100%|██████████| 108/108 [00:59<00:00,  1.80it/s]


Validation Results:
Total samples: 108
Successfully processed: 108
Validation Accuracy: 80.56%





#### Try running with the async client

In [11]:
# Async setup. A notebook kernel normally already has an event loop running,
# so nest_asyncio is applied to allow asyncio.run() to be called from a cell
# (the async evaluation cell in this notebook relies on that).
import nest_asyncio
import asyncio
nest_asyncio.apply()
from llm_utils_async import AsyncBSAgent


In [12]:
# Rebinds `agent` to the async client, pointed at the same local endpoint as
# the synchronous agent created earlier; that earlier agent is no longer
# reachable under this name. The model id is given explicitly here rather
# than read from the server's model list.
agent = AsyncBSAgent(model='llama-3.1-8b-Instruct',base_url='http://localhost:8000/v1',api_key=None)
print(agent.model) 

llama-3.1-8b-Instruct


In [15]:
async def get_climate_classifications(agent, dataset, prompt_template, response_format=None, max_concurrency=None):
    """Classify every paragraph in `dataset` concurrently with an async agent.

    NOTE: this coroutine has the same name as the synchronous helper defined
    earlier in the notebook; once this cell has run, the name refers to this
    version.

    Args:
        agent: Object exposing an awaitable
            ``get_response_content(prompt_template=..., response_format=...)``
            whose result has ``.classification`` and ``.justification``.
        dataset: DataFrame with ``paragraph`` and ``label`` columns.
        prompt_template: Dict whose ``'user'`` entry contains a ``{PARAGRAPH}``
            placeholder. It is deep-copied for every row and never mutated.
        response_format: Schema class handed to the agent. ``None`` (default)
            keeps the previous behaviour and uses ``ClimateClassification``.
            Pass ``ClimateClassification_2label`` together with a two-label
            prompt so the schema cannot return "neutral".
        max_concurrency: Upper bound on requests in flight at once. ``None``
            (default) keeps the previous behaviour: every row is submitted
            immediately with no limit.

    Returns:
        DataFrame with the columns ``paragraph``, ``true_label``,
        ``predicted_label`` and ``justification``, one row per input row in
        input order. Rows whose request raised have ``predicted_label`` set to
        None and ``justification`` set to ``"Error: <message>"``. The columns
        are present even when `dataset` is empty.

    Raises:
        ValueError: If `max_concurrency` is given but smaller than 1.
    """
    if response_format is None:
        # Resolved at call time so the default schema is only needed when the
        # caller does not supply one.
        response_format = ClimateClassification
    if max_concurrency is not None and max_concurrency < 1:
        # A zero-slot semaphore would make every request wait forever.
        raise ValueError("max_concurrency must be a positive integer or None")

    limiter = asyncio.Semaphore(max_concurrency) if max_concurrency is not None else None

    async def request(prompt):
        # Only the network call is throttled; prompt building stays cheap and unthrottled.
        if limiter is None:
            return await agent.get_response_content(prompt_template=prompt, response_format=response_format)
        async with limiter:
            return await agent.get_response_content(prompt_template=prompt, response_format=response_format)

    async def process_row(i):
        row = dataset.iloc[i]  # look the row up once instead of per field
        structured_prompt = copy.deepcopy(prompt_template)
        structured_prompt['user'] = structured_prompt['user'].format(PARAGRAPH=row.paragraph)
        try:
            response = await request(structured_prompt)
            predicted_label = response.classification
            justification = response.justification
        except Exception as e:
            # Deliberate best-effort: record the failure so one bad response
            # does not cancel the other in-flight rows.
            print(f"Error processing row {i}: {str(e)}")
            predicted_label = None
            justification = f"Error: {str(e)}"
        return {
            'paragraph': row.paragraph,
            'true_label': row.label,
            'predicted_label': predicted_label,
            'justification': justification
        }

    # gather() returns results in submission order, regardless of completion order.
    results = await asyncio.gather(*(process_row(i) for i in range(len(dataset))))
    # Pin the columns so an empty dataset still yields the expected schema.
    return pd.DataFrame(results, columns=['paragraph', 'true_label', 'predicted_label', 'justification'])

In [20]:
# This async run is over the *training* split (the recorded output shows 504
# rows), so the result is named and reported as such. Storing it under
# `val_results` would overwrite the real validation predictions and mislabel
# the metric.
train_results = asyncio.run(get_climate_classifications(agent, train_data, long_fewshotcot_pt_2label))
print("\nTraining Results:")
print(f"Total samples: {len(train_results)}")
print(f"Successfully processed: {len(train_results[train_results.predicted_label.notna()])}")
# Rows whose request failed have a null prediction and count as incorrect.
train_accuracy = (train_results['true_label'] == train_results['predicted_label']).mean()
print(f"Training Accuracy: {train_accuracy:.2%}")


Validation Results:
Total samples: 504
Successfully processed: 504
Validation Accuracy: 74.40%
