# SPAKE2+ (P256-SHA256-HKDF-HMAC/CMAC)

SageMath implementation of SPAKE2+ as described in draft-bar-cfrg-spake2plus-00. It generates and prints test vectors for every combination of present and absent identities.

In [25]:
# sage -pip install pycryptodome
from Crypto.Cipher import AES
from Crypto.Protocol.KDF import HKDF
from Crypto.Hash import SHA256, HMAC, CMAC

from struct import *

# P-256 constants and helper functions
# NOTE: this is SageMath, so `^` below is exponentiation (Sage's preparser
# rewrites it to `**`); in plain Python it would be bitwise XOR.
# (px, py): affine coordinates of the P-256 base point.
px = 0x6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296
py = 0x4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5

# p256: the field prime; b256: the curve coefficient b.
p256 = 2^256 - 2^224 + 2^192 + 2^96 - 1
b256 = 0x5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b

# Curve y^2 = x^3 + a*x + b over GF(p256), with a = -3 written as p256 - 3.
# Sage's point constructor EC(x, y) rejects coordinates not on the curve,
# so P, M and N below are validated at definition time.
FF = GF(p256)
EC = EllipticCurve([FF(p256 - 3), FF(b256)])
P = EC(FF(px), FF(py))

# Fixed SPAKE2 point M. The comment gives the seed string it is derived
# from; presumably hash-to-point of that string per the SPAKE2 spec, so
# that its discrete log relative to P is unknown -- confirm against the spec.
# seed: 1.2.840.10045.3.1.7 point generation seed (M)
mx = 0x886e2f97ace46e55ba9dd7242579f2993b64e16ef3dcab95afd497333d8fa12f
my = 0x5ff355163e43ce224e0b0e65ff02ac8e5c7be09419c785e0ca547d55a12e2d20

# Fixed SPAKE2 point N, derived the same way from its own seed string.
# seed: 1.2.840.10045.3.1.7 point generation seed (N)
nx = 0xd8bbd6c639c62937b04d997f38c3770719c629d7014d49a24b4f98baa1292b49
ny = 0x07d60aa6bfade45008a636337f5168c64d9bd36034808cd564490b1e656edbe7

M = EC(FF(mx), FF(my))
N = EC(FF(nx), FF(ny))

def wrap_print(arg, *args, line_length=68):
    """Print the space-joined arguments, hard-wrapped every ``line_length`` characters.

    The text is cut into fixed-width chunks with no regard for word
    boundaries, so long hex strings wrap at a constant column. Chunks that
    contain only whitespace are not printed. No separator is appended when
    only ``arg`` is given, so no printed line ends in a stray space.

    Args:
        arg: first string to print.
        *args: further strings, joined to ``arg`` with single spaces.
        line_length: maximum characters per printed line (keyword-only).
    """
    text = " ".join((arg,) + args)
    for start in range(0, len(text), line_length):
        hunk = text[start:start + line_length]
        # Skip fragments that are pure whitespace (e.g. a lone space that
        # would otherwise be printed as an empty-looking line).
        if hunk.strip():
            print(hunk)

def print_integer(name, x):
    """Print ``<name> = 0x<hex>``, with x zero-padded to 64 hex digits.

    The line is emitted through wrap_print, so long values are wrapped.
    """
    hex_digits = format(x, 'x').zfill(64)
    wrap_print(''.join((name, ' = 0x', hex_digits)))

def encode_point(point):
    """Return ``point`` as an uncompressed SEC1 encoding in lowercase hex.

    The result is ``'04' || X || Y``, where each affine coordinate is
    zero-padded to 64 hex digits (32 bytes).
    NOTE(review): Sage's point at infinity indexes as (0, 1), which would
    yield a meaningless encoding here; callers are expected not to pass it.
    """
    x_hex, y_hex = (format(int(point[i]), '064x') for i in (0, 1))
    return '04' + x_hex + y_hex

def print_point(name, point):
    """Print ``<name> = 0x04<X><Y>`` (uncompressed SEC1 hex) via wrap_print."""
    line = name + ' = 0x'
    line += encode_point(point)
    wrap_print(line)

def pack_point(point):
    """Return the uncompressed SEC1 bytes of ``point`` with an 8-byte length prefix.

    For P-256 coordinates the encoded point is 65 bytes long.
    """
    raw = bytes.fromhex(encode_point(point))
    return pack_len(raw)

def pack_len(bytes):
    """Prefix a byte string with its length as an 8-byte little-endian unsigned integer."""
    # The parameter name shadows the builtin ``bytes`` inside this function
    # only; it is kept as-is for keyword-argument compatibility.
    length_prefix = len(bytes).to_bytes(8, 'little')
    return length_prefix + bytes

def pack_string(s):
    """Length-prefix ``s`` for the transcript, or return ``b''`` if it is empty or None.

    An empty or absent value contributes nothing to the transcript, not even
    an 8-byte zero length.
    NOTE(review): later SPAKE2+ revisions appear to encode an absent identity
    as a zero-length string (8 zero length bytes); confirm which revision
    these vectors are meant to match.
    """
    if not s:
        return b''
    return pack_len(s)

def hkdf(ikm, info):
    """Derive 32 bytes from ``ikm`` with HKDF-SHA256, no salt, and ``info`` as context."""
    # With salt=None, pycryptodome's HKDF uses an all-zero salt of hash length.
    return HKDF(master=ikm, key_len=32, salt=None, hashmod=SHA256,
                num_keys=1, context=info)

def hmac(k, m):
    """Return HMAC-SHA256 of message ``m`` under key ``k`` as a lowercase hex string."""
    mac = HMAC.new(k, digestmod=SHA256)
    mac.update(m)
    tag_hex = mac.hexdigest()
    return tag_hex

def cmac(k, m):
    """Return AES-CMAC of message ``m`` under key ``k`` as a lowercase hex string."""
    mac = CMAC.new(k, ciphermod=AES)
    mac.update(m)
    tag_hex = mac.hexdigest()
    return tag_hex

def derive_keys(TT):
    """Derive the key schedule from transcript ``TT`` and print each key.

    Ka || Ke = SHA256(TT), then KcA || KcB = HKDF(nil, Ka, "ConfirmationKeys").
    Every value is 16 bytes.

    Returns:
        Tuple ``(Ke, KcA, KcB)`` of bytes.
    """
    digest = SHA256.new(data=TT).digest()
    # Ka feeds the confirmation-key KDF below; Ke is handed back to the caller.
    Ka, Ke = digest[:16], digest[16:]
    for label, value in (('Ka', Ka), ('Ke', Ke)):
        wrap_print(label + ' = 0x' + value.hex())

    confirmation = hkdf(Ka, b'ConfirmationKeys')
    KcA, KcB = confirmation[:16], confirmation[16:]
    for label, value in (('KcA', KcA), ('KcB', KcB)):
        wrap_print(label + ' = 0x' + value.hex())

    return Ke, KcA, KcB

def spake2plus(A, B):
    """Run one SPAKE2+ exchange between identities A and B and print its test vector.

    The password-derived scalars and the domain-separation string come from
    the module-level globals ``w0``, ``w1`` and ``Context``. Fresh ephemeral
    scalars x and y are drawn on every call, so each call prints a different
    vector: w0, w1, L, x, X, y, Y, Z, V, TT, the key schedule, and the
    HMAC/CMAC confirmation values.

    Args:
        A: identity of party A (bytes). ``b''`` or ``None`` means implicit;
            pack_string then leaves it out of TT.
        B: identity of party B (bytes), with the same convention.

    Raises:
        AssertionError: if A's and B's computations of Z or V disagree
            (an internal self-check).
    """
    import secrets

    # Order of the P-256 base point P. Scalars live modulo this group order,
    # which is smaller than the field prime p256, so they must not be drawn
    # from GF(p256).
    n = 0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551

    def random_scalar():
        # Uniform nonzero scalar in [1, n).
        return secrets.randbelow(int(n) - 1) + 1

    print_integer('w0', w0)
    print_integer('w1', w1)

    # B generates L = w1*P
    L = w1 * P
    print_point('L', L)

    # A generates key share X = x*P + w0*M
    x = random_scalar()
    print_integer('x', x)
    X = x * P + w0 * M
    print_point('X', X)

    # B generates key share Y = y*P + w0*N
    y = random_scalar()
    print_integer('y', y)
    Y = y * P + w0 * N
    print_point('Y', Y)

    # A computes shared keys Z = x*(Y - w0*N) and V = w1*(Y - w0*N)
    Y_unblinded = Y - w0 * N
    Z = x * Y_unblinded
    V = w1 * Y_unblinded
    print_point('Z', Z)
    print_point('V', V)

    # B computes the same Z, V from its own values (y, w0, L)
    assert Z == y * (X - w0 * M)
    assert V == y * L

    # TT = len(Context) || Context ||
    #      len(A) || A || len(B) || B ||
    #      len(M) || M || len(N) || N ||
    #      len(X) || X || len(Y) || Y ||
    #      len(Z) || Z || len(V) || V ||
    #      len(w0) || w0
    TT = pack_string(Context)
    TT += pack_string(A)
    TT += pack_string(B)
    for point in (M, N, X, Y, Z, V):
        TT += pack_point(point)
    # Encode w0 as a fixed-width 32-byte big-endian scalar (the same
    # 64-hex-digit form print_integer shows). Unpadded hex from
    # format(w0, 'x') would make bytes.fromhex() raise on an odd digit count
    # and would silently drop leading zero bytes.
    TT += pack_len(int(w0).to_bytes(32, 'big'))
    wrap_print('TT = 0x' + TT.hex())

    # Derive key schedule (Ke is already printed by derive_keys and is not
    # needed below).
    _, KcA, KcB = derive_keys(TT)

    # Confirmation MACs over the peer's key share: MAC(KcA, Y) and MAC(KcB, X)
    Y_bytes = bytes.fromhex(encode_point(Y))
    X_bytes = bytes.fromhex(encode_point(X))
    wrap_print('HMAC(KcA, Y) = 0x' + hmac(KcA, Y_bytes))
    wrap_print('HMAC(KcB, X) = 0x' + hmac(KcB, X_bytes))

    # Same confirmation values with AES-CMAC instead of HMAC-SHA256
    wrap_print('CMAC(KcA, Y) = 0x' + cmac(KcA, Y_bytes))
    wrap_print('CMAC(KcB, X) = 0x' + cmac(KcB, X_bytes))

In [26]:
import secrets

# Context for domain separation.
Context = b'SPAKE2+-P256-SHA256-HKDF-HMAC draft-01'

# w0s || w1s = PBKDF(len(pw) || pw || len(A) || A || len(B) || B)
# w0 = w0s (mod p) and w1 = w1s (mod p), where p is the order of the P-256
# group -- not the field prime p256. No password is hashed here: uniformly
# random scalars in [0, p) stand in for the PBKDF output.
P256_ORDER = 0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551
w0 = secrets.randbelow(int(P256_ORDER))
w1 = secrets.randbelow(int(P256_ORDER))

# Use an empty identity (b'') when identities are implicit; pack_string()
# omits empty (or None) identities from the transcript.
A_identities = [b'client', b'']
B_identities = [b'server', b'']

# Print one test vector for every combination of present/absent identities.
for A in A_identities:
    for B in B_identities:
        wrap_print('\n[Context=%s]' % (Context))
        wrap_print('[A=%s]' % (A))
        wrap_print('[B=%s]' % (B))
        spake2plus(A, B)


[Context=b'SPAKE2+-P256-SHA256-HKDF-HMAC draft-01'] 
[A=b'client'] 
[B=b'server'] 
w0 = 0x1d122d5b59da10c389f4951b41abc18ed1919a24c04ede960bcf88dbc4c69
946 
w1 = 0xf327601d0c6cc3071b449555591dc01531528db7b887264bb4515630a6430
d08 
L = 0x0413c6c51ae6fcf717e626f520dd6d60135062220241516c2e522589c08775
a264a6548cb85b9a9d1369517829b8978d0ca5d11059c7d0beeb22c490bdcce4a83d
x = 0x5f6b46fb2ea1910d22faf099d77e1d32b7794d38f69933c55075e50e9158a2
5f 
X = 0x043d0aedfe82808e4ef731cab5f4db9db427d95692bb3c5be5698071765c11
3836b81a7f85c6eed46a073a9fc5049e413b0e75d895d0e622aafa4c0614b3094b45
y = 0x10c2a67d006d5b44d9841f878dc049abdec1b324fc7c15b58af45726c15a59
05 
Y = 0x040a1e796a0fff35a17a1c5ca8c8efe27143f2046727ec5ec763c83ac557be
04ab05d9f86e3aea08c1718eb26153fb3302ed67b1d65e7fbda8a40a0db2998399ba
Z = 0x04dcda70ec5a997386fd8303c38c94760033a8f4de515534792d1b9cefc10f
7aaa5af8ee2212cec16fc6c391b95659ad13c4f0b529a40ccee7cddd3d8568c76b8f
V = 0x04225ec3195304e09fb49ee8fa0c366cacda2fb2518510f7d51c9dbcd2fb87
f