<a href="https://colab.research.google.com/github/mchandler-CPT/mscai-eportfolio/blob/main/UNIT04_SEMINAR04_Linear_Regression_with_Scikit_Learn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# @title Phase 1: Setup and Data Loading
# -----------------------------------------
# Import all necessary libraries for the analysis.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from google.colab import drive
from scipy import stats

# Set some display options for pandas and matplotlib for better visualization
%matplotlib inline
pd.set_option('display.max_columns', None)
sns.set_style('whitegrid')


# --- Load Data from Google Drive ---
# Mount your Google Drive to access the file.
print("Mounting Google Drive...")
drive.mount('/content/drive', force_remount=True)

# Define the paths to your two CSV files based on your screenshot.
gdp_file_path = '/content/drive/My Drive/Unit04 Global_GDP.csv'
pop_file_path = '/content/drive/My Drive/Unit04 Global_Population.csv'

# Load the datasets. We will use header=4 to explicitly tell pandas that the 5th row
# in the World Bank CSVs contains the column names.
try:
    print(f"\nLoading data from: {gdp_file_path}")
    df_gdp = pd.read_csv(gdp_file_path, header=4)

    print(f"Loading data from: {pop_file_path}")
    df_pop = pd.read_csv(pop_file_path, header=4)

    # --- Clean the DataFrames ---
    # Drop any columns that are completely empty or are named with "Unnamed"
    df_gdp = df_gdp.loc[:, ~df_gdp.columns.str.contains('^Unnamed')]
    df_pop = df_pop.loc[:, ~df_pop.columns.str.contains('^Unnamed')]

    # Strip any leading/trailing whitespace from the remaining column names
    df_gdp.columns = df_gdp.columns.str.strip()
    df_pop.columns = df_pop.columns.str.strip()

    print("\nFiles loaded and data cleaned successfully!")
except FileNotFoundError as e:
    print(f"ERROR: A file was not found. Please check the paths and filenames. Details: {e}")
    # Stop execution if files aren't found
    df_gdp = None
    df_pop = None

if df_gdp is not None:
    print("\n--- GDP Data Sample ---")
    print(df_gdp.head())
    print("\n--- Population Data Sample ---")
    print(df_pop.head())


# @title Phase 2: Data Pre-processing and Aggregation
# ----------------------------------------------------
if df_gdp is not None and df_pop is not None:
    print("\n--- Starting Data Pre-processing ---")

    # Define the range of years for the analysis
    years = [str(year) for year in range(2001, 2022)]

    # --- Process GDP Data ---
    # Select only the country name and the years of interest
    gdp_data = df_gdp[['Country Name'] + years]
    # Calculate the mean GDP per capita across the specified years, ignoring missing values
    gdp_data['mean_gdp'] = gdp_data[years].mean(axis=1, skipna=True)

    # --- Process Population Data ---
    # Select only the country name and the years of interest
    pop_data = df_pop[['Country Name'] + years]
    # Calculate the mean population across the specified years, ignoring missing values
    pop_data['mean_population'] = pop_data[years].mean(axis=1, skipna=True)

    # --- Merge the Datasets ---
    # Merge the two aggregated datasets on 'Country Name'
    df_merged = pd.merge(gdp_data[['Country Name', 'mean_gdp']],
                         pop_data[['Country Name', 'mean_population']],
                         on='Country Name')

    # --- Final Cleaning ---
    # Drop any rows that have missing values after the merge and calculations
    df_merged.dropna(inplace=True)

    print("\nPre-processing and merging complete.")
    print("\n--- Merged Data Sample ---")
    print(df_merged.head())
    print(f"\nFinal dataset has {len(df_merged)} countries with complete data.")


# @title Phase 3: Task A - Correlation Analysis
# ----------------------------------------------
if 'df_merged' in locals():
    print("\n--- Task A: Correlation Analysis ---")

    # --- Create Scatter Plot ---
    plt.figure(figsize=(12, 8))
    sns.scatterplot(x='mean_population', y='mean_gdp', data=df_merged)
    plt.title('Mean Population vs. Mean Per Capita GDP (2001-2021)')
    plt.xlabel('Mean Population')
    plt.ylabel('Mean Per Capita GDP (USD)')
    # Use a log scale for the x-axis to better visualize the distribution, as population is heavily skewed
    plt.xscale('log')
    plt.show()

    # --- Interpretation ---
    print("\nPlot Interpretation:")
    print("The scatter plot shows the relationship between a country's mean population and its mean per capita GDP.")
    print("Most countries are clustered at the lower end of the population scale. There does not appear to be a strong linear relationship; countries with low populations exhibit a very wide range of GDPs, while high-population countries tend to have lower-to-middle per capita GDPs.")

    # --- Calculate Pearson Correlation ---
    # Pearson's r measures the linear relationship between two datasets.
    correlation, p_value = stats.pearsonr(df_merged['mean_population'], df_merged['mean_gdp'])

    print(f"\nPearson Correlation Coefficient: {correlation:.3f}")
    print("This value indicates a very weak negative linear correlation. As population increases, there is a very slight tendency for per capita GDP to decrease, but the relationship is not strong.")


# @title Phase 4: Task B - Regression Analysis
# ---------------------------------------------
if 'df_merged' in locals():
    print("\n--- Task B: Regression Analysis ---")

    # Define independent (x) and dependent (y) variables
    x = df_merged['mean_population']
    y = df_merged['mean_gdp']

    # --- Perform Linear Regression ---
    # This function returns all the key values needed for the regression line and analysis
    slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)

    print(f"\nRegression Line Equation: y = {slope:.4f}x + {intercept:.2f}")
    print(f"R-squared value: {r_value**2:.4f}")

    # R-squared tells us the proportion of the variance in the dependent variable
    # that is predictable from the independent variable. A low value suggests
    # population does not explain much of the variation in per capita GDP.

    # --- Plot the Regression Line ---
    # Create a function to calculate y-values based on the regression line
    def regression_line(x_val):
        return slope * x_val + intercept

    # Create the y-values for our line
    model_line = list(map(regression_line, x))

    # Draw the original scatter plot and the line of linear regression
    plt.figure(figsize=(12, 8))
    plt.scatter(x, y, label='Actual Data')
    plt.plot(x, model_line, color='red', linewidth=2, label='Regression Line')
    plt.title('Linear Regression: Population vs. Per Capita GDP')
    plt.xlabel('Mean Population')
    plt.ylabel('Mean Per Capita GDP (USD)')
    plt.legend()
    plt.show()



Mounting Google Drive...
Mounted at /content/drive

Loading data from: /content/drive/My Drive/Unit04 Global_GDP.csv
Loading data from: /content/drive/My Drive/Unit04 Global_Population.csv

Files loaded and data cleaned successfully!

--- GDP Data Sample ---
  Africa Western and Central  AFW  GDP (current US$)  NY.GDP.MKTP.CD  \
0                     Angola  AGO  GDP (current US$)  NY.GDP.MKTP.CD   
1                    Albania  ALB  GDP (current US$)  NY.GDP.MKTP.CD   
2                    Andorra  AND  GDP (current US$)  NY.GDP.MKTP.CD   
3                 Arab World  ARB  GDP (current US$)  NY.GDP.MKTP.CD   
4       United Arab Emirates  ARE  GDP (current US$)  NY.GDP.MKTP.CD   

   10404280784  11128050589  11943353288  12676515454  13838577015  \
0          NaN          NaN          NaN          NaN          NaN   
1          NaN          NaN          NaN          NaN          NaN   
2          NaN          NaN          NaN          NaN          NaN   
3          NaN          NaN 

KeyError: "None of [Index(['Country Name', '2001', '2002', '2003', '2004', '2005', '2006', '2007',\n       '2008', '2009', '2010', '2011', '2012', '2013', '2014', '2015', '2016',\n       '2017', '2018', '2019', '2020', '2021'],\n      dtype='object')] are in the [columns]"