In [None]:
!pip install tensorflow tensorflow-quantum cirq
!pip install pennylane --upgrade



In [None]:
import tensorflow as tf
import tensorflow_quantum as tfq
import cirq
import sympy
import numpy as np
import matplotlib.pyplot as plt

# Seed Python, NumPy and TensorFlow RNGs before any model is built so that weight
# initialisation (including the quantum layer's angles) and fit-time shuffling repeat
# across Restart & Run All. NOTE(review): some GPU kernels remain nondeterministic.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

def preprocess(x, size=(8, 8)):
    """Downsample a batch of RGB images to normalised greyscale thumbnails.

    Args:
        x: image batch of shape (batch, height, width, 3). CIFAR-10 arrives as uint8.
        size: target (height, width); defaults to 8x8.

    Returns:
        float32 tensor of shape (batch, size[0], size[1], 1), scaled by 1/255
        (so in [0, 1] for 8-bit input).
    """
    x = tf.image.rgb_to_grayscale(x)
    # antialias=True low-pass filters before the large 32 -> 8 downscale; without it a
    # bilinear resize reads only a 2x2 neighbourhood of each 4x4 source block and aliases.
    x = tf.image.resize(x, size, antialias=True)
    return tf.cast(x, tf.float32) / 255.0

# Both splits go through the identical preprocessing pipeline.
x_train, x_test = (preprocess(split) for split in (x_train, x_test))

#4 patches, 16 features
def extract_patches(images):
    patches = tf.image.extract_patches(
        images=images,
        sizes=[1, 4, 4, 1],
        strides=[1, 4, 4, 1],
        rates=[1, 1, 1, 1],
        padding="VALID"
    )
    return patches
#shape = (batch, 2, 2, 16)

def compress_patches(patches):
    """Average-pool every 4x4 patch down to 2x2, compressing 16 features to 4.

    Args:
        patches: tensor of shape (batch, 2, 2, 16) as produced by extract_patches; the last
            axis is the row-major flattening of one 4x4 single-channel patch.

    Returns:
        Tensor of shape (batch, 2, 2, 4): the means of each patch's four 2x2 quadrants,
        ordered top-left, top-right, bottom-left, bottom-right.
    """
    # Split the 16 pixel values as (row_block, row_in_block, col_block, col_in_block).
    quadrants = tf.reshape(patches, (-1, 2, 2, 2, 2, 2, 2))
    # Average the within-block row/col axes -> (batch, 2, 2, row_block, col_block).
    # (Reducing both axes of a 4x4 view instead would collapse each patch to ONE value.)
    pooled = tf.reduce_mean(quadrants, axis=(4, 6))
    return tf.reshape(pooled, (-1, 2, 2, 4))

#qnn circuit: each qubit gets a data-encoding RY(x_i) followed by a trainable RY(theta_i),
#then a CNOT ladder links neighbouring qubits before Pauli-Z readout on every qubit.

#qbits
n_qubits = 4
qubits = cirq.GridQubit.rect(1, n_qubits)

# symbols -- sized from n_qubits so the symbol counts cannot drift from the qubit count
input_symbols = sympy.symbols(f"x0:{n_qubits}")          # classical input, one per qubit
trainable_symbols = sympy.symbols(f"theta0:{n_qubits}")  # trainable, one per qubit

circuit_template = cirq.Circuit(
    [cirq.ry(x)(q) for x, q in zip(input_symbols, qubits)],
    [cirq.ry(theta)(q) for theta, q in zip(trainable_symbols, qubits)],
    [cirq.CNOT(a, b) for a, b in zip(qubits, qubits[1:])],
)

readout_ops = [cirq.Z(q) for q in qubits]

#layer
# tfq.layers.PQC has no `symbol_names` argument: its inputs must be a batch of *circuits*,
# and every symbol in `model_circuit` becomes a trainable weight, so it cannot bind
# classical data to x0..x3. This layer binds both explicitly through tfq.layers.Expectation.
class InputEncodedPQC(tf.keras.layers.Layer):
    """Evaluate a parameterised circuit whose angles are part data, part trainable.

    Each row of the layer input supplies the values of ``input_symbols``; the layer owns one
    trainable angle per entry of ``weight_symbols``, shared across the whole batch.

    Args:
        circuit: cirq.Circuit containing every symbol in input_symbols and weight_symbols.
        operators: list of cirq Pauli operators whose expectation values are returned.
        input_symbols: symbols fed from the layer input, in column order.
        weight_symbols: symbols backed by trainable weights.

    Call args:
        inputs: float tensor of shape (batch, len(input_symbols)).

    Returns:
        float32 tensor of shape (batch, len(operators)).
    """

    def __init__(self, circuit, operators, input_symbols, weight_symbols, **kwargs):
        super().__init__(**kwargs)
        # Column order of `symbol_values` in call(): data first, then trainable angles.
        self._symbol_names = [str(s) for s in (*input_symbols, *weight_symbols)]
        self._circuit = tfq.convert_to_tensor([circuit])  # shape (1,), serialised once
        self._operators = list(operators)
        self.theta = self.add_weight(
            name="theta",
            shape=(len(weight_symbols),),
            # Same default range tfq.layers.PQC uses for its trainable symbols.
            initializer=tf.keras.initializers.RandomUniform(0.0, 2.0 * np.pi),
            trainable=True,
        )
        self._expectation = tfq.layers.Expectation()

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        # One copy of the circuit and of the shared angles per input row.
        circuits = tf.tile(self._circuit, [batch_size])
        weights = tf.tile(tf.expand_dims(self.theta, 0), [batch_size, 1])
        symbol_values = tf.concat([tf.cast(inputs, tf.float32), weights], axis=1)
        return self._expectation(
            circuits,
            symbol_names=self._symbol_names,
            symbol_values=symbol_values,
            operators=self._operators,
        )


quantum_layer = InputEncodedPQC(
    circuit=circuit_template,
    operators=readout_ops,
    input_symbols=input_symbols,
    weight_symbols=trainable_symbols,
)

# convert patches into shape (batch, n_features) for PQC input
def patches_to_inputs(patch_tensor):
    # flatten each patch: (batch, 2, 2, 4) -> (batch*2*2, 4)
    patches = tf.reshape(patch_tensor, (-1, 4))
    return patches

#qnn model: 8x8 image -> 2x2 grid of 4x4 patches -> 4 pooled values per patch -> one
#circuit evaluation per patch -> (patches * n_qubits) expectation values -> Dense(10) logits.
# Every raw TensorFlow op is wrapped in a Lambda layer: feeding symbolic Keras tensors to
# bare tf.* functions relies on Keras 2's implicit op wrapping and fails under Keras 3.
n_patches = 2 * 2
inputs = tf.keras.Input(shape=(8, 8, 1))
patches = tf.keras.layers.Lambda(extract_patches, name="extract_patches")(inputs)
compressed = tf.keras.layers.Lambda(compress_patches, name="compress_patches")(patches)
quantum_inputs = tf.keras.layers.Lambda(patches_to_inputs, name="patches_to_inputs")(compressed)
quantum_features = quantum_layer(quantum_inputs)
# Regroup the per-patch rows (batch * n_patches, n_qubits) back into one row per image.
quantum_features = tf.keras.layers.Lambda(
    lambda t: tf.reshape(t, (-1, n_patches * n_qubits)), name="regroup_patches"
)(quantum_features)
outputs = tf.keras.layers.Dense(10)(quantum_features)
qnn_model = tf.keras.Model(inputs=inputs, outputs=outputs)

#baseline pure cnn model
cnn_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(
        filters=4,
        kernel_size=4,
        strides=4,
        input_shape=(8, 8, 1)
    ),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10)
])

#compile models
for model in [cnn_model, qnn_model]:
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"]
    )

#training
cnn_history = cnn_model.fit(
    x_train, y_train,
    validation_data=(x_test, y_test),
    epochs=15,
    batch_size=32
)

qnn_history = qnn_model.fit(
    x_train, y_train,
    validation_data=(x_test, y_test),
    epochs=15,
    batch_size=32
)

#plotting: one figure per metric. Each model keeps the same colour and each split the same
#line style in both figures, so the accuracy and loss plots can be read side by side.
def plot_metric(histories, metric, short_label, ylabel, title):
    """Plot `metric` and `val_{metric}` per epoch for every model in `histories`.

    Args:
        histories: mapping of display name -> Keras History object returned by fit().
        metric: history key, e.g. "accuracy" or "loss".
        short_label: legend suffix, e.g. "Acc" or "Loss".
        ylabel: y-axis label.
        title: figure title.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    for i, (name, history) in enumerate(histories.items()):
        train = history.history[metric]
        val = history.history[f"val_{metric}"]
        epochs = range(1, len(train) + 1)  # 1-based, matching Keras' "Epoch 1/N" log
        color = f"C{i}"  # one colour per model, stable across figures
        ax.plot(epochs, train, color=color, linestyle="-", label=f"{name} Train {short_label}")
        ax.plot(epochs, val, color=color, linestyle="--", label=f"{name} Val {short_label}")
    ax.set(title=title, xlabel="Epoch", ylabel=ylabel)
    ax.legend()
    ax.grid(True)
    plt.show()


histories = {"CNN": cnn_history, "QNN": qnn_history}
plot_metric(histories, "accuracy", "Acc", "Accuracy", "CNN vs QNN Accuracy on 8x8 CIFAR-10")
plot_metric(histories, "loss", "Loss", "Loss", "CNN vs QNN Loss on 8x8 CIFAR-10")


ValueError: Unrecognized keyword arguments passed to PQC: {'symbol_names': (theta0, theta1, theta2, theta3)}