In [None]:
# Mount Google Drive when running on Google Colab, so DATA_PATH can point into
# "/content/drive/...". Outside Colab the `google.colab` package is not
# installed; skip the mount instead of failing, so the notebook still runs
# top-to-bottom locally (with DATA_PATH pointing at a local directory).
try:
  from google.colab import drive
except ImportError:
  drive = None

if drive is not None:
  drive.mount('/content/drive')

In [None]:
#@title Imports and device
import torch
import torch.nn as nn
import numpy as np 
import pickle

from sklearn.feature_extraction.text import CountVectorizer

import torch.distributions as ds
import sklearn.model_selection as ms
import pandas as pd
import scipy
from opt_einsum import contract
from sklearn.metrics import roc_auc_score as auc

# Use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'
print(device)

In [None]:
#@title Utils
def batchify(to_batch, batch_size):
  """Split aligned tensors into shuffled mini-batches.

  Every tensor in `to_batch` is indexed along dim 0 with the same random
  permutation, so batch b of out[0] lines up row-for-row with batch b of
  out[1], out[2], ...

  Args:
    to_batch: sequence of tensors sharing the same size along dim 0
      (e.g. [W, model.phi, model.gamma, y]).
    batch_size: maximum rows per batch; the last batch may be smaller.

  Returns:
    A list with one entry per input tensor, each a list of batch tensors.
  """
  M = to_batch[0].shape[0]
  perm = torch.randperm(M)
  out = [[] for _ in to_batch]
  for start in range(0, M, batch_size):
    idx = perm[start:start + batch_size]
    # Gather each batch separately (instead of permuting the whole tensor once
    # and slicing it) so every batch gets its own autograd node; slices of one
    # shared gather over a parameter would have that node's saved indices
    # freed by the first loss.backward(), breaking later batches.
    for j, thing in enumerate(to_batch):
      out[j].append(thing[idx.to(thing.device)])
  return out

def coherence_single(w1, w2, W):
  """Smoothed pointwise mutual information between vocabulary columns w1, w2.

  A word "occurs" in a document (row of W) when its count is > 0. Each
  document frequency is smoothed by adding 0.01 before taking logs:
    log(P(w1, w2) + eps) - log(P(w1) + eps) - log(P(w2) + eps)
  Returns a 0-dim tensor.
  """
  smoothing = 0.01
  n_docs = W.shape[0]
  has_w1 = W[:, w1] > 0
  has_w2 = W[:, w2] > 0

  def doc_freq(mask):
    # fraction of documents where `mask` holds, plus additive smoothing
    return mask.float().sum() / n_docs + smoothing

  p_joint = doc_freq(has_w1 & has_w2)
  p_w1 = doc_freq(has_w1)
  p_w2 = doc_freq(has_w2)
  return p_joint.log() - p_w1.log() - p_w2.log()
 
def coherence(topics, W):
  """Average pairwise coherence of the top words of every topic.

  Args:
    topics: 2-D array of word indices, one row per topic (e.g. the top-n
      columns of np.argsort over beta).
    W: document-word count matrix forwarded to `coherence_single`.

  Returns:
    The sum of `coherence_single` over every unordered word pair within each
    topic, divided by the number of such pairs, K * n * (n - 1) / 2.
  """
  n_topics, n_words = topics.shape[0], topics.shape[1]
  total = 0
  for k in range(n_topics):
    words = topics[k]
    n = len(words)
    for a in range(n - 1):
      for b in range(a + 1, n):
        total += coherence_single(words[a], words[b], W)
  n_pairs = n_topics * n_words * (n_words - 1) / 2
  return total / n_pairs

def print_topics(model, n, vocab):
  """Print the n most probable words of every topic, most probable first.

  Args:
    model: topic model exposing `beta` (pre-softmax topic-word weights,
      softmaxed over dim 1 here), `K` and `eta`.
    n: number of words to show per topic (0 shows none).
    vocab: word strings indexed by vocabulary id (list or numpy array).
  """
  beta = model.beta.softmax(dim=1).cpu().detach().numpy()
  # np.asarray so a plain list also supports fancy indexing below
  vocab = np.asarray(vocab)
  # argsort is ascending: reverse it, then keep the first n columns
  # (slicing [:, -n:] would return every word when n == 0)
  topn = np.argsort(beta, axis=1)[:, ::-1][:, :n]
  for i in range(model.K):
    # plain number instead of the tensor repr with device / grad_fn noise
    eta_i = float(model.eta[i])
    print(f"Topic {i}: eta = {eta_i}\n {vocab[topn[i]]}")

def s_term_normal(y_batch, gamma_batch, eta, delta, M):
  """Expected Gaussian response log-likelihood E_q[log p(y | theta, eta, delta)].

  With q(theta_m) = Dirichlet(gamma_m) and gamma0_m = sum_k gamma_mk:
    E[theta_m]           = gamma_m / gamma0_m
    E[theta_m theta_m^T] = (gamma_m gamma_m^T + diag(gamma_m))
                           / (gamma0_m (gamma0_m + 1))
  and the returned value (dropping the -M/2 log(2 pi) constant) is
    -M/2 log(delta)
    - 1/(2 delta) * sum_m (y_m^2 - 2 y_m eta^T E[theta_m])
    - 1/(2 delta) * eta^T (sum_m E[theta_m theta_m^T]) eta.

  Args:
    y_batch: one response per document (index 'm').
    gamma_batch: Dirichlet variational parameters, indexed (m, k).
    eta: regression coefficients, indexed (k).
    delta: response variance; callers pass an already-transformed value.
    M: number of documents in the batch.
  """
  base = -0.5 * M * delta.log() - (y_batch ** 2).sum() / (2 * delta)

  gamma_sum = gamma_batch.sum(dim=1, keepdim=True)
  mean_theta = gamma_batch / gamma_sum
  mean_outer = contract('mi,mj->mij', mean_theta, mean_theta, backend='torch')
  # outer * (1 - 1/(gamma0 + 1)) == gamma gamma^T / (gamma0 (gamma0 + 1))
  scaled_outer = -mean_outer / (gamma_sum.unsqueeze(2) + 1) + mean_outer
  diag_part = torch.diag((mean_theta / (gamma_sum + 1)).sum(dim=0))
  second_moment = scaled_outer.sum(dim=0) + diag_part

  cross = contract('m,k,mk->', y_batch, eta, mean_theta, backend='torch')
  quad = contract('k,kq,q->', eta, second_moment, eta, backend='torch')
  return base + (2 * cross - quad) / (2 * delta)

def s_term_bernoulli(y_batch, gamma_batch, eta):
  """Plug-in Bernoulli response log-likelihood for binary targets.

  Uses the variational mean E[theta_m] = gamma_m / sum_k gamma_mk as the
  document's topic proportions and returns
    sum_m  y_m log sigmoid(x_m) + (1 - y_m) log(1 - sigmoid(x_m)),
  with x_m = eta^T E[theta_m].

  Args:
    y_batch: 0/1 targets, one per document.
    gamma_batch: Dirichlet variational parameters, shape (M, K).
    eta: coefficients, shape (K,).

  Returns:
    0-dim tensor.
  """
  gamma_sum = gamma_batch.sum(dim=1, keepdim=True)
  mean_theta = gamma_batch / gamma_sum
  logits = torch.mv(mean_theta, eta)
  y = y_batch.to(logits.dtype)
  # log(1 - sigmoid(x)) == logsigmoid(-x); working on logits keeps both terms
  # finite even when sigmoid(x) rounds to exactly 0 or 1 in floating point.
  log_p1 = torch.nn.functional.logsigmoid(logits)
  log_p0 = torch.nn.functional.logsigmoid(-logits)
  return (y * log_p1 + (1 - y) * log_p0).sum()


In [None]:
#@title Data Manager

# Root directory for datasets and saved checkpoints. It is prepended with plain
# string concatenation (e.g. DATA_PATH + "datasets/Pang_Lee"), so a non-empty
# value must end with a trailing "/". Adjust as necessary; on Colab, e.g.:
# DATA_PATH = "/content/drive/My Drive/Colab Notebooks/research/"
DATA_PATH = ""

def save_dict(d, name):
  """Pickle `d` to the file at path `name`, overwriting it.

  The file is opened in a `with` block so it is flushed and closed before
  the function returns.
  """
  with open(name, 'wb') as file:
    pickle.dump(d, file)

def load_dict(name):
  """Unpickle and return the object stored in the file at path `name`.

  Security: pickle.load can execute arbitrary code from the file, so only
  load files you trust.
  """
  with open(name, 'rb') as file:
    return pickle.load(file)

# On Colab, upload the datasets/Pang_Lee file under DATA_PATH on your Drive first.
def load_Pang_Lee():
  """Load the pickled Pang & Lee dataset dict from DATA_PATH + "datasets/Pang_Lee"."""
  dataset_file = DATA_PATH + "datasets/Pang_Lee"
  return load_dict(dataset_file)

In [None]:
#@title sLDA 
class sLDA(nn.Module):
  """
  Implementation based on McAuliffe and Blei Supervised Topic Models
  (https://arxiv.org/pdf/1003.0783.pdf).
  Optimized with SGD on ELBO instead of coordinate ascent.
  Constrained parameters stored in transformed way to allow
  unconstrained gradient updates:
    beta  -> softmax over the vocabulary (dim=1)
    phi   -> softmax over topics (dim=1)
    delta -> exp
  alpha and gamma are used as stored (no transform).
  """
  def __init__(self, K, V, M, M_val, alpha_fixed=None, device=device):
    """
    Args:
    K: # topics,
    V: Vocab size,
    M: # docs,
    M_val: # val docs,
    alpha_fixed: if truthy, alpha is fixed to a vector of ones (a plain
      tensor, not a learnable parameter); otherwise it is sampled from
      Exponential(1) and learned
    device: device the fixed alpha is created on
    """
    super(sLDA, self).__init__()
    self.name = 'slda'
    self.K = K
    self.V = V
    self.M = M
    self.M_val = M_val
    self.epsilon = 0.0000001
    self.alpha_fixed = alpha_fixed

    # model parameters
    alpha = torch.ones(self.K).to(device) if self.alpha_fixed else \
      ds.Exponential(1).sample([self.K])
    # beta stored pre-softmax (over V)
    beta = ds.Exponential(1).sample([self.K, self.V])
    beta = beta / beta.sum(dim=1, keepdim=True)
    eta = ds.Normal(0,1).sample([self.K])
    # delta stored pre-exponentiated (the ELBO uses delta = exp(param))
    delta = ds.Normal(0,1).sample().abs()

    # variational parameters
    gamma = torch.ones((self.M, self.K))
    gamma_val = torch.ones((M_val, self.K))
    # phi stored pre-softmax (over K)
    phi = torch.ones((self.M, self.K, self.V))
    phi_val = torch.ones((M_val, self.K, self.V))

    self.alpha = alpha if self.alpha_fixed else nn.Parameter(alpha)
    self.beta = nn.Parameter(beta)
    self.gamma = nn.Parameter(gamma)
    self.phi = nn.Parameter(phi)
    self.eta = nn.Parameter(eta)
    self.delta = nn.Parameter(delta)
    self.phi_val = nn.Parameter(phi_val)
    self.gamma_val = nn.Parameter(gamma_val)


  def ELBO(self, W_batch, phi_batch, gamma_batch, y_batch, version='real'):
    """
    Computes the sLDA ELBO for one mini-batch (summed over its documents):
    first_term: E log p(theta|alpha)
    second_term: E log p(z|theta)
    third_term: E log p(w|z,beta)
    fourth_term: E log q(theta|gamma)
    fifth_term: E log q(z|phi)
    s_term: E log p(y|theta) (Gaussian if version == 'real', else Bernoulli)
    ELBO = first + second + third - fourth - fifth + s_term

    Args:
    W_batch: word counts indexed (m, v)
    phi_batch: pre-softmax phi rows indexed (m, k, v), aligned with W_batch
    gamma_batch: Dirichlet variational parameters indexed (m, k)
    y_batch: responses, one per document
    version: 'real' for Gaussian responses, anything else for Bernoulli
    """
    M = W_batch.shape[0]
    ss = torch.digamma(gamma_batch) \
      - torch.digamma(gamma_batch.sum(dim=1, keepdim=True))

    # Transform constrained parameters to be valid. log_softmax keeps the logs
    # finite even if a probability underflows to 0, which would otherwise turn
    # phi * log(phi) into 0 * -inf = nan.
    log_phi = phi_batch.log_softmax(dim=1)
    phi = log_phi.exp()
    log_beta = self.beta.log_softmax(dim=1)
    delta = self.delta.exp()

    first_term = M * (torch.lgamma(self.alpha.sum()) \
      - torch.lgamma(self.alpha).sum()) \
      + contract('mk,k->', ss, self.alpha - 1, backend='torch')

    second_term = contract(
      'mkv,mk,mv->',
      phi, ss, W_batch,
      backend='torch'
    )

    third_term = contract(
      'mkv,mv,kv->',
      phi, W_batch, log_beta,
      backend='torch'
    )

    fourth_term = torch.lgamma(gamma_batch.sum(dim=1)).sum() \
      - torch.lgamma(gamma_batch).sum() \
      + contract('mk,mk->', ss, gamma_batch - 1, backend='torch')

    fifth_term = contract(
      'mkv,mkv,mv->',
      phi, log_phi, W_batch,
      backend='torch'
    )

    if version=='real':
      s_term = s_term_normal(y_batch, gamma_batch, self.eta, delta, M)
    else:
      s_term = s_term_bernoulli(y_batch, gamma_batch, self.eta)

    return first_term + second_term + third_term - \
      fourth_term - fifth_term + s_term


  # can also batch if needed, uses theta map as in pc-slda
  # (http://proceedings.mlr.press/v84/hughes18a/hughes18a.pdf)
  def pred(self, W, y=None):
    """Predict responses as theta_map(W) @ eta.

    Returns (theta_maps, preds); `y` is accepted but unused.
    """
    theta_maps = self.theta_map(W)
    preds = torch.mv(theta_maps, self.eta)
    return theta_maps, preds


  # calculate theta map with SGD on the posterior of theta
  def theta_map(self, W, num_epochs = 500, lr = 0.005):
    """MAP estimate of theta for every row of W, by Adam on theta_post.

    theta is optimized in unconstrained (pre-softmax) form and returned
    softmaxed, shape (W.shape[0], K). The model's beta and alpha are detached
    and transformed once up front, so the loop neither accumulates gradients
    into model parameters nor recomputes their softmax at every step.
    """
    beta = self.beta.softmax(dim=1).detach()
    alpha = self.alpha.detach()
    theta_maps = torch.ones(
      (W.shape[0], self.K),
      requires_grad=True,
      device=beta.device
    )

    opt = torch.optim.Adam([theta_maps], lr=lr)
    for _ in range(num_epochs):
      opt.zero_grad()
      loss = -1 * self._theta_post(W, theta_maps, beta, alpha)
      loss.sum().backward()
      opt.step()

    return theta_maps.softmax(dim=1)


  # calculate the posterior of theta
  def theta_post(self, W, theta):
    """Unnormalized log posterior of pre-softmax theta, one value per row of W."""
    return self._theta_post(W, theta, self.beta.softmax(dim=1), self.alpha)


  def _theta_post(self, W, theta, beta, alpha):
    # log p(W | theta, beta) + log Dir(softmax(theta) | alpha), up to constants;
    # beta is the already-softmaxed topic-word matrix.
    log_theta = theta.log_softmax(dim=1)
    bl = contract('kv,mk->mv', beta, log_theta.exp(), backend='torch')
    out1 = contract('mv,mv->m', W, bl.log(), backend='torch')
    out2 = torch.mv(log_theta, alpha - 1)
    return out1 + out2



In [None]:
#@title pf-sLDA
class pfsLDA(nn.Module):
  """
  Implementation based on Ren et. al. pf-sLDA
  (https://arxiv.org/pdf/1910.05495.pdf).
  Constrained parameters stored in transformed way to allow
  unconstrained gradient updates:
    beta, pi -> softmax over the vocabulary
    phi      -> softmax over topics (dim=1)
    varphi   -> sigmoid
    delta    -> square (delta = param ** 2)
  alpha and gamma are used as stored (no transform).
  """
  def __init__(self, K, V, M, M_val, p, alpha_fixed = None):
    """
    Args:
    K: # topics,
    V: Vocab size,
    M: # docs,
    M_val: # val docs,
    p : switch prior on the logit scale, log(p / (1 - p)) (this is how main()
      passes it); a float or 0-dim tensor
    alpha_fixed: if truthy, alpha is fixed to a vector of ones instead of
      being learned
    """
    super(pfsLDA, self).__init__()
    self.name = 'pfslda'
    self.K = K
    self.V = V
    self.M = M
    self.M_val = M_val
    self.epsilon = 0.0000001
    self.alpha_fixed = alpha_fixed

    # model parameters
    alpha = torch.ones(self.K).to(device) if self.alpha_fixed else \
      ds.Exponential(1).sample([self.K])
    # beta stored pre-softmax (over V)
    beta = ds.Exponential(1).sample([self.K, self.V])
    beta = beta / beta.sum(dim=1, keepdim=True)
    # pi stored pre-softmax (over V)
    pi = ds.Exponential(1).sample([self.V])
    pi = pi / pi.sum()
    eta = ds.Normal(0,1).sample([self.K])
    # delta stored as a square root: the ELBO uses delta = param ** 2
    delta = ds.Normal(0,1).sample().abs()

    # variational parameters
    gamma = torch.ones((self.M, self.K))
    gamma_val = torch.ones(self.M_val, self.K)
    # phi stored pre softmax (over K)
    phi = torch.ones(self.M, self.K, self.V)
    phi_val = torch.ones(self.M_val, self.K, self.V)
    # Keep the prior logit as a CPU 0-dim tensor: those may be combined with
    # tensors on any device, whereas multiplying a CPU tensor by a 0-dim CUDA
    # tensor fails with a device mismatch.
    p = torch.as_tensor(p, dtype=torch.float32).cpu()
    # varphi stored pre-sigmoid, initialized to the prior logit p
    varphi = torch.full((self.V,), p.item())

    self.alpha = alpha if self.alpha_fixed else nn.Parameter(alpha)
    self.beta = nn.Parameter(beta)
    self.gamma = nn.Parameter(gamma)
    self.phi = nn.Parameter(phi)
    self.eta = nn.Parameter(eta)
    self.delta = nn.Parameter(delta)
    self.pi = nn.Parameter(pi)
    self.varphi = nn.Parameter(varphi)
    self.phi_val = nn.Parameter(phi_val)
    self.gamma_val = nn.Parameter(gamma_val)
    self.p = p


  def ELBO(self, W_batch, phi_batch, gamma_batch, y_batch, version='real'):
    """
    Computes pf-sLDA ELBO for one mini-batch (summed over its documents).
    See appendix of https://arxiv.org/pdf/1910.05495.pdf for details
    first_term: E log p(theta|alpha)
    second_term: E log p(z|theta)
    third_term: E log p(w|z,beta,pi,xi)
    fourth_term: E log p(xi|p)
    fifth_term: E log q(theta|gamma)
    sixth_term: E log q(z|phi)
    seventh_term: E log q(xi|varphi)
    s_term: E log p(y|theta) (Gaussian if version == 'real', else Bernoulli)

    Args:
    W_batch: word counts indexed (m, v)
    phi_batch: pre-softmax phi rows indexed (m, k, v), aligned with W_batch
    gamma_batch: Dirichlet variational parameters indexed (m, k)
    y_batch: responses, one per document
    version: 'real' for Gaussian responses, anything else for Bernoulli
    """
    M = W_batch.shape[0]
    ss = torch.digamma(gamma_batch) - \
      torch.digamma(gamma_batch.sum(dim=1, keepdim=True))

    # Transform constrained parameters to be valid. Logs are taken via
    # log_softmax / logsigmoid so they stay finite even when a probability
    # rounds to exactly 0 or 1 (avoids 0 * log(0) = nan).
    log_phi = phi_batch.log_softmax(dim=1)
    phi = log_phi.exp()
    log_beta = self.beta.log_softmax(dim=1)
    log_pi = self.pi.log_softmax(dim=0)
    varphi = self.varphi.sigmoid()
    log_varphi = nn.functional.logsigmoid(self.varphi)
    log_1m_varphi = nn.functional.logsigmoid(-self.varphi)
    log_p = nn.functional.logsigmoid(self.p)
    log_1m_p = nn.functional.logsigmoid(-self.p)
    delta = self.delta ** 2

    first_term = M \
      * (torch.lgamma(self.alpha.sum()) - torch.lgamma(self.alpha).sum()) \
      + contract('mk,k->', ss, self.alpha - 1, backend='torch')

    second_term = contract(
      'mkv,mk,mv->',
      phi, ss, W_batch,
      backend='torch'
    )

    # words with switch on come from the topics, the rest from pi
    third_term1 = contract(
      'mkv,mv,kv,v->',
      phi, W_batch, log_beta, varphi,
      backend='torch'
    )
    third_term2 = contract(
      'mv,v,v->',
      W_batch, log_pi, varphi,
      backend='torch'
    )
    third_term3 = contract(
      'mv,v->',
      W_batch, log_pi,
      backend='torch'
    )
    third_term = third_term1 - third_term2 + third_term3

    fourth_term1 = contract(
      'mv,v->',
      W_batch, varphi,
      backend='torch'
    )
    fourth_term2 = contract(
      'mv,v->',
      W_batch, 1 - varphi,
      backend='torch'
    )
    fourth_term = log_p * fourth_term1 + log_1m_p * fourth_term2

    fifth_term = torch.lgamma(gamma_batch.sum(dim=1)).sum() - \
      torch.lgamma(gamma_batch).sum() + \
      contract('mk,mk->', ss, gamma_batch - 1, backend='torch')

    sixth_term = contract(
      'mkv,mkv,mv->',
      phi, log_phi, W_batch,
      backend='torch'
    )

    # varphi log varphi + (1 - varphi) log(1 - varphi), weighted by counts
    seventh_term = contract(
      'mv,v->',
      W_batch, varphi * log_varphi + (1 - varphi) * log_1m_varphi,
      backend='torch'
    )

    if version=='real':
      s_term = s_term_normal(y_batch, gamma_batch, self.eta, delta, M)
    else:
      s_term = s_term_bernoulli(y_batch, gamma_batch, self.eta)

    return first_term + second_term + third_term + fourth_term \
      - fifth_term - sixth_term - seventh_term + s_term


  # can also batch if needed, uses theta map as in pc-slda
  # (http://proceedings.mlr.press/v84/hughes18a/hughes18a.pdf)
  def pred(self, W, y=None):
    """Predict responses as theta_map(W) @ eta.

    Returns (theta_maps, preds); `y` is accepted but unused.
    """
    theta_maps = self.theta_map(W)
    preds = torch.mv(theta_maps, self.eta)
    return theta_maps, preds


  # calculate theta map with SGD on the posterior of theta
  def theta_map(self, W, num_epochs = 500, lr = 0.005):
    """MAP estimate of theta for every row of W, by Adam on theta_post.

    theta is optimized in unconstrained (pre-softmax) form and returned
    softmaxed, shape (W.shape[0], K). Model parameters are detached and
    transformed once up front, so the loop neither accumulates gradients into
    them nor recomputes their transforms at every step.
    """
    beta = self.beta.softmax(dim=1).detach()
    varphi = self.varphi.sigmoid().detach()
    alpha = self.alpha.detach()
    theta_maps = torch.ones(
      (W.shape[0], self.K),
      requires_grad = True,
      device=beta.device
    )
    opt = torch.optim.Adam([theta_maps], lr = lr)
    for _ in range(num_epochs):
      opt.zero_grad()
      loss = -1 * self._theta_post(W, theta_maps, beta, varphi, alpha)
      loss.sum().backward()
      opt.step()
    return theta_maps.softmax(dim=1)


  # calculate the posterior of theta
  def theta_post(self, W, theta):
    """Unnormalized log posterior of pre-softmax theta, one value per row of W."""
    return self._theta_post(
      W, theta, self.beta.softmax(dim=1), self.varphi.sigmoid(), self.alpha
    )


  def _theta_post(self, W, theta, beta, varphi, alpha):
    # Switch-weighted log p(W | theta, beta) + log Dir(softmax(theta) | alpha)
    # up to constants. The Dirichlet log density uses alpha - 1, matching
    # sLDA.theta_post.
    log_theta = theta.log_softmax(dim=1)
    bl = contract('kv,mk->mv', beta, log_theta.exp(), backend='torch')
    out1 = contract(
      'v,mv,mv->m',
      varphi, W, bl.log(),
      backend='torch'
    )
    out2 = torch.mv(log_theta, alpha - 1)
    return out1 + out2

In [None]:
#@title Training
def _save_checkpoint(model, val_yscore, c):
  """Save model.state_dict() under DATA_PATH, naming the file by model and scores."""
  path = DATA_PATH + f"{model.name}_y{val_yscore:.2f}_c{c:.2f}.pt"
  torch.save(model.state_dict(), path)


def _beats_y_thresh(val_yscore, y_thresh, version):
  """True if y_thresh is set and the score passes it (lower RMSE / higher AUC)."""
  if not y_thresh:
    return False
  if version == 'real':
    return val_yscore < y_thresh
  return val_yscore > y_thresh


def fit(model, W, y, lr, lambd, num_epochs, batch_size,
          check, version, W_val, y_val, device, y_thresh, c_thresh):
  """
  Train `model` by maximizing its ELBO with Adam over shuffled mini-batches.

  Args:
  model: sLDA / pfsLDA instance (uses model.phi, model.gamma, model.eta,
    model.ELBO)
  W: count data,
  y: targets,
  lr: initial learning rate
  lambd: supervised task regularizer weight (L2 penalty on eta)
  num_epochs: number of passes over the training data
  batch_size: documents per mini-batch
  check: evaluate (and possibly checkpoint) every `check` epochs; falsy
    disables periodic evaluation
  version: 'real' (Gaussian response, scored by RMSE) or anything else
    (Bernoulli response, scored by AUC)
  W_val, y_val: validation data used for the y-score
  device: device the batches and validation data are moved to
  y_thresh: checkpoint when the validation score passes this value (RMSE
    below it for 'real', AUC above it otherwise); falsy disables
  c_thresh: checkpoint when coherence exceeds this value; falsy disables
  If neither threshold is set, the final model is evaluated and saved.
  """
  print(f"Training {model.name} on {device}.")
  opt = torch.optim.Adam(model.parameters(), lr = lr)
  # validation data is constant: move it once rather than at every check
  W_val = W_val.to(device)
  y_val = y_val.to(device)
  tot = 0  # defined even when num_epochs == 0 (used by the final save)

  for i in range(num_epochs):
    W_b, phi_b, gamma_b, y_b = batchify(
      [W, model.phi, model.gamma, y], batch_size
    )
    tot = 0
    for W_j, phi_j, gamma_j, y_j in zip(W_b, phi_b, gamma_b, y_b):
      opt.zero_grad()
      elbo = model.ELBO(
        W_j.to(device), phi_j, gamma_j, y_j.to(device), version=version
      )
      tot += elbo.item()
      loss = -1 * elbo + lambd * (model.eta ** 2).sum()
      loss.backward()
      opt.step()

    if check and i % check == 0:
      val_yscore, c = calc_stats_and_print(
        model, W, W_val, y_val, tot / W.sum(), i, version
      )
      if _beats_y_thresh(val_yscore, y_thresh, version) or \
          (c_thresh and c > c_thresh):
        _save_checkpoint(model, val_yscore, c)

  # save last model if no thresholds
  if not y_thresh and not c_thresh:
    val_yscore, c = calc_stats_and_print(
      model, W, W_val, y_val, tot / W.sum(), num_epochs, version
    )
    _save_checkpoint(model, val_yscore, c)

  return

def yscore(model, W, y, version='real'):
  """Score the model's predictions on (W, y).

  version == 'real': root-mean-squared error of the raw predictions
    (returned as a 0-dim tensor).
  otherwise: ROC AUC of sigmoid(predictions) against y, via `auc`.
  """
  preds = model.pred(W)[1]
  if version != 'real':
    prob_np = preds.sigmoid().cpu().detach().numpy()
    y_np = y.cpu().detach().numpy()
    return auc(y_np, prob_np)
  return (preds - y).pow(2).mean().sqrt()
  
def calc_stats_and_print(model, W, W_val, y_val, elbo, i, version):
  """Evaluate the model and print a progress report.

  Computes the validation y-score (see `yscore`) and the coherence of each
  topic's 50 most probable words measured on W. It prints both, together with
  the iteration number and the given ELBO value. Returns (val_yscore,
  coherence).
  """
  val_yscore = yscore(model, W_val, y_val, version=version)
  topic_word = model.beta.softmax(dim=1).cpu().detach().numpy()
  top50 = np.argsort(topic_word, axis=1)[:, -50:]
  c = coherence(top50, W)
  report = (
    f"Iter: {i}",
    f"ELBO: {elbo}",
    f"Val yscore: {val_yscore}",
    f"Coherence: {c}\n",
  )
  for line in report:
    print(line)
  return val_yscore, c

In [None]:
#@title Main
def main(args):
  """Load the Pang & Lee data, build the requested model, train it and return it.

  Reads from `args`: 'K', 'model' ('slda' or 'pfslda'), 'p', 'alpha', 'lr',
  'lambd', 'num_epochs', 'check', 'batch_size', 'y_thresh', 'c_thresh'.
  Optional keys: 'path' (a saved state_dict to start from; only load trusted
  files, torch.load unpickles) and 'version' ('real' for Gaussian responses,
  the default; anything else for Bernoulli).
  Prints the top 10 words of every topic after training.

  Raises:
    ValueError: for K < 1, an invalid switch prior p, or an unknown model.
  """
  device = 'cuda' if torch.cuda.is_available() else 'cpu'

  if args['K'] < 1:
    raise ValueError('Invalid number of topics specified.')

  p = args['p']
  if p > 1 or p < 0:
    raise ValueError('Invalid switch prior p.')
  if args['model'] == 'pfslda' and p in (0, 1):
    # the logit below would be -inf / inf and turn the pf-sLDA ELBO into nan
    raise ValueError('Invalid switch prior p.')
  # Convert to a logit. The tensor deliberately stays on the CPU: a 0-dim CPU
  # tensor can be combined both with CPU-built tensors (e.g. when pfsLDA
  # initializes varphi) and with tensors on the GPU.
  p = torch.tensor(p)
  p = torch.log(p / (1 - p))

  d = load_Pang_Lee()
  W = d['W']
  W_val = d['W_val']
  y = d['y']
  y_val = d['y_val']
  vocab = d['vocab']
  version = args.get('version', 'real')

  V = W.shape[1]
  M = W.shape[0]
  M_val = W_val.shape[0]

  if args['model'] == 'slda':
    model = sLDA(args['K'], V, M, M_val, alpha_fixed=args['alpha'])
  elif args['model'] == 'pfslda':
    model = pfsLDA(args['K'], V, M, M_val, p, alpha_fixed=args['alpha'])
  else:
    raise ValueError(
      f"Unknown model {args['model']!r}; expected 'slda' or 'pfslda'."
    )
  model.to(device)

  if args.get('path'):
    state_dict = torch.load(args['path'], map_location = device)
    model.load_state_dict(state_dict)

  kwargs = {
      'W' : W,
      'y' : y,
      'lr' : args['lr'],
      'lambd' : args['lambd'],
      'num_epochs' : args['num_epochs'],
      'check' : args['check'],
      'batch_size' : args['batch_size'],
      'version' : version,
      'W_val' : W_val,
      'y_val' : y_val,
      'device' : device,
      'y_thresh' : args['y_thresh'],
      'c_thresh' : args['c_thresh']
  }

  fit(model, **kwargs)
  print_topics(model, 10, vocab)
  return model

In [None]:
# Experiment configuration; see main() for how each key is used.
args = dict(
    K=5,               # number of topics
    model='pfslda',    # 'slda' or 'pfslda'
    p=0.15,            # switch prior probability (used by pf-sLDA)
    alpha=True,        # fix the Dirichlet prior alpha to ones
    lr=0.025,          # Adam learning rate
    lambd=0,           # L2 penalty weight on eta
    num_epochs=100,
    check=10,          # evaluate every `check` epochs
    batch_size=100,
    y_thresh=None,     # validation-score checkpoint threshold (None: off)
    c_thresh=None,     # coherence checkpoint threshold (None: off)
)

model = main(args)