# Algorithms 5.1 and 5.2 - Solving the Quaternion Embedding Problem

The paper presents algorithms solving the *Quaternion Embedding Problem*. This file contains implementations of Algorithms 5.1 and 5.2 for finding embeddings of quadratic orders within quaternion orders.

**Algorithm 5.1** can return one embedding, return all embeddings, or determine if no such embeddings exist. It should be used when values of size roughly $N(I)^2 d$ are easy to factor (the paper assumes a factorization oracle). Here $N(I)$ is the norm of the smallest connecting ideal to $\mathcal{O}_0$, which for a random order is about $\sqrt{p}$, and $d$ is the norm of the generator of the quadratic order.  
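
The values that need factoring can be read off from the implementation below. As a sketch in the code's notation (the symbols $\gamma_i$, $\beta_i$ and $e_{ij}$ follow the variable names in the code and are not necessarily the paper's), write $\alpha = \sum_i \alpha_i e_i = \gamma_0 + \gamma_1 i + \gamma_2 j + \gamma_3 k$ with $i^2 = -q$, $j^2 = -p$, where $e_0, \dots, e_3$ is the HNF basis and $N$ is the least common multiple of the denominators of its entries. Assuming the $\gamma_i$ have denominators coprime to $p$:

$$
\begin{aligned}
\mathrm{trd}(\alpha) &= 2\gamma_0 = 2\alpha_0 e_{00} = t, \\
\mathrm{nrd}(\alpha) &= \gamma_0^2 + q\gamma_1^2 + p\,(\gamma_2^2 + q\gamma_3^2) = d, \\
q\gamma_1^2 &\equiv d - \gamma_0^2 \pmod{p}, \\
v &= \frac{N^2\,(d - \gamma_0^2 - q\gamma_1^2)}{p} = \beta_2^2 + q\beta_3^2, \qquad \beta_i = N\gamma_i .
\end{aligned}
$$

The trace fixes $\alpha_0$, the congruence fixes $\alpha_1$ modulo $p$ up to sign, and each admissible $\alpha_1$ gives one equation $x^2 + qy^2 = v$ whose solution requires factoring $v$.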

**Algorithm 5.2** can return one embedding or determine if none exist. In worst cases it will not terminate in a reasonable amount of time. On the plus side, the rerandomization trick means some hard factorizations can be skipped, so it can work with much larger parameters. For quadratic orders with discriminants less than $O(p)$ it is very fast on average, even for cryptographically sized $p$. Examples are included.  

To find only primitive embeddings, see the file `Experiments - Primitive solutions to Quaternion Embedding Problem.ipynb`.

We include our own implementation of Cornacchia's algorithm in `cornacchia.py` for finding all solutions to the equation $x^2 + qy^2 = v$ where $\gcd(q,v)=1$.
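
For small inputs, the output of such a solver can be cross-checked against exhaustive search. The sketch below is not the code in `cornacchia.py`: `brute_force_norm_equation` is a hypothetical helper in plain Python that lists the nonnegative solutions by trying every $y$, so it is only practical for tiny $v$.

```python
from math import isqrt

def brute_force_norm_equation(q, v):
    """List all (x, y) with x, y >= 0 and x^2 + q*y^2 == v, by exhaustive search.

    Hypothetical reference helper for testing; runs in O(sqrt(v/q)) steps.
    """
    if q <= 0:
        raise ValueError("q must be positive")
    solutions = []
    if v < 0:
        return solutions
    # y can be at most sqrt(v / q)
    for y in range(isqrt(v // q) + 1):
        x_squared = v - q * y * y
        x = isqrt(x_squared)
        if x * x == x_squared:
            solutions.append((x, y))
    return solutions

print(brute_force_norm_equation(1, 25))
print(brute_force_norm_equation(2, 11))
```

Signs can be added afterwards, since $(\pm x, \pm y)$ are all solutions whenever $(x, y)$ is.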

## Algorithm 5.1

In [5]:
from sage.all import ZZ, GF, lcm, sqrt, floor
from cornacchia import all_cornacchia
from numpy import argmax
from hnf import lower_hnf_basis
from ideals import small_equivalent_ideal

def Fp_to_int(n):
    """
    Returns element of GF(p) as integer in range (-p/2 ... +p/2]
    """
    if ZZ(n) > n.parent().order() / 2: return ZZ(n) - n.parent().order()
    return ZZ(n)

def find_element_defining_embedding(O, d, t, all_slns=False, filter_func=None):
    """
    Finds an element in quaternion order 'O' with trace 't' and norm 'd'. Set 'all_slns=True' to get all solutions.
        A function can be provided as 'filter_func' which is called when a solution is found to see if it should be counted or not. We use this for filtering primitive solutions.
    """
    slns = [] if all_slns else None
    # Compute the connecting ideal, and find smaller equivalent ideal, to give right order with lower N
    B = O.quaternion_algebra()
    I = B.maximal_order() * O
    J, y = small_equivalent_ideal(I)
    O_new = J.right_order()
    # Put basis in HNF
    basis_hnf = lower_hnf_basis(B, O_new.basis())
    e00, e01, e02, e03 = basis_hnf[0]
    _,   e11, e12, e13 = basis_hnf[1]
    _,   _,   e22, e23 = basis_hnf[2]
    _,   _,   _,   e33 = basis_hnf[3]
    if (e00 == 0) or (e11 == 0) or (e22 == 0) or (e33 == 0):
        return slns
    # Find alpha_0
    alpha_0 = t / (2 * e00)
    if (alpha_0 not in ZZ) or (d not in ZZ):
        return slns
    # Compute a, b, N
    q, p = [ZZ(abs(l)) for l in B.invariants()]
    N = lcm([e.denominator() for e in [e00,e01,e02,e03,e11,e12,e13,e22,e23,e33]])
    N2 = N**2
    # Find residues of alpha_1 mod p
    Fp = GF(p)
    sq_mod_p = Fp(d - (alpha_0 * e00)**2) / Fp(q)
    rt1 = sqrt(sq_mod_p)
    if rt1 not in Fp:
        return slns
    rt2 = -rt1
    residues = [Fp_to_int((rt1 - Fp(alpha_0 * e01)) / Fp(e11)), Fp_to_int((rt2 - Fp(alpha_0 * e01)) / Fp(e11))]
    # compute maximum value of k - for each residue
    temp1 = d - (alpha_0**2)*(e00**2)
    temp1_scaled = N2 * temp1
    temp2 = sqrt(temp1 / q) - alpha_0*e01
    ks = [floor((temp2 - ZZ(r)*e11)/(p*e11)) for r in residues]
    # loop over k decreasing, for each residue
    max_iter = sum([k + 1 for k in ks if k >= 0])
    while max(ks) >= 0:
        k_index = argmax(ks)
        k = ks[k_index]
        r = residues[k_index]
        ks[k_index] = ks[k_index] - 1
        # Compute u and v (v = RHS for Cornacchia)
        alpha_1 = ZZ(r) + k*p
        gamma_1 = alpha_0*e01 + alpha_1*e11
        u = q * N2 * gamma_1**2
        v = ZZ((temp1_scaled - u) / p)
        # find all solutions to Cornacchia's
        betas = all_cornacchia(q, v)
        for beta_pair in betas:
            # Check if this gives a solution with integral alpha_2 and alpha_3
            alpha_2 = (beta_pair[0] - N*alpha_1*e12 - N*alpha_0*e02) / (N*e22)
            alpha_3 = (beta_pair[1] - N*alpha_1*e13 - N*alpha_2*e23 - N*alpha_0*e03) / (N*e33)
            if (alpha_2 in ZZ) and (alpha_3 in ZZ):
                alpha = alpha_0*basis_hnf[0] + alpha_1*basis_hnf[1] + alpha_2*basis_hnf[2] + alpha_3*basis_hnf[3]
                alpha_in_O = y * alpha * y**(-1) # map alpha back in to original order
                valid_sln = True
                if filter_func is not None:
                    valid_sln = filter_func(alpha_in_O, k)
                if valid_sln:
                    if all_slns: slns.append(alpha_in_O)
                    if not all_slns: return alpha_in_O
    return slns

### Example 1:

In [6]:
p = 1000003
B.<i,j,k> = QuaternionAlgebra(-1, -p)
O = B.quaternion_order([1/2 + 13/2*j + 19*k, 1/858*i + 1714/429*j + 641/66*k, 13*j + 5*k, 33*k])

d = 21174601658
t = 21
find_element_defining_embedding(O, d, t, all_slns=True)

[21/2 + 76000229/858*i + 94447/858*j - 2297/66*k,
 21/2 + 24000073/858*i + 28355/858*j - 9169/66*k,
 21/2 + 4000013/858*i + 114475/858*j - 3821/66*k,
 21/2 + 1/858*i - 124843/858*j + 113/66*k,
 21/2 - 1/858*i + 124843/858*j - 113/66*k]
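
The output above can be sanity-checked without Sage. The sketch below assumes the usual formulas for the algebra with $i^2 = -q$ and $j^2 = -p$ (reduced trace $2a$ and reduced norm $a^2 + qb^2 + pc^2 + pqe^2$ for $a + bi + cj + ek$); `reduced_trace` and `reduced_norm` are hypothetical helper names, and it only confirms trace and norm, not membership in $\mathcal{O}$.

```python
from fractions import Fraction as F

p = 1000003
d, t = 21174601658, 21

def reduced_trace(a, b, c, e):
    """Reduced trace of a + b*i + c*j + e*k."""
    return 2 * a

def reduced_norm(a, b, c, e, q=1, p=p):
    """Reduced norm of a + b*i + c*j + e*k, where i^2 = -q and j^2 = -p."""
    return a * a + q * b * b + p * c * c + p * q * e * e

# Coefficients (a, b, c, e) copied from the output of Example 1
solutions = [
    (F(21, 2), F(76000229, 858), F(94447, 858), F(-2297, 66)),
    (F(21, 2), F(24000073, 858), F(28355, 858), F(-9169, 66)),
    (F(21, 2), F(4000013, 858), F(114475, 858), F(-3821, 66)),
    (F(21, 2), F(1, 858), F(-124843, 858), F(113, 66)),
    (F(21, 2), F(-1, 858), F(124843, 858), F(-113, 66)),
]

for coeffs in solutions:
    assert reduced_trace(*coeffs) == t
    assert reduced_norm(*coeffs) == d
```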

### Example 2: No solutions

In [7]:
d = 21174601658
t = 0
print(find_element_defining_embedding(O, d, t))

None


## Algorithm 5.2 with examples $disc(\mathbb{Z}[\omega]) \sim O(p)$

In [8]:
from sage.all import matrix, diagonal_matrix, vector, is_pseudoprime, QQ, ZZ, GF, lcm, sqrt, log, ceil, floor, denominator, round, random_prime, next_prime, QuaternionAlgebra, kronecker, DiagonalQuadraticForm
import itertools
from numpy import argmax
from cornacchia import all_cornacchia
from hnf import lower_hnf_basis
from algorithm_5_1 import Fp_to_int
from random import randint

def factors_easily(n, B=2**20):
    """
    Given a number n, checks if a n is "Cornacchia Friendly" (= easily factorable)
    """
    n = ZZ(n)
    if n < 0: return False
    if n < 2**160: return True
    l,_ = n.factor(limit=B)[-1]
    return l < 2**160 or is_pseudoprime(l)

def quat_algs(p):
    """
    Generate 3 isomorphic quaternion algebras ramified at p \neq 2 and infity, with abs(i^2) small
    """
    Bs = []
    mod = 4
    if p % 4 == 3:
        Bs.append(QuaternionAlgebra(-1, -p))
    q = 1
    while len(Bs) < 3:
        q = next_prime(q)
        if (-q) % mod == 1 and kronecker(-q, p) == -1:
            Bs.append(QuaternionAlgebra(-q, -p))
    assert all([B.ramified_primes() == [p] for B in Bs ])
    return Bs

def isomorphism_gamma(B_old, B):
    """
    Defines an isomorphism of quaternion algebras, See Lemma 10 [EPSV23]
    """
    if B_old == B:
        return B(1), B(1)
    i_old, j_old, k_old = B_old.gens()
    q_old = -ZZ(i_old**2)
    i, j, k = B.gens()
    q = -ZZ(i**2) 
    p = -ZZ(j**2)
    x, y = DiagonalQuadraticForm(QQ, [1,p]).solve(q_old/q)
    return x + j*y, (x + j_old*y)**(-1)

def eval_isomorphism(alpha, B, gamma):
    """
    Evaluates a quaternion in an isomorphism of quaternion algebras
    """
    i, j, k = B.gens()
    return sum([coeff*b for coeff, b in zip(alpha.coefficient_tuple(), [1, i*gamma, j, k*gamma])])

def find_element_defining_embedding_randomized(O, d, t, filter_func=None):
    """
        Continuously randomizes basis for the order O, until an element of trace t norm d can be found without doing any hard factorizations.
            Maps the element back into the starting order and returns it.
        
        Returns two values:
        - The element of trace t norm d.
        - A "confidence" boolean.
            If no element found, True if we're certain it's not possible an element exists. False if it still might be possible as we may have skipped it.
    """
    def find_element_defining_embedding_with_skips(O, d, t):
        """
        Attempts to find an element in quaternion order 'O' with trace 't' and norm 'd', but may miss solutions.
            In solving x^2+|a|y^2=v with Coracchias, skips if v if it is hard to factor.
            Returns solution, and 'confidence' boolean that is True if no solutions have been skipped.
        """
        # Put basis in HNF
        basis_hnf = lower_hnf_basis(B, O.basis())
        e00, e01, e02, e03 = basis_hnf[0]
        _,   e11, e12, e13 = basis_hnf[1]
        _,   _,   e22, e23 = basis_hnf[2]
        _,   _,   _,   e33 = basis_hnf[3]
        if (e00 == 0) or (e11 == 0) or (e22 == 0) or (e33 == 0):
            return None, True
        # Find alpha_0
        alpha_0 = t / (2 * e00)
        if (alpha_0 not in ZZ) or (d not in ZZ):
            # If none works, we're confident no solution exists in any conjugate order
            return None, True
        # Compute N
        N = lcm([e.denominator() for e in [e00,e01,e02,e03,e11,e12,e13,e22,e23,e33]])
        N2 = N**2
        # Find residues of alpha_1 mod p
        Fp = GF(p)
        sq_mod_p = Fp(d - (alpha_0 * e00)**2) / Fp(q)
        rt1 = sqrt(sq_mod_p)
        if rt1 not in Fp:
            return None, True
        rt2 = -rt1
        residues = [Fp_to_int((rt1 - Fp(alpha_0 * e01)) / Fp(e11)), Fp_to_int((rt2 - Fp(alpha_0 * e01)) / Fp(e11))]
        # compute maximum value of k - for each residue
        temp1 = d - (alpha_0**2)*(e00**2)
        temp1_scaled = N2 * temp1
        temp2 = sqrt(temp1 / q) - alpha_0*e01
        ks = [floor((temp2 - ZZ(r)*e11)/(p*e11)) for r in residues]
        # loop over k decreasing, for each residue
        max_iter = sum([k + 1 for k in ks if k >= 0])
        skipped_v = False
        while max(ks) >= 0:
            k_index = argmax(ks)
            k = ks[k_index]
            r = residues[k_index]
            ks[k_index] = ks[k_index] - 1
            # Compute u and v (v = RHS for Cornacchia)
            alpha_1 = ZZ(r) + k*p
            gamma_1 = alpha_0*e01 + alpha_1*e11
            u = q * N2 * gamma_1**2
            v = ZZ((temp1_scaled - u) / p)
            if factors_easily(v):
                # find all solutions to Cornacchia's
                betas = all_cornacchia(q, v)
                for beta_pair in betas:
                    # Check if this gives a solution with integral alpha_2 and alpha_3
                    alpha_2 = (beta_pair[0] - N*alpha_1*e12 - N*alpha_0*e02) / (N*e22)
                    alpha_3 = (beta_pair[1] - N*alpha_1*e13 - N*alpha_2*e23 - N*alpha_0*e03) / (N*e33)
                    if (alpha_2 in ZZ) and (alpha_3 in ZZ):
                        alpha = alpha_0*basis_hnf[0] + alpha_1*basis_hnf[1] + alpha_2*basis_hnf[2] + alpha_3*basis_hnf[3]
                        valid_sln = True
                        if filter_func is not None:
                            valid_sln = filter_func(alpha, k)
                        if valid_sln:
                            return alpha, (not skipped_v)
            else:
                skipped_v = True
        # If we didn't skip any v's we know no solution exists
        return None, (not skipped_v)

    p = -ZZ(O.quaternion_algebra().gens()[1]**2)
    Bs = quat_algs(p)
    for B in Bs:
        # The maximal order with small denominator O_0
        O0 = B.maximal_order()
        # Compute isomorphism between the quat algebras
        gamma, gamma_inv = isomorphism_gamma(O.quaternion_algebra(), B)
        # Transfer the maximal order to new quaternion algebra
        O_in_new_quatalg = B.quaternion_order([eval_isomorphism(alpha, B, gamma) for alpha in O.gens()])
        print(f"\nFinding solution in {O_in_new_quatalg}")
        q, p = [ZZ(abs(l)) for l in B.invariants()]
        # Find connecting ideal
        I = O0 * O_in_new_quatalg
        I = I * denominator(I.norm())
        # Reduced basis to find other small equivalent ideals, which gives suitable isomorphisms of O
        basis_hnf = lower_hnf_basis(B, I.basis())
        M = matrix(QQ, [ai.coefficient_tuple() for ai in basis_hnf])
        S = 2**ceil(log(p*q, 2))
        D = diagonal_matrix(round(S * sqrt(g.reduced_norm())) for g in B.basis())
        reduced_basis = (M * D).LLL() * ~D
        # Define constants for conjugating order
        used = []
        max_size = round(p**(1/1.8),5) + 10
        bound = max(round(log(p,2)/10), 10)
        # Try a bunch of small connecting ideals
        for (a1,a2,a3,a4) in itertools.product(range(0,bound+1), range(-bound,bound+1), range(-bound,bound+1), range(-bound,bound+1)):
            coeffvec = vector(QQ, [a1,a2,a3,a4])
            y = coeffvec * reduced_basis * vector(B.basis())
            Jnorm = y.reduced_norm() / I.norm()
            if y in used or Jnorm > max_size:
                continue
            used.append(y)
            y = y.conjugate() / I.norm()
            J = I * y
            beta, confidence = find_element_defining_embedding_with_skips(J.right_order(), d, t)
            if beta:
                beta_new =  y * beta * y**(-1)
                return eval_isomorphism(beta_new, O.quaternion_algebra(), gamma_inv)
            else:
                if confidence: return None, True
    return None, False

All examples we give below are for cryptographically sized $p$ with the discriminant of the quadratic order approximately the size of $p$.

### Example 1:

We can sample a trace-zero element as follows:

In [9]:
def small_trace0(O):
    """
    Generate a random short element of O with trace 0 (this is only for used for testing)
    """
    M_O = Matrix(QQ, [ai.coefficient_tuple() for ai in O.gens()])
    M_t0 = Matrix([row[1:] for row in M_O[1:]])
    i,j,k = O.quaternion_algebra().gens()
    p = -ZZ(j^2)
    q = -ZZ(i^2)        
    S = 2**ceil(log(p*q, 2))
    D = diagonal_matrix(round(S * sqrt(g.reduced_norm())) for g in [i,j,k])
    shortM_t0 = (M_t0 * D).LLL() * ~D
    bound = floor(p^(1/7.5)) # Should make elements around size p (or slightly smaller than)
    coeffvec = vector(QQ, [randint(-bound, bound) for _ in range(3)])
    omega = coeffvec*shortM_t0*vector([i,j,k])
    return omega

For a $250$-bit prime $p$ and a random order $\mathcal{O}$, we sample a random element giving a discriminant close to $p$, and use our algorithm to recover it:

In [10]:
p = next_prime(2^250)
while p%4 != 3:
    p = next_prime(p)
B = QuaternionAlgebra(-1, -p)
O0 = B.maximal_order()
i,j,k = B.gens()
O = B.quaternion_order((1/2 + 1/2*j + 5312503554041246563877374554369270532*k, 1/41596753742804676578982348015068861470*i + 5101459939357330973899243766421925768/20798376871402338289491174007534430735*j + 454343285072474849939624336299261232984629048508402945760071971088381943711/41596753742804676578982348015068861470*k, j + 10625007108082493127754749108738541064*k, 20798376871402338289491174007534430735*k))

omega = small_trace0(O)
d = omega.reduced_norm()
t = omega.reduced_trace()

In [11]:
rec = find_element_defining_embedding_randomized(O, d, t)
rec


Finding solution in Order of Quaternion Algebra (-1, -1809251394333065553493296640760748560207343510400633813116524750123642651047) with base ring Rational Field with basis (1/2 + 1/2*j + 5312503554041246563877374554369270532*k, 1/41596753742804676578982348015068861470*i + 5101459939357330973899243766421925768/20798376871402338289491174007534430735*j + 454343285072474849939624336299261232984629048508402945760071971088381943711/41596753742804676578982348015068861470*k, j + 10625007108082493127754749108738541064*k, 20798376871402338289491174007534430735*k)


1048106827405728162616907686812226712536483474650420676505816475483609578/20798376871402338289491174007534430735*i + 41549232588444294597371750840757188/20798376871402338289491174007534430735*j + 26410973976346054086050907419311328/20798376871402338289491174007534430735*k

### Example 2: No solutions - determined immediately

In some cases we can be confident that no solution exists, e.g. if solving for $\alpha_0$ from the trace does not give $\alpha_0 \in \mathbb{Z}$.

In [12]:
d = 13385751525286023724122722290538230793986387607736069527954467367683793
rec = find_element_defining_embedding_randomized(O, d, 0)
rec


Finding solution in Order of Quaternion Algebra (-1, -1809251394333065553493296640760748560207343510400633813116524750123642651047) with base ring Rational Field with basis (1/2 + 1/2*j + 5312503554041246563877374554369270532*k, 1/41596753742804676578982348015068861470*i + 5101459939357330973899243766421925768/20798376871402338289491174007534430735*j + 454343285072474849939624336299261232984629048508402945760071971088381943711/41596753742804676578982348015068861470*k, j + 10625007108082493127754749108738541064*k, 20798376871402338289491174007534430735*k)


(None, True)
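
The cheap checks made before any factoring can be sketched in plain Python. This is an illustration under stated assumptions, not the notebook's code: `quick_reject` is a hypothetical helper, $p$ is assumed to be an odd prime, and $q$ and all denominators are assumed coprime to $p$. It tests integrality of $d$ and of $\alpha_0 = t/(2e_{00})$, then uses Euler's criterion to test whether $(d - \gamma_0^2)/q$ is a square modulo $p$.

```python
from fractions import Fraction

def quick_reject(d, t, e00, q, p):
    """Return True if the cheap pre-checks already rule out a solution.

    Hypothetical helper: e00 is the first entry of the HNF basis,
    p an odd prime, q and all denominators coprime to p.
    """
    d, e00 = Fraction(d), Fraction(e00)
    if d.denominator != 1:
        return True
    # The trace condition 2 * alpha_0 * e00 = t must have an integer solution
    alpha_0 = Fraction(t) / (2 * e00)
    if alpha_0.denominator != 1:
        return True
    gamma_0 = alpha_0 * e00
    rhs = d - gamma_0**2
    # Reduce (d - gamma_0^2) / q modulo p (pow with exponent -1 needs Python 3.8+)
    a = rhs.numerator * pow(rhs.denominator * q, -1, p) % p
    if a == 0:
        return False
    # Euler's criterion: a is a nonzero square mod p iff a^((p-1)/2) == 1
    return pow(a, (p - 1) // 2, p) != 1

print(quick_reject(3, 0, Fraction(1, 2), 1, 7))  # 3 is not a square mod 7
print(quick_reject(2, 0, Fraction(1, 2), 1, 7))  # 2 = 3^2 mod 7
print(quick_reject(2, 1, 1, 1, 7))               # alpha_0 = 1/2 is not integral
```

A `False` result only means these checks are inconclusive; the search over $\alpha_1$ is still needed.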

### Example 3: No solutions - takes a few randomizations

Another way to be confident that no solution exists is if, for one rerandomization, we do not skip any of the $v$'s and still do not find a solution. This happens in the following example (which takes about $40$ seconds):

In [13]:
d = 13385751525286023724122722290538230793986387607736069527954467367683794
rec = find_element_defining_embedding_randomized(O, d, 0)
rec


Finding solution in Order of Quaternion Algebra (-1, -1809251394333065553493296640760748560207343510400633813116524750123642651047) with base ring Rational Field with basis (1/2 + 1/2*j + 5312503554041246563877374554369270532*k, 1/41596753742804676578982348015068861470*i + 5101459939357330973899243766421925768/20798376871402338289491174007534430735*j + 454343285072474849939624336299261232984629048508402945760071971088381943711/41596753742804676578982348015068861470*k, j + 10625007108082493127754749108738541064*k, 20798376871402338289491174007534430735*k)


(None, True)

Note that when the algorithm does not terminate, either a solution is unlikely to exist, or the parameters are so large that too many values cannot be factored and are being skipped.
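
The role of the confidence flag in Examples 2 and 3 can be illustrated with a toy search. This is a minimal sketch in plain Python, not the notebook's implementation: `search_with_skips` and the `is_easy` predicate are hypothetical stand-ins for the loop over $v$ and for `factors_easily`, and exhaustive search replaces Cornacchia's algorithm.

```python
from math import isqrt

def search_with_skips(candidates, q, is_easy):
    """Look for v in candidates (positive integers) with x^2 + q*y^2 == v.

    Returns (solution, confident). 'confident' is True only if no
    candidate was skipped before returning.
    """
    skipped = False
    for v in candidates:
        if not is_easy(v):
            # A skipped candidate might have hidden a solution
            skipped = True
            continue
        for y in range(isqrt(v // q) + 1):
            x_squared = v - q * y * y
            x = isqrt(x_squared)
            if x * x == x_squared:
                return (v, x, y), not skipped
    return None, not skipped

# No candidate is a sum of two squares and none is skipped: certain failure
print(search_with_skips([3, 7, 11], 1, lambda v: True))
# Same candidates, but 7 is treated as "hard": failure without certainty
print(search_with_skips([3, 7, 11], 1, lambda v: v != 7))
```

A `(None, False)` result is the case where another rerandomization is worth trying, since different candidates $v$ may then be easy to factor.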