In [6]:
import pennylane as qml
import numpy as np
import random
import matplotlib.pyplot as plt
%matplotlib inline
import csv
import pandas as pd
import argparse
import os
import math
import datetime
import time

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
import torch

from solver import Solver
from data_loader import get_loader
from torch.backends import cudnn
from utils import *
from models import Generator, Discriminator
from data.sparse_molecular_dataset import SparseMolecularDataset
from rdkit import Chem

def str2bool(v):
    return v.lower() in ('true')

In [4]:
parser = argparse.ArgumentParser()

# Model configuration.
parser.add_argument('--z_dim', type=int, default=10, help='dimension of domain labels')
parser.add_argument('--g_conv_dim', default=[128,512], help='number of conv filters in the first layer of G')
parser.add_argument('--d_conv_dim', type=int, default=[[128, 64], 128, [128, 64]], help='number of conv filters in the first layer of D')
parser.add_argument('--g_repeat_num', type=int, default=6, help='number of residual blocks in G')
parser.add_argument('--d_repeat_num', type=int, default=6, help='number of strided conv layers in D')
parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss')
parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss')
parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')
parser.add_argument('--post_method', type=str, default='softmax', choices=['softmax', 'soft_gumbel', 'hard_gumbel'])

# Training configuration.
parser.add_argument('--batch_size', type=int, default=128, help='mini-batch size')
parser.add_argument('--num_iters', type=int, default=10000, help='number of total iterations for training D')
parser.add_argument('--num_iters_decay', type=int, default=5000, help='number of iterations for decaying lr')
parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for G')
parser.add_argument('--d_lr', type=float, default=0.0001, help='learning rate for D')
parser.add_argument('--dropout', type=float, default=0., help='dropout rate')
parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step')

# Test configuration.
parser.add_argument('--test_iters', type=int, default=10000, help='test model from this step')

# Miscellaneous.
parser.add_argument('--num_workers', type=int, default=1)
parser.add_argument('--mode', type=str, default='test', choices=['train', 'test'])
parser.add_argument('--use_tensorboard', type=str2bool, default=False)

# Directories.
parser.add_argument('--mol_data_dir', type=str, default='data/gdb9_9nodes.sparsedataset')
parser.add_argument('--log_dir', type=str, default='molgan/logs')
parser.add_argument('--model_save_dir', type=str, default='molgan/models')
parser.add_argument('--sample_dir', type=str, default='molgan/samples')
parser.add_argument('--result_dir', type=str, default='molgan/results')

# Step size.
parser.add_argument('--log_step', type=int, default=10)
parser.add_argument('--sample_step', type=int, default=1000)
parser.add_argument('--model_save_step', type=int, default=10000)
parser.add_argument('--lr_update_step', type=int, default=1000)

config = parser.parse_known_args()[0]
print(config)

# For fast training.
cudnn.benchmark = True

Namespace(batch_size=128, beta1=0.5, beta2=0.999, d_conv_dim=[[128, 64], 128, [128, 64]], d_lr=0.0001, d_repeat_num=6, dropout=0.0, g_conv_dim=[128, 512], g_lr=0.0001, g_repeat_num=6, lambda_cls=1, lambda_gp=10, lambda_rec=10, log_dir='molgan/logs', log_step=10, lr_update_step=1000, mode='test', model_save_dir='molgan/models', model_save_step=10000, mol_data_dir='data/gdb9_9nodes.sparsedataset', n_critic=5, num_iters=10000, num_iters_decay=5000, num_workers=1, post_method='softmax', result_dir='molgan/results', resume_iters=None, sample_dir='molgan/samples', sample_step=1000, test_iters=10000, use_tensorboard=False, z_dim=10)


In [14]:
cuda = True if torch.cuda.is_available() else False

dev = qml.device('default.qubit', wires=23)
@qml.qnode(dev, interface='torch')
def gen_circuit(w):
    # random noise as generator input
    z = random.uniform(-1, 1)
    layers = 1
    qubits = 23
    cnt = 0
    
    # construct generator circuit for both atom vector and node matrix
    for i in range(qubits):
        qml.RY(np.arcsin(z), wires=i)
        qml.RZ(np.arccos(z), wires=i)
    for l in range(layers):
        for i in range(qubits):
            qml.RY(w[i], wires=i)
            qml.Hadamard(wires=i)
            cnt += 1
        for i in range(qubits-1):
            qml.CNOT(wires=[i, i+1])
            qml.RZ(w[i+qubits], wires=i+1)
            qml.CNOT(wires=[i, i+1])
            cnt += 1
    print(cnt)
    for i in range(qubits):
        qml.Hadamard(wires=i)
    return [qml.expval(qml.PauliZ(i)) for i in range(qubits)]

In [15]:
w = torch.tensor(list(np.random.rand(45)*2-1), requires_grad=True)
t0 = time.time()
gen_circuit(w)
time.time()-t0

45


23.476433515548706

In [35]:
class SparseMoleCular(data.Dataset):
    """Dataset class for the CelebA dataset."""

    def __init__(self, data_dir):
        """Initialize and preprocess the CelebA dataset."""
        self.data = SparseMolecularDataset()
        self.data.load(data_dir)

    def __getitem__(self, index):
        """Return one image and its corresponding attribute label."""

        return index, self.data.data[index], self.data.smiles[index],\
               self.data.data_S[index], self.data.data_A[index],\
               self.data.data_X[index], self.data.data_D[index],\
               self.data.data_F[index], self.data.data_Le[index],\
               self.data.data_Lv[index]

    def __len__(self):
        """Return the number of images."""
        return len(self.data)
 
filename = 'data_bimodal.csv'
dataset = SparseMoleCular(config.mol_data_dir)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=False)

In [36]:
self = Solver(config)

Generator(
  (layers): Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): Tanh()
    (2): Dropout(p=0.0, inplace=True)
    (3): Linear(in_features=128, out_features=512, bias=True)
    (4): Tanh()
    (5): Dropout(p=0.0, inplace=True)
  )
  (edges_layer): Linear(in_features=512, out_features=405, bias=True)
  (nodes_layer): Linear(in_features=512, out_features=45, bias=True)
  (dropoout): Dropout(p=0.0, inplace=False)
)
G
The number of parameters: 298306
Discriminator(
  (gcn_layer): GraphConvolution(
    (linear1): Linear(in_features=5, out_features=128, bias=True)
    (linear2): Linear(in_features=128, out_features=64, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (agg_layer): GraphAggregation(
    (sigmoid_linear): Sequential(
      (0): Linear(in_features=69, out_features=128, bias=True)
      (1): Sigmoid()
    )
    (tanh_linear): Sequential(
      (0): Linear(in_features=69, out_features=128, bias=True)
      (1): Tanh()
    )
    (

In [67]:
[gen_weights]+list(self.G.parameters())+list(self.V.parameters())

[tensor([ 0.7303,  0.3373,  0.8634, -0.5128,  0.8183, -0.0618,  0.0061, -0.4003,
         -0.6139, -0.4951,  0.8553,  0.4394,  0.8351, -0.3772, -0.6404, -0.2580,
         -0.6218, -0.6328, -0.0432], dtype=torch.float64, requires_grad=True),
 Parameter containing:
 tensor([[ 0.1867, -0.2513,  0.1971,  ..., -0.0338, -0.1414, -0.1840],
         [-0.2948,  0.1381, -0.2206,  ..., -0.2816,  0.2555,  0.2047],
         [-0.1871,  0.2188,  0.2683,  ..., -0.2734,  0.3090, -0.2552],
         ...,
         [-0.0926, -0.1622,  0.2693,  ...,  0.1350, -0.1668,  0.0959],
         [ 0.2663, -0.2907,  0.1255,  ..., -0.0267, -0.2872, -0.0840],
         [ 0.2453, -0.1061, -0.0354,  ..., -0.0036,  0.2868,  0.2432]],
        requires_grad=True),
 Parameter containing:
 tensor([-0.1827,  0.2323,  0.0640,  0.1698,  0.0781,  0.1025, -0.2414, -0.0134,
         -0.0814, -0.3163,  0.0013, -0.3057,  0.2794,  0.1091, -0.1999, -0.2088,
         -0.1248,  0.1491, -0.1770,  0.2785,  0.2196, -0.0844, -0.1173, -0.1949,


In [68]:
list(self.V.parameters())

[Parameter containing:
 tensor([[-3.7475e-01, -4.0833e-01,  3.2221e-01,  3.8066e-01,  3.7779e-01],
         [-2.7292e-01,  2.4038e-01, -4.0791e-01, -2.6894e-01, -6.7997e-02],
         [ 1.4437e-01,  2.4562e-01, -1.4657e-02,  2.1876e-01,  1.6600e-01],
         [-4.2427e-01,  2.5682e-01,  1.0205e-01,  6.7404e-03, -1.1130e-01],
         [-3.6588e-01,  1.6119e-01,  9.1851e-02, -3.9853e-02,  6.1906e-02],
         [-1.5517e-02, -3.6955e-01, -3.0877e-01, -3.4947e-01, -2.7227e-01],
         [ 3.7120e-01, -3.5904e-01,  1.1242e-02, -1.3608e-01, -2.8722e-01],
         [ 3.1365e-01, -1.5799e-01, -4.0197e-01, -1.8634e-01, -4.3466e-01],
         [-1.7525e-01,  6.8738e-02,  2.4370e-02,  1.8306e-01, -1.8165e-01],
         [-3.5573e-01, -1.7537e-02,  2.6428e-01,  3.0377e-01,  7.8550e-02],
         [-2.7918e-01,  3.7055e-01, -1.6469e-01,  2.8410e-02,  1.8967e-01],
         [ 1.5390e-01,  2.7701e-01, -4.4218e-01, -2.9641e-01, -2.6583e-01],
         [-1.2951e-01, -7.2168e-02, -3.6892e-01,  4.2199e-01,  4.

In [54]:
w = torch.tensor(list(np.random.rand(19)*2-1), requires_grad=True)

In [69]:
# Learning rate cache for decaying.
g_lr = self.g_lr
d_lr = self.d_lr
gen_weights = torch.tensor(list(np.random.rand(19)*2-1), requires_grad=True)
self.g_optimizer = torch.optim.Adam(list(self.G.parameters())+list(self.V.parameters())+[gen_weights],
                                    self.g_lr, [self.beta1, self.beta2])

# Start training from scratch or resume training.
start_iters = 0
self.resume_iters = 0
if self.resume_iters:
    start_iters = self.resume_iters
    self.restore_model(self.resume_iters)
    
# Start training.
print('Start training...')
start_time = time.time()
cnt = 0
for i in range(self.num_iters):
    print(gen_weights)
    cnt += 1
    if (i+1) % self.log_step == 0:
        mols, _, _, a, x, _, _, _, _ = self.data.next_validation_batch()
        z = self.sample_z(a.shape[0])
        print('[Valid]', '')
    else:
        mols, _, _, a, x, _, _, _, _ = self.data.next_train_batch(self.batch_size)
        

        # =================================================================================== #
        #                             1. Preprocess input data                                #
        # =================================================================================== #

        a = torch.from_numpy(a).to(self.device).long()            # Adjacency.
        x = torch.from_numpy(x).to(self.device).long()            # Nodes.
        a_tensor = self.label2onehot(a, self.b_dim)
        x_tensor = self.label2onehot(x, self.m_dim)
        sample_list = [gen_circuit(gen_weights.detach()) for i in range(self.batch_size)]
        z = torch.stack(tuple(sample_list)).to(self.device).float()

        # =================================================================================== #
        #                             2. Train the discriminator                              #
        # =================================================================================== #

        # Compute loss with real images.
        logits_real, features_real = self.D(a_tensor, None, x_tensor)
        d_loss_real = - torch.mean(logits_real)

        # Compute loss with fake images.
        edges_logits, nodes_logits = self.G(z)
        # Postprocess with Gumbel softmax
        (edges_hat, nodes_hat) = self.postprocess((edges_logits, nodes_logits), self.post_method)
        logits_fake, features_fake = self.D(edges_hat, None, nodes_hat)
        d_loss_fake = torch.mean(logits_fake)

        # Compute loss for gradient penalty.
        eps = torch.rand(logits_real.size(0),1,1,1).to(self.device)
        x_int0 = (eps * a_tensor + (1. - eps) * edges_hat).requires_grad_(True)
        x_int1 = (eps.squeeze(-1) * x_tensor + (1. - eps.squeeze(-1)) * nodes_hat).requires_grad_(True)
        grad0, grad1 = self.D(x_int0, None, x_int1)
        d_loss_gp = self.gradient_penalty(grad0, x_int0) + self.gradient_penalty(grad1, x_int1)


        # Backward and optimize.
        d_loss = d_loss_fake + d_loss_real + self.lambda_gp * d_loss_gp
        self.reset_grad()
        d_loss.backward()
        self.d_optimizer.step()

        # =================================================================================== #
        #                               3. Train the generator                                #
        # =================================================================================== #

        if (i+1) % self.n_critic == 0 or True:
            # Z-to-target
            edges_logits, nodes_logits = self.G(z)
            # Postprocess with Gumbel softmax
            (edges_hat, nodes_hat) = self.postprocess((edges_logits, nodes_logits), self.post_method)
            logits_fake, features_fake = self.D(edges_hat, None, nodes_hat)
            g_loss_fake = - torch.mean(logits_fake)

            # Real Reward
            rewardR = torch.from_numpy(self.reward(mols)).to(self.device)
            # Fake Reward
            (edges_hard, nodes_hard) = self.postprocess((edges_logits, nodes_logits), 'hard_gumbel')
            edges_hard, nodes_hard = torch.max(edges_hard, -1)[1], torch.max(nodes_hard, -1)[1]
            mols = [self.data.matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True)
                    for e_, n_ in zip(edges_hard, nodes_hard)]
            rewardF = torch.from_numpy(self.reward(mols)).to(self.device)

            # Value loss
            value_logit_real,_ = self.V(a_tensor, None, x_tensor, torch.sigmoid)
            value_logit_fake,_ = self.V(edges_hat, None, nodes_hat, torch.sigmoid)
            g_loss_value = torch.mean((value_logit_real - rewardR) ** 2 + (
                                       value_logit_fake - rewardF) ** 2)
            
            # Backward and optimize.
            g_loss = g_loss_fake + g_loss_value
            self.reset_grad()
            g_loss.backward()
            self.g_optimizer.step()

        if cnt == 5:
            break

Start training...
tensor([-0.6913, -0.9595,  0.3173,  0.2617,  0.6782,  0.5515,  0.1620,  0.2770,
        -0.5500, -0.6735, -0.3621,  0.5646,  0.6952, -0.6556,  0.8752, -0.4428,
        -0.8155,  0.1247, -0.9049], dtype=torch.float64, requires_grad=True)


RDKit ERROR: [16:47:11] Explicit valence for atom # 0 N, 10, is greater than permitted
RDKit ERROR: [16:47:11] Explicit valence for atom # 0 O, 11, is greater than permitted
RDKit ERROR: [16:47:11] Explicit valence for atom # 0 F, 11, is greater than permitted
RDKit ERROR: [16:47:11] Explicit valence for atom # 0 O, 9, is greater than permitted
RDKit ERROR: [16:47:11] Explicit valence for atom # 1 C, 7, is greater than permitted
RDKit ERROR: [16:47:11] Explicit valence for atom # 0 O, 11, is greater than permitted
RDKit ERROR: [16:47:11] Explicit valence for atom # 0 O, 9, is greater than permitted
RDKit ERROR: [16:47:11] Explicit valence for atom # 1 O, 16, is greater than permitted
RDKit ERROR: [16:47:11] Explicit valence for atom # 0 F, 12, is greater than permitted
RDKit ERROR: [16:47:11] Explicit valence for atom # 1 O, 16, is greater than permitted
RDKit ERROR: [16:47:11] Explicit valence for atom # 0 N, 13, is greater than permitted
RDKit ERROR: [16:47:11] Explicit valence for a

tensor([-0.6913, -0.9595,  0.3173,  0.2617,  0.6782,  0.5515,  0.1620,  0.2770,
        -0.5500, -0.6735, -0.3621,  0.5646,  0.6952, -0.6556,  0.8752, -0.4428,
        -0.8155,  0.1247, -0.9049], dtype=torch.float64, requires_grad=True)


RDKit ERROR: [16:47:13] Explicit valence for atom # 0 N, 13, is greater than permitted
RDKit ERROR: [16:47:13] Explicit valence for atom # 1 F, 17, is greater than permitted
RDKit ERROR: [16:47:13] Explicit valence for atom # 1 O, 12, is greater than permitted
RDKit ERROR: [16:47:13] Explicit valence for atom # 0 O, 13, is greater than permitted
RDKit ERROR: [16:47:13] Explicit valence for atom # 1 O, 14, is greater than permitted
RDKit ERROR: [16:47:13] Explicit valence for atom # 0 F, 13, is greater than permitted
RDKit ERROR: [16:47:13] Explicit valence for atom # 0 C, 18, is greater than permitted
RDKit ERROR: [16:47:13] Explicit valence for atom # 0 O, 11, is greater than permitted
RDKit ERROR: [16:47:13] Explicit valence for atom # 0 C, 13, is greater than permitted
RDKit ERROR: [16:47:13] Explicit valence for atom # 1 N, 13, is greater than permitted
RDKit ERROR: [16:47:13] Explicit valence for atom # 0 O, 12, is greater than permitted
RDKit ERROR: [16:47:13] Explicit valence fo

tensor([-0.6913, -0.9595,  0.3173,  0.2617,  0.6782,  0.5515,  0.1620,  0.2770,
        -0.5500, -0.6735, -0.3621,  0.5646,  0.6952, -0.6556,  0.8752, -0.4428,
        -0.8155,  0.1247, -0.9049], dtype=torch.float64, requires_grad=True)


RDKit ERROR: [16:47:15] Explicit valence for atom # 1 C, 13, is greater than permitted
RDKit ERROR: [16:47:15] Explicit valence for atom # 2 N, 12, is greater than permitted
RDKit ERROR: [16:47:15] Explicit valence for atom # 1 O, 13, is greater than permitted
RDKit ERROR: [16:47:15] Explicit valence for atom # 0 C, 13, is greater than permitted
RDKit ERROR: [16:47:15] Explicit valence for atom # 0 O, 13, is greater than permitted
RDKit ERROR: [16:47:15] Explicit valence for atom # 1 F, 12, is greater than permitted
RDKit ERROR: [16:47:15] Explicit valence for atom # 0 F, 7, is greater than permitted
RDKit ERROR: [16:47:15] Explicit valence for atom # 0 N, 14, is greater than permitted
RDKit ERROR: [16:47:15] Explicit valence for atom # 0 C, 15, is greater than permitted
RDKit ERROR: [16:47:15] Explicit valence for atom # 0 F, 12, is greater than permitted
RDKit ERROR: [16:47:15] Explicit valence for atom # 0 N, 8, is greater than permitted
RDKit ERROR: [16:47:15] Explicit valence for 

tensor([-0.6913, -0.9595,  0.3173,  0.2617,  0.6782,  0.5515,  0.1620,  0.2770,
        -0.5500, -0.6735, -0.3621,  0.5646,  0.6952, -0.6556,  0.8752, -0.4428,
        -0.8155,  0.1247, -0.9049], dtype=torch.float64, requires_grad=True)


RDKit ERROR: [16:47:17] Explicit valence for atom # 1 C, 9, is greater than permitted
RDKit ERROR: [16:47:17] Explicit valence for atom # 0 N, 11, is greater than permitted
RDKit ERROR: [16:47:17] Explicit valence for atom # 0 N, 10, is greater than permitted
RDKit ERROR: [16:47:17] Explicit valence for atom # 0 N, 10, is greater than permitted
RDKit ERROR: [16:47:17] Explicit valence for atom # 0 N, 14, is greater than permitted
RDKit ERROR: [16:47:17] Explicit valence for atom # 0 C, 11, is greater than permitted
RDKit ERROR: [16:47:17] Explicit valence for atom # 1 N, 10, is greater than permitted
RDKit ERROR: [16:47:17] Explicit valence for atom # 0 N, 8, is greater than permitted
RDKit ERROR: [16:47:17] Explicit valence for atom # 0 N, 14, is greater than permitted
RDKit ERROR: [16:47:17] Explicit valence for atom # 0 N, 13, is greater than permitted
RDKit ERROR: [16:47:17] Explicit valence for atom # 0 F, 7, is greater than permitted
RDKit ERROR: [16:47:17] Explicit valence for a

tensor([-0.6913, -0.9595,  0.3173,  0.2617,  0.6782,  0.5515,  0.1620,  0.2770,
        -0.5500, -0.6735, -0.3621,  0.5646,  0.6952, -0.6556,  0.8752, -0.4428,
        -0.8155,  0.1247, -0.9049], dtype=torch.float64, requires_grad=True)


RDKit ERROR: [16:47:20] Explicit valence for atom # 0 N, 6, is greater than permitted
RDKit ERROR: [16:47:20] Explicit valence for atom # 0 C, 20, is greater than permitted
RDKit ERROR: [16:47:20] Explicit valence for atom # 0 O, 13, is greater than permitted
RDKit ERROR: [16:47:20] Explicit valence for atom # 3 C, 10, is greater than permitted
RDKit ERROR: [16:47:20] Explicit valence for atom # 1 N, 15, is greater than permitted
RDKit ERROR: [16:47:20] Explicit valence for atom # 1 N, 14, is greater than permitted
RDKit ERROR: [16:47:20] Explicit valence for atom # 0 F, 8, is greater than permitted
RDKit ERROR: [16:47:20] Explicit valence for atom # 0 C, 9, is greater than permitted
RDKit ERROR: [16:47:20] Explicit valence for atom # 0 N, 12, is greater than permitted
RDKit ERROR: [16:47:20] Explicit valence for atom # 1 N, 14, is greater than permitted
RDKit ERROR: [16:47:20] Explicit valence for atom # 1 N, 9, is greater than permitted
RDKit ERROR: [16:47:20] Explicit valence for at

In [None]:
# Adversarial ground truths
valid = Variable(Tensor(config.batch_size).fill_(1.0), requires_grad=False)
fake = Variable(Tensor(config.batch_size).fill_(0.0), requires_grad=False)
init_weights = list(np.random.rand(3))
print('Initial generator weights:', init_weights)
print()
gen_weights = torch.tensor(list(np.random.rand(3)), requires_grad=True)
# optimizer_G = torch.optim.Adam([gen_weights], lr = config.lr)
# optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=config.lr)
optimizer_G = torch.optim.SGD([gen_weights], lr=config.lr, momentum=0.9)
optimizer_D = torch.optim.SGD(discriminator.parameters(), lr=config.lr, momentum=0.9)
best_g_loss = np.inf

# Start training.
print('Start training...')

for epoch in range(config.n_epochs):    
    # learning rate decay
    if (epoch+1) % 10 == 0:
        config.lr = config.lr * .6
        print()
        print('Training with learning rate:', config.lr)
        
    for i, samples in enumerate(dataloader):
        # Configure input
        real_samples = Variable(samples.type(Tensor))
        
        # -----------------
        #  Train Generator
        # -----------------
#         if (i%5 == 0):
        optimizer_G.zero_grad()

        # Generate a batch of samples
        sample_list = [gen_circuit(gen_weights) for i in range(config.batch_size)]
        gen_samples = torch.stack(tuple(sample_list))
        gen_samples = gen_samples.to(device)
#         loss = adversarial_loss(discriminator(torch.unsqueeze(gen_samples.float(), 1)), valid)
        loss = adversarial_loss(discriminator(gen_samples.float()), valid)
        g_loss = loss.mean()
        
        # keep track of best G loss and generator parameters
        if g_loss < best_g_loss:
            best_g_loss = g_loss
            best_g_weights = gen_weights            

        g_loss.backward()
        optimizer_G.step()
        
        # ---------------------
        #  Train Discriminator
        # ---------------------
        
        for j in range(config.k):
            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(real_samples), valid)
            fake_loss = adversarial_loss(discriminator(gen_samples.detach().float()), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward(retain_graph=True)
            optimizer_D.step()
        
        if i % config.sample_interval == 0:
            print(
                "%s\t[Epoch %d/%d  Batch %d/%d]\t[D loss: %f]\t[G loss: %f]"
                % (datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), epoch+1, config.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )
            print(gen_weights.detach())

In [None]:
gen_weights.detach().numpy()

In [None]:
samples = []
num_samples = 1000
w = [1.8302, 2.1733, 4.2969] #[-23.1983,  -4.4826, -10.3598]
for i in range(num_samples):
    samples.append(gen_circuit(w))

In [None]:
# the histogram of the data
plt.hist(samples, 30, facecolor='g', alpha=0.75)

plt.xlabel('Value')
plt.ylabel('Frequency')
plt.title('Histogram of bimodal distribution')
plt.xlim(-1, 1)
plt.ylim(0, 140)
plt.grid(True)
plt.show()

In [None]:
dev = qml.device("default.qubit", wires=2, shots=4)

@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    qml.Hadamard(wires=1)
    qml.CNOT(wires=[0, 1])
    qml.RY(x, wires=1)
    return qml.sample(qml.PauliY(0),qml.PauliY(1))

circuit(0.5)

In [None]:
qml.PauliY(0)