In [2]:
import pennylane as qml
from pennylane import numpy as np

In [34]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchsummary import summary

In [29]:
from MNISTData import MNISTData
from AutoEncoder import AutoEncoder

In [7]:
# Length of the encoding of one MNIST image (matches the 16-unit bottleneck
# layer of the autoencoder summarized further below).
ENCODING_SIZE = 16
# The circuit feeds two encoded features into each qubit (one RX angle and one
# RY angle), so the encoding must split evenly across the qubits; an odd size
# would make the RY loop address a wire that does not exist on the device.
assert ENCODING_SIZE % 2 == 0, "ENCODING_SIZE must be even (two features per qubit)"
NUM_QUBITS = ENCODING_SIZE // 2

In [8]:
# PennyLane's built-in "default.qubit" simulator with one wire per qubit.
dev = qml.device("default.qubit", wires=NUM_QUBITS)

In [19]:
from pennylane.ops import RX, RY, CNOT

def _rotation_layer(params):
    """Queue one layer of single-qubit rotations driven by ``params``.

    Entries 0..NUM_QUBITS-1 become RX angles on wires 0..NUM_QUBITS-1; the
    remaining entries (up to ENCODING_SIZE) become RY angles on the same
    wires, in order. Every RX gate is queued before any RY gate.
    """
    for idx in range(ENCODING_SIZE):
        if idx < NUM_QUBITS:
            RX(params[idx], wires=idx)
        else:
            RY(params[idx], wires=idx - NUM_QUBITS)


# `x` is the length-ENCODING_SIZE vector that represents the encoding of an
# MNIST image; `thetas` holds the same number of trainable angles.
@qml.qnode(dev, interface='torch')
def circuit(x, thetas):
    """Variational classifier circuit.

    Layout: a rotation layer driven by `x`, a chain of CNOTs between
    neighbouring wires, then a second rotation layer driven by `thetas`.

    Returns:
        Tuple of PauliZ expectation values, one per wire 0..NUM_QUBITS-1.
    """
    _rotation_layer(x)
    for control in range(NUM_QUBITS - 1):
        CNOT(wires=[control, control + 1])
    _rotation_layer(thetas)
    # NOTE(review): `qml.expval.PauliZ(...)` is legacy PennyLane measurement
    # syntax; newer releases spell it `qml.expval(qml.PauliZ(i))` — confirm
    # against the installed version before upgrading.
    return tuple(qml.expval.PauliZ(wires=wire) for wire in range(NUM_QUBITS))

In [24]:
# Example input: every encoded feature set to pi/5 and every trainable angle
# set to pi. Sized from ENCODING_SIZE rather than a hard-coded 16 so the cell
# keeps working if the encoding width changes.
print(circuit([np.pi / 5] * ENCODING_SIZE, [np.pi] * ENCODING_SIZE))

tensor([0.6545, 0.4284, 0.2804, 0.1835, 0.1201, 0.0786, 0.0515, 0.0337],
       dtype=torch.float64)


In [27]:
loss = nn.CrossEntropyLoss()


def cost(x, thetas, W, actual_labels, qnode=None):
    """Cross-entropy loss of a linear read-out on top of the quantum circuit.

    Args:
        x: encoded image, indexed x[0] .. x[ENCODING_SIZE - 1] by the circuit.
        thetas: trainable rotation angles, same length as `x`.
        W: read-out weight matrix of shape (num_classes, NUM_QUBITS); it is
            multiplied with the vector of circuit expectation values.
        actual_labels: class index of the sample (Python int, 0-d tensor or
            1-element tensor).
        qnode: circuit to evaluate as ``qnode(x, thetas)``; defaults to the
            module-level ``circuit``. Lets other ansätze be plugged in.

    Returns:
        Scalar loss tensor.
    """
    if qnode is None:
        qnode = circuit
    y = qnode(x, thetas)
    # `y` is a 1-D vector of expectation values, and torch.mm only accepts
    # 2-D matrices, so use matmul (a matrix-vector product here). The circuit
    # returns float64, so cast to W's dtype to avoid a dtype mismatch.
    class_predicts = torch.matmul(W, y.to(W.dtype))
    # CrossEntropyLoss expects (batch, classes) logits and (batch,) long targets.
    if class_predicts.dim() == 1:
        class_predicts = class_predicts.unsqueeze(0)
    targets = torch.as_tensor(actual_labels, dtype=torch.long).reshape(-1)
    return loss(class_predicts, targets)

Before we can start training, we need the (encoded) images: load the MNIST data and the pretrained autoencoder whose encodings will be fed into the circuit.

In [30]:
data = MNISTData()
# Train and test loaders, both obtained from the same MNISTData wrapper.
train_loader, test_loader = data.get_train_loader(), data.get_test_loader()

In [31]:
# Run directory of the pretrained autoencoder checkpoint (the name presumably
# is the Unix timestamp of the training run — change it to load another run).
AE_RUN_ID = "1558553790"
load_from = "./autoencoder_models/" + AE_RUN_ID + "/ae.pt"

In [32]:
ae = AutoEncoder()
# map_location="cpu" lets a checkpoint saved from a GPU run be deserialized on
# a CPU-only machine; load_state_dict then copies the weights into ae's
# (CPU-resident) parameters, so the resulting model is the same either way.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
ae.load_state_dict(torch.load(load_from, map_location="cpu"))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [35]:
# torchsummary defaults to device="cuda" and builds CUDA dummy inputs whenever
# a GPU is available; ae is never moved off the CPU in this notebook, so ask
# for CPU inputs explicitly to avoid a device mismatch on GPU machines.
summary(ae, input_size=(1, 28, 28), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 256]         200,960
              ReLU-2                  [-1, 256]               0
            Linear-3                   [-1, 64]          16,448
              ReLU-4                   [-1, 64]               0
            Linear-5                   [-1, 16]           1,040
              ReLU-6                   [-1, 16]               0
            Linear-7                   [-1, 64]           1,088
              ReLU-8                   [-1, 64]               0
            Linear-9                  [-1, 256]          16,640
             ReLU-10                  [-1, 256]               0
           Linear-11                  [-1, 784]         201,488
          Sigmoid-12                  [-1, 784]               0
Total params: 437,664
Trainable params: 437,664
Non-trainable params: 0
-------------------------------