In [1]:
from un_classification import un_classification
un_classification = un_classification

In [20]:
import logging
import os
import pandas as pd
from typing import List, Optional, Dict
from pydantic import BaseModel, Field
from openai import OpenAI
from dotenv import load_dotenv


In [21]:
#using gpt40 api

# load api key
load_dotenv()
api_key = os.getenv("API_KEY")

# set up logging
logging.basicConfig(
    level=logging.INFO, 
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# define pydantic BaseModel
class TagClassification(BaseModel):
    tag: str = Field(..., description="Main category tag")
    subtag1: Optional[str] = Field(None, description="Subcategory of the main tag")
    subtag2: Optional[str] = Field(None, description="Specific item within subcategory")

class ResolutionClassification(BaseModel):
    classifications: List[TagClassification] = Field(..., description="List of relevant classifications for this resolution")

# New models for staged classification
class MainTagClassification(BaseModel):
    main_tags: List[str] = Field(..., description="List of relevant main category tags")

class SubTag1Classification(BaseModel):
    subtag1s: List[str] = Field(..., description="List of relevant subcategories for the main tag")

class SubTag2Classification(BaseModel):
    subtag2s: List[str] = Field(..., description="List of relevant specific items for the subcategory")

def call_api_staged(Title: str, stage: int, previous_tags: Optional[Dict] = None):
    """
    Analyzes a UN resolution text in stages.
    
    Args:
        Title: Title of the resolution to analyze
        stage: 1 for main tag, 2 for subtag1, 3 for subtag2
        previous_tags: Results from previous stages
        
    Returns:
        Structured classification results
    """
    # Initialize OpenAI client
    client = OpenAI(api_key=api_key)
    
    if stage == 1:
        # First stage: identify main tag categories
        main_tag = list(un_classification.keys())
        system_prompt = f"""You are a UN document classification assistant. Your task is to analyze UN resolutions given their Title.
Classify the resolution according to the following valid main categories (select only values from the list):
        
{main_tag}

Rules:
1. Identify ALL relevant main categories from the list.
2. Return only valid category names as a list.
3. If none of the categories apply, return an empty list.
"""
        try:
            logger.info("Calling OpenAI API for main tag classification.")
            response = client.beta.chat.completions.parse(
                model="gpt-4o-mini",
                temperature=0.3,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": f"Resolution text: {Title}"}
                ],
                max_tokens=1000,
                response_format=MainTagClassification,
            )
            
            main_tags_result = response.choices[0].message.parsed
            logger.info("Main tag API call successful.")
            return main_tags_result
            
        except Exception as e:
            logger.error(f"Error during main tag API call: {e}")
            return MainTagClassification(main_tags=[])
        
    elif stage == 2:
        # Second stage: identify subtag1 based on main tags
        main_tag = previous_tags["main_tag"]
        subcategories = list(un_classification[main_tag].keys())
        
        system_prompt = f"""You are a UN document classification assistant. Your task is to analyze UN resolutions given their Title.
For a resolution categorized in the main category '{main_tag}', select the relevant subcategories from the following valid list:
        
{subcategories}

Rules:
1. Select only unique, valid subcategories from the list above.
2. If none of the listed subcategories apply, return an empty string.
3. Return only the valid subcategory names as a list.
"""
        try:
            logger.info(f"Calling OpenAI API for subtag1 classification for {main_tag}.")
            response = client.beta.chat.completions.parse(
                model="gpt-4o-mini",
                temperature=0.3,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": f"Resolution text: {Title}"}
                ],
                max_tokens=1000,
                response_format=SubTag1Classification,
            )
            
            subtag1_result = response.choices[0].message.parsed
            logger.info(f"Subtag1 API call for {main_tag} successful.")
            return subtag1_result
            
        except Exception as e:
            logger.error(f"Error during subtag1 API call for {main_tag}: {e}")
            return SubTag1Classification(subtag1s=[])
        
    elif stage == 3:
        # Third stage: identify subtag2 based on main tag and subtag1
        main_tag = previous_tags["main_tag"]
        subtag1 = previous_tags["subtag1"]
        specific_items = un_classification[main_tag][subtag1]
        
        system_prompt = f"""You are a UN document classification assistant. Your task is to analyze UN resolutions given their Title.
For a resolution categorized as '{main_tag}' > '{subtag1}', choose the single most relevant specific item from the following valid options:
        
{specific_items}

Rules:
1. Select only one item from the above list.
2. If none of the specific items are applicable, or if no valid options exist, return an empty string.
3. Return only the single selected item as a string.
"""
        try:
            logger.info(f"Calling OpenAI API for subtag2 classification for {main_tag} > {subtag1}.")
            response = client.beta.chat.completions.parse(
                model="gpt-4o-mini",
                temperature=0.3,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": f"Resolution text: {Title}"}
                ],
                max_tokens=1000,
                response_format=SubTag2Classification,
            )
            
            subtag2_result = response.choices[0].message.parsed
            logger.info(f"Subtag2 API call for {main_tag} > {subtag1} successful.")
            return subtag2_result
            
        except Exception as e:
            logger.error(f"Error during subtag2 API call for {main_tag} > {subtag1}: {e}")
            return SubTag2Classification(subtag2s=[])

def get_tags_iterative(Title: str) -> List[List]:
    """
    Gets classification tags iteratively in three stages.
    
    Returns:
        List of lists containing [tag, subtag1, subtag2]
    """
    final_results = []
    
    # Stage 1: Get main tags
    main_tags_result = call_api_staged(Title, stage=1)
    if not main_tags_result.main_tags:
        logger.warning("No main tags found.")
        return []
    
    # For each main tag, get subtag1
    for main_tag in main_tags_result.main_tags:
        subtag1_result = call_api_staged(Title, stage=2, previous_tags={"main_tag": main_tag})
        if not subtag1_result.subtag1s:
            logger.warning(f"No subtag1s found for main tag: {main_tag}")
            continue
            
        # For each subtag1, get subtag2
        for subtag1 in subtag1_result.subtag1s:
            subtag2_result = call_api_staged(Title, stage=3, previous_tags={"main_tag": main_tag, "subtag1": subtag1})
            if subtag2_result.subtag2s:
                for subtag2 in subtag2_result.subtag2s:
                    final_results.append([main_tag, subtag1, subtag2])
            else:
                logger.warning(f"No subtag2s found for {main_tag} > {subtag1}")
    
    return final_results


In [22]:
#create sample dataframe
random_seed = 42
df2= pd.read_csv('data/UN_VOTING_DATA_RAW.csv')
df_sample = df2.sample(20,random_state=random_seed)

  df2= pd.read_csv('data/UN_VOTING_DATA_RAW.csv')


In [23]:

#add new column to dataframe with tags
logger.info("Adding tags to sample dataframe.")
df_sample['tags'] = df_sample.apply(lambda row: get_tags_iterative(row['Title']), axis=1)
logger.info("Tags added to sample dataframe.")

2025-03-13 14:49:52,228 - INFO - Adding tags to sample dataframe.
2025-03-13 14:49:52,568 - INFO - Calling OpenAI API for main tag classification.
2025-03-13 14:49:53,788 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-03-13 14:49:53,802 - INFO - Main tag API call successful.
2025-03-13 14:49:54,288 - INFO - Calling OpenAI API for subtag1 classification for POLITICAL AND LEGAL QUESTIONS.
2025-03-13 14:49:55,466 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-03-13 14:49:55,483 - INFO - Subtag1 API call for POLITICAL AND LEGAL QUESTIONS successful.
2025-03-13 14:49:55,926 - INFO - Calling OpenAI API for subtag2 classification for POLITICAL AND LEGAL QUESTIONS > MAINTENANCE OF PEACE AND SECURITY.
2025-03-13 14:49:57,381 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-03-13 14:49:57,396 - INFO - Subtag2 API call for POLITICAL AND LEGAL QUESTIONS > MAINT

In [24]:
#save df_sample to csv
df_sample.to_csv('data/fixed_sample_gpt4o-mini_UN_VOTING_DATA_RAW_with_tags.csv', index=False)
logger.info("Sample dataframe with tags saved to CSV.")

2025-03-13 14:53:07,183 - INFO - Sample dataframe with tags saved to CSV.
