In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
from torch.autograd import Function
from torchvision import datasets, transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

# from torchsummary import summary

import qiskit
from qiskit.visualization import *
from qiskit.circuit.random import random_circuit

from itertools import combinations

### Resources:
 - https://github.com/PlanQK/variational-quanvolutional-neural-networks
 - https://qiskit.org/textbook/ch-machine-learning/machine-learning-qiskit-pytorch.html#quantumlayer
 - https://github.com/yh08037/quantum-neural-network/blob/master/model2-conv/quanv.ipynb
 - https://pytorch.org/docs/stable/notes/extending.html
 


### Create the quantum circuit

In [None]:
# @Julia
def get_quanv_fn(kernel_size=2, backend=None, shots=1024, threshold=127, ansatz=''):
    """Build the callable that evaluates the quanvolution circuit.

    NOTE: this is still a skeleton. None of the arguments are used yet and the
    returned callable raises ``NotImplementedError`` when invoked.

    Parameters
    ----------
    kernel_size : int
        Presumably the side length of the square image patch fed to the
        circuit -- TODO confirm once the circuit is written.
    backend : object, optional
        Presumably the Qiskit backend the circuit will run on -- TODO confirm.
    shots : int
        Presumably the number of shots per circuit execution -- TODO confirm.
    threshold : int
        Suggestion only; not part of a settled interface.
    ansatz : str
        Suggestion only; not part of a settled interface.

    Returns
    -------
    callable
        ``execute_qc(input_data, params)``.
    """
    # TODO (circuit construction, not implemented yet):
    #   - instantiate the quantum circuit
    #   - create the trainable parameter vector
    #   - create the input parameter vector
    #   - apply the appropriate gates

    def execute_qc(input_data, params):
        """Evaluate the circuit for one input and one set of parameters.

        Not implemented yet: the intended steps are to bind the data to the
        circuit, execute it, and extract the output expectations.
        """
        # Fail with an explicit message rather than a NameError on an
        # undefined result variable.
        raise NotImplementedError(
            "get_quanv_fn: the quantum circuit has not been implemented yet"
        )

    return execute_qc


### Create the Quanvolution Class with PyTorch

In [None]:
class QuanvFunction(Function):
    """Variational quanvolution as a custom autograd function.

    ``forward`` delegates the actual computation to a user-supplied callable;
    ``backward`` is still a placeholder that returns no gradients.
    """

    @staticmethod
    def forward(ctx, input_data, params, quantum_circuit_fn):
        """Evaluate the circuit function and wrap its result in a tensor.

        Parameters
        ----------
        ctx : autograd context
            Used to stash the tensors and the circuit function for ``backward``.
        input_data : torch.Tensor
            Data handed unchanged to ``quantum_circuit_fn``.
        params : torch.Tensor
            Circuit parameters handed unchanged to ``quantum_circuit_fn``.
        quantum_circuit_fn : callable
            Called as ``quantum_circuit_fn(input_data, params)``.

        Returns
        -------
        torch.Tensor
            ``torch.tensor([expectations])``, i.e. the callable's result with
            one extra leading dimension.
        """
        # The autograd API is `save_for_backward` (no trailing "s"); the
        # misspelled name raised AttributeError on every forward pass.
        ctx.save_for_backward(input_data, params)
        ctx.qc_fn = quantum_circuit_fn

        expectations = quantum_circuit_fn(input_data, params)
        result = torch.tensor([expectations])
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Placeholder backward pass.

        Returns one entry per ``forward`` input (input_data, params,
        quantum_circuit_fn). All three are currently ``None`` because the
        gradient computation has not been written yet.
        """
        input_data, params = ctx.saved_tensors
        qc_fn = ctx.qc_fn

        # Gradients w.r.t. each input of forward()
        grad_input = grad_params = grad_qc = None

        # Compute gradients
        # @Tristan

        return grad_input, grad_params, grad_qc


class QuanvLayer(nn.Module):
    """Quanvolution (quantum convolution) layer.

    Slides a ``kernel_size`` x ``kernel_size`` window over channel-last images
    and evaluates the quantum circuit function on every patch through
    ``QuanvFunction``.
    """

    def __init__(self, in_channels, out_channels=4, kernel_size=2, stride=1,
                 backend=qiskit.Aer.get_backend('qasm_simulator'), shots=100):
        """Create the layer.

        Parameters
        ----------
        in_channels : int
            Stored on the instance; not used by the layer yet.
        out_channels : int
            Size of the last dimension of the output tensor.
        kernel_size : int
            Side length of the square window.
        stride : int
            Step between consecutive windows, in pixels.
        backend, shots
            Forwarded to ``get_quanv_fn``.
        """
        super(QuanvLayer, self).__init__()

        self.qc_fn = get_quanv_fn(kernel_size=kernel_size, backend=backend, shots=shots) # TODO: multiple circuits? analogue to multiple kernel CNN

        self.in_channels = in_channels
        self.out_channels = out_channels  # TODO: what do we do with out_channels - look at CNN
        self.kernel_size = kernel_size
        self.stride = stride

        # The trainable tensor must not be called `parameters`: that name
        # belongs to nn.Module.parameters(), which optimizers call to collect
        # the trainable tensors of a model.
        self.weight = nn.Parameter(torch.empty(self._get_parameter_shape()))
        nn.init.uniform_(self.weight, -0.1, 0.1)

    def _get_parameter_shape(self):
        """Computes the number of trainable parameters required by the quantum circuit function"""

        # TODO: implement based on some ansatz specification (see get_quanv_fn) that is either provided
        # to the object or a global default

        return (4,)

    def _get_out_dim(self, img):
        """Return the output shape ``(bs, h_out, w_out, out_channels)`` for a
        channel-last batch of shape ``(bs, h, w, ch)``."""
        bs, h, w, ch = img.size()
        h_out = (int(h) - self.kernel_size) // self.stride + 1
        w_out = (int(w) - self.kernel_size) // self.stride + 1
        return bs, h_out, w_out, self.out_channels

    def convolve(self, imgs):
        """Get input to circuit following a convolution pattern

        Yields ``(data, batch_idx, row, col)`` where ``data`` is the
        ``kernel_size`` x ``kernel_size`` patch (all channels) whose top-left
        corner sits at pixel ``(row, col)`` of image ``batch_idx``. Rows and
        columns advance by ``stride``, so the number of patches per axis
        matches ``_get_out_dim``.
        """
        # @Robbie: first-pass implementation, please review.
        bs, h, w, _ = imgs.size()
        k = self.kernel_size

        for batch_idx in range(bs):
            for row in range(0, int(h) - k + 1, self.stride):
                for col in range(0, int(w) - k + 1, self.stride):
                    data = imgs[batch_idx, row:row + k, col:col + k, :]
                    yield data, batch_idx, row, col

    def forward(self, imgs):
        """Apply variational quanvolution layer to image

        Parameters
        ----------
        imgs : torch.Tensor
            A batch of input images. Should have shape [batch_size, height, width, n_channels].

        Returns
        -------
        torch.Tensor
            float32 tensor of shape [batch_size, h_out, w_out, out_channels].
        """
        out = torch.empty(self._get_out_dim(imgs), dtype=torch.float32)

        for data, batch_idx, row, col in self.convolve(imgs):
            res = QuanvFunction.apply(data, self.weight, self.qc_fn)

            # (row, col) are input-pixel coordinates; dividing by the stride
            # maps them to output coordinates.
            out[batch_idx, row // self.stride, col // self.stride] = res

        return out

        

### Define the Network Class

In [None]:
class Net(nn.Module):
    """Hybrid network skeleton: one quanvolution layer followed by classical layers.

    Only the layers are declared so far; ``forward`` is still a placeholder.
    """

    def __init__(self):
        super(Net, self).__init__()
        # The layer class defined in this notebook is `QuanvLayer`; the name
        # `Quanv` does not exist and raised NameError on construction.
        self.quanv = QuanvLayer(in_channels=1, out_channels=4, kernel_size=2)
        # NOTE(review): the quanvolution layer is declared with out_channels=4
        # but this convolution expects 6 input channels -- confirm the intended
        # wiring when forward() is written.
        self.conv = nn.Conv2d(6, 16, kernel_size=5)
        self.dropout = nn.Dropout2d()
        self.fc1 = nn.Linear(256, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        """Placeholder forward pass; currently ignores ``x`` and returns 0."""
        # TODO: this is where we build our entire network: whatever layers of
        # quanvolution, pooling, convolution, dropout, flattening and
        # fully connected layers go here.
        return 0

### Build and Train the Model

In [None]:
# Model, optimizer and objective.
model = Net()
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss()

epochs = 20
loss_list = []  # mean training loss, one entry per epoch

model.train()
for epoch in range(epochs):
    batch_losses = []

    for data, target in train_loader:
        # Reset gradients accumulated by the previous step.
        optimizer.zero_grad()

        # Forward pass and loss.
        output = model(data)
        loss = loss_func(output, target)

        # Backward pass, then update the weights.
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    epoch_mean = sum(batch_losses)/len(batch_losses)
    loss_list.append(epoch_mean)

    progress = 100. * (epoch + 1) / epochs
    print('Training [{:.0f}%]\tLoss: {:.4f}'.format(progress, loss_list[-1]))