In [1]:
# Re-import edited modules automatically before each cell executes,
# so changes to library sources take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
import matplotlib.pyplot as plt

In [3]:
import os
import sys
# Make the directory one level above the working directory importable
# (presumably where the local `bgflow` / `toy_data` packages live — confirm).
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

#### Helper functions

In [4]:
from bgflow import (
    DenseNet,
    
    WrapCDFTransformerWithInverse,
    GridInversion,
    AffineSigmoidComponentInitGrid,
    
    MixtureCDFTransformer,
    AffineSigmoidComponents,    
    
    SmoothRamp,
    BisectionRootFinder,
    
    ConstrainedBoundaryCDFTransformer,
    
    SequentialFlow,
    CouplingFlow,
    SplitFlow,
    InverseFlow,
    SwapFlow
)
    
from toy_data.toy_data import dataset_names, inf_train_gen

from matplotlib.colors import LogNorm


def plot_transformer_density(transformer, label, grid_size=100):
    """Plot a 1D conditional CDF transformer together with two density estimates.

    The transformer is evaluated at `grid_size` evenly spaced points in [0, 1],
    with the conditioning input held at zero. Three curves are drawn:
      - "cdf": the transformed values,
      - "computed pdf": exp(dlogp) as reported by the transformer,
      - "AD pdf": the derivative of the output w.r.t. the input via autograd.
    Overlaying the last two is a visual check that dlogp matches the true
    derivative of the map.

    Args:
        transformer: callable `transformer(x_cond, y) -> (y_out, dlogp)`.
        label: figure title.
        grid_size: number of evaluation points (default 100).
    """
    with torch.enable_grad():
        ys = torch.linspace(0, 1, grid_size).view(-1, 1).requires_grad_(True)
        xs = torch.zeros_like(ys)  # conditioning input held at zero
        ys_, dlogp_ = transformer(xs, ys)
        computed_density = dlogp_.exp()
        # Summing gives per-point derivatives only if each output row depends solely
        # on its own input row (elementwise transformer) — assumed here.
        # The graph is not differentiated again, so no create_graph is needed.
        ad_density = torch.autograd.grad(ys_.sum(), ys)[0]

    y_grid = ys.detach().numpy()
    fig, ax = plt.subplots(figsize=(10, 3))
    ax.set_title(label)
    ax.plot(y_grid, ys_.detach().numpy(), label="cdf")
    ax.plot(y_grid, computed_density.detach().numpy(), alpha=0.5, label="computed pdf")
    ax.plot(y_grid, ad_density.detach().numpy(), linestyle="--", label="AD pdf")
    ax.legend()
        

def make_transformer(
    d_in,
    d_out,
    d_hidden,
    n_components,
    periodic=False,
    zero_boundary_left=False,
    zero_boundary_right=False,
    activation=torch.nn.SiLU(),
    smoothness_type="type1",
    init_weight=1.,
    min_density=1e-4,
    verbose=True,
):
    """Build a mixture-of-affine-sigmoids CDF transformer with a grid-based inverse.

    The forward map is a `MixtureCDFTransformer`. Three separate `DenseNet`s
    (two hidden layers of width `d_hidden`) compute, from the conditioning input:
      - the mixture weights,
      - the ramp `alpha` of the `SmoothRamp`,
      - the `AffineSigmoidComponents` parameters (3 per component).
    The inverse is supplied by a `GridInversion` oracle. Its initial grid comes
    from `AffineSigmoidComponentInitGrid` over the same components.

    Args:
        d_in: dimension of the conditioning input.
        d_out: number of transformed dimensions.
        d_hidden: hidden width of every DenseNet.
        n_components: number of mixture components per output dimension.
        periodic, zero_boundary_left, zero_boundary_right: forwarded to
            `AffineSigmoidComponents`.
        activation: activation module passed to every DenseNet. NOTE: the
            default is a single SiLU instance, created once at definition time
            and shared by every call. That is fine for a stateless activation.
        smoothness_type: `ramp_type` forwarded to `SmoothRamp`.
        init_weight: `weight_scale` of the DenseNet that predicts the
            component parameters.
        min_density: presumably a density floor; passed to
            `AffineSigmoidComponents` as a tensor (default 1e-4, as before).
        verbose: forwarded to `GridInversion` (default True, as before).

    Returns:
        A `WrapCDFTransformerWithInverse` wrapping the transformer and its oracle.
    """
    def dense(n_out, **kwargs):
        # d_in -> d_hidden -> d_hidden -> n_out, shared architecture for all three heads.
        return DenseNet([d_in, d_hidden, d_hidden, n_out], activation, **kwargs)

    # Construction order (weights, alpha, params) matches the previous version,
    # so seeded initialisations are reproduced.
    transformer = MixtureCDFTransformer(
        compute_weights=dense(d_out * n_components),
        compute_components=AffineSigmoidComponents(
            conditional_ramp=SmoothRamp(
                compute_alpha=dense(d_out * n_components),
                unimodal=True,
                ramp_type=smoothness_type,
            ),
            compute_params=dense(d_out * (3 * n_components), weight_scale=init_weight),
            min_density=torch.as_tensor(min_density),
            periodic=periodic,
            zero_boundary_left=zero_boundary_left,
            zero_boundary_right=zero_boundary_right,
        ),
    )
    oracle = GridInversion(
        transformer=transformer,
        compute_init_grid=AffineSigmoidComponentInitGrid(transformer._compute_components),
        verbose=verbose,
    )
    return WrapCDFTransformerWithInverse(transformer=transformer, oracle=oracle)


def make_constrained_transformer(transformer, left_bound=None, right_bound=None, smoothness_type="type1"):
    """Wrap `transformer` in a `ConstrainedBoundaryCDFTransformer` with fixed boundary values.

    Args:
        transformer: the base CDF transformer to wrap.
        left_bound: presumably the target density at the left boundary; passed
            on in log space, so it must be > 0. `None` leaves that side unconstrained.
        right_bound: same as `left_bound`, for the right boundary.
        smoothness_type: forwarded to `ConstrainedBoundaryCDFTransformer`.

    Raises:
        ValueError: if neither bound is given. Previously this surfaced only later,
            as an error from `torch.stack([])` during the forward pass.
    """
    if left_bound is None and right_bound is None:
        raise ValueError("make_constrained_transformer requires at least one of left_bound / right_bound")

    def compute_constraints(x):
        # One log-bound channel per constrained side (left first), broadcast to x's
        # shape and created with x's dtype/device so it can be combined with x.
        bounds = [
            torch.as_tensor(bound, dtype=x.dtype, device=x.device).log().expand_as(x)
            for bound in (left_bound, right_bound)
            if bound is not None
        ]
        return torch.stack(bounds, dim=-1)

    return ConstrainedBoundaryCDFTransformer(
        transformer=transformer,
        compute_constraints=compute_constraints,
        left_constraint=left_bound is not None,
        right_constraint=right_bound is not None,
        smoothness_type=smoothness_type
    )


def make_coupling_flow(transformer_factory, n_layers=4):
    """Stack `n_layers` coupling layers on a 2-way split of the input.

    The layout is:
        SplitFlow(1), [CouplingFlow(transformer_factory()), SwapFlow()] * n_layers,
        InverseFlow(SplitFlow(1))
    `SplitFlow(1)` presumably splits the input into two parts and the final
    `InverseFlow(SplitFlow(1))` merges them back. `SwapFlow` alternates which
    part the next coupling layer transforms.

    Args:
        transformer_factory: zero-argument callable. It is called once per
            coupling layer, so every layer gets its own transformer.
        n_layers: number of coupling/swap pairs (default 4, as before).

    Raises:
        ValueError: if `n_layers < 1` (the flow would have no trainable layers).
    """
    if n_layers < 1:
        raise ValueError(f"n_layers must be >= 1, got {n_layers}")

    layers = [SplitFlow(1)]
    for _ in range(n_layers):
        layers.append(CouplingFlow(transformer_factory()))
        layers.append(SwapFlow())
    layers.append(InverseFlow(SplitFlow(1)))
    return SequentialFlow(layers)

def train(flow, dataset="pinwheel", train_with_inverse=False, batch_size=1_000, learning_rate=1e-3, n_iters=1_000, print_interval=100):
    """Fit `flow` to a 2D toy dataset by minimising the negative log-likelihood with Adam.

    Each iteration draws a fresh batch from `inf_train_gen`. The batch is mapped
    into the unit square via (x + 4) / 8 and wrapped with `% 1`. The loss is the
    mean of -dlogp returned by the flow. This equals the NLL assuming a uniform
    base density on [0, 1]^2, which matches the `torch.rand` sampling in
    `plot_evaluation`.

    Args:
        flow: flow module called as `flow(x, inverse=...) -> (y, dlogp)`.
        dataset: name understood by `inf_train_gen`.
        train_with_inverse: evaluate the flow in its inverse direction on data.
        batch_size: samples per iteration.
        learning_rate: Adam learning rate.
        n_iters: number of optimisation steps.
        print_interval: report the loss every this many iterations. The last
            iteration is always reported.

    Raises:
        ValueError: if `print_interval < 1`.
    """
    if print_interval < 1:
        raise ValueError(f"print_interval must be >= 1, got {print_interval}")

    optim = torch.optim.Adam(flow.parameters(), lr=learning_rate)
    for it in range(n_iters):
        batch = (inf_train_gen(dataset, batch_size=batch_size) + 4) / 8
        x = torch.as_tensor(batch, dtype=torch.float32) % 1
        y, dlogp = flow(x, inverse=train_with_inverse)
        nll = -dlogp.mean()
        optim.zero_grad()
        nll.backward()
        optim.step()
        if it % print_interval == 0 or it == n_iters - 1:
            # "\r" keeps a single, continuously updated progress line.
            print(f"it: {it}/{n_iters}, nll: {nll.item():.4}", end="\r")
    # Terminate the progress line so the final loss stays visible.
    print()
            

            
def plot_evaluation(flow, dataset="pinwheel", train_with_inverse=False, n_samples=100_000, norm="log", grid_size=100):
    """Visually compare a 2D flow on [0, 1]^2 with samples from the toy dataset.

    Draws three figures, all with x0 on the horizontal axis and x1 on the vertical axis:
      1. the model density exp(dlogp) on a `grid_size` x `grid_size` grid,
      2. a 2D histogram of `n_samples` ground-truth samples, preprocessed with
         (x + 4) / 8 and `% 1`,
      3. a 2D histogram of `n_samples` flow samples generated from uniform noise.

    Args:
        flow: flow module called as `flow(x, inverse=...) -> (y, dlogp)`.
        dataset: name understood by `inf_train_gen`.
        train_with_inverse: whether data -> noise is the flow's inverse
            direction. Sampling then uses the opposite direction.
        n_samples: number of ground-truth and model samples.
        norm: "log" for logarithmic colour scales on all three figures; any
            other value gives linear scales.
        grid_size: density-grid resolution and histogram bins per axis
            (default 100, as before).
    """
    # Decide once: the previous version reassigned `norm` to a LogNorm object after
    # the first figure, so the histograms silently fell back to a linear scale.
    use_log = norm == "log"

    def fresh_norm():
        # New norm object per figure so the plots never share colour-scale state.
        return LogNorm() if use_log else None

    def show_hist(points, title):
        plt.figure(figsize=(8, 8))
        plt.title(title)
        plt.hist2d(*points.detach().numpy().T, bins=grid_size, density=True, norm=fresh_norm(), range=((0, 1), (0, 1)))

    print("Computing energy plot...")
    with torch.no_grad():
        axis = torch.linspace(0, 1, grid_size)
        # torch.meshgrid defaults to "ij" indexing: point [i, j] is (x0=axis[i], x1=axis[j]).
        grid = torch.stack(torch.meshgrid(axis, axis), dim=-1).view(-1, 2)
        _, dlogp = flow(grid, inverse=train_with_inverse)
        u = -dlogp.view(grid_size, grid_size)  # energy; rows index x0, columns index x1

    plt.figure(figsize=(8, 8))
    plt.title("Model density exp(-energy)")
    # Transpose + origin="lower" puts x0 horizontally and x1 increasing upwards,
    # matching hist2d; without it the image was transposed and flipped.
    plt.imshow((-u).exp().T.numpy(), norm=fresh_norm(), origin="lower", extent=(0, 1, 0, 1))

    print("Sampling ground truth...")
    data = (inf_train_gen(dataset, batch_size=n_samples) + 4) / 8
    data = torch.as_tensor(data, dtype=torch.float32) % 1
    show_hist(data, "Ground truth samples")

    print("Sampling model...")
    with torch.no_grad():
        z = torch.rand(n_samples, 2)  # uniform base samples on [0, 1)^2
        samples, _ = flow(z, inverse=not train_with_inverse)
    show_hist(samples, "Flow samples")

## How to create compact transformers

In [5]:
# Same unconstrained transformer, once without and once with periodic boundaries.
for periodic, label in [(False, "no constraints, non-periodic"), (True, "no constraints, periodic")]:
    plot_transformer_density(
        make_transformer(d_in=1, d_out=1, d_hidden=40, n_components=4, periodic=periodic, zero_boundary_left=False, zero_boundary_right=False),
        label=label
    )

NameError: name 'NewtonGridInversion' is not defined

## Zero density constraint at boundary

In [None]:
# (zero_boundary_left, zero_boundary_right, figure title) for each variant.
zero_boundary_cases = [
    (True, False, "left boundary is zero"),
    (False, True, "right boundary is zero"),
    (True, True, "both boundaries are zero"),
]
for zero_left, zero_right, label in zero_boundary_cases:
    plot_transformer_density(
        make_transformer(
            d_in=1, d_out=1, d_hidden=40, n_components=4, periodic=False,
            zero_boundary_left=zero_left, zero_boundary_right=zero_right,
        ),
        label=label,
    )

## Value constraint at boundary

In [None]:
# NOTE(review): the constrained side is also flagged zero_boundary_*=True on the base
# transformer — presumably what ConstrainedBoundaryCDFTransformer expects; confirm.
plot_transformer_density(
    transformer=make_constrained_transformer(
        transformer=make_transformer(
            d_in=1, d_out=1, d_hidden=40, n_components=4,
            periodic=False, zero_boundary_left=True, zero_boundary_right=False,
        ),
        left_bound=5,
    ),
    label="left boundary = 5",
)

In [None]:
# Mirror of the previous cell: value constraint on the right boundary only.
plot_transformer_density(
    transformer=make_constrained_transformer(
        transformer=make_transformer(
            d_in=1, d_out=1, d_hidden=40, n_components=4,
            periodic=False, zero_boundary_left=False, zero_boundary_right=True,
        ),
        right_bound=5,
    ),
    label="right boundary = 5",
)

In [None]:
# Value constraints on both boundaries; smoothness_type is spelled out on both wrappers.
plot_transformer_density(
    transformer=make_constrained_transformer(
        transformer=make_transformer(
            d_in=1, d_out=1, d_hidden=40, n_components=4, periodic=False,
            zero_boundary_left=True, zero_boundary_right=True, smoothness_type="type1",
        ),
        left_bound=3,
        right_bound=5,
        smoothness_type="type1",
    ),
    label="left boundary =3, right boundary = 5",
)

## Density Estimation Example

In [None]:
from functools import partial

# make_coupling_flow calls this factory once per coupling layer,
# so every layer gets its own freshly built transformer.
flow = make_coupling_flow(
    transformer_factory=partial(
        make_transformer,
        d_in=1, d_out=1, d_hidden=200, n_components=20, periodic=False,
    )
)

In [None]:
# Maximum-likelihood training on the pinwheel toy dataset (all other settings at their defaults).
train(flow, dataset="pinwheel")

In [None]:
# Compare the learned density and flow samples against ground-truth samples.
plot_evaluation(flow, dataset="pinwheel")