In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
import os
# Move the working directory two levels up (presumably to the project root so the
# later `qml` imports and relative data paths resolve — confirm).
# NOTE: the path is relative, so re-running this cell moves up two more levels
# each time; run it exactly once per kernel session.
os.chdir("../..")


In [54]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import json
from collections import namedtuple


In [123]:
from qml.model.gate import Gateset
from qml.model.unit import UnitManager, Unit
from qml.tools.random import XRandomGenerator


# Configs

In [8]:
# File opened by DPODataset; the path is relative to the current working directory.
# NOTE(review): "databsae" looks like a typo for "database" — confirm against the
# actual file name on disk before renaming.
dpo_db_filename = "dpo_databsae.txt"


## Loader

In [53]:
class DPOBatch:
    """Plain container holding a ``wseries`` together with a best unit and the other units."""

    def __init__(self, wseries, best_unit, other_units):
        # The three values are stored as given: no copying and no validation.
        self.wseries, self.best_unit, self.other_units = (
            wseries,
            best_unit,
            other_units,
        )


In [174]:
class DPODataset:
    """Line-oriented dataset: every non-blank line of ``db_filename`` is one JSON record.

    Records are kept in memory as raw JSON strings and decoded lazily, on
    indexing, through ``DPODataDecoder.from_json``.

    Args:
        db_filename: Path of the text file holding one JSON document per line.
        num_qubits: Passed to ``Gateset.set_num_qubits`` to build the gate set
            used when decoding records.
        dim_wavelet: Forwarded to the decoder; it keeps the first
            ``2 ** dim_wavelet - 1`` entries of each record's ``wseries``.
    """

    def __init__(self, db_filename: str, num_qubits: int, dim_wavelet: int = 4):
        self.db_filename = db_filename
        self.num_qubits = num_qubits
        self.dim_wavelet = dim_wavelet

        self.gateset = Gateset.set_num_qubits(num_qubits)

        self.db = self.load_db_file()

    def __getitem__(self, index):
        """Decode and return the record stored at position ``index``."""
        djson = self.db[index]
        return DPODataDecoder.from_json(djson, self.gateset, self.dim_wavelet)

    def __len__(self):
        """Number of records; same value as ``size``."""
        return len(self.db)

    def load_db_file(self):
        """Read the database file and return its non-blank lines as raw strings.

        Blank or whitespace-only lines (for example a trailing empty line) are
        skipped: they are not records, and keeping them would both inflate
        ``size`` and make ``json.loads`` fail when they are indexed.
        """
        with open(self.db_filename) as fp:
            return [line for line in fp if line.strip()]

    @property
    def size(self):
        """Number of records loaded from the database file."""
        return len(self.db)


# Build the dataset for 2 qubits with dim_wavelet=2 and display the first decoded record.
# Indexing calls DPODataDecoder.from_json, so DPODataDecoder must already be
# defined in the kernel when this cell runs.
dataset = DPODataset(db_filename=dpo_db_filename, num_qubits=2, dim_wavelet=2)
dataset[0]


DPOData(wseries=array([-0.26506404,  0.03155852, -0.4379742 ]), gate_indices=[[0, 1, 0], [3, 2, 1], [0, 0, 2], [3, 2, 1], [0, 2, 3]], qubits=[[1, 0, 1], [1, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1]], losses=array([0.05769153, 0.37874371, 0.27544483, 0.07399594, 0.35288897]))

In [263]:
DPOData = namedtuple("DPOData", ["wseries", "gindices", "qubits", "losses"])

class DPODataDecoder:
    """Decode JSON records into ``DPOData`` and split batches into best / other candidates."""

    KEY_WSERIES = "wseries"
    KEY_UNITS = "units"
    KEY_UNITS_GATES = "gates"
    KEY_UNITS_QUBITS = "qubits"
    KEY_LOSSES = "losses"

    @classmethod
    def from_json(cls, djson: str, gateset: Gateset, dim_wavelet: int = 4):
        """Decode one JSON record into a ``DPOData``.

        Args:
            djson: JSON document with the keys ``wseries``, ``units`` (a list of
                objects carrying ``gates`` and ``qubits``) and ``losses``.
            gateset: Object whose ``keys()`` yields gate names; a gate's index
                is its position in that iteration order.
            dim_wavelet: Only the first ``2 ** dim_wavelet - 1`` entries of
                ``wseries`` are kept.

        Returns:
            ``DPOData`` with the truncated ``wseries`` array, one list of gate
            indices per unit, one ``qubits`` entry per unit, and the losses array.

        Raises:
            KeyError: If a required key is missing or a gate name (upper-cased)
                is not in ``gateset``.
        """
        gate_to_index = {gate_name: idx for idx, gate_name in enumerate(gateset.keys())}
        record = json.loads(djson)

        # wseries: keep only the leading 2**dim_wavelet - 1 entries
        num_entries = 2 ** dim_wavelet - 1
        wseries = np.asarray(record[cls.KEY_WSERIES])[:num_entries]

        # units: gate names are matched case-insensitively via upper()
        units = record[cls.KEY_UNITS]
        gindices = [
            [gate_to_index[gate_name.upper()] for gate_name in unit[cls.KEY_UNITS_GATES]]
            for unit in units
        ]
        qubits = [unit[cls.KEY_UNITS_QUBITS] for unit in units]

        # losses
        losses = np.asarray(record[cls.KEY_LOSSES])

        return DPOData(wseries, gindices, qubits, losses)

    @classmethod
    def divide_best_and_others(cls, data: DPOData):
        """Split a batched ``DPOData`` into the lowest-loss candidate and the rest.

        Args:
            data: Batched data. ``data.losses`` must be convertible to a 2-D
                array (sample, candidate); ``data.gindices`` and ``data.qubits``
                are indexed the same way. Losses may be ndarrays or plain lists.

        Returns:
            ``(best_data, others_data)``. ``best_data`` holds, per sample, the
            candidate with the minimal loss (first one on ties) and its loss as
            an array. ``others_data`` holds the remaining candidates in their
            original order, with losses as nested lists of Python scalars.
            Both share ``data.wseries``.
        """
        # Normalize once so list-of-lists and list-of-arrays behave the same.
        losses = np.asarray(data.losses)
        best_indices = np.argmin(losses, axis=1)

        def pick_best(rows):
            # Per sample, the candidate at the best index.
            return [row[best] for row, best in zip(rows, best_indices)]

        def drop_best(rows):
            # Per sample, every candidate except the one at the best index.
            return [
                [item for idx, item in enumerate(row) if idx != best]
                for row, best in zip(rows, best_indices)
            ]

        best_data = DPOData(
            data.wseries,
            pick_best(data.gindices),
            pick_best(data.qubits),
            np.min(losses, axis=1),
        )

        others_data = DPOData(
            data.wseries,
            drop_best(data.gindices),
            drop_best(data.qubits),
            # tolist() yields Python scalars, matching the former loss.item()
            drop_best(losses.tolist()),
        )

        return best_data, others_data




In [297]:
class DPODataBatchDivided:
    """Array / tensor accessors over one part of a divided batch.

    ``nq`` is the number of classes used for the one-hot qubit encoding and
    ``ngc`` the number of classes used for the one-hot gate-index encoding.
    """

    def __init__(self, batch_data: DPOData, nq: int, ngc: int):
        self.data, self.nq, self.ngc = batch_data, nq, ngc

    @staticmethod
    def as_onehot(indices, num_classes):
        """One-hot encode an integer tensor with ``num_classes`` classes."""
        return nn.functional.one_hot(indices, num_classes)

    @staticmethod
    def _as_float_tensor(array):
        # Shared conversion: numpy array -> float32 torch tensor.
        return torch.from_numpy(array).float()

    @property
    def np_gindices(self):
        """Gate indices as an integer numpy array."""
        gate_indices = np.asarray(self.data.gindices)
        return gate_indices.astype(int)

    @property
    def gindices(self):
        """Gate indices as a float tensor."""
        return self._as_float_tensor(self.np_gindices)

    @property
    def onehot_gindices(self):
        """Gate indices one-hot encoded over ``ngc`` classes."""
        as_long = self.gindices.long()
        return self.as_onehot(as_long, self.ngc)

    @property
    def np_qubits(self):
        """Qubit indices as an integer numpy array."""
        qubit_indices = np.asarray(self.data.qubits)
        return qubit_indices.astype(int)

    @property
    def qubits(self):
        """Qubit indices as a float tensor."""
        return self._as_float_tensor(self.np_qubits)

    @property
    def onehot_qubits(self):
        """Qubit indices one-hot encoded over ``nq`` classes."""
        as_long = self.qubits.long()
        return self.as_onehot(as_long, self.nq)

    @property
    def np_losses(self):
        """Losses as a float numpy array."""
        loss_values = np.asarray(self.data.losses)
        return loss_values.astype(float)

    @property
    def losses(self):
        """Losses as a float tensor."""
        return self._as_float_tensor(self.np_losses)


class DPODataBatch:
    """One batch of DPO data, split into the best candidate and the other candidates.

    Args:
        batch_data: Batched ``DPOData`` whose fields are per-sample sequences.
        num_qubits: Number of classes for the one-hot qubit encoding.
        num_gate_classes: Number of classes for the one-hot gate encoding.
            When ``None`` it is taken from ``Gateset.set_num_qubits(num_qubits).size``.
    """

    def __init__(self, batch_data: DPOData, num_qubits: int, num_gate_classes: int = None):
        if num_gate_classes is None:
            num_gate_classes = Gateset.set_num_qubits(num_qubits).size
        self.num_qubits = num_qubits
        self.num_gate_classes = num_gate_classes
        self.data = data = batch_data
        self.best_data = None
        self.others_data = None
        # Stack the per-sample series into one 2-D array (one row per sample).
        self._wseries = np.vstack(data.wseries)

        self.encode(data)

    def encode(self, data):
        """Split ``data`` into best / others and wrap each part in ``DPODataBatchDivided``."""
        best_data, others_data = DPODataDecoder.divide_best_and_others(data)
        self.best_data = DPODataBatchDivided(best_data, self.num_qubits, self.num_gate_classes)
        self.others_data = DPODataBatchDivided(others_data, self.num_qubits, self.num_gate_classes)

    @property
    def size(self):
        """Number of samples in the batch (rows of the stacked wseries array).

        Previously this property was an unimplemented stub that returned ``None``.
        """
        return int(self._wseries.shape[0])

    @property
    def wserieses(self):
        """Stacked wseries as a float tensor."""
        return torch.from_numpy(self._wseries).float()

    @property
    def np_wserieses(self):
        """Copy of the stacked wseries array; mutating it does not affect the batch."""
        return self._wseries.copy()

    @property
    def best(self):
        """The best-candidate part of the batch."""
        return self.best_data

    @property
    def others(self):
        """The non-best candidates of the batch."""
        return self.others_data


In [296]:
class DPODataLoaderIter:
    """Iterator yielding one ``DPODataBatch`` per group of dataset indices.

    Args:
        dataset: Indexable dataset returning ``DPOData`` records.
        batched_indices: Sequence of index groups, one group per batch.
        num_qubits: Forwarded to every ``DPODataBatch``.
        num_gate_classes: Forwarded to every ``DPODataBatch``.
    """

    def __init__(
            self,
            dataset: DPODataset,
            batched_indices: list[list[int]],
            num_qubits: int,
            num_gate_classes: int,
    ):
        self.db = dataset
        self.indices = batched_indices

        self.nq = num_qubits
        self.ngc = num_gate_classes

        self.indices_iter = iter(batched_indices)

    def __iter__(self):
        # Iterators return themselves so they can be used wherever an iterable is expected.
        return self

    def __next__(self):
        """Load the next group of records and collate them field by field.

        Raises:
            StopIteration: When every index group has been consumed.
        """
        idxs = next(self.indices_iter)
        records = [self.db[idx] for idx in idxs]
        collated = DPOData(
            [record.wseries for record in records],
            [record.gindices for record in records],
            [record.qubits for record in records],
            [record.losses for record in records],
        )
        return DPODataBatch(collated, self.nq, self.ngc)

class DPODataLoader:
    """Shuffling batch loader over a ``DPODataset``.

    Args:
        dataset: Dataset exposing ``size`` and integer indexing.
        num_qubits: Used to derive the number of gate classes and forwarded to batches.
        batch_size: Maximum number of records per batch.
        max_wavelet_dim: Stored on the instance; not used by the loader itself.
        seed: Passed to ``XRandomGenerator``.
    """

    def __init__(self, dataset: DPODataset, num_qubits: int, batch_size: int, max_wavelet_dim: int = 4, seed: int = None):
        self.rng = XRandomGenerator(seed)

        self.dataset = dataset
        self.num_qubits = num_qubits
        self.num_gate_classes = Gateset.set_num_qubits(num_qubits).size
        self.max_wavelet_dim = max_wavelet_dim
        self.batch_size = batch_size

    @property
    def size(self):
        """Number of batches per pass, counting a final partial batch."""
        return int(np.ceil(self.dataset.size / self.batch_size))

    def __iter__(self):
        """Start a new pass: permute all indices and cut them into batches.

        The indices are sliced into consecutive chunks of ``batch_size``; the
        last chunk is shorter when the dataset size is not a multiple of
        ``batch_size``. (A plain reshape to ``(size, batch_size)`` raises
        ValueError in that case.)
        """
        indices = np.arange(self.dataset.size).astype(int)
        indices = np.asarray(self.rng.permutation(indices))
        step = self.batch_size
        batched_indices = [
            indices[batch * step:(batch + 1) * step]
            for batch in range(self.size)
        ]
        return DPODataLoaderIter(self.dataset, batched_indices, self.num_qubits, self.num_gate_classes)



# Pull only the first batch and show the shape and contents of the one-hot
# qubit encoding of the non-best candidates.
loader = DPODataLoader(dataset=dataset, num_qubits=2, batch_size=10)
for batch in loader:
    print(batch.others.onehot_qubits.shape, batch.others.onehot_qubits, sep="\n")
    break


torch.Size([10, 4, 3, 2])
tensor([[[[0, 1],
          [1, 0],
          [1, 0]],

         [[0, 1],
          [1, 0],
          [1, 0]],

         [[1, 0],
          [0, 1],
          [1, 0]],

         [[1, 0],
          [1, 0],
          [0, 1]]],


        [[[1, 0],
          [0, 1],
          [0, 1]],

         [[0, 1],
          [0, 1],
          [0, 1]],

         [[0, 1],
          [0, 1],
          [0, 1]],

         [[0, 1],
          [0, 1],
          [1, 0]]],


        [[[0, 1],
          [0, 1],
          [0, 1]],

         [[0, 1],
          [1, 0],
          [1, 0]],

         [[0, 1],
          [0, 1],
          [0, 1]],

         [[0, 1],
          [1, 0],
          [0, 1]]],


        [[[1, 0],
          [0, 1],
          [1, 0]],

         [[0, 1],
          [0, 1],
          [0, 1]],

         [[1, 0],
          [0, 1],
          [1, 0]],

         [[1, 0],
          [0, 1],
          [1, 0]]],


        [[[0, 1],
          [1, 0],
          [0, 1]],

         [[1, 