In [1]:
import torch
import time
import math
import gc #force garbage collection if necessary

In [2]:
#find GPU device
if torch.cuda.is_available():
    #if Nvidia GPU available, use it
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    #if apple silicon available
    device = torch.device('mps')
else:
    #else just go with CPU
    device = torch.device('cpu')
print(f"Program using {device}")

torch.set_num_threads(8)
torch.set_num_interop_threads(8)

Program using mps


In [3]:
class Gate:
    
    def __init__(self, device, custom=None):
        self.device = device
        self.custom = custom

        self.zero = torch.tensor([[1., 0.],
                               [0., 0.]], device=self.device, dtype=torch.cfloat)
        
        self.one = torch.tensor([[0., 0.],
                               [0., 1.]], device=self.device, dtype=torch.cfloat)       
        
        self.H = torch.tensor([[1., 1.],
                               [1., -1.]], device=self.device, dtype=torch.cfloat) / torch.sqrt(torch.tensor(2))
    
        self.I = torch.tensor([[1., 0.],
                               [0., 1.]], device=self.device, dtype=torch.cfloat)

        self.S = torch.tensor([[1., 0.],
                               [0., 1.j]], device=self.device, dtype=torch.cfloat)

        #square root of X gate
        self.SX= torch.tensor([[1.+1.j, 1.-1.j],
                               [1.-1.j, 1.+1.j]], device=self.device, dtype=torch.cfloat) / 2

        #Phase gate T uses class phase shift method (fourth root of Z)
        self.T = self.P(torch.pi/4)
        
        self.X = torch.tensor([[0., 1.],
                               [1., 0.]], device=self.device, dtype=torch.cfloat)

        self.Y = torch.tensor([[0., -1.j],
                               [1.j, 0.]], device=self.device, dtype=torch.cfloat)

        self.Z = torch.tensor([[1., 0.],
                               [0., -1.]], device=self.device, dtype=torch.cfloat)

    #---Single qubit matrix operator---#

    def apply(self, gate, state, target):
        """
        High-performance strided single-qubit gate application.
        Works on CPU, CUDA, and MPS.
        Avoids high-rank tensors (MPS limit)
        """
    
        N = state.numel()
        if target < 0:
            raise ValueError("Target cannot have negative index")
        if (target + 1) > math.log2(N):
            raise ValueError("Target is out of range")
        if N & (N - 1):
            raise ValueError("State vector length must be power of two")
    
           # bit mask for the target qubit
        bit = 1 << target
    
        # generate all base indices where target bit = 0
        idx0 = torch.arange(N, device=state.device)
        idx0 = idx0[(idx0 & bit) == 0]
        idx1 = idx0 | bit       # matches the |1> amplitudes
    
        # extract paired amplitude vectors
        v0 = state[idx0]
        v1 = state[idx1]
    
        # apply the 2x2 operator
        out0 = gate[0,0] * v0 + gate[0,1] * v1
        out1 = gate[1,0] * v0 + gate[1,1] * v1
    
        # scatter back
        state[idx0] = out0
        state[idx1] = out1
        
    #---Phase transforms---#
    
    '''
    Ph (Global phase transform) takes a real-valued angle argument
    representing a complex phase rotation in the Bloch sphere
    The phase transform rotates the qubit state without changing its
    probability of collapse to |0) or |1)
    '''
    def Ph(self, phase):
        phase = phase % (2*torch.pi) #modulus 2*pi of input
        Ph_angle = torch.tensor([0. + 1j*(phase)], device=self.device, dtype=torch.cfloat)
        Ph_angle = torch.exp(Ph_angle)
        return Ph_angle * self.I
    '''
    Bloch sphere rotation around the x axis
    '''
    def Rx(self, *angle):
        angle = torch.tensor(angle, device=self.device)
        #form sin and cos components
        cos_comp = torch.cos((angle/2))*self.I
        sin_comp = 1j*torch.sin((angle/2))*self.X
        #form complete matrix
        Rx_rot = cos_comp - sin_comp
        return Rx_rot
    '''
    Bloch sphere rotation around the y axis
    '''
    def Ry(self, *angle):
        angle = torch.tensor(angle, device=self.device)
        #form sin and cos components
        cos_comp = torch.cos((angle/2))*self.I
        sin_comp = 1j*torch.sin((angle/2))*self.Y
        #form complete matrix
        Ry_rot = cos_comp - sin_comp
        return Ry_rot
    '''
    Bloch sphere rotation around the z axis
    '''
    def Rz(self, *angle):
        angle = torch.flatten(torch.tensor(angle, device=self.device))
        #form sin and cos components
        cos_comp = torch.cos(angle/2)*self.I
        sin_comp = 1j*torch.sin(angle/2)*self.Z
        #form complete matrix
        Rz_rot = cos_comp - sin_comp
        return Rz_rot
    '''
    Universal Bloch sphere rotation
    U(theta, phi, lambda)
    '''
    def U(self, *angles):
        theta_t, phi_t, lam_t = torch.tensor(angles, device=self.device)
        theta_t /= 2 #divide theta by two for closed matrix form

        #form tensor using closed-form definition
        U_rot = torch.tensor(
            [[  torch.cos(theta_t),
                -1*torch.exp(1.j*lam_t)*torch.sin(theta_t)
             ],
             [  torch.exp(1.j*phi_t)*torch.sin(theta_t),
                torch.exp(1.j*(lam_t+phi_t))*torch.cos(theta_t)
             ]], device=self.device)
        return U_rot

    #---Controlled gates---#
    '''
    These sets of operators are syntactically different than
    2x2 and rotation operators as they must extend between
    a control qubit and a target qubit
    For a n-qubit system with target 4 and control 0, the
    operator must extend across 5 qubits! Furthermore, the
    matrix created for an n-qubit system holds 2^n x 2^n values
    Since for all practical purposes this is vastly inefficient,
    these routines use classical bitwise analogs instead of matrices
    '''
    def CNOT(self, states, control, target):
        '''
        *args = [states, control, target]
        '''
        
        C = 1 << control
        T = 1 << target
        N = len(states)
        if (control < 0) or (target < 0):
            raise ValueError("Indices cannot be negative")
        elif (2**(control+1) > N) or (2**(target+1) > N):
            raise ValueError("Control and target qubits must be in range")
        
        #create indices tensor
        indices = torch.arange(N)
        #perform XOR on target and state indices
        i = indices[(indices & C)!=0]
        j = i^T
        #swap states
        tmp = states[i].clone()
        states[i] = states[j]
        states[j] = tmp
        return states

    #---Non-clifford gates---#
    '''
    These are the set of non-universal quantum gates
    Since these are 2x2, they can still be applied using the apply() function
    '''

    def P(self, phi):
        '''
        This a single amplitude phase shift routine
        '''
        #convert angle to torch tensor and exponentiate
        phi_t = torch.tensor(phi, device=self.device)
        phi_t = torch.exp(1j*phi_t)
        #form matrix as tensor
        P_rot = torch.tensor([
            [1, 0],
            [0, phi_t]
        ], dtype=torch.cfloat, device=self.device)
        return P_rot
        
    #---Measurement---#
    def MCM(self, states, target):
        '''
        Perform a mid-circuit measurement
        Probabilistically samples state |0) or |1) for
        the target qubit, collapsing corresponding amplitudes
        '''
        N = states.numel()
        if (target+1) > math.log2(N):
            raise ValueError("Target is out of range")
        bit_1 = 1 << target
        
        #create indices vector to browse
        indices = torch.arange(N, device=states.device)
        mask = ((indices & bit_1)!=0)
        
        #find probability of measuring 1 for target qubit
        p1 = states[mask].abs().pow(2).sum()

        #define sampling probability and 
        #measure 1 if rand < p1 else measure 0
        if torch.rand(()) < p1:
            self.apply(self.one, states, target)
            states /= torch.sqrt(p1)
        else:
            self.apply(self.zero, states, target)
            states /= torch.sqrt(1-p1)

In [4]:
class Circuit:
    
    def __init__(self, size, device=None):
        
        #constructor takes size parameter describing number of qubits in circuit
        self.size = size
        self.device = device if device else self.get_user_device()
        self.first_mcm = None #None if no mcm, otherwise first index
        self.Gate = Gate(device=device) #create a reference gate object using the circuit's device
        #create a gate reference dict for easy application
        self.gates = {
                        '1':    self.Gate.one,
                        '0':    self.Gate.zero,
                        'CNOT': self.Gate.CNOT,
                        'H':    self.Gate.H,
                        'I':    self.Gate.I,
                        'MCM':  self.Gate.MCM,
                        'P':    self.Gate.P,
                        'PH':   self.Gate.Ph,
                        'RX':   self.Gate.Rx,
                        'RY':   self.Gate.Ry,
                        'Rz':   self.Gate.Rz,
                        'S':    self.Gate.S,
                        'SX':   self.Gate.SX,
                        'T':    self.Gate.T,
                        'U':    self.Gate.U,
                        'X':    self.Gate.X,
                        'Y':    self.Gate.Y,
                        'Z':    self.Gate.Z,
                      }
        self.states = [] #states array for circuit amplitudes
        #create a circuit array to store ordered operations on the qubits
        #stores each as a gate, and parameters: {gate_matrix : {metadata}}
        self.circuit = []
        self.measurements = {}
    
    def add(self, gate, target, *angles, control=None):
        '''
        Add a gate to the circuit construction dict
        Gate is added as a 2x2 or 4x4 matrix with metadata
        Metadata is a dict with gate, target, control, phases,
        i.e. {'H':{'gate':2x2, 'target':0}} 
             {'Ph':{'gate':2x2, 'target':1, 'phases'=[pi]}}
             {'CNOT':{'gate':4x4, 'control':0, 'target':1}}
        Where 2x2 or 4x4 is a tensor matrix
        '''
        #finds necessary kwargs for gate contruction
        kwargs = {}
        
        if control is not None:
            kwargs['control'] = control
            meta['control'] = control
        if angles is not None:
            kwargs['angles'] = angles
            meta['control'] = angles
            
        added_gate = self.gates[gate](self.states, target, **kwargs)
        meta = {
                'gate':added_gate,
                'target': target
               } #gate metadata
        
        self.circuit.append({gate:metadata})
        
    def execute(self, states, shots):
        '''
        Applies gates in circuit from start to measurement
         - Runs circuit n times for mid-circuit measurements
         - Samples n shots from wavefunction probability distribution otherwise
         - Exports counts from measurements to an amplitude dictionary
        '''
        if self.first_mcm:
            
    def first_mcm_idx(self):
        '''
        Returns first index of an mcm operator in the circuit list
        '''
        index = next(
            i for i, d in enumerate(x)
            for v in d.values()
            if v.get('MCM') == 1
        )       
        self.first_mcm = index
        
    def measure(self, states, shots):
        '''
        Measurement routine to sample probability state vector
        shots times.
        '''
        #get probability vector
        probs = self.probs(states)
        #form cumulative probability distribution vector
        cdf = probs.cumsum(dim=0)
        #generate n random numbers in (0,1)
        r = torch.rand(shots, device=states.device)
        #batch binary search cdf for indices of cumulative probability = r
        measurements = torch.searchsorted(cdf, r)
        return measurements
    
    def probs(self, states):
        return states.abs().pow(2)

    def run_circuit(self):
        '''
        Run through the circuit one time to obtain final state amplitudes
        Applies each gate in self.circuit sequentially
        '''
        

        
    def set_size(self, size):
        '''
        Change circuit qubit count
        '''
        self.size = size
    
    #---configure user device---#
    
    def get_user_device(threads):
        #find GPU device
        if torch.cuda.is_available():
            #if Nvidia GPU available, use it
            device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            #if apple silicon available
            device = torch.device('mps')
        else:
            #else just go with CPU
            device = torch.device('cpu')
        print(f"Program using {device}")
        return device

    def change_device(self, new_device):
        self.device = new_device
        #update gate references to new device
        self.Gate = Gate(device=new_device)
    
circuit = Circuit(2)

IndentationError: expected an indented block after 'if' statement on line 72 (2351536367.py, line 74)

In [5]:
def gpu_mem_usage():
    if torch.cuda.is_available():
        mem_used = torch.cuda.memory_allocated()
    elif torch.backends.mps.is_available():
        mem_used = torch.mps.current_allocated_memory()
    else:
        print("Cannot show memory usage for CPU")
        return
    mem_used /= 1024**2
    return mem_used

In [6]:
'''
Save torch tensor directly to binary file
'''
def save_states(states, filename):
    states.cpu().numpy().tofile(filename)
    print(f"Saved wavefunction states to {filename}")

In [14]:
'''
Read raw binary directly into torch tensor on device
'''
def import_states(filename, dtype=torch.cfloat, device='cpu'):
    with open(filename, "rb") as f:
        buf = f.read()

    tensor = torch.frombuffer(buf, dtype=dtype)
    return tensor.to(device)

In [15]:
#test cases
gate = Gate(device=device) #create gate object on mps
gate_cpu = Gate(device=torch.device('cpu'))
n = 26 #test with n qubits

t0_gpu = time.time()
x = torch.zeros(2**n, dtype=torch.cfloat, device=device)
x[0] = 1
for i in range(n):
    gate.apply(gate.H, x, i)
    gate.MCM(x, i)

t1_gpu = time.time()
print(f"GPU accelerated took {t1_gpu - t0_gpu} s")
t0_cpu = time.time()
y = torch.zeros(2**n, dtype=torch.cfloat, device=torch.device('cpu'))
y[0] = 1
for i in range(n):
    gate_cpu.apply(gate_cpu.H, y, i)
    gate_cpu.MCM(y,i)

t1_cpu = time.time()
print(f"CPU took {t1_cpu - t0_cpu} s")
#save_states(x, "saved_states/states.bin")
#print(import_states("saved_states/states.bin", device=device)[0:10])
print(gpu_mem_usage())

GPU accelerated took 31.557900190353394 s
CPU took 45.65886068344116 s
512.0048828125
