In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
import os
os.chdir("../..")


In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import json


In [9]:
from qml.model.gate import Gateset
from qml.db import dpo as xdpo


## Parameters

In [81]:
# circuit
num_qubits = 2
num_gates = 3

# dataset
db_filename = "dpo_databsae.txt"
batch_size = 10
dim_wavelet = 4

# model
dim_hiddens = [32, 32]


## Model

In [96]:
class Policy(nn.Module):

    def __init__(
            self,
            dim_wavelet: int,
            num_qubits: int,
            num_gates: int,
            dim_hiddens: list[int]
    ):
        super().__init__()
        self.nq = num_qubits
        self.ng = num_gates
        self.gset = gset = Gateset.set_num_qubits(num_qubits)

        self.dim_gindices = gset.size * self.ng
        self.dim_qubits = self.ng * self.nq
        self.dim_output = dim_output = self.dim_gindices + self.dim_qubits
        self.dim_input = dim_input = 2 ** dim_wavelet - 1

        dim_units = dim_hiddens.copy()
        dim_units.append(dim_output)
        dim_units.insert(0, dim_input)

        self.net = nn.Sequential(
            *sum([
                self.build_layer(din, dout, activation=(l < len(dim_hiddens)))
                for l, din, dout
                in zip(range(len(dim_units)+1), dim_units[:-1], dim_units[1:])
            ], [])
        )
    
    @staticmethod
    def build_layer(din, dout, activation=True):
        layer = [nn.Linear(din, dout)]
        if activation:
            layer.append(nn.ReLU())
        return layer
    
    def forward(self, x):
        feat = self.net(x)
        return feat
    
    def as_logps(self, logits):
        print(logits.shape)
        logits_gate = logits[..., :self.dim_gindices]
        logits_qbit = logits[..., self.dim_gindices:]
        logps_gate = nn.functional.log_softmax(logits_gate)
        logps_qbit = nn.functional.log_softmax(logits_qbit)
        return logps_gate, logps_qbit


## Data

In [83]:
dataset = xdpo.DPODataset(db_filename, num_qubits, dim_wavelet)
loader = xdpo.DPODataLoader(dataset, num_qubits, batch_size, dim_wavelet)


## Training

In [160]:
policy = Policy(dim_wavelet, num_qubits, num_gates, dim_hiddens)
policy_ref = Policy(dim_wavelet, num_qubits, num_gates, dim_hiddens)
policy_ref.load_state_dict(policy.state_dict())


<All keys matched successfully>

In [161]:
def selected_logps(logits, onehot):
    logps = nn.functional.log_softmax(logits, dim=-1)
    selected_logps = (logps * onehot).sum(dim=-1)
    return selected_logps

def calc_logps(policy, batch):
    logits = policy(batch.wserieses)

    logits_gate = logits[..., :policy.dim_gindices].view(batch.best.onehot_gindices.shape)
    logits_qbit = logits[..., policy.dim_gindices:].view(batch.best.onehot_qubits.shape)

    logps_best = selected_logps(logits_gate, batch.best.onehot_gindices) + selected_logps(logits_qbit, batch.best.onehot_qubits)
    logps_best = logps_best.sum(dim=-1)
    logps_others = selected_logps(logits_gate, batch.others.onehot_gindices) + selected_logps(logits_qbit, batch.others.onehot_qubits)
    logps_others = logps_others.sum(dim=-1)
    
    return logps_best, logps_others


In [164]:
def calc_loss_dpo(logrp, logrp_ref, beta=0.5):
    return -1 * nn.functional.logsigmoid(beta * (logrp - logrp_ref)).mean()

def calc_loss_llh(logp_best):
    return torch.exp(logp_best).mean()


In [167]:
for batch in loader:
    logps_best, logps_others = calc_logps(policy, batch)
    logps_ref_best, logps_ref_others = calc_logps(policy_ref, batch)

    loss_dpo = calc_loss_dpo(
        logps_best - logps_others,
        logps_ref_best - logps_ref_others,
    )
    loss_llh = calc_loss_llh(logps_best)
    print(loss_dpo, loss_llh)
    break


tensor(0.6931, grad_fn=<MulBackward0>) tensor(0.0022, grad_fn=<MeanBackward0>)
