# WOTS (Winternitz One-Time Signature) Implementation
This notebook implements the WOTS signature scheme in Python


In [9]:
import hashlib
import numpy as np
from typing import List

# Parameters
SPX_N = 16  # Hash output length in bytes
SPX_WOTS_W = 16  # Winternitz parameter
SPX_WOTS_LOGW = 4  # log2(SPX_WOTS_W)
SPX_WOTS_LEN1 = 64  # Length of message part
SPX_WOTS_LEN2 = 3   # Length of checksum part
SPX_WOTS_LEN = SPX_WOTS_LEN1 + SPX_WOTS_LEN2  # Total length
SPX_SHA256_OUTPUT_BYTES = 32  # Length of hash output

In [10]:
def prf_addr(seed: bytes, addr: List[int]) -> bytes:
    """PRF function using SHA256"""
    addr_bytes = b''.join(i.to_bytes(4, 'big') for i in addr)
    return hashlib.sha256(seed + addr_bytes).digest()[:SPX_SHA256_OUTPUT_BYTES]

def thash(in_data: bytes, pub_seed: bytes, addr: List[int]) -> bytes:
    """T-hash function using SHA256"""
    addr_bytes = b''.join(i.to_bytes(4, 'big') for i in addr)
    return hashlib.sha256(pub_seed + addr_bytes + in_data).digest()[:SPX_N]

In [11]:
def gen_chain(input_data: bytes, start: int, steps: int, pub_seed: bytes, addr: List[int]) -> bytes:
    """Compute the chaining function"""
    out = input_data[:]
    
    for i in range(start, min(start + steps, SPX_WOTS_W)):
        addr[6] = i  # Set hash address
        out = thash(out, pub_seed, addr)
    
    return out

def base_w(input_bytes: bytes, out_len: int) -> List[int]:
    """base_w algorithm for converting bytes to base w integers"""
    output = []
    total = 0
    bits = 0
    
    for consumed in range(out_len):
        if bits == 0:
            total = input_bytes[len(output) // (8 // SPX_WOTS_LOGW)]
            bits = 8
        bits -= SPX_WOTS_LOGW
        output.append((total >> bits) & (SPX_WOTS_W - 1))
    
    return output

# test base_w
assert base_w(b'\x00', 1) == [0]
assert base_w(b'\x01', 2) == [0, 1]

In [12]:
def wots_gen_pk(seed: bytes, pub_seed: bytes, addr: List[int]) -> bytes:
    """Generate WOTS public key"""
    pk = bytearray()
    
    for i in range(SPX_WOTS_LEN):
        addr[5] = i  # Set chain address
        # Generate private key element
        sk = prf_addr(seed, addr)
        # Compute chain
        pk_element = gen_chain(sk, 0, SPX_WOTS_W - 1, pub_seed, addr)
        pk.extend(pk_element)
    
    return bytes(pk)

def wots_sign(msg: bytes, seed: bytes, pub_seed: bytes, addr: List[int]) -> bytes:
    """Generate WOTS signature"""
    # Convert message to base w
    msg_base_w = base_w(msg, SPX_WOTS_LEN1)
    
    # Compute checksum
    csum = sum(SPX_WOTS_W - 1 - x for x in msg_base_w)
    csum_bytes = csum.to_bytes((SPX_WOTS_LEN2 * SPX_WOTS_LOGW + 7) // 8, 'big')
    csum_base_w = base_w(csum_bytes, SPX_WOTS_LEN2)
    
    lengths = msg_base_w + csum_base_w
    sig = bytearray()
    
    for i in range(SPX_WOTS_LEN):
        addr[5] = i  # Set chain address
        sk = prf_addr(seed, addr)
        sig_element = gen_chain(sk, 0, lengths[i], pub_seed, addr)
        sig.extend(sig_element)
    
    return bytes(sig)
    

In [13]:
def wots_pk_from_sig(sig: bytes, msg: bytes, pub_seed: bytes, addr: List[int]) -> bytes:
    """Compute public key from signature"""
    # Convert message to base w
    msg_base_w = base_w(msg, SPX_WOTS_LEN1)
    
    # Compute checksum
    csum = sum(SPX_WOTS_W - 1 - x for x in msg_base_w)
    csum_bytes = csum.to_bytes((SPX_WOTS_LEN2 * SPX_WOTS_LOGW + 7) // 8, 'big')
    csum_base_w = base_w(csum_bytes, SPX_WOTS_LEN2)
    
    lengths = msg_base_w + csum_base_w
    pk = bytearray()
    
    for i in range(SPX_WOTS_LEN):
        addr[5] = i  # Set chain address
        sig_element = sig[i*SPX_N:(i+1)*SPX_N]
        pk_element = gen_chain(sig_element, lengths[i], 
                              SPX_WOTS_W - 1 - lengths[i], 
                              pub_seed, addr)
        pk.extend(pk_element)
    
    return bytes(pk)

In [15]:
# Example usage
if __name__ == "__main__":
    # Test vectors
    seed = bytes([i % 256 for i in range(32)])
    pub_seed = bytes([(i + 128) % 256 for i in range(32)])
    addr = [0] * 8
    msg = b"Hello, WOTS!"
    # msg extended to 32 bytes
    msg = msg + b'\x00' * (32 - len(msg))
    
    # Generate public key
    pk = wots_gen_pk(seed, pub_seed, addr)
    
    print(f"Public key: {pk.hex()}")
    
    # Sign message
    sig = wots_sign(msg, seed, pub_seed, addr)
    print(f"Signature: {sig.hex()}")
    
    # Verify by reconstructing public key
    pk2 = wots_pk_from_sig(sig, msg, pub_seed, addr)
    print(f"Verified: {pk == pk2}")

Public key: 78353cf2967d5587e4ff59d3ae03c095b3355ed99390c2d2da8c90220d525c2640920c0601885fe8e232da6ba5ae540ca53e03527c09ee5546a820211c1307bfd71667516d2a0358d2cd9513e376c3f3ac2de1b886416d3dd618881b01af6387b1c6dfecfcaf8c854402e6006378c73fba8851691325df9446c0866730458f6f184789c00cefef168179570adacb186eff1d09faa649c12c64e589139589f403228312f3cb3edbb378160e31c807546157c6c31b106d01809e5aca1ae859f9debf6da8a5dbd14bf59409dc6d641aecdb0e9ecb68a5dbd3fd1778d5781a8b3da7b8a0c5ce5a894814d6a9c1b2e4ead00ae5e767ce0de3ce3f3ea57daabbeac3d70759542c4fffeded905614e1a7b9dcbb7124efbd07c3b930b2e6b197b1ad2089cc4d7614914cf126d72e4e60d0a1cc8857965e202fa798d258aca80d3fe799c42cb421c43b60b8525ad012d2f836163dc6652d0b974b29aedacbad33b82f4e809bf6288f736669f8554ef210387ca54c5ac49da5c50f155636b0349fa4e21d76032811276c1b64b5f4996a4745b2fbdf84aece4272b5707d82643b905aa78d40188a8bcaeb8f00c159a6c4fc2afaac702b59445d99c82949472ccd4507e5f5805c15f687f6ea318831db718a37e9c9b465f39dd0a83aa2206db9eba201000728d6ce27d50441865d06c34889e48c

- `wots_PKgen`: use the left node (security key) to create right nodes (public key) on hash chain, then hash all the right nodes to get the public key
- `wots_sign`: use the message to chose the node position on hash chain, concatenate all node to get the signature
- `wots_pkFromSig`: use the signature and message to get the public key