In [27]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (e.g. the `qml` package used below) are re-imported before each cell runs
# and source edits take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
import os
# Drop the "/note/experiments" segment from the current working directory and
# chdir there, so the relative paths used in later cells (for example
# "note/experiments/config.yaml") resolve. Presumably this is the repository
# root when the notebook is started from note/experiments -- TODO confirm.
# NOTE(review): str.replace removes every occurrence of the segment, not just
# a trailing one, and the pattern assumes "/" as the path separator.
cwd = os.getcwd().replace("/note/experiments", "")
os.chdir(cwd)


In [39]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import xtools as xt
import json
import xsim
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import clear_output
from tqdm import tqdm
from collections import namedtuple


In [30]:
from qml.db import dpo as xdpo
from qml.db import target as xtarget
from qml.db.ml import MLDatasetGenerator
from qml.model.gate import Gateset
from qml.model.unit import UnitManager, Unit
from qml.model.encoding import EncodingUnitManager
from qml.model.model import Model
from qml.optimizer import dpo as xdpopt
from qml.optimizer import evaluator as xeval
from qml.tools.random import XRandomGenerator
from qml.tools.sampler import CandidateSampler
from qml import optimizer as xoptim
from qml.tools.config import Config
# from qml.tools.validation import validate
from qml.tools.experiment_tools import train_once, hard_copy, prepare_policy, plot_results, validate

In [31]:
# Load the experiment configuration. The path is relative to the working
# directory, so the chdir cell above must have run first.
cf = Config("note/experiments/config.yaml")

In [32]:
# Display the configured `shots` value as a quick check that the config loaded.
cf.shots

30

In [33]:
def plot_results_on_notebook(logger, vlogger, cf):
    """Redraw the training and validation curves in the current cell output.

    Args:
        logger: training logger; read through ``xsim.Retriever`` for the
            ``epoch``, ``loss`` and ``loss_dpo`` series (plus ``loss_llh``
            when ``cf.dpo.training.cpo`` is truthy).
        vlogger: validation logger; read through ``xsim.Retriever`` for the
            ``epoch``, ``loss`` and ``time`` series.
        cf: experiment config; only ``cf.dpo.training.cpo`` is consulted here.

    Returns:
        tuple: ``(res, vres)`` -- the training and validation series as
        pandas DataFrames, in that order.
    """
    # --- training curves -------------------------------------------------
    train_ret = xsim.Retriever(logger)
    train_columns = {
        "epoch": train_ret.epoch(),
        "loss": train_ret.loss(),
        "loss_dpo": train_ret.loss_dpo(),
    }
    if cf.dpo.training.cpo:
        # The likelihood term is only retrieved and plotted in CPO mode.
        train_columns["loss_llh"] = train_ret.loss_llh()
        plot_labels = ["loss_dpo", "loss_llh", "loss"]
    else:
        plot_labels = ["loss_dpo", "loss"]
    res = pd.DataFrame(train_columns)

    # --- validation curves -----------------------------------------------
    val_ret = xsim.Retriever(vlogger)
    vres = pd.DataFrame({
        "epoch": val_ret.epoch(),
        "loss": val_ret.loss(),
        "time": val_ret.time(),
    })

    # Wipe the previous figure (and any printed status lines) before drawing.
    clear_output()
    fig, (ax_train, ax_val_loss, ax_val_time) = plt.subplots(
        nrows=3, figsize=(8, 10), sharex=True
    )
    res.plot(x="epoch", y=plot_labels, ax=ax_train)
    vres.plot(x="epoch", y=["loss"], ax=ax_val_loss)
    vres.plot(x="epoch", y=["time"], ax=ax_val_time)
    plt.show()

    return res, vres


In [34]:
# training dataset
# Build the DPO dataset from the file named in the config, then wrap it in a
# loader. Both receive cf.nq and cf.ocg.dim_wavelet; the loader additionally
# takes the batch size.
dataset = xdpo.DPODataset(cf.dpo.training.db.filename, cf.nq, cf.ocg.dim_wavelet)
loader = xdpo.DPODataLoader(dataset, cf.nq, cf.dpo.training.db.batch_size, cf.ocg.dim_wavelet)


In [35]:
# validation datasets
# One polynomial target-function generator (configured with
# cf.qml.db.dim_polynomial) feeds an MLDatasetGenerator, which is called
# cf.dpo.validation.num_db times with cf.dpo.validation.dbsize as its argument.
# The resulting list is reused unchanged by every validate(...) call below.
tfun = xtarget.PolynominalTargetFunctionGenerator(cf.qml.db.dim_polynomial)
tgen = MLDatasetGenerator(tfun)
validation_datasets = [
    tgen.generate(cf.dpo.validation.dbsize)
    for _ in range(cf.dpo.validation.num_db)
]

In [42]:
# Create an output directory for results under note/experiments/results.
# NOTE(review): the training cell below repeats this exact call and rebinds
# `save_dir`, so the directory made here appears to stay unused -- consider
# keeping only one of the two calls.
save_dir = xt.make_dirs_current_time("note/experiments/results")

In [44]:
# Every run of this cell writes into its own freshly created output directory.
save_dir = xt.make_dirs_current_time("note/experiments/results")

# Build the CSV paths with xt.join, exactly like the checkpoint paths below,
# so the files land inside save_dir whether or not it ends with a separator
# (plain string concatenation would otherwise fuse the directory and file name).
results_csv = xt.join(save_dir, "results.csv")
vresults_csv = xt.join(save_dir, "vresults.csv")

policy, reference, optimizer, sampler = prepare_policy(cf)
logger = xsim.Logger()    # per-epoch training losses
vlogger = xsim.Logger()   # validation results at each checkpoint

# Epoch 0: checkpoint and validation baseline before any training step.
torch.save(policy.state_dict(), xt.join(save_dir, "results_0000.pth"))
vresults = validate(sampler, validation_datasets, cf)
vlogger.store(
    epoch=0,
    **vresults,
).flush()

for epoch in range(1, cf.dpo.training.num_epochs + 1):
    # hard_copy(policy, reference)  # left disabled, as in the original cell
    # One optimisation step per batch; keep every step's loss record.
    epoch_losses = [
        train_once(policy, reference, optimizer, batch, sampler, cf)
        for batch in loader
    ]

    # Epoch-level numbers are the plain mean over the batches of this epoch.
    loss = np.mean([step.total for step in epoch_losses])
    loss_dpo = np.mean([step.dpo for step in epoch_losses])
    logger.store(
        epoch=epoch,
        loss=loss,
        loss_dpo=loss_dpo,
    )
    status = f"epoch {epoch:>3}  loss: {loss:.4f}  dpo: {loss_dpo:.4f}"

    if cf.dpo.training.cpo:
        # The likelihood term is only logged and reported in CPO mode.
        loss_llh = np.mean([step.llh for step in epoch_losses])
        logger.store(loss_llh=loss_llh)
        status += f"  llh: {loss_llh:.4f}"

    logger.flush()
    print(status)

    if epoch % cf.dpo.validation.interval == 0:
        torch.save(policy.state_dict(), xt.join(save_dir, f"results_{epoch:04d}.pth"))
        vresults = validate(sampler, validation_datasets, cf)
        vlogger.store(
            epoch=epoch,
            **vresults,
        ).flush()

        # Persist the metrics at every checkpoint, not only after the last
        # epoch, so an interrupted run still leaves up-to-date CSV files.
        res, vres = plot_results_on_notebook(logger, vlogger, cf)
        res.to_csv(results_csv)
        vres.to_csv(vresults_csv)


# Final plot and save, covering epochs after the last validation checkpoint.
res, vres = plot_results_on_notebook(logger, vlogger, cf)
res.to_csv(results_csv)
vres.to_csv(vresults_csv)


30


100%|██████████| 3/3 [01:33<00:00, 31.23s/it]


epoch   1  loss: 0.4882  dpo: 0.4846  llh: 0.0037
epoch   2  loss: 0.4051  dpo: 0.4017  llh: 0.0035
epoch   3  loss: 0.3942  dpo: 0.3903  llh: 0.0039
epoch   4  loss: 0.3755  dpo: 0.3722  llh: 0.0033
epoch   5  loss: 0.3572  dpo: 0.3538  llh: 0.0033
epoch   6  loss: 0.3439  dpo: 0.3405  llh: 0.0034
epoch   7  loss: 0.3421  dpo: 0.3392  llh: 0.0029
epoch   8  loss: 0.3316  dpo: 0.3288  llh: 0.0028
epoch   9  loss: 0.3135  dpo: 0.3105  llh: 0.0030
epoch  10  loss: 0.2935  dpo: 0.2905  llh: 0.0030
30


  0%|          | 0/3 [00:03<?, ?it/s]


KeyboardInterrupt: 