# Using Embeddings for Document Classification

## Colab configurations

In [193]:
import os
colab = False # Change to True if using Colab
if colab:
    from google.colab import drive
    drive.mount('/content/drive')

    gdrivedir = '/content/drive/My Drive/Colab Notebooks/deepnlpa3/'
    os.chdir(gdrivedir)
    !pwd

### Download dataset (Skip if not required)

In [194]:
!pip install requests
import requests

def progress_bar(some_iter):
    try:
        from tqdm import tqdm
        return tqdm(some_iter)
    except ModuleNotFoundError:
        return some_iter

def download_file_from_google_drive(id, destination):
    print("Trying to fetch {}".format(destination))

    def get_confirm_token(response):
        for key, value in response.cookies.items():
            if key.startswith('download_warning'):
                return value

        return None

    def save_response_content(response, destination):
        CHUNK_SIZE = 32768

        with open(destination, "wb") as f:
            for chunk in progress_bar(response.iter_content(CHUNK_SIZE)):
                if chunk: # filter out keep-alive new chunks
                    f.write(chunk)

    URL = "https://docs.google.com/uc?export=download"

    session = requests.Session()

    response = session.get(URL, params = { 'id' : id }, stream = True)
    token = get_confirm_token(response)

    if token:
        params = { 'id' : id, 'confirm' : token }
        response = session.get(URL, params = params, stream = True)

    save_response_content(response, destination)

if not colab:
    gdrivedir = ''
    
datadir = 'data/ag_news/'
!mkdir 'data/'
!mkdir 'data/ag_news/'

filename = gdrivedir + datadir + 'news_with_splits.csv'
if not os.path.exists(filename):
    download_file_from_google_drive('1Z4fOgvrNhcn6pYlOxrEuxrPNxT-bLh7T', filename)

!pwd

mkdir: cannot create directory ‘data/’: File exists
mkdir: cannot create directory ‘data/ag_news/’: File exists
/home/karyl/Virtual/Python/deepnlpa3/src


## Requirements

In [195]:
!pip install sentencepiece
!pip install pandas
!pip install tqdm



## Imports

In [196]:
import os
from argparse import Namespace
from collections import Counter, defaultdict # defaultdict NEW
import json
import re
import string

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm #from tqdm import tqdm_notebook

### NEW ####
import sentencepiece as spm
import tempfile
import sys
import copy
import random

## Data Vectorization classes

### The Vocabulary

In [197]:
class Vocabulary(object):
    """Class to process text and extract vocabulary for mapping"""

    def __init__(self, token_to_idx=None):
        """
        Args:
            token_to_idx (dict): a pre-existing map of tokens to indices
        """

        if token_to_idx is None:
            token_to_idx = {}
        self._token_to_idx = token_to_idx

        self._idx_to_token = {idx: token 
                              for token, idx in self._token_to_idx.items()}
        
    def to_serializable(self):
        """ returns a dictionary that can be serialized """
        return {'token_to_idx': self._token_to_idx}

    @classmethod
    def from_serializable(cls, contents):
        """ instantiates the Vocabulary from a serialized dictionary """
        return cls(**contents)

    def add_token(self, token):
        """Update mapping dicts based on the token.

        Args:
            token (str): the item to add into the Vocabulary
        Returns:
            index (int): the integer corresponding to the token
        """
        if token in self._token_to_idx:
            index = self._token_to_idx[token]
        else:
            index = len(self._token_to_idx)
            self._token_to_idx[token] = index
            self._idx_to_token[index] = token
        return index
            
    def add_many(self, tokens):
        """Add a list of tokens into the Vocabulary
        
        Args:
            tokens (list): a list of string tokens
        Returns:
            indices (list): a list of indices corresponding to the tokens
        """
        return [self.add_token(token) for token in tokens]

    def lookup_token(self, token):
        """Retrieve the index associated with the token 
        
        Args:
            token (str): the token to look up 
        Returns:
            index (int): the index corresponding to the token
        """
        return self._token_to_idx[token]

    def lookup_index(self, index):
        """Return the token associated with the index
        
        Args: 
            index (int): the index to look up
        Returns:
            token (str): the token corresponding to the index
        Raises:
            KeyError: if the index is not in the Vocabulary
        """
        if index not in self._idx_to_token:
            raise KeyError("the index (%d) is not in the Vocabulary" % index)
        return self._idx_to_token[index]

    def __str__(self):
        return "<Vocabulary(size=%d)>" % len(self)

    def __len__(self):
        return len(self._token_to_idx)

In [198]:
class WordSequenceVocabulary(Vocabulary):
    def __init__(self, token_to_idx=None, unk_token="<UNK>",
                 mask_token="<MASK>", begin_seq_token="<BEGIN>",
                 end_seq_token="<END>"):

        super(WordSequenceVocabulary, self).__init__(token_to_idx)

        self._mask_token = mask_token
        self._unk_token = unk_token
        self._begin_seq_token = begin_seq_token
        self._end_seq_token = end_seq_token

        self.mask_index = self.add_token(self._mask_token)
        self.unk_index = self.add_token(self._unk_token)
        self.begin_seq_index = self.add_token(self._begin_seq_token)
        self.end_seq_index = self.add_token(self._end_seq_token)

    def to_serializable(self):
        contents = super(WordSequenceVocabulary, self).to_serializable()
        contents.update({'unk_token': self._unk_token,
                         'mask_token': self._mask_token,
                         'begin_seq_token': self._begin_seq_token,
                         'end_seq_token': self._end_seq_token})
        return contents

    def lookup_token(self, token):
        """Retrieve the index associated with the token 
          or the UNK index if token isn't present.
        
        Args:
            token (str): the token to look up 
        Returns:
            index (int): the index corresponding to the token
        Notes:
            `unk_index` needs to be >=0 (having been added into the Vocabulary) 
              for the UNK functionality 
        """
        if self.unk_index >= 0:
            return self._token_to_idx.get(token, self.unk_index)
        else:
            return self._token_to_idx[token]

In [199]:
class CharacterSequenceVocabulary(Vocabulary):
    def __init__(self, token_to_idx=None, pad_token="<PAD>"):

        super(CharacterSequenceVocabulary, self).__init__(token_to_idx)

        self._pad_token = pad_token

        self.pad_index = self.add_token(self._pad_token)

    def to_serializable(self):
        contents = super(CharacterSequenceVocabulary, self).to_serializable()
        contents.update({'pad_token': self._pad_token})
        return contents

    def lookup_token(self, token):
        """Retrieve the index associated with the token 
          or the UNK index if token isn't present.
        
        Args:
            token (str): the token to look up 
        Returns:
            index (int): the index corresponding to the token
        Notes:
            `unk_index` needs to be >=0 (having been added into the Vocabulary) 
              for the UNK functionality 
        """
        return self._token_to_idx[token]

In [200]:
class SubwordSequenceVocabulary(Vocabulary):
    def __init__(self, token_to_idx=None, 
                 pad_token="<PAD>", unk_token="<UNK>",
                 begin_seq_token="<BEGIN>",
                 end_seq_token="<END>"):
        
        super(SubwordSequenceVocabulary, self).__init__(token_to_idx)
        self.token_freq = {}
        self.bpe_codes = []
        
        self._pad_token = pad_token
        self._unk_token = unk_token
        self._begin_seq_token = begin_seq_token
        self._end_seq_token = end_seq_token

        self.pad_index = self.add_token(self._pad_token)
        self.unk_index = self.add_token(self._unk_token)
        self.begin_seq_index = self.add_token(self._begin_seq_token)
        self.end_seq_index = self.add_token(self._end_seq_token)
        
    def to_serializable(self):
        contents = super(SubwordSequenceVocabulary, self).to_serializable()
        contents.update({'token_freq': self.token_freq})
        contents.update({'bpe_codes': self.bpe_codes})
        contents.update({'pad_token': self._pad_token, 
                         'unk_token': self._unk_token,
                         'begin_seq_token': self._begin_seq_token,
                         'end_seq_token': self._end_seq_token})
        return contents
    
    def add_bpe_codes_list(self, bpe_code_list):
        self.bpe_codes.extend(bpe_code_list)
    
    def add_bpe_token(self, token, frequency):
        """Update mapping dicts based on the token.

        Args:
            token (str): the item to add into the Vocabulary
        Returns:
            index (int): the integer corresponding to the token
        """
        if token in self._token_to_idx:
            index = self._token_to_idx[token]
        else:
            index = len(self._token_to_idx)
            self._token_to_idx[token] = index
            self._idx_to_token[index] = token
            
        self.token_freq[token] = frequency
        return index

    def lookup_token(self, token):
        """Retrieve the index associated with the token 
          or the UNK index if token isn't present.
        
        Args:
            token (str): the token to look up 
        Returns:
            index (int): the index corresponding to the token
        Notes:
            `unk_index` needs to be >=0 (having been added into the Vocabulary) 
              for the UNK functionality 
        """
        return self._token_to_idx[token]

In [201]:
class SentenceSequenceVocabulary(Vocabulary):
    def __init__(self, token_to_idx=None, pad_token="<pad>", bos_token="<s>",
                 eos_token="</s>", unk_token="<unk>"):

        super(SentenceSequenceVocabulary, self).__init__(token_to_idx)
        self._pad_token = pad_token
        self._bos_token = bos_token
        self._eos_token = eos_token
        self._unk_token = unk_token

        self.pad_index = self.add_token(self._pad_token) # 0
        self.bos_index = self.add_token(self._bos_token) # 1
        self.eos_index = self.add_token(self._eos_token) # 2
        self.unk_index = self.add_token(self._unk_token) # 3

        self.sp_segmenter = None

    def to_serializable(self):
        contents = super(SentenceSequenceVocabulary, self).to_serializable()
        contents.update({'pad_token': self._pad_token , 'bos_token': self._bos_token,
                         'eos_token' : self._eos_token, 'unk_token': self._unk_token })
        return contents

    def load_vocab_file(self, vocab_file):
        with open(vocab_file, encoding='utf-8') as f:
            vo = [doc.strip().split("\t") for doc in f]

        for i, w in enumerate(vo):  # w[0]: token name, w[1]: token score
            self.add_token(w[0])    # add_token will skip duplicates

    def load_model_file(self, model_file):
        self.sp_segmenter = spm.SentencePieceProcessor()
        self.sp_segmenter.load(model_file)
        self.sp_segmenter.SetEncodeExtraOptions('bos:eos') # auto append bos eos tokens

    def lookup_token(self, token):
        """Retrieve the index associated with the token 
          or the UNK index if token isn't present.
        
        Args:
            token (str): the token to look up 
        Returns:
            index (int): the index corresponding to the token
        Notes:
            `unk_index` needs to be >=0 (having been added into the Vocabulary) 
              for the UNK functionality 
        """
        return self._token_to_idx[token]

### BPE Trainer and Segmenter

In [202]:
def get_vocabulary(fobj, is_dict=False):
    """Read text and return dictionary that encodes vocabulary
    """
    vocab = Counter()
    for i, line in enumerate(fobj):
        if is_dict:
            try:
                word, count = line.strip('\r\n ').split(' ')
            except:
                print('Failed reading vocabulary file at line {0}: {1}'.format(i, line))
                sys.exit(1)
            vocab[word] += int(count)
        else:
            for word in line.strip('\r\n ').split(' '):
                if word:
                    vocab[word] += 1
    return vocab

def update_pair_statistics(pair, changed, stats, indices):
    """Minimally update the indices and frequency of symbol pairs
    if we merge a pair of symbols, only pairs that overlap with occurrences
    of this pair are affected, and need to be updated.
    """
    stats[pair] = 0
    indices[pair] = defaultdict(int)
    first, second = pair
    new_pair = first+second
    for j, word, old_word, freq in changed:

        # find all instances of pair, and update frequency/indices around it
        i = 0
        while True:
            # find first symbol
            try:
                i = old_word.index(first, i)
            except ValueError:
                break
            # if first symbol is followed by second symbol, we've found an occurrence of pair (old_word[i:i+2])
            if i < len(old_word)-1 and old_word[i+1] == second:
                # assuming a symbol sequence "A B C", if "B C" is merged, reduce the frequency of "A B"
                if i:
                    prev = old_word[i-1:i+1]
                    stats[prev] -= freq
                    indices[prev][j] -= 1
                if i < len(old_word)-2:
                    # assuming a symbol sequence "A B C B", if "B C" is merged, reduce the frequency of "C B".
                    # however, skip this if the sequence is A B C B C, because the frequency of "C B" will be reduced by the previous code block
                    if old_word[i+2] != first or i >= len(old_word)-3 or old_word[i+3] != second:
                        nex = old_word[i+1:i+3]
                        stats[nex] -= freq
                        indices[nex][j] -= 1
                i += 2
            else:
                i += 1

        i = 0
        while True:
            try:
                # find new pair
                i = word.index(new_pair, i)
            except ValueError:
                break
            # assuming a symbol sequence "A BC D", if "B C" is merged, increase the frequency of "A BC"
            if i:
                prev = word[i-1:i+1]
                stats[prev] += freq
                indices[prev][j] += 1
            # assuming a symbol sequence "A BC B", if "B C" is merged, increase the frequency of "BC B"
            # however, if the sequence is A BC BC, skip this step because the count of "BC BC" will be incremented by the previous code block
            if i < len(word)-1 and word[i+1] != new_pair:
                nex = word[i:i+2]
                stats[nex] += freq
                indices[nex][j] += 1
            i += 1


def get_pair_statistics(vocab):
    """Count frequency of all symbol pairs, and create index"""

    # data structure of pair frequencies
    stats = defaultdict(int)

    #index from pairs to words
    indices = defaultdict(lambda: defaultdict(int))

    for i, (word, freq) in enumerate(vocab):
        prev_char = word[0]
        for char in word[1:]:
            stats[prev_char, char] += freq
            indices[prev_char, char][i] += 1
            prev_char = char

    return stats, indices


def replace_pair(pair, vocab, indices):
    """Replace all occurrences of a symbol pair ('A', 'B') with a new symbol 'AB'"""
    first, second = pair
    pair_str = ''.join(pair)
    pair_str = pair_str.replace('\\','\\\\')
    changes = []
    pattern = re.compile(r'(?<!\S)' + re.escape(first + ' ' + second) + r'(?!\S)')
    if sys.version_info < (3, 0):
        iterator = indices[pair].iteritems()
    else:
        iterator = indices[pair].items()
    for j, freq in iterator:
        if freq < 1:
            continue
        word, freq = vocab[j]
        new_word = ' '.join(word)
        new_word = pattern.sub(pair_str, new_word)
        new_word = tuple(new_word.split(' '))

        vocab[j] = (new_word, freq)
        changes.append((j, new_word, word, freq))

    return changes

def prune_stats(stats, big_stats, threshold):
    """Prune statistics dict for efficiency of max()
    The frequency of a symbol pair never increases, so pruning is generally safe
    (until we the most frequent pair is less frequent than a pair we previously pruned)
    big_stats keeps full statistics for when we need to access pruned items
    """
    for item,freq in list(stats.items()):
        if freq < threshold:
            del stats[item]
            if freq < 0:
                big_stats[item] += freq
            else:
                big_stats[item] = freq


def learn_bpe(data, num_symbols, min_frequency=2, verbose=False, is_dict=False, total_symbols=False):
    """Learn num_symbols BPE operations from vocabulary, and write to list of tuples.
    """
    vocab = get_vocabulary(data, is_dict) # count words in text

    vocab = dict([(tuple(x[:-1])+(x[-1]+'</w>',) ,y) for (x,y) in vocab.items()])
    
    sorted_vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True)

    stats, indices = get_pair_statistics(sorted_vocab)

    big_stats = copy.deepcopy(stats)

    if total_symbols:
        uniq_char_internal = set()
        uniq_char_final = set()
        for word in vocab:
            for char in word[:-1]:
                uniq_char_internal.add(char)
            uniq_char_final.add(word[-1])
        sys.stderr.write('Number of word-internal characters: {0}\n'.format(len(uniq_char_internal)))
        sys.stderr.write('Number of word-final characters: {0}\n'.format(len(uniq_char_final)))
        sys.stderr.write('Reducing number of merge operations by {0}\n'.format(len(uniq_char_internal) + len(uniq_char_final)))
        num_symbols -= len(uniq_char_internal) + len(uniq_char_final)
        sys.stderr.write('Number of symbols left: {0}\n'.format(num_symbols))

    bpe_codes = []
    # threshold is inspired by Zipfian assumption, but should only affect speed
    threshold = max(stats.values()) / 10
    for i in range(num_symbols):
        if stats:
            most_frequent = max(stats, key=lambda x: (stats[x], x))

        # we probably missed the best pair because of pruning; go back to full statistics
        if not stats or (i and stats[most_frequent] < threshold):
            prune_stats(stats, big_stats, threshold)
            stats = copy.deepcopy(big_stats)
            most_frequent = max(stats, key=lambda x: (stats[x], x))
            # threshold is inspired by Zipfian assumption, but should only affect speed
            threshold = stats[most_frequent] * i/(i+10000.0)
            prune_stats(stats, big_stats, threshold)

        if stats[most_frequent] < min_frequency:
            sys.stderr.write('no pair has frequency >= {0}. Stopping\n'.format(min_frequency))
            break

        if verbose:
            sys.stderr.write('pair {0}: {1} {2} -> {1}{2} (frequency {3})\n'.format(i, most_frequent[0], most_frequent[1], stats[most_frequent]))

        bpe_codes.append((most_frequent[0], most_frequent[1]))
        changes = replace_pair(most_frequent, sorted_vocab, indices)
        update_pair_statistics(most_frequent, changes, stats, indices)
        stats[most_frequent] = 0
        if not i % 100:
            prune_stats(stats, big_stats, threshold)
            
    if verbose:
        print(bpe_codes)
    
    return bpe_codes

class BPE(object):

    def __init__(self, codes, merges=-1, separator='@@', vocab=None, glossaries=None):
        self.version = (0, 2) #Hardcode
        self.bpe_codes = codes
        
        # some hacking to deal with duplicates (only consider first instance)
        self.bpe_codes = dict([(code,i) for (i,code) in reversed(list(enumerate(self.bpe_codes)))])
        
        self.bpe_codes_reverse = dict([(pair[0] + pair[1], pair) for pair,i in self.bpe_codes.items()])

        self.separator = separator

        self.vocab = vocab

        self.glossaries = glossaries if glossaries else []

        self.glossaries_regex = re.compile('^({})$'.format('|'.join(glossaries))) if glossaries else None

        self.cache = {}

    def process_line(self, line, dropout=0):
        """segment line, dealing with leading and trailing whitespace"""

        out = ""

        leading_whitespace = len(line)-len(line.lstrip('\r\n '))
        if leading_whitespace:
            out += line[:leading_whitespace]

        out += self.segment(line, dropout)

        trailing_whitespace = len(line)-len(line.rstrip('\r\n '))
        if trailing_whitespace and trailing_whitespace != len(line):
            out += line[-trailing_whitespace:]
        
        # print(out)
        return out

    def segment(self, sentence, dropout=0):
        """segment single sentence (whitespace-tokenized string) with BPE encoding"""
        segments = self.segment_tokens(sentence.strip('\r\n ').split(' '), dropout)
        return ' '.join(segments)

    def segment_tokens(self, tokens, dropout=0):
        """segment a sequence of tokens with BPE encoding"""
        output = []
        for word in tokens:
            # eliminate double spaces
            if not word:
                continue
            new_word = [out for segment in self._isolate_glossaries(word)
                        for out in encode(segment,
                                          self.bpe_codes,
                                          self.bpe_codes_reverse,
                                          self.vocab,
                                          self.separator,
                                          self.version,
                                          self.cache,
                                          self.glossaries_regex,
                                          dropout)]

            for item in new_word[:-1]:
                output.append(item + self.separator)
            output.append(new_word[-1])

        return output

    def _isolate_glossaries(self, word):
        word_segments = [word]
        for gloss in self.glossaries:
            word_segments = [out_segments for segment in word_segments
                                 for out_segments in isolate_glossary(segment, gloss)]
        return word_segments

def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache, glossaries_regex=None, dropout=0):
    """Encode word based on list of BPE merge operations, which are applied consecutively
    """

    if not dropout and orig in cache:
        return cache[orig]

    if glossaries_regex and glossaries_regex.match(orig):
        cache[orig] = (orig,)
        return (orig,)

    if len(orig) == 1:
        return orig

    if version == (0, 1):
        word = list(orig) + ['</w>']
    elif version == (0, 2): # more consistent handling of word-final segments
        word = list(orig[:-1]) + [orig[-1] + '</w>']
    else:
        raise NotImplementedError

    while len(word) > 1:

        # get list of symbol pairs; optionally apply dropout
        pairs = [(bpe_codes[pair],i,pair) for (i,pair) in enumerate(zip(word, word[1:])) if (not dropout or random.random() > dropout) and pair in bpe_codes]

        if not pairs:
            break

        #get first merge operation in list of BPE codes
        bigram = min(pairs)[2]

        # find start position of all pairs that we want to merge
        positions = [i for (rank,i,pair) in pairs if pair == bigram]

        i = 0
        new_word = []
        bigram = ''.join(bigram)
        for j in positions:
            # merges are invalid if they start before current position. This can happen if there are overlapping pairs: (x x x -> xx x)
            if j < i:
                continue
            new_word.extend(word[i:j]) # all symbols before merged pair
            new_word.append(bigram) # merged pair
            i = j+2 # continue after merged pair
        new_word.extend(word[i:]) # add all symbols until end of word
        word = new_word

    # don't print end-of-word symbols
    if word[-1] == '</w>':
        word = word[:-1]
    elif word[-1].endswith('</w>'):
        word[-1] = word[-1][:-4]

    word = tuple(word)
    if vocab:
        word = check_vocab_and_split(word, bpe_codes_reverse, vocab, separator)

    cache[orig] = word
    return word

def recursive_split(segment, bpe_codes, vocab, separator, final=False):
    """Recursively split segment into smaller units (by reversing BPE merges)
    until all units are either in-vocabulary, or cannot be split futher."""

    try:
        if final:
            left, right = bpe_codes[segment + '</w>']
            right = right[:-4]
        else:
            left, right = bpe_codes[segment]
    except:
        #sys.stderr.write('cannot split {0} further.\n'.format(segment))
        yield segment
        return

    if left + separator in vocab:
        yield left
    else:
        for item in recursive_split(left, bpe_codes, vocab, separator, False):
            yield item

    if (final and right in vocab) or (not final and right + separator in vocab):
        yield right
    else:
        for item in recursive_split(right, bpe_codes, vocab, separator, final):
            yield item

def check_vocab_and_split(orig, bpe_codes, vocab, separator):
    """Check for each segment in word if it is in-vocabulary,
    and segment OOV segments into smaller units by reversing the BPE merge operations"""

    out = []

    for segment in orig[:-1]:
        if segment + separator in vocab:
            out.append(segment)
        else:
            #sys.stderr.write('OOV: {0}\n'.format(segment))
            for item in recursive_split(segment, bpe_codes, vocab, separator, False):
                out.append(item)

    segment = orig[-1]
    if segment in vocab:
        out.append(segment)
    else:
        #sys.stderr.write('OOV: {0}\n'.format(segment))
        for item in recursive_split(segment, bpe_codes, vocab, separator, True):
            out.append(item)

    return out


def read_vocabulary(vocab_file, threshold):
    """read vocabulary file produced by get_vocab.py, and filter according to frequency threshold.
    """

    vocabulary = set()

    for line in vocab_file:
        word, freq = line.strip('\r\n ').split(' ')
        freq = int(freq)
        if threshold == None or freq >= threshold:
            vocabulary.add(word)

    return vocabulary

def isolate_glossary(word, glossary):
    """
    Isolate a glossary present inside a word.

    Returns a list of subwords. In which all 'glossary' glossaries are isolated 

    For example, if 'USA' is the glossary and '1934USABUSA' the word, the return value is:
        ['1934', 'USA', 'B', 'USA']
    """
    # regex equivalent of (if word == glossary or glossary not in word)
    if re.match('^'+glossary+'$', word) or not re.search(glossary, word):
        return [word]
    else:
        segments = re.split(r'({})'.format(glossary), word)
        segments, ending = segments[:-1], segments[-1]
        segments = list(filter(None, segments)) # Remove empty strings in regex group.
        return segments + [ending.strip('\r\n ')] if ending != '' else segments
    
### CALL THIS FUNCTION TO TRAIN AND CREATE VOCAB ###             
def learn_joint_bpe_and_vocab(data, symbols, min_frequency, total_symbols, dropout=0, separator="@@", verbose=False):
    # get combined vocabulary of all input texts
    full_vocab = get_vocabulary(data)
    
    vocab_list = ['{0} {1}'.format(key, freq) for (key, freq) in full_vocab.items()]

    # learn BPE on combined vocabulary
    bpe_codes = learn_bpe(vocab_list, symbols, min_frequency, verbose, is_dict=True, total_symbols=total_symbols)

    bpe = BPE(bpe_codes, separator=separator)
    
    # apply BPE to each training corpus and get vocabulary
    segments = []
    
    for line in data:
        line_segment = bpe.segment(line, dropout=dropout).strip()
        segments.append(line_segment)
              
    vocab = get_vocabulary(segments)
    
    subword_vocab = SubwordSequenceVocabulary()
    subword_vocab.add_bpe_codes_list(bpe_codes)
    
    for key, freq in sorted(vocab.items(), key=lambda x: x[1], reverse=True):
        subword_vocab.add_bpe_token(key, freq)

    return subword_vocab

### The Vectorizer

In [203]:
class NewsVectorizer(object):
    """ The Vectorizer which coordinates the Vocabularies and puts them to use"""    
    def __init__(self, title_vocab, category_vocab, mode): # , title_vocab
        self.mode = mode
        self.title_vocab = title_vocab
        self.category_vocab = category_vocab
        if "bpe" in mode:
            self.bpe = BPE(self.title_vocab.bpe_codes, vocab = self.title_vocab.token_freq)
                
    def vectorize(self, title, max_seq_length, max_word_length, max_sent_length):
        """
        Args:
            word (str): a word
            vector_length (int): an argument for forcing the length of index vector
        Returns:
            the vetorized title (numpy.array)
        """
        if self.mode == "word":
            indices = [self.title_vocab.begin_seq_index]
            indices.extend(self.title_vocab.lookup_token(token) 
                          for token in title.split(" "))
            indices.append(self.title_vocab.end_seq_index)

            vector_length = max_seq_length
            if vector_length < 0: 
                vector_length = len(indices)

            out_vectors = np.zeros(vector_length, dtype=np.int64)
            out_vectors[:len(indices)] = indices
            out_vectors[len(indices):] = self.title_vocab.mask_index

        elif self.mode == "char":
            words = title.split(" ")
            if len(words) > max_seq_length:
                words = words[:max_seq_length]
                
            out_vectors = []
            for word in words:
                word_indices = [self.title_vocab.lookup_token(token) for token in word]
                if len(word_indices) > max_word_length:
                    word_indices = word_indices[:max_word_length]

                out_vector = np.zeros(max_word_length, dtype=np.int64)
                out_vector[:len(word_indices)] = word_indices
                if len(word_indices) < max_word_length:
                    out_vector[len(word_indices):] = self.title_vocab.pad_index
                
                out_vectors.append(out_vector)
                
            if len(words) < max_seq_length:
                null_word_emb = np.array([self.title_vocab.pad_index] * max_word_length, dtype=np.int64)
                for _ in range(max_seq_length - len(words)):
                    out_vectors.append(null_word_emb)

            out_vectors = np.array(out_vectors, dtype=np.int64)

        elif self.mode == "bpe-char":
            words = title.strip().split() # segment title into words
            out_vectors = []  
            for titleword in words:
                encoded = self.bpe.process_line(titleword.strip())
                subwords = encoded.strip().split()
                word_indices = [self.title_vocab.lookup_token(token) for token in subwords]
                if len(word_indices) > max_word_length:
                    word_indices = word_indices[:max_word_length]
                    
                out_vector = np.zeros(max_word_length, dtype=np.int64)
                out_vector[:len(word_indices)] = word_indices
                if len(word_indices) < max_word_length:
                    out_vector[len(word_indices):] = self.title_vocab.pad_index
                
                out_vectors.append(out_vector) # append each subword as a rep of each word

            if len(words) < max_seq_length:
                null_word_emb = np.array([self.title_vocab.pad_index] * max_word_length, dtype=np.int64)
                for _ in range(max_seq_length - len(words)):
                    out_vectors.append(null_word_emb)

            out_vectors = np.array(out_vectors, dtype=np.int64)

        elif self.mode == "bpe-word":
            indices = [self.title_vocab.begin_seq_index]
            encoded = self.bpe.process_line(title.strip())
            indices.extend(self.title_vocab.lookup_token(token) 
                          for token in encoded.strip().split())
            indices.append(self.title_vocab.end_seq_index)

            vector_length = max_sent_length
            out_vectors = np.zeros(vector_length, dtype=np.int64)
            out_vectors[:len(indices)] = indices
            out_vectors[len(indices):] = self.title_vocab.pad_index

        elif self.mode == "sent":
            # words = self.title_vocab.sp_segmenter.encode_as_pieces(title) # for debugging
            # spm already configured to auto add bos and eos to ids
            # encodes entire sentence into tokens
            indices = self.title_vocab.sp_segmenter.encode_as_ids(title) 

            vector_length = max_sent_length
            out_vectors = np.zeros(vector_length, dtype=np.int64)
            out_vectors[:len(indices)] = indices
            out_vectors[len(indices):] = self.title_vocab.pad_index
          
        return out_vectors

    @classmethod
    def from_dataframe(cls, news_df, mode, vocab_size, cutoff=25, delete_files=False):
        """Instantiate the vectorizer from the dataset dataframe
        
        Args:
            news_df (pandas.DataFrame): the target dataset
            cutoff (int): frequency threshold for including in Vocabulary 
        Returns:
            an instance of the NewsVectorizer
        """
        category_vocab = Vocabulary()        
        for category in sorted(set(news_df.category)):
            category_vocab.add_token(category)

        if mode == "word":
            word_counts = Counter()
            for title in news_df.title:
                for token in title.split(" "):
                    if token not in string.punctuation:
                        word_counts[token] += 1
            
            title_vocab = WordSequenceVocabulary()
            for word, word_count in word_counts.items():
                if word_count >= cutoff:        
                    title_vocab.add_token(word)

        elif mode == "char":
            title_vocab = CharacterSequenceVocabulary()
            for title in news_df.title:
                for token in title.split(" "):
                    title_vocab.add_many(list(token))

        elif "bpe" in mode:
            title_vocab = SubwordSequenceVocabulary()
            total_symbols = False
            separator = "@@"
            min_frequency = 0
            title_vocab = learn_joint_bpe_and_vocab(news_df.title, vocab_size, 
                                                    min_frequency, total_symbols, 
                                                    dropout=0, separator=separator, verbose=False)
            for title in news_df.title:
                for token in title.split(" "):
                    title_vocab.add_many(list(token))
            
#             title_vocab.compress_dict(vocab_size)

        elif mode == "sent":
            title_vocab = SentenceSequenceVocabulary()
            news_df_title = news_df.title

            # create a temporary text file from dataframe for spm input 
            tmp = tempfile.NamedTemporaryFile(delete=False) 
            tmp.close()
            with open(tmp.name, 'w') as tmpout:
                for key, value in news_df_title.iteritems():
                    tmpout.write(value + '\n')

            with open(tmp.name, 'r') as tmpin:
                if not delete_files:
                    prefix = "spm/" # create a dir to store
                    handle_dirs(prefix)
                    handle_dirs(prefix)
                
                prefix = prefix + str(vocab_size) + "_train_news_spm"
                model_name = prefix + ".model"
                vocab_name = prefix + ".vocab"

                character_coverage = 1.0  # to reduce character set 
                model_type ="bpe"     # choose from unigram (default), bpe, char, or word

                templates= "--input={} --pad_id={} --bos_id={} --eos_id={} --unk_id={} \
                            --model_prefix={} --vocab_size={} \
                            --character_coverage={} --model_type={}"

                cmd = templates.format(tmp.name,
                                title_vocab.pad_index,
                                title_vocab.bos_index,
                                title_vocab.eos_index,
                                title_vocab.unk_index,
                                prefix, vocab_size,
                                character_coverage, model_type)
                
                spm.SentencePieceTrainer.Train(cmd) # run the trainer on trainset
                
            os.remove(tmp.name) # delete temp txt file after training
            
            title_vocab.load_model_file(model_name) # load the model into spm
            title_vocab.load_vocab_file(vocab_name) # load the vocab file for saving

            if delete_files:
                os.remove(model_name)
                os.remove(vocab_name)
     
        return cls(title_vocab, category_vocab, mode) # title_char_vocab, 

    @classmethod
    def from_serializable(cls, contents, mode):
        if mode == "word":
            title_vocab = \
               WordSequenceVocabulary.from_serializable(contents['title_vocab'])
        elif mode == "char":
            title_vocab = \
              CharacterSequenceVocabulary.from_serializable(contents['title_vocab'])
        elif "bpe" in mode:
            title_vocab = \
              SubwordSequenceVocabulary.from_serializable(contents['title_vocab'])      
        elif mode == "sent":
            title_vocab = \
              SentenceSequenceVocabulary.from_serializable(contents['title_vocab'])

        category_vocab =  \
            Vocabulary.from_serializable(contents['category_vocab'])

        return cls(title_vocab=title_vocab, category_vocab=category_vocab) # title_vocab=title_vocab, 

    def to_serializable(self):
        return {'title_vocab': self.title_vocab.to_serializable(),
                'category_vocab': self.category_vocab.to_serializable()} 
    # 'title_vocab': self.title_vocab.to_serializable(),

### The Dataset

In [204]:
class NewsDataset(Dataset):
    def __init__(self, news_df, vectorizer):
        """
        Args:
            news_df (pandas.DataFrame): the dataset
            vectorizer (NewsVectorizer): vectorizer instatiated from dataset
        """
        self.news_df = news_df
        self._vectorizer = vectorizer

        # +1 if only using begin_seq, +2 if using both begin and end seq tokens
        measure_len = lambda context: len(context.split(" "))
        self._max_seq_length = max(map(measure_len, news_df.title)) + 2
        
        self._max_word_length = 0
        self._max_sent_length = 0
        for title in news_df.title:
            if len(title) > self._max_sent_length:
                self._max_sent_length = len(title)
            for token in title.split(" "):
                if len(token) > self._max_word_length:
                    self._max_word_length = len(token)
                    
        self.train_df = self.news_df[self.news_df.split=='train']
        self.train_size = len(self.train_df)

        self.val_df = self.news_df[self.news_df.split=='val']
        self.validation_size = len(self.val_df)

        self.test_df = self.news_df[self.news_df.split=='test']
        self.test_size = len(self.test_df)

        self._lookup_dict = {'train': (self.train_df, self.train_size),
                             'val': (self.val_df, self.validation_size),
                             'test': (self.test_df, self.test_size)}

        self.set_split('train')

        # Class weights
        class_counts = news_df.category.value_counts().to_dict()
        def sort_key(item):
            return self._vectorizer.category_vocab.lookup_token(item[0])
        sorted_counts = sorted(class_counts.items(), key=sort_key)
        frequencies = [count for _, count in sorted_counts]
        self.class_weights = 1.0 / torch.tensor(frequencies, dtype=torch.float32)
        
        
    @classmethod
    def load_dataset_and_make_vectorizer(cls, news_csv, mode, vocab_size):
        """Load dataset and make a new vectorizer from scratch
        
        Args:
            surname_csv (str): location of the dataset
        Returns:
            an instance of SurnameDataset
        """
        news_df = pd.read_csv(news_csv)
        train_news_df = news_df[news_df.split=='train']
        return cls(news_df, NewsVectorizer.from_dataframe(train_news_df, mode = mode,
                                                          vocab_size = vocab_size))

    @classmethod
    def load_dataset_and_load_vectorizer(cls, news_csv, vectorizer_filepath):
        """Load dataset and the corresponding vectorizer. 
        Used in the case in the vectorizer has been cached for re-use
        
        Args:
            surname_csv (str): location of the dataset
            vectorizer_filepath (str): location of the saved vectorizer
        Returns:
            an instance of SurnameDataset
        """
        news_df = pd.read_csv(news_csv)
        vectorizer = cls.load_vectorizer_only(vectorizer_filepath)
        return cls(news_csv, vectorizer)

    @staticmethod
    def load_vectorizer_only(vectorizer_filepath):
        """a static method for loading the vectorizer from file
        
        Args:
            vectorizer_filepath (str): the location of the serialized vectorizer
        Returns:
            an instance of SurnameVectorizer
        """
        with open(vectorizer_filepath) as fp:
            return NameVectorizer.from_serializable(json.load(fp))

    def save_vectorizer(self, vectorizer_filepath):
        """saves the vectorizer to disk using json
        
        Args:
            vectorizer_filepath (str): the location to save the vectorizer
        """
        with open(vectorizer_filepath, "w") as fp:
            json.dump(self._vectorizer.to_serializable(), fp)

    def get_vectorizer(self):
        """ returns the vectorizer """
        return self._vectorizer

    def set_split(self, split="train"):
        """ selects the splits in the dataset using a column in the dataframe """
        self._target_split = split
        self._target_df, self._target_size = self._lookup_dict[split]

    def __len__(self):
        return self._target_size

    def __getitem__(self, index):
        """the primary entry point method for PyTorch datasets
        
        Args:
            index (int): the index to the data point 
        Returns:
            a dictionary holding the data point's features (x_data) and label (y_target)
        """
        row = self._target_df.iloc[index]

        title_vector = \
            self._vectorizer.vectorize(row.title, self._max_seq_length, 
                                       self._max_word_length, self._max_sent_length)

        category_index = \
            self._vectorizer.category_vocab.lookup_token(row.category)

        return {'x_data': title_vector,
                'y_target': category_index}

    def get_num_batches(self, batch_size):
        """Given a batch size, return the number of batches in the dataset
        
        Args:
            batch_size (int)
        Returns:
            number of batches in the dataset
        """
        return len(self) // batch_size

def generate_batches(dataset, batch_size, shuffle=True,
                     drop_last=True, device="cpu"): 
    """
    A generator function which wraps the PyTorch DataLoader. It will 
      ensure each tensor is on the write device location.
    """
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            shuffle=shuffle, drop_last=drop_last)

    for data_dict in dataloader:
        out_data_dict = {}
        for name, tensor in data_dict.items():
            out_data_dict[name] = data_dict[name].to(device)
        yield out_data_dict

## The Model: NewsClassifier

In [205]:
class NewsClassifier(nn.Module):
    def __init__(self, model_mode, char_embedding_size, word_embedding_size, 
                 char_num_embeddings, word_num_channels, 
                 char_kernel_size, hidden_dim, num_classes, dropout_p, 
                 char_pretrained_embeddings=None, padding_idx=0):
        """
        Args:
            embedding_size (int): size of the embedding vectors
            num_embeddings (int): number of embedding vectors
            filter_width (int): width of the convolutional kernels
            num_channels (int): number of convolutional kernels per layer
            hidden_dim (int): the size of the hidden dimension
            num_classes (int): the number of classes in classification
            dropout_p (float): a dropout parameter 
            pretrained_embeddings (numpy.array): previously trained word embeddings
                default is None. If provided, 
            padding_idx (int): an index representing a null position
        """
        super(NewsClassifier, self).__init__()
        print(("model_mode={}, char_embedding_size={}, word_embedding_size={}, char_num_embeddings={}, word_num_channels={}, " \
              + "char_kernel_size={}, hidden_dim={}, num_classes={}" \
              + "").format(model_mode, char_embedding_size, word_embedding_size, char_num_embeddings, word_num_channels, 
                 char_kernel_size, hidden_dim, num_classes))
        self.model_mode = model_mode

        if "word" in self.model_mode or self.model_mode == "sent":
            if char_pretrained_embeddings is None: # token_emb 
                self.char_emb = nn.Embedding(embedding_dim=word_embedding_size,
                                        num_embeddings=char_num_embeddings,
                                        padding_idx=padding_idx)        
            else:
                char_pretrained_embeddings = torch.from_numpy(char_pretrained_embeddings).float()
                self.char_emb = nn.Embedding(embedding_dim=word_embedding_size,
                                        num_embeddings=char_num_embeddings,
                                        padding_idx=padding_idx,
                                        _weight=char_pretrained_embeddings)
        elif "char" in self.model_mode: 
            if char_pretrained_embeddings is None: # char_emb
                self.char_emb = nn.Embedding(embedding_dim=char_embedding_size,
                                            num_embeddings=char_num_embeddings,
                                            padding_idx=padding_idx)        
            else:
                char_pretrained_embeddings = torch.from_numpy(char_pretrained_embeddings).float()
                self.char_emb = nn.Embedding(embedding_dim=char_embedding_size,
                                            num_embeddings=char_num_embeddings,
                                            padding_idx=padding_idx,
                                            _weight=char_pretrained_embeddings)
        
            self.char_convnet = nn.Sequential(
                nn.Conv1d(in_channels=char_embedding_size, out_channels=word_embedding_size, kernel_size=char_kernel_size),
                nn.ReLU()
            )

        self.word_convnet = nn.Sequential(
            nn.Conv1d(in_channels=word_embedding_size, 
                   out_channels=word_num_channels, kernel_size=3),
            nn.ELU(),
            nn.Conv1d(in_channels=word_num_channels, out_channels=word_num_channels, 
                   kernel_size=3, stride=2),
            nn.ELU(),
            nn.Conv1d(in_channels=word_num_channels, out_channels=word_num_channels, 
                   kernel_size=3, stride=2),
            nn.ELU(),
            nn.Conv1d(in_channels=word_num_channels, out_channels=word_num_channels, 
                   kernel_size=3),
            nn.ELU()
        )
        
        self._dropout_p = dropout_p
        self.fc1 = nn.Linear(word_num_channels, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x_in, apply_softmax=False):
        """The forward pass of the classifier
        
        Args:
            x_in (torch.Tensor): an input data tensor. 
                x_in.shape should be (batch, dataset._max_seq_length)
            apply_softmax (bool): a flag for the softmax activation
                should be false if used with the Cross Entropy losses
        Returns:
            the resulting tensor. tensor.shape should be (batch, num_classes)
        """

        # embed and permute so features are channels
        # x_in: (batch_size, max_seq_size, max_word_size)
        # x_emb: (batch_size, max_seq_size, max_word_size, char_embedding_size)
        x_emb = self.char_emb(x_in)

        if "char" in self.model_mode: # char or bpe-char
            batch_size = x_emb.size(dim=0)
            max_seq_size = x_emb.size(dim=1)
            max_word_size = x_emb.size(dim=2)
            char_embedding_size = x_emb.size(dim=3)
            # x_reshaped: (batch_size * max_seq_size, char_embedding_size, max_word_size)
            x_reshaped = x_emb.view(batch_size * max_seq_size, max_word_size, char_embedding_size).permute(0, 2, 1)

            # x_conv: (batch_size * max_seq_size, word_embedding_size, max_word_size - char_kernel_size + 1)
            x_conv = self.char_convnet(x_reshaped)
            # x_conv_out: (batch_size * max_seq_size, word_embedding_size)
            word_embedding_size = x_conv.size(dim=1)
            remaining_size = x_conv.size(dim=2)
            x_conv_out = F.max_pool1d(x_conv, remaining_size).squeeze(dim=2)
            x_embedding = x_conv_out.view(batch_size, max_seq_size, word_embedding_size)
        elif "word" in self.model_mode or self.model_mode == "sent":
            x_embedding = x_emb
        
        features = self.word_convnet(x_embedding.permute(0, 2, 1))

        # average and remove the extra dimension
        remaining_size = features.size(dim=2)
        features = F.avg_pool1d(features, remaining_size).squeeze(dim=2)
        features = F.dropout(features, p=self._dropout_p)
        
        # mlp classifier
        intermediate_vector = F.relu(F.dropout(self.fc1(features), p=self._dropout_p))
        prediction_vector = self.fc2(intermediate_vector)

        if apply_softmax:
            prediction_vector = F.softmax(prediction_vector, dim=1)

        return prediction_vector

## Training Routine

### Helper functions

In [206]:
def make_train_state(args):
    return {'stop_early': False,
            'early_stopping_step': 0,
            'early_stopping_best_val': 1e8,
            'learning_rate': args.learning_rate,
            'epoch_index': 0,
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'test_loss': -1,
            'test_acc': -1,
            'model_filename': args.model_state_file}

def update_train_state(args, model, train_state):
    """Handle the training state updates.

    Components:
     - Early Stopping: Prevent overfitting.
     - Model Checkpoint: Model is saved if the model is better

    :param args: main arguments
    :param model: model to train
    :param train_state: a dictionary representing the training state values
    :returns:
        a new train_state
    """

    # Save one model at least
    if train_state['epoch_index'] == 0:
        torch.save(model.state_dict(), train_state['model_filename'])
        train_state['stop_early'] = False

    # Save model if performance improved
    elif train_state['epoch_index'] >= 1:
        loss_tm1, loss_t = train_state['val_loss'][-2:]

        # If loss worsened
        if loss_t >= train_state['early_stopping_best_val']:
            # Update step
            train_state['early_stopping_step'] += 1
        # Loss decreased
        else:
            # Save the best model
            if loss_t < train_state['early_stopping_best_val']:
                torch.save(model.state_dict(), train_state['model_filename'])

            # Reset early stopping step
            train_state['early_stopping_step'] = 0

        # Stop early ?
        train_state['stop_early'] = \
            train_state['early_stopping_step'] >= args.early_stopping_criteria

    return train_state

def compute_accuracy(y_pred, y_target):
    _, y_pred_indices = y_pred.max(dim=1)
    n_correct = torch.eq(y_pred_indices, y_target).sum().item()
    return n_correct / len(y_pred_indices) * 100

#### general utilities

In [207]:
def set_seed_everywhere(seed, cuda):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed_all(seed)

def handle_dirs(dirpath):
    if not os.path.exists(dirpath):
        os.makedirs(dirpath)
        
def load_glove_from_file(glove_filepath):
    """
    Load the GloVe embeddings 
    
    Args:
        glove_filepath (str): path to the glove embeddings file 
    Returns:
        word_to_index (dict), embeddings (numpy.ndarary)
    """

    word_to_index = {}
    embeddings = []
    with open(glove_filepath, "r", encoding='utf8') as fp:
        for index, line in enumerate(fp):
            line = line.split(" ") # each line: word num1 num2 ...
            word_to_index[line[0]] = index # word = line[0] 
            embedding_i = np.array([float(val) for val in line[1:]])
            embeddings.append(embedding_i)
    return word_to_index, np.stack(embeddings)

def make_embedding_matrix(glove_filepath, words):
    """
    Create embedding matrix for a specific set of words.
    
    Args:
        glove_filepath (str): file path to the glove embeddigns
        words (list): list of words in the dataset
    """
    word_to_idx, glove_embeddings = load_glove_from_file(glove_filepath)
    embedding_size = glove_embeddings.shape[1]
    
    final_embeddings = np.zeros((len(words), embedding_size))

    for i, word in enumerate(words):
        if word in word_to_idx:
            final_embeddings[i, :] = glove_embeddings[word_to_idx[word]]
        else:
            embedding_i = torch.ones(1, embedding_size)
            torch.nn.init.xavier_uniform_(embedding_i)
            final_embeddings[i, :] = embedding_i

    return final_embeddings

### Settings and some prep work

In [208]:
from argparse import Namespace

In [209]:
args = Namespace(
    # Data and Path hyper parameters
    news_csv="data/ag_news/news_with_splits.csv",
    vectorizer_file="vectorizer.json",
    model_state_file="model.pth",
    save_dir="model_storage/ch5/document_classification",
    model_mode = "sent", # choose from word, char, bpe-char, bpe-word, sent
    # Model hyper parameters
    glove_filepath='data/glove/glove.6B.100d.txt', 
    use_glove=False,
    word_embedding_size=100, 
    char_embedding_size=50,
    char_kernel_size=5,
    hidden_dim=100, 
    word_num_channels=100,
    # Training hyper parameter
    seed=1337, 
    learning_rate=0.001, 
    weight_decay=1e-5, # Newly added to regularize variance
    dropout_p=0.2, #0.1
    batch_size= 128,
    num_epochs=100, 
    early_stopping_criteria=5, 
    vocab_size = 10000, # 1000, 3000, 10000
    # Runtime option
    cuda=True, 
    catch_keyboard_interrupt=True, 
    reload_from_files=False,
    expand_filepaths_to_save_dir=True
) 

if args.expand_filepaths_to_save_dir:
    args.vectorizer_file = args.model_mode + "_" + args.vectorizer_file
    args.model_state_file = args.model_mode + "_" + args.model_state_file
    if "bpe" in args.model_mode or args.model_mode == "sent":
        args.vectorizer_file = str(args.vocab_size) + "_" + args.vectorizer_file
        args.model_state_file = str(args.vocab_size) + "_" + args.model_state_file

    args.vectorizer_file = os.path.join(args.save_dir,
                                        args.vectorizer_file)

    args.model_state_file = os.path.join(args.save_dir,
                                         args.model_state_file)
    
    print("Expanded filepaths: ")
    print("\t{}".format(args.vectorizer_file))
    print("\t{}".format(args.model_state_file))
    
# Check CUDA
if not torch.cuda.is_available():
    args.cuda = False
    
args.device = torch.device("cuda" if args.cuda else "cpu")
print("Using CUDA: {}".format(args.cuda))

# Set seed for reproducibility
set_seed_everywhere(args.seed, args.cuda)

# handle dirs
handle_dirs(args.save_dir)

Expanded filepaths: 
	model_storage/ch5/document_classification/10000_sent_vectorizer.json
	model_storage/ch5/document_classification/10000_sent_model.pth
Using CUDA: True


### Initializations

In [210]:
args.use_glove = False

In [211]:
if args.reload_from_files:
    # training from a checkpoint
    dataset = NewsDataset.load_dataset_and_load_vectorizer(args.news_csv,
                                                           args.vectorizer_file)
else:
    # create dataset and vectorizer
    dataset = NewsDataset.load_dataset_and_make_vectorizer(args.news_csv, 
                                                           mode = args.model_mode, 
                                                           vocab_size = args.vocab_size)
    dataset.save_vectorizer(args.vectorizer_file)
vectorizer = dataset.get_vectorizer()
print("Title vocabulary size:", len(vectorizer.title_vocab))
print("Max sentence length:", dataset._max_sent_length)
print("Max sequence length:", dataset._max_seq_length)
print("Max word length:", dataset._max_word_length)

# Use GloVe or randomly initialized embeddings
if args.use_glove:
    words = vectorizer.title_vocab._token_to_idx.keys()
    embeddings = make_embedding_matrix(glove_filepath=args.glove_filepath, 
                                       words=words)
    print("Using pre-trained embeddings")
else:
    print("Not using pre-trained embeddings")
    embeddings = None

classifier = NewsClassifier(model_mode = args.model_mode,
                            char_embedding_size=args.char_embedding_size, 
                            word_embedding_size=args.word_embedding_size,
                            char_num_embeddings=len(vectorizer.title_vocab),
                            word_num_channels=args.word_num_channels,
                            char_kernel_size=args.char_kernel_size,
                            hidden_dim=args.hidden_dim, 
                            num_classes=len(vectorizer.category_vocab), 
                            dropout_p=args.dropout_p,
                            char_pretrained_embeddings=embeddings,
                            padding_idx=0)
# print(classifier)

Title vocabulary size: 10000
Max sentence length: 115
Max sequence length: 21
Max word length: 35
Not using pre-trained embeddings
model_mode=sent, char_embedding_size=50, word_embedding_size=100, char_num_embeddings=10000, word_num_channels=100, char_kernel_size=5, hidden_dim=100, num_classes=4


### Training loop

In [212]:
classifier = classifier.to(args.device)
dataset.class_weights = dataset.class_weights.to(args.device)
    
loss_func = nn.CrossEntropyLoss(dataset.class_weights)
optimizer = optim.Adam(classifier.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                           mode='min', factor=0.5,
                                           patience=1, verbose=True) #turned on verbose

train_state = make_train_state(args)

epoch_bar = tqdm(desc='training routine', 
                      total=args.num_epochs,
                      position=0)

dataset.set_split('train')
train_bar = tqdm(desc='split=train',
                      total=dataset.get_num_batches(args.batch_size), 
                      position=1, 
                      leave=True)
dataset.set_split('val')
val_bar = tqdm(desc='split=val',
                    total=dataset.get_num_batches(args.batch_size), 
                    position=1, 
                    leave=True)

try:
    for epoch_index in range(args.num_epochs):
        train_state['epoch_index'] = epoch_index

        # Iterate over training dataset

        # setup: batch generator, set loss and acc to 0, set train mode on

        dataset.set_split('train')
        batch_generator = generate_batches(dataset, 
                                           batch_size=args.batch_size, 
                                           device=args.device)
        running_loss = 0.0
        running_acc = 0.0
        classifier.train()

        for batch_index, batch_dict in enumerate(batch_generator):
            # the training routine is these 5 steps:

            # --------------------------------------
            # step 1. zero the gradients
            optimizer.zero_grad()

            # step 2. compute the output
            y_pred = classifier(batch_dict['x_data'])

            # step 3. compute the loss
            loss = loss_func(y_pred, batch_dict['y_target'])
            loss_t = loss.item()
            running_loss += (loss_t - running_loss) / (batch_index + 1)

            # step 4. use loss to produce gradients
            loss.backward()

            # step 5. use optimizer to take gradient step
            optimizer.step()
            # -----------------------------------------
            # compute the accuracy
            acc_t = compute_accuracy(y_pred, batch_dict['y_target'])
            running_acc += (acc_t - running_acc) / (batch_index + 1)

            # update bar
            train_bar.set_postfix(loss=running_loss, acc=running_acc, 
                                  epoch=epoch_index)
            train_bar.update()

        train_state['train_loss'].append(running_loss)
        train_state['train_acc'].append(running_acc)

        # Iterate over val dataset

        # setup: batch generator, set loss and acc to 0; set eval mode on
        dataset.set_split('val')
        batch_generator = generate_batches(dataset, 
                                           batch_size=args.batch_size, 
                                           device=args.device)
        running_loss = 0.
        running_acc = 0.
        classifier.eval()

        for batch_index, batch_dict in enumerate(batch_generator):

            # compute the output
            y_pred =  classifier(batch_dict['x_data'])

            # step 3. compute the loss
            loss = loss_func(y_pred, batch_dict['y_target'])
            loss_t = loss.item()
            running_loss += (loss_t - running_loss) / (batch_index + 1)

            # compute the accuracy
            acc_t = compute_accuracy(y_pred, batch_dict['y_target'])
            running_acc += (acc_t - running_acc) / (batch_index + 1)
            val_bar.set_postfix(loss=running_loss, acc=running_acc, 
                            epoch=epoch_index)
            val_bar.update()

        train_state['val_loss'].append(running_loss)
        train_state['val_acc'].append(running_acc)

        train_state = update_train_state(args=args, model=classifier,
                                         train_state=train_state)

        scheduler.step(train_state['val_loss'][-1])

        if train_state['stop_early']:
            break

        train_bar.n = 0
        val_bar.n = 0
        epoch_bar.update()
except KeyboardInterrupt:
    print("Exiting loop")


HBox(children=(FloatProgress(value=0.0, description='training routine', style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='split=train', max=656.0, style=ProgressStyle(description_…

HBox(children=(FloatProgress(value=0.0, description='split=val', max=140.0, style=ProgressStyle(description_wi…

Epoch     7: reducing learning rate of group 0 to 5.0000e-04.
Epoch     9: reducing learning rate of group 0 to 2.5000e-04.
Epoch    11: reducing learning rate of group 0 to 1.2500e-04.
Epoch    13: reducing learning rate of group 0 to 6.2500e-05.
Epoch    15: reducing learning rate of group 0 to 3.1250e-05.
Epoch    17: reducing learning rate of group 0 to 1.5625e-05.
Epoch    19: reducing learning rate of group 0 to 7.8125e-06.
Epoch    21: reducing learning rate of group 0 to 3.9063e-06.
Epoch    23: reducing learning rate of group 0 to 1.9531e-06.
Epoch    25: reducing learning rate of group 0 to 9.7656e-07.
Epoch    27: reducing learning rate of group 0 to 4.8828e-07.
Epoch    29: reducing learning rate of group 0 to 2.4414e-07.
Epoch    31: reducing learning rate of group 0 to 1.2207e-07.
Epoch    33: reducing learning rate of group 0 to 6.1035e-08.
Epoch    35: reducing learning rate of group 0 to 3.0518e-08.
Epoch    37: reducing learning rate of group 0 to 1.5259e-08.


In [215]:
# compute the loss & accuracy on the test set using the best available model

classifier.load_state_dict(torch.load(train_state['model_filename']))
        
classifier = classifier.to(args.device)   
dataset.class_weights = dataset.class_weights.to(args.device)
loss_func = nn.CrossEntropyLoss(dataset.class_weights)

dataset.set_split('test') 
batch_generator = generate_batches(dataset, 
                                   batch_size=args.batch_size, 
                                   device=args.device)
running_loss = 0.
running_acc = 0.
classifier.eval()

for batch_index, batch_dict in enumerate(batch_generator):
    # compute the output
    y_pred =  classifier(batch_dict['x_data'])
    
    # compute the loss
    loss = loss_func(y_pred, batch_dict['y_target'])
    loss_t = loss.item()
    running_loss += (loss_t - running_loss) / (batch_index + 1)

    # compute the accuracy
    acc_t = compute_accuracy(y_pred, batch_dict['y_target'])
    running_acc += (acc_t - running_acc) / (batch_index + 1)

train_state['test_loss'] = running_loss
train_state['test_acc'] = running_acc


In [216]:
print("Test loss: {};".format(train_state['test_loss']))
print("Test Accuracy: {}".format(train_state['test_acc']))

Test loss: 0.8477242233497758;
Test Accuracy: 82.52232142857146


### Inference

In [None]:
# Preprocess the reviews
def preprocess_text(text):
    text = ' '.join(word.lower() for word in text.split(" "))
    text = re.sub(r"([.,!?])", r" \1 ", text)
    text = re.sub(r"[^a-zA-Z.,!?]+", r" ", text)
    return text

In [None]:
def predict_category(title, classifier, vectorizer, max_seq_length, max_word_length, max_sent_length):
    """Predict a News category for a new title
    
    Args:
        title (str): a raw title string
        classifier (NewsClassifier): an instance of the trained classifier
        vectorizer (NewsVectorizer): the corresponding vectorizer
        max_length (int): the max sequence length
            Note: CNNs are sensitive to the input data tensor size. 
                  This ensures to keep it the same size as the training data
    """
    title = preprocess_text(title)
    vectorized_title = \
        torch.tensor(vectorizer.vectorize(title, max_seq_length, max_word_length, max_sent_length))
    result = classifier(vectorized_title.unsqueeze(0), apply_softmax=True)
    probability_values, indices = result.max(dim=1)
    predicted_category = vectorizer.category_vocab.lookup_index(indices.item())

    return {'category': predicted_category, 
            'probability': probability_values.item()}

In [None]:
def get_samples():
    samples = {}
    for cat in dataset.val_df.category.unique():
        samples[cat] = dataset.val_df.title[dataset.val_df.category==cat].tolist()[:5]
    return samples

val_samples = get_samples()

In [None]:
#title = input("Enter a news title to classify: ")
classifier = classifier.to("cpu")

for truth, sample_group in val_samples.items():
    print(f"True Category: {truth}")
    print("="*30)
    for sample in sample_group:
        prediction = predict_category(sample, classifier, 
                                      vectorizer, dataset._max_seq_length + 1, 
                                      dataset._max_word_length + 1, 
                                      dataset._max_sent_length + 1)
        print("Prediction: {} (p={:0.2f})".format(prediction['category'],
                                                  prediction['probability']))
        print("\t + Sample: {}".format(sample))
    print("-"*30 + "\n")