In [9]:
import sparse
import torch
import datetime
import os
import json


from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3 import PPO
from sb3_contrib.common.wrappers import ActionMasker
from sb3_contrib.ppo_mask import MaskablePPO

import gymnasium as gym


In [10]:
import sys

# Extend the import search path so the project-local `src.*` and `data.*`
# imports below can be resolved.
# NOTE(review): this is a machine-specific absolute path; every user has to
# adapt it to their own checkout before the cell can run.
sys.path.append('/home/laia/Documents/GNN-DRL-qbit-allocation/')  # Change this to the absolute path to src
from src.environment.env import GraphSeriesEnv
from src.models.ppo_policy import CustomPPOPolicy
from src.models.feature_extractor import GATv2FeatureExtractor
from data.circuit_generator import generate_circuit
from src.utils.callback import CustomTensorboardCallback


In [11]:
# Sanity check: ask PyTorch whether a CUDA device is usable; the boolean is
# shown as this cell's output.
torch.cuda.is_available()

True

In [13]:
from src.utils.constants import N_CORES, OUT_FEATURES_GAT


# Timestamp identifier for this run, formatted month_day-hour_minute_second.
run_id = datetime.datetime.now().strftime("%m_%d-%H_%M_%S")


# Keyword arguments forwarded to the feature extractor through
# config['policy_kwargs']['features_extractor_kwargs'].
gnn_config = {
    'in_features': 2 * N_CORES,
    'edge_dim': 2,
    'hidden_features_gat': 64,
    'out_features_gat': OUT_FEATURES_GAT,
    'num_heads': 1,
    'hidden_layers_gat': 4,
    'gat_dropout': 0.3,
}


# Single experiment configuration. Key order is kept stable because this dict
# is serialised to JSON further down in the cell.
config = {
    # Passed to GraphSeriesEnv as `config`.
    'circuit_config': {
        'circuit': None,
        'random_circuits': True,
        'n_slices': 32,
        'n_qubits': 8,
        'gates_per_slice': 1,
    },
    # Passed to GraphSeriesEnv as `weights_reward`.
    'weights_reward': {
        'nonlocal': 10,
        'capacity': 30,
        'intervention': 40,
        'slice_idx': 50,
    },
    # Unpacked as keyword arguments into the MaskablePPO constructor.
    'ppo_config': {
        'batch_size': 64,
        'learning_rate': 0.0005,
        'n_steps': 4096,
        'gae_lambda': 0.95,
    },
    # Passed to MaskablePPO as `policy_kwargs`.
    'policy_kwargs': {
        'net_arch': [64, 64, 64],
        'features_extractor_kwargs': gnn_config,
    },
    # Training budget handed to model.learn().
    'total_timesteps': 3_000_000,
}


# Checkpoint switch and location, read by the model-construction cell.
load_model = False
load_path = 'models/'

# Per-run output locations derived from the run identifier.
run_dir = f"runs/{run_id}/"
model_dir = f'models/{run_id}/'

os.makedirs(model_dir, exist_ok=True)

# Persist the configuration next to the model artifacts. This happens before
# the extractor class is attached below -- presumably because a class object
# cannot be serialised by json.dump.
with open(os.path.join(model_dir, 'config.json'), 'w') as f:
    json.dump(config, f, indent=4)

# Attach the feature-extractor class only after the JSON snapshot was written.
config['policy_kwargs'].update(features_extractor_class=GATv2FeatureExtractor)


In [14]:
# environment

# Load a sparse array from disk (judging by the file name, presumably a
# Cuccaro-adder circuit on 8 qubits -- TODO confirm).
# NOTE(review): `a` is only referenced by the commented-out override below, so
# in the visible code this load has no effect on the environment.
a = sparse.load_npz('../data/cuccaroadder_q8.npz')
#circuit_config = {'circuit': a}


def make_env():
    """Build one action-masked GraphSeriesEnv from the global `config`.

    Returns:
        The GraphSeriesEnv wrapped in an ActionMasker whose mask function
        delegates to the environment's `qbit_mask()` method.
    """
    def _qbit_action_mask(wrapped_env):
        # ActionMasker calls this with the wrapped environment.
        return wrapped_env.qbit_mask()

    base_env = GraphSeriesEnv(
        config=config['circuit_config'],
        weights_reward=config['weights_reward'],
    )
    # Monitor wrapping is intentionally left disabled here.
    return ActionMasker(base_env, _qbit_action_mask)


# Vectorise a single environment built by make_env, then print the wrapped
# environment so the applied wrappers are visible in the cell output.
env = DummyVecEnv([make_env])
print(env.envs[0])

<ActionMasker<GraphSeriesEnv instance>>


In [15]:
# model

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

if not load_model:
    # Fresh agent: custom policy, hyper-parameters taken from `config`.
    model = MaskablePPO(
        policy=CustomPPOPolicy,
        env=env,
        policy_kwargs=config['policy_kwargs'],
        tensorboard_log='runs',
        seed=42,
        verbose=1,
        device=device,
        **config['ppo_config'],
    )
else:
    # Resume from the checkpoint at `load_path`, bound to the current env.
    model = MaskablePPO.load(load_path, env=env, device=device)

print(model.device)

# Handle on the policy's feature extractor for later inspection.
feature_extr = model.policy.features_extractor

Using cuda device
{}
cuda


In [16]:
# train

# Callback constructed with this run's model directory as its save path.
callback = CustomTensorboardCallback(save_path=model_dir, verbose=1)

# reset_num_timesteps=False keeps the model's current timestep counter instead
# of resetting it, so a loaded checkpoint continues its logging where it left off.
# The trailing semicolon suppresses the notebook echo of learn()'s return value,
# keeping the cell output limited to the training log.
model.learn(
    total_timesteps=config['total_timesteps'],
    callback=callback,
    reset_num_timesteps=False,
    tb_log_name=run_id,
);


Logging to runs/01_10-16_22_27_0


-----------------------------------------------
| episode/                         |          |
|    direct_capacity_violation_sum | 0        |
|    final_reward                  | -2640.0  |
|    intervention_sum              | 34       |
|    nl_comm_sum                   | 128      |
| time/                            |          |
|    fps                           | 68       |
|    iterations                    | 1        |
|    time_elapsed                  | 59       |
|    total_timesteps               | 4096     |
| total/                           |          |
|    truncated                     | 0        |
-----------------------------------------------
--------------------------------------------------
| episode/                         |             |
|    direct_capacity_violation_sum | 0           |
|    final_reward                  | -2460.0     |
|    intervention_sum              | 30          |
|    nl_comm_sum                   | 126         |
| time/               

'\nmodel.learn(\n    total_timesteps=wandb_config["total_timesteps"],\n    callback=WandbCallback(\n        gradient_save_freq=100,\n        #model_save_path=f"models/{run.id}",\n        verbose=1,\n    ),\n)\n'