In [41]:
import torch
import math
import cmath

In [45]:
qubit_0 = torch.tensor([1.0, 1.0],dtype=torch.complex64)

H = [[(1/math.sqrt(2)), (1/math.sqrt(2))], [(1/math.sqrt(2)), -(1/math.sqrt(2))]]
theta_value = 2*math.pi

hadamard_gate = torch.tensor(H, dtype=torch.complex64)

rx_gate = torch.tensor([[math.cos(theta_value/2), -1j*math.sin(theta_value/2)],
                        [-1j*math.sin(theta_value/2), math.cos(theta_value/2)]], dtype=torch.complex64)

ry_gate = torch.tensor([[math.cos(theta_value/2), -math.sin(theta_value/2)], [math.sin(theta_value/2), math.cos(theta_value/2)]], dtype=torch.complex64)

rz_gate = torch.tensor([[-cmath.exp(-1.0j*theta_value/2), 0.0], [0.0, cmath.exp(-1.0j*theta_value/2)]], dtype=torch.complex64)
 
pauli_x_gate = torch.tensor([[0.0, 1.0],[1.0 ,0.0]], dtype=torch.complex64)

pauli_y_gate = torch.tensor([[0.0, -1.0j], [1.0j, 0.0]],dtype=torch.complex64)
pauli_z_gate = torch.tensor([[1.0, 0.0], [0.0,-1.0]], dtype=torch.complex64)

phase_gate = torch.tensor([[1.0, 0.0], [0.0, 1.0j]], dtype=torch.complex64)
t_gate = torch.tensor([[1.0, 0.0], [0.0, cmath.exp(math.pi*1.0j/4)]], dtype=torch.complex64)

In [46]:
gate_operation = {
    'pauli_x': pauli_x_gate,
    'pauli_y': pauli_y_gate,
    'pauli_z': pauli_z_gate,
    'rx_gate': rx_gate,
    'ry_gate': ry_gate,
    'rz_gate': rz_gate,
    'phase': phase_gate,
    't': t_gate,
    'hadamard': hadamard_gate
}

In [49]:
run_circuit(['rx_gate', 'pauli_x'], qubit_0)

tensor([-1.-1.2246e-16j, -1.-1.2246e-16j])

In [None]:
def run_circuit(tensor_expression: list, initial_state: torch.Tensor) -> torch.Tensor:
    next_state = initial_state
    for gate_name in tensor_expression:
        gate = gate_operation[gate_name]
        next_state = torch.kron(gate, next_state)
    
    return next_state

In [56]:
no_qubits = 2
initial_state_vector = torch.eye(no_qubits, dtype=torch.complex64)
quantum_processing = run_circuit(['hadamard', 'rx_gate', 'pauli_x', 't'], initial_state_vector)
quantum_processing

tensor([[-0.7071-8.6596e-17j,  0.7071-8.6596e-17j],
        [-0.5000-5.0000e-01j, -0.5000-5.0000e-01j]])