# Classifying MNIST with a simple model and quantum embeddings

Inspired by:  https://www.kaggle.com/code/geekysaint/solving-mnist-using-pytorch

Useful imports

In [1]:
# for the Boson Sampler
import perceval as pcvl
#import perceval.providers.scaleway as scw

import torch
from math import comb

from typing import Iterable

from functools import lru_cache

# for the machine learning model
import torch
import torchvision ## Contains some utilities for working with the image data
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.nn as nn
from tqdm import tqdm
import time

## Definition of the Boson Sampler

In [2]:
class BosonSampler:
    """Boson sampler used to turn a flat tensor into a vector of output probabilities.

    The values of the input tensor are loaded as phase-shifter angles of an
    ``m``-mode interferometer fed with ``n`` photons. The embedding is the
    probability of each thresholded output state (0/1 per mode) in which at
    least ``postselect`` photons are detected.
    """

    def __init__(self, m: int, n: int, postselect: int = None, session = None):
        """
        :param m: number of modes
        :param n: number of input photons (one per mode at most, hence n <= m)
        :param postselect: minimum number of detected photons for an output to be kept,
            defaults to ``n``
        :param session: object exposing ``build_remote_processor()``; when None, the
            circuit is simulated locally
        :raises ValueError: if ``n > m``
        """
        # Explicit raise instead of `assert`, which disappears under `python -O`
        if n > m:
            raise ValueError("Got more photons than modes, can only input 1 photon per mode")
        self.m = m
        self.n = n
        # `postselect or n` would silently replace an explicit postselect=0 by n
        self.postselect = n if postselect is None else postselect
        self.session = session
        # Filled on first call of generate_state_list()
        self._state_list = None

    @property
    def nb_parameters(self):
        """Maximum number of values that embed() accepts."""
        return self.m * (self.m - 1) - (self.m // 2)  # Doesn't count the last layer of PS as it doesn't change anything

    @property
    def nb_parameters_needed(self):
        """Number of phase values the circuit consumes (two per interferometer element)."""
        return self.m * (self.m - 1)

    def create_circuit(self, parameters: Iterable[float] = None):
        """Builds the m-mode interferometer, each element being BS - PS - BS - PS.

        :param parameters: indexable collection of phases; element i of the interferometer
            reads entries 2 * i and 2 * i + 1. When None, symbolic parameters
            phi_0, phi_1, ... are created instead.
        """
        if parameters is None:
            parameters = [p for i in range(self.m * (self.m - 1) // 2)
                            for p in [pcvl.P(f"phi_{2 * i}"), pcvl.P(f"phi_{2 * i + 1}")]]
        return pcvl.GenericInterferometer(self.m, lambda i: (pcvl.BS()
                                                             .add(0, pcvl.PS(parameters[2 * i]))
                                                             .add(0, pcvl.BS())
                                                             .add(0, pcvl.PS(parameters[2 * i + 1]))
                                                             )
                                          )

    def embed(self, t: torch.Tensor, n_sample: int) -> torch.Tensor:
        """Embeds t as the output probabilities of the sampler.

        t is supposed to be normalized: it is flattened, zero-padded up to
        nb_parameters_needed values and multiplied by 2 * pi to give phases.

        :param t: tensor with at most nb_parameters elements
        :param n_sample: number of samples forwarded to run()
        :return: 1D tensor of size embedding_size
        :raises ValueError: if t has more than nb_parameters elements
        """
        t = t.reshape(-1)  # TODO: check if this is a good way to do this
        if len(t) > self.nb_parameters:
            raise ValueError(f"Got too many parameters (got {len(t)}, maximum {self.nb_parameters})")
        missing = self.nb_parameters_needed - len(t)
        if missing:
            # Padding is created on the device of t so that torch.cat cannot fail on a device mismatch
            t = torch.cat((t, torch.zeros(missing, device=t.device)), 0)

        t = t * 2 * torch.pi

        res = self.run(t, n_sample)

        return self.translate_results(res)

    @property
    def embedding_size(self) -> int:
        """Size of the tensor returned by embed()."""
        # For thresholded output, this is the number of binary numbers of m bits
        # having between self.postselect and self.n 1s
        return sum(comb(self.m, k) for k in range(self.postselect, self.n + 1))

    def translate_results(self, res: "pcvl.BSDistribution") -> torch.Tensor:
        """Reads the probability of every state of generate_state_list() in res, in that order."""
        t = torch.zeros(self.embedding_size)

        # The order of the state list fixes the meaning of each coordinate of the embedding
        for i, state in enumerate(self.generate_state_list()):
            t[i] = res[state]

        return t

    def generate_state_list(self):
        """Lists the output states, grouped by increasing number of detected photons.

        Always the same for a given sampler, so it is computed once and stored on
        the instance (a cache on the method itself would be shared at class level
        and would keep every instance alive).
        """
        states = getattr(self, "_state_list", None)
        if states is None:
            states = []
            for k in range(self.postselect, self.n + 1):
                states += self._generate_state_list_k(k)
            self._state_list = states

        return states

    def _generate_state_list_k(self, k):
        # generates all binary states of size self.m having k 1s
        return list(map(pcvl.BasicState, pcvl.utils.qmath.distinct_permutations(k * [1] + (self.m - k) * [0])))

    def _input_positions(self):
        # Modes receiving a photon: n positions evenly spread between mode 0 and mode m - 1.
        # Integer arithmetic gives floor(i * (m - 1) / (n - 1)) without going through floats.
        if self.n == 1:
            return [0]
        return [(i * (self.m - 1)) // (self.n - 1) for i in range(self.n)]

    def prepare_processor(self, processor, parameters: Iterable[float]):
        """Sets the circuit, the post-selection, the thresholded detection and the input state on processor."""
        processor.set_circuit(self.create_circuit(parameters))
        processor.min_detected_photons_filter(self.postselect)
        processor.thresholded_output(True)

        # Evenly spaces the photons
        input_state = self.m * [0]
        for mode in self._input_positions():
            input_state[mode] = 1
        input_state = pcvl.BasicState(input_state)

        processor.with_input(input_state)

    def run(self, parameters: Iterable[float], samples: int) -> "pcvl.BSDistribution":
        """Samples and return the raw results"""
        if self.session is not None:
            proc = self.session.build_remote_processor()

        else:
            # Local simulation
            proc = pcvl.Processor("SLOS", self.m)

        self.prepare_processor(proc, parameters)

        sampler = pcvl.algorithm.Sampler(proc, max_shots_per_call=samples)
        res = sampler.probs(samples)

        return res["results"]

## Dataset

In [None]:
# fetch MNIST (download = True) in case it is not stored locally yet
dataset = MNIST(root = '/home/jupyter-pemeriau/scaleway_demo/mnist-data/', download = True)
print(f"Total length of dataset = {len(dataset)}")

# training part of MNIST, with the images converted to tensors
mnist_dataset = MNIST(
    root = '/home/jupyter-pemeriau/scaleway_demo/mnist-data/',
    train = True,
    transform = transforms.ToTensor(),
)
len_dataset = len(mnist_dataset)

# TODO: here, you can choose the proportions of the dataset to use for training and validation
train_split, val_split = 0.005, 0.0001
train_size = int(train_split * len_dataset)
val_size = int(val_split * len_dataset)
# everything that is neither in the training nor in the validation subset is left aside
not_used = len_dataset - train_size - val_size
train_dataset, val_dataset, not_used_dataset = random_split(
    mnist_dataset, [train_size, val_size, not_used]
)
print("length of Train Datasets: ", len(train_dataset))
print("length of Validation Datasets: ", len(val_dataset))

# dataloaders feeding the model
# the boson sampler embeds one image at a time, hence the batch size of 1
batch_size = 1
train_loader = DataLoader(train_dataset, batch_size, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size, shuffle = False)

## Definition of the model

In [4]:
# fraction of samples whose highest-scoring class matches the label
def accuracy(outputs, labels):
    """Return, as a 0-dim tensor, the share of rows of `outputs` whose arg-max along dim 1 equals `labels`."""
    preds = torch.max(outputs, dim = 1).indices
    n_correct = (preds == labels).sum().item()
    return torch.tensor(n_correct / len(preds))

In [5]:
class MnistModel(nn.Module):
    """Single linear layer classifying 28x28 images into 10 classes.

    When `use_quantum` is True, an embedding vector is concatenated to the
    flattened image before the linear layer.
    """

    def __init__(self, device = 'cpu', use_quantum = False, embedding_size = 435):
        """
        :param device: device on which batches and embeddings are moved before the forward pass
        :param use_quantum: whether an embedding is appended to the flattened image
        :param embedding_size: length of the embedding; the default, 435 = comb(30, 2),
            is the embedding size of a sampler with 30 modes and 2 photons post-selected on 2 detections
        """
        super().__init__()
        input_size = 28 * 28
        num_classes = 10
        self.device = device
        self.use_quantum = use_quantum
        self.embedding_size = embedding_size
        if self.use_quantum:
            input_size += embedding_size
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb, emb = None):
        """Return the class scores for the images `xb` (and their embeddings `emb` in quantum mode).

        :raises ValueError: if `use_quantum` is True and `emb` is None
        """
        xb = xb.reshape(-1, 784)
        if self.use_quantum:
            if emb is None:
                # Without the embedding the input is too short for the linear layer: fail with a clear message
                raise ValueError("use_quantum is True: an embedding must be provided with the images")
            # concatenation of the embeddings and the input images
            xb = torch.cat((xb, emb), dim=1)
        return self.linear(xb)

    def _predict(self, batch, emb):
        # Moves the batch (and the embedding in quantum mode) to self.device and runs the forward pass
        images, labels = batch
        images, labels = images.to(self.device), labels.to(self.device)
        if self.use_quantum:
            if emb is None:
                raise ValueError("use_quantum is True: an embedding must be provided with the images")
            out = self(images, emb.to(self.device))
        else:
            out = self(images)
        return out, labels

    def training_step(self, batch, emb = None):
        """Return the cross-entropy loss and the accuracy of the model on `batch`."""
        out, labels = self._predict(batch, emb)  ## Generate predictions
        loss = F.cross_entropy(out, labels)  ## Calculate the loss
        acc = accuracy(out, labels)
        return loss, acc

    def validation_step(self, batch, emb = None):
        """Same as training_step, but the metrics are returned in a dict ('val_loss', 'val_acc')."""
        out, labels = self._predict(batch, emb)  ## Generate predictions
        loss = F.cross_entropy(out, labels)
        acc = accuracy(out, labels)
        return {'val_loss': loss, 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        """Average the dicts returned by validation_step and return the means as Python floats."""
        epoch_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        epoch_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        """Print the validation metrics of `epoch` and return them as (val_loss, val_acc)."""
        print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result['val_loss'], result['val_acc']))
        return result['val_loss'], result['val_acc']


In [6]:
# evaluation of the model
@torch.no_grad()  # no backward pass here: avoids building an autograd graph for every validation batch
def evaluate(model, val_loader):
    """Run the model on every batch of `val_loader` and return the aggregated metrics.

    In quantum mode, each batch is expected to hold a single image (it is
    squeezed on its first two dimensions) which is embedded with the
    module-level boson sampler `bs`, using 1000 samples.

    :param model: model exposing `use_quantum`, `validation_step` and `validation_epoch_end`
    :param val_loader: iterable of (images, labels) batches
    :return: the value returned by `model.validation_epoch_end`
    """
    if model.use_quantum:
        outputs = []
        for batch in tqdm(val_loader):
            # embedding in the BS
            images, _ = batch
            images = images.squeeze(0).squeeze(0)
            embs = bs.embed(images, 1000)
            outputs.append(model.validation_step(batch, emb=embs.unsqueeze(0)))
    else:
        outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

In [7]:
def plot_training_metrics(train_acc,val_acc,train_loss,val_loss):
    """Draw accuracy (left) and loss (right) curves per epoch and save them to training_curves.png.

    Each argument is a sequence with one value per epoch; the x axis is built
    from the length of `train_acc`.
    """
    fig, axes = plt.subplots(1, 2, figsize = (15, 5))
    epochs = list(range(len(train_acc)))
    # one entry per panel: axis, training curve, validation curve, y label, title
    panels = (
        (axes[0], train_acc, val_acc, "ACC", "Training and validation accuracies"),
        (axes[1], train_loss, val_loss, "Loss", "Training and validation losses"),
    )
    for ax, train_curve, val_curve, y_label, title in panels:
        ax.plot(epochs, train_curve, label = 'training')
        ax.plot(epochs, val_curve, label = 'validation')
        ax.set_xlabel("Epochs")
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.grid(visible = True)
        ax.legend()
    fig.savefig("training_curves.png")

In [8]:
# training loop
def fit(epochs, lr, model, train_loader, val_loader, opt_func = torch.optim.SGD):
    """Train `model` for `epochs` epochs and validate it after each of them.

    In quantum mode, each batch is expected to hold a single image (it is
    squeezed on its first two dimensions) which is embedded with the
    module-level boson sampler `bs`, using 1000 samples.
    `plot_training_metrics` is called after every epoch with the metrics
    collected so far.

    :param epochs: number of passes over `train_loader`
    :param lr: learning rate given to the optimizer
    :param model: model exposing `use_quantum`, `training_step` and `epoch_end`
    :param train_loader: iterable of (images, labels) training batches, must support len()
    :param val_loader: iterable of (images, labels) validation batches
    :param opt_func: optimizer factory called as opt_func(model.parameters(), lr)
    :return: list of the validation results returned by `evaluate`, one per epoch
    """
    history = []
    optimizer = opt_func(model.parameters(), lr)
    train_loss, train_acc, val_loss, val_acc = [], [], [], []
    for epoch in range(epochs):
        training_losses, training_accs = 0.0, 0.0
        ## Training phase
        for step, batch in enumerate(tqdm(train_loader)):
            if model.use_quantum:
                # embedding in the BS
                images, _ = batch
                images = images.squeeze(0).squeeze(0)
                embs = bs.embed(images, 1000)
                loss, acc = model.training_step(batch, emb = embs.unsqueeze(0))
            else:
                loss, acc = model.training_step(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Accumulate the real values: casting to int would truncate them (a loss of 2.3 would count as 2)
            training_losses += loss.item()
            training_accs += acc.item()
            if step%100==0:
                print(f"STEP {step}, Training-acc = {training_accs/(step+1)}, Training-losses = {training_losses/(step+1)}")

        ## Validation phase
        result = evaluate(model, val_loader)
        validation_loss, validation_acc = result['val_loss'], result['val_acc']
        model.epoch_end(epoch, result)
        history.append(result)

        ## summing up all the training and validation metrics
        train_loss.append(training_losses / len(train_loader))
        train_acc.append(training_accs / len(train_loader))
        val_loss.append(validation_loss)
        val_acc.append(validation_acc)

        # plot training curves
        plot_training_metrics(train_acc, val_acc, train_loss, val_loss)
    return history


## Training

In [None]:
# definition of the BosonSampler
# here, we use 30 modes and 2 photons

# session stays None for a local simulation
session = None
# to run a remote session on Scaleway, uncomment the following (together with the
# `perceval.providers.scaleway` import at the top) and fill project_id and token
#session = scw.Session(
#                    platform="sim:sampling:p100",  # or sim:sampling:h100
#                    project_id="",  # Your project id
#                    token="",  # Your personal API key
#                    )

# start session (only when a remote session has been configured)
if session:
    session.start()

In [None]:
# 30 modes, 2 photons, keeping the outputs where at least 2 photons are detected
# `session` is forwarded so that a configured remote session is actually used (None means local simulation)
bs = BosonSampler(30, 2, postselect = 2, session = session)
print(f"Boson sampler defined with number of parameters = {bs.nb_parameters}")

In [None]:
# the model runs on the first CUDA device when there is one, on the CPU otherwise
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(f'DEVICE = {device}')

In [12]:
# build the model and move it to the chosen device in one go
# use_quantum = True feeds the boson sampler embedding to the model along with the image
model = MnistModel(device = device, use_quantum = True).to(device)

In [None]:
# run the training; `experiment` holds the validation metrics of each epoch
experiment = fit(
    epochs = 2,
    lr = 0.001,
    model = model,
    train_loader = train_loader,
    val_loader = val_loader,
)

In [None]:
# end session if needed (nothing to stop when session is None, i.e. for a local simulation)
if session:
    session.stop()