In [20]:
import os
import time
from datetime import datetime

import requests
import psycopg2

# PostgreSQL connection parameters.
# Each value can be overridden through the standard libpq environment variables
# (PGHOST, PGDATABASE, PGUSER, PGPASSWORD, PGPORT) so credentials do not have
# to be edited into (and shared along with) the source; the literals below are
# only fallbacks.
PG_HOST = os.environ.get("PGHOST", "localhost")
PG_DATABASE = os.environ.get("PGDATABASE", "sustainability_db")
PG_USER = os.environ.get("PGUSER", "postgres")
PG_PASSWORD = os.environ.get("PGPASSWORD", "***")
PG_PORT = os.environ.get("PGPORT", "5432")

# World Bank Indicators API (v2) request parameters.
# HTTPS keeps the downloaded data from being read or altered in transit.
base_url = "https://api.worldbank.org/v2"
format_param = "format=json"
per_page = "per_page=1000"  # Rows requested per page; larger results are paged

# World Bank indicator codes to extract (focused on environmental decline and
# sustainability), grouped by theme. The groups are flattened, in order, into
# the module-level ``indicators`` list below.
_INDICATOR_GROUPS = {
    "population": (
        "SP.POP.TOTL",  # Total population
        "SP.POP.GROW",  # Population growth (annual %)
    ),
    "economic growth through agriculture": (
        "NV.AGR.TOTL.KD.ZG",  # Agriculture, value added (annual % growth)
        "NV.AGR.TOTL.ZS",  # Agriculture, value added (% of GDP)
    ),
    "environment (original)": (
        "AG.LND.AGRI.ZS",  # Agricultural land (% of land area)
        "AG.LND.FRST.ZS",  # Forest area (%)
        "NY.GDP.TOTL.RT.ZS",  # Total natural resources rents (% of GDP)
        "EG.ELC.ACCS.UR.ZS",  # Access to electricity (% of urban population)
    ),
    "infrastructure": (
        "ER.H2O.FWTL.ZS",  # Annual freshwater withdrawals, total (% of internal resources)
        "ER.H2O.FWTL.K3",  # Annual freshwater withdrawals, total (billion cubic meters)
        "EG.ELC.FOSL.ZS",  # Electricity production from oil, gas and coal sources (% of total)
        "EG.ELC.RNWX.ZS",  # Electricity production from renewables, excluding hydro (% of total)
    ),
    "environment (new)": (
        "EN.GHG.ALL.PC.CE.AR5",  # Total GHG emissions excl. LULUCF per capita (t CO2e/capita)
        "EN.GHG.ALL.MT.CE.AR5",  # Total GHG emissions excl. LULUCF (Mt CO2e)
        "EN.GHG.CO2.PC.CE.AR5",  # CO2 emissions excl. LULUCF per capita (t CO2e/capita)
        "EN.GHG.CO2.MT.CE.AR5",  # CO2 emissions (total) excl. LULUCF (Mt CO2e)
        "EN.GHG.CH4.MT.CE.AR5",  # Methane (CH4) emissions (total) excl. LULUCF (Mt CO2e)
        "EN.GHG.N2O.MT.CE.AR5",  # Nitrous oxide (N2O) emissions (total) excl. LULUCF (Mt CO2e)
        "AG.LND.TOTL.K2",  # Land area (sq. km)
    ),
    "energy": (
        "EG.FEC.RNEW.ZS",  # Renewable energy consumption (% of total final energy consumption)
        "EG.USE.COMM.FO.ZS",  # Fossil fuel energy consumption (% of total)
    ),
    "climate change": (
        "AG.LND.PRCP.MM",  # Average precipitation in depth (mm per year)
    ),
    "economy": (
        "NY.ADJ.DCO2.CD",  # Adjusted savings: carbon dioxide damage (current US$)
        "NY.ADJ.DNGY.CD",  # Adjusted savings: energy depletion (current US$)
        "NY.ADJ.NNTY.CD",  # Adjusted net national income (current US$)
        "NY.ADJ.NNTY.KD.ZG",  # Adjusted net national income (annual % growth)
    ),
}

indicators = [code for group in _INDICATOR_GROUPS.values() for code in group]

# Inclusive range of years to load; the API "date" filter is written "start:end".
start_year, end_year = 1990, 2020
date_range = ":".join(str(year) for year in (start_year, end_year))

# Aggregate-region codes that appear in World Bank data next to real countries.
# They are looked up one by one and added to dim_country so that fact rows
# reported for them are not skipped.
special_codes = (
    "ZH T5 1W Z4 Z7 ZF ZG ZJ ZQ EU OE "
    "XC XD XE XF XG XH XI XJ XL XM XN "
    "XO XP XQ XT XU"
).split()

# Open a fresh PostgreSQL connection; callers manage commits explicitly
def get_db_connection():
    """Return a new psycopg2 connection built from the PG_* settings.

    Autocommit is switched off, so nothing persists until the caller
    invokes ``commit()`` on the returned connection.
    """
    connection_kwargs = {
        "host": PG_HOST,
        "database": PG_DATABASE,
        "user": PG_USER,
        "password": PG_PASSWORD,
        "port": PG_PORT,
    }
    connection = psycopg2.connect(**connection_kwargs)
    connection.autocommit = False
    return connection

# Function to add a country to the dimension table
def add_country_to_dimension(cursor, country_code, conn, timeout=30):
    """Fetch one country/aggregate from the World Bank API and upsert it into dim_country.

    Args:
        cursor: Open cursor on ``conn`` used for the upsert.
        country_code: World Bank country or aggregate-region id (e.g. ``"1W"``).
        conn: Connection owning ``cursor``. It is committed after a successful
            upsert and rolled back if anything raises.
        timeout: Seconds to wait for the HTTP response. Without a timeout
            ``requests`` can block indefinitely on a stalled connection.

    Returns:
        True if a row was upserted, otherwise False. Errors are printed, not raised.
    """
    try:
        country_url = f"{base_url}/country/{country_code}?{format_param}"
        response = requests.get(country_url, timeout=timeout)
        if response.status_code != 200:
            print(f"  Failed to get data for country code: {country_code}, status: {response.status_code}")
            return False

        data = response.json()
        # Expected payload: [paging_metadata, [country, ...]]. A short list or a
        # missing/null/empty record list is reported as "no data" instead of
        # raising while indexing it.
        records = data[1] if isinstance(data, list) and len(data) > 1 else None
        if not records:
            print(f"  No data returned for country code: {country_code}")
            return False

        country = records[0]
        country_name = country['name']
        # region / incomeLevel are objects with a 'value' key; absent or empty
        # objects map to NULL.
        region = country.get('region') or {}
        income_level = country.get('incomeLevel') or {}

        cursor.execute('''
        INSERT INTO dim_country VALUES (%s, %s, %s, %s)
        ON CONFLICT (country_code) DO UPDATE
        SET country_name = EXCLUDED.country_name,
            region = EXCLUDED.region,
            income_level = EXCLUDED.income_level
        ''', (
            country_code,
            country_name,
            region.get('value'),
            income_level.get('value'),
        ))
        conn.commit()
        print(f"  Added country: {country_code} - {country_name}")
        return True
    except Exception as e:
        # Best-effort: undo any partial work so the caller's connection stays usable.
        conn.rollback()
        print(f"  Error adding country {country_code}: {e}")
        return False

# Create dimension tables

# Category stored in dim_indicator for each indicator code; codes not listed
# fall back to "Other".
_INDICATOR_CATEGORIES = {
    # Population
    "SP.POP.TOTL": "population",  # total population
    "SP.POP.GROW": "population",  # Population growth (annual %)

    # Economic growth through agriculture
    "NV.AGR.TOTL.KD.ZG": "economic_growth",  # Agriculture, value added (annual % growth)
    "NV.AGR.TOTL.ZS": "economic_growth",  # Agriculture, value added (% of GDP)

    # Environment - Original
    "AG.LND.AGRI.ZS": "environment",  # Agriculture land (% of land area)
    "AG.LND.FRST.ZS": "environment",  # % of forest area
    "NY.GDP.TOTL.RT.ZS": "environment",  # Total natural resources rents (% of GDP)
    "EG.ELC.ACCS.UR.ZS": "environment",  # Access to electricity (% of urban population)

    # Infrastructure
    "ER.H2O.FWTL.ZS": "infrastructure",  # Annual freshwater withdrawals, total (% of internal resources)
    "ER.H2O.FWTL.K3": "infrastructure",  # Annual freshwater withdrawals, total (billion cubic meters)
    "EG.ELC.FOSL.ZS": "infrastructure",  # Electricity production from oil, gas and coal sources (% of total)
    "EG.ELC.RNWX.ZS": "infrastructure",  # Electricity production from renewables, excluding hydro (% of total)

    # Environment - New indicators
    "EN.GHG.ALL.PC.CE.AR5": "environment",  # Total GHG emissions excl. LULUCF per capita (t CO2e/capita)
    "EN.GHG.ALL.MT.CE.AR5": "environment",  # Total GHG emissions excl. LULUCF (Mt CO2e)
    "EN.GHG.CO2.PC.CE.AR5": "environment",  # CO2 emissions excl. LULUCF per capita (t CO2e/capita)
    "EN.GHG.CO2.MT.CE.AR5": "environment",  # CO2 emissions (total) excl. LULUCF (Mt CO2e)
    "EN.GHG.CH4.MT.CE.AR5": "environment",  # Methane (CH4) emissions (total) excl. LULUCF (Mt CO2e)
    "EN.GHG.N2O.MT.CE.AR5": "environment",  # Nitrous oxide (N2O) emissions (total) excl. LULUCF (Mt CO2e)
    "AG.LND.TOTL.K2": "environment",  # Land area (sq. km)

    # Energy
    "EG.FEC.RNEW.ZS": "energy",  # Renewable energy consumption (% of total final energy consumption)
    "EG.USE.COMM.FO.ZS": "energy",  # Fossil fuel energy consumption (% of total)

    # Climate change
    "AG.LND.PRCP.MM": "climate_change",  # Average precipitation in depth (mm per year)

    # Economy
    "NY.ADJ.DCO2.CD": "economy",  # Adjusted savings: carbon dioxide damage (current US$)
    "NY.ADJ.DNGY.CD": "economy",  # Adjusted savings: energy depletion (current US$)
    "NY.ADJ.NNTY.CD": "economy",  # Adjusted net national income (current US$)
    "NY.ADJ.NNTY.KD.ZG": "economy",  # Adjusted net national income (annual % growth)
}


def _load_country_pages(cursor, timeout):
    """Upsert every country returned by the paged /country endpoint; return the number of rows sent."""
    page = 1
    total_pages = 1
    country_count = 0

    while page <= total_pages:
        country_url = f"{base_url}/country?{format_param}&{per_page}&page={page}"
        print(f"Fetching countries, page {page}...")
        response = requests.get(country_url, timeout=timeout)

        if response.status_code == 200:
            data = response.json()

            # The page count is taken from the first page's metadata only.
            if page == 1 and data and isinstance(data[0], dict) and 'pages' in data[0]:
                total_pages = data[0]['pages']
                print(f"Total pages of country data: {total_pages}")

            # data[1] holds the records; guard against it being absent or null.
            records = data[1] if len(data) > 1 and data[1] else []
            for country in records:
                region = country.get('region') or {}
                income_level = country.get('incomeLevel') or {}
                cursor.execute('''
                INSERT INTO dim_country VALUES (%s, %s, %s, %s)
                ON CONFLICT (country_code) DO UPDATE
                SET country_name = EXCLUDED.country_name,
                    region = EXCLUDED.region,
                    income_level = EXCLUDED.income_level
                ''', (
                    country['id'],
                    country['name'],
                    region.get('value'),
                    income_level.get('value'),
                ))
                country_count += 1
        else:
            print(f"Error fetching country page {page}: Status code {response.status_code}")

        page += 1
        time.sleep(0.5)  # Rate limiting

    return country_count


def _load_time_dimension(cursor):
    """Create dim_time and upsert one (year, decade) row for every year from start_year to end_year."""
    print("Creating time dimension table...")
    cursor.execute('''
    CREATE TABLE IF NOT EXISTS dim_time (
        year INTEGER PRIMARY KEY,
        decade INTEGER
    )
    ''')
    for year in range(start_year, end_year + 1):
        cursor.execute('''
        INSERT INTO dim_time VALUES (%s, %s)
        ON CONFLICT (year) DO UPDATE SET decade = EXCLUDED.decade
        ''', (year, (year // 10) * 10))


def _load_indicators(cursor, conn, codes, timeout):
    """Upsert each code the World Bank API recognises into dim_indicator; return those codes in order."""
    valid_indicators = []
    for indicator in codes:
        try:
            indicator_url = f"{base_url}/indicator/{indicator}?{format_param}"
            response = requests.get(indicator_url, timeout=timeout)

            if response.status_code == 200:
                data = response.json()
                if len(data) > 1 and data[1]:
                    indicator_info = data[1][0]
                    cursor.execute('''
                    INSERT INTO dim_indicator VALUES (%s, %s, %s)
                    ON CONFLICT (indicator_code) DO UPDATE
                    SET indicator_name = EXCLUDED.indicator_name,
                        category = EXCLUDED.category
                    ''', (
                        indicator,
                        indicator_info.get('name'),
                        _INDICATOR_CATEGORIES.get(indicator, "Other")
                    ))
                    # Commit per indicator so a later failure (and the rollback
                    # below) cannot discard rows that were already accepted.
                    conn.commit()
                    valid_indicators.append(indicator)
                    print(f"Added indicator: {indicator} - {indicator_info.get('name')}")
                else:
                    print(f"Warning: Indicator {indicator} not found in World Bank API")
            else:
                print(f"Warning: Failed to fetch indicator {indicator}, status: {response.status_code}")
        except Exception as e:
            # A failed INSERT aborts the PostgreSQL transaction; without this
            # rollback every subsequent INSERT would be rejected too.
            conn.rollback()
            print(f"Error processing indicator {indicator}: {e}")

        time.sleep(0.5)  # Rate limiting

    return valid_indicators


def create_dimension_tables(timeout=30):
    """Create and populate dim_country, dim_time and dim_indicator.

    Each stage is committed before the next one starts, so a failure while
    loading indicators does not roll back the countries or years.

    Side effects:
        Rebinds the module-level ``indicators`` list to only the codes the
        World Bank API recognised (in their original order).

    Args:
        timeout: Seconds to wait for each HTTP response from the World Bank API.

    Errors are printed and the open transaction rolled back; nothing is raised.
    """
    global indicators

    conn = get_db_connection()
    cursor = conn.cursor()

    try:
        # Countries dimension
        print("Creating country dimension table...")
        cursor.execute('''
        CREATE TABLE IF NOT EXISTS dim_country (
            country_code TEXT PRIMARY KEY,
            country_name TEXT,
            region TEXT,
            income_level TEXT
        )
        ''')
        country_count = _load_country_pages(cursor, timeout)
        conn.commit()
        print(f"Added {country_count} countries from standard API")

        # Aggregate regions are fetched one by one; add_country_to_dimension
        # commits (or rolls back) each of them itself.
        print("Adding special aggregate region codes...")
        for code in special_codes:
            add_country_to_dimension(cursor, code, conn)
            time.sleep(0.5)  # Rate limiting

        _load_time_dimension(cursor)

        # Indicator dimension
        print("Creating indicator dimension table...")
        cursor.execute('''
        CREATE TABLE IF NOT EXISTS dim_indicator (
            indicator_code TEXT PRIMARY KEY,
            indicator_name TEXT,
            category TEXT
        )
        ''')
        # Persist dim_time and the dim_indicator table before the indicator
        # loop, whose per-indicator error handling rolls the transaction back.
        conn.commit()

        # Keep only indicators the API recognises for the fact load.
        indicators = _load_indicators(cursor, conn, indicators, timeout)

        conn.commit()
        print("Dimension tables created successfully.")

    except Exception as e:
        conn.rollback()
        print(f"Error creating dimension tables: {e}")
    finally:
        cursor.close()
        conn.close()

# Create fact table
def create_fact_table():
    """Create ``fact_sustainability`` if it does not already exist.

    Each row is one (country, year, indicator) observation with foreign keys
    into dim_country, dim_time and dim_indicator, so those tables must exist
    first. Errors are printed and the transaction rolled back; nothing is raised.

    ``indicator_value`` is DOUBLE PRECISION. The loader stores Python floats
    (64-bit), and a 4-byte REAL keeps only about 6 significant digits, which
    visibly rounds large series such as total population or national income
    in current US$. Because of ``IF NOT EXISTS``, an already-created table
    keeps its old column type until it is altered, e.g.
    ``ALTER TABLE fact_sustainability ALTER COLUMN indicator_value TYPE DOUBLE PRECISION``.

    NOTE(review): there is no uniqueness constraint on
    (country_code, year, indicator_code), so loading the same data twice
    stores duplicate rows.
    """
    conn = get_db_connection()
    cursor = conn.cursor()

    try:
        print("Creating fact table...")
        cursor.execute('''
        CREATE TABLE IF NOT EXISTS fact_sustainability (
            id SERIAL PRIMARY KEY,
            country_code TEXT,
            year INTEGER,
            indicator_code TEXT,
            indicator_value DOUBLE PRECISION,
            load_timestamp TIMESTAMP,
            FOREIGN KEY (country_code) REFERENCES dim_country(country_code),
            FOREIGN KEY (year) REFERENCES dim_time(year),
            FOREIGN KEY (indicator_code) REFERENCES dim_indicator(indicator_code)
        )
        ''')
        conn.commit()
        print("Fact table created successfully.")

    except Exception as e:
        conn.rollback()
        print(f"Error creating fact table: {e}")
    finally:
        cursor.close()
        conn.close()

# Extract data for each indicator and load into fact table
def _fetch_json_page(page_url, timeout, max_attempts):
    """Return the decoded JSON body of ``page_url``, or None once ``max_attempts`` tries have failed."""
    for _ in range(max_attempts):
        # Rate limiting to avoid API throttling
        time.sleep(1)
        try:
            response = requests.get(page_url, timeout=timeout)
        except requests.RequestException as e:
            print(f"Error: request failed: {e}")
            continue
        if response.status_code != 200:
            print(f"Error: API returned status code {response.status_code}")
            continue
        try:
            return response.json()
        except ValueError as e:  # requests' JSON decode errors subclass ValueError
            print(f"Error parsing JSON response: {e}")
    return None


def _insert_page_rows(conn, cursor, indicator, entries, valid_country_codes, missing_country_codes):
    """Insert one page of API records for ``indicator``; return (inserted, skipped).

    Records with no value, an unparseable value or year, or a year outside
    [start_year, end_year] are ignored. Records whose country code is not in
    ``valid_country_codes`` are counted as skipped, and the code is added to
    ``missing_country_codes`` (mutated in place).
    """
    inserted = 0
    skipped = 0
    uncommitted = 0

    for entry in entries:
        # Skip if no value
        if entry.get('value') is None:
            continue

        country_code = (entry.get('country') or {}).get('id')
        if country_code not in valid_country_codes:
            if country_code:
                missing_country_codes.add(country_code)
            skipped += 1
            continue

        try:
            value = float(entry['value'])
        except (ValueError, TypeError):
            print(f"    Warning: Could not convert value '{entry['value']}' to float")
            continue

        # A missing/null date fails int() too, so year-less rows are not stored.
        try:
            year = int(entry.get('date'))
        except (ValueError, TypeError):
            print(f"    Warning: Could not convert date '{entry.get('date')}' to integer")
            continue
        # Only years present in the time dimension are loaded.
        if not start_year <= year <= end_year:
            continue

        # The savepoint lets one failing INSERT be undone without discarding
        # the rows inserted since the last commit.
        cursor.execute("SAVEPOINT fact_row")
        try:
            cursor.execute('''
            INSERT INTO fact_sustainability
            (country_code, year, indicator_code, indicator_value, load_timestamp)
            VALUES (%s, %s, %s, %s, %s)
            ''', (
                country_code,
                year,
                indicator,
                value,
                datetime.now()
            ))
        except psycopg2.Error as row_error:
            cursor.execute("ROLLBACK TO SAVEPOINT fact_row")
            cursor.execute("RELEASE SAVEPOINT fact_row")
            print(f"Error inserting row: {row_error}")
            continue
        cursor.execute("RELEASE SAVEPOINT fact_row")

        inserted += 1
        uncommitted += 1
        # Commit in smaller batches to avoid long transactions
        if uncommitted >= 100:
            conn.commit()
            uncommitted = 0

    # Final commit for this page
    conn.commit()
    return inserted, skipped


def extract_and_load_data(timeout=30, max_attempts=3):
    """Download every code in ``indicators`` for all countries and load it into fact_sustainability.

    Pages that still fail after ``max_attempts`` tries (non-200 status,
    network error or invalid JSON) are skipped, not retried indefinitely.
    Country codes seen in the data but absent from dim_country are added via
    handle_missing_country_codes after all indicators have been processed.

    Args:
        timeout: Seconds to wait for each HTTP response.
        max_attempts: Tries per page before giving up on it (at least 1).
    """
    print("Extracting and loading data...")
    missing_country_codes = set()

    # dim_country is only changed after the loop below, so the valid codes
    # are read once rather than once per indicator.
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT country_code FROM dim_country")
        valid_country_codes = {row[0] for row in cursor.fetchall()}
        cursor.close()
    finally:
        conn.close()

    for indicator in indicators:
        print(f"Processing indicator: {indicator}")

        # Construct API URL
        indicator_url = f"{base_url}/country/all/indicator/{indicator}?{format_param}&{per_page}&date={date_range}"

        # Initialize for pagination
        page = 1
        total_pages = 1
        skipped_count = 0
        inserted_count = 0

        while page <= total_pages:
            page_url = f"{indicator_url}&page={page}"
            print(f"  Requesting page {page}...")

            data = _fetch_json_page(page_url, timeout, max_attempts)
            if not isinstance(data, list):
                # Move on instead of re-requesting the same page forever.
                print(f"Error processing {indicator} page {page}: no usable response")
                page += 1
                continue

            # data[0] carries paging metadata; data[1] the records (guarded in
            # case it is missing or null).
            if data and isinstance(data[0], dict) and 'pages' in data[0]:
                total_pages = data[0]['pages']
            records = data[1] if len(data) > 1 and data[1] else []

            if records:
                # A fresh connection per page keeps a failure on one page from
                # affecting the next.
                conn = get_db_connection()
                cursor = conn.cursor()
                try:
                    inserted, skipped = _insert_page_rows(
                        conn, cursor, indicator, records,
                        valid_country_codes, missing_country_codes,
                    )
                    inserted_count += inserted
                    skipped_count += skipped
                except Exception as e:
                    conn.rollback()
                    print(f"Error processing {indicator} page {page}: {e}")
                finally:
                    cursor.close()
                    conn.close()

            page += 1

        print(f"  Completed {indicator}: Inserted {inserted_count} records, skipped {skipped_count} records")

    # Handle missing country codes
    if missing_country_codes:
        print(f"Found {len(missing_country_codes)} missing country codes. Adding them to dimension table...")
        handle_missing_country_codes(missing_country_codes)

        # We won't retry data loading automatically to avoid potential loops.
        # NOTE(review): fact_sustainability has no uniqueness constraint, so a
        # full re-run inserts already-loaded rows a second time.
        print("You may want to run the script again to load data for newly added countries.")

# Secondary indexes on the fact table for common lookups
def create_indexes():
    """Create the lookup indexes on fact_sustainability, skipping any that already exist.

    All statements run in one transaction; on any error it is rolled back and
    the error is printed rather than raised.

    NOTE(review): idx_fact_country and idx_fact_indicator are leading-column
    prefixes of the two composite indexes and may be redundant.
    """
    index_statements = (
        'CREATE INDEX IF NOT EXISTS idx_fact_country ON fact_sustainability(country_code)',
        'CREATE INDEX IF NOT EXISTS idx_fact_year ON fact_sustainability(year)',
        'CREATE INDEX IF NOT EXISTS idx_fact_indicator ON fact_sustainability(indicator_code)',
        'CREATE INDEX IF NOT EXISTS idx_fact_country_year ON fact_sustainability(country_code, year)',
        'CREATE INDEX IF NOT EXISTS idx_fact_indicator_year ON fact_sustainability(indicator_code, year)',
    )

    connection = get_db_connection()
    cur = connection.cursor()

    try:
        print("Creating indexes...")
        for statement in index_statements:
            cur.execute(statement)
        connection.commit()
        print("Indexes created successfully.")
    except Exception as err:
        connection.rollback()
        print(f"Error creating indexes: {err}")
    finally:
        cur.close()
        connection.close()

# Handle missing country codes
def handle_missing_country_codes(missing_codes):
    """Look up each code via add_country_to_dimension and upsert it into dim_country.

    Args:
        missing_codes: Iterable of World Bank country/aggregate codes that were
            seen in fact data but are absent from dim_country.

    Returns:
        The number of codes successfully added. Existing callers ignore it.
    """
    # Sorted for a deterministic processing and log order; key=str tolerates
    # mixed or unexpected value types.
    codes = sorted(missing_codes, key=str)

    conn = get_db_connection()
    cursor = conn.cursor()
    added_count = 0
    try:
        for code in codes:
            if add_country_to_dimension(cursor, code, conn):
                added_count += 1
            time.sleep(0.5)  # Rate limiting
    finally:
        # Release the connection even if the loop is interrupted.
        cursor.close()
        conn.close()

    print(f"Added {added_count} out of {len(codes)} missing country codes")
    return added_count

# Script entry point
def main():
    """Run the pipeline in order: dimensions, fact table, data load, indexes.

    An exception escaping any step is printed and the remaining steps are skipped.
    """
    pipeline = (
        create_dimension_tables,
        create_fact_table,
        extract_and_load_data,
        create_indexes,
    )
    try:
        for step in pipeline:
            step()
        print("World Bank environmental sustainability data extraction completed successfully!")
    except Exception as e:
        print(f"An error occurred: {e}")


if __name__ == "__main__":
    main()

Creating country dimension table...
Fetching countries, page 1...
Total pages of country data: 1
Added 296 countries from standard API
Adding special aggregate region codes...
  Added country: ZH - Africa Eastern and Southern
  Added country: T5 - South Asia (IDA & IBRD)
  Added country: 1W - World
  Added country: Z4 - East Asia & Pacific
  Added country: Z7 - Europe & Central Asia
  Added country: ZF - Sub-Saharan Africa (excluding high income)
  Added country: ZG - Sub-Saharan Africa 
  Added country: ZJ - Latin America & Caribbean 
  Added country: ZQ - Middle East & North Africa
  Added country: EU - European Union
  Added country: OE - OECD members
  Added country: XC - Euro area
  Added country: XD - High income
  Added country: XE - Heavily indebted poor countries (HIPC)
  Added country: XF - IBRD only
  Added country: XG - IDA total
  Added country: XH - IDA blend
  Added country: XI - IDA only
  Added country: XJ - Latin America & Caribbean (excluding high income)
  Added cou