In [1]:
import numpy as np
import torch
import torchquantum as tq
import torchquantum.functional as tqf
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from dataset import QuantumSensingDataset
from qnn import QuantumSensing, QuantumML0, QuantumML1

In [2]:
# data
root_dir = 'qml-data/toy/train'
train_dataset = QuantumSensingDataset(root_dir)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
root_dir = 'qml-data/toy/test'
test_dataset = QuantumSensingDataset(root_dir)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, num_workers=4)
print(train_dataset.__len__())
print(test_dataset.__len__())
print(train_dataset[0])

300
100
{'phase': array([4.0597625, 3.9561498, 2.978269 , 2.9022985], dtype=float32), 'label': array(0)}


In [11]:
# the model and training related variables
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = QuantumML0(n_wires=4, n_locations=4).to(device)
n_epochs = 25
optimizer = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)

for e in range(n_epochs):
    print(f'epoch={e}')
    model.train()
    for t, sample in enumerate(train_dataloader):
        thetas = sample['phase']
        targets = sample['label'].to(device)
        # preparing sensing data
        bsz = thetas.shape[0]
        n_qubits = thetas.shape[1]
        qsensing = QuantumSensing(n_qubits=n_qubits, list_of_thetas=thetas, device=device)
        qstate = tq.QuantumState(n_wires=n_qubits, bsz=bsz)
        qsensing(qstate)
        q_device = tq.QuantumDevice(n_wires=n_qubits)
        q_device.reset_states(bsz=bsz)
        # the model
        outputs = model(q_device, qstate.states)
        # compute loss, gradient, optimize, etc...
        loss = F.nll_loss(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if t % 10 == 0:
            print(f'loss={loss.item()}')
    
    model.eval()
    target_all = []
    output_all = []
    with torch.no_grad():
        for t, sample in enumerate(test_dataloader):
            thetas = sample['phase']
            targets = sample['label'].to(device)
            bsz = thetas.shape[0]
            n_qubits = thetas.shape[1]
            qsensing = QuantumSensing(n_qubits=n_qubits, list_of_thetas=thetas, device=device)
            qstate = tq.QuantumState(n_wires=n_qubits, bsz=bsz)
            qsensing(qstate)
            q_device = tq.QuantumDevice(n_wires=n_qubits)
            q_device.reset_states(bsz=bsz)
            # the model
            outputs = model(q_device, qstate.states)
            target_all.append(targets)
            output_all.append(outputs)
        target_all = torch.cat(target_all)
        output_all = torch.cat(output_all)
        
#     print(f'target_all = {target_all}')
#     print(f'output_all = {output_all}')
    _, indices = output_all.topk(1, dim=1)
    masks = indices.eq(target_all.view(-1, 1).expand_as(indices))
    size = target_all.shape[0]
    corrects = masks.sum().item()
    accuracy = corrects / size
    print(f'accuracy={accuracy}')

epoch=0
loss=1.5274028778076172
loss=1.3250788450241089
accuracy=0.25
epoch=1
loss=1.2785444259643555
loss=1.2095234394073486
accuracy=0.96
epoch=2
loss=1.136954426765442
loss=0.9982602000236511
accuracy=1.0
epoch=3
loss=1.0841344594955444
loss=0.8038815259933472
accuracy=1.0
epoch=4
loss=0.7527194619178772
loss=0.6946508288383484
accuracy=1.0
epoch=5
loss=0.6216016411781311
loss=0.5849497318267822
accuracy=1.0
epoch=6
loss=0.5228888988494873
loss=0.5026627779006958
accuracy=1.0
epoch=7
loss=0.46952539682388306
loss=0.4169218838214874
accuracy=1.0
epoch=8
loss=0.40230226516723633
loss=0.34280291199684143
accuracy=1.0
epoch=9
loss=0.32562991976737976
loss=0.2877598702907562
accuracy=1.0
epoch=10
loss=0.29273465275764465
loss=0.21009264886379242
accuracy=1.0
epoch=11
loss=0.2354767918586731
loss=0.2052772492170334
accuracy=1.0
epoch=12
loss=0.20772768557071686
loss=0.20318329334259033
accuracy=1.0
epoch=13
loss=0.19715078175067902
loss=0.16233962774276733
accuracy=1.0
epoch=14
loss=0.146

In [12]:
# the model and training related variables
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = QuantumML1(n_wires=4, n_locations=4).to(device)
n_epochs = 25
optimizer = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)

for e in range(n_epochs):
    print(f'epoch={e}')
    model.train()
    for t, sample in enumerate(train_dataloader):
        thetas = sample['phase']
        targets = sample['label'].to(device)
        # preparing sensing data
        bsz = thetas.shape[0]
        n_qubits = thetas.shape[1]
        qsensing = QuantumSensing(n_qubits=n_qubits, list_of_thetas=thetas, device=device)
        qstate = tq.QuantumState(n_wires=n_qubits, bsz=bsz)
        qsensing(qstate)
        q_device = tq.QuantumDevice(n_wires=n_qubits)
        q_device.reset_states(bsz=bsz)
        # the model
        outputs = model(q_device, qstate.states)
        # compute loss, gradient, optimize, etc...
        loss = F.nll_loss(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if t % 10 == 0:
            print(f'loss={loss.item()}')
    
    model.eval()
    target_all = []
    output_all = []
    with torch.no_grad():
        for t, sample in enumerate(test_dataloader):
            thetas = sample['phase']
            targets = sample['label'].to(device)
            bsz = thetas.shape[0]
            n_qubits = thetas.shape[1]
            qsensing = QuantumSensing(n_qubits=n_qubits, list_of_thetas=thetas, device=device)
            qstate = tq.QuantumState(n_wires=n_qubits, bsz=bsz)
            qsensing(qstate)
            q_device = tq.QuantumDevice(n_wires=n_qubits)
            q_device.reset_states(bsz=bsz)
            # the model
            outputs = model(q_device, qstate.states)
            target_all.append(targets)
            output_all.append(outputs)
        target_all = torch.cat(target_all)
        output_all = torch.cat(output_all)
        
#     print(f'target_all = {target_all}')
#     print(f'output_all = {output_all}')
    _, indices = output_all.topk(1, dim=1)
    masks = indices.eq(target_all.view(-1, 1).expand_as(indices))
    size = target_all.shape[0]
    corrects = masks.sum().item()
    accuracy = corrects / size
    print(f'accuracy={accuracy}')

epoch=0
loss=1.4952408075332642
loss=1.4221601486206055
accuracy=0.25
epoch=1
loss=1.3084717988967896
loss=1.3591673374176025
accuracy=0.68
epoch=2
loss=1.25327467918396
loss=1.1748526096343994
accuracy=1.0
epoch=3
loss=1.11481773853302
loss=1.0419483184814453
accuracy=1.0
epoch=4
loss=0.9796124696731567
loss=0.9416905045509338
accuracy=1.0
epoch=5
loss=0.9174097776412964
loss=0.8544785976409912
accuracy=1.0
epoch=6
loss=0.8041238188743591
loss=0.7262616753578186
accuracy=1.0
epoch=7
loss=0.7357550859451294
loss=0.7362036108970642
accuracy=1.0
epoch=8
loss=0.6418341994285583
loss=0.6573255658149719
accuracy=1.0
epoch=9
loss=0.599551796913147
loss=0.4605522155761719
accuracy=1.0
epoch=10
loss=0.48134204745292664
loss=0.46387407183647156
accuracy=1.0
epoch=11
loss=0.46777525544166565
loss=0.41137099266052246
accuracy=1.0
epoch=12
loss=0.4301094114780426
loss=0.3954290747642517
accuracy=1.0
epoch=13
loss=0.402992308139801
loss=0.3316196799278259
accuracy=1.0
epoch=14
loss=0.35607999563217