In [11]:
import os
import re
import pandas as pd
import jsonlines
import csv
import requests
import time
import pickle
import signal
import logging
import pycountry
from sqlalchemy import create_engine
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Retrieve credentials and connection settings from environment variables.
# Secrets must never be written into this file: anything that ends up in source
# control or notebook history has to be treated as leaked and rotated.
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
DATABASE_USER = os.getenv('DB_USER')
DATABASE_PASSWORD = os.getenv('DB_PASSWORD')
DATABASE_HOST = os.getenv('DB_HOST', 'localhost')
DATABASE_PORT = os.getenv('DB_PORT', '5432')
DATABASE_NAME = os.getenv('DB_NAME', 'podcast_episodes')

# Validate environment variables: the three settings without a default are
# mandatory, and an empty string counts as missing.
missing_vars = [
    env_name
    for env_name, value in (
        ('GOOGLE_API_KEY', GOOGLE_API_KEY),
        ('DB_USER', DATABASE_USER),
        ('DB_PASSWORD', DATABASE_PASSWORD),
    )
    if not value
]

if missing_vars:
    logging.error(f"Missing environment variables: {', '.join(missing_vars)}. "
                  "Please set them before running the script.")
    # SystemExit is raised directly instead of calling exit(): that helper is
    # added by the `site` module for interactive sessions and is not meant for
    # use in programs.
    raise SystemExit(1)

# Construct the database connection string.
# User name and password are percent-encoded so that characters such as '@',
# ':' or '/' inside a credential cannot be mistaken for URL delimiters.
from urllib.parse import quote as _url_quote

DATABASE_CONNECTION_STRING = (
    f"postgresql+psycopg2://{_url_quote(DATABASE_USER, safe='')}"
    f":{_url_quote(DATABASE_PASSWORD, safe='')}"
    f"@{DATABASE_HOST}:{DATABASE_PORT}/{DATABASE_NAME}"
)

# The guest categories this script tries to fill
CATEGORIES = [
    'Authors and Writers', 'Female Voices',
    'Underrepresented Minority Voices', 'Media Professionals',
]
# Only podcasts whose primary genre appears here are kept
TARGET_GENRES = ['Business', 'Society & Culture', 'News']
# Upper bound on the number of guests collected for any single category
MAX_GUESTS_PER_CATEGORY = 100

# Shared mutable state. `cache_lock` is the lock the worker code takes before
# touching `api_cache`, `categorized_guests` or `output_data`.
cache_lock = Lock()
output_data = []
api_cache = {}
categorized_guests = dict((name, []) for name in CATEGORIES)
current_idx = 0

# ------------------- Signal Handling for Graceful Interruption -------------------

def signal_handler(signum, frame):
    """Handle SIGINT (Ctrl-C): save progress, then terminate with exit status 1.

    Args:
        signum: Number of the signal that was delivered (unused).
        frame: Stack frame that was interrupted (unused).

    Raises:
        SystemExit: Always, with code 1, after ``save_progress`` has returned.
    """
    print("\nScript interrupted. Saving progress before exiting...")
    save_progress()
    # Raise SystemExit directly rather than calling exit(): that helper is
    # installed by the `site` module for interactive use and some runtimes
    # rebind it, so it is not a dependable way to stop a program.
    raise SystemExit(1)

# Route Ctrl-C through the handler above so partial results are not lost.
signal.signal(signal.SIGINT, signal_handler)

# ------------------- Progress Saving and Loading -------------------

def save_progress():
    """Persist the current run state to disk.

    Writes ``progress.pkl`` containing ``current_idx``, ``categorized_guests``,
    ``api_cache`` and ``output_data``; when ``output_data`` is non-empty it is
    also written to ``guest_recommendations.csv``.

    Any exception is logged and swallowed so that a failed save never masks the
    reason the save was requested in the first place.
    """
    try:
        # Work on shallow copies: worker threads may still be appending to the
        # live containers, and serialising a dict while it grows can fail.
        # `cache_lock` is deliberately NOT taken here: it is not re-entrant and
        # this function runs inside the SIGINT handler, which may have
        # interrupted code that already holds it.
        snapshot = {
            'current_idx': current_idx,
            'categorized_guests': {
                category: list(entries)
                for category, entries in categorized_guests.items()
            },
            'api_cache': dict(api_cache),
            'output_data': list(output_data),
        }

        # Write to a temporary file and rename it into place, so an interrupted
        # save can never leave a truncated 'progress.pkl' behind.
        tmp_path = 'progress.pkl.tmp'
        with open(tmp_path, 'wb') as f:
            pickle.dump(snapshot, f)
        os.replace(tmp_path, 'progress.pkl')

        # Save output_data to CSV
        if snapshot['output_data']:
            output_df = pd.DataFrame(snapshot['output_data'])
            output_df.to_csv('guest_recommendations.csv', index=False)
            logging.info("Partial 'guest_recommendations.csv' has been saved.")
        logging.info("Progress saved.")
    except Exception as e:
        logging.error(f"Error saving progress: {e}")

def load_progress():
    """Restore run state from ``progress.pkl``, or reset it to a clean start.

    Rebinds the module globals ``current_idx``, ``categorized_guests``,
    ``api_cache`` and ``output_data``. When the file is absent the globals are
    reset to their empty defaults. When the file cannot be unpickled, or does
    not contain a dict, it is deleted and the defaults are used.

    Every name in ``CATEGORIES`` is guaranteed to be a key of
    ``categorized_guests`` afterwards, even if the saved file predates it.
    """
    global current_idx, categorized_guests, api_cache, output_data

    # Start from clean defaults; they are replaced only by a valid saved state.
    current_idx = 0
    categorized_guests = {cat: [] for cat in CATEGORIES}
    api_cache = {}
    output_data = []

    if not os.path.exists('progress.pkl'):
        return

    try:
        with open('progress.pkl', 'rb') as f:
            # NOTE(security): unpickling can execute arbitrary code. This is
            # only acceptable because the file is produced by this script;
            # never point it at a file from an untrusted source.
            progress = pickle.load(f)
        if not isinstance(progress, dict):
            raise pickle.UnpicklingError("progress file does not contain a dict")
    except (EOFError, pickle.UnpicklingError, AttributeError, ImportError, IndexError):
        # The pickle documentation lists these as the errors a damaged or
        # incompatible file may raise while being loaded.
        logging.warning("Progress file is corrupted or empty. Starting from scratch.")
        os.remove('progress.pkl')
        return

    logging.info("Progress loaded.")
    current_idx = progress.get('current_idx', 0)
    # Layer the saved lists over the defaults so a category added after the
    # file was written still has an (empty) entry.
    categorized_guests.update(progress.get('categorized_guests', {}))
    api_cache = progress.get('api_cache', {})
    output_data = progress.get('output_data', [])
# ------------------- Google Knowledge Graph API Interaction -------------------

def get_guest_info(name, service):
    """
    Look up a guest in the Google Knowledge Graph, consulting the cache first.

    Returns a ``(description, detailed_description)`` tuple. A successful lookup
    is stored in ``api_cache`` under ``name``. ``('', '')`` is returned when the
    search yields nothing or when an error other than HTTP 429 occurs (the error
    is logged). An ``HttpError`` with status 429 is logged and re-raised.
    """
    nothing_found = ('', '')

    # Serve repeated names straight from the cache.
    with cache_lock:
        try:
            return api_cache[name]
        except KeyError:
            pass

    try:
        response = service.entities().search(query=name, limit=1).execute()
        time.sleep(0.1)  # Short pause after each request to respect rate limits
        if 'itemListElement' not in response or len(response['itemListElement']) <= 0:
            return nothing_found

        top_hit = response['itemListElement'][0]['result']
        summary = top_hit.get('description', '')
        article = top_hit.get('detailedDescription', {}).get('articleBody', '')
        result = (summary, article)
        with cache_lock:
            api_cache[name] = result
        return result
    except HttpError as e:
        if e.resp.status != 429:
            logging.error(f"HTTP Error fetching data for {name}: {e}")
        else:
            logging.error("Quota exceeded for Google Knowledge Graph API.")
            raise e
    except Exception as e:
        logging.error(f"Unexpected error fetching data for {name}: {e}")
    return nothing_found

# Keyword patterns are anchored on word boundaries. Plain substring tests gave
# false positives such as "comedian" (contains "media"), "creditor" (contains
# "editor") and "authority" (contains "author"). Plural forms are accepted, and
# compounds ending in the keyword ("photojournalist", "screenwriter",
# "coauthor") still match where a leading \w* is allowed.
_MEDIA_PATTERN = re.compile(
    r'\b(?:\w*journalists?|reporters?|correspondents?|editors?|media|'
    r'broadcasters?|anchors?|columnists?)\b'
)
_AUTHOR_PATTERN = re.compile(r'\b\w*(?:authors?|writers?)\b')


def categorize_guest(guest, service):
    """
    Categorize a guest into zero or more of the predefined categories.

    Args:
        guest: Mapping with at least ``'guest_name'``; ``'gender'`` and
            ``'African-American'`` are read when present.
        service: Knowledge Graph client handed on to ``get_guest_info``.

    Returns:
        A set of category names:
        - 'Female Voices' when ``gender`` is 'F' (case-insensitive);
        - 'Underrepresented Minority Voices' when ``African-American`` is truthy;
        - 'Media Professionals' / 'Authors and Writers' when the Knowledge
          Graph description contains one of the corresponding keywords as a
          whole word.

    Raises:
        KeyError: If ``guest`` has no ``'guest_name'``.
        Whatever ``get_guest_info`` raises is propagated unchanged.
    """
    categories = set()

    # Basic attribute checks. `gender` may be None when the source JSON holds
    # null, so normalise to a string before upper-casing.
    if str(guest.get('gender') or '').upper() == 'F':
        categories.add('Female Voices')
    if guest.get('African-American', False):
        categories.add('Underrepresented Minority Voices')

    # Fetch additional info from API
    description, detailed_description = get_guest_info(guest['guest_name'], service)
    combined_info = f"{description} {detailed_description}".lower()

    # Media Professionals
    if _MEDIA_PATTERN.search(combined_info):
        categories.add('Media Professionals')

    # Authors and Writers
    if _AUTHOR_PATTERN.search(combined_info):
        categories.add('Authors and Writers')

    return categories

# ------------------- Parallel Processing Function -------------------

def process_guest(guest, service):
    """
    Worker function to process a single guest.

    Categorizes the guest and, for every category that still has room (fewer
    than ``MAX_GUESTS_PER_CATEGORY`` entries), records one entry in both
    ``categorized_guests[category]`` and ``output_data``.

    Args:
        guest: Record holding the guest and episode fields copied into each entry.
        service: Knowledge Graph client handed on to ``categorize_guest``.

    Returns:
        The list of entries that were actually recorded for this guest. It is
        empty when the guest matched no category, every matching category was
        already full, or an error occurred (the error is logged, not raised).
    """
    try:
        guest_categories = categorize_guest(guest, service)
        # Fields shared by every entry of this guest; built once, outside the lock.
        shared_fields = {
            'guest_name': guest['guest_name'],
            'gender': guest['gender'],
            'African-American': guest['African-American'],
            'podcast_id': guest['podcast_id'],
            'episode_id': guest['episode_id'],
            'podcast_title': guest['podcast_title'],
            'episode_title': guest['episode_title'],
            'episode_description': guest['episode_description']
        }
        categorized_entries = []
        for category in guest_categories:
            # The capacity check and the append must be one atomic step,
            # otherwise two workers could both fill the last free slot.
            with cache_lock:
                if len(categorized_guests[category]) >= MAX_GUESTS_PER_CATEGORY:
                    continue
                entry = {'category': category, **shared_fields}
                categorized_guests[category].append(entry)
                output_data.append(entry)
            # Report the entry to the caller; without this the function always
            # returned an empty list and callers could not count progress.
            categorized_entries.append(entry)
        return categorized_entries
    except Exception as e:
        logging.error(f"Error processing guest {guest.get('guest_name', '<unknown>')}: {e}")
        return []

# ------------------- Main Function -------------------

def main():
    """Run the guest-recommendation pipeline end to end.

    Steps, in order:
      1. Restore saved state with ``load_progress``.
      2. Build the Knowledge Graph client and the SQLAlchemy engine.
      3. Read guests from 'guests-extract.jsonl', keeping only names made of
         at least two words.
      4. Inner-join the guests with ``episodes_recent`` (read in chunks) and
         with 'podcasts_sample.csv'.
      5. Keep podcasts whose primary genre is in ``TARGET_GENRES`` and one row
         per guest name.
      6. Categorize the guests in a thread pool until every category holds
         ``MAX_GUESTS_PER_CATEGORY`` entries or the guests run out.
      7. Write 'guest_recommendations.csv' and delete 'progress.pkl'.

    A failure in steps 2-5 is logged and ends the run early. Returns None.
    """
    # Load progress if any
    load_progress()

    # Initialize Google Knowledge Graph API service
    try:
        service = build('kgsearch', 'v1', developerKey=GOOGLE_API_KEY)
    except Exception as e:
        logging.error(f"Error initializing Google API service: {e}")
        return

    # Initialize database engine
    logging.info("Connecting to the PostgreSQL database...")
    try:
        engine = create_engine(DATABASE_CONNECTION_STRING)
    except Exception as e:
        logging.error(f"Error connecting to the database: {e}")
        return

    # Define the SQL query to fetch episodes_recent data
    episodes_query = """
    SELECT podcast_id, episode_id, episode_title, episode_description
    FROM episodes_recent
    """

    # Load guests data from 'guests-extract.jsonl'
    logging.info("Loading guests data from 'guests-extract.jsonl'...")
    guests_data = []
    try:
        with open('guests-extract.jsonl', 'r', encoding='utf-8') as f:
            for line_number, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    logging.warning(f"Empty line at line {line_number}. Skipping.")
                    continue
                try:
                    obj = jsonlines.Reader([line]).read()
                except jsonlines.InvalidLineError as e:
                    logging.error(f"Invalid JSON at line {line_number}: {e}. Skipping.")
                    continue
                except Exception as e:
                    logging.error(f"Unexpected error at line {line_number}: {e}. Skipping.")
                    continue

                podcast_id = obj.get('podcast_id')
                episode_id = obj.get('episode_id')
                # `or` guards against explicit JSON nulls, which would
                # otherwise abort the whole load through the outer handler.
                guests = obj.get('guests') or []
                for guest in guests:
                    name = (guest.get('name') or '').strip()
                    if not name or len(name.split()) < 2:
                        continue  # Skip guests without at least first and last names
                    guest_record = {
                        'podcast_id': podcast_id,
                        'episode_id': episode_id,
                        'guest_name': name,
                        'gender': guest.get('gender', ''),
                        'African-American': guest.get('African-American', False)
                    }
                    guests_data.append(guest_record)
    except FileNotFoundError:
        logging.error("File 'guests-extract.jsonl' not found.")
        return
    except Exception as e:
        logging.error(f"Error reading 'guests-extract.jsonl': {e}")
        return

    guests_df = pd.DataFrame(guests_data)
    logging.info(f"Total guests loaded: {len(guests_df)}")

    # Merge guests with episodes_recent, one chunk of episode rows at a time
    logging.info("Merging guests with episodes_recent data...")
    try:
        episodes_iter = pd.read_sql_query(episodes_query, engine, chunksize=100000)
        merged_chunks = []
        for chunk_number, chunk in enumerate(episodes_iter, start=1):
            merged_chunk = pd.merge(
                guests_df,
                chunk,
                on=['podcast_id', 'episode_id'],
                how='inner'
            )
            merged_chunks.append(merged_chunk)
            logging.info(f"Processed chunk {chunk_number}: matched {len(merged_chunk)} guests.")
        merged_df = pd.concat(merged_chunks, ignore_index=True)
    except Exception as e:
        logging.error(f"Error merging guests with episodes_recent: {e}")
        return

    logging.info(f"Total guests after merging: {len(merged_df)}")

    # Load podcast data from 'podcasts_sample.csv'
    logging.info("Loading podcast data from 'podcasts_sample.csv'...")
    try:
        # Replace 'sep' based on your CSV delimiter
        # For comma-separated values
        podcasts_df = pd.read_csv(
            'podcasts_sample.csv',
            sep=',',  # Adjust the separator based on your file
            encoding='utf-8',
            dtype={'podcast_id': str},
            on_bad_lines='skip',
            quoting=csv.QUOTE_ALL,  # Adjust quoting based on your data
            escapechar='\\'
        )
        logging.info(f"Columns found in podcast data: {podcasts_df.columns.tolist()}")
    except Exception as e:
        logging.error(f"Error reading 'podcasts_sample.csv': {e}")
        return

    # Rename 'title' to 'podcast_title' for consistency
    if 'title' in podcasts_df.columns:
        podcasts_df.rename(columns={'title': 'podcast_title'}, inplace=True)
    else:
        logging.error("Missing 'title' column in 'podcasts_sample.csv'")
        return

    # Verify that required columns are present
    required_columns = ['podcast_id', 'podcast_title', 'primary_genre']
    missing_columns = [col for col in required_columns if col not in podcasts_df.columns]
    if missing_columns:
        logging.error(f"Missing columns in 'podcasts_sample.csv': {missing_columns}")
        return

    # Merge with podcast titles and genres
    logging.info("Merging with podcast titles and genres...")
    try:
        merged_df = pd.merge(
            merged_df,
            podcasts_df[['podcast_id', 'podcast_title', 'primary_genre']],
            on='podcast_id',
            how='inner'
        )
        logging.info("Merge completed.")
    except Exception as e:
        logging.error(f"Error merging with 'podcasts_sample.csv': {e}")
        return

    logging.info(f"Total guests after merging with podcasts: {len(merged_df)}")

    # Filter for target genres
    logging.info(f"Filtering for target genres: {TARGET_GENRES}")
    merged_df = merged_df[merged_df['primary_genre'].isin(TARGET_GENRES)].copy()
    logging.info(f"Total guests after genre filtering: {len(merged_df)}")

    # Remove duplicate guests based on 'guest_name'
    merged_df = merged_df.drop_duplicates(subset=['guest_name'])
    logging.info(f"Total unique guests after removing duplicates: {len(merged_df)}")

    # Convert merged_df to list of dictionaries for processing
    guests_list = merged_df.to_dict(orient='records')
    total_guests = len(guests_list)
    logging.info(f"Starting categorization of {total_guests} guests...")

    # Initialize ThreadPoolExecutor for parallel processing
    MAX_WORKERS = 20  # Adjust based on your system's capability and API rate limits
    categorized_guest_count = {category: len(categorized_guests[category]) for category in CATEGORIES}

    # Function to determine if we've reached desired counts
    def reached_desired_counts():
        return all(count >= MAX_GUESTS_PER_CATEGORY for count in categorized_guest_count.values())

    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # Create future to guest mapping
        future_to_guest = {
            executor.submit(process_guest, guest, service): guest for guest in guests_list
        }

        # Iterate over completed futures as they finish
        for future in tqdm(as_completed(future_to_guest), total=total_guests, desc="Categorizing Guests"):
            guest = future_to_guest[future]
            try:
                categorized_entries = future.result()
                for entry in categorized_entries:
                    with cache_lock:
                        categorized_guest_count[entry['category']] += 1
                # Check if desired counts reached
                if reached_desired_counts():
                    logging.info("Desired guest counts per category reached. Stopping categorization.")
                    # Leaving the `with` block waits for every queued future,
                    # so `break` alone would still process all remaining
                    # guests. Cancel whatever has not started; cancel() is a
                    # harmless no-op on running or finished futures.
                    for pending in future_to_guest:
                        pending.cancel()
                    break
            except Exception as e:
                logging.error(f"Error in processing guest {guest['guest_name']}: {e}")

    # Convert output_data to DataFrame
    output_df = pd.DataFrame(output_data)

    # Ensure no category has more than MAX_GUESTS_PER_CATEGORY guests
    for category in CATEGORIES:
        if 'category' in output_df.columns:
            in_category = output_df['category'] == category
            current_count = int(in_category.sum())
        else:
            # No guest was categorised at all: the frame has no columns, and
            # indexing 'category' would raise KeyError.
            in_category = None
            current_count = 0

        if current_count < MAX_GUESTS_PER_CATEGORY:
            logging.warning(f"Category '{category}' has only {current_count} guests.")
        elif current_count > MAX_GUESTS_PER_CATEGORY:
            # Randomly sample if more than MAX_GUESTS_PER_CATEGORY
            sampled_df = output_df[in_category].sample(
                n=MAX_GUESTS_PER_CATEGORY, random_state=42
            )
            # Drop ALL rows of this category before adding the sample back;
            # removing only the sampled rows would leave the surplus in place.
            output_df = pd.concat([output_df[~in_category], sampled_df], ignore_index=True)
            logging.info(f"Category '{category}' trimmed to {MAX_GUESTS_PER_CATEGORY} guests.")

    # Save to CSV
    try:
        output_df.to_csv('guest_recommendations.csv', index=False)
        logging.info("CSV file 'guest_recommendations.csv' has been created.")
    except Exception as e:
        logging.error(f"Error saving 'guest_recommendations.csv': {e}")

    # Remove progress file if exists
    if os.path.exists('progress.pkl'):
        try:
            os.remove('progress.pkl')
            logging.info("Progress file removed.")
        except Exception as e:
            logging.error(f"Error removing progress file: {e}")

    logging.info("Script completed successfully.")

if __name__ == '__main__':
    main()

2025-01-14 12:08:48,448 - INFO - Connecting to the PostgreSQL database...
2025-01-14 12:08:48,450 - INFO - Loading guests data from 'guests-extract.jsonl'...
2025-01-14 12:08:50,470 - ERROR - Invalid JSON at line 617899: line contains invalid json: Expecting value: line 1 column 1 (char 0) (line 1). Skipping.
2025-01-14 12:09:20,413 - ERROR - Invalid JSON at line 10142058: line contains invalid json: Expecting value: line 1 column 1 (char 0) (line 1). Skipping.
2025-01-14 12:09:25,995 - INFO - Total guests loaded: 7320672
2025-01-14 12:09:25,995 - INFO - Merging guests with episodes_recent data...
2025-01-14 12:12:18,749 - INFO - Processed chunk 1: matched 90828 guests.
2025-01-14 12:12:22,591 - INFO - Processed chunk 2: matched 76705 guests.
2025-01-14 12:12:26,279 - INFO - Processed chunk 3: matched 59018 guests.
2025-01-14 12:12:30,044 - INFO - Processed chunk 4: matched 99989 guests.
2025-01-14 12:12:33,860 - INFO - Processed chunk 5: matched 88191 guests.
2025-01-14 12:12:37,642 -