In [None]:
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
import sys

from itertools import cycle
from tqdm.auto import trange
from tqdm.notebook import tqdm

sys.path.append('../')
from matplotlib import cm
import matplotlib.pyplot as plt

from utils import plot as plot_utils, s3w as s3w_utils, vmf as vmf_utils
from utils.nf import normalizing_flows
from methods import s3wd as s3w, sswd as ssw, swd as sw

from scipy.stats import gaussian_kde

# Prefer the second GPU (cuda:1) as originally configured, but fall back to
# cuda:0 on single-GPU machines (a non-existent ordinal only fails at first
# use, far from here), and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

In [None]:
def _iterate_forever(loader):
    """Yield batches from ``loader`` indefinitely, starting a fresh pass each time.

    ``itertools.cycle`` is deliberately not used here: it stores the items of
    the first pass and replays them, so a ``shuffle=True`` loader would only be
    shuffled once, and every batch would be kept in memory.
    """
    while True:
        yield from loader


def run_exp(X_target, X0, d_func, d_args, device, n_steps=2000, lr=0.1, batch_size=500):
    """
    Performs a projected gradient flow of particles on the unit sphere.

    Each step draws a batch of target samples, takes one Adam step on the
    particle positions to decrease ``d_func(X_batch, X0, **d_args)``, and then
    re-normalises every particle (row) to unit L2 norm.

    Args:
    - X_target (torch.Tensor): Target samples, one per row.
    - X0 (torch.Tensor): Initial particle positions, one per row. The caller's
      tensor is left untouched; a detached copy on ``device`` is optimised.
    - d_func (callable): Differentiable distance, called as
      ``d_func(X_batch, X0, **d_args)``; must return a scalar tensor.
    - d_args (dict): Extra keyword arguments forwarded to ``d_func``.
    - device (torch.device): Device to perform computations on.
    - n_steps (int): Number of gradient steps.
    - lr (float): Learning rate of the Adam optimiser.
    - batch_size (int): Number of target samples drawn per step.

    Returns:
    - List of ``n_steps + 1`` detached tensors: the initial particles followed
      by the particles after each step.
    - List of ``n_steps`` loss values (Python floats).

    Raises:
    - ValueError: if ``X_target`` is empty (no batch could ever be drawn).
    """
    if len(X_target) == 0:
        raise ValueError("X_target must contain at least one sample")

    loader = DataLoader(X_target, batch_size=batch_size, shuffle=True)
    batches = _iterate_forever(loader)

    # Private leaf copy: ``Tensor.to`` returns the very same object when no
    # conversion is needed, which would let the in-place updates below leak
    # into the caller's tensor.
    X0 = X0.detach().to(device, copy=True).requires_grad_(True)

    optimizer = torch.optim.Adam([X0], lr=lr)

    L = [X0.detach().clone()]
    L_loss = []

    pbar = trange(n_steps)

    for _ in pbar:
        optimizer.zero_grad()
        X_batch = next(batches).float().to(device)

        distance = d_func(X_batch, X0, **d_args)
        distance.backward()
        optimizer.step()

        # Project back onto the unit sphere outside of autograd; the clamp
        # avoids a division by zero for a degenerate all-zero row.
        with torch.no_grad():
            X0.div_(X0.norm(dim=1, keepdim=True).clamp_min(1e-12))

        loss = distance.item()
        L_loss.append(loss)
        L.append(X0.detach().clone())
        pbar.set_postfix_str(f"Loss = {loss:.3f}")

    return L, L_loss

## Run the Gradient Flows and Visualize the Particle Trajectories

In [None]:
# Target: N_CENTERS groups of N_PER_CENTER vMF samples each, with mean
# directions taken from vmf_utils.fibonacci_sphere. Seeds are fixed so that
# Restart & Run All reproduces the same target and initial particles.
SEED = 0
N_CENTERS = 12        # number of mean directions
KAPPA_TARGET = 50     # `kappa` passed to rand_vmf for every group
N_PER_CENTER = 500    # samples drawn around each mean direction

# NOTE(review): assumes rand_vmf draws from NumPy's global RNG — confirm.
np.random.seed(SEED)
torch.manual_seed(SEED)

X_target = torch.cat(
    [
        torch.tensor(
            vmf_utils.rand_vmf(mu / np.linalg.norm(mu), kappa=KAPPA_TARGET, N=N_PER_CENTER),
            dtype=torch.float,
        )
        for mu in vmf_utils.fibonacci_sphere(N_CENTERS)
    ],
    dim=0,
)

# Initial particles: normalised Gaussian draws, i.e. uniform on the sphere,
# one particle per target sample (same shape as X_target).
X0_base = F.normalize(torch.randn(X_target.shape, device=device), p=2, dim=-1)

In [None]:
# Flow with s3w.ri_s3wd (RI-S3W). Full-batch: batch_size follows the size of
# the target set instead of a hard-coded copy of it.
d_func = s3w.ri_s3wd
d_args = {'p': 2, 'n_projs': 100, 'device': device, 'n_rotations': 50}
L_ri, _ = run_exp(X_target, X0_base.clone(), d_func, d_args, device,
                  n_steps=101, lr=.05, batch_size=X_target.shape[0])

In [None]:
# Flow with s3w.ari_s3wd (ARI-S3W). Full-batch: batch_size follows the size of
# the target set instead of a hard-coded copy of it.
d_func = s3w.ari_s3wd
d_args = {'p': 2, 'n_projs': 100, 'device': device, 'n_rotations': 50}
L_ari, _ = run_exp(X_target, X0_base.clone(), d_func, d_args, device,
                   n_steps=101, lr=.05, batch_size=X_target.shape[0])

In [None]:
# Flow with plain s3w.s3wd (S3W). Full-batch: batch_size follows the size of
# the target set instead of a hard-coded copy of it.
d_func = s3w.s3wd
d_args = {'p': 2, 'n_projs': 100, 'device': device}
L_s3w, _ = run_exp(X_target, X0_base.clone(), d_func, d_args, device,
                   n_steps=101, lr=.05, batch_size=X_target.shape[0])

In [None]:
from utils.plot import spherical_to_cartesian


def make_plot(pts, ax, title, n_meridians=50, n_parallels=100):
    """Scatter ``pts`` over a translucent unit-sphere mesh on a 3D axes.

    Args:
    - pts (np.ndarray): Points to draw; columns 0, 1 and 2 are used as x, y, z.
    - ax: Matplotlib axes created with ``projection='3d'``.
    - title (str): Title, placed slightly inside the top of the axes.
    - n_meridians (int): Number of samples of the angle spanning [0, 2*pi].
    - n_parallels (int): Number of samples of the angle spanning [0, pi].
    """
    ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], s=3)

    # Regular angular grid; NOTE(review): presumably (azimuth, polar) in the
    # column order spherical_to_cartesian expects — confirm in utils.plot.
    u, v = np.mgrid[0:2 * np.pi:n_meridians * 1j, 0:np.pi:n_parallels * 1j]
    x, y, z = spherical_to_cartesian(np.column_stack((u.ravel(), v.ravel()))).T
    x, y, z = (coord.reshape(u.shape) for coord in (x, y, z))
    ax.plot_surface(x, y, z, color='gray', alpha=0.05)
    ax.plot_wireframe(x, y, z, color="black", alpha=0.05, lw=1)

    # Operate on `ax` explicitly rather than on pyplot's "current axes".
    ax.set_axis_off()
    ax.view_init(25, 25)
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])

    ax.set_title(title, y=.85)

In [None]:
import os

# Render one PNG per gradient-flow step, the three methods side by side.
FRAME_DIR = 'gif'
os.makedirs(FRAME_DIR, exist_ok=True)  # savefig does not create directories

panels = [('$S3W$', L_s3w), ('$RI$-$S3W$', L_ri), ('$ARI$-$S3W$', L_ari)]
# Stop at the shortest trajectory so no panel is ever indexed out of range.
n_frames = min(len(traj) for _, traj in panels)

for i in tqdm(range(n_frames)):
    fig = plt.figure(figsize=(12, 6))

    for col, (title, traj) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, len(panels), col, projection='3d')
        make_plot(traj[i].cpu().detach().numpy(), ax, title)

    fig.subplots_adjust(wspace=-0.2)
    fig.savefig(os.path.join(FRAME_DIR, f'gif_{i}.png'), bbox_inches='tight', pad_inches=0)

    # Close each figure so memory does not grow across frames.
    plt.close(fig)

## Create GIF

In [None]:
from glob import glob
import imageio
from tqdm import tqdm
from PIL import Image
import numpy as np
import os
file_list = sorted(
    glob('./gif/*.png'),
    # Numeric sort on the index after the last "_" so gif_2 precedes gif_10.
    key=lambda path: int(path.rsplit('_', 1)[-1].split('.', 1)[0]),
)

In [None]:
# Run this cell to crop the margins of the first 66 frames (the same frames
# the GIF cell uses) and save them under ./gif_cropped.
CROP_LEFT, CROP_TOP, CROP_RIGHT, CROP_BOTTOM = 25, 50, 25, 85  # pixels trimmed per side
CROPPED_DIR = './gif_cropped'
os.makedirs(CROPPED_DIR, exist_ok=True)  # Image.save does not create directories

for fname in tqdm(file_list[:66]):
    # The context manager closes the source file handle after the crop is saved.
    with Image.open(fname) as img:
        w, h = img.size
        cropped = img.crop((CROP_LEFT, CROP_TOP, w - CROP_RIGHT, h - CROP_BOTTOM))
        cropped.save(os.path.join(CROPPED_DIR, os.path.basename(fname)))

In [None]:
# Set to True to build the GIF from the cropped frames in ./gif_cropped
# (run the crop cell first).
USE_CROPPED = False

gif_frames = (
    sorted(glob('./gif_cropped/*.png'), key=lambda x: int(x.split('_')[-1].split('.')[0]))
    if USE_CROPPED
    else file_list
)

# NOTE(review): the unit of `duration` depends on the installed imageio
# version/plugin (seconds in the legacy GIF writer, milliseconds in newer
# pillow-based writers) — confirm the intended frame time.
with imageio.get_writer("./gif/gf.gif", mode="I", duration=0.5, loop=0) as writer:
    for fname in tqdm(gif_frames[:66]):  # only the first 66 frames
        writer.append_data(imageio.imread(fname))