# Trainable Quanvolutional NN

In [1]:
import torch
from torch import nn

import torchvision

import pennylane as qml
from pennylane import numpy as np
from pennylane.templates import RandomLayers

from sklearn.metrics import accuracy_score

import numpy             as np
import matplotlib.pyplot as plt
from math import floor

import keras
import tensorflow as tf
import datetime

from keras.models   import Sequential
from keras.layers   import Dense, Dropout, Flatten
from keras.layers   import Conv1D, MaxPooling1D, AveragePooling1D, GlobalMaxPooling1D
from keras          import regularizers, initializers
from keras.datasets import mnist
import pennylane as qml

In [2]:
# Plot palette, indexed positionally by the plotting cells
# (index 0 -> training curves, index 1 -> validation curves).
colors = [
    "#7eb0d5", "#fd7f6f", "#b2e061",
    "#bd7ebe", "#ffb55a", "#8bd3c7",
]

In [6]:
# Load MNIST as (images, integer labels) pairs for the train and test splits.
(trainX, trainY), (testX, testY) = mnist.load_data()

# Grey-scale images: append an explicit channel axis of size 1.
trainX = trainX.reshape(*trainX.shape, 1)
testX  = testX.reshape(*testX.shape, 1)

# Cast to float32, rescale by 1/255 and hand the result over to TensorFlow.
trainX = tf.convert_to_tensor(trainX.astype("float32") / 255)
testX  = tf.convert_to_tensor(testX.astype("float32") / 255)

# Integer class labels -> one-hot vectors.
trainY, testY = (tf.keras.utils.to_categorical(labels) for labels in (trainY, testY))

In [3]:
class QConv2D(tf.keras.layers.Layer):
    """Quanvolutional layer: slides a small parametrised quantum circuit over an image.

    Each ``kernel_size`` patch of the (single-channel) input is flattened and
    fed to a circuit with one qubit per patch pixel. The Pauli-Z expectation
    values of the first ``out_channels`` qubits become the channels of the
    corresponding output pixel.

    Args:
        filters: Accepted for signature parity with classical convolution
            layers; it is not used by this implementation.
        kernel_size: ``(height, width)`` of the patch; the circuit uses
            ``height * width`` wires.
        strides: ``(vertical, horizontal)`` step between consecutive patches.
        device: Name of the PennyLane device the circuit runs on.
        out_channels: Number of measured qubits, capped at the number of wires.
    """

    def __init__(self, filters, kernel_size, strides=(1, 1), device="default.qubit", out_channels=1):
        super().__init__()

        # One qubit per pixel of the patch.
        self.kernel_size = kernel_size
        self.wires = kernel_size[0] * kernel_size[1]
        self.dev = qml.device(device, wires=self.wires)

        self.strides = strides
        # Cannot measure more qubits than the circuit has.
        self.out_channels = min(out_channels, self.wires)

        n_wires = self.wires
        n_out = self.out_channels

        @qml.qnode(device=self.dev, interface='tf')
        def circuit(inputs, weights):
            # Data encoding: one RY rotation per pixel, angle atan(pixel).
            for j in range(n_wires):
                qml.RY(tf.math.atan(inputs[j]), wires=j)

            # A single entangling layer whose RZ angles are again derived
            # from the input pixels (not from trainable parameters).
            qml.BasicEntanglerLayers(
                weights=tf.math.atan(tf.reshape(inputs, (1, -1))),
                wires=range(n_wires),
                rotation=qml.RZ,
            )

            # Trainable part: a general single-qubit rotation on every wire.
            for j in range(n_wires):
                qml.Rot(weights[j, 0], weights[j, 1], weights[j, 2], wires=j)

            # One classical output value per output channel.
            return [qml.expval(qml.PauliZ(j)) for j in range(n_out)]

        weight_shapes = {"weights": [self.wires, 3]}
        self.circuit = qml.qnn.KerasLayer(circuit, weight_shapes=weight_shapes, output_dim=self.out_channels)

    def draw(self):
        """Print a text drawing of the circuit evaluated on all-zero inputs and weights."""
        print(qml.draw(self.circuit.qnode)(inputs=np.zeros(self.wires), weights=np.zeros((self.wires, 3))))

    def call(self, img):
        """Apply the quantum circuit to every patch of a batch of images.

        Args:
            img: Tensor indexed as ``(batch, height, width, channels)``.
                Multi-channel inputs are averaged over the channel axis first.

        Returns:
            Tensor of shape ``(batch, h_out, w_out, out_channels)`` where
            ``h_out = (height - kernel_h) // stride_h + 1`` (same for width).

        Raises:
            ValueError: If the batch size is not statically known; the patch
                loop below is plain Python and needs a concrete size.
        """
        bs, h, w, ch = img.shape
        if bs is None:
            raise ValueError(
                "QConv2D needs a statically known batch size; "
                "compile the model with run_eagerly=True."
            )
        if ch is not None and ch > 1:
            # tf tensors have no .mean()/.reshape(); keepdims preserves the channel axis.
            img = tf.reduce_mean(img, axis=-1, keepdims=True)

        k_h, k_w = self.kernel_size
        s_h, s_w = self.strides
        h_out = (h - k_h) // s_h + 1
        w_out = (w - k_w) // s_w + 1

        # Results are collected as tensors (not written into a NumPy buffer)
        # so that gradients can flow back to the circuit weights.
        pixels = []
        for b in range(bs):
            for row in range(h_out):
                top = row * s_h
                for col in range(w_out):
                    left = col * s_w
                    # Row-major flattening of the k_h x k_w patch.
                    patch = tf.reshape(img[b, top:top + k_h, left:left + k_w, 0], (-1,))
                    pixels.append(tf.reshape(self.circuit(patch), (self.out_channels,)))

        # Patches were visited in (batch, row, col) order, so a plain reshape
        # restores the image layout.
        return tf.reshape(tf.stack(pixels), (bs, h_out, w_out, self.out_channels))

In [5]:
# Build a stand-alone 2x2 layer measuring a single qubit and print its circuit.
qlayer = QConv2D(
    filters=2,
    kernel_size=(2, 2),
    out_channels=1,
)
qlayer.draw()

0: ──RY(0.00)─╭BasicEntanglerLayers(M0)──Rot(0.00,0.00,0.00)─┤  <Z>
1: ──RY(0.00)─├BasicEntanglerLayers(M0)──Rot(0.00,0.00,0.00)─┤     
2: ──RY(0.00)─├BasicEntanglerLayers(M0)──Rot(0.00,0.00,0.00)─┤     
3: ──RY(0.00)─╰BasicEntanglerLayers(M0)──Rot(0.00,0.00,0.00)─┤     


In [15]:
# constants
EPOCHS     = 10
FONTSIZE   = 18


# Quanvolution front-end followed by a flat softmax classifier over 10 classes.
model = tf.keras.Sequential([
    QConv2D(filters=2, kernel_size=(2,2), out_channels=1),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax")
])
opt     = tf.keras.optimizers.Adam(learning_rate=0.01)
loss_fn = tf.keras.losses.CategoricalCrossentropy()

# QConv2D.call iterates over the batch and the image patches with plain Python
# loops, which requires concrete tensor shapes: run the model eagerly instead
# of tracing it into a graph (where the batch dimension is unknown).
model.compile(optimizer=opt, loss=loss_fn, metrics=["accuracy"], run_eagerly=True)

# tensorboard callback
log_dir = "qcnn_logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

In [16]:
NDATA = 10

# Train on the first NDATA samples and validate on the next NDATA samples of
# the training set. Validation data is required here: without it Keras does
# not record the "val_loss" / "val_accuracy" entries read below.
history   = model.fit(
    trainX[:NDATA], trainY[:NDATA],
    validation_data=(trainX[NDATA:2 * NDATA], trainY[NDATA:2 * NDATA]),
    epochs=EPOCHS, batch_size=32, verbose=2
)
# Final score on the first NDATA test samples.
loss, acc = model.evaluate(testX[:NDATA], testY[:NDATA], verbose=0)


# Per-epoch training metrics.
train_loss_history = np.array(history.history["loss"])
train_acc_history  = np.array(history.history["accuracy"])

# Per-epoch validation metrics.
valid_loss_history = np.array(history.history["val_loss"])
valid_acc_history  = np.array(history.history["val_accuracy"])

Epoch 1/10


In [None]:
# One x-position per training epoch.
epochs_grid = np.arange(EPOCHS)

fig, (ax1, ax2) = plt.subplots(figsize=(14, 8), ncols=2, sharex=True, constrained_layout=True)

# Left panel: loss curves. Right panel: error rate (1 - accuracy) curves.
# Each entry: axis, title, y-label, training curve, validation curve, test annotation.
panels = (
    (ax1, "Loss history",     "loss",       train_loss_history,  valid_loss_history,  f"Test loss = {loss:.4f}"),
    (ax2, "Accuracy history", "1-accuracy", 1-train_acc_history, 1-valid_acc_history, f"Test accuracy = {acc:.4f}"),
)

for axis, title, ylabel, train_curve, valid_curve, note in panels:
    # Horizontal grid drawn underneath the curves.
    axis.set_axisbelow(True)
    axis.grid(axis="y", color="black", alpha=0.2)

    axis.set_title(title,     fontsize=FONTSIZE+4)
    axis.set_xlabel("epochs", fontsize=FONTSIZE)
    axis.set_ylabel(ylabel,   fontsize=FONTSIZE)
    axis.tick_params(axis="both", which="major", labelsize=FONTSIZE, length=5)

    axis.plot(epochs_grid, train_curve, linewidth=3, color=colors[0], label="train")
    axis.plot(epochs_grid, valid_curve, linewidth=3, color=colors[1], label="validation")

    # Test-set score, positioned in axes-relative coordinates.
    axis.text(0.4, 0.7, note, fontsize=FONTSIZE, transform=axis.transAxes)

    axis.set_yscale("linear")
    axis.legend(fontsize=FONTSIZE)

plt.show()