This repository has been archived by the owner on Feb 8, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
quantum.py
58 lines (51 loc) · 1.92 KB
/
quantum.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# SPDX-License-Identifier: MIT
# Copyright : JP Morgan Chase & Co and QC Ware
from typing import Callable, List, Tuple
import numpy as np
from jax import numpy as jnp
from jax import scipy as jsp
def make_ortho_fn(
    rbs_idxs: List[List[Tuple[int, int]]],
    num_qubits: int,
) -> Callable:
    """
    Build a pure function mapping RBS angles to an orthogonal matrix.

    Args:
        rbs_idxs: Layers of RBS gates. Each layer is a list of ``(a, b)``
            qubit pairs; a qubit may appear at most once per layer. Every
            gate consumes one angle: angles are consumed layer by layer,
            and within a layer in the order the pairs are listed.
        num_qubits: The total number of qubits (the output matrix is
            ``num_qubits x num_qubits``).

    Returns:
        A pure function ``thetas -> matrix`` mapping a 1-D array of angles
        to an orthogonal matrix (compound 1). The first layer is applied
        first, i.e. the result is ``M_last @ ... @ M_first``. With no
        layers the identity is returned. Extra trailing angles are ignored.

    Raises:
        ValueError: If a gate does not act on exactly two qubits, a qubit
            index is outside ``[0, num_qubits)``, or a qubit is used more
            than once within a single layer. The returned function raises
            ValueError if ``thetas`` is not 1-D or holds too few angles.
    """
    # The circuit layout is static, so everything that depends only on it is
    # computed once here rather than on every call/trace of orthogonal_fn.
    # Each entry: (first angle index, end angle index, inverse permutation,
    # number of qubits not touched by this layer).
    layers = []
    num_params = 0
    for layer_idx, layer in enumerate(rbs_idxs):
        pairs = [tuple(pair) for pair in layer]
        if any(len(pair) != 2 for pair in pairs):
            raise ValueError(
                f"Layer {layer_idx}: every RBS gate must act on exactly two qubits."
            )
        qubits = [int(q) for pair in pairs for q in pair]
        if any(not 0 <= q < num_qubits for q in qubits):
            raise ValueError(
                f"Layer {layer_idx}: qubit indices must lie in [0, {num_qubits})."
            )
        if len(set(qubits)) != len(qubits):
            raise ValueError(
                f"Layer {layer_idx}: a qubit is used by more than one RBS gate."
            )
        used = set(qubits)
        # Position p of the block-diagonal matrix represents qubit order[p]:
        # gate k occupies positions 2k and 2k + 1, idle qubits come last.
        order = qubits + [q for q in range(num_qubits) if q not in used]
        # Indexing rows/columns with the inverse permutation moves every
        # position back onto the qubit it represents.
        inverse = np.argsort(order)
        layers.append(
            (num_params, num_params + len(pairs), inverse, num_qubits - len(qubits))
        )
        num_params += len(pairs)

    def get_rbs_matrix(theta):
        """Return RBS blocks ``[[cos, sin], [-sin, cos]]`` stacked on the leading axis.

        A 1-D ``theta`` of length ``k`` yields shape ``(k, 2, 2)``; a scalar
        yields a single ``(2, 2)`` block.
        """
        cos_theta, sin_theta = jnp.cos(theta), jnp.sin(theta)
        matrix = jnp.array(
            [
                [cos_theta, sin_theta],
                [-sin_theta, cos_theta],
            ]
        )
        # Move the two matrix axes to the end so gates index the first axis.
        return matrix.transpose(*range(2, matrix.ndim), 0, 1)

    def orthogonal_fn(thetas):
        """Return the orthogonal matrix for the given 1-D array of angles."""
        thetas = jnp.asarray(thetas)
        # Shapes are static under jit/vmap, so this check is trace-safe.
        if thetas.ndim != 1 or thetas.shape[0] < num_params:
            raise ValueError(
                f"Expected a 1-D array with at least {num_params} angles, "
                f"got shape {thetas.shape}."
            )
        matrices = []
        for start, stop, inverse, num_idle in layers:
            rbs_blocks = get_rbs_matrix(thetas[start:stop])
            eye_block = jnp.eye(num_idle, dtype=thetas.dtype)
            matrix = jsp.linalg.block_diag(*rbs_blocks, eye_block)
            matrices.append(matrix[inverse][:, inverse])
        if not matrices:
            # No layers: the circuit acts as the identity.
            return jnp.eye(num_qubits, dtype=jnp.result_type(thetas.dtype, 1.0))
        if len(matrices) == 1:
            return matrices[0]
        # Reverse so the first layer is the right-most (first applied) factor.
        return jnp.linalg.multi_dot(matrices[::-1])

    return orthogonal_fn