A few slight changes from 3-6

In [1]:
from abc import ABC, abstractmethod

import os
from tqdm import tqdm
import math 
import time

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns 

from scipy import stats
from scipy.special import digamma

import torch
from torch.distributions import Beta, kl_divergence
from torch.nn.functional import log_softmax
from torch.optim import Adam, SGD
from torch.utils.data import DataLoader, TensorDataset

from bayes_opt import BayesianOptimization

In [2]:
SEED = 1
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f6a5ffd7550>

In [3]:
### these params control the generation scheme (2/10 generation)
rho        = 0.8    # polarization
pop_size   = 50000  # num individuals
epsilon    = 0.05   # expected prop of speech consisting of neutral words
pi         = 0.5    # pi == 0.5 => beta mixture symmetrical (choose beta1 with prob pi = 0.5)
speech_len = 50     # words per speech
lex_size   = 3      # lexicon size 

In [4]:
def generate(rho=rho, N=pop_size, epsilon=epsilon, pi=pi, speech_len=speech_len, lex_size=lex_size, verbose='low'):
    """
    Uses 2/10 generation scheme to generate N samples.

    Returns:
        (X, y), (a, b, rho, epsilon, u, lex_size)
        X.size() == [N, S] is a vector of sequences 
        y.size() == [N] is a vector of political parties
        u.shape  == (N,) is a vector of individual stances
        rho is true polarization
        epsilon is expected prop of neutral words
        a, b are true alpha/beta for beta mixture model
    """
    start = time.time()
    if verbose in ['mid', 'high']:
        print(f'Beginning Data Generation...')
        print(f'=' * 20)
        
    ### get beta mixture model params 
    sigma = 0.175 * (rho ** 2) - 0.3625 * rho + 0.1875
    a     = rho * ((rho * (1 - rho)) / sigma - 1)
    b     = (1 - rho) * ((rho * (1 - rho)) / sigma - 1)

    if verbose in ['mid', 'high']:
        print(f'Lex Size: {lex_size}, Rho: {rho}, N: {pop_size}')
        print(f'epsilon: {epsilon}, pi: {pi}, S: {speech_len}\n')
        print(f"True Alpha: {a}")
        print(f"True Beta: {b}\n")

    mean = a / (a + b)
    var  = a * b / ((a + b)**2 * (a + b + 1))

    if abs(mean - rho) > 10e-15:
        print(f'Mean: {mean}')
        print(f"Rho: {rho}")
        raise AssertionError(f"Mean of BMM params should be rho")

    if abs(var - sigma) > 10e-15:
        print(f'Var: {var}')
        print(f'Sigma: {sigma}')
        raise AssertionError(f"Var of BMM params should be sigma")

    ### u ~ pi Beta(a, b) + (1 - pi) Beta(b, a)
    weights = [pi, 1-pi]
    mixture_samples = np.random.choice([0, 1], size=N, p=weights)

    u = 2 * np.where(mixture_samples == 1, stats.beta.rvs(a, b, size=N), stats.beta.rvs(b, a, size=N)) - 1

    if u.shape != (N,):
        raise AssertionError(f"u.shape should be (N,)")

    if verbose == 'high':
        print(f'mixture samples: {mixture_samples[:3]}')
        print(f'u samples: {u[:3]}')

    ### y = 1(u >= 0)
    y = (u >= 0).astype(int)

    if verbose == 'high':
        print(f'y samples: {y[:3]}\n')

    ### phi is a prob matrix that is a function of u, epsilon
    left_prob    = (1 - (u+1)/2) * (1 - epsilon) / lex_size
    right_prob   = (u+1)/2 * (1 - epsilon) / lex_size
    neutral_prob = np.repeat(epsilon, N) / lex_size
    phi          = np.array([left_prob] * lex_size + [right_prob] * lex_size + [neutral_prob] * lex_size).T
    
    if verbose == 'high':
        for i in range(3):
            print(f'P(L) = {phi[i,0]}, P(R): {phi[i, lex_size]}, P(N): {phi[i,-1]}')

    if phi.shape != (N, 3 * lex_size):
        raise AssertionError(f'phi.shape should be (N, V) == (N, 3L)')

    if abs(sum(phi[0]) - 1) > 10e-5:
        raise AssertionError(f'rows of phi should sum to 1')
    
    X = np.array([stats.multinomial.rvs(n=speech_len, p=phi[i, :]) for i in range(N)])

    if verbose == 'high':
        print(f'X samples (counts):\n {X[:3]}\n')
    
    if X.shape != (N, 3 * lex_size):
        raise AssertionError(f'X.shape should be (N, V) == (N, 3L)')

    if X[:5].sum() != speech_len * 5:
        raise AssertionError(f'rows of phi should sum to 1')

    expanded_rows = []
    for row in X:
        expanded_row = np.repeat(np.arange(len(row)), row)
        np.random.shuffle(expanded_row)
        expanded_rows.append(expanded_row)

    X = np.array(expanded_rows, dtype=np.int64)

    if X.shape != (N, speech_len):
        raise AssertionError(f'X.shape should be (N, S)')
    
    if verbose == 'high':
        print(f'X samples (sequences):\n {X[:3]}\n')

    X = torch.from_numpy(X).to(torch.float32)
    y = torch.from_numpy(y).to(torch.float32)

    known   = (X, y)
    unknown = (a, b, rho, epsilon, u, lex_size)

    if verbose in ['mid', 'high']:
        print('=' * 20)
        print(f'Generation Time: {round(time.time() - start, 3)} seconds for {N} samples.')

    return known, unknown

known, unknown = generate(verbose='high')
X, y = known
a, b, rho, epsilon, u, lex_size = unknown

Beginning Data Generation...
Lex Size: 3, Rho: 0.8, N: 50000
epsilon: 0.05, pi: 0.5, S: 50

True Alpha: 12.673684210526261
True Beta: 3.1684210526315644

mixture samples: [0 1 0]
u samples: [-0.71319006  0.57688516 -0.46523368]
y samples: [0 1 0]

P(L) = 0.2712550923688634, P(R): 0.04541157429780326, P(N): 0.016666666666666666
P(L) = 0.06699318292320035, P(R): 0.24967348374346632, P(N): 0.016666666666666666
P(L) = 0.23199533330535926, P(R): 0.08467133336130739, P(N): 0.016666666666666666
X samples (counts):
 [[16 13 14  2  1  2  0  0  2]
 [ 2  3  3 12 13 15  1  0  1]
 [15 15  8  3  4  5  0  0  0]]

X samples (sequences):
 [[0 2 2 2 2 0 0 2 0 1 0 4 0 1 1 3 2 1 2 0 0 5 0 1 0 0 1 2 1 1 0 8 2 2 0 1
  1 3 1 1 2 0 8 2 1 2 2 5 0 0]
 [3 4 5 4 5 5 3 3 1 6 5 0 3 8 4 4 3 5 0 3 5 5 3 4 5 4 1 3 5 5 5 3 2 2 2 5
  4 5 4 3 4 1 4 3 5 4 3 4 5 4]
 [1 5 1 0 1 1 1 2 1 0 5 4 4 0 0 1 0 0 3 0 1 0 0 0 5 1 0 5 3 5 0 2 1 1 2 1
  0 2 3 0 1 4 2 0 2 2 1 1 4 2]]

Generation Time: 1.808 seconds for 50000 samples.


In [5]:
grid_size  = 200 ### discretize integral into grid_size linearly spaced pts
batch_size = 100 ### num samples processed simultaneously
num_epochs = 200 ### number epochs
lr         = 0.1 ### learning rate

def compute_nll(x_batch, y_batch, log_alpha, log_beta, W):
    """
    Parameters:
        (x_batch, y_batch) a batch of B (x,y) samples 
        (log_alpha, log_beta, W) trainable parameters

    Returns:
        A scaler quantity s.t. calling .backward() will compute
        grad_{theta} (L(theta))
    """
    raise NotImplementedError

def train(X, y, log_alpha=None, log_beta=None, W=None, num_epochs=num_epochs, batch_size=batch_size, grid_size=grid_size, lr=lr, verbose=True, threshold=0.00001, patience=5, speech_len=speech_len, lex_size=lex_size):
    if verbose:
        print(f'Beginning Inference...')
        print(f'=' * 20)
        print(f'Hyperparameters:')
        print(f'Dataset Length: {len(X)}')
        print(f'Number Epochs: {num_epochs}')
        print(f'Batch Size: {batch_size}')
        print(f'Grid Size: {grid_size}')
        print(f"Learning Rate: {lr}")
        print(f"Improvement Threshold: {threshold}")
        print(f"Patience: {patience}")
        print(f'Speech Len: {speech_len}')
        print('=' * 20)

    ### not strictly necessary 
    assert len(X) % batch_size == 0
    
    start = time.time()

    dataset    = TensorDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=batch_size)
    
    X_batch, y_batch = list(dataloader)[0]

    if list(X_batch.size()) != [batch_size, speech_len]:
        raise AssertionError(f"X_ex.size() should be [batch_size, S], got {X_batch.size()}")

    if list(y_batch.size()) != [batch_size]:
        raise AssertionError(f"y_ex.size() should be [batch_size], got {y_batch.size()}")

    if log_alpha is None:
        log_alpha = torch.normal(0, math.sqrt(3), size=(1,), requires_grad=True)

    if log_beta is None:
        log_beta  = torch.normal(0, math.sqrt(3), size=(1,), requires_grad=True)

    if W is None:
        W = torch.normal(0, 1, size=(3 * lex_size,), requires_grad=True)

    print(f'Initial Parameters:')
    print(f'Alpha: {torch.exp(log_alpha).item()}')
    print(f"Beta: {torch.exp(log_beta).item()}")
    print(f"W: {W.data.tolist()}")

    optimizer = SGD([log_alpha, log_beta, W], lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
    # Early stopping logic
    best_nll = float('inf')
    epochs_without_improvement = 0

    for epoch in tqdm(range(num_epochs)):
        if verbose:
            print(f'Epoch {epoch+1} of {num_epochs}')
        ### only step after each epoch
        optimizer.zero_grad() 
        nll = torch.tensor(0.)
        
        for x_batch, y_batch in dataloader:
            nll += compute_nll(x_batch, y_batch, log_alpha, log_beta, W, grid_size, batch_size, speech_len, lex_size)
        
        nll = nll / len(dataloader.dataset)
        nll.backward()
        print('=' * 20)
        print(f'Log Alpha Grad: {log_alpha.grad}')
        print(f'Log Beta Grad: {log_beta.grad}')
        print(f'W Grad: {W.grad}')
        print('=' * 20)
        optimizer.step()
        scheduler.step()

        current_nll = nll.item()
        if current_nll < best_nll - threshold:
            best_nll = current_nll
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if verbose:
            print(f'Epoch {epoch+1}, NLL: {round(nll.item(), 3)}, Alpha: {round(np.exp(log_alpha.item()),3)}, Beta: {round(np.exp(log_beta.item()),3)}, W: {W.tolist()}')
            print('=' * 20)

        if epochs_without_improvement >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break

    final_alpha = np.exp(log_alpha.item())
    final_beta  = np.exp(log_beta.item())

    if verbose:
        print(f"Training Took {round(time.time() - start, 2)} Seconds.")
        
    print(f"Trained Params: alpha: {final_alpha}, beta: {final_beta}, W: {W.tolist()}")

    return best_nll, final_alpha, final_beta, W

In [11]:
### numerically stable factorial
factorial = lambda x : torch.exp(torch.lgamma(x+1))

assert factorial(torch.tensor(3.)) == torch.tensor(6.)
print(f"3.0001! = {factorial(torch.tensor(3.001)).item()}")

### notice this is infinity, so we can't call factorial on negative stuff
print(f"-1! = {factorial(torch.tensor(-1.)).item()}")

def compute_log_joint(u_mat, x_batch, y_batch, log_alpha, log_beta, W, grid_size, batch_size, speech_len, lex_size):
    """
    Parameters:
        x_batch (size [batch_size, 3])
        y_batch (size [batch_size])
        u_mat   (size [batch_size, grid_size])
        (log_alpha, log_beta, W) trainable params 
    
    Returns:
        torch.tensor (size [batch_size, grid_size]), where the (i,j)th element 
        is the log joint probability (log p(u, x, y; theta)). Here u is the (i,j)th
        element of u_mat, x is the ith row of x_batch, y is the ith element of y_batch

    """
    assert list(u_mat.size())   == [batch_size, grid_size] 
    assert list(x_batch.size()) == [batch_size, speech_len]
    assert list(y_batch.size()) == [batch_size]

    if not torch.all((u_mat >= 0).float() == y_batch.unsqueeze(1)).item():
        raise ValueError(f"u_mat is incompatible with y_batch.")

    ### log prior: log p(u) = log(1/4 Beta((u+1)/2; a, b) + 1/4 Beta((u+1)/2; b, a))
    ###                     = log(1/4) + log(Beta((u+1)/2; a, b) + Beta((u+1)/2; b, a))
    alpha = torch.exp(log_alpha)
    beta  = torch.exp(log_beta)
        
    beta_dist_ab = Beta(alpha, beta)
    beta_dist_ba = Beta(beta, alpha)

    ### see 3-6 testing.ipynb test1 to see how log_prob broadcasts
    log_beta_prob_ab = beta_dist_ab.log_prob((u_mat + 1) / 2)
    log_beta_prob_ba = beta_dist_ba.log_prob((u_mat + 1) / 2)

    prior = torch.log(torch.exp(log_beta_prob_ab) + torch.exp(log_beta_prob_ba)) + torch.log(torch.tensor(0.25))

    assert list(prior.size()) == [batch_size, grid_size]
  
    ### log likelihood: log p(x, y | u) = dot(x, log Softmax(Wu)) + log(S!/(x1! x2! x3!)) 

    ### step 1: compute Wu (see 3-6 testing.ipynb test2 to see this logic)
    W_expanded = W.unsqueeze(0).unsqueeze(0)
    u_expanded = u_mat.unsqueeze(2)
    Wu = torch.matmul(u_expanded, W_expanded) 

    assert list(W_expanded.size()) == [1,1,3 * lex_size]
    assert list(u_expanded.size()) == [batch_size, grid_size, 1]
    assert list(Wu.size()) == [batch_size, grid_size, 3 * lex_size]

    ### step 2: dot(x, log Softmax(Wu)) (see 3-6 testing.ipynb test3 to see this logic)
    ones = torch.ones_like(x_batch, dtype=torch.float)
    # Assuming x_batch needs to be adjusted by -1 for 0-indexing
    x_batch_adjusted = x_batch - 1
    x_batch_trans = torch.zeros(batch_size, 3 * lex_size, dtype=torch.float)
    ones = torch.ones_like(x_batch, dtype=torch.float)

    x_batch_trans.scatter_add_(dim=1, index=x_batch_adjusted.long(), src=ones)
    likelihood = (x_batch_trans.unsqueeze(1) * log_softmax(Wu, dim=2)).sum(dim=2)

    ### ensure not taking negative factorials
    assert torch.all(x_batch >= 0)

    assert list(likelihood.size()) == [batch_size, grid_size]
    return prior + likelihood


def compute_nll(x_batch, y_batch, log_alpha, log_beta, W, grid_size, batch_size, speech_len, lex_size):
    """
    Parameters:
        x_batch (size [batch_size, 3])
        y_batch (size [batch_size])
        log_alpha, log_beta, W (trainable params)
        grid_size (int) number of lin spaced points to discretize integral into
    
    Returns:
        quantity s.t. .backward() would be the gradient of the nll
    """
    ### see 3-6 testing.ipynb test5 to see this logic
    u_mat = torch.empty(batch_size, grid_size)
    u_mat[y_batch == 1] = torch.linspace(1/(grid_size+1), 1-1/(grid_size+1), grid_size).repeat((y_batch == 1).sum(), 1)
    u_mat[y_batch == 0] = torch.linspace(-1+1/(grid_size+1), -1/(grid_size+1), grid_size).repeat((y_batch == 0).sum(), 1)

    assert list(u_mat.size()) == [batch_size, grid_size]

    log_joint = compute_log_joint(u_mat, x_batch, y_batch, log_alpha, log_beta, W, grid_size, batch_size, speech_len, lex_size)
    assert list(log_joint.size()) == [batch_size, grid_size]

    ### see 3-6 testing.ipynb test6 to see this logic
    max_log_prob = torch.max(log_joint, dim=1, keepdim=True)[0]
    joint_probs = torch.exp(log_joint - max_log_prob)
    posterior = joint_probs / joint_probs.sum(dim=1, keepdim=True) 
    # detach posterior (see 3-6 testing.ipynb test7 to verify this works)
    weighted_log_joint = posterior.detach() * log_joint
    nll = -weighted_log_joint.sum()

    return nll

3.0001! = 6.007541656494141
-1! = inf


In [7]:
rho = 0.9
N = 50000
epsilon = 0.
lex_size = 5
speech_len = 50

known, unknown = generate(rho=rho, N=N, verbose='high', epsilon=epsilon, lex_size=lex_size, speech_len=speech_len)
X, y = known 
a, b, rho, epsilon, u, lex_size = unknown

Beginning Data Generation...
Lex Size: 5, Rho: 0.9, N: 50000
epsilon: 0.0, pi: 0.5, S: 50

True Alpha: 26.099999999999973
True Beta: 2.8999999999999964

mixture samples: [0 0 0]
u samples: [-0.5642138  -0.92941731 -0.63742607]
y samples: [0 0 0]

P(L) = 0.15642138042073847, P(R): 0.04357861957926153, P(N): 0.0
P(L) = 0.19294173065395445, P(R): 0.007058269346045543, P(N): 0.0
P(L) = 0.16374260689035045, P(R): 0.03625739310964955, P(N): 0.0
X samples (counts):
 [[ 6  8  9  6  7  4  1  3  2  4  0  0  0  0  0]
 [ 9  7 11  7 13  0  0  2  1  0  0  0  0  0  0]
 [10  8  5  8  9  3  1  3  0  3  0  0  0  0  0]]

X samples (sequences):
 [[0 0 1 7 3 2 2 0 3 1 4 4 1 5 0 1 7 8 8 4 2 4 3 5 9 0 4 1 3 1 5 5 2 9 2 1
  2 2 3 9 6 1 4 9 4 2 0 3 2 7]
 [0 2 3 3 4 1 1 4 3 3 1 7 3 2 4 4 2 1 4 0 4 2 4 2 2 2 1 0 4 8 2 0 0 4 3 0
  2 2 4 3 1 1 2 4 0 0 0 7 4 4]
 [0 4 0 1 3 3 7 9 3 1 1 6 9 0 2 3 1 1 4 4 5 0 0 2 0 1 4 4 9 0 0 7 2 4 5 1
  3 0 2 5 0 2 3 4 3 1 7 4 3 4]]

Generation Time: 1.748 seconds for 50000 samples.

In [12]:
best_nll, final_alpha, final_beta, W = train(X, y, patience=100, grid_size=1000, num_epochs=200, lr=.1, speech_len=speech_len,
    log_alpha = torch.tensor(np.log(13.5), requires_grad=True), log_beta=torch.tensor(np.log(1.5), requires_grad=True))

Beginning Inference...
Hyperparameters:
Dataset Length: 50000
Number Epochs: 200
Batch Size: 100
Grid Size: 1000
Learning Rate: 0.1
Improvement Threshold: 1e-05
Patience: 100
Speech Len: 50
Initial Parameters:
Alpha: 13.5
Beta: 1.5
W: [0.02502315863966942, -0.3245088458061218, 0.2878694534301758, 1.0578700304031372, 0.9620521664619446, 0.39347872138023376, 1.1322051286697388, -0.5403615832328796, -2.210235357284546]


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch 1 of 200





IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [7]:
def evaluate(a_hat, b_hat, a_true, b_true):
    pred_dist = Beta(a_hat, b_hat)
    true_dist = Beta(a_true, b_true)
    
    kl_div    = kl_divergence(pred_dist, true_dist)
    print(f'KL Divergence: {kl_div}')
    return kl_div