In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal

import os
import helper as hp
from configparser import ConfigParser
from ppo_refinement import PPORefinement

from kinetics.jacobian_solver import check_jacobian
import concurrent.futures
from functools import partial
import threading

# Define PPO

In [2]:
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# PPO Algorithm Components
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

class Actor(nn.Module):
    """Gaussian policy network.

    A ReLU MLP trunk maps the state to per-dimension action means; the standard
    deviation comes from a learnable, state-independent ``log_std`` vector.

    Args:
        state_dim: length of the input state vector.
        action_dim: length of the action vector.
        hidden_dims: widths of the hidden ``Linear`` + ``ReLU`` layers, in order.
    """

    def __init__(self, state_dim, action_dim, hidden_dims=(256, 512, 1024)):
        super().__init__()
        self.hidden_dims = hidden_dims
        widths = [state_dim, *hidden_dims]
        trunk = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            trunk += [nn.Linear(in_features, out_features), nn.ReLU()]
        self.network = nn.Sequential(*trunk)
        self.mean_layer = nn.Linear(hidden_dims[-1], action_dim)
        # One learnable log-std per action dimension, initialised to -0.5 (std = exp(-0.5) ~ 0.61).
        self.log_std = nn.Parameter(torch.full((action_dim,), -0.5, dtype=torch.float32))

    def forward(self, state):
        """Return ``(mean, std)`` of the action distribution for ``state``.

        When ``mean`` has a leading batch dimension and the std vector's length
        matches ``mean.shape[1]``, std is expanded to ``mean``'s shape; otherwise
        the 1-D std vector is returned as-is.
        """
        mean = self.mean_layer(self.network(state))
        std = self.log_std.exp()
        is_batched = mean.ndim > 1
        if is_batched and std.ndim == 1 and std.shape[0] == mean.shape[1]:
            std = std.unsqueeze(0).expand_as(mean)
        return mean, std

class Critic(nn.Module):
    """State-value network: a ReLU MLP trunk followed by a single-output linear head.

    Args:
        state_dim: length of the input state vector.
        hidden_dims: widths of the hidden ``Linear`` + ``ReLU`` layers, in order.
    """

    def __init__(self, state_dim, hidden_dims=(256, 512, 1024)):
        super().__init__()
        self.hidden_dims = hidden_dims
        widths = (state_dim,) + tuple(hidden_dims)
        # Flatten (Linear, ReLU) pairs for each consecutive pair of widths.
        self.network = nn.Sequential(*[
            layer
            for d_in, d_out in zip(widths, widths[1:])
            for layer in (nn.Linear(d_in, d_out), nn.ReLU())
        ])
        self.value_layer = nn.Linear(hidden_dims[-1], 1)

    def forward(self, state):
        """Return the value estimate for ``state`` (last dimension of size 1)."""
        return self.value_layer(self.network(state))



In [3]:
class PPORefinement:
    """PPO agent that perturbs parameter vectors to push their largest Jacobian eigenvalue down.

    Episode dynamics:
      * ``p0 ~ N(0, p0_init_std**2)`` per dimension, clamped to
        ``[min_x_bounds, max_x_bounds]``; a latent ``z ~ N(0, I)`` is drawn once and
        kept fixed for the episode.
      * At each of ``T_horizon`` steps the actor samples an additive update to ``p``.
        The update is clipped to ``action_clip_range`` and the new ``p`` is clamped
        to the bounds.
      * The step reward is ``1 / (1 + exp(k * (lambda_max(p_next) - eig_partition_final_reward)))``,
        so ``lambda_max`` values below the partition value (-2.5) earn more than 0.5.

    State layout: ``[p_t (param_dim), z (latent_dim), lambda_max(p_t), t]``.

    Concurrency: episodes are collected on a thread pool. ``chk_jcbn`` is driven as
    ``_prepare_parameters(...)`` followed by an argument-less
    ``calc_eigenvalues_recal_vmax()``, i.e. the prepared parameters are stored on
    the checker object. A single shared checker is therefore only used under a
    lock, which serialises eigenvalue evaluations. To evaluate in parallel, pass
    ``chk_jcbn_factory``: a zero-argument callable returning a fully loaded,
    independent checker. It is called once per worker thread.
    """

    def __init__(self, param_dim, latent_dim, min_x_bounds, max_x_bounds,
                 names_km_full, chk_jcbn,
                 actor_hidden_dims=(256, 512, 1024), critic_hidden_dims=(256, 512, 1024),
                 p0_init_std=1, actor_lr=1e-4, critic_lr=1e-4,
                 gamma=0.99, epsilon=0.2, gae_lambda=0.95,
                 ppo_epochs=10, num_episodes_per_update=64,
                 T_horizon=5, k_reward_steepness=1.0,
                 action_clip_range=(-0.1, 0.1),
                 entropy_coeff=0.01, max_grad_norm=0.5,
                 num_threads=8, chk_jcbn_factory=None):
        """Build actor/critic networks, their Adam optimisers and the episode thread pool.

        Args:
            param_dim: length of the parameter vector ``p`` (also the action size).
            latent_dim: length of the per-episode latent vector ``z`` in the state.
            min_x_bounds, max_x_bounds: clamp range for ``p``.
            names_km_full: parameter names passed to ``chk_jcbn._prepare_parameters``.
            chk_jcbn: Jacobian checker used to evaluate ``lambda_max``.
            p0_init_std: std of the Gaussian used to draw ``p0``.
            epsilon: PPO clipping range for the probability ratio.
            T_horizon: steps per episode.
            k_reward_steepness: slope of the sigmoid reward.
            action_clip_range: ``(low, high)`` clip applied to each action before stepping.
            num_threads: worker threads used to collect episodes.
            chk_jcbn_factory: optional zero-argument callable returning an independent
                checker per worker thread; ``None`` shares ``chk_jcbn`` under a lock.
        """
        self.param_dim = param_dim
        self.latent_dim = latent_dim  # size of z in the state
        self.min_x_bounds = min_x_bounds
        self.max_x_bounds = max_x_bounds
        self.p0_init_std = p0_init_std
        self.num_threads = num_threads

        self.names_km_full = names_km_full
        self.chk_jcbn = chk_jcbn
        self.chk_jcbn_factory = chk_jcbn_factory

        # Per-thread checker cache; holds distinct objects only when a factory is given.
        self.thread_local = threading.local()
        # Makes prepare + eigenvalue evaluation atomic on the shared checker.
        self._jcbn_lock = threading.Lock()
        # Keeps progress lines printed by different worker threads from interleaving.
        self._print_lock = threading.Lock()

        # State: p_t (param_dim) + z (latent_dim) + lambda_max (1) + t (1)
        self.state_dim = self.param_dim + self.latent_dim + 1 + 1
        self.action_dim = self.param_dim  # actor outputs additive updates to p_t

        self.actor = Actor(self.state_dim, self.action_dim, hidden_dims=actor_hidden_dims)
        self.critic = Critic(self.state_dim, hidden_dims=critic_hidden_dims)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)

        self.gamma = gamma
        self.epsilon = epsilon
        self.gae_lambda = gae_lambda
        self.ppo_epochs = ppo_epochs
        self.num_episodes_per_update = num_episodes_per_update
        self.T_horizon = T_horizon
        self.k_reward_steepness = k_reward_steepness
        self.action_clip_range = action_clip_range
        self.entropy_coeff = entropy_coeff
        self.max_grad_norm = max_grad_norm

        # lambda_max value at which the reward sigmoid equals 0.5.
        self.eig_partition_final_reward = -2.5

        # Episode thread pool; train() shuts it down and _get_executor() recreates it.
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.num_threads)

    def _get_executor(self):
        """Return the episode thread pool, creating a new one if the previous one was shut down."""
        if self.executor is None:
            self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.num_threads)
        return self.executor

    def _get_thread_local_chk_jcbn(self):
        """Return the checker the calling thread should use.

        With ``chk_jcbn_factory`` each thread gets (and caches) its own checker;
        otherwise every thread gets the shared ``self.chk_jcbn``, which callers must
        only use while holding ``self._jcbn_lock``.
        """
        if not hasattr(self.thread_local, 'chk_jcbn'):
            if self.chk_jcbn_factory is not None:
                self.thread_local.chk_jcbn = self.chk_jcbn_factory()
            else:
                self.thread_local.chk_jcbn = self.chk_jcbn
        return self.thread_local.chk_jcbn

    def _eval_lambda_max(self, checker, p_numpy):
        """Load one parameter vector into ``checker`` and return its first max-eigenvalue entry."""
        checker._prepare_parameters([p_numpy], self.names_km_full)
        max_eig_list = checker.calc_eigenvalues_recal_vmax()
        return max_eig_list[0]

    def _get_lambda_max_single(self, p_tensor_single):
        """Compute ``lambda_max`` for one parameter tensor; safe to call from several threads."""
        p_numpy = p_tensor_single.detach().cpu().numpy()
        checker = self._get_thread_local_chk_jcbn()
        if checker is self.chk_jcbn:
            # Shared checker: another thread must not overwrite the prepared
            # parameters between _prepare_parameters and the eigenvalue call.
            with self._jcbn_lock:
                return self._eval_lambda_max(checker, p_numpy)
        return self._eval_lambda_max(checker, p_numpy)

    def _get_lambda_max_batch(self, p_tensor_batch):
        """Compute ``lambda_max`` for every row of ``p_tensor_batch``; results keep row order.

        Uses its own short-lived pool rather than ``self.executor`` so it can be
        called from inside an episode worker without waiting on that worker's pool.
        """
        with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_threads) as pool:
            return list(pool.map(self._get_lambda_max_single, p_tensor_batch))

    def _get_lambda_max(self, p_tensor_single):
        """Single-vector entry point kept for API compatibility."""
        return self._get_lambda_max_single(p_tensor_single)

    def _compute_reward(self, lambda_max_val):
        """Sigmoid reward ``1 / (1 + exp(k * (lambda_max - eig_partition_final_reward)))``.

        Evaluated as ``exp(-log(1 + exp(x)))`` via ``np.logaddexp`` so very large
        eigenvalues give ~0 without an overflow warning.
        """
        x = self.k_reward_steepness * (lambda_max_val - self.eig_partition_final_reward)
        # TODO: the incidence part of the reward is not used yet.
        return np.exp(-np.logaddexp(0.0, x))

    @staticmethod
    def _make_state(p_torch, z_torch, lambda_max_val, t_step):
        """Concatenate ``[p, z, lambda_max, t]`` into one flat float32 state vector."""
        return torch.cat((
            p_torch, z_torch,
            torch.tensor([lambda_max_val], dtype=torch.float32),
            torch.tensor([t_step], dtype=torch.float32),
        ))

    def _collect_trajectories(self):
        """Collect ``num_episodes_per_update`` episodes concurrently and stack them.

        Returns:
            ``(states, actions, log_probs, rewards, next_states, dones, avg_episode_reward)``.
            Each episode's transitions are contiguous and in step order; episodes
            appear in submission order. If any episode raises, episodes not yet
            started are cancelled and the exception propagates.
        """
        executor = self._get_executor()
        episode_futures = [executor.submit(self._collect_single_episode, i_episode)
                           for i_episode in range(self.num_episodes_per_update)]

        batch_states, batch_actions, batch_log_probs_actions = [], [], []
        batch_rewards, batch_next_states, batch_dones = [], [], []
        all_episode_total_rewards = []

        try:
            # Gathering in submission order keeps the batch layout independent of thread timing.
            for future in episode_futures:
                ep_result = future.result()
                if not ep_result:
                    continue
                (ep_states, ep_actions, ep_log_probs, ep_rewards,
                 ep_next_states, ep_dones, ep_total_reward) = ep_result

                batch_states.extend(ep_states)
                batch_actions.extend(ep_actions)
                batch_log_probs_actions.extend(ep_log_probs)
                batch_rewards.extend(ep_rewards)
                batch_next_states.extend(ep_next_states)
                batch_dones.extend(ep_dones)
                all_episode_total_rewards.append(ep_total_reward)
        except BaseException:
            # Don't keep burning CPU on queued episodes whose batch will be discarded.
            for future in episode_futures:
                future.cancel()
            raise

        final_batch_states = torch.stack(batch_states) if batch_states else torch.empty(0, self.state_dim)
        final_batch_actions = torch.stack(batch_actions) if batch_actions else torch.empty(0, self.action_dim)
        final_batch_log_probs = torch.stack(batch_log_probs_actions) if batch_log_probs_actions else torch.empty(0, 1)
        final_batch_rewards = torch.stack(batch_rewards) if batch_rewards else torch.empty(0, 1)
        final_batch_next_states = torch.stack(batch_next_states) if batch_next_states else torch.empty(0, self.state_dim)
        final_batch_dones = torch.stack(batch_dones) if batch_dones else torch.empty(0, 1)

        avg_episode_reward_val = np.mean(all_episode_total_rewards) if all_episode_total_rewards else 0

        return (final_batch_states, final_batch_actions, final_batch_log_probs,
                final_batch_rewards, final_batch_next_states, final_batch_dones,
                avg_episode_reward_val)

    def _collect_single_episode(self, i_episode):
        """Roll out one episode with the current actor (runs on a worker thread).

        Returns:
            ``(states, actions, log_probs, rewards, next_states, dones, total_reward)``;
            the first six are per-step lists of tensors.
        """
        with self._print_lock:
            print(f"Collecting episode data {i_episode + 1}/{self.num_episodes_per_update}...")

        ep_states, ep_actions, ep_log_probs = [], [], []
        ep_rewards, ep_next_states, ep_dones = [], [], []

        # p0: Gaussian sample clamped to the bounds; no parameters are held fixed.
        p_curr_np = np.random.normal(0, self.p0_init_std, size=self.param_dim)
        p_curr_np = np.clip(p_curr_np, self.min_x_bounds, self.max_x_bounds)
        # z is drawn once and kept for the whole episode.
        z_curr_np = np.random.normal(0, 1, size=self.latent_dim)

        p_curr_torch = torch.tensor(p_curr_np, dtype=torch.float32)
        z_torch_ep = torch.tensor(z_curr_np, dtype=torch.float32)

        episode_total_reward = 0
        # lambda_max(p_curr). After the first step it equals the previous step's
        # lambda_max(p_next), so the checker runs T_horizon + 1 times per episode
        # instead of 2 * T_horizon.
        lambda_max_pt_val = None

        for t_s in range(self.T_horizon):
            if lambda_max_pt_val is None:
                lambda_max_pt_val = self._get_lambda_max(p_curr_torch)
            state_torch_flat = self._make_state(p_curr_torch, z_torch_ep, lambda_max_pt_val, t_s)

            with torch.no_grad():
                action_mean, action_std = self.actor(state_torch_flat.unsqueeze(0))
                dist = Normal(action_mean, action_std)
                action = dist.sample()
                log_prob_action = dist.log_prob(action).sum(dim=-1)

            ep_states.append(state_torch_flat)
            ep_actions.append(action.squeeze(0))
            ep_log_probs.append(log_prob_action)

            # The stored log-prob refers to the unclipped sample; only the
            # environment transition uses the clipped step.
            action_clipped = torch.clamp(action.squeeze(0), self.action_clip_range[0], self.action_clip_range[1])
            p_next_torch = torch.clamp(p_curr_torch + action_clipped, self.min_x_bounds, self.max_x_bounds)

            lambda_max_p_next_val = self._get_lambda_max(p_next_torch)
            reward_val = self._compute_reward(lambda_max_p_next_val)
            is_final_step = (t_s == self.T_horizon - 1)

            ep_rewards.append(torch.tensor([reward_val], dtype=torch.float32))
            episode_total_reward += reward_val
            ep_next_states.append(self._make_state(p_next_torch, z_torch_ep, lambda_max_p_next_val, t_s + 1))
            ep_dones.append(torch.tensor([1.0 if is_final_step else 0.0], dtype=torch.float32))

            p_curr_torch = p_next_torch
            lambda_max_pt_val = lambda_max_p_next_val

        return (ep_states, ep_actions, ep_log_probs, ep_rewards,
                ep_next_states, ep_dones, episode_total_reward)

    def _compute_gae(self, rewards, values, next_values, dones):
        """Generalised Advantage Estimation over a flat batch of transitions.

        A ``done`` of 1 stops both bootstrapping from ``next_values`` and the
        backward accumulation, so consecutive episodes in the batch stay separate.

        Returns:
            ``(advantages, returns)`` with ``returns = advantages + values``.
        """
        if rewards.nelement() == 0:
            return torch.zeros_like(rewards), torch.zeros_like(values)

        advantages = torch.zeros_like(rewards)
        last_advantage = 0
        for t in reversed(range(len(rewards))):
            not_done = 1.0 - dones[t]
            delta = rewards[t] + self.gamma * next_values[t] * not_done - values[t]
            last_advantage = delta + self.gamma * self.gae_lambda * not_done * last_advantage
            advantages[t] = last_advantage
        returns = advantages + values
        return advantages, returns

    def update(self, trajectories_data):
        """Run ``ppo_epochs`` full-batch clipped-PPO updates of the actor and critic.

        Args:
            trajectories_data: tuple returned by ``_collect_trajectories``.

        Returns:
            ``(avg_actor_loss, avg_critic_loss)`` averaged over the epochs, or
            ``(0.0, 0.0)`` for an empty batch.
        """
        states, actions, log_probs_old, rewards, next_states, dones, _ = trajectories_data

        if states.nelement() == 0:
            return 0.0, 0.0

        with torch.no_grad():
            values = self.critic(states)
            next_values = self.critic(next_states)

        advantages, returns = self._compute_gae(rewards, values, next_values, dones)
        if advantages.nelement() == 0:
            return 0.0, 0.0

        # std() of a single element is NaN and would make every loss NaN, so a
        # one-transition batch is left unnormalised.
        if advantages.nelement() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        advantages = advantages.detach()
        returns = returns.detach()
        log_probs_old = log_probs_old.detach()

        actor_total_loss_epoch = 0
        critic_total_loss_epoch = 0

        for _ in range(self.ppo_epochs):
            current_pi_mean, current_pi_std = self.actor(states)
            dist_new = Normal(current_pi_mean, current_pi_std)
            log_probs_new = dist_new.log_prob(actions).sum(dim=-1, keepdim=True)
            entropy = dist_new.entropy().mean()

            ratios = torch.exp(log_probs_new - log_probs_old)
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1.0 - self.epsilon, 1.0 + self.epsilon) * advantages
            actor_loss = -torch.min(surr1, surr2).mean() - self.entropy_coeff * entropy

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
            self.actor_optimizer.step()
            actor_total_loss_epoch += actor_loss.item()

            values_pred = self.critic(states)
            critic_loss = (returns - values_pred).pow(2).mean()

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
            self.critic_optimizer.step()
            critic_total_loss_epoch += critic_loss.item()

        # Guard against ppo_epochs == 0 (totals are then 0).
        n_epochs = max(self.ppo_epochs, 1)
        avg_actor_loss = actor_total_loss_epoch / n_epochs
        avg_critic_loss = critic_total_loss_epoch / n_epochs
        return avg_actor_loss, avg_critic_loss

    def train(self, num_iterations, output_path_base="ppo_training_output"):
        """Alternate trajectory collection and PPO updates for ``num_iterations`` iterations.

        Actor/critic ``state_dict``s are written to ``output_path_base`` every 50
        iterations and after the last one. The episode thread pool is shut down
        when training ends, also when it ends with an exception; a later call to
        ``train`` creates a new pool.

        Returns:
            List with the average episode reward of each iteration.
        """
        os.makedirs(output_path_base, exist_ok=True)

        all_iter_avg_rewards = []
        print(f"Starting PPO training for {num_iterations} iterations with {self.num_threads} threads.")
        print(f"State dim: {self.state_dim}, Action dim: {self.action_dim}, Latent (z) dim: {self.latent_dim}")
        print(f"Num episodes per update: {self.num_episodes_per_update}, Horizon T: {self.T_horizon}")
        print(f"p0 initialized with N(0, {self.p0_init_std**2}) and clamped to [{self.min_x_bounds}, {self.max_x_bounds}]")

        try:
            for iteration in range(num_iterations):
                trajectories_data = self._collect_trajectories()
                avg_episode_reward = trajectories_data[-1]

                if trajectories_data[0].nelement() == 0 and self.num_episodes_per_update > 0:
                    print(f"Iter {iteration:04d}: No trajectories collected. Skipping update. Avg Ep Reward: {avg_episode_reward:.4f}")
                    all_iter_avg_rewards.append(avg_episode_reward)
                    continue

                actor_loss, critic_loss = self.update(trajectories_data)
                all_iter_avg_rewards.append(avg_episode_reward)

                print(f"Iter {iteration:04d}: Avg Ep Reward: {avg_episode_reward:.4f}, "
                      f"Actor Loss: {actor_loss:.4f}, Critic Loss: {critic_loss:.4f}")

                if iteration % 50 == 0 or iteration == num_iterations - 1:
                    actor_path = os.path.join(output_path_base, f"actor_iter_{iteration}.pth")
                    critic_path = os.path.join(output_path_base, f"critic_iter_{iteration}.pth")
                    torch.save(self.actor.state_dict(), actor_path)
                    torch.save(self.critic.state_dict(), critic_path)
        finally:
            # Release worker threads even on failure; _get_executor() rebuilds the pool on reuse.
            if self.executor is not None:
                self.executor.shutdown()
                self.executor = None

        print("Training finished.")
        return all_iter_avg_rewards

# Train PPO agent

In [4]:
# Parse run settings from the config file. ConfigParser.read() silently skips
# files it cannot open, which would otherwise surface later as a confusing
# KeyError on the first section lookup, so fail fast here.
configs = ConfigParser()
if not configs.read('configfile.ini'):
    raise FileNotFoundError("configfile.ini not found or unreadable in the current working directory")

n_samples = int(configs['MLP']['n_samples'])  # NOTE(review): not used by the PPO cells below

lnminkm = float(configs['CONSTRAINTS']['min_km'])  # lower clamp for parameter vectors
lnmaxkm = float(configs['CONSTRAINTS']['max_km'])  # upper clamp for parameter vectors

repeats = int(configs['EVOSTRAT']['repeats'])          # independent PPO runs
generations = int(configs['EVOSTRAT']['generations'])  # used as num_iterations for PPO
ss_idx = int(configs['EVOSTRAT']['ss_idx'])            # steady-state profile index
# The PPO agent's worker-thread count is set where the agent is instantiated;
# EVOSTRAT.n_threads is not read here.

output_path = configs['PATHS']['output_path']
met_model = configs['PATHS']['met_model']
# Full list of parameter names. Loaded from a pickle: only use trusted model files.
names_km_config = hp.load_pkl(f'models/{met_model}/parameter_names_km_fdp1.pkl')

# Dimensions passed directly to PPORefinement
param_dim_config = int(configs['MLP']['no_kms'])
latent_dim_config = int(configs['MLP']['latent_dim'])  # size of the z vector in the state

# Jacobian checker (SKiMpy-based); the PPO agent uses it to evaluate lambda_max.
chk_jcbn = check_jacobian()

In [5]:
# Populate the shared Jacobian checker: first the kinetic/thermodynamic model,
# then the steady-state profile selected by ss_idx, both for the 'fdp1' dataset.
dataset_tag = 'fdp1'

print('---- Load kinetic and thermodynamic data')
chk_jcbn._load_ktmodels(met_model, dataset_tag)

print('---- Load steady state data')
chk_jcbn._load_ssprofile(met_model, dataset_tag, ss_idx)

---- Load kinetic and thermodynamic data


2025-05-12 17:18:13,194 - thermomodel_new - INFO - # Model initialized with units kcal/mol and temperature 298.15 K


---- Load steady state data


In [6]:
# Train one independent PPO agent per repeat. Each repeat writes its actor/critic
# checkpoints and per-iteration average rewards under <output_path>/ppo_repeat_<rep>.
PPO_NUM_THREADS = 32  # worker threads used by each agent for episode collection

print('--- Begin PPO refinement strategy')
for rep in range(repeats):
    this_savepath = os.path.join(output_path, f'ppo_repeat_{rep}')
    os.makedirs(this_savepath, exist_ok=True)

    ppo_agent = PPORefinement(
        param_dim=param_dim_config,
        latent_dim=latent_dim_config,
        min_x_bounds=lnminkm,
        max_x_bounds=lnmaxkm,
        names_km_full=names_km_config,  # full list of parameter names
        chk_jcbn=chk_jcbn,              # shared Jacobian checker loaded above
        p0_init_std=1,
        num_threads=PPO_NUM_THREADS,
    )

    print(f"Repeat {rep}: Starting PPO training for {generations} iterations ({PPO_NUM_THREADS} threads).")
    ppo_iteration_rewards = ppo_agent.train(
        num_iterations=generations,  # 'generations' from the config doubles as PPO iterations
        output_path_base=this_savepath,
    )

    rewards_path = os.path.join(this_savepath, 'ppo_iteration_rewards.pkl')
    hp.save_pkl(rewards_path, ppo_iteration_rewards)
    print(f"Repeat {rep}: PPO training finished. Rewards log saved to {this_savepath}")


--- Begin PPO refinement strategy
Repeat 0: Starting PPO training for 25 iterations (serial execution).
Starting PPO training for 25 iterations with 32 threads.
State dim: 485, Action dim: 384, Latent (z) dim: 99
Num episodes per update: 64, Horizon T: 5
p0 initialized with N(0, 1) and clamped to [-25.0, 3.0]
Collecting episode data 1/64...
Collecting episode data 2/64...Collecting episode data 3/64...

Collecting episode data 4/64...
Collecting episode data 5/64...
Collecting episode data 6/64...
Collecting episode data 7/64...
Collecting episode data 8/64...
Collecting episode data 9/64...
Collecting episode data 10/64...
Collecting episode data 11/64...
Collecting episode data 12/64...
Collecting episode data 13/64...
Collecting episode data 14/64...
Collecting episode data 15/64...
Collecting episode data 16/64...Collecting episode data 17/64...

Collecting episode data 18/64...
Collecting episode data 19/64...
Collecting episode data 20/64...
Collecting episode data 21/64...
Colle

ValueError: array must not contain infs or NaNs