# **Exploration of *Amorpha* spp. GBIF distributions**

## Step 0: Import libraries and set directory paths

In [None]:
import os
# Import libraries
import pygbif.occurrences as occ
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
from matplotlib.patheffects import withStroke # For text effects
import matplotlib.cm as cm # For colormaps
import numpy as np
import time # adds delay between gbif api calls
from shapely.geometry import box # To create a bounding box
import rasterio # For raster data
from rasterio.mask import mask
from rasterio.plot import show as raster_show # For easier plotting with transform
from rasterio.sample import sample_gen # For extracting raster values at points
import seaborn as sns # For a nicer heatmap visualization
from scipy.stats import pearsonr # For R-squared calculation (pearsonr returns r, so r^2)
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

# Data directories.
# Paths are built with os.path.join so they resolve on Windows, macOS and Linux
# alike; a raw string such as r'data\aoi' only acts as a nested path on Windows
# (elsewhere it is a single directory name containing a backslash).
# area of interest
aoi_dir = os.path.join('data', 'aoi')
# GBIF observations
obs_dir = os.path.join('data', 'species')
# Bioclim variables
bio_vars_dir = os.path.join('data', 'bioclim', 'wc2.1_2.5m_bio')
# Elevation raster
elev_file_path = os.path.join('data', 'bioclim', 'wc2.1_2.5m_elev', 'wc2.1_2.5m_elev.tif')

# Say when you're ready!
print(f"Current Working Directory: {os.getcwd()}") # Good to check!
print(f"AOI directory will be: {os.path.abspath(aoi_dir)}")
print(f"Observations directory will be: {os.path.abspath(obs_dir)}")

print("Libraries imported!")


## Step 1: Read the species list from the CSV

In [None]:
# --- Step 1: Read the species list from the CSV ---
# Loads the species CSV into `df_species_list` (reused in Step 4 as the
# speciesKey -> simpSciName map) and builds `species_list`: the unique, non-null
# values of its 'simpSciName' column, which Step 2 queries GBIF for.
# `species_list` is always defined afterwards (empty on any failure).
filename = 'species_gbif.csv'
species_csv_path = os.path.join(obs_dir, filename)
species_list = []
try:
    df_species_list = pd.read_csv(species_csv_path)
except FileNotFoundError:
    print(f"Error: Species CSV file not found at '{species_csv_path}'.")
except Exception as e:
    print(f"Error reading species CSV file '{species_csv_path}': {e}")
else:
    if 'simpSciName' not in df_species_list.columns:
        print(f"Error: Column 'simpSciName' not found in {species_csv_path}. Please check the CSV.")
    else:
        # Drop any row whose value repeats the header text itself (e.g. a header
        # line duplicated inside the file), then keep unique non-null names.
        df_species_list_filtered = df_species_list[df_species_list['simpSciName'] != 'simpSciName']
        species_list = df_species_list_filtered['simpSciName'].dropna().unique().tolist()

        if not species_list: # If list is empty after filtering
            print(f"No valid species names found in {species_csv_path} after filtering. Please check the CSV content.")
        else:
            print(f"Found {len(species_list)} unique species in the CSV: {species_list}")

## Step 2: GBIF API Query
GBIF occurrence data are either loaded from the CSV saved by a previous query or queried fresh from the GBIF API.

In [None]:
# --- Step 2: Fetch or Load Occurrence Data ---
# Loads previously fetched occurrences from `output_csv_path` when that file
# exists and can be read. Otherwise -- including when the cached CSV is present
# but unreadable -- every name in `species_list` (Step 1) is queried against the
# GBIF occurrence search API and the combined result is cached to CSV.
# Note: a CSV round-trip turns list-valued columns such as 'issues' into their
# string repr; Step 3 parses those back.
output_csv_filename = 'df_all_occurrences.csv'
output_csv_path = os.path.join(obs_dir, output_csv_filename)
PAGE_LIMIT = 300 # Max records per page GBIF allows


def print_occurrence_sample(df, columns, header):
    """Print `header` and the first 5 rows of those `columns` that exist in `df`."""
    print(header)
    print(df[[col for col in columns if col in df.columns]].head())


def fetch_species_occurrences(species_name, page_limit=PAGE_LIMIT):
    """Return all GBIF occurrence records (list of dicts) for one species name.

    Queries `occ.search` for North American records that have coordinates,
    advancing `offset` by `page_limit` until GBIF reports `endOfRecords` or
    returns an empty page. If a request raises, the records gathered so far are
    returned and the remaining pages for this species are skipped.
    """
    species_records = []
    current_offset = 0
    while True:
        params = {
            'scientificName': species_name,
            'hasCoordinate': True,
            'continent': 'north_america',
            'limit': page_limit,
            'offset': current_offset
        }
        try:
            response = occ.search(**params)
            results_batch = response.get('results', [])
            end_of_records = response.get('endOfRecords', True)
        except Exception as e:
            print(f"  An error occurred while fetching data for '{species_name}' (offset {current_offset}): {e}")
            print("  Continuing to the next species or offset if possible, but this batch might be lost.")
            break

        if results_batch:
            species_records.extend(results_batch)
            print(f"  Fetched {len(results_batch)} records (offset {current_offset}). Total for '{species_name}': {len(species_records)}")

        if end_of_records or not results_batch:
            print(f"  End of records for '{species_name}'. Total found: {len(species_records)}")
            break

        current_offset += page_limit
        time.sleep(0.2) # brief pause between pages to go easy on the API
    return species_records


def fetch_all_occurrences(species_names):
    """Fetch occurrences for every valid name and return them as one DataFrame.

    Returns an empty DataFrame when nothing was fetched for any species.
    """
    print("\nFetching GBIF occurrence data for each species...")
    per_species_frames = []
    for species_name in species_names:
        if not species_name or pd.isna(species_name):
            print(f"Skipping invalid species name: {species_name}")
            continue

        print(f"\nSearching for all occurrences of '{species_name}' in North America with coordinates...")
        species_records = fetch_species_occurrences(species_name)
        if species_records:
            per_species_frames.append(pd.DataFrame(species_records))
        else:
            print(f"No occurrences ultimately processed for '{species_name}' with the specified criteria after pagination attempt.")

        time.sleep(0.5) # pause between species queries

    if not per_species_frames:
        return pd.DataFrame()
    return pd.concat(per_species_frames, ignore_index=True)


df_all_occurrences = pd.DataFrame()
need_fetch = True # flipped to False only when the cached CSV loads successfully

if os.path.exists(output_csv_path):
    print(f"Found existing processed data file: '{output_csv_path}'. Loading data from CSV...")
    try:
        df_all_occurrences = pd.read_csv(output_csv_path)
    except Exception as e:
        print(f"Error loading data from '{output_csv_path}': {e}")
        print("Proceeding to fetch data from GBIF.")
        df_all_occurrences = pd.DataFrame() # Ensure it's an empty DataFrame
    else:
        need_fetch = False
        print(f"Successfully loaded {len(df_all_occurrences)} records from CSV.")
        if not df_all_occurrences.empty:
            print_occurrence_sample(
                df_all_occurrences,
                ['scientificName', 'decimalLatitude', 'decimalLongitude', 'key', 'issues'],
                "Sample of loaded data (first 5 rows):",
            )
else:
    print(f"Processed data file not found at '{output_csv_path}'. Fetching data from GBIF...")

if need_fetch:
    if 'species_list' not in globals() or not species_list:
        print("Error: 'species_list' is not defined or is empty. Cannot fetch GBIF data.")
        print("Please ensure Step 1 (reading species list) has been executed successfully.")
    else:
        df_all_occurrences = fetch_all_occurrences(species_list)
        if df_all_occurrences.empty:
            print("\nNo occurrences were fetched for any species. No CSV file was saved.")
        else:
            print(f"\nTotal occurrences fetched for all species: {len(df_all_occurrences)}")
            print_occurrence_sample(
                df_all_occurrences,
                ['scientificName', 'decimalLatitude', 'decimalLongitude', 'key'],
                "Sample of combined data (first 5 rows):",
            )

            # --- Write to CSV ---
            try:
                os.makedirs(obs_dir, exist_ok=True) # Ensure the directory exists before writing
                df_all_occurrences.to_csv(output_csv_path, index=False)
                print(f"\nSuccessfully saved all occurrences to '{output_csv_path}'")
            except Exception as e:
                print(f"\nError saving data to '{output_csv_path}': {e}")

# Final check and display if df_all_occurrences is still empty
if df_all_occurrences.empty:
    print("\nNo occurrence data available (either not found in CSV or not fetched).")


## Step 3a: Summarize GBIF data-quality issues

In [None]:
# --- Step 3a: Check for and handle GBIF issues ---
# Summarises the GBIF data-quality flags in df_all_occurrences['issues'].
# It prints a few raw entries (their type differs between a fresh fetch and a
# CSV reload), counts how often each issue code appears, and reports how many
# records carry at least one issue. Nothing is filtered here; see Step 3b.
import ast


def parse_gbif_issues(entry):
    """Return the list of GBIF issue codes held in one 'issues' cell.

    A list, as in freshly fetched pygbif records, is returned unchanged. A
    string, as produced by a CSV round-trip (e.g. "['ZERO_COORDINATE']"), is
    parsed with ast.literal_eval, which accepts only Python literals, and is
    used only if it yields a list. Anything else gives an empty list: NaN, a
    bare code such as "ZERO_COORDINATE", or a non-list literal.
    """
    if isinstance(entry, list):
        return entry
    if isinstance(entry, str):
        try:
            parsed = ast.literal_eval(entry)
        except (ValueError, SyntaxError):
            return []
        if isinstance(parsed, list):
            return parsed
    return []


ORDERED_GEOSPATIAL_ISSUES = {
    "ZERO_COORDINATE": "Zero coordinate (0,0 - often null)",
    "COUNTRY_COORDINATE_MISMATCH": "Country coordinate mismatch",
    "COORDINATE_INVALID": "Coordinate invalid (uninterpretable)",
    "COORDINATE_OUT_OF_RANGE": "Coordinate out of range (-90/90, -180/180)",
    "GEODETIC_DATUM_ASSUMED_WGS84": "Geodetic datum assumed WGS84 (Info: datum was null)",
    "GEODETIC_DATUM_INVALID": "Geodetic datum invalid",
    "COUNTRY_MISMATCH": "Country mismatch (interpreted name vs. code)",
    "COUNTRY_DERIVED_FROM_COORDINATES": "Country derived from coordinates (Info: country was null)",
    "COUNTRY_INVALID": "Country invalid (uninterpretable name/code)",
    "CONTINENT_INVALID": "Continent invalid",
    "COORDINATE_ROUNDED": "Coordinate rounded by GBIF (Info: standard processing)",
    "COORDINATE_REPROJECTED": "Coordinate reprojected to WGS84 (Info: successful)",
    "COORDINATE_REPROJECTION_SUSPICIOUS": "Coordinate reprojection suspicious (large shift)",
    "COORDINATE_REPROJECTION_FAILED": "Coordinate reprojection failed",
    "COORDINATE_UNCERTAINTY_METERS_INVALID": "Coordinate uncertainty meters invalid",
    "COORDINATE_PRECISION_INVALID": "Coordinate precision invalid",
    "PRESUMED_NEGATED_LONGITUDE": "Presumed negated longitude (to match country)",
    "PRESUMED_NEGATED_LATITUDE": "Presumed negated latitude (to match country)"
}

if 'df_all_occurrences' not in globals() or df_all_occurrences.empty:
    print("Error: The DataFrame 'df_all_occurrences' is not defined or is empty.")
    print("Please ensure it's populated with the concatenated GBIF data before running this analysis.")
else:
    print(f"Total records in df_all_occurrences: {len(df_all_occurrences)}\n")

    if 'issues' not in df_all_occurrences.columns:
        print("The 'issues' column was not found in the DataFrame.")
        print("This column is expected from the pygbif download and contains GBIF data quality flags.")
        print(f"Available columns: {df_all_occurrences.columns.tolist()}")
    else:
        print("--- Debugging 'issues' column ---")
        non_null_issues = df_all_occurrences['issues'].dropna()
        if non_null_issues.empty:
            print("Debug: The 'issues' column contains only NaN values.")
        else:
            print("Debug: First 5 non-null entries in 'issues' column and their types:")
            for entry_number, entry in enumerate(non_null_issues.head(5), start=1):
                print(f"  Entry {entry_number}: {entry} (Type: {type(entry)})")
        print("--- End Debugging ---")

        # Parse every cell exactly once; reused for the per-code counts and the summary.
        parsed_issue_lists = [parse_gbif_issues(entry) for entry in df_all_occurrences['issues']]
        all_individual_issues = [code for codes in parsed_issue_lists for code in codes]

        if not all_individual_issues:
            print("No issues were successfully extracted into 'all_individual_issues'.")
            print("This could be due to the 'issues' column format (e.g., strings not lists, all NaNs, or unexpected structure).")
            print("Please check the debug output above for the format of entries in the 'issues' column.")
        else:
            issue_counts = pd.Series(all_individual_issues).value_counts()

            print("\nCounts of records per GBIF Issue Category:")
            print("------------------------------------------------------------------------------------")
            print("Geospatial Issues (in approximate order from the GBIF blog post):")
            for issue_code, description in ORDERED_GEOSPATIAL_ISSUES.items():
                count = issue_counts.get(issue_code, 0)
                print(f"- {description} ({issue_code}): {count} records")

            print("\nOther Issues Found (including other categories like taxonomic, date-related, etc.):")
            print("----------------------------------------------------------------------------------")

            # Codes outside the geospatial list have no description, so print the code itself.
            other_issues_found_count = 0
            for issue_code, count in issue_counts.items():
                if issue_code not in ORDERED_GEOSPATIAL_ISSUES:
                    print(f"- {issue_code}: {count} records")
                    other_issues_found_count += 1

            if other_issues_found_count == 0:
                print("No other significant issue types found beyond the primary geospatial list, or they had zero counts.")

            records_with_any_issue = sum(1 for codes in parsed_issue_lists if codes)

            print("\nSummary:")
            print(f"- Total records with at least one listed issue: {records_with_any_issue}")
            records_without_issues = len(df_all_occurrences) - records_with_any_issue
            print(f"- Total records with no listed issues: {records_without_issues}")

## Step 3b: Filter problematic observations
Records flagged with critical GBIF geospatial issues are removed below. Additional filtering is probably still needed for:
- Observations whose coordinates match collection institutions
- Observations located in water
- Observations located at a country's centroid

In [None]:
# --- Step 3b: Filter out problematic records ---
# Drops every occurrence whose 'issues' list contains one of the critical
# geospatial flags below, producing df_all_occurrences_filtered. Rows with no
# parsable issue list are kept.
import ast # Make sure ast is imported


def extract_issue_list(raw_entry):
    """Return the issue-code list held in one 'issues' cell, or None if it holds none.

    Lists (fresh pygbif results) are returned as-is. Strings (CSV round-trip)
    are parsed with ast.literal_eval and kept only when they evaluate to a list.
    NaN, unparsable strings and non-list literals give None ("nothing to check").
    """
    if isinstance(raw_entry, list):
        return raw_entry
    if not isinstance(raw_entry, str):
        return None
    try:
        evaluated = ast.literal_eval(raw_entry)
    except (ValueError, SyntaxError):
        return None
    return evaluated if isinstance(evaluated, list) else None


has_occurrences = 'df_all_occurrences' in globals() and not df_all_occurrences.empty

if not has_occurrences:
    print("Error: The DataFrame 'df_all_occurrences' is not defined or is empty.")
    print("Please ensure it's populated with the concatenated GBIF data before running this filtering step.")
    df_all_occurrences_filtered = pd.DataFrame()
else:
    print("\nStarting filtering based on critical GBIF issues...")
    print(f"Original number of records: {len(df_all_occurrences)}")

    CRITICAL_ISSUES_TO_FILTER = {
        "ZERO_COORDINATE",
        "COORDINATE_INVALID",
        "COORDINATE_OUT_OF_RANGE",
        "COUNTRY_COORDINATE_MISMATCH",
        "COORDINATE_REPROJECTION_FAILED",
        "COORDINATE_UNCERTAINTY_METERS_INVALID"
    }
    print(f"Filtering out records with any of the following issues: {CRITICAL_ISSUES_TO_FILTER}")

    if 'issues' in df_all_occurrences.columns:
        parsed_entries = [extract_issue_list(value) for value in df_all_occurrences['issues']]

        # How many rows actually carried an issue list (even an empty one).
        actually_had_issues_to_check = sum(codes is not None for codes in parsed_entries)

        # Keep a row when it has no list at all, or its list shares nothing with the critical set.
        rows_to_keep_mask = [
            codes is None or CRITICAL_ISSUES_TO_FILTER.isdisjoint(set(codes))
            for codes in parsed_entries
        ]

        print(f"Debug: Number of rows where an actual list of issues was processed: {actually_had_issues_to_check}")
        df_all_occurrences_filtered = df_all_occurrences[rows_to_keep_mask].copy()

        records_before = len(df_all_occurrences)
        records_after = len(df_all_occurrences_filtered)
        num_records_removed = records_before - records_after
        print(f"Number of records removed due to critical issues: {num_records_removed}")
        print(f"Number of records remaining after filtering: {records_after}")
    else:
        print("Warning: 'issues' column not found in df_all_occurrences. Cannot perform issue-based filtering.")
        df_all_occurrences_filtered = df_all_occurrences.copy()


## Step 4: Assign simplified scientific names based on speciesKey
*Amorpha herbacea* var. *floridana* needs special handling because GBIF treats it as a variety of *Amorpha herbacea*, so speciesKey alone cannot distinguish it; its records are relabelled using their full `scientificName`.

In [None]:
# --- Step 4: Add/update simplified scientific names based on BNM provided list ---
# Adds a 'simpSciName' column to df_all_occurrences_filtered in two passes:
#   1. a left join on speciesKey against df_species_list (loaded in Step 1);
#   2. an override for rows whose scientificName is the Florida variety, which
#      GBIF files under the Amorpha herbacea species, so speciesKey alone
#      cannot tell them apart.


def merge_simp_sci_names(occurrences, species_map):
    """Return `occurrences` with a 'simpSciName' column joined on 'speciesKey'.

    Parameters
    ----------
    occurrences : pandas.DataFrame
        Occurrence records; must contain 'speciesKey'. An existing
        'simpSciName' column is kept as a fallback for rows the map misses.
    species_map : pandas.DataFrame
        Mapping table with 'speciesKey' and 'simpSciName' columns.

    Returns
    -------
    pandas.DataFrame
        A new frame (fresh RangeIndex) with one row per retained occurrence.
        Keys are compared as nullable integers. Occurrences whose key cannot be
        converted are dropped, and the number dropped is reported. If a key
        appears more than once in the map, only its first 'simpSciName' is used,
        so the merge never duplicates occurrence rows.
    """
    original_occurrence_keys_type = occurrences['speciesKey'].dtype
    original_map_keys_type = species_map['speciesKey'].dtype

    temp_occurrences_df = occurrences.copy()
    map_df = species_map[['speciesKey', 'simpSciName']].copy()

    try:
        temp_occurrences_df['speciesKey_for_merge'] = pd.to_numeric(temp_occurrences_df['speciesKey'], errors='coerce').astype('Int64')
        map_df['speciesKey_for_merge'] = pd.to_numeric(map_df['speciesKey'], errors='coerce').astype('Int64')

        rows_before_drop = len(temp_occurrences_df)
        temp_occurrences_df = temp_occurrences_df.dropna(subset=['speciesKey_for_merge'])
        map_df = map_df.dropna(subset=['speciesKey_for_merge'])
        print(f"  Converted 'speciesKey' in occurrences data (was {original_occurrence_keys_type}) and mapping file (was {original_map_keys_type}) to 'Int64' for merging.")
        rows_dropped = rows_before_drop - len(temp_occurrences_df)
        if rows_dropped:
            print(f"  Note: dropped {rows_dropped} occurrence records whose 'speciesKey' is missing or non-numeric.")
    except Exception as e_conv:
        print(f"  Warning: Error during 'speciesKey' type conversion: {e_conv}. Merge might be inaccurate. Using original types.")
        temp_occurrences_df['speciesKey_for_merge'] = temp_occurrences_df['speciesKey']
        map_df['speciesKey_for_merge'] = map_df['speciesKey']

    # A key listed more than once in the map would make the left merge emit one
    # copy of each matching occurrence per map row, so keep one mapping per key.
    repeated_key_mask = map_df['speciesKey_for_merge'].duplicated(keep='first')
    if repeated_key_mask.any():
        clashing_rows = map_df.loc[
            map_df['speciesKey_for_merge'].duplicated(keep=False),
            ['speciesKey_for_merge', 'simpSciName'],
        ]
        print("  Warning: some speciesKey values appear more than once in the species map; "
              "keeping the first 'simpSciName' listed for each (check the CSV order):")
        print(clashing_rows.to_string(index=False))
        map_df = map_df.loc[~repeated_key_mask]

    if 'simpSciName' in temp_occurrences_df.columns:
        temp_occurrences_df = temp_occurrences_df.rename(columns={'simpSciName': 'simpSciName_before_key_merge'})
        print("  Temporarily renamed existing 'simpSciName' to 'simpSciName_before_key_merge'.")

    merged = pd.merge(
        temp_occurrences_df,
        map_df[['speciesKey_for_merge', 'simpSciName']],
        on='speciesKey_for_merge',
        how='left',
        validate='many_to_one',
    )

    # Rows the map did not cover keep whatever simpSciName they had before.
    if 'simpSciName_before_key_merge' in merged.columns:
        merged['simpSciName'] = merged['simpSciName'].fillna(merged['simpSciName_before_key_merge'])
        merged = merged.drop(columns=['simpSciName_before_key_merge'])

    return merged.drop(columns=['speciesKey_for_merge'], errors='ignore')


if 'df_all_occurrences_filtered' not in globals() or df_all_occurrences_filtered.empty:
    print("Error: The DataFrame 'df_all_occurrences_filtered' is not defined or is empty.")
    print("Please ensure it's populated before attempting to add/update simplified scientific names.")
elif 'df_species_list' not in globals() or df_species_list.empty:
    print("Error: The DataFrame 'df_species_list' (your species map) is not defined or is empty.")
    print("Please ensure Step 1 (loading species list/map) has been executed successfully.")
else:
    print("\nAttempting to add/update 'simpSciName' in df_all_occurrences_filtered using the pre-loaded 'df_species_list' and specific rules...")

    # Work on a copy so the original df_species_list is never modified.
    df_species_map = df_species_list.copy()

    # --- Part 1: Merge based on speciesKey from df_species_map (which is df_species_list) ---
    if 'speciesKey' not in df_species_map.columns or 'simpSciName' not in df_species_map.columns:
        print("Error: The pre-loaded 'df_species_list' (used as map) must contain 'speciesKey' and 'simpSciName' columns.")
        print(f"Available columns in 'df_species_list': {df_species_map.columns.tolist()}")
    elif 'speciesKey' not in df_all_occurrences_filtered.columns:
        print("Error: 'speciesKey' column not found in df_all_occurrences_filtered.")
        print(f"Available columns in df_all_occurrences_filtered: {df_all_occurrences_filtered.columns.tolist()}")
    else:
        df_all_occurrences_filtered = merge_simp_sci_names(df_all_occurrences_filtered, df_species_map)

        num_mapped = df_all_occurrences_filtered['simpSciName'].notna().sum()
        num_total = len(df_all_occurrences_filtered)
        print("  Merge based on speciesKey complete. 'simpSciName' column processed.")
        print(f"  {num_mapped} out of {num_total} records now have a non-null 'simpSciName'.")

    # --- Part 2: Apply specific rule for 'Amorpha herbacea var. floridana (Rydb.) Wilbur' ---
    target_scientific_name = 'Amorpha herbacea var. floridana (Rydb.) Wilbur'
    new_simp_sci_name_for_target = 'Amorpha herbacea var. floridana'

    if 'scientificName' not in df_all_occurrences_filtered.columns:
        print(f"\nWarning: 'scientificName' column not found in df_all_occurrences_filtered. Cannot apply specific rule for '{target_scientific_name}'.")
    else:
        mask_specific_species = (df_all_occurrences_filtered['scientificName'] == target_scientific_name)
        num_specific_records = mask_specific_species.sum()

        if num_specific_records > 0:
            if 'simpSciName' not in df_all_occurrences_filtered.columns:
                df_all_occurrences_filtered['simpSciName'] = pd.NA
            df_all_occurrences_filtered.loc[mask_specific_species, 'simpSciName'] = new_simp_sci_name_for_target
            print(f"\nApplied specific rule: Set 'simpSciName' to '{new_simp_sci_name_for_target}' for {num_specific_records} records where scientificName was '{target_scientific_name}'.")
        else:
            print(f"\nSpecific rule: No records found with scientificName '{target_scientific_name}'. No changes made by this rule.")

    print("\nFinal first 5 rows of df_all_occurrences_filtered with 'simpSciName' column after all processing:")
    # Show only the columns that exist, so a missing column cannot raise KeyError here.
    display_cols = [col for col in ('scientificName', 'speciesKey', 'simpSciName')
                    if col in df_all_occurrences_filtered.columns]
    if 'simpSciName' in df_all_occurrences_filtered.columns:
        print(df_all_occurrences_filtered[display_cols].head())
    else:
        print(df_all_occurrences_filtered[display_cols].head(), "(simpSciName column not present)")


## Steps 5–7: Prepare data for plotting (occurrence points, base map, species colors)

In [None]:
# --- Step 5: Data Cleaning (from df_all_occurrences_filtered) and Initial GeoDataFrame creation (unprojected) ---
# Builds gdf_all_occurrences_unprojected: one point per occurrence that has both
# coordinates, in EPSG:4326 (lon/lat). It stays an empty GeoDataFrame whenever
# a precondition fails or construction raises.
gdf_all_occurrences_unprojected = gpd.GeoDataFrame()

if 'df_all_occurrences_filtered' in globals() and not df_all_occurrences_filtered.empty:
    print(f"\nProcessing df_all_occurrences_filtered (contains {len(df_all_occurrences_filtered)} records) for mapping...")
    coordinate_columns = ['decimalLatitude', 'decimalLongitude']
    has_coordinate_columns = all(column in df_all_occurrences_filtered.columns for column in coordinate_columns)

    if not has_coordinate_columns:
        print("  Error: df_all_occurrences_filtered does not have 'decimalLatitude' or 'decimalLongitude' columns.")
    else:
        df_all_occurrences_clean = df_all_occurrences_filtered.dropna(subset=coordinate_columns).copy()

        if df_all_occurrences_clean.empty:
            print("  DataFrame is empty after dropping NA coordinates from df_all_occurrences_filtered.")
        else:
            print(f"  {len(df_all_occurrences_clean)} records remaining after dropping rows with NA coordinates.")
            try:
                # x = longitude, y = latitude
                point_geometry = gpd.points_from_xy(
                    df_all_occurrences_clean['decimalLongitude'],
                    df_all_occurrences_clean['decimalLatitude'],
                )
                gdf_all_occurrences_unprojected = gpd.GeoDataFrame(
                    df_all_occurrences_clean,
                    geometry=point_geometry,
                    crs="EPSG:4326",
                )
                print("  Successfully created gdf_all_occurrences_unprojected from cleaned, filtered data.")
            except Exception as e_gdf:
                print(f"  Error creating GeoDataFrame from cleaned data: {e_gdf}")
else:
    print("Error: df_all_occurrences_filtered is not defined or is empty. Cannot proceed with mapping preparation.")

if gdf_all_occurrences_unprojected.empty:
    print("  gdf_all_occurrences_unprojected is empty. Further mapping steps might not produce output.")


# --- Step 6: Prepare Base Map (North America, Orthographic Projection) ---
# Reads the ne_110m_admin_0_countries shapefile. It keeps only the United
# States, Canada and Mexico when a country-name column can be found; otherwise
# it keeps every polygon. The result is reprojected to an orthographic CRS
# centred at lat 45, lon -100 and stored in north_america_map_proj, which stays
# None if anything fails.
# NOTE(review): this path uses 'data/AOI' while the aoi_dir defined in Step 0 uses
# 'aoi'; they only match on case-insensitive filesystems -- confirm the folder name.
north_america_map_proj = None
ortho_crs = "+proj=ortho +lat_0=45 +lon_0=-100 +x_0=0 +y_0=0 +datum=WGS84 +units=m +no_defs +type=crs"
world_shapefile_path = 'data/AOI/ne_110m_admin_0_countries/ne_110m_admin_0_countries.shp'

try:
    print(f"\nLoading base map from '{world_shapefile_path}'...")
    world = gpd.read_file(world_shapefile_path)
    if world.crs is None:
        print("  Warning: Shapefile has no CRS defined. Assuming EPSG:4326 (WGS84).")
        world = world.set_crs("EPSG:4326", allow_override=True)

    # The first candidate column present in the shapefile is used.
    possible_name_cols = ['NAME', 'ADMIN', 'SOVEREIGNT', 'SOV_A3', 'ADM0_A3', 'NAME_EN', 'NAME_LONG', 'admin', 'name_long', 'name']
    country_name_column = next((candidate for candidate in possible_name_cols if candidate in world.columns), None)

    north_america_map_unprojected = world # fallback: whole world
    if country_name_column is None:
        print("  Warning: Could not find a standard country name column for base map. Using all countries from shapefile.")
    else:
        print(f"  Using column '{country_name_column}' from shapefile for country names for base map.")
        north_america_countries_filter = ['United States of America', 'Canada', 'Mexico', 'United States']
        filtered_map = world[world[country_name_column].isin(north_america_countries_filter)]
        if filtered_map.empty:
            print(f"  Warning: No countries matched the filter for base map. Using all {len(world)} countries from shapefile.")
        else:
            north_america_map_unprojected = filtered_map
            print(f"  Filtered base map to {len(north_america_map_unprojected)} North American polygons.")

    if north_america_map_unprojected is None or north_america_map_unprojected.empty:
        print("  Could not project base map because the unprojected map is empty or None.")
    elif not north_america_map_unprojected.crs:
        print("  Could not project base map because it has no CRS defined.")
    else:
        north_america_map_proj = north_america_map_unprojected.to_crs(ortho_crs)
        center_lat = ortho_crs.split('lat_0=')[1].split(' ')[0]
        center_lon = ortho_crs.split('lon_0=')[1].split(' ')[0]
        print(f"  Base map projected successfully to Orthographic CRS (centered at lat {center_lat}, lon {center_lon}).")

except Exception as e:
    print(f"Error processing base map shapefile: {e}")


# --- Step 7: Generate Species Colors (using 'simpSciName' from gdf_all_occurrences_unprojected) ---
# Builds species_colors, a dict mapping each distinct non-null 'simpSciName' to
# an RGBA tuple from Matplotlib's qualitative 'tab20' colormap.
# plt.get_cmap is used because matplotlib.cm.get_cmap was deprecated in
# Matplotlib 3.7 and removed in 3.9.
unique_species_in_data = [] # Initialize
species_colors = {} # Initialize

if not gdf_all_occurrences_unprojected.empty:
    if 'simpSciName' in gdf_all_occurrences_unprojected.columns:
        unique_species_in_data = gdf_all_occurrences_unprojected['simpSciName'].dropna().unique().tolist()
        print(f"\nFound {len(unique_species_in_data)} unique 'simpSciName' values in the data for color generation.")

        if unique_species_in_data:
            num_unique_species = len(unique_species_in_data)
            if num_unique_species > 20:
                print(f"  Warning: More than 20 unique simplified species names ({num_unique_species}). Colors will repeat from 'tab20'.")
            # With <= 20 names the colormap is resampled to exactly that many entries
            # (so i % 20 == i). With more, it keeps its 20 colors and they cycle.
            colors_cmap = plt.get_cmap('tab20', min(num_unique_species, 20))
            species_colors = {species: colors_cmap(i % 20) for i, species in enumerate(unique_species_in_data)}
            print(f"  Generated colors for {len(species_colors)} simplified species names.")
        else:
            print("  No unique 'simpSciName' values found after dropping NA, so no species colors generated.")
    else:
        print("  'simpSciName' column not found in gdf_all_occurrences_unprojected. Cannot generate species-specific colors based on it.")
        print("  Consider using 'scientificName' or another column if 'simpSciName' was not added correctly.")
else:
    print("\ngdf_all_occurrences_unprojected is empty. Cannot generate species colors.")

## Step 9: Plot individual species distributions

In [None]:
# --- Step 9: Plot individual maps ---
# Setup for the per-species maps. This ensures the occurrence GeoDataFrame, the
# North America base map, the species list and the species -> color mapping all
# exist, and that both map layers are in EPSG:4326. Missing pieces are replaced
# with placeholders.
# NOTE(review): the placeholder occurrences are synthetic demo points; if they are
# used, the maps below do not show real GBIF records.

# 1. Occurrence Data (gdf_all_occurrences_unprojected in EPSG:4326)
if 'gdf_all_occurrences_unprojected' not in globals() or \
   ('empty' in dir(gdf_all_occurrences_unprojected) and gdf_all_occurrences_unprojected.empty):
    print("Setup: 'gdf_all_occurrences_unprojected' not found/empty. Creating placeholder.")
    placeholder_occurrence_data = pd.DataFrame({
        'decimalLatitude': [34.0, 36.5, 38.0, 33.5, 40.0, 42.5, 28.0, 30.0, 35.0], 
        'decimalLongitude': [-118.2, -120.0, -119.5, -117.0, -75.0, -73.0, -82.0, -81.0, -95.0],
        'simpSciName': ['Amorpha californica', 'Amorpha californica', 'Amorpha californica', 'Amorpha californica', 
                        'Species B (East)', 'Species B (East)', 'Species C (South)', 'Species C (South)', 'Species D (Central)'],
        'key': range(9)
    })
    gdf_all_occurrences_unprojected = gpd.GeoDataFrame(
        placeholder_occurrence_data,
        geometry=gpd.points_from_xy(placeholder_occurrence_data.decimalLongitude, 
                                     placeholder_occurrence_data.decimalLatitude),
        crs="EPSG:4326"
    )
else:
    print("Setup: Using existing 'gdf_all_occurrences_unprojected'.")
    if not isinstance(gdf_all_occurrences_unprojected, gpd.GeoDataFrame):
        try:
            gdf_all_occurrences_unprojected = gpd.GeoDataFrame(
                gdf_all_occurrences_unprojected,
                geometry=gpd.points_from_xy(gdf_all_occurrences_unprojected.decimalLongitude, gdf_all_occurrences_unprojected.decimalLatitude),
                crs="EPSG:4326")
        except Exception as e: print(f"Error converting to GDF: {e}")
    elif gdf_all_occurrences_unprojected.crs is None:
        # to_crs() cannot transform geometries that have no CRS. The occurrence points
        # are decimal lon/lat (declared EPSG:4326 in the conversions above), so assign it.
        gdf_all_occurrences_unprojected = gdf_all_occurrences_unprojected.set_crs("EPSG:4326")
    elif gdf_all_occurrences_unprojected.crs.to_string() != "EPSG:4326":
        gdf_all_occurrences_unprojected = gdf_all_occurrences_unprojected.to_crs("EPSG:4326")

# 2. North America Basemap (north_america_map_unprojected in EPSG:4326)
if 'north_america_map_unprojected' not in globals() or \
   north_america_map_unprojected is None or \
   ('empty' in dir(north_america_map_unprojected) and north_america_map_unprojected.empty):
    print("Setup: 'north_america_map_unprojected' not found/empty.")
    shapefile_path_actual = 'data/AOI/ne_110m_admin_0_countries/ne_110m_admin_0_countries.shp' 
    try:
        if os.path.exists(shapefile_path_actual):
            world_temp = gpd.read_file(shapefile_path_actual)
            if world_temp.crs is None: world_temp = world_temp.set_crs("EPSG:4326", allow_override=True)
            north_america_map_unprojected = world_temp[world_temp['ADMIN'].isin(['United States of America', 'Canada', 'Mexico'])]
            if north_america_map_unprojected.empty: 
                 north_america_map_unprojected = world_temp[world_temp['ADMIN'].isin(['United States', 'Canada', 'Mexico'])] # Try alternate name
            if north_america_map_unprojected.empty: 
                north_america_map_unprojected = gpd.GeoDataFrame({'name': ['Placeholder NA Base Map']}, geometry=[box(-179, 10, -50, 85)], crs="EPSG:4326")
            else: north_america_map_unprojected = north_america_map_unprojected.to_crs("EPSG:4326")
        else: raise FileNotFoundError 
    except Exception as e:
        print(f"Setup Error for basemap: {e}. Using placeholder.")
        north_america_map_unprojected = gpd.GeoDataFrame({'name': ['Placeholder NA Base Map']}, geometry=[box(-179, 10, -50, 85)], crs="EPSG:4326")
else:
    print("Setup: Using existing 'north_america_map_unprojected'.")
    if north_america_map_unprojected.crs is None:
        # Same convention as the shapefile fallback above: a base map without a CRS
        # is assumed to be lon/lat (EPSG:4326); to_crs() would raise on it.
        north_america_map_unprojected = north_america_map_unprojected.set_crs("EPSG:4326")
    elif north_america_map_unprojected.crs.to_string() != "EPSG:4326":
        north_america_map_unprojected = north_america_map_unprojected.to_crs("EPSG:4326")

# 3. Unique Species List & 4. Colors
if 'unique_species_in_data' not in globals() or not unique_species_in_data:
    if not gdf_all_occurrences_unprojected.empty and 'simpSciName' in gdf_all_occurrences_unprojected.columns:
        unique_species_in_data = sorted(list(gdf_all_occurrences_unprojected['simpSciName'].dropna().unique()))
    else: unique_species_in_data = ['Amorpha californica', 'Species B (East)', 'Species C (South)', 'Species D (Central)'] 
# Regenerate colors if there are none, or if any species lacks a color.
if 'species_colors' not in globals() or not species_colors or not all(s in species_colors for s in unique_species_in_data):
    n_palette_colors = len(unique_species_in_data) if unique_species_in_data else 1
    try: # Use seaborn for nicer palettes if available
        import seaborn as sns
        palette_gen = sns.color_palette("husl", n_palette_colors)
    except ImportError:
        # plt.get_cmap replaces plt.cm.get_cmap, which was removed in Matplotlib 3.9.
        cmap_gen = plt.get_cmap('tab10', n_palette_colors)
        palette_gen = [cmap_gen(i) for i in range(n_palette_colors)]
    species_colors = {name: palette_gen[i] for i, name in enumerate(unique_species_in_data)}
# --- End of Setup ---


# --- Determine Overall Map Extent for "Flat" Maps (EPSG:4326 degrees) ---
# Computes the shared lon/lat limits (occurrence bounds plus a buffer) and the
# tick positions used by the per-species flat maps and the combined map.
overall_map_lim_min_lon = -125.0 
overall_map_lim_max_lon = -65.0
overall_map_lim_min_lat = 20.0
overall_map_lim_max_lat = 55.0 # Default values

if not gdf_all_occurrences_unprojected.empty:
    # Get bounds from the geometry column of the GeoDataFrame: (minx, miny, maxx, maxy).
    data_bounds = gdf_all_occurrences_unprojected.total_bounds
    # Guard against non-finite bounds (e.g. if no geometry has usable coordinates):
    # NaN limits would make the np.arange tick computation below fail.
    if np.all(np.isfinite(data_bounds)):
        min_lon_all_data, min_lat_all_data, max_lon_all_data, max_lat_all_data = data_bounds
        
        lon_span = max_lon_all_data - min_lon_all_data
        lat_span = max_lat_all_data - min_lat_all_data

        # Add a 10% buffer, with a minimum of 1 degree
        lon_buffer_overall = max(lon_span * 0.10, 1.0) 
        lat_buffer_overall = max(lat_span * 0.10, 1.0) 
        
        overall_map_lim_min_lon = min_lon_all_data - lon_buffer_overall
        overall_map_lim_max_lon = max_lon_all_data + lon_buffer_overall
        overall_map_lim_min_lat = min_lat_all_data - lat_buffer_overall
        overall_map_lim_max_lat = max_lat_all_data + lat_buffer_overall
        print(f"Calculated Overall Map Extent (EPSG:4326): Lon [{overall_map_lim_min_lon:.2f}, {overall_map_lim_max_lon:.2f}], Lat [{overall_map_lim_min_lat:.2f}, {overall_map_lim_max_lat:.2f}]")
    else:
        print(f"Warning: Occurrence geometries have no finite bounds. Using default map extent: Lon [{overall_map_lim_min_lon:.2f}, {overall_map_lim_max_lon:.2f}], Lat [{overall_map_lim_min_lat:.2f}, {overall_map_lim_max_lat:.2f}]")
else:
    print(f"Warning: Occurrence data is empty. Using default map extent: Lon [{overall_map_lim_min_lon:.2f}, {overall_map_lim_max_lon:.2f}], Lat [{overall_map_lim_min_lat:.2f}, {overall_map_lim_max_lat:.2f}]")

# Define Ticks based on the new overall extent (optional, but good for consistency).
# Ticks are multiples of the step that fall inside the limits; the "+ step * 0.1"
# makes np.arange include the upper tick despite floating-point rounding.
lon_tick_step = 10.0 # Adjust step as needed
min_lon_tick = np.ceil(overall_map_lim_min_lon / lon_tick_step) * lon_tick_step
max_lon_tick = np.floor(overall_map_lim_max_lon / lon_tick_step) * lon_tick_step
lon_ticks = np.arange(min_lon_tick, max_lon_tick + lon_tick_step * 0.1, lon_tick_step)

lat_tick_step = 5.0 # Adjust step as needed
min_lat_tick = np.ceil(overall_map_lim_min_lat / lat_tick_step) * lat_tick_step
max_lat_tick = np.floor(overall_map_lim_max_lat / lat_tick_step) * lat_tick_step
lat_ticks = np.arange(min_lat_tick, max_lat_tick + lat_tick_step * 0.1, lat_tick_step)
# --- End of Extent and Tick Calculation ---


# --- Plotting Individual "Flat" Maps with Common Extent ---
# One figure per species: the base map plus that species' points. All figures use
# the overall extent and ticks computed above, so the maps can be compared directly.
if gdf_all_occurrences_unprojected.empty or north_america_map_unprojected.empty or not species_colors or not unique_species_in_data:
    print("Skipping flat maps due to missing data, basemap, colors, or species list after setup.")
else:
    print("\nGenerating individual 'flat' maps with common extent for each species...")
    new_marker_size = 30  # also reused by the combined map in Step 10

    for species_name_to_plot in unique_species_in_data:
        print(f"  Plotting map for: {species_name_to_plot}")
        
        gdf_species_subset = gdf_all_occurrences_unprojected[
            gdf_all_occurrences_unprojected['simpSciName'] == species_name_to_plot
        ].copy() # Use .copy()

        # Get the count of observations for the current species
        observation_count = len(gdf_species_subset)
        legend_label = f"{species_name_to_plot} (n={observation_count})"

        fig, ax = plt.subplots(1, 1, figsize=(12, 9)) 
        
        north_america_map_unprojected.plot(ax=ax, color='lightgray', edgecolor='black', linewidth=0.5, zorder=1)
        
        if not gdf_species_subset.empty:
            gdf_species_subset.plot(
                ax=ax, marker='o', 
                color=species_colors.get(species_name_to_plot, 'purple'), 
                markersize=new_marker_size, 
                alpha=0.7, 
                label=legend_label, # Use the new label with count
                zorder=2
            )
        
        ax.set_title(f'Distribution of {species_name_to_plot}', fontsize=14)
        ax.set_xlabel("Longitude (degrees)", fontsize=10)
        ax.set_ylabel("Latitude (degrees)", fontsize=10)
        
        ax.set_xlim(overall_map_lim_min_lon, overall_map_lim_max_lon)
        ax.set_ylim(overall_map_lim_min_lat, overall_map_lim_max_lat)
        ax.set_xticks(lon_ticks)
        ax.set_yticks(lat_ticks)
        
        ax.grid(True, linestyle='--', alpha=0.7)
        if not gdf_species_subset.empty : 
            ax.legend(loc='best', markerscale=0.8)

        plt.show()
        # Release the figure after display. Under non-interactive backends plt.show()
        # does not close it, and one open figure per species would accumulate.
        plt.close(fig)
print("\nIndividual flat species mapping process complete.")


## Step 10: Plot all *Amorpha* spp. GBIF occurrence data

In [None]:
# --- Step 10: Amorpha spp. distribution map ---
# Plots every species on one flat EPSG:4326 map, colored with species_colors. It
# uses the extent/ticks computed in Step 9 and adds a legend as the color key.

# Check if essential variables are present
missing_vars_final_map = []
if 'gdf_all_occurrences_unprojected' not in globals() or gdf_all_occurrences_unprojected.empty:
    missing_vars_final_map.append('gdf_all_occurrences_unprojected')
if 'north_america_map_unprojected' not in globals() or north_america_map_unprojected.empty:
    missing_vars_final_map.append('north_america_map_unprojected')
if 'unique_species_in_data' not in globals() or not unique_species_in_data:
    missing_vars_final_map.append('unique_species_in_data')
if 'species_colors' not in globals() or not species_colors:
    missing_vars_final_map.append('species_colors')
if 'overall_map_lim_min_lon' not in globals(): # Check one of the extent vars
    missing_vars_final_map.append('overall map limit variables (e.g., overall_map_lim_min_lon)')
if 'new_marker_size' not in globals():
    # new_marker_size is only set when the Step 9 per-species maps were drawn.
    print("Warning: 'new_marker_size' not defined, defaulting to 48 for the final map.")
    new_marker_size = 48


if missing_vars_final_map:
    print(f"Skipping final combined map due to missing prerequisite variables: {', '.join(missing_vars_final_map)}")
else:
    print("\nGenerating final 'flat' map with all species and a color key...")

    fig, ax = plt.subplots(1, 1, figsize=(14, 10)) # Adjusted figsize for map with legend

    # 1. Plot the unprojected basemap
    north_america_map_unprojected.plot(ax=ax, color='lightgray', edgecolor='black', linewidth=0.5, zorder=1)

    # 2. Plot each species
    for species_name_to_plot in unique_species_in_data:
        gdf_species_subset = gdf_all_occurrences_unprojected[
            gdf_all_occurrences_unprojected['simpSciName'] == species_name_to_plot
        ]
        
        if not gdf_species_subset.empty:
            gdf_species_subset.plot(
                ax=ax, 
                marker='o', 
                color=species_colors.get(species_name_to_plot, 'grey'), # Use grey as fallback
                markersize=new_marker_size, 
                alpha=0.7, 
                label=species_name_to_plot, # This label is used by ax.legend()
                zorder=2
            )

    # 3. Set map title, labels, extent, and ticks
    ax.set_title('Combined Distribution of All Amorpha Species', fontsize=16)
    ax.set_xlabel("Longitude (degrees)", fontsize=10)
    ax.set_ylabel("Latitude (degrees)", fontsize=10)
    
    ax.set_xlim(overall_map_lim_min_lon, overall_map_lim_max_lon)
    ax.set_ylim(overall_map_lim_min_lat, overall_map_lim_max_lat)
    
    # Ensure lon_ticks and lat_ticks are defined from the previous step's extent calculation
    if 'lon_ticks' in globals() and 'lat_ticks' in globals():
        ax.set_xticks(lon_ticks)
        ax.set_yticks(lat_ticks)
    else:
        print("Warning: lon_ticks or lat_ticks not defined. Using default ticks for the final map.")
        
    ax.grid(True, linestyle='--', alpha=0.7)

    # 4. Add Legend (Color Key)
    # Adjust legend properties for better placement and appearance
    legend = ax.legend(
        title="Species",
        loc='best', # 'best' tries to find a good spot, or specify e.g. 'upper right'
        markerscale=0.8, 
        fontsize='small',
        title_fontsize='medium',
        ncol=1 # Adjust number of columns if many species
    )
    if legend: # Ensure legend was created
        legend.get_frame().set_alpha(0.9) # Make legend background slightly transparent

    # Optional: Set aspect ratio for the flat map
    # ax.set_aspect('equal', adjustable='box') # Uncomment if you want degrees to be visually ~square

    plt.show()
    # Close the figure after display so pyplot does not keep it open.
    plt.close(fig)
    print("\nFinal combined species map generation complete.")


## Step 11: Observation thinning to reduce spatial autocorrelation

In [None]:
# --- Step 11: Spatial thinning to reduce spatial autocorrelation ---
# For each species, points are binned into square cells of `grid_resolution`
# degrees and one randomly chosen record is kept per occupied cell. Every record
# stays in the output: the added 'thinning_status' column marks it 'kept' or
# 'dropped'. The helper columns 'grid_lon_id', 'grid_lat_id' and
# 'grid_cell_unique_id' are also added.
gdf_thinned_results_list = []
grid_resolution = 0.25 # degrees
print(f"\nPerforming spatial thinning for each species ({grid_resolution}-degree grid)...")

if 'gdf_all_occurrences_unprojected' in globals() and not gdf_all_occurrences_unprojected.empty:
    for species_name_to_thin in unique_species_in_data:
        species_gdf_original = gdf_all_occurrences_unprojected[
            gdf_all_occurrences_unprojected['simpSciName'] == species_name_to_thin
        ].copy()

        if species_gdf_original.empty:
            # Keep the (empty) frame, with a 'thinning_status' column for schema consistency.
            species_gdf_original['thinning_status'] = pd.NA
            gdf_thinned_results_list.append(species_gdf_original)
            continue

        # Grid cell IDs: floor(coordinate / cell size) along each axis.
        species_gdf_original['grid_lon_id'] = np.floor(species_gdf_original.geometry.x / grid_resolution)
        species_gdf_original['grid_lat_id'] = np.floor(species_gdf_original.geometry.y / grid_resolution)
        species_gdf_original['grid_cell_unique_id'] = species_gdf_original['grid_lon_id'].astype(str) + '_' + species_gdf_original['grid_lat_id'].astype(str)

        # Keep one random record per grid cell (the result keeps the original index).
        # GroupBy.sample draws all cells from a single RNG seeded once, so the result
        # is reproducible and the picks are independent across cells. Calling
        # x.sample(1, random_state=42) inside each group would re-seed every cell and
        # always pick the same position within cells of equal size, biasing
        # selection toward record order.
        kept_indices_for_species = (
            species_gdf_original
            .groupby('grid_cell_unique_id')
            .sample(n=1, random_state=42)
            .index
        )

        species_gdf_original['thinning_status'] = 'dropped' # Default to dropped
        species_gdf_original.loc[kept_indices_for_species, 'thinning_status'] = 'kept'
        
        gdf_thinned_results_list.append(species_gdf_original)

    if gdf_thinned_results_list:
        gdf_all_occurrences_thinned_status = pd.concat(gdf_thinned_results_list).reset_index(drop=True)
        # Ensure it's still a GeoDataFrame if concatenation changed its type
        if not isinstance(gdf_all_occurrences_thinned_status, gpd.GeoDataFrame):
             gdf_all_occurrences_thinned_status = gpd.GeoDataFrame(gdf_all_occurrences_thinned_status, geometry='geometry', crs="EPSG:4326")
        print("Spatial thinning complete.")
    else:
        print("No data to thin or no species found. Creating an empty thinned DataFrame.")
        # Create an empty GDF with the expected columns if no species were processed
        columns_if_empty = list(gdf_all_occurrences_unprojected.columns) + ['thinning_status'] if 'gdf_all_occurrences_unprojected' in globals() else ['geometry', 'simpSciName', 'thinning_status']
        gdf_all_occurrences_thinned_status = gpd.GeoDataFrame(columns=columns_if_empty, geometry='geometry', crs="EPSG:4326")

else:
    print("Skipping thinning: 'gdf_all_occurrences_unprojected' is not defined or is empty.")
    # Ensure gdf_all_occurrences_thinned_status exists for subsequent plotting code, even if empty
    gdf_all_occurrences_thinned_status = gpd.GeoDataFrame(columns=['geometry', 'simpSciName', 'thinning_status'], geometry='geometry', crs="EPSG:4326")


# --- Plotting Individual "Flat" Maps with Thinned Data & SPECIES-SPECIFIC ZOOM ---
# One figure per species. Kept points are drawn as filled markers and dropped
# points as hollow ones. The view is zoomed to that species' own bounds; the
# overall extent is used when the species has no rows.
if gdf_all_occurrences_thinned_status.empty or north_america_map_unprojected.empty or not species_colors or not unique_species_in_data:
    print("Skipping thinned maps due to missing thinned data, basemap, colors, or species list.")
else:
    print("\nGenerating individual 'flat' maps with thinned data (zoomed to species extent) for each species...")
    kept_marker_size = 35
    dropped_marker_size = 20
    minimum_extent_buffer_degrees = 0.5 # Minimum buffer in degrees around data

    for species_name_to_plot in unique_species_in_data:
        print(f"  Plotting thinned map for: {species_name_to_plot}")
        
        species_subset_with_status = gdf_all_occurrences_thinned_status[
            gdf_all_occurrences_thinned_status['simpSciName'] == species_name_to_plot
        ].copy()

        fig, ax = plt.subplots(1, 1, figsize=(10, 8)) # Adjust figsize as needed for potentially varying aspects
        
        # Plot the unprojected basemap
        north_america_map_unprojected.plot(ax=ax, color='lightgray', edgecolor='black', linewidth=0.5, zorder=1)
        
        current_species_color = species_colors.get(species_name_to_plot, 'purple') # Default color

        if species_subset_with_status.empty:
            print(f"    No data (kept or dropped) for {species_name_to_plot} to determine zoom. Using overall extent.")
            # Fallback to overall extent if no points for this species
            ax.set_xlim(overall_map_lim_min_lon, overall_map_lim_max_lon)
            ax.set_ylim(overall_map_lim_min_lat, overall_map_lim_max_lat)
            ax.text(0.5, 0.5, 'No occurrence data for this species', 
                    horizontalalignment='center', verticalalignment='center', 
                    transform=ax.transAxes, fontsize=12, color='gray')
        else:
            gdf_kept = species_subset_with_status[species_subset_with_status['thinning_status'] == 'kept']
            gdf_dropped = species_subset_with_status[species_subset_with_status['thinning_status'] == 'dropped']

            count_kept = len(gdf_kept)
            count_dropped = len(gdf_dropped)
            label_kept = f"Kept (n={count_kept})"
            label_dropped = f"Dropped (n={count_dropped})"

            # Plot dropped points first
            if not gdf_dropped.empty:
                gdf_dropped.plot(
                    ax=ax, marker='o', facecolors='none', edgecolors=current_species_color,
                    linewidth=0.8, markersize=dropped_marker_size, alpha=0.4,
                    label=label_dropped, zorder=2
                )
            
            # Plot kept points on top
            if not gdf_kept.empty:
                gdf_kept.plot(
                    ax=ax, marker='o', color=current_species_color, edgecolor='black',
                    linewidth=0.3, markersize=kept_marker_size, alpha=0.8,
                    label=label_kept, zorder=3
                )

            # Calculate species-specific extent for zooming (kept + dropped points)
            min_lon_species, min_lat_species, max_lon_species, max_lat_species = species_subset_with_status.total_bounds
            
            lon_span_species = max_lon_species - min_lon_species
            lat_span_species = max_lat_species - min_lat_species
            
            # Dynamic buffer: 10% of span, or the minimum_extent_buffer_degrees, whichever is larger
            # (the minimum also keeps a single-point species from producing a zero-size view).
            lon_buffer_species = max(lon_span_species * 0.10, minimum_extent_buffer_degrees)
            lat_buffer_species = max(lat_span_species * 0.10, minimum_extent_buffer_degrees)
            
            # Apply species-specific zoomed extent
            ax.set_xlim(min_lon_species - lon_buffer_species, max_lon_species + lon_buffer_species)
            ax.set_ylim(min_lat_species - lat_buffer_species, max_lat_species + lat_buffer_species)
            
            # Add legend
            handles, labels = ax.get_legend_handles_labels()
            if handles:
                ax.legend(handles, labels, loc='best', title=f"{species_name_to_plot}", markerscale=0.8, fontsize=9, title_fontsize=10)

        # Title reports the grid size actually used for thinning (grid_resolution, Step 11).
        ax.set_title(f'Spatially Thinned Distribution of {species_name_to_plot} ({grid_resolution}° grid, Zoomed)', fontsize=14)
        ax.set_xlabel("Longitude (degrees)", fontsize=10)
        ax.set_ylabel("Latitude (degrees)", fontsize=10)
        
        ax.grid(True, linestyle='--', alpha=0.7) # Grid will adapt to auto-ticks

        plt.show()
        # Release the figure after display so one figure per species does not accumulate.
        plt.close(fig)
print("\nIndividual thinned species mapping process (zoomed) complete.")


## Step 12: Import WorldClim 2.1 data (elevation and the 19 bioclimatic variables)
- https://www.worldclim.org/data/worldclim21.html


In [None]:
# --- Step 12: Import WorldClim raster data ---
# Builds the cropping extent for the WorldClim rasters from the (thinned)
# occurrence points: their EPSG:4326 bounds plus a fixed buffer in degrees.
output_cropped_rasters = {} # To store cropped raster data and transforms

# Define the bounding box from your species occurrence data.
# If the thinned occurrences are missing/empty, a placeholder extent (and a matching
# two-point placeholder GeoDataFrame) is used so the rest of the notebook can run.
if 'gdf_all_occurrences_thinned_status' not in globals() or gdf_all_occurrences_thinned_status.empty:
    print("Warning: 'gdf_all_occurrences_thinned_status' not found or empty.")
    print("Using a placeholder extent for North America for demonstration.")
    # Placeholder extent (roughly North America). Replace with your actual data bounds.
    min_lon, min_lat, max_lon, max_lat = -170, 15, -50, 80 
    # Create a placeholder GDF if needed for the code structure to run
    placeholder_occurrence_data = pd.DataFrame({
        'decimalLatitude': [min_lat, max_lat], 
        'decimalLongitude': [min_lon, max_lon],
        'simpSciName': ['PlaceholderSpecies1', 'PlaceholderSpecies2'],
        'thinning_status': ['kept', 'kept']
    })
    gdf_all_occurrences_thinned_status = gpd.GeoDataFrame(
        placeholder_occurrence_data,
        geometry=gpd.points_from_xy(placeholder_occurrence_data.decimalLongitude, 
                                     placeholder_occurrence_data.decimalLatitude),
        crs="EPSG:4326"
    )
else:
    if not isinstance(gdf_all_occurrences_thinned_status, gpd.GeoDataFrame):
        print("Error: gdf_all_occurrences_thinned_status is not a GeoDataFrame.")
        # total_bounds below requires a GeoDataFrame; stop with a clear error instead
        # of failing later with an AttributeError.
        raise TypeError("gdf_all_occurrences_thinned_status must be a GeoDataFrame to compute the cropping extent.")

    # WorldClim default is WGS84 (EPSG:4326), so the cropping extent must be in degrees.
    gdf_for_bounds = gdf_all_occurrences_thinned_status
    if gdf_for_bounds.crs is None:
        # crs.to_string() would raise on a missing CRS; treat the points as lon/lat.
        print("Warning: gdf_all_occurrences_thinned_status has no CRS. Assuming EPSG:4326 (WorldClim's CRS) for the cropping extent.")
    elif gdf_for_bounds.crs.to_string().upper() != 'EPSG:4326':
        print(f"Warning: gdf_all_occurrences_thinned_status CRS is {gdf_for_bounds.crs}. WorldClim is EPSG:4326. Reprojecting a copy to compute the cropping extent.")
        # NOTE: only a copy is reprojected here. The occurrence points themselves must
        # also be in EPSG:4326 before values are extracted from the rasters.
        gdf_for_bounds = gdf_for_bounds.to_crs('EPSG:4326')
    
    min_lon, min_lat, max_lon, max_lat = gdf_for_bounds.total_bounds

# Add a buffer to the bounding box (e.g., 0.5 degrees)
buffer = 0.5 
bbox_for_cropping = box(min_lon - buffer, min_lat - buffer, max_lon + buffer, max_lat + buffer)
print(f"Bounding box for cropping (with {buffer}° buffer): {bbox_for_cropping.bounds}")

# --- Function to load, crop, and return raster data ---
def load_and_crop_raster(raster_path, geometry_to_crop, variable_name="raster"):
    """
    Load a raster, crop it to a geometry, and return the data and its transform.

    Parameters
    ----------
    raster_path : str
        Path to the raster file (e.g. a WorldClim GeoTIFF).
    geometry_to_crop : shapely geometry
        Geometry used to crop/mask the raster. Assumed to be in the raster's CRS
        (EPSG:4326 for WorldClim); no reprojection is done here.
    variable_name : str
        Label used only in progress and error messages.

    Returns
    -------
    tuple
        ``(data, transform)`` on success. ``data`` is a float array with NaN
        wherever the raster has NoData or the pixel lies outside
        ``geometry_to_crop``; a single-band raster is returned 2-D (rows, cols).
        ``transform`` is the transform of the cropped window as returned by
        ``rasterio.mask.mask``. ``(None, None)`` if the file does not exist or
        reading/cropping fails (the error is printed).
    """
    if not os.path.exists(raster_path):
        print(f"Error: Raster file not found at {raster_path}")
        return None, None
        
    try:
        with rasterio.open(raster_path) as src:
            # Cropping assumes geometry_to_crop is in the raster's CRS; only warn here.
            # A raster without a CRS is warned about too, instead of crashing on
            # src.crs.to_string().
            if src.crs is None or src.crs.to_string().upper() != 'EPSG:4326':
                 print(f"Warning: Raster {variable_name} CRS ({src.crs}) may not be EPSG:4326. Cropping assumes matching CRSs.")

            # filled=False returns a masked array. Its mask covers both the raster's own
            # NoData pixels and pixels outside the geometry. With the default
            # filled=True those pixels would be set to src.nodata, or to 0 when the
            # raster declares no NoData value, and 0 is a valid temperature/precipitation
            # value. Filling the mask with NaN keeps them out of later statistics.
            masked_data, cropped_transform = mask(src, [geometry_to_crop], crop=True, filled=False)
            cropped_data = np.ma.filled(masked_data.astype(float), np.nan)

            # Remove the first (band) dimension if it's 1 (mask returns a 3D array)
            if cropped_data.shape[0] == 1:
                cropped_data = cropped_data.squeeze(axis=0)

            print(f"  Successfully loaded and cropped: {variable_name} ({os.path.basename(raster_path)})")
            print(f"    Cropped shape: {cropped_data.shape}")
            print(f"    Cropped min: {np.nanmin(cropped_data):.2f}, max: {np.nanmax(cropped_data):.2f} (after NaN conversion)")
            return cropped_data, cropped_transform
            
    except Exception as e:
        # Best-effort: report the problem and let the caller skip this variable.
        print(f"Error processing {variable_name} ({raster_path}): {e}")
        return None, None

# --- Process Bioclimatic Variables ---
# BIO1..BIO19 are read from bio_vars_dir. Each layer that crops successfully is
# stored in output_cropped_rasters under 'bio<N>' as {'data': ..., 'transform': ...}.
print(f"\nProcessing Bioclimatic variables from: {bio_vars_dir}")
if not os.path.exists(bio_vars_dir):
    print(f"Error: Bioclim directory not found at {bio_vars_dir}")
else:
    for i in range(1, 20): # BIO1 to BIO19
        var_name, file_name = f'bio{i}', f'wc2.1_2.5m_bio_{i}.tif'
        raster_path = os.path.join(bio_vars_dir, file_name)

        cropped_raster_data, cropped_raster_transform = load_and_crop_raster(
            raster_path, bbox_for_cropping, var_name
        )
        if cropped_raster_data is None:
            continue  # load_and_crop_raster has already printed why this layer failed

        output_cropped_rasters[var_name] = dict(
            data=cropped_raster_data,
            transform=cropped_raster_transform,
        )

# --- Process Elevation ---
# The single elevation raster is stored under the key 'elev' in the same format.
print(f"\nProcessing Elevation data from: {elev_file_path}")
if not os.path.exists(elev_file_path):
    print(f"Error: Elevation file not found at {elev_file_path}")
else:
    var_name = 'elev'
    cropped_raster_data, cropped_raster_transform = load_and_crop_raster(
        elev_file_path, bbox_for_cropping, var_name
    )
    if cropped_raster_data is not None:
        output_cropped_rasters[var_name] = dict(
            data=cropped_raster_data,
            transform=cropped_raster_transform,
        )

# --- Summary: one line per stored raster ---
print("\n--- Summary of Cropped Rasters ---")
if not output_cropped_rasters:
    print("No rasters were processed or loaded.")
else:
    for var_name, details in output_cropped_rasters.items():
        print(f"Variable: {var_name}, Shape: {'N/A' if details['data'] is None else details['data'].shape}")

## Step 13a: Correlation matrix to identify and drop highly correlated (R^2 > 0.70) bioclimatic variables

In [None]:
# --- Step 13a: Correlation matrix to reduce highly correlated (R^2 > 0.70) bioclim variables ---

# --- Bioclim Variable Names Mapping ---
# Maps the short raster keys used in output_cropped_rasters ('bio1'..'bio19', 'elev')
# to the readable labels used as DataFrame column names and plot labels.
# Insertion order is BIO1..BIO19 followed by elevation.
bioclim_names = dict(
    bio1='BIO1 Ann Mean Temp',
    bio2='BIO2 Mean Diurnal Range',
    bio3='BIO3 Isothermality',
    bio4='BIO4 Temp Seasonality',
    bio5='BIO5 Max Temp Warm Mth',
    bio6='BIO6 Min Temp Cold Mth',
    bio7='BIO7 Temp Ann Range',
    bio8='BIO8 Mean Temp Wet Qtr',
    bio9='BIO9 Mean Temp Dry Qtr',
    bio10='BIO10 Mean Temp Warm Qtr',
    bio11='BIO11 Mean Temp Cold Qtr',
    bio12='BIO12 Ann Precip',
    bio13='BIO13 Precip Wet Mth',
    bio14='BIO14 Precip Dry Mth',
    bio15='BIO15 Precip Seasonality',
    bio16='BIO16 Precip Wet Qtr',
    bio17='BIO17 Precip Dry Qtr',
    bio18='BIO18 Precip Warm Qtr',
    bio19='BIO19 Precip Cold Qtr',
    elev='Elevation',
)

# --- 1. Extract and Prepare Data for Analysis ---
# Flattens each cropped raster in output_cropped_rasters into one column of
# data_for_analysis, keyed by its descriptive name from bioclim_names. Layers whose
# shape differs from the first layer processed are skipped.
data_for_analysis = {}
reference_shape = None
reference_var_name = None


def bioclim_sort_key(var_key):
    """Sort key ordering 'bio1'..'bio19' numerically, followed by other keys such as 'elev'.

    Plain ``sorted`` compares strings lexicographically ('bio1', 'bio10', ...,
    'bio19', 'bio2', ...), which scrambles the variable order used for the
    DataFrame columns and everything plotted from them.
    """
    suffix = var_key[3:] if var_key.startswith('bio') else ''
    if suffix.isdigit():
        return (0, int(suffix), var_key)
    return (1, 0, var_key)


print("Preparing bioclimatic and elevation data for analysis...")

if 'output_cropped_rasters' not in globals() or not output_cropped_rasters:
    print("Error: 'output_cropped_rasters' is not defined or is empty.")
    output_cropped_rasters = {}

# Natural order: BIO1, BIO2, ..., BIO19, then elevation.
analysis_keys = sorted(
    (key for key in output_cropped_rasters if key.startswith('bio') or key == 'elev'),
    key=bioclim_sort_key,
)

if not analysis_keys:
    print("No bioclimatic or elevation variables found in 'output_cropped_rasters'.")
else:
    for var_name_short in analysis_keys: 
        if var_name_short in output_cropped_rasters and \
           'data' in output_cropped_rasters[var_name_short] and \
           output_cropped_rasters[var_name_short]['data'] is not None:
            
            current_data = output_cropped_rasters[var_name_short]['data']
            if reference_shape is None:
                reference_shape = current_data.shape
                reference_var_name = var_name_short
                print(f"  Using shape of '{bioclim_names.get(var_name_short, var_name_short)}' ({reference_shape}) as reference.")
            
            # Pixels are aligned by flattened position, so every layer must share one shape.
            if current_data.shape != reference_shape:
                print(f"  Warning: Shape of '{bioclim_names.get(var_name_short, var_name_short)}' ({current_data.shape}) "
                      f"does not match reference shape ({reference_shape} from '{bioclim_names.get(reference_var_name, reference_var_name)}').")
                print(f"  Skipping '{bioclim_names.get(var_name_short, var_name_short)}' for analysis.")
                continue
            
            descriptive_name = bioclim_names.get(var_name_short, var_name_short)
            data_for_analysis[descriptive_name] = current_data.flatten()
        else:
            print(f"  Data for '{bioclim_names.get(var_name_short, var_name_short)}' is missing or None. Skipping.")
# --- 2. Create a DataFrame with Descriptive Names ---
if not data_for_analysis:
    print("No valid data to build DataFrame for analysis. Aborting.")
else:
    df_analysis = pd.DataFrame(data_for_analysis)
    ordered_descriptive_names = [bioclim_names.get(k, k) for k in analysis_keys if bioclim_names.get(k,k) in df_analysis.columns]
    df_analysis = df_analysis[ordered_descriptive_names]

    print(f"\nDataFrame created with {len(df_analysis.columns)} variables and {len(df_analysis)} pixels (rows).")

    # --- 3. Handle NoData/NaN values ---
    original_pixel_count = len(df_analysis)
    df_analysis_cleaned = df_analysis.dropna()
    cleaned_pixel_count = len(df_analysis_cleaned)
    print(f"  Removed {original_pixel_count - cleaned_pixel_count} pixels with NaN values.")
    print(f"  {cleaned_pixel_count} pixels remain for analysis.")

    if cleaned_pixel_count < 2:
        print("  Error: Not enough data points after NaN removal for analysis.")
    elif len(df_analysis_cleaned.columns) < 2:
        print(f"  Error: Only {len(df_analysis_cleaned.columns)} variable(s) remain. Need at least 2 for pair plots.")
    else:
        # --- 4. Correlation Matrix (Heatmap) ---
        # (Heatmap code remains the same as previous correct version)
        print("\nCalculating correlation matrix...")
        correlation_matrix_final = df_analysis_cleaned.corr(method='pearson')

        print("\n--- Correlation Matrix (Descriptive Names) ---")
        with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.width', 200):
            print(correlation_matrix_final)

        print("\nVisualizing correlation matrix (heatmap)...")
        plt.figure(figsize=(20, 16))
        sns.heatmap(correlation_matrix_final, annot=True, cmap='coolwarm', fmt=".2f", 
                    linewidths=.5, vmin=-1, vmax=1,
                    cbar_kws={'label': 'Pearson Correlation Coefficient'})
        plt.title('Correlation Matrix of Cropped Climate Variables', fontsize=18)
        plt.xticks(rotation=60, ha='right', fontsize=9)
        plt.yticks(rotation=0, fontsize=9)
        plt.tight_layout()
        plt.show()

        # --- 5. Scatter Plot Matrix with Regression Lines and R-squared ---
        print("\nGenerating scatter plot matrix (this might take time)...")
        
        sample_size_for_pairplot = min(1000, len(df_analysis_cleaned))
        df_for_pairplot = df_analysis_cleaned.sample(n=sample_size_for_pairplot, random_state=42) if len(df_analysis_cleaned) > sample_size_for_pairplot else df_analysis_cleaned
        print(f"  Using {len(df_for_pairplot)} sampled pixels for the pairplot.")

        def annotate_r_squared(x, y, ax=None, **kwargs):
            if not isinstance(x, pd.Series) or not isinstance(y, pd.Series):
                return
            valid_indices = ~x.isnull() & ~y.isnull()
            x_valid = x[valid_indices]
            y_valid = y[valid_indices]
            if len(x_valid) < 2 or len(y_valid) < 2:
                r_squared_text = "R²: N/A"
            else:
                r, _ = pearsonr(x_valid, y_valid)
                r_squared = r**2
                r_squared_text = f'$R^2={r_squared:.2f}$'
            ax = ax or plt.gca()
            ax.text(0.05, 0.9, r_squared_text, transform=ax.transAxes, 
                    fontsize=8, ha='left', va='top', 
                    bbox=dict(boxstyle='round,pad=0.3', fc='wheat', alpha=0.5))

        # MODIFIED HERE: Removed plot_kws from the main pairplot call
        pair_plot = sns.pairplot(df_for_pairplot, 
                                 diag_kind='kde',
                                 diag_kws={'fill': True, 'alpha':0.6}) 
        
        # Apply regression plot to lower triangle and R-squared to upper
        pair_plot.map_lower(sns.regplot, 
                            scatter_kws={'alpha':0.3, 's':10},  # Style for regplot's scatter
                            line_kws={'color':'red', 'linewidth':1}) # Style for regplot's line
        pair_plot.map_upper(annotate_r_squared)
        
        for ax_row in pair_plot.axes:
            for ax_col in ax_row:
                ax_col.tick_params(axis='both', which='major', labelsize=7)
                ax_col.xaxis.label.set_size(8) 
                ax_col.yaxis.label.set_size(8)

        pair_plot.fig.suptitle('Scatter Plot Matrix of Climate Variables (Regression & R²)', y=1.02, fontsize=18)
        plt.tight_layout(rect=[0, 0, 1, 0.98])
        plt.show()
        
        print("\nAnalysis complete.")

## Step 13b: Remove highly correlated bioclim variables (|r| > 0.70) and plot the remaining uncorrelated set

In [None]:
# --- Step 13b: Remove highly correlated bioclim variables and plot the reduced set ---

# --- Bioclim Variable Names Mapping (ensure it's available) ---
bioclim_names = {
    'bio1': 'BIO1 Ann Mean Temp', 'bio2': 'BIO2 Mean Diurnal Range',
    'bio3': 'BIO3 Isothermality', 'bio4': 'BIO4 Temp Seasonality',
    'bio5': 'BIO5 Max Temp Warm Mth', 'bio6': 'BIO6 Min Temp Cold Mth',
    'bio7': 'BIO7 Temp Ann Range', 'bio8': 'BIO8 Mean Temp Wet Qtr',
    'bio9': 'BIO9 Mean Temp Dry Qtr', 'bio10': 'BIO10 Mean Temp Warm Qtr',
    'bio11': 'BIO11 Mean Temp Cold Qtr', 'bio12': 'BIO12 Ann Precip',
    'bio13': 'BIO13 Precip Wet Mth', 'bio14': 'BIO14 Precip Dry Mth',
    'bio15': 'BIO15 Precip Seasonality', 'bio16': 'BIO16 Precip Wet Qtr',
    'bio17': 'BIO17 Precip Dry Qtr', 'bio18': 'BIO18 Precip Warm Qtr',
    'bio19': 'BIO19 Precip Cold Qtr', 'elev': 'Elevation'
}


def find_correlated_variables_to_drop(correlation_matrix, threshold=0.70):
    """Greedily choose variables to drop so that no kept pair has |r| > threshold.

    Pairs are visited once, in upper-triangle order of ``correlation_matrix``
    (a single greedy pass). For each pair whose absolute correlation is strictly
    above ``threshold`` and where neither member has already been dropped, the
    variable with the larger summed absolute correlation to all other variables
    is dropped. On a tie, the variable appearing later in the column order is
    dropped. The summed correlations are computed once, on the full matrix.
    NaN correlations never exceed the threshold.

    Args:
        correlation_matrix: Square correlation DataFrame with identical labels on
            its index and columns (e.g. the output of ``DataFrame.corr``).
        threshold: Absolute-correlation cut-off.

    Returns:
        set: Column labels to remove.
    """
    to_drop = set()
    # Overall "connectedness" of each variable; subtract 1 for the diagonal self-correlation.
    sum_abs_corr = correlation_matrix.abs().sum(axis=1) - 1
    columns = list(correlation_matrix.columns)

    for i, var_i in enumerate(columns):
        for var_j in columns[i + 1:]:
            # A pair involving an already-dropped variable no longer matters.
            if var_i in to_drop or var_j in to_drop:
                continue

            correlation_value = correlation_matrix.loc[var_i, var_j]
            if abs(correlation_value) > threshold:
                print(f"  Found correlation > {threshold}: '{var_i}' and '{var_j}' ({correlation_value:.2f})")

                if sum_abs_corr[var_i] > sum_abs_corr[var_j]:
                    print(f"    Marking '{var_i}' for removal (higher overall abs corr: {sum_abs_corr[var_i]:.2f} vs {sum_abs_corr[var_j]:.2f} for '{var_j}').")
                    to_drop.add(var_i)
                elif sum_abs_corr[var_j] > sum_abs_corr[var_i]:
                    print(f"    Marking '{var_j}' for removal (higher overall abs corr: {sum_abs_corr[var_j]:.2f} vs {sum_abs_corr[var_i]:.2f} for '{var_i}').")
                    to_drop.add(var_j)
                else:
                    # Tie: var_j always comes after var_i in the column order.
                    print(f"    Tie in overall abs corr. Marking '{var_j}' for removal (appears later).")
                    to_drop.add(var_j)
    return to_drop


# Ensure df_analysis_cleaned is defined from the previous step
if 'df_analysis_cleaned' not in globals() or df_analysis_cleaned.empty:
    print("Error: 'df_analysis_cleaned' is not defined or is empty. Please run the previous cell.")
    # Fallback for standalone execution: rebuild it from 'output_cropped_rasters' when available
    # (flatten each layer, keep only layers matching the first layer's shape, drop NaN pixels).
    if 'output_cropped_rasters' in globals() and output_cropped_rasters:
        data_for_analysis = {}
        analysis_keys = sorted([key for key in output_cropped_rasters.keys() if key.startswith('bio') or key == 'elev'])
        reference_shape = None
        if analysis_keys:
            for var_name_short in analysis_keys:
                raster_entry = output_cropped_rasters[var_name_short]
                # Same presence check as Step 13a: a missing 'data' entry is skipped, not a KeyError.
                if 'data' not in raster_entry or raster_entry['data'] is None:
                    continue
                current_data = raster_entry['data']
                if reference_shape is None:
                    reference_shape = current_data.shape
                if current_data.shape == reference_shape:
                    data_for_analysis[bioclim_names.get(var_name_short, var_name_short)] = current_data.flatten()
            df_analysis = pd.DataFrame(data_for_analysis)
            ordered_descriptive_names = [bioclim_names.get(k, k) for k in analysis_keys if bioclim_names.get(k, k) in df_analysis.columns]
            df_analysis = df_analysis[ordered_descriptive_names]
            df_analysis_cleaned = df_analysis.dropna()
            if df_analysis_cleaned.empty:
                print("Recreated df_analysis_cleaned is empty.")
        else:
            print("No keys found to recreate df_analysis_cleaned.")
    else:
        print("Cannot recreate df_analysis_cleaned as output_cropped_rasters is also missing.")


if 'df_analysis_cleaned' in globals() and not df_analysis_cleaned.empty and len(df_analysis_cleaned.columns) >= 2:
    print(f"Starting with {len(df_analysis_cleaned.columns)} variables.")

    # --- 1. Calculate the Full Correlation Matrix ---
    correlation_matrix_full = df_analysis_cleaned.corr(method='pearson')

    # --- 2. Greedily mark variables for removal (|r| > threshold) ---
    threshold = 0.70
    variables_to_remove = find_correlated_variables_to_drop(correlation_matrix_full, threshold)

    print(f"\nVariables marked for removal: {variables_to_remove if variables_to_remove else 'None'}")

    # --- 3. Create DataFrame with Reduced Set of Variables ---
    variables_to_keep = [col for col in df_analysis_cleaned.columns if col not in variables_to_remove]
    df_analysis_reduced = df_analysis_cleaned[variables_to_keep]

    print(f"\nReduced set has {len(df_analysis_reduced.columns)} variables: {df_analysis_reduced.columns.tolist()}")

    if len(df_analysis_reduced.columns) < 2:
        print("  Error: After removing correlated variables, less than 2 variables remain. Cannot create correlation matrix.")
    else:
        # --- 4. Calculate and Display Correlation Matrix for Reduced Set ---
        print("\nCalculating correlation matrix for the reduced set of variables...")
        correlation_matrix_reduced = df_analysis_reduced.corr(method='pearson')

        print("\n--- Correlation Matrix (Reduced Set) ---")
        with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.width', 200):
            print(correlation_matrix_reduced)

        print("\nVisualizing correlation matrix (heatmap) for the reduced set...")
        n_reduced = len(df_analysis_reduced.columns)
        plt.figure(figsize=(max(10, n_reduced * 0.8), max(8, n_reduced * 0.6)))
        sns.heatmap(correlation_matrix_reduced, annot=True, cmap='coolwarm', fmt=".2f",
                    linewidths=.5, vmin=-1, vmax=1,
                    cbar_kws={'label': 'Pearson Correlation Coefficient'})
        plt.title('Correlation Matrix of Reduced Climate Variables (Threshold > 0.70)', fontsize=16)
        plt.xticks(rotation=60, ha='right', fontsize=9)
        plt.yticks(rotation=0, fontsize=9)
        plt.tight_layout()
        plt.show()

        # --- 5. Scatter Plot Matrix for Reduced Set (skipped when many variables remain) ---
        if n_reduced <= 10:  # Limit for a practical pairplot
            print("\nGenerating scatter plot matrix for the reduced set (this might take some time)...")

            sample_size_for_pairplot_reduced = min(1000, len(df_analysis_reduced))
            df_for_pairplot_reduced = df_analysis_reduced.sample(n=sample_size_for_pairplot_reduced, random_state=42) if len(df_analysis_reduced) > sample_size_for_pairplot_reduced else df_analysis_reduced
            print(f"  Using {len(df_for_pairplot_reduced)} sampled pixels for the reduced pairplot.")

            def annotate_r_squared_reduced(x, y, ax=None, **kwargs):
                """Write the R² of x vs y (NaN pairs ignored) in the corner of a PairGrid panel."""
                if not isinstance(x, pd.Series) or not isinstance(y, pd.Series):
                    return
                valid_indices = x.notna() & y.notna()
                x_valid, y_valid = x[valid_indices], y[valid_indices]
                if len(x_valid) < 2:
                    r_squared_text = "R²: N/A"
                else:
                    r, _ = pearsonr(x_valid, y_valid)
                    r_squared_text = f'$R^2={r ** 2:.2f}$'
                ax = ax or plt.gca()
                ax.text(0.05, 0.9, r_squared_text, transform=ax.transAxes, fontsize=8, ha='left', va='top',
                        bbox=dict(boxstyle='round,pad=0.3', fc='wheat', alpha=0.5))

            # Explicit PairGrid: sns.pairplot() would already scatter every off-diagonal
            # panel, so regplot mapped on top of it drew each point twice.
            pair_plot_reduced = sns.PairGrid(df_for_pairplot_reduced, diag_sharey=False)
            pair_plot_reduced.map_diag(sns.kdeplot, fill=True, alpha=0.6)
            pair_plot_reduced.map_lower(sns.regplot, scatter_kws={'alpha': 0.3, 's': 10}, line_kws={'color': 'red', 'linewidth': 1})
            pair_plot_reduced.map_upper(sns.scatterplot)
            pair_plot_reduced.map_upper(annotate_r_squared_reduced)

            for ax_c in pair_plot_reduced.axes.flat:
                ax_c.tick_params(axis='both', which='major', labelsize=7)
                ax_c.xaxis.label.set_size(8)
                ax_c.yaxis.label.set_size(8)

            pair_plot_reduced.fig.suptitle('Scatter Plot Matrix of Reduced Climate Variables (Regression & R²)', y=1.02, fontsize=16)
            plt.tight_layout(rect=[0, 0, 1, 0.98])
            plt.show()
        else:
            print("\nSkipping pairplot for the reduced set as it still has many variables (>10). Heatmap should suffice.")

    print("\nVariable reduction process complete.")
else:
    print("Skipping variable reduction as 'df_analysis_cleaned' is not suitable (e.g. empty or too few columns).")

## Step 13c: Plot reduced BioClim variables

In [None]:
# --- Step 13c: Plot reduced BioClim variables ---
import copy  # colormaps are copied before set_bad() so the shared registered colormap is not mutated


def choose_cmap_name(short_key, desc_name):
    """Return the name of the matplotlib colormap used to draw one climate layer.

    Args:
        short_key: Short raster key such as ``'bio7'`` or ``'elev'`` (may be ``None``).
        desc_name: Descriptive variable name such as ``'BIO7 Temp Ann Range'``.

    Returns:
        ``'terrain'`` for elevation; ``'Blues'`` for precipitation layers (name
        contains "precip" or BIO12-BIO19); ``'coolwarm'`` for temperature layers
        (name contains "temp" or BIO1-BIO11); ``'viridis'`` otherwise.
    """
    if short_key == 'elev':
        return 'terrain'
    # Compare the BIO number numerically: as strings, 'bio2'..'bio9' sort after
    # 'bio11', so a string range check would misclassify them.
    bio_number = None
    if short_key and short_key.startswith('bio') and short_key[3:].isdigit():
        bio_number = int(short_key[3:])
    name_lower = desc_name.lower()
    if 'precip' in name_lower or (bio_number is not None and 12 <= bio_number <= 19):
        return 'Blues'
    if 'temp' in name_lower or (bio_number is not None and 1 <= bio_number <= 11):
        return 'coolwarm'
    return 'viridis'


# Invert bioclim_names (defined in Steps 13a/13b) for lookup from descriptive name to short raster key.
if 'bioclim_names' not in globals():
    print("Warning: 'bioclim_names' not found. Run Step 13b first; no variables can be matched to rasters.")
short_names_map = {v: k for k, v in globals().get('bioclim_names', {}).items()}

# --- Units for the variables expected in the reduced set ---
variable_units = {
    'BIO10 Mean Temp Warm Qtr': '°C', # bio10
    'BIO13 Precip Wet Mth': 'mm',     # bio13
    'BIO15 Precip Seasonality': 'CV', # bio15 (Coefficient of Variation)
    'BIO18 Precip Warm Qtr': 'mm',    # bio18
    'BIO2 Mean Diurnal Range': '°C',  # bio2
    'BIO8 Mean Temp Wet Qtr': '°C',   # bio8
    'BIO9 Mean Temp Dry Qtr': '°C',   # bio9
    'Elevation': 'meters'             # elev
}
# Add any other variables that might be in df_analysis_reduced with their units.
# A variable missing from variable_units is plotted with an empty colorbar label.


# Check if df_analysis_reduced exists (a leftover placeholder without `.empty` counts as missing)
if 'df_analysis_reduced' not in globals() or getattr(df_analysis_reduced, 'empty', True):
    print("Error: 'df_analysis_reduced' (DataFrame with selected variables) not found or is empty.")
    print("Please ensure the variable reduction step has been run successfully.")
    # An empty DataFrame keeps `.columns` and `.empty` working here and in later cells.
    df_analysis_reduced = pd.DataFrame()


if 'output_cropped_rasters' not in globals() or not output_cropped_rasters:
    print("Error: 'output_cropped_rasters' not found or is empty.")
    output_cropped_rasters = {}

if 'north_america_map_unprojected' not in globals():
    print("Warning: 'north_america_map_unprojected' not found. Basemap will not be plotted.")
    north_america_map_unprojected = None

if 'gdf_all_occurrences_thinned_status' not in globals():
    print("Warning: 'gdf_all_occurrences_thinned_status' not found. Occurrence points will not be plotted.")
    gdf_all_occurrences_thinned_status = None


# --- Plotting Selected Variables ---
final_selected_variables_desc = list(df_analysis_reduced.columns)

if not final_selected_variables_desc:
    print("No variables in the reduced set to plot.")
else:
    print(f"\nPlotting {len(final_selected_variables_desc)} selected climate variables over their cropped extent...")

    # Plot limits (left, right, bottom, top) shared by every panel, taken from the first
    # raster that is plotted. This assumes the cropped rasters share one extent.
    plot_extent = None

    for desc_name in final_selected_variables_desc:
        short_key = short_names_map.get(desc_name)
        raster_entry = output_cropped_rasters.get(short_key) if short_key else None

        if raster_entry is None or 'data' not in raster_entry or raster_entry['data'] is None:
            print(f"  Skipping '{desc_name}': Data not found in output_cropped_rasters or is None.")
            continue

        raster_data = raster_entry['data']
        # imshow takes a georeferenced extent rather than an affine transform.
        raster_extent = rasterio.plot.plotting_extent(raster_data, raster_entry['transform'])
        if plot_extent is None:
            plot_extent = raster_extent
        unit = variable_units.get(desc_name, '')

        cmap = copy.copy(plt.get_cmap(choose_cmap_name(short_key, desc_name)))
        cmap.set_bad(color='lightgray')  # NaN / masked cells

        fig, ax = plt.subplots(1, 1, figsize=(10, 7))
        im = ax.imshow(raster_data, cmap=cmap, extent=raster_extent)

        # Basemap as outlines only: a filled layer drawn above the raster would tint
        # the data so its colours no longer match the colorbar.
        if north_america_map_unprojected is not None:
            try:
                north_america_map_unprojected.plot(ax=ax, facecolor='none', edgecolor='gray', linewidth=0.5, zorder=2)
            except Exception as e:
                print(f"    Could not plot/clip basemap for {desc_name}: {e}")

        # Overlay occurrence points
        if gdf_all_occurrences_thinned_status is not None and not gdf_all_occurrences_thinned_status.empty:
            gdf_all_occurrences_thinned_status.plot(ax=ax, marker='.', color='red', markersize=5, zorder=3, alpha=0.6, label='Occurrences')

        # Set limits after the overlays so the continental basemap does not widen the view.
        left, right, bottom, top = plot_extent
        ax.set_xlim(left, right)
        ax.set_ylim(bottom, top)

        ax.set_title(f'{desc_name}', fontsize=14)
        ax.set_xlabel("Longitude", fontsize=10)
        ax.set_ylabel("Latitude", fontsize=10)
        ax.grid(True, linestyle='--', alpha=0.5)

        cbar = fig.colorbar(im, ax=ax, orientation='vertical', fraction=0.046, pad=0.04)
        cbar.set_label(f'{unit}', fontsize=10)

        plt.tight_layout()
        plt.show()

print("\nFinished plotting selected climate variables.")


## Step 14a: Extract climate vars for all occurrence data and visualize

In [None]:
# --- Step 14: Extract climate vars for all occurrence data and visualize ---

# --- Define species_colors (ensure this is done consistently, e.g., from an earlier cell) ---
# This should be defined based on the species present in your main occurrence dataset
# gdf_all_occurrences_thinned_status or a similar definitive list.
if 'species_colors' not in globals():
    print("Warning: 'species_colors' not defined. Attempting to create it.")
    if 'gdf_all_occurrences_thinned_status' in globals() and \
       'simpSciName' in gdf_all_occurrences_thinned_status.columns and \
       not gdf_all_occurrences_thinned_status.empty:
        
        unique_species_for_colors = sorted(list(gdf_all_occurrences_thinned_status['simpSciName'].dropna().unique()))
        if unique_species_for_colors:
            try:
                palette_gen = sns.color_palette("husl", len(unique_species_for_colors))
            except Exception: # Fallback if seaborn fails or husl changes
                cmap_gen = plt.cm.get_cmap('tab10', max(1,len(unique_species_for_colors))) # Ensure at least 1 color
                palette_gen = [cmap_gen(i) for i in range(len(unique_species_for_colors))]
            species_colors = {name: palette_gen[i] for i, name in enumerate(unique_species_for_colors)}
            print(f"Defined 'species_colors' for {len(species_colors)} species.")
        else:
            print("Could not define 'species_colors': No unique species names found in thinned data.")
            species_colors = {}
    else:
        print("Could not define 'species_colors': 'gdf_all_occurrences_thinned_status' or 'simpSciName' column missing/empty.")
        species_colors = {}
# --- End of species_colors definition ---


# --- Placeholder checks for other prerequisites (if running standalone) ---
if 'gdf_all_occurrences_thinned_status' not in globals() or gdf_all_occurrences_thinned_status.empty:
    print("Error: gdf_all_occurrences_thinned_status is not defined or empty.")
    gdf_all_occurrences_thinned_status = gpd.GeoDataFrame(
        {'simpSciName': ['Amorpha test'], 'geometry': [box(-100, 30, -99, 31)]}, crs="EPSG:4326"
    )
if 'output_cropped_rasters' not in globals() or not output_cropped_rasters:
    print("Error: output_cropped_rasters is not defined or empty.")
    output_cropped_rasters = {}
if 'df_analysis_reduced' not in globals() or df_analysis_reduced.empty:
    print("Error: df_analysis_reduced (selected climate variables) is not defined or empty.")
    class PlaceholderReducedDF: columns = []
    df_analysis_reduced = PlaceholderReducedDF()
if 'bioclim_names' not in globals(): 
    bioclim_names = {f'bio{i}': f'BIO{i}' for i in range(1,20)}; bioclim_names['elev'] = 'Elevation'
if 'short_names_map' not in globals(): 
    short_names_map = {v: k for k, v in bioclim_names.items()}
if 'variable_units' not in globals(): 
    variable_units = {name: 'units' for name in bioclim_names.values()}


# --- Part 1: Extract Climate Data for Species Occurrence Points ---
print("\n--- Part 1: Extracting Climate Data for Occurrence Points ---")
gdf_occurrences_with_climate = gdf_all_occurrences_thinned_status.copy()
selected_vars_descriptive = list(df_analysis_reduced.columns)

if not selected_vars_descriptive:
    print("No selected climate variables found in df_analysis_reduced. Skipping extraction and subsequent analyses.")
    if 'gdf_occurrences_with_climate' not in globals(): gdf_occurrences_with_climate = pd.DataFrame()
else:
    if 'geometry' not in gdf_occurrences_with_climate.columns or not isinstance(gdf_occurrences_with_climate.geometry, gpd.GeoSeries):
        print("Error: 'geometry' column missing or invalid in gdf_all_occurrences_thinned_status.")
        selected_vars_descriptive = [] 
    elif gdf_occurrences_with_climate.crs is None or gdf_occurrences_with_climate.crs.to_string().upper() != 'EPSG:4326':
        print(f"Warning: Occurrence CRS is {gdf_occurrences_with_climate.crs}. Attempting reproject to EPSG:4326.")
        try:
            gdf_occurrences_with_climate = gdf_occurrences_with_climate.to_crs("EPSG:4326")
        except Exception as e:
            print(f"Error reprojecting occurrences: {e}. Extraction may fail."); selected_vars_descriptive = []
    
    if selected_vars_descriptive:
        coords = [(p.x, p.y) for p in gdf_occurrences_with_climate.geometry]
        for desc_name in selected_vars_descriptive:
            short_key = short_names_map.get(desc_name)
            if not short_key or short_key not in output_cropped_rasters or output_cropped_rasters[short_key]['data'] is None:
                print(f"  Skipping extraction for '{desc_name}': Data not found."); gdf_occurrences_with_climate[desc_name] = np.nan; continue
            raster_data, raster_transform = output_cropped_rasters[short_key]['data'], output_cropped_rasters[short_key]['transform']
            print(f"  Extracting values for: {desc_name} ({short_key})")
            try:
                raster_data_for_sampling = raster_data[np.newaxis, :, :] if raster_data.ndim == 2 else raster_data
                with rasterio.MemoryFile() as memfile:
                    with memfile.open(driver='GTiff', height=raster_data_for_sampling.shape[1], width=raster_data_for_sampling.shape[2], 
                                      count=raster_data_for_sampling.shape[0], dtype=str(raster_data_for_sampling.dtype), 
                                      crs='EPSG:4326', transform=raster_transform) as dataset:
                        dataset.write(raster_data_for_sampling)
                    with memfile.open() as dataset_to_sample:
                        extracted_values = [val[0] if val is not None and len(val)>0 else np.nan for val in dataset_to_sample.sample(coords, masked=True)]
                gdf_occurrences_with_climate[desc_name] = extracted_values
                print(f"    Successfully extracted. Min: {np.nanmin(extracted_values):.2f}, Max: {np.nanmax(extracted_values):.2f}")
            except Exception as e:
                print(f"    Error extracting values for {desc_name}: {e}"); gdf_occurrences_with_climate[desc_name] = np.nan

        print("\nClimate data extraction complete.")
        valid_climate_cols = [col for col in selected_vars_descriptive if col in gdf_occurrences_with_climate.columns]
        print("Sample of occurrences with extracted climate data:")
        print(gdf_occurrences_with_climate[['simpSciName'] + valid_climate_cols].head())
        
        # --- Part 2a: Overall Climate Niche (Box Plots) ---
        print("\n--- Part 2a: Overall Climate Niche (All Species Together) ---")
        if valid_climate_cols and not gdf_occurrences_with_climate[valid_climate_cols].empty:
            num_vars = len(valid_climate_cols); plots_per_row = 4; num_rows = (num_vars + plots_per_row - 1) // plots_per_row
            if num_vars > 0 and num_rows > 0:
                plt.figure(figsize=(max(15, plots_per_row * 4), num_rows * 4))
                for i, var_name in enumerate(valid_climate_cols):
                    plt.subplot(num_rows, plots_per_row, i + 1)
                    sns.boxplot(y=gdf_occurrences_with_climate[var_name].dropna())
                    unit = variable_units.get(var_name, ''); plt.title(var_name, fontsize=10)
                    plt.ylabel(f"Value ({unit})" if unit else "Value", fontsize=9); plt.xticks([])
                    plt.tick_params(axis='y', labelsize=8)
                plt.suptitle('Overall Climatic Niche for Amorpha spp.', fontsize=16, y=1.02 if num_rows >1 else 1.05)
                plt.tight_layout(rect=[0, 0, 1, 0.98 if num_rows >1 else 0.95]); plt.show()
            else: print("No valid climate variables to plot for overall niche.")
        else: print("No climate data extracted to plot overall niche.")

        # --- Part 2b: Species-Specific Climate Niches (Box Plots) ---
        print("\n--- Part 2b: Species-Specific Climate Niches ---")
        if 'simpSciName' in gdf_occurrences_with_climate.columns and valid_climate_cols and not gdf_occurrences_with_climate[valid_climate_cols].empty:
            unique_species_in_data = sorted(gdf_occurrences_with_climate['simpSciName'].dropna().unique())
            # Ensure palette uses the main species_colors, defaulting if a species isn't in the map
            species_specific_palette = {s: species_colors.get(s, '#808080') for s in unique_species_in_data}

            for var_name in valid_climate_cols:
                if gdf_occurrences_with_climate[var_name].notna().sum() < 1:
                    print(f"Skipping boxplot for {var_name} as it contains all NaN values."); continue
                plt.figure(figsize=(max(10, len(unique_species_in_data) * 0.8), 6))
                sns.boxplot(x='simpSciName', y=var_name, 
                            data=gdf_occurrences_with_climate.dropna(subset=[var_name]), 
                            order=unique_species_in_data, # Use sorted list
                            palette=species_specific_palette) # MODIFIED: Use consistent species_colors
                unit = variable_units.get(var_name, ''); plt.title(f'Climatic Niche for: {var_name}', fontsize=14)
                plt.xlabel('Species', fontsize=10); plt.ylabel(f"Value ({unit})" if unit else "Value", fontsize=10)
                plt.xticks(rotation=45, ha='right', fontsize=8); plt.yticks(fontsize=9)
                plt.tight_layout(); plt.show()
        else: print("Insufficient data for species-specific boxplots.")

        # --- Part 2c: PCA ---
        print("\n--- Part 2c: Principal Component Analysis (PCA) ---")
        pca_data = gdf_occurrences_with_climate[valid_climate_cols].dropna()
        if pca_data.shape[0] < 2 or pca_data.shape[1] < 2:
            print(f"Skipping PCA: Not enough data after NaN removal (Samples: {pca_data.shape[0]}, Features: {pca_data.shape[1]})")
        else:
            print(f"Performing PCA on {pca_data.shape[0]} samples and {pca_data.shape[1]} climate variables.")
            scaler = StandardScaler(); scaled_data = scaler.fit_transform(pca_data)
            pca = PCA(n_components=None); pca_transformed_data = pca.fit_transform(scaled_data)
            explained_variance_ratio = pca.explained_variance_ratio_
            print(f"\nExplained variance by each component: {np.round(explained_variance_ratio, 3)}")
            print(f"Cumulative explained variance: {np.round(np.cumsum(explained_variance_ratio), 3)}")
            print(f"Variance explained by PC1: {explained_variance_ratio[0]*100:.2f}%")
            if len(explained_variance_ratio) > 1:
                print(f"Variance explained by PC2: {explained_variance_ratio[1]*100:.2f}%")
                print(f"Variance explained by PC1 & PC2: {(explained_variance_ratio[0] + explained_variance_ratio[1])*100:.2f}%")
            else: print("Only one principal component was computed.")

            if pca_transformed_data.shape[1] >= 2:
                pca_df = pd.DataFrame(data=pca_transformed_data[:, :2], columns=['PC1', 'PC2'])
                pca_df['simpSciName'] = gdf_occurrences_with_climate.loc[pca_data.index, 'simpSciName'].values
                
                unique_species_in_pca = sorted(pca_df['simpSciName'].unique())
                pca_palette = {s: species_colors.get(s, '#808080') for s in unique_species_in_pca}

                plt.figure(figsize=(12, 10))
                sns.scatterplot(x='PC1', y='PC2', hue='simpSciName', data=pca_df, 
                                palette=pca_palette, # MODIFIED: Use consistent species_colors
                                hue_order=unique_species_in_pca, # Consistent legend order
                                s=20, alpha=0.7)
                loadings = pca.components_[:2, :].T
                pc1_max_abs, pc2_max_abs = 0,0
                if not pca_df.empty:
                    if 'PC1' in pca_df.columns and pca_df['PC1'].notna().any(): pc1_max_abs = np.abs(pca_df['PC1']).max()
                    if 'PC2' in pca_df.columns and pca_df['PC2'].notna().any(): pc2_max_abs = np.abs(pca_df['PC2']).max()
                arrow_scale_base = max(pc1_max_abs, pc2_max_abs) if (pc1_max_abs > 0 or pc2_max_abs > 0) else 1.0
                arrow_scale = np.sqrt(arrow_scale_base) * 1.2
                if pd.isna(arrow_scale) or np.isinf(arrow_scale) : arrow_scale = 1.5
                for i, var_name in enumerate(pca_data.columns):
                    plt.arrow(0,0,loadings[i,0]*arrow_scale,loadings[i,1]*arrow_scale,color='black',alpha=0.7,head_width=0.08,head_length=0.1,overhang=0.3)
                    plt.text(loadings[i,0]*arrow_scale*1.15,loadings[i,1]*arrow_scale*1.15,var_name,color='dimgray',ha='center',va='center',fontsize=9)
                plt.xlabel(f'PC1 ({explained_variance_ratio[0]*100:.1f}%)',fontsize=12); plt.ylabel(f'PC2 ({explained_variance_ratio[1]*100:.1f}%)',fontsize=12)
                plt.title('PCA of Climate Variables for Amorpha Occurrences',fontsize=16); plt.axhline(0,color='grey',ls='--',lw=0.7,alpha=0.5); plt.axvline(0,color='grey',ls='--',lw=0.7,alpha=0.5)
                plt.grid(True,ls=':',alpha=0.4); plt.legend(title='Species',bbox_to_anchor=(1.05,1),loc='upper left',fontsize=9,title_fontsize=10)
                plt.tight_layout(rect=[0,0,0.85,1]); plt.show()
                print("\n--- Loadings for Principal Components ---")
                loadings_df = pd.DataFrame(pca.components_[:2,:].T,columns=['PC1','PC2'],index=pca_data.columns)
                print(loadings_df.sort_values(by='PC1',ascending=False,key=abs))
                print("\nVariables contributing most to PC1 (absolute value):"); print(loadings_df['PC1'].abs().sort_values(ascending=False).head())
                print("\nVariables contributing most to PC2 (absolute value):"); print(loadings_df['PC2'].abs().sort_values(ascending=False).head())
            elif pca_transformed_data.shape[1] == 1:
                print("Only one principal component available. Biplot cannot be generated."); pca_df = pd.DataFrame(data=pca_transformed_data[:,:1],columns=['PC1'])
                pca_df['simpSciName'] = gdf_occurrences_with_climate.loc[pca_data.index,'simpSciName'].values
                unique_species_in_pca_pc1 = sorted(pca_df['simpSciName'].unique())
                pca_palette_pc1 = {s: species_colors.get(s, '#808080') for s in unique_species_in_pca_pc1}
                plt.figure(figsize=(max(10,len(unique_species_in_pca_pc1)*0.6),6))
                sns.stripplot(x='simpSciName',y='PC1',data=pca_df,jitter=True,palette=pca_palette_pc1, order=unique_species_in_pca_pc1)
                plt.title('PC1 Scores by Species'); plt.ylabel(f'PC1 ({explained_variance_ratio[0]*100:.1f}%)'); plt.xticks(rotation=45,ha='right'); plt.tight_layout(); plt.show()
                print("\n--- Loadings for Principal Component 1 ---"); loadings_df = pd.DataFrame(pca.components_[:1,:].T,columns=['PC1'],index=pca_data.columns)
                print(loadings_df.sort_values(by='PC1',ascending=False,key=abs))
            else: print("No principal components with significant variance found for biplot.")
# End of the main 'else' block above. This line is at top level, so the
# completion message prints whichever branch of the analysis ran.
print("\n--- All analyses complete ---")

## Step 14b: Additional PCA visualizations

In [None]:
# --- Step 14b: Additional PCA visualizations ---

# --- Prerequisite checks ---
# The plots below reuse results from the PCA cell (pca_df, pca_data, pca,
# loadings, explained_variance_ratio, loadings_df). When this cell runs
# without them, each missing object is replaced independently by a clearly
# labelled placeholder so the plotting code still runs. Placeholder values are
# random and carry no scientific meaning. Objects that already exist are never
# overwritten.

# Seeded generator so placeholder figures are reproducible between runs.
_placeholder_rng = np.random.default_rng(0)
_N_PLACEHOLDER_POINTS = 100

# Placeholder for pca_df (PC1/PC2 scores plus a species label per point).
if 'pca_df' not in globals():
    print("Warning: 'pca_df' not found. Creating a placeholder for visualization structure.")
    # Draw labels from the observed species when available, otherwise from
    # generic names. Always draw exactly _N_PLACEHOLDER_POINTS labels so every
    # column of the DataFrame has the same length.
    _label_pool = ['Species A', 'Species B']
    if 'gdf_occurrences_with_climate' in globals() and 'simpSciName' in gdf_occurrences_with_climate.columns:
        _observed_species = gdf_occurrences_with_climate['simpSciName'].dropna().unique()
        if len(_observed_species) > 0:
            _label_pool = list(_observed_species)
    pca_df = pd.DataFrame({
        'PC1': _placeholder_rng.standard_normal(_N_PLACEHOLDER_POINTS),
        'PC2': _placeholder_rng.standard_normal(_N_PLACEHOLDER_POINTS),
        'simpSciName': _placeholder_rng.choice(_label_pool, size=_N_PLACEHOLDER_POINTS, replace=True),
    })

# Placeholder for pca_data. In this cell only its column names are used
# (to count variables and label the loading arrows / heatmap rows).
if 'pca_data' not in globals():
    print("Warning: 'pca_data' not found. Creating placeholder.")
    if 'selected_vars_descriptive' in globals() and selected_vars_descriptive:
        # Take a subset (at least two names) of the descriptive variable names.
        pca_data_cols = selected_vars_descriptive[:max(2, len(selected_vars_descriptive) // 2)]
    else:
        pca_data_cols = ['BIO1 Ann Mean Temp', 'BIO12 Ann Precip', 'Elevation']
    pca_data = pd.DataFrame(
        _placeholder_rng.random((_N_PLACEHOLDER_POINTS, len(pca_data_cols))),
        columns=pca_data_cols,
    )

_num_vars = len(pca_data.columns)

# Placeholder for the fitted PCA object. The plots below only read
# `explained_variance_ratio_`. This check is separate from the `loadings`
# check so a real PCA fit is never replaced just because `loadings` is missing.
if 'pca' not in globals():
    print("Warning: 'pca' object not found. Creating placeholder.")

    class PlaceholderPCA:
        """Minimal stand-in for a fitted sklearn PCA exposing only `explained_variance_ratio_`."""

        def __init__(self, explained_variance_ratio):
            self.explained_variance_ratio_ = np.asarray(explained_variance_ratio, dtype=float)

    # 0.4, 0.3, then 0.1 per additional variable. Keep at least two entries so
    # the PC1/PC2 axis labels can always be built, and renormalise when the
    # total would exceed 1.
    _ratios = np.array([0.4, 0.3] + [0.1] * max(_num_vars - 2, 0))
    if _ratios.sum() > 1:
        _ratios = _ratios / _ratios.sum()
    pca = PlaceholderPCA(_ratios)

# Loadings array: one row per variable, columns = PC1, PC2.
if 'loadings' not in globals():
    _components = getattr(pca, 'components_', None)
    if (_components is not None and np.ndim(_components) == 2
            and np.shape(_components)[0] >= 2 and np.shape(_components)[1] == _num_vars):
        # A real PCA fit is available: derive loadings as the PCA cell does
        # (first two components, transposed to variables x PCs).
        print("Warning: 'loadings' array not found. Deriving it from pca.components_.")
        loadings = np.asarray(_components)[:2, :].T
    else:
        print("Warning: 'loadings' array not found. Creating placeholder.")
        loadings = _placeholder_rng.random((_num_vars, 2))

if 'explained_variance_ratio' not in globals():
    explained_variance_ratio = pca.explained_variance_ratio_

# Loadings table (variables x PCs) used by the heatmap.
if 'loadings_df' not in globals():
    print("Warning: 'loadings_df' not found. Creating placeholder.")
    _n_pcs = min(loadings.shape[1], 2)
    # Use the variable names as row labels when the counts agree. Otherwise
    # fall back to generic names instead of failing on a length mismatch.
    if loadings.shape[0] == _num_vars:
        _row_labels = pca_data.columns
    else:
        _row_labels = [f'Var{k + 1}' for k in range(loadings.shape[0])]
    loadings_df = pd.DataFrame(
        loadings[:, :_n_pcs],
        columns=['PC1', 'PC2'][:_n_pcs],
        index=_row_labels,
    )
# --- End Prerequisite Checks ---


# --- Visualization 1: 2D Density Plot (KDE) of PC Scores ---
# Filled 2D kernel density of all occurrence points in PC1/PC2 space. When a
# loadings array matching pca_data's columns is available, each variable is
# also drawn as a red arrow from the origin.
print("\n--- PCA Viz 1: 2D Density Plot (KDE) ---")
_kde_ready = (
    'pca_df' in globals()
    and not pca_df.empty
    and 'PC1' in pca_df.columns
    and 'PC2' in pca_df.columns
    and len(pca_df) > 2  # KDE needs a few points
)

if not _kde_ready:
    print("Skipping 2D Density Plot: pca_df or required columns/data missing or insufficient.")
else:
    plt.figure(figsize=(10, 8))
    sns.kdeplot(
        x='PC1', y='PC2', data=pca_df,
        fill=True, cmap="viridis", thresh=0.05, alpha=0.7, levels=10,
    )

    _loadings_match = (
        'loadings' in globals()
        and 'pca_data' in globals()
        and loadings.shape[0] == len(pca_data.columns)
        and loadings.shape[1] >= 2
    )
    if _loadings_match:
        # Arrow length = loading * sqrt(largest |PC score|) * 1.2, so the
        # arrows scale with the spread of the scores.
        if not pca_df.empty and pca_df['PC1'].notna().any() and pca_df['PC2'].notna().any():
            arrow_scale_base = max(np.abs(pca_df['PC1']).max(), np.abs(pca_df['PC2']).max())
        else:
            arrow_scale_base = 1.0

        if arrow_scale_base > 0:
            arrow_scale = np.sqrt(arrow_scale_base) * 1.2
        else:
            arrow_scale = 1.2

        if not np.isfinite(arrow_scale):
            arrow_scale = 1.5  # fallback for NaN/inf scales

        for i, var_name in enumerate(pca_data.columns):
            tip_x = loadings[i, 0] * arrow_scale
            tip_y = loadings[i, 1] * arrow_scale
            plt.arrow(0, 0, tip_x, tip_y, color='red', alpha=0.8,
                      head_width=0.08, head_length=0.1, overhang=0.3)
            # Variable label placed 15% beyond the arrow tip.
            plt.text(tip_x * 1.15, tip_y * 1.15, var_name, color='black',
                     ha='center', va='center', fontsize=9, weight='bold')

    plt.xlabel(f'PC1 ({explained_variance_ratio[0]*100:.1f}%)', fontsize=12)
    plt.ylabel(f'PC2 ({explained_variance_ratio[1]*100:.1f}%)', fontsize=12)
    plt.title('2D Density of Occurrence Points in PC Space', fontsize=16)
    plt.axhline(0, color='grey', linestyle='--', linewidth=0.7, alpha=0.5)
    plt.axvline(0, color='grey', linestyle='--', linewidth=0.7, alpha=0.5)
    plt.grid(True, linestyle=':', alpha=0.4)
    plt.tight_layout()
    plt.show()


# --- Visualization 2: Heatmap of PC Loadings ---
# Annotated heatmap of each variable's loading on PC1 and PC2, using a
# diverging colormap centred at 0 so sign and magnitude are easy to compare.
print("\n--- PCA Viz 2: Heatmap of Loadings ---")
if 'loadings_df' in globals() and not loadings_df.empty and 'PC1' in loadings_df.columns and 'PC2' in loadings_df.columns:
    # Figure height grows with the number of variables (rows).
    plt.figure(figsize=(8, max(6, len(loadings_df) * 0.45)))
    sns.heatmap(loadings_df[['PC1', 'PC2']], annot=True, cmap="coolwarm", center=0, fmt=".2f",
                linewidths=.5, cbar_kws={'label': 'Loading Value'})
    plt.title('Heatmap of PCA Loadings for PC1 & PC2', fontsize=14)
    plt.ylabel('Original Climate Variables', fontsize=10)
    plt.xlabel('Principal Components', fontsize=10)
    plt.xticks(rotation=0)
    plt.yticks(rotation=0, fontsize=9)
    plt.tight_layout()
    plt.show()
else:
    print("Skipping Loadings Heatmap: loadings_df or required PCs not available.")


# --- Visualization 3: Scree Plot (Explained Variance) ---
# Bars show the variance explained by each component. The dashed step line
# shows the cumulative total.
print("\n--- PCA Viz 3: Scree Plot ---")
if 'pca' in globals() and hasattr(pca, 'explained_variance_ratio_') and len(pca.explained_variance_ratio_) > 0:
    explained_variance = pca.explained_variance_ratio_
    cumulative_explained_variance = np.cumsum(explained_variance)
    num_components = len(explained_variance)
    component_positions = range(1, num_components + 1)  # 1-based component index

    plt.figure(figsize=(10, 6))
    plt.bar(component_positions, explained_variance, alpha=0.7, align='center',
            label='Individual explained variance', color='skyblue')
    plt.step(component_positions, cumulative_explained_variance, where='mid',
             label='Cumulative explained variance', color='red', linestyle='--')
    plt.ylabel('Explained Variance Ratio', fontsize=12)
    plt.xlabel('Principal Component Index', fontsize=12)
    plt.title('Scree Plot - Explained Variance by Principal Component', fontsize=14)
    plt.xticks(ticks=component_positions)
    plt.yticks(np.arange(0, 1.1, 0.1))
    plt.legend(loc='best')
    plt.grid(True, linestyle=':', alpha=0.6)

    # Annotate the cumulative curve at the first five components, at the
    # component where it first exceeds 90%, and at the last component.
    for i, cum_var in enumerate(cumulative_explained_variance):
        first_above_90 = i > 0 and cum_var > 0.9 and cumulative_explained_variance[i - 1] <= 0.9
        if i < 5 or first_above_90 or i == num_components - 1:
            plt.text(i + 1.1, cum_var + 0.02, f"{cum_var*100:.1f}%", fontsize=8, color='red')

    # Lay out only after every artist, including the annotations, is added.
    plt.tight_layout()
    plt.show()
else:
    print("Skipping Scree Plot: PCA object or explained_variance_ratio_ not available or empty.")

print("\n--- All PCA visualizations attempted ---")


## Step 14c: PCA by species

In [None]:
# --- Step 14c: PCA by species ---

# Fallbacks so this cell can run on its own. Each object normally produced by
# the PCA cell is replaced by an empty or neutral stand-in only when missing.
if 'pca_df' not in globals() or pca_df.empty:
    print("Error: 'pca_df' not found or empty. Cannot generate species-specific KDEs.")
    pca_df = pd.DataFrame({'PC1': [], 'PC2': [], 'simpSciName': []})

if 'loadings' not in globals():
    # All-zero loadings: the arrows collapse onto the origin.
    loadings = np.array([[0, 0], [0, 0]])

if 'pca_data' not in globals():
    # Only the column names are used below (arrow check and labels).
    pca_data = pd.DataFrame(columns=['Var1', 'Var2'])

if 'explained_variance_ratio' not in globals():
    # Neutral 50/50 split; only read for the PC axis labels.
    explained_variance_ratio = np.array([0.5, 0.5])

if 'species_colors' not in globals():
    # Empty map: every species falls back to the default seaborn colour.
    species_colors = {}

# --- Static Individual 2D KDE Plots per Species with FIXED AXES ---
# One figure per species: a filled 2D KDE of that species' PC1/PC2 scores, the
# raw points, and the variable loading arrows. Every figure uses the same axis
# limits and the same arrow scale, so species can be compared directly.
print("\n--- PCA Viz: Individual 2D Density Plots (KDE) per Species (Fixed Axes) ---")

_species_kde_ready = (
    'pca_df' in globals()
    and not pca_df.empty
    and 'PC1' in pca_df.columns
    and 'PC2' in pca_df.columns
    and 'simpSciName' in pca_df.columns
    and len(pca_df) > 2
)

if not _species_kde_ready:
    print("Skipping Individual Species KDE Plots: pca_df or required columns/data missing or insufficient.")
else:
    unique_species = sorted(pca_df['simpSciName'].dropna().unique())

    if not unique_species:
        print("No unique species found in pca_df to plot.")
    else:
        # --- Shared axis limits: overall PC score range plus a 5% margin ---
        pc1_min_overall, pc1_max_overall = pca_df['PC1'].min(), pca_df['PC1'].max()
        pc2_min_overall, pc2_max_overall = pca_df['PC2'].min(), pca_df['PC2'].max()
        pc1_range = pc1_max_overall - pc1_min_overall
        pc2_range = pc2_max_overall - pc2_min_overall

        # A (near-)zero range, e.g. all points at the same score, gets a fixed 0.5 margin.
        if pc1_range > 1e-6:
            buffer_pc1 = pc1_range * 0.05
        else:
            buffer_pc1 = 0.5
        if pc2_range > 1e-6:
            buffer_pc2 = pc2_range * 0.05
        else:
            buffer_pc2 = 0.5

        fixed_xlim = (pc1_min_overall - buffer_pc1, pc1_max_overall + buffer_pc1)
        fixed_ylim = (pc2_min_overall - buffer_pc2, pc2_max_overall + buffer_pc2)
        print(f"  Using fixed axes for all species plots: Xlim={fixed_xlim}, Ylim={fixed_ylim}")

        # --- Shared arrow scale: sqrt of the largest absolute axis limit ---
        overall_arrow_scale_base = max(abs(limit) for limit in (*fixed_xlim, *fixed_ylim))
        if overall_arrow_scale_base > 0:
            overall_arrow_scale = np.sqrt(overall_arrow_scale_base)
        else:
            overall_arrow_scale = 1.0
        if not np.isfinite(overall_arrow_scale) or overall_arrow_scale == 0:
            overall_arrow_scale = 1.0

        for species_name in unique_species:
            species_data_pca = pca_df[pca_df['simpSciName'] == species_name]

            # Species with fewer than three points are skipped.
            if len(species_data_pca) < 3:
                print(f"  Skipping KDE for '{species_name}': Insufficient data points ({len(species_data_pca)}).")
                continue

            plt.figure(figsize=(10, 8))
            # Unmapped species fall back to the first colour of the current seaborn palette.
            species_color = species_colors.get(species_name, sns.color_palette()[0])

            # Density surface plus the individual points for this species.
            sns.kdeplot(
                x='PC1', y='PC2', data=species_data_pca,
                fill=True, color=species_color, alpha=0.6,
                thresh=0.05, levels=8,
            )
            plt.scatter(species_data_pca['PC1'], species_data_pca['PC2'],
                        color=species_color, s=10, alpha=0.3)

            # Loading arrows use the shared scale so lengths are comparable across figures.
            arrows_ok = (
                'loadings' in globals()
                and 'pca_data' in globals()
                and loadings.shape[0] == len(pca_data.columns)
                and loadings.shape[1] >= 2
            )
            if arrows_ok:
                for i, var_name in enumerate(pca_data.columns):
                    tip_x = loadings[i, 0] * overall_arrow_scale
                    tip_y = loadings[i, 1] * overall_arrow_scale
                    plt.arrow(0, 0, tip_x, tip_y, color='black', alpha=0.7,
                              head_width=0.05, head_length=0.08, overhang=0.3, zorder=5)
                    plt.text(tip_x * 1.15, tip_y * 1.15, var_name, color='dimgray',
                             ha='center', va='center', fontsize=8, zorder=5,
                             path_effects=[withStroke(linewidth=0.5, foreground='white')])

            # Identical limits for every species figure.
            plt.xlim(fixed_xlim)
            plt.ylim(fixed_ylim)

            plt.xlabel(f'PC1 ({explained_variance_ratio[0]*100:.1f}%)', fontsize=12)
            plt.ylabel(f'PC2 ({explained_variance_ratio[1]*100:.1f}%)', fontsize=12)
            plt.title(f'2D Density (KDE) in PC Space for: {species_name}', fontsize=14)
            plt.axhline(0, color='grey', linestyle='--', linewidth=0.7, alpha=0.5)
            plt.axvline(0, color='grey', linestyle='--', linewidth=0.7, alpha=0.5)
            plt.grid(True, linestyle=':', alpha=0.4)
            plt.tight_layout()
            plt.show()
