In [None]:
import json
from pathlib import Path
import numpy as np
import pandas as pd
from loguru import logger
import pygmo as pg

from solarmed_optimization import (EnvironmentVariables,
                                   ProblemParameters,
                                   ProblemSamples)
from solarmed_optimization.utils import (evaluate_model,
                                         decision_vector_to_decision_variables,
                                         validate_dec_var_updates)
from solarmed_optimization.utils.initialization import problem_initialization
from solarmed_optimization.problems import BaseMinlpProblem, get_bounds, evaluate_fitness

logger.disable("phd_visualizations")

#%% Constants
# Input / output locations, relative to the notebook's working directory
data_path: Path = Path("./data")
output_path: Path = Path("./results")
fsm_data_path: Path = Path("./results")  # same directory as `output_path`

# Dataset day(s) to optimize.
# Other values used previously: "20230707_20230710", "20230630"
date_str: str = "20230703"

# Population size for the optimization algorithm
pop_size: int = 10


### Initialize problem

In [None]:
# Either load the saved parameters from JSON or create a new default instance
with open(output_path / "problem_params.json", encoding="utf-8") as f:
    problem_params = ProblemParameters(**json.load(f))
# problem_params = ProblemParameters()

# Build the problem data: samples configuration, parameters, data frame and model
problem_data = problem_initialization(problem_params=problem_params,
                                      date_str=date_str,
                                      data_path=data_path)

ps: ProblemSamples = problem_data.problem_samples
pp: ProblemParameters = problem_data.problem_params
df: pd.DataFrame = problem_data.df
model = problem_data.model

# Result accumulators — presumably filled by a later optimization loop (TODO confirm)
df_hors: list[pd.DataFrame] = []
df_sim: pd.DataFrame = None

# Fill missing data
# df['med_vac_state'] = 2

# Only the first optimization step is run in this cell
opt_step_idx: int = 0
idx_mod = pp.idx_start
# Row-index span (end exclusive) of the prediction horizon: the
# `n_evals_mod_in_hor_window` samples right after the current one
hor_span = (idx_mod + 1, idx_mod + 1 + ps.n_evals_mod_in_hor_window)

print()
print(f"Optimization step {opt_step_idx+1}/{ps.max_opt_steps}")

## Environment variables predictions over the horizon
ds = df.iloc[hor_span[0]:hor_span[1]]
# `iloc` silently truncates past the end of `df`; fail loudly instead of building
# environment arrays whose length disagrees with the cost arrays below
if len(ds) != ps.n_evals_mod_in_hor_window:
    raise ValueError(
        f"Horizon {hor_span} needs {ps.n_evals_mod_in_hor_window} samples, "
        f"but only {len(ds)} are available in `df`"
    )

n_hor = ps.n_evals_mod_in_hor_window
env_vars: EnvironmentVariables = EnvironmentVariables(
    I=ds['I'].values,
    Tamb=ds['Tamb'].values,
    Tmed_c_in=ds['Tmed_c_in'].values,
    # Costs are held constant over the whole horizon
    cost_w=np.full(n_hor, pp.env_params.cost_w, dtype=float),
    cost_e=np.full(n_hor, pp.env_params.cost_e, dtype=float),
)

## Initialize problem
base_problem = BaseMinlpProblem(
    model=model,
    sample_time_opt=pp.sample_time_opt,
    optim_window_time=pp.optim_window_time,
    env_vars=env_vars,
    dec_var_updates=pp.dec_var_updates,
    fsm_valid_sequences=pp.fsm_valid_sequences,
    fsm_data_path=fsm_data_path
)


## EvoX specific code

In [None]:
from evox import Problem, State
# from jax import random

class MinlpProblem(Problem):
    """EvoX `Problem` adapter around a `BaseMinlpProblem`.

    Fitness computation is delegated to `evaluate_fitness`. The EvoX state is
    passed through untouched, because this adapter keeps no EvoX state of its own.
    """

    # Wrapped problem whose fitness is evaluated
    problem_instance: BaseMinlpProblem

    def __init__(self, problem_instance: BaseMinlpProblem) -> None:
        super().__init__()
        self.problem_instance = problem_instance

    def evaluate(self, state: State, x: np.ndarray) -> tuple[float, State]:
        """Compute the fitness of the candidate(s) `x`.

        Args:
            state: EvoX state for this problem; returned unchanged.
            x: Candidate decision vector(s) proposed by the algorithm.
                NOTE(review): EvoX presumably passes the whole population here,
                as a (pop_size, n_dec_vars) JAX array — confirm that
                `evaluate_fitness` accepts a batch rather than a single vector.

        Returns:
            A `(fitness, state)` tuple.
        """
        fitness = evaluate_fitness(self.problem_instance, x)
        return fitness, state


In [None]:
from evox import algorithms, workflows, monitors, use_state
import jax.numpy as jnp
from jax import random

# Decision-variable bounds of the wrapped problem
lb, ub = get_bounds(base_problem)

# Algorithm: particle swarm optimization within the bounds [lb, ub]
algo = algorithms.PSO(
    lb=jnp.array(lb),
    ub=jnp.array(ub),
    # Use the configured constant rather than a duplicated literal
    pop_size=pop_size,
)

# Problem: EvoX adapter around the base MINLP problem
problem = MinlpProblem(problem_instance=base_problem)

# Monitor: queried later for the latest and best solutions/fitness
monitor = monitors.EvalMonitor()

workflow = workflows.StdWorkflow(
    algorithm=algo,
    problem=problem,
    monitors=[monitor],
)

# Fixed PRNG seed so that runs are reproducible
seed: int = 42
state = workflow.init(random.PRNGKey(seed))


In [None]:
# Advance the optimizer and report the population after every generation.
# `use_state(...)(state)` yields (value, state); only the value is shown here.
n_generations: int = 10
for i in range(n_generations):
    state = workflow.step(state)
    print(f"Generation {i}")
    latest_solution = use_state(monitor.get_latest_solution)(state)[0]
    print(f"Solution space: \n{latest_solution}")
    latest_fitness = use_state(monitor.get_latest_fitness)(state)[0]
    print(f"Fitness: \n{latest_fitness}")


In [None]:
# Query the monitor for the best result seen so far;
# each `use_state` call returns (value, updated_state)
best_fitness, state = use_state(monitor.get_best_fitness)(state)
best_solution, state = use_state(monitor.get_best_solution)(state)

print(f"Best fitness so far: {best_fitness}")
print(f"Best solution so far: {best_solution}")

monitor.plot()
