In [None]:
import wandb

import torch
import torch.nn as nn
from gymnasium import spaces
from qiskit.transpiler import CouplingMap
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

from ai_linear_function_synthesis.env import *

In [None]:
# --- Problem definition ---------------------------------------------------
# Linear-chain coupling map built over NUM_QUBITS qubits.
NUM_QUBITS = 5
COUPLING_MAP = CouplingMap.from_line(NUM_QUBITS)

# --- Environment settings -------------------------------------------------
# Forwarded to the environments as `eval_batch_size` and
# `success_rate_threshold` respectively.
BATCH_SIZE = 100
SUCCESS_RATE_THRESHOLD = 0.8

# --- Training settings ----------------------------------------------------
# RL algorithm class and the `total_timesteps` budget given to `learn`.
model_class = PPO
NUM_TIME_STEPS = 500_000

# Settings attached to every wandb run via `wandb.init(config=config)`.
config = dict(
    num_qubits=NUM_QUBITS,
    coupling_map=COUPLING_MAP,
    eval_batch_size=BATCH_SIZE,
    success_rate_threshold=SUCCESS_RATE_THRESHOLD,
    model_class=model_class.__name__,
    num_time_steps=NUM_TIME_STEPS,
)

# Pick the torch device: MPS wins over CUDA when both report available,
# and CPU is the fallback when neither does.
if torch.backends.mps.is_available():
    device_name = "mps"
elif torch.cuda.is_available():
    device_name = "cuda"
else:
    device_name = "cpu"
print(f"device_name: {device_name}")

In [None]:
class CustomCNN(BaseFeaturesExtractor):
    """
    Features extractor made of a single 3x3 convolution followed by a
    two-layer fully connected head.

    ref: stable-baselines3 documentation & https://github.com/greentfrapp/snake

    :param observation_space: (spaces.Box) Observation space. Its first
        dimension is taken as the number of input channels, i.e. a
        channels-first (CxHxW) layout is assumed.
    :param features_dim: (int) Number of features extracted.
        This is the width of both linear layers, including the last one.
    """

    def __init__(self, observation_space: spaces.Box, features_dim: int = 64):
        super().__init__(observation_space, features_dim)

        # Channels-first observations: the channel count is the leading
        # dimension. Any re-ordering must happen upstream (pre-processing
        # or a wrapper).
        in_channels = observation_space.shape[0]
        conv = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.cnn = nn.Sequential(conv, nn.ReLU(), nn.Flatten())

        # Find the flattened size empirically by pushing one sampled
        # observation (with a leading batch axis) through the conv stack.
        with torch.no_grad():
            probe = torch.as_tensor(observation_space.sample()[None]).float()
            flat_dim = self.cnn(probe).shape[1]

        self.linear = nn.Sequential(
            nn.Linear(flat_dim, features_dim),
            nn.ReLU(),
            nn.Linear(features_dim, features_dim),
            nn.ReLU(),
        )

        # Replace the default initialisation of every Conv2d / Linear layer.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform weights and zero biases for Linear / Conv2d modules.

        Modules of any other type are left untouched.
        """
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            return
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Run the conv stack, then the linear head, on a batch of observations."""
        flattened = self.cnn(observations)
        return self.linear(flattened)

## Curriculum learning + sparse reward

In [None]:
# Curriculum-learning experiment with the sparse-reward environment.
#
# The wandb run is used as a context manager so that it is always finished,
# including when training raises or the cell is interrupted. Previously
# `wandb.finish()` was only reached on the success path, which could leave
# the run open.
with wandb.init(
    project="ai-linear-function-synthesis",
    name="run_curriculum",
    config=config,
):
    # Built inside the run because the environment is asked to log to
    # wandb (`wandb_log=True`).
    env = AILinearFunctionSynthesis(
        coupling_map=COUPLING_MAP,
        eval_batch_size=BATCH_SIZE,
        success_rate_threshold=SUCCESS_RATE_THRESHOLD,
        wandb_log=True,
    )

    # Initialize the model with the custom CNN features extractor.
    model = model_class(
        policy="CnnPolicy",
        env=env,
        verbose=1,
        policy_kwargs=dict(
            features_extractor_class=CustomCNN,
            features_extractor_kwargs=dict(features_dim=64),
        ),
        device=device_name,
    )

    # Train the model
    model.learn(
        total_timesteps=NUM_TIME_STEPS,
        log_interval=10**5,
        progress_bar=True,
    )

    # Persist the trained policy before the run is closed.
    model.save("saved_models/lin_func_curriculum")

## No curriculum learning + sparse reward

In [None]:
# Baseline experiment: sparse reward without curriculum learning.
#
# The wandb run is used as a context manager so that it is always finished,
# including when training raises or the cell is interrupted. Previously
# `wandb.finish()` was only reached on the success path, which could leave
# the run open.
with wandb.init(
    project="ai-linear-function-synthesis",
    name="run_no_curriculum",
    config=config,
):
    # Built inside the run because the environment is asked to log to
    # wandb (`wandb_log=True`). This variant is constructed without a
    # success-rate threshold.
    env = AILinearFunctionSynthesisNoCurriculumLearning(
        coupling_map=COUPLING_MAP,
        eval_batch_size=BATCH_SIZE,
        wandb_log=True,
    )

    # Initialize the model with the custom CNN features extractor.
    model = model_class(
        policy="CnnPolicy",
        env=env,
        verbose=1,
        policy_kwargs=dict(
            features_extractor_class=CustomCNN,
            features_extractor_kwargs=dict(features_dim=64),
        ),
        device=device_name,
    )

    # Train the model
    model.learn(
        total_timesteps=NUM_TIME_STEPS,
        log_interval=10**5,
        progress_bar=True,
    )

    # Persist the trained policy before the run is closed.
    model.save("saved_models/lin_func_no_curriculum")

## Curriculum learning + dense reward

In [None]:
# Curriculum-learning experiment with the dense-reward environment.
#
# The wandb run is used as a context manager so that it is always finished,
# including when training raises or the cell is interrupted. Previously
# `wandb.finish()` was only reached on the success path, which could leave
# the run open.
with wandb.init(
    project="ai-linear-function-synthesis",
    name="run_dense",
    config=config,
):
    # Built inside the run because the environment is asked to log to
    # wandb (`wandb_log=True`).
    env = AILinearFunctionSynthesisDenseReward(
        coupling_map=COUPLING_MAP,
        eval_batch_size=BATCH_SIZE,
        success_rate_threshold=SUCCESS_RATE_THRESHOLD,
        wandb_log=True,
    )

    # Initialize the model with the custom CNN features extractor.
    model = model_class(
        policy="CnnPolicy",
        env=env,
        verbose=1,
        policy_kwargs=dict(
            features_extractor_class=CustomCNN,
            features_extractor_kwargs=dict(features_dim=64),
        ),
        device=device_name,
    )

    # Train the model
    model.learn(
        total_timesteps=NUM_TIME_STEPS,
        log_interval=10**5,
        progress_bar=True,
    )

    # Persist the trained policy before the run is closed.
    model.save("saved_models/lin_func_dense")