In [1]:
import torch
from torch import nn
import torchvision.transforms as T
import torch.utils.data as data

import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv
import glob

# Quantum-inspired computer vision

Reimagining the 3 color channels as information for 8-state qudits (3 two-state channels give 2^3 = 8 basis states).

In [2]:
# Folder that holds the stock photos. The trailing slash matters: the glob
# pattern below is built by plain string concatenation.
root_directory = '/home/shiva/data/StockImages/'
# Collect every .jpg directly inside that folder (non-recursive).
image_paths = list(glob.iglob(root_directory + '*.jpg'))

In [3]:
# Preprocessing pipeline: tensor conversion first, then a fixed 256x256 resize
# with antialiasing switched off.
# NOTE(review): the visible cells build Dataset(root_directory) without passing
# this pipeline, so it may be unused -- confirm.
transforms = T.Compose(
    transforms=[
        T.ToTensor(),
        T.Resize(size=(256, 256), antialias=False),
    ]
)

In [4]:
class Dataset(data.Dataset):
    """Image dataset that returns every picture stacked with its complement.

    Each item is ``torch.stack((img, 1 - img), dim=0)`` where ``img`` is the
    transformed image, so index 0 along the new leading dimension is the image
    itself and index 1 is its complement.

    The dataset also precomputes ``self.bases``: all ``2 ** nchannels`` binary
    strings of length ``nchannels`` in counting order (most significant bit
    first), stored as an int32 tensor of shape ``(2 ** nchannels, nchannels)``.
    """

    # (width, height) handed to cv.resize for every loaded image.
    image_size = (256, 256)

    def __init__(self, root_directory, transforms=None, nchannels=3):
        """Collect the image paths and build the basis table.

        Args:
            root_directory: Path prefix that is concatenated with ``'*.jpg'``
                to form the glob pattern; include the trailing separator.
            transforms: Callable applied to the resized RGB array. ``None``
                (the default) means ``T.ToTensor()``.
            nchannels: Length of each binary basis string in ``self.bases``.
        """
        # glob yields paths in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        self.images = sorted(glob.glob(root_directory + '*.jpg'))
        # The default is resolved here rather than in the signature so that no
        # transform object is instantiated when the class is defined.
        self.transforms = T.ToTensor() if transforms is None else transforms
        self.nchannels = nchannels
        self.bases = self.__make_bases__()

    def __new_order_of_magnitude__(self, size):
        """Return a ``(2 * size, 1)`` column: ``size`` zeros, then ``size`` ones.

        This is the new most-significant-bit column used when the basis table
        doubles in height.
        """
        zeros = torch.zeros((size, 1))
        ones = torch.ones((size, 1))
        return torch.cat((zeros, ones), dim=0)

    def __make_bases__(self):
        """Build the table of all binary strings of length ``nchannels``.

        Starts from the single-bit table ``[[0], [1]]`` and, while the table has
        fewer than ``nchannels`` columns, duplicates the rows and prepends a
        new leading column of zeros (top half) and ones (bottom half).

        Returns:
            int32 tensor of shape ``(2 ** nchannels, nchannels)``. When
            ``nchannels <= 1`` the loop never runs and the ``(2, 1)`` start
            table is returned.
        """
        nstates = 2
        bases = torch.arange(nstates).reshape((nstates, 1))
        while bases.shape[1] < self.nchannels:
            leading_bit = self.__new_order_of_magnitude__(bases.shape[0])
            repeated = torch.cat((bases, bases), dim=0)
            bases = torch.cat((leading_bit, repeated), dim=1)
        return bases.to(torch.int)

    def __zeros_ones__(self, img):
        """Stack ``img`` (index 0) and ``1 - img`` (index 1) on a new dim 0."""
        return torch.stack((img, 1 - img), dim=0)

    def quantum_tensors(self, img):
        """Combine the image and its complement according to every basis row.

        For each row ``b`` of ``self.bases`` the result holds
        ``X[b[0]] + X[b[1]] + ... + X[b[-1]]`` with ``X = (img, 1 - img)``.
        Every entry of the row is used, so this works for any ``nchannels``
        (the three-entry case gives the same sum, in the same order, as the
        explicit ``b[0] + b[1] + b[2]`` form).

        Each term is the *whole* image or the *whole* complement; the basis
        entries do not select individual color channels.
        NOTE(review): per-channel selection may have been intended -- confirm.

        Args:
            img: Transformed image tensor.

        Returns:
            Tensor of shape ``(len(self.bases), *img.shape)``.
        """
        X = self.__zeros_ones__(img)
        combined = []
        # tolist() turns the basis table into plain Python ints for indexing.
        for basis in self.bases.tolist():
            total = X[basis[0]]
            # Accumulate left to right to keep the floating-point sum order.
            for state in basis[1:]:
                total = total + X[state]
            combined.append(total)
        return torch.stack(combined, dim=0)

    def __getitem__(self, index):
        """Load, convert, resize and transform one image.

        Returns:
            The transformed image stacked with its complement
            (see ``__zeros_ones__``).

        Raises:
            OSError: If OpenCV cannot read the file at ``self.images[index]``.
        """
        path = self.images[index]
        img = cv.imread(path)
        if img is None:
            # cv.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise OSError(f"Could not read image file: {path}")
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
        img = cv.resize(img, self.image_size)
        img = self.transforms(img)
        # quantum_tensors(img) is the alternative encoding; the stacked
        # image/complement pair is what is returned.
        return self.__zeros_ones__(img)

    def __len__(self):
        """Number of image files found by the glob pattern."""
        return len(self.images)

In [5]:
# Build the dataset with its default transform and draw a single batch to
# inspect its shape and value range.
SEED = 0  # fixes the shuffle order so the sampled batch is reproducible
dataset = Dataset(root_directory)
dataloader = data.DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,  # skip a trailing batch with fewer than 4 images
    generator=torch.Generator().manual_seed(SEED),
)
X = next(iter(dataloader))
print(X.shape, X.min(), X.max())

torch.Size([4, 2, 3, 256, 256]) tensor(0.) tensor(1.)


Looking to entangle states

Entangle 3 qubits per pixel

3D analog of the CX + H gate

Entangle pixels together

In [6]:
def _gather_states(batch, basis):
    """Stack ``batch[:, s]`` for every state index ``s`` in ``basis`` on dim 1."""
    return torch.stack([batch[:, s] for s in basis], dim=1)


# Basis rows 3 and 5 of the dataset's basis table.
b2, b4 = (dataset.bases[row] for row in (3, 5))
Xq2 = _gather_states(X, b2)
Xq4 = _gather_states(X, b4)

# The stacked tensors are compared element-wise, then their sums over dim 1.
print(torch.all(Xq2 == Xq4))
print(torch.all(Xq2.sum(dim=1) == Xq4.sum(dim=1)))

tensor(False)
tensor(True)


In [7]:
# Show the two basis rows used in the comparison above.
print(*(dataset.bases[row] for row in (3, 5)))

tensor([0, 1, 1], dtype=torch.int32) tensor([1, 0, 1], dtype=torch.int32)
