# Defaults and Imports

In [340]:
import numpy as np
import os
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd

from tqdm import tqdm
from itertools import combinations
from numpy.random import default_rng
from numpy.linalg import norm

In [339]:
path = '../writeup/images/q6'
seed = 0
rng = default_rng(seed)

# Probabilistic JS Lemma

In [308]:
# parameters
eps = 0.1
delta = 0.2
n = 1000
D = 50

assert(eps > 0)
assert(delta < 0.5)

def compute_dim(eps, delta):
    d = np.power(eps, -2) * np.log(1/delta)
    return np.ceil(d).astype(int) # round up and convert to int

In [309]:
def generate_data(n, D):
    # generate normalized data matrix X [n, D]
    X = rng.multivariate_normal(mean=np.array([0]), cov=np.array([[1]]),
                                size=(n,D)).squeeze()
    X /= X.sum(axis=1)[:, np.newaxis]

    # projection dimension
    d = compute_dim(eps, delta)

    # generate projection matrix
    A = rng.multivariate_normal(mean=np.array([0]), cov=np.array([[1/d]]),
                                size=(d,D)).squeeze()
    
    return X, A


In [319]:
def check_distortion(X, A):
    n = X.shape[0]
    distored = 0
    
    # iterate over all data points x_i
    for x in X:
        # check if projected data pt Ax beyond distortion bounds eps
        distored += np.abs((norm(A @ x) ** 2) - 1) > eps # distored = 1, undistorted = 0
    
    return distored / n # normalize

In [320]:
X, A = generate_data(n, D)
distortion_percentage = check_distortion(X, A)
print(distortion_percentage)

0.95


# Deterministic JL Lemma

In [324]:
eps = 0.1 # distortion bound
n = 15 # num data points
d = 300 # original dimension

def projection_dim(eps, n):
    t1 = np.power(eps, 2)/2 - np.power(eps, 3)/3
    t2 = (4 / t1) * np.log(n)
    return np.ceil(t2).astype(int)

In [325]:
def generate_data(n, d):
    X = rng.multivariate_normal(mean=np.array([0]), cov=np.array([[1]]),
                                size=(n,d)).squeeze() # [n x d]
    # X = np.eye(n, d)
    k = projection_dim(eps, n) # projection dim
    A = rng.multivariate_normal(mean=np.array([0]), cov=np.array([[1/k]]),
                                size=(k,d)).squeeze() # [k x d]
    return X, A

    # A x = [k x d] x [d x 1] = [k x 1]

def check_projection(X, A, eps):
    lower_list, upper_list, new_dist_list = [], [], []
    within_bounds = 0 # num pairwise distances that are out of bounds
    n = X.shape[0]
    
    # iterate pairwise over all data points u, v \in X
    for i,j in tqdm(combinations(range(X.shape[0]), 2)):
        u, v = X[i,:], X[j,:]
        original_dist = norm(u-v) ** 2
        new_dist = norm(A @ u - A @ v) ** 2
                
        lower_bound = (1 - eps) * original_dist
        upper_bound = (1 + eps) * original_dist
        within_bounds += (lower_bound <= new_dist <= upper_bound) # True = 1, False = 0
        
        lower_list.append(lower_bound)
        upper_list.append(upper_bound)
        new_dist_list.append(new_dist)
        
    within_bounds /= (n*(n-1)/2) # normalize
    return within_bounds, lower_list, upper_list, new_dist_list

In [326]:
X, A = generate_data(n, d)
within_bounds, lower, upper, new_dist = check_projection(X, A, eps)

print(f'{within_bounds*100}% of pairwise distances are within bounds')


105it [00:00, 2546.32it/s]

100.0% of pairwise distances are within bounds





In [341]:
def plot_distortion(lower, upper, new_dist):
    lower_df = pd.DataFrame({'Lower Distortion Bound': lower})
    upper_df = pd.DataFrame({'Upper Distortion Bound': upper})
    new_dist_df = pd.DataFrame({'Embedding Distortion': new_dist})
    
    fig1 = px.scatter(lower_df)
    fig2 = px.scatter(upper_df)
    fig3 = px.scatter(new_dist_df, color_discrete_sequence=['black'])
    fig = go.Figure(data=fig1.data + fig2.data + fig3.data).update_layout(
            xaxis_title="Data Point Pair", yaxis_title="Distance",
            margin_l=5, margin_t=5, margin_b=5, margin_r=5,
            font_family="Serif", font_size=14)
    fig.update_traces(marker={'size': 3})
    fig.layout.showlegend = True
    
    fname = 'distortion.png'
    fig3.write_image(os.path.join(path, fname))
    fig.show()
    
plot_distortion(lower, upper, new_dist)