In [7]:
%load_ext autoreload
%autoreload 2


In [8]:
import os
os.chdir("../..")


## Considerer la format de sortie de l'apprentissage par renforcement

Input of the RL system is a wavelet series of the regidual error comme the state vector


Here the original quantum circuit is written as the folloing:
$$
\Sigma_\ell = \underset{j \in J}{\otimes} \Sigma_{\ell j} \longleftarrow a_\ell = \left< \: g_\ell, \; q_\ell \: \right> \longleftarrow \pi\left( s_\ell ; \; \phi \right)
$$

State consists the partial wavelet series of the residual error vector:
$$
s_\ell = F_w \left( \varepsilon_\ell; \; \omega \right) \\

\varepsilon_\ell = \left[ \: t_k - \left< \: B \rho_\ell \left( x_k; \; \theta_L \right) \: \right> | \; \forall \left( x_k, t_k \right) \in \mathcal{D} \: \right]
$$

In [574]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import json


In [199]:
from qml.model.gate import Gateset
from qml.model.unit import UnitManager, Unit


In [562]:
num_qubits = 2
num_gates = 3
dim_wavelet = 4

batch_size = 4

beta = 0.5


In [141]:
class Policy(nn.Module):

    def __init__(self, dim_wavelet: int, num_qubits: int, num_gates: int, dim_hiddens: list[int] = [32, 32]):
        super().__init__()

        self.nq = num_qubits
        self.ng = num_gates
        gset = Gateset.set_num_qubits(num_qubits)

        self.dim_gateinfo = gset.size * self.ng
        self.dim_qubitinfo = self.ng * self.nq
        self.dim_output = dim_output = self.dim_gateinfo + self.dim_qubitinfo
        self.dim_input = dim_input = 2 ** dim_wavelet - 1

        dim_units = dim_hiddens.copy()
        dim_units.append(dim_output)
        dim_units.insert(0, dim_input)
        self.net = nn.Sequential(
            *sum([
                self.build_layer(din, dout, activation=(l < len(dim_hiddens)))
                for l, din, dout in zip(range(len(dim_units)+1), dim_units[:-1], dim_units[1:]) 
            ], [])
        )
    
    @staticmethod
    def build_layer(din, dout, activation=True):
        layer = [nn.Linear(din, dout)]
        if activation:
            layer.append(nn.ReLU())
        return layer
    
    def forward(self, x):
        feat = self.net(x)
        return feat


actor = Policy(dim_wavelet, num_qubits, num_gates)
xs = torch.from_numpy(np.random.rand(batch_size, 2**dim_wavelet-1)).float()
act_features = actor(xs)
act_features.shape
# act_features = act_features.view(-1, 2, num_gates)
# actions = torch.argmax(act_features, dim=-1)
# print(act_features)
# print(actions)


torch.Size([4, 18])

In [307]:
def calc_prob_table(logits: torch.Tensor, num_gates: int):
    logits = logits.reshape(num_gates, -1)
    probs = torch.softmax(logits, 1)
    return probs

def select_idx_at_prob(probs: torch.Tensor):
    np_probs = probs.detach().numpy()
    size = np_probs.shape[1]
    selectables = np.arange(size)
    idxs = np.array([
        np.random.choice(selectables, replace=True, p=prob)
        for prob in np_probs
    ])
    return idxs

def get_onehot(idx, num_classes):
    onehot = nn.functional.one_hot(torch.from_numpy(idx), num_classes)
    return onehot

def calc_selected_probs(probs, onehot):
    filtered = probs * onehot
    probs = filtered.sum(1)
    return probs


In [166]:
num_candidates = 5

batch_size = 4
len_wavelet_series = 2 ** dim_wavelet - 1
gset = Gateset.set_num_qubits(num_qubits)

dim_gateinfo = gset.size * num_gates

uman = UnitManager(num_qubits, num_gates)


In [148]:
# prepare models
policy = Policy(dim_wavelet, num_qubits, num_gates)
policy_ref = Policy(dim_wavelet, num_qubits, num_gates)
policy_ref.load_state_dict(policy.state_dict())


<All keys matched successfully>

Training with buffer
- wavelet seriese
- list[unit json]
- list[loss]

1. predict prob table
2. calc probs for units
3. calc loss of DPO
3. training step!

In [180]:
gdict = {key: i for i, key in enumerate(gset.keys())}
gdict


{'RX': 0, 'RY': 1, 'RZ': 2, 'CZ': 3}

In [170]:
# wavelet series
states = np.random.rand(2 ** dim_wavelet - 1)

# unit json
units = [uman.generate_random_unit() for _ in range(num_candidates)]
ujsons = [unit.to_json() for unit in units]

# losses for each unit
losses = np.random.rand(num_candidates)
print(losses)


[0.13829247 0.69440971 0.12206243 0.09769475 0.42599988]


## NEW data gen

In [228]:
def generate_random_data():
    states = np.random.rand(2 ** dim_wavelet - 1)
    udicts = [unit.to_dict() for unit in units]
    losses = np.random.rand(num_candidates)
    return states, udicts, losses


In [235]:
def make_db_json(pwseriese, udicts, losses):
    if isinstance(pwseriese, np.ndarray):
        pwseriese = pwseriese.tolist()
    if isinstance(udicts[0], Unit):
        udicts = [unit.to_dict() for unit in units]
    if isinstance(losses, np.ndarray):
        losses = losses.tolist()
    data = dict(
        pwseries=pwseriese,
        units=udicts,
        losses=losses,
    )
    djson = json.dumps(data)
    return djson


In [236]:
djson = make_db_json(states, units, losses)
djson


'{"pwseries": [0.5610219497431205, 0.4248068396481297, 0.32671685122822425, 0.005159735984395897, 0.5024070739362936, 0.10545909912556395, 0.17577469853905225, 0.49217431117780874, 0.053339417202042605, 0.2658375751933949, 0.35820264999165874, 0.3539492500895648, 0.8988563619370747, 0.024339032191310617, 0.973204539282582], "units": [{"name": "unit_15", "gates": ["cz", "ry", "cz"], "qubits": [1, 0, 0], "params": [0.0]}, {"name": "unit_16", "gates": ["rx", "rx", "rz"], "qubits": [1, 0, 0], "params": [0.0, 0.0, 0.0]}, {"name": "unit_17", "gates": ["rz", "rz", "cz"], "qubits": [1, 0, 1], "params": [0.0, 0.0]}, {"name": "unit_18", "gates": ["rx", "rx", "ry"], "qubits": [1, 0, 0], "params": [0.0, 0.0, 0.0]}, {"name": "unit_19", "gates": ["rx", "rz", "cz"], "qubits": [0, 0, 1], "params": [0.0, 0.0]}], "losses": [0.1648565772331949, 0.027398155840465788, 0.7185280222624457, 0.12626717648477725, 0.7594827934668941]}'

In [237]:
rdata = json.loads(djson)
rdata


{'pwseries': [0.5610219497431205,
  0.4248068396481297,
  0.32671685122822425,
  0.005159735984395897,
  0.5024070739362936,
  0.10545909912556395,
  0.17577469853905225,
  0.49217431117780874,
  0.053339417202042605,
  0.2658375751933949,
  0.35820264999165874,
  0.3539492500895648,
  0.8988563619370747,
  0.024339032191310617,
  0.973204539282582],
 'units': [{'name': 'unit_15',
   'gates': ['cz', 'ry', 'cz'],
   'qubits': [1, 0, 0],
   'params': [0.0]},
  {'name': 'unit_16',
   'gates': ['rx', 'rx', 'rz'],
   'qubits': [1, 0, 0],
   'params': [0.0, 0.0, 0.0]},
  {'name': 'unit_17',
   'gates': ['rz', 'rz', 'cz'],
   'qubits': [1, 0, 1],
   'params': [0.0, 0.0]},
  {'name': 'unit_18',
   'gates': ['rx', 'rx', 'ry'],
   'qubits': [1, 0, 0],
   'params': [0.0, 0.0, 0.0]},
  {'name': 'unit_19',
   'gates': ['rx', 'rz', 'cz'],
   'qubits': [0, 0, 1],
   'params': [0.0, 0.0]}],
 'losses': [0.1648565772331949,
  0.027398155840465788,
  0.7185280222624457,
  0.12626717648477725,
  0.7594827

In [247]:
def decode_gate_indices(udicts):
    # convert json to indices
    # 1 gates
    sgates = [uinfo["gates"] for uinfo in udicts]
    igates = [
        [gdict[sg.upper()] for sg in sgate]
        for sgate in sgates
    ]
    return igates

def decode_qubits(udicts):
    # 2 qubits
    iqubits = [uinfo["qubits"] for uinfo in udicts]
    
    return iqubits


In [239]:
decode_udicts(rdata["units"])


([[3, 1, 3], [0, 0, 2], [2, 2, 3], [0, 0, 1], [0, 2, 3]],
 [[1, 0, 0], [1, 0, 0], [1, 0, 1], [1, 0, 0], [0, 0, 1]])

In [241]:
x = torch.tensor(rdata["pwseries"])


### Consider le batch

In [383]:
class BatchData:

    def __init__(self, pwserieses, gate_indices, qubits, losses, num_gate_classes, num_qubits):
        self._pwserieses = np.asarray(pwserieses)
        self._gate_indices = np.asarray(gate_indices)
        self._qubits = np.asarray(qubits)
        self._losses = np.asarray(losses)
        self.size = len(qubits)
        self.num_gate_classes = num_gate_classes
        self.num_qubits = num_qubits
        self.num_cadicates = self.losses.shape[-1]
    
    @property
    def pwserieses(self):
        return torch.from_numpy(self._pwserieses).float()
    @property
    def states(self):
        return self.pwserieses
    @property
    def np_pwserieses(self):
        return self._pwserieses.copy()
    
    @property
    def gate_indices(self):
        return torch.from_numpy(self._gate_indices).float()
    @property
    def igates(self):
        return self.gate_indices
    @property
    def np_gate_indices(self):
        return self._gate_indices.copy()
    @property
    def onehot_igates(self):
        return get_onehot(self._gate_indices, self.num_gate_classes)
    
    @property
    def qubits(self):
        return torch.from_numpy(self._qubits).float()
    @property
    def np_qubits(self):
        return self._qubits.copy()
    @property
    def onehot_qubits(self):
        return get_onehot(self._qubits, self.num_qubits)
    
    @property
    def losses(self):
        return torch.from_numpy(self._losses).float()
    @property
    def np_losses(self):
        return self._losses.copy()
    @property
    def best_indices(self):
        return torch.argmin(self.losses, dim=-1)
    @property
    def onehot_ibests(self):
        return get_onehot(np.argmin(self.np_losses, axis=-1), self.num_cadicates)


In [384]:
batch_data = [
    make_db_json(*generate_random_data())
    for _ in range(batch_size)
]
batch_data


['{"pwseries": [0.5660569355850306, 0.5801335132441083, 0.41876590896698174, 0.7757465754614421, 0.17146400948803642, 0.8249182098833542, 0.7029926606449767, 0.02850381134829305, 0.8216821751237374, 0.7469196735665115, 0.20055910903906726, 0.8503942063551014, 0.9675239400755534, 0.38901468536903605, 0.36128104384650805], "units": [{"name": "unit_15", "gates": ["cz", "ry", "cz"], "qubits": [1, 0, 0], "params": [0.0]}, {"name": "unit_16", "gates": ["rx", "rx", "rz"], "qubits": [1, 0, 0], "params": [0.0, 0.0, 0.0]}, {"name": "unit_17", "gates": ["rz", "rz", "cz"], "qubits": [1, 0, 1], "params": [0.0, 0.0]}, {"name": "unit_18", "gates": ["rx", "rx", "ry"], "qubits": [1, 0, 0], "params": [0.0, 0.0, 0.0]}, {"name": "unit_19", "gates": ["rx", "rz", "cz"], "qubits": [0, 0, 1], "params": [0.0, 0.0]}], "losses": [0.46741752904476186, 0.17349056878167646, 0.5966442677772882, 0.7614166474717902, 0.7162669422931702]}',
 '{"pwseries": [0.5519988502746666, 0.4644204501949948, 0.2041280407023356, 0.69

In [386]:
batch = BatchData(
    np.vstack([bd["pwseries"] for bd in batch_dicts]),
    [decode_gate_indices(bd["units"]) for bd in batch_dicts],
    [decode_qubits(bd["units"]) for bd in batch_dicts],
    np.vstack([bd["losses"] for bd in batch_dicts]),
    gset.size, num_qubits
)


In [398]:
def calc_selected_probs(logits, onehots, num_gates):
    probs = calc_prob_table(logits, num_gates)
    probs = probs.view(1, *probs.shape)
    filtered_probs = probs * onehots
    selected_probs = filtered_probs.sum(dim=-1)
    return selected_probs

def _calc_prob(gate_logits, qubit_logits, gate_onehot, qubit_onehot):
    gate_selected_probs = calc_selected_probs(gate_logits, gate_onehot, policy.ng)
    qubit_selected_probs = calc_selected_probs(qubit_logits, qubit_onehot, policy.ng)
    selected_probs = gate_selected_probs * qubit_selected_probs
    probs = selected_probs.prod(dim=1)
    return probs

def calc_probs(model, batch):
    batch_logits = model(batch.states)
    batch_gate_logits  = batch_logits[..., :policy.dim_gateinfo]
    batch_qubit_logits = batch_logits[..., policy.dim_gateinfo:]

    probs = torch.vstack([
        _calc_prob(gate_logits, qubit_logits, gate_onehot, qubit_onehot)
        for gate_logits, qubit_logits, gate_onehot, qubit_onehot
        in zip(batch_gate_logits, batch_qubit_logits, batch.onehot_igates, batch.onehot_qubits)
    ])
    return probs

probs = calc_probs(policy, batch)
probs_ref = calc_probs(policy_ref, batch).detach()
print(probs)
print(probs_ref)


tensor([[0.0027, 0.0028, 0.0013, 0.0031, 0.0014],
        [0.0025, 0.0027, 0.0013, 0.0032, 0.0015],
        [0.0025, 0.0026, 0.0013, 0.0029, 0.0014],
        [0.0025, 0.0027, 0.0013, 0.0029, 0.0013]], grad_fn=<CatBackward0>)
tensor([[0.0027, 0.0028, 0.0013, 0.0031, 0.0014],
        [0.0025, 0.0027, 0.0013, 0.0032, 0.0015],
        [0.0025, 0.0026, 0.0013, 0.0029, 0.0014],
        [0.0025, 0.0027, 0.0013, 0.0029, 0.0013]])


In [495]:
print(batch.onehot_igates.shape)
print(batch.best_indices.shape)
# print(batch.best_indices.reshape(1, num_candidates, 1))
# best_gate_indices = torch.gather(batch.onehot_igates, -1, )
# best_gate_indices


torch.Size([4, 5, 3, 4])
torch.Size([4])


In [551]:
logits = policy(batch.states)
gate_logits = logits[..., :policy.dim_gateinfo]
qbit_logits = logits[..., policy.dim_gateinfo:]

rs_gate_logits = gate_logits.view(batch.size, policy.ng, -1)
logp_gate = nn.functional.log_softmax(rs_gate_logits, dim=-1)

rs_qbit_logits = qbit_logits.view(batch.size, policy.ng, -1)
logp_qbit = nn.functional.log_softmax(rs_qbit_logits, dim=-1)


In [552]:
rs_logp_gate = logp_gate.unsqueeze(1)
selected_logp_gate = (batch.onehot_igates * rs_logp_gate).sum(dim=-1)

rs_logp_qbit = logp_qbit.unsqueeze(1)
selected_logp_qbit = (batch.onehot_qubits * rs_logp_qbit).sum(dim=-1)

selected_logp = (selected_logp_gate + selected_logp_qbit).sum(dim=-1)
print(selected_logp_gate.shape)
print(selected_logp_qbit.shape)
print(selected_logp.shape)


torch.Size([4, 5, 3])
torch.Size([4, 5, 3])
torch.Size([4, 5])


In [531]:
best_logp = (batch.onehot_ibests * selected_logp).sum(dim=-1)
best_logp


tensor([-5.8862, -5.9920, -6.6633, -6.6659], grad_fn=<SumBackward1>)

In [532]:
other_logp = (selected_logp.sum(dim=-1) - best_logp) / (num_candidates - 1)
other_logp


tensor([-6.2424, -6.2138, -6.0962, -6.1011], grad_fn=<DivBackward0>)

In [607]:
def calc_logp(logits, num_gates):
    rs_logits = logits.view(batch.size, num_gates, -1)
    return nn.functional.log_softmax(rs_logits, dim=-1)

def calc_selected_logp(logp, onehot):
    rs_logp = logp.unsqueeze(1)
    return (onehot * rs_logp).sum(dim=-1)


def calc_logps(model, batch, num_candidates, num_gates, detach=False):
    logits = model(batch.states)
    logits_gate = logits[..., :model.dim_gateinfo]
    logits_qbit = logits[..., model.dim_gateinfo:]

    logp_gate = calc_logp(logits_gate, num_gates)
    logp_qbit = calc_logp(logits_qbit, num_gates)

    selected_logp_gate = calc_selected_logp(logp_gate, batch.onehot_igates)
    selected_logp_qbit = calc_selected_logp(logp_qbit, batch.onehot_qubits)

    selected_logp = (selected_logp_gate + selected_logp_qbit).sum(dim=-1)
    
    logp_best = (batch.onehot_ibests * selected_logp).sum(dim=-1)
    logp_others = (selected_logp.sum(dim=-1) - logp_best) / (num_candidates - 1)
    if detach:
        logp_best = logp_best.detach()
        logp_others = logp_others.detach()
    return logp_best, logp_others

calc_logps(policy, batch, num_candidates, num_gates)


(tensor([-5.8899, -6.4940, -6.3668, -6.2932], grad_fn=<SumBackward1>),
 tensor([-6.2314, -6.1534, -6.1864, -6.1731], grad_fn=<DivBackward0>))

In [608]:
def calc_loss_dpo(logrp, logrp_ref, beta=0.5):
    return -1 * nn.functional.logsigmoid(beta * (logrp - logrp_ref)).mean()

def calc_loss_llh(logp_best):
    return torch.exp(logp_best).mean()


In [581]:
logp_best, logp_others = calc_logps(policy, batch, num_candidates, num_gates)
logp_ref_best, logp_ref_others = calc_logps(policy_ref, batch, num_candidates, num_gates, detach=True)
logp_best, logp_ref_best


(tensor([-5.8862, -5.9920, -6.6633, -6.6659], grad_fn=<SumBackward1>),
 tensor([-5.8862, -5.9920, -6.6633, -6.6659]))

In [582]:
loss_dpo = calc_loss_dpo(logp_best - logp_others, logp_ref_best - logp_ref_others, beta=beta)
loss_llh = calc_loss_llh(logp_best)
loss_dpo, loss_llh


(tensor(0.6931, grad_fn=<MulBackward0>),
 tensor(0.0020, grad_fn=<MeanBackward0>))

In [None]:
logrp = logp_best - logp_others
logrp_ref = logp_ref_best - logp_ref_others


In [566]:
loss_dpo = -1 * nn.functional.logsigmoid(beta * (logp_best - logp_ref_best - logp_others + logp_ref_others))
loss_llh = torch.exp(logp_best)
print(loss_dpo, loss_llh)


tensor([0.6931, 0.6931, 0.6931, 0.6931], grad_fn=<MulBackward0>) tensor([0.0028, 0.0025, 0.0013, 0.0013], grad_fn=<ExpBackward0>)


In [643]:
policy = Policy(dim_wavelet, num_qubits, num_gates)
policy_ref = Policy(dim_wavelet, num_qubits, num_gates)
policy_ref.load_state_dict(policy.state_dict())


<All keys matched successfully>

In [644]:
optimizer = optim.Adam(policy.parameters(), lr=1e-3)


In [645]:
# before train step
logits = policy(batch.states)
logits_gate = logits[..., :policy.dim_gateinfo]
logits_qbit = logits[..., policy.dim_gateinfo:]

before = calc_logp(logits_gate, num_gates)[..., 0].detach().numpy()


# Calc loss here

In [646]:
for step in range(50):
    # prediction
    logp_best, logp_others = calc_logps(policy, batch, num_candidates, num_gates)
    logp_ref_best, logp_ref_others = calc_logps(policy_ref, batch, num_candidates, num_gates, detach=True)

    # losses
    loss_dpo = calc_loss_dpo(logp_best - logp_others, logp_ref_best - logp_ref_others, beta=beta)
    loss_llh = calc_loss_llh(logp_best)
    loss = loss_dpo + 0 * loss_llh
    print(f"step:{step+1:>3d} loss: {loss_dpo.item():6.3f}")

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


step:  1 loss:  0.693
step:  2 loss:  0.688
step:  3 loss:  0.682
step:  4 loss:  0.678
step:  5 loss:  0.673
step:  6 loss:  0.668
step:  7 loss:  0.663
step:  8 loss:  0.659
step:  9 loss:  0.654
step: 10 loss:  0.650
step: 11 loss:  0.646
step: 12 loss:  0.642
step: 13 loss:  0.637
step: 14 loss:  0.633
step: 15 loss:  0.629
step: 16 loss:  0.625
step: 17 loss:  0.621
step: 18 loss:  0.617
step: 19 loss:  0.613
step: 20 loss:  0.609
step: 21 loss:  0.605
step: 22 loss:  0.601
step: 23 loss:  0.597
step: 24 loss:  0.593
step: 25 loss:  0.589
step: 26 loss:  0.585
step: 27 loss:  0.580
step: 28 loss:  0.576
step: 29 loss:  0.571
step: 30 loss:  0.566
step: 31 loss:  0.562
step: 32 loss:  0.557
step: 33 loss:  0.551
step: 34 loss:  0.546
step: 35 loss:  0.541
step: 36 loss:  0.535
step: 37 loss:  0.529
step: 38 loss:  0.523
step: 39 loss:  0.517
step: 40 loss:  0.510
step: 41 loss:  0.504
step: 42 loss:  0.497
step: 43 loss:  0.491
step: 44 loss:  0.484
step: 45 loss:  0.477
step: 46 l

In [647]:
# after train step
logits = policy(batch.states)
logits_gate = logits[..., :policy.dim_gateinfo]
logits_qbit = logits[..., policy.dim_gateinfo:]

after = calc_logp(logits_gate, num_gates)[..., 0].detach().numpy()


In [648]:
after - before, batch.best_indices[0]


(array([[-0.4901234 , -0.46423495, -0.3027004 ],
        [-0.63642216, -0.5208156 , -0.26988602],
        [-0.6360792 , -0.5051565 , -0.28410172],
        [-0.73366046, -0.61706424, -0.39302683]], dtype=float32),
 tensor(1))

In [650]:
logits = policy(batch.states)
logits_gate = logits[..., :policy.dim_gateinfo]
logits_qbit = logits[..., policy.dim_gateinfo:]
torch.exp(calc_logp(logits_gate, num_gates)[0, ...])


tensor([[0.1430, 0.2282, 0.2980, 0.3308],
        [0.1494, 0.3352, 0.2723, 0.2431],
        [0.1406, 0.1072, 0.4759, 0.2763]], grad_fn=<ExpBackward0>)