Imports:

In [1]:
import sys
sys.path.append("../")
sys.path.append("spsa/")

import numpy as np
import matplotlib.pyplot as plt
import torch 
import torch.nn as nn
import qiskit 
from qiskit_algorithms.optimizers import SPSA
import models
from models import CNN_Simple as CNN_Simple
import torchvision
import torchvision.transforms as transforms

Global Variables:

In [2]:
classes = ("T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
        "Sandal", "Shirt", "Sneaker", "Bag", "Ankle Boot")

transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

training_set = torchvision.datasets.FashionMNIST("./data", train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST("./data", train=False, transform=transform, download=True)
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

data = training_set.data.type(torch.double)[:, None, :, :]
target = training_set.targets

Code:

In [3]:
def load_params(model, param_tensor):
    index = 0
    for param in model.parameters():
        param_length = param.numel()
        param.data = param_tensor[index:index + param_length].view(param.size())
        index += param_length

In [4]:
criterion = torch.nn.CrossEntropyLoss()
def batch_loss(batch_input, batch_output):
  def loss_f(weights):
    model = CNN_Simple()
    load_params(model, torch.tensor(weights))
    prediction = model.forward(batch_input)
    return criterion(prediction, batch_output.type(torch.long)).item()
  
  return loss_f

In [5]:
class ShowSteps:
  def __init__(self):
    self.losses = []

  def callback(self, n, weights, loss_f, step_size, accepted):
    self.losses.append(loss_f)

In [6]:
steps = ShowSteps()
model = CNN_Simple()
param_list = np.random.randn(sum(p.numel() for p in model.parameters())).astype(np.float32)
param_tensor = torch.tensor(param_list)
losses = []

for batch_input, batch_output in training_loader:
  our_spsa = SPSA(maxiter=1, learning_rate=0.001, perturbation=0.1, callback=steps.callback)
  result = our_spsa.minimize(batch_loss(batch_input.type(torch.double), batch_output.type(torch.double)), param_tensor)
  param_tensor = result.x
  losses.append(result.fun)
  print(losses[-1])
  #print("Loss =", param_tensor)

6175159.572779831
54968703633107.58
1.039641634880008e+36
1.3788189349041631e+101
1.6592746537978438e+101
1.3443412180249783e+101
1.4149889579264777e+101
1.4723744270580092e+101
1.5215864173537083e+101
1.4453600533222943e+101
1.5183521590835096e+101
1.1173321447006756e+101
1.0909325015864397e+101
1.4840454027766293e+101
1.7170641281526396e+101
1.980291702983017e+101
2.131301054225362e+101
1.2914075031975265e+101
2.0727274990444485e+101
7.721365492158521e+100
1.24447389335083e+101
2.1381449145968622e+101
1.4148036665553523e+101
1.4369615006725747e+101
1.3492669460517938e+101
1.23378661167178e+101
1.473722349818859e+101
6.085273708489832e+100
1.769173161014073e+101
1.4916423551269378e+101
9.323548155039694e+100
6.884420621869938e+100
1.5378268354933235e+101
1.4101888823024757e+101
1.0521866883445724e+101
9.359381829088553e+100
1.610348980681016e+101
1.8390916747678518e+101
1.7178386807635783e+101
1.5605540584001416e+101
1.983568106197484e+101


KeyboardInterrupt: 