Notebook to add better categories to RSO data

In [13]:
import json
import asyncio
from groq import Groq
import os
from typing import Dict, List, Optional
import time
from tqdm import tqdm
from dotenv import load_dotenv
from datetime import datetime, timedelta
import tiktoken
import random


In [14]:
load_dotenv()  # This loads the variables from .env

# Shared Groq API client. The key is read from the GROQ_API_KEY environment
# variable; a missing variable raises KeyError on this line.
client = Groq(
    api_key=os.environ["GROQ_API_KEY"],
)
# Shared rate limiter consulted by categorize_rso before every API call.
# NOTE(review): TokenBucket is defined in the cell *below* this one, so on a fresh
# kernel (Restart & Run All) this line raises NameError. Run the TokenBucket cell
# first, or move its definition above this cell.
token_bucket = TokenBucket(tokens_per_minute=4500)  # Using 4500 to be conservative


In [12]:
class TokenBucket:
    """Token-bucket rate limiter for a tokens-per-minute quota.

    The bucket starts full, refills continuously at ``tokens_per_minute / 60``
    tokens per second, and never holds more than ``tokens_per_minute`` tokens.

    Attributes:
        max_tokens: Upper bound on the balance.
        tokens: Current balance. It goes negative while a caller is waiting
            out a request that was larger than the available balance.
        last_update: ``datetime`` at which the balance was last refilled.
        tokens_per_minute: Refill rate.
    """

    def __init__(self, tokens_per_minute: int = 5000):
        """Create a full bucket.

        Args:
            tokens_per_minute: Quota per minute; must be positive.

        Raises:
            ValueError: If ``tokens_per_minute`` is zero or negative (the
                wait-time computation divides by it).
        """
        if tokens_per_minute <= 0:
            raise ValueError("tokens_per_minute must be positive")
        self.max_tokens = tokens_per_minute
        self.tokens = tokens_per_minute
        self.last_update = datetime.now()
        self.tokens_per_minute = tokens_per_minute

    def update_tokens(self) -> None:
        """Credit the tokens earned since ``last_update``, capped at ``max_tokens``."""
        now = datetime.now()
        # datetime.now() is wall-clock time and can step backwards (clock
        # adjustment, DST); a negative interval must not drain the bucket.
        elapsed_seconds = max(0.0, (now - self.last_update).total_seconds())
        self.tokens = min(
            self.max_tokens,
            self.tokens + elapsed_seconds * self.tokens_per_minute / 60,
        )
        self.last_update = now

    def consume(self, tokens: int) -> float:
        """Reserve ``tokens`` and report how long the caller must wait.

        The request is always charged against the balance. When the balance is
        insufficient it goes negative and the returned value is the number of
        seconds after which the refill will have covered the shortfall. The
        caller is expected to sleep for that long and then proceed; it should
        not call ``consume`` again for the same request.

        Args:
            tokens: Estimated cost of the upcoming request.

        Returns:
            ``0.0`` if the request can proceed immediately, otherwise the
            number of seconds to wait first.
        """
        self.update_tokens()
        shortfall = tokens - self.tokens
        # Charge unconditionally: previously a request that had to wait was
        # never deducted, so throttled calls were invisible to the limiter.
        self.tokens -= tokens
        if shortfall > 0:
            return shortfall * 60 / self.tokens_per_minute
        return 0.0

In [15]:
def count_tokens(text: str) -> int:
    """Estimate the token count of ``text`` using tiktoken.

    The ``cl100k_base`` encoding is used as an approximation; it is not
    necessarily the tokenizer of the model that will receive the text.

    Args:
        text: Arbitrary text, e.g. a fully rendered prompt.

    Returns:
        Number of tokens ``text`` encodes to under ``cl100k_base``.
    """
    # Using cl100k_base as an approximation
    encoding = tiktoken.get_encoding("cl100k_base")
    # By default encode() raises ValueError when the text contains a special
    # token literal such as "<|endoftext|>". Descriptions are scraped free text,
    # so treat any such sequence as ordinary text instead of failing.
    return len(encoding.encode(text, disallowed_special=()))

def exponential_backoff(attempt: int, base_delay: float = 1) -> float:
    """Return the sleep time for a retry: capped exponential delay plus jitter.

    Args:
        attempt: Zero-based retry counter; the delay doubles with each attempt.
        base_delay: Delay in seconds used for ``attempt == 0``.

    Returns:
        ``min(300, base_delay * 2**attempt)`` plus a random extra of up to 10%
        of that value. The jitter is added after the cap is applied.
    """
    ceiling_seconds = 300  # five minutes
    capped = min(ceiling_seconds, base_delay * (2 ** attempt))
    return capped + random.uniform(0, 0.1 * capped)

In [19]:
# Closed vocabulary of category labels: it is listed verbatim in the prompt and
# used to filter the model's answer. Every entry needs its own trailing comma;
# adjacent string literals are silently concatenated by Python (a missing comma
# previously merged "Sports" and "Team Sports" into "SportsTeam Sports").
VALID_CATEGORIES = [
    "Biology", "Business", "Chemistry", "Physics", "Mathematics", "Computer Science",
    "Data Science", "Economics", "Psychology", "Sociology", "Political Science",
    "History", "Philosophy", "Literature", "Languages", "Law", "Medicine",
    "Nursing", "Public Health", "Engineering", "Environmental Science",
    "Finance", "Investment", "Quantitative Trading", "Private Equity",
    "Venture Capital", "Consulting", "Marketing", "Entrepreneurship",
    "Real Estate", "Technology", "Software Development", "Product Management",
    "Healthcare", "Legal", "Research", "Journalism", "Media Production",
    "Visual Arts", "Painting", "Photography", "Digital Art", "Music",
    "Band", "Choir", "A Cappella", "Theater", "Dance", "Film",
    "Creative Writing", "Design", "Cultural", "International", "Religious",
    "LGBTQ+", "Gender & Sexuality", "Social Justice", "Political", "Activism",
    "Community Service/volunteering", "Mentorship", "Environmental", "Sports",
    "Team Sports", "Individual Sports", "Gaming",
    "Debate", "Model UN", "Food & Cooking", "Travel", "Outdoor Activities",
    "Student Government", "Publications", "Mental Health", "Wellness",
    "Career Development", "Academic Support", "Leadership", "Greek Life",
]

In [6]:
# Worked examples embedded in the prompt. Every label in "ideal_categories"
# must be an entry of VALID_CATEGORIES: categorize_rso drops any category that
# is not in that list, so an example with an unknown label teaches the model to
# produce output that is then thrown away.
FEW_SHOT_EXAMPLES = [
    {
        "name": "Investment Banking Group",
        "description": "A professional organization dedicated to educating members about investment banking, private equity, and financial markets. We host networking events, technical workshops, and mock interviews to prepare students for careers in finance.",
        "ideal_categories": ["Finance", "Investment", "Career Development"],
        "explanation": "This RSO focuses on finance education and career preparation, warranting multiple related financial categories."
    },
    {
        "name": "Data Science for Social Good",
        "description": "We apply data science and machine learning techniques to tackle social issues in healthcare, education, and environmental sustainability. Members work on real-world projects while learning technical skills.",
        "ideal_categories": ["Data Science", "Computer Science", "Community Service/volunteering"],
        "explanation": "Combines technical data science work with social impact, deserving both technical and service categories."
    },
    {
        "name": "Mental Health Alliance",
        "description": "A student organization focused on promoting mental health awareness, reducing stigma, and connecting students with resources. We organize wellness workshops, peer support groups, and educational events.",
        # "Student Life" was listed here but is not a valid category.
        "ideal_categories": ["Mental Health", "Wellness"],
        "explanation": "Focuses on mental health awareness and student wellness."
    }
]

In [None]:
def create_prompt(rso: Dict) -> str:
    """Build the categorization prompt sent to the LLM for one RSO.

    The prompt lists every entry of VALID_CATEGORIES, renders each item of
    FEW_SHOT_EXAMPLES as a worked example (separated by ``---`` lines), then
    appends the RSO's name and description and asks for a JSON reply.

    Args:
        rso: RSO record. Must contain 'name'. The description is taken from
            'full_description', falling back to 'description_preview' when the
            former is missing or empty.

    Returns:
        The complete prompt text.
    """
    rendered_examples = []
    for example in FEW_SHOT_EXAMPLES:
        example_categories = ', '.join(example['ideal_categories'])
        rendered_examples.append(
            f"\nName: {example['name']}"
            f"\nDescription: {example['description']}"
            f"\nCategories: {example_categories}"
            f"\nExplanation: {example['explanation']}"
        )
    few_shot_text = "\n---\n".join(rendered_examples)

    category_list = ', '.join(VALID_CATEGORIES)
    description = rso.get('full_description', '') or rso.get('description_preview', '')

    return f"""You are an expert at categorizing university student organizations. Given an RSO's name and description, assign it relevant categories from the provided list. Each RSO can have multiple categories if appropriate.

Valid categories: {category_list}

Here are some examples:
{few_shot_text}

For the following RSO, please provide:
1. A list of relevant categories (can be multiple)
2. A confidence score (0-100) for each category
3. A brief explanation of your categorization

Name: {rso['name']}
Description: {description}

Response should be in JSON format:
{{
  "categories": [
    {{"name": "category_name", "confidence": 95}},
    {{"name": "another_category", "confidence": 85}}
  ],
  "explanation": "Brief explanation of categorization"
}}"""

In [16]:
def _parse_json_response(raw: Optional[str]) -> Dict:
    """Parse the model's reply into a dict, tolerating text around the JSON object."""
    if not raw:
        raise ValueError("Empty response from model")
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        # Chat models frequently wrap the JSON in markdown fences or add a
        # sentence before/after it; retry on the outermost {...} span.
        start, end = raw.find("{"), raw.rfind("}")
        if start == -1 or end <= start:
            raise
        parsed = json.loads(raw[start:end + 1])
    if not isinstance(parsed, dict):
        raise ValueError("Model response is not a JSON object")
    return parsed


def categorize_rso(rso: Dict, attempt: int = 0) -> Optional[Dict]:
    """Categorize a single RSO using the Groq API with rate limiting.

    Builds the prompt, reserves the estimated token cost with the shared
    ``token_bucket`` (sleeping if it asks for a wait), calls the chat
    completion endpoint and parses the JSON reply.

    Args:
        rso: RSO record passed to ``create_prompt``.
        attempt: Zero-based retry counter used for rate-limit backoff; callers
            normally leave it at the default.

    Returns:
        A dict that always has ``'categories'`` (only entries whose ``'name'``
        is in VALID_CATEGORIES) and ``'explanation'`` (empty string when the
        model omitted it), or ``None`` when the request or parsing failed.
        Errors are printed and not raised.
    """
    # Resolved up front so the error handler below can never raise on its own.
    rso_name = rso.get('name', '<unnamed>')
    try:
        prompt = create_prompt(rso)
        estimated_tokens = count_tokens(prompt) + 500  # Add buffer for response

        # Check token bucket
        wait_time = token_bucket.consume(estimated_tokens)
        if wait_time > 0:
            print(f"\nRate limit approaching, waiting {wait_time:.2f} seconds...")
            time.sleep(wait_time)

        # NOTE(review): confirm this model id is still served by Groq.
        completion = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model="mixtral-8x7b-32768",
            temperature=0.3,
            max_tokens=1000
        )

        response = _parse_json_response(completion.choices[0].message.content)

        # Validate categories against our list; skip malformed entries instead
        # of discarding the whole answer.
        response['categories'] = [
            cat for cat in (response.get('categories') or [])
            if isinstance(cat, dict) and cat.get('name') in VALID_CATEGORIES
        ]
        # Callers index 'explanation' directly, so guarantee the key exists.
        response.setdefault('explanation', '')

        return response

    except Exception as e:
        # Match on the message, and also on an HTTP 429 status when the
        # exception exposes one (presumably the SDK's status errors do).
        is_rate_limit = (
            "rate_limit" in str(e).lower()
            or getattr(e, "status_code", None) == 429
        )
        if is_rate_limit:
            if attempt < 5:  # Max 5 retries
                delay = exponential_backoff(attempt)
                print(f"\nRate limit hit for {rso_name}, waiting {delay:.2f} seconds...")
                time.sleep(delay)
                return categorize_rso(rso, attempt + 1)
            print(f"\nMax retries reached for {rso_name}")
        print(f"Error categorizing RSO {rso_name}: {str(e)}")
        return None

In [17]:
def process_rsos(
    input_file: Optional[str] = None,
    output_file: Optional[str] = None,
    rsos: Optional[List[Dict]] = None,
    batch_size: int = 3,
    batch_delay: float = 2,
):
    """Process all RSOs from input file or list and save results to output file.

    Each RSO is passed to ``categorize_rso``. On success the record gains
    ``'ai_categories'`` and ``'categorization_explanation'``; on failure it is
    kept unchanged. The input records are updated in place.

    Args:
        input_file: Path of a JSON file holding a list of RSO records. When
            given, it takes precedence over ``rsos``.
        output_file: Destination for the final JSON. While running,
            ``<output_file>.partial`` is rewritten after every batch; if the
            run fails, whatever was processed is written to
            ``<output_file>.error_partial``.
        rsos: RSO records to process when no ``input_file`` is given.
        batch_size: Number of RSOs handled between checkpoints.
        batch_delay: Seconds to pause between consecutive batches.

    Returns:
        The list of processed records. Errors are printed rather than raised,
        so the list may be partial, and is empty if the failure happened
        before any RSO was processed.
    """
    # Bound before the try block so the error handler can always reference it
    # (a missing input file or empty input used to hit an unbound local there).
    results = []
    try:
        # Read input JSON if file provided, otherwise use provided list
        if input_file:
            with open(input_file, 'r', encoding='utf-8') as f:
                rsos = json.load(f)

        if not rsos:
            raise ValueError("No RSOs provided")

        # Process RSOs in batches with progress bar
        for i in tqdm(range(0, len(rsos), batch_size)):
            for rso in rsos[i:i + batch_size]:
                categorization = categorize_rso(rso)
                if categorization:
                    # .get(): one reply lacking a key must not abort the run.
                    rso['ai_categories'] = categorization.get('categories', [])
                    rso['categorization_explanation'] = categorization.get('explanation', '')
                results.append(rso)

            # Save intermediate results every batch
            if output_file:
                with open(f"{output_file}.partial", 'w', encoding='utf-8') as f:
                    json.dump(results, f, indent=2)

            # Conservative delay between batches; nothing to wait for after the last one
            if i + batch_size < len(rsos):
                time.sleep(batch_delay)

        # Write final results to output file if provided
        if output_file:
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=2)
            print(f'\nCategorization complete! Results saved to {output_file}')

        return results

    except Exception as e:
        print(f'\nError processing RSOs: {str(e)}')
        # Save partial results if available
        if results and output_file:
            with open(f"{output_file}.error_partial", 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=2)
            print(f'Partial results saved to {output_file}.error_partial')
        return results

In [None]:
# Categorize every RSO in the input file and write the enriched records out.
results = process_rsos(
    input_file='rso_data_detailed.json',
    output_file='categorized_rsos.json',
)