In [2]:
from typing import Optional
import numpy as np

# Inclusive bounds of a signed 32-bit integer (-2**31 and 2**31 - 1).
INT32_MIN, INT32_MAX = np.iinfo(np.int32).min, np.iinfo(np.int32).max


def uniform_sample_int32(size: int) -> np.ndarray:
    """Draw ``size`` int32 values uniformly over [INT32_MIN, INT32_MAX]."""
    # np.random.randint treats `high` as exclusive, so step one past the max.
    upper_exclusive = INT32_MAX + 1
    return np.random.randint(INT32_MIN, upper_exclusive, size=size, dtype=np.int32)


def gaussian_sample_int32(std: float, size: Optional[int]) -> np.ndarray:
    """Sample Gaussian noise scaled to the int32 range.

    Draws from N(0, std**2), multiplies by INT32_MAX, truncates toward zero
    and converts to int32.

    Args:
        std: Standard deviation, expressed as a fraction of INT32_MAX.
        size: Number of samples, or None for a single scalar sample.

    Returns:
        An int32 ndarray with ``size`` entries, or an ``np.int32`` scalar when
        ``size`` is None.
    """
    scaled = np.trunc(INT32_MAX * np.random.normal(loc=0.0, scale=std, size=size))
    # A direct float -> int32 cast is undefined for values outside the int32
    # range (reachable when std is large). Reduce modulo 2**32 into
    # [-2**31, 2**31) first so such samples wrap like all other int32 math.
    wrapped = np.mod(scaled, 2.0**32)
    wrapped = np.where(wrapped >= 2.0**31, wrapped - 2.0**32, wrapped)
    # `[()]` turns a 0-d array into a scalar and leaves n-d arrays unchanged.
    return wrapped.astype(np.int32)[()]

# Sanity check: the int32 bounds and one draw from each sampler.
print(INT32_MIN)
print(INT32_MAX)
print(uniform_sample_int32(size=1))
print(gaussian_sample_int32(std=0.01, size=1))

-2147483648
2147483647
[1281071357]
[18756883]


In [3]:
import dataclasses
import numpy as np

@dataclasses.dataclass
class LweConfig:
    """Parameters shared by an LWE key and the ciphertexts made with it.

    Attributes:
        dimension: Size of the LWE encryption key (and of each ciphertext's
            ``a`` vector).
        noise_std: Standard deviation of the encryption noise, as a fraction
            of INT32_MAX (it is passed to ``gaussian_sample_int32``).
    """

    dimension: int
    noise_std: float

@dataclasses.dataclass
class LwePlaintext:
    """An LWE plaintext: a single message value (int32 by convention)."""

    message: np.int32


@dataclasses.dataclass
class LweCiphertext:
    """An LWE ciphertext (a, b) together with the config it was made under."""

    config: LweConfig
    a: np.ndarray  # An int32 array of size config.dimension
    b: np.int32

    def __eq__(self, other: object) -> bool:
        # The dataclass-generated __eq__ compares the fields as a tuple, which
        # raises "truth value of an array ... is ambiguous" for the ndarray
        # field `a`. Compare `a` element-wise instead.
        if other.__class__ is not self.__class__:
            return NotImplemented
        return bool(
            self.config == other.config
            and np.array_equal(self.a, other.a)
            and self.b == other.b
        )


@dataclasses.dataclass
class LweEncryptionKey:
    """An LWE secret key together with the config it was generated for."""

    config: LweConfig
    # An int32 array of size config.dimension (entries in {0, 1} when produced
    # by generate_lwe_key).
    key: np.ndarray

    def __eq__(self, other: object) -> bool:
        # The dataclass-generated __eq__ would compare the ndarray `key` inside
        # a tuple and raise "truth value of an array ... is ambiguous".
        if other.__class__ is not self.__class__:
            return NotImplemented
        return bool(
            self.config == other.config and np.array_equal(self.key, other.key)
        )


def generate_lwe_key(config: LweConfig) -> LweEncryptionKey:
    """Sample a fresh binary LWE secret key for ``config``.

    Each of the ``config.dimension`` entries is drawn uniformly from {0, 1}
    and stored as int32.
    """
    secret_bits = np.random.randint(0, 2, size=(config.dimension,), dtype=np.int32)
    return LweEncryptionKey(config=config, key=secret_bits)


def lwe_encrypt(
    plaintext: LwePlaintext, key: LweEncryptionKey
) -> LweCiphertext:
    """Encrypt ``plaintext`` under ``key``.

    Draws a uniform int32 mask ``a`` and a Gaussian noise term ``e``, then
    returns (a, b) with b = <a, key> + message + e. Every sum is done with
    int32 dtype, so it wraps modulo 2**32.
    """
    cfg = key.config
    # The mask is drawn before the noise; keep this order so seeded runs
    # consume the random stream identically.
    mask = uniform_sample_int32(size=cfg.dimension)
    error = gaussian_sample_int32(std=cfg.noise_std, size=None)

    masked_key = np.dot(mask, key.key)
    with_message = np.add(masked_key, plaintext.message, dtype=np.int32)
    body = np.add(with_message, error, dtype=np.int32)
    return LweCiphertext(config=cfg, a=mask, b=body)


def lwe_decrypt(
    ciphertext: LweCiphertext, key: LweEncryptionKey
) -> LwePlaintext:
    """Decrypt ``ciphertext`` with ``key``.

    Returns b - <a, key> computed with int32 dtype. For a ciphertext from
    ``lwe_encrypt`` this is message + noise; the noise is not removed here.
    """
    inner = np.dot(ciphertext.a, key.key)
    noisy_message = np.subtract(ciphertext.b, inner, dtype=np.int32)
    return LwePlaintext(noisy_message)

In [4]:
# The Lattice Estimator (https://github.com/malb/lattice-estimator) can be used
# to assess the security level of LWE parameters such as these.
LWE_CONFIG = LweConfig(dimension=1024, noise_std=2**(-24))

# Generate an LWE key.
key = generate_lwe_key(LWE_CONFIG)

# This is the plaintext that we will encrypt. It is stored as np.int32 to match
# the LwePlaintext.message annotation and the int32 arithmetic of the scheme.
plaintext = LwePlaintext(np.int32(2**29))

# Encrypt the plaintext NUM_TRIALS times and store the error of each ciphertext.
# The error is computed with int32 dtype so it wraps modulo 2**32 like every
# other LWE operation, regardless of the message's Python/NumPy type.
NUM_TRIALS = 1000
errors = []
for _ in range(NUM_TRIALS):
    ciphertext = lwe_encrypt(plaintext, key)
    decrypted = lwe_decrypt(ciphertext, key)
    errors.append(np.subtract(decrypted.message, plaintext.message, dtype=np.int32))

In [5]:
def encode(i: int) -> np.int32:
    """Encode an integer in [-4, 4) as an int32.

    The int32 range is split into 8 equal slots of width 2**29, and ``i`` is
    placed at the start of its slot. The product is formed with int32 dtype,
    so it wraps modulo 2**32.
    """
    slot_width = 2**29
    return np.multiply(i, slot_width, dtype=np.int32)


def decode(i: np.int32) -> int:
    """Decode an int32 to an integer in the range [-4, 4) mod 8.

    Rounds ``i / 2**29`` to the nearest integer (ties to even, via np.rint)
    and reduces the result into [-4, 4).
    """
    nearest = int(np.rint(i / 2**29))
    # Shift into [0, 8), reduce mod 8, then shift back down by 4.
    return (nearest + 4) % 8 - 4

def lwe_encode(i: int) -> LwePlaintext:
    """Wrap ``encode(i)`` in an LwePlaintext; ``i`` should lie in [-4, 4)."""
    encoded = encode(i)
    return LwePlaintext(encoded)


def lwe_decode(plaintext: LwePlaintext) -> int:
    """Round an LWE plaintext's message to an integer in [-4, 4) mod 8."""
    message = plaintext.message
    return decode(message)

In [6]:
def lwe_add(
    ciphertext_left: LweCiphertext,
    ciphertext_right: LweCiphertext) -> LweCiphertext:
    """Homomorphic addition evaluation.

       If ciphertext_left is an encryption of m_left and ciphertext_right is
       an encryption of m_right, the component-wise int32 sum of (a, b) is an
       encryption of m_left + m_right. The noise terms add as well.
    """
    summed_a = np.add(ciphertext_left.a, ciphertext_right.a, dtype=np.int32)
    summed_b = np.add(ciphertext_left.b, ciphertext_right.b, dtype=np.int32)
    # The result carries the left operand's config; configs are not compared.
    return LweCiphertext(ciphertext_left.config, summed_a, summed_b)

def lwe_subtract(
    ciphertext_left: LweCiphertext,
    ciphertext_right: LweCiphertext) -> LweCiphertext:
    """Homomorphic subtraction evaluation.

       If ciphertext_left is an encryption of m_left and ciphertext_right is
       an encryption of m_right, the component-wise int32 difference of (a, b)
       is an encryption of m_left - m_right.
    """
    diff_a = np.subtract(ciphertext_left.a, ciphertext_right.a, dtype=np.int32)
    diff_b = np.subtract(ciphertext_left.b, ciphertext_right.b, dtype=np.int32)
    # The result carries the left operand's config; configs are not compared.
    return LweCiphertext(ciphertext_left.config, diff_a, diff_b)

In [13]:
LWE_CONFIG = LweConfig(dimension=1024, noise_std=2**(-24))

# Fresh key for the homomorphic-addition demo.
key = generate_lwe_key(LWE_CONFIG)

m_1 = lwe_encode(1)
m_2 = lwe_encode(2)

c_1 = lwe_encrypt(m_1, key)
c_2 = lwe_encrypt(m_2, key)

# Add the two ciphertexts homomorphically and decrypt the result.
c_3 = lwe_add(c_1, c_2)
m_3 = lwe_decrypt(c_3, key)

# Messages live in the integers mod 8, represented in [-4, 4), so the
# expected plaintext sum must be reduced the same way (e.g. 3 + 2 -> -3).
expected = (lwe_decode(m_1) + lwe_decode(m_2) + 4) % 8 - 4
print(expected)
print(lwe_decode(m_3))
assert lwe_decode(m_3) == expected, "homomorphic addition decrypted incorrectly"

3
3
