In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import sys
sys.path.insert(1, '../discrete_mixflows/')
from discrete_mixflows import *
from gibbs import *

# Notebook-wide matplotlib defaults, applied in a single rcParams update:
# no "too many open figures" warning, wide 15 x 7.5 figures, large font.
plt.rcParams.update({
    'figure.max_open_warning': 0,
    'figure.figsize': (15, 7.5),
    'font.size': 24,
})

In [2]:
########################
########################
# target specification #
########################
########################
# size of the target's support: the pmf below has K1 entries
K1=10
# fix numpy's global RNG so the randomly drawn pmf is the same on every run
np.random.seed(2023)
# draw K1 uniform(0,1) weights, then rescale them in place so they sum to one
prbs=np.random.rand(K1)
prbs/=prbs.sum()
def lp(x,axis=None):
    # compute the univariate log joint and conditional target pmfs
    #
    # inputs:
    #    x    : (1,d) integer array with state values; each entry indexes into prbs
    #    axis : int, full conditional to calculate; returns joint if None
    #           (only compared against None: the returned table is the same
    #           for every non-None value)
    # outputs:
    #   ext_lprb : if axis is None, log prbs evaluated at each entry of x and
    #              squeezed to a (d,) array (0-d when d=1, because of np.squeeze);
    #              else, (d,K1) array whose d rows are each log prbs
    
    d=x.shape[1]
    # take the log once on the (K1,) pmf, then stack d copies as rows -> (d,K1)
    ext_lprb=np.tile(np.log(prbs),(d,1))
    # identity test: `axis==None` compares elementwise (and then fails in `if`)
    # when axis is an array
    if axis is None: return np.squeeze(ext_lprb[np.arange(d),x])
    return ext_lprb

In [3]:
import torch
import flowtorch.bijectors as bij
import flowtorch.distributions as dist

In [None]:
# Import the module under its own alias: the original bound both the module and
# the hypernetwork instance to `params`, so re-running the cell failed because
# the instance has no `DenseAutoregressive` attribute.
import flowtorch.parameters as ft_params

# Seed torch's global RNG (numpy is seeded in the target-specification cell) so
# parameter initialisation and sampling are reproducible across runs.
torch.manual_seed(2023)

# NOTE(review): `temperature` is not defined in any cell visible here; it must
# be set (relaxation temperature of the RelaxedOneHotCategorical) before this
# cell runs on a fresh kernel -- confirm where it is meant to come from.

# Lazily instantiated flow plus base and target distributions
# `params` still ends up bound to the hypernetwork instance, as before.
params = ft_params.DenseAutoregressive(hidden_dims=(32,))
bijectors = bij.AffineAutoregressive(params)
# Base: relaxed one-hot categorical with uniform probabilities over K1 states
base_dist = torch.distributions.Independent(torch.distributions.RelaxedOneHotCategorical(
      torch.tensor([temperature]),
      torch.tensor(np.ones(K1)/K1)),0)
# Target: same relaxed family, with the target pmf `prbs` as probabilities
target_dist = torch.distributions.Independent(torch.distributions.RelaxedOneHotCategorical(
      torch.tensor([temperature]),
      torch.tensor(prbs)),0)

# Instantiate transformed distribution and parameters
flow = dist.Flow(base_dist, bijectors)

# Training loop: fit the flow by maximum likelihood on draws from the target
opt = torch.optim.Adam(flow.parameters(), lr=5e-3)
frame = 0  # not used in this cell; kept in case later cells read it
for idx in range(3001):
    opt.zero_grad()

    # Minimize KL(p || q): average negative log-likelihood of target samples
    # under the flow. The original sampled from `base_dist`, which left
    # `target_dist` unused and fitted the flow to its own base distribution.
    y = target_dist.sample((1000,)).to(torch.float32)
    loss = -flow.log_prob(y).mean()

    if idx % 500 == 0:
        # .item() prints the scalar loss instead of the tensor/grad_fn repr
        print('epoch', idx, 'loss', loss.item())

    loss.backward()
    opt.step()