In [22]:
import torch
import numpy as np
from torch.utils.data import DataLoader
import sys
sys.path.append('../')
from utils import vmf as vmf_utils
from methods import s3wd as s3w, wd as wd

# Prefer GPU 3 (as in the original runs) but fall back to the highest available
# GPU index, or the CPU, so the notebook also runs on machines with fewer GPUs.
_preferred_cuda_index = 3
if torch.cuda.is_available():
    device = torch.device(f'cuda:{min(_preferred_cuda_index, torch.cuda.device_count() - 1)}')
else:
    device = torch.device('cpu')

# Seed both RNGs used in this notebook: numpy (presumably used by
# vmf_utils.rand_vmf — TODO confirm) and torch (presumably used for the random
# projections/rotations inside the distance functions — TODO confirm).
random_seed = 1
np.random.seed(random_seed)
torch.manual_seed(random_seed)

def _vmf_at_pole(N, kappa, z, device='cpu'):
    """Draw ``N`` von Mises-Fisher samples with mean direction ``(0, 0, z)``.

    Shared implementation of the two pole samplers below. The mean direction is
    built on the CPU because it is handed straight to ``vmf_utils.rand_vmf``
    (the original code also passed a CPU tensor). Only the returned samples are
    placed on ``device``.

    Returns a float32 tensor on ``device``; its shape is whatever ``rand_vmf``
    produces (presumably ``(N, 3)`` — TODO confirm against utils.vmf).
    """
    mu = torch.tensor([0.0, 0.0, z], dtype=torch.float32)
    samples = vmf_utils.rand_vmf(mu, kappa=kappa, N=N)
    # as_tensor avoids the copy-construct warning if rand_vmf returns a tensor,
    # and converts a numpy array to float32 on the requested device in one step.
    return torch.as_tensor(samples, dtype=torch.float32, device=device)


def vmf_at_north_pole(N, kappa, device='cpu'):
    """Draw ``N`` vMF samples concentrated around the north pole ``(0, 0, 1)``."""
    return _vmf_at_pole(N, kappa, 1.0, device)


def vmf_at_south_pole(N, kappa, device='cpu'):
    """Draw ``N`` vMF samples concentrated around the south pole ``(0, 0, -1)``."""
    return _vmf_at_pole(N, kappa, -1.0, device)

def test_sensitivity(d_func, d_args, epsilon_list, n_runs, device='cpu'):
    kappa = 1000
    N = 2048
    X = vmf_at_north_pole(N, kappa, device)
    Xt = vmf_at_south_pole(N, kappa, device)

    results = {}
    for epsilon in epsilon_list:
        d_args['eps'] = epsilon
        distances = []
        for _ in range(n_runs):
            w = d_func(X, Xt, **d_args)
            distances.append(w.item())
        mean_distance = np.mean(distances)
        std_distance = np.std(distances)
        results[epsilon] = (mean_distance, std_distance)

    for epsilon, (mean_distance, std_distance) in results.items():
        print(f"Epsilon: {epsilon:.4f}, Mean Distance: {mean_distance:.4f}, Std Distance: {std_distance:.4f}")


# Values of the `eps` keyword swept for every distance below:
# a 1/5 pattern per decade, from 1e-6 up to 5e-2.
eps = [
    1e-6, 5e-6,
    1e-5, 5e-5,
    1e-4, 5e-4,
    1e-3, 5e-3,
    1e-2, 5e-2,
]


In [31]:
# Set 1: shared keyword arguments for every distance evaluated with 128 projections.
print("Set 1: 128 projections")
d_args_128 = dict(p=2, device=device, n_projs=128)
n_runs = 100  # repetitions per epsilon value

Set 1: 128 projections


In [32]:
# Baseline: plain `s3wd` with the Set-1 (128-projection) arguments.
d_func = s3w.s3wd
s3w_results_128 = test_sensitivity(
    d_func,
    d_args_128,
    eps,
    n_runs,
    device=device,
)

Epsilon: 0.0000, Mean Distance: 957.7993, Std Distance: 1.6724
Epsilon: 0.0000, Mean Distance: 958.0859, Std Distance: 2.0987
Epsilon: 0.0000, Mean Distance: 958.1065, Std Distance: 2.0977
Epsilon: 0.0001, Mean Distance: 958.0544, Std Distance: 1.7864
Epsilon: 0.0001, Mean Distance: 957.6334, Std Distance: 1.9493
Epsilon: 0.0005, Mean Distance: 956.7641, Std Distance: 1.7809
Epsilon: 0.0010, Mean Distance: 954.5304, Std Distance: 2.1129
Epsilon: 0.0050, Mean Distance: 939.0492, Std Distance: 1.9431
Epsilon: 0.0100, Mean Distance: 925.8701, Std Distance: 1.6517
Epsilon: 0.0500, Mean Distance: 871.6674, Std Distance: 1.7789


In [33]:
# `ri_s3wd` with a single rotation (128 projections).
# Pass a copy of the shared settings with `n_rotations` overridden instead of
# writing into `d_args_128`. This keeps the cell idempotent and prevents
# `n_rotations` from leaking into other cells (e.g. the plain `s3wd` cell,
# which would otherwise receive it as an extra keyword argument).
d_func = s3w.ri_s3wd
ri_s3w_1_results_128 = test_sensitivity(
    d_func, {**d_args_128, 'n_rotations': 1}, eps, n_runs, device
)

Epsilon: 0.0000, Mean Distance: 1771.9628, Std Distance: 249.2399
Epsilon: 0.0000, Mean Distance: 1858.7635, Std Distance: 226.5994
Epsilon: 0.0000, Mean Distance: 1803.3944, Std Distance: 290.0595
Epsilon: 0.0001, Mean Distance: 1829.8545, Std Distance: 259.4876
Epsilon: 0.0001, Mean Distance: 1831.2919, Std Distance: 259.3316
Epsilon: 0.0005, Mean Distance: 1838.7870, Std Distance: 287.2997
Epsilon: 0.0010, Mean Distance: 1780.9043, Std Distance: 274.3009
Epsilon: 0.0050, Mean Distance: 1870.2810, Std Distance: 251.9407
Epsilon: 0.0100, Mean Distance: 1791.5991, Std Distance: 282.1853
Epsilon: 0.0500, Mean Distance: 1844.3739, Std Distance: 249.5595


In [34]:
# `ri_s3wd` with 10 rotations (128 projections).
# Uses a per-cell copy of `d_args_128` so the shared dict is never mutated and
# the cells in this section can be re-run in any order.
d_func = s3w.ri_s3wd
ri_s3w_10_results_128 = test_sensitivity(
    d_func, {**d_args_128, 'n_rotations': 10}, eps, n_runs, device
)

Epsilon: 0.0000, Mean Distance: 1827.1313, Std Distance: 82.7943
Epsilon: 0.0000, Mean Distance: 1824.0676, Std Distance: 92.0620
Epsilon: 0.0000, Mean Distance: 1837.7344, Std Distance: 74.9027
Epsilon: 0.0001, Mean Distance: 1828.4397, Std Distance: 75.1069
Epsilon: 0.0001, Mean Distance: 1817.1224, Std Distance: 77.1129
Epsilon: 0.0005, Mean Distance: 1824.8650, Std Distance: 66.5573
Epsilon: 0.0010, Mean Distance: 1822.0892, Std Distance: 78.9655
Epsilon: 0.0050, Mean Distance: 1823.6524, Std Distance: 76.0546
Epsilon: 0.0100, Mean Distance: 1832.7319, Std Distance: 75.4223
Epsilon: 0.0500, Mean Distance: 1813.2539, Std Distance: 81.1063


In [35]:
# `ari_s3wd` with 30 rotations (128 projections).
# Uses a per-cell copy of `d_args_128` so the shared dict is never mutated and
# the cells in this section can be re-run in any order.
d_func = s3w.ari_s3wd
ari_s3w_30_results_128 = test_sensitivity(
    d_func, {**d_args_128, 'n_rotations': 30}, eps, n_runs, device
)

Epsilon: 0.0000, Mean Distance: 1866.4228, Std Distance: 31.7251
Epsilon: 0.0000, Mean Distance: 1868.2872, Std Distance: 35.2663
Epsilon: 0.0000, Mean Distance: 1864.0660, Std Distance: 35.6205
Epsilon: 0.0001, Mean Distance: 1862.0703, Std Distance: 33.8414
Epsilon: 0.0001, Mean Distance: 1862.2889, Std Distance: 34.4164
Epsilon: 0.0005, Mean Distance: 1867.2440, Std Distance: 33.0247
Epsilon: 0.0010, Mean Distance: 1866.4241, Std Distance: 32.1685
Epsilon: 0.0050, Mean Distance: 1864.3097, Std Distance: 32.0242
Epsilon: 0.0100, Mean Distance: 1863.0574, Std Distance: 34.9036
Epsilon: 0.0500, Mean Distance: 1860.0517, Std Distance: 35.2010


In [36]:
# Set 2: same sweep, but every distance now uses 512 projections.
print("Set 2: 512 projections")
d_args_512 = dict(p=2, device=device, n_projs=512)

Set 2: 512 projections


In [37]:
# Baseline: plain `s3wd` with the Set-2 (512-projection) arguments.
d_func = s3w.s3wd
s3w_results_512 = test_sensitivity(
    d_func,
    d_args_512,
    eps,
    n_runs,
    device=device,
)

Epsilon: 0.0000, Mean Distance: 956.8493, Std Distance: 0.3606
Epsilon: 0.0000, Mean Distance: 956.9313, Std Distance: 0.3472
Epsilon: 0.0000, Mean Distance: 956.8701, Std Distance: 0.3873
Epsilon: 0.0001, Mean Distance: 956.8674, Std Distance: 0.3550
Epsilon: 0.0001, Mean Distance: 956.7620, Std Distance: 0.3746
Epsilon: 0.0005, Mean Distance: 955.5398, Std Distance: 0.3477
Epsilon: 0.0010, Mean Distance: 953.3782, Std Distance: 0.3911
Epsilon: 0.0050, Mean Distance: 938.0304, Std Distance: 0.3898
Epsilon: 0.0100, Mean Distance: 925.1433, Std Distance: 0.3116
Epsilon: 0.0500, Mean Distance: 870.8593, Std Distance: 0.3439


In [38]:
# `ri_s3wd` with a single rotation (512 projections).
# Pass a copy of the shared settings with `n_rotations` overridden instead of
# writing into `d_args_512`. This keeps the cell idempotent and prevents
# `n_rotations` from leaking into other cells (e.g. the plain `s3wd` cell,
# which would otherwise receive it as an extra keyword argument).
d_func = s3w.ri_s3wd
ri_s3w_1_results_512 = test_sensitivity(
    d_func, {**d_args_512, 'n_rotations': 1}, eps, n_runs, device
)

Epsilon: 0.0000, Mean Distance: 1819.6581, Std Distance: 219.2019
Epsilon: 0.0000, Mean Distance: 1817.0251, Std Distance: 211.0000
Epsilon: 0.0000, Mean Distance: 1828.9268, Std Distance: 228.0387
Epsilon: 0.0001, Mean Distance: 1785.1691, Std Distance: 245.0443
Epsilon: 0.0001, Mean Distance: 1862.3419, Std Distance: 222.5437
Epsilon: 0.0005, Mean Distance: 1814.1375, Std Distance: 233.9971
Epsilon: 0.0010, Mean Distance: 1876.2494, Std Distance: 201.2994
Epsilon: 0.0050, Mean Distance: 1831.1130, Std Distance: 219.0665
Epsilon: 0.0100, Mean Distance: 1829.4150, Std Distance: 241.5372
Epsilon: 0.0500, Mean Distance: 1811.2832, Std Distance: 232.5653


In [39]:
# `ri_s3wd` with 10 rotations (512 projections).
# Uses a per-cell copy of `d_args_512` so the shared dict is never mutated and
# the cells in this section can be re-run in any order.
d_func = s3w.ri_s3wd
ri_s3w_10_results_512 = test_sensitivity(
    d_func, {**d_args_512, 'n_rotations': 10}, eps, n_runs, device
)

Epsilon: 0.0000, Mean Distance: 1818.8249, Std Distance: 73.4450
Epsilon: 0.0000, Mean Distance: 1818.2983, Std Distance: 69.1194
Epsilon: 0.0000, Mean Distance: 1824.9081, Std Distance: 79.5574
Epsilon: 0.0001, Mean Distance: 1829.2710, Std Distance: 67.0221
Epsilon: 0.0001, Mean Distance: 1813.8406, Std Distance: 77.5415
Epsilon: 0.0005, Mean Distance: 1825.4436, Std Distance: 78.7873
Epsilon: 0.0010, Mean Distance: 1826.7527, Std Distance: 79.6352
Epsilon: 0.0050, Mean Distance: 1816.8614, Std Distance: 78.7885
Epsilon: 0.0100, Mean Distance: 1817.2168, Std Distance: 77.8463
Epsilon: 0.0500, Mean Distance: 1819.8834, Std Distance: 77.0752


In [40]:
# `ari_s3wd` with 30 rotations (512 projections).
# Uses a per-cell copy of `d_args_512` so the shared dict is never mutated and
# the cells in this section can be re-run in any order.
d_func = s3w.ari_s3wd
ari_s3w_30_results_512 = test_sensitivity(
    d_func, {**d_args_512, 'n_rotations': 30}, eps, n_runs, device
)

Epsilon: 0.0000, Mean Distance: 1865.0331, Std Distance: 28.3183
Epsilon: 0.0000, Mean Distance: 1868.2435, Std Distance: 31.2540
Epsilon: 0.0000, Mean Distance: 1865.6800, Std Distance: 30.4766
Epsilon: 0.0001, Mean Distance: 1868.1232, Std Distance: 32.6014
Epsilon: 0.0001, Mean Distance: 1866.7863, Std Distance: 31.8164
Epsilon: 0.0005, Mean Distance: 1862.7934, Std Distance: 29.9063
Epsilon: 0.0010, Mean Distance: 1866.6707, Std Distance: 30.8366
Epsilon: 0.0050, Mean Distance: 1864.5816, Std Distance: 35.8230
Epsilon: 0.0100, Mean Distance: 1866.7260, Std Distance: 30.6681
Epsilon: 0.0500, Mean Distance: 1869.8366, Std Distance: 28.5427
