In [1]:
import os
import numpy as np
import torch
import torch.optim as optim
import torch.distributions as D
from tqdm import tqdm, trange
from src.svgd import SVGD
from src.gsvgd import FullGSVGDBatch
from src.kernel import RBF, BatchRBF
from src.utils import plot_particles
from src.Tmy_svgd import tmySVGD
from src.manifold import Grassmann
from src.s_svgd import SlicedSVGD
from src.mysvgd import etmySVGD
import matplotlib.pyplot as plt
from src.rand_mysvgd import min_mySVGD

import pickle
import argparse
import time

import torch.autograd as autograd
from scipy.stats import energy_distance


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Experiment configuration, expressed as an argparse parser so the same
# settings can be driven from a script. The notebook parses an empty argv, so
# the defaults below are the values actually used.
parser = argparse.ArgumentParser(description='Running xshaped experiment.')
parser.add_argument('--dim', type=int,default=5, help='dimension')
# Values <= 0 are a sentinel: the config cell then runs effective dims 1, 2 and 5.
parser.add_argument('--effdim', type=int, default=-1, help='effective dimension for GSVGD (<= 0 runs effective dims 1, 2 and 5)')
parser.add_argument('--lr', type=float,default=0.01, help='learning rate')
parser.add_argument('--lr_g', type=float, default=0.1, help='learning rate for g')
parser.add_argument('--delta', type=float,default=0.01, help='stepsize for projections')
parser.add_argument('--T', type=float, default=1e-4, help='noise multiplier for projections')
parser.add_argument('--nparticles', type=int,default=100, help='no. of particles')
parser.add_argument('--epochs', type=int, default=50000,help='no. of epochs')
parser.add_argument('--nmix', type=int, default=2, help='no. of modes')
parser.add_argument('--metric', type=str, default="energy", help='distance metric')
# Kept as a string for backward compatibility; only the exact text "True" enables noise.
parser.add_argument('--noise', type=str, default="True", help='whether to add noise')
parser.add_argument('--kernel', type=str, default="rbf", help='kernel')
parser.add_argument('--gpu', type=int, default=3, help='gpu index (-1 for cpu)')
parser.add_argument('--seed', type=int, default=0, help='random seed') 
parser.add_argument('--suffix', type=str, default="", help='suffix for res folder')
# No default: when omitted, the number of projections is derived from dim and effdim.
parser.add_argument('--m', type=int, help='no. of projections')
parser.add_argument('--save_every', type=int, default=200, help='step intervals to save particles')
parser.add_argument('--method', type=str, default="all", help='which method to use')



_StoreAction(option_strings=['--method'], dest='method', nargs=None, const=None, default='all', type=<class 'str'>, choices=None, required=False, help='which method to use', metavar=None)

In [3]:
# Parse an empty argv so the parser defaults are used inside the notebook.
args = parser.parse_args([])

# Only select the requested GPU when it really exists; otherwise fall back to
# the CPU instead of failing later with an invalid-device error.
use_cuda = torch.cuda.is_available() and 0 <= args.gpu < torch.cuda.device_count()
if args.gpu != -1 and not use_cuda:
    print(f"cuda:{args.gpu} is not available, falling back to cpu")
device = torch.device(f'cuda:{args.gpu}' if use_cuda else 'cpu')

# Unpack the settings into module-level names used by the following cells.
dim = args.dim
lr = args.lr
delta = args.delta
T = args.T
nparticles = args.nparticles
epochs = args.epochs
seed = args.seed
# A non-positive --effdim means "sweep the default effective dimensions".
eff_dims = [args.effdim] if args.effdim > 0 else [1, 2, 5]
nmix = args.nmix
# --noise is a string flag; only the exact text "True" switches noise on.
add_noise = args.noise == "True"
# Radius of the circle on which the mixture means are placed.
radius = 5
save_every = args.save_every
print(f"Running for dim: {dim}, lr: {lr}, nparticles: {nparticles}")

Running for dim: 5, lr: 0.01, nparticles: 100


In [4]:
def comm_func_eval(samples, ground_truth):
    """Compare the first two moments of two sample sets.

    Both inputs are reduced along axis 0 (mean, then variance). For each
    statistic the squared difference between the two sets is averaged over
    the remaining entries.

    Returns:
        dict with keys 'mean_dis' (mean-vector discrepancy) and 'var_dis'
        (variance-vector discrepancy), each a scalar tensor.
    """
    # Work on copies so the callers' tensors are never touched.
    samples = samples.clone()
    ground_truth = ground_truth.clone()

    def _mean_sq_gap(stat):
        # Mean squared gap between the axis-0 statistic of both sample sets.
        gap = stat(samples, axis=0) - stat(ground_truth, axis=0)
        return torch.mean(gap ** 2)

    return {
        'mean_dis': _mean_sq_gap(torch.mean),
        'var_dis': _mean_sq_gap(torch.var),
    }

In [5]:

# Only select the requested GPU when it really exists; otherwise fall back to
# the CPU instead of failing later with an invalid-device error.
use_cuda = torch.cuda.is_available() and 0 <= args.gpu < torch.cuda.device_count()
if args.gpu != -1 and not use_cuda:
    print(f"cuda:{args.gpu} is not available, falling back to cpu")
device = torch.device(f'cuda:{args.gpu}' if use_cuda else 'cpu')

metric = args.metric

# Output directory: one folder per hyper-parameter combination, one sub-folder per seed.
results_folder = f"./res/multimodal{args.suffix}/{args.kernel}_epoch{epochs}_lr{lr}_delta{delta}_n{nparticles}_dim{dim}"
results_folder = f"{results_folder}/seed{seed}"

In [6]:
def mix_gauss_experiment(mixture_dist, means):
    '''Build a mixture of multivariate Gaussians with identity covariance.

    Args:
        mixture_dist: Mixing distribution over the components; it is handed
            unchanged to MixtureSameFamily.
        means: Tensor of shape (nmix, d), where nmix is the number of
            components and d is the dimension of each component.

    Returns:
        The MixtureSameFamily distribution, with its components placed on the
        module-level `device`.
    '''
    # Unit standard deviation for every coordinate of every component.
    unit_scale = torch.ones((means.shape[0], means.shape[1]), device=device)
    # Independent(..., 1) turns the last axis into the event dimension.
    components = D.Independent(D.Normal(means.to(device), unit_scale), 1)
    return D.mixture_same_family.MixtureSameFamily(mixture_dist, components)


def points_on_circle(theta, rad):
    '''Return a single 2-D point on the circle of radius `rad`.

    The point sits at angle `theta + pi/4`, measured from the positive
    x-axis, and is returned as a tensor of shape (1, 2). Callers append any
    further coordinates themselves.
    '''
    # Every point is rotated by a fixed 45-degree offset.
    angle = theta + 0.25*np.pi
    x_coord = rad * np.cos(angle)
    y_coord = rad * np.sin(angle)
    return torch.Tensor([[x_coord, y_coord]])

In [7]:

# exist_ok avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(results_folder, exist_ok=True)

# Map the kernel name to the single-kernel and batched-kernel classes.
if args.kernel == "rbf":
    Kernel = RBF
    BatchKernel = BatchRBF
else:
    # Fail here with a clear message rather than with a NameError further down.
    raise ValueError(f"Unsupported kernel: {args.kernel!r} (only 'rbf' is available)")
kernel = Kernel(method="med_heuristic")

In [8]:
# This cell re-parses the configuration, repeating the earlier config cell.
# Parse an empty argv so the parser defaults are used inside the notebook.
args = parser.parse_args([])

# Only select the requested GPU when it really exists; otherwise fall back to
# the CPU instead of failing later with an invalid-device error.
use_cuda = torch.cuda.is_available() and 0 <= args.gpu < torch.cuda.device_count()
if args.gpu != -1 and not use_cuda:
    print(f"cuda:{args.gpu} is not available, falling back to cpu")
device = torch.device(f'cuda:{args.gpu}' if use_cuda else 'cpu')

# Unpack the settings into module-level names used by the following cells.
dim = args.dim
lr = args.lr
delta = args.delta
T = args.T
nparticles = args.nparticles
epochs = args.epochs
seed = args.seed
# A non-positive --effdim means "sweep the default effective dimensions".
eff_dims = [args.effdim] if args.effdim > 0 else [1, 2, 5]
nmix = args.nmix
# --noise is a string flag; only the exact text "True" switches noise on.
add_noise = args.noise == "True"
# Radius of the circle on which the mixture means are placed.
radius = 5
save_every = args.save_every
print(f"Running for dim: {dim}, lr: {lr}, nparticles: {nparticles}")

def comm_func_eval(samples, ground_truth):
    """Measure how far the moments of `samples` are from `ground_truth`.

    The mean and the variance are taken along axis 0 of each input. The
    squared gap between the two sets is then averaged into a single scalar
    per statistic.

    Returns:
        dict mapping 'mean_dis' and 'var_dis' to scalar tensors.
    """
    # Copies keep the computation independent of the callers' tensors.
    particles = samples.clone()
    reference = ground_truth.clone()

    mean_gap = torch.mean(particles, axis=0) - torch.mean(reference, axis=0)
    var_gap = torch.var(particles, axis=0) - torch.var(reference, axis=0)

    out = {}
    out['mean_dis'] = torch.mean(mean_gap ** 2)
    out['var_dis'] = torch.mean(var_gap ** 2)
    return out


def score(X):
    """Return the gradient of the target log-density with respect to X.

    The log-probability comes from the module-level `distribution`. X itself
    is left untouched: the gradient is taken on a detached copy.
    """
    points = X.detach().clone().requires_grad_(True)
    total_log_prob = distribution.log_prob(points).sum()
    # Summing first lets a single backward pass yield every per-point gradient.
    (grad,) = autograd.grad(total_log_prob, points)
    return grad

def energy_dis(x, x_target, dim):
    """Sum of per-coordinate energy distances between x and x_target, divided by dim.

    Each column of `x` is compared with the matching column of `x_target`
    using scipy's 1-D `energy_distance`. The distances are added up over the
    columns of `x` and the total is divided by `dim`.
    """
    x_np = x.cpu().detach().numpy()
    target_np = x_target.cpu().detach().numpy()

    total = sum(
        energy_distance(x_np[:, col], target_np[:, col])
        for col in range(x.shape[1])
    )
    return total / dim



# Map the kernel name to the single-kernel and batched-kernel classes.
if args.kernel == "rbf":
    Kernel = RBF
    BatchKernel = BatchRBF
else:
    # Without this branch an unknown name would silently reuse whatever
    # Kernel/BatchKernel an earlier cell left behind (or hit a NameError).
    raise ValueError(f"Unsupported kernel: {args.kernel!r} (only 'rbf' is available)")

Running for dim: 5, lr: 0.01, nparticles: 100


In [9]:
print(f"Device: {device}")
torch.manual_seed(seed)

# Result tables filled by the experiment loop: one row per dimension setting
# and one column per method (0 = SVGD, 1 = mysvgd, 2 = GSVGD, 3 = S-SVGD).
list_norm, list_gr_var, list_tr, list_eig, list_energy = (
    torch.zeros(20, 4) for _ in range(5)
)

# Row counter; the experiment loop increments it before each write.
s = -1

Device: cuda:3


In [10]:
# Hyper-parameters of the etmySVGD run. They are kept in their own names so
# the globally configured `lr` (used by SVGD, GSVGD and S-SVGD) is not overwritten.
MYSVGD_LR = 0.01
MYSVGD_N_ITER = 20000


def record_metrics(row, col, method_name, particles, x_target, cov_target, dim):
    """Print and store the evaluation metrics of one method for one dimension.

    Args:
        row: Row of the result tables (index of the dimension setting).
        col: Column of the result tables (0 = SVGD, 1 = mysvgd, 2 = GSVGD,
            3 = S-SVGD).
        method_name: Label used in the printed log lines.
        particles: Final particles of the method, one particle per row.
        x_target: Samples drawn from the target distribution.
        cov_target: Covariance of `x_target`.
        dim: Divisor passed on to `energy_dis`.
    """
    # Metrics are pure evaluation; detach so no autograd graph is stored in the tables.
    particles = particles.detach()
    cov_diff = cov_target - torch.cov(particles.T)

    norm = torch.linalg.norm(cov_diff)
    var_dis = comm_func_eval(particles, x_target)['var_dis']
    trace = torch.trace(cov_diff)
    # torch.linalg.eig returns complex eigenvalues in no guaranteed order; the
    # first one is reported, as before.
    evals, _ = torch.linalg.eig(cov_diff)
    energy = energy_dis(particles, x_target, dim)

    print(f"norm_mse of {method_name} : {norm}")
    # NOTE: the value logged under "mmd" is the 'var_dis' entry of comm_func_eval.
    print(f"mmd of {method_name} : {var_dis}")
    print(f"trace of {method_name}: {trace}")
    print(f"eig of {method_name} is : {evals[0]}")
    print(f"energy of {method_name} is : {energy}")

    list_norm[row, col] = norm
    list_gr_var[row, col] = var_dis
    list_tr[row, col] = trace
    # Store the real part explicitly: writing the complex value into the real
    # table discards the imaginary part anyway, but emits a warning.
    list_eig[row, col] = evals[0].real
    list_energy[row, col] = energy


def run_gsvgd(eff_dims):
    """Run GSVGD once per effective dimension, restarting from `x_init` each time.

    Reads `dim`, `x_init` and `distribution` from the current loop iteration.

    Returns:
        (res_gsvgd, x_gsvgd): a placeholder list with one entry per effective
        dimension, and the particles of the LAST effective dimension only.
    """
    res_gsvgd = [0] * len(eff_dims)
    for eff_dim in eff_dims:
        print(f"Running GSVGD with eff dim = {eff_dim}")

        m = min(20, dim // eff_dim) if args.m is None else args.m
        print("number of projections:", m)

        # sample from variational density
        x_gsvgd = x_init.clone()

        kernel_gsvgd = BatchKernel(method="med_heuristic")
        optimizer = optim.Adam([x_gsvgd], lr=lr)
        manifold = Grassmann(dim, eff_dim)
        # Initial projections: the first m*eff_dim columns of the identity.
        U = torch.eye(dim, device=device).requires_grad_(True)
        U = U[:, :(m*eff_dim)]

        gsvgd = FullGSVGDBatch(
            target=distribution,
            kernel=kernel_gsvgd,
            manifold=manifold,
            optimizer=optimizer,
            delta=delta,
            T=T,
            device=device,
            noise=add_noise
        )
        U, metric_gsvgd = gsvgd.fit(x_gsvgd, U, m, epochs,
            verbose=True, save_every=save_every, threshold=0.0001*m)

    return res_gsvgd, x_gsvgd


# enumerate supplies the result-table row, so re-running this cell always
# starts at row 0 instead of depending on a counter left over from another cell.
for s, dim in enumerate(range(5, 90, 5)):
    print(f"Running for dim: {dim}")
    print("#####################################################")

    ## target density
    mix_means = torch.cat(
        [points_on_circle(i * 2*np.pi / nmix, rad=radius) for i in range(nmix)]).to(device)
    # Pad the 2-D circle points with zeros up to the current dimension.
    mix_means = torch.cat((mix_means, torch.zeros((mix_means.shape[0], dim - 2), device=device)), dim=1)

    distribution = mix_gauss_experiment(
        mixture_dist=D.Categorical(torch.ones(mix_means.shape[0], device=device)),
        means=mix_means
    )

    # sample from target (for computing metric)
    x_target = distribution.sample((nparticles, )).to(device)

    # sample from variational density
    x_init =  torch.randn(nparticles, *distribution.event_shape).to(device)

    # Covariance over the coordinates (samples are rows, hence the transpose).
    cov = torch.cov(x_target.T)

    # Shared by the SVGD and mysvgd runs of this dimension.
    kernel = Kernel(method="med_heuristic")

    ## SVGD
    if args.method in ["SVGD", "all"]:
        print("Running SVGD >>>>>>>>>>>>>>>>>>")
        # sample from variational density
        x = x_init.clone().to(device)
        svgd = SVGD(distribution, kernel, optim.Adam([x], lr=lr), device=device)
        svgd.fit(x, epochs, verbose=True, save_every=save_every)

        theta = x
        record_metrics(s, 0, "svgd", theta, x_target, cov, dim)

    ## mysvgd
    if args.method in ["mysvgd", "all"]:
        print('Running mysvgd >>>>>>>>>>>>>>>>>>>>>>>')
        x0 = x_init.clone().to(device)
        vector1  = torch.randn(nparticles, dim).to(device)

        # NOTE(review): the recorded run died with a CUDA OOM inside this call
        # at dim 25 - check whether etmySVGD.update accumulates GPU memory.
        theta, vector = etmySVGD(kernel,device).update(
            x0, score, k = 2, n_iter = MYSVGD_N_ITER, lr = MYSVGD_LR, vector=vector1)

        record_metrics(s, 1, "mysvgd", theta, x_target, cov, dim)

    ## GSVGD
    if args.method in ["GSVGD", "all"]:
        res_gsvgd, x_gsvgd = run_gsvgd(eff_dims)
        theta = x_gsvgd
        record_metrics(s, 2, "gsvgd", theta, x_target, cov, dim)

    ## S-SVGD
    if args.method in ["S-SVGD", "all"]:
        # sample from variational density
        print("Running S-SVGD >>>>>>>>>>>>>>>>>>>>>>>")
        x_s_svgd = x_init.clone().requires_grad_()
        s_svgd = SlicedSVGD(distribution, device=device)

        x_s_svgd, metric_s_svgd = s_svgd.fit(
            samples=x_s_svgd,
            n_epoch=epochs,
            lr=args.lr_g,
            eps=lr,
            save_every=save_every
        )

        theta = x_s_svgd
        record_metrics(s, 3, "s-svgd", theta, x_target, cov, dim)

    # Checkpoint after every dimension so a crash late in the sweep does not
    # throw away the results that were already computed.
    torch.save(list_norm, "norm_xshape_mu_vs_55.pt")
    torch.save(list_gr_var, "gr_xshape_mu_vs_55.pt")
    torch.save(list_tr, "tr_xshape_mu_vs_55.pt")
    torch.save(list_eig, "eig_xshape_mu_vs_55.pt")
    torch.save(list_energy, "energy_xshape_mu_vs_55.pt")

Running for dim: 5
#####################################################
Running SVGD >>>>>>>>>>>>>>>>>>


100%|██████████| 50000/50000 [03:27<00:00, 240.62it/s]
  list_eig[s,0] = evals[0]


norm_mse of svgd : 1.870155692100525
mmd of svgd : 0.3287099003791809
trace of svgd: 0.2635122537612915
eig of svgd is : (-1.3219611644744873+0j)
energy of svgd is : 0.2934601376914839
Running mysvgd >>>>>>>>>>>>>>>>>>>>>>>


100%|██████████| 20000/20000 [06:14<00:00, 53.42it/s]


norm_mse of mysvgd : 2.1231632232666016
mmd of mysvgd : 0.4813431203365326
trace of mysvgd: -0.9696491360664368
eig of mysvgd is : (-1.8180630207061768+0j)
energy of mysvgd is : 0.1936267681655142
Running GSVGD with eff dim = 1
number of projections: 5


The boolean parameter 'some' has been replaced with a string parameter 'mode'.
Q, R = torch.qr(A, some)
should be replaced with
Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete') (Triggered internally at /opt/conda/conda-bld/pytorch_1666643016022/work/aten/src/ATen/native/BatchLinearAlgebra.cpp:2349.)
  A, _ = torch.qr(A)
100%|██████████| 50000/50000 [21:03<00:00, 39.57it/s]


Running GSVGD with eff dim = 2
number of projections: 2


100%|██████████| 50000/50000 [21:28<00:00, 38.80it/s]


Running GSVGD with eff dim = 5
number of projections: 1


100%|██████████| 50000/50000 [21:11<00:00, 39.33it/s]


norm_mse of mysvgd : 2.038560628890991
mmd of mysvgd : 0.3934960961341858
trace of mysvgd: 0.9066106081008911
eig of mysvgd is : (-1.420514464378357+0j)
energy of mysvgd is : 0.2582384562635255
Running S-SVGD >>>>>>>>>>>>>>>>>>>>>>>


100%|██████████| 50000/50000 [10:39<00:00, 78.14it/s] 


norm_mse of mysvgd : 2.1585428714752197
mmd of mysvgd : 0.4376320540904999
trace of mysvgd: -0.9753745794296265
eig of mysvgd is : (-1.9187865257263184+0j)
energy of mysvgd is : 0.2628116730788622
Running for dim: 10
#####################################################
Running SVGD >>>>>>>>>>>>>>>>>>


100%|██████████| 50000/50000 [03:04<00:00, 270.70it/s]


norm_mse of svgd : 2.1009714603424072
mmd of svgd : 0.0813242569565773
trace of svgd: 0.2847434878349304
eig of svgd is : (-1.6020543575286865+0j)
energy of svgd is : 0.1311446247427375
Running mysvgd >>>>>>>>>>>>>>>>>>>>>>>


100%|██████████| 20000/20000 [14:41<00:00, 22.70it/s]


norm_mse of mysvgd : 2.4550788402557373
mmd of mysvgd : 0.06861469894647598
trace of mysvgd: -0.3879725933074951
eig of mysvgd is : (-1.6887787580490112+0j)
energy of mysvgd is : 0.14263796350105093
Running GSVGD with eff dim = 1
number of projections: 10


100%|██████████| 50000/50000 [20:18<00:00, 41.03it/s]


Running GSVGD with eff dim = 2
number of projections: 5


100%|██████████| 50000/50000 [21:03<00:00, 39.57it/s]


Running GSVGD with eff dim = 5
number of projections: 2


100%|██████████| 50000/50000 [20:04<00:00, 41.51it/s]


norm_mse of mysvgd : 2.3509292602539062
mmd of mysvgd : 0.13764429092407227
trace of mysvgd: 1.9698641300201416
eig of mysvgd is : (-1.5949381589889526+0j)
energy of mysvgd is : 0.15574531857651536
Running S-SVGD >>>>>>>>>>>>>>>>>>>>>>>


100%|██████████| 50000/50000 [09:57<00:00, 83.74it/s] 


norm_mse of mysvgd : 2.495126485824585
mmd of mysvgd : 0.21320366859436035
trace of mysvgd: 1.9009621143341064
eig of mysvgd is : (1.903307557106018+0j)
energy of mysvgd is : 0.13921389425911757
Running for dim: 15
#####################################################
Running SVGD >>>>>>>>>>>>>>>>>>


100%|██████████| 50000/50000 [02:43<00:00, 305.34it/s]


norm_mse of svgd : 3.2651901245117188
mmd of svgd : 0.26048585772514343
trace of svgd: 7.332609176635742
eig of svgd is : (-1.0059247016906738+0j)
energy of svgd is : 0.2336186809447471
Running mysvgd >>>>>>>>>>>>>>>>>>>>>>>


100%|██████████| 20000/20000 [17:55<00:00, 18.60it/s]


norm_mse of mysvgd : 3.5922436714172363
mmd of mysvgd : 0.01967620477080345
trace of mysvgd: 0.12354147434234619
eig of mysvgd is : (2.35280179977417+0j)
energy of mysvgd is : 0.136101893393052
Running GSVGD with eff dim = 1
number of projections: 15


100%|██████████| 50000/50000 [19:40<00:00, 42.35it/s]


Running GSVGD with eff dim = 2
number of projections: 7


100%|██████████| 50000/50000 [20:51<00:00, 39.97it/s]


Running GSVGD with eff dim = 5
number of projections: 3


100%|██████████| 50000/50000 [20:36<00:00, 40.43it/s]


norm_mse of mysvgd : 3.1827056407928467
mmd of mysvgd : 0.21694344282150269
trace of mysvgd: 6.611688613891602
eig of mysvgd is : (2.123818874359131+0j)
energy of mysvgd is : 0.21184762329377574
Running S-SVGD >>>>>>>>>>>>>>>>>>>>>>>


100%|██████████| 50000/50000 [10:08<00:00, 82.21it/s] 


norm_mse of mysvgd : 4.261358261108398
mmd of mysvgd : 0.11094582825899124
trace of mysvgd: 4.413021087646484
eig of mysvgd is : (-2.7631278038024902+0j)
energy of mysvgd is : 0.1819295522027851
Running for dim: 20
#####################################################
Running SVGD >>>>>>>>>>>>>>>>>>


100%|██████████| 50000/50000 [02:59<00:00, 278.34it/s]


norm_mse of svgd : 4.492527484893799
mmd of svgd : 0.21782703697681427
trace of svgd: 5.992672920227051
eig of svgd is : (3.2775585651397705+0j)
energy of svgd is : 0.13362173408843583
Running mysvgd >>>>>>>>>>>>>>>>>>>>>>>


100%|██████████| 20000/20000 [25:19<00:00, 13.16it/s]


norm_mse of mysvgd : 4.689162254333496
mmd of mysvgd : 0.1872454285621643
trace of mysvgd: 1.2345104217529297
eig of mysvgd is : (3.3768811225891113+0j)
energy of mysvgd is : 0.09302457372653616
Running GSVGD with eff dim = 1
number of projections: 20


100%|██████████| 50000/50000 [20:25<00:00, 40.79it/s]


Running GSVGD with eff dim = 2
number of projections: 10


100%|██████████| 50000/50000 [22:28<00:00, 37.09it/s]


Running GSVGD with eff dim = 5
number of projections: 4


100%|██████████| 50000/50000 [22:16<00:00, 37.40it/s]


norm_mse of mysvgd : 4.708903789520264
mmd of mysvgd : 0.3124522864818573
trace of mysvgd: 8.584290504455566
eig of mysvgd is : (3.3106255531311035+0j)
energy of mysvgd is : 0.164042540213243
Running S-SVGD >>>>>>>>>>>>>>>>>>>>>>>


100%|██████████| 50000/50000 [10:26<00:00, 79.84it/s] 


norm_mse of mysvgd : 5.3883514404296875
mmd of mysvgd : 0.23265047371387482
trace of mysvgd: 5.934279441833496
eig of mysvgd is : (3.611668586730957+0j)
energy of mysvgd is : 0.1382329593718162
Running for dim: 25
#####################################################
Running SVGD >>>>>>>>>>>>>>>>>>


100%|██████████| 50000/50000 [02:53<00:00, 288.26it/s]


norm_mse of svgd : 4.1528120040893555
mmd of svgd : 0.18036270141601562
trace of svgd: 9.750812530517578
eig of svgd is : (2.2867202758789062+0j)
energy of svgd is : 0.17634321857664997
Running mysvgd >>>>>>>>>>>>>>>>>>>>>>>


 84%|████████▍ | 16751/20000 [30:39<05:56,  9.11it/s] 


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 3; 23.69 GiB total capacity; 7.65 GiB already allocated; 2.94 MiB free; 7.65 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

: 