In [1]:
## KAGGLE ONLY
from shutil import copyfile
copyfile(src="../input/scriptandpickle/generate_dataloaders.py", dst="../working/generate_dataloaders.py")
copyfile(src="../input/scriptssss/model.py", dst="../working/model.py")
copyfile(src="../input/newevaluation/evaluation.py", dst="../working/evaluation.py")

copyfile(src="../input/newfiles/train_dataloader_lstm.p", dst="../working/train_dataloader_lstm.p")
copyfile(src="../input/newfiles/val_dataloader_lstm.p", dst="../working/val_dataloader_lstm.p")
copyfile(src="../input/newfiles/dictionary_lstm.p", dst="../working/dictionary.p")
copyfile(src="../input/newfiles/train_unlabeled_dataloader_lstm.p", dst="../working/train_unlabelled_dataloader_lstm.p")
copyfile(src="../input/newfiles/train_labeled_dataloader_lstm.p", dst="../working/train_labelled_dataloader_lstm.p")

'../working/train_labelled_dataloader_lstm.p'

In [2]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
import torch.nn.functional as F

#from datasets import get_mnist_dataset, get_data_loader
#from utils import *
#from models import *

import pickle as pkl
import os
import datetime as dt
import pandas as pd
import random

from generate_dataloaders import *

from tqdm import tqdm_notebook as tqdm

import evaluation
import importlib
importlib.reload(evaluation)

<module 'evaluation' from '/kaggle/working/evaluation.py'>

## Get Dataloaders

In [3]:
seed = 1029
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
np.random.seed(seed)  # Numpy module.
random.seed(seed)  # Python random module.
torch.manual_seed(seed)
torch.backends.cudnn.enabled = False 
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)

def _init_fn(worker_id):
    np.random.seed(int(seed))

In [4]:
path = os.getcwd()
data_dir = path + '/'
#data_dir = path +'/data/' #Uncomment for local system

#### *Verify filenames are consistent*

In [5]:
train_loader_labelled = pkl.load(open(data_dir + 'train_labelled_dataloader_lstm.p','rb'))
train_loader_unlabelled = pkl.load(open(data_dir + 'train_unlabelled_dataloader_lstm.p','rb'))
val_loader = pkl.load(open(data_dir + 'val_dataloader_lstm.p','rb'))

In [6]:
review_dict = pkl.load(open(data_dir + 'dictionary.p','rb'))

In [7]:
#%conda install pytorch torchvision -c pytorch
## if torch.__version__ is not 1.3.1, run this cell then restart kernel

In [8]:
print(torch.__version__)

1.3.0


## PRE TRAINED WORD EMBEDDINGS 

In [9]:
def get_coefs(word, *arr):
    return word, np.asarray(arr, dtype='float16')

In [10]:
def load_embeddings(path):
    with open(path) as f:
        return dict(get_coefs(*line.strip().split(' ')) for line in tqdm(f))

In [11]:
def build_matrix(review_dict, embedding_index ,dim = 200):
#     embedding_index = load_embeddings(path)
    embedding_matrix = np.zeros((len(review_dict.tokens), dim))
    unknown_words = []
    
    for word, i in review_dict.ids.items():
        try:
            embedding_matrix[i] = embedding_index[word]
        except KeyError:
            unknown_words.append(word)
    return embedding_matrix, unknown_words

In [12]:
## LOCAL - 2nd line // KAGGLE -- 1st line
glove_twitter = '../input/glove-global-vectors-for-word-representation/glove.twitter.27B.200d.txt' #Change loc for local system
#glove_twitter = data_dir + 'glove.twitter.27B.200d.txt'

In [13]:
embedding_index = load_embeddings(glove_twitter)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  This is separate from the ipykernel package so we can avoid doing imports until


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




In [14]:
glove_embedding_index,unknown_words = build_matrix(review_dict, embedding_index)
del embedding_index

In [15]:
len(review_dict.tokens)

16256

In [16]:
len(unknown_words)

4428

In [17]:
# for word in unknown_words:
#     print(word)

In [18]:
review_dict.get_id('great')

34

## Neural Network LSTM Class

NOTE: Data loader is defined as:
- tuple: (tokens, flagged_index, problematic)

In [19]:
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False
        
def unfreeze_model(model):
    for param in model.parameters():
        param.requires_grad = True

In [20]:
class LSTM_model(nn.Module):
    """
    LSTM classification model using pretrained glove embeddings
    """
    # NOTE: we can't use linear layer until we take weighted average, otherwise it will
    # remember certain positions incorrectly (ie, 4th word has bigger weights vs 7th word)
    def __init__(self, embedding_matrix, num_hidden_layers = 3, hidden_size = 100, num_classes = 2):
        super(LSTM_model, self).__init__()
        vocab_size = embedding_matrix.shape[0]
        embed_size = embedding_matrix.shape[1]
        self.embedding_matrix = embedding_matrix
        
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.num_hidden_layers = num_hidden_layers
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        
        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=0)    
        self.embed.weight = nn.Parameter(torch.tensor(embedding_matrix, dtype=torch.float32))
        self.embed.weight.requires_grad = False
        
        self.lstm = nn.LSTM(self.embed_size,self.hidden_size,self.num_hidden_layers, batch_first=True,bidirectional= True,bias=True)
        
        self.projection = nn.Linear(2*self.hidden_size, self.num_classes, bias=True)

    
    def forward(self, tokens, flagged_index):
        batch_size, num_tokens = tokens.shape
        embedding = self.embed(tokens)
#         print(embedding.shape) # below assumes "batch_size x num_tokens x Emb_dim" (VERIFY)
        
        lstm_output = self.lstm(embedding)
        # lstm_output is a tuple containing lstm output and (hidden_state, lstm_cell). 
        # lstm_output[0] would be of shape "batch_size x num_tokens x hidden_size" (VERIFY)
        
        logits = self.projection(lstm_output[0])
        # logits would be of shape "batch_size x num_tokens x num_classes (2)" (VERIFY)
        
        batch_size, _, __ = logits.shape
        
        #selecting the logit at the flagged index
        relevant_logits = logits[list(range(batch_size)),flagged_index]
        # relevant_logits would be of shape "batch_size x num_classes (2)" (VERIFY)
        
        return relevant_logits

## First performing fully supervised learning using the labelled set to train new vector representations

In [21]:
num_gpus = torch.cuda.device_count()
if num_gpus > 0:
    current_device = 'cuda'
else:
    current_device = 'cpu'

model = LSTM_model(glove_embedding_index, num_hidden_layers = 3, hidden_size = 100, num_classes = 2).to(current_device)

In [22]:
criterion = nn.CrossEntropyLoss(reduction='sum')
#optimizer = torch.optim.Adam(model.parameters(), 0.01, amsgrad=True)

## Supervised model training

In [33]:
def train_supervised_model(model, criterion, train_loader_labelled, valid_loader, num_frozen_epochs=8, num_unfrozen_epochs=0, path_to_save=None, print_every = 1000):

    train_losses=[]
    val_losses=[]
    num_gpus = torch.cuda.device_count()
    if num_gpus > 0:
        current_device = 'cuda'
    else:
        current_device = 'cpu'
    
    num_first_epochs = num_frozen_epochs
    num_second_epochs = num_unfrozen_epochs
    
    empty_centroids = torch.tensor([])
    # freeze part    
    optimizer = torch.optim.Adam(model.parameters(), 0.01, amsgrad=True)
    
    for epoch in range(num_first_epochs):
        print('{} | Epoch {}'.format(dt.datetime.now(), epoch))
        model.train()
        total_epoch_loss = 0
        
        for i,(tokens_labelled, labels, flagged_indices_labelled) in tqdm(enumerate(train_loader_labelled)):
            
            tokens_labelled = tokens_labelled.to(current_device)
            flagged_indices_labelled = flagged_indices_labelled.to(current_device)
            labels = labels.to(current_device)

            # forward pass and compute loss
            logits = model(tokens_labelled,flagged_indices_labelled)
            
            loss = criterion(logits, labels)
        
            # run update step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            #Add loss to the epoch loss
            total_epoch_loss += loss.detach()

            if i % print_every == 0:
                losses = loss/len(tokens_labelled)
                print('Average training loss at batch ',i,': %.3f' % losses)
            
        total_epoch_loss /= len(train_loader_labelled.dataset)
        total_epoch_loss = total_epoch_loss.detach()
        train_losses.append(total_epoch_loss)
        print('Average training loss after epoch ',epoch,': %.3f' % total_epoch_loss)
        
        # calculate validation loss after every epoch
        total_validation_loss = 0
        for i, (tokens, labels, flagged_indices) in enumerate(valid_loader):
            model.eval()
            tokens = tokens.to(current_device)
            labels = labels.to(current_device)
            flagged_indices = flagged_indices.to(current_device)
            
            # forward pass and compute loss
            logits = model(tokens,flagged_indices)
            
            loss = criterion(logits, labels)
            
            #Add loss to the validation loss
            total_validation_loss += loss

        total_validation_loss /= len(valid_loader.dataset)
        val_losses.append(total_validation_loss)
        print('Average validation loss after epoch ',epoch,': %.3f' % total_validation_loss)
        print('Train result:')
        TP_cluster, FP_cluster=evaluation.main(model, empty_centroids, train_loader_labelled, criterion, data_dir, current_device)
        print()
        print('Validation result:')
        TP_cluster, FP_cluster=evaluation.main(model, empty_centroids, valid_loader, criterion, data_dir, current_device)
        
        if path_to_save == None:
            pass
        else:
            opts = {"embedding_matrix":model.embedding_matrix,\
                    "num_hidden_layers":model.num_hidden_layers,\
                    "hidden_size":model.hidden_size,\
                    "num_classes":model.num_classes}
            torch.save(model.state_dict(), path_to_save+ str(epoch)+'_model_dict_labelled.pt')
            torch.save(train_losses, path_to_save+ str(epoch)+'_train_losses_labelled')
            torch.save(val_losses, path_to_save+ str(epoch)+'_val_losses_labelled')
            torch.save(opts, path_to_save+str(epoch)+'_opts_labelled')

    # unfreeze part
    unfreeze_model(model)
    
    params_to_update = []
    for name,param in model.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
    

    optimizer = torch.optim.Adam(model.parameters(), 0.01, amsgrad=True)
    
    for epoch in range(num_second_epochs):
        print('{} | Epoch {}'.format(dt.datetime.now(), epoch))
        model.train()
        total_epoch_loss = 0

        for i,(tokens_labelled, labels, flagged_indices_labelled) in tqdm(enumerate(train_loader_labelled)):
            
            tokens_labelled = tokens_labelled.to(current_device)
            flagged_indices_labelled = flagged_indices_labelled.to(current_device)
            labels = labels.to(current_device)

            # forward pass and compute loss
            logits = model(tokens_labelled,flagged_indices_labelled)
            
            loss = criterion(logits, labels)
        
            # run update step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            #Add loss to the epoch loss
            total_epoch_loss += loss.detach()

            if i % print_every == 0:
                losses = loss/len(tokens_labelled)
                print('Average training loss at batch ',i,': %.3f' % losses)
            
        total_epoch_loss /= len(train_loader_labelled.dataset)
        total_epoch_loss = total_epoch_loss.detach()
        train_losses.append(total_epoch_loss)
        print('Average training loss after epoch ',epoch,': %.3f' % total_epoch_loss)
        
        # calculate validation loss after every epoch
        total_validation_loss = 0
        for i, (tokens, labels, flagged_indices) in enumerate(valid_loader):
            model.eval()
            tokens = tokens.to(current_device)
            labels = labels.to(current_device)
            flagged_indices = flagged_indices.to(current_device)
            
            # forward pass and compute loss
            logits = model(tokens,flagged_indices)
            
            loss = criterion(logits, labels)
            
            #Add loss to the validation loss
            total_validation_loss += loss

        total_validation_loss /= len(valid_loader.dataset)
        val_losses.append(total_validation_loss)
        print('Average validation loss after epoch ',epoch,': %.3f' % total_validation_loss)
        print('Train result:')
        TP_cluster, FP_cluster=evaluation.main(model, empty_centroids, train_loader_labelled, criterion, data_dir, current_device)
        print()
        print('Validation result:')
        TP_cluster, FP_cluster=evaluation.main(model, empty_centroids, valid_loader, criterion, data_dir, current_device)
        
        
        if path_to_save == None:
            pass
        else:
            opts = {"embedding_matrix":model.embedding_matrix,\
                    "num_hidden_layers":model.num_hidden_layers,\
                    "hidden_size":model.hidden_size,\
                    "num_classes":model.num_classes}
            torch.save(model.state_dict(), path_to_save+ str(epoch)+'_model_dict_labelled.pt')
            torch.save(train_losses, path_to_save+ str(epoch)+'_train_losses_labelled')
            torch.save(val_losses, path_to_save+ str(epoch)+'_val_losses_labelled')
            torch.save(opts, path_to_save+str(epoch)+'_opts_labelled')

    return model, train_losses, val_losses

In [34]:
path = os.getcwd()
model_folder = 'lstm_unfrozen_model/'
model_dir = path + '/' #+ model_folder

In [35]:
len(train_loader_labelled)

226

In [36]:
model, train_losses, val_losses = train_supervised_model(model, criterion, train_loader_labelled, val_loader, num_frozen_epochs=3, num_unfrozen_epochs=10, path_to_save=model_dir)


2019-12-07 21:24:20.166308 | Epoch 0


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Average training loss at batch  0 : 0.016

Average training loss after epoch  0 : 0.065
Average validation loss after epoch  0 : 0.462
Train result:
Total examples in val loader: 7226
Assigned to cluster 1: 3778
TP_rate: 0.9499735309687666
FP_rate: 0.05002646903123346
FN_rate: 0.0069605568445475635
TN_rate: 0.9930394431554525


Accuracy: 0.9715064870621095
Precision: 0.9499735309687666
Recall: 0.9927261898879022
F1 score: 0.9708794352192728
Validation result:
Total examples in val loader: 517
Assigned to cluster 1: 478
TP_rate: 0.9121338912133892
FP_rate: 0.08786610878661087
FN_rate: 0.41025641025641024
TN_rate: 0.5897435897435898


Accuracy: 0.7509387404784895
Precision: 0.9121338912133892
Recall: 0.6897614797987993
F1 score: 0.7855130041114243
2019-12-07 21:24:34.786821 | Epoch 1


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Average training loss at batch  0 : 0.098

Average training loss after epoch  1 : 0.047
Average validation loss after epoch  1 : 0.567
Train result:
Total examples in val loader: 7226
Assigned to cluster 1: 3662
TP_rate: 0.9814309120699072
FP_rate: 0.018569087930092845
FN_rate: 0.005331088664421998
TN_rate: 0.994668911335578


Accuracy: 0.9880499117027426
Precision: 0.9814309120699072
Recall: 0.9945973916096742
F1 score: 0.9879702870370578
Validation result:
Total examples in val loader: 517
Assigned to cluster 1: 455
TP_rate: 0.9384615384615385
FP_rate: 0.06153846153846154
FN_rate: 0.4032258064516129
TN_rate: 0.5967741935483871


Accuracy: 0.7676178660049628
Precision: 0.9384615384615385
Recall: 0.6994636582208249
F1 score: 0.8015259086574124
2019-12-07 21:24:49.649500 | Epoch 2


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Average training loss at batch  0 : 0.007

Average training loss after epoch  2 : 0.036
Average validation loss after epoch  2 : 0.535
Train result:
Total examples in val loader: 7226
Assigned to cluster 1: 3563
TP_rate: 0.9935447656469267
FP_rate: 0.006455234353073253
FN_rate: 0.01992901992901993
TN_rate: 0.98007098007098


Accuracy: 0.9868078728589533
Precision: 0.9935447656469267
Recall: 0.9803359295399097
F1 score: 0.986896152077517
Validation result:
Total examples in val loader: 517
Assigned to cluster 1: 437
TP_rate: 0.9473684210526315
FP_rate: 0.05263157894736842
FN_rate: 0.475
TN_rate: 0.525


Accuracy: 0.7361842105263158
Precision: 0.9473684210526315
Recall: 0.666049953746531
F1 score: 0.7821835958718087
2019-12-07 21:25:04.536092 | Epoch 0


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Average training loss at batch  0 : 0.097

Average training loss after epoch  0 : 0.043
Average validation loss after epoch  0 : 0.656
Train result:
Total examples in val loader: 7226
Assigned to cluster 1: 3597
TP_rate: 0.9958298582151793
FP_rate: 0.004170141784820684
FN_rate: 0.008542298153761367
TN_rate: 0.9914577018462386


Accuracy: 0.993643780030709
Precision: 0.9958298582151793
Recall: 0.991494887527902
F1 score: 0.9936576449147989
Validation result:
Total examples in val loader: 517
Assigned to cluster 1: 456
TP_rate: 0.9298245614035088
FP_rate: 0.07017543859649122
FN_rate: 0.45901639344262296
TN_rate: 0.5409836065573771


Accuracy: 0.735404083980443
Precision: 0.9298245614035088
Recall: 0.6694967902257196
F1 score: 0.7784733927281483
2019-12-07 21:25:19.371246 | Epoch 1


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Average training loss at batch  0 : 0.005

Average training loss after epoch  1 : 0.049
Average validation loss after epoch  1 : 0.482
Train result:
Total examples in val loader: 7226
Assigned to cluster 1: 3600
TP_rate: 0.9930555555555556
FP_rate: 0.006944444444444444
FN_rate: 0.010479867622724766
TN_rate: 0.9895201323772752


Accuracy: 0.9912878439664154
Precision: 0.9930555555555556
Recall: 0.989557052615508
F1 score: 0.9913032173698589
Validation result:
Total examples in val loader: 517
Assigned to cluster 1: 453
TP_rate: 0.9359823399558499
FP_rate: 0.0640176600441501
FN_rate: 0.4375
TN_rate: 0.5625


Accuracy: 0.7492411699779249
Precision: 0.9359823399558499
Recall: 0.681466599698644
F1 score: 0.7886996454106844
2019-12-07 21:25:34.151085 | Epoch 2


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Average training loss at batch  0 : 0.098

Average training loss after epoch  2 : 0.036
Average validation loss after epoch  2 : 0.512
Train result:
Total examples in val loader: 7226
Assigned to cluster 1: 3608
TP_rate: 0.9925166297117517
FP_rate: 0.007483370288248337
FN_rate: 0.008844665561083471
TN_rate: 0.9911553344389166


Accuracy: 0.9918359820753342
Precision: 0.9925166297117517
Recall: 0.9911673582723469
F1 score: 0.9918415351151748
Validation result:
Total examples in val loader: 517
Assigned to cluster 1: 456
TP_rate: 0.9385964912280702
FP_rate: 0.06140350877192982
FN_rate: 0.39344262295081966
TN_rate: 0.6065573770491803


Accuracy: 0.7725769341386253
Precision: 0.9385964912280702
Recall: 0.704631328943107
F1 score: 0.8049577603749152
2019-12-07 21:25:48.958454 | Epoch 3


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Average training loss at batch  0 : 0.085

Average training loss after epoch  3 : 0.026
Average validation loss after epoch  3 : 0.727
Train result:
Total examples in val loader: 7226
Assigned to cluster 1: 3620
TP_rate: 0.9903314917127072
FP_rate: 0.009668508287292817
FN_rate: 0.00776483638380477
TN_rate: 0.9922351636161952


Accuracy: 0.9912833276644513
Precision: 0.9903314917127072
Recall: 0.992220353722157
F1 score: 0.991275022917586
Validation result:
Total examples in val loader: 517
Assigned to cluster 1: 464
TP_rate: 0.9267241379310345
FP_rate: 0.07327586206896551
FN_rate: 0.41509433962264153
TN_rate: 0.5849056603773585


Accuracy: 0.7558148991541964
Precision: 0.9267241379310345
Recall: 0.6906479180556396
F1 score: 0.7914568501475951
2019-12-07 21:26:03.695562 | Epoch 4


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Average training loss at batch  0 : 0.002

Average training loss after epoch  4 : 0.024
Average validation loss after epoch  4 : 0.731
Train result:
Total examples in val loader: 7226
Assigned to cluster 1: 3621
TP_rate: 0.9928196630764982
FP_rate: 0.007180336923501795
FN_rate: 0.0049930651872399446
TN_rate: 0.99500693481276


Accuracy: 0.9939132989446291
Precision: 0.9928196630764982
Recall: 0.9949959896824244
F1 score: 0.9939066350221317
Validation result:
Total examples in val loader: 517
Assigned to cluster 1: 465
TP_rate: 0.9204301075268817
FP_rate: 0.07956989247311828
FN_rate: 0.46153846153846156
TN_rate: 0.5384615384615384


Accuracy: 0.72944582299421
Precision: 0.9204301075268817
Recall: 0.6660282499401483
F1 score: 0.772831446628238


### Clustering Stuff

In [None]:
class KMeansCriterion(nn.Module):
    
    def __init__(self, lmbda):
        super().__init__()
        self.lmbda = lmbda
    
    def forward(self, embeddings, centroids, labelled = False,  cluster_assignments = None):
        if labelled:
            num_reviews = len(cluster_assignments)
            distances = torch.sum((embeddings[:, None, :] - centroids)**2, 2)
            cluster_distances = distances[list(range(num_reviews)),cluster_assignments]
            loss = self.lmbda * cluster_distances.sum()
        else:
            distances = torch.sum((embeddings[:, None, :] - centroids)**2, 2)
            cluster_distances, cluster_assignments = distances.min(1)
            loss = self.lmbda * cluster_distances.sum()
        return loss, cluster_assignments

In [None]:
def centroid_init(k, d, dataloader, model, current_device):
    ## Here we ideally don't want to do randomized/zero initialization
    centroid_sums = torch.zeros(k, d).to(current_device)
    centroid_counts = torch.zeros(k).to(current_device)
    for (tokens, labels, flagged_indices) in dataloader:
        # cluster_assignments = torch.LongTensor(tokens.size(0)).random_(k)
        cluster_assignments = labels.to(current_device)
        
        model.eval()
        sentence_embed = model(tokens.to(current_device),flagged_indices.to(current_device))
    
        update_clusters(centroid_sums.detach(), centroid_counts.detach(),
                        cluster_assignments.detach(), sentence_embed.to(current_device).detach())
    
    centroid_means = centroid_sums / centroid_counts[:, None].to(current_device)
    return centroid_means.clone()

def update_clusters(centroid_sums, centroid_counts,
                    cluster_assignments, embeddings):
    k = centroid_sums.size(0)

    centroid_sums.index_add_(0, cluster_assignments, embeddings)
    bin_counts = torch.bincount(cluster_assignments,minlength=k).type(torch.FloatTensor).to(current_device)
    centroid_counts.add_(bin_counts)

## Dataloader stuff

In [None]:
def loadLabelledBatch(train_loader_labelled_iter, train_loader_labelled):
    try:
        tokens, labels, flagged_indices = next(train_loader_labelled_iter)
    except StopIteration:
        train_loader_labelled_iter = iter(train_loader_labelled)
        tokens, labels, flagged_indices = next(train_loader_labelled_iter)

    return tokens, labels, flagged_indices, train_loader_labelled_iter


def loadUnlabelledBatch(train_loader_unlabelled_iter, train_loader_unlabelled):
    try:
        tokens, labels, flagged_indices = next(train_loader_unlabelled_iter)
    except StopIteration:
        train_loader_unlabelled_iter = iter(train_loader_unlabelled)
        tokens, labels, flagged_indices = next(train_loader_unlabelled_iter)

    return tokens, labels, flagged_indices, train_loader_unlabelled_iter

### Training Function

In [None]:
def train_clusters(model, centroids, criterion, train_loader_labelled, train_loader_unlabelled, valid_loader, num_epochs=15, num_batches = 1000, path_to_save=None, print_every = 1000):

    train_loader_labelled_iter = iter(train_loader_labelled)
    train_loader_unlabelled_iter = iter(train_loader_unlabelled)

    train_losses=[]
    val_losses=[]
    num_gpus = torch.cuda.device_count()
    if num_gpus > 0:
        current_device = 'cuda'
    else:
        current_device = 'cpu'
    
    optimizer = torch.optim.Adam(model.parameters(), 0.01, amsgrad=True)
    
    for epoch in range(num_epochs):
        print('{} | Epoch {}'.format(dt.datetime.now(), epoch))
        model.eval() # we're only clustering, not training model
        k, d = centroids.size()
        centroid_sums = torch.zeros_like(centroids).to(current_device)
        centroid_counts = torch.zeros(k).to(current_device)
        total_epoch_loss = 0
        
        for i in tqdm(range(int(num_batches))):
            tokens_labelled, labels, flagged_indices_labelled, train_loader_labelled_iter = loadLabelledBatch(train_loader_labelled_iter, train_loader_labelled)
            tokens_unlabelled, _, flagged_indices_unlabelled, train_loader_unlabelled_iter = loadUnlabelledBatch(train_loader_unlabelled_iter, train_loader_unlabelled)

            tokens_labelled = tokens_labelled.to(current_device)
            labels = labels.to(current_device)
            flagged_indices_labelled = flagged_indices_labelled.to(current_device)
            
            tokens_unlabelled = tokens_unlabelled.to(current_device)
            flagged_indices_unlabelled = flagged_indices_unlabelled.to(current_device)

            # forward pass and compute loss
            sentence_embed_labelled = model(tokens_labelled,flagged_indices_labelled)
            sentence_embed_unlabelled = model(tokens_unlabelled,flagged_indices_unlabelled)
            
            cluster_loss_unlabelled, cluster_assignments_unlabelled = criterion(sentence_embed_unlabelled, centroids.detach())
            cluster_loss_labelled, cluster_assignments_labelled = criterion(sentence_embed_labelled, centroids.detach(), labelled = True, cluster_assignments = labels)
    
            total_batch_loss = cluster_loss_labelled.data + cluster_loss_unlabelled.data
            
#             #Add loss to the epoch loss
            total_epoch_loss += total_batch_loss.data

#             # store centroid sums and counts in memory for later centering
            update_clusters(centroid_sums.detach(), centroid_counts.detach(),
                            cluster_assignments_labelled.detach(), sentence_embed_labelled.detach())
    
            update_clusters(centroid_sums.detach(), centroid_counts.detach(),
                            cluster_assignments_unlabelled.detach(), sentence_embed_unlabelled.detach())

            if i % print_every == 0:
                losses = total_batch_loss/(len(tokens_labelled)+ len(tokens_unlabelled))
                print('Average training loss at batch ',i,': %.3f' % losses)
            
        total_epoch_loss /= (len(train_loader_labelled.dataset)+len(train_loader_unlabelled.dataset))
        train_losses.append(total_epoch_loss)
        print('Average training loss after epoch ',epoch,': %.3f' % total_epoch_loss)
        
        # update centroids based on assignments from autoencoders
        centroids = centroid_sums / (centroid_counts[:, None] + 1).to(current_device)
        
        # calculate validation loss after every epoch
        total_validation_loss = 0
        for i, (tokens, labels, flagged_indices) in enumerate(valid_loader):
            model.eval()
            tokens = tokens.to(current_device)
            labels = labels.to(current_device)
            flagged_indices = flagged_indices.to(current_device)
            
            # forward pass and compute loss
            sentence_embed = model(tokens,flagged_indices)
            cluster_loss, cluster_assignments = criterion(sentence_embed, centroids)
            
            #Add loss to the validation loss
            total_validation_loss += cluster_loss.data

        total_validation_loss /= len(valid_loader.dataset)
        val_losses.append(total_validation_loss)
        print('Average validation loss after epoch ',epoch,': %.3f' % total_validation_loss)
        
        if path_to_save == None:
            pass
        else:
            opts = {"embedding_matrix":model.embedding_matrix,\
                    "num_hidden_layers":model.num_hidden_layers,\
                    "hidden_size":model.hidden_size,\
                    "num_classes":model.num_classes}
            torch.save(model.state_dict(), path_to_save+'model_dict_unlabelled.pt')
            torch.save(centroids, path_to_save+'centroids_unlabelled')
            torch.save(train_losses, path_to_save+'train_losses_unlabelled')
            torch.save(val_losses, path_to_save+'val_losses_unlabelled')
            torch.save(opts, path_to_save+'opts_unlabelled')
        
    return model, centroids, train_losses, val_losses

In [None]:
unsupervised_model = LSTM_model(glove_embedding_index, num_hidden_layers = 3, hidden_size = 100, num_classes = 2).to(current_device)

unsupervised_model.projection = nn.Identity()

In [None]:
centroids = centroid_init(2, 2*unsupervised_model.hidden_size, train_loader_labelled, unsupervised_model, current_device)
criterion = KMeansCriterion(1).to(current_device)
#optimizer = torch.optim.Adam(model.parameters(), 0.01, amsgrad=True)

In [None]:
centroids.shape

In [None]:
path = os.getcwd()
model_folder = 'lstm_unfrozen_model/'
model_dir = path + '/'#'/models/' + model_folder

In [None]:
num_batches = int(len(train_loader_unlabelled.dataset)/train_loader_unlabelled.batch_size)+1
num_batches

In [None]:
lstm_model, lstm_centroids, lstm_train_losses, lstm_val_losses = train_clusters(unsupervised_model, centroids, criterion, train_loader_labelled,train_loader_unlabelled, val_loader, num_epochs=3, num_batches=num_batches, path_to_save=model_dir)


In [None]:
torch.save(lstm_centroids, model_dir+'centroids_unlabelled')

In [None]:
# #Only needed for Kaggle

# from IPython.display import FileLink, FileLinks 
# FileLinks('.') #lists all downloadable files on server

# Evaluate Model

## Supervised Evaluation

In [None]:
num_gpus = torch.cuda.device_count()
if num_gpus > 0:
    current_device = 'cuda'
else:
    current_device = 'cpu'

In [None]:
## This cell will change for each model
model_folder = 'lstm_model/'

criterion = nn.CrossEntropyLoss(reduction='sum')
criterion = criterion.to(current_device)

path = os.getcwd()
model_dir = ''#path #+ '/models/' + model_folder

opts = torch.load(model_dir+'opts_labelled')
model = LSTM_model(opts['embedding_matrix']) #change here depending on model
model.load_state_dict(torch.load(model_dir+'model_dict_labelled.pt',map_location=lambda storage, loc: storage))
model = model.to(current_device)

In [None]:
empty_centroids = torch.tensor([])

TP_cluster, FP_cluster=evaluation.main(model, empty_centroids, train_loader_labelled, criterion, data_dir, current_device)

## Unsupervised Evaluation

In [None]:
## This cell will change for each model
model_folder = 'lstm_model/'

criterion = KMeansCriterion(1)
criterion = criterion.to(current_device)

path = os.getcwd()
model_dir = ''#path + '/models/' + model_folder

opts = torch.load(model_dir+'opts_unlabelled')
model = LSTM_model(opts['embedding_matrix']) #change here depending on model
model.projection = nn.Identity()
model.load_state_dict(torch.load(model_dir+'model_dict_unlabelled.pt',map_location=lambda storage, loc: storage))
model = model.to(current_device)
centroids = torch.load(model_dir+'centroids_unlabelled',map_location=lambda storage, loc: storage)

In [None]:
TP_cluster, FP_cluster=evaluation.main(model, centroids, val_loader, criterion, data_dir, current_device)

In [None]:
#TP_cluster

In [None]:
#FP_cluster

In [None]:
#FP_cluster[FP_cluster.original == 1]