In this notebook I walk through the generation/inference pipeline, spending most of the time on inference.

In [1]:
from abc import ABC, abstractmethod

import os
from tqdm import tqdm
import math 
import time

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns 

from scipy import stats

import torch
from torch.distributions import Beta
from torch.distributions.bernoulli import Bernoulli
from torch.nn.functional import log_softmax
from torch.optim import Adam, SGD
from torch.utils.data import DataLoader, TensorDataset

In [2]:
SEED = 1
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f05af2cf350>

### Generation

Here I am using 2/10 generation.

In [3]:
### these params control the generation scheme
rho        = 0.8   # polarization
pop_size   = 50000 # num individuals
epsilon    = 0.05  # expected prop of speech consisting of neutral words
pi         = 0.5   # pi == 0.5 => beta mixture symmetrical (choose beta1 with prob pi = 0.5)
speech_len = 15    # words per speech

In [7]:
def generate(rho=rho, N=pop_size, epsilon=epsilon, pi=pi, speech_len=speech_len):
    """
    Uses 2/10 generation scheme to generate N samples.

    Returns:
        (X, y), (a, b, rho, epsilon, u)
        X.size() == [N, 3] is a matrix of word counts (one row per individual)
        y.size() == [N] is a vector of political parties
        u.shape  == (N,) is a vector of individual stances
        rho is true polarization
        epsilon is expected prop of neutral words
        a, b are true alpha/beta for beta mixture model
    """
    start = time.time()
    print(f'Beginning Data Generation...')
    print(f'=' * 20)
    
    ### get beta mixture model params 
    sigma = 0.175 * (rho ** 2) - 0.3625 * rho + 0.1875
    a     = rho * ((rho * (1 - rho)) / sigma - 1)
    b     = (1 - rho) * ((rho * (1 - rho)) / sigma - 1)

    print(f"True Alpha: {a}")
    print(f"True Beta: {b}")
    print(f"True Epsilon: {epsilon}\n")

    mean = a / (a + b)
    var  = a * b / ((a + b)**2 * (a + b + 1))

    if abs(mean - rho) > 10e-15:
        print(f'Mean: {mean}')
        print(f"Rho: {rho}")
        raise AssertionError(f"Mean of BMM params should be rho")

    if abs(var - sigma) > 10e-15:
        print(f'Var: {var}')
        print(f'Sigma: {sigma}')
        raise AssertionError(f"Var of BMM params should be sigma")

    ### u = 2t - 1, t ~ pi Beta(a, b) + (1 - pi) Beta(b, a)
    ### component label 1 (prob pi) -> Beta(a, b); label 0 (prob 1 - pi) -> Beta(b, a)
    weights = [1 - pi, pi]
    mixture_samples = np.random.choice([0, 1], size=N, p=weights)

    u = 2 * np.where(mixture_samples == 1, stats.beta.rvs(a, b, size=N), stats.beta.rvs(b, a, size=N)) - 1

    if u.shape != (N,):
        raise AssertionError(f"u.shape should be (N,)")

    print(f'mixture samples: {mixture_samples[:5]}')
    print(f'u samples: {u[:5]}\n')

    ### y = 1(u >= 0)
    y = (u >= 0).astype(int)
    print(f'y samples: {y[:5]}\n')

    ### phi is a prob matrix that is a function of u, epsilon
    phi = np.array([(1 - (u+1)/2) * (1 - epsilon), (u+1)/2 * (1 - epsilon), np.repeat(epsilon, N)]).T
    print(f'phi samples:\n {phi[:5, :]}\n')

    if phi.shape != (N, 3):
        raise AssertionError(f'phi.shape should be (N, V) == (N, 3)')

    if not np.allclose(phi.sum(axis=1), 1):
        raise AssertionError(f'rows of phi should sum to 1')
    
    X = np.array([stats.multinomial.rvs(n=speech_len, p=phi[i, :]) for i in range(N)])
    print(f'X samples:\n {X[:5, :]}\n')
    
    if X.shape != (N, 3):
        raise AssertionError(f'X.shape should be (N, V) == (N, 3)')

    if not (X.sum(axis=1) == speech_len).all():
        raise AssertionError(f'rows of X should sum to speech_len')

    X = torch.from_numpy(X).to(torch.float32)
    y = torch.from_numpy(y).to(torch.float32)

    known   = (X, y)
    unknown = (a, b, rho, epsilon, u)
    print('=' * 20)
    print(f'Generation Time: {round(time.time() - start, 3)} seconds for {N} samples.')
    return known, unknown

known, unknown = generate()
X, y = known
a, b, rho, epsilon, u = unknown

Beginning Data Generation...
True Alpha: 12.673684210526261
True Beta: 3.1684210526315644
True Epsilon: 0.05

mixture samples: [1 0 1 1 1]
u samples: [ 0.58836654 -0.44066272  0.36309715  0.47382017  0.55063803]

y samples: [1 0 1 1 1]

phi samples:
 [[0.19552589 0.75447411 0.05      ]
 [0.68431479 0.26568521 0.05      ]
 [0.30252886 0.64747114 0.05      ]
 [0.24993542 0.70006458 0.05      ]
 [0.21344694 0.73655306 0.05      ]]

X samples:
 [[ 5  9  1]
 [12  3  0]
 [ 7  6  2]
 [ 2 12  1]
 [ 4 11  0]]

Generation Time: 1.637 seconds for 50000 samples.
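The Beta parameters in `generate` come from matching moments. Writing $c = a + b$ and setting $a = \rho c$, $b = (1-\rho)c$ fixes the mean $a/(a+b) = \rho$; the Beta variance $ab/((a+b)^2(a+b+1))$ then simplifies to $\rho(1-\rho)/(c+1)$, so matching it to $\sigma$ gives $c = \rho(1-\rho)/\sigma - 1$. A minimal sketch reproducing the printed values, assuming the same quadratic schedule for $\sigma$ as in `generate`:

```python
rho = 0.8
sigma = 0.175 * (rho ** 2) - 0.3625 * rho + 0.1875  # variance schedule copied from generate()
concentration = rho * (1 - rho) / sigma - 1          # c = a + b, from Var = rho(1 - rho) / (c + 1)
a = rho * concentration                              # mean a / (a + b) = rho by construction
b = (1 - rho) * concentration

# Recompute the Beta(a, b) moments to confirm they match rho and sigma
mean = a / (a + b)
var = a * b / ((a + b) ** 2 * (a + b + 1))
print(f"a = {a:.4f}, b = {b:.4f}, mean = {mean:.4f}, var = {var:.6f}")
```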


### Inference

The goal of inference is to recover $u$ given a sample of $(x,y)$ pairs. We assume the following data generating process:
\begin{align*}
    u &= 2t - 1, \quad t \sim \pi\cdot\mathrm{Beta}\left(\alpha,\beta\right)+(1-\pi)\cdot\mathrm{Beta}\left(\beta,\alpha\right) \\
    y &= 1(u \ge \lambda) \\
    x &\sim \mathrm{Multinomial}\left(S, \mathrm{Softmax}(Wu)\right), \quad W \in \mathbb{R}^V.
\end{align*}

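For now, we assume $\lambda = 0$, $V = 3$, and $\pi = 0.5$, so our parameters are $\theta = \{\alpha, \beta, W\}$; here $S$ is the number of words per speech (`speech_len`) and $V$ is the vocabulary size. Under this assumed form, the distribution of $u$ is fully determined by $\alpha, \beta$, so recovering the distribution of $u$ is equivalent to recovering $\alpha, \beta$. Writing $x^{(n)} = (x^{(n)}_1, x^{(n)}_2, x^{(n)}_3)$ for the word counts of speech $n$, this gives the following joint: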

\begin{align*}
p(u,x^{(n)},y^{(n)};\theta) &= p(u)\,p(y^{(n)}\mid u)\,p(x^{(n)}\mid u) \\
&= \left(\frac{1}{4}\mathrm{Beta}\left(\frac{u+1}{2};\alpha,\beta\right)+\frac{1}{4}\mathrm{Beta}\left(\frac{u+1}{2};\beta,\alpha\right)\right) \\
&\quad\cdot 1\left(1(u\ge0)=y^{(n)}\right)\cdot\frac{S!}{x^{(n)}_1!\,x^{(n)}_2!\,x^{(n)}_3!}\prod_{v=1}^{V}\mathrm{softmax}(Wu)_{v}^{x^{(n)}_v}
\end{align*}

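Here, we get the density of $u$ from a change of variables: $u = 2t - 1$ is a monotonic transform of $t \sim \pi\cdot\mathrm{Beta}(\alpha,\beta)+(1-\pi)\cdot\mathrm{Beta}(\beta,\alpha)$, whose density is known. Applying the formula found [here](https://en.wikipedia.org/wiki/Probability_density_function#Function_of_random_variables_and_change_of_variables_in_the_probability_density_function) with $|\mathrm{d}t/\mathrm{d}u| = \frac{1}{2}$ gives $p(u) = \frac{1}{2}\left(\pi\,\mathrm{Beta}\left(\frac{u+1}{2};\alpha,\beta\right)+(1-\pi)\,\mathrm{Beta}\left(\frac{u+1}{2};\beta,\alpha\right)\right)$, which with $\pi = 0.5$ is the prior factor in the expression above.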
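As a sanity check on the two factors of the joint, here is a minimal sketch (assuming SciPy is available, and using the true $\alpha, \beta$ printed during generation with $\pi = 0.5$): the prior density of $u$ should integrate to 1 over $[-1,1]$ and be symmetric, and the count likelihood $\log S! - \sum_v \log x_v! + \sum_v x_v \log \mathrm{softmax}(Wu)_v$ should match SciPy's multinomial log-pmf. The values of `W`, `u`, and `x` are illustrative, not fitted.

```python
import torch
from scipy import stats
from scipy.integrate import quad
from torch.nn.functional import log_softmax

# True parameters printed by generate() (rho = 0.8); pi = 0.5 as assumed in the text
alpha, beta, pi = 12.673684210526261, 3.1684210526315644, 0.5

def u_density(u):
    """Prior p(u), where (u + 1) / 2 ~ pi*Beta(alpha, beta) + (1 - pi)*Beta(beta, alpha)."""
    t = (u + 1) / 2
    mix = pi * stats.beta.pdf(t, alpha, beta) + (1 - pi) * stats.beta.pdf(t, beta, alpha)
    return 0.5 * mix  # Jacobian |dt/du| = 1/2 from the change of variables

prior_mass, _ = quad(u_density, -1, 1)  # should be close to 1

# Count likelihood: log S! - sum_v log x_v! + sum_v x_v log softmax(Wu)_v
W = torch.tensor([2.17, 4.71, 2.11], dtype=torch.float64)  # illustrative, not fitted
u = 0.4
x = torch.tensor([5.0, 9.0, 1.0], dtype=torch.float64)      # word counts, S = 15
log_p = log_softmax(W * u, dim=0)
S = x.sum()
log_lik = (x * log_p).sum() + torch.lgamma(S + 1) - torch.lgamma(x + 1).sum()

# Reference value from SciPy's multinomial log-pmf
reference = stats.multinomial.logpmf(x.numpy().astype(int), n=int(S.item()), p=log_p.exp().numpy())
print(f"prior mass = {prior_mass:.8f}, log-likelihood = {log_lik.item():.6f} (SciPy: {reference:.6f})")
```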

The basic idea is to maximize the log marginal likelihood with respect to $\theta = \{\alpha, \beta, W\}$, i.e., to find 
$$
\theta^* = \arg\max_{\theta} \underbrace{\log p(x^{(1:N)}, y^{(1:N)}; \theta)}_{\mathcal{L}(\theta)}.
$$


Concretely, the goal is to recover $\alpha, \beta$ from the $(x,y)$ pairs, and hence the distribution of $u$, given that we assume $\frac{u+1}{2} \sim \frac{1}{2}\mathrm{Beta}(\alpha, \beta) + \frac{1}{2}\mathrm{Beta}(\beta, \alpha)$.

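To get maximum-likelihood estimates of $\alpha, \beta, W$, we take gradient descent steps on the negative log marginal likelihood (in the code, $\alpha$ and $\beta$ are parameterized through their logarithms so they stay positive). The gradient of the log marginal likelihood is given by Fisher's identity: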
$$
\nabla_\theta \log p(x^{(1:N)},y^{(1:N)}) = \sum_{n=1}^{N} \mathbb{E}_{p(u|x^{(n)},y^{(n)};\theta)}[\nabla_\theta \log p(u,x^{(n)},y^{(n)};\theta)].
$$
The first issue is that we cannot compute the posterior analytically: its normalizer $p(x^{(n)}, y^{(n)};\theta) = \int p(u, x^{(n)}, y^{(n)};\theta)\,\mathrm{d}u$ has no closed form. However, since $u$ is one-dimensional and lies in $[-1,1]$, we can discretize the integral, which is simpler than sampling from the posterior with MCMC or approximating it with variational inference. To do this, let's write the gradient as follows:
$$
\nabla_\theta \log p(x^{(1:N)},y^{(1:N)}) = \sum_{n=1}^{N} \int_{u} \frac{p(x^{(n)}, y^{(n)}, u;\ \theta)}{p(x^{(n)},y^{(n)};\  \theta)}\nabla_\theta \log p(u,x^{(n)},y^{(n)};\theta) \mathrm{d}u
$$
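To discretize, we assume that $u$ takes on only $G$ linearly spaced values $u_1, \dots, u_G$ in $[-1,1]$ rather than being continuous, and replace both integrals (the one above and the one in the normalizer) with sums over these points; the grid spacing cancels in the ratio. This gives us 
$$
\nabla_\theta \log p(x^{(1:N)},y^{(1:N)}) \approx \sum_{n=1}^{N} \sum_{g=1}^{G} \frac{p(x^{(n)}, y^{(n)}, u_g;\ \theta)}{\sum_{g'=1}^{G} p(x^{(n)}, y^{(n)}, u_{g'};\ \theta)}\nabla_\theta \log p(u_g,x^{(n)},y^{(n)};\theta)
$$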
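Notice that when $y^{(n)} \ne 1(u \ge 0)$, the joint probability $p(x^{(n)}, y^{(n)}, u;\ \theta)$ is $0$ ($u$ and $y^{(n)}$ are 'incompatible'). Thus, in practice we can use $y^{(n)}$ to choose the grid: if $y^{(n)} = 1$, evaluate $p(u, x^{(n)}, y^{(n)};\theta)$ only at $G$ linearly spaced points in $[0,1]$ (in $[-1,0]$ if $y^{(n)} = 0$), and then normalize over those points. The code below uses the interior points $k/(G+1)$, $k = 1, \dots, G$ (shifted by $-1$ when $y^{(n)} = 0$), which avoids evaluating the Beta density at the endpoints $u = \pm 1$, where its log can be infinite.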
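A minimal sketch of the $y$-conditioned grid, using a hypothetical helper `y_conditioned_grid` (not part of the notebook's code): each row gets $G$ interior points of $(0,1)$ when $y = 1$ or of $(-1,0)$ when $y = 0$, and a matrix of log joints on that grid is normalized row-wise with a softmax. The log joint here is a placeholder, not the model's.

```python
import torch

def y_conditioned_grid(y, grid_size):
    """Hypothetical helper: [B, G] grid of u values, one row per sample.

    Rows with y == 1 get the interior points k / (G + 1), k = 1..G, of (0, 1);
    rows with y == 0 get the same points shifted by -1, i.e. interior points of (-1, 0).
    """
    pos = torch.linspace(1 / (grid_size + 1), 1 - 1 / (grid_size + 1), grid_size)
    return torch.where(y.unsqueeze(1) == 1, pos, pos - 1)  # broadcasts [B, 1] with [G]

y = torch.tensor([1.0, 0.0, 1.0])
u_mat = y_conditioned_grid(y, grid_size=4)  # rows: (0.2, 0.4, 0.6, 0.8) or (-0.8, -0.6, -0.4, -0.2)

# Placeholder log joint (illustration only); normalizing each row over its compatible grid
# gives the discretized posterior p(u_g | x, y; theta)
log_joint = -(u_mat ** 2)
posterior = torch.softmax(log_joint, dim=1)
print(u_mat)
print(posterior.sum(dim=1))
```

Building both halves from one `linspace` with `torch.where` gives the same points as two separate `linspace` calls, up to floating-point rounding.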

It is convenient to move the gradient outside the sum, since then we can compute the double sum efficiently as a matrix of weighted log joints, sum it, and call .backward() once. This is what I was doing initially, but done naively it is wrong: the posterior weights themselves depend on $\theta$, so autograd would also differentiate through them. To fix this, I call .detach() on the "weights" (the posterior probabilities), so gradients flow only through the joint log probability, which is exactly the expression above.
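A small autograd check of this argument, using a toy log joint on a 1-D grid that depends on a scalar `theta` (the toy model is an assumption for illustration, not the notebook's model): the gradient of the discretized log marginal $\log \sum_g p(u_g, x, y;\theta)$ should equal the gradient of the detached-weight surrogate, while the undetached version does not.

```python
import torch

# Toy setup (assumption for illustration): a scalar parameter theta and a 1-D grid of u values
theta = torch.tensor(0.3, requires_grad=True)
u_grid = torch.linspace(0.05, 0.95, 10)
log_joint = -(u_grid - theta) ** 2 / 0.1 + theta * u_grid  # stands in for log p(u_g, x, y; theta)

# (1) Exact gradient of the discretized log marginal, log sum_g p(u_g, x, y; theta)
log_marginal = torch.logsumexp(log_joint, dim=0)
(grad_marginal,) = torch.autograd.grad(log_marginal, theta, retain_graph=True)

# (2) Surrogate with detached posterior weights (the negative of what compute_nll returns)
posterior = torch.softmax(log_joint, dim=0)
surrogate = (posterior.detach() * log_joint).sum()
(grad_surrogate,) = torch.autograd.grad(surrogate, theta, retain_graph=True)

# (3) Same surrogate without .detach(): autograd also differentiates the weights
naive = (posterior * log_joint).sum()
(grad_naive,) = torch.autograd.grad(naive, theta)

print(grad_marginal.item(), grad_surrogate.item(), grad_naive.item())
```

The extra term in the naive gradient is $\sum_g \log p_g \, \nabla_\theta w_g$, the posterior covariance of $\log p_g$ and $\nabla_\theta \log p_g$, which is generally nonzero.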

In [31]:
grid_size  = 200 ### discretize integral into grid_size linearly spaced pts
batch_size = 100 ### num samples processed simultaneously
num_epochs = 200 ### number epochs
lr         = 0.1 ### learning rate

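def compute_nll(x_batch, y_batch, log_alpha, log_beta, W, grid_size):
    """
    Placeholder; the actual implementation is defined in the next cell.

    Parameters:
        (x_batch, y_batch) a batch of B (x,y) samples 
        (log_alpha, log_beta, W) trainable parameters
        grid_size number of grid points used to discretize u

    Returns:
        A scalar quantity s.t. calling .backward() will compute
        the (discretized) gradient of -L(theta) for this batch
    """
    raise NotImplementedError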

def train(X, y, num_epochs=num_epochs, batch_size=batch_size, grid_size=grid_size, lr=lr):
    print(f'Beginning Inference...')
    print(f'=' * 20)
    print(f'Hyperparameters:')
    print(f'Dataset Length: {len(X)}')
    print(f'Number Epochs: {num_epochs}')
    print(f'Batch Size: {batch_size}')
    print(f'Grid Size: {grid_size}')
    print(f"Learning Rate: {lr}")
    print('=' * 20)
    
    start = time.time()

    dataset    = TensorDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=batch_size)
    
    X_batch, y_batch = next(iter(dataloader))  ### first batch only, without materializing the whole loader

    if list(X_batch.size()) != [batch_size, 3]:
        raise AssertionError(f"X_batch.size() should be [batch_size, V] == [batch_size, 3], got {X_batch.size()}")

    if list(y_batch.size()) != [batch_size]:
        raise AssertionError(f"y_batch.size() should be [batch_size], got {y_batch.size()}")

    ### play with parameter initialization!
    log_alpha = torch.normal(1.5, 1, size=(1,), requires_grad=True)
    log_beta  = torch.normal(0, 1, size=(1,), requires_grad=True)
    W         = torch.normal(3, 2, size=(3,), requires_grad=True)

    print(f'Initial Parameters:')
    print(f'Alpha: {torch.exp(log_alpha).item()}')
    print(f"Beta: {torch.exp(log_beta).item()}")
    print(f"W: {W.data.tolist()}")
    optimizer = SGD([log_alpha, log_beta, W], lr=lr)
    
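    for epoch in tqdm(range(num_epochs)):
        print(f'Epoch {epoch+1} of {num_epochs}')
        ### full-batch gradient descent: gradients accumulate over all minibatches, one step per epoch
        optimizer.zero_grad()
        nll = 0.0
        for x_batch, y_batch in dataloader:
            ### backprop each minibatch right away so its graph is freed; .grad sums across batches,
            ### giving the same gradient as one backward pass on the full-dataset average
            batch_nll = compute_nll(x_batch, y_batch, log_alpha, log_beta, W, grid_size) / len(dataloader.dataset)
            batch_nll.backward()
            nll += batch_nll.item()

        optimizer.step()
        print(f'Epoch {epoch+1}, NLL: {round(nll, 3)}, Alpha: {round(np.exp(log_alpha.item()),3)}, Beta: {round(np.exp(log_beta.item()),3)}, W: {W.tolist()}')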
        print('=' * 20)

    final_alpha = np.exp(log_alpha.item())
    final_beta  = np.exp(log_beta.item())

    print(f"Training Took {round(time.time() - start, 2)} Seconds.")
    print(f"Trained Params: alpha: {final_alpha}, beta: {final_beta}, W: {W.tolist()}")

    return final_alpha, final_beta, W

In [32]:
factorial = lambda x : torch.exp(torch.lgamma(x+1))

def compute_log_joint(u_mat, x_batch, y_batch, log_alpha, log_beta, W):
    """
    Parameters:
        x_batch (size [batch_size, 3])
        y_batch (size [batch_size])
        u_mat   (size [batch_size, N])
        (log_alpha, log_beta, W) trainable params 
        constants (bool) true iff include constants
    
    constants are necessary for computing joint probabilities exactly but not for gradient expressions
    
    Returns:
        torch.tensor (size [batch_size, N]), where the (i,j)th element 
        is the log joint probability (log p(u, x, y; theta)). Here u is the (i,j)th
        element of u_mat, x is the ith row of x_batch, y is the ith element of y_batch

    """
    ### log prior: log p(u) = log(1/4 Beta((u+1)/2; a, b) + 1/4 Beta((u+1)/2; b, a))
    alpha = torch.exp(log_alpha)
    beta  = torch.exp(log_beta)
        
    beta_dist_ab = Beta(alpha, beta)
    beta_dist_ba = Beta(beta, alpha)

    log_beta_prob_ab = beta_dist_ab.log_prob((u_mat + 1) / 2)
    log_beta_prob_ba = beta_dist_ba.log_prob((u_mat + 1) / 2)

    prior = torch.logaddexp(log_beta_prob_ab, log_beta_prob_ba) + torch.log(torch.tensor(0.25))  ### stable log(e^a + e^b)

    assert prior.size() == u_mat.size()
  
    ### log likelihood: log p(x, y | u) = dot(x, log Softmax(Wu)) + log(S!) - log(x1!) - log(x2!) - log(x3!) (assuming y is compatible)
    W_expanded = W.unsqueeze(0).unsqueeze(0)
    assert list(W_expanded.size()) == [1,1,3]
    u_expanded = u_mat.unsqueeze(2)

    softmax_input = torch.matmul(u_expanded, W_expanded)  ### size BxNx3
    likelihood = (x_batch.unsqueeze(1) * log_softmax(softmax_input, dim=2)).sum(dim=2) ### size BxN
    
    S = x_batch.sum(dim=1, keepdim=True)  ### words per speech (speech_len for this data)
    likelihood += torch.lgamma(S + 1)  ### log(S!), via lgamma to avoid overflowing S!
    likelihood -= torch.lgamma(x_batch + 1).sum(dim=1, keepdim=True)  # - log(x_1!) - log(x_2!) - log(x_3!)

    return prior + likelihood


def compute_nll(x_batch, y_batch, log_alpha, log_beta, W, grid_size):
    """
    Parameters:
        x_batch (size [batch_size, 3])
        y_batch (size [batch_size])
        (log_alpha, log_beta, W) trainable params
        grid_size number of grid points used to discretize u
    
    Returns:
        A scalar tensor: minus the sum, over the batch and the grid, of the detached
        posterior weight p(u | x^(i), y^(i); theta) times log p(u, x^(i), y^(i); theta).
        Calling .backward() on it gives the discretized gradient of the batch NLL.
    """
    B = x_batch.size(0)  ### the last batch may be smaller than batch_size
    u_mat = torch.empty(B, grid_size)
    ### y = 1 -> interior points of (0, 1); y = 0 -> interior points of (-1, 0)
    u_mat[y_batch == 1] = torch.linspace(1/(grid_size+1), 1-1/(grid_size+1), grid_size).repeat(int((y_batch == 1).sum()), 1)
    u_mat[y_batch == 0] = torch.linspace(-1+1/(grid_size+1), -1/(grid_size+1), grid_size).repeat(int((y_batch == 0).sum()), 1)

    assert list(u_mat.size()) == [B, grid_size]

    log_joint = compute_log_joint(u_mat, x_batch, y_batch, log_alpha, log_beta, W)

    ### posterior over the grid: torch.softmax normalizes each row in a numerically stable
    ### way (subtracting the row max), so exp(log_joint) does not underflow
    posterior = torch.softmax(log_joint, dim=1)

    ### detach the posterior so gradients flow only through log_joint, not through the weights
    weighted_log_joint = posterior.detach() * log_joint
    nll = -weighted_log_joint.sum()

    return nll
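The value printed as `NLL` during training is this detached-weight surrogate, which has the right gradient but not the value of the discretized negative log marginal likelihood. Per sample, with weights $w_g = p_g / \sum_{g'} p_{g'}$, we have $\sum_g w_g \log p_g = \log \sum_g p_g - H(w)$, so the surrogate equals $-\log\sum_g p_g + H(w)$, whereas the discretized NLL is $-\log\sum_g p_g + \log(G+1)$ (the grid spacing is $1/(G+1)$). A minimal sketch for monitoring the discretized marginal directly, with a hypothetical helper `discretized_log_marginal` (assumption: a Riemann sum on the same grid as `compute_nll`):

```python
import math

import torch

def discretized_log_marginal(log_joint, grid_size):
    """Hypothetical helper: log p(x, y; theta) per row from a [B, G] matrix of log joints.

    Riemann sum over the y-compatible half-interval on the grid k / (G + 1), k = 1..G,
    whose spacing is 1 / (G + 1); logsumexp keeps the sum numerically stable.
    """
    return torch.logsumexp(log_joint, dim=1) - math.log(grid_size + 1)

# Check on a known integrand: f(u) = 2u integrates to 1 over [0, 1]
G = 200
u = torch.linspace(1 / (G + 1), 1 - 1 / (G + 1), G, dtype=torch.float64)
approx = discretized_log_marginal(torch.log(2 * u).unsqueeze(0), G)  # log(G / (G + 1)) up to rounding

# Identity behind the printed NLL: sum_g w_g log p_g = logsumexp_g(log p_g) - H(w)
torch.manual_seed(0)
log_joint = torch.randn(5, G, dtype=torch.float64)
w = torch.softmax(log_joint, dim=1)
weighted = (w * log_joint).sum(dim=1)
entropy = -(w * torch.log(w)).sum(dim=1)
```

Inside `train`, one could accumulate `-discretized_log_marginal(log_joint, grid_size).sum().item()` per batch (with `log_joint` from `compute_log_joint`) and divide by the dataset size to report a per-sample marginal NLL alongside the surrogate.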

In [33]:
train(X, y)

Beginning Inference...
Hyperparameters:
Dataset Length: 50000
Number Epochs: 200
Batch Size: 100
Grid Size: 200
Learning Rate: 0.1
Initial Parameters:
Alpha: 2.9923245906829834
Beta: 5.803126811981201
W: [2.1709372997283936, 4.712274551391602, 2.114123821258545]


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch 1 of 200


  0%|          | 1/200 [00:06<20:48,  6.27s/it]

Epoch 1, NLL: 6.908, Alpha: 2.825, Beta: 6.041, W: [2.0754234790802, 4.7453532218933105, 2.1765589714050293]
Epoch 2 of 200


  1%|          | 2/200 [00:10<16:52,  5.11s/it]

Epoch 2, NLL: 6.728, Alpha: 2.681, Beta: 6.279, W: [1.9800111055374146, 4.7709197998046875, 2.2464048862457275]
Epoch 3 of 200


  2%|▏         | 3/200 [00:16<17:35,  5.36s/it]

Epoch 3, NLL: 6.544, Alpha: 2.556, Beta: 6.511, W: [1.8874129056930542, 4.789531707763672, 2.3203909397125244]
Epoch 4 of 200


  2%|▏         | 4/200 [00:22<18:54,  5.79s/it]

Epoch 4, NLL: 6.364, Alpha: 2.449, Beta: 6.731, W: [1.8003851175308228, 4.802069187164307, 2.394881248474121]
Epoch 5 of 200


  2%|▎         | 5/200 [00:29<20:20,  6.26s/it]

Epoch 5, NLL: 6.201, Alpha: 2.359, Beta: 6.932, W: [1.7210174798965454, 4.809635162353516, 2.4666826725006104]
Epoch 6 of 200


  3%|▎         | 6/200 [00:33<16:57,  5.25s/it]

Epoch 6, NLL: 6.061, Alpha: 2.284, Beta: 7.111, W: [1.6503965854644775, 4.8133978843688965, 2.533540725708008]
Epoch 7 of 200


  4%|▎         | 7/200 [00:37<15:35,  4.85s/it]

Epoch 7, NLL: 5.948, Alpha: 2.222, Beta: 7.264, W: [1.5886757373809814, 4.814443111419678, 2.5942161083221436]
Epoch 8 of 200


  4%|▍         | 8/200 [00:40<14:17,  4.47s/it]

Epoch 8, NLL: 5.859, Alpha: 2.172, Beta: 7.391, W: [1.535359501838684, 4.813689708709717, 2.6482858657836914]
Epoch 9 of 200


  4%|▍         | 9/200 [00:45<14:39,  4.61s/it]

Epoch 9, NLL: 5.791, Alpha: 2.131, Beta: 7.494, W: [1.489605188369751, 4.81186056137085, 2.6958696842193604]
Epoch 10 of 200


  5%|▌         | 10/200 [00:49<14:08,  4.46s/it]

Epoch 10, NLL: 5.74, Alpha: 2.098, Beta: 7.574, W: [1.4504468441009521, 4.809495449066162, 2.737393379211426]
Epoch 11 of 200


  6%|▌         | 11/200 [00:57<16:45,  5.32s/it]

Epoch 11, NLL: 5.702, Alpha: 2.072, Beta: 7.635, W: [1.4169307947158813, 4.806979179382324, 2.773425579071045]
Epoch 12 of 200


  6%|▌         | 12/200 [01:03<17:17,  5.52s/it]

Epoch 12, NLL: 5.674, Alpha: 2.051, Beta: 7.679, W: [1.388183832168579, 4.804574966430664, 2.804576873779297]
Epoch 13 of 200


  6%|▋         | 13/200 [01:06<15:40,  5.03s/it]

Epoch 13, NLL: 5.654, Alpha: 2.035, Beta: 7.709, W: [1.3634395599365234, 4.802452087402344, 2.831444025039673]
Epoch 14 of 200


  7%|▋         | 14/200 [01:12<16:17,  5.26s/it]

Epoch 14, NLL: 5.639, Alpha: 2.022, Beta: 7.727, W: [1.342042326927185, 4.8007121086120605, 2.854581356048584]
Epoch 15 of 200


  8%|▊         | 15/200 [01:16<14:46,  4.79s/it]

Epoch 15, NLL: 5.628, Alpha: 2.011, Beta: 7.735, W: [1.323439359664917, 4.799407482147217, 2.8744890689849854]
Epoch 16 of 200


  8%|▊         | 16/200 [01:21<15:19,  5.00s/it]

Epoch 16, NLL: 5.619, Alpha: 2.003, Beta: 7.736, W: [1.3071691989898682, 4.798556327819824, 2.8916103839874268]
Epoch 17 of 200


  8%|▊         | 17/200 [01:28<16:33,  5.43s/it]

Epoch 17, NLL: 5.613, Alpha: 1.997, Beta: 7.73, W: [1.2928483486175537, 4.798154354095459, 2.9063332080841064]
Epoch 18 of 200


  9%|▉         | 18/200 [01:35<18:24,  6.07s/it]

Epoch 18, NLL: 5.609, Alpha: 1.992, Beta: 7.72, W: [1.2801594734191895, 4.798182010650635, 2.918994665145874]
Epoch 19 of 200


 10%|▉         | 19/200 [01:40<16:54,  5.61s/it]

Epoch 19, NLL: 5.605, Alpha: 1.988, Beta: 7.706, W: [1.2688398361206055, 4.798610210418701, 2.9298858642578125]
Epoch 20 of 200


 10%|█         | 20/200 [01:44<15:06,  5.04s/it]

Epoch 20, NLL: 5.602, Alpha: 1.985, Beta: 7.689, W: [1.2586723566055298, 4.799406051635742, 2.939257860183716]
Epoch 21 of 200


 10%|█         | 21/200 [01:48<14:52,  4.99s/it]

Epoch 21, NLL: 5.6, Alpha: 1.983, Beta: 7.67, W: [1.2494771480560303, 4.800533294677734, 2.9473259449005127]
Epoch 22 of 200


 11%|█         | 22/200 [01:52<13:37,  4.59s/it]

Epoch 22, NLL: 5.599, Alpha: 1.981, Beta: 7.649, W: [1.2411054372787476, 4.801955699920654, 2.954275131225586]
Epoch 23 of 200


 12%|█▏        | 23/200 [01:58<14:25,  4.89s/it]

Epoch 23, NLL: 5.598, Alpha: 1.98, Beta: 7.628, W: [1.2334336042404175, 4.803638935089111, 2.960263967514038]
Epoch 24 of 200


 12%|█▏        | 24/200 [02:05<16:22,  5.58s/it]

Epoch 24, NLL: 5.597, Alpha: 1.979, Beta: 7.605, W: [1.2263590097427368, 4.805549621582031, 2.965428113937378]
Epoch 25 of 200


 12%|█▎        | 25/200 [02:12<17:58,  6.16s/it]

Epoch 25, NLL: 5.596, Alpha: 1.978, Beta: 7.582, W: [1.2197959423065186, 4.807656764984131, 2.969883680343628]
Epoch 26 of 200


 13%|█▎        | 26/200 [02:17<16:48,  5.79s/it]

Epoch 26, NLL: 5.595, Alpha: 1.978, Beta: 7.558, W: [1.2136731147766113, 4.809932708740234, 2.9737303256988525]
Epoch 27 of 200


 14%|█▎        | 27/200 [02:23<16:13,  5.63s/it]

Epoch 27, NLL: 5.594, Alpha: 1.978, Beta: 7.535, W: [1.2079306840896606, 4.812352180480957, 2.97705340385437]
Epoch 28 of 200


 14%|█▍        | 28/200 [02:28<16:09,  5.64s/it]

Epoch 28, NLL: 5.594, Alpha: 1.977, Beta: 7.511, W: [1.2025184631347656, 4.814891815185547, 2.9799258708953857]
Epoch 29 of 200


 14%|█▍        | 29/200 [02:35<16:53,  5.92s/it]

Epoch 29, NLL: 5.594, Alpha: 1.977, Beta: 7.488, W: [1.1973942518234253, 4.817531585693359, 2.982410192489624]
Epoch 30 of 200


 15%|█▌        | 30/200 [02:41<16:45,  5.92s/it]

Epoch 30, NLL: 5.593, Alpha: 1.977, Beta: 7.465, W: [1.1925225257873535, 4.820253372192383, 2.984560251235962]
Epoch 31 of 200


 16%|█▌        | 31/200 [02:47<16:50,  5.98s/it]

Epoch 31, NLL: 5.593, Alpha: 1.978, Beta: 7.442, W: [1.187873125076294, 4.823040962219238, 2.986422061920166]
Epoch 32 of 200


 16%|█▌        | 32/200 [02:52<15:53,  5.68s/it]

Epoch 32, NLL: 5.592, Alpha: 1.978, Beta: 7.42, W: [1.1834205389022827, 4.825880527496338, 2.988035202026367]
Epoch 33 of 200


 16%|█▋        | 33/200 [02:58<16:28,  5.92s/it]

Epoch 33, NLL: 5.592, Alpha: 1.978, Beta: 7.397, W: [1.1791430711746216, 4.82875919342041, 2.989433765411377]
Epoch 34 of 200


 17%|█▋        | 34/200 [03:02<14:32,  5.25s/it]

Epoch 34, NLL: 5.592, Alpha: 1.978, Beta: 7.375, W: [1.1750221252441406, 4.831666946411133, 2.9906468391418457]
Epoch 35 of 200


 18%|█▊        | 35/200 [03:07<13:47,  5.02s/it]

Epoch 35, NLL: 5.592, Alpha: 1.978, Beta: 7.354, W: [1.1710419654846191, 4.834594249725342, 2.991699695587158]
Epoch 36 of 200


 18%|█▊        | 36/200 [03:11<12:54,  4.72s/it]

Epoch 36, NLL: 5.591, Alpha: 1.979, Beta: 7.333, W: [1.1671890020370483, 4.837532997131348, 2.9926140308380127]
Epoch 37 of 200


 18%|█▊        | 37/200 [03:15<12:20,  4.55s/it]

Epoch 37, NLL: 5.591, Alpha: 1.979, Beta: 7.312, W: [1.1634514331817627, 4.840476036071777, 2.993408441543579]
Epoch 38 of 200


 19%|█▉        | 38/200 [03:19<12:13,  4.53s/it]

Epoch 38, NLL: 5.591, Alpha: 1.979, Beta: 7.292, W: [1.15981924533844, 4.843417644500732, 2.9940991401672363]
Epoch 39 of 200


 20%|█▉        | 39/200 [03:25<13:27,  5.02s/it]

Epoch 39, NLL: 5.59, Alpha: 1.98, Beta: 7.272, W: [1.1562837362289429, 4.846352577209473, 2.994699716567993]
Epoch 40 of 200


 20%|██        | 40/200 [03:31<13:31,  5.07s/it]

Epoch 40, NLL: 5.59, Alpha: 1.98, Beta: 7.252, W: [1.1528372764587402, 4.849276542663574, 2.995222330093384]
Epoch 41 of 200


 20%|██        | 41/200 [03:35<12:33,  4.74s/it]

Epoch 41, NLL: 5.59, Alpha: 1.98, Beta: 7.233, W: [1.1494731903076172, 4.8521857261657715, 2.9956772327423096]
Epoch 42 of 200


 21%|██        | 42/200 [03:39<12:28,  4.74s/it]

Epoch 42, NLL: 5.59, Alpha: 1.98, Beta: 7.215, W: [1.1461857557296753, 4.855076789855957, 2.9960734844207764]
Epoch 43 of 200


 22%|██▏       | 43/200 [03:43<11:37,  4.44s/it]

Epoch 43, NLL: 5.589, Alpha: 1.981, Beta: 7.196, W: [1.1429699659347534, 4.85794734954834, 2.9964187145233154]
Epoch 44 of 200


 22%|██▏       | 44/200 [03:48<12:18,  4.74s/it]

Epoch 44, NLL: 5.589, Alpha: 1.981, Beta: 7.178, W: [1.1398212909698486, 4.860795021057129, 2.9967195987701416]
Epoch 45 of 200


 22%|██▎       | 45/200 [03:53<12:23,  4.79s/it]

Epoch 45, NLL: 5.589, Alpha: 1.981, Beta: 7.161, W: [1.1367357969284058, 4.86361837387085, 2.9969818592071533]
Epoch 46 of 200


 23%|██▎       | 46/200 [03:59<12:46,  4.98s/it]

Epoch 46, NLL: 5.589, Alpha: 1.982, Beta: 7.143, W: [1.1337099075317383, 4.866415500640869, 2.9972105026245117]
Epoch 47 of 200


 24%|██▎       | 47/200 [04:03<11:48,  4.63s/it]

Epoch 47, NLL: 5.588, Alpha: 1.982, Beta: 7.126, W: [1.1307406425476074, 4.869185447692871, 2.9974100589752197]
Epoch 48 of 200


 24%|██▍       | 48/200 [04:07<11:17,  4.46s/it]

Epoch 48, NLL: 5.588, Alpha: 1.982, Beta: 7.11, W: [1.127825140953064, 4.871926784515381, 2.997584342956543]
Epoch 49 of 200


 24%|██▍       | 49/200 [04:10<10:09,  4.03s/it]

Epoch 49, NLL: 5.588, Alpha: 1.982, Beta: 7.093, W: [1.1249608993530273, 4.87463903427124, 2.9977364540100098]
Epoch 50 of 200


 25%|██▌       | 50/200 [04:13<09:35,  3.83s/it]

Epoch 50, NLL: 5.588, Alpha: 1.983, Beta: 7.077, W: [1.122145652770996, 4.877321720123291, 2.9978692531585693]
Epoch 51 of 200


 26%|██▌       | 51/200 [04:18<10:07,  4.08s/it]

Epoch 51, NLL: 5.587, Alpha: 1.983, Beta: 7.062, W: [1.1193773746490479, 4.879973888397217, 2.9979851245880127]
Epoch 52 of 200


 26%|██▌       | 52/200 [04:21<09:37,  3.91s/it]

Epoch 52, NLL: 5.587, Alpha: 1.983, Beta: 7.046, W: [1.1166542768478394, 4.882595539093018, 2.998086452484131]
Epoch 53 of 200


 26%|██▋       | 53/200 [04:27<10:57,  4.47s/it]

Epoch 53, NLL: 5.587, Alpha: 1.984, Beta: 7.031, W: [1.1139745712280273, 4.885186672210693, 2.9981749057769775]
Epoch 54 of 200


 27%|██▋       | 54/200 [04:31<10:51,  4.46s/it]

Epoch 54, NLL: 5.587, Alpha: 1.984, Beta: 7.016, W: [1.1113368272781372, 4.887747287750244, 2.9982521533966064]
Epoch 55 of 200


 28%|██▊       | 55/200 [04:36<11:12,  4.64s/it]

Epoch 55, NLL: 5.587, Alpha: 1.984, Beta: 7.002, W: [1.1087396144866943, 4.890276908874512, 2.998319625854492]
Epoch 56 of 200


 28%|██▊       | 56/200 [04:39<09:47,  4.08s/it]

Epoch 56, NLL: 5.586, Alpha: 1.985, Beta: 6.988, W: [1.1061816215515137, 4.892776012420654, 2.9983785152435303]
Epoch 57 of 200


 28%|██▊       | 57/200 [04:45<11:15,  4.72s/it]

Epoch 57, NLL: 5.586, Alpha: 1.985, Beta: 6.974, W: [1.1036616563796997, 4.895244598388672, 2.998430013656616]
Epoch 58 of 200


 29%|██▉       | 58/200 [04:50<10:55,  4.62s/it]

Epoch 58, NLL: 5.586, Alpha: 1.985, Beta: 6.96, W: [1.101178526878357, 4.8976826667785645, 2.9984750747680664]
Epoch 59 of 200


 30%|██▉       | 59/200 [04:57<12:35,  5.36s/it]

Epoch 59, NLL: 5.586, Alpha: 1.985, Beta: 6.947, W: [1.098731279373169, 4.90009069442749, 2.998514413833618]
Epoch 60 of 200


 30%|███       | 60/200 [05:03<12:42,  5.44s/it]

Epoch 60, NLL: 5.585, Alpha: 1.986, Beta: 6.934, W: [1.0963189601898193, 4.902468681335449, 2.998548746109009]
Epoch 61 of 200


 30%|███       | 61/200 [05:07<11:35,  5.01s/it]

Epoch 61, NLL: 5.585, Alpha: 1.986, Beta: 6.921, W: [1.0939406156539917, 4.9048171043396, 2.9985787868499756]
Epoch 62 of 200


 31%|███       | 62/200 [05:13<12:38,  5.50s/it]

Epoch 62, NLL: 5.585, Alpha: 1.986, Beta: 6.908, W: [1.0915954113006592, 4.9071364402771, 2.9986047744750977]
Epoch 63 of 200


 32%|███▏      | 63/200 [05:17<11:22,  4.98s/it]

Epoch 63, NLL: 5.585, Alpha: 1.987, Beta: 6.895, W: [1.089282512664795, 4.909426689147949, 2.9986274242401123]
Epoch 64 of 200


 32%|███▏      | 64/200 [05:23<11:53,  5.25s/it]

Epoch 64, NLL: 5.585, Alpha: 1.987, Beta: 6.883, W: [1.087001085281372, 4.911688327789307, 2.9986472129821777]
Epoch 65 of 200


 32%|███▎      | 65/200 [05:27<10:54,  4.85s/it]

Epoch 65, NLL: 5.584, Alpha: 1.987, Beta: 6.871, W: [1.0847504138946533, 4.91392183303833, 2.998664379119873]
Epoch 66 of 200


 33%|███▎      | 66/200 [05:33<11:34,  5.18s/it]

Epoch 66, NLL: 5.584, Alpha: 1.987, Beta: 6.859, W: [1.082529902458191, 4.9161272048950195, 2.9986793994903564]
Epoch 67 of 200


 34%|███▎      | 67/200 [05:37<10:54,  4.92s/it]

Epoch 67, NLL: 5.584, Alpha: 1.988, Beta: 6.848, W: [1.0803388357162476, 4.918305397033691, 2.998692512512207]
Epoch 68 of 200


 34%|███▍      | 68/200 [05:42<11:09,  5.07s/it]

Epoch 68, NLL: 5.584, Alpha: 1.988, Beta: 6.836, W: [1.0781766176223755, 4.920456409454346, 2.998703718185425]
Epoch 69 of 200


 34%|███▍      | 69/200 [05:46<10:00,  4.58s/it]

Epoch 69, NLL: 5.584, Alpha: 1.988, Beta: 6.825, W: [1.076042652130127, 4.922580718994141, 2.998713493347168]
Epoch 70 of 200


 35%|███▌      | 70/200 [05:50<09:54,  4.57s/it]

Epoch 70, NLL: 5.583, Alpha: 1.989, Beta: 6.814, W: [1.0739362239837646, 4.924678802490234, 2.9987218379974365]
Epoch 71 of 200


 36%|███▌      | 71/200 [05:55<10:05,  4.70s/it]

Epoch 71, NLL: 5.583, Alpha: 1.989, Beta: 6.804, W: [1.07185697555542, 4.926750659942627, 2.9987289905548096]
Epoch 72 of 200


 36%|███▌      | 72/200 [06:00<10:11,  4.78s/it]

Epoch 72, NLL: 5.583, Alpha: 1.989, Beta: 6.793, W: [1.0698041915893555, 4.928797245025635, 2.998735189437866]
Epoch 73 of 200


 36%|███▋      | 73/200 [06:04<09:37,  4.54s/it]

Epoch 73, NLL: 5.583, Alpha: 1.99, Beta: 6.783, W: [1.0677775144577026, 4.930818557739258, 2.9987404346466064]
Epoch 74 of 200


 37%|███▋      | 74/200 [06:10<09:58,  4.75s/it]

Epoch 74, NLL: 5.583, Alpha: 1.99, Beta: 6.773, W: [1.0657763481140137, 4.932815074920654, 2.9987449645996094]
Epoch 75 of 200


 38%|███▊      | 75/200 [06:14<09:21,  4.49s/it]

Epoch 75, NLL: 5.582, Alpha: 1.99, Beta: 6.763, W: [1.0638002157211304, 4.934787273406982, 2.998748779296875]
Epoch 76 of 200


 38%|███▊      | 76/200 [06:18<09:03,  4.38s/it]

Epoch 76, NLL: 5.582, Alpha: 1.991, Beta: 6.753, W: [1.0618486404418945, 4.9367356300354, 2.9987518787384033]
Epoch 77 of 200


 38%|███▊      | 77/200 [06:21<08:10,  3.99s/it]

Epoch 77, NLL: 5.582, Alpha: 1.991, Beta: 6.743, W: [1.059921145439148, 4.938660621643066, 2.9987545013427734]
Epoch 78 of 200


 39%|███▉      | 78/200 [06:24<07:57,  3.91s/it]

Epoch 78, NLL: 5.582, Alpha: 1.991, Beta: 6.734, W: [1.058017373085022, 4.9405622482299805, 2.9987566471099854]
Epoch 79 of 200


 40%|███▉      | 79/200 [06:28<07:27,  3.70s/it]

Epoch 79, NLL: 5.582, Alpha: 1.992, Beta: 6.724, W: [1.0561368465423584, 4.942440986633301, 2.998758316040039]
Epoch 80 of 200


 40%|████      | 80/200 [06:31<07:16,  3.64s/it]

Epoch 80, NLL: 5.581, Alpha: 1.992, Beta: 6.715, W: [1.054279088973999, 4.9442973136901855, 2.9987597465515137]
Epoch 81 of 200


 40%|████      | 81/200 [06:35<07:05,  3.58s/it]

Epoch 81, NLL: 5.581, Alpha: 1.992, Beta: 6.706, W: [1.0524437427520752, 4.946131706237793, 2.998760938644409]
Epoch 82 of 200


 41%|████      | 82/200 [06:40<08:02,  4.09s/it]

Epoch 82, NLL: 5.581, Alpha: 1.993, Beta: 6.697, W: [1.0506304502487183, 4.947944164276123, 2.9987616539001465]
Epoch 83 of 200


 42%|████▏     | 83/200 [06:44<08:05,  4.15s/it]

Epoch 83, NLL: 5.581, Alpha: 1.993, Beta: 6.688, W: [1.0488388538360596, 4.949735164642334, 2.9987621307373047]
Epoch 84 of 200


 42%|████▏     | 84/200 [06:48<07:56,  4.10s/it]

Epoch 84, NLL: 5.581, Alpha: 1.993, Beta: 6.68, W: [1.047068476676941, 4.951505184173584, 2.998762369155884]
Epoch 85 of 200


 42%|████▎     | 85/200 [06:52<07:48,  4.08s/it]

Epoch 85, NLL: 5.581, Alpha: 1.994, Beta: 6.672, W: [1.0453190803527832, 4.953254699707031, 2.998762607574463]
Epoch 86 of 200


 43%|████▎     | 86/200 [06:56<07:44,  4.08s/it]

Epoch 86, NLL: 5.58, Alpha: 1.994, Beta: 6.663, W: [1.0435901880264282, 4.954983711242676, 2.998762607574463]
Epoch 87 of 200


 44%|████▎     | 87/200 [07:01<08:18,  4.41s/it]

Epoch 87, NLL: 5.58, Alpha: 1.994, Beta: 6.655, W: [1.0418814420700073, 4.956692695617676, 2.998762369155884]
Epoch 88 of 200


 44%|████▍     | 88/200 [07:07<08:36,  4.61s/it]

Epoch 88, NLL: 5.58, Alpha: 1.995, Beta: 6.647, W: [1.0401926040649414, 4.958381652832031, 2.9987621307373047]
Epoch 89 of 200


 44%|████▍     | 89/200 [07:12<09:16,  5.02s/it]

Epoch 89, NLL: 5.58, Alpha: 1.995, Beta: 6.639, W: [1.0385233163833618, 4.960051536560059, 2.9987616539001465]
Epoch 90 of 200


 45%|████▌     | 90/200 [07:17<08:42,  4.75s/it]

Epoch 90, NLL: 5.58, Alpha: 1.995, Beta: 6.632, W: [1.0368732213974, 4.9617018699646, 2.9987611770629883]
Epoch 91 of 200


 46%|████▌     | 91/200 [07:21<08:19,  4.58s/it]

Epoch 91, NLL: 5.58, Alpha: 1.996, Beta: 6.624, W: [1.0352420806884766, 4.963333606719971, 2.99876070022583]
Epoch 92 of 200


 46%|████▌     | 92/200 [07:25<08:11,  4.55s/it]

Epoch 92, NLL: 5.579, Alpha: 1.996, Beta: 6.617, W: [1.0336295366287231, 4.964946746826172, 2.9987599849700928]
Epoch 93 of 200


 46%|████▋     | 93/200 [07:30<08:18,  4.66s/it]

Epoch 93, NLL: 5.579, Alpha: 1.996, Beta: 6.609, W: [1.032035231590271, 4.966541767120361, 2.9987592697143555]
Epoch 94 of 200


 47%|████▋     | 94/200 [07:34<07:36,  4.31s/it]

Epoch 94, NLL: 5.579, Alpha: 1.997, Beta: 6.602, W: [1.030458927154541, 4.968118667602539, 2.998758554458618]
Epoch 95 of 200


 48%|████▊     | 95/200 [07:38<07:19,  4.18s/it]

Epoch 95, NLL: 5.579, Alpha: 1.997, Beta: 6.595, W: [1.028900384902954, 4.969677925109863, 2.998757839202881]
Epoch 96 of 200


 48%|████▊     | 96/200 [07:41<06:46,  3.91s/it]

Epoch 96, NLL: 5.579, Alpha: 1.997, Beta: 6.588, W: [1.0273592472076416, 4.971220016479492, 2.9987568855285645]
Epoch 97 of 200


 48%|████▊     | 97/200 [07:46<07:08,  4.16s/it]

Epoch 97, NLL: 5.579, Alpha: 1.998, Beta: 6.581, W: [1.0258352756500244, 4.972744941711426, 2.998755931854248]
Epoch 98 of 200


 49%|████▉     | 98/200 [07:49<06:36,  3.89s/it]

Epoch 98, NLL: 5.578, Alpha: 1.998, Beta: 6.575, W: [1.0243282318115234, 4.974253177642822, 2.9987549781799316]
Epoch 99 of 200


 50%|████▉     | 99/200 [07:52<06:17,  3.74s/it]

Epoch 99, NLL: 5.578, Alpha: 1.999, Beta: 6.568, W: [1.02283775806427, 4.975744724273682, 2.9987540245056152]
Epoch 100 of 200


 50%|█████     | 100/200 [07:56<06:08,  3.68s/it]

Epoch 100, NLL: 5.578, Alpha: 1.999, Beta: 6.561, W: [1.021363615989685, 4.977219581604004, 2.998753070831299]
Epoch 101 of 200


 50%|█████     | 101/200 [07:59<05:47,  3.51s/it]

Epoch 101, NLL: 5.578, Alpha: 1.999, Beta: 6.555, W: [1.0199055671691895, 4.9786787033081055, 2.9987521171569824]
Epoch 102 of 200


 51%|█████     | 102/200 [08:02<05:40,  3.47s/it]

Epoch 102, NLL: 5.578, Alpha: 2.0, Beta: 6.549, W: [1.018463373184204, 4.980121612548828, 2.998751163482666]
Epoch 103 of 200


 52%|█████▏    | 103/200 [08:06<05:54,  3.65s/it]

Epoch 103, NLL: 5.578, Alpha: 2.0, Beta: 6.543, W: [1.01703679561615, 4.981549263000488, 2.9987502098083496]
Epoch 104 of 200


 52%|█████▏    | 104/200 [08:10<05:48,  3.63s/it]

Epoch 104, NLL: 5.577, Alpha: 2.0, Beta: 6.537, W: [1.0156255960464478, 4.982961654663086, 2.998749256134033]
Epoch 105 of 200


 52%|█████▎    | 105/200 [08:14<05:43,  3.61s/it]

Epoch 105, NLL: 5.577, Alpha: 2.001, Beta: 6.531, W: [1.0142295360565186, 4.984358787536621, 2.998748302459717]
Epoch 106 of 200


 53%|█████▎    | 106/200 [08:17<05:29,  3.50s/it]

Epoch 106, NLL: 5.577, Alpha: 2.001, Beta: 6.525, W: [1.0128483772277832, 4.985741138458252, 2.9987473487854004]
Epoch 107 of 200


 54%|█████▎    | 107/200 [08:21<05:48,  3.75s/it]

Epoch 107, NLL: 5.577, Alpha: 2.002, Beta: 6.519, W: [1.0114818811416626, 4.9871087074279785, 2.998746395111084]
Epoch 108 of 200


 54%|█████▍    | 108/200 [08:24<05:32,  3.61s/it]

Epoch 108, NLL: 5.577, Alpha: 2.002, Beta: 6.513, W: [1.0101299285888672, 4.988461494445801, 2.9987454414367676]
Epoch 109 of 200


 55%|█████▍    | 109/200 [08:28<05:26,  3.59s/it]

Epoch 109, NLL: 5.577, Alpha: 2.002, Beta: 6.508, W: [1.0087921619415283, 4.989800453186035, 2.998744487762451]
Epoch 110 of 200


 55%|█████▌    | 110/200 [08:32<05:29,  3.66s/it]

Epoch 110, NLL: 5.577, Alpha: 2.003, Beta: 6.502, W: [1.0074684619903564, 4.991125106811523, 2.9987435340881348]
Epoch 111 of 200


 56%|█████▌    | 111/200 [08:36<05:39,  3.81s/it]

Epoch 111, NLL: 5.576, Alpha: 2.003, Beta: 6.497, W: [1.0061585903167725, 4.992435932159424, 2.9987425804138184]
Epoch 112 of 200


 56%|█████▌    | 112/200 [08:40<05:39,  3.86s/it]

Epoch 112, NLL: 5.576, Alpha: 2.004, Beta: 6.492, W: [1.0048623085021973, 4.9937334060668945, 2.998741626739502]
Epoch 113 of 200


 56%|█████▋    | 113/200 [08:45<06:06,  4.22s/it]

Epoch 113, NLL: 5.576, Alpha: 2.004, Beta: 6.486, W: [1.0035794973373413, 4.9950175285339355, 2.9987406730651855]
Epoch 114 of 200


 57%|█████▋    | 114/200 [08:49<05:53,  4.11s/it]

Epoch 114, NLL: 5.576, Alpha: 2.004, Beta: 6.481, W: [1.0023099184036255, 4.996288299560547, 2.998739719390869]
Epoch 115 of 200


 57%|█████▊    | 115/200 [08:52<05:36,  3.96s/it]

Epoch 115, NLL: 5.576, Alpha: 2.005, Beta: 6.476, W: [1.0010533332824707, 4.9975457191467285, 2.9987387657165527]
Epoch 116 of 200


 58%|█████▊    | 116/200 [08:56<05:15,  3.76s/it]

Epoch 116, NLL: 5.576, Alpha: 2.005, Beta: 6.471, W: [0.9998095631599426, 4.998790740966797, 2.9987378120422363]
Epoch 117 of 200


 58%|█████▊    | 117/200 [09:00<05:17,  3.82s/it]

Epoch 117, NLL: 5.576, Alpha: 2.006, Beta: 6.466, W: [0.9985784292221069, 5.000022888183594, 2.99873685836792]
Epoch 118 of 200


 59%|█████▉    | 118/200 [09:03<04:58,  3.64s/it]

Epoch 118, NLL: 5.575, Alpha: 2.006, Beta: 6.462, W: [0.9973598122596741, 5.001242637634277, 2.9987359046936035]
Epoch 119 of 200


 60%|█████▉    | 119/200 [09:06<04:44,  3.52s/it]

Epoch 119, NLL: 5.575, Alpha: 2.007, Beta: 6.457, W: [0.9961534738540649, 5.002449989318848, 2.998734951019287]
Epoch 120 of 200


 60%|██████    | 120/200 [09:10<04:58,  3.73s/it]

Epoch 120, NLL: 5.575, Alpha: 2.007, Beta: 6.452, W: [0.9949592351913452, 5.003645420074463, 2.9987339973449707]
Epoch 121 of 200


 60%|██████    | 121/200 [09:13<04:40,  3.55s/it]

Epoch 121, NLL: 5.575, Alpha: 2.007, Beta: 6.448, W: [0.9937769174575806, 5.004828929901123, 2.9987330436706543]
Epoch 122 of 200


 61%|██████    | 122/200 [09:17<04:29,  3.46s/it]

Epoch 122, NLL: 5.575, Alpha: 2.008, Beta: 6.443, W: [0.9926064014434814, 5.006000518798828, 2.998732089996338]
Epoch 123 of 200


 62%|██████▏   | 123/200 [09:20<04:27,  3.48s/it]

Epoch 123, NLL: 5.575, Alpha: 2.008, Beta: 6.439, W: [0.9914475083351135, 5.007160186767578, 2.9987311363220215]
Epoch 124 of 200


 62%|██████▏   | 124/200 [09:24<04:21,  3.44s/it]

Epoch 124, NLL: 5.575, Alpha: 2.009, Beta: 6.434, W: [0.9902999997138977, 5.008308410644531, 2.998730182647705]
Epoch 125 of 200


 62%|██████▎   | 125/200 [09:27<04:22,  3.50s/it]

Epoch 125, NLL: 5.574, Alpha: 2.009, Beta: 6.43, W: [0.9891638159751892, 5.009445667266846, 2.9987292289733887]
Epoch 126 of 200


 63%|██████▎   | 126/200 [09:31<04:25,  3.58s/it]

Epoch 126, NLL: 5.574, Alpha: 2.01, Beta: 6.426, W: [0.9880387187004089, 5.010571479797363, 2.9987282752990723]
Epoch 127 of 200


 64%|██████▎   | 127/200 [09:34<04:13,  3.48s/it]

Epoch 127, NLL: 5.574, Alpha: 2.01, Beta: 6.422, W: [0.9869245886802673, 5.011686325073242, 2.998727321624756]
Epoch 128 of 200


 64%|██████▍   | 128/200 [09:38<04:10,  3.48s/it]

Epoch 128, NLL: 5.574, Alpha: 2.011, Beta: 6.418, W: [0.9858213067054749, 5.012790679931641, 2.9987263679504395]
Epoch 129 of 200


 64%|██████▍   | 129/200 [09:41<04:04,  3.45s/it]

Epoch 129, NLL: 5.574, Alpha: 2.011, Beta: 6.414, W: [0.9847286939620972, 5.0138840675354, 2.998725414276123]
Epoch 130 of 200


 65%|██████▌   | 130/200 [09:44<04:00,  3.44s/it]

Epoch 130, NLL: 5.574, Alpha: 2.011, Beta: 6.41, W: [0.9836465716362, 5.01496696472168, 2.9987244606018066]
Epoch 131 of 200


 66%|██████▌   | 131/200 [09:49<04:15,  3.71s/it]

Epoch 131, NLL: 5.574, Alpha: 2.012, Beta: 6.406, W: [0.9825748205184937, 5.016039848327637, 2.9987235069274902]
Epoch 132 of 200


 66%|██████▌   | 132/200 [09:53<04:21,  3.84s/it]

Epoch 132, NLL: 5.574, Alpha: 2.012, Beta: 6.402, W: [0.9815133213996887, 5.017102241516113, 2.998722553253174]
Epoch 133 of 200


 66%|██████▋   | 133/200 [09:56<04:10,  3.74s/it]

Epoch 133, NLL: 5.573, Alpha: 2.013, Beta: 6.399, W: [0.9804618954658508, 5.018154621124268, 2.9987215995788574]
Epoch 134 of 200


 67%|██████▋   | 134/200 [10:00<04:07,  3.75s/it]

Epoch 134, NLL: 5.573, Alpha: 2.013, Beta: 6.395, W: [0.9794204235076904, 5.0191969871521, 2.998720645904541]
Epoch 135 of 200


 68%|██████▊   | 135/200 [10:04<03:59,  3.69s/it]

Epoch 135, NLL: 5.573, Alpha: 2.014, Beta: 6.391, W: [0.978388786315918, 5.020229339599609, 2.9987196922302246]
Epoch 136 of 200


 68%|██████▊   | 136/200 [10:08<03:57,  3.71s/it]

Epoch 136, NLL: 5.573, Alpha: 2.014, Beta: 6.388, W: [0.9773668050765991, 5.021252155303955, 2.998718738555908]
Epoch 137 of 200


 68%|██████▊   | 137/200 [10:11<03:50,  3.65s/it]

Epoch 137, NLL: 5.573, Alpha: 2.015, Beta: 6.384, W: [0.9763544201850891, 5.022265434265137, 2.998717784881592]
Epoch 138 of 200


 69%|██████▉   | 138/200 [10:15<03:59,  3.87s/it]

Epoch 138, NLL: 5.573, Alpha: 2.015, Beta: 6.381, W: [0.9753514528274536, 5.023269176483154, 2.9987170696258545]
Epoch 139 of 200


 70%|██████▉   | 139/200 [10:19<03:49,  3.76s/it]

Epoch 139, NLL: 5.573, Alpha: 2.016, Beta: 6.378, W: [0.9743577837944031, 5.024263858795166, 2.998716115951538]
Epoch 140 of 200


 70%|███████   | 140/200 [10:22<03:40,  3.67s/it]

Epoch 140, NLL: 5.573, Alpha: 2.016, Beta: 6.374, W: [0.973373293876648, 5.025249004364014, 2.998715400695801]
Epoch 141 of 200


 70%|███████   | 141/200 [10:26<03:34,  3.64s/it]

Epoch 141, NLL: 5.572, Alpha: 2.017, Beta: 6.371, W: [0.9723978638648987, 5.0262250900268555, 2.9987146854400635]
Epoch 142 of 200


 71%|███████   | 142/200 [10:29<03:25,  3.54s/it]

Epoch 142, NLL: 5.572, Alpha: 2.017, Beta: 6.368, W: [0.9714313745498657, 5.02719259262085, 2.998713970184326]
Epoch 143 of 200


 72%|███████▏  | 143/200 [10:33<03:28,  3.66s/it]

Epoch 143, NLL: 5.572, Alpha: 2.017, Beta: 6.365, W: [0.9704737067222595, 5.028151035308838, 2.998713254928589]
Epoch 144 of 200


 72%|███████▏  | 144/200 [10:37<03:33,  3.81s/it]

Epoch 144, NLL: 5.572, Alpha: 2.018, Beta: 6.362, W: [0.9695247411727905, 5.0291008949279785, 2.9987125396728516]
Epoch 145 of 200


 72%|███████▎  | 145/200 [10:41<03:26,  3.76s/it]

Epoch 145, NLL: 5.572, Alpha: 2.018, Beta: 6.359, W: [0.9685843586921692, 5.0300421714782715, 2.9987118244171143]
Epoch 146 of 200


 73%|███████▎  | 146/200 [10:45<03:23,  3.77s/it]

Epoch 146, NLL: 5.572, Alpha: 2.019, Beta: 6.356, W: [0.967652440071106, 5.030974864959717, 2.998711109161377]
Epoch 147 of 200


 74%|███████▎  | 147/200 [10:49<03:29,  3.95s/it]

Epoch 147, NLL: 5.572, Alpha: 2.019, Beta: 6.353, W: [0.966728925704956, 5.0318989753723145, 2.9987103939056396]
Epoch 148 of 200


 74%|███████▍  | 148/200 [10:53<03:29,  4.03s/it]

Epoch 148, NLL: 5.572, Alpha: 2.02, Beta: 6.35, W: [0.9658136963844299, 5.032814979553223, 2.9987096786499023]
Epoch 149 of 200


 74%|███████▍  | 149/200 [10:57<03:19,  3.90s/it]

Epoch 149, NLL: 5.572, Alpha: 2.02, Beta: 6.347, W: [0.9649065732955933, 5.033722877502441, 2.998708963394165]
Epoch 150 of 200


 75%|███████▌  | 150/200 [11:01<03:16,  3.92s/it]

Epoch 150, NLL: 5.571, Alpha: 2.021, Beta: 6.344, W: [0.9640074968338013, 5.034622669219971, 2.9987082481384277]
Epoch 151 of 200


 76%|███████▌  | 151/200 [11:05<03:13,  3.94s/it]

Epoch 151, NLL: 5.571, Alpha: 2.021, Beta: 6.342, W: [0.9631164073944092, 5.0355143547058105, 2.9987075328826904]
Epoch 152 of 200


 76%|███████▌  | 152/200 [11:09<03:09,  3.94s/it]

Epoch 152, NLL: 5.571, Alpha: 2.022, Beta: 6.339, W: [0.9622331261634827, 5.036398410797119, 2.998706817626953]
Epoch 153 of 200


 76%|███████▋  | 153/200 [11:14<03:21,  4.28s/it]

Epoch 153, NLL: 5.571, Alpha: 2.022, Beta: 6.336, W: [0.961357593536377, 5.0372748374938965, 2.998706102371216]
Epoch 154 of 200


 77%|███████▋  | 154/200 [11:17<03:06,  4.05s/it]

Epoch 154, NLL: 5.571, Alpha: 2.023, Beta: 6.334, W: [0.9604897499084473, 5.038143634796143, 2.9987053871154785]
Epoch 155 of 200


 78%|███████▊  | 155/200 [11:21<02:59,  3.99s/it]

Epoch 155, NLL: 5.571, Alpha: 2.023, Beta: 6.331, W: [0.9596294164657593, 5.039004802703857, 2.998704671859741]
Epoch 156 of 200


 78%|███████▊  | 156/200 [11:25<02:55,  3.99s/it]

Epoch 156, NLL: 5.571, Alpha: 2.024, Beta: 6.329, W: [0.9587765336036682, 5.039858341217041, 2.998703956604004]
Epoch 157 of 200


 78%|███████▊  | 157/200 [11:29<02:46,  3.86s/it]

Epoch 157, NLL: 5.571, Alpha: 2.024, Beta: 6.326, W: [0.9579309821128845, 5.040704727172852, 2.9987032413482666]
Epoch 158 of 200


 79%|███████▉  | 158/200 [11:34<02:51,  4.09s/it]

Epoch 158, NLL: 5.571, Alpha: 2.025, Beta: 6.324, W: [0.9570927023887634, 5.041543960571289, 2.9987025260925293]
Epoch 159 of 200


 80%|███████▉  | 159/200 [11:37<02:46,  4.05s/it]

Epoch 159, NLL: 5.57, Alpha: 2.025, Beta: 6.322, W: [0.9562616348266602, 5.042375564575195, 2.998701810836792]
Epoch 160 of 200


 80%|████████  | 160/200 [11:41<02:33,  3.83s/it]

Epoch 160, NLL: 5.57, Alpha: 2.026, Beta: 6.319, W: [0.9554376602172852, 5.043200492858887, 2.9987010955810547]
Epoch 161 of 200


 80%|████████  | 161/200 [11:45<02:36,  4.01s/it]

Epoch 161, NLL: 5.57, Alpha: 2.026, Beta: 6.317, W: [0.9546206593513489, 5.044018268585205, 2.9987003803253174]
Epoch 162 of 200


 81%|████████  | 162/200 [11:48<02:23,  3.78s/it]

Epoch 162, NLL: 5.57, Alpha: 2.027, Beta: 6.315, W: [0.9538105726242065, 5.04482889175415, 2.99869966506958]
Epoch 163 of 200


 82%|████████▏ | 163/200 [11:52<02:18,  3.75s/it]

Epoch 163, NLL: 5.57, Alpha: 2.027, Beta: 6.313, W: [0.9530072808265686, 5.045632839202881, 2.9986989498138428]
Epoch 164 of 200


 82%|████████▏ | 164/200 [11:56<02:17,  3.83s/it]

Epoch 164, NLL: 5.57, Alpha: 2.028, Beta: 6.311, W: [0.9522107243537903, 5.0464301109313965, 2.9986982345581055]
Epoch 165 of 200


 82%|████████▎ | 165/200 [11:59<02:06,  3.62s/it]

Epoch 165, NLL: 5.57, Alpha: 2.028, Beta: 6.309, W: [0.9514208436012268, 5.047220706939697, 2.998697519302368]
Epoch 166 of 200


 83%|████████▎ | 166/200 [12:03<02:00,  3.54s/it]

Epoch 166, NLL: 5.57, Alpha: 2.029, Beta: 6.306, W: [0.9506375193595886, 5.048004627227783, 2.998696804046631]
Epoch 167 of 200


 84%|████████▎ | 167/200 [12:07<02:03,  3.74s/it]

Epoch 167, NLL: 5.57, Alpha: 2.03, Beta: 6.304, W: [0.949860692024231, 5.0487823486328125, 2.9986960887908936]
Epoch 168 of 200


 84%|████████▍ | 168/200 [12:10<01:53,  3.53s/it]

Epoch 168, NLL: 5.569, Alpha: 2.03, Beta: 6.302, W: [0.949090301990509, 5.049553394317627, 2.9986953735351562]
Epoch 169 of 200


 84%|████████▍ | 169/200 [12:13<01:47,  3.47s/it]

Epoch 169, NLL: 5.569, Alpha: 2.031, Beta: 6.301, W: [0.9483262300491333, 5.050318241119385, 2.998694658279419]
Epoch 170 of 200


 85%|████████▌ | 170/200 [12:17<01:46,  3.54s/it]

Epoch 170, NLL: 5.569, Alpha: 2.031, Beta: 6.299, W: [0.947568416595459, 5.051076889038086, 2.9986939430236816]
Epoch 171 of 200


 86%|████████▌ | 171/200 [12:20<01:40,  3.45s/it]

Epoch 171, NLL: 5.569, Alpha: 2.032, Beta: 6.297, W: [0.9468168020248413, 5.0518293380737305, 2.9986932277679443]
Epoch 172 of 200


 86%|████████▌ | 172/200 [12:24<01:42,  3.65s/it]

Epoch 172, NLL: 5.569, Alpha: 2.032, Beta: 6.295, W: [0.9460712671279907, 5.052575588226318, 2.998692512512207]
Epoch 173 of 200


 86%|████████▋ | 173/200 [12:28<01:36,  3.56s/it]

Epoch 173, NLL: 5.569, Alpha: 2.033, Beta: 6.293, W: [0.9453317523002625, 5.05331563949585, 2.998692035675049]
Epoch 174 of 200


 87%|████████▋ | 174/200 [12:31<01:29,  3.43s/it]

Epoch 174, NLL: 5.569, Alpha: 2.033, Beta: 6.291, W: [0.9445981979370117, 5.054049968719482, 2.9986915588378906]
Epoch 175 of 200


 88%|████████▊ | 175/200 [12:34<01:26,  3.45s/it]

Epoch 175, NLL: 5.569, Alpha: 2.034, Beta: 6.29, W: [0.9438706040382385, 5.054778099060059, 2.9986910820007324]
Epoch 176 of 200


 88%|████████▊ | 176/200 [12:37<01:20,  3.37s/it]

Epoch 176, NLL: 5.569, Alpha: 2.034, Beta: 6.288, W: [0.9431487917900085, 5.055500507354736, 2.998690605163574]
Epoch 177 of 200


 88%|████████▊ | 177/200 [12:41<01:19,  3.45s/it]

Epoch 177, NLL: 5.569, Alpha: 2.035, Beta: 6.286, W: [0.9424327611923218, 5.056217193603516, 2.998689889907837]
Epoch 178 of 200


 89%|████████▉ | 178/200 [12:45<01:21,  3.69s/it]

Epoch 178, NLL: 5.568, Alpha: 2.035, Beta: 6.285, W: [0.9417223930358887, 5.0569281578063965, 2.9986894130706787]
Epoch 179 of 200


 90%|████████▉ | 179/200 [12:51<01:28,  4.20s/it]

Epoch 179, NLL: 5.568, Alpha: 2.036, Beta: 6.283, W: [0.9410176277160645, 5.057633399963379, 2.9986889362335205]
Epoch 180 of 200


 90%|█████████ | 180/200 [12:55<01:25,  4.26s/it]

Epoch 180, NLL: 5.568, Alpha: 2.036, Beta: 6.282, W: [0.9403184652328491, 5.058333396911621, 2.998688220977783]
Epoch 181 of 200


 90%|█████████ | 181/200 [13:01<01:29,  4.71s/it]

Epoch 181, NLL: 5.568, Alpha: 2.037, Beta: 6.28, W: [0.9396247863769531, 5.059027671813965, 2.998687744140625]
Epoch 182 of 200


 91%|█████████ | 182/200 [13:08<01:37,  5.41s/it]

Epoch 182, NLL: 5.568, Alpha: 2.038, Beta: 6.279, W: [0.9389364719390869, 5.059716701507568, 2.998687267303467]
Epoch 183 of 200


 92%|█████████▏| 183/200 [13:13<01:32,  5.41s/it]

Epoch 183, NLL: 5.568, Alpha: 2.038, Beta: 6.277, W: [0.9382535219192505, 5.060400009155273, 2.9986867904663086]
Epoch 184 of 200


 92%|█████████▏| 184/200 [13:17<01:19,  4.99s/it]

Epoch 184, NLL: 5.568, Alpha: 2.039, Beta: 6.276, W: [0.9375758767127991, 5.061078071594238, 2.9986863136291504]
Epoch 185 of 200


 92%|█████████▎| 185/200 [13:22<01:12,  4.82s/it]

Epoch 185, NLL: 5.568, Alpha: 2.039, Beta: 6.275, W: [0.9369034767150879, 5.061750888824463, 2.998685598373413]
Epoch 186 of 200


 93%|█████████▎| 186/200 [13:25<01:01,  4.38s/it]

Epoch 186, NLL: 5.568, Alpha: 2.04, Beta: 6.273, W: [0.9362362623214722, 5.062418460845947, 2.998685121536255]
Epoch 187 of 200


 94%|█████████▎| 187/200 [13:29<00:56,  4.33s/it]

Epoch 187, NLL: 5.568, Alpha: 2.04, Beta: 6.272, W: [0.9355741739273071, 5.06308126449585, 2.9986846446990967]
Epoch 188 of 200


 94%|█████████▍| 188/200 [13:34<00:53,  4.42s/it]

Epoch 188, NLL: 5.568, Alpha: 2.041, Beta: 6.271, W: [0.934917151927948, 5.063738822937012, 2.9986841678619385]
Epoch 189 of 200


 94%|█████████▍| 189/200 [13:39<00:49,  4.50s/it]

Epoch 189, NLL: 5.567, Alpha: 2.041, Beta: 6.269, W: [0.9342650771141052, 5.064391613006592, 2.9986836910247803]
Epoch 190 of 200


 95%|█████████▌| 190/200 [13:42<00:41,  4.16s/it]

Epoch 190, NLL: 5.567, Alpha: 2.042, Beta: 6.268, W: [0.9336179494857788, 5.065039157867432, 2.998683214187622]
Epoch 191 of 200


 96%|█████████▌| 191/200 [13:46<00:35,  3.99s/it]

Epoch 191, NLL: 5.567, Alpha: 2.043, Beta: 6.267, W: [0.932975709438324, 5.0656819343566895, 2.9986824989318848]
Epoch 192 of 200


 96%|█████████▌| 192/200 [13:50<00:32,  4.09s/it]

Epoch 192, NLL: 5.567, Alpha: 2.043, Beta: 6.266, W: [0.932338297367096, 5.066319942474365, 2.9986820220947266]
Epoch 193 of 200


 96%|█████████▋| 193/200 [13:54<00:28,  4.04s/it]

Epoch 193, NLL: 5.567, Alpha: 2.044, Beta: 6.265, W: [0.93170565366745, 5.066953182220459, 2.9986815452575684]
Epoch 194 of 200


 97%|█████████▋| 194/200 [14:01<00:30,  5.02s/it]

Epoch 194, NLL: 5.567, Alpha: 2.044, Beta: 6.264, W: [0.9310777187347412, 5.067581653594971, 2.99868106842041]
Epoch 195 of 200


 98%|█████████▊| 195/200 [14:09<00:29,  5.83s/it]

Epoch 195, NLL: 5.567, Alpha: 2.045, Beta: 6.263, W: [0.930454432964325, 5.0682053565979, 2.998680591583252]
Epoch 196 of 200


 98%|█████████▊| 196/200 [14:16<00:25,  6.34s/it]

Epoch 196, NLL: 5.567, Alpha: 2.045, Beta: 6.261, W: [0.9298357963562012, 5.068824768066406, 2.9986801147460938]
Epoch 197 of 200


 98%|█████████▊| 197/200 [14:20<00:16,  5.64s/it]

Epoch 197, NLL: 5.567, Alpha: 2.046, Beta: 6.26, W: [0.9292216897010803, 5.06943941116333, 2.9986796379089355]
Epoch 198 of 200


 99%|█████████▉| 198/200 [14:25<00:10,  5.30s/it]

Epoch 198, NLL: 5.567, Alpha: 2.046, Beta: 6.259, W: [0.9286121129989624, 5.07004976272583, 2.9986791610717773]
Epoch 199 of 200


100%|█████████▉| 199/200 [14:30<00:05,  5.30s/it]

Epoch 199, NLL: 5.567, Alpha: 2.047, Beta: 6.259, W: [0.9280070066452026, 5.070655345916748, 2.998678684234619]
Epoch 200 of 200


100%|██████████| 200/200 [14:39<00:00,  4.40s/it]

Epoch 200, NLL: 5.566, Alpha: 2.048, Beta: 6.258, W: [0.9274062514305115, 5.071256637573242, 2.998678207397461]
Training Took 879.89 Seconds.
Trained Params: alpha: 2.047592904495089, beta: 6.257558317257251, W: [0.9274062514305115, 5.071256637573242, 2.998678207397461]





(2.047592904495089,
 6.257558317257251,
 tensor([0.9274, 5.0713, 2.9987], requires_grad=True))
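The printed log is the only record of the optimization trajectory. The sketch below is not part of the original notebook. It parses the per-epoch summary lines back into numbers so you can check how much NLL, Alpha and Beta are still moving near the end of training. The regex assumes exactly the `Epoch N, NLL: ..., Alpha: ..., Beta: ...` format printed above, and `parse_log` / `recent_drift` are hypothetical helper names.

```python
import re

# Matches the per-epoch summary lines printed by the training loop, e.g.
# "Epoch 200, NLL: 5.566, Alpha: 2.048, Beta: 6.258, W: [...]".
# The "Epoch N of 200" header lines have no comma after N, so they are skipped.
EPOCH_RE = re.compile(
    r"Epoch (\d+), NLL: ([\d.]+), Alpha: ([\d.]+), Beta: ([\d.]+)"
)


def parse_log(text):
    """Return (epoch, nll, alpha, beta) tuples for every summary line in `text`."""
    return [
        (int(e), float(nll), float(a), float(b))
        for e, nll, a, b in EPOCH_RE.findall(text)
    ]


def recent_drift(records, window=10):
    """Average per-epoch change of (NLL, alpha, beta) over the last `window` epochs."""
    tail = records[-(window + 1):]
    if len(tail) < 2:
        raise ValueError("need at least two epochs to measure drift")
    first, last = tail[0], tail[-1]
    n_epochs = last[0] - first[0]
    return tuple((last[i] - first[i]) / n_epochs for i in (1, 2, 3))
```

For example, `recent_drift(parse_log(log_text))` returns the average per-epoch change in NLL, alpha and beta over the last ten epochs. Here `log_text` is a hypothetical string holding the captured output of the training cell. Steady nonzero drift in alpha or beta suggests that more epochs, or a different learning rate, might still move the estimates.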

In [35]:
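# Mean of Beta(beta, alpha), i.e. beta / (alpha + beta), using the fitted values; compare with the true rho = 0.8
6.257558317257251/(6.257558317257251 + 2.047592904495089)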

0.7534550726623545
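The cell above treats `beta / (alpha + beta)` as the estimated polarization. The minimal sketch below compares that number with the ground truth. It reuses the Beta-parameter formulas from `generate` with `rho = 0.8`. It also assumes, without confirmation from the code shown, that the symmetric mixture uses components Beta(a, b) and Beta(b, a). Under that assumption a fitted (alpha, beta) is only identified up to swapping, so the estimate takes the larger parameter in the numerator. The helper names are hypothetical.

```python
def beta_params_from_rho(rho):
    """True (a, b) of one mixture component, using the formulas in `generate`."""
    sigma = 0.175 * rho ** 2 - 0.3625 * rho + 0.1875
    scale = rho * (1 - rho) / sigma - 1   # equals a + b
    return rho * scale, (1 - rho) * scale


def estimated_polarization(alpha, beta):
    """Mean of the more polarized component; assumes (alpha, beta) and (beta, alpha) are equivalent."""
    return max(alpha, beta) / (alpha + beta)


def beta_var(a, b):
    """Variance of Beta(a, b), same expression as `var` in `generate`."""
    return a * b / ((a + b) ** 2 * (a + b + 1))


rho_true = 0.8
a_true, b_true = beta_params_from_rho(rho_true)
alpha_hat, beta_hat = 2.047592904495089, 6.257558317257251  # trained values printed above

rho_hat = estimated_polarization(alpha_hat, beta_hat)
print(f"true a, b:    {a_true:.3f}, {b_true:.3f}")
print(f"rho true/hat: {rho_true:.3f} / {rho_hat:.4f}")
print(f"var true/hat: {beta_var(a_true, b_true):.4f} / {beta_var(alpha_hat, beta_hat):.4f}")
```

With these formulas a + b = rho(1 - rho)/sigma - 1. As a result, each true component has mean rho and variance exactly sigma. Comparing `beta_var` for the fitted and true parameters therefore shows whether the fit also recovers the spread, not just the mean.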