Optimize the GWR bandwidth using a reinforcement learning approach (Proximal Policy Optimization, PPO)

In [1]:
from stable_baselines3 import PPO
import pandas as pd

from src.optimizer.reinforce.gwr_optimizer import GwrOptimizerRL
from src.dataset.interfaces.spatial_dataset import IFieldInfo
from src.optimizer.reinforce.callback import EpisodeTracker
from src.dataset.spatial_dataset import SpatialDataset
from src.kernel.gwr_kernel import GwrKernel
from src.log.gwr_logger import GwrLogger
from src.model.gwr import GWR

Create a logger to record the GWR model's information.

In [2]:
# Single logger instance shared by the dataset, kernel, model, RL environment
# and training callback created in the cells below.
logger = GwrLogger()

Load the Georgia dataset and create a spatial dataset.

In [3]:
# Read the Georgia table from disk first, then describe which of its columns
# act as predictors, response and coordinates.
georgia_data = pd.read_csv(r'./data/GData_utm.csv')
georgia_fields = IFieldInfo(
    predictor_fields=['PctFB', 'PctBlack', 'PctRural'],
    response_field='PctBach',
    coordinate_x_field='Longitud',
    coordinate_y_field='Latitude',
)

# isSpherical=True is forwarded as-is; presumably it tells SpatialDataset that
# the coordinate columns are geographic (lon/lat) -- confirm against
# SpatialDataset's documentation.
spatialDataset = SpatialDataset(
    georgia_data, georgia_fields, logger, isSpherical=True
)

{'2025-03-15 04:58:04': 'SpatialDataset : Data schema matchs with the data.'}
{'2025-03-15 04:58:04': 'SpatialDataset : Data points created.'}


Create a GWR kernel and GWR model.

In [4]:
# Kernel configuration kept in one place so it is easy to tweak.
kernel_options = {
    'kernel_type': 'bisquare',
    'kernel_bandwidth_type': 'adaptive',
}

# The kernel and the model both receive the same dataset and logger.
kernel = GwrKernel(spatialDataset, logger, **kernel_options)
gwr = GWR(spatialDataset, kernel, logger)

{'2025-03-15 04:58:04': 'GwrKernel : Kernel is initialized.'}
{'2025-03-15 04:58:04': 'GWR : GWR model is initialized.'}


Initialize the GWR Gym environment.

In [5]:
# Upper bandwidth bound is the size of the first axis of the design matrix.
n_rows = spatialDataset.x_matrix.shape[0]

# Environment wrapping the GWR model; bandwidth is searched in [10, n_rows]
# and actions are bounded to [-10, 10].
env = GwrOptimizerRL(
    gwr,
    logger,
    min_bandwidth=10,
    max_bandwidth=n_rows,
    min_action=-10,
    max_action=10,
)

{'2025-03-15 04:58:04': 'GwrOptimizerRL: GwrOptimizerRL environment is initialized.'}


Use PPO to optimize the bandwidth.

In [7]:
# Training budget (environment steps) shared by the tracker and the learner.
TOTAL_TIMESTEPS = 5000
# Fixed seed so that a Restart & Run All reproduces the same training run.
# NOTE(review): Stable-Baselines3 documents that seeding gives repeatable
# results on the same platform / PyTorch version, not across all setups.
SEED = 42

# Callback that receives the logger and the total step budget.
episodeTracker = EpisodeTracker(
  logger,
  total_timesteps=TOTAL_TIMESTEPS
)

# PPO agent with an MLP policy, trained on CPU against the GWR environment.
model = PPO(
  "MlpPolicy",
  env,
  verbose=1,
  device='cpu',
  seed=SEED
)
model.learn(
  total_timesteps=TOTAL_TIMESTEPS,
  callback=episodeTracker
)
logger.append_info("PPO: PPO finished training.")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
{'2025-03-15 05:01:19': 'Episode 1 ends, total steps: 100, remaining steps: 4900'}
{'2025-03-15 05:01:21': 'Episode 2 ends, total steps: 100, remaining steps: 4800'}
{'2025-03-15 05:01:22': 'Episode 3 ends, total steps: 100, remaining steps: 4700'}
{'2025-03-15 05:01:24': 'Episode 4 ends, total steps: 100, remaining steps: 4600'}
{'2025-03-15 05:01:24': 'episode 90 reached the reward threshold: R2 0.7629490710812805, bandwidth: 36'}
{'2025-03-15 05:01:24': 'Episode 5 ends, total steps: 1, remaining steps: 4599'}
{'2025-03-15 05:01:24': 'episode 91 reached the reward threshold: R2 0.8255832052225135, bandwidth: 24'}
{'2025-03-15 05:01:24': 'Episode 6 ends, total steps: 1, remaining steps: 4598'}
{'2025-03-15 05:01:26': 'Episode 7 ends, total steps: 100, remaining steps: 4498'}
{'2025-03-15 05:01:28': 'Episode 8 ends, total steps: 100, remaining steps: 4398'}
{'2025-03-15 05:01:30': 'Episode 9 e

Test the model

In [None]:
# Roll out the trained policy once and log every step.
# Upper bound on the number of evaluation steps.
MAX_TEST_STEPS = 100

# Named throwaways are used instead of `_` so that IPython's `_`
# (last cell output) is not overwritten by this cell.
obs, _reset_info = env.reset()
for _step in range(MAX_TEST_STEPS):
    # deterministic=True makes the policy return its deterministic action
    # instead of sampling, so the evaluation is repeatable for a given model.
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, truncated, _step_info = env.step(action)
    logger.append_info(
        f"Bandwidth: {obs}, Reward (R2): {reward}"
    )
    # Stop as soon as the environment reports the episode is over.
    if done or truncated:
        break

# Persist everything the logger has collected.
logger.save_model_info_json()