In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [1]:
import os


def strip_notebook_dir(path, suffix="/note/experiments"):
    """Return ``path`` with a trailing ``suffix`` removed.

    Only a suffix at the very end of ``path`` is stripped. Occurrences in the
    middle of the path, or as a prefix of a longer directory name such as
    ``/note/experiments_old``, are left alone.

    Args:
        path: Directory path to normalise.
        suffix: Trailing path fragment to drop when present.

    Returns:
        ``path`` without the trailing ``suffix``, or ``path`` unchanged when it
        does not end with ``suffix`` (or ``suffix`` is empty).
    """
    if suffix and path.endswith(suffix):
        return path[: -len(suffix)]
    return path


# Run the rest of the notebook from the project root so that root-relative
# paths (e.g. "note/experiments/config.yaml") resolve. Re-running this cell is
# a no-op once the working directory no longer ends with the suffix.
cwd = strip_notebook_dir(os.getcwd())
os.chdir(cwd)


In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import xtools as xt
import json
import xsim
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import clear_output
from tqdm import tqdm
from collections import namedtuple


In [3]:
from qml.db import dpo as xdpo
from qml.db import target as xtarget
from qml.db.ml import MLDatasetGenerator
from qml.db.dpo.decoder import DPODataDecoder, DPOData
from qml.db.dpo.loader import DPODataBatch
from qml.model.gate import Gateset
from qml.model.unit import UnitManager, Unit
from qml.model.encoding import EncodingUnitManager
from qml.model.model import Model
from qml.optimizer import dpo as xdpopt
from qml.optimizer import evaluator as xeval
from qml.tools.random import XRandomGenerator
from qml.tools.sampler import CandidateSampler
from qml import optimizer as xoptim
from qml.tools.config import Config
# from qml.tools.validation import validate
from qml.tools.validation import get_base_qc
from qml.tools.experiment_tools import train_once, hard_copy, prepare_policy, plot_results, validate, generate_batch, soft_update

In [13]:
# Experiment configuration object; the path is relative, so it presumably
# resolves against the current working directory — TODO confirm in Config.
cf = Config("note/experiments/config.yaml")
# Seed handed to XRandomGenerator in the experiment cell. None presumably
# requests non-deterministic seeding — TODO confirm; set an integer for
# reproducible runs.
seed = None

In [14]:
# Checkpoints of the trained policies, given relative to the project root
# (the working directory of this notebook), like the config path, instead of
# an absolute path that only exists on one machine.
results_dir = "note/experiments/results/2025.04.14.195932"
online_dpo_weight_filepath = f"{results_dir}/results_00730.pth"
# NOTE(review): this is the same checkpoint file as the DPO policy above, so
# both policies are loaded with identical weights — confirm whether a separate
# CPO checkpoint was intended here.
online_cpo_weight_filepath = f"{results_dir}/results_00730.pth"

In [18]:
# build policies
policy_dpo, _, _, sampler_dpo = prepare_policy(cf)
policy_dpo.load_state_dict(torch.load(online_dpo_weight_filepath))
policy_cpo, _, _, sampler_cpo = prepare_policy(cf)
policy_cpo.load_state_dict(torch.load(online_cpo_weight_filepath))


<All keys matched successfully>

In [24]:
# Single seeded generator; the dataset seeds below are drawn from it in a
# fixed order (target function first, then dataset generator).
rng = XRandomGenerator(seed)
# The return value is not used in this cell; presumably the call configures
# Gateset for cf.nq qubits — TODO confirm.
gset = Gateset.set_num_qubits(cf.nq)

# target datasets
tfun = xtarget.PolynominalTargetFunctionGenerator(cf.qml.db.dim_polynomial, seed=rng.new_seed())
tgen = MLDatasetGenerator(tfun, seed=rng.new_seed())
Dqml = tgen.generate(cf.qml.db.size)
# (for a quick smoke run, generate a tiny dataset instead: tgen.generate(5))

# prepare
base_model = get_base_qc(cf.nq, cf.qml.db.dim_input, cf.qml.db.dim_output, cf.shots)
weval = xeval.WaveletEvaluator(cf.wavelet, Dqml, wavelet_dim=cf.ocg.dim_wavelet)

# Evaluate the DPO policy; only its sampler is used in the loop below.
policy = policy_dpo
sampler = sampler_dpo
model = Model(
    base_model.nq, base_model.nc,
    base_model.input_units,
    base_model.fixed_units,
    base_model.trainable_units,
)

# Number of sample-and-train rounds to run.
# The full validation schedule would be:
#   range(1, cf.dpo.validation.num_rounds + 1 + 2)
num_rounds = 1

# Bind the result up front so a zero-round run prints None instead of raising
# NameError or showing a stale value left over from an earlier execution.
tresult = None

# The loop variable is deliberately not called `round`: that would shadow the
# builtin round() for every later cell of the notebook.
for round_idx in tqdm(range(num_rounds)):
    # 1. measure the wavelet series
    wresult = weval(model.trainable_parameters, model)

    # 2. estimate the candidate unit
    wseries = wresult.powers
    candidate = sampler.sample(wseries)
    # NOTE(review): this is called on the model that is replaced right below;
    # whether it affects the rebuilt model depends on whether Model shares its
    # unit containers with base_model — confirm.
    model.fix_trainable_units()

    # 3. update model
    model = Model(
        base_model.nq, base_model.nc,
        base_model.input_units,
        base_model.fixed_units,
        candidate,
    )

    # 4. train the model
    optimizer = xoptim.LocalSearchOptimizer(Dqml)
    tresult = optimizer.optimize(model, cf.qml.num_train, verbose=False)

# Result of the last training round (None when no round was executed).
print(tresult)


  0%|          | 0/1 [00:00<?, ?it/s]

[[-0.52068123 -0.28068123 -0.40068123 -0.44068123 -0.40068123 -0.44068123
  -0.44068123 -0.12068123 -0.52068123 -0.40068123]
 [-0.64270338 -0.40270338 -0.52270338 -0.56270338 -0.52270338 -0.56270338
  -0.56270338 -0.24270338 -0.64270338 -0.52270338]
 [-0.60466196 -0.36466196 -0.48466196 -0.52466196 -0.48466196 -0.52466196
  -0.52466196 -0.20466196 -0.60466196 -0.48466196]
 [-0.80508968 -0.56508968 -0.68508968 -0.72508968 -0.68508968 -0.72508968
  -0.72508968 -0.40508968 -0.80508968 -0.68508968]
 [-0.77585951 -0.53585951 -0.65585951 -0.69585951 -0.65585951 -0.69585951
  -0.69585951 -0.37585951 -0.77585951 -0.65585951]
 [-0.50296094 -0.26296094 -0.38296094 -0.42296094 -0.38296094 -0.42296094
  -0.42296094 -0.10296094 -0.50296094 -0.38296094]
 [-0.51007931 -0.27007931 -0.39007931 -0.43007931 -0.39007931 -0.43007931
  -0.43007931 -0.11007931 -0.51007931 -0.39007931]
 [-0.15305292  0.08694708 -0.03305292 -0.07305292 -0.03305292 -0.07305292
  -0.07305292  0.24694708 -0.15305292 -0.03305292]


100%|██████████| 1/1 [00:10<00:00, 10.70s/it]

<qml.tools.logger.Logger object at 0x174768d10>



