In [1]:
!pip install torchinfo

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from torchinfo import summary
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator, GloVe, vocab
from tqdm import tqdm
from nltk.tokenize import sent_tokenize, word_tokenize
from collections import OrderedDict



# 1. Hierarchical Attention Network

## Data-Related Classes and Functions

In [2]:
class DataPreprocessorHcl():

    def __init__(self, num_classes, data_vocab):
        self.num_classes = num_classes
        self.vocab = data_vocab
        print("Vocab created: {} unique tokens".format(len(self.vocab)))
    
    @classmethod
    def from_pretrained_embeds(cls, num_classes, embed_path, embed_dim, sep=" ",  specials=['<unk>']):
        # start with all '0's for special tokens
        embeds = [np.asarray([0]*embed_dim, dtype=np.float32)]*len(specials)
        words = OrderedDict()
        with open(embed_path, encoding="utf-8") as f:
            for i, line in enumerate(f):
                #twitter 27B not used, ignore this if block
                if i == 38522 and 'twitter.27B.100d' in embed_path:
                    continue
                splitline = line.split()
                
                word = splitline[0]
                if word not in words:
                    words[word] = 0
                words[word]+=1
                embeds.append(np.asarray(splitline[1:], dtype=np.float32))
                
        embeds = torch.tensor(np.array(embeds))
        data_vocab = vocab(words, specials=specials)
        data_vocab.set_default_index(data_vocab['<unk>'])
        return cls(num_classes, data_vocab), embeds

    @classmethod
    def __yield_tokens(cls, df):
        for row in df.itertuples(index=False):
            yield word_tokenize(row.Text.lower())

    def get_vocab_size(self):
        return len(self.vocab)
    
    def preprocess_data(self, df, max_sent_len, max_num_sents, clean=True):
        '''
        Converts text into integers that index the vocab, and labels into the range [0,num_classes-1].
        Also calculates the number of sentences in each text and the length (number of tokens) of each sentence.
        
        Returns
            X: an array of size (N, max_sent_len, max_num_sents) containing the indices and padding
            ylens: a df where each row has the processed label, number of sentences and a list of sentence lengths
        '''
        if clean:
            X = df['Text'].apply(lambda t: [self.vocab(word_tokenize(s.lower())) for s in sent_tokenize(t.replace("'",""))])
        else:
            X = df['Text'].apply(lambda t: [self.vocab(word_tokenize(s.lower())) for s in sent_tokenize(t)])
        num_sentences = X.apply(lambda sentences : min(max_num_sents, len(sentences)))
        num_sentences.name = 'Num_Sentences'
        num_tokens = X.apply(lambda sentences : list(map(lambda s: min(max_sent_len, len(s)), sentences))[:max_num_sents])
        num_tokens = num_tokens.apply(lambda num_ls : num_ls + [0 for _ in range(max_num_sents-len(num_ls))]) #padding
        num_tokens.name = 'Num_Tokens'
        
        X_padded = np.zeros((len(X), max_num_sents, max_sent_len), dtype='int32')
        for i, sentences in X.items():
            for j, sent in enumerate(sentences):
                if j >= max_num_sents:
                    break
                k = min(max_sent_len, len(sent))
                X_padded[i,j,:k] = sent[:k]
                
        y = df['Label'].apply(lambda l: l-1)
        return X_padded, pd.concat([y, num_sentences, num_tokens], axis=1)

class WrapperDatasetHcl(Dataset):
    '''
    Wrapper for use with dataloader
    '''
    def __init__(self, X, y_and_lens):
        self.X = X
        self.y = y_and_lens['Label']
        self.num_sents = y_and_lens['Num_Sentences']
        self.num_tokens = y_and_lens['Num_Tokens']
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y.iloc[idx], self.num_sents.iloc[idx],  self.num_tokens.iloc[idx]

def collate_fnHcl(batch):
    X = []
    y = []
    num_sent = []
    sent_len = []
    for row in batch:
        X.append(row[0])
        y.append(row[1])
        num_sent.append(row[2])
        sent_len.append(row[3])
    X = torch.tensor(X, dtype=torch.long)
    y = torch.tensor(y, dtype=torch.long)
    num_sent = torch.tensor(num_sent, dtype=torch.long)
    sent_len = torch.tensor(sent_len, dtype=torch.long)
    return X, y, num_sent, sent_len
                 


## Trainer Class

In [52]:
class Trainer():
    def __init__(self, model, train_loader, val_loader, num_epochs, lr, weight_decay = 0,
               lr_anneal_factor=None, lr_anneal_patience=None,
               loss_weights=None,
               save_loss_acc_plots=True):
        #Data
        self.train_loader = train_loader
        self.val_loader = val_loader

        #Model
        self.model=model.to(DEVICE)

        #Training
        self.num_epochs = num_epochs
        self.loss_fn = nn.CrossEntropyLoss()
        if loss_weights is not None:
            self.loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(loss_weights, dtype=torch.float32, device=DEVICE))
        self.optimizer = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        self.to_anneal_lr = lr_anneal_factor and lr_anneal_patience
        if self.to_anneal_lr:
            self.scheduler = ReduceLROnPlateau(self.optimizer,
                                             factor=lr_anneal_factor,
                                             patience=lr_anneal_patience)

    def train(self, save_folder_root_path, file_name_root,
            plot_loss_acc=True, verbose=True):
        #account for pytorch versions < 2.2
        if hasattr(self.scheduler, 'verbose'):
            self.scheduler.verbose=verbose
            
        #Set up logging
        self.__set_up_logging(save_folder_root_path, file_name_root, plot_loss_acc)
        
        best_val_loss = float('inf')
        for i in range(1, self.num_epochs+1):
            self.model.train()
            train_loss = []
            preds = []
            truths = []
            for X, y, num_sent, sent_len in tqdm(self.train_loader, disable=not verbose):
                #Move to correct device
                X = X.to(DEVICE)
                y = y.to(DEVICE)

                #Forward pass
                outputs = self.model(X, num_sent, sent_len, return_attn_weights=False)
                loss = self.loss_fn(outputs, y)

                #Backprop
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                #Logging purposes
                train_loss.append(loss.item())
                preds.append(torch.argmax(outputs, dim=-1).cpu())
                truths.append(y.cpu())

            #Validation
            self.model.eval()
            val_loss, val_acc, val_f1 = self.validate()

            #Logging
            train_loss, train_acc, train_f1 = self.__log(train_loss, preds, truths, val_loss, val_acc, val_f1)
            tqdm.write("Epoch {} Complete:\n Train: loss={}, acc={}, F1={}\n Val  : loss={}, acc={}, F1={}\n".format(
                        i, train_loss, train_acc, train_f1, val_loss, val_acc, val_f1))
            
            #Save model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(self.model.state_dict(),
                           self.model_folder + file_name_root + '_model.pt')
            
            #LR Annealing
            if self.to_anneal_lr:
                #account for pytorch versions >= 2.2
                if not hasattr(self.scheduler, 'verbose') and i > 1:
                    last_lr = self.scheduler.get_last_lr()
                    
                self.scheduler.step(val_loss)
                
                #account for pytorch versions >= 2.2
                if not hasattr(self.scheduler, 'verbose') and i > 1 and last_lr != self.scheduler.get_last_lr():
                    print("LR Reduced from {} to {} for next epoch onwards".format(last_lr, self.scheduler.get_last_lr()))
                    
        #Save plots
        if self.plot_loss_acc:
            self.plot_metrics()

        return self.model_folder + file_name_root + '_model.pt', self.plots_folder + file_name_root + '_plot.png'

    def validate(self):
        with torch.no_grad():
            losses = []
            preds = []
            truths = []
            for X, y, num_sent, sent_len in self.val_loader:
                #Move to correct device
                X = X.to(DEVICE)
                y = y.to(DEVICE)

                #Forward pass
                outputs = self.model(X, num_sent, sent_len, return_attn_weights=False)
                loss = self.loss_fn(outputs, y)

                #Logging
                losses.append(loss.item())
                preds.append(torch.argmax(outputs, dim=-1).cpu())
                truths.append(y.cpu())
        preds = torch.cat(preds)
        truths = torch.cat(truths)
        return sum(losses)/len(losses), accuracy_score(truths, preds), f1_score(truths, preds, average='macro')

    def plot_metrics(self):
        fig, axs = plt.subplots(1, 3, figsize=(15,5))
        axs[0].plot(range(1, self.num_epochs + 1), self.train_metrics['loss'], color='b', label='Train')
        axs[0].plot(range(1, self.num_epochs + 1), self.val_metrics['loss'], color='r', label='Validation')
        axs[0].set_ylabel('Loss')
        axs[0].set_xlabel('Epochs')
        axs[0].set_ylim(bottom=0)
        axs[0].set_xticks(range(0, self.num_epochs + 1, 2))
        axs[0].grid(visible=True, which='major', axis='both')

        axs[1].plot(range(1, self.num_epochs + 1), self.train_metrics['accuracy'], color='b', label='Train')
        axs[1].plot(range(1, self.num_epochs + 1), self.val_metrics['accuracy'], color='r', label='Validation')
        axs[1].set_ylabel('Accuracy')
        axs[1].set_xlabel('Epochs')
        axs[1].set_ylim(bottom=0)
        axs[1].set_xticks(range(0, self.num_epochs + 1, 2))
        axs[1].grid(visible=True, which='major', axis='both')

        axs[2].plot(range(1, self.num_epochs + 1), self.train_metrics['f1'], color='b', label='Train')
        axs[2].plot(range(1, self.num_epochs + 1), self.val_metrics['f1'], color='r', label='Validation')
        axs[2].set_ylabel('F1 (Macro-averaged)')
        axs[2].set_xlabel('Epochs')
        axs[2].set_ylim(bottom=0)
        axs[2].set_xticks(range(0, self.num_epochs + 1, 2))
        axs[2].grid(visible=True, which='major', axis='both')

        fig.legend(*axs[2].get_legend_handles_labels(), loc='upper center')
        fig.savefig(self.plots_folder + self.file_name_root + '_plot.png')
        plt.show()
        plt.close()

    def __set_up_logging(self, save_folder_root_path, file_name_root, plot_loss_acc):
        self.model_folder = save_folder_root_path+'model/'
        self.plots_folder = save_folder_root_path+'plots/'
        self.file_name_root = file_name_root
        self.plot_loss_acc = plot_loss_acc
        if not os.path.exists(self.model_folder):
            os.makedirs(self.model_folder)
        if self.plot_loss_acc and not os.path.exists(self.plots_folder):
            os.makedirs(self.plots_folder)
        self.train_metrics = {'loss':[], 'accuracy':[], 'f1':[]}
        self.val_metrics = {'loss':[], 'accuracy':[], 'f1':[]}
        return
    
    def __log(self, train_loss, train_preds, train_truths, val_loss, val_acc, val_f1):
        train_loss = sum(train_loss)/len(train_loss)
        preds = torch.cat(train_preds)
        truths = torch.cat(train_truths)
        train_acc = accuracy_score(truths, preds)
        train_f1 = f1_score(truths, preds, average='macro')

        self.train_metrics['loss'].append(train_loss)
        self.train_metrics['accuracy'].append(train_acc)
        self.train_metrics['f1'].append(train_f1)
        self.val_metrics['loss'].append(val_loss)
        self.val_metrics['accuracy'].append(val_acc)
        self.val_metrics['f1'].append(val_f1)

        return train_loss, train_acc, train_f1

## Model Architecture

In [44]:
class AttentionUnit(nn.Module):
    def __init__(self, input_dim,, num_outputs=1, attn_dropout=0.0):
        super(AttentionUnit, self).__init__()
        self.hidden = nn.Linear(input_dim, input_dim)
        self.query = nn.Linear(input_dim, num_outputs, bias=False)
        
    def forward(self, encoder_output, padding_positions=None, return_weights=False):
        # [B,L,H]-->[B,L,H]
        hidden_rep = F.tanh(self.hidden(encoder_output))
        
        # [B,L,H]-->[B,L,1]
        similarity = self.query(hidden_rep)
        if padding_positions is not None:
            similarity = similarity.masked_fill(padding_positions, -float('inf'))
        attention_weights = F.softmax(similarity, dim=1)
        
        #Return weighted sum [B,L,1], [B,L,H]-->[B,H]
        if return_weights:
            return torch.bmm(attention_weights.transpose(1,2), hidden_rep).squeeze(1), attention_weights
        return torch.bmm(attention_weights.transpose(1,2), hidden_rep).squeeze(1)

class BiLSTMHeAttFCNNClassifier(nn.Module):
    '''
    Classifier that uses heirarchical attention to encode a document and 
    a Fully-Connected Neural Network(FCNN) as a decoder.

    '''
    def __init__(self, vocab_len, embed_dim, hidden_dim, num_lstm_layers, num_classes, attn_dropout=0.0, pretrained_embeddings=None, freeze_embeds=False):
        super(BiLSTMHeAttFCNNClassifier, self).__init__()
        if pretrained_embeddings is not None:
            self.embedding = nn.Embedding.from_pretrained(pretrained_embeddings, freeze=freeze_embeds)
        else:
            self.embedding = nn.Embedding(num_embeddings=vocab_len, embedding_dim=embed_dim)
        
        self.word_encoder = nn.LSTM(input_size=embed_dim, hidden_size=hidden_dim, num_layers=num_lstm_layers, batch_first=True, bidirectional=True)
        self.word_attn = AttentionUnit(2*hidden_dim)
        
        self.sent_encoder = nn.LSTM(input_size=2*hidden_dim, hidden_size=hidden_dim, num_layers=num_lstm_layers, batch_first=True, bidirectional=True)
        self.sent_attn = AttentionUnit(2*hidden_dim)
        
        self.decoder = nn.Linear(2*hidden_dim, num_classes)


    def forward(self, X_batch, num_sents, sent_lens, return_attn_weights=False):
        max_sent_len = X_batch.shape[2]
        max_num_sent = X_batch.shape[1]
        
        # Use word embeddings to form sentence embeddings
        word_attn_weights = []
        docs = []
        for doc, n, lens in zip(X_batch, num_sents, sent_lens):
            words_batch = doc[:n]
            embeddings = self.embedding(words_batch)
            output, (_, _) = self.word_encoder(embeddings)
            padding_positions = self.__get_padding_masks(lens[:n], max_sent_len).to(output.device)
            sent_embeddings = self.word_attn(output, padding_positions=padding_positions, return_weights=return_attn_weights)
            if return_attn_weights:
                word_attn_weights.append(sent_embeddings[1])
                sent_embeddings = sent_embeddings[0]
            sent_embeddings = self.__repad_sentence_embeddings(sent_embeddings, max_num_sent)
            docs.append(sent_embeddings)
        
        # Use sentence embeddings to form document embedding
        sent_embeddings_batch = torch.stack(docs) 
        output, (_, _) = self.sent_encoder(sent_embeddings_batch)
        padding_positions = self.__get_padding_masks(num_sents, max_num_sent).to(output.device)
        doc_embeddings = self.word_attn(output, padding_positions=padding_positions, return_weights=return_attn_weights)
        # Pass document embedding through output layer
        if return_attn_weights:
            return self.decoder(doc_embeddings[0]), word_attn_weights, doc_embeddings[1]
        else:
            return self.decoder(doc_embeddings)
        
    def __repad_sentence_embeddings(self, sents, max_num_sent):
        return torch.cat([sents,
                          torch.zeros((max_num_sent-sents.shape[0], 
                                       sents.shape[1]), device=sents.device)],dim=0)
    
    def __get_padding_masks(self, lengths, max_len):
        '''
        Returns a mask (shape BxLx1) that indicates the position of pad tokens as '1's
        '''
        return torch.tensor([[False]*i + [True]*(max_len-i) for i in lengths]).unsqueeze(2)

## Loading & Preprocessing the data

The cell below runs the preprocessing and saves the preprocessed items to the specified `SAVED_FOLDER_PATH`. As the preprocessing takes quite long, users are advised to skip this cell and use the already preprocessed inputs. Note:

- `X` and `X_test` are numpy arrays of size `[NUM_DOCS, MAX_NUM_SENTS, MAX_SENT_LEN]`. Each entry is the corresponding token's index in the preprocessor's vocabulary. The (i,j,k)-th entry corresponds to the k-th token of the j-th sentence of the i-th document. 
- `ylens` and `ylens_test` are dataframes with `NUM_DOCS` rows and 3 columns: `Label`, `Num_Sentences` and `Num_Tokens`

In [5]:
###### Uncomment this cell to run data preprocessing for HAN ############
# SAVE_FOLDER_PATH = './HAN_prepro_data/'
# FULL_TRAIN_DATA_PATH = '/kaggle/input/lun-glove/fulltrain.csv'
# FULL_TEST_DATA_PATH = '/kaggle/input/lun-glove/balancedtest.csv'
# GLOVE_TEXT_FILE_PATH = '/kaggle/input/lun-glove/glove.6B.100d.txt'

# MAX_SENT_LEN = 30
# MAX_NUM_SENTS = 30
# NUM_CLASSES = 4
# EMBED_DIM = 100

# glovepp, embeds = DataPreprocessorHcl.from_pretrained_embeds(NUM_CLASSES, GLOVE_TEXT_FILE_PATH, EMBED_DIM)
# train_df = pd.read_csv(FULL_TRAIN_DATA_PATH, header=None, names=['Label', 'Text'])
# test_df = pd.read_csv(FULL_TEST_DATA_PATH, header=None, names=['Label', 'Text'])
# X, ylens = glovepp.preprocess_data(train_df, MAX_SENT_LEN, MAX_NUM_SENTS)
# X_test, ylens_test = glovepp.preprocess_data(test_df, MAX_SENT_LEN, MAX_NUM_SENTS)

# VOCAB_LEN = len(glovepp.vocab)

# np.save(SAVE_FOLDER_PATH+'X_train_prep.npy',X)
# np.save(SAVE_FOLDER_PATH+'X_test_prep.npy',X_test)
# ylens.to_csv(SAVE_FOLDER_PATH+'ylens_train_prep.csv', index=False)
# ylens_test.to_csv(SAVE_FOLDER_PATH+'ylens_test_prep.csv', index=False)
# np.save(SAVE_FOLDER_PATH+'glove_embs.npy',embeds)

In [3]:
MAX_SENT_LEN = 30
MAX_NUM_SENTS = 30
NUM_CLASSES = 4
EMBED_DIM = 100
VOCAB_SIZE = 400001 #hardcoded for convenience; see prev cell for how it was obtained

X = np.load('./HAN_prepro_data/X_train_prep.npy')
ylens = pd.read_csv('./HAN_prepro_data/ylens_train_prep.csv')
X_test = np.load('./HAN_prepro_data/X_test_prep.npy')
ylens_test = pd.read_csv('./HAN_prepro_data/ylens_test_prep.csv')
embeds = torch.tensor(np.load('./glove_embs.npy'))

import ast
ylens['Num_Tokens'] = ylens['Num_Tokens'].apply(ast.literal_eval)
ylens_test['Num_Tokens'] = ylens_test['Num_Tokens'].apply(ast.literal_eval)

X_train, X_val, ylens_train, ylens_val = train_test_split(X, ylens, test_size=0.2, random_state=42)

In [4]:
def check_for_bugs(ylens, num_classes, max_num_sent, max_sent_len):
    if (ylens['Label'] < num_classes).all():
        print("Num classes correct")
    else:
        assert False
    if (ylens['Num_Sentences'] <= max_num_sent).all():
        print("Num sentences correct")
    else:
        assert False
    if (ylens['Num_Tokens'].apply(lambda ls : all([le <= max_sent_len for le in ls]))).all():
        print("Num tokenss correct")
    else:
        assert False
    return
check_for_bugs(ylens, NUM_CLASSES, MAX_NUM_SENTS, MAX_SENT_LEN)
check_for_bugs(ylens_test, NUM_CLASSES, MAX_NUM_SENTS, MAX_SENT_LEN)

Num classes correct
Num sentences correct
Num tokenss correct
Num classes correct
Num sentences correct
Num tokenss correct


In [54]:
TRAIN_BATCH_SIZE = 256
VALID_BATCH_SIZE = 512
train_loader = DataLoader(WrapperDatasetHcl(X_train, ylens_train),
                          batch_size=TRAIN_BATCH_SIZE,
                          collate_fn=collate_fnHcl,
                          shuffle=True)

val_loader = DataLoader(WrapperDatasetHcl(X_val, ylens_val),
                          batch_size=VALID_BATCH_SIZE,
                          collate_fn=collate_fnHcl,
                          shuffle=False)

test_loader = DataLoader(WrapperDatasetHcl(X_test, ylens_test),
                          batch_size=VALID_BATCH_SIZE,
                          collate_fn=collate_fnHcl,
                          shuffle=False)

## Training

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
HIDDEN_DIM = 100
NUM_LSTM_LAYERS = 1

NUM_EPOCHS = 10
LEARNING_RATE = 5e-04
WEIGHT_DECAY = 5e-06
LR_ANNEAL_FACTOR = 0.5
LR_ANNEAL_PATIENCE = 2

RANDOM_SEED = 42

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SAVE_FOLDER_PATH="./outputs/"
FILES_NAME_FORMAT="bestHAN_msl{}_mns{}_ba{}_emb{}hid{}lay{}cla{}_ep{}lr{}wd{}_af{}ap{}"

records = {"path": [], "precision":[], "recall": [], "f1": [], "acc":[]}

torch.manual_seed(RANDOM_SEED)
model = BiLSTMHeAttFCNNClassifier(VOCAB_SIZE,
                              EMBED_DIM,
                              HIDDEN_DIM,
                              NUM_LSTM_LAYERS,
                              NUM_CLASSES,
                              pretrained_embeddings = embeds.to(torch.float32))
trainer = Trainer(model, train_loader, val_loader,
                  NUM_EPOCHS, LEARNING_RATE,
                  weight_decay=WEIGHT_DECAY,
                  lr_anneal_factor=LR_ANNEAL_FACTOR,
                  lr_anneal_patience=LR_ANNEAL_PATIENCE
                  save_loss_acc_plots=True)
model_path, _ = trainer.train(
    SAVE_FOLDER_PATH,
    FILES_NAME_FORMAT.format(MAX_SENT_LEN,MAX_NUM_SENTS, TRAIN_BATCH_SIZE, EMBED_DIM,
                             HIDDEN_DIM, NUM_LSTM_LAYERS, NUM_CLASSES,
                             NUM_EPOCHS, LEARNING_RATE, WEIGHT_DECAY,
                             LR_ANNEAL_FACTOR, LR_ANNEAL_PATIENCE),
    verbose=True)

## Evaluation

In [55]:
from sklearn.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score
MODEL_PATH='./outputs/model/bestHAN_msl30_mns30_ba256_emb100hid100lay1cla4_ep10lr0.0005wd5e-06_af0.5_ap2_model'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BiLSTMHeAttFCNNClassifier(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM, NUM_LSTM_LAYERS,
                                  NUM_CLASSES, pretrained_embeddings = embeds.to(torch.float32))
model.to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))

records = {'split':[], 'acc':[],'f1':[],'precision':[], 'recall':[]}
for i, loader in enumerate([train_loader, val_loader, test_loader]):
    
        records['split'].append('train' if i == 0 else 'val' if i==1 else 'test')
        
        preds = []
        truths = []
        model.eval()
        with torch.no_grad():
            for X, y, num_sent, sent_len in loader:
                #Move to correct device
                X = X.to(DEVICE)

                #Forward pass
                outputs = model(X, num_sent, sent_len)

                #Logging
                preds.append(torch.argmax(outputs, dim=-1).cpu())
                truths.append(y)
        preds = torch.cat(preds)
        truths = torch.cat(truths)
        records['acc'].append(accuracy_score(truths, preds))
        records['f1'].append(f1_score(truths, preds, average='macro'))
        records['precision'].append(precision_score(truths, preds, average='macro'))
        records['recall'].append(recall_score(truths, preds, average='macro'))
        if i == 0:
            print("TRAINING DATA")
        elif i == 1:
            print("VALIDATION DATA")
        elif i == 2:
            print("TEST DATA")
        print(classification_report(truths, preds))

  X = torch.tensor(X, dtype=torch.long)


TEST DATA
              precision    recall  f1-score   support

           0       0.93      0.78      0.85       750
           1       0.77      0.54      0.64       750
           2       0.64      0.79      0.71       750
           3       0.75      0.92      0.83       750

    accuracy                           0.76      3000
   macro avg       0.77      0.76      0.76      3000
weighted avg       0.77      0.76      0.76      3000



# 2. Flat Attention Network

## Data-Related Classes and Functions; Trainer Class

In [2]:
class DataPreprocessorFlat():

    def __init__(self, num_classes, data_vocab):
        self.num_classes = num_classes
        self.vocab = data_vocab
        print("Vocab created: {} unique tokens".format(len(self.vocab)))
        
    @classmethod
    def from_train_df(cls, num_classes, train_df, specials=['<unk>']):
        data_vocab = build_vocab_from_iterator(cls.__yield_tokens(train_df), specials=specials)
        data_vocab.set_default_index(data_vocab['<unk>'])
        return cls(num_classes, data_vocab)
    
    @classmethod
    def from_pretrained_embeds(cls, num_classes, embed_path, embed_dim, sep=" ",  specials=['<unk>']):
        # start with all '0's for special tokens
        embeds = [np.asarray([0]*embed_dim, dtype=np.float32)]*len(specials)
        words = OrderedDict()
        with open(embed_path, encoding="utf-8") as f:
            for i, line in enumerate(f):
                if i == 38522 and 'twitter.27B.100d' in embed_path:
                    continue
                splitline = line.split()
                
                word = splitline[0]
                words[word] = 1
                
                embeds.append(np.asarray(splitline[1:], dtype=np.float32))
                
        embeds = torch.tensor(np.array(embeds))
        data_vocab = vocab(words, min_freq=1, specials=specials)
        data_vocab.set_default_index(data_vocab['<unk>'])
        return cls(num_classes, data_vocab), embeds

    @classmethod
    def __yield_tokens(cls, df):
        for row in df.itertuples(index=False):
            yield word_tokenize(row.Text.replace("'","").lower())

    def get_vocab_size(self):
        return len(self.vocab)
    
    def preprocess_data_flat(self, df, clean=True):
        '''
        Converts text into integers that index the vocab,
        and converts labels into the range [0,num_classes-1]
        '''
        if clean:
            X = df['Text'].apply(lambda t: [self.vocab(word_tokenize(s.lower())) for s in sent_tokenize(t.replace("'",""))])
            X = X.apply(lambda ls_of_ls: [token for ls in ls_of_ls for token in ls])
        else:
            X = df['Text'].apply(lambda t: self.vocab(word_tokenize(t)))

        if self.num_classes == 4:
            y = df['Label'].apply(lambda l: l-1)
        else: #num_classes == 2
            # assume test.xlsx
            y = df['Label']
        return X, y
    
class WrapperDatasetFlat(Dataset):
    '''
    Wrapper for use with dataloader
    '''
    def __init__(self, X, y):
        self.X = X
        self.y = y
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X.iloc[idx], self.y.iloc[idx]
                 
def make_flat_collate_function(model_max_len):
    '''
    Returns a function that pads and truncates sequences
    in each batch up to min(batch's max length, model_max_len).
    '''
    def collate(batch):
        '''
        Returns the batch of padded sequences
        and corresponding lengths and labels
        '''
        batch_max_len = max([len(tids) for tids, _ in batch])
        output_len = min(model_max_len, batch_max_len)

        X_padded = torch.zeros((len(batch), output_len), dtype=torch.long)
        lengths = torch.empty((len(batch)), dtype=torch.long)
        labels = torch.empty((len(batch)), dtype=torch.long)
        for i, (tids, label) in enumerate(batch):
            if len(tids) > output_len:
                # sequence longer than output_len --> truncate
                X_padded[i, :] = torch.tensor(tids[:output_len], dtype=torch.long)
                lengths[i] = output_len
            else:
                # sequence shorter than output_len --> pad
                X_padded[i, :len(tids)] = torch.tensor(tids, dtype=torch.long)
                lengths[i] = len(tids)
            labels[i] = label
        return X_padded, lengths, labels
                 
    return collate

class FlatTrainer():
    def __init__(self, model, train_loader, val_loader, num_epochs, lr, weight_decay = 0,
               lr_anneal_factor=None, lr_anneal_patience=None,
               loss_weights=None,
               save_loss_acc_plots=True):
        #Data
        self.train_loader = train_loader
        self.val_loader = val_loader

        #Model
        self.model=model.to(DEVICE)

        #Training
        self.num_epochs = num_epochs
        self.loss_fn = nn.CrossEntropyLoss()
        if loss_weights is not None:
            self.loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(loss_weights, dtype=torch.float32, device=DEVICE))
        self.optimizer = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        self.to_anneal_lr = lr_anneal_factor and lr_anneal_patience
        if self.to_anneal_lr:
            self.scheduler = ReduceLROnPlateau(self.optimizer,
                                             factor=lr_anneal_factor,
                                             patience=lr_anneal_patience)

    def train(self, save_folder_root_path, file_name_root,
            plot_loss_acc=True, verbose=True):
        #account for pytorch versions < 2.2
        if hasattr(self.scheduler, 'verbose'):
            self.scheduler.verbose=verbose
            
        #Set up logging
        self.model_folder = save_folder_root_path+'model/'
        self.plots_folder = save_folder_root_path+'plots/'
        self.file_name_root = file_name_root
        self.plot_loss_acc = plot_loss_acc
        if not os.path.exists(self.model_folder):
            os.makedirs(self.model_folder)
        if self.plot_loss_acc and not os.path.exists(self.plots_folder):
            os.makedirs(self.plots_folder)
        self.train_metrics = {'loss':[], 'accuracy':[], 'f1':[]}
        self.val_metrics = {'loss':[], 'accuracy':[], 'f1':[]}
        best_val_loss = float('inf')
        for i in range(1, self.num_epochs+1):
            self.model.train()
            train_loss = []
            preds = []
            truths = []
            for X, lengths, y in tqdm(self.train_loader, disable = not verbose):
                #Move to correct device
                X = X.to(DEVICE)
                y = y.to(DEVICE)

                #Forward pass
                outputs = self.model(X, lengths)
                if type(outputs)==tuple:
                    logits = outputs[0]
                else:
                    logits = outputs
                loss = self.loss_fn(logits, y)

                #Backprop
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                #Logging
                train_loss.append(loss.item())
                preds.append(torch.argmax(logits, dim=-1).cpu())
                truths.append(y.cpu())

            #Validation
            self.model.eval()
            val_loss, val_acc, val_f1 = self.validate()

            #Logging
            train_loss = sum(train_loss)/len(train_loss)
            preds = torch.cat(preds)
            truths = torch.cat(truths)
            train_acc = accuracy_score(truths, preds)
            train_f1 = f1_score(truths, preds, average='macro')

            self.train_metrics['loss'].append(train_loss)
            self.train_metrics['accuracy'].append(train_acc)
            self.train_metrics['f1'].append(train_f1)
            self.val_metrics['loss'].append(val_loss)
            self.val_metrics['accuracy'].append(val_acc)
            self.val_metrics['f1'].append(val_f1)
            
            if verbose:
                tqdm.write("Epoch {} Complete:\n Train: loss={}, acc={}, F1={}\n Val  : loss={}, acc={}, F1={}\n".format(
                            i, train_loss, train_acc, train_f1, val_loss, val_acc, val_f1))
            #Save model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(self.model.state_dict(),
                           self.model_folder + file_name_root + '_model.pt')
            
            ##LR Annealing
            if self.to_anneal_lr:
                #account for pytorch versions >= 2.2
                if not hasattr(self.scheduler, 'verbose') and i > 1:
                    last_lr = self.scheduler.get_last_lr()
                    
                self.scheduler.step(val_loss)
                
                #account for pytorch versions >= 2.2
                if not hasattr(self.scheduler, 'verbose') and i > 1 and last_lr != self.scheduler.get_last_lr():
                    print("LR Reduced from {} to {} for next epoch onwards".format(last_lr, self.scheduler.get_last_lr()))
                    
        #Save plots
        if self.plot_loss_acc:
            self.plot_metrics(self.train_metrics, self.val_metrics)

        return self.model_folder + file_name_root + '_model.pt', self.model_folder + file_name_root + 'plot.pt'

    def validate(self, test_loader=None):
        with torch.no_grad():
            losses = []
            preds = []
            truths = []
            val_loader = self.val_loader if test_loader is None else test_loader
            for X, lengths, y in val_loader:
                #Move to correct device
                X = X.to(DEVICE)
                y = y.to(DEVICE)

                #Forward pass
                outputs = self.model(X, lengths)
                
                if type(outputs)==tuple:
                    logits = outputs[0]
                else:
                    logits = outputs
                loss = self.loss_fn(logits, y)

                #Logging
                losses.append(loss.item())
                preds.append(torch.argmax(logits, dim=-1).cpu())
                truths.append(y.cpu())
        preds = torch.cat(preds)
        truths = torch.cat(truths)
        return sum(losses)/len(losses), accuracy_score(truths, preds), f1_score(truths, preds, average='macro')

    def plot_metrics(self, train_metrics, val_metrics):
        fig, axs = plt.subplots(1, 3, figsize=(15,5))
        axs[0].plot(range(1, self.num_epochs + 1), train_metrics['loss'], color='b', label='Train')
        axs[0].plot(range(1, self.num_epochs + 1), val_metrics['loss'], color='r', label='Validation')
        axs[0].set_ylabel('Loss')
        axs[0].set_xlabel('Epochs')
        axs[0].set_ylim(bottom=0)
        axs[0].set_xticks(range(0, self.num_epochs + 1, 2))
        axs[0].grid(visible=True, which='major', axis='both')

        axs[1].plot(range(1, self.num_epochs + 1), train_metrics['accuracy'], color='b', label='Train')
        axs[1].plot(range(1, self.num_epochs + 1), val_metrics['accuracy'], color='r', label='Validation')
        axs[1].set_ylabel('Accuracy')
        axs[1].set_xlabel('Epochs')
        axs[1].set_ylim(bottom=0)
        axs[1].set_xticks(range(0, self.num_epochs + 1, 2))
        axs[1].grid(visible=True, which='major', axis='both')

        axs[2].plot(range(1, self.num_epochs + 1), train_metrics['f1'], color='b', label='Train')
        axs[2].plot(range(1, self.num_epochs + 1), val_metrics['f1'], color='r', label='Validation')
        axs[2].set_ylabel('F1 (Macro-averaged)')
        axs[2].set_xlabel('Epochs')
        axs[2].set_ylim(bottom=0)
        axs[2].set_xticks(range(0, self.num_epochs + 1, 2))
        axs[2].grid(visible=True, which='major', axis='both')

        fig.legend(*axs[2].get_legend_handles_labels(), loc='upper center')
        fig.savefig(self.plots_folder + self.file_name_root + '_plot.png')
        plt.show()
        plt.close()


## Model Architecture

In [38]:
class AttentionUnit(nn.Module):
    def __init__(self, input_dim, hidden_dim=None, num_outputs=1, attn_dropout=0.0):
        super(AttentionUnit, self).__init__()
        if hidden_dim is None:
            hidden_dim = input_dim
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.query = nn.Linear(hidden_dim, num_outputs, bias=False)
    def forward(self, encoder_output, padding_positions=None, return_weights=False):
        # [B,L,H]-->[B,L,H]
        hidden_rep = F.tanh(self.hidden(encoder_output))
        # [B,L,H]-->[B,L,1]
        similarity = self.query(hidden_rep)
        if padding_positions is not None:
            similarity = similarity.masked_fill(padding_positions, -float('inf'))
        attention_weights = F.softmax(similarity, dim=1)
        #Return weighted sum [B,L,1], [B,L,H]-->[B,H]
        if return_weights:
            return torch.bmm(attention_weights.transpose(1,2), hidden_rep).squeeze(1), attention_weights
        return torch.bmm(attention_weights.transpose(1,2), hidden_rep).squeeze(1)

class LSTMFlatAttentionFCNNClassifier(torch.nn.Module):
    '''
    Classifier that uses an LSTM as an encoder followed by an attention block
    and a Fully-Connected Neural Network(FCNN) as a decoder.
    '''
    def __init__(self, vocab_len, embed_dim, hidden_dim, num_lstm_layers, num_classes, attn_dropout=0.0, pretrained_embeddings=None, freeze_embeds=False):
        super(LSTMFlatAttentionFCNNClassifier, self).__init__()
        if pretrained_embeddings is not None:
            self.embedding = nn.Embedding.from_pretrained(pretrained_embeddings, freeze=freeze_embeds)
        else:
            self.embedding = nn.Embedding(num_embeddings=vocab_len, embedding_dim=embed_dim)

        self.encoder = nn.LSTM(input_size=embed_dim, hidden_size=hidden_dim, num_layers=num_lstm_layers, batch_first=True, bidirectional=True)
        self.attn = AttentionUnit(2*hidden_dim)
        self.decoder = nn.Linear(2*hidden_dim, num_classes)

    def forward(self, X_batch, lengths, return_attn_weights=False):
        embeddings = self.embedding(X_batch)

        embeddings = nn.utils.rnn.pack_padded_sequence(embeddings, lengths.cpu(), enforce_sorted=False, batch_first=True)
        output, (_, _) = self.encoder(embeddings)
        output, _ = nn.utils.rnn.pad_packed_sequence(output,batch_first=True)

        padding_positions = self.__get_padding_masks(lengths).to(output.device)
        doc_embeddings = self.attn(output,padding_positions=padding_positions,return_weights=return_attn_weights)
        
        if return_attn_weights:
            return self.decoder(doc_embeddings[0]), doc_embeddings[1]
        else:
            return self.decoder(doc_embeddings)
    
    def __get_padding_masks(self, lengths):
        '''
        Returns a mask (shape BxLx1) that indicates the position of pad tokens
        '''
        max_len = lengths.max()
        return torch.tensor([[False]*i + [True]*(max_len-i) for i in lengths]).unsqueeze(2)

## Loading and Preprocessing the data

In [None]:
# train_df = pd.read_csv('/kaggle/input/lun-glove/fulltrain.csv', header=None, names=['Label','Text'])
# test_df = pd.read_excel('/kaggle/input/lun-glove/balancedtest.csv', header=None, names=['Label','Text'])

# dp, embeds = DataPreprocessorFlat.from_pretrained_embeds(4, '/kaggle/input/lun-glove/glove.6B.100d.txt', 100, sep=" ",  specials=['<unk>'])
# X, y = dp.preprocess_data_flat(train_df, clean=True)
# X_test, y_test = dp.preprocess_data_flat(test_df, clean=True)
# X.to_parquet('./FAN_prepro_data/X_train_prep_flat.parquet')
# y.to_parquet('./FAN_prepro_data/y_train_prep_flat.parquet')
# X_test.to_parquet('./FAN_prepro_data/X_test_prep_flat.parquet')
# y_test.to_parquet('./FAN_prepro_data/y_test_prep_flat.parquet')

# VOCAB_LEN = len(dp.vocab)

In [39]:
X = pd.read_parquet('./FAN_prepro_data/X_train_prep_flat.parquet')['Text']
y = pd.read_parquet('./FAN_prepro_data/y_train_prep_flat.parquet')['Label']
X_test = pd.read_parquet('./FAN_prepro_data/X_test_prep_flat.parquet')['Text']
y_test = pd.read_parquet('./FAN_prepro_data/y_test_prep_flat.parquet')['Label']
embeds = torch.tensor(np.load('./glove_embs.npy'))
VOCAB_SIZE = 400001

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

In [40]:
NUM_CLASSES = 4
TRAIN_BATCH_SIZE = 256
VALID_BATCH_SIZE = 512
MODEL_MAX_LEN = 500
collate_flat = make_flat_collate_function(MODEL_MAX_LEN)
train_loader = DataLoader(WrapperDatasetFlat(X_train, y_train),
                          batch_size=TRAIN_BATCH_SIZE,
                          collate_fn=collate_flat,
                          shuffle=True)

val_loader = DataLoader(WrapperDatasetFlat(X_val, y_val),
                          batch_size=VALID_BATCH_SIZE,
                          collate_fn=collate_flat,
                          shuffle=False)

test_loader = DataLoader(WrapperDatasetFlat(X_test, y_test),
                          batch_size=VALID_BATCH_SIZE,
                          collate_fn=collate_flat,
                          shuffle=False)

## Training

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
EMBED_DIM = 100
HIDDEN_DIM = 100
NUM_LSTM_LAYERS = 1

NUM_EPOCHS = 10
LEARNING_RATE = 5e-04
WEIGHT_DECAY = 5e-06
LR_ANNEAL_FACTOR = 0.5
LR_ANNEAL_PATIENCE = 2

RANDOM_SEED = 42

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SAVE_FOLDER_PATH="./outputs/"
FILES_NAME_FORMAT="bestFAN_ml{}_ba{}_emb{}hid{}lay{}cla{}_ep{}lr{}wd{}_af{}ap{}"
records = {"path": [], "precision":[], "recall": [], "f1": [], "acc":[]}


torch.manual_seed(RANDOM_SEED)
model = LSTMFlatAttentionFCNNClassifier(VOCAB_SIZE,
                              EMBED_DIM,
                              HIDDEN_DIM,
                              NUM_LSTM_LAYERS,
                              NUM_CLASSES,
                              pretrained_embeddings = embeds.to(torch.float32))
trainer = FlatTrainer(model, train_loader, val_loader,
                  NUM_EPOCHS, LEARNING_RATE,
                  weight_decay=WEIGHT_DECAY,
                  lr_anneal_factor=LR_ANNEAL_FACTOR,
                  lr_anneal_patience=LR_ANNEAL_PATIENCE,
                  loss_weights=LOSS_WEIGHTS,
                  save_loss_acc_plots=True)
model_path, _ = trainer.train(SAVE_FOLDER_PATH,
              FILES_NAME_FORMAT.format(MODEL_MAX_LEN, TRAIN_BATCH_SIZE, EMBED_DIM,
                                       HIDDEN_DIM, NUM_LSTM_LAYERS, NUM_CLASSES,
                                      NUM_EPOCHS, LEARNING_RATE, WEIGHT_DECAY,
                                      LR_ANNEAL_FACTOR, LR_ANNEAL_PATIENCE),
              verbose=True)

## Evaluation

In [43]:
from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score
base="./outputs/model/bestFAN_ml500_ba256_emb100hid100lay1cla4_ep10lr0.0005wd5e-06_af0.5ap2_model.pt"
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
records = {'split':[], 'acc':[],'f1':[],'precision':[], 'recall':[]}
for i, loader in enumerate([train_loader, val_loader, test_loader]):
        model = LSTMFlatAttentionFCNNClassifier(VOCAB_SIZE, 100, 100, 1, 4,
                                                pretrained_embeddings = embeds.to(torch.float32))
        model.to(DEVICE)
        model.load_state_dict(torch.load(base, map_location=DEVICE))
        preds = []
        truths = []
        model.eval()
        with torch.no_grad():
            for X, lengths, y in loader:
                #Move to correct device
                X = X.to(DEVICE)

                #Forward pass
                outputs = model(X, lengths)

                #Logging
                preds.append(torch.argmax(outputs, dim=-1).cpu())
                truths.append(y)
        preds = torch.cat(preds)
        truths = torch.cat(truths)
        records['split'].append('train' if i==0 else 'val' if i==1 else 'test')
        records['acc'].append(accuracy_score(truths, preds))
        records['f1'].append(f1_score(truths, preds, average='macro'))
        records['precision'].append(precision_score(truths, preds, average='macro'))
        records['recall'].append(recall_score(truths, preds, average='macro'))
        if i == 0:
            print("TRAINING DATA")
        elif i == 1:
            print("VALIDATION DATA")
        elif i == 2:
            print("TEST DATA")
        print(classification_report(truths, preds))

TEST DATA
              precision    recall  f1-score   support

           0       0.92      0.73      0.81       750
           1       0.78      0.46      0.58       750
           2       0.58      0.72      0.64       750
           3       0.67      0.93      0.78       750

    accuracy                           0.71      3000
   macro avg       0.74      0.71      0.70      3000
weighted avg       0.74      0.71      0.70      3000

