# In [None]:
import warnings
import os
import sys
from math import comb
import torch
warnings.filterwarnings('ignore')

# Put the parent of the current working directory at the front of the import
# path so that the `src.*` imports below can be resolved from this notebook.
project_root = os.path.abspath(os.pardir)
if project_root not in sys.path:
    sys.path.insert(0, project_root)
from src.rl_environment import defining_environments
from src.rl_policy import UnifiedPolicy
from src.cortical_estimator import CORTICAL
from src.ba_estimator import MI_ESTIMATOR
from src.utils import set_seed, train_cases

# In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration: seed, problem size, initial thresholds,
# training hyper-parameters, MI estimator and policy.
# ---------------------------------------------------------------------------
set_seed(42)  # project helper; called before any tensor/model is created

dimension = 2
num_thresholds = 4
# alphabet_size = sum_{k=0}^{dimension} C(num_thresholds, k)
alphabet_size = sum(comb(num_thresholds, k) for k in range(dimension + 1))
box_param = 1.5
mi_estimator = 'BA'  # supported values: 'BA' or 'CORTICAL'

# Hand-picked initial thresholds keyed by (dimension, num_thresholds).
# Every preset holds num_thresholds * (dimension + 1) numbers.
# NOTE(review): presumably (dimension + 1) coefficients per threshold —
# confirm against defining_environments.
_THRESHOLD_PRESETS = {
    (1, 1): [1.0, 0.0],
    (1, 2): [1.0, -0.5, 1.0, 0.5],
    (1, 3): [1.0, -0.5, 1.0, 0.0, 1.0, 0.5],
    (2, 2): [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
    (2, 3): [0.0, 1.0, 0.0, 1.0, 0.0, -0.25, 1.0, 0.0, 0.25],
    (2, 4): [1.0, 0.0, -0.25, 1.0, 0.0, 0.25, 0.0, 1.0, 0.25, 0.0, 1.0, -0.25],
}
# Configurations without a preset keep thrsh = None, exactly as before.
_preset = _THRESHOLD_PRESETS.get((dimension, num_thresholds))
thrsh = torch.tensor(_preset) if _preset is not None else None

num_envs, max_steps, patience, num_episodes = 10, 2000, 100, 5
norm_patience = (2000, 2000)  # a 2-tuple, passed through unchanged below
kl_coeff = 0.3
lr = 1e-3

# Build the mutual-information estimator selected by `mi_estimator`.
if mi_estimator == 'BA':
    mi_est = MI_ESTIMATOR(dimension, box_param, 'identity-csi', 10000)
elif mi_estimator == 'CORTICAL':
    mi_est = CORTICAL(
        (512, dimension, alphabet_size, num_thresholds, num_envs),
        alphas=[1.0, 0.5, 0.1, 0.01],
        lambda_entropy=0.3,
        cost_coef=10.0,
        box_param=box_param,
    )
    # The CORTICAL estimator loads its models from disk before use.
    mi_est.load_models('./models/cortical_models/')
else:
    # Fail here with a clear message instead of a NameError on `mi_est` below.
    raise ValueError(
        f"Unknown mi_estimator {mi_estimator!r}; expected 'BA' or 'CORTICAL'."
    )

policy = UnifiedPolicy(
    dimension,
    alphabet_size,
    num_thresholds,
    box_param,
    kl_coeff,
    mi_est,
    policy_scale=3,
)

# Identifier handed to train_cases for this run.
run_id = 'run0'

# Create the environments. The call returns three values; the second one
# rebinds `thrsh`, the third is kept as `qtpts`.
envs, thrsh, qtpts = defining_environments(
    dimension,
    num_thresholds,
    alphabet_size,
    box_param,
    (-10.0, 40.0),  # NOTE(review): meaning of this range is set by defining_environments — confirm
    num_envs,
    max_steps,
    patience,
    mi_est,
    norm_patience,
    True,   # NOTE(review): positional flag — check the callee's signature
    False,  # NOTE(review): positional flag — check the callee's signature
    thrsh,
)

# Render the first environment and display the resulting figure.
fig = envs[0].render()
fig.show()

# Train; the returned object replaces the current `policy`.
policy = train_cases(
    dimension,
    num_thresholds,
    alphabet_size,
    box_param,
    num_envs,
    max_steps,
    patience,
    num_episodes,
    norm_patience,
    lr,
    mi_est,
    policy,
    run_id,
)