Optimize LGWR local bandwidths using Reinforcement learning approach (Proximal Policy Optimization, PPO)

In [1]:
from stable_baselines3 import PPO
import pandas as pd

from src.optimizer.reinforce.lgwr_optimizer import LgwrOptimizerRL
from src.dataset.interfaces.spatial_dataset import IFieldInfo
from src.optimizer.reinforce.callback import EpisodeTracker
from src.dataset.spatial_dataset import SpatialDataset
from src.kernel.lgwr_kernel import LgwrKernel
from src.log.lgwr_logger import LgwrLogger
from src.model.lgwr import LGWR

Create a logger to record the LGWR model's information.

In [2]:
logger = LgwrLogger()

Load the Georgia dataset and create a spatial dataset.

In [3]:
georgia_data = pd.read_csv(r'./data/GData_utm.csv')
spatialDataset = SpatialDataset(
	georgia_data,
	IFieldInfo(
		predictor_fields=['PctFB', 'PctBlack', 'PctRural'],
		response_field='PctBach',
		coordinate_x_field='Longitud',
		coordinate_y_field='Latitude'
	),
	logger,
	isSpherical=True
)

{'2025-03-15 04:53:49': 'SpatialDataset : Data schema matchs with the data.'}
{'2025-03-15 04:53:50': 'SpatialDataset : Data points created.'}


Create a LGWR kernel and LGWR model.

In [4]:
kernel = LgwrKernel(
	spatialDataset,
	logger,
	kernel_type='bisquare',
	kernel_bandwidth_type='adaptive'
)
lgwr = LGWR(spatialDataset, kernel, logger)

{'2025-03-15 04:53:50': 'LgwrKernel : Kernel is initialized.'}
{'2025-03-15 04:53:50': 'LGWR : LGWR model is initialized.'}


Initialize lgwr gym environment

In [5]:
env = LgwrOptimizerRL(
	lgwr,
	logger,
	min_bandwidth=10,
	max_bandwidth=spatialDataset.x_matrix.shape[0],
	min_action=-10,
	max_action=10,
	max_steps=500,
	reward_threshold=0.75
)

{'2025-03-15 04:53:50': 'LgwrOptimizerRL: LgwrOptimizerRL environment is initialized.'}


Using PPO to optimize the bandwidth vector
(local bandwidths for each location)

In [6]:
TOTAL_TIMESTEPS = 5000
episodeTracker = EpisodeTracker(
	logger,
	total_timesteps=TOTAL_TIMESTEPS
)
model = PPO(
	"MlpPolicy",
	env,
	verbose=1,
	device='cpu'
)
model.learn(
	total_timesteps=TOTAL_TIMESTEPS,
	callback=episodeTracker
)
logger.append_info("PPO: PPO finished training.")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
{'2025-03-15 04:53:51': 'episode 1 reached the reward threshold: R2 0.7503555064972859, bandwidth vector: [40 36 40 41 47 34 34 41 44 45 39 41 41 39 41 44 42 40 43 36 49 41 46 44\n 39 46 43 46 40 43 39 48 43 47 42 42 44 43 39 41 42 41 39 43 40 38 44 43\n 38 42 39 41 41 45 41 40 39 37 42 43 41 47 38 44 39 40 38 43 41 40 42 42\n 45 48 41 42 36 39 43 43 41 42 40 37 44 47 42 43 48 41 45 41 36 42 37 45\n 40 45 42 44 44 43 42 41 45 45 47 38 39 49 42 37 35 42 41 46 43 43 40 42\n 45 40 49 39 44 36 44 41 41 41 39 31 44 38 40 43 42 37 46 45 42 39 38 49\n 44 40 41 34 39 39 43 41 41 46 46 42 44 40 48]'}
{'2025-03-15 04:53:51': 'Episode 1 ends, total steps: 9, remaining steps: 4991'}
{'2025-03-15 04:54:04': 'Episode 2 ends, total steps: 500, remaining steps: 4491'}
{'2025-03-15 04:54:14': 'Episode 3 ends, total steps: 500, remaining steps: 3991'}
{'2025-03-15 04:54:24': 'Episode 4 ends, total steps: 500, r

Test the model

In [7]:
obs, _ = env.reset()
for _ in range(1000):
	action, _ = model.predict(obs)
	obs, reward, done, truncated, _ = env.step(action)
	logger.append_info(
		f"Bandwidth: {obs}, Reward (R2): {reward}"
	)
	if done or truncated:
		break

{'2025-03-15 04:56:03': 'Bandwidth: [120 117 120 119 122 119 120 118 119 119 117 120 121 118 118 118 119 119\n 119 119 121 122 118 118 118 118 118 119 119 120 118 121 119 119 116 119\n 121 119 120 119 119 118 118 120 120 120 119 120 119 120 119 117 118 119\n 119 121 117 120 121 118 118 119 119 120 122 119 119 120 121 119 118 117\n 118 120 120 119 117 120 118 120 119 120 120 117 119 120 119 118 119 120\n 119 117 119 121 119 119 120 118 119 119 119 119 121 120 118 120 118 120\n 118 120 120 119 121 118 118 118 119 117 120 119 117 119 118 120 119 119\n 120 118 118 119 120 121 121 119 118 118 117 118 119 119 119 118 118 117\n 118 118 119 118 118 120 119 119 118 121 119 118 117 119 119], Reward (R2): 0.6751007733372907'}
{'2025-03-15 04:56:03': 'Bandwidth: [122 116 119 119 122 120 120 118 119 120 116 121 122 117 119 116 118 119\n 121 119 119 124 116 116 119 117 116 118 119 122 117 122 120 119 115 119\n 121 118 119 121 119 118 118 120 119 121 118 122 119 119 120 118 118 120\n 121 120 117 122 

Save the log

In [8]:
logger.save_model_info_json()