# CS295/395: Secure Distributed Computation
## Homework 8

## Useful Definitions for the FV FHE Scheme

Reference: J. Fan and F. Vercauteren, [Somewhat Practical Fully Homomorphic Encryption](https://eprint.iacr.org/2012/144.pdf) (2012).

In [None]:
import random
import numpy as np
from collections import defaultdict
from collections import namedtuple

#q = 2**13 - 1
q = 2**32
p = q**3

t = 2

def noise():
    return random.randint(-5, 5) % q

def keygen():
    s = random.randint(0, t-1)
    a = random.randint(1, q-1)
    e = noise()
    pk = (-(a*s+e)%q, a)
    return s,pk

def eval_keygen(sk):
    s = sk
    a = random.randint(1, p*q-1)
    e = noise()
    rlk = (-(a*s + e) + p * s**2) % (p*q)
    return (rlk, a)

def encrypt(pk, m):
    p0, p1 = pk
    u = random.randint(0, t-1)
    e1 = noise()
    e2 = noise()
    Delta = q // t  # floor(q / t), computed exactly with integer arithmetic
    ct1 = (p0*u + e1 + Delta*m) % q
    ct2 = (p1 * u + e2) % q
    return (ct1, ct2)

def decrypt(sk, ct):
    s = sk
    c0, c1 = ct
    #print('Decrypt, before rounding:', t * ((c0 + c1 * s) % q) / q)
    m = round(t * ((c0 + c1 * s) % q) / q) % t
    return m

def e_add(ct1, ct2):
    o1 = (ct1[0] + ct2[0]) % q
    o2 = (ct1[1] + ct2[1]) % q
    return (o1, o2)

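def round_div(a, b):
    # Round a / b to the nearest integer using exact integer arithmetic (b > 0).
    # Float division (round(a / b)) loses precision once a / b exceeds 2**53,
    # which happens in the relinearization step below (c2 * rlk[0] / p is up to ~2**64).
    return (2 * a + b) // (2 * b)

def e_mul(ct1, ct2, rlk):
    # multiplication: scale the tensor product by t/q and round
    c0 = round_div(t * (ct1[0] * ct2[0]), q) % q
    c1 = round_div(t * (ct1[0] * ct2[1] + ct1[1] * ct2[0]), q) % q
    c2 = round_div(t * (ct1[1] * ct2[1]), q) % q

    # degree reduction (relinearization): fold the c2 * s**2 term back into two components using rlk
    c20 = round_div(c2 * rlk[0], p) % q
    c21 = round_div(c2 * rlk[1], p) % q

    return ((c0 + c20) % q, (c1 + c21) % q)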

AddGate = namedtuple('AddGate', ['in1', 'in2'])
MultGate = namedtuple('MultGate', ['in1', 'in2'])

largest_wire = 0
def new_wire():
    global largest_wire
    largest_wire += 1
    return f'w{largest_wire}'
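
The definitions above implement a toy, scalar version of FV: the secret key, plaintexts, and ciphertext components are single integers rather than polynomials, so the following is a summary of this code, not of the paper's general construction. Writing $[x]_q$ for reduction into $[0, q)$ (what Python's `%` computes), $\lfloor \cdot \rceil$ for rounding to the nearest integer, $p = q^3$ for the relinearization factor, and $a', e'$ for the fresh values drawn in `eval_keygen`, the code computes:

$$
\begin{aligned}
\Delta &= \lfloor q/t \rfloor, \qquad s \in \{0, \dots, t-1\}, \qquad e, e', e_1, e_2 \in \{-5, \dots, 5\} \\
\mathsf{pk} &= (p_0, p_1) = \big([-(a s + e)]_q,\ a\big), \qquad a \in \{1, \dots, q-1\} \\
\mathsf{rlk} &= (\mathsf{rlk}_0, \mathsf{rlk}_1) = \big([-(a' s + e') + p s^2]_{pq},\ a'\big), \qquad a' \in \{1, \dots, pq-1\} \\
\mathsf{Enc}(m) &= (c_0, c_1) = \big([p_0 u + e_1 + \Delta m]_q,\ [p_1 u + e_2]_q\big), \qquad u \in \{0, \dots, t-1\} \\
\mathsf{Dec}(c_0, c_1) &= \Big[\Big\lfloor \tfrac{t\,[c_0 + c_1 s]_q}{q} \Big\rceil\Big]_t
\end{aligned}
$$

Substituting $\mathsf{pk}$ gives $c_0 + c_1 s \equiv \Delta m + v \pmod q$ with noise $v = e_1 + e_2 s - e u$. When $t$ divides $q$ (as with $q = 2^{32}$, $t = 2$), decryption returns $m$ as long as $|v| < \Delta/2$. `e_add` adds the noise terms of its inputs. For ciphertexts $(c_0, c_1)$ and $(c_0', c_1')$, `e_mul` computes

$$
\begin{aligned}
\tilde c_0 &= \Big[\Big\lfloor \tfrac{t\, c_0 c_0'}{q} \Big\rceil\Big]_q, \qquad
\tilde c_1 = \Big[\Big\lfloor \tfrac{t\,(c_0 c_1' + c_1 c_0')}{q} \Big\rceil\Big]_q, \qquad
\tilde c_2 = \Big[\Big\lfloor \tfrac{t\, c_1 c_1'}{q} \Big\rceil\Big]_q \\
\mathsf{ct}_{\times} &= \Big(\Big[\tilde c_0 + \Big\lfloor \tfrac{\tilde c_2\, \mathsf{rlk}_0}{p} \Big\rceil\Big]_q,\ \Big[\tilde c_1 + \Big\lfloor \tfrac{\tilde c_2\, \mathsf{rlk}_1}{p} \Big\rceil\Big]_q\Big)
\end{aligned}
$$

Because $\mathsf{rlk}_0 + \mathsf{rlk}_1 s \equiv p s^2 - e' \pmod{pq}$, dividing by $p$ leaves approximately $\tilde c_2 s^2$, with only a small extra error on the order of $\tilde c_2 e'/p$ plus rounding.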

## Question 1 (20 points)

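Implement a function `fv_eval_circuit` that evaluates a circuit homomorphically, given ciphertexts for its input wires. It should compute a ciphertext for each wire in `eval_order` and return `wire_values` with those ciphertexts added.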

In [None]:
def fv_eval_circuit(circuit, eval_order, wire_values, rlk):
    # YOUR CODE HERE
    raise NotImplementedError()
            
    return wire_values

In [None]:
# TEST CASE: add x and y

# Generate keys
sk, pk = keygen()
rlk = eval_keygen(sk)

# Encrypt inputs
inputs = {'x': encrypt(pk, 0),
          'y': encrypt(pk, 1)}

# Build circuit
circ = {'w1': AddGate('x', 'y')}
eval_order = ['w1']

# Run the circuit
wire_values = fv_eval_circuit(circ, eval_order, inputs, rlk)
output = wire_values['w1']

print(decrypt(sk, output))

In [None]:
# TEST CASE: Multiply x by itself 50 times

# Generate keys
sk, pk = keygen()
rlk = eval_keygen(sk)

# Encrypt inputs
inputs = {'x': encrypt(pk, 1)}

# Build circuit
circ = {}
last_wire = 'x'
for i in range(50):
    next_wire = new_wire()
    circ[next_wire] = MultGate(last_wire, 'x')
    last_wire = next_wire
eval_order = list(circ.keys())
output_wire = last_wire

# Run the circuit
wire_values = fv_eval_circuit(circ, eval_order, inputs, rlk)
output = wire_values[output_wire]

print(decrypt(sk, output))
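
Both test cases build `eval_order` as `list(circ.keys())`. That works because Python dictionaries preserve insertion order (guaranteed since Python 3.7) and each gate is inserted after the gates that feed it. If a circuit builder ever inserts gates out of dependency order, a topological sort recovers a valid order. The sketch below uses Kahn's algorithm; `topo_order` and `DemoGate` are hypothetical names introduced here, and `DemoGate` only stands in for `AddGate`/`MultGate` (same `in1`/`in2` fields) so that running this cell does not redefine the notebook's gate classes.

```python
from collections import deque, namedtuple

# Hypothetical stand-in for AddGate / MultGate: only the in1 / in2 fields matter here.
DemoGate = namedtuple('DemoGate', ['in1', 'in2'])

def topo_order(circuit):
    """Hypothetical helper: order gate wires so each gate comes after the gates feeding it.

    Wires that are not keys of `circuit` (e.g. 'x', 'y', 'zero') are treated as inputs.
    Raises ValueError if the circuit contains a cycle.
    """
    indegree = {w: 0 for w in circuit}   # number of gate (non-input) operands per wire
    users = {w: [] for w in circuit}     # gates that read each gate wire
    for w, gate in circuit.items():
        for src in (gate.in1, gate.in2):
            if src in circuit:
                indegree[w] += 1
                users[src].append(w)

    ready = deque(w for w, d in indegree.items() if d == 0)
    order = []
    while ready:
        w = ready.popleft()
        order.append(w)
        for u in users[w]:
            indegree[u] -= 1
            if indegree[u] == 0:         # all operands of u are now available
                ready.append(u)

    if len(order) != len(circuit):
        raise ValueError("circuit has a cycle")
    return order

# Gates deliberately inserted out of dependency order:
demo_circ = {'w2': DemoGate('w1', 'x'), 'w1': DemoGate('x', 'y')}
print(list(demo_circ.keys()))  # ['w2', 'w1']  (not a valid evaluation order)
print(topo_order(demo_circ))   # ['w1', 'w2']
```

Kahn's algorithm is iterative, so unlike a recursive depth-first search it does not run into Python's recursion limit on very deep circuits.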

## Question 2 (20 points)

Implement a function `add_ints` that adds two *integers* using the FV somewhat homomorphic encryption (SHE) scheme. Your solution should take two integers and a bitwidth ($n$), and:

1. Convert the two integers into bit-strings (lists of bits) using `convert_to_bitstring`
2. Generate a keypair (public, secret, and relinearization keys)
3. Encrypt each bit of the two bit-strings
4. Construct a bitstring adder circuit (reference the exercise from 9/30/2020)
5. Construct an `inputs` dictionary containing the encrypted bits from (3), plus an entry mapping `'zero'` to an encryption of the bit 0
6. Evaluate the circuit on the inputs
7. Construct an encrypted bit-string for the result from the output wire values
8. Decrypt each bit of the result from (7)
9. Convert the decrypted bit-string from (8) back into an integer and return it
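
Because the plaintext modulus is `t = 2`, the homomorphic operations act on the underlying bits modulo 2: an `AddGate` (evaluated with `e_add`) corresponds to XOR and a `MultGate` (evaluated with `e_mul`) corresponds to AND. The check below verifies that correspondence on plaintext bits only, with no encryption involved; `add_gate_plain` and `mult_gate_plain` are hypothetical names for illustration. It may help when translating the adder from step 4 into `AddGate`/`MultGate` wires.

```python
# Plaintext semantics of the two gate types when t = 2 (no encryption involved).
T = 2  # plaintext modulus, matching t in the definitions cell

def add_gate_plain(a, b):
    """What an AddGate computes on the underlying plaintext bits."""
    return (a + b) % T

def mult_gate_plain(a, b):
    """What a MultGate computes on the underlying plaintext bits."""
    return (a * b) % T

for a in (0, 1):
    for b in (0, 1):
        assert add_gate_plain(a, b) == a ^ b   # addition mod 2 is XOR
        assert mult_gate_plain(a, b) == a & b  # multiplication mod 2 is AND
print("AddGate acts as XOR and MultGate acts as AND on bits")
```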

In [None]:
def convert_to_bitstring(i, num_bits):
    # Most-significant bit first, zero-padded to num_bits
    s = f'{i:0{num_bits}b}'
    return [int(b) for b in s]

def convert_to_int(bitstring):
    return int("".join([str(v) for v in bitstring]), 2)
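
The helpers above return bits most-significant first, so index 0 holds the highest-order bit and a ripple-carry adder has to start from the last index. The sketch below repeats the two helpers so it runs on its own and shows the round trip; note that a value needing more than `num_bits` bits is not truncated, the list simply comes out longer.

```python
def convert_to_bitstring(i, num_bits):
    """Same as the notebook helper: MSB-first list of bits, zero-padded to num_bits."""
    s = f'{i:0{num_bits}b}'
    return [int(b) for b in s]

def convert_to_int(bitstring):
    """Inverse of convert_to_bitstring: interpret an MSB-first bit list as an integer."""
    return int("".join(str(v) for v in bitstring), 2)

bits = convert_to_bitstring(5, 8)
print(bits)                  # [0, 0, 0, 0, 0, 1, 0, 1]  (index 0 is the most significant bit)
print(convert_to_int(bits))  # 5

# The round trip holds for every value that fits in the bitwidth.
assert all(convert_to_int(convert_to_bitstring(v, 8)) == v for v in range(2**8))

# Values that do not fit are NOT truncated: the list just gets longer.
print(len(convert_to_bitstring(256, 8)))  # 9
```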

In [None]:
# YOUR CODE HERE
raise NotImplementedError()

def add_ints(x, y, bitwidth):
    # YOUR CODE HERE
    raise NotImplementedError()
    
# Examples
print('Adding 5 and 2, 8-bit:', add_ints(5, 2, 8))
print('Adding 5 and 2, 16-bit:', add_ints(5, 2, 16))

In [None]:
# TEST CASE
assert add_ints(5, 2, 8) == 7
assert add_ints(5, 2, 16) == 7
#assert add_ints(5, 2, 64) == 7  ## Too much noise for this one!

## Question 3 (20 points)

In 1-2 sentences each, answer the following:

- How many gates are required for the 8-bit addition? How many are required for the 16-bit addition?
- Try performing 64-bit or 128-bit addition (or even higher). At what bitwidth do you start to see wrong answers?
- Why do you sometimes get the wrong answer for larger bitwidths?
- What parameter would you change to make these errors go away? How would you change it, and why?
- Is the addition circuit "wide" or "deep"? In other words, how large is the longest path (in terms of gates) from an input to the output, relative to the total number of gates?
- Would you expect noise to be worse in the case of a "wide" circuit, or a "deep" circuit? Why?

YOUR ANSWER HERE