In [1]:

import torch
import torch.nn as nn
import random
from torch.utils.data.dataset import TensorDataset
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
import os
import numpy as np
import pandas as pd
from os.path import join
from typing import Optional
from tqdm.notebook import tqdm
from scipy.stats import wasserstein_distance 
from matplotlib import pyplot as plt
import threading, concurrent
import glob
import numpy as np
import matplotlib.pyplot as plt
import sys
from qiskit import QuantumCircuit, QuantumRegister
from qiskit import Aer, execute

# Dataset

In [2]:

sys.path.append('../')

class DatasetImages(Dataset): #
    def __init__(self, data_type: str = 'train'):
        with open('../data/images.npy', 'rb') as f:
            images = np.load(f)
        with open('../data/labels.npy', 'rb') as f:
            labels = np.load(f)

        # Scale the images to be between 0 and 2pi TODO: Chcek
        images = (images - images.min()) / (images.max() - images.min()) * 2 * np.pi

        # Convert to numpy
        images = np.array(images, dtype=np.float32)
        labels = np.array(labels, dtype=np.float32)


        # images shape must be (n_samples, -1)
        images = images.reshape(images.shape[0], -1)


        # Split the data into train and validation
        n_train = int(0.8 * len(images))
        n_val = len(images) - n_train

        if data_type == 'train':
            images = images[:n_train]
            labels = labels[:n_train]
        elif data_type == 'val':
            images = images[n_train:]
            labels = labels[n_train:]

        self.n_samples = len(images)
        self.images = images
        self.labels = labels


        # Print shapes DEBUG
        print(f"Type: {data_type}")
        print(f"Images shape: {self.images.shape}")
        print(f"Labels shape: {self.labels.shape}")



    def __getitem__(self, index):
        return torch.from_numpy(self.images)[index], torch.from_numpy(self.labels)[index]

    def __len__(self):
        return self.n_samples
        
    def all_items(self):
        return torch.from_numpy(self.images), torch.from_numpy(self.labels)


## Encoder

In [3]:
def encoder(images):

        output = []
        for image in images:
            # Initialize QC
            q = QuantumRegister(16)
            circuit = QuantumCircuit(q)

            # Angle embedding
            # Each data point is mapped to a rotation
            # the rotation will alternate between the x, y and z axis
            # Essentially each qubit will have the data of 3 pixels

            total_features = 0 # Total features processed

            while total_features < len(image):
                # Lazy way to do it
                for i in range(16):
                    try:
                        circuit.rx(image[total_features], q[i])
                    except:
                        total_features += 3
                        break
                    
                    try:
                        circuit.ry(image[total_features+1], q[i])
                    except:
                        total_features += 3
                        break

                    try:
                        circuit.rz(image[total_features+2], q[i])
                    except:
                        total_features += 3
                        break
                    total_features += 3

            # Measure and return the state
            circuit.measure_all()

            backend = Aer.get_backend('qasm_simulator')
            job = execute(circuit, backend, shots=1000)
            result = job.result()
            counts = result.get_counts(circuit)

            # Get the state with the biggest probability
            state = max(counts, key=counts.get)
            output.append(torch.tensor([int(i) for i in state], dtype=torch.float32)) # binary state


        output = torch.stack(output)
        #print("QC Output", output.shape)
        return output

## Decoder

In [4]:
#-> Decoder
# Will receive an hidden state of 16 dimensions
# will output a 28x28 image
n_layers_encoder = 3
layers = []
hidden_size = 16

# Input layer
layers.append(nn.Linear(hidden_size, 16))
layers.append(nn.ReLU())

# Hidden layers
for i in range(n_layers_encoder):
    layers.append(nn.Linear(16, 16))
    layers.append(nn.ReLU())

# Output layer
layers.append(nn.Linear(16, 28*28))
layers.append(nn.Sigmoid())

decoder = nn.Sequential(*layers)

## Training stuff

In [5]:
def decode(x):
    # Pass through encoder
    return decoder(x)

def forward(x):
    # Pass through encoder
    image_embedding = encoder(x)
    
    # Pass through decoder
    output = decoder(image_embedding)

    return output

def training_step(batch, batch_idx):
    x, _ = batch
    
    # Pass
    output = forward(x)

    # Loss
    loss_criterion = nn.MSELoss()

    # Print x and output shapes
    #print("x", x.shape)
    #print("output", output.shape)

    loss = loss_criterion(x, output)

    if loss.isnan().any():
        raise KeyboardInterrupt

    return loss

In [6]:
# Define optimizer
optimizer = Adam(decoder.parameters(), lr=1e-3)

In [7]:
# Define dataset
dataset = DatasetImages('train')
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

Type: train
Images shape: (1600, 784)
Labels shape: (1600,)


In [8]:
val_dataset = DatasetImages('val').all_items()

Type: val
Images shape: (400, 784)
Labels shape: (400,)


In [9]:
# Train
best_loss = 1e10
for epoch in range(100):
    for batch_idx, batch in enumerate(dataloader):
        loss = training_step(batch, batch_idx)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # benchmark on validation set
    with torch.no_grad():
        val_loss = training_step(val_dataset, 0)
        print(f'Epoch {epoch} - Val Loss: {val_loss.item()}')

        if val_loss.item() < best_loss:
            best_loss = val_loss.item()
            torch.save(decoder.state_dict(), 'best_model.pt')

    print(f'Epoch {epoch} - Loss: {loss.item()}')

Epoch 0 - Val Loss: 8.987215042114258
Epoch 0 - Loss: 7.314968109130859
