In [1]:
!pip install pennylane pennylane-lightning torch torchvision matplotlib
!pip install jax==0.4.28 jaxlib==0.4.28

Collecting pennylane
  Downloading PennyLane-0.41.1-py3-none-any.whl.metadata (10 kB)
Collecting pennylane-lightning
  Downloading pennylane_lightning-0.41.1-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (12 kB)
Collecting rustworkx>=0.14.0 (from pennylane)
  Downloading rustworkx-0.16.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting tomlkit (from pennylane)
  Downloading tomlkit-0.13.3-py3-none-any.whl.metadata (2.8 kB)
Collecting appdirs (from pennylane)
  Downloading appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB)
Collecting autoray>=0.6.11 (from pennylane)
  Downloading autoray-0.7.1-py3-none-any.whl.metadata (5.8 kB)
Collecting diastatic-malt (from pennylane)
  Downloading diastatic_malt-2.15.2-py3-none-any.whl.metadata (2.6 kB)
Collecting scipy-openblas32>=0.3.26 (from pennylane-lightning)
  Downloading scipy_openblas32-0.3.29.265.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (56 kB)
[2K     [90m━━

In [2]:
import warnings, os
import numpy as np
warnings.filterwarnings("ignore")
np.seterr(all='ignore')
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

In [3]:
print('********************************************************************************************************************')
print('Single QBIT Encoding for MNIST dataset')
print('https://ieeexplore.ieee.org/abstract/document/9798852')
import pennylane as qml
# NOTE(review): rebinds `np` to PennyLane's wrapped numpy for the rest of the notebook;
# `np` is not used in this cell.
from pennylane import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# --- configuration ----------------------------------------------------------------
n_epochs = 6
batch_size = 1
learning_rate = 1e-4
img_size = 12           # images are resized to img_size x img_size

# --- data: MNIST restricted to digits 0 and 1 --------------------------------------
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor()
])
train_raw = datasets.MNIST(root=".", train=True, download=True, transform=transform)
test_raw = datasets.MNIST(root=".", train=False, download=True, transform=transform)
# Pick the 0/1 indices from the stored label tensor first, so the Resize/ToTensor
# transform runs only on the kept images instead of on every image in each split.
# `train_raw[i]` yields the same (image, int label) pair that iterating the dataset does.
keep_digits = torch.tensor([0, 1])
train_idx = torch.isin(train_raw.targets, keep_digits).nonzero().flatten().tolist()
test_idx = torch.isin(test_raw.targets, keep_digits).nonzero().flatten().tolist()
train_data = [train_raw[i] for i in train_idx]
test_data = [test_raw[i] for i in test_idx]
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1)

# One-qubit simulator; every gate in this cell acts on wire 0.
dev = qml.device("default.qubit", wires=1)
def single_qubit_encode(x, theta, phi):
    """Apply the data re-uploading encoder on wire 0, three features per layer.

    Each consecutive triple (x[i], x[i+1], x[i+2]) becomes one RZ-RY-RZ layer with
    angles theta[k] + x[i+k] * phi[k]; a trailing partial triple is zero-padded.

    Args:
        x: flat sequence of features (list, tuple, 1-D array or 1-D tensor).
        theta: length-3 per-rotation offsets, reused by every layer.
        phi: length-3 per-rotation feature scales, reused by every layer.
    """
    for i in range(0, len(x), 3):
        # list(...) makes the zero-padding a concatenation for any sliceable sequence;
        # tuple + list raises, and array/tensor + list broadcasts instead.
        x_pad = list(x[i:i + 3])
        x_pad += [0] * (3 - len(x_pad))
        beta = theta[0] + x_pad[0] * phi[0]
        gamma = theta[1] + x_pad[1] * phi[1]
        delta = theta[2] + x_pad[2] * phi[2]
        qml.RZ(beta, wires=0)
        qml.RY(gamma, wires=0)
        qml.RZ(delta, wires=0)
@qml.qnode(dev, interface="torch")
def quantum_circuit(inputs, theta, phi):
    """Return the PauliZ expectation on wire 0 after a Hadamard and the encoder.

    inputs: flat feature sequence handed to single_qubit_encode.
    theta, phi: length-3 trainable offsets / feature scales for every encoding layer.
    """
    qml.Hadamard(0)
    single_qubit_encode(inputs, theta, phi)
    z_expectation = qml.expval(qml.PauliZ(0))
    return z_expectation
class QuantumNet(nn.Module):
    """Single-qubit classifier built on `quantum_circuit`.

    Trainable parameters are two length-3 vectors: `theta` (per-rotation offsets)
    and `phi` (per-rotation feature scales), shared by every encoding layer.
    `forward` returns P(|0>) = (1 + <Z>) / 2 for each image.
    """

    def __init__(self):
        super().__init__()
        self.theta = nn.Parameter(torch.randn(3) * 0.1)
        self.phi   = nn.Parameter(torch.randn(3) * 0.1)

    def forward(self, x):
        """Return a float32 tensor of P(|0>) values, one per image.

        Args:
            x: an (N, C, H, W) batch as yielded by the DataLoader -> output shape (N,);
               a tensor of any other rank is treated as a single image -> shape (1,).
        """
        # Encode each image on its own: flattening the whole tensor would merge every
        # image of a batch into one long input vector and return a single output.
        if x.dim() == 4:
            samples = x.flatten(start_dim=1)
        else:
            samples = x.reshape(1, -1)
        expvals = torch.stack([
            quantum_circuit(sample.to(torch.float32).tolist(), self.theta, self.phi)
            for sample in samples
        ])
        prob_0 = (1 + expvals) / 2
        return prob_0.to(torch.float32)
# Seed before building the model so parameter init and the shuffled batch order are
# reproducible on Restart & Run All.
torch.manual_seed(0)
model = QuantumNet()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.MSELoss()
n_train = len(train_loader.dataset)
for epoch in range(n_epochs):
    model.train()
    total_loss, correct = 0, 0  # total_loss sums loss.item() over every batch of the epoch
    for img, label in train_loader:
        # The model outputs P(|0>), so digit 0 -> target 1.0 and digit 1 -> target 0.0.
        target = (label == 0).to(torch.float32)
        optimizer.zero_grad()
        output = model(img)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # Predict digit 0 when P(|0>) >= 0.5, else digit 1 (a NaN output counts as 1).
        pred = (output >= 0.5).logical_not().long()
        correct += (pred == label).sum().item()
    # `correct` counts images, so divide by the number of images, not batches.
    acc = correct / n_train
    print(f"Epoch {epoch+1:2d} | Loss: {total_loss:.4f} | Accuracy: {acc*100:.2f}%")
model.eval()
correct = 0
with torch.no_grad():
    for img, label in test_loader:
        output = model(img)
        pred = (output >= 0.5).logical_not().long()
        correct += (pred == label).sum().item()
print(f"\nTest Accuracy: {correct / len(test_loader.dataset) * 100:.2f}%")
print('********************************************************************************************************************')

********************************************************************************************************************
Single QBIT Encoding for MNIST dataset
https://ieeexplore.ieee.org/abstract/document/9798852


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 488kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.42MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.57MB/s]


Epoch  1 | Loss: 1074.9817 | Accuracy: 92.93%
Epoch  2 | Loss: 831.5095 | Accuracy: 92.81%
Epoch  3 | Loss: 817.5536 | Accuracy: 92.84%
Epoch  4 | Loss: 807.9657 | Accuracy: 92.88%
Epoch  5 | Loss: 804.5721 | Accuracy: 92.95%
Epoch  6 | Loss: 801.5699 | Accuracy: 92.86%

Test Accuracy: 91.30%
********************************************************************************************************************


In [4]:
print('Simple CNN')
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# --- configuration (mirrors the quantum cell's settings for a like-for-like comparison)
n_epochs = 6
batch_size = 1
learning_rate = 1e-4
img_size = 12

# Unlike the quantum cell, inputs are also normalized with the commonly used MNIST
# mean/std values.
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_raw = datasets.MNIST(root=".", train=True, download=True, transform=transform)
test_raw = datasets.MNIST(root=".", train=False, download=True, transform=transform)
# Pick the 0/1 indices from the stored label tensor first, so the transform runs only
# on the kept images instead of on every image in each split.
# `train_raw[i]` yields the same (image, int label) pair that iterating the dataset does.
keep_digits = torch.tensor([0, 1])
train_idx = torch.isin(train_raw.targets, keep_digits).nonzero().flatten().tolist()
test_idx = torch.isin(test_raw.targets, keep_digits).nonzero().flatten().tolist()
train_data = [train_raw[i] for i in train_idx]
test_data = [test_raw[i] for i in test_idx]
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1)
class SampleCNN(nn.Module):
    """Two conv -> ReLU -> 2x2 max-pool stages followed by one linear layer.

    Args:
        image_size: side length of the square single-channel inputs. When omitted,
            the notebook-level `img_size` is read at construction time.
        num_classes: number of output logits (2 for the 0-vs-1 task).
    """

    def __init__(self, image_size=None, num_classes=2):
        super().__init__()
        # Fall back to the notebook-level `img_size` so `SampleCNN()` keeps working.
        side = img_size if image_size is None else image_size
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        # The padded 3x3 convs keep H/W; each 2x2 max-pool floors it by half, and
        # floor(floor(n / 2) / 2) == n // 4, so two pools leave side // 4 per axis.
        self.fc1_input_features = 32 * (side // 4) * (side // 4)
        self.fc1 = nn.Linear(self.fc1_input_features, num_classes)

    def forward(self, x):
        """Map an (N, 1, H, W) batch to (N, num_classes) logits."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # flatten(1) keeps the batch axis, so a wrongly sized input fails loudly in fc1;
        # view(-1, features) would silently fold it into extra "samples".
        x = torch.flatten(x, 1)
        return self.fc1(x)
# Seed before building the model so weight init and the shuffled batch order are
# reproducible on Restart & Run All.
torch.manual_seed(0)
model_dl = SampleCNN()
optimizer_dl = torch.optim.Adam(model_dl.parameters(), lr=learning_rate)
loss_fn_dl = nn.CrossEntropyLoss()
n_train_dl = len(train_loader.dataset)
for epoch in range(n_epochs):
    model_dl.train()
    total_loss, correct = 0, 0  # total_loss sums loss_dl.item() over every batch of the epoch
    for img, label in train_loader:
        target_dl = label.to(torch.long)
        optimizer_dl.zero_grad()
        output_dl = model_dl(img)
        loss_dl = loss_fn_dl(output_dl, target_dl)
        loss_dl.backward()
        optimizer_dl.step()
        total_loss += loss_dl.item()
        pred_dl = output_dl.argmax(dim=1, keepdim=True)
        correct += pred_dl.eq(target_dl.view_as(pred_dl)).sum().item()
    # `correct` counts images, so divide by the number of images; dividing by
    # len(train_loader) (the number of batches) overstates accuracy when batch_size > 1.
    acc_dl = correct / n_train_dl
    print(f"Epoch {epoch+1:2d} | Loss: {total_loss:.4f} | Accuracy: {acc_dl*100:.2f}%")
model_dl.eval()
correct = 0
with torch.no_grad():
    for img, label in test_loader:
        output_dl = model_dl(img)
        pred_dl = output_dl.argmax(dim=1, keepdim=True)
        correct += pred_dl.eq(label.view_as(pred_dl)).sum().item()
print(f"\nTest Accuracy: {correct / len(test_loader.dataset) * 100:.2f}%")
print('********************************************************************************************************************')

Simple CNN
Epoch  1 | Loss: 376.9544 | Accuracy: 99.46%
Epoch  2 | Loss: 37.9077 | Accuracy: 99.88%
Epoch  3 | Loss: 25.7753 | Accuracy: 99.94%
Epoch  4 | Loss: 15.5955 | Accuracy: 99.95%
Epoch  5 | Loss: 12.2532 | Accuracy: 99.97%
Epoch  6 | Loss: 14.4317 | Accuracy: 99.96%

Test Accuracy: 99.91%
********************************************************************************************************************
