In [2]:
import torch
import time
import math
import gc #force garbage collection if necessary

In [59]:
class Gate:
    
    def __init__(self, device, custom=None):
        self.device = device
        self.custom = custom

        self.zero = torch.tensor([[1., 0.],
                               [0., 0.]], device=self.device, dtype=torch.cfloat)
        
        self.one = torch.tensor([[0., 0.],
                               [0., 1.]], device=self.device, dtype=torch.cfloat)       
        
        self.H = torch.tensor([[1., 1.],
                               [1., -1.]], device=self.device, dtype=torch.cfloat) / torch.sqrt(torch.tensor(2))
    
        self.I = torch.tensor([[1.+0.j, 0.+0.j],
                               [0.+0.j, 1.+0.j]], device=self.device, dtype=torch.cfloat)

        self.S = torch.tensor([[1., 0.],
                               [0., 1.j]], device=self.device, dtype=torch.cfloat)

        #square root of X gate
        self.SX= torch.tensor([[1.+1.j, 1.-1.j],
                               [1.-1.j, 1.+1.j]], device=self.device, dtype=torch.cfloat) / 2

        #Phase gate T uses class phase shift method (fourth root of Z)
        self.T = self.P(torch.pi/4)
        
        self.X = torch.tensor([[0., 1.],
                               [1., 0.]], device=self.device, dtype=torch.cfloat)

        self.Y = torch.tensor([[0., -1.j],
                               [1.j, 0.]], device=self.device, dtype=torch.cfloat)

        self.Z = torch.tensor([[1., 0.],
                               [0., -1.]], device=self.device, dtype=torch.cfloat)

    #---Single qubit matrix operator---#

    def apply(self, gate, state, target, controls=None):
        """
        High-performance strided single-qubit gate application.
        Works on CPU, CUDA, and MPS.
        Takes (optional) list of control qubits for gate
        Avoids high-rank tensors (MPS limit)
        """
    
        N = state.numel()
        if target < 0:
            raise ValueError("Target cannot have negative index")
        if (target + 1) > math.log2(N):
            raise ValueError("Target is out of range")
        if N & (N - 1):
            raise ValueError("State vector length must be power of two")
        
        # bit mask for the target qubit
        t_bit = 1 << target
        
        # generate all base indices where target bit = 0
        idx0 = torch.arange(N, device=state.device)
        idx0 = idx0[(idx0 & t_bit) == 0]
        c_mask = 0
        if controls:
            c_mask = sum(1<<c for c in controls)
            idx0 = idx0[(idx0 & c_mask)!=0] #if there are controls, match only |1> control amplitudes
        idx1 = idx0 | t_bit       #matches the |1> amplitudes of targets
        # extract paired amplitude vectors
        v0 = state[idx0]
        v1 = state[idx1]
        # apply the 2x2 operator
        out0 = gate[0,0] * v0 + gate[0,1] * v1
        out1 = gate[1,0] * v0 + gate[1,1] * v1
        # scatter back
        state[idx0] = out0
        state[idx1] = out1
        
    #---Phase transforms---#
    
    '''
    Ph (Global phase transform) takes a real-valued angle argument
    representing a complex phase rotation in the Bloch sphere
    The phase transform rotates the qubit state without changing its
    probability of collapse to |0) or |1)
    '''
    def Ph(self, phase):
        phase = phase % (2*torch.pi) #modulus 2*pi of input
        Ph_angle = torch.tensor([0. + 1j*(phase)], device=self.device, dtype=torch.cfloat)
        Ph_angle = torch.exp(Ph_angle)
        return Ph_angle * self.I
    '''
    Bloch sphere rotation around the x axis
    '''
    def Rx(self, angle):
        angle = torch.tensor(angle, device=self.device)
        #form sin and cos components
        cos_comp = torch.cos((angle/2))*self.I
        sin_comp = 1j*torch.sin((angle/2))*self.X
        #form complete matrix
        Rx_rot = cos_comp - sin_comp
        return Rx_rot
    '''
    Bloch sphere rotation around the y axis
    '''
    def Ry(self, angle):
        angle = torch.tensor(angle, device=self.device)
        #form sin and cos components
        cos_comp = torch.cos((angle/2))*self.I
        sin_comp = 1j*torch.sin((angle/2))*self.Y
        #form complete matrix
        Ry_rot = cos_comp - sin_comp
        return Ry_rot
    '''
    Bloch sphere rotation around the z axis
    '''
    def Rz(self, angle):
        angle = torch.flatten(torch.tensor(angle, device=self.device))
        #form sin and cos components
        cos_comp = torch.cos(angle/2)*self.I
        sin_comp = 1j*torch.sin(angle/2)*self.Z
        #form complete matrix
        Rz_rot = cos_comp - sin_comp
        return Rz_rot
    '''
    Universal Bloch sphere rotation
    U(theta, phi, lambda)
    '''
    def U(self, angles):
        theta_t, phi_t, lam_t = torch.tensor(angles, device=self.device)
        theta_t /= 2 #divide theta by two for closed matrix form

        #form tensor using closed-form definition
        U_rot = torch.tensor(
            [[  torch.cos(theta_t),
                -1*torch.exp(1.j*lam_t)*torch.sin(theta_t)
             ],
             [  torch.exp(1.j*phi_t)*torch.sin(theta_t),
                torch.exp(1.j*(lam_t+phi_t))*torch.cos(theta_t)
             ]], device=self.device)
        return U_rot

    #---Controlled gates---#
    '''
    These sets of operators are syntactically different than
    2x2 and rotation operators as they must extend between
    a control qubit and a target qubit
    For a n-qubit system with target 4 and control 0, the
    operator must extend across 5 qubits! Furthermore, the
    matrix created for an n-qubit system holds 2^n x 2^n values
    Since for all practical purposes this is vastly inefficient,
    these routines use classical bitwise analogs instead of matrices
    '''
    def CNOT(self, states, control, target):
        '''
        *args = [states, control, target]
        '''
        
        C = 1 << control
        T = 1 << target
        N = len(states)
        if (control < 0) or (target < 0):
            raise ValueError("Indices cannot be negative")
        elif (2**(control+1) > N) or (2**(target+1) > N):
            raise ValueError("Control and target qubits must be in range")
        
        #create indices tensor
        indices = torch.arange(N)
        #perform XOR on target and state indices
        i = indices[(indices & C)!=0]
        j = i^T
        #swap states
        tmp = states[i].clone()
        states[i] = states[j]
        states[j] = tmp
        return states

    #---Non-clifford gates---#
    '''
    These are the set of non-universal quantum gates
    Since these are 2x2, they can still be applied using the apply() function
    '''

    def P(self, phi):
        '''
        This a single amplitude phase shift routine
        '''
        #convert angle to torch tensor and exponentiate
        phi_t = torch.tensor(phi, device=self.device)
        phi_t = torch.exp(1j*phi_t)
        #form matrix as tensor
        P_rot = torch.tensor([
            [1, 0],
            [0, phi_t]
        ], dtype=torch.cfloat, device=self.device)
        return P_rot
        
    #---Measurement---#
    def MCM(self, states, target):
        """
        Mid-circuit measurement on a statevector.
        Collapses amplitudes of |0> and |1> subspaces for the target qubit,
        then renormalizes the surviving branch.
        """
    
        # basic consistency check: avoid floating-point roulette
        N = states.numel()
        num_qubits = int(math.log2(N))
        if target < 0 or target >= num_qubits:
            raise ValueError("Target qubit index out of range.")
    
        # bit mask selecting the |1> subspace
        bit = 1 << target
        indices = torch.arange(N, device=states.device)
        mask = (indices & bit) != 0
    
        # probability of collapsing to |1>
        p1 = states[mask].abs().pow(2).sum()
        p1 = p1.item()  # extract proper float to avoid device shenanigans
    
        # sample measurement result
        measure_one = (torch.rand(()) < p1).item() if p1 > 0 else False
        # small courtesy: if p1 == 0 or 1, sampling becomes deterministic
    
        if measure_one:
            # collapse to |1>
            # zero everything in |0> subspace
            states[~mask] = 0.0
            if p1 > 0:
                states /= math.sqrt(p1)
        else:
            # collapse to |0>
            # zero everything in |1> subspace
            states[mask] = 0.0
            p0 = 1.0 - p1
            if p0 > 0:
                states /= math.sqrt(p0)
    
        return int(measure_one)

In [60]:
class Circuit:
    
    def __init__(self, size, device=None, threads=8):

        self.size = size
        self.device = self.get_user_device(device, threads)
        self.first_mcm = None #None if no mcm, otherwise first index
        self.Gate = Gate(device=self.device) #create a reference gate object using the circuit's device
        #create a gate reference dict for easy application
        self.gates = {
                        '1':    self.Gate.one,
                        '0':    self.Gate.zero,
                        'NOT':  self.Gate.X,
                        'H':    self.Gate.H,
                        'I':    self.Gate.I,
                        'MCM':  self.Gate.MCM,
                        'P':    self.Gate.P,
                        'PH':   self.Gate.Ph,
                        'RX':   self.Gate.Rx,
                        'RY':   self.Gate.Ry,
                        'Rz':   self.Gate.Rz,
                        'S':    self.Gate.S,
                        'SX':   self.Gate.SX,
                        'T':    self.Gate.T,
                        'U':    self.Gate.U,
                        'X':    self.Gate.X,
                        'Y':    self.Gate.Y,
                        'Z':    self.Gate.Z,
                      }
        self.states = None #states array for circuit amplitudes
        #create a circuit array to store ordered operations on the qubits
        #stores each as a gate, and parameters: {gate_matrix : {metadata}}
        self.circuit = []
        self.measurements = None
    
    def add(self, gate, target, angles=None, controls=None):
        '''
        Add a gate to the circuit construction dict
        Gate is input as a string i.e. 'H', 'NOT'
        Gate is added with metadata as a 2x2 or 4x4 matrix
        Data structure is a dict with gate name, target, control, and phases
        i.e. {'name':'H', 'target':0}
             {'name':'Ph', 'target':1, 'phases'=[pi]}
             {'name':'NOT', 'controls':[0], 'target':1}
        '''
        #finds necessary kwargs for gate contruction
        kwargs = {}
        meta = {
                'target': target,
                'name':gate
               } #gate metadata
        if controls is not None:
            kwargs['controls'] = controls
            meta['controls'] = controls
        if angles is not None:
            kwargs['angles'] = angles
            meta['angles'] = angles
        
        self.circuit.append(meta)
        
    def execute(self, shots, cache=False):
        '''
        Applies gates in circuit from start to measurement
         - Runs circuit n times for mid-circuit measurements
         - Samples n shots from wavefunction probability distribution otherwise
         - Exports counts from measurements to an amplitude dictionary
         - Optionally caches results up to first mcm
        '''
        self.measurements = torch.zeros(2**self.size, dtype=torch.int, device=self.device)
        mcm_exist = self.next_mcm_idx()
        if mcm_exist:
            if cache:
                self.run_circuit(stop_idx = mcm_exist)
                cached_states = self.states.clone()
                for _ in range(shots):
                    self.run_circuit(states = cached_states, start_idx = mcm_exist)
                    self.measurements += torch.bincount(self.measure(self.states, shots=1), minlength=2**self.size)
            else:
                for _ in range(shots):
                    self.run_circuit()
                    self.measurements += torch.bincount(self.measure(self.states, shots=1), minlength=2**self.size)
        else:
            self.run_circuit()
            self.measurements += torch.bincount(self.measure(self.states, shots=shots), minlength=2**self.size)
        
    def next_mcm_idx(self, start_idx=0):
        '''
        Returns index of next mcm operator in the circuit list
        '''
        index = next(
            (i for i, d in enumerate(self.circuit[start_idx:])
                if 'MCM' in d.values()),
            None
        )       
        return index
        
    def measure(self, states, shots):
        '''
        Measurement routine to sample probability state vector
        shots times.
        '''
        #get probability vector
        probs = self.probs(states)
        #form cumulative probability distribution vector
        cdf = probs.cumsum(dim=0)
        #generate n random numbers in (0,1)
        r = torch.rand(shots, device=states.device)
        #batch binary search cdf for indices of cumulative probability = r
        measurements = torch.searchsorted(cdf, r)
        return measurements
    
    def probs(self, states):
        return states.abs().pow(2)

    def run_circuit(self, states=None, start_idx=None, stop_idx=None):
        '''
        Run through the circuit gates one at a time to obtain state amplitudes
        Applies each gate in self.circuit sequentially
        Setting start_idx evaluates the circuit starting at given index
        Setting stop_idx stops evaluating the circuit at the given index
        By default will start at 0 and end at end of circuit
        '''
        
        start_idx = 0 if start_idx is None else start_idx
        stop_idx = len(self.circuit) if stop_idx is None else stop_idx
        
        if start_idx>len(self.circuit) or stop_idx>len(self.circuit) or start_idx<0 or stop_idx<0:
            raise ValueError("Start/Stop indices are out of bounds")
            
        if states==None:
            self.states = torch.zeros(2**self.size, dtype=torch.cfloat, device=self.device)
            self.states[0] = 1 #initialize qubits all to zero state (|0) amplitude is 1)
        else: 
            self.states = states.clone()

        for gate in self.circuit[start_idx:stop_idx]:
            #construct 2x2 gate operator
            if "angles" in gate:
                gate_matrix = self.gates[gate['name']](gate['angles'])
            else:
                gate_matrix = self.gates[gate['name']]
                
            #dispatch an application routine per gate
            if "controls" in gate:
                self.Gate.apply(state=self.states, gate=gate_matrix, target=gate['target'], controls=gate['controls'])
            elif gate['name'] == 'MCM':
                self.Gate.MCM(self.states, target=gate['target'])
            else:
                self.Gate.apply(state=self.states, gate=gate_matrix, target=gate['target'])
    
    def set_size(self, size):
        '''
        Change circuit qubit count
        '''
        self.size = size

    def print_circuit(self):
        
        for gate in self.circuit:
            print(f"Gate: {gate['name']} | Target: {gate['target']} | {f"Control: {gate['control']}" if 'control' in gate else ""}")
    
    #---configure user device---#
    
    def get_user_device(self, device_type, threads):
        #find GPU device
        if str(device_type).lower() == 'gpu':    
            if torch.cuda.is_available():
                #if Nvidia GPU available, use it
                device = torch.device('cuda')
                print("Program using cuda")
            elif torch.backends.mps.is_available():
                #if apple silicon available
                device = torch.device('mps')
                print("Program using mps")
        elif str(device_type).lower() == 'cpu':
            #else just go with CPU
            device = torch.device('cpu')
            #set cpu threads
            torch.set_num_threads(threads)
            print(f"Program using cpu with {threads} threads")
        else:
            device = device_type
        return device

    def change_device(self, new_device):
        self.device = new_device
        #update gate references to new device
        self.Gate = Gate(device=new_device)

In [61]:
def gpu_mem_usage():
    if torch.cuda.is_available():
        mem_used = torch.cuda.memory_allocated()
    elif torch.backends.mps.is_available():
        mem_used = torch.mps.current_allocated_memory()
    else:
        print("Cannot show memory usage for CPU")
        return
    mem_used /= 1024**2
    return mem_used

In [62]:
'''
Save torch tensor directly to binary file
''' 
def save_states(states, filename):
    states.cpu().numpy().tofile(filename)
    print(f"Saved wavefunction states to {filename}")

In [63]:
'''
Read raw binary directly into torch tensor on device
'''
def import_states(filename, dtype=torch.cfloat, device='cpu'):
    with open(filename, "rb") as f:
        buf = f.read()

    tensor = torch.frombuffer(buf, dtype=dtype)
    return tensor.to(device)

In [75]:
n=22
q = Circuit(n, 'gpu')
qcpu = Circuit(n, device='cpu')

for i in range(n):
    q.add('H', target=i)
    q.add('X', target=i)
    q.add('Y', target=i)
    q.add('Z', target=i)
    q.add('T', target=i)
    qcpu.add('H', target=i)
    qcpu.add('X', target=i)
    qcpu.add('Y', target=i)
    qcpu.add('Z', target=i)
    qcpu.add('T', target=i)
print(q.circuit)
t0 = time.time()
q.execute(shots=100000)
t1 = time.time()
qcpu.execute(shots=100000)
t2 = time.time()
print(f"GPU {t1-t0}")
print(f"CPU {t2-t1}")

q.measurements

Program using mps
Program using cpu with 8 threads
[{'target': 0, 'name': 'H'}, {'target': 0, 'name': 'X'}, {'target': 0, 'name': 'Y'}, {'target': 0, 'name': 'Z'}, {'target': 0, 'name': 'T'}, {'target': 1, 'name': 'H'}, {'target': 1, 'name': 'X'}, {'target': 1, 'name': 'Y'}, {'target': 1, 'name': 'Z'}, {'target': 1, 'name': 'T'}, {'target': 2, 'name': 'H'}, {'target': 2, 'name': 'X'}, {'target': 2, 'name': 'Y'}, {'target': 2, 'name': 'Z'}, {'target': 2, 'name': 'T'}, {'target': 3, 'name': 'H'}, {'target': 3, 'name': 'X'}, {'target': 3, 'name': 'Y'}, {'target': 3, 'name': 'Z'}, {'target': 3, 'name': 'T'}, {'target': 4, 'name': 'H'}, {'target': 4, 'name': 'X'}, {'target': 4, 'name': 'Y'}, {'target': 4, 'name': 'Z'}, {'target': 4, 'name': 'T'}, {'target': 5, 'name': 'H'}, {'target': 5, 'name': 'X'}, {'target': 5, 'name': 'Y'}, {'target': 5, 'name': 'Z'}, {'target': 5, 'name': 'T'}, {'target': 6, 'name': 'H'}, {'target': 6, 'name': 'X'}, {'target': 6, 'name': 'Y'}, {'target': 6, 'name': 'Z

tensor([0, 0, 0,  ..., 0, 0, 0], device='mps:0', dtype=torch.int32)

In [45]:
#save_states(x, "saved_states/states.bin")
#print(import_states("saved_states/states.bin", device=device)[0:10])
print(gpu_mem_usage())

660.0146484375
