# Trainable Quanvolutional NN in Tensorflow

In [35]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist


import pennylane as qml
from pennylane import numpy as np
from pennylane.templates import RandomLayers

from sklearn.metrics import accuracy_score

import datetime

In [47]:
class QLayer(tf.keras.layers.Layer):
    
    def __init__(
        self, 
        stride         = 2,
        device         = "default.qubit", #qulacs.simulator
        wires          = 4,
        circuit_layers = 4,
        n_rotations    = 8,
        out_channels   = 4,
        seed           = None
    ):
        
        super().__init__()
        
        # init device
        self.wires = wires
        self.dev   = qml.device(device, wires=self.wires)
                
        self.stride       = stride
        self.out_channels = min(out_channels, wires)
                
        if seed is None:
            seed = np.random.randint(low=0, high=10e6)
            
        # random circuits
        @qml.qnode(device=self.dev)
        def circuit(inputs, weights):
            n_inputs=4
            # Encoding of 4 classical input values
            for j in range(n_inputs):
                qml.RY(inputs[j], wires=j)
            # Random quantum circuit
            RandomLayers(weights, wires=list(range(self.wires)), seed=seed)
            
            # Measurement producing 4 classical output values
            return [qml.expval(qml.PauliZ(j)) for j in range(self.out_channels)]

        weight_shapes = {"weights": [circuit_layers, n_rotations]}
        self.circuit = qml.qnn.KerasLayer(circuit, weight_shapes=weight_shapes, output_dim=1)
        
    def draw(self):
        # build circuit by sending dummy data through it
        _ = self.circuit(inputs=tf.convert_to_tensor(np.zeros(4)))
        print(self.circuit.qnode.draw())
        self.circuit.zero_grad()
        
    def call(self, img):
        bs, h, w, ch = img.size()
        if ch > 1:
            img = img.mean(axis=-1).reshape(bs, h, w, 1)
                        
        kernel_size = 2        
        h_out = (h-kernel_size) // self.stride + 1
        w_out = (w-kernel_size) // self.stride + 1
        
        
        out = tf.zeros((bs, h_out, w_out, self.out_channels))
        
        # Loop over the coordinates of the top-left pixel of 2X2 squares
        for b in range(bs):
            for j in range(0, h_out, self.stride):
                for k in range(0, w_out, self.stride):
                    # Process a squared 2x2 region of the image with a quantum circuit
                    q_results = self.circuit(
                        inputs=tf.Tensor([
                            img[b, j, k, 0],
                            img[b, j, k + 1, 0],
                            img[b, j + 1, k, 0],
                            img[b, j + 1, k + 1, 0]
                        ])
                    )
                    # Assign expectation values to different channels of the output pixel (j/2, k/2)
                    for c in range(self.out_channels):
                        out[b, j // kernel_size, k // kernel_size, c] = q_results[c]
                        
                
        return out
    

In [48]:
qlayer = QLayer(
        stride         = 2,
        device         = "default.qubit", #qulacs.simulator
        wires          = 4,
        circuit_layers = 4,
        n_rotations    = 8,
        out_channels   = 4,
        seed           = None
)

In [49]:
def transform(x):
    x = np.array(x)
    x = x/255.0
    
    return tf.convert_to_tensor(x).float()

In [50]:
(trainX, trainY), (testX, testY) = mnist.load_data()

# grey-scale ==> 1 channel
trainX = trainX.reshape(trainX.shape[0], trainX.shape[1], trainX.shape[2], 1)
testX  = testX.reshape(testX.shape[0], testX.shape[1], testX.shape[2], 1)

# pixel normalization
trainX = trainX.astype("float32")
testX  = testX.astype("float32")
trainX = tf.convert_to_tensor(trainX / 255)
testX  = tf.convert_to_tensor(testX / 255)

# label one-hot encoding
trainY = tf.keras.utils.to_categorical(trainY)
testY  = tf.keras.utils.to_categorical(testY)

In [51]:
model = tf.keras.Sequential(
    qlayer(stride=2, circuit_layers=2, n_rotations=4, out_channels=4),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(14*14*4, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax")
)
opt     = tf.keras.optimizers.Adam(learning_rate=0.01)
loss_fn = tf.keras.losses.CategoricalCrossentropy()

model.compile(optimizer=opt, loss=loss_fn, metrics=["accuracy"])

log_dir = "logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
model.summary()

ValueError: The first argument to `Layer.call` must always be passed.