In [1]:
import os
import numpy as np
import torch
import torch.optim as optim
import torch.distributions as D
from src.Tmy_svgd import etmySVGD
from src.svgd import SVGD
from src.gsvgd import FullGSVGDBatch
from src.kernel import RBF, BatchRBF
from src.utils import plot_particles

from src.manifold import Grassmann
from src.s_svgd import SlicedSVGD
from src.mysvgd import mySVGD
from src.rand_mysvgd import min_mySVGD
import pickle
import argparse
import time
import torch
import torch.optim as optim
import torch.autograd as autograd
import autograd.numpy as np
from tqdm import tqdm, trange

import torch.distributions as D
import gc

import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


ImportError: cannot import name 'tmySVGD' from 'src.Tmy_svgd' (/home/zhoujk/SVGD/SVGD_code/GSVGD-main/src/Tmy_svgd.py)

In [None]:
# Experiment configuration expressed as an argparse parser. The notebook later
# calls `parser.parse_args([])`, so the defaults declared here are the values
# that are actually used unless that call is changed.
parser = argparse.ArgumentParser(description='Running xshaped experiment.')
d = 180   # default ambient dimension of the target
s = 2000  # default number of particles
parser.add_argument('--dim', type=int, default=d, help='dimension')
# A non-positive value makes the notebook sweep the effective dimensions 1, 2 and 5.
parser.add_argument('--effdim', type=int, default=-1,
                    help='effective dimension of each projection (<= 0 sweeps 1, 2 and 5)')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
parser.add_argument('--delta', type=float, default=0.01, help='stepsize for projections')
parser.add_argument('--T', type=float, default=1e-4, help='noise multiplier for projections')
parser.add_argument('--lr_g', type=float, default=0.1, help='learning rate for g')
parser.add_argument('--nparticles', type=int, default=s, help='no. of particles')
# NOTE(review): 10**12 epochs is effectively unbounded; presumably the fit
# routines are expected to stop on their own criterion -- confirm.
parser.add_argument('--epochs', type=int, default=1000000000000, help='no. of epochs')
parser.add_argument('--metric', type=str, default="energy", help='distance metric')
# Kept as a string flag ("True"/"False"), not a boolean, for backward compatibility.
parser.add_argument('--noise', type=str, default="True", help='whether to add noise')
parser.add_argument('--kernel', type=str, default="rbf", help='kernel')
parser.add_argument('--gpu', type=int, default=3, help='gpu index (-1 selects the CPU)')
parser.add_argument('--seed', type=int, default=235, help='random seed')
parser.add_argument('--suffix', type=str, default="", help='suffix for res folder')
# No default: None and -1 are both given special meaning by the GSVGD cell.
parser.add_argument('--m', type=int,
                    help='no. of projections (unset: min(20, dim // effdim); -1: dim // effdim)')
parser.add_argument('--save_every', type=int, default=200, help='step intervals to save particles')
parser.add_argument('--method', type=str, default="all", help='which method to use')

In [None]:
# Parse an empty argument list: inside the notebook every option therefore
# takes the default declared on the parser.
args = parser.parse_args([])

# Unpack the settings that the rest of the notebook reads as plain globals.
dim, nparticles = args.dim, args.nparticles
lr = lr_gsvgd = args.lr  # GSVGD starts from the same learning rate as SVGD
delta, T = args.delta, args.T
epochs, seed = args.epochs, args.seed

# A positive --effdim pins a single effective dimension; otherwise sweep three.
if args.effdim > 0:
    eff_dims = [args.effdim]
else:
    eff_dims = [1, 2, 5]

save_every = args.save_every # save metric values
print(f"Running for dim: {dim}, lr: {lr}, nparticles: {nparticles}")

In [None]:
# Resolve the compute device. `--gpu -1` selects the CPU explicitly. Any other
# index is used only if such a CUDA device is actually visible; otherwise fall
# back to the CPU with a message, instead of building a device handle that
# fails later at the first tensor allocation.
if 0 <= args.gpu < torch.cuda.device_count():
    device = torch.device(f'cuda:{args.gpu}')
else:
    if args.gpu != -1:
        print(f"cuda:{args.gpu} is not available; falling back to CPU")
    device = torch.device('cpu')

metric = args.metric

# Output directory: one folder per hyper-parameter combination, then one
# sub-folder per random seed.
results_folder = f"./res/gaussian{args.suffix}/{args.kernel}_epoch{epochs}_lr{lr}_delta{delta}_n{nparticles}_dim{dim}"
results_folder = f"{results_folder}/seed{seed}"

In [None]:
# Create the output directory; exist_ok avoids the check-then-create race and
# makes the cell safe to re-run.
os.makedirs(results_folder, exist_ok=True)

# Map the --kernel option onto the kernel classes used by the method cells.
if args.kernel == "rbf":
    Kernel = RBF
    BatchKernel = BatchRBF
else:
    # Fail here with a clear message rather than with a NameError on `Kernel`
    # in a later cell.
    raise ValueError(f"Unsupported kernel {args.kernel!r}; only 'rbf' is implemented")

In [None]:
def comm_func_eval(samples, ground_truth):
    """Compare the per-coordinate mean and variance of two sample sets.

    Both inputs are cloned first, so the tensors passed in are never touched.
    Statistics are reduced over axis 0.

    Args:
        samples: tensor of particles to evaluate.
        ground_truth: tensor of reference draws to compare against.

    Returns:
        dict with two scalar tensors:
            'mean_dis': mean squared difference between the axis-0 means.
            'var_dis':  mean squared difference between the axis-0 variances.
    """
    approx = samples.clone()
    reference = ground_truth.clone()

    mean_gap = approx.mean(dim=0) - reference.mean(dim=0)
    var_gap = approx.var(dim=0) - reference.var(dim=0)

    return {
        'mean_dis': (mean_gap ** 2).mean(),
        'var_dis': (var_gap ** 2).mean(),
    }

In [None]:
# Build the target distribution, a reference sample from it, and the shared
# initial particles that every method cell starts from.
print(f"Device: {device}")
torch.manual_seed(seed)

## target density
# Zero mean vector of length `dim`, created directly on `device`.
means = torch.zeros(dim, device=device)

# NOTE(review): no random draw happens between `torch.manual_seed(seed)` above
# and this re-seed, so the `seed` call has no effect on the samples below.
# Confirm whether the target sample is meant to be independent of `--seed`.
torch.manual_seed(0)
# The bare string below is a disabled alternative (a random dense covariance
# hard-wired to 'cuda'). It is evaluated as an expression and discarded, so it
# has no effect; only the identity covariance after it is used.
'''A
A = torch.randn(dim,dim).to('cuda') * 0.9
A = torch.matmul(A, A.T)

m = torch.max(A) 
B = torch.eye(dim).to('cuda') * m + 0.1
diag = torch.diag(A)
cov = A + B'''

# Identity covariance of size dim x dim.
cov = torch.eye(dim, device=device)

# Zero-mean, identity-covariance Gaussian target.
distribution = D.MultivariateNormal(means.to(device), cov)

# sample from target (for computing metric)
x_target = distribution.sample((nparticles, ))
# sample from variational density
# NOTE(review): the seed is hard-coded to 235. That equals the default
# `--seed`, but a different `--seed` would not change these initial particles.
torch.manual_seed(235)
# Standard-normal draws scaled by sqrt(2) and shifted by 2. `np` is whatever
# the import cell bound last (autograd.numpy); only `sqrt` is used here.
x_init = 2 + np.sqrt(2) * torch.randn(nparticles, *distribution.event_shape, device=device)



    
    

In [None]:
## SVGD
# Baseline run: fit SVGD from a copy of the shared initial particles, time the
# fit, and save a plot of initial vs. final particles.
if args.method in ("SVGD", "all"):
    print("Running SVGD")

    # Particles are a private copy of the variational sample; the optimiser is
    # given this tensor as its only parameter.
    x = x_init.clone().to(device)
    kernel = Kernel(method="med_heuristic")
    svgd = SVGD(
        distribution,
        kernel,
        optim.Adam([x], lr=lr),
        device=device,
    )

    # Wall-clock time of the fit only.
    start = time.time()
    svgd.fit(x, epochs, verbose=True, save_every=save_every)
    elapsed_time_svgd = time.time() - start

    # plot particles
    fig_svgd = plot_particles(
        x_init.detach(), x.detach(), distribution,
        d=6.0, step=0.1, concat=means[2:],
        savedir=results_folder + "/svgd.png",
    )

In [None]:
# Diagnostics for the SVGD particles: per-particle norms, plus how far the
# empirical moments are from the target's.
theta = x

# Euclidean norm of every particle in one vectorised call. The previous
# version looped in Python and copied each row to the host separately.
samn_svgd = torch.linalg.norm(theta.detach(), dim=1).cpu().tolist()
index_svgd = list(range(theta.shape[0]))

# `plt` comes from the notebook's import cell; no need to re-import it here.
plt.scatter(index_svgd, samn_svgd, c='blue')

# Empirical covariance across particles (rows are particles, hence the
# transpose) and its distance from the target covariance.
cov_svgd = torch.cov(theta.T)
print(torch.linalg.norm(cov - cov_svgd))

print(comm_func_eval(theta, x_target))
print(torch.trace(cov - cov_svgd))
# NOTE(review): `eig` returns complex eigenvalues in no guaranteed order, so
# `evals[0]` is not necessarily the largest or smallest one. If an extreme
# eigenvalue of this symmetric difference is wanted, consider
# torch.linalg.eigh -- confirm intent before changing.
(evals, evecs) = torch.linalg.eig(cov - cov_svgd)
print(evals[0])

In [None]:
print("Running min_mySVGD")


def score(X):
    """Return the gradient of the target log-density with respect to X.

    The gradient is taken on a detached copy of X, so X itself is not modified
    and no graph is attached to it. `autograd` is torch.autograd here.

    Args:
        X: tensor of points accepted by `distribution.log_prob`.

    Returns:
        Tensor with the same shape as X holding d/dX of sum(log_prob(X)).
    """
    X_cp = X.clone().detach().requires_grad_()
    log_prob = distribution.log_prob(X_cp)
    score_func = autograd.grad(log_prob.sum(), X_cp)[0]
    return score_func


# sample from variational density
res = []
rres = []
steps = []
# NOTE(review): this rebinds the notebook-level `lr` (previously args.lr), so
# any later cell that reads `lr` sees 0.1. Kept as-is to preserve behaviour;
# consider a dedicated name.
lr = 0.1
# Work on a copy, like the other method cells do, so the shared `x_init`
# cannot be altered should the update run in place.
x0 = x_init.clone()
vector1 = torch.randn(nparticles, dim).to(device)
# The import cell brings in `etmySVGD` from src.Tmy_svgd; the name `tmySVGD`
# used here before is never imported and raised a NameError.
# NOTE: `kernel` is the instance created in the SVGD cell.
model = etmySVGD(kernel, device)

theta, vector = model.update(x0, score, k=2, n_iter=5000, debug=False, lr=lr, vector=vector1)
    

In [None]:

# Diagnostics for the particles currently bound to `theta`: per-particle
# norms, plus how far the empirical moments are from the target's.

# Euclidean norm of every particle in one vectorised call. The previous
# version looped in Python and copied each row to the host separately.
samn_svgd = torch.linalg.norm(theta.detach(), dim=1).cpu().tolist()
index_svgd = list(range(theta.shape[0]))

# `plt` comes from the notebook's import cell; no need to re-import it here.
plt.scatter(index_svgd, samn_svgd, c='blue')

# Empirical covariance across particles (rows are particles, hence the
# transpose) and its distance from the target covariance.
cov_mysvgd = torch.cov(theta.T)
print(torch.linalg.norm(cov - cov_mysvgd))
print(comm_func_eval(theta, x_target))
print(torch.trace(cov - cov_mysvgd))
# NOTE(review): `eig` returns complex eigenvalues in no guaranteed order, so
# `evals[0]` is not necessarily the largest or smallest one. If an extreme
# eigenvalue of this symmetric difference is wanted, consider
# torch.linalg.eigh -- confirm intent before changing.
(evals, evecs) = torch.linalg.eig(cov - cov_mysvgd)
print(evals[0])

In [None]:
# Display the particle tensor currently bound to `theta` as this cell's output.
theta

In [None]:

# Plot the shared initial particles against the particles currently bound to
# `theta`. No `savedir` is passed for this figure.
# NOTE(review): this reuses the name `fig_svgd`, replacing the figure handle
# created by the SVGD cell.
fig_svgd = plot_particles(
    x_init.detach(), theta.detach(), distribution,
    d=6.0, step=0.1, concat=means[2:],
)

In [None]:
## GSVGD
if args.method in ("GSVGD", "all"):
    # One result slot per effective dimension; filled in by run_gsvgd.
    res_gsvgd = [0 for _ in eff_dims]

    def run_gsvgd(eff_dims):
        """Fit GSVGD once per effective dimension and collect the results.

        Each run starts from a fresh copy of the shared initial particles,
        times the fit, saves a particle plot, and stores a result dict in the
        enclosing `res_gsvgd` list at the position of its effective dimension.

        Args:
            eff_dims: iterable of effective (projection) dimensions to run.

        Returns:
            (res_gsvgd, particles): the populated result list and the final
            particles of the LAST effective dimension only.
        """
        for slot, eff_dim in enumerate(eff_dims):
            print(f"Running GSVGD with eff dim = {eff_dim}")

            # Number of projections: unset -> capped at 20, -1 -> as many as
            # fit into `dim`, anything else -> taken literally.
            if args.m is None:
                m = min(20, dim // eff_dim)
            else:
                m = dim // eff_dim if args.m == -1 else args.m
            print("number of projections:", m)

            # sample from variational density
            start_particles = x_init.clone()
            particles = start_particles.clone()

            batch_kernel = BatchKernel(method="med_heuristic")
            adam = optim.Adam([particles], lr=lr_gsvgd)
            grassmann = Grassmann(dim, eff_dim)
            # Initial projections: the first m * eff_dim columns of the identity.
            proj = torch.eye(dim).requires_grad_().to(device)[:, :(m*eff_dim)]

            gsvgd = FullGSVGDBatch(
                target=distribution,
                kernel=batch_kernel,
                manifold=grassmann,
                optimizer=adam,
                delta=delta,
                T=T,
                device=device
            )

            # Wall-clock time of the fit only.
            tic = time.time()
            proj, fit_metric = gsvgd.fit(
                particles, proj, m, epochs,
                verbose=True, save_every=save_every, threshold=0.0001*m,
            )
            elapsed_time = time.time() - tic

            # plot particles
            figure = plot_particles(
                start_particles.detach(),
                particles.detach(),
                distribution,
                d=6.0,
                step=0.1,
                concat=means[2:],
                savedir=results_folder + f"/fullgsvgd_effdim{eff_dim}_lr{lr_gsvgd}_delta{delta}_m{m}_T{T}.png"
            )

            # store results
            res_gsvgd[slot] = {
                "init": start_particles,
                "final": particles,
                "metric": fit_metric,
                "fig": figure,
                "particles": gsvgd.particles,
                "pam": gsvgd.pam,
                "res": gsvgd,
                "elapsed_time": elapsed_time,
            }
        return res_gsvgd, particles

    res_gsvgd, x_gsvgd = run_gsvgd(eff_dims)

In [None]:
# Diagnostics for the GSVGD particles (those returned for the last effective
# dimension): per-particle norms, plus how far the empirical moments are from
# the target's.
theta = x_gsvgd

# Euclidean norm of every particle in one vectorised call. The previous
# version looped in Python and copied each row to the host separately.
samn_svgd = torch.linalg.norm(theta.detach(), dim=1).cpu().tolist()
index_svgd = list(range(theta.shape[0]))

# `plt` comes from the notebook's import cell; no need to re-import it here.
plt.scatter(index_svgd, samn_svgd, c='blue')

# Empirical covariance across particles (rows are particles, hence the
# transpose) and its distance from the target covariance.
cov_gsvgd = torch.cov(theta.T)
print(torch.linalg.norm(cov - cov_gsvgd))
print(comm_func_eval(theta, x_target))
print(torch.trace(cov - cov_gsvgd))
# NOTE(review): `eig` returns complex eigenvalues in no guaranteed order, so
# `evals[0]` is not necessarily the largest or smallest one. If an extreme
# eigenvalue of this symmetric difference is wanted, consider
# torch.linalg.eigh -- confirm intent before changing.
(evals, evecs) = torch.linalg.eig(cov - cov_gsvgd)
print(evals[0])

In [None]:

    
## S-SVGD
# Sliced SVGD run: fit from a copy of the shared initial particles, time the
# fit, and save a plot of initial vs. final particles.
if args.method in ("S-SVGD", "all"):
    print("Running S-SVGD")

    # sample from variational density
    x_init_s_svgd = x_init.clone()
    x_s_svgd = x_init_s_svgd.clone().requires_grad_()
    s_svgd = SlicedSVGD(distribution, device=device)

    # NOTE(review): `eps` reads the notebook-level `lr`, which the min_mySVGD
    # cell rebinds to 0.1 when it has been run before this one.
    start = time.time()
    x_s_svgd, metric_s_svgd = s_svgd.fit(
        samples=x_s_svgd, n_epoch=epochs,
        lr=args.lr_g, eps=lr,
        save_every=save_every,
    )
    elapsed_time_s_svgd = time.time() - start

    # plot particles
    fig_s_svgd = plot_particles(
        x_init_s_svgd.detach(), x_s_svgd.detach(), distribution,
        d=6.0, step=0.1, concat=means[2:],
        savedir=results_folder + f"/ssvgd_lr{lr}_lrg{args.lr_g}.png",
    )




In [None]:
# Diagnostics for the S-SVGD particles: per-particle norms, plus how far the
# empirical moments are from the target's.
theta = x_s_svgd

# Euclidean norm of every particle in one vectorised call. The previous
# version looped in Python and copied each row to the host separately.
# detach() keeps the norms out of any autograd graph attached to the particles.
samn_svgd = torch.linalg.norm(theta.detach(), dim=1).cpu().tolist()
index_svgd = list(range(theta.shape[0]))

# `plt` comes from the notebook's import cell; no need to re-import it here.
plt.scatter(index_svgd, samn_svgd, c='blue')

# Empirical covariance across particles (rows are particles, hence the
# transpose) and its distance from the target covariance.
cov_ssvgd = torch.cov(theta.T)
print(torch.linalg.norm(cov - cov_ssvgd))
print(comm_func_eval(theta, x_target))
print(torch.trace(cov - cov_ssvgd))
# NOTE(review): `eig` returns complex eigenvalues in no guaranteed order, so
# `evals[0]` is not necessarily the largest or smallest one. If an extreme
# eigenvalue of this symmetric difference is wanted, consider
# torch.linalg.eigh -- confirm intent before changing.
(evals, evecs) = torch.linalg.eig(cov - cov_ssvgd)
print(evals[0])