In [None]:
# NOTE(review): presumably SageMath's interactive reset(), which clears the
# session's user-defined names before this cell re-imports and redefines all
# helpers below — confirm this is intended if the file is ever run as a script.
reset()
import numpy as np
import networkx as nx
from numpy import linalg as LA
import random
from fpylll import IntegerMatrix,LLL,GSO

# Coordinates of a target vector with respect to a lattice basis, computed over QQ.
def find_linear_combination(basis, target):
    """Return the rational coordinates of ``target`` in ``basis``, padded with a trailing 0.

    The extra 0 is the coefficient of the superbasis vector, so the result can be
    used directly with the superbasis returned by ``get_superbasis``.

    Parameters:
        basis: sequence of basis vectors (anything ``vector`` accepts); all vectors
            are expected to have the same length (the ambient dimension).
        target: the vector to express in ``basis``.

    Returns:
        A Sage vector of ``len(basis) + 1`` rational coordinates (last entry 0);
        the string "target vector does not have the right dimension" when
        ``target`` does not live in the basis' ambient space (kept for backward
        compatibility); or None (after printing a message) when ``target`` is not
        in the span of ``basis``. Sage signals that case with ArithmeticError,
        which is caught here.
    """
    # Compare against the ambient dimension (length of the basis vectors), not
    # the number of basis vectors: the two only coincide for square bases.
    ambient_dim = len(basis[0]) if len(basis) else 0
    if len(list(target)) != ambient_dim:
        return "target vector does not have the right dimension"

    basis_vectors = [vector(x) for x in basis]
    VS = (QQ ** ambient_dim).span_of_basis(basis_vectors)
    try:
        z = VS.coordinate_vector(vector(target))
    except ArithmeticError:
        print("Oops! vector is not in the span.")
        return None
    return vector(list(z) + [0])

def create_weighted_graph_networkx(dimension, S, Q):
    """Build the capacitated graph whose source/sink minimum cut gives the next step.

    Nodes (dimension + 3 in total): 0 is the source, 1..dimension+1 are the
    superbasis vectors (node i corresponds to index i-1 of S and Q), and
    dimension+2 is the sink.

    Edges:
      * superbasis nodes i < j: capacity -Q[i-1][j-1];
      * source--node i: -S[i-1] if S[i-1] < 0, else 0;
      * node i--sink: S[i-1] if S[i-1] >= 0, else 0;
      * a zero-capacity source--sink edge.

    Parameters:
        dimension: lattice dimension n; S and Q must cover n+1 superbasis vectors.
        S: linear terms s_i = -2 * sum_j q_ij p_j (see calculate_s_2).
        Q: Gram matrix of the superbasis, indexable as Q[i][j].

    Raises:
        ValueError: if some off-diagonal Q entry is positive. That would give a
            negative capacity, and the min-cut formulation needs non-negative
            capacities (i.e. an obtuse superbasis).

    NOTE(review): this function is defined twice in this file; the later
    definition is the one in effect at runtime.
    """
    graph = nx.Graph()
    source = 0
    sink = dimension + 2
    # range() excludes its end, so use sink + 1 to add the sink explicitly.
    graph.add_nodes_from(range(source, sink + 1))
    graph.add_edge(source, sink, capacity=0.0)

    for i in range(1, dimension + 2):
        for j in range(i + 1, dimension + 2):
            capacity = -Q[i - 1][j - 1]
            if capacity < 0:
                raise ValueError(
                    f"Q[{i - 1}][{j - 1}] = {Q[i - 1][j - 1]} > 0: "
                    "superbasis is not obtuse"
                )
            graph.add_edge(i, j, capacity=capacity)

    for i in range(dimension + 1):
        node = i + 1
        if S[i] < 0:
            graph.add_edge(source, node, capacity=-S[i])
            graph.add_edge(node, sink, capacity=0.0)
        else:
            graph.add_edge(source, node, capacity=0.0)
            graph.add_edge(node, sink, capacity=S[i])
    return graph

def get_qij(superbasis):
    """Return the Gram matrix Q of ``superbasis``, with Q[i, j] = <b_i, b_j>.

    ``superbasis`` is a sized sequence of vectors that support ``dot_product``
    (e.g. Sage vectors). The result is an n x n Sage matrix, n = len(superbasis).
    """
    size = len(superbasis)
    entries = []
    for left in superbasis:
        for right in superbasis:
            entries.append(left.dot_product(right))
    return matrix(size, size, entries)

def gen_ntru_vfk(N, P, q, k):
    """Return the 2N x 2N block matrix [[I, -k*I], [P*I, (q - P*k)*I]].

    I is the N x N identity. The matrix is assembled one block row at a time
    with Sage's block_matrix, in the same way as before (only the unused
    locals were removed).

    Parameters:
        N: size of each identity block.
        P, q, k: scalar parameters of the construction.
    """
    I = identity_matrix(N)
    top = block_matrix([[I, -k * I]])
    bottom = block_matrix([[P * I, (q - P * k) * I]])
    return block_matrix([[top], [bottom]])

def get_superbasis(basis):  # basis is a list
    """Extend ``basis`` to a superbasis by appending minus the sum of its vectors.

    Returns a list of len(basis) + 1 Sage vectors b_0, ..., b_{n-1}, b_n with
    b_n = -(b_0 + ... + b_{n-1}), so that all returned vectors sum to zero.
    """
    vectors = [vector(entry) for entry in basis]
    closing_vector = -sum(vectors)
    superbasis = [vector(v) for v in vectors]
    superbasis.append(vector(closing_vector))
    return superbasis


def calculate_s(z, u, Q):
    """Return the list s with s_i = -2 * sum_j Q[i][j] * (z - u)_j.

    ``Q`` must provide ``dimensions()`` -> (rows, columns) and ``Q[i][j]``
    indexing (as a Sage matrix does). Each entry is reduced with np.sum.
    """
    offset = np.subtract(z, u)
    n_rows, n_cols = Q.dimensions()
    return [
        -2 * np.sum([Q[r][c] * offset[c] for c in range(n_cols)])
        for r in range(n_rows)
    ]


def calculate_s_2(z, u, Q):
    """Return the list s with s_i = -2 * <Q row i, z - u> for i in range(len(z - u)).

    ``Q[i]`` must support ``dot_product`` against a Sage vector (e.g. a row of
    a Sage matrix).
    """
    diff = np.subtract(z, u)
    result = []
    for k in range(len(diff)):
        result.append(-2 * Q[k].dot_product(vector(diff)))
    return result

def length(v):
    """Return the expanded sum of squares of the entries of ``v[0]``.

    Despite the name, no square root is taken (this is a squared norm). Only
    ``v[0]`` is used, so ``v`` is presumably a one-row container, e.g. a
    one-row matrix or a list holding one vector (confirm with callers). The
    number of entries is printed as a side effect. ``.expand()`` is called on
    the result, so the entries are presumably Sage (symbolic) objects.

    NOTE(review): ``length`` is defined twice in this file with identical
    bodies; the later definition is the one in effect at runtime.
    """
    row = v[0]
    print(len(row))
    total = sum((entry ** 2 for entry in row), 0)
    return total.expand()

def t_coeffs(dimension, set_c, source):
    """Convert the source side of a cut into a 0/1 coefficient vector.

    Parameters:
        dimension: length of the returned vector (number of superbasis nodes).
        set_c: iterable of graph nodes on the source side of the cut. Node i
            (1-based) maps to index i-1 of the result. It is NOT modified.
        source: the source node, which is excluded from the result.

    Returns:
        numpy float array of length ``dimension`` holding 1 for every non-source
        node in ``set_c`` and 0 elsewhere.

    Raises:
        KeyError: if ``source`` is not in ``set_c`` (same as before).
    """
    # Work on a copy: the old code called set_c.remove(), mutating the caller's set.
    nodes = set(set_c)
    nodes.remove(source)
    t = np.zeros(shape=dimension)
    for node in nodes:
        t[node - 1] = 1
    return t

def do_mincut(n, S, Q):
    """Return the 0/1 step vector given by an s-t minimum cut of the superbasis graph.

    Parameters:
        n: lattice dimension. The graph has source 0, superbasis nodes
            1..n+1 and sink n+2.
        S: linear terms of the current iterate (depend on the target vector).
        Q: Gram matrix of the superbasis.

    Returns:
        numpy float array of length n+1 with t[i-1] = 1 exactly when superbasis
        node i is on the source side of the minimum cut.
    """
    source = 0
    sink = n + 2
    G = create_weighted_graph_networkx(n, S, Q)
    _cut_value, (reachable, non_reachable) = nx.minimum_cut(G, source, sink)
    # Keep whichever side of the partition contains the source.
    cut = reachable if source in reachable else non_reachable
    return t_coeffs(n + 1, cut, source)

def cvp_sage(target_vector, basis):
    """Solve CVP exactly with Sage's IntegerLattice and report the distance.

    Prints the closest lattice vector and its Euclidean distance to
    ``target_vector``.

    Parameters:
        target_vector: the target point.
        basis: rows of the lattice basis (anything ``matrix`` accepts).

    Returns:
        The closest lattice vector. The function previously returned None, so
        callers that ignore the result are unaffected.
    """
    from sage.modules.free_module_integer import IntegerLattice

    lattice = IntegerLattice(matrix(basis))
    # IntegerLattice already LLL-reduces on construction by default. The
    # explicit call is kept so the stored reduced basis matches previous runs.
    lattice.LLL()
    v = lattice.closest_vector(target_vector)
    print("CVP(L,target) result from sagemath:", v)
    distance = LA.norm(np.subtract(target_vector, v))
    print("distance=", distance)
    return v

def create_weighted_graph_networkx(dimension, S, Q):
    """Build the capacitated graph whose source/sink minimum cut gives the next step.

    Nodes (dimension + 3 in total): 0 is the source, 1..dimension+1 are the
    superbasis vectors (node i corresponds to index i-1 of S and Q), and
    dimension+2 is the sink.

    Edges:
      * superbasis nodes i < j: capacity -Q[i-1][j-1];
      * source--node i: -S[i-1] if S[i-1] < 0, else 0;
      * node i--sink: S[i-1] if S[i-1] >= 0, else 0;
      * a zero-capacity source--sink edge.

    Parameters:
        dimension: lattice dimension n; S and Q must cover n+1 superbasis vectors.
        S: linear terms s_i = -2 * sum_j q_ij p_j (see calculate_s_2).
        Q: Gram matrix of the superbasis, indexable as Q[i][j].

    Raises:
        ValueError: if some off-diagonal Q entry is positive. That would give a
            negative capacity, and the min-cut formulation needs non-negative
            capacities (i.e. an obtuse superbasis).

    NOTE(review): this function is defined twice in this file; this later
    definition is the one in effect at runtime.
    """
    graph = nx.Graph()
    source = 0
    sink = dimension + 2
    # range() excludes its end, so use sink + 1 to add the sink explicitly.
    graph.add_nodes_from(range(source, sink + 1))
    graph.add_edge(source, sink, capacity=0.0)

    for i in range(1, dimension + 2):
        for j in range(i + 1, dimension + 2):
            capacity = -Q[i - 1][j - 1]
            if capacity < 0:
                raise ValueError(
                    f"Q[{i - 1}][{j - 1}] = {Q[i - 1][j - 1]} > 0: "
                    "superbasis is not obtuse"
                )
            graph.add_edge(i, j, capacity=capacity)

    for i in range(dimension + 1):
        node = i + 1
        if S[i] < 0:
            graph.add_edge(source, node, capacity=-S[i])
            graph.add_edge(node, sink, capacity=0.0)
        else:
            graph.add_edge(source, node, capacity=0.0)
            graph.add_edge(node, sink, capacity=S[i])
    return graph

def is_vfk(Q, N):
    '''
    Return True when every off-diagonal entry of Q is <= 0 (obtuse superbasis).

    Q is the Gram matrix (q_ij = <b_i, b_j>) of the superbasis. It may be
    anything np.array turns into a square 2-D array (list of lists, numpy
    array, Sage matrix).
    N is the parameter given when we build ntru_vfk_matrix. It is kept for
    backward compatibility. The check now uses Q's own size instead of
    assuming Q has exactly 2N+1 rows, which previously raised IndexError or
    misreported when the sizes differed.
    '''
    gram = np.array(Q)
    size = gram.shape[0]
    return all(
        gram[i][j] <= 0
        for i in range(size)
        for j in range(size)
        if i != j
    )


def is_vfk_extended(A, N):
    '''
    Return True if the rows of A, extended to a superbasis, form an obtuse superbasis.

    A is the matrix to be checked (a Sage matrix; its rows form the basis).
    N is the parameter given when we build ntru_vfk_matrix, passed through to is_vfk.

    Remark: this is not a generic function. It assumes that the rows of A
    provide a basis that can be extended to an (obtuse) superbasis. It reuses
    get_superbasis, get_qij and is_vfk instead of duplicating the check inline.
    '''
    super_basis = get_superbasis(A.rows())
    return is_vfk(get_qij(super_basis), N)

def length(v):
    """Return the expanded sum of squares of the entries of ``v[0]``.

    Despite the name, no square root is taken (this is a squared norm). Only
    ``v[0]`` is used, so ``v`` is presumably a one-row container, e.g. a
    one-row matrix or a list holding one vector (confirm with callers). The
    number of entries is printed as a side effect. ``.expand()`` is called on
    the result, so the entries are presumably Sage (symbolic) objects.

    NOTE(review): ``length`` is defined twice in this file with identical
    bodies; this later definition is the one in effect at runtime.
    """
    row = v[0]
    print(len(row))
    total = sum((entry ** 2 for entry in row), 0)
    return total.expand()

def fptosage(A):
    """Copy an fpylll IntegerMatrix into a Sage integer matrix of the same shape.

    The previous version allocated ``matrix(A.nrows)`` (square), which failed
    for matrices with ncols != nrows. The target is now sized nrows x ncols.
    """
    C = matrix(A.nrows, A.ncols)
    for i in range(A.nrows):
        C[i] = list(A[i])
    return C

def babai_fpylll(M, target_vector, N):
    """Approximate CVP with Babai's nearest-plane on an LLL-reduced copy of M.

    Parameters:
        M: integer basis (rows), convertible with IntegerMatrix.from_matrix.
        target_vector: the target point.
        N: kept for backward compatibility. The combination now uses every row
            of the reduced matrix instead of the hard-coded first 2*N rows,
            which silently dropped rows when M had more than 2N of them.

    Returns:
        (w, dist): w is the lattice vector as a list, and dist is its Euclidean
        distance to target_vector as a numerical value.
    """
    M_fpylll = IntegerMatrix.from_matrix(M)
    LLL.reduction(M_fpylll, delta=0.99)
    M_GSO = GSO.Mat(M_fpylll)
    M_GSO.update_gso()

    # Coefficients of the Babai point with respect to the reduced rows.
    coeffs = M_GSO.babai(target_vector)
    M_sage = fptosage(M_fpylll)
    w = sum(coeffs[i] * M_sage[i] for i in range(M_fpylll.nrows)).list()
    return w, (vector(w) - vector(target_vector)).norm().n()

def cvp_vfk(N,Q,basis,super_basis,target_vector):
    # Iterative CVP on a lattice given by an obtuse superbasis.
    # Starting from the floor of target_vector's coordinates, each iteration adds
    # the 0/1 step t from do_mincut to the superbasis coefficients u_point. It stops
    # when the distance to target_vector no longer decreases.
    #   N: the lattice is 2N-dimensional; the superbasis has 2N+1 vectors.
    #   Q: Gram matrix of super_basis.
    #   basis: the 2N basis vectors; super_basis: basis plus the closing vector.
    # Returns the lattice point as a Sage vector on early termination.
    # NOTE(review): if the loop runs all 2*N iterations without returning, the
    # visible code falls through and implicitly returns None — confirm nothing
    # follows the loop, or add a final return.
    # NOTE(review): find_linear_combination may return None or an error string;
    # that case is not checked before np.floor below.
    z = find_linear_combination(basis, target_vector)
    u_point = np.floor(np.array(z)).astype(int)
    S = calculate_s_2(z,u_point,Q)
    vec = np.dot(u_point, super_basis).astype(int)
    print("distance before the loop:",LA.norm(np.subtract(target_vector, vec)))
    for i in range(2*N):  # the iter should converge at n iters, not n+1
        # TODO: add if y (target) is inside the voronoi cell of u then stop the procedure.\
        distance_old = LA.norm(np.subtract(target_vector, vec))
        print("++++++++++++++++++++++++++++++++++++++++++++++++++")
        print("Iter: ", i)
        t = do_mincut(2*N,S,Q) # you must update S
        #print(u_point)
        u_point = np.add(u_point, t).astype(int)
        #print(u_point)
        vec = np.dot(u_point, super_basis).astype(int)
        #print("x=",vec)
        distance = LA.norm(np.subtract(target_vector, vec))
        if distance>=distance_old:
            # NOTE(review): vec has already been moved by t here, so when
            # distance > distance_old the returned point is farther than the
            # previous one — consider returning the previous vec instead.
            #print("vector=",vec)
            print("distance=",distance)
            print("+++++++")
            return vector(vec)
        S = calculate_s_2(z,u_point,Q) # why S reamains the same???
        # S depends on p = z - u_point (see calculate_s_2), so it changes
        # whenever t is nonzero.