In [1]:
import time

import matplotlib.pyplot as plt
import numpy as np
import torch

from datatypes import torchcomplex, show
from gate_implementation import apply_gate, apply_sparse_gate, device

In [2]:
def test_correctness():
    """Check apply_gate and apply_sparse_gate on a hand-computed 3-qubit example.

    Applies the two-qubit gate kron(X, Z) to qubits [1, 2] of the state
    psi0 = [1, 2, ..., 8] and compares the result with a hand-computed vector.
    The same gate is given once as a dense matrix and once in sparse
    ((row, col) pairs, values) form; both must reproduce the expected state.

    Raises:
        AssertionError: if either implementation disagrees with the expected state.
    """
    psi0 = torch.arange(8, device=device) + 1
    # Hand-computed result of applying kron(X, Z) to qubits [1, 2] of psi0.
    expected = torch.tensor([3, -4, 1, -2, 7, -8, 5, -6]).to(torch.float64).cpu()

    X = np.array([[0, 1], [1, 0]])
    Z = np.array([[1, 0], [0, -1]])
    XZ = torchcomplex(torch.tensor(np.kron(X, Z)).to(torch.float64)).to(device)
    psi1 = apply_gate([1, 2], XZ, psi0)
    np.testing.assert_allclose(psi1.cpu(), expected, atol=0.00001)
    print("Correctness test passed for dense")

    # The same gate as its nonzero entries: (row, col) index pairs plus values.
    # The values are moved to `device` like the dense matrix above, so the sparse
    # path is tested with all inputs on one device (matters when device is CUDA).
    XZ_sparse = (
        [(0, 2), (1, 3), (2, 0), (3, 1)],
        torchcomplex(torch.tensor([1.0, -1.0, 1.0, -1.0])).to(device),
    )
    psi2 = apply_sparse_gate([1, 2], XZ_sparse, psi0)
    np.testing.assert_allclose(psi2.cpu(), expected, atol=0.00001)
    print("Correctness test passed for sparse")

In [3]:
class PerfTest:
    """Benchmarks for applying a k-qubit gate to an n-qubit state vector.

    Dense benchmarks use the k-fold Kronecker power of the Hadamard matrix with
    ``apply_gate``. Sparse benchmarks use X^k (a bit flip on every target qubit)
    as an ``(indices, values)`` pair with ``apply_sparse_gate``. Tensors are
    created on the module-level ``device``.
    """

    @staticmethod
    def time(n, positions, sparse=False):
        """Time one gate application, print a one-line report, and return seconds.

        Args:
            n: number of qubits of the state; the state vector has 2**n entries.
            positions: list of target qubit positions, or an int k meaning the
                first k qubits, i.e. ``list(range(k))``.
            sparse: if True, time ``apply_sparse_gate`` with X^k; otherwise time
                ``apply_gate`` with the dense Hadamard power H^k.

        Returns:
            Elapsed time of the single gate application, in seconds.
        """
        if isinstance(positions, int):
            positions = list(range(positions))

        k = len(positions)
        if sparse:
            matrix = PerfTest.get_k_qubit_sparse_Xk(k)
        else:
            matrix = PerfTest.get_k_qubit_dense_Hadamard(k)
        psi0 = PerfTest.get_state(n)

        if device.type == "cuda":
            # CUDA work is asynchronous, so a host timer would mostly measure kernel
            # launches; record events on the stream and read the GPU-side interval.
            # https://discuss.pytorch.org/t/how-to-measure-time-in-pytorch/26964/10
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            PerfTest.apply_gate_to_qubits(positions, matrix, psi0, sparse=sparse)
            end.record()
            torch.cuda.synchronize()  # `end` must have completed before reading it
            s = start.elapsed_time(end) / 1000  # milliseconds -> seconds
        else:
            # perf_counter is monotonic and high-resolution; time.time() is neither.
            t0 = time.perf_counter()
            PerfTest.apply_gate_to_qubits(positions, matrix, psi0, sparse=sparse)
            s = time.perf_counter() - t0

        if sparse:
            print(f"{s:.4f} s for sparse {k} qubit X^k gate on {n} qubit state")
        else:
            print(f"{s:.4f} s for dense {k} qubit Hadamard gate on {n} qubit state")

        return s

    @staticmethod
    def apply_gate_to_qubits(positions, matrix, psi0, sparse=False):
        """Dispatch to apply_sparse_gate or apply_gate and return its result."""
        if sparse:
            return apply_sparse_gate(positions, matrix, psi0)
        else:
            return apply_gate(positions, matrix, psi0)

    @staticmethod
    def get_state(n):
        """Return the unnormalized test vector [0, 1, ..., 2**n - 1] on `device`."""
        return torch.arange(2**n, device=device)

    @staticmethod
    def get_k_qubit_dense_Hadamard(k):
        """Return the dense 2**k x 2**k matrix H^(x)k (k-fold Kronecker power of H).

        k = 0 yields the 1x1 matrix [[1]].
        """
        # np.sqrt, not torch.sqrt: torch.sqrt only accepts Tensors and raises on an int.
        H = torchcomplex(torch.tensor([[1.0, 1.0], [1.0, -1.0]]) / np.sqrt(2)).to(device)
        H_k = torchcomplex(torch.tensor([[1.0]])).to(device)
        for _ in range(k):
            H_k = torch.kron(H_k, H)
        return H_k

    @staticmethod
    def get_k_qubit_sparse_Xk(k):
        """Return X^(x)k in sparse form as ``(indices, values)``.

        X^(x)k flips every bit of a k-bit basis index, so row i has a single 1 in
        column (2**k - 1) - i.

        Returns:
            indices: int64 tensor of shape (2**k, 2) holding (row, col) pairs.
            values: 2**k ones, converted with ``torchcomplex`` like the sparse
                values used in ``test_correctness``.
        """
        K = 2**k
        rows = torch.arange(K, device=device)
        # Column indices must stay in [0, K): (K - 1) - i, not K - i.
        cols = (K - 1) - rows
        values = torchcomplex(torch.ones(K)).to(device)
        return torch.stack([rows, cols], dim=1), values

In [4]:
# Sanity-check the dense and sparse gate implementations before benchmarking them;
# test_correctness raises AssertionError if either result is wrong.
test_correctness()

Correctness test passed for dense
Correctness test passed for sparse


In [5]:
# Seconds per gate application, keyed by k (number of target qubits), on a 20-qubit state.
# Entries are written one k at a time, so a failure mid-sweep keeps the earlier timings.
densetimes = {}
sparsetimes = {}

# Dense sweep: k = 1..11 (H^k is a 2**k x 2**k matrix, presumably why it stops earlier).
for k in range(1, 12):
    densetimes[k] = PerfTest.time(20, k, sparse=False)

# Sparse sweep: k = 1..18 with the X^k gate in (indices, values) form.
for k in range(1, 19):
    sparsetimes[k] = PerfTest.time(20, k, sparse=True)

AttributeError: 'numpy.ndarray' object has no attribute 'is_complex'

In [None]:
# Re-time the dense k = 11 case on its own and overwrite densetimes[11].
# NOTE(review): the dense sweep above already covers k = 11 (range(1, 12)); this looks like
# a leftover from recovering after that cell raised — consider removing it.
densetimes[11]=PerfTest.time(20,11)

In [None]:
# Log-scale comparison of dense vs. sparse gate timings (densetimes / sparsetimes: k -> s).
# Keep the `fig` handle: by default the inline backend closes pyplot's current figure
# when the cell ends, so a later bare plt.savefig would save an empty figure.
fig, ax = plt.subplots()

ax.set_xlabel("k")
ax.set_ylabel("Time (s)")
ax.set_yscale("log")
# Horizontal reference lines at 1 minute and 1 second.
ax.axhline(60, lw=.5, ls=':', color='gray', label="1 min")
ax.axhline(1, lw=.5, color='gray', label="1 second")
# NOTE(review): unlabeled vertical marker at k = 10; its meaning isn't stated here — label or explain it.
ax.axvline(10, lw=.5, ls=':', color='gray')
ax.set_xticks(list(range(1, 21)))

# Materialize the dict views as plain lists (dicts keep insertion order, i.e. ascending k).
ax.plot(list(densetimes), list(densetimes.values()), '-o', label="dense Hadamard gate H^k")
ax.plot(list(sparsetimes), list(sparsetimes.values()), '--s', label="sparse X^k gate")

ax.legend()
ax.set_title("Time to apply k-qubit gate to 20-qubit state");

In [None]:
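# To save the plot, call savefig on the figure object (or in the plotting cell itself):
# by default Jupyter's inline backend closes pyplot figures when a cell finishes, so a bare
# plt.savefig in this separate cell would write a blank figure.
#plt.savefig("example_outputs/many_qubit_gate_perftest.pdf")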