torchlambertw: Lambert W function and Lambert W x F distributions in pytorch

IMPORTANT: This is an early prototype implementation of the Lambert W function and Lambert W x F distributions in torch. For now it serves as a reference for the discussion in pytorch/pytorch#108948. Use it only for prototyping/R&D (see also LICENSE).

See https://github.com/gmgeorg/torchlambertw/issues for remaining issues/TODOs.


Overview

This library is a native implementation in pytorch of

  • the Lambert W function (special.lambertw)

  • Lambert W x F distributions (torch.distributions)

While this library is standalone for now, the goal is to get both the mathematical function and the distributions into the torch core package.

See also pytorch/pytorch#108948.

IMPORTANT: See also the accompanying pylambertw module which uses torchlambertw under the hood to train distribution parameters and can be used to Gaussianize skewed, heavy-tailed data.

The torchlambertw module here is solely focused on providing the building blocks for Lambert W functions and Lambert W x F distributions. If you are interested in applying the transformations and estimating the parameters of these distributions, take a look at pylambertw instead.

Installation

It can be installed directly from GitHub using:

pip install git+https://github.com/gmgeorg/torchlambertw.git

Lambert W function (math)

Implementation of the Lambert W function (special function) in torch:

import numpy as np
import torch
import torchlambertw as tw

tw.special.lambertw(torch.tensor([-1., 0., 1., -np.exp(-1)]))

output:

tensor([nan,  0.0000,  0.5671, -1.0000], dtype=torch.float64)
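A quick way to sanity-check values like these is the defining identity W(z) * exp(W(z)) = z. The following is a minimal, stdlib-only sketch of the principal branch using Newton's method. The helper name lambertw_principal is hypothetical and this is not how torchlambertw computes the function; it is only meant to illustrate what the function returns.

```python
import math

def lambertw_principal(z, tol=1e-12, max_iter=200):
    """Illustrative principal branch W_0(z) for a real scalar z (hypothetical helper)."""
    if z < -math.exp(-1.0):
        return math.nan  # no real solution below -1/e
    # log1p(z) is an upper bound on W_0(z), so Newton's method on
    # f(w) = w * exp(w) - z approaches the root from the right.
    w = math.log1p(z)
    for _ in range(max_iter):
        ew = math.exp(w)
        slope = ew * (w + 1.0)
        if slope == 0.0:  # exactly at the branch point w = -1
            break
        step = (w * ew - z) / slope
        w -= step
        if abs(step) <= tol * (1.0 + abs(w)):
            break
    return w

for z in (0.0, 1.0, 5.0):
    w = lambertw_principal(z)
    print(f"W({z}) = {w:.4f}, W * exp(W) = {w * math.exp(w):.4f}")
```

Note that this simple iteration loses accuracy very close to the branch point z = -1/e, where the derivative vanishes; a production implementation needs more care there.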

As a more interesting example you can use this implementation to replicate the figure on the Lambert W Function Wikipedia page:

import numpy as np
import matplotlib.pyplot as plt
import torch
from torchlambertw import special

def plot_lambertW(range_start, range_end, num_points=2000):
    x_values = np.linspace(range_start, range_end, num_points)
    x_values_torch = torch.tensor(x_values)
    principal_branch_values = special.lambertw(x_values_torch, k=0).numpy()
    non_principal_branch_values = special.lambertw(x_values_torch, k=-1).numpy()

    plt.figure(figsize=(8, 5))
    plt.plot(x_values, principal_branch_values, label="Principal Branch", color='blue')
    plt.plot(x_values, non_principal_branch_values, label="Non-Principal Branch", color='red')

    plt.title("Lambert W Function")
    plt.xlabel("x")
    plt.ylabel("W(x)")
    plt.xlim(range_start, range_end)
    plt.ylim(-4, 2)  # same range as wiki figure
    plt.axhline(0, color='black', linestyle='--', linewidth=0.5)
    plt.axvline(0, color='black', linestyle='--', linewidth=0.5)
    plt.legend()

    plt.grid(True)
    plt.show()

# Example usage:
plot_lambertW(-1, 6)

Figure: Lambert W function, principal and non-principal branches.

Lambert W x F distributions

For the original papers see Goerg 2011 & 2015. If you want to jump into applications and examples, I suggest looking at the LambertW R package for detailed references and links to many external examples on Stack Overflow / Cross Validated and other blogs.

Important: The torch.distributions framework lets you build any Lambert W x F distribution by combining the skewed or heavy-tail Lambert W transform implemented here with whatever base_distribution (that's F) makes sense to you. Voila! You have just built a Lambert W x F distribution.

See demo notebook for details.

In a nutshell

Lambert W x F distributions are a generalized family of distributions that take an "input" X ~ F and transform it into a skewed and/or heavy-tailed output, Y ~ Lambert W x F, via a parameterized transformation. See Goerg (2011, 2015) for details.

The random variable obtained this way is said to have a Lambert W x F distribution. For parameter values of 0, the new variable collapses to X, which means that Lambert W x F distributions always contain the original base distribution F as a special case. I.e., it does not hurt to impose a Lambert W x F distribution on your data: in the worst case, the parameter estimates are 0 and you get F back; in the best case, you properly account for skewness and heavy tails in your data and can even remove them (by transforming the data back to having X ~ F).

The convenient part about this is that when working with data y1, ..., yn, you can estimate the transformation from the data and transform it back into the (unobserved) x1, ..., xn. This is particularly useful when X ~ Normal(loc, scale), as then you can "Gaussianize" your data.
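To make the transformation concrete: in Goerg (2011, 2015) the standardized input U = (X - loc) / scale is mapped to U * exp(gamma * U) for the skewed case or to U * exp(delta / 2 * U^2) for the heavy-tailed case, and then rescaled. Below is a small stdlib-only sketch of these two forward maps. The function names are hypothetical, and it is an assumption here that the tailweight parameter used in this README plays the role of delta.

```python
import math

def skew_transform(x, loc, scale, gamma):
    """Skewed Lambert W transform of a scalar x (illustrative, not the library API)."""
    u = (x - loc) / scale
    return loc + scale * u * math.exp(gamma * u)

def tail_transform(x, loc, scale, delta):
    """Heavy-tail Lambert W transform of a scalar x (illustrative, not the library API)."""
    u = (x - loc) / scale
    return loc + scale * u * math.exp(0.5 * delta * u * u)

x = 7.0  # two scale units above loc for loc=1, scale=3
print(tail_transform(x, 1.0, 3.0, 0.0))   # delta = 0: x is returned unchanged
print(tail_transform(x, 1.0, 3.0, 0.25))  # delta > 0: pushed further into the tail
print(skew_transform(x, 1.0, 3.0, 0.1))   # gamma > 0: right side is stretched
```

With a parameter of 0 the exponential factor is exactly 1, which is why the base distribution F is always contained as a special case.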

Heavy-tail Lambert W x F distributions

Here is an illustration of a heavy-tail Lambert W x Gaussian distribution, which takes a Gaussian input and turns it into something heavy-tailed. If tailweight = 0, then it is just a Gaussian again.

from torchlambertw import distributions as tlwd

# Implements a Lambert W x Normal distribution with (loc=1, scale=3, tailweight=0.75)
m = tlwd.TailLambertWNormal(loc=1.0, scale=3.0, tailweight=0.75)
m.sample((2,))
tensor([[ 0.0159], [-0.9322]])

This distribution is quite heavy-tailed: moments exist only up to order 1 / tailweight = 1.33, i.e., this random variable / distribution has infinite (population) variance.

m.tailweight, m.support, m.mean, m.variance
(tensor([0.7500]), Real(), tensor([1.]), tensor([inf]))
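The infinite variance shown above is consistent with a closed form for the Gaussian case: for U ~ N(0, 1), E[U^2 * exp(delta * U^2)] equals (1 - 2 * delta)^(-3/2) when delta < 1/2 and diverges otherwise. The sketch below assumes that tailweight is the delta in Y = loc + scale * U * exp(delta / 2 * U^2); the helper name is hypothetical and not part of torchlambertw.

```python
import math

def tail_lambertw_normal_variance(scale, tailweight):
    """Population variance of a heavy-tail Lambert W x Normal variable (illustrative).

    Assumes tailweight >= 0 and Y = loc + scale * U * exp(tailweight / 2 * U**2).
    """
    if tailweight >= 0.5:
        return math.inf  # second moment does not exist
    return scale ** 2 * (1.0 - 2.0 * tailweight) ** (-1.5)

for tw_value in (0.0, 0.25, 0.75):
    print(tw_value, tail_lambertw_normal_variance(3.0, tw_value))
```

For tailweight = 0 this reduces to the Gaussian variance scale^2, and it grows without bound as tailweight approaches 0.5.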

Let's draw a random sample from the distribution and plot its histogram / KDE and a normal Q-Q plot.

import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.api as sm
import torch

torch.manual_seed(0)
# Use a less heavy-tailed distribution with a tail parameter of 0.25 (i.e., moments of order < 1/0.25 = 4 exist).
m = tlwd.TailLambertWNormal(loc=1.0, scale=3.0, tailweight=0.25)
y = m.sample((1000,)).numpy().ravel()

# Histogram with KDE overlay, followed by a normal Q-Q plot.
sns.displot(y, kde=True)
plt.show()
sm.qqplot(y, line='45', fit=True)
plt.grid()
plt.show()

Figure: Lambert W x Gaussian histogram and KDE.

Figure: Lambert W x Gaussian normal Q-Q plot.

Back-transformation

The parameters (loc, scale, tailweight) can be estimated from the data using the accompanying pylambertw module (see also LambertW R package).

Let's say you have the estimated parameters; then you can obtain the unobserved, Gaussian data using:

torch.manual_seed(0)

m = tlwd.TailLambertWNormal(loc=1.0, scale=3.0, tailweight=0.25)

y = m.sample((1000,)).numpy().ravel()
x = m.transforms[0]._inverse(torch.tensor(y)).numpy().ravel()
sns.displot(x, kde=True)
plt.show()
sm.qqplot(x, line='45', fit=True)
plt.grid()
plt.show()

Figure: histogram and KDE of the back-transformed data.

Figure: normal Q-Q plot of the back-transformed data.
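The back-transformation has a closed form in terms of the Lambert W function. With z = (y - loc) / scale and z = u * exp(delta / 2 * u^2), multiplying the square by delta gives delta * z^2 = (delta * u^2) * exp(delta * u^2), so u = sign(z) * sqrt(W(delta * z^2) / delta), as in Goerg (2015). The stdlib-only sketch below checks the round trip for scalars. All function names are hypothetical, tailweight is assumed to correspond to delta, and this is not the code path used by m.transforms[0]._inverse.

```python
import math

def lambertw_nonneg(z, tol=1e-14, max_iter=100):
    """Principal branch W_0(z) for z >= 0 via Newton's method (illustrative)."""
    w = math.log1p(z)  # upper bound on W_0(z), so the slope below stays positive
    for _ in range(max_iter):
        ew = math.exp(w)
        step = (w * ew - z) / (ew * (w + 1.0))
        w -= step
        if abs(step) <= tol * (1.0 + abs(w)):
            break
    return max(w, 0.0)

def tail_forward(x, loc, scale, delta):
    u = (x - loc) / scale
    return loc + scale * u * math.exp(0.5 * delta * u * u)

def tail_inverse(y, loc, scale, delta):
    if delta == 0.0:
        return y  # identity transform
    z = (y - loc) / scale
    u = math.copysign(math.sqrt(lambertw_nonneg(delta * z * z) / delta), z)
    return loc + scale * u

loc, scale, delta = 1.0, 3.0, 0.25
for x in (-8.0, 1.0, 2.5, 20.0):
    y = tail_forward(x, loc, scale, delta)
    back = tail_inverse(y, loc, scale, delta)
    print(f"x={x:6.2f} -> y={y:10.3f} -> back={back:6.2f}")
```

Because the heavy-tail transform is strictly increasing for delta >= 0, this inverse is unique, which is what makes "Gaussianizing" heavy-tailed data well defined.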

Skewed Lambert W x F distributions

For examples of skewed Lambert W x F distributions with F = Normal, Exponential, or Gamma, see the demo notebook.
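For intuition on the skewed case, here is a stdlib-only sketch of the transform z = u * exp(gamma * u) from Goerg (2011) and its inverse u = W(gamma * z) / gamma on the principal branch. The function names are hypothetical and this is not the torchlambertw API. Unlike the heavy-tail transform, the skewed transform is not one-to-one on the whole real line: the principal branch only recovers inputs with gamma * u >= -1, and outputs below the minimum of the transform have no real preimage.

```python
import math

def lambertw_principal(z, tol=1e-12, max_iter=200):
    """Illustrative principal branch W_0(z) for a real scalar z via Newton's method."""
    if z < -math.exp(-1.0):
        return math.nan  # no real solution below -1/e
    w = math.log1p(z)  # upper bound on W_0(z)
    for _ in range(max_iter):
        ew = math.exp(w)
        slope = ew * (w + 1.0)
        if slope == 0.0:
            break
        step = (w * ew - z) / slope
        w -= step
        if abs(step) <= tol * (1.0 + abs(w)):
            break
    return w

def skew_forward(x, loc, scale, gamma):
    u = (x - loc) / scale
    return loc + scale * u * math.exp(gamma * u)

def skew_inverse(y, loc, scale, gamma):
    if gamma == 0.0:
        return y  # identity transform
    z = (y - loc) / scale
    return loc + scale * lambertw_principal(gamma * z) / gamma

gamma = 0.2
for x in (-2.0, 0.0, 1.5, 4.0):
    y = skew_forward(x, 0.0, 1.0, gamma)
    print(f"x={x:5.2f} -> y={y:7.3f} -> back={skew_inverse(y, 0.0, 1.0, gamma):5.2f}")

# An output below the minimum of u * exp(gamma * u) has no real preimage.
print(skew_inverse(-10.0, 0.0, 1.0, gamma))
```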

Implementation details

This implementation closely follows the TensorFlow Probability version in tfp.math.lambertw.

Related Implementations

See also here and here for minimal example pytorch implementations (not optimized for fast iteration, but good starting points).

References

License

This project is licensed under the terms of the MIT license.