In [112]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
import torch.nn.functional as F

#from datasets import get_mnist_dataset, get_data_loader
#from utils import *
#from models import *

import pickle as pkl
import os
import datetime as dt

from generate_dataloaders import *

from tqdm import tqdm_notebook as tqdm

## Get Dataloaders

In [89]:
def get_dataloaders(train_filename,val_filename):
    path = os.getcwd()
    data_dir = path + '/data/'
    train_dataloader = pkl.load(open(data_dir + train_filename,'rb'))
    val_dataloader = pkl.load(open(data_dir + val_filename,'rb'))
    return train_dataloader,val_dataloader

In [90]:
path = os.getcwd()
data_dir = path + '/data/'

In [91]:
train_loader,val_loader = get_dataloaders('train_dataloader.p','val_dataloader.p')

In [92]:
ground_truth_dataloader = pkl.load(open(data_dir + 'ground_truth_dataloader.p','rb'))

In [93]:
print(torch.__version__)

1.0.0


## Scratchwork (IGNORE)

In [None]:
for i,x in enumerate(train_loader):
    print(len(x[0]))
    break

32


In [None]:
minibatch = torch.tensor([
                            [[1,2,3,4,5],[3,3,3,3,3],[1,1,1,1,1],[2,1,2,1,2]],
                            [[0,1,0,1,0],[1,1,1,1,1],[2,0,0,0,0],[0,0,0,0,2]]
                         ], dtype=torch.float32)

flagged_indices = torch.tensor([1,2])

upweight_value = 10

print(minibatch.shape)
print(minibatch)

print(flagged_indices.shape)
print(flagged_indices)

torch.Size([2, 4, 5])
tensor([[[1., 2., 3., 4., 5.],
         [3., 3., 3., 3., 3.],
         [1., 1., 1., 1., 1.],
         [2., 1., 2., 1., 2.]],

        [[0., 1., 0., 1., 0.],
         [1., 1., 1., 1., 1.],
         [2., 0., 0., 0., 0.],
         [0., 0., 0., 0., 2.]]])
torch.Size([2])
tensor([1, 2])


In [None]:
batch_size, num_tokens, emb_dim = minibatch.shape
print(type(minibatch))
minibatch[range(batch_size),flagged_indices,:] *= upweight_value
print(batch_size, num_tokens, emb_dim)
minibatch

<class 'torch.Tensor'>
2 4 5


tensor([[[ 1.,  2.,  3.,  4.,  5.],
         [30., 30., 30., 30., 30.],
         [ 1.,  1.,  1.,  1.,  1.],
         [ 2.,  1.,  2.,  1.,  2.]],

        [[ 0.,  1.,  0.,  1.,  0.],
         [ 1.,  1.,  1.,  1.,  1.],
         [20.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  2.]]])

In [None]:
minibatch.sum(1) / (num_tokens + upweight_value - 1)

tensor([[2.6154, 2.6154, 2.7692, 2.7692, 2.9231],
        [1.6154, 0.1538, 0.0769, 0.1538, 0.2308]])

In [None]:
print(type(minibatch))

<class 'torch.Tensor'>


In [None]:
embed = torch.tensor(np.array([[2,4,5,6],[1,3,45,7],[3,4,5,6]]))

In [None]:
centers = torch.tensor(np.array(([2,3,4,5],[1,2,4,5])))

In [None]:
torch.sum((embed[:,None,:]-centers)**2,2)

tensor([[   3,    7],
        [1686, 1686],
        [   4,   10]])

In [None]:
cluster_distances, cluster_assignments = torch.sum((embed[:,None,:]-centers)**2, 2).min(1)
cluster_assignments

tensor([0, 1, 0])

In [None]:
for i, (tokens, labels, flagged_indices) in enumerate(train_loader):
    #print(tokens, labels, flagged_indices)
    break

In [56]:
cluster_assts = torch.LongTensor([1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1,
        0, 0, 0, 1, 0, 0, 0, 1])
k = 2
bin_counts = torch.bincount(cluster_assts,minlength=k)

In [72]:
bin_counts = bin_counts.type(torch.FloatTensor).to(current_device)
bin_counts

tensor([16., 16.])

## Neural Network Class

NOTE: Data loader is defined as:
- tuple: (tokens, flagged_index, problematic)

In [94]:
class neuralNetBow(nn.Module):
    """
    BagOfWords classification model
    """
    # NOTE: we can't use linear layer until we take weighted average, otherwise it will
    # remember certain positions incorrectly (ie, 4th word has bigger weights vs 7th word)
    def __init__(self, vocab_size, emb_dim, upweight=10):
        super(neuralNetBow, self).__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim, padding_idx=2)
        self.upweight = upweight
    
    def forward(self, tokens, flagged_index):
        batch_size, num_tokens = tokens.shape
        embedding = self.embed(tokens)
#         print(embedding.shape) # below assumes "batch_size x num_tokens x Emb_dim" (VERIFY)
        
        # upweight by flagged_index
#         print(type(embedding))
        embedding[torch.LongTensor(range(batch_size)),flagged_index.type(torch.LongTensor),:] *= self.upweight
        
        # average across embeddings
        embedding_ave = embedding.sum(1) / (num_tokens + self.upweight - 1)
        
        return embedding_ave

### Clustering Stuff (un-tailored)

In [95]:
class KMeansCriterion(nn.Module):
    
    def __init__(self, lmbda):
        super().__init__()
        self.lmbda = lmbda
    
    def forward(self, embeddings, centroids):
        distances = torch.sum((embeddings[:, None, :] - centroids)**2, 2)
        cluster_distances, cluster_assignments = distances.min(1)
        loss = self.lmbda * cluster_distances.sum()
        return loss, cluster_assignments

In [96]:
def centroid_init(k, d, dataloader, model, current_device):
    ## Here we ideally don't want to do randomized/zero initialization
    centroid_sums = torch.zeros(k, d).to(current_device)
    centroid_counts = torch.zeros(k).to(current_device)
    for (tokens, labels, flagged_indices) in dataloader:
        # cluster_assignments = torch.LongTensor(tokens.size(0)).random_(k)
        cluster_assignments = labels.to(current_device)
        
        model.eval()
        sentence_embed = model(tokens.to(current_device),flagged_indices.to(current_device))
    
        update_clusters(centroid_sums, centroid_counts,
                        cluster_assignments, sentence_embed.to(current_device))
    
    centroid_means = centroid_sums / centroid_counts[:, None].to(current_device)
    return centroid_means.clone()

def update_clusters(centroid_sums, centroid_counts,
                    cluster_assignments, embeddings):
    k = centroid_sums.size(0)

    centroid_sums.index_add_(0, cluster_assignments, embeddings)
    bin_counts = torch.bincount(cluster_assignments,minlength=k).type(torch.FloatTensor).to(current_device)
    centroid_counts.add_(bin_counts)
    
    #np_cluster_assignments = cluster_assignments.to('cpu')
    #np_counts = np.bincount(np_cluster_assignments.data.numpy(), minlength=k)
    #centroid_counts.add_(torch.FloatTensor(np_counts))

### Training Function (un-tailored, needs alterations)

In [113]:
def train_model(model, centroids, criterion, optimizer, train_loader, valid_loader, num_epochs=10, path_to_save=None, print_every = 100):

    train_losses=[]
    val_losses=[]
    num_gpus = torch.cuda.device_count()
    if num_gpus > 0:
        current_device = 'cuda'
    else:
        current_device = 'cpu'
    
    for epoch in range(num_epochs):
        print('{} | Epoch {}'.format(dt.datetime.now(), epoch))
        model.train()
        k, d = centroids.size()
        centroid_sums = torch.zeros_like(centroids).to(current_device)
        centroid_counts = torch.zeros(k).to(current_device)
        total_epoch_loss = 0

        # run one epoch of gradient descent on autoencoders wrt centroids
        for i, (tokens, labels, flagged_indices) in tqdm(enumerate(train_loader)):
            tokens = tokens.to(current_device)
            labels = labels.to(current_device)
            flagged_indices = flagged_indices.to(current_device)

            # forward pass and compute loss
            sentence_embed = model(tokens,flagged_indices)
            cluster_loss, cluster_assignments = criterion(sentence_embed, centroids.detach())

            # run update step
            optimizer.zero_grad()
            cluster_loss.backward()
            optimizer.step()
            
            #Add loss to the epoch loss
            total_epoch_loss += cluster_loss.data

            # store centroid sums and counts in memory for later centering
            update_clusters(centroid_sums, centroid_counts,
                            cluster_assignments, sentence_embed)

            if i % print_every == 0:
                losses = cluster_loss.data/len(tokens)
                print('Average training loss at batch ',i,': %.3f' % losses)
            
        total_epoch_loss /= len(train_loader.dataset)
        train_losses.append(total_epoch_loss)
        print('Average training loss after epoch ',epoch,': %.3f' % total_epoch_loss)
        
        # update centroids based on assignments from autoencoders
        centroids = centroid_sums / (centroid_counts[:, None] + 1).to(current_device)
        
        # calculate validation loss after every epoch
        total_validation_loss = 0
        for i, (tokens, labels, flagged_indices) in enumerate(valid_loader):
            model.eval()
            tokens = tokens.to(current_device)
            labels = labels.to(current_device)
            flagged_indices = flagged_indices.to(current_device)
            
            # forward pass and compute loss
            sentence_embed = model(tokens,flagged_indices)
            cluster_loss, cluster_assignments = criterion(sentence_embed, centroids)
            
            #Add loss to the validation loss
            total_validation_loss += cluster_loss.data

        total_validation_loss /= len(valid_loader.dataset)
        val_losses.append(total_validation_loss)
        print('Average validation loss after epoch ',epoch,': %.3f' % total_validation_loss)
        
        if path_to_save == None:
            pass
        else:
            torch.save(model.state_dict(), path_to_save+'_dict_epoch'+str(epoch)+'.pt')
            torch.save(centroids, path_to_save+'_centroids_epoch'+str(epoch))
            torch.save(train_losses, path_to_save+'_train_losses')
            torch.save(val_losses, path_to_save+'_val_losses')
        
    return model, centroids, train_losses, val_losses

In [114]:
opts = {
    'vocab_size': 20000,
    'emb_dim': 512
}

In [115]:
num_gpus = torch.cuda.device_count()
if num_gpus > 0:
    current_device = 'cuda'
else:
    current_device = 'cpu'

model = neuralNetBow(opts['vocab_size'], opts['emb_dim']).to(current_device)

In [116]:
# model = neuralNetBow(opts['vocab_size'], opts['emb_dim'])
centroids = centroid_init(2, opts['emb_dim'],ground_truth_dataloader, model, current_device)
criterion = KMeansCriterion(1).to(current_device)
optimizer = torch.optim.Adam(model.parameters(), 0.01, amsgrad=True)

In [117]:
centroids

tensor([[ 2.3282e-01, -4.9162e-01, -2.4064e-01,  ...,  2.6259e-01,
          1.1852e-01,  3.5728e-02],
        [ 2.2224e-01, -4.7257e-01, -1.2804e-01,  ...,  1.1401e-01,
          6.7910e-02, -2.6494e-04]], grad_fn=<CloneBackward>)

In [118]:
current_device

'cpu'

In [None]:
train_model(model, centroids, criterion, optimizer, train_loader, val_loader, num_epochs=5, path_to_save="baseline_model")




0it [00:00, ?it/s][A[A[A


1it [00:00,  7.21it/s][A[A[A

2019-11-13 18:12:16.519000 | Epoch 0
Average training loss at batch  0 : 50.999





3it [00:00,  8.11it/s][A[A[A


5it [00:00,  8.89it/s][A[A[A


7it [00:00,  9.59it/s][A[A[A


9it [00:00, 10.19it/s][A[A[A


11it [00:00, 10.63it/s][A[A[A


13it [00:01, 10.92it/s][A[A[A


15it [00:01, 11.09it/s][A[A[A


17it [00:01, 11.30it/s][A[A[A


19it [00:01, 11.44it/s][A[A[A


21it [00:01, 11.59it/s][A[A[A


23it [00:02, 11.68it/s][A[A[A


25it [00:02, 11.70it/s][A[A[A


27it [00:02, 11.66it/s][A[A[A


29it [00:02, 11.50it/s][A[A[A


31it [00:02, 11.47it/s][A[A[A


33it [00:02, 11.46it/s][A[A[A


35it [00:03, 11.49it/s][A[A[A


37it [00:03, 11.41it/s][A[A[A


39it [00:03, 11.38it/s][A[A[A


41it [00:03, 11.48it/s][A[A[A


43it [00:03, 11.15it/s][A[A[A


45it [00:03, 10.86it/s][A[A[A


47it [00:04, 10.85it/s][A[A[A


49it [00:04, 10.90it/s][A[A[A


51it [00:04, 11.10it/s][A[A[A


53it [00:04, 11.07it/s][A[A[A


55it [00:04, 11.11it/s][A[A[A


57it [00:05, 11.26it/s][A[A[A


59it [00:05, 11.37it/

Average training loss at batch  100 : 17.081





105it [00:09, 11.32it/s][A[A[A


107it [00:09, 11.09it/s][A[A[A


109it [00:09, 11.30it/s][A[A[A


111it [00:09, 11.46it/s][A[A[A


113it [00:10, 11.53it/s][A[A[A


115it [00:10, 11.43it/s][A[A[A


117it [00:10, 11.13it/s][A[A[A


119it [00:10, 10.94it/s][A[A[A


121it [00:10, 11.15it/s][A[A[A


123it [00:10, 11.35it/s][A[A[A


125it [00:11, 11.32it/s][A[A[A


127it [00:11, 11.00it/s][A[A[A


129it [00:11, 11.06it/s][A[A[A


131it [00:11, 11.26it/s][A[A[A


133it [00:11, 11.43it/s][A[A[A


135it [00:11, 11.14it/s][A[A[A


137it [00:12, 11.39it/s][A[A[A


139it [00:12, 11.03it/s][A[A[A


141it [00:12, 10.68it/s][A[A[A


143it [00:12, 10.37it/s][A[A[A


145it [00:12, 10.47it/s][A[A[A


147it [00:13, 10.89it/s][A[A[A


149it [00:13, 10.84it/s][A[A[A


151it [00:13, 10.74it/s][A[A[A


153it [00:13, 10.56it/s][A[A[A


155it [00:13, 10.93it/s][A[A[A


157it [00:14, 11.23it/s][A[A[A


159it [00:14, 11.35it/s]

Average training loss at batch  200 : 9.037





205it [00:18, 11.96it/s][A[A[A


207it [00:18, 11.94it/s][A[A[A


209it [00:18, 11.98it/s][A[A[A


211it [00:18, 12.04it/s][A[A[A


213it [00:18, 11.98it/s][A[A[A


215it [00:18, 12.02it/s][A[A[A


217it [00:19, 12.06it/s][A[A[A


219it [00:19, 12.01it/s][A[A[A


221it [00:19, 12.00it/s][A[A[A


223it [00:19, 12.01it/s][A[A[A


225it [00:19, 11.97it/s][A[A[A


227it [00:19, 11.99it/s][A[A[A


229it [00:20, 12.04it/s][A[A[A


231it [00:20, 12.02it/s][A[A[A


233it [00:20, 12.02it/s][A[A[A


235it [00:20, 12.02it/s][A[A[A


237it [00:20, 11.75it/s][A[A[A


239it [00:20, 10.95it/s][A[A[A


241it [00:21, 10.56it/s][A[A[A


243it [00:21, 10.33it/s][A[A[A


245it [00:21, 10.09it/s][A[A[A


247it [00:21, 10.07it/s][A[A[A


249it [00:21,  9.99it/s][A[A[A


251it [00:22, 10.45it/s][A[A[A


253it [00:22, 10.77it/s][A[A[A


255it [00:22, 10.80it/s][A[A[A


257it [00:22, 10.30it/s][A[A[A


259it [00:22, 10.15it/s]

Average training loss at batch  300 : 6.494





305it [00:27, 10.87it/s][A[A[A


307it [00:27, 10.46it/s][A[A[A


309it [00:27, 10.37it/s][A[A[A


311it [00:27, 10.17it/s][A[A[A


313it [00:28, 10.18it/s][A[A[A


315it [00:28, 10.60it/s][A[A[A


317it [00:28, 10.96it/s][A[A[A


319it [00:28, 11.01it/s][A[A[A


321it [00:28, 10.61it/s][A[A[A


323it [00:28, 11.00it/s][A[A[A


325it [00:29, 11.33it/s][A[A[A


327it [00:29, 11.44it/s][A[A[A


329it [00:29, 11.54it/s][A[A[A


331it [00:29, 11.64it/s][A[A[A


333it [00:29, 11.67it/s][A[A[A


335it [00:29, 11.72it/s][A[A[A


337it [00:30, 11.77it/s][A[A[A


339it [00:30, 11.70it/s][A[A[A


341it [00:30, 11.76it/s][A[A[A


343it [00:30, 11.65it/s][A[A[A


345it [00:30, 11.71it/s][A[A[A


347it [00:30, 11.77it/s][A[A[A


349it [00:31, 11.75it/s][A[A[A


351it [00:31, 11.64it/s][A[A[A


353it [00:31, 11.69it/s][A[A[A


355it [00:31, 11.62it/s][A[A[A


357it [00:31, 11.49it/s][A[A[A


359it [00:32, 11.35it/s]

Average training loss at batch  400 : 3.110





405it [00:36, 11.79it/s][A[A[A


407it [00:36, 11.84it/s][A[A[A


409it [00:36, 11.78it/s][A[A[A


411it [00:36, 11.74it/s][A[A[A


413it [00:36, 11.78it/s][A[A[A


415it [00:36, 11.85it/s][A[A[A


417it [00:37, 11.74it/s][A[A[A


419it [00:37, 11.48it/s][A[A[A


421it [00:37, 11.35it/s][A[A[A


423it [00:37, 11.28it/s][A[A[A


425it [00:37, 11.32it/s][A[A[A


427it [00:37, 11.35it/s][A[A[A


429it [00:38, 11.42it/s][A[A[A


431it [00:38, 11.57it/s][A[A[A


433it [00:38, 11.67it/s][A[A[A


435it [00:38, 11.78it/s][A[A[A


437it [00:38, 11.88it/s][A[A[A


439it [00:38, 11.80it/s][A[A[A


441it [00:39, 11.75it/s][A[A[A


443it [00:39, 11.72it/s][A[A[A


445it [00:39, 11.61it/s][A[A[A


447it [00:39, 11.75it/s][A[A[A


449it [00:39, 11.66it/s][A[A[A


451it [00:39, 11.75it/s][A[A[A


453it [00:40, 11.84it/s][A[A[A


455it [00:40, 11.78it/s][A[A[A


457it [00:40, 11.84it/s][A[A[A


459it [00:40, 11.77it/s]

Average training loss at batch  500 : 3.033





505it [00:44, 11.74it/s][A[A[A


507it [00:44, 11.71it/s][A[A[A


509it [00:44, 11.59it/s][A[A[A


511it [00:45, 11.63it/s][A[A[A


513it [00:45, 11.72it/s][A[A[A


515it [00:45, 11.68it/s][A[A[A


517it [00:45, 11.47it/s][A[A[A


519it [00:45, 11.44it/s][A[A[A


521it [00:45, 11.56it/s][A[A[A


523it [00:46, 11.51it/s][A[A[A


525it [00:46, 11.56it/s][A[A[A


527it [00:46, 11.62it/s][A[A[A


529it [00:46, 11.78it/s][A[A[A


531it [00:46, 11.86it/s][A[A[A


533it [00:46, 11.78it/s][A[A[A


535it [00:47, 11.67it/s][A[A[A


537it [00:47, 11.58it/s][A[A[A


539it [00:47, 11.59it/s][A[A[A


541it [00:47, 11.55it/s][A[A[A


543it [00:47, 11.52it/s][A[A[A


545it [00:48, 11.57it/s][A[A[A


547it [00:48, 11.58it/s][A[A[A


549it [00:48, 11.68it/s][A[A[A


551it [00:48, 11.74it/s][A[A[A


553it [00:48, 11.72it/s][A[A[A


555it [00:48, 11.70it/s][A[A[A


557it [00:49, 11.72it/s][A[A[A


559it [00:49, 11.77it/s]