In [1]:
from __future__ import print_function 
import h5py 
import numpy as np 
from tqdm import tqdm 

import torch.utils.data 
from torch import nn, optim 
from torchinfo import summary 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from base_model import TinyClassifier 
from hawq_model import HawqTinyClassifier 
from brevitas_model import BrevitasTinyClassifier 

In [3]:
from brevitas.export.onnx.qonnx.manager import QONNXManager
from hawq.utils.export import ExportManager

### Setup QAT training parameters

In [9]:
framework = "hawq"  # options: hawq/brevitas

# Hyper-Parameters
torch.manual_seed(4)
epochs = 100
train = 9000#data.shape[0]*0.99
batch_size = 12800
learning_rate = 1e-4

device = 'cuda' if torch.cuda.is_available() else 'cpu'
kwargs = {'num_workers':0, 'pin_memory': True} 
plot_interval = 1

In [10]:
if framework.lower() == "hawq":
    base_model = TinyClassifier()
    # base_model.load_state_dict(torch.load("../../checkpoints/checkpoint_tiny_affine.pth"))
    model = HawqTinyClassifier(base_model)
elif framework.lower() == "brevitas":
    model = BrevitasTinyClassifier()

In [11]:
# Load the raw data
IF = -136.75/1e3
with h5py.File(r'../../../datasets/qubits/00002_IQ_plot_raw.h5', 'r') as f:
    adc_g_1 = np.array(f['adc_g_1'])[0]
    adc_g_2 = np.array(f['adc_g_2'])[0]
    adc_e_1 = np.array(f['adc_e_1'])[0]
    adc_e_2 = np.array(f['adc_e_2'])[0]

""" Select the range of time series data. Each data is 2000 element vector 
representing 2000ns readout signal"""
csr = range(500,1500)
sr = len(csr)

I_g = adc_g_1[:,csr]
Q_g = adc_g_2[:,csr] 
I_e = adc_e_1[:,csr] 
Q_e = adc_e_2[:,csr] 

# Dataset Creation
data = np.zeros((adc_g_1.shape[0]*2,sr,2))
data[0:adc_g_1.shape[0],:,0] = I_g
data[0:adc_g_1.shape[0],:,1] = Q_g
data[adc_g_1.shape[0]:adc_g_1.shape[0]*2,:,0] = I_e
data[adc_g_1.shape[0]:adc_g_1.shape[0]*2,:,1] = Q_e

labels = np.zeros(I_e.shape[0]*2)
labels[I_e.shape[0]:I_e.shape[0]*2] = 1

data = torch.from_numpy(data).float()
labels = torch.from_numpy(labels).float()

class Qubit_Readout_Dataset():
    
    def __init__(self):
        self.data = data
        self.labels = labels    
        self.data = self.data.reshape(len(data),sr*2)
       
    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


# Dataloader Prep
test = len(data)-train

dataset = Qubit_Readout_Dataset()
train_data, test_data = torch.utils.data.random_split(dataset, [int(train), int(test)])

num_workers = 0
train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, 
                                            num_workers = num_workers, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size, 
                                            num_workers = num_workers, shuffle=True)


In [12]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = learning_rate)

# Loss tracks
train_loss_track=np.array([])
test_loss_track=np.array([])
acc_track=np.array([])

In [13]:
# Training loop
for epoch in tqdm(range(epochs)):
    
    train_loss = 0
    model.train() 
    
    for data, labels in train_loader:
        optimizer.zero_grad()
        data, labels = data.to(device), labels.to(device)
        states = model(data)
        loss = criterion(states, labels.long())
        loss.backward()
        train_loss += loss.detach().cpu().numpy()
        optimizer.step()
        
    train_loss_track = np.append(train_loss_track,np.asarray(train_loss))
        
    test_loss = 0    
    model.eval()
    with torch.no_grad():
        for data, labels in test_loader:
            data, labels = data.to(device), labels.to(device)
            print(data.shape)
            states = model(data)
            loss = criterion(states, labels.long())
            test_loss += loss.detach().cpu().numpy()
    
    test_loss_track = np.append(test_loss_track,np.asarray(test_loss))
    
    if epoch % plot_interval == 0:
        print('====> Epoch: {} Training loss: {:.6f}'.format(
                  epoch, train_loss ))
        print('====> Epoch: {} Test loss: {:.6f}'.format(
                  epoch, test_loss ))

        p1, p2 = 0, 800

    # Readout Fidelity
    model.eval() 
    cc = 0
    y_true = torch.tensor([]).to(device)
    y_pred = torch.tensor([]).to(device)
    
    with torch.no_grad():
        for data, target in test_loader:
            
            data=data.to(device)
            states = model(data)
            target = target.to(device) 
    
            val, ind = torch.max(states,1)
            y_pred = torch.cat((y_pred, ind), 0)
            y_true = torch.cat((y_true, target), 0)
    
    acc = y_true-y_pred
    accuracy = (len(y_true)-torch.count_nonzero(acc))/len(y_true)
    accuracy = accuracy.item()
    
    acc_track = np.append(acc_track,np.asarray(accuracy))
    print('Readout Fidelity: %', accuracy*100)

    if accuracy*100 > 90.:
        break


  1%|          | 1/100 [00:00<00:19,  5.21it/s]

torch.Size([1000, 2000])
====> Epoch: 0 Training loss: 0.693147
====> Epoch: 0 Test loss: 0.786075
Readout Fidelity: % 50.599998235702515


  2%|▏         | 2/100 [00:00<00:21,  4.53it/s]

torch.Size([1000, 2000])
====> Epoch: 1 Training loss: 0.693075
====> Epoch: 1 Test loss: 0.758394
Readout Fidelity: % 51.20000243186951
torch.Size([1000, 2000])
====> Epoch: 2 Training loss: 0.693566
====> Epoch: 2 Test loss: 0.721135


  4%|▍         | 4/100 [00:00<00:18,  5.08it/s]

Readout Fidelity: % 51.20000243186951
torch.Size([1000, 2000])
====> Epoch: 3 Training loss: 0.692677
====> Epoch: 3 Test loss: 0.717862
Readout Fidelity: % 51.099997758865356


  6%|▌         | 6/100 [00:01<00:18,  5.02it/s]

torch.Size([1000, 2000])
====> Epoch: 4 Training loss: 0.691816
====> Epoch: 4 Test loss: 0.713026
Readout Fidelity: % 51.20000243186951
torch.Size([1000, 2000])
====> Epoch: 5 Training loss: 0.692606
====> Epoch: 5 Test loss: 0.708284
Readout Fidelity: % 51.39999985694885


  7%|▋         | 7/100 [00:01<00:19,  4.69it/s]

torch.Size([1000, 2000])
====> Epoch: 6 Training loss: 0.691720
====> Epoch: 6 Test loss: 0.697002
Readout Fidelity: % 51.499998569488525
torch.Size([1000, 2000])
====> Epoch: 7 Training loss: 0.690439
====> Epoch: 7 Test loss: 0.696179


  8%|▊         | 8/100 [00:01<00:19,  4.84it/s]

Readout Fidelity: % 51.5999972820282


  9%|▉         | 9/100 [00:01<00:19,  4.59it/s]

torch.Size([1000, 2000])
====> Epoch: 8 Training loss: 0.690334
====> Epoch: 8 Test loss: 0.695806
Readout Fidelity: % 51.5999972820282
torch.Size([1000, 2000])
====> Epoch: 9 Training loss: 0.689602
====> Epoch: 9 Test loss: 0.686927


 11%|█         | 11/100 [00:02<00:17,  4.96it/s]

Readout Fidelity: % 52.399998903274536
torch.Size([1000, 2000])
====> Epoch: 10 Training loss: 0.688047
====> Epoch: 10 Test loss: 0.683300
Readout Fidelity: % 52.30000019073486


 12%|█▏        | 12/100 [00:02<00:18,  4.67it/s]

torch.Size([1000, 2000])
====> Epoch: 11 Training loss: 0.686064
====> Epoch: 11 Test loss: 0.682751
Readout Fidelity: % 53.200000524520874
torch.Size([1000, 2000])
====> Epoch: 12 Training loss: 0.684200
====> Epoch: 12 Test loss: 0.678329


 13%|█▎        | 13/100 [00:02<00:18,  4.82it/s]

Readout Fidelity: % 53.1000018119812


 14%|█▍        | 14/100 [00:02<00:18,  4.64it/s]

torch.Size([1000, 2000])
====> Epoch: 13 Training loss: 0.684312
====> Epoch: 13 Test loss: 0.670483
Readout Fidelity: % 54.19999957084656
torch.Size([1000, 2000])
====> Epoch: 14 Training loss: 0.682672
====> Epoch: 14 Test loss: 0.666486


 15%|█▌        | 15/100 [00:03<00:17,  4.83it/s]

Readout Fidelity: % 54.500001668930054
torch.Size([1000, 2000])
====> Epoch: 15 Training loss: 0.679000
====> Epoch: 15 Test loss: 0.661972


 17%|█▋        | 17/100 [00:03<00:17,  4.81it/s]

Readout Fidelity: % 54.90000247955322
torch.Size([1000, 2000])
====> Epoch: 16 Training loss: 0.677140
====> Epoch: 16 Test loss: 0.664092
Readout Fidelity: % 55.59999942779541


 18%|█▊        | 18/100 [00:03<00:16,  4.95it/s]

torch.Size([1000, 2000])
====> Epoch: 17 Training loss: 0.675673
====> Epoch: 17 Test loss: 0.658145
Readout Fidelity: % 56.00000023841858


 19%|█▉        | 19/100 [00:03<00:17,  4.72it/s]

torch.Size([1000, 2000])
====> Epoch: 18 Training loss: 0.675451
====> Epoch: 18 Test loss: 0.650819
Readout Fidelity: % 56.90000057220459
torch.Size([1000, 2000])
====> Epoch: 19 Training loss: 0.673156
====> Epoch: 19 Test loss: 0.647348


 20%|██        | 20/100 [00:04<00:16,  4.89it/s]

Readout Fidelity: % 56.99999928474426


 21%|██        | 21/100 [00:04<00:16,  4.66it/s]

torch.Size([1000, 2000])
====> Epoch: 20 Training loss: 0.670364
====> Epoch: 20 Test loss: 0.652934
Readout Fidelity: % 57.200002670288086
torch.Size([1000, 2000])
====> Epoch: 21 Training loss: 0.666726
====> Epoch: 21 Test loss: 0.649388


 23%|██▎       | 23/100 [00:04<00:15,  4.94it/s]

Readout Fidelity: % 58.49999785423279
torch.Size([1000, 2000])
====> Epoch: 22 Training loss: 0.664705
====> Epoch: 22 Test loss: 0.643638
Readout Fidelity: % 59.700000286102295


 24%|██▍       | 24/100 [00:05<00:16,  4.72it/s]

torch.Size([1000, 2000])
====> Epoch: 23 Training loss: 0.659802
====> Epoch: 23 Test loss: 0.632839
Readout Fidelity: % 60.100001096725464
torch.Size([1000, 2000])
====> Epoch: 24 Training loss: 0.655866
====> Epoch: 24 Test loss: 0.631387


 25%|██▌       | 25/100 [00:05<00:15,  4.88it/s]

Readout Fidelity: % 60.50000190734863


 26%|██▌       | 26/100 [00:05<00:15,  4.63it/s]

torch.Size([1000, 2000])
====> Epoch: 25 Training loss: 0.650889
====> Epoch: 25 Test loss: 0.623159
Readout Fidelity: % 62.199997901916504
torch.Size([1000, 2000])
====> Epoch: 26 Training loss: 0.646404
====> Epoch: 26 Test loss: 0.628787


 27%|██▋       | 27/100 [00:05<00:15,  4.78it/s]

Readout Fidelity: % 61.900001764297485


 28%|██▊       | 28/100 [00:05<00:15,  4.54it/s]

torch.Size([1000, 2000])
====> Epoch: 27 Training loss: 0.642336
====> Epoch: 27 Test loss: 0.624748
Readout Fidelity: % 62.40000128746033
torch.Size([1000, 2000])
====> Epoch: 28 Training loss: 0.639443
====> Epoch: 28 Test loss: 0.610805


 30%|███       | 30/100 [00:06<00:14,  4.87it/s]

Readout Fidelity: % 63.89999985694885
torch.Size([1000, 2000])
====> Epoch: 29 Training loss: 0.635496
====> Epoch: 29 Test loss: 0.612657
Readout Fidelity: % 64.3999993801117


 31%|███       | 31/100 [00:06<00:14,  4.60it/s]

torch.Size([1000, 2000])
====> Epoch: 30 Training loss: 0.630490
====> Epoch: 30 Test loss: 0.613406
Readout Fidelity: % 65.79999923706055
torch.Size([1000, 2000])
====> Epoch: 31 Training loss: 0.624767
====> Epoch: 31 Test loss: 0.610681


 32%|███▏      | 32/100 [00:06<00:14,  4.75it/s]

Readout Fidelity: % 65.49999713897705


 33%|███▎      | 33/100 [00:06<00:14,  4.51it/s]

torch.Size([1000, 2000])
====> Epoch: 32 Training loss: 0.620217
====> Epoch: 32 Test loss: 0.602844
Readout Fidelity: % 66.50000214576721
torch.Size([1000, 2000])
====> Epoch: 33 Training loss: 0.615964
====> Epoch: 33 Test loss: 0.608792


 34%|███▍      | 34/100 [00:07<00:13,  4.74it/s]

Readout Fidelity: % 67.00000166893005


 35%|███▌      | 35/100 [00:07<00:14,  4.53it/s]

torch.Size([1000, 2000])
====> Epoch: 34 Training loss: 0.609395
====> Epoch: 34 Test loss: 0.617131
Readout Fidelity: % 67.1999990940094


 36%|███▌      | 36/100 [00:07<00:14,  4.46it/s]

torch.Size([1000, 2000])
====> Epoch: 35 Training loss: 0.604919
====> Epoch: 35 Test loss: 0.613537
Readout Fidelity: % 67.69999861717224
torch.Size([1000, 2000])
====> Epoch: 36 Training loss: 0.603254
====> Epoch: 36 Test loss: 0.614490


 37%|███▋      | 37/100 [00:07<00:13,  4.68it/s]

Readout Fidelity: % 67.69999861717224


 38%|███▊      | 38/100 [00:08<00:13,  4.51it/s]

torch.Size([1000, 2000])
====> Epoch: 37 Training loss: 0.594692
====> Epoch: 37 Test loss: 0.621835
Readout Fidelity: % 67.5000011920929
torch.Size([1000, 2000])
====> Epoch: 38 Training loss: 0.590340
====> Epoch: 38 Test loss: 0.609625


 39%|███▉      | 39/100 [00:08<00:12,  4.72it/s]

Readout Fidelity: % 68.00000071525574


 40%|████      | 40/100 [00:08<00:13,  4.46it/s]

torch.Size([1000, 2000])
====> Epoch: 39 Training loss: 0.586290
====> Epoch: 39 Test loss: 0.623161
Readout Fidelity: % 68.00000071525574
torch.Size([1000, 2000])
====> Epoch: 40 Training loss: 0.583956
====> Epoch: 40 Test loss: 0.626856


 42%|████▏     | 42/100 [00:08<00:11,  4.84it/s]

Readout Fidelity: % 68.4000015258789
torch.Size([1000, 2000])
====> Epoch: 41 Training loss: 0.575143
====> Epoch: 41 Test loss: 0.617529
Readout Fidelity: % 68.99999976158142


 43%|████▎     | 43/100 [00:09<00:12,  4.63it/s]

torch.Size([1000, 2000])
====> Epoch: 42 Training loss: 0.572458
====> Epoch: 42 Test loss: 0.622727
Readout Fidelity: % 68.69999766349792
torch.Size([1000, 2000])
====> Epoch: 43 Training loss: 0.567638
====> Epoch: 43 Test loss: 0.607966


 44%|████▍     | 44/100 [00:09<00:11,  4.77it/s]

Readout Fidelity: % 69.0999984741211


 45%|████▌     | 45/100 [00:09<00:12,  4.56it/s]

torch.Size([1000, 2000])
====> Epoch: 44 Training loss: 0.564447
====> Epoch: 44 Test loss: 0.617284
Readout Fidelity: % 69.40000057220459
torch.Size([1000, 2000])


 46%|████▌     | 46/100 [00:09<00:11,  4.70it/s]

====> Epoch: 45 Training loss: 0.564424
====> Epoch: 45 Test loss: 0.621460
Readout Fidelity: % 68.90000104904175


 47%|████▋     | 47/100 [00:09<00:11,  4.49it/s]

torch.Size([1000, 2000])
====> Epoch: 46 Training loss: 0.557580
====> Epoch: 46 Test loss: 0.615825
Readout Fidelity: % 69.30000185966492
torch.Size([1000, 2000])
====> Epoch: 47 Training loss: 0.551972
====> Epoch: 47 Test loss: 0.611431


 49%|████▉     | 49/100 [00:10<00:10,  4.83it/s]

Readout Fidelity: % 69.70000267028809
torch.Size([1000, 2000])
====> Epoch: 48 Training loss: 0.548431
====> Epoch: 48 Test loss: 0.615599
Readout Fidelity: % 70.3000009059906


 51%|█████     | 51/100 [00:10<00:10,  4.82it/s]

torch.Size([1000, 2000])
====> Epoch: 49 Training loss: 0.548829
====> Epoch: 49 Test loss: 0.617353
Readout Fidelity: % 70.09999752044678
torch.Size([1000, 2000])
====> Epoch: 50 Training loss: 0.546483
====> Epoch: 50 Test loss: 0.620916
Readout Fidelity: % 70.20000219345093


 52%|█████▏    | 52/100 [00:11<00:10,  4.55it/s]

torch.Size([1000, 2000])
====> Epoch: 51 Training loss: 0.545075
====> Epoch: 51 Test loss: 0.605044
Readout Fidelity: % 70.20000219345093
torch.Size([1000, 2000])
====> Epoch: 52 Training loss: 0.541948
====> Epoch: 52 Test loss: 0.600252


 53%|█████▎    | 53/100 [00:11<00:09,  4.73it/s]

Readout Fidelity: % 70.59999704360962


 54%|█████▍    | 54/100 [00:11<00:10,  4.54it/s]

torch.Size([1000, 2000])
====> Epoch: 53 Training loss: 0.532577
====> Epoch: 53 Test loss: 0.604166
Readout Fidelity: % 71.29999995231628
torch.Size([1000, 2000])


 55%|█████▌    | 55/100 [00:11<00:09,  4.70it/s]

====> Epoch: 54 Training loss: 0.528924
====> Epoch: 54 Test loss: 0.605084
Readout Fidelity: % 70.70000171661377
torch.Size([1000, 2000])
====> Epoch: 55 Training loss: 0.524742
====> Epoch: 55 Test loss: 0.596527


 56%|█████▌    | 56/100 [00:11<00:09,  4.82it/s]

Readout Fidelity: % 71.20000123977661


 57%|█████▋    | 57/100 [00:12<00:09,  4.60it/s]

torch.Size([1000, 2000])
====> Epoch: 56 Training loss: 0.519231
====> Epoch: 56 Test loss: 0.600195
Readout Fidelity: % 70.99999785423279
torch.Size([1000, 2000])
====> Epoch: 57 Training loss: 0.518727
====> Epoch: 57 Test loss: 0.597944


 58%|█████▊    | 58/100 [00:12<00:08,  4.79it/s]

Readout Fidelity: % 71.39999866485596


 59%|█████▉    | 59/100 [00:12<00:09,  4.54it/s]

torch.Size([1000, 2000])
====> Epoch: 58 Training loss: 0.515642
====> Epoch: 58 Test loss: 0.594377
Readout Fidelity: % 71.20000123977661
torch.Size([1000, 2000])


 60%|██████    | 60/100 [00:12<00:08,  4.69it/s]

====> Epoch: 59 Training loss: 0.514705
====> Epoch: 59 Test loss: 0.589449
Readout Fidelity: % 71.29999995231628
torch.Size([1000, 2000])
====> Epoch: 60 Training loss: 0.510471
====> Epoch: 60 Test loss: 0.590097


 62%|██████▏   | 62/100 [00:13<00:08,  4.68it/s]

Readout Fidelity: % 71.39999866485596
torch.Size([1000, 2000])
====> Epoch: 61 Training loss: 0.511862
====> Epoch: 61 Test loss: 0.583693
Readout Fidelity: % 70.99999785423279


 63%|██████▎   | 63/100 [00:13<00:07,  4.82it/s]

torch.Size([1000, 2000])
====> Epoch: 62 Training loss: 0.509804
====> Epoch: 62 Test loss: 0.588588
Readout Fidelity: % 70.80000042915344


 64%|██████▍   | 64/100 [00:13<00:07,  4.58it/s]

torch.Size([1000, 2000])
====> Epoch: 63 Training loss: 0.510533
====> Epoch: 63 Test loss: 0.597982
Readout Fidelity: % 70.59999704360962
torch.Size([1000, 2000])
====> Epoch: 64 Training loss: 0.504520
====> Epoch: 64 Test loss: 0.599955


 65%|██████▌   | 65/100 [00:13<00:07,  4.73it/s]

Readout Fidelity: % 71.60000205039978


 66%|██████▌   | 66/100 [00:14<00:07,  4.50it/s]

torch.Size([1000, 2000])
====> Epoch: 65 Training loss: 0.506691
====> Epoch: 65 Test loss: 0.595754
Readout Fidelity: % 71.39999866485596
torch.Size([1000, 2000])
====> Epoch: 66 Training loss: 0.503329
====> Epoch: 66 Test loss: 0.592337


 68%|██████▊   | 68/100 [00:14<00:06,  4.83it/s]

Readout Fidelity: % 71.29999995231628
torch.Size([1000, 2000])
====> Epoch: 67 Training loss: 0.497278
====> Epoch: 67 Test loss: 0.595782
Readout Fidelity: % 70.59999704360962


 69%|██████▉   | 69/100 [00:14<00:06,  4.59it/s]

torch.Size([1000, 2000])
====> Epoch: 68 Training loss: 0.500188
====> Epoch: 68 Test loss: 0.607505
Readout Fidelity: % 71.60000205039978
torch.Size([1000, 2000])


 70%|███████   | 70/100 [00:14<00:06,  4.70it/s]

====> Epoch: 69 Training loss: 0.502008
====> Epoch: 69 Test loss: 0.596718
Readout Fidelity: % 71.39999866485596


 71%|███████   | 71/100 [00:15<00:06,  4.49it/s]

torch.Size([1000, 2000])
====> Epoch: 70 Training loss: 0.503621
====> Epoch: 70 Test loss: 0.592096
Readout Fidelity: % 71.20000123977661
torch.Size([1000, 2000])


 72%|███████▏  | 72/100 [00:15<00:06,  4.63it/s]

====> Epoch: 71 Training loss: 0.506085
====> Epoch: 71 Test loss: 0.597515
Readout Fidelity: % 71.39999866485596


 73%|███████▎  | 73/100 [00:15<00:06,  4.44it/s]

torch.Size([1000, 2000])
====> Epoch: 72 Training loss: 0.504987
====> Epoch: 72 Test loss: 0.617229
Readout Fidelity: % 70.89999914169312
torch.Size([1000, 2000])
====> Epoch: 73 Training loss: 0.498659
====> Epoch: 73 Test loss: 0.592201


 75%|███████▌  | 75/100 [00:15<00:05,  4.78it/s]

Readout Fidelity: % 70.59999704360962
torch.Size([1000, 2000])
====> Epoch: 74 Training loss: 0.491950
====> Epoch: 74 Test loss: 0.591008
Readout Fidelity: % 71.79999947547913


 76%|███████▌  | 76/100 [00:16<00:05,  4.54it/s]

torch.Size([1000, 2000])
====> Epoch: 75 Training loss: 0.488398
====> Epoch: 75 Test loss: 0.598566
Readout Fidelity: % 72.00000286102295
torch.Size([1000, 2000])
====> Epoch: 76 Training loss: 0.484758
====> Epoch: 76 Test loss: 0.599707


 77%|███████▋  | 77/100 [00:16<00:04,  4.74it/s]

Readout Fidelity: % 72.60000109672546


 79%|███████▉  | 79/100 [00:16<00:04,  4.77it/s]

torch.Size([1000, 2000])
====> Epoch: 77 Training loss: 0.479408
====> Epoch: 77 Test loss: 0.588624
Readout Fidelity: % 72.2000002861023
torch.Size([1000, 2000])
====> Epoch: 78 Training loss: 0.476448
====> Epoch: 78 Test loss: 0.603593
Readout Fidelity: % 72.00000286102295


 80%|████████  | 80/100 [00:17<00:04,  4.57it/s]

torch.Size([1000, 2000])
====> Epoch: 79 Training loss: 0.475086
====> Epoch: 79 Test loss: 0.604122
Readout Fidelity: % 70.99999785423279
torch.Size([1000, 2000])
====> Epoch: 80 Training loss: 0.476076
====> Epoch: 80 Test loss: 0.603094


 82%|████████▏ | 82/100 [00:17<00:03,  4.82it/s]

Readout Fidelity: % 69.90000009536743
torch.Size([1000, 2000])
====> Epoch: 81 Training loss: 0.479656
====> Epoch: 81 Test loss: 0.576955
Readout Fidelity: % 70.3000009059906


 83%|████████▎ | 83/100 [00:17<00:03,  4.56it/s]

torch.Size([1000, 2000])
====> Epoch: 82 Training loss: 0.474433
====> Epoch: 82 Test loss: 0.599745
Readout Fidelity: % 70.49999833106995
torch.Size([1000, 2000])
====> Epoch: 83 Training loss: 0.472044
====> Epoch: 83 Test loss: 0.601784


 84%|████████▍ | 84/100 [00:17<00:03,  4.74it/s]

Readout Fidelity: % 69.49999928474426


 85%|████████▌ | 85/100 [00:18<00:03,  4.55it/s]

torch.Size([1000, 2000])
====> Epoch: 84 Training loss: 0.472583
====> Epoch: 84 Test loss: 0.602270
Readout Fidelity: % 70.59999704360962
torch.Size([1000, 2000])
====> Epoch: 85 Training loss: 0.473469
====> Epoch: 85 Test loss: 0.595614


 87%|████████▋ | 87/100 [00:18<00:02,  4.85it/s]

Readout Fidelity: % 70.09999752044678
torch.Size([1000, 2000])
====> Epoch: 86 Training loss: 0.466836
====> Epoch: 86 Test loss: 0.608186
Readout Fidelity: % 70.49999833106995


 88%|████████▊ | 88/100 [00:18<00:02,  4.60it/s]

torch.Size([1000, 2000])
====> Epoch: 87 Training loss: 0.460413
====> Epoch: 87 Test loss: 0.618493
Readout Fidelity: % 70.3000009059906
torch.Size([1000, 2000])
====> Epoch: 88 Training loss: 0.458652
====> Epoch: 88 Test loss: 0.601627


 89%|████████▉ | 89/100 [00:18<00:02,  4.77it/s]

Readout Fidelity: % 71.49999737739563


 90%|█████████ | 90/100 [00:19<00:02,  4.57it/s]

torch.Size([1000, 2000])
====> Epoch: 89 Training loss: 0.450966
====> Epoch: 89 Test loss: 0.594449
Readout Fidelity: % 71.49999737739563
torch.Size([1000, 2000])
====> Epoch: 90 Training loss: 0.450296
====> Epoch: 90 Test loss: 0.606179


 91%|█████████ | 91/100 [00:19<00:01,  4.74it/s]

Readout Fidelity: % 71.20000123977661


 92%|█████████▏| 92/100 [00:19<00:01,  4.54it/s]

torch.Size([1000, 2000])
====> Epoch: 91 Training loss: 0.450508
====> Epoch: 91 Test loss: 0.607974
Readout Fidelity: % 72.69999980926514
torch.Size([1000, 2000])
====> Epoch: 92 Training loss: 0.438752
====> Epoch: 92 Test loss: 0.618964


 94%|█████████▍| 94/100 [00:20<00:01,  4.89it/s]

Readout Fidelity: % 72.29999899864197
torch.Size([1000, 2000])
====> Epoch: 93 Training loss: 0.437869
====> Epoch: 93 Test loss: 0.610165
Readout Fidelity: % 72.00000286102295


 95%|█████████▌| 95/100 [00:20<00:01,  4.68it/s]

torch.Size([1000, 2000])
====> Epoch: 94 Training loss: 0.438353
====> Epoch: 94 Test loss: 0.619009
Readout Fidelity: % 71.10000252723694
torch.Size([1000, 2000])


 96%|█████████▌| 96/100 [00:20<00:00,  4.78it/s]

====> Epoch: 95 Training loss: 0.433903
====> Epoch: 95 Test loss: 0.614676
Readout Fidelity: % 70.99999785423279


 97%|█████████▋| 97/100 [00:20<00:00,  4.56it/s]

torch.Size([1000, 2000])
====> Epoch: 96 Training loss: 0.428104
====> Epoch: 96 Test loss: 0.602814
Readout Fidelity: % 70.99999785423279
torch.Size([1000, 2000])
====> Epoch: 97 Training loss: 0.425982
====> Epoch: 97 Test loss: 0.607206


 98%|█████████▊| 98/100 [00:20<00:00,  4.77it/s]

Readout Fidelity: % 70.89999914169312


 99%|█████████▉| 99/100 [00:21<00:00,  4.56it/s]

torch.Size([1000, 2000])
====> Epoch: 98 Training loss: 0.428136
====> Epoch: 98 Test loss: 0.615201
Readout Fidelity: % 70.20000219345093
torch.Size([1000, 2000])


100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

====> Epoch: 99 Training loss: 0.428824
====> Epoch: 99 Test loss: 0.625968
Readout Fidelity: % 69.90000009536743





In [15]:
torch.save(model.state_dict(), f"../../checkpoints/{framework}_ckp_a8.pth") 

In [16]:
if framework.lower() == "hawq": 
    manager = ExportManager(model)
    manager.export(torch.randn(1, 2000), "../../checkpoints/hawq_a8.onnx")
elif framework.lower() == "brevitas":
    QONNXManager.export(model, input_shape=(1, 2000), export_path='checkpoints/brevitas.onnx')

Exporting model...
Optimizing...
Model saved to: ../../checkpoints/hawq_a8.onnx


