In [32]:
import numpy as np
import torch
import torchquantum as tq
import torchquantum.functional as tqf
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from dataset import QuantumSensingDataset
from qnn import QuantumSensing, QuantumML

In [17]:
# Data: toy quantum-sensing train/test splits.
# Each sample is a dict with 'phase' (float32 array) and 'label' (integer array),
# as shown by the printed sample below.
train_dir = 'qml-data/toy/train'
test_dir = 'qml-data/toy/test'
batch_size = 16
num_workers = 4

train_dataset = QuantumSensingDataset(train_dir)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

test_dataset = QuantumSensingDataset(test_dir)
# Evaluation order does not affect accuracy, so keep it deterministic (no shuffle).
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

print(len(train_dataset))
print(len(test_dataset))
print(train_dataset[0])

300
100
{'phase': array([4.0597625, 3.9561498, 2.978269 , 2.9022985], dtype=float32), 'label': array(0)}


In [39]:
# Model, optimizer and LR schedule.
# F.nll_loss is applied to the model outputs below, so QuantumML presumably
# returns log-probabilities of shape (batch, n_classes) — TODO confirm in qnn.py.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = QuantumML(n_wires=4, n_locations=4).to(device)
n_epochs = 25
optimizer = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)


def _encode_batch(thetas, device):
    """Run the quantum-sensing circuit on a batch of phases.

    Parameters
    ----------
    thetas : torch.Tensor
        Batch of phases from the dataloader. Dimension 0 is the batch size and
        dimension 1 is used as the number of qubits.
    device : torch.device
        Device passed through to ``QuantumSensing``.

    Returns
    -------
    tuple
        ``(q_device, states)``: a freshly reset ``tq.QuantumDevice`` and the
        sensed ``qstate.states``, ready to be passed to ``model``.
    """
    # Sizes must come from the current batch: the last batch of an epoch is
    # smaller than batch_size (300 % 16 == 12 for train, 100 % 16 == 4 for test).
    bsz, n_qubits = thetas.shape[0], thetas.shape[1]
    qsensing = QuantumSensing(n_qubits=n_qubits, list_of_thetas=thetas, device=device)
    # NOTE(review): qstate/q_device are created without a device argument;
    # confirm they end up on `device` when running on CUDA.
    qstate = tq.QuantumState(n_wires=n_qubits, bsz=bsz)
    qsensing(qstate)
    q_device = tq.QuantumDevice(n_wires=n_qubits)
    q_device.reset_states(bsz=bsz)
    return q_device, qstate.states


for epoch in range(n_epochs):
    print(f'epoch={epoch}')

    # --- training ---
    model.train()
    for t, sample in enumerate(train_dataloader):
        targets = sample['label'].to(device)
        q_device, states = _encode_batch(sample['phase'], device)
        outputs = model(q_device, states)
        loss = F.nll_loss(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if t % 10 == 0:
            print(f'loss={loss.item()}')

    # --- evaluation ---
    model.eval()
    target_all = []
    output_all = []
    with torch.no_grad():
        for sample in test_dataloader:
            q_device, states = _encode_batch(sample['phase'], device)
            output_all.append(model(q_device, states))
            target_all.append(sample['label'].to(device))
    target_all = torch.cat(target_all)
    output_all = torch.cat(output_all)

    # Top-1 accuracy: predicted class is the index of the largest output.
    predictions = output_all.argmax(dim=1)
    corrects = predictions.eq(target_all).sum().item()
    accuracy = corrects / target_all.shape[0]
    print(f'accuracy={accuracy}')

    # Advance the cosine LR schedule once per epoch (T_max=n_epochs);
    # without this call the scheduler never changes the learning rate.
    scheduler.step()

epoch=0
loss=1.4783298969268799
loss=1.341200351715088
accuracy=0.48
epoch=1
loss=1.1503210067749023
loss=1.1464542150497437
accuracy=0.84
epoch=2
loss=1.1050156354904175
loss=1.0077183246612549
accuracy=1.0
epoch=3
loss=0.9447122812271118
loss=0.8660270571708679
accuracy=1.0
epoch=4
loss=0.7749543190002441
loss=0.756308913230896
accuracy=1.0
epoch=5
loss=0.6835474371910095
loss=0.6683788895606995
accuracy=1.0
epoch=6
loss=0.5853619575500488
loss=0.5427542328834534
accuracy=1.0
epoch=7
loss=0.4951237738132477
loss=0.47508564591407776
accuracy=1.0
epoch=8
loss=0.45358145236968994
loss=0.3992227613925934
accuracy=1.0
epoch=9
loss=0.3552781045436859
loss=0.3695710599422455
accuracy=1.0
epoch=10
loss=0.315591424703598
loss=0.36177298426628113
accuracy=1.0
epoch=11
loss=0.2917434573173523
loss=0.27257809042930603
accuracy=1.0
epoch=12
loss=0.25115787982940674
loss=0.2463836371898651
accuracy=1.0
epoch=13
loss=0.24202734231948853
loss=0.21259015798568726
accuracy=1.0
epoch=14
loss=0.21787908