In [None]:
import torch
from torch import cat, no_grad, manual_seed
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.optim as optim
from torch.nn import (
    Module,
    Conv2d,
    Linear,
    Dropout2d,
    NLLLoss,
    MaxPool2d,
    Flatten,
    Sequential,
    ReLU,
)
import torch.nn.functional as F

In [None]:
# Necessary imports

import numpy as np
import matplotlib.pyplot as plt

from torch import Tensor
from torch.nn import Linear, CrossEntropyLoss, MSELoss
from torch.optim import LBFGS
import torch.nn as nn
from qiskit import QuantumCircuit
from qiskit.circuit import Parameter
from qiskit.circuit.library import RealAmplitudes, ZZFeatureMap , EfficientSU2
from qiskit_algorithms.utils import algorithm_globals
from qiskit_machine_learning.neural_networks import SamplerQNN, EstimatorQNN
from qiskit_machine_learning.connectors import TorchConnector
from qiskit.quantum_info import SparsePauliOp

import pennylane as qml

In [None]:
# Choose the compute device: the first CUDA GPU when one is visible and GPU use
# is enabled through `use_cuda`, otherwise fall back to the CPU.
use_cuda = True
device = 'cuda:0' if (torch.cuda.is_available() and use_cuda) else 'cpu'
print(f'Using {device} for training.')

In [None]:
# Hyperparameters
batch_size = 64  # NOTE: the data-loading cells further down reassign this to 32

# Data
class_n = 2          # number of output classes of the classifier head
train_samples = -1   # presumably "use every sample" -- TODO confirm
test_samples = -1    # reassigned to 50 in the test-set cell
# Dummy CIFAR-sized batch (1 image, 3 channels, 32x32) used to smoke-test the model.
test_input = torch.randn(1, 3, 32, 32).to(device)

# Model configuration
qubitn = 3           # number of wires of the quantum device
q_depth = 1          # number of variational layers in the circuit weights
out_type = 'pool'    # 'pool' selects the tensor-product Pauli observables in `circuit`

In [None]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Subset
import numpy as np


# Fix PyTorch's global RNG seed before any data or model object is created.
manual_seed(42)

# Image preprocessing: convert to a tensor, then normalise each of the three
# channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-100 training split, downloaded into ./data when it is not already there.
trainset = datasets.CIFAR100(root='./data', train=True, transform=transform, download=True)


def get_class_indices(dataset, classes):
    """Return the positions of all samples whose label is contained in ``classes``.

    Args:
        dataset: object exposing ``__len__`` and an indexable ``targets`` attribute.
        classes: container of labels to keep (tested with ``in``).

    Returns:
        List of integer positions, in ascending order.
    """
    return [
        position
        for position in range(len(dataset))
        if dataset.targets[position] in classes
    ]


# The two CIFAR-100 class ids that form the binary classification task.
classes_of_interest = [3, 88]

# Restrict the training set to those two classes.
train_indices = get_class_indices(dataset=trainset, classes=classes_of_interest)
X_train = Subset(dataset=trainset, indices=train_indices)

# NOTE: no label remapping happens in this cell, so samples carry whatever is
# stored in trainset.targets. `train_loader` is rebuilt in the next cell.
batch_size = 32
train_loader = DataLoader(dataset=X_train, batch_size=batch_size, shuffle=True)

# Number of mini-batches per epoch (displayed as the cell output).
len(train_loader)

In [None]:
def filter_and_relabel(dataset, classes):
    """Select the samples of two classes and compute binary labels for them.

    Args:
        dataset: object exposing ``__len__`` and an indexable ``targets`` attribute.
        classes: sequence whose first entry is mapped to label 0 and whose
            second entry is mapped to label 1.

    Returns:
        ``(indices, new_labels)``: positions of the matching samples and the
        corresponding 0/1 labels, both in dataset order. ``dataset`` itself is
        not modified.
    """
    indices, new_labels = [], []
    for position in range(len(dataset)):
        label = dataset.targets[position]
        if label == classes[0]:
            remapped = 0
        elif label == classes[1]:
            remapped = 1
        else:
            continue  # sample belongs to neither class
        indices.append(position)
        new_labels.append(remapped)
    return indices, new_labels


# Select the samples of the two classes of interest and their binary labels.
train_indices, train_new_labels = filter_and_relabel(trainset, classes_of_interest)

# Remap labels lazily through `target_transform` instead of overwriting
# trainset.targets. The stored labels stay untouched, so this cell can be
# re-run: overwriting them would leave no sample labelled with the original
# class ids and a second run would produce an empty subset.
_train_label_map = {
    trainset.targets[idx]: new_label
    for idx, new_label in zip(train_indices, train_new_labels)
}
# Labels outside the map are passed through unchanged.
trainset.target_transform = lambda label: _train_label_map.get(label, label)

train_subset = Subset(trainset, train_indices)
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)

# Number of mini-batches per epoch (displayed as the cell output).
len(train_loader)


In [None]:
def filter_and_relabel(dataset, classes):
    """Pick out the samples of two classes and derive 0/1 labels for them.

    This cell redefines the helper; the behaviour is the same as the
    definition in the previous cell.

    Args:
        dataset: object exposing ``__len__`` and an indexable ``targets`` attribute.
        classes: sequence; ``classes[0]`` becomes label 0, ``classes[1]`` label 1.

    Returns:
        Tuple ``(indices, new_labels)`` of two equally long lists, in dataset
        order. The dataset is left unchanged.
    """
    indices, new_labels = [], []
    for position in range(len(dataset)):
        label = dataset.targets[position]
        if label == classes[0]:
            remapped = 0
        elif label == classes[1]:
            remapped = 1
        else:
            continue  # neither of the two requested classes
        indices.append(position)
        new_labels.append(remapped)
    return indices, new_labels


# Reseed PyTorch's global RNG before building the test loader.
manual_seed(5)

# CIFAR-100 test split, with the same preprocessing as the training split.
testset = datasets.CIFAR100(root='./data', train=False, transform=transform, download=True)

# Samples of the two classes of interest together with their 0/1 labels.
test_indices, test_new_labels = filter_and_relabel(testset, classes_of_interest)

# Keep only the first `test_samples` matching images.
test_samples = 50
test_indices = test_indices[:test_samples]
test_new_labels = test_new_labels[:test_samples]

# Write the binary labels into the dataset itself (in place). `testset` is
# re-created at the top of this cell, so re-running starts from fresh labels.
test_subset = Subset(testset, test_indices)
for i, idx in enumerate(test_indices):
    testset.targets[idx] = test_new_labels[i]

batch_size = 32
test_loader = DataLoader(dataset=test_subset, batch_size=batch_size, shuffle=True)

# Number of mini-batches in the test loader (displayed as the cell output).
len(test_loader)

In [None]:
class CNet(Module):
    """Purely classical two-convolution classifier for single-channel images.

    ``fc1`` expects 256 = 16 x 4 x 4 features, which is what two 5x5
    convolutions, each followed by 2x2 max pooling, produce for 28x28 inputs.

    Args:
        class_n: number of output logits.
    """

    def __init__(self, class_n):
        super().__init__()
        self.conv1 = Conv2d(in_channels=1, out_channels=9, kernel_size=5)
        self.conv2 = Conv2d(in_channels=9, out_channels=16, kernel_size=5)
        self.dropout = Dropout2d()
        self.fc1 = Linear(in_features=256, out_features=64)
        self.fc2 = Linear(in_features=64, out_features=class_n)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, class_n)``."""
        stage1 = F.max_pool2d(F.relu(self.conv1(x)), 2)
        stage2 = F.max_pool2d(F.relu(self.conv2(stage1)), 2)
        dropped = self.dropout(stage2)
        flat = dropped.view(dropped.shape[0], -1)  # one feature row per sample
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [None]:
# Pauli observable classes in X, Y, Z order (used to build the measured observables).
SU2 = [qml.PauliX, qml.PauliY, qml.PauliZ]

# Single-qubit rotation gate classes in X, Y, Z order (used by `XYZ`).
R = [qml.RX, qml.RY, qml.RZ]
def XYZ(x, qid):
    """Apply the rotation gates listed in ``R`` to wire ``qid``.

    Gates and angles are paired in order, so ``x[0]`` drives ``R[0]`` and so
    on; pairing stops at the shorter of the two sequences.

    Args:
        x: iterable of rotation angles.
        qid: wire the rotations act on.
    """
    for gate, angle in zip(R, x):
        gate(angle, wires=qid)

def entangle_layer(targets):
    """Apply one CNOT per entry of ``targets``, in the given order.

    Args:
        targets: iterable whose items are passed as ``wires`` to ``qml.CNOT``.
    """
    for wire_pair in targets:
        qml.CNOT(wires=wire_pair)

In [None]:
# Simulator backend with one wire per qubit.
dev = qml.device('lightning.qubit', wires=qubitn)
# Alternative backend: qml.device('default.qubit', wires=qubitn)


@qml.qnode(dev, interface='torch')
def circuit(inputs, weight):
    """Variational circuit: angle encoding, entangling layers, Pauli readout.

    Args:
        inputs: tensor with ``qubitn * 3`` elements; it is viewed as
            ``(qubitn, 3)``, i.e. one (RX, RY, RZ) angle triple per qubit.
        weight: trainable angles indexed as ``weight[layer][qubit][block]``,
            where block 0 is applied before and block 1 after the entangling
            CNOTs, and each block holds the angles consumed by ``XYZ``.

    Returns:
        List of expectation values. For ``out_type == "pool"`` there is one
        per Pauli in ``SU2``, measured as the tensor product of that Pauli
        over all wires; otherwise one PauliZ expectation per wire.
    """
    # Encode the inputs. The view uses `qubitn` rather than a literal 3 so the
    # circuit stays consistent with the number of wires of `dev`.
    for qid, angles in enumerate(inputs.view(qubitn, 3)):
        XYZ(angles, qid)

    # Nearest-neighbour entangling pattern: (0, 1), (1, 2), ...
    entangle = [[i, i + 1] for i in range(qubitn - 1)]
    for single_layer in weight:
        for qid, weights in enumerate(single_layer):
            XYZ(weights[0], qid)
        entangle_layer(entangle)
        for qid, weights in enumerate(single_layer):
            XYZ(weights[1], qid)

    rs = []
    if out_type == "pool":
        for pauli in SU2:
            # Build P(0) @ P(1) @ ... @ P(qubitn - 1) for the current Pauli P.
            ob = pauli(0)
            for i in range(1, qubitn):
                ob = ob @ pauli(i)
            rs.append(qml.expval(ob))
    else:
        for i in range(qubitn):
            rs.append(qml.expval(qml.PauliZ(i)))
    return rs

In [None]:
class QNN(nn.Module):
    """Torch module that evaluates the quantum ``circuit`` on each sample of a batch.

    Args:
        qubitn: number of qubits; second dimension of the circuit weights.
        q_depth: number of variational layers; first dimension of the weights.
        out: readout mode, stored on the instance as ``self.out``.
    """

    def __init__(self, qubitn, q_depth, out) -> None:
        super().__init__()
        # Weight layout: (layer, qubit, block before/after entangling, angle).
        self.qnn = qml.qnn.TorchLayer(circuit, {"weight": (q_depth, qubitn, 2, 3)})
        self.out = out
        self.qubitn = qubitn
        self.q_depth = q_depth

    def forward(self, input):
        """Run the circuit once per row of ``input``.

        Args:
            input: 2-D tensor ``(batch, features)``.

        Returns:
            Tensor ``(batch, k)`` where ``k`` is the number of values the
            circuit returns for one sample; row ``i`` belongs to ``input[i]``.
        """
        b, _ = input.shape  # also rejects inputs that are not 2-D
        # Iterate over the samples themselves. The previous loop enumerated
        # input[1:] but indexed input[idx], so sample 0 was evaluated twice
        # and the last sample never.
        rows = [self.qnn(sample).reshape(1, -1) for sample in input]
        # Let the row width follow the circuit output instead of assuming 3.
        return torch.cat(rows, dim=0).reshape(b, -1)

class QSELayer(nn.Module):
    """Channel re-weighting layer whose gate values come from the quantum network.

    Each channel is reduced to its spatial mean, the resulting vector is passed
    through ``QNN`` and then through a linear layer with a sigmoid, and the
    result scales the input channel-wise.

    Args:
        channel: number of channels of the feature map to re-weight.
        qubitn: number of qubits of the quantum network.
        q_depth: number of variational layers of the quantum network.
        out: readout mode; ``'pool'`` fixes the quantum output width to 3,
            any other value uses ``qubitn``.
    """

    def __init__(self, channel, qubitn, q_depth, out):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        fin = 3 if out == 'pool' else qubitn
        self.fc = nn.Sequential(nn.Linear(fin, channel), nn.Sigmoid())
        self.qnn = QNN(qubitn=qubitn, q_depth=q_depth, out=out)

    def forward(self, x):
        """Scale every channel of the 4-D tensor ``x`` by its learned gate."""
        batch, channels, _, _ = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)  # per-channel mean
        quantum_features = self.qnn(squeezed)
        gate = self.fc(quantum_features).view(batch, channels, 1, 1)
        return x * gate.expand_as(x)

In [None]:
# define Hybrid Net with Quantum Pooling
class QSENet(Module):
    """Hybrid CNN for 3-channel images with a quantum channel-gating layer.

    ``fc1`` expects 400 = 16 x 5 x 5 features, which is what two 5x5
    convolutions, each followed by 2x2 max pooling, produce for 32x32 inputs.

    Args:
        qubitn: number of qubits; the first convolution emits ``3 * qubitn`` channels.
        class_n: number of output logits.
        q_depth: number of variational layers of the quantum circuit.
        out: readout mode forwarded to ``QSELayer``.
    """

    def __init__(self, qubitn, class_n, q_depth, out):
        super().__init__()
        chn = 3 * qubitn
        self.conv1 = Conv2d(in_channels=3, out_channels=chn, kernel_size=5)
        self.qse = QSELayer(channel=chn, qubitn=qubitn, q_depth=q_depth, out=out)
        self.conv2 = Conv2d(in_channels=chn, out_channels=16, kernel_size=5)
        self.dropout = Dropout2d()
        self.fc1 = Linear(in_features=400, out_features=64)
        self.fc2 = Linear(in_features=64, out_features=class_n)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, class_n)``."""
        stage1 = F.max_pool2d(F.relu(self.conv1(x)), 2)
        gated = self.qse(stage1)  # channel-wise re-weighting
        stage2 = F.max_pool2d(F.relu(self.conv2(gated)), 2)
        dropped = self.dropout(stage2)
        flat = dropped.view(dropped.shape[0], -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)


# Build the hybrid network from the configuration cell and move it to the
# selected device (Module.to returns the module itself).
QPSE = QSENet(qubitn=qubitn, class_n=class_n, q_depth=q_depth, out=out_type).to(device)

# Smoke test: one forward pass on the dummy input; the logits are the cell output.
QPSE(test_input)

In [None]:
# `model` is an alias of the hybrid network, not a copy.
model = QPSE
model.to(device=device)

In [None]:
# Optimizer and loss function for the hybrid model
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_func = CrossEntropyLoss()

# Training configuration and bookkeeping
epochs = 20       # number of passes over the training loader
loss_list = []    # mean training loss of each epoch
model.train()     # training mode (enables dropout)

n_batches = len(train_loader)
for epoch in range(epochs):
    total_loss = []  # per-batch losses of the current epoch
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad(set_to_none=True)  # clear gradients of the previous step
        data = data.to(device)
        target = target.to(device)
        output = model(data)  # forward pass
        loss = loss_func(output, target)
        loss.backward()  # backward pass
        optimizer.step()  # parameter update
        total_loss.append(loss.item())
        # batch_idx is zero-based: count the batch just finished so that the
        # progress display reaches 100 % at the end of the epoch.
        print("\rDone {:.3f} %" .format(100 * (batch_idx + 1) / n_batches), end='')
    loss_list.append(sum(total_loss) / len(total_loss))
    print("\nTraining [{:.0f}%]\tLoss: {:.4f}".format(100.0 * (epoch + 1) / epochs, loss_list[-1]))

In [None]:
# Plot loss convergence. `loss_list` holds one mean loss per epoch and the
# training criterion is CrossEntropyLoss, so the axes are labelled accordingly.
plt.plot(loss_list)
plt.title("Hybrid NN Training Convergence")
plt.xlabel("Epoch")
plt.ylabel("Cross-Entropy Loss")
plt.show()

In [None]:
# Persist the trained parameters (state dict only, not the module object).
save_name = "Qmodel4.pt"
torch.save(model.state_dict(), f=save_name)

In [None]:
# Reload the saved parameters. `modelt` is an alias of `model` (same object),
# so loading also overwrites the weights seen through `model`.
load_name = "Qmodel4.pt"
modelt = model
# Read from `load_name` (previously defined but unused) and map the stored
# tensors onto the active device, so that a checkpoint written on a GPU can
# also be loaded on a CPU-only machine.
modelt.load_state_dict(torch.load(load_name, map_location=device))

In [None]:
# Evaluate the reloaded model on the test loader.
modelt.eval()  # evaluation mode (disables dropout)
target_loader = test_loader
modelt.to(device)

with no_grad():
    correct = 0          # number of correctly classified samples
    n_samples = 0        # number of samples actually evaluated
    test_loss_sum = 0.0  # sum of per-sample losses (fresh accumulator, so no
                         # training losses leak into the test metric)
    for batch_idx, (data, target) in enumerate(target_loader):
        data = data.to(device)
        target = target.to(device)
        output = modelt(data)
        if len(output.shape) == 1:
            # A single un-batched prediction: add the batch dimension back.
            output = output.reshape(1, *output.shape)

        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

        # loss_func returns the batch mean; weight it by the batch size so that
        # a smaller final batch does not distort the average.
        batch_len = target.size(0)
        loss = loss_func(output, target)
        test_loss_sum += loss.item() * batch_len
        n_samples += batch_len

    # Normalise by the true sample count: len(loader) * batch_size overcounts
    # whenever the last batch is not full.
    denominator = max(n_samples, 1)  # guards against an empty loader
    print(
        "Performance on test data:\n\tLoss: {:.4f}\n\tAccuracy: {:.1f}%".format(
            test_loss_sum / denominator, correct / denominator * 100
        )
    )