# Simulating EPA-RE using points of low order

In [None]:
import itertools
import gc
import glob
import pickle
import random
import re
import hashlib

import warnings
warnings.filterwarnings(
    "ignore",
    message="pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html.",
    category=UserWarning
)

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from collections import Counter
from pathlib import Path
from random import randint, randbytes
from typing import Type, Any

from tqdm.auto import tqdm, trange

from pyecsca.ec.params import DomainParameters, get_params
from pyecsca.ec.mult import *
from pyecsca.ec.mod import mod
from pyecsca.sca.re.rpa import multiple_graph
from pyecsca.sca.re.epa import graph_to_check_inputs, evaluate_checks
from pyecsca.misc.utils import TaskExecutor

from common import *

## Initialize

In [None]:
def silence():
    """Install a filter that hides the ``pkg_resources`` deprecation ``UserWarning``.

    The function is also handed to the worker pools in this notebook as their
    ``initializer``, which is why it imports ``warnings`` itself instead of
    relying on the notebook-level import.
    """
    import warnings

    deprecation_notice = "pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html."
    warnings.filterwarnings(action="ignore", message=deprecation_notice, category=UserWarning)


# Apply the filter to the notebook process itself right away.
silence()

In [None]:
# Sizes of the search space. A "configuration" is one multiplier (with an
# optional countermeasure) paired with one error model.
nmults, nmults_ctr, nerror_models = map(len, (all_mults, all_mults_with_ctr, all_error_models))
ncfgs = nmults_ctr * nerror_models

print(f"Scalar multipliers considered:  {nmults}")
print(f"Scalar multipliers (with a single countermeasure) considered:  {nmults_ctr}")
print(f"Error models considered:  {nerror_models}")
print(f"Total configurations considered:  {ncfgs}")

## Prepare

In [None]:
# Curve under test; `bits` is the bit length of the group order and bounds the
# random scalars drawn during simulation.
category, curve = "secg", "secp256r1"
params = get_params(category, curve, "projective")
bits = params.order.bit_length()

# Switches forwarded to `multiple_graph`; they are also encoded in the chunk file names.
use_init = use_multiply = True
# NOTE(review): `kind` and `selected_mults` are not read in the cells shown here --
# presumably they are used by other cells or notebooks; confirm before removing.
kind = "precomp+necessary"
selected_mults = all_mults

# Workload sizing: scalar multiplications per task and worker processes per pool.
samples = 100
num_workers = 20

In [None]:
def simulate_multiples(mult: MultIdent,
                       params: DomainParameters,
                       bits: int,
                       samples: int = 100,
                       use_init: bool = True,
                       use_multiply: bool = True,
                       seed: bytes | None = None) -> MultResults:
    """Simulate `samples` scalar multiplications with `mult` and collect their multiple graphs.

    :param mult: The multiplier (class, partial and optional countermeasure) to simulate.
    :param params: Domain parameters passed through to `multiple_graph`.
    :param bits: Upper bound exponent for the random scalars (drawn from ``[1, 2**bits]``).
    :param samples: Number of scalar multiplications to simulate.
    :param use_init: Passed through to `multiple_graph`.
    :param use_multiply: Passed through to `multiple_graph`.
    :param seed: If given, seeds the module-level `random` generator first.
    :return: A `MultResults` wrapping one `multiple_graph` result per sample.

    NOTE(review): `random.randint` includes both end points, so ``2**bits`` itself
    (a ``bits + 1``-bit value) can be drawn. Kept as-is so seeded runs stay reproducible.
    """
    if seed is not None:
        random.seed(seed)

    upper = 2**bits
    if mult.countermeasure is not None:
        # With a countermeasure, one scalar is fixed and reused for every sample.
        fixed = random.randint(1, upper)
        scalars = [fixed] * samples
    else:
        # Without a countermeasure, every sample gets its own random scalar.
        scalars = [random.randint(1, upper) for _ in range(samples)]

    graphs = [
        multiple_graph(k, params, mult.klass, mult.partial, use_init, use_multiply)
        for k in scalars
    ]
    return MultResults(graphs, samples)

In [None]:
def evaluate_multiples(mult: MultIdent, res: MultResults, divisors: set[int]):
    """Turn simulated multiple graphs into a per-divisor error probability map.

    For every sample in `res` and every divisor ``q`` the error-model checks of
    `mult` are evaluated; the results are summed per divisor and divided by
    ``len(res)``.

    :param mult: Multiplier identifier carrying the `error_model` to apply.
    :param res: Iterable of ``(ctx, out)`` pairs as produced by the simulation.
    :param divisors: The divisors ``q`` to evaluate the checks for.
    :return: A `ProbMap` holding the non-zero probabilities, a short hash of the
        sorted divisor set and the sample count.
    """
    model = mult.error_model
    samples = len(res)
    # 8-byte fingerprint of the divisor set, stored alongside the probabilities.
    divisors_hash = hashlib.blake2b(str(sorted(divisors)).encode(), digest_size=8).digest()

    errors = dict.fromkeys(divisors, 0)
    for ctx, out in res:
        check_inputs = graph_to_check_inputs(ctx, out,
                                             check_condition=model.check_condition,
                                             precomp_to_affine=model.precomp_to_affine)
        for q in divisors:
            check_funcs = {"add": model.check_add(q), "affine": model.check_affine(q)}
            errors[q] += evaluate_checks(check_funcs=check_funcs, check_inputs=check_inputs)

    # Keep the probmap small: divisors that never triggered an error are left out.
    probs = {q: count / samples for q, count in errors.items() if count != 0}
    return ProbMap(probs, divisors_hash, samples)

In [None]:
def evaluate_multiples_direct(mult: MultIdent, fname: str, offset: int, divisors: set[int]):
    """Load one pickled ``(mult, results)`` record from `fname` and evaluate it.

    Taking a file name and byte offset instead of the results themselves lets a
    worker read the record on its own, so the (large) results are not shipped
    through the task queue.

    :param mult: Multiplier identifier carrying the `error_model` to apply.
    :param fname: Path of the chunk file holding the pickled records.
    :param offset: Byte offset in `fname` at which the wanted record starts.
    :param divisors: The divisors to evaluate the checks for.
    :return: The `ProbMap` computed by `evaluate_multiples` for the loaded results.
    """
    with open(fname, "rb") as f:
        f.seek(offset)
        # NOTE: pickle must only be used on trusted files (here: chunks written by this notebook).
        _, res = pickle.load(f)
    # Delegate instead of duplicating the evaluation loop, so both entry points
    # always compute the same thing.
    return evaluate_multiples(mult, res, divisors)

## Run
Run this cell as many times as you want. Each run writes one new chunk file, named with a fresh random chunk ID.

In [None]:
# A fresh random ID per run: it names the chunk file and is passed as the RNG seed to every task.
chunk_id = randbytes(4).hex()
with TaskExecutor(max_workers=num_workers, initializer=silence) as pool:
    # One simulation task per multiplier/countermeasure combination.
    for mult in all_mults_with_ctr:
        pool.submit_task(mult, simulate_multiples,
                         mult, params, bits, samples,
                         use_init=use_init, use_multiply=use_multiply, seed=chunk_id)

    # Results are appended to the chunk file as independent pickle records, in completion order.
    with open(f"multiples_{bits}_{'init' if use_init else 'noinit'}_{'mult' if use_multiply else 'nomult'}_chunk{chunk_id}.pickle","wb") as h:
        finished = tqdm(pool.as_completed(), desc="Computing multiple graphs.", total=len(pool.tasks))
        for mult, future in finished:
            print(f"Got {mult}.")
            error = future.exception()
            if error:
                # A failed task is reported and simply left out of the chunk.
                print("Error!", error)
                continue
            res = future.result()
            pickle.dump((mult, res), h)
            gc.collect()

## Process

In [None]:
def _drain_completed(pool, progress, handle):
    """Append the result of every completed task in `pool` to `handle`, reporting failures."""
    for full, future in pool.as_completed():
        progress.update(1)
        if error := future.exception():
            print("Error!", full, error)
            continue
        pickle.dump((full, future.result()), handle)


# Chunk file names encode the simulation parameters and the chunk ID.
_chunk_name_re = re.compile(
    r"multiples_(?P<bits>[0-9]+)_(?P<init>(?:no)?init)_(?P<mult>(?:no)?mult)_chunk(?P<id>[0-9a-f]+)\.pickle"
)

with TaskExecutor(max_workers=num_workers, initializer=silence) as pool:
    chunk_glob = f"multiples_{bits}_{'init' if use_init else 'noinit'}_{'mult' if use_multiply else 'nomult'}_chunk*.pickle"
    for in_fname in tqdm(glob.glob(chunk_glob), desc="Processing chunks", smoothing=0):
        match = _chunk_name_re.fullmatch(in_fname)
        if match is None:
            print(f"Skipping {in_fname}, unexpected file name.")
            continue
        # Keep the parsed name parts (strings) in their own variables. Rebinding the
        # notebook-level `bits`, `use_init` and `use_multiply` here would break the
        # other cells, which expect an int and two booleans.
        chunk_bits = match.group("bits")
        init_tag = match.group("init")
        mult_tag = match.group("mult")
        chunk_id = match.group("id")
        out_fname = f"probs_{chunk_bits}_{init_tag}_{mult_tag}_chunk{chunk_id}.pickle"

        in_file = Path(in_fname)
        out_file = Path(out_fname)

        # Every (multiplier, error model) pair that needs a probmap for this chunk.
        cfgs_todo = {
            mult.with_error_model(error_model)
            for mult in all_mults_with_ctr
            for error_model in all_error_models
        }

        if out_file.exists():
            print(f"Processing chunk {chunk_id}, some(or all) probmaps found.")
            with out_file.open("r+b") as f:
                # End offset of the last record that was read back successfully.
                last_end = 0
                while True:
                    try:
                        full, _ = pickle.load(f)
                    except (EOFError, pickle.UnpicklingError):
                        # Either a clean end of file (truncating is then a no-op) or a
                        # damaged tail from an interrupted run. Cut the tail off so the
                        # records appended below remain readable.
                        f.truncate(last_end)
                        break
                    # discard(): tolerate duplicate records and configurations that
                    # are no longer part of the current configuration set.
                    cfgs_todo.discard(full)
                    last_end = f.tell()
            if not cfgs_todo:
                print("Chunk complete. Continuing...")
                # Nothing to compute, so do not re-read the (large) input chunk.
                continue
            print(f"Chunk missing {len(cfgs_todo)} probmaps, computing...")
        else:
            print(f"Processing chunk {chunk_id}, no probmaps found.")

        with in_file.open("rb") as f, out_file.open("ab") as h:
            loading_bar = tqdm(total=nmults_ctr, desc=f"Loading chunk {chunk_id}.", smoothing=0)
            processing_bar = tqdm(total=len(cfgs_todo), desc=f"Processing {chunk_id}.", smoothing=0)
            try:
                while True:
                    # Remember where this record starts; workers seek here themselves.
                    start = f.tell()
                    try:
                        mult, vals = pickle.load(f)
                    except EOFError:
                        break
                    except pickle.UnpicklingError:
                        print("Bad unpickling, the multiples file is likely truncated.")
                        break
                    # Only the key and the offset are needed in this process.
                    del vals
                    loading_bar.update(1)
                    for error_model in all_error_models:
                        full = mult.with_error_model(error_model)
                        if full in cfgs_todo:
                            # Pass the file name and offset to speed up computation start.
                            pool.submit_task(full,
                                             evaluate_multiples_direct,
                                             full, in_fname, start, divisor_map["all"])
                    gc.collect()
                    # Bound the number of queued tasks by writing out finished ones.
                    if len(pool.tasks) > 1000:
                        _drain_completed(pool, processing_bar, h)
                _drain_completed(pool, processing_bar, h)
            finally:
                loading_bar.close()
                processing_bar.close()
        print("Chunk completed.")


## Misc

In [None]:
from pyinstrument import Profiler as PyProfiler

# Scratch cell: evaluate every error model on the results of one multiplier.
# It reads `multiples_mults`, which is filled by the loading cell further down,
# so that cell has to be run first.
mult, res = next(iter(multiples_mults.items()))

# All combinations of check subsets, precomputation handling and check condition.
model_space = itertools.product(powerset(checks_add), (True, False), ("all", "necessary"))
for checks, precomp_to_affine, check_condition in model_space:
    error_model = ErrorModel(checks, check_condition=check_condition, precomp_to_affine=precomp_to_affine)
    full = mult.with_error_model(error_model)
    print(full)
    # To profile, wrap the call below in `with PyProfiler() as prof:` and then
    # `print(prof.output_text(unicode=True, color=True))`.
    probmap = evaluate_multiples(full, res, divisor_map["all"])
    #print(probmap)

In [None]:
# Merge the simulation results of all matching chunk files, keyed by multiplier.
multiples_mults = {}
for fname in glob.glob(f"multiples_{bits}_{'init' if use_init else 'noinit'}_{'mult' if use_multiply else 'nomult'}_chunk*.pickle"):
    with open(fname, "rb") as f:
        # Each chunk file is a sequence of independent pickle records.
        while True:
            try:
                mult, vals = pickle.load(f)
            except EOFError:
                break
            except pickle.UnpicklingError:
                # Same handling as in the processing cell: keep what was read so far
                # from a chunk whose writer was interrupted.
                print("Bad unpickling, the multiples file is likely truncated.")
                break
            if mult in multiples_mults:
                multiples_mults[mult].merge(vals)
            else:
                multiples_mults[mult] = vals