### [PSMT_Malicious] Project (ND&HYU Collaboration)
- Mathematical Construction of Approximated VAF and Analysis
- Seunghun Paik

In [1]:
# Libraries
import numpy as np
import matplotlib.pyplot as plt
import math
from tqdm.auto import tqdm

### Helper Tools

In [2]:
# Helper Function to Plot Functions
def plot_single(fn, lb, ub, *, num_points=10001, **kwargs):
    """Plot y = fn(x, **kwargs) for x sampled uniformly on [lb, ub].

    Args:
        fn: vectorized callable; it receives a NumPy array of sample points
            (plus **kwargs) and must return an array of the same length.
        lb, ub: endpoints of the plotted interval.
        num_points: number of sample points (keyword-only). Defaults to
            10001, which is what the previous sizing logic always produced.
        **kwargs: forwarded unchanged to fn.
    """
    # The former `min(int(ub-lb), 1001)` followed by `max(npoint, 10001)`
    # always collapsed to 10001, so the range-dependent term was dead code.
    # That constant is now the explicit default of `num_points`.
    base = np.linspace(lb, ub, num_points)
    values = fn(base, **kwargs)
    plt.plot(base, values)
    

### Core Functions

In [3]:
# Single weak-DEP step
def wDEPSingle(x, k):
    """Apply one weak-DEP step: (3*sqrt(3) / (2*k*sqrt(k))) * x * (k - x^2).

    The leading constant normalizes the cubic x*(k - x^2) so that its local
    maximum, attained at x = sqrt(k/3), equals exactly 1. Works elementwise
    on NumPy arrays as well as on scalars.
    """
    # sqrt(27/4) = 3*sqrt(3)/2 and k**1.5 = k*sqrt(k)
    gain = pow(27 / 4, 0.5) / pow(k, 1.5)
    return gain * x * (k - x ** 2)

# Weak DEP: composition of n single weak-DEP steps
def wDEP(x, k, L, R, n):
    """Evaluate the n-step weak DEP at x.

    The input is first scaled down by L**(n-1) * R. Each of the n steps
    computes y -> (k*y - y^3) * scale with scale = sqrt(27/4) / k**1.5.
    The first n-1 steps are then multiplied by L and the final step by R.

    Args:
        x: scalar or NumPy array input.
        k: weak-DEP step parameter.
        L: expansion factor applied after each of the first n-1 steps.
        R: factor applied after the last step.
        n: number of steps.
    """
    scale = pow(27 / 4, 0.5) / pow(k, 1.5)
    y = x / (pow(L, n - 1) * R)
    for step in reversed(range(n)):
        y = k * y - y ** 3
        expansion = R if step == 0 else L
        y = y * expansion * scale
    return y
    
# Transformation and Squaring
def polyVAF(x, n):
    """Approximate VAF built from n transformation steps followed by 4 squarings.

    Computes t = (1 - 1.5 x^2)^2, applies t -> (1.5 t - 0.5)^2 n times,
    and then squares the result 4 more times (an overall 16th power,
    done as repeated squaring).
    """
    def transform(t):
        return (1.5 * t - 0.5) ** 2

    t = (1 - 1.5 * x ** 2) ** 2
    for _ in range(n):
        t = transform(t)

    for _ in range(4):
        t = t ** 2
    return t

# Naive Squaring
def squareVAF(x, n):
    """Baseline VAF using plain repeated squaring.

    Returns (1 - 1.5 x^2)^2 squared n further times, i.e.
    (1 - 1.5 x^2)^(2^(n+1)).
    """
    result = (1 - 1.5 * x ** 2) ** 2
    squarings_left = n
    for _ in range(squarings_left):
        result = result ** 2
    return result

### Tools for Analysis

In [4]:
# Finds the epsilon parameter for VAFs
def compute_crit_pt(f, bd, n_iter=1000, **kwargs):
    """Find the point x in [0, 1] where f(x, **kwargs) drops to bd (the VAF epsilon).

    The bracket [0, upper] is grown by factors of 10 starting from 1e-12
    until f(upper) <= bd, capping upper at 1. It is then refined by
    bisection, keeping f > bd at `lower` and f <= bd at `upper`. This
    assumes f > bd near 0 (as for the VAFs here, where f(0) == 1).

    Args:
        f: callable f(x, **kwargs).
        bd: threshold value.
        n_iter: maximum number of bisection steps.
        **kwargs: forwarded to f.

    Returns:
        The final bisection midpoint. Returns 1 when f is still above bd
        at x = 1.
    """
    lower = 0
    upper = 1e-12

    while f(upper, **kwargs) > bd:
        upper *= 10
        if upper >= 1:
            # Clamp to the end of the domain. Only give up if f is still
            # above bd there. Previously this returned 1 unconditionally,
            # even when a crossing existed in (0.1, 1].
            upper = 1
            if f(upper, **kwargs) > bd:
                # No point in [0, 1] where f falls to bd
                return 1
            break

    mid = upper  # value returned if n_iter <= 0
    for _ in range(n_iter):
        mid = (lower + upper) / 2
        # Once mid coincides with an endpoint, the bracket cannot shrink any
        # further in floating point. Every later iteration would produce
        # this same mid, so stopping here returns an identical result.
        if mid == lower or mid == upper:
            break
        fmid = f(mid, **kwargs) - bd
        if fmid > 0:
            lower = mid
        else:
            upper = mid

    return mid

In [5]:
# This function computes the root of f1(X) = f2(X)
# For a constant function, it behaves as the root finding algorithm. (Bisection Search)
def solver(f1, f2, lo, hi, n_iter=1000, **kwargs):
    """Find x between lo and hi with f1(x, **kwargs) == f2(x) by bisection.

    Args:
        f1: callable f1(x, **kwargs).
        f2: callable f2(x). No kwargs are forwarded to it.
        lo, hi: bracket endpoints, in either order.
        n_iter: maximum number of bisection steps (default 1000,
            matching compute_crit_pt).
        **kwargs: forwarded to f1 only.

    Returns:
        The final bisection midpoint. If f1 - f2 has the same sign at both
        endpoints, no root is bracketed and the result is not meaningful.
    """
    v1 = f1(lo, **kwargs) - f2(lo)
    v2 = f1(hi, **kwargs) - f2(hi)

    # The loop keeps `lo` where f1 - f2 > 0 and `hi` where f1 - f2 <= 0.
    # Orient the bracket accordingly. The non-strict comparisons also cover
    # a root lying exactly on an endpoint, which the old strict test
    # (v1 < 0 and v2 > 0) misoriented.
    if v1 <= 0 <= v2 and v1 < v2:
        lo, hi = hi, lo

    mid = (lo + hi) / 2  # value returned if n_iter <= 0
    for _ in range(n_iter):
        mid = (lo + hi) / 2
        # Converged at float precision: all later mids would equal this one
        if mid == lo or mid == hi:
            break
        vmid = f1(mid, **kwargs) - f2(mid)
        if vmid > 0:
            lo = mid
        else:
            hi = mid

    return mid
        
# Local inverse of a single weak-DEP step (wDEPSingle).
# The search bracket is [0, sqrt(k/3)], on which wDEPSingle increases from
# 0 to 1, so the result is meaningful for target values v in [0, 1].
def inv_f(v, k):
    """Return x in [0, sqrt(k/3)] with wDEPSingle(x, k) ~= v, via bisection (solver)."""
    peak = pow(k / 3, 0.5)

    def constant_target(_x):
        return v

    return solver(wDEPSingle, constant_target, 0, peak, k=k)

# Peak of the weak DEP closest to the origin, i.e. the extent of the domain
# around 0 on which wDEP behaves monotonically.
def firstPeak(k, L, R, n):
    """Return the positive peak of wDEP(., k, L, R, n) closest to the origin.

    The last step peaks when its normalized input equals sqrt(k/3). Walking
    back through the n-1 inner steps (each y -> L * wDEPSingle(y, k)) with
    inv_f recovers the corresponding normalized input. That value is
    finally rescaled by L**(n-1) * R.
    """
    peak = pow(k / 3, 0.5)
    for _ in range(1, n):
        peak = inv_f(peak / L, k)
    return peak * pow(L, n - 1) * R

In [6]:
# Depth Calculator
# Greedy-Type Algorithm
def greedy_depth(val):
    """Depth cost of `val` VAF transformation steps under greedy grouping.

    Steps are consumed greedily in chunks of 4, 3, 2 or 1, costing 5, 4, 3
    and 2 depth respectively. For a non-negative integer val this equals
    val + ceil(val / 4). Returns 0 when val <= 0.
    """
    # (chunk size, depth cost) from largest to smallest. Anything smaller
    # falls back to a single step of cost 2.
    chunk_table = ((4, 5), (3, 4), (2, 3))
    depth = 0
    while val > 0:
        for size, cost in chunk_table:
            if val >= size:
                break
        else:
            size, cost = 1, 2
        val -= size
        depth += cost
    return depth

### Parameter Selector

In [7]:
# Parameter Selection
# k: params for weak DEP
# L: Expansion Rate
# M: Desirable domain (The resulting DEP will be defined over [-M, M])
def DEPSelector(k, L, M):
    """Grid-search (R, n) for the weak-DEP + VAF pipeline with minimal depth.

    R runs over a grid on [1, M]. For each R, n is the smallest integer with
    L**n * R >= M (n == 0 means no DEP step is needed). For every (R, n):
      * the separation parameter e_sep is estimated, and
      * for both VAF constructions (polyVAF and squareVAF), the smallest
        number of iterations whose 2**-20 crossing point (see
        compute_crit_pt) falls below e_sep is found. The cheaper
        construction is kept.

    Args:
        k: weak-DEP step parameter (see wDEPSingle); requires L**2 < k.
        L: expansion rate per DEP step.
        M: desired domain; the resulting DEP is defined over [-M, M].
           Must be an integer >= 1 (it also determines the grid size).

    Returns:
        (best_key, retdict). retdict maps each (R, n) to a summary dict.
        best_key is the (R, n) with minimal total depth; ties go to the
        larger "range", which is log2 of the covered domain.

    Raises:
        AssertionError: if L**2 >= k.
        ValueError: if M < 1 (the grid would be empty).
    """
    # Avoid invalid parameters
    assert L ** 2 < k, "weak DEP requires L**2 < k"
    if M < 1:
        raise ValueError("M must be >= 1")

    # Threshold defining the VAF epsilon
    bd = pow(2, -20)
    # compute_crit_pt(fn, bd, n=i) depends only on (fn, i), not on (R, n),
    # so it is computed once per pair and reused across the whole grid.
    crit_cache = {}

    def min_iters(fn, e_sep):
        """Smallest i >= 0 with compute_crit_pt(fn, bd, n=i) < e_sep."""
        i = 0
        while True:
            key = (fn, i)
            if key not in crit_cache:
                crit_cache[key] = compute_crit_pt(fn, bd, n=i)
            if crit_cache[key] < e_sep:
                return i
            i += 1

    # Grid search with respect to R & n
    Rs = np.linspace(1, M, min(M, 1001))
    ns = np.ceil(np.log2(M / Rs) / math.log2(L)).astype(int)

    # Tail value of a single weak-DEP step, evaluated at L
    fL = wDEPSingle(L, k)
    retdict = dict()

    best_key = None
    best_depth = 10000
    best_range = 0

    for (R, n) in tqdm(zip(Rs, ns), total=len(Rs)):
        if n == 0:
            # No DEP iterations are needed for this R
            e_sep = 1 / R
            i_new = min_iters(polyVAF, e_sep)
            # Depth 7 comes from 4 squarings & 3 ops to compute the first function
            depth_vaf = greedy_depth(i_new) + 7
            i_sq = min_iters(squareVAF, e_sep)
            # Depth 3 comes from computing the first function
            depth_sq = i_sq + 3

            depth_from_vaf = min(depth_vaf, depth_sq)
            use_new = depth_vaf < depth_sq
            curr_depth = depth_from_vaf + 1
            # Report the range in log2 units, like the n > 0 branch. The raw R
            # used before made the tie-break compare mismatched units.
            curr_range = math.log2(R)
            retdict[(R, n)] = {
                "depth": curr_depth,
                "depth_from_vaf": depth_from_vaf,
                "e_sep": e_sep,
                "n_fn": i_new if use_new else i_sq,
                "new": use_new,
                "range": curr_range,
            }
        else:
            # The evaluation result at 1 (normalized by R) is also important
            f1 = wDEP(1, k=k, L=L, R=R, n=n) / R
            # Check whether the function increases monotonically from 0 to 1
            critpt = firstPeak(k, L, R, n)
            # If not, f(1) carries no meaningful information, so use fL only
            e_sep = fL if critpt < 1 else min(fL, f1)

            i_new = min_iters(polyVAF, e_sep)
            # One depth is saved by merging the last scalar multiplication of
            # the DEP with the scalar multiplication for computing 1.5x^2
            depth_vaf = greedy_depth(i_new) + 6
            i_sq = min_iters(squareVAF, e_sep)
            depth_sq = i_sq + 2

            depth_from_vaf = min(depth_vaf, depth_sq)
            use_new = depth_vaf < depth_sq
            # DEP requires 2n + 1 depths
            curr_depth = 2 * n + depth_from_vaf + 1
            curr_range = n * math.log2(L) + math.log2(R)
            retdict[(R, n)] = {
                "depth": curr_depth,
                "n_dep": n,
                "e_sep": e_sep,
                "depth_from_vaf": depth_from_vaf,
                "n_fn": i_new if use_new else i_sq,
                "new": use_new,
                "range": curr_range,
            }

        # Record & update the minimum-depth entry (ties -> larger range)
        if curr_depth < best_depth or (curr_depth == best_depth and curr_range > best_range):
            best_depth = curr_depth
            best_range = curr_range
            best_key = (R, n)

    print("Minimal Depth (R,n)=", best_key, retdict[best_key])

    return best_key, retdict


### Example Runs

In [8]:
# Weak-DEP parameter pairs (k, L); each satisfies L**2 < k as DEPSelector requires.
# Target domain: M = 2^12.
for k, L in [(4.5, 2), (27/4, 2.59), (17, 4)]:
    best_key, retdict = DEPSelector(k, L, 1 << 12)

  0%|          | 0/1001 [00:00<?, ?it/s]

Minimal Depth (R,n)= (357.265, 4) {'depth': 25, 'n_dep': 4, 'e_sep': 0.006297801763796971, 'depth_from_vaf': 16, 'n_fn': 8, 'new': True, 'range': 12.48085077484337}


  0%|          | 0/1001 [00:00<?, ?it/s]

Minimal Depth (R,n)= (91.08999999999999, 4) {'depth': 24, 'n_dep': 4, 'e_sep': 0.010977923238173573, 'depth_from_vaf': 15, 'n_fn': 7, 'new': True, 'range': 12.00102916797928}


  0%|          | 0/1001 [00:00<?, ?it/s]

Minimal Depth (R,n)= (1.0, 6) {'depth': 24, 'n_dep': 6, 'e_sep': 0.06255547396645149, 'depth_from_vaf': 11, 'n_fn': 4, 'new': True, 'range': 12.0}


In [9]:
# Weak-DEP parameter pairs (k, L); each satisfies L**2 < k as DEPSelector requires.
# Target domain: M = 2^16.
for k, L in [(4.5, 2), (27/4, 2.59), (17, 4)]:
    best_key, retdict = DEPSelector(k, L, 1 << 16)

  0%|          | 0/1001 [00:00<?, ?it/s]

Minimal Depth (R,n)= (656.3499999999999, 7) {'depth': 31, 'n_dep': 7, 'e_sep': 0.006297683851748112, 'depth_from_vaf': 16, 'n_fn': 8, 'new': True, 'range': 16.358321529937218}


  0%|          | 0/1001 [00:00<?, ?it/s]

Minimal Depth (R,n)= (132.07, 7) {'depth': 31, 'n_dep': 7, 'e_sep': 0.007571666680962719, 'depth_from_vaf': 16, 'n_fn': 8, 'new': True, 'range': 16.655823667506642}


  0%|          | 0/1001 [00:00<?, ?it/s]

Minimal Depth (R,n)= (5112.73, 2) {'depth': 31, 'n_dep': 2, 'e_sep': 7.766082369278072e-05, 'depth_from_vaf': 26, 'n_fn': 16, 'new': True, 'range': 16.31987812489656}


In [10]:
# Weak-DEP parameter pairs (k, L); each satisfies L**2 < k as DEPSelector requires.
# Target domain: M = 2^20.
for k, L in [(4.5, 2), (27/4, 2.59), (17, 4)]:
    best_key, retdict = DEPSelector(k, L, 1 << 20)

  0%|          | 0/1001 [00:00<?, ?it/s]

Minimal Depth (R,n)= (34603.975, 5) {'depth': 37, 'n_dep': 5, 'e_sep': 7.96346650211341e-05, 'depth_from_vaf': 26, 'n_fn': 16, 'new': True, 'range': 20.078650151035674}


  0%|          | 0/1001 [00:00<?, ?it/s]

Minimal Depth (R,n)= (12583.900000000001, 5) {'depth': 37, 'n_dep': 5, 'e_sep': 7.946661995888676e-05, 'depth_from_vaf': 26, 'n_fn': 16, 'new': True, 'range': 20.48405198041052}


  0%|          | 0/1001 [00:00<?, ?it/s]

Minimal Depth (R,n)= (1049.575, 5) {'depth': 37, 'n_dep': 5, 'e_sep': 9.465066080089269e-05, 'depth_from_vaf': 26, 'n_fn': 16, 'new': True, 'range': 20.035589546348895}
