In [4]:
from abc import ABC, abstractmethod

import os
from tqdm import tqdm
import math 
import time

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns 

from scipy import stats
from scipy.special import digamma

import torch
from torch.distributions import Beta, kl_divergence
from torch.nn.functional import log_softmax
from torch.optim import Adam, SGD
from torch.utils.data import DataLoader, TensorDataset

from bayes_opt import BayesianOptimization

In [5]:
SEED = 1
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fd6901d43f0>

In [6]:
### these params control the generation scheme (2/10 generation)
rho        = 0.8    # polarization
pop_size   = 50000  # num individuals
epsilon    = 0.05   # expected prop of speech consisting of neutral words
pi         = 0.5    # pi == 0.5 => beta mixture symmetrical (choose beta1 with prob pi = 0.5)
speech_len = 50     # words per speech
lex_size   = 3      # lexicon size 

In [7]:
def generate(rho=rho, N=pop_size, epsilon=epsilon, pi=pi, speech_len=speech_len, lex_size=lex_size, verbose='low'):
    """
    Uses 2/10 generation scheme to generate N samples.

    Returns:
        (X, y), (a, b, rho, epsilon, u, lex_size)
        X.size() == [N, S] is a vector of sequences 
        y.size() == [N] is a vector of political parties
        u.shape  == (N,) is a vector of individual stances
        rho is true polarization
        epsilon is expected prop of neutral words
        a, b are true alpha/beta for beta mixture model
    """
    start = time.time()
    if verbose in ['mid', 'high']:
        print(f'Beginning Data Generation...')
        print(f'=' * 20)
        
    ### get beta mixture model params 
    sigma = 0.175 * (rho ** 2) - 0.3625 * rho + 0.1875
    a     = rho * ((rho * (1 - rho)) / sigma - 1)
    b     = (1 - rho) * ((rho * (1 - rho)) / sigma - 1)

    if verbose in ['mid', 'high']:
        print(f'Lex Size: {lex_size}, Rho: {rho}, N: {pop_size}')
        print(f'epsilon: {epsilon}, pi: {pi}, S: {speech_len}\n')
        print(f"True Alpha: {a}")
        print(f"True Beta: {b}\n")

    mean = a / (a + b)
    var  = a * b / ((a + b)**2 * (a + b + 1))

    if abs(mean - rho) > 10e-15:
        print(f'Mean: {mean}')
        print(f"Rho: {rho}")
        raise AssertionError(f"Mean of BMM params should be rho")

    if abs(var - sigma) > 10e-15:
        print(f'Var: {var}')
        print(f'Sigma: {sigma}')
        raise AssertionError(f"Var of BMM params should be sigma")

    ### u ~ pi Beta(a, b) + (1 - pi) Beta(b, a)
    weights = [pi, 1-pi]
    mixture_samples = np.random.choice([0, 1], size=N, p=weights)

    u = 2 * np.where(mixture_samples == 1, stats.beta.rvs(a, b, size=N), stats.beta.rvs(b, a, size=N)) - 1

    if u.shape != (N,):
        raise AssertionError(f"u.shape should be (N,)")

    if verbose == 'high':
        print(f'mixture samples: {mixture_samples[:3]}')
        print(f'u samples: {u[:3]}')

    ### y = 1(u >= 0)
    y = (u >= 0).astype(int)

    if verbose == 'high':
        print(f'y samples: {y[:3]}\n')

    ### phi is a prob matrix that is a function of u, epsilon
    left_prob    = (1 - (u+1)/2) * (1 - epsilon) / lex_size
    right_prob   = (u+1)/2 * (1 - epsilon) / lex_size
    neutral_prob = np.repeat(epsilon, N) / lex_size
    phi          = np.array([left_prob] * lex_size + [right_prob] * lex_size + [neutral_prob] * lex_size).T
    
    if verbose == 'high':
        for i in range(3):
            print(f'P(L) = {phi[i,0]}, P(R): {phi[i, lex_size]}, P(N): {phi[i,-1]}')

    if phi.shape != (N, 3 * lex_size):
        raise AssertionError(f'phi.shape should be (N, V) == (N, 3L)')

    if abs(sum(phi[0]) - 1) > 10e-5:
        raise AssertionError(f'rows of phi should sum to 1')
    
    X = np.array([stats.multinomial.rvs(n=speech_len, p=phi[i, :]) for i in range(N)])

    if verbose == 'high':
        print(f'X samples (counts):\n {X[:3]}\n')
    
    if X.shape != (N, 3 * lex_size):
        raise AssertionError(f'X.shape should be (N, V) == (N, 3L)')

    if X[:5].sum() != speech_len * 5:
        raise AssertionError(f'rows of phi should sum to 1')

    expanded_rows = []
    for row in X:
        expanded_row = np.repeat(np.arange(len(row)), row)
        np.random.shuffle(expanded_row)
        expanded_rows.append(expanded_row)

    X = np.array(expanded_rows, dtype=np.int64)

    if X.shape != (N, speech_len):
        raise AssertionError(f'X.shape should be (N, S)')
    
    if verbose == 'high':
        print(f'X samples (sequences):\n {X[:3]}\n')

    X = torch.from_numpy(X).to(torch.float32)
    y = torch.from_numpy(y).to(torch.float32)

    known   = (X, y)
    unknown = (a, b, rho, epsilon, u, lex_size)

    if verbose in ['mid', 'high']:
        print('=' * 20)
        print(f'Generation Time: {round(time.time() - start, 3)} seconds for {N} samples.')

    return known, unknown

known, unknown = generate(verbose='high')
X, y = known
a, b, rho, epsilon, u, lex_size = unknown

Beginning Data Generation...
Lex Size: 3, Rho: 0.8, N: 50000
epsilon: 0.05, pi: 0.5, S: 50

True Alpha: 12.673684210526261
True Beta: 3.1684210526315644

mixture samples: [0 1 0]
u samples: [-0.71319006  0.57688516 -0.46523368]
y samples: [0 1 0]

P(L) = 0.2712550923688634, P(R): 0.04541157429780326, P(N): 0.016666666666666666
P(L) = 0.06699318292320035, P(R): 0.24967348374346632, P(N): 0.016666666666666666
P(L) = 0.23199533330535926, P(R): 0.08467133336130739, P(N): 0.016666666666666666
X samples (counts):
 [[16 13 14  2  1  2  0  0  2]
 [ 2  3  3 12 13 15  1  0  1]
 [15 15  8  3  4  5  0  0  0]]

X samples (sequences):
 [[0 2 2 2 2 0 0 2 0 1 0 4 0 1 1 3 2 1 2 0 0 5 0 1 0 0 1 2 1 1 0 8 2 2 0 1
  1 3 1 1 2 0 8 2 1 2 2 5 0 0]
 [3 4 5 4 5 5 3 3 1 6 5 0 3 8 4 4 3 5 0 3 5 5 3 4 5 4 1 3 5 5 5 3 2 2 2 5
  4 5 4 3 4 1 4 3 5 4 3 4 5 4]
 [1 5 1 0 1 1 1 2 1 0 5 4 4 0 0 1 0 0 3 0 1 0 0 0 5 1 0 5 3 5 0 2 1 1 2 1
  0 2 3 0 1 4 2 0 2 2 1 1 4 2]]

Generation Time: 2.07 seconds for 50000 samples.


In [5]:
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.metrics import accuracy_score


In [7]:
X_np = X.numpy()
y_np = y.numpy()

# Initialize the classifier
gb_clf = GradientBoostingClassifier()

# Setup cross-validation
kf = StratifiedKFold(n_splits=5)

# Get predicted probabilities for each sample across all folds using predict_proba
predicted_probs = cross_val_predict(gb_clf, X_np, y_np, cv=kf, method='predict_proba')

# Calculate the mean predicted probabilities (p_hats)
p_hats = predicted_probs[:, 1]  # Get probabilities for class 1

# Compute mean accuracy
predicted_class = predicted_probs[:, 1] > 0.5  # Convert probabilities to class predictions (threshold 0.5)
mean_accuracy = accuracy_score(y_np, predicted_class)

# Print results
print(f"Mean Accuracy: {mean_accuracy}")
print(f"p_hats (mean predicted probabilities for class 1): {p_hats}")

Mean Accuracy: 0.98782
p_hats (mean predicted probabilities for class 1): [6.22585932e-04 9.96754482e-01 4.88675910e-03 ... 9.97225659e-01
 1.01027479e-01 9.97233834e-01]


In [5]:
grid_size  = 200 ### discretize integral into grid_size linearly spaced pts
batch_size = 100 ### num samples processed simultaneously
num_epochs = 200 ### number epochs
lr         = 0.1 ### learning rate

def compute_nll(x_batch, y_batch, log_alpha, log_beta, W):
    """
    Parameters:
        (x_batch, y_batch) a batch of B (x,y) samples 
        (log_alpha, log_beta, W) trainable parameters

    Returns:
        A scaler quantity s.t. calling .backward() will compute
        grad_{theta} (L(theta))
    """
    raise NotImplementedError

def train(X, y, log_alpha=None, log_beta=None, W=None, num_epochs=num_epochs, batch_size=batch_size, grid_size=grid_size, lr=lr, verbose=True, threshold=0.00001, patience=5, speech_len=speech_len, lex_size=lex_size):
    if verbose:
        print(f'Beginning Inference...')
        print(f'=' * 20)
        print(f'Hyperparameters:')
        print(f'Dataset Length: {len(X)}')
        print(f'Number Epochs: {num_epochs}')
        print(f'Batch Size: {batch_size}')
        print(f'Grid Size: {grid_size}')
        print(f"Learning Rate: {lr}")
        print(f"Improvement Threshold: {threshold}")
        print(f"Patience: {patience}")
        print(f'Speech Len: {speech_len}')
        print('=' * 20)

    ### not strictly necessary 
    assert len(X) % batch_size == 0
    
    start = time.time()

    dataset    = TensorDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=batch_size)
    
    X_batch, y_batch = list(dataloader)[0]

    if list(X_batch.size()) != [batch_size, speech_len]:
        raise AssertionError(f"X_ex.size() should be [batch_size, S], got {X_batch.size()}")

    if list(y_batch.size()) != [batch_size]:
        raise AssertionError(f"y_ex.size() should be [batch_size], got {y_batch.size()}")

    if log_alpha is None:
        log_alpha = torch.normal(0, math.sqrt(3), size=(1,), requires_grad=True)

    if log_beta is None:
        log_beta  = torch.normal(0, math.sqrt(3), size=(1,), requires_grad=True)

    if W is None:
        W = torch.normal(0, 1, size=(3 * lex_size,), requires_grad=True)

    print(f'Initial Parameters:')
    print(f'Alpha: {torch.exp(log_alpha).item()}')
    print(f"Beta: {torch.exp(log_beta).item()}")
    print(f"W: {W.data.tolist()}")

    optimizer = SGD([log_alpha, log_beta, W], lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
    # Early stopping logic
    best_nll = float('inf')
    epochs_without_improvement = 0

    for epoch in tqdm(range(num_epochs)):
        if verbose:
            print(f'Epoch {epoch+1} of {num_epochs}')
        ### only step after each epoch
        optimizer.zero_grad() 
        nll = torch.tensor(0.)
        
        for x_batch, y_batch in dataloader:
            nll += compute_nll(x_batch, y_batch, log_alpha, log_beta, W, grid_size, batch_size, speech_len, lex_size)
        
        nll = nll / len(dataloader.dataset)
        nll.backward()
        print('=' * 20)
        print(f'Log Alpha Grad: {log_alpha.grad}')
        print(f'Log Beta Grad: {log_beta.grad}')
        print(f'W Grad: {W.grad}')
        print('=' * 20)
        optimizer.step()
        scheduler.step()

        current_nll = nll.item()
        if current_nll < best_nll - threshold:
            best_nll = current_nll
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if verbose:
            print(f'Epoch {epoch+1}, NLL: {round(nll.item(), 3)}, Alpha: {round(np.exp(log_alpha.item()),3)}, Beta: {round(np.exp(log_beta.item()),3)}, W: {W.tolist()}')
            print('=' * 20)

        if epochs_without_improvement >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break

    final_alpha = np.exp(log_alpha.item())
    final_beta  = np.exp(log_beta.item())

    if verbose:
        print(f"Training Took {round(time.time() - start, 2)} Seconds.")
        
    print(f"Trained Params: alpha: {final_alpha}, beta: {final_beta}, W: {W.tolist()}")

    return best_nll, final_alpha, final_beta, W

In [6]:
### numerically stable factorial
factorial = lambda x : torch.exp(torch.lgamma(x+1))

assert factorial(torch.tensor(3.)) == torch.tensor(6.)
print(f"3.0001! = {factorial(torch.tensor(3.001)).item()}")

### notice this is infinity, so we can't call factorial on negative stuff
print(f"-1! = {factorial(torch.tensor(-1.)).item()}")

def compute_log_joint(u_mat, x_batch, y_batch, log_alpha, log_beta, W, grid_size, batch_size, speech_len, lex_size):
    """
    Parameters:
        x_batch (size [batch_size, 3])
        y_batch (size [batch_size])
        u_mat   (size [batch_size, grid_size])
        (log_alpha, log_beta, W) trainable params 
    
    Returns:
        torch.tensor (size [batch_size, grid_size]), where the (i,j)th element 
        is the log joint probability (log p(u, x, y; theta)). Here u is the (i,j)th
        element of u_mat, x is the ith row of x_batch, y is the ith element of y_batch

    """
    assert list(u_mat.size())   == [batch_size, grid_size] 
    assert list(x_batch.size()) == [batch_size, speech_len]
    assert list(y_batch.size()) == [batch_size]

    if not torch.all((u_mat >= 0).float() == y_batch.unsqueeze(1)).item():
        raise ValueError(f"u_mat is incompatible with y_batch.")

    ### log prior: log p(u) = log(1/4 Beta((u+1)/2; a, b) + 1/4 Beta((u+1)/2; b, a))
    ###                     = log(1/4) + log(Beta((u+1)/2; a, b) + Beta((u+1)/2; b, a))
    alpha = torch.exp(log_alpha)
    beta  = torch.exp(log_beta)
        
    beta_dist_ab = Beta(alpha, beta)
    beta_dist_ba = Beta(beta, alpha)

    ### see 3-6 testing.ipynb test1 to see how log_prob broadcasts
    log_beta_prob_ab = beta_dist_ab.log_prob((u_mat + 1) / 2)
    log_beta_prob_ba = beta_dist_ba.log_prob((u_mat + 1) / 2)

    prior = torch.log(torch.exp(log_beta_prob_ab) + torch.exp(log_beta_prob_ba)) + torch.log(torch.tensor(0.25))

    assert list(prior.size()) == [batch_size, grid_size]
  
    ### log likelihood: log p(x, y | u) = dot(x, log Softmax(Wu)) + log(S!/(x1! x2! x3!)) 

    ### step 1: compute Wu (see 3-6 testing.ipynb test2 to see this logic)
    W_expanded = W.unsqueeze(0).unsqueeze(0)
    u_expanded = u_mat.unsqueeze(2)
    Wu = torch.matmul(u_expanded, W_expanded) 

    assert list(W_expanded.size()) == [1,1,3 * lex_size]
    assert list(u_expanded.size()) == [batch_size, grid_size, 1]
    assert list(Wu.size()) == [batch_size, grid_size, 3 * lex_size]

    ### step 2: dot(x, log Softmax(Wu)) (see 3-6 testing.ipynb test3 to see this logic)
    #print(x_batch[:2]) # B x S 
    ones = torch.ones_like(x_batch, dtype=torch.float)
    #print(ones[:2]) # B x S 
    # Assuming x_batch needs to be adjusted by -1 for 0-indexing
    x_batch_trans = torch.zeros(batch_size, 3 * lex_size, dtype=torch.float)

    ones = torch.ones_like(x_batch, dtype=torch.float)
    x_batch_trans.scatter_add_(dim=1, index=x_batch.long(), src=ones)
    likelihood = (x_batch_trans.unsqueeze(1) * log_softmax(Wu, dim=2)).sum(dim=2)

    ### ensure not taking negative factorials
    assert torch.all(x_batch >= 0)

    assert list(likelihood.size()) == [batch_size, grid_size]
    return prior + likelihood


def compute_nll(x_batch, y_batch, log_alpha, log_beta, W, grid_size, batch_size, speech_len, lex_size):
    """
    Parameters:
        x_batch (size [batch_size, 3])
        y_batch (size [batch_size])
        log_alpha, log_beta, W (trainable params)
        grid_size (int) number of lin spaced points to discretize integral into
    
    Returns:
        quantity s.t. .backward() would be the gradient of the nll
    """
    ### see 3-6 testing.ipynb test5 to see this logic
    u_mat = torch.empty(batch_size, grid_size)
    u_mat[y_batch == 1] = torch.linspace(1/(grid_size+1), 1-1/(grid_size+1), grid_size).repeat((y_batch == 1).sum(), 1)
    u_mat[y_batch == 0] = torch.linspace(-1+1/(grid_size+1), -1/(grid_size+1), grid_size).repeat((y_batch == 0).sum(), 1)

    assert list(u_mat.size()) == [batch_size, grid_size]

    log_joint = compute_log_joint(u_mat, x_batch, y_batch, log_alpha, log_beta, W, grid_size, batch_size, speech_len, lex_size)
    assert list(log_joint.size()) == [batch_size, grid_size]

    ### see 3-6 testing.ipynb test6 to see this logic
    max_log_prob = torch.max(log_joint, dim=1, keepdim=True)[0]
    joint_probs = torch.exp(log_joint - max_log_prob)
    posterior = joint_probs / joint_probs.sum(dim=1, keepdim=True) 
    # detach posterior (see 3-6 testing.ipynb test7 to verify this works)
    weighted_log_joint = posterior.detach() * log_joint
    nll = -weighted_log_joint.sum()

    return nll

3.0001! = 6.007541656494141
-1! = inf


In [7]:
rho = 0.8
N = 50000
epsilon = 0.
lex_size = 5
speech_len = 50

known, unknown = generate(rho=rho, N=N, verbose='high', epsilon=epsilon, lex_size=lex_size, speech_len=speech_len)
X, y = known 
a, b, rho, epsilon, u, lex_size = unknown

Beginning Data Generation...
Lex Size: 5, Rho: 0.8, N: 50000
epsilon: 0.0, pi: 0.5, S: 50

True Alpha: 12.673684210526261
True Beta: 3.1684210526315644

mixture samples: [0 0 0]
u samples: [-0.28357936 -0.52389393 -0.25908543]
y samples: [0 0 0]

P(L) = 0.12835793622206623, P(R): 0.07164206377793378, P(N): 0.0
P(L) = 0.15238939281614144, P(R): 0.047610607183858546, P(N): 0.0
P(L) = 0.12590854271088564, P(R): 0.07409145728911436, P(N): 0.0
X samples (counts):
 [[ 6  9  7  6  3  4  5  6  2  2  0  0  0  0  0]
 [ 5 10  7  8 11  2  2  1  1  3  0  0  0  0  0]
 [ 5  7  7  4  5  6  3  5  3  5  0  0  0  0  0]]

X samples (sequences):
 [[2 3 1 0 1 1 5 6 5 3 7 3 2 2 3 2 3 7 7 6 4 1 1 0 5 0 5 6 4 6 2 9 0 6 1 8
  8 7 2 0 3 2 1 4 1 7 9 1 7 0]
 [6 2 1 1 9 1 4 4 0 1 4 3 1 0 5 4 2 1 6 1 2 4 4 2 4 4 7 0 4 4 2 3 8 2 3 3
  1 1 5 4 2 0 3 3 9 0 9 3 3 1]
 [8 5 5 2 9 0 3 0 5 7 9 4 7 7 5 8 2 7 4 3 6 2 1 2 3 6 9 3 4 4 5 2 4 0 1 8
  7 9 0 1 1 5 1 2 2 6 1 9 0 1]]

Generation Time: 2.066 seconds for 50000 samples.

Testing if magnitude of starting params matter when initial params are bigger

In [11]:
best_nll, final_alpha, final_beta, W = train(X, y, patience=100, grid_size=200, num_epochs=200,
                                            lr=.1, speech_len=speech_len, lex_size=lex_size, 
                                            log_alpha= torch.tensor(np.log(12.67 + 0.7), requires_grad=True),
                                            log_beta= torch.tensor(np.log(3.168 + 0.7), requires_grad=True))
                                    

Beginning Inference...
Hyperparameters:
Dataset Length: 50000
Number Epochs: 200
Batch Size: 100
Grid Size: 200
Learning Rate: 0.1
Improvement Threshold: 1e-05
Patience: 100
Speech Len: 50
Initial Parameters:
Alpha: 13.369999999999996
Beta: 3.8680000000000003
W: [-0.8029409646987915, 0.2365746796131134, 0.2856927514076233, 0.6898148655891418, -0.6330540180206299, 0.8794752955436707, -0.6841781735420227, 0.45329079031944275, 0.29115796089172363, -0.8317165970802307, -0.5525082945823669, 0.6354773044586182, -0.39681580662727356, -0.6570598483085632, -1.6427524089813232]


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch 1 of 200


  0%|          | 1/200 [00:11<38:39, 11.65s/it]

Log Alpha Grad: 2.3922102321122476
Log Beta Grad: -1.7699673850672748
W Grad: tensor([ 0.4093,  0.8110,  0.8409,  1.0852,  0.4642, -0.7275, -1.4446, -0.9908,
        -1.0771, -1.4860,  0.3603,  1.0042,  0.4246,  0.3196,  0.0069])
Epoch 1, NLL: 136.245, Alpha: 10.525, Beta: 4.617, W: [-0.8438687920570374, 0.1554744839668274, 0.2016005665063858, 0.5812910199165344, -0.6794707179069519, 0.9522245526313782, -0.5397146940231323, 0.5523750185966492, 0.39886826276779175, -0.6831145882606506, -0.5885363221168518, 0.5350614786148071, -0.43927285075187683, -0.6890159249305725, -1.6434441804885864]
Epoch 2 of 200


  1%|          | 2/200 [00:23<37:51, 11.47s/it]

Log Alpha Grad: 0.8172785324908814
Log Beta Grad: -0.641462905858895
W Grad: tensor([ 0.3699,  0.7474,  0.7746,  0.9961,  0.4222, -0.6753, -1.4216, -0.9434,
        -1.0331, -1.4651,  0.3891,  0.9776,  0.4489,  0.3511,  0.0616])
Epoch 2, NLL: 134.474, Alpha: 9.699, Beta: 4.923, W: [-0.880861222743988, 0.08073736727237701, 0.12414225190877914, 0.4816761314868927, -0.7216887474060059, 1.0197585821151733, -0.3975584805011749, 0.6467196941375732, 0.5021738409996033, -0.5366095900535583, -0.6274418234825134, 0.4373030364513397, -0.4841669499874115, -0.724124550819397, -1.6496024131774902]
Epoch 3 of 200


  2%|▏         | 3/200 [00:33<35:34, 10.84s/it]

Log Alpha Grad: 0.14750296999448578
Log Beta Grad: 0.08783507727487323
W Grad: tensor([ 0.3624,  0.7701,  0.7985,  1.0312,  0.4195, -0.6096, -1.5263, -0.9315,
        -1.0405, -1.5813,  0.4034,  1.0197,  0.4671,  0.3630,  0.0542])
Epoch 3, NLL: 133.333, Alpha: 9.557, Beta: 4.88, W: [-0.9171000123023987, 0.0037264185957610607, 0.044296327978372574, 0.37855765223503113, -0.7636391520500183, 1.0807145833969116, -0.24492555856704712, 0.7398684620857239, 0.6062217354774475, -0.3784807324409485, -0.6677847504615784, 0.33532819151878357, -0.5308747291564941, -0.7604256272315979, -1.655025839805603]
Epoch 4 of 200


  2%|▏         | 4/200 [00:41<32:48, 10.04s/it]

Log Alpha Grad: -0.14158019593291685
Log Beta Grad: 0.529571169448739
W Grad: tensor([ 0.3622,  0.8123,  0.8425,  1.0913,  0.4263, -0.5019, -1.6526, -0.8944,
        -1.0295, -1.7245,  0.3938,  1.0546,  0.4634,  0.3495,  0.0069])
Epoch 4, NLL: 132.174, Alpha: 9.694, Beta: 4.628, W: [-0.9533156752586365, -0.07750232517719269, -0.03995591029524803, 0.26942935585975647, -0.8062700629234314, 1.130906343460083, -0.07966738939285278, 0.8293095231056213, 0.7091734409332275, -0.20603378117084503, -0.7071682214736938, 0.22986473143100739, -0.5772165656089783, -0.7953760623931885, -1.655720591545105]
Epoch 5 of 200


  2%|▎         | 5/200 [00:51<32:02,  9.86s/it]

Log Alpha Grad: -0.238180978715729
Log Beta Grad: 0.7618874017693374
W Grad: tensor([ 0.3684,  0.8522,  0.8836,  1.1386,  0.4394, -0.3642, -1.7365, -0.8126,
        -0.9711, -1.8277,  0.3485,  1.0380,  0.4234,  0.3005, -0.0804])
Epoch 5, NLL: 130.9, Alpha: 9.927, Beta: 4.289, W: [-0.9901584386825562, -0.16272282600402832, -0.12831445038318634, 0.15557068586349487, -0.8502122759819031, 1.1673293113708496, 0.0939847007393837, 0.9105679392814636, 0.8062854409217834, -0.02325969934463501, -0.7420158982276917, 0.1260630339384079, -0.6195525527000427, -0.8254227638244629, -1.6476854085922241]
Epoch 6 of 200


  3%|▎         | 6/200 [01:02<33:24, 10.33s/it]

Log Alpha Grad: -0.24315453186662964
Log Beta Grad: 0.85297740860503
W Grad: tensor([ 0.3825,  0.8854,  0.9168,  1.1647,  0.4595, -0.2328, -1.7461, -0.6976,
        -0.8684, -1.8559,  0.2685,  0.9622,  0.3472,  0.2176, -0.2035])
Epoch 6, NLL: 129.574, Alpha: 10.172, Beta: 3.938, W: [-1.0284069776535034, -0.2512620687484741, -0.21999037265777588, 0.0391019806265831, -0.8961595892906189, 1.190605878829956, 0.2685965299606323, 0.980330765247345, 0.8931275606155396, 0.16232989728450775, -0.7688685059547424, 0.02984069660305977, -0.6542713046073914, -0.8471825122833252, -1.6273349523544312]
Epoch 7 of 200


  4%|▎         | 7/200 [01:12<32:52, 10.22s/it]

Log Alpha Grad: -0.21722641595503592
Log Beta Grad: 0.8607163375877113
W Grad: tensor([ 0.4013,  0.9128,  0.9432,  1.1750,  0.4836, -0.1368, -1.6764, -0.5738,
        -0.7420, -1.8009,  0.1645,  0.8443,  0.2457,  0.1113, -0.3519])
Epoch 7, NLL: 128.285, Alpha: 10.395, Beta: 3.613, W: [-1.068532109260559, -0.3425390124320984, -0.31431329250335693, -0.07840154320001602, -0.9445224404335022, 1.2042858600616455, 0.4362335801124573, 1.0377143621444702, 0.9673259258270264, 0.342417448759079, -0.7853213548660278, -0.054587479680776596, -0.6788404583930969, -0.8583154678344727, -1.5921469926834106]
Epoch 8 of 200


  4%|▍         | 8/200 [01:22<32:19, 10.10s/it]

Log Alpha Grad: -0.1890148821835595
Log Beta Grad: 0.8240840064110296
W Grad: tensor([ 0.4204,  0.9363,  0.9654,  1.1780,  0.5080, -0.0838, -1.5401, -0.4617,
        -0.6142, -1.6729,  0.0481,  0.7054,  0.1310, -0.0069, -0.5130])
Epoch 8, NLL: 127.097, Alpha: 10.593, Beta: 3.327, W: [-1.1105704307556152, -0.4361717998981476, -0.4108574390411377, -0.19619767367839813, -0.9953193068504333, 1.2126654386520386, 0.5902392268180847, 1.08388090133667, 1.0287436246871948, 0.5097063183784485, -0.7901307940483093, -0.12512843310832977, -0.6919366121292114, -0.8576207756996155, -1.5408453941345215]
Epoch 9 of 200


  4%|▍         | 9/200 [01:30<29:42,  9.33s/it]

Log Alpha Grad: -0.17092943596464885
Log Beta Grad: 0.7692770879577343
W Grad: tensor([ 0.4364,  0.9581,  0.9858,  1.1803,  0.5294, -0.0644, -1.3593, -0.3708,
        -0.5000, -1.4926, -0.0720,  0.5610,  0.0121, -0.1285, -0.6755])
Epoch 9, NLL: 126.036, Alpha: 10.776, Beta: 3.081, W: [-1.154206395149231, -0.5319806933403015, -0.5094401240348816, -0.31422871351242065, -1.048256278038025, 1.2191065549850464, 0.7261669635772705, 1.1209566593170166, 1.0787404775619507, 0.658967137336731, -0.7829334735870361, -0.18123123049736023, -0.693146288394928, -0.8447672128677368, -1.4732904434204102]
Epoch 10 of 200


  5%|▌         | 10/200 [01:40<30:04,  9.50s/it]

Log Alpha Grad: -0.166933237025135
Log Beta Grad: 0.7135703182236883
W Grad: tensor([ 0.4456,  0.9782,  1.0048,  1.1850,  0.5445, -0.0633, -1.1580, -0.3007,
        -0.4045, -1.2842, -0.1889,  0.4205, -0.1042, -0.2466, -0.8281])
Epoch 10, NLL: 125.097, Alpha: 10.957, Beta: 2.869, W: [-1.198770523071289, -0.6297968029975891, -0.6099176406860352, -0.4327293336391449, -1.1027045249938965, 1.225433111190796, 0.841968297958374, 1.1510236263275146, 1.119193434715271, 0.7873894572257996, -0.7640390396118164, -0.22327697277069092, -0.6827273964881897, -0.820105791091919, -1.3904831409454346]
Epoch 11 of 200


  6%|▌         | 11/200 [01:50<30:55,  9.82s/it]

Log Alpha Grad: -0.1743711426672713
Log Beta Grad: 0.6659019166696298
W Grad: tensor([ 0.4440,  0.9926,  1.0184,  1.1890,  0.5486, -0.0670, -0.9566, -0.2458,
        -0.3261, -1.0698, -0.2957,  0.2898, -0.2112, -0.3538, -0.9565])
Epoch 11, NLL: 124.249, Alpha: 11.15, Beta: 2.684, W: [-1.2431671619415283, -0.7290549874305725, -0.7117571234703064, -0.5516313910484314, -1.157569169998169, 1.2321327924728394, 0.9376273155212402, 1.1755986213684082, 1.1518045663833618, 0.8943722248077393, -0.7344661355018616, -0.2522587180137634, -0.6616092920303345, -0.7847287654876709, -1.2948358058929443]
Epoch 12 of 200


  6%|▌         | 12/200 [02:01<31:23, 10.02s/it]

Log Alpha Grad: -0.18539557361095524
Log Beta Grad: 0.6269314367958408
W Grad: tensor([ 0.4277,  0.9934,  1.0186,  1.1836,  0.5370, -0.0679, -0.7694, -0.2003,
        -0.2609, -0.8664, -0.3837,  0.1744, -0.3009, -0.4409, -1.0445])
Epoch 12, NLL: 123.465, Alpha: 11.359, Beta: 2.521, W: [-1.2859410047531128, -0.8283953070640564, -0.8136177659034729, -0.6699957847595215, -1.2112724781036377, 1.2389240264892578, 1.0145704746246338, 1.1956294775009155, 1.1778907775878906, 0.9810141921043396, -0.6960958242416382, -0.2697029709815979, -0.6315222978591919, -0.7406377792358398, -1.1903905868530273]
Epoch 13 of 200


  6%|▋         | 13/200 [02:09<30:11,  9.69s/it]

Log Alpha Grad: -0.19025498897746485
Log Beta Grad: 0.5910629433982553
W Grad: tensor([ 0.3969,  0.9723,  0.9968,  1.1584,  0.5078, -0.0641, -0.6054, -0.1614,
        -0.2062, -0.6854, -0.4446,  0.0793, -0.3654, -0.4994, -1.0796])
Epoch 13, NLL: 122.731, Alpha: 11.577, Beta: 2.376, W: [-1.325630784034729, -0.9256241321563721, -0.9132987260818481, -0.7858381271362305, -1.2620521783828735, 1.2453383207321167, 1.0751093626022339, 1.2117723226547241, 1.1985145807266235, 1.04954993724823, -0.6516401767730713, -0.27763086557388306, -0.5949777960777283, -0.690700888633728, -1.0824337005615234]
Epoch 14 of 200


  7%|▋         | 14/200 [02:18<29:13,  9.43s/it]

Log Alpha Grad: -0.18174254825727426
Log Beta Grad: 0.5512721083441146
W Grad: tensor([ 0.3563,  0.9256,  0.9491,  1.1070,  0.4644, -0.0575, -0.4688, -0.1288,
        -0.1616, -0.5328, -0.4741,  0.0075, -0.4007, -0.5250, -1.0604])
Epoch 14, NLL: 122.054, Alpha: 11.789, Beta: 2.249, W: [-1.3612606525421143, -1.0181812047958374, -1.0082062482833862, -0.896540105342865, -1.3084900379180908, 1.2510923147201538, 1.1219898462295532, 1.2246572971343994, 1.214674711227417, 1.1028318405151367, -0.6042313575744629, -0.2783780097961426, -0.5549079179763794, -0.6382027864456177, -0.9763903617858887]
Epoch 15 of 200


  8%|▊         | 15/200 [02:27<28:14,  9.16s/it]

Log Alpha Grad: -0.15799840827185474
Log Beta Grad: 0.5038933286987157
W Grad: tensor([ 0.3136,  0.8566,  0.8786,  1.0306,  0.4142, -0.0509, -0.3603, -0.1032,
        -0.1270, -0.4104, -0.4744, -0.0411, -0.4082, -0.5203, -0.9977])
Epoch 15, NLL: 121.454, Alpha: 11.977, Beta: 2.138, W: [-1.3926182985305786, -1.1038390398025513, -1.0960649251937866, -0.9995973706245422, -1.3499122858047485, 1.2561863660812378, 1.1580191850662231, 1.2349820137023926, 1.2273733615875244, 1.143875002861023, -0.556792676448822, -0.27426865696907043, -0.5140926837921143, -0.5861738920211792, -0.8766186833381653]
Epoch 16 of 200


  8%|▊         | 16/200 [02:35<26:46,  8.73s/it]

Log Alpha Grad: -0.12197774144199253
Log Beta Grad: 0.4499110603567332
W Grad: tensor([ 0.2753,  0.7736,  0.7936,  0.9364,  0.3650, -0.0463, -0.2776, -0.0848,
        -0.1019, -0.3163, -0.4524, -0.0697, -0.3939, -0.4929, -0.9081])
Epoch 16, NLL: 120.947, Alpha: 12.124, Beta: 2.044, W: [-1.4201449155807495, -1.1811978816986084, -1.1754273176193237, -1.093234896659851, -1.386410117149353, 1.2608178853988647, 1.1857768297195435, 1.243459701538086, 1.2375624179840088, 1.1755003929138184, -0.5115519762039185, -0.26730185747146606, -0.47469910979270935, -0.5368869304656982, -0.7858048677444458]
Epoch 17 of 200


  8%|▊         | 17/200 [02:42<25:04,  8.22s/it]

Log Alpha Grad: -0.07888898214447454
Log Beta Grad: 0.3930576269869269
W Grad: tensor([ 0.2446,  0.6859,  0.7036,  0.8340,  0.3216, -0.0443, -0.2167, -0.0727,
        -0.0850, -0.2462, -0.4166, -0.0831, -0.3658, -0.4517, -0.8074])
Epoch 17, NLL: 120.536, Alpha: 12.22, Beta: 1.965, W: [-1.4446011781692505, -1.2497844696044922, -1.2457857131958008, -1.1766339540481567, -1.4185693264007568, 1.2652513980865479, 1.2074425220489502, 1.2507301568984985, 1.2460626363754272, 1.200120449066162, -0.46989062428474426, -0.258988618850708, -0.43812018632888794, -0.4917145371437073, -0.7050613164901733]
Epoch 18 of 200


  9%|▉         | 18/200 [02:50<25:23,  8.37s/it]

Log Alpha Grad: -0.03384366469906593
Log Beta Grad: 0.3373624810335541
W Grad: tensor([ 0.2216,  0.6008,  0.6160,  0.7319,  0.2858, -0.0448, -0.1730, -0.0658,
        -0.0746, -0.1955, -0.3743, -0.0864, -0.3306, -0.4045, -0.7068])
Epoch 18, NLL: 120.214, Alpha: 12.261, Beta: 1.9, W: [-1.4667614698410034, -1.309863567352295, -1.3073893785476685, -1.2498278617858887, -1.4471511840820312, 1.2697268724441528, 1.224744439125061, 1.2573051452636719, 1.2535200119018555, 1.2196660041809082, -0.43245652318000793, -0.2503499388694763, -0.4050615727901459, -0.45126280188560486, -0.6343808174133301]
Epoch 19 of 200


 10%|▉         | 19/200 [02:59<25:57,  8.60s/it]

Log Alpha Grad: 0.009306747500644448
Log Beta Grad: 0.2858021132417882
W Grad: tensor([ 0.2049,  0.5229,  0.5357,  0.6362,  0.2573, -0.0469, -0.1425, -0.0625,
        -0.0689, -0.1594, -0.3308, -0.0834, -0.2933, -0.3566, -0.6127])
Epoch 19, NLL: 119.968, Alpha: 12.25, Beta: 1.847, W: [-1.4872503280639648, -1.3621560335159302, -1.3609626293182373, -1.3134485483169556, -1.4728771448135376, 1.274421215057373, 1.2389898300170898, 1.263558030128479, 1.2604056596755981, 1.2356107234954834, -0.3993791937828064, -0.2420068383216858, -0.3757316768169403, -0.41560474038124084, -0.5731108784675598]
Epoch 20 of 200


 10%|█         | 20/200 [03:09<26:23,  8.79s/it]

Log Alpha Grad: 0.048172087895179294
Log Beta Grad: 0.24002863189497542
W Grad: tensor([ 0.1925,  0.4544,  0.4649,  0.5502,  0.2346, -0.0502, -0.1215, -0.0618,
        -0.0664, -0.1343, -0.2891, -0.0771, -0.2570, -0.3111, -0.5281])
Epoch 20, NLL: 119.782, Alpha: 12.191, Beta: 1.803, W: [-1.5065046548843384, -1.4075921773910522, -1.4074532985687256, -1.3684639930725098, -1.4963364601135254, 1.279443621635437, 1.2511348724365234, 1.2697408199310303, 1.2670421600341797, 1.2490423917770386, -0.3704739212989807, -0.23429358005523682, -0.35002899169921875, -0.38449662923812866, -0.5203028917312622]
Epoch 21 of 200


 10%|█         | 21/200 [03:18<26:48,  8.99s/it]

Log Alpha Grad: 0.08157327553744524
Log Beta Grad: 0.20062151848820445
W Grad: tensor([ 0.1830,  0.3955,  0.4040,  0.4751,  0.2164, -0.0540, -0.1073, -0.0627,
        -0.0660, -0.1170, -0.2508, -0.0694, -0.2234, -0.2696, -0.4538])
Epoch 21, NLL: 119.642, Alpha: 12.092, Beta: 1.767, W: [-1.5248011350631714, -1.447142243385315, -1.4478576183319092, -1.4159729480743408, -1.5179729461669922, 1.2848472595214844, 1.261861801147461, 1.2760101556777954, 1.2736371755599976, 1.2607446908950806, -0.34539321064949036, -0.22735683619976044, -0.3276844024658203, -0.3575356602668762, -0.4749268591403961]
Epoch 22 of 200


 11%|█         | 22/200 [03:27<26:43,  9.01s/it]

Log Alpha Grad: 0.10918561069654428
Log Beta Grad: 0.1674573504971963
W Grad: tensor([ 0.1750,  0.3458,  0.3526,  0.4110,  0.2014, -0.0580, -0.0979, -0.0644,
        -0.0668, -0.1053, -0.2167, -0.0613, -0.1933, -0.2328, -0.3894])
Epoch 22, NLL: 119.537, Alpha: 11.961, Beta: 1.738, W: [-1.5423027276992798, -1.4817190170288086, -1.4831178188323975, -1.4570692777633667, -1.5381083488464355, 1.2906438112258911, 1.271647334098816, 1.2824537754058838, 1.2803146839141846, 1.2712715864181519, -0.32372477650642395, -0.22122950851917267, -0.3083556294441223, -0.33426064252853394, -0.43598631024360657]
Epoch 23 of 200


 12%|█▏        | 23/200 [03:40<29:53, 10.13s/it]

Log Alpha Grad: 0.13121876965141757
Log Beta Grad: 0.14002163296599732
W Grad: tensor([ 0.1680,  0.3041,  0.3096,  0.3569,  0.1887, -0.0617, -0.0917, -0.0666,
        -0.0683, -0.0974, -0.1868, -0.0535, -0.1667, -0.2005, -0.3341])
Epoch 23, NLL: 119.457, Alpha: 11.805, Beta: 1.714, W: [-1.5591009855270386, -1.5121327638626099, -1.5140728950500488, -1.4927639961242676, -1.556973934173584, 1.2968175411224365, 1.2808187007904053, 1.2891114950180054, 1.2871403694152832, 1.2810084819793701, -0.30504900217056274, -0.21588076651096344, -0.29168352484703064, -0.3142097592353821, -0.40257206559181213]
Epoch 24 of 200


 12%|█▏        | 24/200 [03:52<31:04, 10.59s/it]

Log Alpha Grad: 0.14818000545713345
Log Beta Grad: 0.11762216034615101
W Grad: tensor([ 0.1615,  0.2694,  0.2737,  0.3119,  0.1776, -0.0652, -0.0878, -0.0688,
        -0.0700, -0.0921, -0.1608, -0.0463, -0.1436, -0.1726, -0.2869])
Epoch 24, NLL: 119.395, Alpha: 11.631, Beta: 1.693, W: [-1.5752465724945068, -1.5390775203704834, -1.5414435863494873, -1.5239492654800415, -1.574737787246704, 1.3033356666564941, 1.2895961999893188, 1.2959916591644287, 1.294141173362732, 1.2902191877365112, -0.2889694273471832, -0.21124768257141113, -0.2773231565952301, -0.2969508469104767, -0.3738812208175659]
Epoch 25 of 200


 12%|█▎        | 25/200 [04:06<33:51, 11.61s/it]

Log Alpha Grad: 0.16071249049707487
Log Beta Grad: 0.0995231383945903
W Grad: tensor([ 0.1552,  0.2406,  0.2439,  0.2744,  0.1679, -0.0682, -0.0853, -0.0709,
        -0.0718, -0.0886, -0.1384, -0.0399, -0.1236, -0.1486, -0.2466])
Epoch 25, NLL: 119.347, Alpha: 11.446, Beta: 1.677, W: [-1.59076988697052, -1.5631349086761475, -1.565834403038025, -1.5513890981674194, -1.5915261507034302, 1.3101563453674316, 1.2981247901916504, 1.3030825853347778, 1.301318883895874, 1.2990810871124268, -0.2751268446445465, -0.20725437998771667, -0.2649587392807007, -0.2820945978164673, -0.34921789169311523]
Epoch 26 of 200


 13%|█▎        | 26/200 [04:17<33:09, 11.43s/it]

Log Alpha Grad: 0.16949099803779516
Log Beta Grad: 0.08501982040312697
W Grad: tensor([ 0.1492,  0.2165,  0.2191,  0.2434,  0.1591, -0.0708, -0.0837, -0.0728,
        -0.0734, -0.0863, -0.1192, -0.0343, -0.1065, -0.1280, -0.2123])
Epoch 26, NLL: 119.309, Alpha: 11.254, Beta: 1.663, W: [-1.6056925058364868, -1.5847853422164917, -1.5877450704574585, -1.5757241249084473, -1.607437252998352, 1.3172334432601929, 1.3064967393875122, 1.3103610277175903, 1.3086596727371216, 1.307710886001587, -0.2632039189338684, -0.2038230001926422, -0.25430944561958313, -0.2692984640598297, -0.32798585295677185]
Epoch 27 of 200


 14%|█▎        | 27/200 [04:27<31:49, 11.04s/it]

Log Alpha Grad: 0.17516427330978074
Log Beta Grad: 0.07347357669395578
W Grad: tensor([ 0.1434,  0.1964,  0.1984,  0.2176,  0.1511, -0.0729, -0.0827, -0.0744,
        -0.0748, -0.0847, -0.1028, -0.0294, -0.0918, -0.1103, -0.1831])
Epoch 27, NLL: 119.278, Alpha: 11.058, Beta: 1.65, W: [-1.6200333833694458, -1.6044224500656128, -1.6075853109359741, -1.5974843502044678, -1.6225510835647583, 1.3245208263397217, 1.3147683143615723, 1.3177974224090576, 1.31614089012146, 1.3161829710006714, -0.2529246509075165, -0.20087948441505432, -0.24512997269630432, -0.25826534628868103, -0.3096776306629181]
Epoch 28 of 200


 14%|█▍        | 28/200 [04:46<38:26, 13.41s/it]

Log Alpha Grad: 0.1783197004639847
Log Beta Grad: 0.0643273525002049
W Grad: tensor([ 0.1378,  0.1795,  0.1810,  0.1962,  0.1438, -0.0745, -0.0820, -0.0756,
        -0.0759, -0.0836, -0.0887, -0.0252, -0.0792, -0.0953, -0.1581])
Epoch 28, NLL: 119.252, Alpha: 10.863, Beta: 1.64, W: [-1.6338120698928833, -1.6223676204681396, -1.625688910484314, -1.6171040534973145, -1.6369349956512451, 1.3319737911224365, 1.3229713439941406, 1.3253599405288696, 1.3237357139587402, 1.3245426416397095, -0.24405145645141602, -0.19835640490055084, -0.23720847070217133, -0.24873989820480347, -0.2938627600669861]
Epoch 29 of 200


 14%|█▍        | 29/200 [04:57<36:09, 12.69s/it]

Log Alpha Grad: 0.17947169317444964
Log Beta Grad: 0.05710809676581601
W Grad: tensor([ 0.1324,  0.1651,  0.1664,  0.1783,  0.1371, -0.0758, -0.0815, -0.0766,
        -0.0768, -0.0827, -0.0767, -0.0216, -0.0685, -0.0824, -0.1369])
Epoch 29, NLL: 119.231, Alpha: 10.669, Beta: 1.63, W: [-1.6470495462417603, -1.6388823986053467, -1.6423275470733643, -1.634937047958374, -1.6506472826004028, 1.3395508527755737, 1.3311212062835693, 1.3330165147781372, 1.3314155340194702, 1.3328156471252441, -0.23638103902339935, -0.19619391858577728, -0.23036330938339233, -0.2405039370059967, -0.2801768481731415]
Epoch 30 of 200


 15%|█▌        | 30/200 [05:07<33:42, 11.90s/it]

Log Alpha Grad: 0.17905813798537296
Log Beta Grad: 0.05142036428816849
W Grad: tensor([ 0.1272,  0.1530,  0.1539,  0.1633,  0.1309, -0.0766, -0.0810, -0.0772,
        -0.0774, -0.0820, -0.0664, -0.0185, -0.0592, -0.0713, -0.1187])
Epoch 30, NLL: 119.212, Alpha: 10.48, Beta: 1.622, W: [-1.6597683429718018, -1.6541799306869507, -1.6577222347259521, -1.651271104812622, -1.6637394428253174, 1.3472141027450562, 1.3392229080200195, 1.3407365083694458, 1.339152455329895, 1.3410146236419678, -0.22974006831645966, -0.1943398118019104, -0.22443942725658417, -0.23337160050868988, -0.26831158995628357]
Epoch 31 of 200


 16%|█▌        | 31/200 [05:16<31:33, 11.20s/it]

Log Alpha Grad: 0.17744400531068866
Log Beta Grad: 0.046938860442494706
W Grad: tensor([ 0.1222,  0.1425,  0.1433,  0.1507,  0.1252, -0.0772, -0.0805, -0.0775,
        -0.0777, -0.0813, -0.0576, -0.0159, -0.0513, -0.0619, -0.1031])
Epoch 31, NLL: 119.196, Alpha: 10.296, Beta: 1.614, W: [-1.6719919443130493, -1.6684341430664062, -1.6720532178878784, -1.6663397550582886, -1.6762577295303345, 1.354929804801941, 1.347274661064148, 1.3484914302825928, 1.3469198942184448, 1.349143624305725, -0.223981112241745, -0.1927490383386612, -0.21930478513240814, -0.22718487679958344, -0.2580059766769409]
Epoch 32 of 200


 16%|█▌        | 32/200 [05:28<31:38, 11.30s/it]

Log Alpha Grad: 0.17492797859744688
Log Beta Grad: 0.04339944977611854
W Grad: tensor([ 0.1175,  0.1335,  0.1341,  0.1399,  0.1199, -0.0774, -0.0800, -0.0776,
        -0.0777, -0.0806, -0.0500, -0.0137, -0.0446, -0.0538, -0.0897])
Epoch 32, NLL: 119.182, Alpha: 10.117, Beta: 1.607, W: [-1.6837445497512817, -1.681787371635437, -1.6854674816131592, -1.6803325414657593, -1.6882439851760864, 1.3626681566238403, 1.3552709817886353, 1.3562557697296143, 1.3546935319900513, 1.3572015762329102, -0.21897882223129272, -0.1913829743862152, -0.21484707295894623, -0.22180946171283722, -0.2490389049053192]
Epoch 33 of 200


 16%|█▋        | 33/200 [05:38<30:08, 10.83s/it]

Log Alpha Grad: 0.17175136424643217
Log Beta Grad: 0.0405892729573844
W Grad: tensor([ 0.1131,  0.1257,  0.1262,  0.1307,  0.1149, -0.0774, -0.0793, -0.0775,
        -0.0776, -0.0798, -0.0435, -0.0117, -0.0388, -0.0468, -0.0782])
Epoch 33, NLL: 119.169, Alpha: 9.945, Beta: 1.601, W: [-1.6950504779815674, -1.6943566799163818, -1.698085904121399, -1.6934040784835815, -1.6997361183166504, 1.3704031705856323, 1.3632042407989502, 1.364006757736206, 1.3624515533447266, 1.3651843070983887, -0.21462664008140564, -0.19020867347717285, -0.21097077429294586, -0.21713121235370636, -0.24122276902198792]
Epoch 34 of 200


 17%|█▋        | 34/200 [05:48<29:27, 10.65s/it]

Log Alpha Grad: 0.16810648864547886
Log Beta Grad: 0.038340232839914994
W Grad: tensor([ 0.1088,  0.1188,  0.1192,  0.1228,  0.1103, -0.0771, -0.0786, -0.0772,
        -0.0772, -0.0790, -0.0379, -0.0101, -0.0338, -0.0408, -0.0682])
Epoch 34, NLL: 119.157, Alpha: 9.779, Beta: 1.595, W: [-1.7059338092803955, -1.7062387466430664, -1.7100077867507935, -1.705680251121521, -1.710768699645996, 1.378112554550171, 1.3710659742355347, 1.3717243671417236, 1.370174765586853, 1.3730859756469727, -0.21083390712738037, -0.1891980916261673, -0.20759464800357819, -0.21305304765701294, -0.23439815640449524]
Epoch 35 of 200


 18%|█▊        | 35/200 [05:59<29:53, 10.87s/it]

Log Alpha Grad: 0.16414486337417541
Log Beta Grad: 0.03651722008509272
W Grad: tensor([ 0.1048,  0.1127,  0.1131,  0.1158,  0.1060, -0.0766, -0.0778, -0.0767,
        -0.0767, -0.0781, -0.0331, -0.0087, -0.0295, -0.0356, -0.0597])
Epoch 35, NLL: 119.147, Alpha: 9.62, Beta: 1.589, W: [-1.716417908668518, -1.7175136804580688, -1.7213155031204224, -1.7172645330429077, -1.721373200416565, 1.385777235031128, 1.3788478374481201, 1.3793915510177612, 1.377846598625183, 1.3809001445770264, -0.2075234204530716, -0.18832740187644958, -0.20464949309825897, -0.2094922810792923, -0.22842937707901]
Epoch 36 of 200


 18%|█▊        | 36/200 [06:10<29:41, 10.86s/it]

Log Alpha Grad: 0.15998456552669704
Log Beta Grad: 0.03501531618232695
W Grad: tensor([ 0.1011,  0.1073,  0.1076,  0.1098,  0.1021, -0.0760, -0.0769, -0.0760,
        -0.0761, -0.0772, -0.0289, -0.0075, -0.0257, -0.0311, -0.0523])
Epoch 36, NLL: 119.137, Alpha: 9.467, Beta: 1.583, W: [-1.7265251874923706, -1.7282484769821167, -1.732077717781067, -1.7282421588897705, -1.7315785884857178, 1.3933812379837036, 1.3865418434143066, 1.386993646621704, 1.3854526281356812, 1.3886204957962036, -0.20462937653064728, -0.18757636845111847, -0.20207631587982178, -0.20637841522693634, -0.2232007086277008]
Epoch 37 of 200


 18%|█▊        | 37/200 [06:21<29:18, 10.79s/it]

Log Alpha Grad: 0.15571606476165714
Log Beta Grad: 0.03375424983960653
W Grad: tensor([ 0.0975,  0.1025,  0.1027,  0.1044,  0.0983, -0.0753, -0.0760, -0.0752,
        -0.0753, -0.0762, -0.0253, -0.0065, -0.0225, -0.0273, -0.0459])
Epoch 37, NLL: 119.128, Alpha: 9.321, Beta: 1.578, W: [-1.7362771034240723, -1.738499402999878, -1.74235200881958, -1.7386837005615234, -1.7414110898971558, 1.4009112119674683, 1.3941408395767212, 1.3945183753967285, 1.392980933189392, 1.3962410688400269, -0.20209555327892303, -0.18692782521247864, -0.1998247653245926, -0.20365120470523834, -0.21861326694488525]
Epoch 38 of 200


 19%|█▉        | 38/200 [06:30<27:55, 10.34s/it]

Log Alpha Grad: 0.15140822142167362
Log Beta Grad: 0.032670940124487485
W Grad: tensor([ 0.0942,  0.0981,  0.0983,  0.0996,  0.0948, -0.0744, -0.0750, -0.0744,
        -0.0744, -0.0752, -0.0222, -0.0056, -0.0197, -0.0239, -0.0403])
Epoch 38, NLL: 119.12, Alpha: 9.181, Beta: 1.573, W: [-1.7456939220428467, -1.7483140230178833, -1.752186894416809, -1.7486478090286255, -1.7508949041366577, 1.4083560705184937, 1.4016386270523071, 1.4019556045532227, 1.4004215002059937, 1.4037564992904663, -0.19987384974956512, -0.18636713922023773, -0.1978517472743988, -0.20125913619995117, -0.2145823836326599]
Epoch 39 of 200


 20%|█▉        | 39/200 [06:39<26:50, 10.00s/it]

Log Alpha Grad: 0.14711284747233466
Log Beta Grad: 0.031716191793390976
W Grad: tensor([ 0.0910,  0.0942,  0.0944,  0.0954,  0.0916, -0.0735, -0.0739, -0.0734,
        -0.0734, -0.0741, -0.0195, -0.0049, -0.0173, -0.0210, -0.0355])
Epoch 39, NLL: 119.112, Alpha: 9.047, Beta: 1.568, W: [-1.754794716835022, -1.75773286819458, -1.7616236209869385, -1.7581838369369507, -1.7600520849227905, 1.415706992149353, 1.4090299606323242, 1.409296989440918, 1.4077662229537964, 1.411162257194519, -0.19792306423187256, -0.18588189780712128, -0.19612038135528564, -0.19915802776813507, -0.21103543043136597]
Epoch 40 of 200


 20%|██        | 40/200 [06:49<26:14,  9.84s/it]

Log Alpha Grad: 0.14286761782795973
Log Beta Grad: 0.03085525056799678
W Grad: tensor([ 0.0880,  0.0906,  0.0907,  0.0915,  0.0885, -0.0725, -0.0728, -0.0724,
        -0.0724, -0.0729, -0.0172, -0.0042, -0.0152, -0.0185, -0.0313])
Epoch 40, NLL: 119.104, Alpha: 8.919, Beta: 1.563, W: [-1.7635973691940308, -1.7667906284332275, -1.7706973552703857, -1.7673333883285522, -1.7689027786254883, 1.4229567050933838, 1.416310429573059, 1.4165359735488892, 1.4150084257125854, 1.418454647064209, -0.19620782136917114, -0.18546150624752045, -0.194598987698555, -0.1973099708557129, -0.2079099714756012]
Epoch 41 of 200


 20%|██        | 41/200 [06:59<26:06,  9.85s/it]

Log Alpha Grad: 0.1387000424536972
Log Beta Grad: 0.030060220447264422
W Grad: tensor([ 0.0852,  0.0873,  0.0874,  0.0880,  0.0856, -0.0714, -0.0717, -0.0713,
        -0.0713, -0.0718, -0.0151, -0.0036, -0.0134, -0.0163, -0.0276])
Epoch 41, NLL: 119.097, Alpha: 8.796, Beta: 1.558, W: [-1.7721186876296997, -1.775517225265503, -1.7794384956359863, -1.7761318683624268, -1.7774654626846313, 1.4300997257232666, 1.4234766960144043, 1.423667311668396, 1.4221431016921997, 1.425630807876587, -0.19469767808914185, -0.18509694933891296, -0.19326035678386688, -0.1956823319196701, -0.2051522433757782]
Epoch 42 of 200


 21%|██        | 42/200 [07:09<26:18,  9.99s/it]

Log Alpha Grad: 0.13462894858972008
Log Beta Grad: 0.029312282102465517
W Grad: tensor([ 0.0826,  0.0842,  0.0843,  0.0848,  0.0829, -0.0703, -0.0705, -0.0702,
        -0.0702, -0.0706, -0.0133, -0.0032, -0.0118, -0.0144, -0.0244])
Epoch 42, NLL: 119.091, Alpha: 8.678, Beta: 1.554, W: [-1.7803741693496704, -1.7839385271072388, -1.787873387336731, -1.784609317779541, -1.785757303237915, 1.4371317625045776, 1.4305261373519897, 1.4306871891021729, 1.429166316986084, 1.4326884746551514, -0.1933664232492447, -0.1847805231809616, -0.1920810341835022, -0.19424699246883392, -0.20271584391593933]
Epoch 43 of 200


 22%|██▏       | 43/200 [07:19<26:26, 10.10s/it]

Log Alpha Grad: 0.13066757444302776
Log Beta Grad: 0.02859516002724754
W Grad: tensor([ 0.0800,  0.0814,  0.0815,  0.0818,  0.0804, -0.0692, -0.0693, -0.0691,
        -0.0691, -0.0694, -0.0117, -0.0027, -0.0104, -0.0127, -0.0216])
Epoch 43, NLL: 119.085, Alpha: 8.565, Beta: 1.549, W: [-1.788378357887268, -1.7920773029327393, -1.7960249185562134, -1.7927918434143066, -1.793793797492981, 1.4440498352050781, 1.4374570846557617, 1.437592625617981, 1.436075210571289, 1.4396262168884277, -0.19219143688678741, -0.1845056265592575, -0.1910407990217209, -0.19297967851161957, -0.20056065917015076]
Epoch 44 of 200


 22%|██▏       | 44/200 [07:31<27:27, 10.56s/it]

Log Alpha Grad: 0.1268236029486551
Log Beta Grad: 0.02789904926523484
W Grad: tensor([ 0.0777,  0.0788,  0.0789,  0.0791,  0.0780, -0.0680, -0.0681, -0.0679,
        -0.0679, -0.0682, -0.0104, -0.0024, -0.0092, -0.0112, -0.0191])
Epoch 44, NLL: 119.079, Alpha: 8.457, Beta: 1.545, W: [-1.7961448431015015, -1.7999534606933594, -1.8039133548736572, -1.8007020950317383, -1.8015893697738647, 1.4508517980575562, 1.4442684650421143, 1.4443817138671875, 1.4428677558898926, 1.4464432001113892, -0.1911531537771225, -0.18426662683486938, -0.19012217223644257, -0.1918594092130661, -0.1986519694328308]
Epoch 45 of 200


 22%|██▎       | 45/200 [07:41<27:12, 10.53s/it]

Log Alpha Grad: 0.12310167198764398
Log Beta Grad: 0.027216653088103202
W Grad: tensor([ 0.0754,  0.0763,  0.0764,  0.0766,  0.0757, -0.0668, -0.0669, -0.0667,
        -0.0668, -0.0670, -0.0092, -0.0021, -0.0081, -0.0099, -0.0169])
Epoch 45, NLL: 119.073, Alpha: 8.354, Beta: 1.541, W: [-1.8036861419677734, -1.8075847625732422, -1.811556339263916, -1.8083597421646118, -1.8091572523117065, 1.4575363397598267, 1.4509596824645996, 1.4510533809661865, 1.4495428800582886, 1.4531391859054565, -0.19023461639881134, -0.18405868113040924, -0.1893100142478943, -0.19086799025535583, -0.19695962965488434]
Epoch 46 of 200


 23%|██▎       | 46/200 [07:50<25:22,  9.89s/it]

Log Alpha Grad: 0.11950351274830989
Log Beta Grad: 0.026542946055689517
W Grad: tensor([ 0.0733,  0.0740,  0.0741,  0.0742,  0.0735, -0.0657, -0.0657, -0.0655,
        -0.0656, -0.0658, -0.0081, -0.0018, -0.0072, -0.0088, -0.0150])
Epoch 46, NLL: 119.068, Alpha: 8.255, Beta: 1.537, W: [-1.81101393699646, -1.8149868249893188, -1.8189697265625, -1.815782070159912, -1.8165096044540405, 1.464102864265442, 1.4575308561325073, 1.4576071500778198, 1.4561002254486084, 1.4597142934799194, -0.1894211322069168, -0.183877632021904, -0.18859121203422546, -0.18998965620994568, -0.19545747339725494]
Epoch 47 of 200


 24%|██▎       | 47/200 [07:59<24:38,  9.66s/it]

Log Alpha Grad: 0.11602873056035086
Log Beta Grad: 0.025874779506172552
W Grad: tensor([ 0.0713,  0.0719,  0.0720,  0.0720,  0.0715, -0.0645, -0.0645, -0.0644,
        -0.0644, -0.0645, -0.0072, -0.0016, -0.0064, -0.0078, -0.0133])
Epoch 47, NLL: 119.063, Alpha: 8.16, Beta: 1.533, W: [-1.8181390762329102, -1.8221735954284668, -1.8261675834655762, -1.8229843378067017, -1.8236576318740845, 1.4705514907836914, 1.4639825820922852, 1.4640430212020874, 1.462539792060852, 1.4661691188812256, -0.18869993090629578, -0.1837199181318283, -0.187954381108284, -0.1892106682062149, -0.19412268698215485]
Epoch 48 of 200


 24%|██▍       | 48/200 [08:08<24:21,  9.62s/it]

Log Alpha Grad: 0.1126759999680612
Log Beta Grad: 0.025209611744681178
W Grad: tensor([ 0.0693,  0.0698,  0.0699,  0.0700,  0.0695, -0.0633, -0.0633, -0.0632,
        -0.0632, -0.0634, -0.0064, -0.0014, -0.0056, -0.0069, -0.0119])
Epoch 48, NLL: 119.058, Alpha: 8.068, Beta: 1.529, W: [-1.8250716924667358, -1.829157829284668, -1.833162546157837, -1.8299802541732788, -1.8306118249893188, 1.4768825769424438, 1.4703155755996704, 1.4703617095947266, 1.4688620567321777, 1.472504734992981, -0.18805991113185883, -0.18358244001865387, -0.18738959729671478, -0.18851910531520844, -0.19293542206287384]
Epoch 49 of 200


 24%|██▍       | 49/200 [08:18<24:22,  9.68s/it]

Log Alpha Grad: 0.10944222427910544
Log Beta Grad: 0.02454776103753362
W Grad: tensor([ 0.0675,  0.0679,  0.0680,  0.0680,  0.0677, -0.0621, -0.0622, -0.0620,
        -0.0621, -0.0622, -0.0057, -0.0012, -0.0050, -0.0061, -0.0106])
Epoch 49, NLL: 119.053, Alpha: 7.98, Beta: 1.525, W: [-1.831821084022522, -1.8359508514404297, -1.839966058731079, -1.8367820978164673, -1.8373817205429077, 1.4830971956253052, 1.4765311479568481, 1.4765640497207642, 1.4750680923461914, 1.4787224531173706, -0.18749135732650757, -0.18346256017684937, -0.1868882179260254, -0.18790453672409058, -0.1918783187866211]
Epoch 50 of 200


 25%|██▌       | 50/200 [08:28<24:39,  9.86s/it]

Log Alpha Grad: 0.10632443607719311
Log Beta Grad: 0.02388763760737569
W Grad: tensor([ 0.0658,  0.0661,  0.0662,  0.0662,  0.0659, -0.0610, -0.0610, -0.0609,
        -0.0609, -0.0610, -0.0051, -0.0010, -0.0045, -0.0055, -0.0094])
Epoch 50, NLL: 119.049, Alpha: 7.896, Beta: 1.522, W: [-1.8383961915969849, -1.8425629138946533, -1.8465886116027832, -1.8434005975723267, -1.843976378440857, 1.489196538925171, 1.482630729675293, 1.4826513528823853, 1.4811590909957886, 1.4848237037658691, -0.1869858354330063, -0.18335796892642975, -0.18644271790981293, -0.18735788762569427, -0.19093620777130127]
Epoch 51 of 200


 26%|██▌       | 51/200 [08:39<25:21, 10.21s/it]

Log Alpha Grad: 0.10331867422866343
Log Beta Grad: 0.023230131763081138
W Grad: tensor([ 0.0641,  0.0644,  0.0645,  0.0645,  0.0643, -0.0599, -0.0599, -0.0597,
        -0.0598, -0.0599, -0.0045, -0.0009, -0.0040, -0.0049, -0.0084])
Epoch 51, NLL: 119.045, Alpha: 7.815, Beta: 1.518, W: [-1.8448050022125244, -1.849003553390503, -1.8530395030975342, -1.8498458862304688, -1.8504040241241455, 1.4951821565628052, 1.4886159896850586, 1.488625168800354, 1.4871366024017334, 1.4908102750778198, -0.18653593957424164, -0.18326669931411743, -0.1860465109348297, -0.1868712157011032, -0.19009579718112946]
Epoch 52 of 200


 26%|██▌       | 52/200 [08:52<27:00, 10.95s/it]

Log Alpha Grad: 0.10042093274527586
Log Beta Grad: 0.02257516255928826
W Grad: tensor([ 0.0625,  0.0628,  0.0629,  0.0628,  0.0627, -0.0587, -0.0587, -0.0586,
        -0.0587, -0.0587, -0.0040, -0.0008, -0.0035, -0.0043, -0.0075])
Epoch 52, NLL: 119.041, Alpha: 7.737, Beta: 1.515, W: [-1.8510551452636719, -1.8552813529968262, -1.8593274354934692, -1.8561269044876099, -1.8566724061965942, 1.5010557174682617, 1.4944888353347778, 1.4944872856140137, 1.4930022954940796, 1.496684193611145, -0.1861352026462555, -0.18318702280521393, -0.18569384515285492, -0.18643754720687866, -0.18934546411037445]
Epoch 53 of 200


 26%|██▋       | 53/200 [09:02<26:11, 10.69s/it]

Log Alpha Grad: 0.09762712745953513
Log Beta Grad: 0.02192362271057173
W Grad: tensor([ 0.0610,  0.0612,  0.0613,  0.0612,  0.0612, -0.0576, -0.0576, -0.0575,
        -0.0576, -0.0576, -0.0036, -0.0007, -0.0031, -0.0039, -0.0067])
Epoch 53, NLL: 119.037, Alpha: 7.661, Beta: 1.511, W: [-1.8571537733078003, -1.8614041805267334, -1.8654602766036987, -1.8622517585754395, -1.8627886772155762, 1.506819248199463, 1.5002511739730835, 1.5002394914627075, 1.49875807762146, 1.5024473667144775, -0.185777947306633, -0.18311746418476105, -0.18537966907024384, -0.18605080246925354, -0.1886749565601349]
Epoch 54 of 200


 27%|██▋       | 54/200 [09:12<25:08, 10.33s/it]

Log Alpha Grad: 0.09493299979007078
Log Beta Grad: 0.02127610108791091
W Grad: tensor([ 0.0595,  0.0598,  0.0599,  0.0598,  0.0597, -0.0566, -0.0565, -0.0564,
        -0.0565, -0.0565, -0.0032, -0.0006, -0.0028, -0.0035, -0.0060])
Epoch 54, NLL: 119.033, Alpha: 7.589, Beta: 1.508, W: [-1.863107442855835, -1.8673793077468872, -1.8714454174041748, -1.8682280778884888, -1.8687596321105957, 1.5124746561050415, 1.505905270576477, 1.5058839321136475, 1.5044060945510864, 1.5081020593643188, -0.1854592114686966, -0.18305672705173492, -0.1850995570421219, -0.18570560216903687, -0.18807528913021088]
Epoch 55 of 200


 28%|██▊       | 55/200 [09:21<24:13, 10.02s/it]

Log Alpha Grad: 0.09233429630889398
Log Beta Grad: 0.0206334135725995
W Grad: tensor([ 0.0582,  0.0583,  0.0584,  0.0583,  0.0583, -0.0555, -0.0555, -0.0554,
        -0.0554, -0.0555, -0.0028, -0.0005, -0.0025, -0.0031, -0.0054])
Epoch 55, NLL: 119.03, Alpha: 7.519, Beta: 1.505, W: [-1.868922472000122, -1.8732134103775024, -1.8772894144058228, -1.8740627765655518, -1.8745914697647095, 1.518024206161499, 1.5114532709121704, 1.5114227533340454, 1.509948492050171, 1.5136505365371704, -0.1851746141910553, -0.18300367891788483, -0.18484961986541748, -0.18539725244045258, -0.18753854930400848]
Epoch 56 of 200


 28%|██▊       | 56/200 [09:29<22:18,  9.30s/it]

Log Alpha Grad: 0.08982718893935017
Log Beta Grad: 0.019995995870124968
W Grad: tensor([ 0.0568,  0.0570,  0.0571,  0.0570,  0.0570, -0.0545, -0.0544, -0.0544,
        -0.0544, -0.0544, -0.0025, -0.0005, -0.0022, -0.0028, -0.0048])
Epoch 56, NLL: 119.026, Alpha: 7.452, Beta: 1.502, W: [-1.87460458278656, -1.8789128065109253, -1.8829987049102783, -1.8797622919082642, -1.880290150642395, 1.5234700441360474, 1.516897439956665, 1.5168582201004028, 1.5153874158859253, 1.5190949440002441, -0.1849202960729599, -0.18295736610889435, -0.18462644517421722, -0.18512161076068878, -0.1870577484369278]
Epoch 57 of 200


 28%|██▊       | 57/200 [09:39<22:44,  9.54s/it]

Log Alpha Grad: 0.0874074736731733
Log Beta Grad: 0.019365127125755114
W Grad: tensor([ 0.0555,  0.0557,  0.0558,  0.0557,  0.0557, -0.0534, -0.0534, -0.0533,
        -0.0534, -0.0534, -0.0023, -0.0004, -0.0020, -0.0025, -0.0043])
Epoch 57, NLL: 119.023, Alpha: 7.387, Beta: 1.499, W: [-1.8801593780517578, -1.8844833374023438, -1.8885788917541504, -1.885332465171814, -1.8858611583709717, 1.528814435005188, 1.522240161895752, 1.5221924781799316, 1.520725131034851, 1.5244377851486206, -0.1846928894519806, -0.18291692435741425, -0.18442702293395996, -0.184875026345253, -0.18662674725055695]
Epoch 58 of 200


 29%|██▉       | 58/200 [09:47<21:42,  9.17s/it]

Log Alpha Grad: 0.08507122388441735
Log Beta Grad: 0.01874191433593023
W Grad: tensor([ 0.0543,  0.0545,  0.0546,  0.0545,  0.0545, -0.0525, -0.0524, -0.0524,
        -0.0524, -0.0524, -0.0020, -0.0004, -0.0018, -0.0022, -0.0039])
Epoch 58, NLL: 119.02, Alpha: 7.325, Beta: 1.496, W: [-1.8855918645858765, -1.889930248260498, -1.8940355777740479, -1.8907787799835205, -1.8913094997406006, 1.5340596437454224, 1.5274837017059326, 1.5274279117584229, 1.5259639024734497, 1.5296812057495117, -0.1844893991947174, -0.18288162350654602, -0.18424870073795319, -0.18465428054332733, -0.186240091919899]
Epoch 59 of 200


 30%|██▉       | 59/200 [09:57<21:58,  9.35s/it]

Log Alpha Grad: 0.08281520699448708
Log Beta Grad: 0.018125003344592477
W Grad: tensor([ 0.0532,  0.0533,  0.0534,  0.0533,  0.0533, -0.0515, -0.0515, -0.0514,
        -0.0514, -0.0515, -0.0018, -0.0003, -0.0016, -0.0020, -0.0035])
Epoch 59, NLL: 119.017, Alpha: 7.264, Beta: 1.494, W: [-1.8909069299697876, -1.8952587842941284, -1.8993736505508423, -1.8961063623428345, -1.8966401815414429, 1.5392080545425415, 1.5326303243637085, 1.532566785812378, 1.5311059951782227, 1.534827709197998, -0.18430717289447784, -0.18285082280635834, -0.18408915400505066, -0.18445652723312378, -0.1858929842710495]
Epoch 60 of 200


 30%|███       | 60/200 [10:09<23:37, 10.12s/it]

Log Alpha Grad: 0.08063554901367395
Log Beta Grad: 0.017516884330674744
W Grad: tensor([ 0.0520,  0.0521,  0.0522,  0.0521,  0.0522, -0.0505, -0.0505, -0.0504,
        -0.0505, -0.0505, -0.0016, -0.0003, -0.0014, -0.0018, -0.0031])
Epoch 60, NLL: 119.014, Alpha: 7.206, Beta: 1.491, W: [-1.8961091041564941, -1.9004734754562378, -1.9045978784561157, -1.9013200998306274, -1.9018577337265015, 1.5442619323730469, 1.5376824140548706, 1.5376113653182983, 1.536153793334961, 1.5398796796798706, -0.18414390087127686, -0.18282395601272583, -0.18394629657268524, -0.18427926301956177, -0.18558116257190704]
Epoch 61 of 200


 30%|███       | 61/200 [10:19<23:14, 10.03s/it]

Log Alpha Grad: 0.07852877564385918
Log Beta Grad: 0.016917759624706315
W Grad: tensor([ 0.0509,  0.0511,  0.0511,  0.0510,  0.0511, -0.0496, -0.0496, -0.0495,
        -0.0496, -0.0496, -0.0015, -0.0002, -0.0013, -0.0016, -0.0028])
Epoch 61, NLL: 119.011, Alpha: 7.15, Beta: 1.489, W: [-1.90120267868042, -1.905578851699829, -1.9097126722335815, -1.9064244031906128, -1.9069664478302002, 1.5492236614227295, 1.54264235496521, 1.542564034461975, 1.541109561920166, 1.5448392629623413, -0.1839975118637085, -0.1828005313873291, -0.18381831049919128, -0.18412025272846222, -0.18530085682868958]
Epoch 62 of 200


 31%|███       | 62/200 [10:28<22:45,  9.89s/it]

Log Alpha Grad: 0.07649202144632213
Log Beta Grad: 0.016327081979938655
W Grad: tensor([ 0.0499,  0.0500,  0.0501,  0.0500,  0.0500, -0.0487, -0.0487, -0.0486,
        -0.0487, -0.0487, -0.0013, -0.0002, -0.0011, -0.0014, -0.0025])
Epoch 62, NLL: 119.008, Alpha: 7.095, Beta: 1.486, W: [-1.9061917066574097, -1.9105790853500366, -1.914722204208374, -1.9114233255386353, -1.9119703769683838, 1.5540953874588013, 1.547512412071228, 1.5474269390106201, 1.5459755659103394, 1.5497089624404907, -0.18386617302894592, -0.18278013169765472, -0.18370358645915985, -0.18397754430770874, -0.18504872918128967]
Epoch 63 of 200


 32%|███▏      | 63/200 [10:36<21:32,  9.43s/it]

Log Alpha Grad: 0.07452211654820091
Log Beta Grad: 0.015746049407898283
W Grad: tensor([ 0.0489,  0.0490,  0.0491,  0.0490,  0.0490, -0.0478, -0.0478, -0.0478,
        -0.0478, -0.0478, -0.0012, -0.0002, -0.0010, -0.0013, -0.0023])
Epoch 63, NLL: 119.006, Alpha: 7.042, Beta: 1.484, W: [-1.911080002784729, -1.915477991104126, -1.9196304082870483, -1.916321039199829, -1.9168734550476074, 1.5588794946670532, 1.5522947311401367, 1.5522024631500244, 1.5507540702819824, 1.5544909238815308, -0.1837482750415802, -0.18276236951351166, -0.18360067903995514, -0.18384937942028046, -0.18482179939746857]
Epoch 64 of 200


 32%|███▏      | 64/200 [10:45<20:52,  9.21s/it]

Log Alpha Grad: 0.07261618978613622
Log Beta Grad: 0.01517494353206208
W Grad: tensor([ 0.0479,  0.0480,  0.0481,  0.0480,  0.0481, -0.0470, -0.0470, -0.0469,
        -0.0469, -0.0470, -0.0011, -0.0002, -0.0009, -0.0012, -0.0020])
Epoch 64, NLL: 119.003, Alpha: 6.991, Beta: 1.482, W: [-1.915871262550354, -1.9202793836593628, -1.9244409799575806, -1.9211211204528809, -1.921679139137268, 1.5635782480239868, 1.556991696357727, 1.5568927526474, 1.5554473400115967, 1.559187412261963, -0.18364238739013672, -0.18274693191051483, -0.18350832164287567, -0.1837342083454132, -0.18461741507053375]
Epoch 65 of 200


 32%|███▎      | 65/200 [10:56<22:04,  9.81s/it]

Log Alpha Grad: 0.07077150979736382
Log Beta Grad: 0.014613444917994173
W Grad: tensor([ 0.0470,  0.0471,  0.0472,  0.0471,  0.0471, -0.0462, -0.0461, -0.0461,
        -0.0461, -0.0461, -0.0010, -0.0001, -0.0008, -0.0010, -0.0018])
Epoch 65, NLL: 119.001, Alpha: 6.942, Beta: 1.479, W: [-1.920568823814392, -1.924986720085144, -1.9291573762893677, -1.925827145576477, -1.9263910055160522, 1.5681936740875244, 1.561605453491211, 1.561500072479248, 1.560057520866394, 1.563800573348999, -0.1835472285747528, -0.18273352086544037, -0.1834253966808319, -0.18363066017627716, -0.1844332367181778]
Epoch 66 of 200


 33%|███▎      | 66/200 [11:07<22:17,  9.98s/it]

Log Alpha Grad: 0.06898529518700738
Log Beta Grad: 0.01406261805411749
W Grad: tensor([ 0.0461,  0.0462,  0.0463,  0.0462,  0.0462, -0.0453, -0.0453, -0.0453,
        -0.0453, -0.0453, -0.0009, -0.0001, -0.0007, -0.0009, -0.0017])
Epoch 66, NLL: 118.999, Alpha: 6.894, Beta: 1.477, W: [-1.9251759052276611, -1.9296032190322876, -1.9337828159332275, -1.9304423332214355, -1.9310121536254883, 1.5727280378341675, 1.5661381483078003, 1.5660264492034912, 1.5645867586135864, 1.5683326721191406, -0.18346166610717773, -0.18272188305854797, -0.18335090577602386, -0.18353751301765442, -0.18426717817783356]
Epoch 67 of 200


 34%|███▎      | 67/200 [11:21<24:54, 11.23s/it]

Log Alpha Grad: 0.06725530439128105
Log Beta Grad: 0.01352200024282876
W Grad: tensor([ 0.0452,  0.0453,  0.0454,  0.0453,  0.0453, -0.0446, -0.0445, -0.0445,
        -0.0445, -0.0445, -0.0008, -0.0001, -0.0007, -0.0008, -0.0015])
Epoch 67, NLL: 118.997, Alpha: 6.848, Beta: 1.475, W: [-1.929695725440979, -1.9341320991516113, -1.9383206367492676, -1.9349699020385742, -1.935545802116394, 1.5771833658218384, 1.5705918073654175, 1.5704740285873413, 1.5690370798110962, 1.57278573513031, -0.1833847016096115, -0.18271180987358093, -0.1832839548587799, -0.18345367908477783, -0.1841173768043518]
Epoch 68 of 200


 34%|███▍      | 68/200 [11:31<24:14, 11.02s/it]

Log Alpha Grad: 0.06557906855116277
Log Beta Grad: 0.012992067599549302
W Grad: tensor([ 4.4353e-02,  4.4441e-02,  4.4530e-02,  4.4429e-02,  4.4490e-02,
        -4.3784e-02, -4.3767e-02, -4.3708e-02, -4.3735e-02, -4.3761e-02,
        -6.9275e-04, -8.7138e-05, -6.0208e-04, -7.5482e-04, -1.3521e-03])
Epoch 68, NLL: 118.994, Alpha: 6.803, Beta: 1.473, W: [-1.9341310262680054, -1.9385762214660645, -1.9427735805511475, -1.9394127130508423, -1.9399948120117188, 1.581561803817749, 1.5749685764312744, 1.5748448371887207, 1.5734105110168457, 1.5771617889404297, -0.1833154261112213, -0.18270309269428253, -0.18322373926639557, -0.1833781898021698, -0.18398216366767883]
Epoch 69 of 200


 34%|███▍      | 69/200 [11:39<22:06, 10.12s/it]

Log Alpha Grad: 0.06395459851384411
Log Beta Grad: 0.012472002895789976
W Grad: tensor([ 4.3537e-02,  4.3622e-02,  4.3710e-02,  4.3610e-02,  4.3672e-02,
        -4.3034e-02, -4.3018e-02, -4.2960e-02, -4.2986e-02, -4.3011e-02,
        -6.2378e-04, -7.5233e-05, -5.4166e-04, -6.8001e-04, -1.2211e-03])
Epoch 69, NLL: 118.992, Alpha: 6.76, Beta: 1.472, W: [-1.938484787940979, -1.9429384469985962, -1.947144627571106, -1.9437737464904785, -1.9443620443344116, 1.5858652591705322, 1.579270362854004, 1.5791409015655518, 1.5777090787887573, 1.5814628601074219, -0.18325304985046387, -0.18269556760787964, -0.18316957354545593, -0.18331019580364227, -0.18386004865169525]
Epoch 70 of 200


 35%|███▌      | 70/200 [11:50<22:08, 10.22s/it]

Log Alpha Grad: 0.06237946844446703
Log Beta Grad: 0.011962971320067977
W Grad: tensor([ 4.2747e-02,  4.2830e-02,  4.2917e-02,  4.2818e-02,  4.2880e-02,
        -4.2304e-02, -4.2288e-02, -4.2232e-02, -4.2257e-02, -4.2281e-02,
        -5.6193e-04, -6.4729e-05, -4.8748e-04, -6.1291e-04, -1.1033e-03])
Epoch 70, NLL: 118.99, Alpha: 6.718, Beta: 1.47, W: [-1.9427595138549805, -1.9472215175628662, -1.9514362812042236, -1.948055624961853, -1.9486501216888428, 1.5900956392288208, 1.5834991931915283, 1.5833641290664673, 1.5819348096847534, 1.585690975189209, -0.18319685757160187, -0.1826891005039215, -0.18312083184719086, -0.18324890732765198, -0.183749720454216]
Epoch 71 of 200


 36%|███▌      | 71/200 [12:01<22:18, 10.38s/it]

Log Alpha Grad: 0.060851811762010724
Log Beta Grad: 0.011464587863547947
W Grad: tensor([ 4.1983e-02,  4.2064e-02,  4.2150e-02,  4.2052e-02,  4.2114e-02,
        -4.1593e-02, -4.1577e-02, -4.1523e-02, -4.1547e-02, -4.1570e-02,
        -5.0643e-04, -5.5532e-05, -4.3892e-04, -5.5264e-04, -9.9736e-04])
Epoch 71, NLL: 118.988, Alpha: 6.677, Beta: 1.468, W: [-1.9469578266143799, -1.951427936553955, -1.9556512832641602, -1.9522608518600464, -1.9528615474700928, 1.594254970550537, 1.5876569747924805, 1.5875164270401, 1.5860894918441772, 1.5898480415344238, -0.18314620852470398, -0.1826835423707962, -0.18307693302631378, -0.183193638920784, -0.18364998698234558]
Epoch 72 of 200


 36%|███▌      | 72/200 [12:10<21:11,  9.93s/it]

Log Alpha Grad: 0.05936980832477084
Log Beta Grad: 0.010976407152751567
W Grad: tensor([ 4.1242e-02,  4.1321e-02,  4.1406e-02,  4.1310e-02,  4.1372e-02,
        -4.0901e-02, -4.0885e-02, -4.0832e-02, -4.0856e-02, -4.0878e-02,
        -4.5660e-04, -4.7521e-05, -3.9533e-04, -4.9854e-04, -9.0201e-04])
Epoch 72, NLL: 118.987, Alpha: 6.638, Beta: 1.466, W: [-1.9510819911956787, -1.9555600881576538, -1.959791898727417, -1.9563918113708496, -1.9569987058639526, 1.5983450412750244, 1.5917454957962036, 1.5915995836257935, 1.5901750326156616, 1.5939358472824097, -0.18310055136680603, -0.18267878890037537, -0.1830374002456665, -0.18314377963542938, -0.18355979025363922]
Epoch 73 of 200


 36%|███▋      | 73/200 [12:19<20:37,  9.75s/it]

Log Alpha Grad: 0.0579314810157843
Log Beta Grad: 0.010499050851057823
W Grad: tensor([ 4.0525e-02,  4.0602e-02,  4.0686e-02,  4.0591e-02,  4.0652e-02,
        -4.0226e-02, -4.0210e-02, -4.0159e-02, -4.0182e-02, -4.0203e-02,
        -4.1185e-04, -4.0529e-05, -3.5626e-04, -4.4991e-04, -8.1619e-04])
Epoch 73, NLL: 118.985, Alpha: 6.599, Beta: 1.465, W: [-1.955134391784668, -1.9596202373504639, -1.9638605117797852, -1.9604508876800537, -1.9610639810562134, 1.602367639541626, 1.595766544342041, 1.5956155061721802, 1.5941932201385498, 1.5979561805725098, -0.18305936455726624, -0.18267473578453064, -0.18300177156925201, -0.18309879302978516, -0.18347817659378052]
Epoch 74 of 200


 37%|███▋      | 74/200 [12:31<22:04, 10.51s/it]

Log Alpha Grad: 0.05653529293399355
Log Beta Grad: 0.01003183757105747
W Grad: tensor([ 3.9829e-02,  3.9905e-02,  3.9988e-02,  3.9894e-02,  3.9955e-02,
        -3.9569e-02, -3.9554e-02, -3.9503e-02, -3.9526e-02, -3.9546e-02,
        -3.7161e-04, -3.4379e-05, -3.2113e-04, -4.0618e-04, -7.3882e-04])
Epoch 74, NLL: 118.983, Alpha: 6.562, Beta: 1.463, W: [-1.9591172933578491, -1.9636107683181763, -1.9678593873977661, -1.9644402265548706, -1.9650593996047974, 1.606324553489685, 1.599721908569336, 1.599565863609314, 1.598145842552185, 1.6019108295440674, -0.18302220106124878, -0.18267129361629486, -0.18296965956687927, -0.18305817246437073, -0.18340429663658142]
Epoch 75 of 200


 38%|███▊      | 75/200 [12:41<21:36, 10.38s/it]

Log Alpha Grad: 0.05517960587134552
Log Beta Grad: 0.009574771739623197
W Grad: tensor([ 3.9154e-02,  3.9229e-02,  3.9311e-02,  3.9218e-02,  3.9279e-02,
        -3.8929e-02, -3.8914e-02, -3.8865e-02, -3.8887e-02, -3.8906e-02,
        -3.3544e-04, -2.9024e-05, -2.8956e-04, -3.6685e-04, -6.6907e-04])
Epoch 75, NLL: 118.981, Alpha: 6.526, Beta: 1.462, W: [-1.9630327224731445, -1.9675335884094238, -1.9717904329299927, -1.9683620929718018, -1.968987226486206, 1.6102174520492554, 1.603613257408142, 1.6034523248672485, 1.602034568786621, 1.6058014631271362, -0.18298865854740143, -0.18266838788986206, -0.18294070661067963, -0.18302148580551147, -0.18333739042282104]
Epoch 76 of 200


 38%|███▊      | 76/200 [12:52<21:36, 10.45s/it]

Log Alpha Grad: 0.05386276735867231
Log Beta Grad: 0.0091279083646062
W Grad: tensor([ 3.8500e-02,  3.8573e-02,  3.8654e-02,  3.8563e-02,  3.8623e-02,
        -3.8306e-02, -3.8290e-02, -3.8243e-02, -3.8264e-02, -3.8283e-02,
        -3.0291e-04, -2.4375e-05, -2.6119e-04, -3.3145e-04, -6.0619e-04])
Epoch 76, NLL: 118.98, Alpha: 6.491, Beta: 1.461, W: [-1.9668827056884766, -1.9713908433914185, -1.9756559133529663, -1.97221839427948, -1.9728494882583618, 1.6140480041503906, 1.6074422597885132, 1.6072765588760376, 1.605860948562622, 1.60962975025177, -0.18295836448669434, -0.18266594409942627, -0.1829145848751068, -0.18298834562301636, -0.18327677249908447]
Epoch 77 of 200


 38%|███▊      | 77/200 [13:04<22:11, 10.83s/it]

Log Alpha Grad: 0.05258326463008554
Log Beta Grad: 0.008691319620795036
W Grad: tensor([ 3.7864e-02,  3.7936e-02,  3.8017e-02,  3.7926e-02,  3.7985e-02,
        -3.7698e-02, -3.7683e-02, -3.7637e-02, -3.7657e-02, -3.7676e-02,
        -2.7359e-04, -2.0283e-05, -2.3568e-04, -2.9957e-04, -5.4943e-04])
Epoch 77, NLL: 118.978, Alpha: 6.457, Beta: 1.459, W: [-1.970669150352478, -1.975184440612793, -1.9794576168060303, -1.976011037826538, -1.9766480922698975, 1.617817759513855, 1.611210584640503, 1.6110402345657349, 1.6096266508102417, 1.6133973598480225, -0.18293100595474243, -0.1826639175415039, -0.18289101123809814, -0.18295839428901672, -0.18322183191776276]
Epoch 78 of 200


 39%|███▉      | 78/200 [13:12<20:31, 10.09s/it]

Log Alpha Grad: 0.05133979233742328
Log Beta Grad: 0.00826455034656985
W Grad: tensor([ 3.7247e-02,  3.7318e-02,  3.7398e-02,  3.7308e-02,  3.7367e-02,
        -3.7106e-02, -3.7091e-02, -3.7046e-02, -3.7066e-02, -3.7084e-02,
        -2.4719e-04, -1.6721e-05, -2.1267e-04, -2.7084e-04, -4.9814e-04])
Epoch 78, NLL: 118.977, Alpha: 6.424, Beta: 1.458, W: [-1.9743938446044922, -1.9789161682128906, -1.9831973314285278, -1.9797418117523193, -1.9803848266601562, 1.6215283870697021, 1.614919662475586, 1.614744782447815, 1.6133332252502441, 1.6171057224273682, -0.18290628492832184, -0.1826622486114502, -0.18286974728107452, -0.1829313039779663, -0.18317201733589172]
Epoch 79 of 200


 39%|███▉      | 78/200 [13:12<20:39, 10.16s/it]


KeyboardInterrupt: 

In [12]:
best_nll, final_alpha, final_beta, W = train(X, y, patience=100, grid_size=200, num_epochs=200,
                                            lr=.1, speech_len=speech_len, lex_size=lex_size, 
                                            log_alpha= torch.tensor(np.log(12.67 + -.5), requires_grad=True),
                                            log_beta= torch.tensor(np.log(3.168 + 0.1), requires_grad=True),
                                            W = torch.normal(10, 2, size=(3 * lex_size,), requires_grad=True))

Beginning Inference...
Hyperparameters:
Dataset Length: 50000
Number Epochs: 200
Batch Size: 100
Grid Size: 200
Learning Rate: 0.1
Improvement Threshold: 1e-05
Patience: 100
Speech Len: 50
Initial Parameters:
Alpha: 12.170000000000002
Beta: 3.268
W: [11.960583686828613, 11.142026901245117, 12.866023063659668, 13.337730407714844, 13.613570213317871, 8.694578170776367, 12.097506523132324, 10.995043754577637, 10.773017883300781, 6.944344997406006, 10.778311729431152, 8.514144897460938, 9.893498420715332, 10.913046836853027, 11.223490715026855]


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch 1 of 200


  0%|          | 1/200 [00:14<47:20, 14.27s/it]

Log Alpha Grad: 5.239091613097167
Log Beta Grad: -3.165634268340038
W Grad: tensor([ 0.0724,  0.0668,  0.0805,  0.0846,  0.0877, -0.1184, -0.1077, -0.1123,
        -0.1127, -0.1199,  0.0398,  0.0221,  0.0323,  0.0410,  0.0439])
Epoch 1, NLL: 138.867, Alpha: 7.207, Beta: 4.485, W: [11.953339576721191, 11.135344505310059, 12.857972145080566, 13.329269409179688, 13.604802131652832, 8.706415176391602, 12.10827922821045, 11.006277084350586, 10.784293174743652, 6.956336975097656, 10.774333000183105, 8.511940002441406, 9.890268325805664, 10.908946990966797, 11.219100952148438]
Epoch 2 of 200


  1%|          | 2/200 [00:23<36:21, 11.02s/it]

Log Alpha Grad: 1.2758036240218333
Log Beta Grad: -1.4815351790474376
W Grad: tensor([ 0.0686,  0.0645,  0.0745,  0.0773,  0.0794, -0.1062, -0.0981, -0.1014,
        -0.1016, -0.1076,  0.0333,  0.0191,  0.0274,  0.0343,  0.0365])
Epoch 2, NLL: 136.415, Alpha: 6.344, Beta: 5.201, W: [11.946480751037598, 11.128896713256836, 12.850522994995117, 13.321539878845215, 13.596858024597168, 8.717032432556152, 12.118086814880371, 11.016416549682617, 10.794458389282227, 6.967097759246826, 10.770999908447266, 8.510031700134277, 9.887528419494629, 10.905519485473633, 11.215448379516602]
Epoch 3 of 200


  2%|▏         | 3/200 [00:34<36:47, 11.21s/it]

Log Alpha Grad: 0.3554134003383521
Log Beta Grad: -0.8107643294474587
W Grad: tensor([ 0.0686,  0.0645,  0.0743,  0.0770,  0.0791, -0.1057, -0.0978, -0.1010,
        -0.1013, -0.1072,  0.0331,  0.0190,  0.0272,  0.0340,  0.0362])
Epoch 3, NLL: 136.131, Alpha: 6.122, Beta: 5.64, W: [11.93962574005127, 11.122442245483398, 12.843093872070312, 13.31383991241455, 13.58895206451416, 8.727605819702148, 12.127863883972168, 11.02651596069336, 10.804583549499512, 6.9778151512146, 10.767694473266602, 8.508129119873047, 9.884806632995605, 10.902120590209961, 11.21182918548584]
Epoch 4 of 200


  2%|▏         | 4/200 [00:46<37:06, 11.36s/it]

Log Alpha Grad: -0.003631288493307368
Log Beta Grad: -0.4951826158853243
W Grad: tensor([ 0.0688,  0.0648,  0.0746,  0.0773,  0.0794, -0.1063, -0.0983, -0.1016,
        -0.1018, -0.1078,  0.0334,  0.0193,  0.0275,  0.0343,  0.0365])
Epoch 4, NLL: 136.068, Alpha: 6.125, Beta: 5.927, W: [11.932744979858398, 11.115964889526367, 12.8356351852417, 13.306109428405762, 13.581013679504395, 8.738239288330078, 12.1376953125, 11.03667163848877, 10.814765930175781, 6.988593578338623, 10.764357566833496, 8.506203651428223, 9.88205623626709, 10.898690223693848, 11.208176612854004]
Epoch 5 of 200


  2%|▎         | 5/200 [00:54<33:58, 10.46s/it]

Log Alpha Grad: -0.152139682276802
Log Beta Grad: -0.3534663942221082
W Grad: tensor([ 0.0691,  0.0650,  0.0750,  0.0777,  0.0798, -0.1071, -0.0990, -0.1023,
        -0.1026, -0.1086,  0.0338,  0.0195,  0.0279,  0.0347,  0.0370])
Epoch 5, NLL: 136.044, Alpha: 6.218, Beta: 6.14, W: [11.925832748413086, 11.109461784362793, 12.82813835144043, 13.298336029052734, 13.573029518127441, 8.748953819274902, 12.14759635925293, 11.04690170288086, 10.825023651123047, 6.999454021453857, 10.760977745056152, 8.504250526428223, 9.87926959991455, 10.895215034484863, 11.204477310180664]
Epoch 6 of 200


  3%|▎         | 6/200 [01:05<33:34, 10.38s/it]

Log Alpha Grad: -0.21311246456151228
Log Beta Grad: -0.29274321349673904
W Grad: tensor([ 0.0695,  0.0653,  0.0754,  0.0782,  0.0803, -0.1080, -0.0998, -0.1031,
        -0.1034, -0.1095,  0.0343,  0.0198,  0.0283,  0.0352,  0.0375])
Epoch 6, NLL: 136.027, Alpha: 6.352, Beta: 6.322, W: [11.918888092041016, 11.10293197631836, 12.820599555969238, 13.29051685333252, 13.564996719360352, 8.759754180908203, 12.157571792602539, 11.057211875915527, 10.83536148071289, 7.010401725769043, 10.757552146911621, 8.502266883850098, 9.876443862915039, 10.891693115234375, 11.200728416442871]
Epoch 7 of 200


  4%|▎         | 7/200 [01:17<35:32, 11.05s/it]

Log Alpha Grad: -0.23733340423096427
Log Beta Grad: -0.26766204176494185
W Grad: tensor([ 0.0698,  0.0656,  0.0758,  0.0787,  0.0808, -0.1089, -0.1005, -0.1039,
        -0.1042, -0.1104,  0.0347,  0.0201,  0.0287,  0.0357,  0.0380])
Epoch 7, NLL: 136.011, Alpha: 6.505, Beta: 6.494, W: [11.911909103393555, 11.09637451171875, 12.813018798828125, 13.282649993896484, 13.556913375854492, 8.77064323425293, 12.167623519897461, 11.067604064941406, 10.845782279968262, 7.021439075469971, 10.75407886505127, 8.500252723693848, 9.873578071594238, 10.88812255859375, 11.196928024291992]
Epoch 8 of 200


  4%|▍         | 8/200 [01:30<37:42, 11.79s/it]

Log Alpha Grad: -0.2464280738732359
Log Beta Grad: -0.2575237577802299
W Grad: tensor([ 0.0701,  0.0659,  0.0762,  0.0792,  0.0813, -0.1098, -0.1013, -0.1048,
        -0.1051, -0.1113,  0.0352,  0.0205,  0.0291,  0.0362,  0.0385])
Epoch 8, NLL: 135.995, Alpha: 6.667, Beta: 6.663, W: [11.904895782470703, 11.089789390563965, 12.805395126342773, 13.274734497070312, 13.548778533935547, 8.781623840332031, 12.177753448486328, 11.078079223632812, 10.856287956237793, 7.032567977905273, 10.750556945800781, 8.498208045959473, 9.870671272277832, 10.884502410888672, 11.193074226379395]
Epoch 9 of 200


  4%|▍         | 9/200 [01:42<37:39, 11.83s/it]

Log Alpha Grad: -0.2494908881176136
Log Beta Grad: -0.25338015928895996
W Grad: tensor([ 0.0705,  0.0661,  0.0767,  0.0796,  0.0819, -0.1107, -0.1021, -0.1056,
        -0.1059, -0.1122,  0.0357,  0.0208,  0.0295,  0.0367,  0.0391])
Epoch 9, NLL: 135.979, Alpha: 6.836, Beta: 6.834, W: [11.897848129272461, 11.083175659179688, 12.797727584838867, 13.266770362854004, 13.5405912399292, 8.792696952819824, 12.187963485717773, 11.088640213012695, 10.866880416870117, 7.043790817260742, 10.74698543548584, 8.49613094329834, 9.867721557617188, 10.880830764770508, 11.189167022705078]
Epoch 10 of 200


  5%|▌         | 10/200 [01:53<36:31, 11.54s/it]

Log Alpha Grad: -0.25023541470089455
Log Beta Grad: -0.2515378897519179
W Grad: tensor([ 0.0708,  0.0664,  0.0771,  0.0801,  0.0824, -0.1117, -0.1029, -0.1065,
        -0.1068, -0.1132,  0.0362,  0.0211,  0.0299,  0.0372,  0.0396])
Epoch 10, NLL: 135.963, Alpha: 7.009, Beta: 7.008, W: [11.890764236450195, 11.076532363891602, 12.790014266967773, 13.258755683898926, 13.532349586486816, 8.803865432739258, 12.198254585266113, 11.099288940429688, 10.87756061553955, 7.055109977722168, 10.743362426757812, 8.49402141571045, 9.864728927612305, 10.877106666564941, 11.18520450592041]
Epoch 11 of 200


  6%|▌         | 11/200 [02:02<33:30, 10.64s/it]

Log Alpha Grad: -0.2501223118365571
Log Beta Grad: -0.2505377394840475
W Grad: tensor([ 0.0712,  0.0667,  0.0776,  0.0807,  0.0830, -0.1127, -0.1037, -0.1074,
        -0.1077, -0.1142,  0.0367,  0.0214,  0.0304,  0.0378,  0.0402])
Epoch 11, NLL: 135.947, Alpha: 7.186, Beta: 7.186, W: [11.883644104003906, 11.069860458374023, 12.782255172729492, 13.250689506530762, 13.524052619934082, 8.815131187438965, 12.20862865447998, 11.110026359558105, 10.888331413269043, 7.066527366638184, 10.7396879196167, 8.491877555847168, 9.861692428588867, 10.873329162597656, 11.181185722351074]
Epoch 12 of 200


  6%|▌         | 12/200 [02:11<32:06, 10.25s/it]

Log Alpha Grad: -0.24970212820142124
Log Beta Grad: -0.2498278821604869
W Grad: tensor([ 0.0716,  0.0670,  0.0781,  0.0812,  0.0835, -0.1136, -0.1046, -0.1083,
        -0.1086, -0.1152,  0.0373,  0.0218,  0.0308,  0.0383,  0.0408])
Epoch 12, NLL: 135.932, Alpha: 7.368, Beta: 7.368, W: [11.876486778259277, 11.06315803527832, 12.774449348449707, 13.242570877075195, 13.51569938659668, 8.826496124267578, 12.219087600708008, 11.120855331420898, 10.899194717407227, 7.078044891357422, 10.735960006713867, 8.489700317382812, 9.858611106872559, 10.869497299194336, 11.177108764648438]
Epoch 13 of 200


  6%|▋         | 13/200 [02:21<31:59, 10.27s/it]

Log Alpha Grad: -0.2491727748277517
Log Beta Grad: -0.24920876058821081
W Grad: tensor([ 0.0720,  0.0673,  0.0785,  0.0817,  0.0841, -0.1147, -0.1055, -0.1092,
        -0.1096, -0.1162,  0.0378,  0.0221,  0.0313,  0.0389,  0.0414])
Epoch 13, NLL: 135.916, Alpha: 7.554, Beta: 7.554, W: [11.869291305541992, 11.056425094604492, 12.766595840454102, 13.234397888183594, 13.507288932800293, 8.83796215057373, 12.229632377624512, 11.131776809692383, 10.910152435302734, 7.089665412902832, 10.732177734375, 8.48748779296875, 9.855484008789062, 10.865610122680664, 11.172972679138184]
Epoch 14 of 200


  7%|▋         | 14/200 [02:32<32:22, 10.44s/it]

Log Alpha Grad: -0.2486016132030356
Log Beta Grad: -0.24861130429236536
W Grad: tensor([ 0.0723,  0.0676,  0.0790,  0.0823,  0.0847, -0.1157, -0.1063, -0.1102,
        -0.1105, -0.1173,  0.0384,  0.0225,  0.0317,  0.0394,  0.0420])
Epoch 14, NLL: 135.9, Alpha: 7.744, Beta: 7.744, W: [11.86205768585205, 11.049660682678223, 12.758692741394043, 13.226170539855957, 13.498819351196289, 8.849533081054688, 12.240265846252441, 11.142793655395508, 10.9212064743042, 7.101391315460205, 10.728339195251465, 8.485239028930664, 9.852309226989746, 10.861665725708008, 11.16877555847168]
Epoch 15 of 200


  8%|▊         | 15/200 [02:43<32:05, 10.41s/it]

Log Alpha Grad: -0.2480099485266394
Log Beta Grad: -0.24801239203769307
W Grad: tensor([ 0.0727,  0.0680,  0.0795,  0.0828,  0.0853, -0.1168, -0.1072, -0.1111,
        -0.1115, -0.1183,  0.0390,  0.0229,  0.0322,  0.0400,  0.0426])
Epoch 15, NLL: 135.884, Alpha: 7.939, Beta: 7.939, W: [11.85478401184082, 11.042864799499512, 12.750740051269531, 13.217886924743652, 13.490289688110352, 8.861209869384766, 12.250988960266113, 11.153907775878906, 10.932358741760254, 7.113224983215332, 10.724444389343262, 8.482954025268555, 9.84908676147461, 10.857662200927734, 11.16451644897461]
Epoch 16 of 200


  8%|▊         | 16/200 [02:53<31:45, 10.36s/it]

Log Alpha Grad: -0.24740368875230495
Log Beta Grad: -0.24740426217218958
W Grad: tensor([ 0.0731,  0.0683,  0.0800,  0.0834,  0.0859, -0.1179, -0.1082, -0.1121,
        -0.1125, -0.1194,  0.0395,  0.0232,  0.0327,  0.0406,  0.0432])
Epoch 16, NLL: 135.868, Alpha: 8.137, Beta: 8.137, W: [11.8474702835083, 11.036036491394043, 12.742735862731934, 13.209545135498047, 13.481698989868164, 8.87299633026123, 12.261804580688477, 11.165122032165527, 10.943612098693848, 7.125169277191162, 10.720490455627441, 8.480630874633789, 9.84581470489502, 10.853598594665527, 11.160194396972656]
Epoch 17 of 200


  8%|▊         | 17/200 [03:01<29:53,  9.80s/it]

Log Alpha Grad: -0.24678380341242134
Log Beta Grad: -0.24678392778503916
W Grad: tensor([ 0.0736,  0.0686,  0.0806,  0.0840,  0.0865, -0.1190, -0.1091, -0.1132,
        -0.1136, -0.1206,  0.0401,  0.0236,  0.0332,  0.0412,  0.0439])
Epoch 17, NLL: 135.852, Alpha: 8.341, Beta: 8.341, W: [11.840115547180176, 11.0291748046875, 12.734679222106934, 13.201144218444824, 13.473045349121094, 8.884894371032715, 12.272714614868164, 11.176437377929688, 10.954968452453613, 7.137226581573486, 10.716477394104004, 8.478269577026367, 9.84249210357666, 10.84947395324707, 11.155807495117188]
Epoch 18 of 200


  9%|▉         | 18/200 [03:12<30:42, 10.13s/it]

Log Alpha Grad: -0.24614986722512905
Log Beta Grad: -0.2461498919528384
W Grad: tensor([ 0.0740,  0.0689,  0.0811,  0.0846,  0.0872, -0.1201, -0.1101, -0.1142,
        -0.1146, -0.1217,  0.0407,  0.0240,  0.0337,  0.0419,  0.0445])
Epoch 18, NLL: 135.836, Alpha: 8.549, Beta: 8.549, W: [11.832718849182129, 11.022279739379883, 12.726569175720215, 13.192683219909668, 13.464326858520508, 8.896906852722168, 12.283720016479492, 11.187856674194336, 10.9664306640625, 7.149399757385254, 10.71240234375, 8.475869178771973, 9.839117050170898, 10.84528636932373, 11.15135383605957]
Epoch 19 of 200


 10%|▉         | 19/200 [03:23<31:27, 10.43s/it]

Log Alpha Grad: -0.24550090489003143
Log Beta Grad: -0.24550090935196506
W Grad: tensor([ 0.0744,  0.0693,  0.0816,  0.0852,  0.0878, -0.1213, -0.1110, -0.1153,
        -0.1157, -0.1229,  0.0414,  0.0244,  0.0343,  0.0425,  0.0452])
Epoch 19, NLL: 135.82, Alpha: 8.761, Beta: 8.761, W: [11.825278282165527, 11.015350341796875, 12.718404769897461, 13.184160232543945, 13.455541610717773, 8.909036636352539, 12.29482364654541, 11.199382781982422, 10.97800064086914, 7.161692142486572, 10.708264350891113, 8.473427772521973, 9.835689544677734, 10.841034889221191, 11.146831512451172]
Epoch 20 of 200


 10%|█         | 20/200 [03:32<29:51,  9.95s/it]

Log Alpha Grad: -0.2448358749569914
Log Beta Grad: -0.24483587567879903
W Grad: tensor([ 0.0748,  0.0696,  0.0822,  0.0859,  0.0885, -0.1225, -0.1120, -0.1164,
        -0.1168, -0.1241,  0.0420,  0.0248,  0.0348,  0.0432,  0.0459])
Epoch 20, NLL: 135.803, Alpha: 8.978, Beta: 8.978, W: [11.817793846130371, 11.00838565826416, 12.710184097290039, 13.17557430267334, 13.446688652038574, 8.921286582946777, 12.306028366088867, 11.211018562316895, 10.989681243896484, 7.174106121063232, 10.704062461853027, 8.470945358276367, 9.832207679748535, 10.836716651916504, 11.142239570617676]
Epoch 21 of 200


 10%|█         | 21/200 [03:42<29:34,  9.91s/it]

Log Alpha Grad: -0.24415359290712588
Log Beta Grad: -0.24415359301020703
W Grad: tensor([ 0.0753,  0.0700,  0.0828,  0.0865,  0.0892, -0.1237, -0.1131, -0.1175,
        -0.1179, -0.1254,  0.0427,  0.0252,  0.0354,  0.0439,  0.0466])
Epoch 21, NLL: 135.787, Alpha: 9.2, Beta: 9.2, W: [11.810264587402344, 11.001385688781738, 12.701905250549316, 13.166923522949219, 13.437766075134277, 8.933659553527832, 12.317336082458496, 11.222765922546387, 11.00147533416748, 7.186645030975342, 10.69979476928711, 8.46842098236084, 9.828669548034668, 10.832330703735352, 11.13757610321045]
Epoch 22 of 200


 10%|█         | 21/200 [03:47<32:16, 10.82s/it]


KeyboardInterrupt: 

In [13]:
best_nll, final_alpha, final_beta, W = train(X, y, patience=100, grid_size=200, num_epochs=200,
                                            lr=.1, speech_len=speech_len, lex_size=lex_size, 
                                            log_alpha= torch.tensor(np.log(12.67/2), requires_grad=True),
                                            log_beta= torch.tensor(np.log(3.168/2), requires_grad=True),
                                            W = torch.normal(5, 2, size=(3 * lex_size,), requires_grad=True))

Beginning Inference...
Hyperparameters:
Dataset Length: 50000
Number Epochs: 200
Batch Size: 100
Grid Size: 200
Learning Rate: 0.1
Improvement Threshold: 1e-05
Patience: 100
Speech Len: 50
Initial Parameters:
Alpha: 6.335
Beta: 1.584
W: [7.939346790313721, 4.2098917961120605, 3.979785442352295, 7.232603549957275, 3.8148579597473145, 6.81782865524292, 2.8415327072143555, 3.649792194366455, 7.81660795211792, 1.30876886844635, 3.8690185546875, 3.18125319480896, 4.021854400634766, 3.7650866508483887, 5.767164707183838]


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch 1 of 200


  0%|          | 1/200 [00:10<35:21, 10.66s/it]

Log Alpha Grad: 2.7409296677660584
Log Beta Grad: -1.8574226545494417
W Grad: tensor([ 0.1757,  0.1105,  0.1088,  0.1601,  0.1070, -0.1697, -0.2052, -0.2003,
        -0.1523, -0.2136,  0.0506,  0.0394,  0.0532,  0.0489,  0.0867])
Epoch 1, NLL: 137.632, Alpha: 4.816, Beta: 1.907, W: [7.921778678894043, 4.198837757110596, 3.9689013957977295, 7.21658992767334, 3.8041574954986572, 6.834796905517578, 2.8620548248291016, 3.6698269844055176, 7.831838130950928, 1.3301278352737427, 3.863954782485962, 3.17730975151062, 4.0165300369262695, 3.760197639465332, 5.758491039276123]
Epoch 2 of 200


  1%|          | 2/200 [00:22<38:22, 11.63s/it]

Log Alpha Grad: 1.520025228146022
Log Beta Grad: -1.44292274903582
W Grad: tensor([ 0.1719,  0.1098,  0.1082,  0.1572,  0.1064, -0.1676, -0.2018, -0.1970,
        -0.1511, -0.2100,  0.0498,  0.0390,  0.0524,  0.0481,  0.0846])
Epoch 2, NLL: 136.708, Alpha: 4.137, Beta: 2.203, W: [7.904592514038086, 4.187854766845703, 3.958080768585205, 7.200870513916016, 3.7935142517089844, 6.85155725479126, 2.882232666015625, 3.6895298957824707, 7.846945285797119, 1.351123332977295, 3.858971357345581, 3.1734094619750977, 4.011294841766357, 3.755382776260376, 5.750032901763916]
Epoch 3 of 200


  2%|▏         | 3/200 [00:35<40:16, 12.27s/it]

Log Alpha Grad: 0.8830067787551209
Log Beta Grad: -1.1299970713980678
W Grad: tensor([ 0.1724,  0.1103,  0.1087,  0.1578,  0.1069, -0.1684, -0.2030, -0.1981,
        -0.1517, -0.2112,  0.0503,  0.0394,  0.0528,  0.0486,  0.0851])
Epoch 3, NLL: 136.316, Alpha: 3.787, Beta: 2.467, W: [7.88735294342041, 4.176820755004883, 3.9472105503082275, 7.185094833374023, 3.782822370529175, 6.868393421173096, 2.9025278091430664, 3.7093420028686523, 7.862116813659668, 1.3722467422485352, 3.8539412021636963, 3.1694650650024414, 4.006012439727783, 3.750521659851074, 5.741523742675781]
Epoch 4 of 200


  2%|▏         | 4/200 [00:47<38:42, 11.85s/it]

Log Alpha Grad: 0.4964435850164178
Log Beta Grad: -0.8894735772992531
W Grad: tensor([ 0.1744,  0.1112,  0.1095,  0.1595,  0.1077, -0.1701, -0.2057, -0.2007,
        -0.1530, -0.2142,  0.0512,  0.0402,  0.0538,  0.0495,  0.0865])
Epoch 4, NLL: 136.121, Alpha: 3.604, Beta: 2.696, W: [7.8699140548706055, 4.165698051452637, 3.936255931854248, 7.1691460609436035, 3.7720494270324707, 6.885400295257568, 2.923096179962158, 3.72941255569458, 7.877412796020508, 1.3936631679534912, 3.8488199710845947, 3.165445327758789, 4.000635623931885, 3.745572090148926, 5.732871055603027]
Epoch 5 of 200


  2%|▎         | 5/200 [00:54<33:34, 10.33s/it]

Log Alpha Grad: 0.24675945613167935
Log Beta Grad: -0.707270556173286
W Grad: tensor([ 0.1771,  0.1123,  0.1105,  0.1618,  0.1087, -0.1722, -0.2092, -0.2040,
        -0.1545, -0.2179,  0.0523,  0.0411,  0.0549,  0.0506,  0.0884])
Epoch 5, NLL: 136.013, Alpha: 3.516, Beta: 2.894, W: [7.852207660675049, 4.154469966888428, 3.925201654434204, 7.152969837188721, 3.761180877685547, 6.902620315551758, 2.94401216506958, 3.749812126159668, 7.892858982086182, 1.4154517650604248, 3.8435864448547363, 3.1613352298736572, 3.99514102935791, 3.740513563156128, 5.7240309715271]
Epoch 6 of 200


  3%|▎         | 6/200 [01:02<30:03,  9.30s/it]

Log Alpha Grad: 0.08084384484693512
Log Beta Grad: -0.5720226206593847
W Grad: tensor([ 0.1801,  0.1134,  0.1116,  0.1644,  0.1097, -0.1746, -0.2131, -0.2077,
        -0.1561, -0.2221,  0.0536,  0.0421,  0.0563,  0.0518,  0.0905])
Epoch 6, NLL: 135.945, Alpha: 3.488, Beta: 3.064, W: [7.834194183349609, 4.143126964569092, 3.914038896560669, 7.136534690856934, 3.750209093093872, 6.920077323913574, 2.965319871902466, 3.770582675933838, 7.9084672927856445, 1.437659740447998, 3.8382279872894287, 3.1571261882781982, 3.989515542984009, 3.7353341579437256, 5.714978218078613]
Epoch 7 of 200


  4%|▎         | 7/200 [01:12<31:13,  9.71s/it]

Log Alpha Grad: -0.03070025014479279
Log Beta Grad: -0.4736181580452941
W Grad: tensor([ 0.1835,  0.1146,  0.1128,  0.1672,  0.1108, -0.1771, -0.2173, -0.2117,
        -0.1578, -0.2266,  0.0549,  0.0431,  0.0577,  0.0531,  0.0928])
Epoch 7, NLL: 135.895, Alpha: 3.499, Beta: 3.213, W: [7.8158464431762695, 4.131662845611572, 3.9027621746063232, 7.119819164276123, 3.739128589630127, 6.937785625457764, 2.9870502948760986, 3.791753053665161, 7.924243927001953, 1.4603208303451538, 3.8327362537384033, 3.1528117656707764, 3.983750104904175, 3.7300260066986084, 5.705695629119873]
Epoch 8 of 200


  4%|▍         | 8/200 [01:21<30:25,  9.51s/it]

Log Alpha Grad: -0.10583785465164483
Log Beta Grad: -0.4032982353354786
W Grad: tensor([ 0.1870,  0.1159,  0.1140,  0.1701,  0.1119, -0.1797, -0.2218, -0.2159,
        -0.1595, -0.2314,  0.0563,  0.0442,  0.0591,  0.0544,  0.0953])
Epoch 8, NLL: 135.855, Alpha: 3.536, Beta: 3.345, W: [7.797144412994385, 4.120072364807129, 3.891366958618164, 7.10280704498291, 3.727935314178467, 6.955755233764648, 3.009228467941284, 3.8133466243743896, 7.940191745758057, 1.4834624528884888, 3.8271045684814453, 3.1483874320983887, 3.977837562561035, 3.7245826721191406, 5.696170330047607]
Epoch 9 of 200


  4%|▍         | 9/200 [01:31<30:26,  9.56s/it]

Log Alpha Grad: -0.15622117638049612
Log Beta Grad: -0.3538183569191582
W Grad: tensor([ 0.1908,  0.1172,  0.1152,  0.1732,  0.1131, -0.1824, -0.2265, -0.2204,
        -0.1612, -0.2365,  0.0578,  0.0454,  0.0607,  0.0558,  0.0978])
Epoch 9, NLL: 135.817, Alpha: 3.591, Beta: 3.466, W: [7.778069496154785, 4.108351230621338, 3.879849433898926, 7.085484027862549, 3.716625690460205, 6.973994731903076, 3.031877040863037, 3.8353841304779053, 7.95631217956543, 1.507109522819519, 3.8213272094726562, 3.1438486576080322, 3.971771717071533, 3.718998432159424, 5.686390399932861]
Epoch 10 of 200


  5%|▌         | 10/200 [01:37<27:14,  8.60s/it]

Log Alpha Grad: -0.1896763134729509
Log Beta Grad: -0.3194378063151835
W Grad: tensor([ 0.1947,  0.1186,  0.1164,  0.1765,  0.1143, -0.1852, -0.2314, -0.2250,
        -0.1629, -0.2418,  0.0593,  0.0466,  0.0622,  0.0573,  0.1005])
Epoch 10, NLL: 135.781, Alpha: 3.66, Beta: 3.578, W: [7.758604526519775, 4.096495628356934, 3.868206024169922, 7.067835807800293, 3.7051966190338135, 6.992510795593262, 3.0550177097320557, 3.8578851222991943, 7.972605228424072, 1.531286358833313, 3.815398693084717, 3.1391918659210205, 3.9655468463897705, 3.713268280029297, 5.676344871520996]
Epoch 11 of 200


  6%|▌         | 11/200 [01:47<28:26,  9.03s/it]

Log Alpha Grad: -0.2115591290467139
Log Beta Grad: -0.29576745948243116
W Grad: tensor([ 0.1987,  0.1199,  0.1177,  0.1799,  0.1155, -0.1880, -0.2365, -0.2298,
        -0.1646, -0.2473,  0.0609,  0.0478,  0.0639,  0.0588,  0.1032])
Epoch 11, NLL: 135.744, Alpha: 3.739, Beta: 3.686, W: [7.73873233795166, 4.08450174331665, 3.856433153152466, 7.049849033355713, 3.6936450004577637, 7.01131010055542, 3.078672170639038, 3.8808696269989014, 7.989069938659668, 1.556017518043518, 3.8093135356903076, 3.134413003921509, 3.9591572284698486, 3.7073872089385986, 5.666022777557373]
Epoch 12 of 200


  6%|▌         | 12/200 [01:57<28:52,  9.22s/it]

Log Alpha Grad: -0.22556402734752143
Log Beta Grad: -0.27954870081232586
W Grad: tensor([ 0.2030,  0.1214,  0.1191,  0.1834,  0.1168, -0.1909, -0.2419, -0.2349,
        -0.1663, -0.2531,  0.0625,  0.0490,  0.0656,  0.0604,  0.1061])
Epoch 12, NLL: 135.707, Alpha: 3.824, Beta: 3.79, W: [7.718435764312744, 4.0723652839660645, 3.8445277214050293, 7.031509876251221, 3.6819677352905273, 7.030398368835449, 3.102863311767578, 3.90435791015625, 8.005703926086426, 1.581328272819519, 3.8030667304992676, 3.1295084953308105, 3.9525973796844482, 3.701349973678589, 5.655413627624512]
Epoch 13 of 200


  6%|▋         | 13/200 [02:06<28:01,  8.99s/it]

Log Alpha Grad: -0.23424071814955183
Log Beta Grad: -0.2684206636554879
W Grad: tensor([ 0.2074,  0.1228,  0.1204,  0.1871,  0.1181, -0.1938, -0.2475, -0.2401,
        -0.1680, -0.2592,  0.0641,  0.0503,  0.0674,  0.0620,  0.1091])
Epoch 13, NLL: 135.668, Alpha: 3.914, Beta: 3.893, W: [7.697696685791016, 4.06008243560791, 3.8324859142303467, 7.01280403137207, 3.670161485671997, 7.049781322479248, 3.1276142597198486, 3.9283711910247803, 8.022505760192871, 1.6072453260421753, 3.7966527938842773, 3.12447452545166, 3.945861339569092, 3.6951515674591064, 5.644505977630615]
Epoch 14 of 200


  7%|▋         | 14/200 [02:13<26:50,  8.66s/it]

Log Alpha Grad: -0.23934079721297843
Log Beta Grad: -0.26070802112001845
W Grad: tensor([ 0.2120,  0.1243,  0.1218,  0.1909,  0.1194, -0.1968, -0.2534, -0.2456,
        -0.1697, -0.2655,  0.0659,  0.0517,  0.0692,  0.0636,  0.1122])
Epoch 14, NLL: 135.629, Alpha: 4.009, Beta: 3.996, W: [7.6764960289001465, 4.047649383544922, 3.8203043937683105, 6.993716239929199, 3.6582233905792236, 7.069464206695557, 3.152949333190918, 3.9529311656951904, 8.039471626281738, 1.6337963342666626, 3.7900662422180176, 3.119307518005371, 3.9389431476593018, 3.6887869834899902, 5.633288860321045]
Epoch 15 of 200


  8%|▊         | 15/200 [02:21<25:42,  8.34s/it]

Log Alpha Grad: -0.24205774118981366
Log Beta Grad: -0.25524335820816446
W Grad: tensor([ 0.2168,  0.1259,  0.1232,  0.1948,  0.1207, -0.1999, -0.2595, -0.2513,
        -0.1712, -0.2721,  0.0676,  0.0530,  0.0711,  0.0654,  0.1154])
Epoch 15, NLL: 135.588, Alpha: 4.107, Beta: 4.099, W: [7.654814720153809, 4.035062313079834, 3.8079795837402344, 6.974231243133545, 3.6461501121520996, 7.089451789855957, 3.178894519805908, 3.978060722351074, 8.056596755981445, 1.6610106229782104, 3.783301591873169, 3.1140034198760986, 3.9318370819091797, 3.682250738143921, 5.6217498779296875]
Epoch 16 of 200


  8%|▊         | 16/200 [02:30<25:47,  8.41s/it]

Log Alpha Grad: -0.24319586252968475
Log Beta Grad: -0.2512262735797951
W Grad: tensor([ 0.2218,  0.1275,  0.1247,  0.1990,  0.1221, -0.2030, -0.2658, -0.2572,
        -0.1728, -0.2791,  0.0695,  0.0544,  0.0730,  0.0671,  0.1187])
Epoch 16, NLL: 135.545, Alpha: 4.209, Beta: 4.204, W: [7.632632732391357, 4.0223164558410645, 3.7955081462860107, 6.9543328285217285, 3.6339383125305176, 7.109748840332031, 3.2054762840270996, 4.003783702850342, 8.073875427246094, 1.6889188289642334, 3.776353120803833, 3.1085586547851562, 3.924537181854248, 3.675537586212158, 5.609877109527588]
Epoch 17 of 200


  8%|▊         | 17/200 [02:40<27:42,  9.09s/it]

Log Alpha Grad: -0.24329009598896076
Log Beta Grad: -0.2481160187200666
W Grad: tensor([ 0.2270,  0.1291,  0.1262,  0.2033,  0.1235, -0.2061, -0.2725, -0.2634,
        -0.1742, -0.2863,  0.0714,  0.0559,  0.0750,  0.0690,  0.1222])
Epoch 17, NLL: 135.501, Alpha: 4.312, Beta: 4.309, W: [7.609928607940674, 4.009407997131348, 3.782886266708374, 6.934004306793213, 3.6215851306915283, 7.130359172821045, 3.232722759246826, 4.030124664306641, 8.091300010681152, 1.7175533771514893, 3.7692153453826904, 3.1029696464538574, 3.91703724861145, 3.668642282485962, 5.597658157348633]
Epoch 18 of 200


  9%|▉         | 18/200 [02:49<26:58,  8.89s/it]

Log Alpha Grad: -0.24269061497103972
Log Beta Grad: -0.24555190070422706
W Grad: tensor([ 0.2325,  0.1308,  0.1278,  0.2078,  0.1250, -0.2093, -0.2794, -0.2699,
        -0.1756, -0.2939,  0.0733,  0.0574,  0.0771,  0.0708,  0.1258])
Epoch 18, NLL: 135.456, Alpha: 4.418, Beta: 4.416, W: [7.5866804122924805, 3.9963326454162598, 3.7701101303100586, 6.913228511810303, 3.6090872287750244, 7.151285648345947, 3.2606632709503174, 4.057109832763672, 8.10886287689209, 1.7469482421875, 3.761882781982422, 3.0972325801849365, 3.9093310832977295, 3.661559581756592, 5.585080146789551]
Epoch 19 of 200


 10%|▉         | 19/200 [02:58<27:17,  9.04s/it]

Log Alpha Grad: -0.24162251268011797
Log Beta Grad: -0.24329602465452138
W Grad: tensor([ 0.2381,  0.1325,  0.1293,  0.2124,  0.1265, -0.2125, -0.2867, -0.2766,
        -0.1769, -0.3019,  0.0753,  0.0589,  0.0792,  0.0728,  0.1295])
Epoch 19, NLL: 135.408, Alpha: 4.526, Beta: 4.525, W: [7.562865734100342, 3.983086109161377, 3.757176399230957, 6.891987323760986, 3.5964412689208984, 7.1725311279296875, 3.2893285751342773, 4.084765911102295, 8.126553535461426, 1.7771391868591309, 3.754349708557129, 3.091343879699707, 3.9014127254486084, 3.6542840003967285, 5.572129726409912]
Epoch 20 of 200


 10%|█         | 20/200 [03:07<27:28,  9.16s/it]

Log Alpha Grad: -0.24022714073640092
Log Beta Grad: -0.24119263498999843
W Grad: tensor([ 0.2440,  0.1342,  0.1309,  0.2173,  0.1280, -0.2157, -0.2942, -0.2835,
        -0.1781, -0.3102,  0.0774,  0.0604,  0.0814,  0.0747,  0.1334])
Epoch 20, NLL: 135.358, Alpha: 4.636, Beta: 4.636, W: [7.538460731506348, 3.9696640968322754, 3.744081497192383, 6.870262145996094, 3.583644390106201, 7.19409704208374, 3.3187503814697266, 4.1131205558776855, 8.144360542297363, 1.8081635236740112, 3.746610403060913, 3.0853004455566406, 3.8932759761810303, 3.646810531616211, 5.558793544769287]
Epoch 21 of 200


 10%|█         | 21/200 [03:17<27:43,  9.29s/it]

Log Alpha Grad: -0.23859061387287556
Log Beta Grad: -0.2391400432047272
W Grad: tensor([ 0.2502,  0.1360,  0.1326,  0.2223,  0.1295, -0.2189, -0.3021, -0.2908,
        -0.1791, -0.3190,  0.0795,  0.0620,  0.0836,  0.0768,  0.1374])
Epoch 21, NLL: 135.307, Alpha: 4.748, Beta: 4.748, W: [7.51344108581543, 3.9560623168945312, 3.7308218479156494, 6.848033428192139, 3.570693254470825, 7.2159833908081055, 3.3489620685577393, 4.142202854156494, 8.162269592285156, 1.8400605916976929, 3.738659620285034, 3.07909893989563, 3.8849148750305176, 3.639133930206299, 5.545057773590088]
Epoch 22 of 200


 11%|█         | 22/200 [03:26<27:11,  9.16s/it]

Log Alpha Grad: -0.2367629429230323
Log Beta Grad: -0.23707135266941506
W Grad: tensor([ 0.2566,  0.1379,  0.1343,  0.2275,  0.1311, -0.2221, -0.3104, -0.2984,
        -0.1799, -0.3281,  0.0817,  0.0636,  0.0859,  0.0788,  0.1415])
Epoch 22, NLL: 135.253, Alpha: 4.862, Beta: 4.862, W: [7.487782001495361, 3.9422767162323, 3.7173938751220703, 6.825281620025635, 3.5575850009918213, 7.238189697265625, 3.379998207092285, 4.1720428466796875, 8.180264472961426, 1.8728715181350708, 3.730491876602173, 3.0727365016937256, 3.8763234615325928, 3.63124942779541, 5.530908584594727]
Epoch 23 of 200


 12%|█▏        | 23/200 [03:34<26:08,  8.86s/it]

Log Alpha Grad: -0.23477097794816612
Log Beta Grad: -0.2349417592833581
W Grad: tensor([ 0.2632,  0.1397,  0.1360,  0.2330,  0.1327, -0.2252, -0.3190, -0.3063,
        -0.1806, -0.3377,  0.0839,  0.0653,  0.0883,  0.0810,  0.1458])
Epoch 23, NLL: 135.196, Alpha: 4.977, Beta: 4.977, W: [7.461457252502441, 3.9283030033111572, 3.703794240951538, 6.801986217498779, 3.5443167686462402, 7.260714054107666, 3.4118947982788086, 4.202670574188232, 8.198326110839844, 1.9066392183303833, 3.722102165222168, 3.0662105083465576, 3.8674960136413574, 3.623152256011963, 5.516332149505615]
Epoch 24 of 200


 12%|█▏        | 24/200 [03:42<24:46,  8.44s/it]

Log Alpha Grad: -0.23262671619090702
Log Beta Grad: -0.23272002529233604
W Grad: tensor([ 0.2702,  0.1417,  0.1377,  0.2386,  0.1343, -0.2284, -0.3279, -0.3145,
        -0.1811, -0.3477,  0.0862,  0.0669,  0.0907,  0.0831,  0.1502])
Epoch 24, NLL: 135.138, Alpha: 5.095, Beta: 5.095, W: [7.434440612792969, 3.914137125015259, 3.6900196075439453, 6.7781267166137695, 3.530885934829712, 7.283552169799805, 3.4446887969970703, 4.2341179847717285, 8.21643352508545, 1.941408634185791, 3.7134857177734375, 3.059518575668335, 3.8584272861480713, 3.614837884902954, 5.501314640045166]
Epoch 25 of 200


 12%|█▎        | 25/200 [03:49<23:22,  8.01s/it]

Log Alpha Grad: -0.23033281600477862
Log Beta Grad: -0.2303831304193925
W Grad: tensor([ 0.2774,  0.1436,  0.1395,  0.2444,  0.1360, -0.2315, -0.3373, -0.3230,
        -0.1813, -0.3582,  0.0885,  0.0686,  0.0932,  0.0854,  0.1547])
Epoch 25, NLL: 135.076, Alpha: 5.213, Beta: 5.213, W: [7.406705379486084, 3.8997750282287598, 3.6760668754577637, 6.753681659698486, 3.517289876937866, 7.306697845458984, 3.4784185886383057, 4.266417503356934, 8.234561920166016, 1.9772263765335083, 3.7046380043029785, 3.052659034729004, 3.849112033843994, 3.606302499771118, 5.485841751098633]
Epoch 26 of 200


 13%|█▎        | 26/200 [03:57<23:23,  8.07s/it]

Log Alpha Grad: -0.22788596310719147
Log Beta Grad: -0.2279127487694473
W Grad: tensor([ 0.2848,  0.1456,  0.1413,  0.2505,  0.1376, -0.2345, -0.3471, -0.3318,
        -0.1812, -0.3691,  0.0908,  0.0703,  0.0957,  0.0876,  0.1594])
Epoch 26, NLL: 135.011, Alpha: 5.334, Beta: 5.333, W: [7.378223419189453, 3.8852131366729736, 3.661933183670044, 6.7286295890808105, 3.503526210784912, 7.330143451690674, 3.5131239891052246, 4.299602031707764, 8.252684593200684, 2.014141082763672, 3.6955549716949463, 3.045630693435669, 3.839545488357544, 3.5975422859191895, 5.469900131225586]
Epoch 27 of 200


 14%|█▎        | 27/200 [04:04<22:23,  7.77s/it]

Log Alpha Grad: -0.22527902365690428
Log Beta Grad: -0.22529310873548317
W Grad: tensor([ 0.2926,  0.1477,  0.1432,  0.2568,  0.1393, -0.2374, -0.3572, -0.3410,
        -0.1808, -0.3806,  0.0932,  0.0720,  0.0982,  0.0899,  0.1642])
Epoch 27, NLL: 134.944, Alpha: 5.455, Beta: 5.455, W: [7.348967552185059, 3.8704476356506348, 3.647615909576416, 6.702948570251465, 3.4895927906036377, 7.353878974914551, 3.5488452911376953, 4.333705425262451, 8.270769119262695, 2.0522029399871826, 3.6862330436706543, 3.0384328365325928, 3.829723596572876, 3.5885543823242188, 5.4534759521484375]
Epoch 28 of 200


 14%|█▍        | 28/200 [04:10<20:41,  7.22s/it]

Log Alpha Grad: -0.22250235888584408
Log Beta Grad: -0.22250967908145172
W Grad: tensor([ 0.3006,  0.1497,  0.1450,  0.2633,  0.1411, -0.2401, -0.3678, -0.3506,
        -0.1801, -0.3926,  0.0956,  0.0737,  0.1008,  0.0922,  0.1692])
Epoch 28, NLL: 134.873, Alpha: 5.578, Beta: 5.578, W: [7.318909645080566, 3.8554751873016357, 3.633112668991089, 6.67661714553833, 3.47548770904541, 7.3778910636901855, 3.5856239795684814, 4.3687615394592285, 8.28878116607666, 2.091464042663574, 3.6766695976257324, 3.0310654640197754, 3.8196427822113037, 3.579336404800415, 5.436556339263916]
Epoch 29 of 200


 14%|█▍        | 29/200 [04:16<19:58,  7.01s/it]

Log Alpha Grad: -0.2195446886167025
Log Beta Grad: -0.219548451426128
W Grad: tensor([ 0.3089,  0.1518,  0.1469,  0.2700,  0.1428, -0.2427, -0.3788, -0.3604,
        -0.1790, -0.4051,  0.0981,  0.0754,  0.1034,  0.0945,  0.1743])
Epoch 29, NLL: 134.798, Alpha: 5.702, Beta: 5.702, W: [7.288021564483643, 3.8402926921844482, 3.6184210777282715, 6.649612903594971, 3.461209535598755, 7.402163505554199, 3.623502254486084, 4.404804229736328, 8.306681632995605, 2.1319777965545654, 3.6668624877929688, 3.023529529571533, 3.809300422668457, 3.5698864459991455, 5.41912841796875]
Epoch 30 of 200


 15%|█▌        | 30/200 [04:24<20:34,  7.26s/it]

Log Alpha Grad: -0.2163934261185132
Log Beta Grad: -0.2163953408709197
W Grad: tensor([ 0.3175,  0.1540,  0.1488,  0.2770,  0.1445, -0.2451, -0.3902, -0.3706,
        -0.1774, -0.4182,  0.1005,  0.0770,  0.1061,  0.0968,  0.1795])
Epoch 30, NLL: 134.719, Alpha: 5.826, Beta: 5.826, W: [7.256275653839111, 3.824897289276123, 3.6035397052764893, 6.621913909912109, 3.4467573165893555, 7.426677227020264, 3.662523031234741, 4.441867828369141, 8.324426651000977, 2.173799753189087, 3.6568102836608887, 3.015826940536499, 3.798694372177124, 3.560204029083252, 5.401180744171143]
Epoch 31 of 200


 16%|█▌        | 31/200 [04:33<21:41,  7.70s/it]

Log Alpha Grad: -0.2130352680253515
Log Beta Grad: -0.21303623358622745
W Grad: tensor([ 0.3263,  0.1561,  0.1507,  0.2842,  0.1463, -0.2473, -0.4021, -0.3812,
        -0.1754, -0.4319,  0.1030,  0.0787,  0.1087,  0.0991,  0.1848])
Epoch 31, NLL: 134.637, Alpha: 5.952, Beta: 5.952, W: [7.223644256591797, 3.8092868328094482, 3.5884668827056885, 6.593498706817627, 3.4321300983428955, 7.451409339904785, 3.7027294635772705, 4.47998571395874, 8.341967582702637, 2.2169859409332275, 3.646512985229492, 3.007960557937622, 3.7878236770629883, 3.5502896308898926, 5.382702350616455]
Epoch 32 of 200


 16%|█▌        | 31/200 [04:41<25:34,  9.08s/it]


KeyboardInterrupt: 