In [1]:
warnings.filterwarnings("ignore")

In [2]:
def balance(e, q=None):
    """Map ``e`` to its centered (balanced) integer representative(s).

    Scalars are lifted to ZZ; any value strictly greater than ``q // 2`` has
    ``q`` subtracted from it, so results lie in a symmetric range around 0.

    Elements whose parent supports ``change_ring(ZZ)`` (e.g. vectors) are
    rebuilt over ZZ with every entry balanced recursively.

    Args:
        e: a scalar, or an iterable element with a Sage parent.
        q: the modulus. When omitted it is taken from ``parent(e).order()``,
            or from the parent's base ring order if the parent has no
            ``order`` method.

    Returns:
        An integer, or a ZZ-based element built from balanced entries.
    """
    try:
        zz_parent = parent(e).change_ring(ZZ)
        balanced_entries = [balance(entry, q) for entry in e]
        return zz_parent(balanced_entries)
    except (TypeError, AttributeError):
        pass

    # Scalar path: e is not a container whose ring can be changed to ZZ.
    modulus = q
    if modulus is None:
        try:
            modulus = parent(e).order()
        except AttributeError:
            modulus = parent(e).base_ring().order()
    value = ZZ(e)
    if value > modulus // 2:
        return value - modulus
    return value

def encode(e, q=None):
    """Scale message bit(s) ``e`` by ``round(q / 2)``.

    Each scalar is lifted to ZZ and multiplied by ``(q + 1) // 2``, which is
    ``round(q / 2)`` (half rounded up) for positive ``q``. Elements whose
    parent supports ``change_ring(ZZ)`` (e.g. vectors) are rebuilt over ZZ
    with every entry encoded recursively.

    Args:
        e: a scalar, or an iterable element with a Sage parent. Presumably
            its entries are bits; the notebook passes 0/1 coefficients.
        q: the modulus. When omitted it is taken from ``parent(e).order()``,
            or from the parent's base ring order if the parent has no
            ``order`` method.

    Returns:
        An integer, or a ZZ-based element built from encoded entries.
    """
    try:
        # Try to recursively encode a vector-like element entry by entry.
        p = parent(e).change_ring(ZZ)
        return p([encode(e_, q) for e_ in e])
    except (TypeError, AttributeError):
        # e is a scalar: encode it directly.
        if q is None:
            try:
                q = parent(e).order()
            except AttributeError:
                q = parent(e).base_ring().order()
        e_z = ZZ(e)
        # Exact integer form of round(q/2). Unlike round(q/2), it does not
        # depend on whether q is a Sage Integer (exact rational, rounds half
        # away from zero) or a Python int (float division, banker's
        # rounding): for q = 7681 it is always 3841.
        half_q = (q + 1) // 2
        return half_q * e_z

def decode(e, q=None):
    """Threshold-decode value(s) ``e`` back to bits.

    Each scalar is reduced mod ``q`` and lifted to its centered
    representative ``c``. The result is 0 when
    ``-round(q/4) <= c < round(q/4)`` and 1 otherwise, so values near 0
    decode to 0 and values near ``q/2`` (as produced by ``encode`` of a 1)
    decode to 1. Elements whose parent supports ``change_ring(ZZ)`` (e.g.
    vectors, or lifted polynomials as used by the signature code) are
    rebuilt over ZZ with every entry decoded recursively.

    Args:
        e: a scalar, or an iterable element with a Sage parent.
        q: the modulus. When omitted it is taken from ``parent(e).order()``,
            or from the parent's base ring order if the parent has no
            ``order`` method.

    Returns:
        0 or 1, or a ZZ-based element built from decoded bits.
    """
    try:
        # Try to recursively decode a vector-like element entry by entry.
        p = parent(e).change_ring(ZZ)
        return p([decode(e_, q) for e_ in e])
    except (TypeError, AttributeError):
        # e is a scalar: decode it directly.
        if q is None:
            try:
                q = parent(e).order()
            except AttributeError:
                q = parent(e).base_ring().order()
        e_z = Zmod(q)(e).lift_centered()
        # Exact integer form of round(q/4) for positive q. Unlike round(q/4),
        # it gives the same threshold whether q is a Sage Integer or a
        # Python int.
        quarter_q = (q + 2) // 4
        return 0 if -quarter_q <= e_z < quarter_q else 1

from sage.stats.distributions.discrete_gaussian_integer import  DiscreteGaussianDistributionIntegerSampler as DG


In [3]:
from hashlib import sha256


# Default scheme parameters, used as the SDSIG constructor defaults.
q = 7681          # coefficient modulus: the ring is built over GF(q)
n = 256           # ring degree: reduction polynomial is y^n + 1
t = floor(q / 4)  # quarter of q; not referenced in the visible cells

# Widths handed to the discrete Gaussian sampler:
# se for error polynomials, ss for secret polynomials (same object).
se = sqrt(1.2)
ss = se
class SDSIG:
    """Toy signature scheme over R = GF(q)[y] / (y^n + 1).

    A key pair is generated on construction:
    ``sk = r2`` and ``pk = (p, a)`` with ``p = r1 - a*r2``, where ``a`` has
    uniform coefficients and ``r1``, ``r2`` have coefficients drawn from
    the secret sampler.
    """

    def __init__(self, q=q, n=n, kie=se, kis=ss, DG=DG):
        """Build the ring and the samplers, then generate a key pair.

        Args:
            q: coefficient modulus; the ring is built over GF(q).
            n: ring degree; the reduction polynomial is y^n + 1.
            kie: width passed to ``DG`` for the error sampler.
            kis: width passed to ``DG`` for the secret sampler.
            DG: sampler factory called as ``DG(width)``. The returned object
                is called with no arguments to draw one integer coefficient.
        """
        self.q = q
        self.d = n
        self.kie = kie
        self.kis = kis
        self.Dr = DG(self.kis)  # secret-coefficient sampler
        self.De = DG(self.kie)  # error-coefficient sampler
        Zq, Y = PolynomialRing(GF(q), 'y').objgen()
        R, X = Zq.quotient_ring(Y**n + 1, 'x').objgen()
        self.R = R
        self.X = X
        self._keygen()

    def _sample_element(self, q=None):
        """Return an element of R whose coefficients are randint(0, q - 1).

        ``q`` defaults to the instance modulus ``self.q``, not the module-level
        ``q``. This keeps sampling correct for instances built with a
        different modulus.
        """
        if q is None:
            q = self.q
        coeffs = [randint(0, q - 1) for _ in range(self.d)]
        return self.R(coeffs)

    def _sample_short_secret(self):
        """Return an element of R with coefficients drawn from ``self.Dr``."""
        coeffs = [self.Dr() for _ in range(self.d)]
        return self.R(coeffs)

    def _sample_small_error(self):
        """Return an element of R with coefficients drawn from ``self.De``."""
        coeffs = [self.De() for _ in range(self.d)]
        return self.R(coeffs)

    def _keygen(self):
        """Set ``sk = r2`` and ``pk = (p, a)`` with ``p = r1 - a*r2``."""
        a = self._sample_element()
        r1 = self._sample_short_secret()
        r2 = self._sample_short_secret()
        p = r1 - a*r2
        self.sk = r2
        self.pk = (p, a)

    def sign(self, m):
        """Sign the message polynomial ``m``.

        ``m`` is assumed to be already hashed (its coefficients are passed
        through ``encode``).

        Returns:
            ``(C1, C2, C3, h)`` where:
            - algebraically ``C2 - a*C1 + C3 = encode(m) + e3 + e4`` in R, so
              ``m`` can be decoded from the triple;
            - ``h = sha256(str(m) + str(decode(lift(a*r2*e1))))``.
        """
        p, a = self.pk
        r2 = self.sk
        mbar = encode(m, self.q)  # assume that m is already hashed
        e1, e2, e3, e4 = [self._sample_small_error() for _ in range(4)]
        # NOTE(review): balance() returns ZZ-based representatives, while
        # verify() calls C1.lift(); confirm balance()'s output supports .lift().
        C1 = balance(p * e1 + e2)
        C2 = balance(p * a * e1 + e3 + mbar)
        C3 = balance(a * e2 + e4)
        # Decode the *lifted* polynomial, exactly as verify() does for -C1, so
        # both sides hash the same representation (same ring, same variable).
        masked = (a * r2 * e1).lift()
        h = sha256((str(m) + str(decode(masked, self.q))).encode()).digest()
        return (C1, C2, C3, h)

    def verify(self, c, m):
        """Check a signature ``c = (C1, C2, C3, h)`` on message ``m``.

        Accepts when both conditions hold:
        (1) the hash recomputed from ``decode(-C1)`` equals ``h``;
        (2) ``C2 - a*C1 + C3`` decodes to ``m``.

        With ``C1 = p*e1 + e2`` and ``p = r1 - a*r2``,
        ``-C1 = a*r2*e1 - r1*e1 - e2``. So (1) reproduces the signer's
        ``decode(a*r2*e1)`` only when ``r1*e1 + e2`` does not push a
        coefficient across a decoding threshold.
        """
        C1, C2, C3, h = c
        _, a = self.pk
        h_recomputed = sha256((str(m) + str(decode(-C1.lift(), self.q))).encode()).digest()
        return h_recomputed == h and decode((C2 - a*C1 + C3).lift(), self.q) == m

        

In [4]:
from tqdm import trange
# Experiment: build (c1, c2, c3, h) from the public key alone (no call to
# sig.sign, r1/r2 unused below) and measure how often sig.verify accepts it.
se = ss = sqrt(1.2)
n = 256
sig = SDSIG(kie=se, kis=ss, n=n)
r2 = sig.sk.lift()
p, a = [_.lift() for _ in sig.pk]
r1 = p + a*r2  # reconstructed secret; not used by the loop below

suc = 0
NTEST = 1000

for i in trange(NTEST):
    # Random binary message with sig.d coefficients, assumed already hashed.
    m = sig.R([randint(0, 1) for _ in range(sig.d)]).lift()
    mbar = encode(m, sig.q)
    while True:
        e1, e2, e3, e4 = [sig._sample_small_error() for _ in range(4)]
        c1 = p*e1 + e2
        c2 = p*a*e1 + e3 + mbar
        c3 = a*e2 + e4
        h = sha256((str(m) + str(decode(-c1.lift(), sig.q))).encode()).digest()
        # Retry until the forged triple decodes to m, which is verify's
        # second check; this uses only public values.
        if decode((c2 - a*c1 + c3).lift(), sig.q) == m:
            break

    if sig.verify((c1, c2, c3, h), m):
        suc += 1
print("Success Rate: ", suc/NTEST*100., "%")

100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:07<00:00, 133.73it/s]

Success Rate:  100.000000000000 %



