# Imports

In [None]:
# Standard libraries
import os
import joblib

# Data handling
import pandas as pd
import numpy as np

# Machine learning
from xgboost import XGBRegressor
from sklearn.metrics import mean_squared_error, r2_score

# Geospatial
import geopandas as gpd
from shapely.geometry import Polygon
from shapely import wkt
import contextily as ctx  # for background basemap tiles

# Visualization
import matplotlib.pyplot as plt
import folium
import branca.colormap as cm
from IPython.display import HTML

# Explainable AI
import shap

# Data preprocessing

### Prepare earthquake data

In [None]:
# Column layout of the whitespace-delimited SCEC yearly catalogue files, in file order
col_names = ['Date', 'Time', 'Event Type', 'GT', 'Magnitude', 'Magnitude Type', 'Latitude', 'Longitude', 'Depth', 'Quality', 'Event ID', 'Number of Picked Phases', 'Ngrams']


def _read_catalog_year(yr):
    """Read one yearly SCEC catalogue, adding a parsed 'Datetime' and the source 'Year'."""
    frame = pd.read_csv(f'data/raw/SCEC_DC/{yr}.catalog', sep=r'\s+', comment='#', header=None, names=col_names)
    # Unparseable Date/Time combinations become NaT rather than raising
    frame['Datetime'] = pd.to_datetime(frame['Date'] + ' ' + frame['Time'], errors='coerce')
    frame['Year'] = yr
    return frame


# Only 1977-2018 are loaded (range end is exclusive) so the years line up with the wells data
all_eq = [_read_catalog_year(yr) for yr in range(1977, 2019)]

eqs = pd.concat(all_eq, ignore_index=True)
eqs.to_csv('data/processed/earthquakes.csv', index=False)

In [None]:
# Reload the processed catalogue so later cells do not need to re-parse the raw files
eqs = pd.read_csv(filepath_or_buffer='data/processed/earthquakes.csv')
eqs.head(n=5)

### Prepare well injections data

In [None]:
# API numbers of every well, one per line; each names a per-well injection workbook
api_filepath = 'data/raw/dataAllFields/WellsAPInumber.dat'

with open(api_filepath, 'r') as api_file:
    wells_api = [entry.strip() for entry in api_file]


def _load_injection_workbook(api_number):
    """Return one well's injection rows, each prefixed with the well's metadata row."""
    path = f'data/raw/dataAllFields/Well_Injection_API_{api_number}.xlsx'
    header_row = pd.read_excel(path, nrows=1)   # single metadata row at the top of the sheet
    records = pd.read_excel(path, skiprows=3)   # injection table further down the sheet

    # Spreadsheet padding shows up as 'Unnamed: N' columns - discard them
    unnamed = records.columns[records.columns.str.contains('Unnamed')]
    records = records.drop(columns=unnamed)

    # A well with no injection rows is kept as one all-NA row so its metadata still appears
    if records.empty:
        records = pd.DataFrame({name: pd.Series([pd.NA], dtype='object') for name in records.columns})

    # Repeat the metadata once per injection row and place both side by side
    repeated_meta = pd.concat([header_row] * len(records), ignore_index=True)
    return pd.concat([repeated_meta, records], axis=1)


all_injs = [_load_injection_workbook(api) for api in wells_api]
injs = pd.concat(all_injs, ignore_index=True)

# Let pandas pick the best (nullable) dtype for every column before saving
injs = injs.convert_dtypes()

injs.to_csv('data/processed/well_injections.csv', index=False)

In [None]:
# Reload the combined injection table (low_memory=False: infer dtypes from the whole file)
inj = pd.read_csv(filepath_or_buffer='data/processed/well_injections.csv', low_memory=False)
inj.head(n=5)

### Prepare well production data

In [None]:
# API numbers of every well, one per line; each names a per-well production workbook
api_filepath = 'data/raw/dataAllFields/WellsAPInumber.dat'

with open(api_filepath, 'r') as api_file:
    wells_api = [entry.strip() for entry in api_file]


def _load_production_workbook(api_number):
    """Return one well's production rows, each prefixed with the well's metadata row."""
    path = f'data/raw/dataAllFields/Well_Production_API_{api_number}.xlsx'
    header_row = pd.read_excel(path, nrows=1)   # single metadata row at the top of the sheet
    records = pd.read_excel(path, skiprows=3)   # production table further down the sheet

    # Spreadsheet padding shows up as 'Unnamed: N' columns - discard them
    unnamed = records.columns[records.columns.str.contains('Unnamed')]
    records = records.drop(columns=unnamed)

    # A well with no production rows is kept as one all-NA row so its metadata still appears
    if records.empty:
        records = pd.DataFrame({name: pd.Series([pd.NA], dtype='object') for name in records.columns})

    # Repeat the metadata once per production row and place both side by side
    repeated_meta = pd.concat([header_row] * len(records), ignore_index=True)
    return pd.concat([repeated_meta, records], axis=1)


all_prod = [_load_production_workbook(api) for api in wells_api]
prod = pd.concat(all_prod, ignore_index=True)

# Let pandas pick the best (nullable) dtype for every column before saving
prod = prod.convert_dtypes()

prod.to_csv('data/processed/well_productions.csv', index=False)

In [None]:
# Reload the combined production table (low_memory=False: infer dtypes from the whole file)
prod = pd.read_csv(filepath_or_buffer='data/processed/well_productions.csv', low_memory=False)
prod.head(n=5)

# Plot distribution of well locations with their corresponding:   
- Total water/steam injected  
- Total gas/air injected  
- Total water produced  
- Total gas produced  
- Total oil produced  

## Total water/steam injected

### Prepare total injections data

In [None]:
inj.keys()  # column labels of the injection table (DataFrame.keys() returns the column Index)

In [None]:
# Keep only annual-total rows: their 'Injection Date' cell contains "Total" (any case)
is_total_row = inj['Injection Date'].str.contains('Total', case=False, na=False)
inj_total = inj.loc[is_total_row].copy()

# Lifetime totals per well (sum of its annual totals), broadcast onto each of its rows
per_well = inj_total.groupby('Well #')
total_specs = {
    'Total Water or Steam Injected (bbl)': 'Water or Steam Injected (bbl)',
    'Total Gas or Air Injected (Mcf)': 'Gas or Air Injected (Mcf)',
}
for total_col, source_col in total_specs.items():
    inj_total[total_col] = per_well[source_col].transform('sum')

# One row per well; the first row already carries the broadcast totals
inj_totals = inj_total.drop_duplicates(subset='Well #', keep='first').copy()

# A zero latitude or longitude marks an invalid location - drop those wells
has_coords = inj_totals['Latitude'].ne(0) & inj_totals['Longitude'].ne(0)
inj_totals = inj_totals[has_coords]

inj_totals.head()

In [None]:
inj_totals.keys()  # column labels of the per-well injection totals

In [None]:
# Colour each injection well by log1p(total water/steam injected) on a 10-step RdBu scale
value_col = 'Total Water or Steam Injected (bbl)'
min_val = np.log1p(inj_totals[value_col].min())
max_val = np.log1p(inj_totals[value_col].max())

colormap = cm.linear.RdBu_10.scale(min_val, max_val).to_step(10)
colormap.caption = 'Logarithmic Total Water/Steam Injected (bbl)'

# Base map: OpenStreetMap tiles centred at (34.05, -118.25)
wells_water_inj_map = folium.Map(location=[34.05, -118.25], zoom_start=10, tiles='OpenStreetMap')

# log1p compresses the wide range of volumes before the colour lookup
log_volumes = np.log1p(inj_totals[value_col])
for lat, lon, log_vol in zip(inj_totals['Latitude'], inj_totals['Longitude'], log_volumes):
    shade = colormap(log_vol)
    folium.CircleMarker(
        location=[lat, lon],
        radius=5,
        color=shade,
        fill=True,
        fill_color=shade,
        fill_opacity=1,
        opacity=0,  # no outline; only the filled disc is drawn
    ).add_to(wells_water_inj_map)

colormap.add_to(wells_water_inj_map)  # legend
wells_water_inj_map.save("vizs/distr_maps/wells_water_inj_map.html")

## Total gas/air injected

In [None]:
# Colour each injection well by log1p(total gas/air injected) on a 10-step RdBu scale
value_col = 'Total Gas or Air Injected (Mcf)'
min_val = np.log1p(inj_totals[value_col].min())
max_val = np.log1p(inj_totals[value_col].max())

colormap = cm.linear.RdBu_10.scale(min_val, max_val).to_step(10)
colormap.caption = 'Logarithmic Total Gas/Air Injected (Mcf)'

# Base map: OpenStreetMap tiles centred at (34.05, -118.25)
wells_gas_inj_map = folium.Map(location=[34.05, -118.25], zoom_start=10, tiles='OpenStreetMap')

log_volumes = np.log1p(inj_totals[value_col])
for lat, lon, log_vol in zip(inj_totals['Latitude'], inj_totals['Longitude'], log_volumes):
    shade = colormap(log_vol)
    folium.CircleMarker(
        location=[lat, lon],
        radius=5,
        color=shade,
        fill=True,
        fill_color=shade,
        fill_opacity=1,
        opacity=0,  # no outline; only the filled disc is drawn
    ).add_to(wells_gas_inj_map)

colormap.add_to(wells_gas_inj_map)  # legend
wells_gas_inj_map.save("vizs/distr_maps/wells_gas_inj_map.html")

#### Side-by-side view of both injection maps

In [None]:
# Inline each folium map's HTML and show the two injection maps next to each other
water_html, gas_html = [fmap._repr_html_() for fmap in (wells_water_inj_map, wells_gas_inj_map)]

side_by_side = f"""
<div style="display: flex; justify-content: space-between;">

  <div style="width: 49.5%;">
    <h3 style="text-align: center;">Distribution of Wells by Total Water/Steam Injected</h3>
    {water_html}
  </div>

  <div style="width: 49.5%;">
    <h3 style="text-align: center;">Distribution of Wells by Total Gas/Air Injected</h3>
    {gas_html}
  </div>

</div>
"""
HTML(side_by_side)

In [None]:
# Write a wrapper page that embeds the two saved injection maps through relative <iframe>s.
# It only renders when opened from the same folder as wells_water_inj_map.html and
# wells_gas_inj_map.html, so keep all three files together.
combined_inj_page = """
    <html>
    <head><title>Well Injection Maps</title></head>
    <body>

    <div style="display: flex; justify-content: space-between;">

      <div style="width: 49.5%;">
        <h3 style="text-align: center;">Distribution of Wells by Total Water/Steam Injected</h3>
        <iframe src="wells_water_inj_map.html" width="100%" height="500" style="border:none;"></iframe>
      </div>

      <div style="width: 49.5%;">
        <h3 style="text-align: center;">Distribution of Wells by Total Gas/Air Injected</h3>
        <iframe src="wells_gas_inj_map.html" width="100%" height="500" style="border:none;"></iframe>
      </div>

    </div>

    </body>
    </html>
    """

with open("vizs/distr_maps/combined_well_inj_maps.html", "w") as out_file:
    out_file.write(combined_inj_page)

## Total water produced

### Prepare total production data

In [None]:
prod.keys()  # column labels of the production table (DataFrame.keys() returns the column Index)

In [None]:
# Keep only annual-total rows: their 'Production Date' cell contains "Total" (any case)
is_total_row = prod['Production Date'].str.contains('Total', case=False, na=False)
prod_total = prod.loc[is_total_row].copy()

# Lifetime totals per well (sum of its annual totals), broadcast onto each of its rows
per_well = prod_total.groupby('Well #')
total_specs = {
    'Total Water Produced (bbl)': 'Water Produced (bbl)',
    'Total Gas Produced (Mcf)': 'Gas Produced (Mcf)',
    'Total Oil Produced (bbl)': 'Oil Produced (bbl)',
}
for total_col, source_col in total_specs.items():
    prod_total[total_col] = per_well[source_col].transform('sum')

# One row per well; the first row already carries the broadcast totals
prod_totals = prod_total.drop_duplicates(subset='Well #', keep='first').copy()

# A zero latitude or longitude marks an invalid location - drop those wells
has_coords = prod_totals['Latitude'].ne(0) & prod_totals['Longitude'].ne(0)
prod_totals = prod_totals[has_coords]

prod_totals.head()

In [None]:
# Colour each production well by log1p(total water produced) on a 10-step RdBu scale
value_col = 'Total Water Produced (bbl)'
min_val = np.log1p(prod_totals[value_col].min())
max_val = np.log1p(prod_totals[value_col].max())

colormap = cm.linear.RdBu_10.scale(min_val, max_val).to_step(10)
colormap.caption = 'Logarithmic Total Water Produced (bbl)'

# Base map: OpenStreetMap tiles centred at (34.05, -118.25)
wells_water_prod_map = folium.Map(location=[34.05, -118.25], zoom_start=10, tiles='OpenStreetMap')

# log1p compresses the wide range of volumes before the colour lookup
log_volumes = np.log1p(prod_totals[value_col])
for lat, lon, log_vol in zip(prod_totals['Latitude'], prod_totals['Longitude'], log_volumes):
    shade = colormap(log_vol)
    folium.CircleMarker(
        location=[lat, lon],
        radius=5,
        color=shade,
        fill=True,
        fill_color=shade,
        fill_opacity=1,
        opacity=0,  # no outline; only the filled disc is drawn
    ).add_to(wells_water_prod_map)

colormap.add_to(wells_water_prod_map)  # legend
wells_water_prod_map.save("vizs/distr_maps/wells_water_prod_map.html")

## Total gas produced

In [None]:
# Colour each production well by log1p(total gas produced) on a 10-step RdBu scale
value_col = 'Total Gas Produced (Mcf)'
min_val = np.log1p(prod_totals[value_col].min())
max_val = np.log1p(prod_totals[value_col].max())

colormap = cm.linear.RdBu_10.scale(min_val, max_val).to_step(10)
colormap.caption = 'Logarithmic Total Gas Produced (Mcf)'

# Base map: OpenStreetMap tiles centred at (34.05, -118.25)
wells_gas_prod_map = folium.Map(location=[34.05, -118.25], zoom_start=10, tiles='OpenStreetMap')

# log1p compresses the wide range of volumes before the colour lookup
log_volumes = np.log1p(prod_totals[value_col])
for lat, lon, log_vol in zip(prod_totals['Latitude'], prod_totals['Longitude'], log_volumes):
    shade = colormap(log_vol)
    folium.CircleMarker(
        location=[lat, lon],
        radius=5,
        color=shade,
        fill=True,
        fill_color=shade,
        fill_opacity=1,
        opacity=0,  # no outline; only the filled disc is drawn
    ).add_to(wells_gas_prod_map)

colormap.add_to(wells_gas_prod_map)  # legend
wells_gas_prod_map.save("vizs/distr_maps/wells_gas_prod_map.html")

## Total oil produced

In [None]:
# Colour each production well by log1p(total oil produced) on a 10-step RdBu scale
value_col = 'Total Oil Produced (bbl)'
min_val = np.log1p(prod_totals[value_col].min())
max_val = np.log1p(prod_totals[value_col].max())

colormap = cm.linear.RdBu_10.scale(min_val, max_val).to_step(10)
colormap.caption = 'Logarithmic Total Oil Produced (bbl)'

# Base map: OpenStreetMap tiles centred at (34.05, -118.25)
wells_oil_prod_map = folium.Map(location=[34.05, -118.25], zoom_start=10, tiles='OpenStreetMap')

# log1p compresses the wide range of volumes before the colour lookup
log_volumes = np.log1p(prod_totals[value_col])
for lat, lon, log_vol in zip(prod_totals['Latitude'], prod_totals['Longitude'], log_volumes):
    shade = colormap(log_vol)
    folium.CircleMarker(
        location=[lat, lon],
        radius=5,
        color=shade,
        fill=True,
        fill_color=shade,
        fill_opacity=1,
        opacity=0,  # no outline; only the filled disc is drawn
    ).add_to(wells_oil_prod_map)

colormap.add_to(wells_oil_prod_map)  # legend
wells_oil_prod_map.save("vizs/distr_maps/wells_oil_prod_map.html")

#### Side-by-side view of all production maps

In [None]:
# Inline each folium production map's HTML and show the three maps side by side.
water_html = wells_water_prod_map._repr_html_()
gas_html = wells_gas_prod_map._repr_html_()
oil_html = wells_oil_prod_map._repr_html_()

# Each panel embeds its own map; the oil panel previously embedded gas_html by mistake.
HTML(f"""
<div style="display: flex; justify-content: space-between;">

  <div style="width: 33%;">
    <h3 style="text-align: center;">Distribution of Wells by Total Water Produced</h3>
    {water_html}
  </div>

  <div style="width: 33%;">
    <h3 style="text-align: center;">Distribution of Wells by Total Gas Produced</h3>
    {gas_html}
  </div>

  <div style="width: 33%;">
    <h3 style="text-align: center;">Distribution of Wells by Total Oil Produced</h3>
    {oil_html}
  </div>

</div>
""")

In [None]:
# Write a wrapper page that embeds the three saved production maps through relative <iframe>s.
# It only renders when opened from the same folder as the three wells_*_prod_map.html files,
# so keep all four files together.
combined_prod_page = """
    <html>
    <head><title>Well Production Maps</title></head>
    <body>

    <div style="display: flex; justify-content: space-between;">

      <div style="width: 33%;">
        <h3 style="text-align: center;">Distribution of Wells by Total Water Produced</h3>
        <iframe src="wells_water_prod_map.html" width="100%" height="500" style="border:none;"></iframe>
      </div>

      <div style="width: 33%;">
        <h3 style="text-align: center;">Distribution of Wells by Total Gas Produced</h3>
        <iframe src="wells_gas_prod_map.html" width="100%" height="500" style="border:none;"></iframe>
      </div>

      <div style="width: 33%;">
        <h3 style="text-align: center;">Distribution of Wells by Total Oil Produced</h3>
        <iframe src="wells_oil_prod_map.html" width="100%" height="500" style="border:none;"></iframe>
      </div>

    </div>

    </body>
    </html>
    """

with open("vizs/distr_maps/combined_well_prod_maps.html", "w") as out_file:
    out_file.write(combined_prod_page)

# XGBoost model

## Further data preprocessing and feature engineering

In [None]:
# Columns treated as irrelevant for modelling; each table only contains some of them,
# hence errors='ignore' below
drop_cols = [
    'API #', 'Operator Name', 'County Name', 'Field Name', 'Lease Name',
    'Area Name', 'Area Code', 'District #', 'Section', 'Township',
    'Range', 'Base Meridian', 'API Number', 'PWT Status', 'Status', 'Pool Code',
    'Event Type', 'GT', 'Event ID', 'Number of Picked Phases', 'Ngrams',
]

# Trimmed, independent copies of the injection, production and earthquake tables
inj_copy, prod_copy, eqs_copy = (
    frame.drop(columns=drop_cols, errors='ignore').copy() for frame in (inj, prod, eqs)
)

In [None]:
# Work on fresh copies so the *_copy tables stay untouched
inj_copy1, prod_copy1, eqs_copy1 = inj_copy.copy(), prod_copy.copy(), eqs_copy.copy()

# For each table: parse its date column (unparseable values -> NaT), then derive a
# monthly 'YearMonth' period that later serves as a merge key
date_columns = (
    (inj_copy1, 'Injection Date'),
    (prod_copy1, 'Production Date'),
    (eqs_copy1, 'Date'),
)
for frame, date_col in date_columns:
    frame[date_col] = pd.to_datetime(frame[date_col], errors='coerce')
    frame['YearMonth'] = frame[date_col].dt.to_period('M')

In [None]:
# Pair injection and production records at the same coordinates in the same month
# (inner join: only location-months present in both tables survive)
merge_keys = ['Latitude', 'Longitude', 'YearMonth']
inj_prod = inj_copy1.merge(prod_copy1, how='inner', on=merge_keys)

inj_prod.head()

In [None]:
# Build a regular lon/lat grid over the injection/production points and tag every
# record with the cell it falls in ('grid_id' plus the cell polygon as 'grid_geometry').
# Convert DataFrame to GeoDataFrame (point geometry from Longitude/Latitude, WGS84)
geometry = gpd.points_from_xy(inj_prod['Longitude'], inj_prod['Latitude'])
gdf = gpd.GeoDataFrame(inj_prod, geometry=geometry, crs="EPSG:4326")

# Get bounding box of all points
minx, miny, maxx, maxy = gdf.total_bounds

grid_size = 0.05  # cell edge in degrees (~5km)

cols = list(np.arange(minx, maxx + grid_size, grid_size))
rows = list(np.arange(miny, maxy + grid_size, grid_size))

# Create polygons for the grid; cell id is '<column index>_<row index>'
polygons = []
ids = []
for i, x in enumerate(cols[:-1]):
    for j, y in enumerate(rows[:-1]):
        polygons.append(
            Polygon([
                (x, y),
                (x + grid_size, y),
                (x + grid_size, y + grid_size),
                (x, y + grid_size)
            ])
        )
        ids.append(f'{i}_{j}')

# 'grid_poly' keeps a plain-column copy of each polygon so it survives the spatial join
grid = gpd.GeoDataFrame({'grid_id': ids, 'grid_poly': polygons}, crs="EPSG:4326", geometry=polygons)

# Spatial join to attach grid polygon to each point
inj_prod_gdf = gpd.sjoin(gdf, grid, how='inner', predicate='intersects')

# A point lying exactly on a shared cell edge or corner intersects several cells, which
# would duplicate the record. inj_prod has a fresh unique index (from pd.merge), so keep
# only the first match per point to give each record exactly one cell.
inj_prod_gdf = inj_prod_gdf[~inj_prod_gdf.index.duplicated(keep='first')]

# Keep polygon geometry in a new column instead of replacing point geometry
inj_prod_gdf = inj_prod_gdf.rename(columns={'grid_poly': 'grid_geometry'})

# Format YearMonth string ('YYYY-MM')
inj_prod_gdf['YearMonth'] = inj_prod_gdf['YearMonth'].dt.strftime('%Y-%m')

In [None]:
inj_prod_gdf.head(n=5)  # preview grid-tagged injection/production records

In [None]:
# Tag each earthquake with a cell of the grid built from the well data; the inner join
# drops earthquakes that fall outside that grid.
# Convert DataFrame to GeoDataFrame (point geometry from Longitude/Latitude, WGS84)
geometry = gpd.points_from_xy(eqs_copy1['Longitude'], eqs_copy1['Latitude'])
gdf = gpd.GeoDataFrame(eqs_copy1, geometry=geometry, crs="EPSG:4326")

# Spatial join to attach grid polygon to each point using inj_prod_gdf's grid
eqs_gdf = gpd.sjoin(gdf, grid, how='inner', predicate='intersects')

# An epicentre exactly on a shared cell edge or corner intersects several cells and would
# be counted in each; keep only its first match so every earthquake lands in one cell.
eqs_gdf = eqs_gdf[~eqs_gdf.index.duplicated(keep='first')]

# Keep polygon geometry in a new column instead of replacing point geometry
eqs_gdf = eqs_gdf.rename(columns={'grid_poly': 'grid_geometry'})

# Format YearMonth string ('YYYY-MM'), matching inj_prod_gdf for the later merge
eqs_gdf['YearMonth'] = eqs_gdf['YearMonth'].dt.strftime('%Y-%m')

In [None]:
eqs_gdf.head(n=5)  # preview grid-tagged earthquakes

In [None]:
# Compare the columns of the two grid-tagged tables before merging them
for tagged in (inj_prod_gdf, eqs_gdf):
    print(tagged.columns)

In [None]:
# Join well activity with earthquakes in the same grid cell during the same month
# (inner join: only cell-months with both well records and earthquakes remain)
join_keys = ['grid_id', 'grid_geometry', 'YearMonth']
overall = inj_prod_gdf.merge(eqs_gdf, how='inner', on=join_keys)

overall.head()

In [None]:
overall.keys()  # column labels of the merged well/earthquake table

In [None]:
# Work on a narrower copy of the merged table: drop merge-suffixed (_x/_y) duplicates,
# raw date/time fields, point geometries and spatial-join index columns.
# errors='ignore' because not every listed column is guaranteed to exist.
drop_cols = [
    'Well #_x', 'Latitude_x', 'Longitude_x', 'Injection Date', 'Well Type_x', 'Reported Date_x',
    'Well #_y', 'Well Type_y', 'Reported Date_y', 'geometry_x', 'index_right_x', 'Date', 'Time',
    'Latitude_y', 'Longitude_y', 'Datetime', 'Year', 'geometry_y', 'index_right_y',
]
overall_df = overall.copy().drop(columns=drop_cols, errors='ignore')

overall_df.head()

In [None]:
overall_df.keys()  # columns remaining after the trim

In [None]:
# Order rows chronologically inside each grid cell so the shift/rolling/cumsum features
# below run forward in time. Sorting on 'grid_id' alone left the within-cell order
# arbitrary (the default single-key sort is not stable). 'YearMonth' is still a
# 'YYYY-MM' string at this point, so lexical order is chronological order.
overall_df = overall_df.sort_values(['grid_id', 'YearMonth'])

# Calculate lags and rolling averages
value_cols = ['Water or Steam Injected (bbl)', 'Gas or Air Injected (Mcf)',
              'Oil Produced (bbl)', 'Water Produced (bbl)', 'Gas Produced (Mcf)']

time_windows = {'3m': 3, '1y': 12, '5y': 60}

grouped = overall_df.groupby(['grid_id'])

# NOTE(review): rolling windows count rows, not calendar months - a grid cell can have many
# rows in the same month after the merges, so '3m' spans 3 rows, not 3 months.
for col in value_cols:
    for label, window in time_windows.items():

        # Lag (value from previous row of the same cell).
        # NOTE(review): shift(1) for every label, so the three _lag_ columns are identical;
        # use shift(window) if a window-length lag was intended.
        overall_df[f'{col}_lag_{label}'] = grouped[col].transform(lambda x: x.shift(1))

        # Rolling mean
        overall_df[f'{col}_roll_mean_{label}'] = grouped[col].transform(lambda x: x.rolling(window, min_periods=1).mean())

        # Rolling sum
        overall_df[f'{col}_roll_sum_{label}'] = grouped[col].transform(lambda x: x.rolling(window, min_periods=1).sum())

        # Rolling std
        overall_df[f'{col}_roll_std_{label}'] = grouped[col].transform(lambda x: x.rolling(window, min_periods=1).std())

# Ratios (1e-6 in the denominator avoids division by zero)
overall_df['water_gas_inj_ratio'] = overall_df['Water or Steam Injected (bbl)'] / (overall_df['Gas or Air Injected (Mcf)'] + 1e-6)

overall_df['water_gas_prod_ratio'] = overall_df['Water Produced (bbl)'] / (overall_df['Gas Produced (Mcf)'] + 1e-6)
overall_df['water_oil_prod_ratio'] = overall_df['Water Produced (bbl)'] / (overall_df['Oil Produced (bbl)'] + 1e-6)
overall_df['gas_oil_prod_ratio'] = overall_df['Gas Produced (Mcf)'] / (overall_df['Oil Produced (bbl)'] + 1e-6)

# Cumulative totals per grid cell (in the chronological order established above)
overall_df['cum_water_inj'] = grouped['Water or Steam Injected (bbl)'].cumsum()
overall_df['cum_gas_inj'] = grouped['Gas or Air Injected (Mcf)'].cumsum()

overall_df['cum_water_prod'] = grouped['Water Produced (bbl)'].cumsum()
overall_df['cum_gas_prod'] = grouped['Gas Produced (Mcf)'].cumsum()
overall_df['cum_oil_prod'] = grouped['Oil Produced (bbl)'].cumsum()

# Additional features
overall_df['inj_intensity'] = (overall_df['Water or Steam Injected (bbl)'] / (overall_df['Days Well Injected'] + 1e-6))
overall_df['prod_efficiency'] = (overall_df['Oil Produced (bbl)'] / (overall_df['Oil Produced (bbl)'] + overall_df['Water Produced (bbl)'] + 1e-6))
overall_df['pressure_diff'] = overall_df['Tubing Pressure'] - overall_df['Casing Pressure']
overall_df['depth_norm_inj'] = (overall_df['Water or Steam Injected (bbl)'] / (overall_df['Depth'] + 1e-6))
overall_df['btu_norm_gas'] = (overall_df['Gas Produced (Mcf)'] * overall_df['BTU'])

# Seasonality: encode the calendar month on the unit circle
overall_df['YearMonth'] = overall_df['YearMonth'].astype('period[M]')
overall_df['Month'] = overall_df['YearMonth'].dt.month

overall_df['month_sin'] = np.sin(2*np.pi*overall_df['Month']/12)
overall_df['month_cos'] = np.cos(2*np.pi*overall_df['Month']/12)

In [None]:
overall_df.keys()  # columns after feature engineering

In [None]:
# Create target cols
# Number of non-null 'Magnitude' values per (grid cell, month), repeated on every row of that cell-month
overall_df['Total Earthquakes'] = (overall_df.groupby(['grid_id', 'YearMonth'])['Magnitude'].transform('count'))

# Group afresh here instead of reusing the GroupBy from the feature cell: that object was
# built before 'Total Earthquakes' existed and goes stale whenever overall_df is reassigned
# (e.g. when cells are re-run out of order).
grouped = overall_df.groupby(['grid_id'])

# NOTE(review): 'Total Earthquakes' is repeated on every row of its cell-month, so a rolling
# sum over rows adds the same month's count several times; consider aggregating per month first.
for label, window in time_windows.items():
    overall_df[f'earthquake_count_{label}'] = (grouped['Total Earthquakes'].transform(lambda x: x.rolling(window, min_periods=1).sum()))

    overall_df[f'earthquake_avg_mag_{label}'] = (grouped['Magnitude'].transform(lambda x: x.rolling(window, min_periods=1).mean()))

In [None]:
overall_df.to_csv(path_or_buf='data/final/final_dataset.csv', index=False)  # persist the modelling table

In [None]:
# Reload the modelling table from disk
df = pd.read_csv(filepath_or_buffer='data/final/final_dataset.csv')

df.head(n=5)

In [None]:
# Chronological 80/10/10 split over distinct months: every month belongs to exactly one
# split, and later months go to val/test, to avoid leaking future information into training
df = df.sort_values("YearMonth").reset_index(drop=True)

months = df['YearMonth'].dropna().sort_values().unique()
n = len(months)
train_end, val_end = int(0.8 * n), int(0.9 * n)

train_months, val_months, test_months = months[:train_end], months[train_end:val_end], months[val_end:]

train_df, val_df, test_df = (
    df[df['YearMonth'].isin(split_months)] for split_months in (train_months, val_months, test_months)
)

In [None]:
df_cols = list(df.columns)
# The six target columns are the last ones in the saved dataset, interleaved as
# count/avg_mag for 3m, 1y and 5y; every earlier column is a feature candidate.
feature_cols = [c for c in df_cols[:-6] if c not in ['Production Date', 'YearMonth', 'grid_id', 'grid_geometry', 'Total Earthquakes', 'Gravity of Oil']]
target_cols = df_cols[-6:]
# NOTE(review): feature_cols still contains earthquake-catalogue columns such as 'Magnitude',
# 'Magnitude Type' and 'Quality', which describe the events being predicted (target leakage).

X_train = train_df[feature_cols].copy()
X_val = val_df[feature_cols].copy()
X_test = test_df[feature_cols].copy()

# Handle categorical cols with ONE dtype per column, inferred from the whole dataset.
# Calling astype('category') on each split separately gives each split its own category
# list, so the same label can map to different category codes in train vs. val/test.
cat_cols = ['Source of Water', 'Kind of Water', 'Magnitude Type', 'Quality']
cat_dtypes = {c: df[c].astype('category').dtype for c in cat_cols}
for c, shared_dtype in cat_dtypes.items():
    X_train[c] = X_train[c].astype(shared_dtype)
    X_val[c] = X_val[c].astype(shared_dtype)
    X_test[c] = X_test[c].astype(shared_dtype)

count_targets = target_cols[::2]  # earthquake_count_3m,1y,5y
mag_targets = target_cols[1::2]   # earthquake_avg_mag_3m,1y,5y

# earthquake_count_3m,1y,5y: modelled on log1p scale; training targets clipped at their 99th percentile
y_rate_train = np.log1p(train_df[count_targets])
y_rate_train = y_rate_train.clip(upper=y_rate_train.quantile(0.99), axis=1)  # Clip outliers
y_rate_val = np.log1p(val_df[count_targets])
y_rate_test = np.log1p(test_df[count_targets])

# earthquake_avg_mag_3m,1y,5y: used on their natural scale
y_mag_train = train_df[mag_targets]
y_mag_val = val_df[mag_targets]
y_mag_test = test_df[mag_targets]

## Model training, prediction on all splits, and evaluation on the validation set

In [None]:
# Train, eval and save results of the models: one XGBoost regressor per target.
os.makedirs('xgb_model', exist_ok=True)
os.makedirs('xgb_model/results', exist_ok=True)
os.makedirs('xgb_model/results/splits', exist_ok=True)
os.makedirs('xgb_model/indiv', exist_ok=True)

models = {}
predictions = {}

eval_file = 'xgb_model/results/final_eval.txt'
with open(eval_file, 'w') as f:
    f.write("Final Evaluation Metrics\n")

all_targets = count_targets + mag_targets

for target in all_targets:
    is_count = target in count_targets

    if is_count:
        # y_rate_* already hold log1p(count) (applied when the targets were built), so the
        # model is fitted on them directly. Applying log1p again trained on
        # log1p(log1p(count)), and the single expm1 below then returned log-scale values.
        y_train_fit = y_rate_train[target]
        y_val_fit = y_rate_val[target]
        y_val_true = np.expm1(y_val_fit)  # observed counts, for evaluation
    else:
        y_train_fit = y_mag_train[target]
        y_val_fit = y_mag_val[target]
        y_val_true = y_val_fit

    # Initialize and fit model
    model = XGBRegressor(
        n_estimators=400,
        max_depth=6,
        learning_rate=0.05,
        subsample=0.8,
        colsample_bytree=0.8,
        objective="reg:squarederror",
        tree_method="hist",
        enable_categorical=True,
        random_state=10
    )
    model.fit(X_train, y_train_fit, eval_set=[(X_val, y_val_fit)], verbose=0)

    # Predict, mapping count models back from log1p scale to counts
    y_train_pred = np.expm1(model.predict(X_train)) if is_count else model.predict(X_train)
    y_val_pred   = np.expm1(model.predict(X_val)) if is_count else model.predict(X_val)
    y_test_pred  = np.expm1(model.predict(X_test)) if is_count else model.predict(X_test)

    # Store models and predictions
    models[target] = model
    predictions[target] = {'train': y_train_pred, 'val': y_val_pred, 'test': y_test_pred}

    # Save model
    joblib.dump(model, f'xgb_model/indiv/{target}.pkl')

    # Save predictions
    pd.DataFrame({'pred': y_train_pred}).to_csv(f'xgb_model/results/splits/{target}_train.csv', index=False)
    pd.DataFrame({'pred': y_val_pred}).to_csv(f'xgb_model/results/splits/{target}_val.csv', index=False)
    pd.DataFrame({'pred': y_test_pred}).to_csv(f'xgb_model/results/splits/{target}_test.csv', index=False)

    # Evaluate on the validation split in natural units (counts / magnitude).
    # mean_squared_error returns the MSE, so take its square root to report an RMSE.
    rmse_val = np.sqrt(mean_squared_error(y_val_true, y_val_pred))
    r2_val = r2_score(y_val_true, y_val_pred)
    print(f"{target} RMSE: {rmse_val:.3f}, R^2: {r2_val:.3f}")

    # Append evaluation metrics to txt file
    with open(eval_file, 'a') as f:
        f.write(f"{target} RMSE: {rmse_val:.3f}, R^2: {r2_val:.3f}\n")

## SHAP analysis

In [None]:
# SHAP analysis: for every trained model, rank features by mean |SHAP value| over the
# training features, write the top-10 names to a text file, and save + show a bar plot
# and a beeswarm plot restricted to those 10 features.
os.makedirs("vizs/shap", exist_ok=True)
os.makedirs("xgb_model/results", exist_ok=True)
os.makedirs("xgb_model/results/top10_features", exist_ok=True)

# Convert cat cols to cat codes
# NOTE(review): the models were fitted on pandas 'category' columns, but SHAP receives the
# integer codes here - confirm TreeExplainer treats the models' categorical splits consistently.
X_train_shap = X_train.copy()
for c in cat_cols:
    X_train_shap[c] = X_train_shap[c].cat.codes

top_features = {}  # target -> names of its 10 highest-ranked features

for target, model in models.items():
    print(f"\n=== SHAP Analysis for {target} ===")
    
    explainer = shap.TreeExplainer(model)
    # Indexed below as (rows, features): mean over axis 0 and column selection on axis 1
    shap_values = explainer.shap_values(X_train_shap)
    
    # Top 10 features by mean absolute SHAP value
    mean_abs_shap = pd.DataFrame({'feature': X_train_shap.columns,
                                  'mean_abs_shap': np.abs(shap_values).mean(axis=0)
                                }).sort_values(by='mean_abs_shap', ascending=False)
    
    top_10 = mean_abs_shap.head(10)
    top_features[target] = top_10['feature'].tolist()
    top_features_lst = top_10['feature'].tolist()
    # Column slice of the SHAP matrix matching the top-10 features, in ranked order
    shap_values_top10 = shap_values[:, [X_train_shap.columns.get_loc(f) for f in top_features_lst]]
    X_top10 = X_train_shap[top_features_lst]

    # Save top 10 features to txt
    top10_file = f"xgb_model/results/top10_features/top10_{target}.txt"
    with open(top10_file, 'w') as f:
        f.write(f"Top 10 features for {target}:\n")
        for feat in top_features[target]:
            f.write(f"{feat}\n")    
    
    # Barplot (reversed so the most important feature is drawn at the top)
    plt.figure(figsize=(10,5))
    plt.barh(top_10['feature'][::-1], top_10['mean_abs_shap'][::-1], color='skyblue')
    plt.xlabel("Mean |SHAP value|")
    plt.title(f"Top 10 Features: {target}")
    plt.tight_layout()
    plt.savefig(f"vizs/shap/{target}_barplot.png")
    plt.show()
    plt.close()
    
    # Beeswarm for top 10 features; show=False keeps the figure open so the title can be
    # added and the plot saved before it is displayed
    shap.summary_plot(shap_values_top10, X_top10, plot_type="dot", show=False)
    plt.title(f"SHAP Beeswarm: {target}")
    plt.tight_layout()
    plt.savefig(f"vizs/shap/{target}_beeswarm.png")
    plt.show()
    plt.close()
    
    print(f"Top 10 features for {target}:")
    print(top_features[target])

## Predict on full dataset to create prediction maps

In [None]:
# Prediction: run every per-target model over the whole dataset (all splits) and save the
# predictions alongside the original rows.
os.makedirs("xgb_model", exist_ok=True)
os.makedirs("xgb_model/results", exist_ok=True)

pred_df = df.copy()
# Prepare full feature set
X_full = df[feature_cols].copy()

# Reuse the exact categorical dtypes the models were trained with: re-inferring categories
# here could order the labels differently and change the codes the trees split on.
# Any other object column is converted to 'category' as before.
for c in X_full.columns:
    train_dtype = X_train[c].dtype
    if isinstance(train_dtype, pd.CategoricalDtype):
        X_full[c] = X_full[c].astype(train_dtype)
    elif X_full[c].dtype == object:
        X_full[c] = X_full[c].astype('category')

full_preds = {}

# Earthquake rate preds; expm1 assumes the count models were fitted on log1p(count) targets
for target in count_targets:
    y_pred = np.expm1(models[target].predict(X_full))

    full_preds[target] = y_pred
    pred_df['pred_' + target] = y_pred

# Average earthquake magnitude preds (natural scale)
for target in mag_targets:
    y_pred = models[target].predict(X_full)

    full_preds[target] = y_pred
    pred_df['pred_' + target] = y_pred

# Save the combined predictions
pred_df.to_csv("xgb_model/results/full_dataset_preds.csv", index=False)

# Save the dict of per-target models used above (these were fitted on the training split only)
joblib.dump(models, "xgb_model/full_dataset_xgb.pkl")

In [None]:
# Prediction maps: per grid cell, the mean of each prediction column, drawn on a 2x3 panel
# (earthquake-rate targets on the top row, average-magnitude targets on the bottom row)
os.makedirs("vizs/pred_maps", exist_ok=True)

# pred_df came from a CSV, so grid_geometry may still be WKT text; parse only string entries
pred_df["grid_geometry"] = pred_df["grid_geometry"].apply(lambda x: wkt.loads(x) if isinstance(x, str) else x)

gdf = gpd.GeoDataFrame(pred_df, geometry="grid_geometry", crs="EPSG:4326")

# One row per grid cell: its polygon plus the mean of every prediction column
agg_spec = {"grid_geometry": "first"}
agg_spec.update({'pred_' + col: 'mean' for col in count_targets + mag_targets})
agg_gdf = gdf.groupby("grid_id").agg(agg_spec).reset_index()

# Reproject to EPSG:3857 (Web Mercator) for the contextily basemap
agg_gdf = gpd.GeoDataFrame(agg_gdf, geometry='grid_geometry', crs='EPSG:4326').to_crs(epsg=3857)

# (column to plot, colormap, panel title), in row-major panel order
panels = [
    ('pred_' + count_targets[0], 'Reds', 'Predicted Earthquake Rate (3M)'),
    ('pred_' + count_targets[1], 'Reds', 'Predicted Earthquake Rate (1Y)'),
    ('pred_' + count_targets[2], 'Reds', 'Predicted Earthquake Rate (5Y)'),
    ('pred_' + mag_targets[0], 'Blues', 'Predicted Average Earthquake Magnitude (3M)'),
    ('pred_' + mag_targets[1], 'Blues', 'Predicted Average Earthquake Magnitude (1Y)'),
    ('pred_' + mag_targets[2], 'Blues', 'Predicted Average Earthquake Magnitude (5Y)'),
]

fig, ax = plt.subplots(2, 3, figsize=(20, 12))

for panel_ax, (column, cmap_name, title) in zip(ax.flat, panels):
    agg_gdf.plot(column=column, cmap=cmap_name, legend=True, alpha=0.6, edgecolor='k', ax=panel_ax)
    ctx.add_basemap(panel_ax, source=ctx.providers.CartoDB.Positron)
    panel_ax.set_title(title)

plt.tight_layout()
plt.savefig('vizs/pred_maps/final_preds.png')
plt.show()