In [None]:
import random
import numpy as np
import torch
import re
import glob
import io
import os
import time
import json
import itertools
import logging
import gensim
import csv
import warnings
import sys
# Suppress every Python warning for the whole kernel session. Note this would
# also hide any deprecation warnings raised by the libraries imported here.
warnings.filterwarnings("ignore")

from tqdm import tqdm
from collections import Counter
from statistics import mean 
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
import pandas as pd 
import mpld3
from PIL import Image, ImageDraw, ImageFont

from transformers import BertTokenizer, BertModel

from torchtext import data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence
import torch
# The notebook pins itself to the second GPU when one exists. The guard keeps
# it runnable on CPU-only or single-GPU machines, where set_device(1) raises.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)

# Render matplotlib figures inline and enable mpld3's notebook display
# (analysis_plotter attaches mpld3 tooltip plugins to its figures).
%matplotlib inline
mpld3.enable_notebook()

SEED = 77
BATCH_SIZE = 16  # batch size

OUTPUT_DIM = 5
EMBEDDING_DIM = 100
N_EPOCHS = 100
TRAIN_RATIO = 0.8
# One positive-class weight per diagnosis label (label order as in the corpus).
# Kept as a float tensor because it is a loss weight; presumably passed as
# `pos_weight` to a BCE-with-logits loss later — TODO confirm at the call site.
POS_WEIGHT = torch.tensor([1, 7, 8, 4, 9], dtype=torch.float)

MICRO = 'micro'
MACRO = 'macro'
FINE_TUNE_EMBEDDING = True  # fine-tune word embeddings?

SENTENCE_LIMIT = 100  # presumably the sentence_limit for the corpus readers
WORD_LIMIT = 1000     # presumably the word_limit for the corpus readers
MIN_WORD_COUNT = 5
# Paths are assembled with os.path.join so they resolve on POSIX as well as on
# Windows (where they are identical to the original backslash literals).
DATA_FOLDER = os.path.join('.', 'HAN')
WORKER = 0
MAX_VOCAB_SIZE = 25000

WORD_RNN_SIZE = 250
SENTENCE_RNN_SIZE = 200
WORD_RNN_LAYERS = 2
SENTENCE_RNN_LAYERS = 2
WORD_ATTENTION_SIZE = 250
SENTENCE_ATTENTION_SIZE = 200
DROPOUT = 0.5

NTUH_DIR = os.path.join('..', 'Datasets', 'NTUH')
ORIGINAL_TRAIN = os.path.join(NTUH_DIR, 'train_preprocessing.txt')
TEST = os.path.join(NTUH_DIR, 'test_preprocessing.txt')
TRAIN = os.path.join(NTUH_DIR, 'train_preprocessing.txt')
VALID = os.path.join(NTUH_DIR, 'validate_preprocessing.txt')

# Seed every RNG in use so runs are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

# Required Functions and Definitions

In [None]:
# Tokenizers
def sentence_segment(sentence):
    """
    Split a document string into sentences on ``<sep>`` markers.

    A run of consecutive ``<sep>`` markers, together with the whitespace
    around it, counts as a single boundary. Any sentence that contains the
    text ``<unk>`` is dropped entirely.

    :param sentence: the whole document as one string
    :return: list of sentence strings. It may contain '' when the text starts
             or ends with ``<sep>``; read_all_text_from_corpus skips
             sentences that have no words.
    """
    sentences = re.split(r'\s*<sep>(?:\s*<sep>)*\s*', sentence)
    # Substring test: the whole sentence is discarded, not just the token.
    return [sent for sent in sentences if '<unk>' not in sent]

# Tokenizers used by read_all_text_from_corpus: documents are split into
# sentences on <sep> markers, and sentences into words on whitespace.
sent_tokenizer = sentence_segment
word_tokenizer = str.split

def preprocess(text):
    """
    Lower-case one raw text value.

    A float input (e.g. a NaN standing in for an empty cell) becomes ''.
    """
    return '' if isinstance(text, float) else text.lower()

def read_all_text_from_corpus(tsv_file, sentence_limit, word_limit):
    """
    Read a tab-separated NTUH corpus file into tokenised documents.

    Each non-blank line is lower-cased and must hold 8 tab-separated fields:
    patient id, BH text, EP text, then five label values (parsed with
    ``float``) in the order major depressive, schizophrenia, bipolar,
    minor depressive, dementia. The BH and EP texts are joined with
    ``<sep>`` and split with ``sent_tokenizer`` / ``word_tokenizer``.

    :param tsv_file: path of the UTF-8 corpus file
    :param sentence_limit: keep at most this many sentences per document
    :param word_limit: keep at most this many words per sentence
    :return: (docs, labels, word_counter).
             docs is a list of documents, each a list of word lists.
             labels is a parallel list of 5-float lists.
             word_counter counts every word that was kept.
    :raises ValueError: if a non-blank line does not have exactly 8 fields
    """
    docs = []
    labels = []
    word_counter = Counter()
    with io.open(tsv_file, 'r', encoding="utf-8") as file:
        for line in tqdm(file):
            # Blank lines (e.g. a trailing newline at EOF) would break the
            # 8-field unpack below, so skip them.
            if not line.strip():
                continue
            pid, bh_text, ep_text, major_d, sc, bp, minor_d, de = preprocess(line).strip().split('\t')
            all_text = "%s <sep> %s" % (bh_text, ep_text)
            sentences = list(sent_tokenizer(all_text))

            words = []
            for s in sentences[:sentence_limit]:
                w = word_tokenizer(s)[:word_limit]
                # Drop sentences with no words (e.g. '' around <sep> markers).
                if not w:
                    continue
                words.append(w)
                word_counter.update(w)
            # Drop documents whose sentences were all empty (labels too, so
            # docs and labels stay aligned).
            if not words:
                continue

            labels.append([float(major_d), float(sc), float(bp), float(minor_d), float(de)])
            docs.append(words)

    return docs, labels, word_counter

def create_input_files(train, test, output_folder, sentence_limit, word_limit, read_tsv,
                       min_word_count=5, save_word2vec_data = False, valid = None):
    """
    Build a word map from the training corpus and save encoded, padded splits.

    Files written to ``output_folder``:
      * ``word2vec_data.pth.tar``: raw training docs (only if ``save_word2vec_data``);
      * ``word_map.json``: '<pad>' -> 0, then every training word seen at
        least ``min_word_count`` times, then '<unk>' last;
      * ``TRAIN_data.pth.tar``, ``TEST_data.pth.tar`` and, if ``valid`` is
        given, ``VALID_data.pth.tar``. Each is a dict with keys 'docs',
        'labels', 'sentences_per_document' and 'words_per_sentence'.
        'docs' is padded with 0 to sentence_limit x word_limit.

    :param train: training corpus path, passed to ``read_tsv``
    :param test: test corpus path, passed to ``read_tsv``
    :param output_folder: directory that receives all output files
    :param sentence_limit: number of sentences every document is padded to
    :param word_limit: number of words every sentence is padded to
    :param read_tsv: reader called as ``read_tsv(path, sentence_limit, word_limit)``
        returning ``(docs, labels, word_counter)``
    :param min_word_count: minimum training count for a word to get its own id
    :param save_word2vec_data: also save the raw training docs
    :param valid: optional validation corpus path
    """

    def encode_split(docs, labels, word_map):
        # Map words to ids (unknown words -> '<unk>'), then right-pad every
        # sentence to word_limit and every document to sentence_limit with 0.
        unk_idx = word_map['<unk>']
        encoded_docs = [
            [[word_map.get(w, unk_idx) for w in sentence] + [0] * (word_limit - len(sentence))
             for sentence in doc]
            # One fresh list per padding sentence (no aliasing of a single row).
            + [[0] * word_limit for _ in range(sentence_limit - len(doc))]
            for doc in docs
        ]
        return {'docs': encoded_docs,
                'labels': labels,
                'sentences_per_document': [len(doc) for doc in docs],
                'words_per_sentence': [[len(sentence) for sentence in doc] + [0] * (sentence_limit - len(doc))
                                       for doc in docs]}

    def process_split(path, long_name, short_name, file_name, word_map):
        # Read, encode and save one held-out split (test or validation).
        print(f'Reading and preprocessing {long_name} data {path}...\n')
        docs, labels, _ = read_tsv(path, sentence_limit, word_limit)
        print(f'\nEncoding and padding {long_name} data...\n')
        content = encode_split(docs, labels, word_map)
        print('Saving...\n')
        torch.save(content, os.path.join(output_folder, file_name))
        print('Encoded, padded %s data saved to %s.\n' % (short_name, os.path.abspath(output_folder)))

    # Read training data
    print(f'\nReading and preprocessing training data {train}...\n')
    train_docs, train_labels, word_counter = read_tsv(train, sentence_limit, word_limit)
    # Save text data for word2vec
    if save_word2vec_data:
        torch.save(train_docs, os.path.join(output_folder, 'word2vec_data.pth.tar'))
        print('\nText data for word2vec saved to %s.\n' % os.path.abspath(output_folder))

    # Create word map: '<pad>' first (id 0), frequent words, '<unk>' last.
    word_map = {'<pad>': 0}
    for word, count in word_counter.items():
        if count >= min_word_count:
            word_map[word] = len(word_map)
    word_map['<unk>'] = len(word_map)
    print('\nDiscarding words with counts less than %d, the size of the vocabulary is %d.\n' % (
        min_word_count, len(word_map)))

    with open(os.path.join(output_folder, 'word_map.json'), 'w') as j:
        json.dump(word_map, j)
    print('Word map saved to %s.\n' % os.path.abspath(output_folder))

    print('Encoding and padding training data...\n')
    train_content = encode_split(train_docs, train_labels, word_map)
    print('Saving...\n')
    torch.save(train_content, os.path.join(output_folder, 'TRAIN_data.pth.tar'))
    print('Encoded, padded training data saved to %s.\n' % os.path.abspath(output_folder))
    # Release the (potentially large) training split before reading the next one.
    del train_docs, train_labels, train_content

    process_split(test, 'test', 'test', 'TEST_data.pth.tar', word_map)
    if valid:
        process_split(valid, 'validation', 'valid', 'VALID_data.pth.tar', word_map)
    print('All done!\n')
    
    
class NTUH_HANDataset(Dataset):
    """
    Map-style dataset over one encoded split saved as ``<SPLIT>_data.pth.tar``.

    Each item is a 4-tuple of tensors:
    (word ids, [number of sentences], words per sentence, label vector).
    """
    diagnosis_types = ['major_depressive', 'schizophrenia', 'biploar', 'minor_depressive', 'dementia']

    def __init__(self, data_folder, split):
        # The split name is upper-cased to match the saved file names.
        self.split = split.upper()
        # NOTE(review): torch.load unpickles the file; only load trusted data.
        self.data = torch.load(os.path.join(data_folder, self.split + '_data.pth.tar'))

    def __getitem__(self, i):
        record = self.data
        word_ids = torch.LongTensor(record['docs'][i])
        n_sentences = torch.LongTensor([record['sentences_per_document'][i]])
        n_words = torch.LongTensor(record['words_per_sentence'][i])
        target = torch.FloatTensor(record['labels'][i])
        return word_ids, n_sentences, n_words, target

    def __len__(self):
        return len(self.data['labels'])
    
class NTUHDataset(data.Dataset):
    #urls = ['Datasets\\NTUH\\corpus.txt']
    name = 'ntuh'
    dirname = 'ntuh'
    diagnosis_types = ['major_depressive', 'schizophrenia', 'biploar', 'minor_depressive', 'dementia']
    
    @staticmethod
    def sort_key(ex):
        return len(ex.all_text) # TODO add ep_text?

    def __init__(self, path, id_field, bh_text_field, ep_text_field, all_text_field,
                 major_label_field, sch_label_field, bipolar_label_field, minor_label_field, dementia_label_field,
                 **kwargs):
        fields = [('patient_id', id_field), 
                  ('bh_text', bh_text_field),
                  ('ep_text', ep_text_field),
                  ('all_text', all_text_field),
                  ('major_depressive', major_label_field),
                  ('schizophrenia', sch_label_field),
                  ('biploar', bipolar_label_field),
                  ('minor_depressive', minor_label_field),
                  ('dementia', dementia_label_field)]
        examples = []
        
        for fname in glob.iglob(path + '.txt'):
            with io.open(fname, 'r', encoding="utf-8") as f:
                for line in f:
                    pid, bh_text, ep_text, major_d, sc, bp, minor_d, de = line.strip().split('\t')
                    all_text = "%s <sep> %s" % (bh_text, ep_text)
                    examples.append(data.Example.fromlist([pid, bh_text, ep_text, all_text, major_d, sc, bp, minor_d, de], 
                                                          fields))
        super(NTUHDataset, self).__init__(examples, fields, **kwargs)

    @classmethod
    def splits(cls, id_field,
               bh_text_field, ep_text_field, all_text_field,
               major_label_field, sch_label_field, bipolar_label_field, minor_label_field, dementia_label_field,
               root='..\\Datasets\\NTUH',
               train='train_preprocessing', test='test_preprocessing', **kwargs):
        return super(NTUHDataset, cls).splits(
            path = root, root=root, id_field=id_field,
            bh_text_field = bh_text_field, ep_text_field = ep_text_field, all_text_field = all_text_field, 
            major_label_field = major_label_field, sch_label_field = sch_label_field, 
            bipolar_label_field = bipolar_label_field, minor_label_field = minor_label_field, 
            dementia_label_field = dementia_label_field,
            train=train, validation=None, test=test, **kwargs)
    
def count_parameters(model):
    """Return the total number of scalar values in the model's trainable parameters."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(map(lambda param: param.numel(), trainable))

def create_input_files_for_glove(glove_stoi, train, test, output_folder, sentence_limit, word_limit, read_tsv,
                       min_word_count=5, valid = None, unk_idx=None):
    """
    Encode and pad the corpus splits with a GloVe vocabulary and save them.

    Files written to ``output_folder``:
      * ``glove_word_map.json``: a plain-dict copy of ``glove_stoi``;
      * ``GLOVE_TRAIN_data.pth.tar``, ``GLOVE_TEST_data.pth.tar`` and, if
        ``valid`` is given, ``GLOVE_VALID_data.pth.tar``. Each is a dict with
        keys 'docs', 'labels', 'sentences_per_document' and 'words_per_sentence'.

    :param glove_stoi: word -> index mapping of the GloVe vocabulary
    :param read_tsv: reader called as ``read_tsv(path, sentence_limit, word_limit)``
        returning ``(docs, labels, word_counter)``
    :param min_word_count: unused; kept for signature compatibility
    :param valid: optional validation corpus path
    :param unk_idx: id for out-of-vocabulary words. If None, the notebook-level
        ``UNK_IDX`` is used (the previous behaviour), falling back to
        ``glove_stoi['<unk>']``.
    :raises ValueError: if no unknown-word index can be determined
    NOTE(review): padding uses id 0 as before, which need not be the '<pad>'
    index of a torchtext vocab. Padded positions are only ignored because
    the model packs sequences by length — confirm this is intended.
    """

    def encode_split(docs, labels, word_map, unk):
        # Map words to ids and right-pad sentences / documents with 0.
        encoded_docs = [
            [[word_map.get(w, unk) for w in sentence] + [0] * (word_limit - len(sentence))
             for sentence in doc]
            # One fresh list per padding sentence (no aliasing of a single row).
            + [[0] * word_limit for _ in range(sentence_limit - len(doc))]
            for doc in docs
        ]
        return {'docs': encoded_docs,
                'labels': labels,
                'sentences_per_document': [len(doc) for doc in docs],
                'words_per_sentence': [[len(sentence) for sentence in doc] + [0] * (sentence_limit - len(doc))
                                       for doc in docs]}

    def process_split(path, long_name, short_name, file_name, word_map, unk):
        # Read, encode and save one held-out split (test or validation).
        print(f'Reading and preprocessing {long_name} data {path}...\n')
        docs, labels, _ = read_tsv(path, sentence_limit, word_limit)
        print(f'\nEncoding and padding {long_name} data...\n')
        content = encode_split(docs, labels, word_map, unk)
        print('Saving...\n')
        torch.save(content, os.path.join(output_folder, file_name))
        print('Encoded, padded %s data (%s) saved to %s.\n' % (short_name, file_name, os.path.abspath(output_folder)))

    # Read training data
    print(f'\nReading and preprocessing training data {train}...\n')
    train_docs, train_labels, _ = read_tsv(train, sentence_limit, word_limit)

    # Word map: a plain dict copy (drops any defaultdict behaviour of the vocab).
    word_map = dict(glove_stoi)
    print('\nThe size of the vocabulary is %d.\n' % (len(word_map)))

    with open(os.path.join(output_folder, 'glove_word_map.json'), 'w') as j:
        json.dump(word_map, j)
    print('Word map saved to %s.\n' % os.path.abspath(output_folder))

    if unk_idx is None:
        # Previously this read an implicit notebook global; keep honouring it.
        unk_idx = globals().get('UNK_IDX', word_map.get('<unk>'))
        if unk_idx is None:
            raise ValueError("unk_idx not given, no UNK_IDX defined and no '<unk>' in glove_stoi")

    print('Encoding and padding training data...\n')
    train_content = encode_split(train_docs, train_labels, word_map, unk_idx)
    print('Saving...\n')
    torch.save(train_content, os.path.join(output_folder, 'GLOVE_TRAIN_data.pth.tar'))
    print('Encoded, padded training data (GLOVE_TRAIN_data.pth.tar) saved to %s.\n' % os.path.abspath(output_folder))
    # Release the (potentially large) training split before reading the next one.
    del train_docs, train_labels, train_content

    process_split(test, 'test', 'test', 'GLOVE_TEST_data.pth.tar', word_map, unk_idx)
    if valid:
        process_split(valid, 'validation', 'valid', 'GLOVE_VALID_data.pth.tar', word_map, unk_idx)
    print('All done!\n')
    
def load_glove_embeddings():
    """
    Build the combined-text vocabulary from the NTUH training split with GloVe vectors.

    All RNGs are re-seeded first, because words without a pretrained vector
    are initialised randomly (``unk_init=torch.Tensor.normal_``).

    :return: (embedding matrix, embedding dimension, vocab, pad token, unk token)
    """
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True

    # Four lower-casing text fields (id, BH, EP, BH+EP), then five float label fields.
    id_text, bh_text, ep_text, all_text = (data.Field(batch_first=True, lower=True) for _ in range(4))
    label_fields = [data.LabelField(dtype=torch.float) for _ in range(5)]

    full_train_data, _test_data = NTUHDataset.splits(id_text, bh_text, ep_text, all_text, *label_fields)
    all_text.build_vocab(full_train_data, max_size=MAX_VOCAB_SIZE, vectors="glove.6B.300d",
                         unk_init=torch.Tensor.normal_)

    vocab = all_text.vocab
    vectors = vocab.vectors
    print("\nEmbedding length is %d.\n" % len(vectors))
    return vectors, vectors.shape[1], vocab, all_text.pad_token, all_text.unk_token

def analysis_plotter(fig, ax, train, valid, title, param_dict1, param_dict2):
    """
    Plot train/valid curves on ``ax`` and mark the epochs where ``valid`` rose.

    Each point of ``valid`` that is larger than its predecessor gets a
    scatter marker and an mpld3 hover tooltip of the form ``"<epoch>: <value>"``.
    """
    ax.plot(train, **param_dict1)
    ax.plot(valid, **param_dict2)
    ax.title.set_text(title)
    ax.legend()

    rising_epochs = []
    rising_values = []
    previous = float('inf')  # the first point can never count as a rise
    for epoch, value in enumerate(valid):
        if value > previous:
            rising_epochs.append(epoch)
            rising_values.append(value)
        previous = value

    markers = ax.scatter(rising_epochs, rising_values)
    tooltip_text = [f'{epoch}: {value}' for epoch, value in zip(rising_epochs, rising_values)]
    tooltip = mpld3.plugins.PointLabelTooltip(markers, labels=tooltip_text)
    mpld3.plugins.connect(fig, tooltip)

In [None]:
# Path to a saved model checkpoint, or None when no checkpoint is used.
checkpoint = None  # path to model checkpoint, None if none

def init_embedding(input_embedding):
    """
    Fill ``input_embedding`` in place from U(-sqrt(3/dim), +sqrt(3/dim)).

    ``dim`` is the size of the tensor's second dimension, so each value has
    variance 1/dim. All RNGs (random, numpy, torch) are re-seeded with SEED
    first, which makes the result reproducible.

    :param input_embedding: embedding tensor to initialise
    """
    for reseed in (random.seed, np.random.seed, torch.manual_seed):
        reseed(SEED)
    torch.backends.cudnn.deterministic = True

    embedding_dim = input_embedding.size(1)
    limit = np.sqrt(3.0 / embedding_dim)
    nn.init.uniform_(input_embedding, -limit, limit)

# Load the word -> index map from DATA_FOLDER (presumably the word_map.json
# written by create_input_files with output_folder=DATA_FOLDER — confirm).
with open(os.path.join(DATA_FOLDER, 'word_map.json'), 'r') as j:
    word_map = json.load(j)

In [None]:
# Corpus column names; the last five are the diagnosis labels.
names = ['ID', 'BH Text', 'EP Text', 'Major Depressive', 'Schizophrenia', 'Biploar', 'Minor Depressive', 'Dementia']

# Diagnosis name -> class index, and class index -> diagnosis name.
label_map = {label: index for index, label in enumerate(names[3:])}
rev_label_map = dict(enumerate(names[3:]))
n_classes = len(label_map)

print(label_map)
print(rev_label_map)

## Model Definitions

In [None]:
class HierarchialAttentionNetwork(nn.Module):
    """
    Hierarchical attention network for multi-label document classification.

    A SentenceAttention encoder produces one document vector; dropout and a
    linear layer then turn it into one raw score per class.
    """

    def __init__(self, n_classes, vocab_size, emb_size, word_rnn_size, sentence_rnn_size, word_rnn_layers,
                 sentence_rnn_layers, word_att_size, sentence_att_size, dropout=0.5):
        super(HierarchialAttentionNetwork, self).__init__()

        # Document encoder; its output has 2 * sentence_rnn_size features.
        # Submodules are created in this order so parameter init consumes the RNG identically.
        self.sentence_attention = SentenceAttention(vocab_size, emb_size, word_rnn_size, sentence_rnn_size,
                                                    word_rnn_layers, sentence_rnn_layers, word_att_size,
                                                    sentence_att_size, dropout)
        self.fc = nn.Linear(2 * sentence_rnn_size, n_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, documents, sentences_per_document, words_per_sentence):
        """Return (class scores (n_documents, n_classes), word alphas, sentence alphas)."""
        doc_vectors, word_alphas, sentence_alphas = self.sentence_attention(
            documents, sentences_per_document, words_per_sentence)
        class_scores = self.fc(self.dropout(doc_vectors))
        return class_scores, word_alphas, sentence_alphas


class SentenceAttention(nn.Module):
    """
    Sentence-level half of the hierarchical attention network.

    Encodes every real (non-padding) sentence with WordAttention, runs a
    bidirectional GRU over each document's sentence vectors, and pools them
    into one document vector with additive attention against a learned
    context vector.
    """

    def __init__(self, vocab_size, emb_size, word_rnn_size, sentence_rnn_size, word_rnn_layers, sentence_rnn_layers,
                 word_att_size, sentence_att_size, dropout):
        super(SentenceAttention, self).__init__()

        # Word-level encoder: one 2 * word_rnn_size vector per sentence.
        self.word_attention = WordAttention(vocab_size, emb_size, word_rnn_size, word_rnn_layers, word_att_size,
                                            dropout)

        # Bidirectional GRU over the sentence vectors of each document.
        self.sentence_rnn = nn.GRU(2 * word_rnn_size, sentence_rnn_size, num_layers=sentence_rnn_layers,
                                   bidirectional=True, dropout=dropout, batch_first=True)

        # Attention projection of each GRU output ...
        self.sentence_attention = nn.Linear(2 * sentence_rnn_size, sentence_att_size)

        # ... scored against a learned context vector (a bias-free linear layer).
        self.sentence_context_vector = nn.Linear(sentence_att_size, 1,
                                                 bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, documents, sentences_per_document, words_per_sentence):
        """
        Encode a batch of documents.

        :param documents: padded word ids. As batched from NTUH_HANDataset this
            is presumably (n_documents, sentence_limit, word_limit) — TODO confirm.
        :param sentences_per_document: (n_documents,) number of real sentences
        :param words_per_sentence: per-sentence word counts, presumably
            (n_documents, sentence_limit)
        :return: document vectors (n_documents, 2 * sentence_rnn_size);
            word attention weights padded to (n_documents, max sentences, max words);
            sentence attention weights (n_documents, max sentences)
        """
        # Pack away padding sentences; .data becomes a flat batch holding the
        # real sentences of all documents.
        packed_sentences = pack_padded_sequence(documents,
                                                lengths=sentences_per_document.tolist(), 
                                                batch_first=True,
                                                enforce_sorted=False)  
        # Pack the word counts with the same lengths so they line up with
        # packed_sentences.data sentence-for-sentence.
        packed_words_per_sentence = pack_padded_sequence(words_per_sentence,
                                                         lengths=sentences_per_document.tolist(),
                                                         batch_first=True,
                                                         enforce_sorted=False)  
        # One vector (and one row of word attention weights) per real sentence.
        sentences, word_alphas = self.word_attention(packed_sentences.data,
                                                     packed_words_per_sentence.data)
        sentences = self.dropout(sentences)
        
        # Re-wrap the sentence vectors with the original packing metadata so
        # the GRU runs over each document's sequence of sentences.
        packed_sentences, _ = self.sentence_rnn(PackedSequence(data=sentences,
                                                               batch_sizes=packed_sentences.batch_sizes,
                                                               sorted_indices=packed_sentences.sorted_indices,
                                                               unsorted_indices=packed_sentences.unsorted_indices)) 

        # Attention score for every real sentence.
        att_s = self.sentence_attention(packed_sentences.data)
        att_s = torch.tanh(att_s)
        att_s = self.sentence_context_vector(att_s).squeeze(1) 

        # Subtract the batch-wide max before exp() to avoid overflow; the
        # common factor cancels in the per-document normalisation below.
        max_value = att_s.max() 
        att_s = torch.exp(att_s - max_value) 

        # Back to (n_documents, max sentences). Padded positions are filled
        # with 0, so they receive no attention weight.
        att_s, _ = pad_packed_sequence(PackedSequence(data=att_s,
                                                      batch_sizes=packed_sentences.batch_sizes,
                                                      sorted_indices=packed_sentences.sorted_indices,
                                                      unsorted_indices=packed_sentences.unsorted_indices),
                                       batch_first=True)

        # Softmax over each document's real sentences.
        sentence_alphas = att_s / torch.sum(att_s, dim=1, keepdim=True) 

        # GRU outputs padded back to (n_documents, max sentences, 2 * sentence_rnn_size).
        documents, _ = pad_packed_sequence(packed_sentences,
                                           batch_first=True) 

        # Attention-weighted sum over sentences gives one vector per document.
        documents = documents * sentence_alphas.unsqueeze(2)
        documents = documents.sum(dim=1) 

        # Regroup the per-sentence word weights by document.
        word_alphas, _ = pad_packed_sequence(PackedSequence(data=word_alphas,
                                                            batch_sizes=packed_sentences.batch_sizes,
                                                            sorted_indices=packed_sentences.sorted_indices,
                                                            unsorted_indices=packed_sentences.unsorted_indices),
                                             batch_first=True)

        return documents, word_alphas, sentence_alphas

class WordAttention(nn.Module):
    """
    Word-level half of the hierarchical attention network.

    Embeds a flat batch of sentences, runs a bidirectional GRU over the real
    (non-padding) words, and pools them into one vector per sentence with
    additive attention against a learned context vector.
    """

    def __init__(self, vocab_size, emb_size, word_rnn_size, word_rnn_layers, word_att_size, dropout):
        super(WordAttention, self).__init__()

        self.embeddings = nn.Embedding(vocab_size, emb_size)

        # Bidirectional GRU over the words of each sentence.
        self.word_rnn = nn.GRU(emb_size, word_rnn_size, num_layers=word_rnn_layers, bidirectional=True,
                               dropout=dropout, batch_first=True)

        # Attention projection of each GRU output ...
        self.word_attention = nn.Linear(2 * word_rnn_size, word_att_size)

        # ... scored against a learned context vector (a bias-free linear layer).
        self.word_context_vector = nn.Linear(word_att_size, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def init_embeddings(self, embeddings):
        """
        Replace the embedding weights with ``embeddings``, wrapped as a new
        trainable Parameter. Presumably (vocab_size, emb_size) — TODO confirm.
        """
        self.embeddings.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=False):
        """Allow (True) or block (False, the default) gradient updates of the embedding weights."""
        for p in self.embeddings.parameters():
            p.requires_grad = fine_tune

    def forward(self, sentences, words_per_sentence):
        """
        Encode a flat batch of sentences.

        :param sentences: padded word ids, (n_sentences, padded length)
        :param words_per_sentence: (n_sentences,) number of real words per sentence
        :return: sentence vectors (n_sentences, 2 * word_rnn_size) and word
            attention weights (n_sentences, max words in the batch)
        """
        # Embed word ids, with dropout on the embeddings.
        sentences = self.dropout(self.embeddings(sentences)) 
        
        # Pack so the GRU skips padding words; lengths need not be sorted.
        packed_words = pack_padded_sequence(sentences,
                                            lengths=words_per_sentence.tolist(),
                                            batch_first=True,
                                            enforce_sorted=False) 
        
        packed_words, _ = self.word_rnn(
            packed_words)
        
        # Attention score for every real word.
        att_w = self.word_attention(packed_words.data) 
        att_w = torch.tanh(att_w)
        att_w = self.word_context_vector(att_w).squeeze(1)  # (n_words)

        # Subtract the batch-wide max before exp() to avoid overflow; the
        # common factor cancels in the per-sentence normalisation below.
        max_value = att_w.max()
        att_w = torch.exp(att_w - max_value)
        
        # Back to (n_sentences, max words). Padded positions are filled with
        # 0, so they receive no attention weight.
        att_w, _ = pad_packed_sequence(PackedSequence(data=att_w,
                                                      batch_sizes=packed_words.batch_sizes,
                                                      sorted_indices=packed_words.sorted_indices,
                                                      unsorted_indices=packed_words.unsorted_indices),
                                       batch_first=True) 
        
        # Softmax over each sentence's real words.
        word_alphas = att_w / torch.sum(att_w, dim=1, keepdim=True)
        
        
        # GRU outputs padded back to (n_sentences, max words, 2 * word_rnn_size).
        sentences, _ = pad_packed_sequence(packed_words,
                                           batch_first=True)

        # Attention-weighted sum over words gives one vector per sentence.
        sentences = sentences * word_alphas.unsqueeze(2) 
        sentences = sentences.sum(dim=1)
        
        return sentences, word_alphas

In [None]:
def f_measure(predictions, labels):
    """
    Per-class, micro- and macro-averaged precision / recall / F1 for one batch.

    A class counts as predicted when ``round(sigmoid(score)) == 1``.

    :param predictions: raw model scores, (n_documents, n_classes)
    :param labels: targets of the same shape (1 = positive)
    :return: ``(diagnoses, predicts)``.
        ``diagnoses`` holds:
          * a 'micro' entry;
          * one entry per class index that had at least one tp/fp/fn, storing
            only its non-zero 'tp'/'fp'/'fn' counts;
          * 'p', 'r' and 'f' in every entry above;
          * a 'macro' entry with the unweighted mean 'p', 'r', 'f' over those
            classes (0.0 when there are none).
        ``predicts`` is the rounded predictions as a nested list.
    """
    MICRO = 'micro'
    MACRO = 'macro'
    diagnoses = {MICRO: {}}

    def count(did, kind):
        # Add one `kind` ('tp' / 'fp' / 'fn') to class `did` and to micro.
        per_class = diagnoses.setdefault(did, {})
        per_class[kind] = per_class.get(kind, 0) + 1
        diagnoses[MICRO][kind] = diagnoses[MICRO].get(kind, 0) + 1

    def ratio(num, den):
        # Precision / recall with 0/0 defined as 0.0.
        try:
            return num / den
        except ZeroDivisionError:
            return 0.0

    # Metrics need no autograd history.
    rounded_preds = torch.round(torch.sigmoid(predictions.detach()))
    predicts = rounded_preds.tolist()

    for pred_row, label_row in zip(predicts, labels.tolist()):
        for did, (pred, label) in enumerate(zip(pred_row, label_row)):
            if pred == 1:
                count(did, 'tp' if label == 1 else 'fp')
            elif pred == 0 and label == 1:
                count(did, 'fn')

    macro_sums = {'p': 0.0, 'r': 0.0, 'f': 0.0}
    n_classes = 0
    for d, stats in diagnoses.items():
        tp, fp, fn = stats.get('tp', 0), stats.get('fp', 0), stats.get('fn', 0)
        stats['p'] = ratio(tp, tp + fp)
        stats['r'] = ratio(tp, tp + fn)
        try:
            stats['f'] = 2 / (1 / stats['p'] + 1 / stats['r'])
        except ZeroDivisionError:
            stats['f'] = 0.0
        if d != MICRO:
            n_classes += 1
            for metric in macro_sums:
                macro_sums[metric] += stats[metric]

    # Guard: an all-negative batch has no class entries at all.
    diagnoses[MACRO] = {metric: (total / float(n_classes) if n_classes else 0.0)
                        for metric, total in macro_sums.items()}
    return diagnoses, predicts


In [None]:
class AverageMeter(object):
    """Track the latest value and a count-weighted running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` observations."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        
def clip_gradient(optimizer, grad_clip):
    """Clamp, in place, every existing gradient of the optimizer's parameters to [-grad_clip, grad_clip]."""
    existing_grads = (param.grad
                      for group in optimizer.param_groups
                      for param in group['params']
                      if param.grad is not None)
    for grad in existing_grads:
        grad.data.clamp_(-grad_clip, grad_clip)

def update_fscores(new, overall):
    """
    Fold one batch's counts (from f_measure) into running totals and recompute P/R/F1.

    :param new: batch result of f_measure. Class indices and 'micro' map to
        dicts with optional 'tp'/'fp'/'fn' counts; 'macro' is ignored.
    :param overall: running totals in the same layout; updated in place
        (pass ``{}`` for the first batch)
    :return: ``overall``, with 'p', 'r' and 'f' refreshed for 'micro' and
        every class. 'macro' holds their unweighted mean over classes, or
        0.0 when no class has been seen yet.
    """
    MICRO = 'micro'
    MACRO = 'macro'

    def ratio(num, den):
        # Precision / recall with 0/0 defined as 0.0.
        try:
            return num / den
        except ZeroDivisionError:
            return 0.0

    overall.setdefault(MICRO, {})
    for k, counts in new.items():
        # 'macro' holds only derived averages. Every other entry, 'micro'
        # included, is added to its own total exactly once. The previous loop
        # also folded every entry into 'micro', which tripled its counts.
        if k == MACRO:
            continue
        totals = overall.setdefault(k, {})
        for kind in ('tp', 'fp', 'fn'):
            totals[kind] = totals.get(kind, 0) + counts.get(kind, 0)

    macro_sums = {'p': 0.0, 'r': 0.0, 'f': 0.0}
    n_classes = 0
    for d, stats in overall.items():
        if d == MACRO:
            continue
        tp, fp, fn = stats.get('tp', 0), stats.get('fp', 0), stats.get('fn', 0)
        stats['p'] = ratio(tp, tp + fp)
        stats['r'] = ratio(tp, tp + fn)
        try:
            stats['f'] = 2 / (1 / stats['p'] + 1 / stats['r'])
        except ZeroDivisionError:
            stats['f'] = 0.0
        if d != MICRO:
            n_classes += 1
            for metric in macro_sums:
                macro_sums[metric] += stats[metric]

    overall[MACRO] = {metric: (total / float(n_classes) if n_classes else 0.0)
                      for metric, total in macro_sums.items()}
    return overall
    
def train(train_loader, model, criterion, optimizer, grad_clip):
    """
    Run one training epoch.

    :param train_loader: iterable of (documents, sentences_per_document,
        words_per_sentence, labels) batches, as produced from NTUH_HANDataset
    :param model: the network; called as model(documents, sentences, words)
    :param criterion: loss taking (scores, labels)
    :param optimizer: optimizer over the model's parameters
    :param grad_clip: clamp gradients to [-grad_clip, grad_clip]; None disables
    :return: (average loss per document, micro F1 over the epoch). The F1 is
        0.0 when the loader yields no batches.
    """
    model.train()

    losses = AverageMeter()
    accs = {}

    for documents, sentences_per_document, words_per_sentence, labels in train_loader:
        optimizer.zero_grad()

        documents = documents.to(device)
        # The dataset yields a 1-element tensor per document, so the batch is
        # (batch, 1); flatten it to (batch,).
        sentences_per_document = sentences_per_document.squeeze(1).to(device)
        words_per_sentence = words_per_sentence.to(device)
        labels = labels.to(device)

        scores, _, _ = model(documents, sentences_per_document, words_per_sentence)
        loss = criterion(scores, labels)
        loss.backward()

        if grad_clip is not None:
            clip_gradient(optimizer, grad_clip)

        optimizer.step()

        # Metrics need no autograd history.
        fscores, _ = f_measure(scores.detach(), labels)
        losses.update(loss.item(), labels.size(0))
        accs = update_fscores(fscores, accs)

    return losses.avg, accs.get('micro', {}).get('f', 0.0)

def save_checkpoint(epoch, model, optimizer, word_map, data_folder, model_name):
    """
    Save the epoch number, model, optimizer and word map together in one file.

    The whole ``model`` and ``optimizer`` objects are pickled (not their
    state dicts), so loading needs the same class definitions available.

    :param data_folder: output directory
    :param model_name: file name inside ``data_folder`` (e.g. 'checkpoint_han.pth.tar')
    """
    checkpoint_path = os.path.join(data_folder, model_name)
    torch.save(dict(epoch=epoch, model=model, optimizer=optimizer, word_map=word_map),
               checkpoint_path)
    
def epoch_time(start_time, end_time):
    """Split the span ``end_time - start_time`` (seconds) into ``(minutes, seconds)``.

    Both parts are truncated toward zero with ``int``.
    """
    span = end_time - start_time
    whole_minutes = int(span / 60)
    return whole_minutes, int(span - 60 * whole_minutes)

def train_epoch(start_epoch, epochs, data_loader, model, criterion, optimizer, word_map, model_name, grad_clip, valid_iterator = None, 
                interval = 50, early_stop = False, period = 20, gap = 0.005, threshold = 0.5):
    """Train ``model`` for epochs ``start_epoch .. epochs - 1`` with optional validation.

    When ``valid_iterator`` is given, each epoch is scored with ``evaluate``.
    Two checkpoints are then written to ``DATA_FOLDER`` via ``save_checkpoint``:
    ``<model_name>_loss.pt`` holds the lowest validation loss seen so far, and
    ``<model_name>_fscore.pt`` the highest validation micro-F. Without a
    validation iterator only training runs: no checkpoints are written and
    early stopping is disabled.

    Args:
        start_epoch: first epoch index (inclusive).
        epochs: last epoch index (exclusive).
        data_loader: training loader passed to ``train``.
        model, criterion, optimizer: forwarded to ``train`` / ``evaluate``.
        word_map: vocabulary stored inside each checkpoint.
        model_name: checkpoint file-name prefix.
        grad_clip: clipping value forwarded to ``train``; ``None`` disables clipping.
        valid_iterator: optional validation loader.
        interval: print a progress report every ``interval`` epochs. The final
            epoch is always reported.
        early_stop: enable early stopping.
        period: number of degraded epochs, counted since the last new best
            validation F-score, tolerated before stopping.
        gap: an epoch counts as degraded when its validation F-score is more
            than ``gap`` below the best one.
        threshold: early stopping only engages once the best validation
            F-score exceeds this value.

    Returns:
        ``(train_losses, valid_losses, train_accs, valid_accs)`` as per-epoch
        lists. The validation lists stay empty without ``valid_iterator``.
    """
    best_valid_loss = float('inf')
    best_valid_fscore = 0
    train_losses, valid_losses = [], []
    train_accs, valid_accs = [], []
    observed_time = 0  # degraded epochs since the last best validation F-score
    # DataLoader truthiness goes through len(), so an empty loader counts as "no validation".
    has_validation = bool(valid_iterator)

    for epoch in range(start_epoch, epochs):
        start_time = time.time()
        train_loss, train_acc = train(train_loader=data_loader, model=model,
                                      criterion=criterion, optimizer=optimizer, grad_clip=grad_clip)
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        # Only the training pass is timed; validation time is excluded.
        epoch_mins, epoch_secs = epoch_time(start_time, time.time())

        if has_validation:
            valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
            valid_losses.append(valid_loss)
            valid_accs.append(valid_acc)

        if (epoch + 1) % interval == 0 or epoch == epochs - 1:
            print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
            print(f'\tTrain Loss: {train_loss:.3f} | Train micro-F-score: {train_acc*100:.2f}%')
            if has_validation:
                print(f'\t Val. Loss: {valid_loss:.3f} |  Val. micro-F-score: {valid_acc*100:.2f}%')

        if not has_validation:
            # Checkpoint selection and early stopping are driven by validation metrics.
            continue

        # Save checkpoint
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            save_checkpoint(epoch, model, optimizer, word_map, DATA_FOLDER, model_name + '_loss.pt')

        if early_stop and best_valid_fscore > threshold and best_valid_fscore - valid_acc > gap:
            observed_time += 1
            print(f'\rBest validation F-measure: {best_valid_fscore:.3f}/Current F-measure: {valid_acc:.3f} [Times: {observed_time}/{period}]')
            if observed_time >= period:
                print(f'Early stop at epoch {epoch+1:02}.')
                break
        if valid_acc > best_valid_fscore:
            best_valid_fscore = valid_acc
            save_checkpoint(epoch, model, optimizer, word_map, DATA_FOLDER, model_name + '_fscore.pt')
            observed_time = 0

    return train_losses, valid_losses, train_accs, valid_accs


def evaluate(model, iterator, criterion):
    """Score ``model`` on ``iterator`` without gradient tracking.

    Returns:
        ``(mean loss, mean micro F-score)``.

        The loss is weighted by batch size, so it is a per-sample average that
        a smaller final batch does not distort. This makes it comparable with
        ``train``, which passes ``labels.size(0)`` as the weight to
        ``AverageMeter.update``.

        The F-score is the unweighted mean of the per-batch micro F-scores
        returned by ``f_measure``.

    Raises:
        ValueError: if ``iterator`` yields no batches.
    """
    model.eval()
    weighted_loss = 0.0
    fscore_sum = 0.0
    n_samples = 0
    n_batches = 0

    with torch.no_grad():
        for documents, sentences_per_document, words_per_sentence, labels in iterator:
            documents = documents.to(device)
            sentences_per_document = sentences_per_document.squeeze(1).to(device)
            words_per_sentence = words_per_sentence.to(device)
            labels = labels.to(device)

            # Forward prop. (attention weights are not needed here)
            scores, _, _ = model(documents, sentences_per_document, words_per_sentence)

            loss = criterion(scores, labels)
            fscores, _ = f_measure(scores, labels)

            batch_size = labels.size(0)
            weighted_loss += loss.item() * batch_size
            n_samples += batch_size
            fscore_sum += fscores['micro']["f"]
            n_batches += 1

    if n_batches == 0 or n_samples == 0:
        raise ValueError('evaluate() received an empty iterator')
    return weighted_loss / n_samples, fscore_sum / n_batches

def _confusion_counts(rounded_preds, labels):
    """Per-class confusion counts for one batch.

    Args:
        rounded_preds: thresholded predictions, one row per document and one
            column per class.
        labels: gold labels with the same layout.

    Returns:
        A list with one ``{'tp', 'fp', 'fn', 'tn'}`` dict of ints per class
        column. The counting rules are:

        - A prediction of 1 is a TP when it equals the label, otherwise an FP.
        - A prediction of 0 is an FN when the label is 1, otherwise a TN.
        - Entries that are neither 0 nor 1 (e.g. NaN) are not counted.
    """
    predicted_pos = rounded_preds == 1
    predicted_neg = rounded_preds == 0
    agrees = rounded_preds == labels
    gold_pos = labels == 1
    per_kind = {
        'tp': (predicted_pos & agrees).sum(dim=0).tolist(),
        'fp': (predicted_pos & ~agrees).sum(dim=0).tolist(),
        'fn': (predicted_neg & gold_pos).sum(dim=0).tolist(),
        'tn': (predicted_neg & ~gold_pos).sum(dim=0).tolist(),
    }
    return [{kind: counts[c] for kind, counts in per_kind.items()}
            for c in range(rounded_preds.size(1))]


def _precision_recall_f(counts):
    """Return ``(precision, recall, F1)`` from tp/fp/fn counts.

    Zero denominators give 0.0.
    """
    tp, fp, fn = counts['tp'], counts['fp'], counts['fn']
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    fscore = 2 / (1 / precision + 1 / recall) if precision and recall else 0.0
    return precision, recall, fscore


def test(model, iterator, model_type, model_path = None):
    """Compute per-class, micro- and macro-averaged precision/recall/F1 on ``iterator``.

    Args:
        model: model returning ``(scores, word_alphas, sentence_alphas)``.
        iterator: loader yielding ``(documents, sentences_per_document,
            words_per_sentence, labels)`` batches.
        model_type: unused; kept so existing callers keep working.
        model_path: optional path of a state dict loaded into ``model`` first.

    Returns:
        ``(diagnoses, predicts)``.

        ``diagnoses`` contains ``'micro'`` first, then one entry per class
        index (0 .. n_classes-1), then ``'macro'``:

        - Micro and per-class entries hold 'tp', 'fp', 'fn', 'tn', 'p', 'r'
          and 'f'.
        - ``'macro'`` holds the unweighted mean over classes of 'p', 'r' and
          'f', with undefined ratios counted as 0.0.

        ``predicts`` lists the sigmoid-rounded predictions, one row per document.
    """
    if model_path:
        model.load_state_dict(torch.load(model_path))

    model.eval()

    MICRO = 'micro'
    MACRO = 'macro'
    predicts = []
    class_counts = []  # one {'tp','fp','fn','tn'} dict per class, created on the first batch
    with torch.no_grad():
        for documents, sentences_per_document, words_per_sentence, labels in tqdm(iterator, desc='Evaluating'):
            documents = documents.to(device)  # (batch_size, sentence_limit, word_limit)
            sentences_per_document = sentences_per_document.squeeze(1).to(device)  # (batch_size)
            words_per_sentence = words_per_sentence.to(device)  # (batch_size, sentence_limit)
            labels = labels.to(device)

            # Forward prop.
            scores, _, _ = model(documents, sentences_per_document, words_per_sentence)

            rounded_preds = torch.round(torch.sigmoid(scores))
            predicts.extend(rounded_preds.tolist())

            batch_counts = _confusion_counts(rounded_preds, labels)
            if not class_counts:
                class_counts = batch_counts
            else:
                for totals, batch in zip(class_counts, batch_counts):
                    for kind in totals:
                        totals[kind] += batch[kind]

    # Micro-average pools the counts of all classes before computing the ratios.
    micro = {kind: sum(c[kind] for c in class_counts) for kind in ('tp', 'fp', 'fn', 'tn')}
    diagnoses = {MICRO: micro}
    for did, counts in enumerate(class_counts):
        diagnoses[did] = counts
    for entry in diagnoses.values():
        entry['p'], entry['r'], entry['f'] = _precision_recall_f(entry)

    # Macro-average: every class contributes, so an all-undefined class set yields 0.0
    # instead of a missing key.
    n_classes = len(class_counts)
    diagnoses[MACRO] = {
        metric: (sum(diagnoses[did][metric] for did in range(n_classes)) / float(n_classes)
                 if n_classes else 0.0)
        for metric in ('p', 'r', 'f')
    }
    return diagnoses, predicts

def initialize_embeddings(embedding_dim, word_map):
    """Allocate a randomly initialised embedding matrix for ``word_map``.

    An uninitialised CPU float32 tensor of shape
    ``(len(word_map), embedding_dim)`` is created and handed to
    ``init_embedding``. That helper is defined elsewhere and presumably fills
    the tensor in place, since its return value is not used.

    Returns:
        ``(embeddings, embedding_dim)``.
    """
    print("\nEmbedding length is %d.\n" % embedding_dim)

    vocab_size = len(word_map)
    embeddings = torch.empty(vocab_size, embedding_dim, dtype=torch.float32)
    init_embedding(embeddings)

    print("Done.\n Embedding vocabulary: %d.\n" % vocab_size)
    return embeddings, embedding_dim

# Dataset Analysis
## Training Set

In [None]:
# `names` is the column list defined earlier in the notebook.
ntuhdataset = pd.read_csv('../Datasets/NTUH/train_preprocessing.txt', sep='\t', names=names)

# Token count per note: remove anything in angle brackets (e.g. <sep>, <unk>), then
# split on whitespace.
for column, label in (('BH Text', 'BH'), ('EP Text', 'EP')):
    token_lengths = ntuhdataset[column].apply(
        lambda x: len(re.sub(r'<[^>]+>', '', str(x)).split()))
    avg_length = mean(token_lengths)
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.set_title(f'{label} Token Length Distribution: Average length: {avg_length}')
    ax.hist(token_lengths, bins=50)
    plt.show()

In [None]:
# Sentence count per note = number of '<sep>'-delimited segments (empty segments included,
# unlike sentence_segment, which collapses repeated <sep> markers).
for column, label in (('BH Text', 'BH'), ('EP Text', 'EP')):
    sentence_counts = ntuhdataset[column].apply(lambda x: str(x).count('<sep>') + 1)
    avg_count = mean(sentence_counts)
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.set_title(f'{label} Sentence Number Distribution: Average length: {avg_count}')
    ax.hist(sentence_counts, bins=50)
    plt.show()

# Preprocessing

Create a validation set by splitting the original training data.

In [None]:
class NTUHDataset(data.Dataset):
    """torchtext Dataset over the tab-separated NTUH files.

    Each non-blank line holds eight tab-separated fields: the patient id, the
    BH text, the EP text, and five labels (major depressive, schizophrenia,
    bipolar, minor depressive, dementia). A combined ``all_text`` field
    (``"<BH> <sep> <EP>"``) is synthesised for every example.
    """
    #urls = ['Datasets\\NTUH\\corpus.txt']
    name = 'ntuh'
    dirname = 'ntuh'
    # NOTE: 'biploar' (sic) is also the example attribute name set by the fields below and
    # read by other cells (e.g. `example.biploar`), so it cannot be corrected here alone.
    diagnosis_types = ['major_depressive', 'schizophrenia', 'biploar', 'minor_depressive', 'dementia']
    
    @staticmethod
    def sort_key(ex):
        """Sort/bucket examples by the length of the combined text."""
        return len(ex.all_text) # TODO add ep_text?

    def __init__(self, path, id_field, bh_text_field, ep_text_field, all_text_field,
                 major_label_field, sch_label_field, bipolar_label_field, minor_label_field, dementia_label_field,
                 **kwargs):
        """Read every file matching ``path + '.txt'`` into examples.

        Args:
            path: file path without the ``.txt`` extension (glob patterns allowed).
            id_field ... dementia_label_field: torchtext fields for the id, the
                three texts and the five labels.
            **kwargs: forwarded to ``data.Dataset``.

        Raises:
            ValueError: if a non-blank line does not have exactly 8 tab-separated fields.
        """
        fields = [('patient_id', id_field), 
                  ('bh_text', bh_text_field),
                  ('ep_text', ep_text_field),
                  ('all_text', all_text_field),
                  ('major_depressive', major_label_field),
                  ('schizophrenia', sch_label_field),
                  ('biploar', bipolar_label_field),
                  ('minor_depressive', minor_label_field),
                  ('dementia', dementia_label_field)]
        examples = []

        for fname in glob.iglob(path + '.txt'):
            with open(fname, 'r', encoding="utf-8") as f:
                for line_no, line in enumerate(f, start=1):
                    line = line.strip()
                    if not line:
                        continue  # tolerate blank lines, e.g. an empty last line
                    columns = line.split('\t')
                    if len(columns) != 8:
                        raise ValueError(f'{fname}:{line_no}: expected 8 tab-separated fields, '
                                         f'got {len(columns)}')
                    pid, bh_text, ep_text, major_d, sc, bp, minor_d, de = columns
                    all_text = "%s <sep> %s" % (bh_text, ep_text)
                    examples.append(data.Example.fromlist(
                        [pid, bh_text, ep_text, all_text, major_d, sc, bp, minor_d, de], fields))
        super(NTUHDataset, self).__init__(examples, fields, **kwargs)

    @classmethod
    def splits(cls, id_field,
               bh_text_field, ep_text_field, all_text_field,
               major_label_field, sch_label_field, bipolar_label_field, minor_label_field, dementia_label_field,
               root='..\\Datasets\\NTUH',
               train='train_preprocessing', test='test_preprocessing', **kwargs):
        r"""Create the train and test NTUHDataset objects.

        Arguments:
            id_field: field for the patient id.
            bh_text_field, ep_text_field, all_text_field: fields for the BH text,
                the EP text and their ``<sep>``-joined concatenation.
            major_label_field, sch_label_field, bipolar_label_field,
            minor_label_field, dementia_label_field: fields for the five
                diagnosis labels.
            root: directory containing the data files. Default is ``..\Datasets\NTUH``.
            train: training file name inside ``root``, without ``.txt``.
            test: test file name inside ``root``, without ``.txt``.
            Remaining keyword arguments: passed to the splits method of Dataset.

        Returns:
            ``(train_dataset, test_dataset)``. No validation split is produced here.
        """
        return super(NTUHDataset, cls).splits(
            path = root, root=root, id_field=id_field,
            bh_text_field = bh_text_field, ep_text_field = ep_text_field, all_text_field = all_text_field, 
            major_label_field = major_label_field, sch_label_field = sch_label_field, 
            bipolar_label_field = bipolar_label_field, minor_label_field = minor_label_field, 
            dementia_label_field = dementia_label_field,
            train=train, validation=None, test=test, **kwargs)

In [None]:
ID_TEXT = data.Field(batch_first = True)
BH_TEXT = data.Field(batch_first = True)
EP_TEXT = data.Field(batch_first = True)
ALL_TEXT = data.Field(batch_first = True)

MAJ_LABEL = data.LabelField(dtype = torch.float)
SCH_LABEL = data.LabelField(dtype = torch.float)
BIP_LABEL = data.LabelField(dtype = torch.float)
MIN_LABEL = data.LabelField(dtype = torch.float)
DEM_LABEL = data.LabelField(dtype = torch.float)
full_train_data, test_data = NTUHDataset.splits(ID_TEXT, BH_TEXT, EP_TEXT, ALL_TEXT, 
                                           MAJ_LABEL, SCH_LABEL, BIP_LABEL, MIN_LABEL, DEM_LABEL)

# Dataset.split expects a random *state* (a random.getstate() value). random.seed()
# returns None, so seed first and pass the resulting state explicitly for a reproducible split.
random.seed(SEED)
train_data, valid_data = full_train_data.split(split_ratio=TRAIN_RATIO, random_state=random.getstate())

## Dump Split Training and Validation Data

In [None]:
# NOTE(review): in the config cell TRAIN and ORIGINAL_TRAIN are the same file
# (train_preprocessing.txt), which NTUHDataset.splits reads as the full training corpus.
# Writing the 80% split to TRAIN therefore overwrites the raw data, and re-running the
# notebook would re-split an already-split set. Point TRAIN at a separate file.
ROW_TEMPLATE = ('{example.patient_id[0]}\t{0}\t{1}\t{example.major_depressive}\t{example.schizophrenia}'
                '\t{example.biploar}\t{example.minor_depressive}\t{example.dementia}\n')

# One tab-separated line per example: first id token, re-joined BH/EP tokens, five labels.
for out_path, split_data in ((TRAIN, train_data), (VALID, valid_data)):
    with open(out_path, 'wt', encoding="utf-8") as out_file:
        for example in split_data.examples:
            out_file.write(ROW_TEMPLATE.format(' '.join(example.bh_text), ' '.join(example.ep_text),
                                               example=example))

In [None]:
# Reload the freshly written training split for inspection (`names` presumably comes from an
# earlier cell). NOTE(review): `ntuhdataset` is rebound for every file, which hides which split
# is on screen; a distinct name would be clearer.
ntuhdataset = pd.read_csv(TRAIN, sep ='\t', names = names)
ntuhdataset

In [None]:
# Reload the validation split written above for inspection.
ntuhdataset = pd.read_csv(VALID, sep ='\t', names = names)
ntuhdataset

## Create Input Files

In [None]:
# Encode the train/test/valid splits into HAN input files under DATA_FOLDER.
# NOTE(review): the meaning of the positional `True` is not visible here; confirm against
# create_input_files' signature and consider passing it by keyword.
create_input_files(TRAIN, TEST, DATA_FOLDER, SENTENCE_LIMIT, WORD_LIMIT, read_all_text_from_corpus,
                       MIN_WORD_COUNT, True, VALID)

# Randomly Initialized Word Embedding

In [None]:
# Random embedding matrix with one row per vocabulary entry.
# NOTE(review): `word_map` is not assigned in this section (create_input_files' return value
# is discarded); confirm where it comes from so Restart & Run All works. The word2vec section
# reloads it from DATA_FOLDER/word_map.json.
embeddings, emb_size = initialize_embeddings(EMBEDDING_DIM, word_map)

In [None]:
# Inspect the shape and values of the random embedding matrix.
print(embeddings.shape)
print(embeddings)

In [None]:
# Vocabulary indices of the padding and unknown tokens (their rows are reset to zero below).
print(f'<pad>: {word_map["<pad>"]}\n<unk>: {word_map["<unk>"]}\n')

In [None]:
# Display the matrix before the special-token rows are reset.
embeddings

In [None]:
# Start <pad> and <unk> from an all-zero vector instead of random values.
embeddings[word_map["<pad>"]] = torch.zeros(EMBEDDING_DIM)
embeddings[word_map["<unk>"]] = torch.zeros(EMBEDDING_DIM)

In [None]:
# Display the matrix again to confirm the two rows are now zero.
embeddings

# Dataset and DataLoader

In [None]:
# Re-seed every RNG so the loader below is reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# DataLoaders
# NOTE(review): this unshuffled loader only reports the number of training batches;
# `train_loader` is rebuilt with shuffle=True in the next code cell.
train_loader = DataLoader(NTUH_HANDataset(DATA_FOLDER, 'train'), batch_size=BATCH_SIZE, shuffle=False,
                                               num_workers=WORKER, pin_memory=True)
len(train_loader)

# Initialize Our Model

In [None]:
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

start_epoch = 0  # start at this epoch
grad_clip = None  # no gradient clipping

# DataLoaders
train_loader = DataLoader(NTUH_HANDataset(DATA_FOLDER, 'train'), batch_size=BATCH_SIZE, shuffle=True,
                                               num_workers=WORKER, pin_memory=True)

# The validation set is not shuffled. evaluate() averages per-batch micro-F scores, so
# reshuffling batches every epoch would make the score, and therefore checkpoint selection
# and early stopping, fluctuate even for an unchanged model.
valid_loader = DataLoader(NTUH_HANDataset(DATA_FOLDER, 'valid'), batch_size=BATCH_SIZE, shuffle=False,
                                               num_workers=WORKER, pin_memory=True)

In [None]:
# Hierarchical attention network sized from the config cell. `n_classes` is not defined in
# this section; presumably it matches OUTPUT_DIM (5) from the config cell, so confirm.
model = HierarchialAttentionNetwork(n_classes=n_classes,
                                            vocab_size=len(word_map),
                                            emb_size=emb_size,
                                            word_rnn_size=WORD_RNN_SIZE,
                                            sentence_rnn_size=SENTENCE_RNN_SIZE,
                                            word_rnn_layers=WORD_RNN_LAYERS,
                                            sentence_rnn_layers=SENTENCE_RNN_LAYERS,
                                            word_att_size=WORD_ATTENTION_SIZE,
                                            sentence_att_size=SENTENCE_ATTENTION_SIZE,
                                            dropout=DROPOUT)
model

In [None]:
# Copy the prepared embedding matrix into the word-level encoder and set whether it keeps
# training (FINE_TUNE_EMBEDDING from the config cell).
model.sentence_attention.word_attention.init_embeddings(
            embeddings)  
model.sentence_attention.word_attention.fine_tune_embeddings(FINE_TUNE_EMBEDDING) 

# Train Our Model

In [None]:
# Move the model first, then build the optimizer over its (trainable) parameters, as the
# torch.optim documentation recommends.
model = model.to(device)
optimizer = optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()))

# Keep cuDNN's autotuner off. Benchmark mode can pick different, non-deterministic kernels
# between runs, which defeats `cudnn.deterministic = True` set alongside the seeds.
cudnn.benchmark = False
criterion = nn.BCEWithLogitsLoss()
criterion = criterion.to(device)

In [None]:
# Train the randomly initialised HAN with early stopping. Checkpoints are written as
# han_rand_loss.pt / han_rand_fscore.pt in DATA_FOLDER.
train_losses, valid_losses, train_accs, valid_accs = \
    train_epoch(start_epoch, N_EPOCHS, train_loader, model, criterion, optimizer, word_map, 
                'han_rand', grad_clip, valid_loader, early_stop = True, period = 30)

In [None]:
# Training vs. validation curves: loss on top, micro-F below (analysis_plotter is defined
# elsewhere in the notebook).
fig, (ax1, ax2) = plt.subplots(2, figsize=(15,10))
analysis_plotter(fig, ax1, train_losses, valid_losses, 'Training/Validation Loss', {'label': 'Training Loss'}, {'label': 'Validation Loss'})
analysis_plotter(fig, ax2, train_accs, valid_accs, 'Training/Validation Micro-F-Measure', {'label': 'Training F-Measure'}, {'label': 'Validation F-Measure'})

In [None]:
test_loader = DataLoader(NTUH_HANDataset(DATA_FOLDER, 'test'), batch_size=BATCH_SIZE, shuffle=False,
                                          num_workers=WORKER, pin_memory=True)

# Score the final-epoch model. test() switches to eval mode itself and ignores its third
# argument (model_type); pass 0 there like the other evaluation cells so model_path stays None.
test_f_scores, predicts = test(model, test_loader, 0)

# String keys are the micro/macro aggregates; integer keys are class indices.
for f, metrics in test_f_scores.items():
    title = f'{f}-average' if f in (MICRO, MACRO) else NTUH_HANDataset.diagnosis_types[f]
    print(f'{title}:\n\tprecision: {metrics["p"]:0.3f}\n\trecall: {metrics["r"]:0.3f}\n\tf-score: {metrics["f"]:0.3f}\n')

In [None]:
checkpoint_path = os.path.join(DATA_FOLDER, 'han_rand_fscore.pt')

# Load the model with the best validation micro-F (save_checkpoint pickles the whole model).
checkpoint = torch.load(checkpoint_path)
model = checkpoint['model']

# Load test data
test_loader = DataLoader(NTUH_HANDataset(DATA_FOLDER, 'test'), batch_size=BATCH_SIZE, shuffle=False,
                                          num_workers=WORKER, pin_memory=True)

test_f_scores, predicts = test(model, test_loader, 0)

# String keys are the micro/macro aggregates; integer keys are class indices.
for f, metrics in test_f_scores.items():
    title = f'{f}-average' if f in (MICRO, MACRO) else NTUH_HANDataset.diagnosis_types[f]
    print(f'{title}:\n\tprecision: {metrics["p"]:0.3f}\n\trecall: {metrics["r"]:0.3f}\n\tf-score: {metrics["f"]:0.3f}\n')

# Dealing with Class Imbalance

In [None]:
# Fresh, reproducible setup for the class-weighted experiment.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# New random embeddings, with the <pad> and <unk> rows reset to zero.
embeddings, emb_size = initialize_embeddings(EMBEDDING_DIM, word_map)
embeddings[word_map["<pad>"]] = torch.zeros(EMBEDDING_DIM)
embeddings[word_map["<unk>"]] = torch.zeros(EMBEDDING_DIM)

# Same architecture as the unweighted run.
model = HierarchialAttentionNetwork(n_classes=n_classes,
                                            vocab_size=len(word_map),
                                            emb_size=emb_size,
                                            word_rnn_size=WORD_RNN_SIZE,
                                            sentence_rnn_size=SENTENCE_RNN_SIZE,
                                            word_rnn_layers=WORD_RNN_LAYERS,
                                            sentence_rnn_layers=SENTENCE_RNN_LAYERS,
                                            word_att_size=WORD_ATTENTION_SIZE,
                                            sentence_att_size=SENTENCE_ATTENTION_SIZE,
                                            dropout=DROPOUT)

model.sentence_attention.word_attention.init_embeddings(
            embeddings)  
model.sentence_attention.word_attention.fine_tune_embeddings(FINE_TUNE_EMBEDDING) 

In [None]:
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# DataLoaders
train_loader = DataLoader(NTUH_HANDataset(DATA_FOLDER, 'train'), batch_size=BATCH_SIZE, shuffle=True,
                                               num_workers=WORKER, pin_memory=True)
# The validation set is not shuffled: evaluate() averages per-batch micro-F, so reshuffling
# would add epoch-to-epoch noise to checkpoint selection and early stopping.
valid_loader = DataLoader(NTUH_HANDataset(DATA_FOLDER, 'valid'), batch_size=BATCH_SIZE, shuffle=False,
                                               num_workers=WORKER, pin_memory=True)

In [None]:
# Train with a class-weighted BCE loss. POS_WEIGHT (config cell) scales the positive term
# of each of the five labels to counter class imbalance.
# NOTE(review): model_name 'han_rand' is the same as in the unweighted run, so this
# overwrites han_rand_loss.pt / han_rand_fscore.pt and the unweighted checkpoints are lost.
optimizer = optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()))
criterion = nn.BCEWithLogitsLoss(pos_weight=POS_WEIGHT)
model = model.to(device)
criterion = criterion.to(device)

train_losses, valid_losses, train_accs, valid_accs = \
    train_epoch(start_epoch, N_EPOCHS, train_loader, model, criterion, optimizer, word_map, 
                'han_rand', grad_clip, valid_loader, early_stop = True, period = 30)

In [None]:
# Training vs. validation curves for the class-weighted run.
fig, (ax1, ax2) = plt.subplots(2, figsize=(15,10))
analysis_plotter(fig, ax1, train_losses, valid_losses, 'Training/Validation Loss', {'label': 'Training Loss'}, {'label': 'Validation Loss'})
analysis_plotter(fig, ax2, train_accs, valid_accs, 'Training/Validation Micro-F-Measure', {'label': 'Training F-Measure'}, {'label': 'Validation F-Measure'})

In [None]:
# Load test data
test_loader = DataLoader(NTUH_HANDataset(DATA_FOLDER, 'test'), batch_size=BATCH_SIZE, shuffle=False,
                                          num_workers=WORKER, pin_memory=True)

# Score the final-epoch weighted model.
test_f_scores, predicts = test(model, test_loader, 0)

# String keys are the micro/macro aggregates; integer keys are class indices.
for f, metrics in test_f_scores.items():
    title = f'{f}-average' if f in (MICRO, MACRO) else NTUH_HANDataset.diagnosis_types[f]
    print(f'{title}:\n\tprecision: {metrics["p"]:0.3f}\n\trecall: {metrics["r"]:0.3f}\n\tf-score: {metrics["f"]:0.3f}\n')

In [None]:
checkpoint_path = os.path.join(DATA_FOLDER, 'han_rand_fscore.pt')

# Load the best-validation-F model of the weighted run (save_checkpoint pickles the whole model).
checkpoint = torch.load(checkpoint_path)
model = checkpoint['model']

# Load test data
test_loader = DataLoader(NTUH_HANDataset(DATA_FOLDER, 'test'), batch_size=BATCH_SIZE, shuffle=False,
                                          num_workers=WORKER, pin_memory=True)

test_f_scores, predicts = test(model, test_loader, 0)

# String keys are the micro/macro aggregates; integer keys are class indices.
for f, metrics in test_f_scores.items():
    title = f'{f}-average' if f in (MICRO, MACRO) else NTUH_HANDataset.diagnosis_types[f]
    print(f'{title}:\n\tprecision: {metrics["p"]:0.3f}\n\trecall: {metrics["r"]:0.3f}\n\tf-score: {metrics["f"]:0.3f}\n')

# Word2vec

In [None]:
def train_word2vec_model(data_folder, embedding_dim, algorithm='skipgram'):
    """Train word2vec on the corpus saved in ``data_folder`` and store the vectors.

    Args:
        data_folder: directory holding ``word2vec_data.pth.tar`` and receiving the
            outputs. The loaded object is flattened one level with
            ``itertools.chain`` and the result is fed to Word2Vec as sentences
            (presumably documents → sentences → tokens; confirm).
        embedding_dim: vector size.
        algorithm: ``'skipgram'`` selects skip-gram (sg=1); any other value
            selects CBOW (sg=0). The value is also embedded in the output file
            names.

    Writes ``word2vec_<algorithm>_model`` (KeyedVectors.save) and
    ``word2vec_<algorithm>_model.bin`` (save_word2vec_format) to ``data_folder``.
    """
    # Compare by value: `is` on a str literal depends on interning and emits a SyntaxWarning.
    sg = 1 if algorithm == 'skipgram' else 0

    sentences = torch.load(os.path.join(data_folder, 'word2vec_data.pth.tar'))
    sentences = list(itertools.chain.from_iterable(sentences))

    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)

    # NOTE: `size`/`iter` are the gensim 3.x parameter names (gensim 4 uses `vector_size`/
    # `epochs`). With workers=8 the result is not bit-reproducible even with a fixed seed.
    model = gensim.models.word2vec.Word2Vec(sentences=sentences, size=embedding_dim, workers=8, window=10, min_count=5,
                                            sg=sg, iter = 50, seed = SEED)
    model.init_sims(True)  # replace=True: keep only the L2-normalised vectors
    model.wv.save(os.path.join(data_folder, f'word2vec_{algorithm}_model')) 
    # NOTE(review): save_word2vec_format defaults to binary=False, so this ".bin" file is text.
    model.wv.save_word2vec_format(os.path.join(data_folder, f'word2vec_{algorithm}_model.bin'))

train_word2vec_model(data_folder=DATA_FOLDER, embedding_dim= EMBEDDING_DIM)

In [None]:
word2vec_file = os.path.join(DATA_FOLDER, 'word2vec_skipgram_model')  # path to pre-trained word2vec embeddings

def load_word2vec_embeddings(word2vec_file, word_map):
    """Build an embedding matrix for ``word_map`` from saved gensim KeyedVectors.

    Rows of words found in the word2vec vocabulary are copied from it. All
    other rows keep the values written by ``init_embedding``.

    Returns:
        ``(embeddings, vector_size)``: a float32 tensor of shape
        ``(len(word_map), vector_size)`` and its width.
    """
    w2v = gensim.models.KeyedVectors.load(word2vec_file, mmap='r')

    print("\nEmbedding length is %d.\n" % w2v.vector_size)
    embeddings = torch.FloatTensor(len(word_map), w2v.vector_size)
    init_embedding(embeddings)

    print("Loading embeddings...")
    # `w2v.vocab` is the gensim 3.x vocabulary mapping.
    for word, index in word_map.items():
        if word in w2v.vocab:
            embeddings[index] = torch.FloatTensor(w2v[word])

    print("Done.\n Embedding vocabulary: %d.\n" % len(word_map))

    return embeddings, w2v.vector_size

with open(os.path.join(DATA_FOLDER, 'word_map.json'), 'r') as j:
    word_map = json.load(j)

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

embeddings, emb_size = load_word2vec_embeddings(word2vec_file, word_map)
# Zero rows with the loaded vector width, not EMBEDDING_DIM, so this still works if the
# word2vec model was trained with a different size.
embeddings[word_map["<pad>"]] = torch.zeros(emb_size)
embeddings[word_map["<unk>"]] = torch.zeros(emb_size)

# Initialize and Train Our Model

In [None]:
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

train_loader = DataLoader(NTUH_HANDataset(DATA_FOLDER, 'train'), batch_size=BATCH_SIZE, shuffle=True,
                                               num_workers=WORKER, pin_memory=True)
# The validation set is not shuffled: evaluate() averages per-batch micro-F, so reshuffling
# would add epoch-to-epoch noise to checkpoint selection and early stopping.
valid_loader = DataLoader(NTUH_HANDataset(DATA_FOLDER, 'valid'), batch_size=BATCH_SIZE, shuffle=False,
                                               num_workers=WORKER, pin_memory=True)

# Same architecture as the random-embedding runs; only the embedding initialisation differs.
model = HierarchialAttentionNetwork(n_classes=n_classes,
                                            vocab_size=len(word_map),
                                            emb_size=emb_size,
                                            word_rnn_size=WORD_RNN_SIZE,
                                            sentence_rnn_size=SENTENCE_RNN_SIZE,
                                            word_rnn_layers=WORD_RNN_LAYERS,
                                            sentence_rnn_layers=SENTENCE_RNN_LAYERS,
                                            word_att_size=WORD_ATTENTION_SIZE,
                                            sentence_att_size=SENTENCE_ATTENTION_SIZE,
                                            dropout=DROPOUT)

model.sentence_attention.word_attention.init_embeddings(embeddings) 
model.sentence_attention.word_attention.fine_tune_embeddings(FINE_TUNE_EMBEDDING) 

In [None]:
# Train the word2vec-initialised HAN with the class-weighted loss. Checkpoints are written as
# han_w2v_loss.pt / han_w2v_fscore.pt in DATA_FOLDER.
optimizer = optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()))
criterion = nn.BCEWithLogitsLoss(pos_weight=POS_WEIGHT)
model = model.to(device)
criterion = criterion.to(device)

train_losses, valid_losses, train_accs, valid_accs = \
    train_epoch(start_epoch, N_EPOCHS, train_loader, model, criterion, optimizer, word_map, 
                'han_w2v', grad_clip, valid_loader, early_stop = True, period = 30)

In [None]:
# Training vs. validation curves for the word2vec run.
fig, (ax1, ax2) = plt.subplots(2, figsize=(15,10))
analysis_plotter(fig, ax1, train_losses, valid_losses, 'Training/Validation Loss', {'label': 'Training Loss'}, {'label': 'Validation Loss'})
analysis_plotter(fig, ax2, train_accs, valid_accs, 'Training/Validation Micro-F-Measure', {'label': 'Training F-Measure'}, {'label': 'Validation F-Measure'})

In [None]:
# Load test data
test_loader = DataLoader(NTUH_HANDataset(DATA_FOLDER, 'test'), batch_size=BATCH_SIZE, shuffle=False,
                                          num_workers=WORKER, pin_memory=True)

# Score the final-epoch word2vec model.
test_f_scores, predicts = test(model, test_loader, 0)

# String keys are the micro/macro aggregates; integer keys are class indices.
for f, metrics in test_f_scores.items():
    title = f'{f}-average' if f in (MICRO, MACRO) else NTUH_HANDataset.diagnosis_types[f]
    print(f'{title}:\n\tprecision: {metrics["p"]:0.3f}\n\trecall: {metrics["r"]:0.3f}\n\tf-score: {metrics["f"]:0.3f}\n')

In [None]:
checkpoint_path = os.path.join(DATA_FOLDER, 'han_w2v_fscore.pt')

# Load the best-validation-F word2vec model (save_checkpoint pickles the whole model).
checkpoint = torch.load(checkpoint_path)
model = checkpoint['model']

# Load test data
test_loader = DataLoader(NTUH_HANDataset(DATA_FOLDER, 'test'), batch_size=BATCH_SIZE, shuffle=False,
                                          num_workers=WORKER, pin_memory=True)

test_f_scores, predicts = test(model, test_loader, 0)

# String keys are the micro/macro aggregates; integer keys are class indices.
for f, metrics in test_f_scores.items():
    title = f'{f}-average' if f in (MICRO, MACRO) else NTUH_HANDataset.diagnosis_types[f]
    print(f'{title}:\n\tprecision: {metrics["p"]:0.3f}\n\trecall: {metrics["r"]:0.3f}\n\tf-score: {metrics["f"]:0.3f}\n')

In [None]:
# NOTE(review): this repeats the evaluation of the previous cell on the same model and loader.
# test() ignores its third argument (model_type); pass 0 like the other cells instead of
# `criterion, 0`, which put 0 into model_path.
test_f_scores, predicts = test(model, test_loader, 0)

# String keys are the micro/macro aggregates; integer keys are class indices.
for f, metrics in test_f_scores.items():
    title = f'{f}-average' if f in (MICRO, MACRO) else NTUH_HANDataset.diagnosis_types[f]
    print(f'{title}:\n\tprecision: {metrics["p"]:0.3f}\n\trecall: {metrics["r"]:0.3f}\n\tf-score: {metrics["f"]:0.3f}\n')

## GloVe

In [None]:
# Load the GloVe embedding matrix with its vocabulary and the pad/unknown tokens
# (load_glove_embeddings is defined elsewhere in the notebook).
embeddings, emb_size, vocab, PAD, GLOVE_UNK_TOKEN = load_glove_embeddings()

GLOVE_INPUT_DIM = len(vocab)  # vocabulary size
GLOVE_EMBEDDING_DIM = emb_size
GLOVE_PAD_IDX = vocab[PAD]
GLOVE_UNK_IDX = vocab[GLOVE_UNK_TOKEN]

print("Input dimension: %s\nUnknown word (%s) index: %s\nPadding index: %s\nEmbedding dimension: %s" % 
      (GLOVE_INPUT_DIM, GLOVE_UNK_TOKEN, GLOVE_UNK_IDX, GLOVE_PAD_IDX, GLOVE_EMBEDDING_DIM))

In [None]:
# Re-encode the splits with the GloVe vocabulary (vocab.stoi) so word indices match the GloVe
# embedding rows. Presumably this also writes glove_word_map.json, which the next cell reads.
create_input_files_for_glove(vocab.stoi, TRAIN, TEST, DATA_FOLDER, SENTENCE_LIMIT, WORD_LIMIT, 
                             read_all_text_from_corpus, MIN_WORD_COUNT, VALID)

In [None]:
# Vocabulary used to build the GloVe-encoded input files.
with open(os.path.join(DATA_FOLDER, 'glove_word_map.json'), 'r') as j:
    word_map = json.load(j)

# Print the matrix before and after zeroing the padding and unknown-word rows.
print(embeddings)
embeddings[GLOVE_PAD_IDX] = torch.zeros(GLOVE_EMBEDDING_DIM)
embeddings[GLOVE_UNK_IDX] = torch.zeros(GLOVE_EMBEDDING_DIM)
print(embeddings)

In [None]:
# NOTE(review): this rebinds the config constant for every later cell (the BERT section
# sets it again); a dedicated GLOVE_BATCH_SIZE would avoid hidden state.
BATCH_SIZE = 8 # batch size

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
# DataLoaders
# NOTE(review): the splits are named 'GLOVE_TRAIN'/'GLOVE_VALID' here but 'glove_test' in the
# test cell; confirm how NTUH_HANDataset maps split names to files (case may matter).
train_loader = DataLoader(NTUH_HANDataset(DATA_FOLDER, 'GLOVE_TRAIN'), batch_size=BATCH_SIZE, shuffle=True,
                                               num_workers=WORKER, pin_memory=True)
# The validation set is not shuffled: evaluate() averages per-batch micro-F, so reshuffling
# would add epoch-to-epoch noise to checkpoint selection and early stopping.
valid_loader = DataLoader(NTUH_HANDataset(DATA_FOLDER, 'GLOVE_VALID'), batch_size=BATCH_SIZE, shuffle=False,
                                               num_workers=WORKER, pin_memory=True)

model = HierarchialAttentionNetwork(n_classes=n_classes,
                                            vocab_size=len(word_map),
                                            emb_size=emb_size,
                                            word_rnn_size=WORD_RNN_SIZE,
                                            sentence_rnn_size=SENTENCE_RNN_SIZE,
                                            word_rnn_layers=WORD_RNN_LAYERS,
                                            sentence_rnn_layers=SENTENCE_RNN_LAYERS,
                                            word_att_size=WORD_ATTENTION_SIZE,
                                            sentence_att_size=SENTENCE_ATTENTION_SIZE,
                                            dropout=DROPOUT)

model.sentence_attention.word_attention.init_embeddings(embeddings)  
model.sentence_attention.word_attention.fine_tune_embeddings(FINE_TUNE_EMBEDDING)
model

In [None]:
# Sanity check: the embedding layer should now hold the GloVe matrix.
print(model.sentence_attention.word_attention.embeddings.weight)
print(f'The model has {count_parameters(model):,} trainable parameters')

In [None]:
# Train the GloVe-initialised HAN with the class-weighted loss. Checkpoints are written as
# glove_loss.pt / glove_fscore.pt in DATA_FOLDER.
optimizer = optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()))
criterion = nn.BCEWithLogitsLoss(pos_weight=POS_WEIGHT)
model = model.to(device)
criterion = criterion.to(device)
start_epoch = 0
grad_clip = None  # no gradient clipping
train_losses, valid_losses, train_accs, valid_accs = \
    train_epoch(start_epoch, N_EPOCHS, train_loader, model, criterion, optimizer, word_map, 
                'glove', grad_clip, valid_loader, early_stop=True, period = 30)

In [None]:
fig, (ax1, ax2) = plt.subplots(2, figsize=(15,10))
analysis_plotter(fig, ax1, train_losses, valid_losses, 'Training/Validation Loss', {'label': 'Training Loss'}, {'label': 'Validation Loss'})
analysis_plotter(fig, ax2, train_accs, valid_accs, 'Training/Validation Micro-F-Measure', {'label': 'Training F-Measure'}, {'label': 'Validation F-Measure'})                

# Load test data
test_loader = DataLoader(NTUH_HANDataset(DATA_FOLDER, 'glove_test'), batch_size=BATCH_SIZE, shuffle=False,
                                          num_workers=WORKER, pin_memory=True)

# Score the final-epoch GloVe model.
test_f_scores, predicts = test(model, test_loader, 0)

# String keys are the micro/macro aggregates; integer keys are class indices.
for f, metrics in test_f_scores.items():
    title = f'{f}-average' if f in (MICRO, MACRO) else NTUH_HANDataset.diagnosis_types[f]
    print(f'{title}:\n\tprecision: {metrics["p"]:0.3f}\n\trecall: {metrics["r"]:0.3f}\n\tf-score: {metrics["f"]:0.3f}\n')

In [None]:
checkpoint_path = os.path.join(DATA_FOLDER, 'glove_fscore.pt')

# Load the best-validation-F GloVe model (save_checkpoint pickles the whole model).
checkpoint = torch.load(checkpoint_path)
model = checkpoint['model']

# Reuses `test_loader` ('glove_test') built in the previous cell.
test_f_scores, predicts = test(model, test_loader, 0)

# String keys are the micro/macro aggregates; integer keys are class indices.
for f, metrics in test_f_scores.items():
    title = f'{f}-average' if f in (MICRO, MACRO) else NTUH_HANDataset.diagnosis_types[f]
    print(f'{title}:\n\tprecision: {metrics["p"]:0.3f}\n\trecall: {metrics["r"]:0.3f}\n\tf-score: {metrics["f"]:0.3f}\n')

# BERT

In [None]:
# Reproducible setup for the BERT experiment.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
DROPOUT = 0  # dropout disabled for the BERT experiment (overrides the config value)

# Smaller batches than the other experiments (presumably to fit BERT in GPU memory).
BATCH_SIZE = 2

BERT_MODEL = 'bert-base-uncased'

# Pretrained uncased tokenizer and encoder; all hidden states are returned.
bert_tokenizer = BertTokenizer.from_pretrained(BERT_MODEL, do_lower_case = True)
bert = BertModel.from_pretrained(BERT_MODEL, output_hidden_states = True)

# Special tokens taken from the tokenizer and their vocabulary ids.
BERT_BOS_TOKEN = bert_tokenizer.cls_token
BERT_EOS_TOKEN = bert_tokenizer.sep_token
BERT_PAD = bert_tokenizer.pad_token
BERT_UNK = bert_tokenizer.unk_token

BERT_BOS_IDX = bert_tokenizer.convert_tokens_to_ids(BERT_BOS_TOKEN)
BERT_EOS_IDX = bert_tokenizer.convert_tokens_to_ids(BERT_EOS_TOKEN)
BERT_PAD_IDX = bert_tokenizer.convert_tokens_to_ids(BERT_PAD)
BERT_UNK_IDX = bert_tokenizer.convert_tokens_to_ids(BERT_UNK)

print(f'{BERT_BOS_TOKEN}: {BERT_BOS_IDX}, {BERT_EOS_TOKEN}:{BERT_EOS_IDX}, {BERT_PAD}:{BERT_PAD_IDX}, \
{BERT_UNK}:{BERT_UNK_IDX}')

# Maximum input length (in tokens) the pretrained checkpoint accepts; used as the
# per-sentence word limit for BERT inputs.
BERT_WORD_LIMIT = bert_tokenizer.max_model_input_sizes[BERT_MODEL]
print(BERT_WORD_LIMIT)

In [None]:
def _bert_update_word_map(bert_tokenizer, word_map, word_counter):
    """Map every token in ``word_counter`` to its BERT vocabulary id inside ``word_map``."""
    for word in word_counter:
        word_map[word] = bert_tokenizer.convert_tokens_to_ids(word)


def _bert_encode_docs(docs, word_map, sentence_limit, word_limit):
    """Turn tokenised documents into fixed-size grids of BERT ids.

    Every sentence becomes ``[CLS] ids... [SEP]`` right-padded with [PAD] up to
    ``word_limit`` ids (tokens missing from ``word_map`` become [UNK]); every
    document is padded with all-[PAD] sentences up to ``sentence_limit`` rows.

    Returns:
        encoded_docs: nested lists, ``len(docs) x sentence_limit x word_limit``
            (assuming each sentence holds at most ``word_limit - 2`` tokens).
        sentences_per_document: number of real sentences per document.
        words_per_sentence: per document, ``len(sentence) + 2`` for each real
            sentence (the +2 counts [CLS]/[SEP]) followed by 0 for padding rows.
    """
    encoded_docs = []
    for doc in docs:
        encoded_doc = []
        for sentence in doc:
            ids = [BERT_BOS_IDX] + [word_map.get(w, BERT_UNK_IDX) for w in sentence] + [BERT_EOS_IDX]
            encoded_doc.append(ids + [BERT_PAD_IDX] * (word_limit - len(ids)))
        encoded_doc += [[BERT_PAD_IDX] * word_limit] * (sentence_limit - len(doc))
        encoded_docs.append(encoded_doc)

    sentences_per_document = [len(doc) for doc in docs]
    words_per_sentence = [[len(s) + 2 for s in doc] + [0] * (sentence_limit - len(doc)) for doc in docs]
    return encoded_docs, sentences_per_document, words_per_sentence


def _bert_save_split(output_folder, file_name, docs, labels, word_map, sentence_limit, word_limit):
    """Encode ``docs`` and ``torch.save`` them with ``labels`` to ``output_folder/file_name``."""
    encoded_docs, sentences_per_document, words_per_sentence = _bert_encode_docs(
        docs, word_map, sentence_limit, word_limit)
    print('Saving...\n')
    torch.save({'docs': encoded_docs,
                'labels': labels,
                'sentences_per_document': sentences_per_document,
                'words_per_sentence': words_per_sentence},
               os.path.join(output_folder, file_name))


def create_input_files_for_bert(bert_tokenizer, train, test, output_folder, sentence_limit, word_limit, read_tsv,
                                valid=None):
    """Tokenise, encode and save the train/test(/validation) splits for the BERT HAN.

    Args:
        bert_tokenizer: tokenizer passed to ``read_tsv`` and used to look up token ids.
        train, test: paths of the corpus files to read with ``read_tsv``.
        output_folder: directory receiving ``BERT_TRAIN_data.pth.tar``,
            ``BERT_TEST_data.pth.tar``, optionally ``BERT_VALID_data.pth.tar``,
            and ``BERT_word_map.json``.
        sentence_limit: number of sentence rows per encoded document.
        word_limit: number of ids per encoded sentence (including [CLS]/[SEP]).
        read_tsv: callable ``(tokenizer, path, sentence_limit, word_limit) ->
            (docs, labels, word_counter)``.
        valid: optional path of a validation corpus.

    Each saved file is a dict with keys 'docs', 'labels',
    'sentences_per_document' and 'words_per_sentence'. The word map
    accumulates the tokens of every split read so far.
    """
    print(f'\nReading and preprocessing training data {train}...\n')
    train_docs, train_labels, word_counter = read_tsv(bert_tokenizer, train, sentence_limit, word_limit)
    print(len(train_docs))

    word_map = dict()
    _bert_update_word_map(bert_tokenizer, word_map, word_counter)

    print('Encoding and padding training data...\n')
    _bert_save_split(output_folder, 'BERT_TRAIN_data.pth.tar', train_docs, train_labels, word_map,
                     sentence_limit, word_limit)
    print('Encoded, padded training data (BERT_TRAIN_data.pth.tar) saved to %s.\n' % os.path.abspath(output_folder))
    del train_docs, train_labels

    print(f'Reading and preprocessing test data {test}...\n')
    test_docs, test_labels, word_counter = read_tsv(bert_tokenizer, test, sentence_limit, word_limit)
    print(f'Updating word map based on test data {test}...\n')
    _bert_update_word_map(bert_tokenizer, word_map, word_counter)

    print('\nEncoding and padding test data...\n')
    _bert_save_split(output_folder, 'BERT_TEST_data.pth.tar', test_docs, test_labels, word_map,
                     sentence_limit, word_limit)
    print('Encoded, padded test data (BERT_TEST_data.pth.tar) saved to %s.\n' % os.path.abspath(output_folder))
    del test_docs, test_labels

    if valid:
        print(f'Reading and preprocessing validation data {valid}...\n')
        valid_docs, valid_labels, word_counter = read_tsv(bert_tokenizer, valid, sentence_limit, word_limit)
        print(f'Updating word map based on validation data {valid}...\n')
        _bert_update_word_map(bert_tokenizer, word_map, word_counter)

        print('\nEncoding and padding validation data...\n')
        _bert_save_split(output_folder, 'BERT_VALID_data.pth.tar', valid_docs, valid_labels, word_map,
                         sentence_limit, word_limit)
        print('Encoded, padded valid data (BERT_VALID_data.pth.tar) saved to %s.\n' % os.path.abspath(output_folder))
        del valid_docs, valid_labels

    print('\nThe size of the vocabulary is %d.\n' % (len(word_map)))
    with open(os.path.join(output_folder, 'BERT_word_map.json'), 'w') as j:
        json.dump(word_map, j)
    print('Word map saved to %s.\n' % os.path.abspath(output_folder))

    print('All done!\n')

In [None]:
def read_all_text_from_corpus_with_bert(word_tokenizer, tsv_file, sentence_limit, word_limit):
    """Read a tab-separated corpus and tokenise it with a BERT word-piece tokenizer.

    Each line goes through ``preprocess`` and is split on tabs into eight fields:
    an id, two text fields and five numeric labels. The two text fields are
    joined with ``<sep>`` and split into sentences by ``sent_tokenizer``. At most
    ``sentence_limit`` sentences are kept, each truncated to ``word_limit - 2``
    word pieces (leaving room for [CLS] and [SEP]). Sentences that tokenise to
    nothing are dropped, as are documents left without any sentence.

    Returns:
        docs: list of documents, each a list of sentences (lists of word pieces).
        labels: list of five-float label lists, aligned with ``docs``.
        word_counter: Counter of word-piece frequencies over the kept sentences.
    """
    docs, labels = [], []
    word_counter = Counter()
    max_pieces = word_limit - 2
    with io.open(tsv_file, 'r', encoding="utf-8") as corpus:
        for raw_line in tqdm(corpus):
            _pid, bh_text, ep_text, major_d, sc, bp, minor_d, de = preprocess(raw_line).strip().split('\t')
            sentences = list(sent_tokenizer("%s <sep> %s" % (bh_text, ep_text)))

            tokenised = []
            for sentence in sentences[:sentence_limit]:
                pieces = word_tokenizer.tokenize(sentence)[:max_pieces]
                if not pieces:
                    continue
                tokenised.append(pieces)
                word_counter.update(pieces)

            if not tokenised:
                continue

            labels.append([float(value) for value in (major_d, sc, bp, minor_d, de)])
            docs.append(tokenised)
    return docs, labels, word_counter

In [None]:
checkpoint = None

# `create_input_files_for_bert` writes the map as 'BERT_word_map.json'; use the same
# spelling so loading also works on case-sensitive filesystems.
with open(os.path.join(DATA_FOLDER, 'BERT_word_map.json'), 'r') as j:
    word_map = json.load(j)

# Precomputed BERT embeddings, presumably for the training sentences (per the file
# name), keyed by the space-joined id string of each padded sentence row.
han_bert_cache = torch.load('han_cache_train.pt')

In [None]:
class BERTHierarchialAttentionNetwork(nn.Module):
    """Hierarchical attention network whose word encoder is fed BERT embeddings.

    ``BERTSentenceAttention`` (which contains the BERT-backed word encoder) turns
    each document into a ``2 * sentence_rnn_size`` vector; after dropout a linear
    layer maps it to ``n_classes`` logits.

    ``han_bert_cache`` (dict of precomputed embeddings) is also handed to the word
    encoder; it is mirrored here so callers can swap it on both objects.
    """

    def __init__(self, bert, n_classes, vocab_size, word_rnn_size, sentence_rnn_size, word_rnn_layers,
                 sentence_rnn_layers, word_att_size, sentence_att_size, han_bert_cache, dropout=0.5):
        super().__init__()

        # Attribute names / registration order are unchanged so existing checkpoints still match.
        self.sentence_attention = BERTSentenceAttention(
            bert=bert,
            vocab_size=vocab_size,
            word_rnn_size=word_rnn_size,
            sentence_rnn_size=sentence_rnn_size,
            word_rnn_layers=word_rnn_layers,
            sentence_rnn_layers=sentence_rnn_layers,
            word_att_size=word_att_size,
            sentence_att_size=sentence_att_size,
            dropout=dropout,
            han_bert_cache=han_bert_cache,
        )
        self.fc = nn.Linear(2 * sentence_rnn_size, n_classes)
        self.han_bert_cache = han_bert_cache
        self.dropout = nn.Dropout(dropout)

    def forward(self, documents, sentences_per_document, words_per_sentence):
        """Return (logits, word_alphas, sentence_alphas) for a batch of encoded documents."""
        doc_vectors, word_alphas, sentence_alphas = self.sentence_attention(
            documents, sentences_per_document, words_per_sentence)
        logits = self.fc(self.dropout(doc_vectors))  # (n_documents, n_classes)
        return logits, word_alphas, sentence_alphas


class BERTSentenceAttention(nn.Module):
    """Sentence-level encoder and attention of the BERT-backed HAN.

    Real (non-padding) sentences are encoded by ``BERTWordAttention``; the
    sentence vectors go through a bidirectional GRU, are scored with additive
    attention (linear -> tanh -> context vector) and the attention-weighted sum
    gives one vector per document.
    """

    def __init__(self, bert, vocab_size, word_rnn_size, sentence_rnn_size, word_rnn_layers, sentence_rnn_layers,
                 word_att_size, sentence_att_size, dropout, han_bert_cache):
        super().__init__()

        self.word_attention = BERTWordAttention(
            bert=bert,
            vocab_size=vocab_size,
            word_rnn_size=word_rnn_size,
            word_rnn_layers=word_rnn_layers,
            word_att_size=word_att_size,
            dropout=dropout,
            han_bert_cache=han_bert_cache,
        )
        self.sentence_rnn = nn.GRU(2 * word_rnn_size, sentence_rnn_size, num_layers=sentence_rnn_layers,
                                   bidirectional=True, dropout=dropout, batch_first=True)
        self.sentence_attention = nn.Linear(2 * sentence_rnn_size, sentence_att_size)
        self.sentence_context_vector = nn.Linear(sentence_att_size, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _repack(data, template):
        """Wrap ``data`` in a PackedSequence sharing ``template``'s batch layout."""
        return PackedSequence(data=data,
                              batch_sizes=template.batch_sizes,
                              sorted_indices=template.sorted_indices,
                              unsorted_indices=template.unsorted_indices)

    def forward(self, documents, sentences_per_document, words_per_sentence):
        """Return (document vectors, padded word alphas, sentence alphas) for a batch."""
        doc_lengths = sentences_per_document.tolist()

        # Flatten the batch into its real sentences (and their word counts).
        packed_docs = pack_padded_sequence(documents, lengths=doc_lengths,
                                           batch_first=True, enforce_sorted=False)
        packed_word_counts = pack_padded_sequence(words_per_sentence, lengths=doc_lengths,
                                                  batch_first=True, enforce_sorted=False)

        sentence_vectors, word_alphas = self.word_attention(packed_docs.data, packed_word_counts.data)
        sentence_vectors = self.dropout(sentence_vectors)

        rnn_out, _ = self.sentence_rnn(self._repack(sentence_vectors, packed_docs))

        scores = self.sentence_context_vector(torch.tanh(self.sentence_attention(rnn_out.data))).squeeze(1)
        # Shift by the global maximum before exponentiating to avoid overflow.
        scores = torch.exp(scores - scores.max())
        scores, _ = pad_packed_sequence(self._repack(scores, rnn_out), batch_first=True)
        sentence_alphas = scores / torch.sum(scores, dim=1, keepdim=True)

        padded_rnn_out, _ = pad_packed_sequence(rnn_out, batch_first=True)
        doc_vectors = (padded_rnn_out * sentence_alphas.unsqueeze(2)).sum(dim=1)

        word_alphas, _ = pad_packed_sequence(self._repack(word_alphas, rnn_out), batch_first=True)
        return doc_vectors, word_alphas, sentence_alphas


class BERTWordAttention(nn.Module):
    """Word-level encoder and attention of the BERT-backed HAN.

    Word vectors come from ``han_bert_cache`` when present, otherwise from
    ``bert`` (run without gradients, in eval mode). They pass through a
    bidirectional GRU, are scored with additive attention
    (linear -> tanh -> context vector) and the attention-weighted sum gives one
    vector per sentence.
    """

    def __init__(self, bert, vocab_size, word_rnn_size, word_rnn_layers, word_att_size, dropout, han_bert_cache):
        super(BERTWordAttention, self).__init__()
        # vocab_size is not used (embeddings come from BERT); kept for interface compatibility.
        self.bert = bert
        self.han_bert_cache = han_bert_cache
        self.bert.eval()
        self.embedding_dim = self.bert.config.to_dict()['hidden_size']

        self.word_rnn = nn.GRU(self.embedding_dim, word_rnn_size, num_layers=word_rnn_layers, bidirectional=True,
                               dropout=dropout, batch_first=True)

        self.word_attention = nn.Linear(2 * word_rnn_size, word_att_size)

        self.word_context_vector = nn.Linear(word_att_size, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def create_attention_masks(self, ids):
        """Return a float tensor on ``device``: 1.0 for real tokens, 0.0 for BERT [PAD] ids."""
        # Compare against BERT's own pad id, which is what the rows are padded with.
        attention_masks = [[float(token_id != BERT_PAD_IDX) for token_id in row] for row in ids]
        return torch.tensor(attention_masks).to(device)

    def _bert_embed_row(self, sents):
        """Compute BERT word embeddings for one padded id row, zero-padded to WORD_LIMIT rows.

        The row is cut after every [SEP] (BERT_EOS_IDX); each piece is fed to BERT
        as its own sequence, padded with BERT_PAD_IDX to WORD_LIMIT ids.
        NOTE(review): WORD_LIMIT must not exceed BERT_WORD_LIMIT for this to run.
        For every real token the sum of four hidden-state layers is kept; the
        pieces are concatenated and zero-padded to (WORD_LIMIT, embedding_dim).
        """
        sep_idxes = (sents == BERT_EOS_IDX).nonzero().squeeze(1).data.tolist()
        seq_lengths = []
        sents_ids = []
        start = 0
        for end in sep_idxes:
            length = end - start + 1
            segment = [BERT_PAD_IDX] * WORD_LIMIT
            segment[:length] = sents[start:end + 1].data.tolist()
            seq_lengths.append(length)
            sents_ids.append(segment)
            start = end + 1
        attention_masks = self.create_attention_masks(sents_ids)
        sents_ids = torch.tensor(sents_ids).to(device)

        # The result is stored in the cache permanently, so it must not depend on
        # dropout; model.train() would otherwise have put self.bert in training mode.
        was_training = self.bert.training
        self.bert.eval()
        try:
            with torch.no_grad():
                _, _, hidden_states = self.bert(sents_ids, attention_masks)
                # NOTE(review): hidden_states[:-1] drops the final layer, so the four layers
                # summed below are the 2nd- to 5th-from-last. Kept unchanged so new entries
                # agree with the precomputed cache -- confirm this is intended.
                token_embeddings = torch.stack(hidden_states[:-1], dim=0)
                # Assuming each layer is (seqs, tokens, hidden): -> (seqs, tokens, layers, hidden).
                token_embeddings = token_embeddings.permute(1, 2, 0, 3)
                pieces = []
                for seq_idx, tks in enumerate(token_embeddings):
                    real_tokens = tks[:seq_lengths[seq_idx]]
                    pieces.append(real_tokens[:, -1] + real_tokens[:, -2] + real_tokens[:, -3] + real_tokens[:, -4])
                sent_embeddings = torch.cat(pieces, 0)

                if sent_embeddings.shape[0] != WORD_LIMIT:
                    padding = torch.zeros(WORD_LIMIT - sent_embeddings.shape[0], self.embedding_dim).to(device)
                    sent_embeddings = torch.cat((sent_embeddings, padding), 0)
        finally:
            self.bert.train(was_training)
        return sent_embeddings

    def embeddings(self, batch):
        """Return word embeddings for a batch of padded id rows, stacked along dim 0.

        Each row is looked up in ``han_bert_cache`` by the space-joined string of
        its ids; misses are computed with BERT and added to the cache.
        """
        batch_embeddings = []
        for sents in batch:
            key = ' '.join(str(x) for x in sents.data.tolist())
            if key in self.han_bert_cache:
                sent_embeddings = self.han_bert_cache[key].to(device)
            else:
                print('Not Found')
                sent_embeddings = self._bert_embed_row(sents)
                self.han_bert_cache[key] = sent_embeddings
            batch_embeddings.append(sent_embeddings)
        return torch.stack(batch_embeddings, 0)

    def forward(self, sentences, words_per_sentence):
        """Return (sentence vectors, padded word alphas) for a flat batch of sentence rows."""
        sentences = self.dropout(self.embeddings(sentences))
        packed_words = pack_padded_sequence(sentences,
                                            lengths=words_per_sentence.tolist(),
                                            batch_first=True,
                                            enforce_sorted=False)

        packed_words, _ = self.word_rnn(packed_words)

        att_w = self.word_attention(packed_words.data)
        att_w = torch.tanh(att_w)
        att_w = self.word_context_vector(att_w).squeeze(1)

        # Shift by the global maximum before exponentiating to avoid overflow.
        max_value = att_w.max()
        att_w = torch.exp(att_w - max_value)

        att_w, _ = pad_packed_sequence(PackedSequence(data=att_w,
                                                      batch_sizes=packed_words.batch_sizes,
                                                      sorted_indices=packed_words.sorted_indices,
                                                      unsorted_indices=packed_words.unsorted_indices),
                                       batch_first=True)

        # Padding positions were filled with 0, so they get zero weight.
        word_alphas = att_w / torch.sum(att_w, dim=1, keepdim=True)
        sentences, _ = pad_packed_sequence(packed_words,
                                           batch_first=True)

        sentences = sentences * word_alphas.unsqueeze(2)
        sentences = sentences.sum(dim=1)

        return sentences, word_alphas

In [None]:
def _save_checkpoint_without_bert_cache(epoch, model, optimizer, word_map, file_name):
    """Call ``save_checkpoint`` with the BERT embedding cache temporarily detached.

    The cache is referenced from ``model`` and from
    ``model.sentence_attention.word_attention``. Both references are cleared for
    the save (presumably so the cache is not serialised with the model) and are
    restored in a ``finally`` block, so the model keeps its cache even if saving fails.
    """
    word_attention = model.sentence_attention.word_attention
    cache = model.han_bert_cache
    model.han_bert_cache = word_attention.han_bert_cache = None
    try:
        save_checkpoint(epoch, model, optimizer, word_map, DATA_FOLDER, file_name)
    finally:
        model.han_bert_cache = word_attention.han_bert_cache = cache


def bert_train_epoch(start_epoch, epochs, data_loader, model, criterion, optimizer, word_map, model_name, grad_clip,
                     valid_iterator=None, interval=1, early_stop=False, period=20, gap=0.005, threshold=0.5,
                     best_valid_fscore=0, best_valid_loss=float('inf')):
    """Train ``model`` for epochs ``start_epoch .. epochs - 1`` and checkpoint it.

    Each epoch calls ``train`` on ``data_loader`` and, when ``valid_iterator`` is
    given, ``evaluate`` on it. Checkpoints written to DATA_FOLDER:
      * ``<model_name>_loss.pt``    when the validation loss improves
        (without validation the loss is taken as 0),
      * ``<model_name>_fscore.pt``  when the validation micro-F-score improves,
      * ``<model_name>_current.pt`` after every completed epoch,
      * ``<model_name>_except.pt``  when a RuntimeError interrupts an epoch;
        the function then returns early.
    The BERT embedding cache is detached for every save and always reattached.

    Early stopping (``early_stop=True``; needs validation): once the best
    F-score exceeds ``threshold``, every epoch scoring more than ``gap`` below it
    is counted, and ``period`` counted epochs stop training. A new best resets
    the count. Stats are printed every ``interval`` epochs and on the last one.

    ``best_valid_fscore`` / ``best_valid_loss`` let a resumed run keep the
    previous bests so their checkpoints are only overwritten by better models.

    Returns:
        (train_losses, valid_losses, train_accs, valid_accs) lists.
    """
    train_losses, valid_losses = [], []
    train_accs, valid_accs = [], []
    observed_time = 0

    for epoch in range(start_epoch, epochs):
        start_time = time.time()
        try:
            train_loss, train_acc = train(train_loader=data_loader, model=model,
                                          criterion=criterion, optimizer=optimizer, grad_clip=grad_clip)
            train_losses.append(train_loss)
            train_accs.append(train_acc)

            # Epoch time covers training only, not validation.
            epoch_mins, epoch_secs = epoch_time(start_time, time.time())

            if valid_iterator:
                valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
                valid_losses.append(valid_loss)
                valid_accs.append(valid_acc)
            else:
                valid_loss, valid_acc = 0, None

            if (epoch + 1) % interval == 0 or epoch == epochs - 1:
                print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
                print(f'\tTrain Loss: {train_loss:.3f} | Train micro-F-score: {train_acc*100:.2f}%')
                if valid_iterator:
                    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. micro-F-score: {valid_acc*100:.2f}%')

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                _save_checkpoint_without_bert_cache(epoch, model, optimizer, word_map, model_name + '_loss.pt')

            # F-score bookkeeping and early stopping need a validation score.
            if valid_iterator:
                if early_stop and best_valid_fscore > threshold and best_valid_fscore - valid_acc > gap:
                    observed_time += 1
                    print(f'\rBest validation F-measure: {best_valid_fscore:.3f}/Current F-measure: {valid_acc:.3f} [Times: {observed_time}/{period}]')
                    if observed_time >= period:
                        print(f'Early stop at epoch {epoch+1:02}.')
                        break
                if valid_acc > best_valid_fscore:
                    best_valid_fscore = valid_acc
                    _save_checkpoint_without_bert_cache(epoch, model, optimizer, word_map, model_name + '_fscore.pt')
                    observed_time = 0
        except RuntimeError as e:
            print(f'Runtime error at epoch {epoch+1:02}: {str(e)}')
            _save_checkpoint_without_bert_cache(epoch, model, optimizer, word_map, model_name + '_except.pt')
            return train_losses, valid_losses, train_accs, valid_accs

        _save_checkpoint_without_bert_cache(epoch, model, optimizer, word_map, model_name + '_current.pt')

    return train_losses, valid_losses, train_accs, valid_accs

In [None]:
# Reset every RNG so loader shuffling and weight initialisation are reproducible.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)
torch.backends.cudnn.deterministic = True

# DataLoaders over the pre-encoded BERT splits.
loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=True, num_workers=WORKER, pin_memory=True)
train_loader = DataLoader(NTUH_HANDataset(DATA_FOLDER, 'BERT_TRAIN'), **loader_kwargs)
valid_loader = DataLoader(NTUH_HANDataset(DATA_FOLDER, 'BERT_VALID'), **loader_kwargs)

model = BERTHierarchialAttentionNetwork(
    bert=bert,
    n_classes=n_classes,
    vocab_size=len(word_map),
    word_rnn_size=WORD_RNN_SIZE,
    sentence_rnn_size=SENTENCE_RNN_SIZE,
    word_rnn_layers=WORD_RNN_LAYERS,
    sentence_rnn_layers=SENTENCE_RNN_LAYERS,
    word_att_size=WORD_ATTENTION_SIZE,
    sentence_att_size=SENTENCE_ATTENTION_SIZE,
    han_bert_cache=han_bert_cache,
    dropout=DROPOUT,
)
print(f'The model has {count_parameters(model):,} trainable parameters')
model

In [None]:
# Re-seed every RNG and rebuild the train/validation loaders with the current BATCH_SIZE.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

train_loader, valid_loader = (
    DataLoader(NTUH_HANDataset(DATA_FOLDER, split_name), batch_size=BATCH_SIZE, shuffle=True,
               num_workers=WORKER, pin_memory=True)
    for split_name in ('BERT_TRAIN', 'BERT_VALID')
)

In [None]:
# Resume from the latest BERT-HAN checkpoint. The embedding cache is detached
# before every save, so reattach it to both modules that read it.
checkpoint = torch.load(os.path.join(DATA_FOLDER, 'bert_current.pt'))
model = checkpoint['model']
model.han_bert_cache = han_bert_cache
model.sentence_attention.word_attention.han_bert_cache = han_bert_cache

In [None]:
# Resume training of the BERT-backed HAN.
N_EPOCHS = 500
# NOTE(review): train_loader / valid_loader were already built with the previous
# BATCH_SIZE, so this assignment does not change the training batch size; it only
# affects loaders created afterwards (e.g. test_loader). Confirm this is intended.
BATCH_SIZE = 12
# NOTE(review): a fresh Adam optimizer is created here; any optimizer state stored
# in the loaded checkpoint is not restored. The filter also keeps the BERT
# parameters (they receive no gradients because BERT runs under no_grad).
optimizer = optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()))
# Per-class pos_weight up-weights positive examples in the multi-label BCE loss.
criterion = nn.BCEWithLogitsLoss(pos_weight=POS_WEIGHT)
model = model.to(device)
criterion = criterion.to(device)
# Resume bookkeeping from the previous run: the epoch to restart at (previously 89)
# and the best validation F-score so far, so '_fscore.pt' is only overwritten by a better model.
start_epoch = 113 #89
best_valid_fscore = 0.5258
# Early stop after this many epochs scoring more than `gap` below the best F-score.
period = 6
grad_clip = None
train_losses, valid_losses, train_accs, valid_accs = \
    bert_train_epoch(start_epoch, N_EPOCHS, train_loader, model, criterion, optimizer, word_map, 
                'bert', grad_clip, valid_loader, early_stop = True, period = period, best_valid_fscore = best_valid_fscore)

In [None]:
# Swap in the precomputed BERT embeddings for the test sentences.
han_bert_cache_test = torch.load('han_cache_test.pt')
model.han_bert_cache = han_bert_cache_test
model.sentence_attention.word_attention.han_bert_cache = han_bert_cache_test

In [None]:
test_loader = DataLoader(NTUH_HANDataset(DATA_FOLDER, 'bert_test'), batch_size=BATCH_SIZE, shuffle=False,
                                          num_workers=WORKER, pin_memory=True)

test_f_scores, predicts = test(model, test_loader, 0)

# Keys are either an averaging mode (MICRO / MACRO) or a per-class index.
# Compare by equality: identity (`is`) of str objects is an interpreter detail.
for f in test_f_scores:
    if f in (MICRO, MACRO):
        print(f'{f}-average:\n\tprecision: {test_f_scores[f]["p"]:0.3f}\n\trecall: {test_f_scores[f]["r"]:0.3f}\n\tf-score: {test_f_scores[f]["f"]:0.3f}\n')
    else:
        print(f'{NTUH_HANDataset.diagnosis_types[f]}:\n\tprecision: {test_f_scores[f]["p"]:0.3f}\n\trecall: {test_f_scores[f]["r"]:0.3f}\n\tf-score: {test_f_scores[f]["f"]:0.3f}\n')


In [None]:
checkpoint_path = os.path.join(DATA_FOLDER, 'bert_fscore.pt')

# Best-F-score BERT-HAN checkpoint; reattach the test embedding cache that was
# detached before saving.
checkpoint = torch.load(checkpoint_path)
model = checkpoint['model']
model.han_bert_cache = model.sentence_attention.word_attention.han_bert_cache = han_bert_cache_test
test_f_scores, predicts = test(model, test_loader, 0)

# Keys are either an averaging mode (MICRO / MACRO) or a per-class index.
# Compare by equality: identity (`is`) of str objects is an interpreter detail.
for f in test_f_scores:
    if f in (MICRO, MACRO):
        print(f'{f}-average:\n\tprecision: {test_f_scores[f]["p"]:0.3f}\n\trecall: {test_f_scores[f]["r"]:0.3f}\n\tf-score: {test_f_scores[f]["f"]:0.3f}\n')
    else:
        print(f'{NTUH_HANDataset.diagnosis_types[f]}:\n\tprecision: {test_f_scores[f]["p"]:0.3f}\n\trecall: {test_f_scores[f]["r"]:0.3f}\n\tf-score: {test_f_scores[f]["f"]:0.3f}\n')
                                            

# Results Analysis

In [None]:
def classify(document):
    """Run the global ``model`` on one raw corpus line.

    The line goes through ``preprocess`` and is split on tabs into eight fields
    (the label fields are ignored); the two text fields are joined with
    ``<sep>``, split into sentences by ``sent_tokenizer`` and into words by
    ``word_tokenizer``. At most ``sentence_limit`` non-empty sentences of at most
    ``word_limit`` words are kept, encoded with ``word_map`` (unknown words map
    to ``word_map['<unk>']``) and zero-padded.

    NOTE(review): relies on lower-case globals ``sentence_limit`` / ``word_limit``
    (not SENTENCE_LIMIT / WORD_LIMIT) defined elsewhere. It also encodes in the
    GloVe-HAN format (``'<unk>'`` key, pad 0, no [CLS]/[SEP]); with the BERT
    ``word_map`` loaded above, ``'<unk>'`` is presumably missing -- confirm
    which model this is meant for.

    Returns:
        doc: kept sentences as lists of words.
        scores: (n_classes,) logits.
        word_alphas: (n_sentences, max_sent_len_in_document) word attention weights.
        sentence_alphas: (n_sentences,) sentence attention weights.
        words_in_each_sentence: (n_sentences,) sentence lengths.
    """
    # Inference: disable dropout and autograd, otherwise scores and attention
    # weights (and hence the visualisation) vary between calls.
    model.eval()

    _pid, bh_text, ep_text, _major_d, _sc, _bp, _minor_d, _de = preprocess(document).strip().split('\t')
    sentences = list(sent_tokenizer("%s <sep> %s" % (bh_text, ep_text)))

    doc = []
    for sentence in sentences[:sentence_limit]:
        words = word_tokenizer(sentence)[:word_limit]
        if len(words) == 0:
            continue
        doc.append(words)

    unk_idx = word_map['<unk>']
    encoded_doc = [[word_map.get(w, unk_idx) for w in s] + [0] * (word_limit - len(s)) for s in doc]
    encoded_doc += [[0] * word_limit] * (sentence_limit - len(doc))

    sentences_in_doc = torch.LongTensor([len(doc)]).to(device)  # (1)
    words_in_each_sentence = torch.LongTensor([len(s) for s in doc]).unsqueeze(0).to(device)  # (1, n_sentences)
    encoded_doc = torch.LongTensor(encoded_doc).unsqueeze(0).to(device)  # (1, sentence_limit, word_limit)

    with torch.no_grad():
        # (1, n_classes), (1, n_sentences, max_sent_len_in_document), (1, n_sentences)
        scores, word_alphas, sentence_alphas = model(encoded_doc, sentences_in_doc, words_in_each_sentence)

    word_alphas = word_alphas.squeeze(0)  # (n_sentences, max_sent_len_in_document)
    sentence_alphas = sentence_alphas.squeeze(0)  # (n_sentences)
    words_in_each_sentence = words_in_each_sentence.squeeze(0)  # (n_sentences)

    return doc, scores.squeeze(0), word_alphas, sentence_alphas, words_in_each_sentence


def visualize_attention(doc, scores, word_alphas, sentence_alphas, words_in_each_sentence):
    """Draw ``doc`` with its attention weights and show it with matplotlib.

    Each word's font size and colour intensity grow with its combined weight
    (sentence weight x word weight, rescaled by sentence length relative to the
    longest sentence); a bar left of each sentence shows the sentence weight.
    Categories whose sigmoid score rounds to 1 are listed at the top, one per
    line, with their raw score (logit).

    The arguments are the outputs of ``classify``: the tokenised sentences,
    (n_classes,) logits, (n_sentences, max_len) word weights, (n_sentences,)
    sentence weights and (n_sentences,) sentence lengths.

    Requires the font file ``./calibril.ttf`` and the global ``rev_label_map``
    (class index -> name).
    """
    font_path = "./calibril.ttf"
    fonts = {}

    def font_of(size):
        # Load each font size once instead of once per word.
        if size not in fonts:
            fonts[size] = ImageFont.truetype(font_path, size)
        return fonts[size]

    def text_size(font, text):
        # FreeTypeFont.getsize() was removed in Pillow 10; fall back to getlength/getbbox there.
        if hasattr(font, 'getsize'):
            return font.getsize(text)
        return int(round(font.getlength(text))), font.getbbox(text)[3]

    # Predicted categories.
    rounded_preds = torch.round(torch.sigmoid(scores)).squeeze(0)  # (n_classes)
    score_list = scores.tolist()
    prediction = ''
    for i, pred in enumerate(rounded_preds):
        if pred == 1:
            prediction += '{category} ({score:.2f})\n'.format(category=rev_label_map[i], score=score_list[i])

    alphas = (sentence_alphas.unsqueeze(1) * word_alphas * words_in_each_sentence.unsqueeze(
        1).float() / words_in_each_sentence.max().float())
    alphas = alphas.to('cpu')

    min_font_size = 15
    max_font_size = 55
    space_size = text_size(font_of(max_font_size), ' ')
    line_spacing = 15
    left_buffer = 300  # room on the left for the sentence-weight bars
    top_buffer = 2 * min_font_size + 3 * line_spacing  # room for the category header
    image_width = left_buffer
    image_height = top_buffer + line_spacing
    word_loc = [image_width, image_height]
    rectangle_height = 0.75 * max_font_size
    max_rectangle_width = 0.8 * left_buffer
    rectangle_loc = [0.9 * left_buffer,
                     image_height + rectangle_height]
    word_viz_properties = []
    sentence_viz_properties = []
    max_sentence_alpha = sentence_alphas.max().item()
    max_alpha = alphas.max().item()
    for s, sentence in enumerate(doc):
        sentence_factor = sentence_alphas[s].item() / max_sentence_alpha
        rectangle_saturation = str(int(sentence_factor * 100))
        rectangle_lightness = str(25 + 50 - int(sentence_factor * 50))
        rectangle_color = 'hsl(0,' + rectangle_saturation + '%,' + rectangle_lightness + '%)'
        rectangle_bounds = [rectangle_loc[0] - sentence_factor * max_rectangle_width,
                            rectangle_loc[1] - rectangle_height] + rectangle_loc

        sentence_viz_properties.append({'bounds': rectangle_bounds.copy(),
                                        'color': rectangle_color})

        for w, word in enumerate(sentence):
            word_factor = alphas[s, w].item() / max_alpha

            word_saturation = str(int(word_factor * 100))
            word_lightness = str(25 + 50 - int(word_factor * 50))
            word_color = 'hsl(0,' + word_saturation + '%,' + word_lightness + '%)'

            word_font_size = int(min_font_size + word_factor * (max_font_size - min_font_size))
            word_font = font_of(word_font_size)

            word_viz_properties.append({'loc': word_loc.copy(),
                                        'word': word,
                                        'font': word_font,
                                        'color': word_color})

            word_size = text_size(word_font, word)
            word_loc[0] += word_size[0] + space_size[0]
            image_width = max(image_width, word_loc[0])
        word_loc[0] = left_buffer
        word_loc[1] += max_font_size + line_spacing
        image_height = max(image_height, word_loc[1])
        rectangle_loc[1] += max_font_size + line_spacing

    img = Image.new('RGB', (image_width, image_height), (255, 255, 255))

    draw = ImageDraw.Draw(img)
    for viz in word_viz_properties:
        draw.text(xy=viz['loc'], text=viz['word'], fill=viz['color'], font=viz['font'])
    for viz in sentence_viz_properties:
        draw.rectangle(xy=viz['bounds'], fill=viz['color'])
    category_font = font_of(min_font_size)
    draw.text(xy=[line_spacing, line_spacing], text='Detected Category:', fill='grey', font=category_font)
    draw.text(xy=[line_spacing, line_spacing + text_size(category_font, 'Detected Category:')[1] + line_spacing],
              text=prediction.upper(), fill='black',
              font=category_font)
    del draw

    fig, ax = plt.subplots(figsize=(15, 15))  # 15 x 15 inch figure
    ax.imshow(np.asarray(img), interpolation='none')
    plt.show()

In [None]:
# Keep the raw test lines (one document per line) for the attention examples below.
with io.open(TEST, 'r', encoding="utf-8") as test_file:
    test_lines = list(tqdm(test_file))
test_lines[3]

In [None]:
# Highest logit of the first test document and the index of its class.
torch.max(classify(test_lines[0])[1], dim=0)

In [None]:
# Attention visualisation for the fifth test document.
example_outputs = classify(test_lines[4])
visualize_attention(*example_outputs)