In this notebook I walk through the generation/inference pipeline (spending most of the time on inference).

In [27]:
from abc import ABC, abstractmethod

import os
from tqdm import tqdm
import math 
import time

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns 

from scipy import stats

import torch
from torch.distributions import Beta
from torch.distributions.bernoulli import Bernoulli
from torch.nn.functional import log_softmax
from torch.optim import Adam, SGD
from torch.utils.data import DataLoader, TensorDataset

In [13]:
SEED = 1
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fb08d0bf350>

### Generation:

Here I am using 2/10 generation.

In [14]:
### these params control the generation scheme
rho        = 0.8   # polarization
pop_size   = 50000 # num individuals
epsilon    = 0.05  # expected prop of speech consisting of neutral words
pi         = 0.5   # pi == 0.5 => beta mixture symmetrical (choose beta1 with prob pi = 0.5)
speech_len = 15    # words per speech

In [15]:
def generate(rho=rho, N=pop_size, epsilon=epsilon, pi=pi, speech_len=speech_len):
    """
    Uses 2/10 generation scheme to generate N samples.

    Returns:
        (X, y), (a, b, rho, epsilon, u)
        X.size() == [N, 3] is a vector of word counts
        y.size() == [N] is a vector of political parties
        u.shape  == (N,) is a vector of individual stances
        rho is true polarization
        epsilon is expected prop of neutral words
        a, b are true alpha/beta for beta mixture model
    """
    start = time.time()
    print(f'Beginning Data Generation...')
    print(f'=' * 20)
    
    ### get beta mixture model params 
    sigma = 0.175 * (rho ** 2) - 0.3625 * rho + 0.1875
    a     = rho * ((rho * (1 - rho)) / sigma - 1)
    b     = (1 - rho) * ((rho * (1 - rho)) / sigma - 1)

    print(f"True Alpha: {a}")
    print(f"True Beta: {b}")
    print(f"True Epsilon: {epsilon}\n")

    mean = a / (a + b)
    var  = a * b / ((a + b)**2 * (a + b + 1))

    if abs(mean - rho) > 10e-15:
        print(f'Mean: {mean}')
        print(f"Rho: {rho}")
        raise AssertionError(f"Mean of BMM params should be rho")

    if abs(var - sigma) > 10e-15:
        print(f'Var: {var}')
        print(f'Sigma: {sigma}')
        raise AssertionError(f"Var of BMM params should be sigma")

    ### u ~ pi Beta(a, b) + (1 - pi) Beta(b, a)
    ### label 1 selects Beta(a, b) below, so it must be drawn with prob pi
    weights = [1 - pi, pi]
    mixture_samples = np.random.choice([0, 1], size=N, p=weights)

    u = 2 * np.where(mixture_samples == 1, stats.beta.rvs(a, b, size=N), stats.beta.rvs(b, a, size=N)) - 1

    if u.shape != (N,):
        raise AssertionError(f"u.shape should be (N,)")

    print(f'mixture samples: {mixture_samples[:5]}')
    print(f'u samples: {u[:5]}\n')

    ### y = 1(u >= 0)
    y = (u >= 0).astype(int)
    print(f'y samples: {y[:5]}\n')

    ### phi is a prob matrix that is a function of u, epsilon
    phi = np.array([(1 - (u+1)/2) * (1 - epsilon), (u+1)/2 * (1 - epsilon), np.repeat(epsilon, N)]).T
    print(f'phi samples:\n {phi[:5, :]}\n')

    if phi.shape != (N, 3):
        raise AssertionError(f'phi.shape should be (N, V) == (N, 3)')

    if abs(sum(phi[0]) - 1) > 10e-5:
        raise AssertionError(f'rows of phi should sum to 1')
    
    X = np.array([stats.multinomial.rvs(n=speech_len, p=phi[i, :]) for i in range(N)])
    print(f'X samples:\n {X[:5, :]}\n')
    
    if X.shape != (N, 3):
        raise AssertionError(f'X.shape should be (N, V) == (N, 3)')

    if X[:5].sum() != speech_len * 5:
        raise AssertionError(f'rows of X should sum to speech_len')

    X = torch.from_numpy(X).to(torch.float32)
    y = torch.from_numpy(y).to(torch.float32)

    known   = (X, y)
    unknown = (a, b, rho, epsilon, u)
    print('=' * 20)
    print(f'Generation Time: {round(time.time() - start, 3)} seconds for {N} samples.')
    return known, unknown

known, unknown = generate()
X, y = known
a, b, rho, epsilon, u = unknown

Beginning Data Generation...
True Alpha: 12.673684210526261
True Beta: 3.1684210526315644
True Epsilon: 0.05

mixture samples: [0 1 0 0 0]
u samples: [-0.71319006  0.57688516 -0.46523368 -0.74177243 -0.73809268]

y samples: [0 1 0 0 0]

phi samples:
 [[0.81376528 0.13623472 0.05      ]
 [0.20097955 0.74902045 0.05      ]
 [0.695986   0.254014   0.05      ]
 [0.8273419  0.1226581  0.05      ]
 [0.82559402 0.12440598 0.05      ]]

X samples:
 [[11  3  1]
 [ 2 12  1]
 [11  3  1]
 [14  0  1]
 [14  1  0]]

Generation Time: 1.479 seconds for 50000 samples.
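
The alpha/beta values printed above come from moment matching: a Beta distribution with mean m and variance s has a = m(m(1-m)/s - 1) and b = (1-m)(m(1-m)/s - 1). Below is a minimal plain-Python sketch that reproduces the printed values for rho = 0.8. The helper name `beta_params_from_moments` and the `_chk` variables are hypothetical and only used for this check.

```python
def beta_params_from_moments(mean, var):
    """Method-of-moments Beta parameters; requires 0 < var < mean * (1 - mean)."""
    if not 0 < var < mean * (1 - mean):
        raise ValueError("need 0 < var < mean * (1 - mean)")
    common = mean * (1 - mean) / var - 1
    return mean * common, (1 - mean) * common

rho_chk = 0.8
# Same variance schedule as in generate()
sigma_chk = 0.175 * rho_chk ** 2 - 0.3625 * rho_chk + 0.1875
a_chk, b_chk = beta_params_from_moments(rho_chk, sigma_chk)

# Round trip: recover the mean and variance from (a, b)
mean_chk = a_chk / (a_chk + b_chk)
var_chk = a_chk * b_chk / ((a_chk + b_chk) ** 2 * (a_chk + b_chk + 1))
print(a_chk, b_chk, mean_chk, var_chk)
```

The guard on `var` matters because the formula only yields positive parameters when the requested variance is below mean * (1 - mean).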


### Inference

The goal of inference is to recover the distribution of $u$ given a sample of $(x,y)$ pairs. We assume the following data generating process:
\begin{align*}
    \frac{u+1}{2} &\sim \pi\cdot\mathrm{Beta}\left(\alpha,\beta\right)+(1-\pi)\cdot\mathrm{Beta}\left(\beta,\alpha\right) \\
    y &= 1(u \ge \lambda) \\
    x &\sim \mathrm{Multinomial}\left(S, \mathrm{Softmax}(Wu)\right), \quad W \in \mathbb{R}^V.
\end{align*}

Here $S$ is the number of words per speech. For now, we assume $\lambda = 0, V = 3, \pi = 0.5$, so our parameters are $\theta = \{\alpha, \beta, W\}$. Because the distribution of $u$ is fully determined by $\alpha, \beta$, recovering it is equivalent to recovering $\alpha, \beta$. This also gives us the following joint:

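\begin{align*}
p(u,x^{(n)},y^{(n)};\theta) &= p(u)\,p(y^{(n)}|u)\,p(x^{(n)}|u) \\
&= \left(\frac{1}{4}\mathrm{Beta}\left(\frac{u+1}{2};\alpha,\beta\right)+\frac{1}{4}\mathrm{Beta}\left(\frac{u+1}{2};\beta,\alpha\right)\right) \\
&\quad\cdot 1\left(1(u\ge0)=y^{(n)}\right)\frac{S!}{x^{(n)}_1!\,x^{(n)}_2!\,x^{(n)}_3!}\prod_{v=1}^{3}\mathrm{softmax}(Wu)_{v}^{x^{(n)}_v}
\end{align*}

Here $x^{(n)}_v$ is the number of times word $v$ appears in speech $n$, and the counts sum to $S$. The factor $\frac{1}{4}$ is the mixture weight $\frac{1}{2}$ times the Jacobian $\frac{1}{2}$ of the map $u \mapsto \frac{u+1}{2}$.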
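
As a sanity check on the prior term, the sketch below integrates the density of $u$ over $[-1,1]$ with a midpoint rule and confirms it is close to 1. It is plain Python and independent of the notebook's functions; `beta_pdf`, `prior_pdf`, and the demo values for alpha/beta are assumptions chosen to be near the true values printed during generation.

```python
import math

def beta_pdf(t, a, b):
    """Beta(a, b) density at t in (0, 1), computed in log space."""
    log_norm = math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
    return math.exp(log_norm + (a - 1) * math.log(t) + (b - 1) * math.log(1 - t))

def prior_pdf(u, a, b):
    """Density of u = 2t - 1, where t is an equal-weight mix of Beta(a, b) and Beta(b, a)."""
    t = (u + 1) / 2
    return 0.25 * beta_pdf(t, a, b) + 0.25 * beta_pdf(t, b, a)

# Illustrative parameter values (assumed, close to the printed true alpha/beta)
a_demo, b_demo = 12.67, 3.17

# Midpoint rule on [-1, 1]; midpoints never touch the endpoints, so log() is safe
n_cells = 4000
width = 2.0 / n_cells
total_mass = sum(
    prior_pdf(-1 + (g + 0.5) * width, a_demo, b_demo) for g in range(n_cells)
) * width
print(total_mass)
```

With $\pi = 0.5$ the density is symmetric around $u = 0$, which is why half of the mass falls on each party.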

The basic idea is to maximize the log marginal likelihood. Let $\theta = \{\alpha, \beta, W\}$. We'd like to find 
$$
\theta^* = \operatorname*{arg\,max}_{\theta} \underbrace{\log p(x^{(1:N)}, y^{(1:N)}; \theta)}_{\mathcal{L}(\theta)}. 
$$


In other words, the goal is to recover alpha/beta from the (x,y) pairs. Since we assume u ~ 2[1/2 Beta(a, b) + 1/2 Beta(b, a)] - 1, this pins down the distribution of u.

To get MLE estimates for alpha, beta, W, we will take gradient ascent steps on the log marginal likelihood (equivalently, gradient descent steps on the negative log likelihood). Its gradient is given by the following expression:
$$
\nabla_\theta \log p(x^{(1:N)},y^{(1:N)}) = \sum_{n=1}^{N} \mathbb{E}_{p(u|x^{(n)},y^{(n)};\theta)}[\nabla_\theta \log p(u,x^{(n)},y^{(n)};\theta)].
$$
There are a number of issues here, beginning with the fact that we cannot recover the posterior analytically because its normalizing integral is intractable. However, since $u \in [-1,1]$, we can discretize the integral, which is simpler than sampling from the posterior with MCMC or approximating it with variational inference. To do this, let's write the gradient as follows:
$$
\nabla_\theta \log p(x^{(1:N)},y^{(1:N)}) = \sum_{n=1}^{N} \int_{u} \frac{p(x^{(n)}, y^{(n)}, u;\ \theta)}{p(x^{(n)},y^{(n)};\  \theta)}\nabla_\theta \log p(u,x^{(n)},y^{(n)};\theta) \mathrm{d}u
$$
To discretize, we assume that $u$ takes one of $G$ linearly spaced values in $[-1,1]$ rather than varying continuously over $[-1,1]$, which gives us 
$$
\nabla_\theta \log p(x^{(1:N)},y^{(1:N)}) = \sum_{n=1}^{N} \sum_{u} \frac{p(x^{(n)}, y^{(n)}, u;\ \theta)}{\sum_{u'} p(x^{(n)},y^{(n)}, u';\  \theta)}\nabla_\theta \log p(u,x^{(n)},y^{(n)};\theta)
$$
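
The posterior-weighted form of the gradient can be checked numerically on a toy grid. The sketch below uses a made-up one-parameter model (a Gaussian-style log joint, not the notebook's model) and compares the weighted sum of per-grid-point gradients against a central finite difference of the log marginal. All names are illustrative.

```python
import math

# Toy model on a grid (assumed for illustration):
#   log p(u_g, x; theta) = log_prior[g] - 0.5 * (x - theta * u_g) ** 2
toy_grid = [-0.75, -0.25, 0.25, 0.75]
toy_log_prior = [math.log(0.25)] * len(toy_grid)
x_obs = 0.9

def toy_log_joint(theta):
    return [lp - 0.5 * (x_obs - theta * u) ** 2 for lp, u in zip(toy_log_prior, toy_grid)]

def toy_log_marginal(theta):
    lj = toy_log_joint(theta)
    m = max(lj)
    return m + math.log(sum(math.exp(v - m) for v in lj))

def posterior_weighted_grad(theta):
    lj = toy_log_joint(theta)
    m = max(lj)
    unnorm = [math.exp(v - m) for v in lj]
    z = sum(unnorm)
    weights = [w / z for w in unnorm]                    # p(u_g | x; theta)
    grads = [(x_obs - theta * u) * u for u in toy_grid]  # d/dtheta log p(u_g, x; theta)
    return sum(w * g for w, g in zip(weights, grads))

theta0, step = 1.3, 1e-5
finite_diff = (toy_log_marginal(theta0 + step) - toy_log_marginal(theta0 - step)) / (2 * step)
analytic = posterior_weighted_grad(theta0)
print(finite_diff, analytic)
```

The two numbers should agree up to finite-difference error, which is the identity the inference code relies on.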
It's convenient to move the gradient outside the sum: we can then compute the double sum with matrix operations, sum the resulting matrix, and call .backward(). However, the "weight" (the posterior probability) also depends on $\theta$, so we can only move the gradient outside if we .detach() the weight so that no gradient flows through it.
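
Here is a small sketch of the detach trick in torch, under the assumption of a toy log joint that is linear in the parameters (the `feats` tensor is random and purely illustrative). It compares the gradient of the exact log marginal, computed with `torch.logsumexp`, against the gradient of the surrogate built from detached posterior weights.

```python
import torch

torch.manual_seed(0)

# Toy setup: B = 4 samples, G = 5 grid points, 3 parameters (all made up)
theta_toy = torch.randn(3, dtype=torch.float64, requires_grad=True)
feats = torch.randn(4, 5, 3, dtype=torch.float64)

def toy_log_joint(theta):
    return feats @ theta  # shape [B, G]

# Exact objective: sum over samples of log sum_g exp(log joint)
exact = torch.logsumexp(toy_log_joint(theta_toy), dim=1).sum()
(grad_exact,) = torch.autograd.grad(exact, theta_toy)

# Surrogate: posterior weights are treated as constants via .detach()
lj = toy_log_joint(theta_toy)
post_weights = torch.softmax(lj, dim=1).detach()
surrogate = (post_weights * lj).sum()
(grad_surrogate,) = torch.autograd.grad(surrogate, theta_toy)

print(grad_exact, grad_surrogate)
```

The surrogate's value differs from the log marginal, so only its gradient should be interpreted; the value is not a likelihood.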


In [22]:
grid_size  = 200 ### discretize integral into grid_size linearly spaced pts
batch_size = 100 ### num samples processed simultaneously
num_epochs = 200 ### number epochs
lr         = 0.1 ### learning rate

log_alpha = torch.normal(1.5, 1, size=(1,), requires_grad=True)
log_beta  = torch.normal(0, 1, size=(1,), requires_grad=True)
W         = torch.normal(3, 2, size=(3,), requires_grad=True)

print(f'Beginning Inference')
print(f'Initial Parameters:')
print(f'Alpha: {torch.exp(log_alpha).item()}')
print(f"Beta: {torch.exp(log_beta).item()}")
print(f"W: {W.data.tolist()}")

Beginning Inference
Initial Parameters:
Alpha: 0.9299120903015137
Beta: 0.8840445876121521
W: [10.173978805541992, -0.6625802516937256, 6.1974005699157715]


In [20]:
dataset    = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=batch_size)

X_batch, y_batch = list(dataloader)[0]

if list(X_batch.size()) != [batch_size, 3]:
    raise AssertionError(f"X_batch.size() should be [batch_size, V] == [batch_size, 3], got {X_batch.size()}")

if list(y_batch.size()) != [batch_size]:
    raise AssertionError(f"y_batch.size() should be [batch_size], got {y_batch.size()}")

In [23]:
def compute_nll(x_batch, y_batch, log_alpha, log_beta, W):
    """
    Parameters:
        (x_batch, y_batch) a batch of B (x,y) samples 
        (log_alpha, log_beta, W) trainable parameters

    Returns:
        A scalar quantity s.t. calling .backward() will compute
        the gradient of the negative log likelihood, -grad_{theta} (L(theta))
    """
    raise NotImplementedError

def train(X, y, num_epochs=num_epochs, batch_size=batch_size, grid_size=grid_size, lr=lr):
    print(f'Beginning Training...')
    start = time.time()

    dataset    = TensorDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    log_alpha = torch.normal(1.5, 1, size=(1,), requires_grad=True)
    log_beta  = torch.normal(0, 1, size=(1,), requires_grad=True)
    W         = torch.normal(3, 2, size=(3,), requires_grad=True)

    optimizer = SGD([log_alpha, log_beta, W], lr=lr)
    
    for epoch in tqdm(range(num_epochs)):
        print(f'Epoch {epoch+1} of {num_epochs}')
        ### only step after each epoch
        optimizer.zero_grad() 
        nll = 0
        for x_batch, y_batch in dataloader:
            nll += compute_nll(x_batch, y_batch, log_alpha, log_beta, W)
        
        nll = nll / len(dataloader.dataset)
        nll.backward()
        optimizer.step()
        print(f'Epoch {epoch+1}, NLL: {round(nll.item(), 3)}, Alpha: {round(np.exp(log_alpha.item()),3)}, Beta: {round(np.exp(log_beta.item()),3)}, W: {W.tolist()}')
        print('=' * 20)

    final_alpha = np.exp(log_alpha.item())
    final_beta  = np.exp(log_beta.item())

    print(f"Training Took {round(time.time() - start, 2)} Seconds.")
    print(f"Trained Params: alpha: {final_alpha}, beta: {final_beta}, W: {W.tolist()}")

    return final_alpha, final_beta, W

Let's check how a few things work in torch.

In [32]:
"""
Checking if .log_prob() broadcasts across a tensor
"""
test = torch.tensor(
    [
    [1., 0., 1.],
    [1., 0., 1.]
    ]
)
test_dist = Bernoulli(torch.tensor([0.7]))
torch.exp(test_dist.log_prob(test))

tensor([[0.7000, 0.3000, 0.7000],
        [0.7000, 0.3000, 0.7000]])

In [49]:
"""
Given an NxM matrix and a length-B vector, create an NxMxB tensor
whose (i, j) slice is the vector scaled by the (i, j)th matrix value
"""
test_u = torch.tensor(
    [
    [1., 2., 2.4],
    [1., 3., 1.1]
    ]
)
test_W = torch.tensor([1.,1.,2.])

test_res = torch.matmul(test_u.unsqueeze(2), test_W.unsqueeze(0).unsqueeze(0))

assert torch.all(test_res[0, 1] == torch.tensor([2., 2., 4.]))

In [70]:
test_res.size()


torch.Size([2, 3, 3])

In [77]:
test_res

tensor([[[1.0000, 1.0000, 2.0000],
         [2.0000, 2.0000, 4.0000],
         [2.4000, 2.4000, 4.8000]],

        [[1.0000, 1.0000, 2.0000],
         [3.0000, 3.0000, 6.0000],
         [1.1000, 1.1000, 2.2000]]])

In [79]:
ex = torch.tensor([[3., 7., 5.],
    [8., 3., 4.]])
print(log_softmax(test_res, dim=2))
ex.unsqueeze(1) * log_softmax(test_res, dim=2)

tensor([[[-1.5514, -1.5514, -0.5514],
         [-2.2395, -2.2395, -0.2395],
         [-2.5667, -2.5667, -0.1667]],

        [[-1.5514, -1.5514, -0.5514],
         [-3.0949, -3.0949, -0.0949],
         [-1.6103, -1.6103, -0.5103]]])


tensor([[[ -4.6543, -10.8601,  -2.7572],
         [ -6.7186, -15.6768,  -1.1977],
         [ -7.7002, -17.9671,  -0.8337]],

        [[-12.4116,  -4.6543,  -2.2058],
         [-24.7594,  -9.2848,  -0.3797],
         [-12.8822,  -4.8308,  -2.0411]]])

In [55]:
"""
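Factorial via the gamma function: x! = exp(lgamma(x + 1))
Also defined for non-integer x; it overflows for large x, so use lgamma directly when only log(x!) is needed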
"""
factorial = lambda x : torch.exp(torch.lgamma(x+1))

assert factorial(torch.tensor(1.)) == torch.tensor(1.)
assert factorial(torch.tensor(2.)) == torch.tensor(2.)
assert factorial(torch.tensor(3.)) == torch.tensor(6.)
print(f"2.9999! = {factorial(torch.tensor(2.9999))}")

2.9999! = 5.999247074127197
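
Since the log joint only needs log-factorials, the multinomial coefficient can be computed entirely in log space. The sketch below does this in plain Python with `math.lgamma`; the helper name `log_multinomial_coef` is hypothetical, and the counts are the first row of the X samples printed during generation.

```python
import math

def log_multinomial_coef(counts):
    """log( S! / (x_1! ... x_V!) ) with S = sum(counts), using log(n!) = lgamma(n + 1)."""
    total = sum(counts)
    return math.lgamma(total + 1) - sum(math.lgamma(c + 1) for c in counts)

demo_counts = [11, 3, 1]
log_coef = log_multinomial_coef(demo_counts)

# Cross-check against exact integer arithmetic
exact_coef = math.factorial(15) // (
    math.factorial(11) * math.factorial(3) * math.factorial(1)
)
print(log_coef, math.log(exact_coef))

# Large counts stay finite in log space, even though 10000! does not fit in a float
big_log_coef = log_multinomial_coef([5000, 3000, 2000])
print(big_log_coef)
```

Working in log space avoids the exp-then-log round trip and keeps the computation finite for long speeches.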


In [64]:
test_x_batch

tensor([[11.,  3.,  1.],
        [ 2., 12.,  1.],
        [11.,  3.,  1.],
        [14.,  0.,  1.],
        [14.,  1.,  0.]])

In [66]:
test_x_batch = X_batch[:5]
print(test_x_batch)

x_factorial = factorial(test_x_batch)
print(torch.log(x_factorial))
print(torch.log(x_factorial.sum(dim=1, keepdim=True)))
"""
x_batch * factorial(x_batch) 
    x_factorial = factorial(x_batch)

    sum_x_factorial = x_factorial.sum(dim=1, keepdim=True)

    scaling_factor = x_factorial / sum_x_factorial

    scaled_x_batch = x_batch * scaling_factor
"""

tensor([[11.,  3.,  1.],
        [ 2., 12.,  1.],
        [11.,  3.,  1.],
        [14.,  0.,  1.],
        [14.,  1.,  0.]])
tensor([[17.5023,  1.7918,  0.0000],
        [ 0.6931, 19.9872,  0.0000],
        [17.5023,  1.7918,  0.0000],
        [25.1912,  0.0000,  0.0000],
        [25.1912,  0.0000,  0.0000]])
tensor([[17.5023],
        [19.9872],
        [17.5023],
        [25.1912],
        [25.1912]])


'\nx_batch * factorial(x_batch) \n    x_factorial = factorial(x_batch)\n\n    sum_x_factorial = x_factorial.sum(dim=1, keepdim=True)\n\n    scaling_factor = x_factorial / sum_x_factorial\n\n    scaled_x_batch = x_batch * scaling_factor\n'

In [90]:
jamie = torch.tensor([
    [1.1, 2., 3.4],
    [4.2, 5.1, 6.2]
])
jamie - torch.log(factorial(X_batch[:2])).sum(dim=1, keepdim=True)

tensor([[-18.1941, -17.2941, -15.8941],
        [-16.4804, -15.5804, -14.4804]])

In [95]:
def compute_log_joint(u_mat, x_batch, y_batch, log_alpha, log_beta, W, constants=True):
    """
    Parameters:
        x_batch (size [batch_size, 3])
        y_batch (size [batch_size])
        u_mat   (size [batch_size, grid_size])
        (log_alpha, log_beta, W) trainable params 
        constants (bool) true iff include constants
    
    constants are necessary for computing joint probabilities exactly but not for gradient expressions
    
    Returns:
        torch.tensor (size [batch_size, grid_size]), where the (i,j)th element 
        is the log joint probability (log p(u, x, y; theta)). Here u is the (i,j)th
        element of u_mat, x is the ith row of x_batch, y is the ith element of y_batch

    """
    ### log prior: log p(u) = log(1/4 Beta((u+1)/2; a, b) + 1/4 Beta((u+1)/2; b, a))
    alpha = torch.exp(log_alpha)
    beta  = torch.exp(log_beta)
        
    beta_dist_ab = Beta(alpha, beta)
    beta_dist_ba = Beta(beta, alpha)

    log_beta_prob_ab = beta_dist_ab.log_prob((u_mat + 1) / 2)
    log_beta_prob_ba = beta_dist_ba.log_prob((u_mat + 1) / 2)

    ### logaddexp computes log(exp(a) + exp(b)) without leaving log space
    prior = torch.logaddexp(log_beta_prob_ab, log_beta_prob_ba) + math.log(0.25)

    assert prior.size() == u_mat.size()
  
    ### log likelihood: log p(x, y | u) = dot(x, log Softmax(Wu)) (assuming y is compatible)
    W_expanded = W.unsqueeze(0).unsqueeze(0)
    assert list(W_expanded.size()) == [1,1,3]
    u_expanded = u_mat.unsqueeze(2)

    softmax_input = torch.matmul(u_expanded, W_expanded)  ### size BxNx3
    likelihood = (x_batch.unsqueeze(1) * log_softmax(softmax_input, dim=2)).sum(dim=2) ### size BxN
    
    if constants:
        ### log multinomial coefficient, using log(n!) = lgamma(n + 1) directly
        likelihood += torch.lgamma(x_batch.sum(dim=1, keepdim=True) + 1)  ### log(S!)
        likelihood -= torch.lgamma(x_batch + 1).sum(dim=1, keepdim=True)  ### - log(x_1!) - log(x_2!) - log(x_3!)

    return prior + likelihood


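def compute_nll(x_batch, y_batch, log_alpha, log_beta, W):
    """
    Parameters:
        x_batch (size [batch_size, 3])
        y_batch (size [batch_size])
        (log_alpha, log_beta, W) trainable params
    
    Returns:
        A scalar torch.tensor: minus the posterior-weighted log joint, summed over
        the batch and the grid of u values. The posterior weights are detached, so
        calling .backward() gives the gradient of the negative log marginal
        likelihood (on the discretized grid), even though the value itself is not
        the exact NLL.
    """
    # Grid of u values per sample: (0, 1) where y == 1, (-1, 0) where y == 0
    # Use the actual batch size so a smaller final batch still works
    u_mat = torch.empty(x_batch.size(0), grid_size, device=x_batch.device)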
    u_mat[y_batch == 1] = torch.linspace(1/(grid_size+1), 1-1/(grid_size+1), grid_size).repeat((y_batch == 1).sum(), 1)
    u_mat[y_batch == 0] = torch.linspace(-1+1/(grid_size+1), -1/(grid_size+1), grid_size).repeat((y_batch == 0).sum(), 1)

    assert list(u_mat.size()) == [x_batch.size()[0], grid_size]

    # Compute log_joint with u_mat, x_batch, and y_batch
    log_joint = compute_log_joint(u_mat, x_batch, y_batch, log_alpha, log_beta, W)

    # Compute the posterior probabilities but detach to prevent gradients through this path
    max_log_prob = torch.max(log_joint, dim=1, keepdim=True)[0]
    joint_probs = torch.exp(log_joint - max_log_prob)
    posterior = joint_probs / joint_probs.sum(dim=1, keepdim=True)  # Do not detach yet

    # Compute the weighted negative log likelihood (NLL)
    # Detach posterior here to use it as a constant weight for log_joint contributions
    weighted_log_joint = posterior.detach() * log_joint
    nll = -weighted_log_joint.sum()

    return nll
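
The max subtraction in `compute_nll` is what keeps the posterior normalization from underflowing, since log joints over 15 words plus constants can be very negative. A minimal plain-Python sketch of the same idea follows; the helper name and the example values are assumptions, not taken from the notebook.

```python
import math

def normalize_log_weights(log_w):
    """Turn unnormalized log weights into probabilities, subtracting the max first."""
    m = max(log_w)
    shifted = [math.exp(v - m) for v in log_w]  # the largest term becomes exp(0) = 1
    z = sum(shifted)
    return [s / z for s in shifted]

# Illustrative row of log joint values over three grid points
log_joint_row = [-1000.0, -1001.0, -1003.0]

stable_weights = normalize_log_weights(log_joint_row)
print(stable_weights)

# The naive version underflows: every exp() is exactly 0.0, so the division fails
naive_unnorm = [math.exp(v) for v in log_joint_row]
try:
    naive_weights = [w / sum(naive_unnorm) for w in naive_unnorm]
except ZeroDivisionError:
    naive_weights = None
print(naive_weights)
```

Subtracting the row maximum does not change the normalized weights, because the common factor cancels between numerator and denominator.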

In [96]:
train(X, y)

Beginning Training...


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch 1 of 200


  0%|          | 1/200 [00:07<25:21,  7.64s/it]

Epoch 1, NLL: 21.233, Alpha: 8.47, Beta: 0.222, W: [4.713240146636963, 3.1470117568969727, 1.3748860359191895]
Epoch 2 of 200


  1%|          | 2/200 [00:11<17:23,  5.27s/it]

Epoch 2, NLL: 15.56, Alpha: 5.744, Beta: 0.251, W: [4.536328315734863, 3.3263635635375977, 1.3724459409713745]
Epoch 3 of 200


  2%|▏         | 3/200 [00:15<15:06,  4.60s/it]

Epoch 3, NLL: 13.535, Alpha: 4.488, Beta: 0.282, W: [4.347681045532227, 3.517460823059082, 1.3699960708618164]
Epoch 4 of 200


  2%|▏         | 4/200 [00:21<16:52,  5.17s/it]

Epoch 4, NLL: 12.083, Alpha: 3.797, Beta: 0.313, W: [4.127195358276367, 3.7357943058013916, 1.372148036956787]
Epoch 5 of 200


  2%|▎         | 5/200 [00:27<18:06,  5.57s/it]

Epoch 5, NLL: 10.612, Alpha: 3.374, Beta: 0.341, W: [3.895415782928467, 3.9576807022094727, 1.382041335105896]
Epoch 6 of 200


  3%|▎         | 6/200 [00:33<18:19,  5.67s/it]

Epoch 6, NLL: 9.35, Alpha: 3.082, Beta: 0.367, W: [3.6958789825439453, 4.140596866607666, 1.3986619710922241]
Epoch 7 of 200


  4%|▎         | 7/200 [00:37<17:05,  5.31s/it]

Epoch 7, NLL: 8.571, Alpha: 2.86, Beta: 0.392, W: [3.5391464233398438, 4.276275157928467, 1.4197161197662354]
Epoch 8 of 200


  4%|▍         | 8/200 [00:42<16:08,  5.04s/it]

Epoch 8, NLL: 8.145, Alpha: 2.683, Beta: 0.417, W: [3.415137529373169, 4.376164436340332, 1.4438358545303345]
Epoch 9 of 200


  4%|▍         | 9/200 [00:47<16:01,  5.04s/it]

Epoch 9, NLL: 7.904, Alpha: 2.538, Beta: 0.441, W: [3.313215732574463, 4.451636791229248, 1.470285415649414]
Epoch 10 of 200


  5%|▌         | 10/200 [00:51<15:28,  4.89s/it]

Epoch 10, NLL: 7.758, Alpha: 2.418, Beta: 0.465, W: [3.2259585857391357, 4.510488986968994, 1.4986904859542847]
Epoch 11 of 200


  6%|▌         | 11/200 [00:56<15:01,  4.77s/it]

Epoch 11, NLL: 7.66, Alpha: 2.316, Beta: 0.488, W: [3.148481845855713, 4.557747840881348, 1.528908133506775]
Epoch 12 of 200


  6%|▌         | 12/200 [01:01<15:09,  4.84s/it]

Epoch 12, NLL: 7.591, Alpha: 2.229, Beta: 0.511, W: [3.077500104904175, 4.5966796875, 1.5609582662582397]
Epoch 13 of 200


  6%|▋         | 13/200 [01:04<13:39,  4.38s/it]

Epoch 13, NLL: 7.539, Alpha: 2.155, Beta: 0.534, W: [3.0106887817382812, 4.629456996917725, 1.5949922800064087]
Epoch 14 of 200


  7%|▋         | 14/200 [01:07<12:22,  3.99s/it]

Epoch 14, NLL: 7.497, Alpha: 2.091, Beta: 0.556, W: [2.9462907314300537, 4.657562255859375, 1.6312850713729858]
Epoch 15 of 200


  8%|▊         | 15/200 [01:12<13:05,  4.25s/it]

Epoch 15, NLL: 7.461, Alpha: 2.036, Beta: 0.578, W: [2.8828673362731934, 4.68202543258667, 1.6702452898025513]
Epoch 16 of 200


  8%|▊         | 16/200 [01:17<13:47,  4.50s/it]

Epoch 16, NLL: 7.429, Alpha: 1.989, Beta: 0.599, W: [2.8191304206848145, 4.703566074371338, 1.7124414443969727]
Epoch 17 of 200


  8%|▊         | 17/200 [01:23<14:46,  4.85s/it]

Epoch 17, NLL: 7.397, Alpha: 1.948, Beta: 0.619, W: [2.7538187503814697, 4.722676753997803, 1.7586421966552734]
Epoch 18 of 200


  9%|▉         | 18/200 [01:28<15:16,  5.03s/it]

Epoch 18, NLL: 7.362, Alpha: 1.913, Beta: 0.639, W: [2.685606002807617, 4.739666938781738, 1.8098647594451904]
Epoch 19 of 200


 10%|▉         | 19/200 [01:33<14:57,  4.96s/it]

Epoch 19, NLL: 7.321, Alpha: 1.884, Beta: 0.657, W: [2.613051652908325, 4.754677772521973, 1.867408275604248]
Epoch 20 of 200


 10%|█         | 20/200 [01:40<16:57,  5.65s/it]

Epoch 20, NLL: 7.268, Alpha: 1.86, Beta: 0.674, W: [2.5346548557281494, 4.767674922943115, 1.9328080415725708]
Epoch 21 of 200


 10%|█         | 21/200 [01:45<15:53,  5.33s/it]

Epoch 21, NLL: 7.194, Alpha: 1.841, Beta: 0.688, W: [2.4491279125213623, 4.778433322906494, 2.0075764656066895]
Epoch 22 of 200


 11%|█         | 22/200 [01:48<13:49,  4.66s/it]

Epoch 22, NLL: 7.09, Alpha: 1.828, Beta: 0.7, W: [2.356024742126465, 4.786552906036377, 2.092560052871704]
Epoch 23 of 200


 12%|█▏        | 23/200 [01:52<13:18,  4.51s/it]

Epoch 23, NLL: 6.945, Alpha: 1.82, Beta: 0.709, W: [2.256631374359131, 4.791574954986572, 2.186931610107422]
Epoch 24 of 200


 12%|█▏        | 24/200 [01:57<13:06,  4.47s/it]

Epoch 24, NLL: 6.761, Alpha: 1.818, Beta: 0.714, W: [2.1544930934906006, 4.793245315551758, 2.2873995304107666]
Epoch 25 of 200


 12%|█▎        | 25/200 [02:01<13:14,  4.54s/it]

Epoch 25, NLL: 6.552, Alpha: 1.82, Beta: 0.715, W: [2.0547046661376953, 4.791801929473877, 2.388631582260132]
Epoch 26 of 200


 13%|█▎        | 26/200 [02:09<16:00,  5.52s/it]

Epoch 26, NLL: 6.343, Alpha: 1.826, Beta: 0.714, W: [1.962065577507019, 4.788021087646484, 2.4850516319274902]
Epoch 27 of 200


 14%|█▎        | 27/200 [02:16<17:11,  5.96s/it]

Epoch 27, NLL: 6.158, Alpha: 1.835, Beta: 0.71, W: [1.8795262575149536, 4.782947540283203, 2.572664260864258]
Epoch 28 of 200


 14%|█▍        | 28/200 [02:23<17:48,  6.21s/it]

Epoch 28, NLL: 6.008, Alpha: 1.845, Beta: 0.705, W: [1.807943344116211, 4.777551174163818, 2.649643659591675]
Epoch 29 of 200


 14%|█▍        | 29/200 [02:30<18:43,  6.57s/it]

Epoch 29, NLL: 5.893, Alpha: 1.856, Beta: 0.7, W: [1.746748685836792, 4.77253532409668, 2.7158539295196533]
Epoch 30 of 200


 15%|█▌        | 30/200 [02:33<15:39,  5.53s/it]

Epoch 30, NLL: 5.809, Alpha: 1.867, Beta: 0.696, W: [1.6947191953659058, 4.768319129943848, 2.772099494934082]
Epoch 31 of 200


 16%|█▌        | 31/200 [02:37<13:43,  4.88s/it]

Epoch 31, NLL: 5.748, Alpha: 1.879, Beta: 0.691, W: [1.6504793167114258, 4.765097141265869, 2.81956148147583]
Epoch 32 of 200


 16%|█▌        | 32/200 [02:42<13:54,  4.96s/it]

Epoch 32, NLL: 5.704, Alpha: 1.889, Beta: 0.687, W: [1.612742304801941, 4.762913703918457, 2.8594815731048584]
Epoch 33 of 200


 16%|█▋        | 33/200 [02:48<15:08,  5.44s/it]

Epoch 33, NLL: 5.672, Alpha: 1.9, Beta: 0.684, W: [1.5803958177566528, 4.761725425720215, 2.8930163383483887]
Epoch 34 of 200


 17%|█▋        | 34/200 [02:56<16:31,  5.97s/it]

Epoch 34, NLL: 5.649, Alpha: 1.909, Beta: 0.681, W: [1.5525126457214355, 4.761440753936768, 2.9211843013763428]
Epoch 35 of 200


 18%|█▊        | 35/200 [03:03<17:11,  6.25s/it]

Epoch 35, NLL: 5.632, Alpha: 1.919, Beta: 0.679, W: [1.5283327102661133, 4.761948108673096, 2.944856882095337]
Epoch 36 of 200


 18%|█▊        | 36/200 [03:13<20:16,  7.42s/it]

Epoch 36, NLL: 5.62, Alpha: 1.927, Beta: 0.677, W: [1.5072367191314697, 4.763131618499756, 2.9647693634033203]
Epoch 37 of 200


 18%|█▊        | 37/200 [03:17<17:25,  6.41s/it]

Epoch 37, NLL: 5.611, Alpha: 1.935, Beta: 0.676, W: [1.4887202978134155, 4.764880657196045, 2.981537103652954]
Epoch 38 of 200


 19%|█▉        | 38/200 [03:23<16:53,  6.26s/it]

Epoch 38, NLL: 5.604, Alpha: 1.943, Beta: 0.675, W: [1.4723718166351318, 4.767093658447266, 2.9956727027893066]
Epoch 39 of 200


 20%|█▉        | 39/200 [03:29<16:54,  6.30s/it]

Epoch 39, NLL: 5.599, Alpha: 1.95, Beta: 0.675, W: [1.4578540325164795, 4.769680500030518, 3.007603645324707]
Epoch 40 of 200


 20%|██        | 40/200 [03:36<17:14,  6.47s/it]

Epoch 40, NLL: 5.596, Alpha: 1.957, Beta: 0.675, W: [1.4448893070220947, 4.772563457489014, 3.017685651779175]
Epoch 41 of 200


 20%|██        | 41/200 [03:44<18:25,  6.95s/it]

Epoch 41, NLL: 5.593, Alpha: 1.963, Beta: 0.675, W: [1.4332479238510132, 4.7756757736206055, 3.026214838027954]
Epoch 42 of 200


 21%|██        | 42/200 [03:48<16:09,  6.13s/it]

Epoch 42, NLL: 5.591, Alpha: 1.969, Beta: 0.675, W: [1.4227391481399536, 4.778961181640625, 3.0334384441375732]
Epoch 43 of 200


 22%|██▏       | 43/200 [03:51<13:18,  5.09s/it]

Epoch 43, NLL: 5.59, Alpha: 1.974, Beta: 0.676, W: [1.413203477859497, 4.78237247467041, 3.039562702178955]
Epoch 44 of 200


 22%|██▏       | 44/200 [03:55<12:41,  4.88s/it]

Epoch 44, NLL: 5.589, Alpha: 1.979, Beta: 0.677, W: [1.4045072793960571, 4.7858710289001465, 3.044760227203369]
Epoch 45 of 200


 22%|██▎       | 45/200 [04:01<13:01,  5.04s/it]

Epoch 45, NLL: 5.588, Alpha: 1.984, Beta: 0.678, W: [1.3965378999710083, 4.789424896240234, 3.04917573928833]
Epoch 46 of 200


 23%|██▎       | 46/200 [04:09<15:16,  5.95s/it]

Epoch 46, NLL: 5.587, Alpha: 1.989, Beta: 0.68, W: [1.3892003297805786, 4.793007850646973, 3.0529301166534424]
Epoch 47 of 200


 24%|██▎       | 47/200 [04:18<17:59,  7.05s/it]

Epoch 47, NLL: 5.587, Alpha: 1.993, Beta: 0.681, W: [1.3824137449264526, 4.796599388122559, 3.0561254024505615]
Epoch 48 of 200


 24%|██▍       | 48/200 [04:28<20:04,  7.93s/it]

Epoch 48, NLL: 5.587, Alpha: 1.997, Beta: 0.683, W: [1.37610924243927, 4.800182342529297, 3.058846950531006]
Epoch 49 of 200


 24%|██▍       | 49/200 [04:31<16:13,  6.45s/it]

Epoch 49, NLL: 5.587, Alpha: 2.001, Beta: 0.684, W: [1.3702282905578613, 4.803743362426758, 3.061167001724243]
Epoch 50 of 200


 25%|██▌       | 50/200 [04:35<14:16,  5.71s/it]

Epoch 50, NLL: 5.587, Alpha: 2.005, Beta: 0.686, W: [1.3647204637527466, 4.807271957397461, 3.0631463527679443]
Epoch 51 of 200


 26%|██▌       | 51/200 [04:40<13:01,  5.24s/it]

Epoch 51, NLL: 5.587, Alpha: 2.009, Beta: 0.688, W: [1.3595424890518188, 4.810760021209717, 3.064836025238037]
Epoch 52 of 200


 26%|██▌       | 52/200 [04:44<12:24,  5.03s/it]

Epoch 52, NLL: 5.587, Alpha: 2.012, Beta: 0.69, W: [1.3546571731567383, 4.814201831817627, 3.066279649734497]
Epoch 53 of 200


 26%|██▋       | 53/200 [04:51<13:40,  5.58s/it]

Epoch 53, NLL: 5.587, Alpha: 2.016, Beta: 0.692, W: [1.3500323295593262, 4.817592620849609, 3.0675137042999268]
Epoch 54 of 200


 27%|██▋       | 54/200 [04:56<13:01,  5.35s/it]

Epoch 54, NLL: 5.588, Alpha: 2.019, Beta: 0.694, W: [1.3456400632858276, 4.820929527282715, 3.0685691833496094]
Epoch 55 of 200


 28%|██▊       | 55/200 [05:01<13:09,  5.45s/it]

Epoch 55, NLL: 5.588, Alpha: 2.022, Beta: 0.696, W: [1.3414561748504639, 4.824209690093994, 3.069472551345825]
Epoch 56 of 200


 28%|██▊       | 56/200 [05:05<11:30,  4.80s/it]

Epoch 56, NLL: 5.588, Alpha: 2.025, Beta: 0.698, W: [1.3374598026275635, 4.827432632446289, 3.0702462196350098]
Epoch 57 of 200


 28%|██▊       | 57/200 [05:08<10:11,  4.28s/it]

Epoch 57, NLL: 5.589, Alpha: 2.028, Beta: 0.7, W: [1.333632469177246, 4.830596923828125, 3.070909023284912]
Epoch 58 of 200


 29%|██▉       | 58/200 [05:13<10:41,  4.52s/it]

Epoch 58, NLL: 5.589, Alpha: 2.031, Beta: 0.703, W: [1.329958200454712, 4.83370304107666, 3.071477174758911]
Epoch 59 of 200


 30%|██▉       | 59/200 [05:19<11:55,  5.07s/it]

Epoch 59, NLL: 5.589, Alpha: 2.034, Beta: 0.705, W: [1.326422929763794, 4.8367509841918945, 3.0719645023345947]
Epoch 60 of 200


 30%|███       | 60/200 [05:25<12:01,  5.16s/it]

Epoch 60, NLL: 5.59, Alpha: 2.037, Beta: 0.707, W: [1.3230143785476685, 4.8397417068481445, 3.0723824501037598]
Epoch 61 of 200


 30%|███       | 61/200 [05:30<11:55,  5.15s/it]

Epoch 61, NLL: 5.59, Alpha: 2.04, Beta: 0.709, W: [1.3197216987609863, 4.842675685882568, 3.0727410316467285]
Epoch 62 of 200


 31%|███       | 62/200 [05:36<12:23,  5.39s/it]

Epoch 62, NLL: 5.59, Alpha: 2.042, Beta: 0.711, W: [1.316535472869873, 4.845554351806641, 3.0730488300323486]
Epoch 63 of 200


 32%|███▏      | 63/200 [05:41<11:59,  5.25s/it]

Epoch 63, NLL: 5.591, Alpha: 2.045, Beta: 0.714, W: [1.313447117805481, 4.848378658294678, 3.073312997817993]
Epoch 64 of 200


 32%|███▏      | 64/200 [05:46<12:02,  5.31s/it]

Epoch 64, NLL: 5.591, Alpha: 2.048, Beta: 0.716, W: [1.310449242591858, 4.851149559020996, 3.073539972305298]
Epoch 65 of 200


 32%|███▎      | 65/200 [05:49<10:30,  4.67s/it]

Epoch 65, NLL: 5.592, Alpha: 2.05, Beta: 0.718, W: [1.307535171508789, 4.8538689613342285, 3.073734760284424]
Epoch 66 of 200


 33%|███▎      | 66/200 [05:52<09:22,  4.19s/it]

Epoch 66, NLL: 5.592, Alpha: 2.053, Beta: 0.72, W: [1.3046989440917969, 4.856537818908691, 3.073902130126953]
Epoch 67 of 200


 34%|███▎      | 67/200 [05:57<09:45,  4.40s/it]

Epoch 67, NLL: 5.592, Alpha: 2.056, Beta: 0.723, W: [1.3019354343414307, 4.859157562255859, 3.0740458965301514]
Epoch 68 of 200


 34%|███▍      | 68/200 [06:03<10:56,  4.97s/it]

Epoch 68, NLL: 5.593, Alpha: 2.058, Beta: 0.725, W: [1.2992398738861084, 4.861729621887207, 3.074169158935547]
Epoch 69 of 200


 34%|███▍      | 69/200 [06:11<12:40,  5.81s/it]

Epoch 69, NLL: 5.593, Alpha: 2.061, Beta: 0.727, W: [1.2966080904006958, 4.864255428314209, 3.074275016784668]
Epoch 70 of 200


 35%|███▌      | 70/200 [06:19<13:47,  6.37s/it]

Epoch 70, NLL: 5.593, Alpha: 2.063, Beta: 0.729, W: [1.2940363883972168, 4.86673641204834, 3.0743658542633057]
Epoch 71 of 200


 36%|███▌      | 71/200 [06:25<13:20,  6.21s/it]

Epoch 71, NLL: 5.594, Alpha: 2.066, Beta: 0.731, W: [1.2915213108062744, 4.869173526763916, 3.0744435787200928]
Epoch 72 of 200


 36%|███▌      | 72/200 [06:31<13:37,  6.38s/it]

Epoch 72, NLL: 5.594, Alpha: 2.068, Beta: 0.733, W: [1.2890597581863403, 4.871568202972412, 3.074510335922241]
Epoch 73 of 200


 36%|███▋      | 73/200 [06:38<13:23,  6.33s/it]

Epoch 73, NLL: 5.594, Alpha: 2.071, Beta: 0.736, W: [1.2866489887237549, 4.873921871185303, 3.0745673179626465]
Epoch 74 of 200


 37%|███▋      | 74/200 [06:41<11:28,  5.46s/it]

Epoch 74, NLL: 5.595, Alpha: 2.074, Beta: 0.738, W: [1.2842864990234375, 4.876235485076904, 3.0746161937713623]
Epoch 75 of 200


 38%|███▊      | 75/200 [06:45<10:05,  4.85s/it]

Epoch 75, NLL: 5.595, Alpha: 2.076, Beta: 0.74, W: [1.2819700241088867, 4.878510475158691, 3.074657917022705]
Epoch 76 of 200


 38%|███▊      | 76/200 [06:52<11:34,  5.60s/it]

Epoch 76, NLL: 5.595, Alpha: 2.079, Beta: 0.742, W: [1.2796975374221802, 4.880747318267822, 3.074693441390991]
Epoch 77 of 200


 38%|███▊      | 77/200 [07:00<13:02,  6.36s/it]

Epoch 77, NLL: 5.596, Alpha: 2.081, Beta: 0.744, W: [1.277467131614685, 4.8829474449157715, 3.074723720550537]
Epoch 78 of 200


 39%|███▉      | 78/200 [07:08<13:59,  6.88s/it]

Epoch 78, NLL: 5.596, Alpha: 2.084, Beta: 0.746, W: [1.275277018547058, 4.8851118087768555, 3.074749231338501]
Epoch 79 of 200


 40%|███▉      | 79/200 [07:17<14:47,  7.34s/it]

Epoch 79, NLL: 5.596, Alpha: 2.086, Beta: 0.748, W: [1.2731255292892456, 4.887241840362549, 3.074770927429199]
Epoch 80 of 200


 40%|████      | 80/200 [07:24<14:26,  7.22s/it]

Epoch 80, NLL: 5.597, Alpha: 2.089, Beta: 0.75, W: [1.2710113525390625, 4.88933801651001, 3.074789047241211]
Epoch 81 of 200


 40%|████      | 81/200 [07:27<12:23,  6.25s/it]

Epoch 81, NLL: 5.597, Alpha: 2.091, Beta: 0.752, W: [1.2689330577850342, 4.8914008140563965, 3.0748043060302734]
Epoch 82 of 200


 41%|████      | 82/200 [07:39<15:40,  7.97s/it]

Epoch 82, NLL: 5.597, Alpha: 2.094, Beta: 0.754, W: [1.2668893337249756, 4.893431663513184, 3.074817180633545]
Epoch 83 of 200


 42%|████▏     | 83/200 [07:49<16:19,  8.37s/it]

Epoch 83, NLL: 5.597, Alpha: 2.096, Beta: 0.756, W: [1.2648791074752808, 4.8954315185546875, 3.0748276710510254]
Epoch 84 of 200


 42%|████▏     | 84/200 [07:56<15:38,  8.09s/it]

Epoch 84, NLL: 5.598, Alpha: 2.099, Beta: 0.758, W: [1.2629013061523438, 4.897400856018066, 3.074836254119873]
Epoch 85 of 200


 42%|████▎     | 85/200 [07:59<12:35,  6.57s/it]

Epoch 85, NLL: 5.598, Alpha: 2.101, Beta: 0.76, W: [1.2609548568725586, 4.8993401527404785, 3.074843406677246]
Epoch 86 of 200


 43%|████▎     | 86/200 [08:02<10:27,  5.50s/it]

Epoch 86, NLL: 5.598, Alpha: 2.104, Beta: 0.762, W: [1.2590389251708984, 4.90125036239624, 3.0748491287231445]
Epoch 87 of 200


 44%|████▎     | 87/200 [08:07<09:49,  5.22s/it]

Epoch 87, NLL: 5.598, Alpha: 2.106, Beta: 0.764, W: [1.2571525573730469, 4.903132438659668, 3.0748536586761475]
Epoch 88 of 200


 44%|████▍     | 88/200 [08:14<10:35,  5.68s/it]

Epoch 88, NLL: 5.599, Alpha: 2.109, Beta: 0.766, W: [1.255294919013977, 4.90498685836792, 3.074856996536255]
Epoch 89 of 200


 44%|████▍     | 89/200 [08:22<11:52,  6.42s/it]

Epoch 89, NLL: 5.599, Alpha: 2.111, Beta: 0.768, W: [1.2534652948379517, 4.906814098358154, 3.074859380722046]
Epoch 90 of 200


 45%|████▌     | 90/200 [08:29<12:14,  6.68s/it]

Epoch 90, NLL: 5.599, Alpha: 2.114, Beta: 0.77, W: [1.2516628503799438, 4.908614635467529, 3.0748610496520996]
Epoch 91 of 200


 46%|████▌     | 91/200 [08:35<11:48,  6.50s/it]

Epoch 91, NLL: 5.599, Alpha: 2.117, Beta: 0.771, W: [1.2498869895935059, 4.910389423370361, 3.074862003326416]
Epoch 92 of 200


 46%|████▌     | 92/200 [08:40<10:50,  6.03s/it]

Epoch 92, NLL: 5.6, Alpha: 2.119, Beta: 0.773, W: [1.2481369972229004, 4.912138938903809, 3.074862480163574]
Epoch 93 of 200


 46%|████▋     | 93/200 [08:43<09:12,  5.16s/it]

Epoch 93, NLL: 5.6, Alpha: 2.122, Beta: 0.775, W: [1.2464122772216797, 4.913863658905029, 3.074862480163574]
Epoch 94 of 200


 47%|████▋     | 94/200 [08:47<08:33,  4.84s/it]

Epoch 94, NLL: 5.6, Alpha: 2.124, Beta: 0.777, W: [1.244712233543396, 4.915564060211182, 3.074862003326416]
Epoch 95 of 200


 48%|████▊     | 95/200 [08:54<09:25,  5.38s/it]

Epoch 95, NLL: 5.6, Alpha: 2.127, Beta: 0.779, W: [1.2430363893508911, 4.917240619659424, 3.0748610496520996]
Epoch 96 of 200


 48%|████▊     | 96/200 [09:00<09:40,  5.58s/it]

Epoch 96, NLL: 5.6, Alpha: 2.129, Beta: 0.781, W: [1.2413840293884277, 4.918894290924072, 3.074859857559204]
Epoch 97 of 200


 48%|████▊     | 97/200 [09:06<09:37,  5.61s/it]

Epoch 97, NLL: 5.601, Alpha: 2.132, Beta: 0.782, W: [1.2397547960281372, 4.920525074005127, 3.0748584270477295]
Epoch 98 of 200


 49%|████▉     | 98/200 [09:11<09:39,  5.68s/it]

Epoch 98, NLL: 5.601, Alpha: 2.135, Beta: 0.784, W: [1.2381480932235718, 4.922133445739746, 3.074856758117676]
Epoch 99 of 200


 50%|████▉     | 99/200 [09:16<09:05,  5.40s/it]

Epoch 99, NLL: 5.601, Alpha: 2.137, Beta: 0.786, W: [1.2365635633468628, 4.923719882965088, 3.074854850769043]
Epoch 100 of 200


 50%|█████     | 100/200 [09:19<07:49,  4.69s/it]

Epoch 100, NLL: 5.601, Alpha: 2.14, Beta: 0.788, W: [1.235000729560852, 4.9252848625183105, 3.074852705001831]
Epoch 101 of 200


 50%|█████     | 101/200 [09:22<06:54,  4.19s/it]

Epoch 101, NLL: 5.601, Alpha: 2.142, Beta: 0.789, W: [1.2334591150283813, 4.926828861236572, 3.074850559234619]
Epoch 102 of 200


 51%|█████     | 102/200 [09:27<07:03,  4.32s/it]

Epoch 102, NLL: 5.602, Alpha: 2.145, Beta: 0.791, W: [1.2319382429122925, 4.928351879119873, 3.0748484134674072]
Epoch 103 of 200


 52%|█████▏    | 103/200 [09:31<06:56,  4.30s/it]

Epoch 103, NLL: 5.602, Alpha: 2.148, Beta: 0.793, W: [1.2304377555847168, 4.929854869842529, 3.074846029281616]
Epoch 104 of 200


 52%|█████▏    | 104/200 [09:36<07:11,  4.50s/it]

Epoch 104, NLL: 5.602, Alpha: 2.15, Beta: 0.794, W: [1.2289572954177856, 4.931337833404541, 3.074843645095825]
Epoch 105 of 200


 52%|█████▎    | 105/200 [09:42<08:01,  5.07s/it]

Epoch 105, NLL: 5.602, Alpha: 2.153, Beta: 0.796, W: [1.2274965047836304, 4.932801246643066, 3.074841022491455]
Epoch 106 of 200


 53%|█████▎    | 106/200 [09:51<09:35,  6.12s/it]

Epoch 106, NLL: 5.602, Alpha: 2.155, Beta: 0.798, W: [1.2260550260543823, 4.9342451095581055, 3.074838399887085]
Epoch 107 of 200


 54%|█████▎    | 107/200 [09:55<08:23,  5.42s/it]

Epoch 107, NLL: 5.602, Alpha: 2.158, Beta: 0.799, W: [1.2246325016021729, 4.935670375823975, 3.074835777282715]
Epoch 108 of 200


 54%|█████▍    | 108/200 [09:59<07:47,  5.08s/it]

Epoch 108, NLL: 5.603, Alpha: 2.161, Beta: 0.801, W: [1.2232285737991333, 4.937077045440674, 3.0748331546783447]
Epoch 109 of 200


 55%|█████▍    | 109/200 [10:04<07:44,  5.11s/it]

Epoch 109, NLL: 5.603, Alpha: 2.163, Beta: 0.802, W: [1.221842885017395, 4.938465118408203, 3.0748305320739746]
Epoch 110 of 200


 55%|█████▌    | 110/200 [10:09<07:34,  5.04s/it]

Epoch 110, NLL: 5.603, Alpha: 2.166, Beta: 0.804, W: [1.220475196838379, 4.939835548400879, 3.0748279094696045]
Epoch 111 of 200


 56%|█████▌    | 111/200 [10:14<07:31,  5.07s/it]

Epoch 111, NLL: 5.603, Alpha: 2.169, Beta: 0.806, W: [1.2191251516342163, 4.941188335418701, 3.0748252868652344]
Epoch 112 of 200


 56%|█████▌    | 112/200 [10:20<07:40,  5.23s/it]

Epoch 112, NLL: 5.603, Alpha: 2.171, Beta: 0.807, W: [1.2177925109863281, 4.94252347946167, 3.0748226642608643]
Epoch 113 of 200


 56%|█████▋    | 113/200 [10:25<07:25,  5.12s/it]

Epoch 113, NLL: 5.603, Alpha: 2.174, Beta: 0.809, W: [1.2164769172668457, 4.943841934204102, 3.074820041656494]
Epoch 114 of 200


 57%|█████▋    | 114/200 [10:28<06:26,  4.49s/it]

Epoch 114, NLL: 5.603, Alpha: 2.177, Beta: 0.81, W: [1.21517813205719, 4.945143222808838, 3.074817419052124]
Epoch 115 of 200


 57%|█████▊    | 115/200 [10:32<06:08,  4.34s/it]

Epoch 115, NLL: 5.603, Alpha: 2.179, Beta: 0.812, W: [1.2138957977294922, 4.946428298950195, 3.074814796447754]
Epoch 116 of 200


 58%|█████▊    | 116/200 [10:38<06:42,  4.79s/it]

Epoch 116, NLL: 5.604, Alpha: 2.182, Beta: 0.813, W: [1.212629795074463, 4.947697162628174, 3.074812173843384]
Epoch 117 of 200


 58%|█████▊    | 117/200 [10:46<08:12,  5.93s/it]

Epoch 117, NLL: 5.604, Alpha: 2.185, Beta: 0.815, W: [1.2113797664642334, 4.948949813842773, 3.0748095512390137]
Epoch 118 of 200


 59%|█████▉    | 118/200 [10:53<08:19,  6.09s/it]

Epoch 118, NLL: 5.604, Alpha: 2.187, Beta: 0.816, W: [1.2101454734802246, 4.950186729431152, 3.0748069286346436]
Epoch 119 of 200


 60%|█████▉    | 119/200 [11:00<08:41,  6.43s/it]

Epoch 119, NLL: 5.604, Alpha: 2.19, Beta: 0.818, W: [1.2089266777038574, 4.951408386230469, 3.0748043060302734]
Epoch 120 of 200


 60%|██████    | 120/200 [11:07<08:52,  6.65s/it]

Epoch 120, NLL: 5.604, Alpha: 2.193, Beta: 0.819, W: [1.2077230215072632, 4.952614784240723, 3.0748016834259033]
Epoch 121 of 200


 60%|██████    | 121/200 [11:11<07:29,  5.69s/it]

Epoch 121, NLL: 5.604, Alpha: 2.195, Beta: 0.821, W: [1.2065343856811523, 4.953805923461914, 3.074799060821533]
Epoch 122 of 200


 61%|██████    | 122/200 [11:16<07:11,  5.53s/it]

Epoch 122, NLL: 5.604, Alpha: 2.198, Beta: 0.822, W: [1.2053605318069458, 4.954982280731201, 3.074796438217163]
Epoch 123 of 200


 62%|██████▏   | 123/200 [11:22<07:22,  5.74s/it]

Epoch 123, NLL: 5.604, Alpha: 2.201, Beta: 0.824, W: [1.204201340675354, 4.956143856048584, 3.074793815612793]
Epoch 124 of 200


 62%|██████▏   | 124/200 [11:31<08:33,  6.76s/it]

Epoch 124, NLL: 5.604, Alpha: 2.204, Beta: 0.825, W: [1.2030564546585083, 4.957291126251221, 3.074791193008423]
Epoch 125 of 200


 62%|██████▎   | 125/200 [11:38<08:33,  6.84s/it]

Epoch 125, NLL: 5.605, Alpha: 2.206, Beta: 0.827, W: [1.2019256353378296, 4.9584245681762695, 3.074788808822632]
Epoch 126 of 200


 63%|██████▎   | 126/200 [11:41<07:04,  5.74s/it]

Epoch 126, NLL: 5.605, Alpha: 2.209, Beta: 0.828, W: [1.2008087635040283, 4.959543704986572, 3.074786424636841]
Epoch 127 of 200


 64%|██████▎   | 127/200 [11:44<05:59,  4.93s/it]

Epoch 127, NLL: 5.605, Alpha: 2.212, Beta: 0.829, W: [1.199705719947815, 4.960649490356445, 3.07478404045105]
Epoch 128 of 200


 64%|██████▍   | 128/200 [11:49<05:58,  4.98s/it]

Epoch 128, NLL: 5.605, Alpha: 2.215, Beta: 0.831, W: [1.1986161470413208, 4.9617414474487305, 3.074781656265259]
Epoch 129 of 200


 64%|██████▍   | 129/200 [11:55<06:08,  5.20s/it]

Epoch 129, NLL: 5.605, Alpha: 2.217, Beta: 0.832, W: [1.1975399255752563, 4.962820053100586, 3.0747792720794678]
Epoch 130 of 200


 65%|██████▌   | 130/200 [12:04<07:31,  6.44s/it]

Epoch 130, NLL: 5.605, Alpha: 2.22, Beta: 0.834, W: [1.1964768171310425, 4.963885307312012, 3.0747768878936768]
Epoch 131 of 200
