In [60]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal

import os
import helper as hp
from configparser import ConfigParser
from ppo_refinement import PPORefinement

from kinetics.jacobian_solver import check_jacobian
import concurrent.futures
from functools import partial
import threading

# Define PPO

In [61]:
class Actor(nn.Module):
    """Gaussian policy network: maps a state to the mean and std of a diagonal Normal.

    Args:
        state_dim: Size of the last dimension of the input state.
        action_dim: Number of action components produced by each head.
        hidden_dims: Widths of the ReLU MLP trunk. May be empty, in which case the
            mean / log-std heads act directly on the state (a linear policy).
    """

    def __init__(self, state_dim, action_dim, hidden_dims=(64, 64)):
        super(Actor, self).__init__()
        self.hidden_dims = hidden_dims
        layers = []
        input_d = state_dim
        for hidden_d in hidden_dims:
            layers.append(nn.Linear(input_d, hidden_d))
            layers.append(nn.ReLU())
            input_d = hidden_d

        self.network = nn.Sequential(*layers)
        # input_d is the trunk's output width (state_dim when hidden_dims is empty);
        # indexing hidden_dims[-1] would raise IndexError for an empty trunk.
        self.mean_layer = nn.Linear(input_d, action_dim)
        self.log_std_layer = nn.Linear(input_d, action_dim)
        # Bias the log-std head so the initial std is around exp(-0.5) (weights also contribute).
        nn.init.constant_(self.log_std_layer.bias, -0.5)

    def forward(self, state):
        """Return ``(mean, std)`` for ``state``.

        The log-std is clamped to [-20, 2] before exponentiation to avoid
        numerical under/overflow of the standard deviation.
        """
        x = self.network(state)
        mean = self.mean_layer(x)
        log_std = torch.clamp(self.log_std_layer(x), -20, 2)
        std = torch.exp(log_std)
        return mean, std

class Critic(nn.Module):
    """State-value network: maps a state to a single scalar value estimate.

    Args:
        state_dim: Size of the last dimension of the input state.
        hidden_dims: Widths of the ReLU MLP trunk. May be empty, in which case the
            value head acts directly on the state.
    """

    def __init__(self, state_dim, hidden_dims=(64, 64)):
        super(Critic, self).__init__()
        self.hidden_dims = hidden_dims
        layers = []
        input_d = state_dim
        for hidden_d in hidden_dims:
            layers.append(nn.Linear(input_d, hidden_d))
            layers.append(nn.ReLU())
            input_d = hidden_d

        self.network = nn.Sequential(*layers)
        # input_d is the trunk's output width (state_dim when hidden_dims is empty);
        # indexing hidden_dims[-1] would raise IndexError for an empty trunk.
        self.value_layer = nn.Linear(input_d, 1)

    def forward(self, state):
        """Return the value estimate with a trailing dimension of size 1."""
        x = self.network(state)
        return self.value_layer(x)


In [81]:
class PPORefinement:
    """PPO agent that iteratively perturbs kinetic parameter vectors to improve an eigenvalue-based reward.

    Each episode draws an initial parameter vector ``p0`` (see ``_sample_initial_params``)
    and a latent vector ``z ~ N(0, I)``, then applies ``T_horizon`` additive updates
    ``p_{t+1} = clamp(p_t + a_t, min_x_bounds, max_x_bounds)``. Actions are Gaussian
    samples squashed with ``tanh`` into ``action_clip_range``. Rewards come from the
    eigenvalue list that ``chk_jcbn`` reports for ``p_{t+1}`` (see ``_compute_reward``).

    State layout: ``[p_t (param_dim), z (latent_dim), lambda(p_t) (1), t (1)]``.
    """

    def __init__(self, param_dim, latent_dim, min_x_bounds, max_x_bounds,
                 names_km_full, chk_jcbn,
                 actor_hidden_dims=(256, 512, 1024), critic_hidden_dims=(256, 512, 1024),
                 p0_init_std=1, actor_lr=1e-4, critic_lr=1e-4,
                 gamma=0.99, epsilon=0.2, gae_lambda=0.95,
                 ppo_epochs=10, num_episodes_per_update=32,
                 T_horizon=5, k_reward_steepness=1.0,
                 action_clip_range=(-0.1, 0.1),
                 entropy_coeff=0.01, max_grad_norm=0.5,
                 reward_flag=0):
        """
        Args:
            param_dim: Number of entries in p (also the action dimension).
            latent_dim: Size of the random latent vector z appended to the state.
            min_x_bounds, max_x_bounds: Box bounds applied to p after every update.
            names_km_full: Parameter names forwarded to ``chk_jcbn._prepare_parameters``.
            chk_jcbn: Jacobian checker used to compute eigenvalues for a parameter vector.
            actor_hidden_dims, critic_hidden_dims: MLP widths of the two networks.
            p0_init_std: Std of the Gaussian noise used to draw p0. p0 is min-max
                rescaled afterwards, so this value does not change p0's spread.
            actor_lr, critic_lr: Adam learning rates.
            gamma, gae_lambda: Discount and GAE smoothing factors.
            epsilon: PPO ratio clipping parameter.
            ppo_epochs: Full-batch gradient steps (actor and critic) per update.
            num_episodes_per_update: Episodes collected before each update.
            T_horizon: Steps per episode.
            k_reward_steepness: Sigmoid steepness used when ``reward_flag == 0``.
            action_clip_range: ``(low, high)`` interval that squashed actions are mapped into.
            entropy_coeff: Weight of the entropy bonus in the actor loss.
            max_grad_norm: Gradient-norm clipping threshold for both networks.
            reward_flag: 0 -> sigmoid of the first eigenvalue; otherwise exponential of
                the mean of the first ``n_consider`` eigenvalues.
        """
        self.param_dim = param_dim
        self.latent_dim = latent_dim  # For z in state
        self.min_x_bounds = min_x_bounds
        self.max_x_bounds = max_x_bounds
        self.p0_init_std = p0_init_std

        self.names_km_full = names_km_full
        self.chk_jcbn = chk_jcbn  # Jacobian checker instance used for every eigenvalue evaluation

        # State dim: p_t (param_dim) + z (latent_dim) + lambda_max (1) + t (1)
        self.state_dim = self.param_dim + self.latent_dim + 1 + 1
        self.action_dim = self.param_dim  # Actor outputs updates to p_t

        self.actor = Actor(self.state_dim, self.action_dim, hidden_dims=actor_hidden_dims)
        self.critic = Critic(self.state_dim, hidden_dims=critic_hidden_dims)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)

        self.gamma = gamma
        self.epsilon = epsilon
        self.gae_lambda = gae_lambda
        self.ppo_epochs = ppo_epochs
        self.num_episodes_per_update = num_episodes_per_update
        self.T_horizon = T_horizon
        self.k_reward_steepness = k_reward_steepness
        self.action_clip_range = action_clip_range
        self.entropy_coeff = entropy_coeff
        self.max_grad_norm = max_grad_norm
        self.reward_flag = reward_flag

        # Sigmoid midpoint for the reward_flag == 0 reward.
        self.eig_partition_final_reward = -2.5

    # ------------------------------------------------------------------ helpers

    def _get_lambda_max(self, p_tensor_single):
        """Return the eigenvalue list ``chk_jcbn`` reports for one parameter vector, sorted ascending.

        NOTE(review): callers use element ``[0]`` as "lambda_max". That is only the
        maximum eigenvalue if ``calc_eigenvalues_recal_vmax`` returns one max-eigenvalue
        per prepared parameter set (a single entry here). Confirm against kinetics.jacobian_solver.
        """
        p_numpy = p_tensor_single.detach().cpu().numpy()
        self.chk_jcbn._prepare_parameters([p_numpy], self.names_km_full)
        max_eig_list = self.chk_jcbn.calc_eigenvalues_recal_vmax()
        max_eig_list.sort()
        return max_eig_list

    def _sample_initial_params(self):
        """Draw p0: Gaussian noise min-max rescaled onto ``[min_x_bounds, max_x_bounds]``.

        The smallest entry lands on ``min_x_bounds`` and the largest on ``max_x_bounds``.
        If all noise entries are equal (e.g. ``param_dim == 1``), every entry is placed at
        the midpoint instead of dividing by zero.
        """
        noise = np.random.normal(0, self.p0_init_std, size=self.param_dim)
        span = noise.max() - noise.min()
        if span > 0:
            unit = (noise - noise.min()) / span
        else:
            unit = np.full(self.param_dim, 0.5)
        return unit * (self.max_x_bounds - self.min_x_bounds) + self.min_x_bounds

    @staticmethod
    def _build_state(p_torch, z_torch, lambda_max_val, t_step):
        """Concatenate ``[p, z, lambda_max, t]`` into a flat float32 state vector."""
        return torch.cat((
            p_torch, z_torch,
            torch.tensor([lambda_max_val], dtype=torch.float32),
            torch.tensor([t_step], dtype=torch.float32),
        ))

    def _squash_action(self, raw_action):
        """Map a raw (pre-tanh) Gaussian sample into ``action_clip_range`` via tanh."""
        low, high = self.action_clip_range
        scale = (high - low) / 2.0
        bias = (high + low) / 2.0
        return torch.tanh(raw_action) * scale + bias

    @staticmethod
    def _tanh_log_det_jacobian(raw_action):
        """Sum over the last dim of ``log|d tanh(x)/dx|``, computed stably as ``2*(log 2 - x - softplus(-2x))``."""
        return (2.0 * (np.log(2.0) - raw_action
                       - torch.nn.functional.softplus(-2.0 * raw_action))).sum(dim=-1)

    def _log_prob_squashed(self, dist, raw_action):
        """Log-density of the tanh-squashed action produced from ``raw_action``.

        The constant ``-action_dim * log(scale)`` term is omitted: it is identical for
        old and new policies and cancels in the PPO probability ratio.
        """
        return dist.log_prob(raw_action).sum(dim=-1) - self._tanh_log_det_jacobian(raw_action)

    # ------------------------------------------------------------------ reward

    def _compute_reward(self, lambdas_val, n_consider=10):
        """Scalar reward from an (ascending) eigenvalue list.

        reward_flag == 0: sigmoid ``1 / (1 + exp(k * (lambda_0 - eig_partition_final_reward)))``,
            and 0 when ``lambda_0 > 10``.
        otherwise: ``exp(-0.1 * mean(lambdas[:n_consider])) / 2``; the mean is taken over
            the entries actually available, which may be fewer than ``n_consider``.

        Raises:
            ValueError: if ``lambdas_val`` is empty and ``reward_flag != 0``.
        """
        if self.reward_flag == 0:
            lambda_max_val = lambdas_val[0]
            if lambda_max_val > 10:
                return 0.0
            return 1.0 / (1.0 + np.exp(self.k_reward_steepness * (lambda_max_val - self.eig_partition_final_reward)))

        considered = list(lambdas_val[:n_consider])
        if not considered:
            raise ValueError("Cannot compute reward from an empty eigenvalue list.")
        considered_avg = sum(considered) / len(considered)
        # TODO: Right now, we are not using the Incidence part of the reward.
        return np.exp(-0.1 * considered_avg) / 2

    # ------------------------------------------------------------------ rollouts

    def _collect_trajectories(self):
        """Roll out ``num_episodes_per_update`` episodes with the current actor.

        Returns:
            ``(states, raw_actions, log_probs, rewards, next_states, dones, avg_episode_reward)``
            - states / next_states: (N, state_dim)
            - raw_actions: (N, action_dim) *pre-tanh* Gaussian samples. The applied update
              is ``_squash_action(raw_action)``.
            - log_probs, rewards, dones: (N, 1)
            - avg_episode_reward: mean over episodes of each episode's per-step mean reward.
        """
        batch_states, batch_raw_actions, batch_log_probs = [], [], []
        batch_rewards, batch_next_states, batch_dones = [], [], []
        all_episode_total_rewards = []

        for i_episode in range(self.num_episodes_per_update):
            print(f"Collecting episode data {i_episode + 1}/{self.num_episodes_per_update}...")
            p_curr_torch = torch.tensor(self._sample_initial_params(), dtype=torch.float32)
            z_torch_ep = torch.tensor(np.random.normal(0, 1, size=self.latent_dim), dtype=torch.float32)

            episode_total_reward = 0.0
            lambda_max_curr = None

            for t_s in range(self.T_horizon):
                if t_s == 0:
                    # Later steps reuse the eigenvalue already computed for p_{t+1}; the solve is expensive.
                    lambda_max_curr = self._get_lambda_max(p_curr_torch)[0]
                state_torch_flat = self._build_state(p_curr_torch, z_torch_ep, lambda_max_curr, t_s)

                with torch.no_grad():
                    action_mean, action_std = self.actor(state_torch_flat.unsqueeze(0))
                    dist = Normal(action_mean, action_std)
                    raw_action = dist.sample()                            # (1, action_dim)
                    log_prob = self._log_prob_squashed(dist, raw_action)  # (1,)
                    action_step = self._squash_action(raw_action).squeeze(0)

                batch_states.append(state_torch_flat)
                batch_raw_actions.append(raw_action.squeeze(0))
                batch_log_probs.append(log_prob)

                p_next_torch = torch.clamp(p_curr_torch + action_step, self.min_x_bounds, self.max_x_bounds)

                lambdas_p_next_val = self._get_lambda_max(p_next_torch)
                reward_val = self._compute_reward(lambdas_p_next_val)
                batch_rewards.append(torch.tensor([reward_val], dtype=torch.float32))
                # Accumulates the per-step *mean* reward of the episode.
                episode_total_reward += reward_val / self.T_horizon

                is_final_step = (t_s == self.T_horizon - 1)
                batch_next_states.append(
                    self._build_state(p_next_torch, z_torch_ep, lambdas_p_next_val[0], t_s + 1))
                batch_dones.append(torch.tensor([1.0 if is_final_step else 0.0], dtype=torch.float32))

                p_curr_torch = p_next_torch
                lambda_max_curr = lambdas_p_next_val[0]

            all_episode_total_rewards.append(episode_total_reward)

        final_batch_states = torch.stack(batch_states) if batch_states else torch.empty(0, self.state_dim)
        final_batch_actions = torch.stack(batch_raw_actions) if batch_raw_actions else torch.empty(0, self.action_dim)
        final_batch_log_probs = torch.stack(batch_log_probs) if batch_log_probs else torch.empty(0, 1)
        final_batch_rewards = torch.stack(batch_rewards) if batch_rewards else torch.empty(0, 1)
        final_batch_next_states = torch.stack(batch_next_states) if batch_next_states else torch.empty(0, self.state_dim)
        final_batch_dones = torch.stack(batch_dones) if batch_dones else torch.empty(0, 1)

        avg_episode_reward_val = np.mean(all_episode_total_rewards) if all_episode_total_rewards else 0

        return (final_batch_states, final_batch_actions, final_batch_log_probs,
                final_batch_rewards, final_batch_next_states, final_batch_dones,
                avg_episode_reward_val)

    # ------------------------------------------------------------------ learning

    def _compute_gae(self, rewards, values, next_values, dones):
        """Generalized Advantage Estimation over a batch of concatenated episodes.

        Inputs are (N, 1) tensors in collection order. ``dones[t] == 1`` cuts both the
        bootstrap term and the advantage recursion, so episodes do not leak into each other.

        Returns:
            ``(advantages, returns)`` with ``returns = advantages + values``.
        """
        if rewards.nelement() == 0:
            return torch.zeros_like(rewards), torch.zeros_like(values)

        advantages = torch.zeros_like(rewards)
        last_advantage = 0
        for t in reversed(range(len(rewards))):
            not_done = 1.0 - dones[t]
            delta = rewards[t] + self.gamma * next_values[t] * not_done - values[t]
            last_advantage = delta + self.gamma * self.gae_lambda * not_done * last_advantage
            advantages[t] = last_advantage
        return advantages, advantages + values

    def update(self, trajectories_data):
        """Run ``ppo_epochs`` full-batch PPO steps on the output of ``_collect_trajectories``.

        Returns:
            ``(avg_actor_loss, avg_critic_loss)``; ``(0.0, 0.0)`` for an empty batch.
        """
        states, raw_actions, log_probs_old, rewards, next_states, dones, _ = trajectories_data

        if states.nelement() == 0:
            return 0.0, 0.0

        with torch.no_grad():
            values = self.critic(states)
            next_values = self.critic(next_states)

        advantages, returns = self._compute_gae(rewards, values, next_values, dones)
        if advantages.nelement() == 0:
            return 0.0, 0.0

        # std() of a single element is NaN; only normalise when it is defined.
        if advantages.nelement() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        advantages = advantages.detach()
        returns = returns.detach()
        log_probs_old = log_probs_old.detach()

        actor_total_loss_epoch = 0.0
        critic_total_loss_epoch = 0.0

        for _ in range(self.ppo_epochs):
            current_pi_mean, current_pi_std = self.actor(states)
            dist_new = Normal(current_pi_mean, current_pi_std)
            # Score the same raw samples that produced log_probs_old so the ratio is a valid importance weight.
            log_probs_new = self._log_prob_squashed(dist_new, raw_actions).unsqueeze(-1)
            entropy = dist_new.entropy().mean()

            ratios = torch.exp(log_probs_new - log_probs_old)
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1.0 - self.epsilon, 1.0 + self.epsilon) * advantages
            actor_loss = -torch.min(surr1, surr2).mean() - self.entropy_coeff * entropy

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
            self.actor_optimizer.step()
            actor_total_loss_epoch += actor_loss.item()

            values_pred = self.critic(states)
            critic_loss = (returns - values_pred).pow(2).mean()

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
            self.critic_optimizer.step()
            critic_total_loss_epoch += critic_loss.item()

        n_epochs = max(self.ppo_epochs, 1)
        return actor_total_loss_epoch / n_epochs, critic_total_loss_epoch / n_epochs

    def train(self, num_iterations, output_path_base="ppo_training_output", checkpoint_every=50):
        """Alternate trajectory collection and PPO updates for ``num_iterations`` iterations.

        Writes ``best_actor.pth`` / ``best_critic.pth`` whenever the average episode reward
        improves. These are the weights of the policy that *collected* that reward, saved
        before the update. Every ``checkpoint_every`` iterations, and at the last iteration,
        ``actor_iter_<i>.pth`` / ``critic_iter_<i>.pth`` are also written; a falsy
        ``checkpoint_every`` keeps only the final one.

        Returns:
            List of the average episode reward of every iteration.
        """
        import os
        os.makedirs(output_path_base, exist_ok=True)

        all_iter_avg_rewards = []
        best_avg_reward = float('-inf')
        best_actor_path = os.path.join(output_path_base, "best_actor.pth")
        best_critic_path = os.path.join(output_path_base, "best_critic.pth")

        print(f"Starting PPO training for {num_iterations} iterations (serial execution).")
        print(f"State dim: {self.state_dim}, Action dim: {self.action_dim}, Latent (z) dim: {self.latent_dim}")
        print(f"Num episodes per update: {self.num_episodes_per_update}, Horizon T: {self.T_horizon}")
        print(f"p0 drawn from N(0, {self.p0_init_std**2}) and min-max rescaled to [{self.min_x_bounds}, {self.max_x_bounds}]")

        for iteration in range(num_iterations):
            trajectories_data = self._collect_trajectories()
            avg_episode_reward = trajectories_data[-1]
            all_iter_avg_rewards.append(avg_episode_reward)

            if trajectories_data[0].nelement() == 0 and self.num_episodes_per_update > 0:
                print(f"Iter {iteration:04d}: No trajectories collected. Skipping update. Avg Ep Reward: {avg_episode_reward:.4f}")
                continue

            # The reward was earned by the current (pre-update) policy, so save it now.
            if avg_episode_reward > best_avg_reward:
                best_avg_reward = avg_episode_reward
                torch.save(self.actor.state_dict(), best_actor_path)
                torch.save(self.critic.state_dict(), best_critic_path)
                print(f"Iter {iteration:04d}: New best model saved with Avg Ep Reward: {best_avg_reward:.4f}")

            actor_loss, critic_loss = self.update(trajectories_data)
            print(f"Iter {iteration:04d}: Avg Ep Reward: {avg_episode_reward:.4f}, "
                  f"Actor Loss: {actor_loss:.4f}, Critic Loss: {critic_loss:.4f}")

            is_periodic = bool(checkpoint_every) and iteration % checkpoint_every == 0
            if is_periodic or iteration == num_iterations - 1:
                actor_path = os.path.join(output_path_base, f"actor_iter_{iteration}.pth")
                critic_path = os.path.join(output_path_base, f"critic_iter_{iteration}.pth")
                torch.save(self.actor.state_dict(), actor_path)
                torch.save(self.critic.state_dict(), critic_path)

        print("Training finished.")
        return all_iter_avg_rewards


# Train PPO agent

In [82]:
# Parse run settings from the config file (fail loudly if it is missing, instead of a KeyError later).
configs = ConfigParser()
if not configs.read('configfile.ini'):
    raise FileNotFoundError("configfile.ini not found in the current working directory")

n_samples = int(configs['MLP']['n_samples'])  # not referenced by the PPO cells shown here

lnminkm = float(configs['CONSTRAINTS']['min_km'])
lnmaxkm = float(configs['CONSTRAINTS']['max_km'])

repeats = int(configs['EVOSTRAT']['repeats'])
generations = int(configs['EVOSTRAT']['generations'])  # used as num_iterations for PPO
ss_idx = int(configs['EVOSTRAT']['ss_idx'])
# n_threads is not read: PPO trajectory collection is single-threaded.

# Optional reproducibility: set `seed = <int>` under [EVOSTRAT]; rollouts stay unseeded if absent.
seed = configs['EVOSTRAT'].getint('seed', fallback=None)
if seed is not None:
    np.random.seed(seed)
    torch.manual_seed(seed)

output_path = configs['PATHS']['output_path']
met_model = configs['PATHS']['met_model']
names_km_config = hp.load_pkl(f'models/{met_model}/parameter_names_km_fdp1.pkl')  # Full list of param names

# Parameters needed directly by PPORefinement
param_dim_config = int(configs['MLP']['no_kms'])
latent_dim_config = int(configs['MLP']['latent_dim'])  # For z vector in state

# Jacobian checker from SKimPy: the PPO agent uses it for every eigenvalue evaluation.
chk_jcbn = check_jacobian()

In [83]:
# Populate the Jacobian checker with the kinetic/thermodynamic model and the
# steady-state profile (index ss_idx) that every eigenvalue evaluation relies on.

print('---- Load kinetic and thermodynamic data')
chk_jcbn._load_ktmodels(met_model, 'fdp1')

print('---- Load steady state data')
chk_jcbn._load_ssprofile(met_model, 'fdp1', ss_idx)

---- Load kinetic and thermodynamic data


2025-05-13 17:02:30,900 - thermomodel_new - INFO - # Model initialized with units kcal/mol and temperature 298.15 K


---- Load steady state data


In [84]:
print('--- Begin PPO refinement strategy')
for rep in range(repeats):
    # Each repeat trains a fresh agent and writes its checkpoints to its own folder.
    this_savepath = f'{output_path}/ppo_repeat_{rep}/'
    os.makedirs(this_savepath, exist_ok=True)

    ppo_agent = PPORefinement(
        param_dim=param_dim_config,
        latent_dim=latent_dim_config,
        min_x_bounds=lnminkm,
        max_x_bounds=lnmaxkm,
        names_km_full=names_km_config,  # Full list of parameter names
        chk_jcbn=chk_jcbn,              # Shared Jacobian checker instance
        p0_init_std=1,                  # Same as the constructor default
        ppo_epochs=30,
        k_reward_steepness=1.0,         # Only affects the reward when reward_flag == 0
        reward_flag=1,
    )

    print(f"Repeat {rep}: Starting PPO training for {generations} iterations (serial execution).")
    ppo_iteration_rewards = ppo_agent.train(
        num_iterations=generations,  # 'generations' from config doubles as PPO iterations
        output_path_base=this_savepath
    )

    hp.save_pkl(os.path.join(this_savepath, 'ppo_iteration_rewards.pkl'), ppo_iteration_rewards)
    print(f"Repeat {rep}: PPO training finished. Rewards log saved to {this_savepath}")


--- Begin PPO refinement strategy
Repeat 0: Starting PPO training for 25 iterations (serial execution).
Starting PPO training for 25 iterations (serial execution).
State dim: 485, Action dim: 384, Latent (z) dim: 99
Num episodes per update: 32, Horizon T: 5
p0 initialized with N(0, 1) and clamped to [-25.0, 3.0]
Collecting episode data 1/32...
Collecting episode data 2/32...
Collecting episode data 3/32...
Collecting episode data 4/32...
Collecting episode data 5/32...
Collecting episode data 6/32...
Collecting episode data 7/32...
Collecting episode data 8/32...
Collecting episode data 9/32...
Collecting episode data 10/32...
Collecting episode data 11/32...
Collecting episode data 12/32...
Collecting episode data 13/32...
Collecting episode data 14/32...
Collecting episode data 15/32...
Collecting episode data 16/32...
Collecting episode data 17/32...
Collecting episode data 18/32...
Collecting episode data 19/32...
Collecting episode data 20/32...
Collecting episode data 21/32...
Co

In [87]:
def evaluate_policy_incidence(ppo_instance, actor_path, num_trials=1, deterministic=False, verbose=True):
    """
    Estimate how often rollouts of a saved actor end in a valid model.

    Each trial draws p0 (Gaussian noise min-max rescaled onto the parameter bounds)
    and z ~ N(0, I), then rolls the policy out for ``ppo_instance.T_horizon`` steps.
    Each action is tanh-squashed into ``ppo_instance.action_clip_range``, matching how
    PPORefinement applies actions during training. A trial is valid when the first
    entry of ``ppo_instance._get_lambda_max(p_final)`` is
    ``<= ppo_instance.eig_partition_final_reward``.

    Side effects: loads the weights into ``ppo_instance.actor`` and switches it to
    eval mode (the instance is modified in place).

    Args:
        ppo_instance: The PPORefinement instance supplying the actor, bounds and checker.
        actor_path: Path to the saved actor state dict (e.g. best_actor.pth).
        num_trials: Number of independent rollouts.
        deterministic: If True, act with the policy mean instead of sampling.
        verbose: If True, print each trial's final eigenvalue.

    Returns:
        incidence_rate: Fraction of valid trials (0.0 when num_trials <= 0).
        all_final_params: Final parameter vectors (numpy) of the valid trials.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    ppo_instance.actor.load_state_dict(torch.load(actor_path, map_location="cpu"))
    ppo_instance.actor.eval()

    # Same tanh squashing as training: raw sample -> tanh -> affine map onto the clip range.
    low, high = ppo_instance.action_clip_range
    action_scale = (high - low) / 2.0
    action_bias = (high + low) / 2.0

    valid_count = 0
    all_final_params = []

    for _ in range(max(num_trials, 0)):
        with torch.no_grad():
            noise = np.random.normal(0, ppo_instance.p0_init_std, size=ppo_instance.param_dim)
            span = noise.max() - noise.min()
            # Guard the degenerate case (e.g. param_dim == 1) instead of dividing by zero.
            unit = (noise - noise.min()) / span if span > 0 else np.full(ppo_instance.param_dim, 0.5)
            p0_np = unit * (ppo_instance.max_x_bounds - ppo_instance.min_x_bounds) + ppo_instance.min_x_bounds
            p_curr = torch.tensor(p0_np, dtype=torch.float32)

            z = torch.tensor(np.random.normal(0, 1, size=ppo_instance.latent_dim), dtype=torch.float32)

            for t_s in range(ppo_instance.T_horizon):
                lambda_max_val = ppo_instance._get_lambda_max(p_curr)[0]
                state_torch = torch.cat((
                    p_curr,
                    z,
                    torch.tensor([lambda_max_val], dtype=torch.float32),
                    torch.tensor([t_s], dtype=torch.float32)
                )).unsqueeze(0)

                action_mean, action_std = ppo_instance.actor(state_torch)
                raw_action = action_mean if deterministic else Normal(action_mean, action_std).sample()
                action = torch.tanh(raw_action.squeeze(0)) * action_scale + action_bias
                p_curr = torch.clamp(p_curr + action, ppo_instance.min_x_bounds, ppo_instance.max_x_bounds)

            p_final_np = p_curr.detach().cpu().numpy()
            # Use the same (sorted) eigenvalue helper as the training reward.
            max_eig_val = ppo_instance._get_lambda_max(p_curr)[0]

        if verbose:
            print(max_eig_val)
        if max_eig_val <= ppo_instance.eig_partition_final_reward:
            valid_count += 1
            all_final_params.append(p_final_np)

    incidence_rate = valid_count / num_trials if num_trials > 0 else 0.0
    print(f"Incidence Rate (valid models): {incidence_rate:.4f} ({valid_count}/{num_trials})")
    return incidence_rate, all_final_params

In [88]:
# Evaluate the best actor of repeat 0, read from the same output_path the training cell wrote to.
incidence_rate, all_final_params = evaluate_policy_incidence(
    ppo_agent,
    actor_path=os.path.join(output_path, "ppo_repeat_0", "best_actor.pth"),
    num_trials=50,
)

-0.8366565126703308
-0.3800016951214321
-1.99830588230128
-0.4950441624175344
1.8939371083605685
-0.1287878808091703
-0.5420135140380844
2.811543767125017
89.79009964834373
1.1905271860053948
-0.7440562845678245
-0.26093326038648024
281.33711289760083
1.8959274821311063
21.076483413500537
-0.7848911422596059
-0.46239764871201394
-0.0030001931842974267
-0.7472626009699929
-0.25284471214816406
-0.14477922602576504
2894.0299066180614
28.143691154887954
-0.42334207193272766
-0.23080973333406413
-0.6674578873282064
-0.26084045480645185
0.03949169982462305
-0.0711469562871184
-1.2096061977805528
-1.989381638100718
-0.11895712013903316
-0.005023202219284921
-0.2697773826847881
20065.50846423974
-0.9702740914229454
-1.6513761774943059
36.01367575364909
6.342956395388841
-0.19269096471234842
1.9035265583539625
-0.08565397724548474
-0.048998832546863595
1.2558307326107339
358.45473654409363
33.51080792573376
-1.3218516227764643
-0.0695718884783448
-0.11163448234629156
-0.2409647641202906
Inciden