In [1]:
# Reload edited project modules (e.g. the `qml` package) automatically before each cell executes.
%load_ext autoreload
%autoreload 2


In [2]:
import os

# Work from two directories above the notebook (presumably the repository root) so that
# relative paths such as `db_filename` and the `qml` package resolve.
# The starting directory is remembered in NOTEBOOK_DIR, so re-running this cell lands in
# the same place instead of climbing two more levels each time.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
os.chdir(os.path.normpath(os.path.join(NOTEBOOK_DIR, "..", "..")))


In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import json
import xsim
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import clear_output


In [51]:
from qml.db import dpo as xdpo
from qml.model.gate import Gateset
from qml.model.unit import UnitManager, Unit
from qml.optimizer import dpo as xdpopt
from qml.tools.random import XRandomGenerator


## Configs

In [5]:
# --- circuit: size of the search space ----------------------------------------
num_qubits, num_gates = 2, 3

# --- dataset ------------------------------------------------------------------
# NOTE(review): "databsae" looks like a typo of "database"; confirm the file's actual
# name on disk before renaming either the file or this string.
db_filename = "dpo_databsae.txt"
batch_size = 10  # NOTE(review): the DataLoader cell passes a literal 1, not this value
dim_wavelet = 4  # the Policy input width is 2 ** dim_wavelet - 1

# --- model: hidden-layer widths of the policy MLP -----------------------------
dim_hiddens = [32, 32]

# --- training -----------------------------------------------------------------
lr, max_epoch = 1e-3, 1000
cpo = True  # NOTE(review): meaning not evident from the visible cells; document it


## Data

In [17]:
# Load the DPO dataset from `db_filename` and wrap it in a loader.
# NOTE(review): the third positional argument of DPODataLoader is presumably the batch size.
# It is hard-coded to 1 rather than `batch_size` from the config cell, which appears to be
# because the candidate sampler handles one state at a time. Confirm before changing it.
dataset = xdpo.DPODataset(db_filename, num_qubits, dim_wavelet)
loader = xdpo.DPODataLoader(dataset, num_qubits, 1, dim_wavelet)


## Training

In [59]:
class Policy(nn.Module):
    """MLP that maps a wavelet feature vector to gate-type and qubit logits.

    Output layout along the last axis: ``num_gates * gset.size`` gate-type logits,
    followed by ``num_gates * num_qubits`` qubit logits.
    """

    def __init__(
            self,
            dim_wavelet: int,
            num_qubits: int,
            num_gates: int,
            dim_hiddens: list[int],
    ):
        """
        Args:
            dim_wavelet: sets the input width, which is ``2 ** dim_wavelet - 1``.
            num_qubits: number of qubits (``self.nq``).
            num_gates: number of gate slots (``self.ng``).
            dim_hiddens: hidden-layer widths, each followed by a ReLU; not modified.
        """
        super().__init__()
        self.nq = num_qubits
        self.ng = num_gates
        gset = Gateset.set_num_qubits(num_qubits)
        self.gset = gset

        dim_input = 2 ** dim_wavelet - 1
        self.dim_gindices = gset.size * self.ng
        self.dim_qubits = self.ng * self.nq
        dim_output = self.dim_gindices + self.dim_qubits
        self.dim_output = dim_output
        self.dim_input = dim_input

        widths = [dim_input] + dim_hiddens + [dim_output]
        num_hidden = len(dim_hiddens)
        layers = []
        for idx in range(len(widths) - 1):
            # Every layer except the final projection to the logits gets a ReLU.
            layers += self.build_layer(widths[idx], widths[idx + 1], activation=idx < num_hidden)
        self.net = nn.Sequential(*layers)

    @staticmethod
    def build_layer(din, dout, activation=True):
        """Return ``[Linear(din, dout)]``, plus a trailing ReLU when ``activation`` is truthy."""
        linear = nn.Linear(din, dout)
        return [linear, nn.ReLU()] if activation else [linear]

    def forward(self, x):
        """Return the raw (unnormalised) logits produced by the MLP for ``x``."""
        return self.net(x)


class CandidateSampler:
    """Draw candidate circuits from the per-gate-slot distributions produced by a Policy.

    The policy output (last axis) holds ``dim_gindices`` gate-type logits followed by the
    qubit logits. Each part is grouped into ``ng`` equal-sized slots, and every slot is an
    independent categorical distribution: over ``gset.values()`` for the gate type and
    over ``range(nq)`` for the qubit.
    """

    def __init__(self, policy: "Policy", seed: "int | None" = None):
        """
        Args:
            policy: network whose ``nq``, ``ng``, ``gset`` and ``dim_gindices`` define the layout.
            seed: seed for this sampler's ``XRandomGenerator`` (passed through unchanged).
        """
        self.rng = XRandomGenerator(seed)
        self.policy = policy

        self.nq = policy.nq
        self.ng = policy.ng
        self.gset = policy.gset
        self.dim_gindices = policy.dim_gindices

        # The unit manager gets its own seed, drawn from this sampler's RNG.
        self.uman = UnitManager(self.nq, self.ng, self.rng.new_seed())

    def divide_gate_and_qubit(self, logits):
        """Split ``logits`` on the last axis into ``(gate_logits, qubit_logits)``."""
        logits_gate = logits[..., :self.dim_gindices]
        logits_qbit = logits[..., self.dim_gindices:]
        return logits_gate, logits_qbit

    def _group_by_slot(self, flat):
        """Reshape ``(..., ng * k)`` to ``(..., ng, k)``: one row per gate slot."""
        return flat.reshape(*flat.shape[:-1], self.ng, -1)

    def as_logps(self, logits):
        """Per-slot log-probabilities, returned in the same flat layout as the input.

        Each gate slot is normalised over its own options only. These are the same
        distributions that ``as_probs`` and ``sample`` use; the slots are not
        normalised jointly.

        Returns:
            ``(logps_gate, logps_qbit)`` with the shapes of the two halves of ``logits``.
        """
        logits_gate, logits_qbit = self.divide_gate_and_qubit(logits)
        logps_gate = nn.functional.log_softmax(self._group_by_slot(logits_gate), dim=-1)
        logps_qbit = nn.functional.log_softmax(self._group_by_slot(logits_qbit), dim=-1)
        return logps_gate.reshape(logits_gate.shape), logps_qbit.reshape(logits_qbit.shape)

    def as_probs(self, logits, as_numpy=False):
        """Per-slot probabilities; each row sums to 1.

        Returns ``(probs_gate, probs_qbit)`` with these shapes:

        - ``(ng, k)`` for a single sample (1-D logits or a leading batch of one);
        - ``(*batch, ng, k)`` otherwise.

        With ``as_numpy=True`` the tensors are detached and converted to numpy arrays.
        """
        logits_gate, logits_qbit = self.divide_gate_and_qubit(logits)
        probs_gate = torch.softmax(self._group_by_slot(logits_gate), dim=-1)
        probs_qbit = torch.softmax(self._group_by_slot(logits_qbit), dim=-1)
        if logits.shape[:-1].numel() == 1:
            # Single sample: keep the (ng, k) layout that `sample` iterates over row by row.
            probs_gate = probs_gate.reshape(self.ng, -1)
            probs_qbit = probs_qbit.reshape(self.ng, -1)
        if not as_numpy:
            return probs_gate, probs_qbit
        return probs_gate.detach().numpy(), probs_qbit.detach().numpy()

    def sample_from_probs(self, probs, clist):
        """Draw one element of ``clist`` per row of ``probs`` (each row weights ``clist``)."""
        return [self.rng.choice(clist, p=prob) for prob in probs]

    def sample(self, x, num_candidates: int = 1):
        """Sample ``num_candidates`` candidates for a single input state.

        Args:
            x: features for ONE state, as a numpy array or tensor. It may be 1-D or
                carry a leading batch axis of size one.
            num_candidates: number of candidates to draw.

        Returns:
            list of objects built by ``self.uman.from_info_and_qubits``.

        Raises:
            ValueError: if ``x`` holds more than one state. The per-slot probabilities
                are computed for one state only.
        """
        if not isinstance(x, torch.Tensor):
            x = torch.from_numpy(x).float()
        if x.ndim < 2:
            x = x.unsqueeze(0)  # add a leading batch axis of size one
        if x.shape[:-1].numel() != 1:
            raise ValueError(
                f"sample() draws candidates for a single state; got input of shape {tuple(x.shape)}"
            )
        # Only detached numpy probabilities are used below, so no autograd graph is needed.
        with torch.no_grad():
            logits = self.policy.forward(x)
        probs_gate, probs_qbit = self.as_probs(logits, as_numpy=True)

        gate_choices = list(self.gset.values())  # gate infos
        qubit_choices = list(range(self.nq))     # qubit indices
        return [
            self.uman.from_info_and_qubits(
                self.sample_from_probs(probs_gate, gate_choices),
                self.sample_from_probs(probs_qbit, qubit_choices),
            )
            for _ in range(num_candidates)
        ]
    

In [60]:
# Flat lists of the available gate entries and of the qubit indices.
gateset = Gateset.set_num_qubits(num_qubits)
glist = [*gateset.values()]
qlist = list(range(num_qubits))


In [61]:
# Seed the torch RNG (used for Linear-layer initialisation) and the sampler's generator so
# that re-running the notebook reproduces the same network and, presumably, the same
# sampled candidates. This assumes XRandomGenerator is deterministic for a fixed seed.
SEED = 0
torch.manual_seed(SEED)
policy = Policy(dim_wavelet, num_qubits, num_gates, dim_hiddens)
print(policy)
sampler = CandidateSampler(policy, seed=SEED)


Policy(
  (net): Sequential(
    (0): Linear(in_features=15, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): ReLU()
    (4): Linear(in_features=32, out_features=18, bias=True)
  )
)


In [73]:
# Sanity check: sample 5 candidates for the first batch from the loader and print each
# candidate's gates, separated by a dashed line. Only the first batch is used (see `break`).
# NOTE(review): `batch`, `candidates`, `candidate` and `gate` remain defined as notebook
# globals after this cell; later cells may rely on them.
for batch in loader:
    candidates = sampler.sample(batch.wserieses, 5)
    for candidate in candidates:
        for gate in candidate.gates:
            print(gate.gate)
        print("- "*20)  # separator between candidates
    break  # inspect only the first batch


Instruction(name='rz', num_qubits=1, num_clbits=0, params=[Parameter(unit_51_param_0)])
Instruction(name='rx', num_qubits=1, num_clbits=0, params=[Parameter(unit_51_param_1)])
Instruction(name='rx', num_qubits=1, num_clbits=0, params=[Parameter(unit_51_param_2)])
- - - - - - - - - - - - - - - - - - - - 
Instruction(name='ry', num_qubits=1, num_clbits=0, params=[Parameter(unit_52_param_0)])
Instruction(name='rz', num_qubits=1, num_clbits=0, params=[Parameter(unit_52_param_1)])
Instruction(name='cz', num_qubits=2, num_clbits=0, params=[])
- - - - - - - - - - - - - - - - - - - - 
Instruction(name='ry', num_qubits=1, num_clbits=0, params=[Parameter(unit_53_param_0)])
Instruction(name='rz', num_qubits=1, num_clbits=0, params=[Parameter(unit_53_param_1)])
Instruction(name='cz', num_qubits=2, num_clbits=0, params=[])
- - - - - - - - - - - - - - - - - - - - 
Instruction(name='rz', num_qubits=1, num_clbits=0, params=[Parameter(unit_54_param_0)])
Instruction(name='ry', num_qubits=1, num_clbits=0