# Olympic Medal Prediction - Data Download and Preparation

## Setup Instructions

### 1. Install Required Packages
```bash
pip install kaggle wbdata pandas numpy requests openpyxl pycountry
```

### 2. Configure Kaggle API
1. Go to https://www.kaggle.com/account
2. Click "Create New API Token" - downloads `kaggle.json`
3. Place it at:
   - **Linux/Mac**: `~/.kaggle/kaggle.json`
   - **Windows**: `C:\Users\<username>\.kaggle\kaggle.json`
4. Set permissions (Linux/Mac): `chmod 600 ~/.kaggle/kaggle.json`

### 3. Run All Cells
This notebook will:
- Download Olympic medal data from Kaggle
- Pull World Bank development indicators
- Download UNDP HDI data
- Create country ISO3 mappings
- Merge everything into `data/processed/olympics_merged.csv`

In [2]:
import os
import sys
import pandas as pd
import numpy as np
import requests
import zipfile
import wbdata
import pycountry
from pathlib import Path
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Create directory structure
Path('data/raw/olympic').mkdir(parents=True, exist_ok=True)
Path('data/raw/worldbank').mkdir(parents=True, exist_ok=True)
Path('data/raw/hdi').mkdir(parents=True, exist_ok=True)
Path('data/processed').mkdir(parents=True, exist_ok=True)

print("✓ Directory structure created")

✓ Directory structure created


## Step 1: Download Olympic Medal Data from Kaggle

We'll use the `olympic-games-medal-tally-by-country` dataset which provides clean country × games medal totals.

In [3]:
import os
import pandas as pd

def download_kaggle_olympics_new():
    """Download the alternative Olympic dataset from Kaggle."""
    try:
        import kaggle

        dataset = 'andreinovikov/olympic-games'  # new dataset
        dest_path = 'data/raw/olympic'
        os.makedirs(dest_path, exist_ok=True)

        print(f"Downloading {dataset} from Kaggle...")
        kaggle.api.dataset_download_files(
            dataset,
            path=dest_path,
            unzip=True
        )

        # List downloaded files
        files = os.listdir(dest_path)
        print(f"✓ Downloaded files: {files}")

        return True

    except Exception as e:
        print(f"❌ Error downloading from Kaggle: {e}")
        print("\nTroubleshooting:")
        print("1. Check kaggle.json is in correct location (~/.kaggle/kaggle.json)")
        print("2. Make sure you've accepted dataset terms on Kaggle website")
        print("3. Try downloading manually if API fails")
        return False

# Run the download
download_kaggle_olympics_new()


Downloading andreinovikov/olympic-games from Kaggle...
Dataset URL: https://www.kaggle.com/datasets/andreinovikov/olympic-games
✓ Downloaded files: ['medals_processed.csv', 'olympic_games.csv']


True

In [4]:
def load_and_prepare_olympic_data_new():
    """Load Olympic data from the new Kaggle dataset."""
    olympic_files = os.listdir('data/raw/olympic')
    print(f"Available files: {olympic_files}")

    # Usually the main CSV is named something like 'olympic_games.csv'
    medals_file = [f for f in olympic_files if f.endswith('.csv')][0]
    df = pd.read_csv(f'data/raw/olympic/{medals_file}')

    # Rename columns to match your old workflow
    df.rename(columns={
        'country': 'country_name',
        'year': 'year',
        'games_type': 'games_type',
        'gold': 'gold',
        'silver': 'silver',
        'bronze': 'bronze'
    }, inplace=True)

    # Compute total medals
    df['total_medals'] = df['gold'] + df['silver'] + df['bronze']

    # Save processed CSV
    df.to_csv('data/raw/olympic/medals_processed.csv', index=False)
    print(f"✓ Processed Olympic data saved: {df.shape}")

    return df

# Load the dataset
olympic_df = load_and_prepare_olympic_data_new()


Available files: ['medals_processed.csv', 'olympic_games.csv']
✓ Processed Olympic data saved: (1781, 12)


## Step 2: Create Country ISO3 Mapping

In [5]:
def create_country_mapping():
    """Create mapping from country names to ISO3 codes, including special Olympic cases"""
    
    # ✅ Extended manual mappings to cover all Olympic variations and historical countries
    manual_mappings = {
        # Standard corrections
        'United States': 'USA',
        'United States of America': 'USA',
        'Great Britain': 'GBR',
        'Russia': 'RUS',
        'Russian Federation': 'RUS',
        'ROC': 'RUS',  # Russian Olympic Committee
        'Soviet Union': 'RUS',
        'China': 'CHN',
        "People's Republic of China": 'CHN',
        'South Korea': 'KOR',
        'Korea': 'KOR',
        'Chinese Taipei': 'TWN',
        'Hong Kong': 'HKG',
        'Hong Kong, China': 'HKG',
        'Iran': 'IRN',
        'Netherlands': 'NLD',
        'Czech Republic': 'CZE',
        'Czechia': 'CZE',
        'North Korea': 'PRK',
        'Vietnam': 'VNM',
        'Venezuela': 'VEN',
        'Syria': 'SYR',
        'Tanzania': 'TZA',
        'Bahamas': 'BHS',
        'Philippines': 'PHL',
        'Moldova': 'MDA',
        'Ivory Coast': 'CIV',
        "Côte d'Ivoire": 'CIV',
        'Turkey': 'TUR',
        'Türkiye': 'TUR',

        # 🏛 Historical / special teams
        'Independent Olympic Athletes': None,
        'Serbia and Montenegro': 'SRB',
        'Czechoslovakia': 'CZE',
        'Unified Team': 'RUS',
        'German Democratic Republic (Germany)': 'DEU',
        'Netherlands Antilles': 'NLD',
        'USSR': 'RUS',
        'Virgin Islands, US': 'VIR',
        'Yugoslavia': 'SRB',
        'Mixed team': None,
    }

    def get_iso3(country_name):
        """Get ISO3 code for a country name"""
        # Manual override first
        if country_name in manual_mappings:
            return manual_mappings[country_name]
        
        # Try pycountry fuzzy search
        try:
            result = pycountry.countries.search_fuzzy(country_name)
            return result[0].alpha_3
        except:
            return None
    

    # Create mapping from Olympic data
    if 'olympic_df' in globals() and olympic_df is not None and 'country_name' in olympic_df.columns:
        unique_countries = olympic_df['country_name'].unique()
        
        mapping_data = []
        for country in unique_countries:
            iso3 = get_iso3(country)
            mapping_data.append({
                'country_name': country,
                'iso3': iso3
            })
        
        mapping_df = pd.DataFrame(mapping_data)
        
        # Report unmapped countries
        unmapped = mapping_df[mapping_df['iso3'].isnull()]
        if len(unmapped) > 0:
            print(f"⚠️  Unmapped countries ({len(unmapped)}):")
            print(unmapped['country_name'].tolist())
        else:
            print("✅ All countries successfully mapped!")
        
        # Save mapping
        os.makedirs('data/raw', exist_ok=True)
        mapping_df.to_csv('data/raw/country_iso3_mapping.csv', index=False)
        print(f"\n✓ Country mapping created: {len(mapping_df)} countries")
        print(f"  Mapped: {len(mapping_df[mapping_df['iso3'].notnull()])}")
        
        return mapping_df
    
    else:
        print("❌ olympic_df not found or missing 'country_name' column.")
        return None

# Run mapping
country_mapping = create_country_mapping()
# Clean up special non-national Olympic entries
invalid_countries = ['Independent Olympic Athletes', 'Mixed team']

before = len(olympic_df)
olympic_df = olympic_df[~olympic_df['country_name'].isin(invalid_countries)]
after = len(olympic_df)

print(f"✅ Removed {before - after} non-national entries ({invalid_countries})")
print(f"Remaining rows: {after}")
print(f"Unique countries now: {olympic_df['country_name'].nunique()}")


⚠️  Unmapped countries (2):
['Independent Olympic Athletes', 'Mixed team']

✓ Country mapping created: 152 countries
  Mapped: 150
✅ Removed 6 non-national entries (['Independent Olympic Athletes', 'Mixed team'])
Remaining rows: 1775
Unique countries now: 150


## Step 3: Download World Bank Indicators

In [6]:
import wbdata
import pandas as pd
from datetime import datetime
import pycountry
import os

def download_world_bank_data(start_year=1960, end_year=2024):
    """Download World Bank indicators using wbdata (compatible with older versions)"""

    indicators = {
        'SP.POP.TOTL': 'population',
        'NY.GDP.MKTP.CD': 'gdp_current_usd',
        'NY.GDP.PCAP.CD': 'gdp_per_capita',
        'SP.DYN.LE00.IN': 'life_expectancy',
        'SE.ADT.LITR.ZS': 'adult_literacy_rate',
        'SL.UEM.TOTL.ZS': 'unemployment_rate'
    }

    print(f"Downloading World Bank data from {start_year} to {end_year}...")

    os.makedirs('data/raw/worldbank', exist_ok=True)
    all_data = []

    for wb_code, var_name in indicators.items():
        print(f"  Downloading {var_name} ({wb_code})...", end=' ')
        try:
            # Download without any extra arguments
            data = wbdata.get_dataframe({wb_code: var_name})
            data = data.reset_index()
            
            # Convert date column to year
            data['year'] = pd.to_datetime(data['date']).dt.year
            
            # Filter by year range
            data = data[(data['year'] >= start_year) & (data['year'] <= end_year)]
            
            data = data[['country', 'year', var_name]]
            
            # Save CSV
            data.to_csv(f'data/raw/worldbank/{var_name}.csv', index=False)

            all_data.append(data)
            print(f"✓ ({len(data)} rows)")
        except Exception as e:
            print(f"❌ Error: {e}")

    # Merge all indicators
    if all_data:
        wb_combined = all_data[0]
        for df in all_data[1:]:
            wb_combined = wb_combined.merge(df, on=['country', 'year'], how='outer')

        # Add ISO3 codes
        def get_iso3_wb(country_name):
            try:
                result = pycountry.countries.search_fuzzy(country_name)
                return result[0].alpha_3
            except:
                return None

        # Manual corrections for common World Bank country name mismatches
        manual_iso_map = {
            "Egypt, Arab Rep.": "EGY",
            "Iran, Islamic Rep.": "IRN",
            "Korea, Rep.": "KOR",
            "Korea, Dem. People's Rep.": "PRK",
            "Russian Federation": "RUS",
            "Syrian Arab Republic": "SYR",
            "Venezuela, RB": "VEN",
            "Vietnam": "VNM",
            "Yemen, Rep.": "YEM",
            "Congo, Rep.": "COG",
            "Congo, Dem. Rep.": "COD",
            "Gambia, The": "GMB",
            "Hong Kong SAR, China": "HKG",
            "Lao PDR": "LAO",
            "Macao SAR, China": "MAC",
            "Slovak Republic": "SVK",
            "United States": "USA",
            "United Kingdom": "GBR"
        }

        wb_combined['iso3'] = wb_combined['country'].map(manual_iso_map).fillna(
            wb_combined['country'].apply(get_iso3_wb)
        )

        # Report countries still not mapped
        unmapped = wb_combined[wb_combined['iso3'].isnull()]['country'].unique()
        if len(unmapped) > 0:
            print(f"\n⚠️  Countries still not mapped to ISO3 ({len(unmapped)}):")
            print(list(unmapped))
        else:
            print("\n✓ All countries successfully mapped to ISO3")

        wb_combined.to_csv('data/raw/worldbank_combined.csv', index=False)
        print(f"\n✓ World Bank combined data saved: {wb_combined.shape}")
        print(f"  Countries with ISO3: {wb_combined['iso3'].notnull().sum() / len(wb_combined) * 100:.1f}%")

        return wb_combined

    return None

# Run it
wb_data = download_world_bank_data()


Downloading World Bank data from 1960 to 2024...
  Downloading population (SP.POP.TOTL)... ✓ (17290 rows)
  Downloading gdp_current_usd (NY.GDP.MKTP.CD)... ✓ (17290 rows)
  Downloading gdp_per_capita (NY.GDP.PCAP.CD)... ✓ (17290 rows)
  Downloading life_expectancy (SP.DYN.LE00.IN)... ✓ (17290 rows)
  Downloading adult_literacy_rate (SE.ADT.LITR.ZS)... ✓ (17290 rows)
  Downloading unemployment_rate (SL.UEM.TOTL.ZS)... ✓ (17290 rows)

⚠️  Countries still not mapped to ISO3 (60):
['Africa Eastern and Southern', 'Africa Western and Central', 'Arab World', 'Bahamas, The', 'Caribbean small states', 'Central Europe and the Baltics', 'Channel Islands', 'Early-demographic dividend', 'East Asia & Pacific', 'East Asia & Pacific (IDA & IBRD countries)', 'East Asia & Pacific (excluding high income)', 'Euro area', 'Europe & Central Asia', 'Europe & Central Asia (IDA & IBRD countries)', 'Europe & Central Asia (excluding high income)', 'European Union', 'Fragile and conflict affected situations', 'Hea

## Step 4: Download UNDP HDI Data

In [7]:
def download_hdi_data():
    """Download UNDP Human Development Index data"""
    
    print("Downloading HDI data...")
    
    # UNDP HDI data URL (latest available)
    url = 'https://hdr.undp.org/sites/default/files/2021-22_HDR/HDR21-22_Composite_indices_complete_time_series.csv'
    
    try:
        hdi_df = pd.read_csv(url)
        
        print(f"✓ HDI data downloaded: {hdi_df.shape}")
        print(f"Columns: {hdi_df.columns.tolist()[:10]}...")  # Show first 10 columns
        
        # Reshape from wide to long format
        # Find year columns (typically 1990-2021)
        year_cols = [col for col in hdi_df.columns if col.isdigit()]
        
        if year_cols:
            id_cols = [col for col in hdi_df.columns if col not in year_cols]
            hdi_long = hdi_df.melt(
                id_vars=id_cols,
                value_vars=year_cols,
                var_name='year',
                value_name='hdi'
            )
            hdi_long['year'] = hdi_long['year'].astype(int)
            
            # Keep relevant columns
            keep_cols = ['country', 'iso3', 'year', 'hdi']
            available_cols = [c for c in keep_cols if c in hdi_long.columns]
            
            if 'iso3' not in available_cols and 'country' in available_cols:
                # Add ISO3 codes
                def get_iso3_hdi(country_name):
                    try:
                        result = pycountry.countries.search_fuzzy(str(country_name))
                        return result[0].alpha_3
                    except:
                        return None
                hdi_long['iso3'] = hdi_long['country'].apply(get_iso3_hdi)
                available_cols.append('iso3')
            
            hdi_clean = hdi_long[available_cols].copy()
            
            # Remove missing HDI values
            hdi_clean = hdi_clean[hdi_clean['hdi'].notnull()]
            
            # Save
            hdi_clean.to_csv('data/raw/hdi/hdi_data.csv', index=False)
            print(f"✓ HDI data processed and saved: {hdi_clean.shape}")
            
            return hdi_clean
        
    except Exception as e:
        print(f"❌ Error downloading HDI data: {e}")
        print("Continuing without HDI data...")
        return None

hdi_data = download_hdi_data()

Downloading HDI data...
✓ HDI data downloaded: (206, 1008)
Columns: ['iso3', 'country', 'hdicode', 'region', 'hdi_rank_2021', 'hdi_1990', 'hdi_1991', 'hdi_1992', 'hdi_1993', 'hdi_1994']...


## Step 5: Merge All Datasets

In [8]:
def merge_all_data_full(olympic_df, country_mapping, wb_data, hdi_data):
    """
    Merge Olympic, World Bank, and HDI data, including all participating countries.
    Creates 0-medal rows for countries that didn't win.
    """
    
    print("\n" + "="*60)
    print("MERGING ALL DATASETS (INCLUDE ZERO-MEDALS)")
    print("="*60)
    
    if olympic_df is None:
        print("❌ No Olympic data available")
        return None
    
    # Add ISO3 codes first
    if country_mapping is not None:
        olympic_df = olympic_df.merge(
            country_mapping[['country_name', 'iso3']],
            on='country_name',
            how='left'
        )
        print(f"After adding ISO3: {olympic_df.shape}")
        print(f"  Missing ISO3: {olympic_df['iso3'].isnull().sum()}")
    
    # Get list of all countries and games
    all_countries = country_mapping['iso3'].dropna().unique()
    all_games = olympic_df[['year','games_type']].drop_duplicates()
    
    # Create full country × games table
    full_rows = []
    for year, game in all_games.values:
        for iso3 in all_countries:
            full_rows.append([iso3, year, game])
    
    full_df = pd.DataFrame(full_rows, columns=['iso3','year','games_type'])
    
    # Merge medal info (left join so missing medals become NaN)
    full_df = full_df.merge(
        olympic_df[['iso3','year','games_type','gold','silver','bronze','total_medals']],
        on=['iso3','year','games_type'],
        how='left'
    )
    
    # Fill missing medals with 0
    for col in ['gold','silver','bronze','total_medals']:
        full_df[col] = full_df[col].fillna(0)
    
    # Classification target
    full_df['has_medal'] = (full_df['total_medals'] > 0).astype(int)
    
    print(f"\n✅ Medal coverage after adding zero-medal countries:")
    print(full_df['has_medal'].value_counts())
    
    # Merge World Bank indicators
    if wb_data is not None:
        wb_subset = wb_data[wb_data['iso3'].notnull()].copy()
        full_df = full_df.merge(
            wb_subset,
            on=['iso3','year'],
            how='left',
            suffixes=('', '_wb')
        )
        
        wb_indicators = ['population', 'gdp_current_usd', 'gdp_per_capita', 
                         'life_expectancy', 'adult_literacy_rate', 'unemployment_rate']
        for indicator in wb_indicators:
            if indicator in full_df.columns:
                full_df[indicator] = full_df.groupby('iso3')[indicator].fillna(method='ffill')
    
    # Merge HDI
    if hdi_data is not None:
        hdi_subset = hdi_data[hdi_data['iso3'].notnull()][['iso3','year','hdi']].copy()
        full_df = full_df.merge(
            hdi_subset,
            on=['iso3','year'],
            how='left'
        )
        full_df['hdi'] = full_df.groupby('iso3')['hdi'].fillna(method='ffill')
    
    # Reorder columns
    base_cols = ['iso3','year','games_type']
    medal_cols = ['gold','silver','bronze','total_medals','has_medal']
    other_cols = [c for c in full_df.columns if c not in base_cols + medal_cols]
    full_df = full_df[base_cols + medal_cols + other_cols]
    
    # Save final dataset
    full_df.to_csv('data/processed/olympics_merged.csv', index=False)
    
    print(f"\n🎯 FINAL MERGED DATA: {full_df.shape}")
    print(f"Columns: {full_df.columns.tolist()}")
    print(f"Data coverage: years {full_df['year'].min()}-{full_df['year'].max()}, countries {full_df['iso3'].nunique()}")
    
    return full_df


In [None]:
final_data = merge_all_data_full(olympic_df, country_mapping, wb_data, hdi_data)

# Quick check
print("\nSample of merged data (including zero-medals):")
print(final_data.head(10))

# Check class balance
print("\nClass balance in 'has_medal':")
print(final_data['has_medal'].value_counts())


MERGING ALL DATASETS (INCLUDE ZERO-MEDALS)
After adding ISO3: (1775, 13)
  Missing ISO3: 0

✅ Medal coverage after adding zero-medal countries:
has_medal
0    5511
1    1775
Name: count, dtype: int64

🎯 FINAL MERGED DATA: (7390, 15)
Columns: ['iso3', 'year', 'games_type', 'gold', 'silver', 'bronze', 'total_medals', 'has_medal', 'country', 'population', 'gdp_current_usd', 'gdp_per_capita', 'life_expectancy', 'adult_literacy_rate', 'unemployment_rate']
Data coverage: years 1896-2022, countries 137

Sample of merged data (including zero-medals):
  iso3  year games_type  gold  silver  bronze  total_medals  has_medal  \
0  AUS  2022     Winter   1.0     2.0     1.0           4.0          1   
1  AUT  2022     Winter   7.0     7.0     4.0          18.0          1   
2  BLR  2022     Winter   0.0     2.0     0.0           2.0          1   
3  BEL  2022     Winter   1.0     0.0     1.0           2.0          1   
4  CAN  2022     Winter   4.0     8.0    14.0          26.0          1   
5  CZE

## Summary

### Files Created:
- `data/raw/olympic/` - Raw Olympic medal data
- `data/raw/worldbank/` - Individual World Bank indicator files
- `data/raw/worldbank_combined.csv` - Combined World Bank data
- `data/raw/hdi/hdi_data.csv` - UNDP HDI data
- `data/raw/country_iso3_mapping.csv` - Country name to ISO3 mapping
- `data/processed/olympics_merged.csv` - **Final merged dataset**

### Next Steps:
1. Run `train_two_stage_model.ipynb` to train the prediction model
2. Use `evaluate_and_report.ipynb` to evaluate and generate reports
3. Launch `streamlit run app.py` for interactive predictions