In [10]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
import torch.nn.functional as F

#from datasets import get_mnist_dataset, get_data_loader
#from utils import *
#from models import *

import pickle as pkl
import os

## Get Dataloaders

In [12]:
def get_dataloaders(train_filename, val_filename, data_dir=None):
    """Load the pickled training and validation data loaders.

    Args:
        train_filename: file name (relative to ``data_dir``) of the pickled
            training loader.
        val_filename: file name (relative to ``data_dir``) of the pickled
            validation loader.
        data_dir: directory holding both pickles. Defaults to ``<cwd>/data``.

    Returns:
        Tuple ``(train_dataloader, val_dataloader)`` exactly as unpickled.

    Security: ``pickle.load`` can execute arbitrary code, so only load files
    from a trusted source.
    """
    if data_dir is None:
        data_dir = os.path.join(os.getcwd(), 'data')

    def _load(filename):
        # `with` closes the handle even if unpickling raises
        with open(os.path.join(data_dir, filename), 'rb') as fh:
            return pkl.load(fh)

    return _load(train_filename), _load(val_filename)

In [13]:
#train_dataloader,val_dataloader = get_dataloaders('name1','name2')

## Scratchwork (IGNORE)

In [26]:
# 2x4 demo tensor (default float dtype) holding 0..7 in row-major order
tensor = torch.Tensor([list(range(0, 4)), list(range(4, 8))])

In [32]:
tensor.sum(0)  # reduce over dim 0 (rows): one sum per column, shape (4,)

tensor([ 4.,  6.,  8., 10.])

## Neural Network Class

NOTE: each item produced by the data loader is a tuple:
- `(tokens, flagged_index, problematic)`

In [34]:
class neuralNetBow(nn.Module):
    """
    Bag-of-words encoder: a weighted average of token embeddings in which the
    token at ``flagged_index`` counts ``upweight`` times as much as every
    other token.
    """
    # NOTE: in baseline model, can't do linear layers, because they will remember certain
    # positions as being more important than others (ie, 4th word vs 7th word)
    def __init__(self, vocab_size, emb_dim, upweight=10):
        """
        Args:
            vocab_size: number of rows in the embedding table.
            emb_dim: embedding dimensionality.
            upweight: weight of the flagged token (all other tokens weigh 1).
        """
        super(neuralNetBow, self).__init__()
        # id 0 is the padding id: its embedding starts at zero and gets no gradient
        self.embed = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.upweight = upweight

    def forward(self, tokens, flagged_index):
        """
        Args:
            tokens: 1-D LongTensor of token ids, shape (num_tokens,).
            flagged_index: position within ``tokens`` of the flagged token.

        Returns:
            Tensor of shape (emb_dim,) equal to sum_i w_i * e_i / sum_i w_i,
            where w_i = upweight for the flagged token and 1 otherwise.
        """
        num_tokens = len(tokens)
        embedding = self.embed(tokens)  # (num_tokens, emb_dim) for 1-D tokens

        # Weight the flagged token's *row* (axis 0). Indexing embedding[:, idx]
        # would instead scale one embedding feature column of every token.
        weights = torch.ones(num_tokens, dtype=embedding.dtype,
                             device=embedding.device)
        weights[flagged_index] = self.upweight

        # weighted average over tokens; weights.sum() == num_tokens + upweight - 1
        return (embedding * weights[:, None]).sum(0) / weights.sum()

### Clustering Utilities (generic; not yet adapted to this dataset)

In [16]:
class KMeansCriterion(nn.Module):
    """K-means clustering loss.

    Each embedding is assigned to its nearest centroid (squared Euclidean
    distance). The loss is ``lmbda`` times the summed squared distance of
    every embedding to its assigned centroid.
    """

    def __init__(self, lmbda):
        """
        Args:
            lmbda: scale of the clustering term relative to other losses.
        """
        super().__init__()
        self.lmbda = lmbda

    def forward(self, embeddings, centroids):
        """
        Args:
            embeddings: (n, d) tensor.
            centroids: (k, d) tensor, broadcast against every embedding.

        Returns:
            Tuple ``(loss, cluster_assignments)``: the scalar loss, and an
            (n,) LongTensor holding each embedding's nearest-centroid index.
        """
        # (n, 1, d) - (k, d) -> (n, k, d); reduce over d -> (n, k)
        distances = ((embeddings[:, None, :] - centroids) ** 2).sum(dim=2)
        # k-means assigns each point to its *closest* centroid, hence min
        cluster_distances, cluster_assignments = distances.min(dim=1)
        loss = self.lmbda * cluster_distances.sum()
        return loss, cluster_assignments

In [35]:
def centroid_init(k, d):
    """Initialise ``k`` centroids of dimension ``d`` from random assignments.

    Every embedding that ``encoder`` produces over ``trainloader`` is assigned
    to a uniformly random cluster. Each centroid is then the mean of its
    members.

    NOTE(review): relies on the globals ``trainloader`` and ``encoder`` and on
    ``update_clusters``. It also assumes the loader yields ``(X, y)`` pairs,
    which does not match the ``(tokens, flagged_index, problematic)`` tuples
    described in the data-loader note. TODO: adapt.

    Returns:
        (k, d) tensor of centroids, not attached to any autograd graph.
        A cluster that received no samples gets a zero centroid.
    """
    # NOTE: ideally this would not use a randomized/zero initialization
    centroid_sums = torch.zeros(k, d)
    centroid_counts = torch.zeros(k)
    # these are statistics only, so don't build a graph over the whole epoch
    with torch.no_grad():
        for X, _ in trainloader:
            cluster_assignments = torch.randint(k, (X.size(0),))
            embeddings = encoder(X)
            update_clusters(centroid_sums, centroid_counts,
                            cluster_assignments, embeddings)

    # clamp avoids 0/0 = NaN for a cluster that received no samples
    centroid_means = centroid_sums / centroid_counts.clamp(min=1)[:, None]
    return centroid_means

def update_clusters(centroid_sums, centroid_counts,
                    cluster_assignments, embeddings):
    """Accumulate per-cluster embedding sums and member counts, in place.

    Args:
        centroid_sums: (k, d) float tensor; row j gains the sum of the
            embeddings assigned to cluster j.
        centroid_counts: (k,) float tensor; entry j gains the number of
            embeddings assigned to cluster j.
        cluster_assignments: (n,) LongTensor of cluster ids in [0, k).
        embeddings: (n, d) tensor, same dtype as ``centroid_sums``.
    """
    k = centroid_sums.size(0)
    # running statistics only: detach so each batch's autograd graph is not
    # chained onto the accumulators
    centroid_sums.index_add_(0, cluster_assignments, embeddings.detach())
    # torch.bincount stays on the tensor's device (no numpy/CPU round-trip)
    counts = torch.bincount(cluster_assignments, minlength=k)
    centroid_counts.add_(counts.to(centroid_counts.dtype))

### Training Function

In [36]:
def train(encoder, decoder, centroids, optimizer, criterion,
          print_every=100, verbose=False):
    """Run one epoch of joint autoencoder + clustering training.

    For each batch from ``trainloader``, the loss is the reconstruction MSE
    plus the clustering loss from ``criterion``, and ``optimizer`` takes one
    step. Each batch's embeddings are accumulated per assigned cluster so the
    centroids can be re-centred at the end of the epoch. The centroids passed
    in stay fixed during the epoch.

    Args:
        encoder: module mapping a batch ``X`` to embeddings.
        decoder: module mapping embeddings back to the shape of ``X``.
        centroids: (k, d) tensor of current centroids.
        optimizer: optimizer over the parameters being trained.
        criterion: callable ``(embeddings, centroids) -> (loss, assignments)``,
            e.g. ``KMeansCriterion``.
        print_every: when ``verbose``, report every this many batches.
        verbose: plot reconstructions and print losses periodically.

    Returns:
        Tuple ``(centroid_means, centroid_counts)``: the (k, d) re-centred
        centroids and the (k,) number of embeddings assigned to each cluster.

    NOTE(review): relies on the globals ``trainloader`` and ``update_clusters``,
    and, when ``verbose``, on ``plot_batch``. It assumes the loader yields
    ``(X, y)`` pairs. TODO: adapt to this dataset's tuples.
    """
    k = centroids.size(0)
    centroid_sums = torch.zeros_like(centroids)
    # same dtype/device as the centroids
    centroid_counts = centroids.new_zeros(k)

    # run one epoch of gradient descent on autoencoders wrt centroids
    for i, (X, _) in enumerate(trainloader):
        # forward pass and compute loss
        embeddings = encoder(X)
        X_hat = decoder(embeddings)
        recon_loss = F.mse_loss(X_hat, X)
        cluster_loss, cluster_assignments = criterion(embeddings, centroids)
        loss = recon_loss + cluster_loss

        # run update step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # store centroid sums and counts for later centering; detached
        # because these statistics must not keep the batch's graph alive
        update_clusters(centroid_sums, centroid_counts,
                        cluster_assignments, embeddings.detach())

        if verbose and i % print_every == 0:
            # visualise the reconstruction of the current batch
            plot_batch(X_hat.detach())
            # .item() reads a 0-dim tensor (indexing it with [0] raises)
            losses = (loss.item(), recon_loss.item(), cluster_loss.item())
            print('Trn Loss: %.3f [Recon Loss %.3f, Cluster Loss %.3f]' % losses)

    # Update centroids from this epoch's assignments. clamp(min=1) guards empty
    # clusters; adding 1 to every count would shrink all centroids toward 0.
    centroid_means = centroid_sums / centroid_counts.clamp(min=1)[:, None]
    return centroid_means, centroid_counts