In [2]:
# Automatically reload edited Python modules before each cell executes.
%load_ext autoreload
%autoreload 2

In [3]:
import matplotlib as mpl

# Non-interactive pgf backend: figures are rendered through LaTeX (pdflatex below).
mpl.use("pgf")

import seaborn as sns

# Apply the seaborn style BEFORE the rc overrides: seaborn styles also set
# ``font.family`` (to sans-serif), which would otherwise silently replace the
# serif family requested below.
sns.set_style("darkgrid")

params = {
    'axes.labelsize': 8,
    'axes.unicode_minus': False,
    'font.size': 8,
    'legend.fontsize': 10,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'figure.figsize': [4.5, 4.5],
    'font.family': 'serif',
    # Listed once: the key previously appeared twice (False, then True) and only
    # the last value ever took effect.
    'text.usetex': True,
    "pgf.texsystem": "pdflatex",
    'pgf.rcfonts': True,
}
mpl.rcParams.update(params)

In [4]:
import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import torch

cmap = 'YlGnBu'


def plot_density(d, ax=None, norm=None, bounds=(-2, 5, -2, 2), exp=True, resolution=100):
    """Draw the (log-)density of a 2-D torch distribution as an image.

    Parameters
    ----------
    d : torch.distributions.Distribution
        Its ``log_prob`` is called on an ``(resolution**2, 2)`` tensor of grid
        points and must return one log-density per point.
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure/axes pair is created when omitted.
    norm : matplotlib colour normalisation, optional
        Passed straight through to ``imshow``.
    bounds : sequence of 4 numbers
        ``(xmin, xmax, ymin, ymax)``: the evaluation grid and the axes limits.
    exp : bool
        If True plot ``exp(log_prob)``; otherwise plot the log-density itself.
    resolution : int
        Number of grid points along each axis (previously fixed at 100).

    Returns
    -------
    The Axes that was drawn on.
    """
    xmin, xmax, ymin, ymax = bounds
    # 'ij' indexing (the historical default, made explicit to avoid torch's
    # warning): xx varies along dim 0 (x), yy along dim 1 (y).
    xx, yy = torch.meshgrid(
        torch.linspace(xmin, xmax, resolution),
        torch.linspace(ymin, ymax, resolution),
        indexing='ij',
    )
    points = torch.stack((xx, yy), dim=-1).reshape(-1, 2)
    f = d.log_prob(points).reshape(resolution, resolution).detach().numpy()
    if exp:
        f = np.exp(f)
    if ax is None:
        _, ax = plt.subplots()
    ax.grid(False)
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    # f[i, j] is the value at (x_i, y_j); rot90 makes x the column index and puts
    # y_max in the top row, matching imshow's default upper origin with `extent`.
    ax.imshow(np.rot90(f), cmap=cmap, norm=norm,
              extent=[xmin, xmax, ymin, ymax], aspect='auto')
    return ax

In [5]:
import torch
import torch.distributions as dist

# Product of two independent Student-t marginals with df = 1 and df = 2;
# Independent(..., 1) reinterprets the length-2 batch as one 2-D event.
# df must be floating point: an integer tensor makes loc/scale int64 as well,
# and sampling then fails because StudentT draws normals of df's dtype.
d = dist.Independent(
    dist.StudentT(df=torch.tensor([1.0, 2.0]), loc=0.0, scale=1.0),
    1,
)

In [6]:
import torch.distributions.transforms


def rect_to_polar(x, y):
    """Convert Cartesian coordinates to polar coordinates.

    Args:
        x, y: floating-point tensors of Cartesian coordinates (broadcastable).

    Returns:
        Tuple ``(r, theta)``: the radius and the angle as given by ``torch.atan2``.
    """
    # hypot avoids the overflow/underflow of sqrt(x**2 + y**2) for very large or
    # very small coordinates.
    r = torch.hypot(x, y)
    theta = torch.atan2(y, x)
    return r, theta

class SpinTransform(torch.distributions.transforms.Transform):
    """Bijective "spin" of the plane: rotate each point by an angle equal to its radius.

    In polar coordinates the forward map is ``(r, theta) -> (r, theta + r)`` and
    the inverse is ``(r, theta) -> (r, theta - r)``. Inputs are tensors whose last
    dimension holds ``(x, y)``, e.g. one point of shape ``(2,)`` or a batch of
    shape ``(N, 2)``.
    """

    # The map acts on 2-D vectors as a whole (event_dim 1), and its image is the
    # whole plane, not just positive numbers.
    domain = torch.distributions.constraints.real_vector
    codomain = torch.distributions.constraints.real_vector
    bijective = True
    # NOTE(review): ``sign`` is only meaningful for univariate monotone transforms;
    # kept for backward compatibility with code that may read it.
    sign = +1

    def __eq__(self, other):
        return isinstance(other, SpinTransform)

    def __hash__(self):
        # Defining __eq__ alone makes a class unhashable; all instances compare
        # equal, so they share one hash.
        return hash(SpinTransform)

    @staticmethod
    def _spin(v, direction):
        """Rotate each point of ``v`` by ``direction * radius`` radians."""
        x, y = v[..., 0], v[..., 1]
        r = torch.hypot(x, y)
        z = torch.polar(r, torch.atan2(y, x) + direction * r)
        return torch.stack([z.real, z.imag], dim=-1)

    def _call(self, x):
        return self._spin(x, +1)

    def _inverse(self, y):
        return self._spin(y, -1)

    def log_abs_det_jacobian(self, x, y):
        """Return zeros with the batch shape of ``x``.

        The radius is unchanged and the angle is shifted by a function of the
        radius only, so in polar coordinates the Jacobian is ``[[1, 0], [1, 1]]``
        (determinant 1) and the area element ``r dr dtheta`` is preserved:
        the transform is volume-preserving, hence log|det J| = 0.
        """
        return torch.zeros_like(x[..., 0])

In [11]:
# Three panels: base log-density, log-density after the spin transform, and a
# 1-D slice of the latter along the positive x-axis.
plot_bounds = (-10, 10, -10, 10)
d_spin = dist.TransformedDistribution(d, [SpinTransform()])

fig, ax = plt.subplots(1, 3, figsize=(5.5, 2))

plot_density(d, ax=ax[0], bounds=plot_bounds, exp=False)
# Raw strings: '\m', '\o', '\l' are invalid Python escape sequences.
ax[0].set_title(r'$\mathrm{StudentT}(1) \otimes \mathrm{StudentT}(2)$')

plot_density(d_spin, bounds=plot_bounds, ax=ax[1], exp=False)
ax[1].set_title('After spin transform')

# Plot against x itself rather than the sample index 0..99.
xs = torch.linspace(0, 10, steps=100)
slice_points = torch.stack([xs, torch.zeros_like(xs)], dim=1)
ax[2].plot(xs.numpy(), d_spin.log_prob(slice_points).detach().numpy())
ax[2].set_title(r'$\log p(x,y=0)$')
ax[2].set_xlabel('$x$')

fig.tight_layout()
# fig.show() only emits a "non-interactive" warning under the pgf backend;
# end the cell with the figure so Jupyter's display machinery handles it.
fig

In [12]:
# Write the spiral figure to disk; the output format follows the file extension.
fig.savefig(fname='spiral.pdf')