In [4]:
import torch
import time
import math
import gc #force garbage collection if necessary

In [5]:
#find GPU device
if torch.cuda.is_available():
    #if Nvidia GPU available, use it
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    #if apple silicon available
    device = torch.device('mps')
else:
    #else just go with CPU
    device = torch.device('cpu')

print(f"Program using {device}")

Program using mps


In [6]:
class Gate:
    
    def __init__(self, device, custom=None):
        self.device = device
        self.custom = custom
        self.H = torch.tensor([[1., 1.],
                               [1., -1.]], device=self.device, dtype=torch.cfloat) / torch.sqrt(torch.tensor(2))
    
        self.I = torch.tensor([[1., 0.],
                               [0., 1.]], device=self.device, dtype=torch.cfloat)

        self.S = torch.tensor([[1., 0.],
                               [0., 1.j]], device=self.device, dtype=torch.cfloat)

        #square root of X gate
        self.SX= torch.tensor([[1.+1.j, 1.-1.j],
                               [1.-1.j, 1.+1.j]], device=self.device, dtype=torch.cfloat) / 2

        #Phase gate T uses class phase shift method (fourth root of Z)
        self.T = self.P(torch.pi/4)
        
        self.X = torch.tensor([[0., 1.],
                               [1., 0.]], device=self.device, dtype=torch.cfloat)

        self.Y = torch.tensor([[0., -1.j],
                               [1.j, 0.]], device=self.device, dtype=torch.cfloat)

        self.Z = torch.tensor([[1., 0.],
                               [0., -1.]], device=self.device, dtype=torch.cfloat)

    #---Single qubit matrix operator---#
    '''
    Routine to apply given single-qubit 2x2 unitary matrix operator on a target 
    qubit index in some given states vector (representing the wavefunction) if the 
    number of qubits is greater than 16
    As MPS cannot currently support tensors of rank >16, this is required to provide
    a small speedup using MPS over CPU
    '''

    def apply_mps(self, gate, states, target):
        N = states.numel()
        if target < 0:
            raise ValueError("Target cannot have negative index!")
        if N == 0:
            raise ValueError("State vector cannot be empty!")
        if (N & (N - 1)) != 0:
            raise ValueError("State vector length must be a power of two")
        n = int(math.log2(N))
        if 2**(target + 1) > N:
            raise ValueError("Target index cannot exceed circuit!")
        if gate.shape != torch.Size([2, 2]):
            raise ValueError("Provided gate is not 2x2 unitary!")
            
        #create mask to identify target qubit
        mask = 1 << target
        #range indices of wavefunction amplitudes
        indices = torch.arange(states.numel())
        #identify indices adjacent to target qubit index
        i = indices[(indices & mask)==0]
        j = i ^ mask
        #apply operator
        a0, a1 = states[i], states[j]
        states[i] = gate[0,0] * a0 + gate[0,1] * a1
        states[j] = gate[1,0] * a0 + gate[1,1] * a1
        return states

    '''
    General apply routine for 2x2 unitary matrix operator on target qubit
        '''
    def apply(self, gate, state, target):
        """
        High-performance strided single-qubit gate application.
        Works on CPU, CUDA, and MPS.
        Avoids high-rank tensors (MPS limit), avoids masked indexing (CUDA poison).
        """
    
        N = state.numel()
        if target < 0:
            raise ValueError("Target cannot have negative index")
        if N & (N - 1):
            raise ValueError("State vector length must be power of two")
    
        # stride between alternating |0> and |1> amplitudes of this target qubit
        stride = 1 << target
        block = stride << 1  # = 2 * stride
    
        # number of blocks
        num_blocks = N // block
    
        # reshape into (num_blocks, stride) but flattened because we want 1D layout
        # we will manually create views using strides
        # indexes:
        #   base + offset gives |0>
        #   base + offset + stride gives |1>
    
        # we create views with appropriate strides â€” no memory copies
        s = state  # alias
    
        # a0: all |0> amplitudes for this target qubit
        a0 = s.as_strided(
            size=(num_blocks, stride),
            stride=(block, 1)
        )
    
        # a1: all |1> amplitudes for this target qubit
        a1 = s.as_strided(
            size=(num_blocks, stride),
            stride=(block, 1),
            storage_offset=stride
        )
    
        # Stack into (num_blocks * stride, 2) column vectors:
        # But we want shape (2, M) to matmul with (2,2)
        # So we do the transpose
        M = num_blocks * stride
    
        v0 = a0.reshape(M)
        v1 = a1.reshape(M)
    
        # Build a stacked (2, M) matrix WITHOUT allocating huge tensors
        V = torch.stack((v0, v1), dim=0)  # shape (2, M)
    
        # Apply the 2x2 gate: (2,2) @ (2,M) -> (2,M)
        out = gate @ V
    
        # Write results back in-place
        v0.copy_(out[0])
        v1.copy_(out[1])
    
        return state
    #---Phase transforms---#
    
    '''
    Ph (Global phase transform) takes a real-valued angle argument
    representing a complex phase rotation in the Block sphere
    The phase transform rotates the qubit state without changing its
    probability of collapse to |0) or |1)
    '''
    def Ph(self, phase):
        phase = phase % (2*torch.pi) #modulus 2*pi of input
        Ph_angle = torch.tensor([0. + 1j*(phase)], device=self.device, dtype=torch.cfloat)
        Ph_angle = torch.exp(Ph_angle)
        return Ph_angle * self.I
    '''
    Bloch sphere rotation around the x axis
    '''
    def Rx(self, angle):
        angle = torch.tensor(angle, device=self.device)
        #form sin and cos components
        cos_comp = torch.cos((angle/2))*self.I
        sin_comp = 1j*torch.sin((angle/2))*self.X
        #form complete matrix
        Rx_rot = cos_comp - sin_comp
        return Rx_rot
    '''
    Bloch sphere rotation around the y axis
    '''
    def Ry(self, angle):
        angle = torch.tensor(angle, device=self.device)
        #form sin and cos components
        cos_comp = torch.cos((angle/2))*self.I
        sin_comp = 1j*torch.sin((angle/2))*self.Y
        #form complete matrix
        Ry_rot = cos_comp - sin_comp
        return Ry_rot
    '''
    Bloch sphere rotation around the z axis
    '''
    def Rz(self, angle):
        angle = torch.tensor(angle, device=self.device)
        #form sin and cos components
        cos_comp = torch.cos(angle/2)*self.I
        sin_comp = 1j*torch.sin(angle/2)*self.Z
        #form complete matrix
        Rz_rot = cos_comp - sin_comp
        return Rz_rot
    '''
    Universal Bloch sphere rotation
    U(theta, phi, lambda)
    '''
    def U(self, theta, phi, lam):
        #convert angles to tensors
        theta_t = torch.tensor(theta/2, device=self.device)
        phi_t = torch.tensor(phi, device=self.device)
        lam_t = torch.tensor(lam, device=self.device)
        #form tensor using closed-form definition
        U_rot = torch.tensor(
            [[  torch.cos(theta_t),
                -1*torch.exp(1.j*lam_t)*torch.sin(theta_t)
             ],
             [  torch.exp(1.j*phi_t)*torch.sin(theta_t),
                torch.exp(1.j*(lam_t+phi_t))*torch.cos(theta_t)
             ]], device=self.device)
        return U_rot

    #---Controlled gates---#
    '''
    These sets of operators are syntactically different than
    2x2 and rotation operators as they must extend between
    a control qubit and a target qubit
    For a n-qubit system with target 4 and control 0, the
    operator must extend across 5 qubits! Furthermore, the
    matrix created for an n-qubit system holds 2^n x 2^n values
    Since for all practical purposes this is vastly inefficient,
    these routines use classical bitwise analogs instead of matrices
    '''
    def CNOT(self, states, control, target):
        C = 1 << control
        T = 1 << target
        N = len(states)
        if (control < 0) or (target < 0):
            raise ValueError("Indices cannot be negative")
        elif (2**(control+1) > N) or (2**(target+1) > N):
            raise ValueError("Control and target qubits must be in range")
        
        #create indices tensor
        indices = torch.arange(N)
        #perform XOR on target and state indices
        i = indices[(indices & C)!=0]
        j = i^T
        #swap states
        tmp = states[i].clone()
        states[i] = states[j]
        states[j] = tmp
        return states

    #---Non-clifford gates---#
    '''
    These are the set of non-universal quantum gates
    Since these are 2x2, they can still be applied using the apply() function
    '''
    
    '''
    This a single amplitude phase shift routine
    '''
    def P(self, phi):
        #convert angle to torch tensor and exponentiate
        phi_t = torch.tensor(phi, device=self.device)
        phi_t = torch.exp(1j*phi_t)
        #form matrix as tensor
        P_rot = torch.tensor([
            [1, 0],
            [0, phi_t]
        ], dtype=torch.cfloat, device=self.device)
        return P_rot
    

In [7]:
class circuit:
    def __init__(self, size):
        
        #constructor takes size parameter describing number of qubits in circuit
        self.size = size
        

In [14]:
#test cases
gate = Gate(device=device) #create gate object on mps
gate_cpu = Gate(device=torch.device('cpu'))
n = 26 #test with n qubits

t0_gpu = time.time()
x = torch.zeros(2**n, dtype=torch.cfloat, device=device)
x[0] = 1
for i in range(n):
    x = gate.apply(gate.H, x, i)
    x = gate.apply(gate.Y, x, i)
    x = gate.apply(gate.Rx(torch.pi/2), x, i)
t1_gpu = time.time()
print(f"MPS accelerated took {t1_gpu - t0_gpu} s")

t0_cpu = time.time()
y = torch.zeros(2**n, dtype=torch.cfloat, device=torch.device('cpu'))
y[0] = 1
for i in range(n):
    y = gate_cpu.apply(gate_cpu.H, y, i)
    y = gate_cpu.apply(gate_cpu.Y, y, i)
    y = gate_cpu.apply(gate_cpu.Rx(torch.pi/2), y, i)
t1_cpu = time.time()
print(f"CPU took {t1_cpu - t0_cpu} s")
#print(x)
save_states(x, "saved_states/states.bin")
print(import_states("saved_states/states.bin", device=device)[0:10])
print(gpu_mem_usage())

MPS accelerated took 12.089386940002441 s
CPU took 40.67214608192444 s
Saved wavefunction states to saved_states/states.bin
tensor([0.-0.5000j, 0.+0.5000j, 0.+0.0000j, 0.+0.0000j, 0.+0.0000j, 0.+0.0000j,
        0.+0.0000j, 0.+0.0000j, 0.+0.0000j, 0.+0.0000j], device='mps:0')
512.001953125


In [8]:
def gpu_mem_usage():
    if torch.cuda.is_available():
        mem_used = torch.cuda.memory_allocated()
    elif torch.backends.mps.is_available():
        mem_used = torch.mps.current_allocated_memory()
    else:
        print("Cannot show memory usage for CPU")
        return
    mem_used /= 1024**2
    return mem_used

In [9]:
'''
Save torch tensor directly to binary file
'''
def save_states(states, filename):
    states.cpu().numpy().tofile(filename)
    print(f"Saved wavefunction states to {filename}")

In [10]:
'''
Read raw binary directly into torch tensor on device
'''
def import_states(filename, dtype=torch.cfloat, device='cpu'):
    with open(filename, "rb") as f:
        buf = f.read()

    tensor = torch.frombuffer(buf, dtype=dtype)
    return tensor.to(device)