# Kernel Mean Embedding (KME) Analysis

This notebook performs the KME experiments in Section 7.6.

0. **Setup and Parameters**: Create directories and set parameters.  

1. **Compute Sliced Wasserstein Distance**: Computing the SWD is the computationally expensive step, so we compute it once up front and reuse the stored result for every kernel below.  

2. **Compute Sliced Wasserstein Kernels**: Compute the Sliced Wasserstein Kernel for various hyperparameters.  

3. **Perform Regression**: Perform parameter estimation and optimize over the SWK hyperparameters.


## 0. Setup and Parameters

In [None]:
using MAT
using Statistics
# using ThreadPools
include("compute_features.jl")
include("PathSignatures.jl")
include("regression_utils.jl")
include("utils.jl")

In [None]:
# Experiment configuration
ss = 50  # selects the data/ss<ss>/ input directory

# Dataset variant; the experiments in the paper use "init20" and "random20"
tdir = "init20"


# Regression settings
num_iterations = 100
tr_split = 0.8  # fraction of the data used for training
hyp_cv = 4      # cross-validation folds used for hyperparameter optimization

# SVR hyperparameter grids, one value per power of ten
SVR_C = 10.0 .^ (-3:1)    # candidate C values
SVR_eps = 10.0 .^ (-4:0)  # candidate epsilon values

In [None]:
# Create the output directories used by the cells below.
base_fpath = string("data/ss", ss, "/", tdir, "/")

# The later cells write to "SWD/" (distances, step 1), "SWD_KE/" (kernels, step 2)
# and "RG/" (regression results, step 3), so those are the directories that must
# exist. `mkpath` also creates missing parent directories and does nothing when the
# directory is already present, so no `isdir` guard is needed.
for subdir in ("SWD/", "SWD_KE/", "RG/")
    mkpath(string(base_fpath, subdir))
end

## 1. Compute Sliced Wasserstein Distance

In [None]:
# Load the arrays "B0", "B1" and "B2" stored in PD/PD.mat. The do-block form of
# `matopen` closes the file even when one of the reads throws.
fname = string(base_fpath, "PD/PD.mat")
B0, B1, B2 = matopen(fname, "r") do file
    read(file, "B0"), read(file, "B1"), read(file, "B2")
end

# `batch_SWD` comes from one of the included project files.
# NOTE(review): the trailing arguments (1, 500, 20) are presumably the first run, the
# last run and the number of samples per run, matching the 500 runs and the 20-wide
# blocks used when the kernels are built -- confirm against its definition.
D_all = batch_SWD(B0, B1, B2, 1, 500, 20)

# Store the distances so the expensive computation above only has to run once.
fname = string(base_fpath, "SWD/SWD.mat")
mkpath(dirname(fname))  # the write fails if the "SWD/" directory does not exist yet
matopen(fname, "w") do file
    write(file, "SWD", D_all)
end



## 2. Compute Sliced Wasserstein Kernels

In [None]:
# SMM kernel for the dataset selected by `tdir`: one kernel matrix per bandwidth in
# `sigma_all`, each written to "SWD_KE/SW_<sigma*1000>.mat" under the variable "K".

SWD = matopen(string(base_fpath, "SWD/SWD.mat"), "r") do file
    read(file, "SWD")
end
numRun = 500

KE_fpath = string(base_fpath, "SWD_KE/")
mkpath(KE_fpath)  # make sure the output directory exists before writing into it

# Sub-block of every per-pair distance matrix SWD[i, j, :, :] that enters the kernel.
# NOTE(review): presumably rows 1:20 index the samples of run i and columns 21:40 those
# of run j -- confirm against `batch_SWD`.
row_idx = 1:20
col_idx = 21:40

sigma_all = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0]
for sigma in sigma_all
    SMMK = zeros(numRun, numRun)

    for i in 1:numRun
        for j in i:numRun
            # Only the selected sub-block is summed, so only that part is exponentiated;
            # the view avoids copying the whole slice first.
            curD = @view SWD[i, j, row_idx, col_idx]
            SMMK[i, j] = sum(exp.(-curD ./ (2 * sigma)))

            # Mirror into the lower triangle (a no-op on the diagonal).
            SMMK[j, i] = SMMK[i, j]
        end
    end

    # `round` rather than `Int(...)`: the product is a floating-point number and `Int`
    # throws an InexactError as soon as it is not exactly integral.
    KE_fname = string(KE_fpath, "SW_", round(Int, sigma * 1000), ".mat")
    matopen(KE_fname, "w") do file
        write(file, "K", SMMK)
    end
end

## 3. Perform Regression

In [None]:
# SW hyperparameter search: load every stored kernel, run the multi-kernel regression
# and save the resulting errors and selected SVR parameters.

CL = matopen("CL_data.mat", "r") do file
    read(file, "CL")
end

KE_fpath = string(base_fpath,"SWD_KE/")
RG_fpath = string(base_fpath,"RG/")

# Only pick up kernel files: any other entry in the directory (e.g. ".DS_Store")
# would make `matopen` fail.
all_K = filter(f -> endswith(f, ".mat"), readdir(KE_fpath))
numK = length(all_K)
numK > 0 || error(string("No kernel files found in ", KE_fpath, "; run step 2 first"))

# Kept as an untyped vector, as before, because the signature of
# `run_regression_multikernel` is not visible from this notebook.
K_all = []
for curK in all_K
    K = matopen(string(KE_fpath, curK), "r") do file
        read(file, "K")
    end
    push!(K_all, K)
end

reg_error, SVR_params = run_regression_multikernel(K_all, K_all, CL, num_iterations, SVR_C, SVR_eps, hyp_cv, tr_split)

mkpath(RG_fpath)  # the write below fails if the "RG/" directory does not exist yet
fname = string(RG_fpath, "SW_SMM.mat")
matopen(fname, "w") do file
    write(file, "reg_error", reg_error)
    write(file, "SVR_params", SVR_params)
end

# Mean and standard deviation of the regression error along the first dimension.
reg_mean = mean(reg_error, dims=1)
reg_std = std(reg_error, dims=1)
