<a href="https://colab.research.google.com/github/joshtburdick/misc/blob/master/plog/Factoring_using_loopy_belief_propagation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Factoring using loopy belief propagation

Or, "Factoring using factor graphs" `:-)`

Here, we use Fermat's difference-of-squares method.
Let $N$ be the number to factor.

First, we'll find solutions to $N \equiv a^2 - b^2 \pmod{m}$, where $m > N$ is the product of smaller primes $p_i$. (Since this congruence only holds modulo $m$, and we only learn the squares $a^2$ and $b^2$, it doesn't directly give a factor.)

Having found $a^2$ and $b^2$, we then search through the possible values of $a$ and $b$, looking for a pair such that $\gcd(a+b, N)$ is neither 1 nor $N$. Such a GCD (somewhat as in the quadratic sieve) is then a nontrivial factor of $N$.
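For comparison, here is a minimal sketch of the classical (exact, not modular) form of Fermat's method: search for an $a$ such that $a^2 - N$ is a perfect square $b^2$, so that $N = (a-b)(a+b)$. The name `fermat_factor` is made up for illustration, the sketch assumes $N$ is odd and composite, and it isn't used anywhere else in this notebook.

```python
import math

def fermat_factor(n):
    """Classical Fermat factorization (illustrative; assumes n is odd and > 1).

    Returns (a - b, a + b) with n == (a - b) * (a + b).
    """
    a = math.isqrt(n)
    if a * a < n:
        a += 1                    # start at ceil(sqrt(n))
    while True:
        b2 = a * a - n            # candidate for b^2
        b = math.isqrt(b2)
        if b * b == b2:           # a^2 - n is a perfect square
            return a - b, a + b
        a += 1

print(fermat_factor(19 * 23))
```

The modular version used in the rest of this notebook relaxes the exact equation to a congruence, which is why the extra GCD search is needed afterwards.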


We'll use
- the [PGMax](https://github.com/google-deepmind/PGMax) library, for loopy belief propagation
- the [modulo](https://github.com/lapets/modulo) library, for calculations related to the Chinese Remainder Theorem


In [1]:
!pip install --quiet pgmax modulo

In [2]:
import itertools
import pdb
import math

import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pgmax
from pgmax import fgraph, fgroup, infer, vgroup

from modulo import modulo

## Solving for the "difference of squares"

We'll need a standard numbering of the quadratic residues mod a prime
$p$.

In [3]:
def number_quadratic_residues(p):
    """
    Given a prime p, number the quadratic residues mod p.

    Args:
        p: The prime number.

    Returns:
        A dictionary mapping the quadratic residues (non-zero) to their 0-based index.
    """
    quadratic_residues = set()
    for i in range(1, p):
        quadratic_residues.add((i * i) % p)

    # Sort the quadratic residues and create a mapping to 0-based indices
    sorted_residues = sorted(list(quadratic_residues))
    residue_to_index = {residue: i for i, residue in enumerate(sorted_residues)}

    return residue_to_index

In [4]:
# Quick check of this.
number_quadratic_residues(11)

{1: 0, 3: 1, 4: 2, 5: 3, 9: 4}

In [5]:
# Since they're sorted, the keys give a convenient way to get the i'th residue.
list(number_quadratic_residues(11).keys())

[1, 3, 4, 5, 9]
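As a cross-check on the numbering, here's a small sketch that lists the nonzero quadratic residues using Euler's criterion instead of squaring everything: for an odd prime $p$ and $r$ not divisible by $p$, $r$ is a quadratic residue exactly when $r^{(p-1)/2} \equiv 1 \pmod{p}$. The function names are only for illustration, and this assumes $p$ is an odd prime.

```python
def is_quadratic_residue(r, p):
    """Euler's criterion (assumes p is an odd prime and r % p != 0)."""
    return pow(r, (p - 1) // 2, p) == 1

def residues_by_euler(p):
    """Nonzero quadratic residues mod p, in increasing order."""
    return [r for r in range(1, p) if is_quadratic_residue(r, p)]

print(residues_by_euler(11))
```

An odd prime $p$ has $(p-1)/2$ nonzero quadratic residues, so this list should have the same length as the numbering dictionary.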

The modulus $m$ is the product of many (smallish) primes.
For each pair of primes $p_i$ and $p_j$, we'll need to find the possible values of $a^2$ and $b^2$ modulo $p_i p_j$.

In [6]:
def quadratic_residue_table(N, p1, p2, qr_numbering):
    """
    Finds pairs (sa, sb) of quadratic residues modulo p1*p2 such that N = sa - sb (mod p1*p2).

    Args:
        N: The integer to factor.
        p1: The first prime.
        p2: The second prime.
        qr_numbering: A dictionary mapping primes to dictionaries,
           each of which maps quadratic residues to their 0-based index
           for that prime.

    Returns:
        A list of unique 4-tuples of the *indices* (in `qr_numbering`)
        of `(sa % p1, sa % p2, sb % p1, sb % p2)` satisfying the
        conditions.
    """
    qr_1 = qr_numbering[p1]
    qr_2 = qr_numbering[p2]
    m = p1 * p2
    solutions = set()
    for a in range(m):
        sa = (a * a) % m
        for b in range(m):
            sb = (b * b) % m
            if (sa - sb) % m == N % m:
                try:
                    solutions.add(
                        (qr_1[ sa % p1 ], qr_2[ sa % p2 ],
                         qr_1[ sb % p1 ], qr_2[ sb % p2 ]))
                except KeyError:
                    # A KeyError here means that sa or sb is 0 mod p1
                    # or p2; 0 is excluded from the residue numbering.
                    pass
    return list(solutions)

In [7]:
quadratic_residue_table(3*5, 11, 17,
  {p: number_quadratic_residues(p) for p in [11,17]})

[(4, 5, 3, 6),
 (3, 1, 0, 2),
 (3, 5, 0, 6),
 (3, 7, 0, 0),
 (4, 7, 3, 0),
 (4, 1, 3, 2)]
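The double loop over `range(m)` costs on the order of $m^2$ steps. By the Chinese Remainder Theorem, for distinct primes the condition modulo $p_1 p_2$ holds exactly when it holds modulo $p_1$ and modulo $p_2$ separately, so the same table can (I believe) be built from two much smaller per-prime scans. This is a sketch under that assumption; `residue_index`, `per_prime_pairs`, and `table_via_crt` are illustrative names, not part of the code above.

```python
import itertools

def residue_index(p):
    """Sorted nonzero quadratic residues mod p, mapped to 0-based indices."""
    residues = sorted({(i * i) % p for i in range(1, p)})
    return {r: i for i, r in enumerate(residues)}

def per_prime_pairs(N, p):
    """Index pairs (ia, ib) of nonzero residues with sa - sb == N (mod p)."""
    idx = residue_index(p)
    return [(idx[sa], idx[sb])
            for sa in idx for sb in idx
            if (sa - sb) % p == N % p]

def table_via_crt(N, p1, p2):
    """Combine per-prime pairs into (ia1, ia2, ib1, ib2) index tuples."""
    return [(ia1, ia2, ib1, ib2)
            for (ia1, ib1), (ia2, ib2) in itertools.product(
                per_prime_pairs(N, p1), per_prime_pairs(N, p2))]

print(sorted(table_via_crt(3 * 5, 11, 17)))
```

For $N = 15$ and the primes 11 and 17, this gives the same six tuples as the output above (in a different order).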

In a previous version of this, I was having trouble defining variables with different numbers of possible values using PGMax. Thus, here we just assume the small primes are about the same size, and use the maximum. (The factors will be sparse, anyway.)
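To see how much padding "use the maximum" implies, here's a quick sketch that counts the real states per prime and the padded state count. It assumes all the primes are odd (so each has $(p-1)/2$ nonzero quadratic residues); `padded_state_counts` is just an illustrative helper.

```python
def padded_state_counts(primes):
    """Number of nonzero quadratic residues per prime, and the shared maximum."""
    counts = {p: (p - 1) // 2 for p in primes}   # assumes odd primes
    return counts, max(counts.values())

counts, num_states = padded_state_counts([7, 11, 13, 17])
# States with index >= counts[p] are padding for prime p.
unused = {p: num_states - c for p, c in counts.items()}
print(counts, num_states, unused)
```

The padded indices never appear in any factor's list of valid configurations, which is the sense in which the factors are sparse.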

In [8]:
### NOTE: this cell is deprecated (it wasn't working) and is kept only for reference.
# Trying to build the factors.
# First, create the variables.
primes = [7,11,13]
var_names = [f"a_{p}" for p in primes] + [f"b_{p}" for p in primes]
variable_group = vgroup.VarDict(
    num_states = np.array(primes + primes),
    variable_names = var_names
)
print(np.array(primes+primes))
print(var_names)
# Initialize the factor graph.
fg = fgraph.FactorGraph(variable_groups=variable_group)

# Try to add one of the factors.
print(variable_group["a_7"])
print("printed")

factor_group = pgmax.fgroup.EnumFactorGroup(
  variables_for_factors=[variable_group[s]
    for s in ["a_7", "a_11", "b_7", "b_11"]],
  factor_configs=np.eye(4),
)

fg.add_factors(factor_group)
fg

[ 7 11 13  7 11 13]
['a_7', 'a_11', 'a_13', 'b_7', 'b_11', 'b_13']
(np.int64(957625600000000000), np.int64(7))
printed


FactorGraph(variable_groups=[VarDict(num_states=array([ 7, 11, 13,  7, 11, 13]), _hash=957625600000000000, variable_names=['a_7', 'a_11', 'a_13', 'b_7', 'b_11', 'b_13'])])

In [9]:
qr = {p: number_quadratic_residues(p) for p in primes}
[p for p in qr]

[7, 11, 13]

In [10]:
def build_factor_graph(N, primes):
    """
    Builds a factor graph for finding pairs of quadratic residues (sa, sb)
    modulo prod(primes) such that N = sa - sb (mod pi*pj),
    for all possible combinations of primes (pi, pj).

    Args:
        N: The integer to factor.
        primes: A list of small primes.

    Returns a tuple (qr_numbering, variables, fg):
        qr_numbering: A dictionary mapping primes to dictionaries,
           each of which maps quadratic residues to their 0-based index
           for that prime.
        variables: A variable group
        fg: A PGMax factor graph.
    """
    # Standardize list of primes. (FIXME check that they're prime?)
    primes = sorted(list(set(primes)))
    print(f"in build_factor_graph: primes = {primes}")
    # Standard numbering of quadratic residues.
    qr_numbering = {p: number_quadratic_residues(p) for p in primes}
    # Create variables, with one row each for a and b, and
    # one column for each prime.
    num_residues = max([len(qr_numbering[p]) for p in primes])
    variables = vgroup.NDVarArray(
        num_states=num_residues, shape=(2, len(primes)))
    # Initialize the factor graph.
    fg = fgraph.FactorGraph(variable_groups=variables)
    # Add a factor for each pair of primes.
    for (i, j) in itertools.combinations(range(len(primes)), 2):
        p_i = primes[i]
        p_j = primes[j]
        matching_qr = quadratic_residue_table(
            N, p_i, p_j, qr_numbering)
        # print((p_i, p_j))
        # print(np.array(matching_qr))
        vars_for_factor = [(0,i), (0,j), (1,i), (1,j)]
        factor = pgmax.factor.EnumFactor(
            variables=[variables[s] for s in vars_for_factor],
            factor_configs=np.array(matching_qr),
            log_potentials=np.zeros(len(matching_qr))
        )
        fg.add_factors(factor)
    return qr_numbering, variables, fg

## Searching for $a$ and $b$ with a common factor

So far, we've solved for $a^2$ and $b^2$; but squaring has lost the information about what $a$ and $b$ are.

In the original [quadratic sieve](https://en.wikipedia.org/wiki/Quadratic_sieve) method, we find $a$ and $b$, then find the greatest common divisor of $N$ with $a+b$ and $a-b$.
In this case, we only know $a^2 \pmod{p_i}$, so we need to consider both square roots $\pm a \pmod{p_i}$. We need to consider these for all $p_i$ (and similarly for $b$).

Empirically, amongst all of these, there's often been a nontrivial GCD with $N$. (FIXME prove this?) So the cost is exponential in the number of primes $p_i$ (with $k$ primes there are $2^k$ sign choices for $a$, and likewise for $b$); but it is at least deterministic.

Open question: is it enough to just check whether, for some $A$ and $B$, $A+B$ divides $N$?
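To make the "all sign choices" step concrete, here is a sketch in plain Python that combines one square root per prime via the Chinese Remainder Theorem and enumerates every combination. It's an illustration only: `crt` and `all_lifts` are hypothetical helpers (not the `modulo` library's API), and `pow(m, -1, p)` needs Python 3.8 or later.

```python
import itertools

def crt(residues, moduli):
    """Solve x == residues[i] (mod moduli[i]) for pairwise coprime moduli.

    Returns x with 0 <= x < product(moduli).
    """
    x, m = 0, 1
    for r, p in zip(residues, moduli):
        # Pick t so that x + m*t == r (mod p).
        t = ((r - x) * pow(m, -1, p)) % p
        x += m * t
        m *= p
    return x

def all_lifts(roots_per_prime, moduli):
    """One CRT solution for every way of choosing a root per prime."""
    return [crt(choice, moduli)
            for choice in itertools.product(*roots_per_prime)]

# Square roots +/-a per prime, for the primes 7, 11, 13, 17.
lifts = all_lifts([[2, 5], [1, 10], [5, 8], [2, 15]], [7, 11, 13, 17])
print(len(lifts))
```

With $k$ primes there are $2^k$ lifts for $a$ and $2^k$ for $b$, so the GCD search looks at up to $4^k$ pairs.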

In [11]:
def solve_mod_primes(x_mod, primes):
    """Given what x is (mod some primes), solve for x.

    x_mod: an array of small integers, such that x % primes[i] == x_mod[i]
    primes: an array of primes

    Returns: x, in the range 1 <= x <= product(primes),
        satisfying x % primes[i] == x_mod[i].
    """
    x = modulo(x_mod[0], primes[0])
    for i in range(1, len(primes)):
        x &= modulo(x_mod[i], primes[i])
    return int(x)


In [12]:
# a quick check of this (note the argument order: residues first, then primes)
print(primes)
x = solve_mod_primes([2,3,4], primes) % np.prod(primes)
print(x)
print(x % primes)

[7, 11, 13]
212
[2 3 4]


In [13]:
# some itertools practice
a = [[1,-1], [4,-4], [5,-5]]
list(itertools.product(*a))

[(1, 4, 5),
 (1, 4, -5),
 (1, -4, 5),
 (1, -4, -5),
 (-1, 4, 5),
 (-1, 4, -5),
 (-1, -4, 5),
 (-1, -4, -5)]

In [14]:
def search_for_common_factors(qr_numbering, N, ab):
    """Searches for a factor of N, based on the values of a^2 and b^2.

    This finds all A = +/- a mod p_i (and similarly for B),
    mod all the p_i, computes GCD(N, A+B), and checks whether it's
    nontrivial (that is, not 1 or N.)

    qr_numbering: the numbering of quadratic residues
    N: the number to factor
    ab: an array of shape (2, len(primes)), containing the
      indices of a^2 and b^2
    """
    primes = list(qr_numbering.keys())
    print(f"in search_for_common_factors: primes={primes}")
    print(f"ab = {ab}")
    def square_root_mod(ix2, p):
        """Given the index of x^2, computes +/-x (mod p).

        Here, "ix2" refers to the index of x2 in the
        numbering of quadratic residues."""
        x2 = list(qr_numbering[p].keys())[ix2]
        x = [i for i in range(1, p) if (i * i) % p == x2]
        print(f"p={p}, ix2={ix2}, x2={x2}, x={x}")
        return x
    # Get +/- a and b
    # pdb.set_trace()
    a = [square_root_mod(i, p) for (i,p) in zip(ab[0], primes)]
    b = [square_root_mod(i, p) for (i,p) in zip(ab[1], primes)]
    print(a)
    print(b)
    # get all possible a +/- b ("generalized", for however
    # many prime factors)
    a_mod_m = [solve_mod_primes(a1, primes)
        for a1 in itertools.product(*a)]
    b_mod_m = [solve_mod_primes(b1, primes)
        for b1 in itertools.product(*b)]
    # check GCD of each of these
    f = None
    for (a1,b1) in itertools.product(a_mod_m, b_mod_m):
        g = math.gcd(N, a1+b1)
        if g != 1 and g != N:
            f = g
            print(f"f = {f}")
            break
    # f stays None if no nontrivial factor was found
    return {"f": f,
            "a": a,
            "b": b}

# Task
Create a Python function that takes an integer `N` and a list of small primes `primes` as input. The function should build a factor graph using the PGMax library to solve the congruence $N \equiv a^2 - b^2 \pmod{m}$, where $m$ is the product of the primes in the input list. The function should then use inference on the factor graph to find values for $a$ and $b$ modulo $m$, and finally use these values to find factors of $N$ by computing $\text{GCD}(a+b, N)$. The function should return the found factors of $N$.

## Define the main function

### Subtask:
Create a function that accepts the number to factor `N` and a list of small primes `primes`.


**Reasoning**:
Define the `factor_with_pgmax` function with the specified arguments and docstring.



In [15]:
def factor_with_pgmax(N, primes, num_iters=3000, seed=0):
    """
    Factors an integer N using the PGMax library by solving the congruence
    N = a^2 - b^2 (mod m), where m is the product of the given primes.

    Args:
        N: The integer to factor.
        primes: A list of small primes whose product forms the modulus m.
        num_iters: The number of iterations to run the inference algorithm.
        seed: The seed for the random number generator.

    Returns:
        The dictionary returned by search_for_common_factors: "f" is the
        GCD found by the search (nontrivial if the search succeeded), and
        "a" and "b" list the candidate square roots modulo each prime.
    """
    qr_numbering, variables, fg = build_factor_graph(N, primes)
    bp = infer.build_inferer(fg.bp_state, backend="bp")

    # Initialize PRNG key
    key = jax.random.PRNGKey(seed)
    # Random Gumbel evidence; this is basically as in the Ising example.
    num_states = fg.variable_groups[0].num_states[0,0]
    evidence_updates={variables: jax.random.gumbel(
        key,
        shape=list(fg.variable_groups[0].shape) + [num_states]
    )}
    # Run MAP inference
    inferer_arrays = bp.init(evidence_updates=evidence_updates)
    # inferer_arrays = bp.init()
    inferer_arrays, msgs_deltas = bp.run_with_diffs(inferer_arrays, num_iters=num_iters, temperature=0)
    # Compute the beliefs
    beliefs = bp.get_beliefs(inferer_arrays)
    print("computed beliefs")
    # Get the MAP states
    map_states = infer.decode_map_states(beliefs)
    print("decoded map states")
    # Compute the energy of the decoding
    # FIXME I think this is throwing a FutureWarning; I'm not sure why.
    decoding_energy = (
        infer.compute_energy(fg.bp_state, inferer_arrays, map_states)[0]
    )
    # return map_states[variables], decoding_energy
    ab = np.array(map_states[variables])
    print(f"map_states[variables] = {ab}")
    print(f"type(ab) = {type(ab)}")
    # search for a factor, based on this
    return search_for_common_factors(qr_numbering, N, ab)


In [16]:
# quick test of this
factor_with_pgmax(19*23, [7,11,13,17])

in build_factor_graph: primes = [7, 11, 13, 17]
computed beliefs
decoded map states




map_states[variables] = [[2 0 5 2]
 [0 2 2 4]]
type(ab) = <class 'numpy.ndarray'>
in search_for_common_factors: primes=[7, 11, 13, 17]
ab = [[2 0 5 2]
 [0 2 2 4]]
p=7, ix2=2, x2=4, x=[2, 5]
p=11, ix2=0, x2=1, x=[1, 10]
p=13, ix2=5, x2=12, x=[5, 8]
p=17, ix2=2, x2=4, x=[2, 15]
p=7, ix2=0, x2=1, x=[1, 6]
p=11, ix2=2, x2=4, x=[2, 9]
p=13, ix2=2, x2=4, x=[2, 11]
p=17, ix2=4, x2=9, x=[3, 14]
[[2, 5], [1, 10], [5, 8], [2, 15]]
[[1, 6], [2, 9], [2, 11], [3, 14]]
f = 23


{'f': 23,
 'a': [[2, 5], [1, 10], [5, 8], [2, 15]],
 'b': [[1, 6], [2, 9], [2, 11], [3, 14]]}
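Since loopy BP isn't guaranteed to converge to a valid assignment, it seems worth checking the decoded indices directly. This sketch recomputes the residue numbering in plain Python and tests $a^2 - b^2 \equiv N \pmod{p}$ for each prime; `check_decoding` is an illustrative helper, and the `ab` values are copied from the run above.

```python
def residues_sorted(p):
    """Nonzero quadratic residues mod p, in increasing order."""
    return sorted({(i * i) % p for i in range(1, p)})

def check_decoding(N, primes, ab):
    """Per prime, report whether the decoded a^2, b^2 satisfy the congruence."""
    ok = {}
    for k, p in enumerate(primes):
        qr = residues_sorted(p)
        ia, ib = ab[0][k], ab[1][k]
        # Padded (out-of-range) indices don't correspond to any residue.
        if ia >= len(qr) or ib >= len(qr):
            ok[p] = False
            continue
        ok[p] = (qr[ia] - qr[ib]) % p == N % p
    return ok

print(check_decoding(19 * 23, [7, 11, 13, 17],
                     [[2, 0, 5, 2], [0, 2, 2, 4]]))
```

For this run, every prime passes, so the MAP state is a genuine solution of the first-phase problem.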

## Summary
This found a factor of 19*23, at least when I ran it.

There isn't any good reason to think that the first-phase SAT problem is any easier than, say, using a SAT solver on factoring encoded using a multiplication circuit. Nonetheless, it might be interesting to see how LBP does on it (empirically).