In [1]:
!pip install torchinfo

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from torchinfo import summary
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator, GloVe, vocab
from tqdm import tqdm
from nltk.tokenize import sent_tokenize, word_tokenize
from collections import OrderedDict



# 1. HAN Analysis

This notebook corresponds to the discussion section of the report. First, define the model:

In [2]:
class AttentionUnit(nn.Module):
    def __init__(self, input_dim, hidden_dim=None, num_outputs=1, attn_dropout=0.0):
        super(AttentionUnit, self).__init__()
        if hidden_dim is None:
            hidden_dim = input_dim
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.query = nn.Linear(hidden_dim, num_outputs, bias=False)
        
    def forward(self, encoder_output, padding_positions=None, return_weights=False):
        # [B,L,I]-->[B,L,H]
        hidden_rep = F.tanh(self.hidden(encoder_output))
        
        # [B,L,H]-->[B,L,1]
        similarity = self.query(hidden_rep)
        if padding_positions is not None:
            similarity = similarity.masked_fill(padding_positions, -float('inf'))
        attention_weights = F.softmax(similarity, dim=1)
        
        #Return weighted sum [B,L,1], [B,L,H]-->[B,H]
        if return_weights:
            return torch.bmm(attention_weights.transpose(1,2), hidden_rep).squeeze(1), attention_weights
        return torch.bmm(attention_weights.transpose(1,2), hidden_rep).squeeze(1)

class BiLSTMHeAttFCNNClassifier(nn.Module):
    '''
    Classifier that uses heirarchical attention to encode a document and 
    a Fully-Connected Neural Network(FCNN) as a decoder.

    '''
    def __init__(self, vocab_len, embed_dim, hidden_dim, num_lstm_layers, num_classes, attn_dropout=0.0, pretrained_embeddings=None, freeze_embeds=False):
        super(BiLSTMHeAttFCNNClassifier, self).__init__()
        if pretrained_embeddings is not None:
            self.embedding = nn.Embedding.from_pretrained(pretrained_embeddings, freeze=freeze_embeds)
        else:
            self.embedding = nn.Embedding(num_embeddings=vocab_len, embedding_dim=embed_dim)
        
        self.word_encoder = nn.LSTM(input_size=embed_dim, hidden_size=hidden_dim, num_layers=num_lstm_layers, batch_first=True, bidirectional=True)
        self.word_attn = AttentionUnit(2*hidden_dim)
        
        self.sent_encoder = nn.LSTM(input_size=2*hidden_dim, hidden_size=hidden_dim, num_layers=num_lstm_layers, batch_first=True, bidirectional=True)
        self.sent_attn = AttentionUnit(2*hidden_dim)
        
        self.decoder = nn.Linear(2*hidden_dim, num_classes)


    def forward(self, X_batch, num_sents, sent_lens, return_attn_weights=False):
        '''
        Returns logits if return_attn_weights is False,
        else returns (logits, word-level attention weights, sentence-level attention weights)
        '''
        max_sent_len = X_batch.shape[2]
        max_num_sent = X_batch.shape[1]
        
        # Use word embeddings to form sentence embeddings
        word_attn_weights = []
        docs = []
        for doc, n, lens in zip(X_batch, num_sents, sent_lens):
            words_batch = doc[:n]
            embeddings = self.embedding(words_batch)
            output, (_, _) = self.word_encoder(embeddings)
            padding_positions = self.__get_padding_masks(lens[:n], max_sent_len).to(output.device)
            sent_embeddings = self.word_attn(output, padding_positions=padding_positions, return_weights=return_attn_weights)
            if return_attn_weights:
                word_attn_weights.append(sent_embeddings[1])
                sent_embeddings = sent_embeddings[0]
            sent_embeddings = self.__repad_sentence_embeddings(sent_embeddings, max_num_sent)
            docs.append(sent_embeddings)
        
        # Use sentence embeddings to form document embedding
        sent_embeddings_batch = torch.stack(docs) 
        output, (_, _) = self.sent_encoder(sent_embeddings_batch)
        padding_positions = self.__get_padding_masks(num_sents, max_num_sent).to(output.device)
        doc_embeddings = self.word_attn(output, padding_positions=padding_positions, return_weights=return_attn_weights)
        # Pass document embedding through output layer
        if return_attn_weights:
            return self.decoder(doc_embeddings[0]), word_attn_weights, doc_embeddings[1]
        else:
            return self.decoder(doc_embeddings)
        
    def __repad_sentence_embeddings(self, sents, max_num_sent):
        return torch.cat([sents,
                          torch.zeros((max_num_sent-sents.shape[0], 
                                       sents.shape[1]), device=sents.device)],dim=0)
    
    def __get_padding_masks(self, lengths, max_len):
        '''
        Returns a mask (shape BxLx1) that indicates the position of pad tokens as '1's
        '''
        return torch.tensor([[False]*i + [True]*(max_len-i) for i in lengths]).unsqueeze(2)

## 1. Categorized evals

Set hyperparameters and load model

In [3]:
MAX_SENT_LEN = 30
MAX_NUM_SENTS = 30
NUM_CLASSES = 4
EMBED_DIM = 100
HIDDEN_DIM = 100
NUM_LSTM_LAYERS = 1

VOCAB_LEN = 400001 #harcoded for convenience; see below for how it was obtained
# glove, _ = DataPreprocessorHcl.from_pretrained_embeds(NUM_CLASSES,'/kaggle/input/lun-glove/glove.6B.100d.txt', EMBED_DIM)
# VOCAB_LEN = len(glove.vocab)

MODEL_PATH = './outputs/model/bestHAN_msl30_mns30_ba256_emb100hid100lay1cla4_ep10lr0.0005wd5e-06_af0.5_ap2_model.pt'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model = BiLSTMHeAttFCNNClassifier(VOCAB_LEN, EMBED_DIM, HIDDEN_DIM, NUM_LSTM_LAYERS, NUM_CLASSES)
model.to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))

<All keys matched successfully>

Load Preprocessed test data and sub-categorized test data

In [118]:
X_hier_test = np.load('./HAN_prepro_data/X_test_prep.npy')
ylens_test = pd.read_csv('./HAN_prepro_data/ylens_test_prep.csv')
import ast
ylens_test['Num_Tokens'] = ylens_test['Num_Tokens'].apply(ast.literal_eval)

In [22]:
test_df_categorised = pd.read_csv('/kaggle/input/lun-glove/balancedtestwithclass_new_cleaned.csv')

In [119]:
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
def categorised_eval(categories_df, X, ylens, model, device, category_list=[0,1,2,3,4,5], batch_size=128, return_preds=False):
    all_preds = []
    records = {'category':[], 'support':[], 'acc':[], 'f1':[], 'precision':[], 'recall':[]}
    for cat in category_list:
        idx = categories_df[categories_df['Category']==cat].index
        ylens_cat = ylens.loc[idx]
        X_cat = X[idx]
        
        model.to(device)
        preds=[]
        truths=[]
        
        for tokens, (_, (label, num_sent, sent_len)) in tqdm(zip(X_cat, ylens_cat[['Label','Num_Sentences','Num_Tokens']].iterrows())):
            
            X_in = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)
            num_sent = torch.tensor(num_sent, dtype=torch.long).unsqueeze(0)
            sent_len = torch.tensor(sent_len, dtype=torch.long).unsqueeze(0)
            
            #Forward pass
            outputs = model(X_in, num_sent, sent_len, return_attn_weights=False)

            #Logging
            preds.append(torch.argmax(outputs, dim=-1).cpu().item())
            truths.append(label)
        
        records['category'].append(cat)
        records['support'].append(len(X_cat))
        records['acc'].append(accuracy_score(truths, preds))
        records['f1'].append(f1_score(truths, preds, average='macro'))
        records['precision'].append(precision_score(truths, preds, average='macro'))
        records['recall'].append(recall_score(truths, preds, average='macro'))
        all_preds.append((idx,preds,truths))
    return records if not return_preds else (records, all_preds)

In [120]:
results, preds = categorised_eval(test_df_categorised, X_hier_test, ylens_test, model, DEVICE, category_list=[0,1,2,3,4,5], batch_size=128, return_preds=True)

179it [00:02, 77.85it/s]
112it [00:01, 93.76it/s]
559it [00:07, 72.61it/s]
930it [00:11, 82.10it/s] 
1051it [00:11, 88.23it/s]
169it [00:01, 95.39it/s]


In [None]:
pd.DataFrame(results)

In [None]:
pd.DataFrame(results).to_csv('./outputs/HAN_categorized_eval_results.csv')

In [18]:
# health propaganda articles
prophealth = test_df_categorised[(test_df_categorised['Category']==2) & (test_df_categorised['Label'] == 3)]
prophealth

Unnamed: 0,Text,Label,Category,Length
1500,New research suggests that one of the more po...,3,2,728
1503,Infectious disease medicine and psychiatry ha...,3,2,618
1504,Water fluoridation may cause hypothyroidism a...,3,2,630
1507,"As controversial as it may sound, botox has n...",3,2,1537
1508,Having bad skin is something that millions of...,3,2,473
...,...,...,...,...
2242,There's a good chance that the vast majority ...,3,2,525
2245,Medicinal substances that can effectively hel...,3,2,594
2246,Tweet (NewsTarget) Viruses that cause winter ...,3,2,489
2247,Tweet (NewsTarget) The foundation of a health...,3,2,707


In [111]:
from sklearn.metrics import confusion_matrix, classification_report
print(classification_report(test_df_categorised[test_df_categorised['Category'] == 2]['Label'], test_df_categorised[test_df_categorised['Category'] == 2]['prediction']))

              precision    recall  f1-score   support

           1       0.48      0.62      0.54        16
           2       0.35      0.33      0.34        18
           3       0.96      0.78      0.86       493
           4       0.21      0.81      0.34        32

    accuracy                           0.76       559
   macro avg       0.50      0.64      0.52       559
weighted avg       0.89      0.76      0.81       559



In [145]:
prophealth[prophealth['Label'] != prophealth['prediction']]

Unnamed: 0,Text,Label,Category,Length,prediction
1500,New research suggests that one of the more po...,3,2,728,4
1512,Tweet Eli Lilly treated the American public '...,3,2,291,2
1526,"Tens of millions of Americans are obese, and ...",3,2,785,4
1535,Compared to land vegetables like broccoli and...,3,2,551,4
1540,Tweet (NewsTarget) In the third installment o...,3,2,84,1
...,...,...,...,...,...
2202,Winter is the season of the kidney according ...,3,2,515,4
2203,June is National Fresh Fruit and Vegetable Mo...,3,2,337,4
2208,As terrifying as a diagnosis of cancer can be...,3,2,587,4
2220,From happy childish squeals of delight during...,3,2,614,4


## 2. Visualizing attention

In [13]:
class AttnVizPreprocessorHcl():

    def __init__(self, data_vocab):
        self.vocab = data_vocab
        print("Vocab created: {} unique tokens".format(len(self.vocab)))
        
    @classmethod
    def from_pretrained_embeds(cls, embed_path, embed_dim, sep=" ",  specials=['<unk>']):
        # start with all '0's for special tokens
        embeds = [np.asarray([0]*embed_dim, dtype=np.float32)]*len(specials)
        words = OrderedDict()
        with open(embed_path, encoding="utf-8") as f:
            for i, line in enumerate(f):
                if i == 38522 and 'twitter.27B.100d' in embed_path:
                    continue
                splitline = line.split()
                
                word = splitline[0]
                if word not in words:
                    words[word] = 0
                words[word]+=1
                embeds.append(np.asarray(splitline[1:], dtype=np.float32))
                
        embeds = torch.tensor(np.array(embeds))
        data_vocab = vocab(words, specials=specials)
        data_vocab.set_default_index(data_vocab['<unk>'])
        return cls(data_vocab)

    def get_vocab_size(self):
        return len(self.vocab)
    
    def preprocess_single_row(self, row, max_sent_len, max_num_sents, preprocess_label=False):
        '''
        Converts text into integers that index the vocab,
        and converts labels into the range [0,num_classes-1]
        
        Return tokens by sentence (unpadded), idx by sentence (padded), label, num_sentences, num_tokens
        '''
        text = row['Text']
        label = row['Label']
        
        words = [word_tokenize(sent.lower()) for sent in sent_tokenize(text.replace("'",""))]
        token_idxs = [self.vocab(sent) for sent in words]
        num_sentences = min(max_num_sents, len(words))
        num_tokens = [min(max_sent_len, len(sent)) for sent in words][:max_num_sents]
        num_tokens = num_tokens + [0 for _ in range(max_sent_len-len(num_tokens))] #padding
        
        tokens_padded = np.zeros((1, max_num_sents, max_sent_len), dtype='int32')
        for j, sent in enumerate(token_idxs):
            if j >= max_num_sents:
                break
            k = min(max_sent_len, len(sent))
            tokens_padded[0,j,:k] = sent[:k]
                
        if preprocess_label:
            label -= 1
        return words, torch.tensor(tokens_padded, dtype=torch.long), label,\
                torch.tensor(num_sentences, dtype=torch.long).unsqueeze(0),\
                torch.tensor(num_tokens, dtype=torch.long).unsqueeze(0)

In [159]:
MAX_SENT_LEN = 30
MAX_NUM_SENTS = 30
NUM_CLASSES = 4
EMBED_DIM = 100
HIDDEN_DIM = 100
NUM_LSTM_LAYERS = 1

VOCAB_LEN = 400001 #harcoded for convenience; see below for how it was obtained
# glove, _ = DataPreprocessorHcl.from_pretrained_embeds(NUM_CLASSES,'/kaggle/input/lun-glove/glove.6B.100d.txt', EMBED_DIM)
# VOCAB_LEN = len(glove.vocab)

MODEL_PATH = '/kaggle/input/bilstmheattnewsclassifier/pytorch/best_performing/1/glounfro_clean_msl30_mns30_batch256_embed100hidden100layers1classes4_ep10lr0.0005wd5e-06_af0.5_ap2.pt'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

modelhan = BiLSTMHeAttFCNNClassifier(VOCAB_LEN, EMBED_DIM, HIDDEN_DIM, NUM_LSTM_LAYERS, NUM_CLASSES)
modelhan.to(DEVICE)
modelhan.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
modelhan.eval()

BiLSTMHeAttFCNNClassifier(
  (embedding): Embedding(400001, 100)
  (word_encoder): LSTM(100, 100, batch_first=True, bidirectional=True)
  (word_attn): AttentionUnit(
    (hidden): Linear(in_features=200, out_features=200, bias=True)
    (query): Linear(in_features=200, out_features=1, bias=False)
  )
  (sent_encoder): LSTM(200, 100, batch_first=True, bidirectional=True)
  (sent_attn): AttentionUnit(
    (hidden): Linear(in_features=200, out_features=200, bias=True)
    (query): Linear(in_features=200, out_features=1, bias=False)
  )
  (decoder): Linear(in_features=200, out_features=4, bias=True)
)

In [14]:
EMBED_PATH = '/kaggle/input/lun-glove/glove.6B.100d.txt'
EMBED_DIM = 100
MAX_SENT_LEN = 30
MAX_NUM_SENT = 30
pp = AttnVizPreprocessorHcl.from_pretrained_embeds(EMBED_PATH, EMBED_DIM)

Vocab created: 400001 unique tokens


In [11]:
train_df = pd.read_csv('/kaggle/input/lun-glove/fulltrain.csv', header=None, names=['Label','Text'])

In [14]:
import re
HEDGE_REGEX = r"may|might|possib|probab|assum|likely|perhap|seem"

train_df['lengths'] = train_df['Text'].apply(lambda s: len(s.split()))
train_df['has_hedge'] = train_df['Text'].apply(lambda text : len(re.findall(HEDGE_REGEX, text)))
train_df['has_2p'] = train_df['Text'].apply(lambda text : len(re.findall("you",text)))
train_df['has_1ps'] = train_df['Text'].apply(lambda text : len(re.findall(r"\b(I|i)\b", text)))
train_df['has_nums'] = train_df['Text'].apply(lambda text : len(re.findall(r"[0-9]|millio|trillio|billio|dollar|\$|\%", text)))

train_df['has_bri_ire_afh_am_hondu'] = train_df['Text'].apply(lambda text : len(re.findall(r"brit|afgha|america|u.s|hondur|ire", text)))

In [17]:
train_df[train_df.has_hedge > 0]

Unnamed: 0,Label,Text,lengths,has_hedge,has_2p,has_1ps,has_nums,has_bri_ire_afh_am_hondu
0,1,"A little less than a decade ago, hockey fans w...",146,2,0,0,0,2
1,1,The writers of the HBO series The Sopranos too...,122,1,0,1,3,0
2,1,Despite claims from the TV news outlet to offe...,705,2,3,0,35,15
3,1,After receiving 'subpar' service and experienc...,705,2,3,15,18,2
5,1,"At a cafeteria-table press conference Monday, ...",93,1,0,1,7,0
...,...,...,...,...,...,...,...,...
48829,4,Opposition Democratic Progressive Party (DPP) ...,319,1,0,0,11,2
48833,4,Taiwan is likely to be affected by a dust stor...,147,1,0,0,7,4
48838,4,President Ma Ying-jeou said Wednesday that a s...,368,1,0,0,19,4
48842,4,A preliminary report on the cause of the April...,476,1,0,0,5,7


In [8]:
X_test = np.load('./HAN_prepro_data/X_test_prep.npy')
ylens_test = pd.read_csv('./HAN_prepro_data/ylens_test_prep.csv')
import ast
ylens_test['Num_Tokens'] = ylens_test['Num_Tokens'].apply(ast.literal_eval)

In [4]:
test_df_categorised = pd.read_csv('./HAN_prepro_data/balancedtestwithclass_new_cleaned.csv')

In [103]:
import re
HEDGE_REGEX = r"may|might|possib|probab|assum|likely|perhap|seem"

test_df_categorised['lengths'] = test_df_categorised['Text'].apply(lambda s: len(s.split()))
test_df_categorised['has_hedge'] = test_df_categorised['Text'].apply(lambda text : len(re.findall(HEDGE_REGEX, text)))
test_df_categorised['has_2p'] = test_df_categorised['Text'].apply(lambda text : len(re.findall("you",text)))
test_df_categorised['has_1ps'] = test_df_categorised['Text'].apply(lambda text : len(re.findall(r"\b(I|i)\b", text)))
test_df_categorised['has_nums'] = test_df_categorised['Text'].apply(lambda text : len(re.findall(r"[0-9]|millio|trillio|billio|dollar|\$|\%", text)))

test_df_categorised['has_bri_ire_afh_am_hondu'] = test_df_categorised['Text'].apply(lambda text : len(re.findall(r"brit|afgha|america|u.s|hondur|ire", text)))

In [214]:

import torch
import json
from IPython.display import Markdown, display

labelid2label = ['Satire', 'Hoax', 'Propaganda', 'Trusted']

class AttentionVisualizerHcl():
    def __init__(self, model, preprocessor):

        class Struct:
            def __init__(self, **entries):
                self.__dict__.update(entries)
        
        self.model = model
        self.preprocessor = preprocessor

    def visualize_attention(self, doc_and_label_row, max_sent_len, max_num_sents, device='cpu', preprocess_label=False, sents_only=False):
        # doc_and_label_row should contain a row of the df with columns ['Text'] and ['Label']
        
        words, X, y, num_sentences, num_tokens = self.preprocessor.preprocess_single_row(doc_and_label_row, max_sent_len, max_num_sents, preprocess_label)
        pred, word_weights, sent_weights = self.model(X, num_sentences, num_tokens, return_attn_weights=True)
        pred = torch.argmax(pred, dim = -1).cpu().item()
        # remove the padding elements (which have 0 weight)
        sent_weights = sent_weights.squeeze(0,2)[:num_sentences]
        word_weights = word_weights[0].squeeze(2).tolist()
        word_weights = [weights[:num] for num, weights in zip(num_tokens[0], word_weights)]
        
        display(Markdown('<p style="font-size:18px"> Ground Truth: '+ labelid2label[y] + '&emsp;&emsp;&emsp;Prediction: '+labelid2label[pred] +'</p>'))
        i = 0
        if sents_only:
            line = []
            for sent_weight in sent_weights:
                sent_weight = sent_weight.item()
                line.append('<span style="background-color:rgba(255,0,0,' +\
                        str(sent_weight) +\
                        ');font-size:16px;color:rgba(255,0,0,0);">' + '_____' + '</span>')
            display(Markdown(" ".join(line)))
            return words, word_weights, sent_weights
        max_weight_sent = max(map(max, word_weights))
        min_weight_sent = min(map(min, word_weights))
        
        for sent, word_weights_sent, sent_weight in zip(words, word_weights, sent_weights):
            i += 1
            sent_weight = sent_weight.item()
            line = [self.__make_sent(sent_weight)]
            line_length = 5
            for word, word_weight in zip(sent, word_weights_sent):
                line_length += len(word) + 1
                line.append(self.__make_word(word, self.__scale_weight(word_weight,
                                                                       max_weight_sent,
                                                                       min_weight_sent),
                                             sent_weight))
                
                if line_length > 60:
                    display(Markdown(" ".join(line)))
                    line = [self.__make_blank()]
                    line_length = 5
                    
            display(Markdown(" ".join(line)))
        return words, word_weights, sent_weights

    def __make_blank(self):
        return '<span style="color:rgba(255,255,255,0);font-size:16px">' + '_____' + '</span>'
    
    def __make_sent(self, sent_weight):
        return '<span style="background-color:rgba(255,0,0,' +\
                    str(sent_weight) +\
                    ');font-size:16px;color:rgba(255,0,0,0);">' + '_____' + '</span>'
    
    def __make_word(self, word, word_weight, sent_weight):
        return '<span style="background-color:rgba(0,0,255,' +\
                        str(word_weight*sent_weight) + ');font-size:16px;">' +\
                        word.replace('$', '\$').replace("'", "\'") + '</span>'
    
    def __scale_weight(self, orig_weight, max_weight, min_weight):
        out = (orig_weight-min_weight)/(max_weight-min_weight)
        if type(out) == torch.Tensor:
            out = out.item()
        return out



In [216]:
avhan =AttentionVisualizerHcl(modelhan, pp)

In [287]:
a,b,c = av.visualize_attention(test_df_categorised.loc[925], MAX_SENT_LEN, MAX_NUM_SENT, preprocess_label=True)

<p style="font-size:18px"> Ground Truth: Hoax&emsp;&emsp;&emsp;Prediction: Hoax</p>

<span style="background-color:rgba(255,0,0,0.38874557614326477);font-size:16px;color:rgba(255,0,0,0);">_____</span> <span style="background-color:rgba(0,0,255,0.016056856978135723);font-size:16px;">5</span> <span style="background-color:rgba(0,0,255,0.020698486688090607);font-size:16px;">fast</span> <span style="background-color:rgba(0,0,255,0.02129833752830908);font-size:16px;">facts</span> <span style="background-color:rgba(0,0,255,0.011524553205517521);font-size:16px;">you</span> <span style="background-color:rgba(0,0,255,0.016486570148576907);font-size:16px;">probably</span> <span style="background-color:rgba(0,0,255,0.03097247692137667);font-size:16px;">didnt</span> <span style="background-color:rgba(0,0,255,0.016011315827400934);font-size:16px;">know</span> <span style="background-color:rgba(0,0,255,0.0171647435457063);font-size:16px;">about</span> <span style="background-color:rgba(0,0,255,0.05146342352472175);font-size:16px;">melania</span> <span style="background-color:rgba(0,0,255,0.041566688314819945);font-size:16px;">trump</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.04591509591271989);font-size:16px;">melania</span> <span style="background-color:rgba(0,0,255,0.028355713106230727);font-size:16px;">trump</span> <span style="background-color:rgba(0,0,255,0.012247979738759828);font-size:16px;">put</span> <span style="background-color:rgba(0,0,255,0.013013152133284144);font-size:16px;">herself</span> <span style="background-color:rgba(0,0,255,0.009197132015150808);font-size:16px;">in</span> <span style="background-color:rgba(0,0,255,0.010008817029084101);font-size:16px;">the</span> <span style="background-color:rgba(0,0,255,0.01834613074796995);font-size:16px;">spotlight</span> <span style="background-color:rgba(0,0,255,0.00870751180335951);font-size:16px;">for</span> <span style="background-color:rgba(0,0,255,0.008723928056131775);font-size:16px;">the</span> <span style="background-color:rgba(0,0,255,0.010647690366518863);font-size:16px;">first</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.008847746471155497);font-size:16px;">time</span> <span style="background-color:rgba(0,0,255,0.011365597006624983);font-size:16px;">in</span> <span style="background-color:rgba(0,0,255,0.016603707321182683);font-size:16px;">husband</span> <span style="background-color:rgba(0,0,255,0.03700650170729606);font-size:16px;">donalds</span> <span style="background-color:rgba(0,0,255,0.01621257881927857);font-size:16px;">campaign</span> <span style="background-color:rgba(0,0,255,0.014686968851251658);font-size:16px;">for</span> <span style="background-color:rgba(0,0,255,0.016837381248815574);font-size:16px;">the</span> <span style="background-color:rgba(0,0,255,0.0252438916419563);font-size:16px;">us</span> <span style="background-color:rgba(0,0,255,0.04378042908871803);font-size:16px;">presidency</span> <span style="background-color:rgba(0,0,255,0.024503234169420577);font-size:16px;">last</span>

<span style="background-color:rgba(255,0,0,0.22007574141025543);font-size:16px;color:rgba(255,0,0,0);">_____</span> <span style="background-color:rgba(0,0,255,0.02273649267632144);font-size:16px;">no</span> <span style="background-color:rgba(0,0,255,0.03821373473479856);font-size:16px;">sooner</span> <span style="background-color:rgba(0,0,255,0.018649093130474206);font-size:16px;">than</span> <span style="background-color:rgba(0,0,255,0.014499200555871734);font-size:16px;">she</span> <span style="background-color:rgba(0,0,255,0.018259406043645718);font-size:16px;">left</span> <span style="background-color:rgba(0,0,255,0.018866790258637245);font-size:16px;">the</span> <span style="background-color:rgba(0,0,255,0.030778401239769125);font-size:16px;">stage</span> <span style="background-color:rgba(0,0,255,0.020357442717770652);font-size:16px;">the</span> <span style="background-color:rgba(0,0,255,0.024849845463984527);font-size:16px;">critics</span> <span style="background-color:rgba(0,0,255,0.015349968370589944);font-size:16px;">were</span> <span style="background-color:rgba(0,0,255,0.019294035377876557);font-size:16px;">looking</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.018778214049307275);font-size:16px;">for</span> <span style="background-color:rgba(0,0,255,0.026431328558105618);font-size:16px;">anything</span> <span style="background-color:rgba(0,0,255,0.023157477832222818);font-size:16px;">to</span> <span style="background-color:rgba(0,0,255,0.05776415248607784);font-size:16px;">bash</span> <span style="background-color:rgba(0,0,255,0.02682186673431031);font-size:16px;">her</span> <span style="background-color:rgba(0,0,255,0.030701174053866447);font-size:16px;">one</span> <span style="background-color:rgba(0,0,255,0.04361369235567004);font-size:16px;">.</span>

<span style="background-color:rgba(255,0,0,0.3911786675453186);font-size:16px;color:rgba(255,0,0,0);">_____</span> <span style="background-color:rgba(0,0,255,0.031957847680044546);font-size:16px;">but</span> <span style="background-color:rgba(0,0,255,0.03599790734979267);font-size:16px;">here</span> <span style="background-color:rgba(0,0,255,0.0296223557946747);font-size:16px;">are</span> <span style="background-color:rgba(0,0,255,0.023020838125524853);font-size:16px;">some</span> <span style="background-color:rgba(0,0,255,0.03905215102418634);font-size:16px;">facts</span> <span style="background-color:rgba(0,0,255,0.01846301448845027);font-size:16px;">that</span> <span style="background-color:rgba(0,0,255,0.02310931820121194);font-size:16px;">you</span> <span style="background-color:rgba(0,0,255,0.029764580963340698);font-size:16px;">may</span> <span style="background-color:rgba(0,0,255,0.028613988156066132);font-size:16px;">never</span> <span style="background-color:rgba(0,0,255,0.027366275796589727);font-size:16px;">have</span> <span style="background-color:rgba(0,0,255,0.040575180293827035);font-size:16px;">known</span> <span style="background-color:rgba(0,0,255,0.04628526382162988);font-size:16px;">about</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.1690829805142743);font-size:16px;">melania</span> <span style="background-color:rgba(0,0,255,0.082531055180701);font-size:16px;">.</span>

In [306]:
a,b,c = av.visualize_attention(test_df_categorised.loc[1890], MAX_SENT_LEN, MAX_NUM_SENT, preprocess_label=True)

<p style="font-size:18px"> Ground Truth: Propaganda&emsp;&emsp;&emsp;Prediction: Propaganda</p>

<span style="background-color:rgba(255,0,0,0.09338170289993286);font-size:16px;color:rgba(255,0,0,0);">_____</span> <span style="background-color:rgba(0,0,255,0.02092750562633254);font-size:16px;">even</span> <span style="background-color:rgba(0,0,255,0.02368837651080304);font-size:16px;">obama</span> <span style="background-color:rgba(0,0,255,0.04979552181771765);font-size:16px;">doesnt</span> <span style="background-color:rgba(0,0,255,0.01527078041893253);font-size:16px;">have</span> <span style="background-color:rgba(0,0,255,0.01447282377451875);font-size:16px;">to</span> <span style="background-color:rgba(0,0,255,0.025939287446273517);font-size:16px;">buy</span> <span style="background-color:rgba(0,0,255,0.05893300377786621);font-size:16px;">obamacare</span> <span style="background-color:rgba(0,0,255,0.010631760941267998);font-size:16px;">,</span> <span style="background-color:rgba(0,0,255,0.013529857892208407);font-size:16px;">it</span> <span style="background-color:rgba(0,0,255,0.027426691691069198);font-size:16px;">turns</span> <span style="background-color:rgba(0,0,255,0.020298586214447824);font-size:16px;">out</span> <span style="background-color:rgba(0,0,255,0.02467000552769924);font-size:16px;">.</span>

<span style="background-color:rgba(255,0,0,0.1411142796278);font-size:16px;color:rgba(255,0,0,0);">_____</span> <span style="background-color:rgba(0,0,255,0.008329520109437348);font-size:16px;">and</span> <span style="background-color:rgba(0,0,255,0.00803010261123373);font-size:16px;">the</span> <span style="background-color:rgba(0,0,255,0.008891861238929802);font-size:16px;">very</span> <span style="background-color:rgba(0,0,255,0.013397972173326961);font-size:16px;">fact</span> <span style="background-color:rgba(0,0,255,0.008930756496887834);font-size:16px;">that</span> <span style="background-color:rgba(0,0,255,0.045641123647989196);font-size:16px;">obamacare</span> <span style="background-color:rgba(0,0,255,0.015450093763923579);font-size:16px;">forces</span> <span style="background-color:rgba(0,0,255,0.01288559984918485);font-size:16px;">citizens</span> <span style="background-color:rgba(0,0,255,0.007020857850592278);font-size:16px;">to</span> <span style="background-color:rgba(0,0,255,0.01278941126331217);font-size:16px;">purchase</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.007016186613236593);font-size:16px;">a</span> <span style="background-color:rgba(0,0,255,0.012094125990982666);font-size:16px;">private</span> <span style="background-color:rgba(0,0,255,0.011755189796761711);font-size:16px;">insurance</span> <span style="background-color:rgba(0,0,255,0.012629792927773888);font-size:16px;">product</span> <span style="background-color:rgba(0,0,255,0.00847571108654909);font-size:16px;">or</span> <span style="background-color:rgba(0,0,255,0.008087956403666339);font-size:16px;">be</span> <span style="background-color:rgba(0,0,255,0.016707451478635803);font-size:16px;">fined</span> <span style="background-color:rgba(0,0,255,0.008619499972911568);font-size:16px;">by</span> <span style="background-color:rgba(0,0,255,0.008260855159368027);font-size:16px;">the</span> <span style="background-color:rgba(0,0,255,0.023645554399136078);font-size:16px;">irs</span> <span style="background-color:rgba(0,0,255,0.008413722031568233);font-size:16px;">is</span> <span style="background-color:rgba(0,0,255,0.01737559795347016);font-size:16px;">blatantly</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.012869389060230909);font-size:16px;">unconstitutional</span> <span style="background-color:rgba(0,0,255,0.004814714523017811);font-size:16px;">and</span> <span style="background-color:rgba(0,0,255,0.005071731585981598);font-size:16px;">an</span> <span style="background-color:rgba(0,0,255,0.01329769450550302);font-size:16px;">outlandish</span> <span style="background-color:rgba(0,0,255,0.01374621585695854);font-size:16px;">interpretation</span> <span style="background-color:rgba(0,0,255,0.008014191996160195);font-size:16px;">of</span> <span style="background-color:rgba(0,0,255,0.009145837029699688);font-size:16px;">the</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.02424311453118411);font-size:16px;">commerce</span>

<span style="background-color:rgba(255,0,0,0.11197298020124435);font-size:16px;color:rgba(255,0,0,0);">_____</span> <span style="background-color:rgba(0,0,255,0.086220334175198);font-size:16px;">obamacare</span> <span style="background-color:rgba(0,0,255,0.019757235127196304);font-size:16px;">is</span> <span style="background-color:rgba(0,0,255,0.016195168264801284);font-size:16px;">the</span> <span style="background-color:rgba(0,0,255,0.022681242540419908);font-size:16px;">biggest</span> <span style="background-color:rgba(0,0,255,0.020808622350178628);font-size:16px;">government</span> <span style="background-color:rgba(0,0,255,0.027173111154929926);font-size:16px;">boondoggle</span> <span style="background-color:rgba(0,0,255,0.013781338403809842);font-size:16px;">our</span> <span style="background-color:rgba(0,0,255,0.012855061952215807);font-size:16px;">nation</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.006376685247572061);font-size:16px;">has</span> <span style="background-color:rgba(0,0,255,0.006612188536016596);font-size:16px;">ever</span> <span style="background-color:rgba(0,0,255,0.005607169361417602);font-size:16px;">seen</span> <span style="background-color:rgba(0,0,255,0.003732506977548856);font-size:16px;">,</span> <span style="background-color:rgba(0,0,255,0.004618592759377537);font-size:16px;">and</span> <span style="background-color:rgba(0,0,255,0.004657084665605073);font-size:16px;">it</span> <span style="background-color:rgba(0,0,255,0.006399652680191761);font-size:16px;">is</span> <span style="background-color:rgba(0,0,255,0.013107637019526896);font-size:16px;">doomed</span> <span style="background-color:rgba(0,0,255,0.006729884707865038);font-size:16px;">to</span> <span style="background-color:rgba(0,0,255,0.011967809380681461);font-size:16px;">crash</span> <span style="background-color:rgba(0,0,255,0.008513768421821099);font-size:16px;">and</span> <span style="background-color:rgba(0,0,255,0.02102474271860885);font-size:16px;">burn</span> <span style="background-color:rgba(0,0,255,0.015803798825034496);font-size:16px;">.</span>

<span style="background-color:rgba(255,0,0,0.10022211819887161);font-size:16px;color:rgba(255,0,0,0);">_____</span> <span style="background-color:rgba(0,0,255,0.015915433330761394);font-size:16px;">nobody</span> <span style="background-color:rgba(0,0,255,0.008048285271199994);font-size:16px;">who</span> <span style="background-color:rgba(0,0,255,0.015511796820600278);font-size:16px;">understands</span> <span style="background-color:rgba(0,0,255,0.007280429003525972);font-size:16px;">it</span> <span style="background-color:rgba(0,0,255,0.009542483583794817);font-size:16px;">wants</span> <span style="background-color:rgba(0,0,255,0.006004604041091346);font-size:16px;">it</span> <span style="background-color:rgba(0,0,255,0.005313362335782852);font-size:16px;">,</span> <span style="background-color:rgba(0,0,255,0.007124336422278201);font-size:16px;">and</span> <span style="background-color:rgba(0,0,255,0.008431416460744925);font-size:16px;">even</span> <span style="background-color:rgba(0,0,255,0.011003799910034308);font-size:16px;">those</span> <span style="background-color:rgba(0,0,255,0.01061402545419669);font-size:16px;">who</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.03237498000839386);font-size:16px;">foolishly</span> <span style="background-color:rgba(0,0,255,0.012226791858974708);font-size:16px;">were</span> <span style="background-color:rgba(0,0,255,0.03336381362622351);font-size:16px;">mind-tricked</span> <span style="background-color:rgba(0,0,255,0.013896092660392632);font-size:16px;">into</span> <span style="background-color:rgba(0,0,255,0.017509849083397706);font-size:16px;">supporting</span> <span style="background-color:rgba(0,0,255,0.006759423797918685);font-size:16px;">it</span> <span style="background-color:rgba(0,0,255,0.006032398910607779);font-size:16px;">have</span> <span style="background-color:rgba(0,0,255,0.005710984761715393);font-size:16px;">no</span> <span style="background-color:rgba(0,0,255,0.013186921330115607);font-size:16px;">clue</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.005587469895084356);font-size:16px;">just</span> <span style="background-color:rgba(0,0,255,0.006932721884571039);font-size:16px;">how</span> <span style="background-color:rgba(0,0,255,0.008559021305629654);font-size:16px;">badly</span> <span style="background-color:rgba(0,0,255,0.006484713002205196);font-size:16px;">its</span> <span style="background-color:rgba(0,0,255,0.006661195327982396);font-size:16px;">going</span> <span style="background-color:rgba(0,0,255,0.006297318760686278);font-size:16px;">to</span> <span style="background-color:rgba(0,0,255,0.011073751746868053);font-size:16px;">hurt</span> <span style="background-color:rgba(0,0,255,0.008929161978893038);font-size:16px;">them</span> <span style="background-color:rgba(0,0,255,0.010202124074484863);font-size:16px;">.</span>

<span style="background-color:rgba(255,0,0,0.10944917798042297);font-size:16px;color:rgba(255,0,0,0);">_____</span> <span style="background-color:rgba(0,0,255,0.021775225553672284);font-size:16px;">every</span> <span style="background-color:rgba(0,0,255,0.045344738284564534);font-size:16px;">rational</span> <span style="background-color:rgba(0,0,255,0.019074477528823894);font-size:16px;">person</span> <span style="background-color:rgba(0,0,255,0.01908953057137037);font-size:16px;">wants</span> <span style="background-color:rgba(0,0,255,0.013248761499762013);font-size:16px;">to</span> <span style="background-color:rgba(0,0,255,0.01569730882365263);font-size:16px;">be</span> <span style="background-color:rgba(0,0,255,0.04961821005312744);font-size:16px;">exempted</span> <span style="background-color:rgba(0,0,255,0.02456105642158917);font-size:16px;">from</span> <span style="background-color:rgba(0,0,255,0.09696883588996431);font-size:16px;">obamacare</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.025452915300322017);font-size:16px;">.</span>

<span style="background-color:rgba(255,0,0,0.09913370013237);font-size:16px;color:rgba(255,0,0,0);">_____</span> <span style="background-color:rgba(0,0,255,0.014735476153195768);font-size:16px;">any</span> <span style="background-color:rgba(0,0,255,0.020452794701408747);font-size:16px;">lawmakers</span> <span style="background-color:rgba(0,0,255,0.010194783464121298);font-size:16px;">who</span> <span style="background-color:rgba(0,0,255,0.021042886960401314);font-size:16px;">continue</span> <span style="background-color:rgba(0,0,255,0.013112348863892167);font-size:16px;">to</span> <span style="background-color:rgba(0,0,255,0.03231812335239214);font-size:16px;">defend</span> <span style="background-color:rgba(0,0,255,0.04572311267363345);font-size:16px;">obamacare</span> <span style="background-color:rgba(0,0,255,0.009534331452688715);font-size:16px;">will</span> <span style="background-color:rgba(0,0,255,0.009883563851818126);font-size:16px;">likely</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.007562506032253274);font-size:16px;">find</span> <span style="background-color:rgba(0,0,255,0.008336097356041329);font-size:16px;">themselves</span> <span style="background-color:rgba(0,0,255,0.004892056882025155);font-size:16px;">out</span> <span style="background-color:rgba(0,0,255,0.005098033909291541);font-size:16px;">of</span> <span style="background-color:rgba(0,0,255,0.005233893393429653);font-size:16px;">a</span> <span style="background-color:rgba(0,0,255,0.0074683581584986955);font-size:16px;">job</span> <span style="background-color:rgba(0,0,255,0.006497470636094177);font-size:16px;">when</span> <span style="background-color:rgba(0,0,255,0.01097637590568201);font-size:16px;">the</span> <span style="background-color:rgba(0,0,255,0.03423641171798404);font-size:16px;">2014</span> <span style="background-color:rgba(0,0,255,0.014720929519373807);font-size:16px;">elections</span> <span style="background-color:rgba(0,0,255,0.00972228235982183);font-size:16px;">come</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.011211980468589892);font-size:16px;">around</span> <span style="background-color:rgba(0,0,255,0.011901230756664445);font-size:16px;">.</span>

In [299]:
short_trusted = (test_df_categorised['Label'] == 3) & (test_df_categorised['Text'].apply(lambda t: re.search(r'obama|trump|melania|biden', t.lower())))
test_df_categorised[short_trusted].sort_values('lengths').head(20)

Unnamed: 0,Text,Label,Category,lengths,has_hedge,has_2p,has_1ps,has_nums,has_bri_ire_afh_am_hondu,has_days
1890,Top lawmakers on Capitol Hill are negotiating...,3,4,321,1,0,1,4,2,0
2191,After months of collecting signatures to get ...,3,4,322,0,0,0,30,2,0
2143,The battle continues to rage between the indi...,3,4,380,0,0,1,10,4,0
1727,With preparations being made for the massive ...,3,4,423,0,0,0,8,1,0
1896,Thank you to all those who voted in our selfn...,3,4,510,5,4,2,57,5,0
2007,In what is quickly shaping up to be the bigge...,3,4,531,3,0,0,11,5,0
1587,"The Law of Attraction, made popular by 'The S...",3,2,534,0,7,7,20,7,12
1775,Tweet (NewsTarget) California governor Arnold...,3,4,565,0,0,2,20,3,0
1756,"If you're 'Ready for Hillary' in 2016, you mi...",3,4,574,4,6,0,18,4,0
1644,"With a sad twist of irony, corporate and gove...",3,3,576,0,1,0,8,3,0


In [222]:
a,b,c = avhan.visualize_attention(test_df.loc[561], MAX_SENT_LEN, MAX_NUM_SENT, preprocess_label=True)

<p style="font-size:18px"> Ground Truth: Satire&emsp;&emsp;&emsp;Prediction: Satire</p>

<span style="background-color:rgba(255,0,0,0.24200034141540527);font-size:16px;color:rgba(255,0,0,0);">_____</span> <span style="background-color:rgba(0,0,255,0.12800093275906924);font-size:16px;">cnn</span> <span style="background-color:rgba(0,0,255,0.0686656732192339);font-size:16px;">apologized</span> <span style="background-color:rgba(0,0,255,0.023735232263345278);font-size:16px;">to</span> <span style="background-color:rgba(0,0,255,0.03139357901777049);font-size:16px;">its</span> <span style="background-color:rgba(0,0,255,0.04592112639376344);font-size:16px;">viewers</span> <span style="background-color:rgba(0,0,255,0.03297088506997794);font-size:16px;">today</span> <span style="background-color:rgba(0,0,255,0.02809632537479983);font-size:16px;">for</span> <span style="background-color:rgba(0,0,255,0.04498048909980942);font-size:16px;">briefly</span> <span style="background-color:rgba(0,0,255,0.05424646468441979);font-size:16px;">airing</span> <span style="background-color:rgba(0,0,255,0.011694496760459914);font-size:16px;">a</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.013522807299843266);font-size:16px;">story</span> <span style="background-color:rgba(0,0,255,0.006186737586989254);font-size:16px;">on</span> <span style="background-color:rgba(0,0,255,0.006035391070540488);font-size:16px;">sunday</span> <span style="background-color:rgba(0,0,255,0.000594714139850972);font-size:16px;">that</span> <span style="background-color:rgba(0,0,255,0.0);font-size:16px;">had</span> <span style="background-color:rgba(0,0,255,0.006083250916267387);font-size:16px;">nothing</span> <span style="background-color:rgba(0,0,255,0.0011575253685266321);font-size:16px;">to</span> <span style="background-color:rgba(0,0,255,0.004128781753417192);font-size:16px;">do</span> <span style="background-color:rgba(0,0,255,0.007408141100190395);font-size:16px;">with</span> <span style="background-color:rgba(0,0,255,0.011685815890008213);font-size:16px;">the</span> <span style="background-color:rgba(0,0,255,0.023447332433060464);font-size:16px;">missing</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.05884020775388262);font-size:16px;">malaysia</span> <span style="background-color:rgba(0,0,255,0.053390606686742885);font-size:16px;">airlines</span> <span style="background-color:rgba(0,0,255,0.0688980601869735);font-size:16px;">flight</span> <span style="background-color:rgba(0,0,255,0.04763412665826261);font-size:16px;">.</span>

<span style="background-color:rgba(255,0,0,0.1594906896352768);font-size:16px;color:rgba(255,0,0,0);">_____</span> <span style="background-color:rgba(0,0,255,0.020649687691945987);font-size:16px;">the</span> <span style="background-color:rgba(0,0,255,0.01806528332357854);font-size:16px;">story</span> <span style="background-color:rgba(0,0,255,0.008646701176333989);font-size:16px;">,</span> <span style="background-color:rgba(0,0,255,0.01616890932158344);font-size:16px;">which</span> <span style="background-color:rgba(0,0,255,0.029477051111571808);font-size:16px;">caused</span> <span style="background-color:rgba(0,0,255,0.019445070525178833);font-size:16px;">thousands</span> <span style="background-color:rgba(0,0,255,0.013894376726604674);font-size:16px;">of</span> <span style="background-color:rgba(0,0,255,0.022729155530379585);font-size:16px;">viewers</span> <span style="background-color:rgba(0,0,255,0.011627106412791006);font-size:16px;">to</span> <span style="background-color:rgba(0,0,255,0.027408742458556182);font-size:16px;">contact</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.012749780262664837);font-size:16px;">the</span> <span style="background-color:rgba(0,0,255,0.024212070591006957);font-size:16px;">network</span> <span style="background-color:rgba(0,0,255,0.011138855219376497);font-size:16px;">in</span> <span style="background-color:rgba(0,0,255,0.019359907712518405);font-size:16px;">anger</span> <span style="background-color:rgba(0,0,255,0.003234946891362784);font-size:16px;">,</span> <span style="background-color:rgba(0,0,255,0.0027125262895308574);font-size:16px;">had</span> <span style="background-color:rgba(0,0,255,0.004640105857272176);font-size:16px;">something</span> <span style="background-color:rgba(0,0,255,0.005116338466443487);font-size:16px;">to</span> <span style="background-color:rgba(0,0,255,0.007460621025571788);font-size:16px;">do</span> <span style="background-color:rgba(0,0,255,0.011782907475781665);font-size:16px;">with</span> <span style="background-color:rgba(0,0,255,0.03097433474280984);font-size:16px;">crimea</span> <span style="background-color:rgba(0,0,255,0.011159797107724308);font-size:16px;">,</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.03868795194976976);font-size:16px;">ukraine</span> <span style="background-color:rgba(0,0,255,0.012694811958309184);font-size:16px;">,</span> <span style="background-color:rgba(0,0,255,0.023550440378798432);font-size:16px;">and</span> <span style="background-color:rgba(0,0,255,0.04992943166600428);font-size:16px;">russia</span> <span style="background-color:rgba(0,0,255,0.03472092075275935);font-size:16px;">.</span>

<span style="background-color:rgba(255,0,0,0.12299803644418716);font-size:16px;color:rgba(255,0,0,0);">_____</span> <span style="background-color:rgba(0,0,255,0.010093119473519926);font-size:16px;">in</span> <span style="background-color:rgba(0,0,255,0.009448084967218757);font-size:16px;">the</span> <span style="background-color:rgba(0,0,255,0.012868296658751758);font-size:16px;">official</span> <span style="background-color:rgba(0,0,255,0.014594173455341478);font-size:16px;">apology</span> <span style="background-color:rgba(0,0,255,0.00643498179435829);font-size:16px;">,</span> <span style="background-color:rgba(0,0,255,0.02652403866875466);font-size:16px;">cnn</span> <span style="background-color:rgba(0,0,255,0.02058838454793205);font-size:16px;">chief</span> <span style="background-color:rgba(0,0,255,0.023513236783770788);font-size:16px;">jeff</span> <span style="background-color:rgba(0,0,255,0.02006761178976706);font-size:16px;">zucker</span> <span style="background-color:rgba(0,0,255,0.004801311342946055);font-size:16px;">wrote</span> <span style="background-color:rgba(0,0,255,9.357809542919335e-05);font-size:16px;">,</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.0006071479116512968);font-size:16px;">on</span> <span style="background-color:rgba(0,0,255,0.0014025631151565775);font-size:16px;">sunday</span> <span style="background-color:rgba(0,0,255,0.0006888246952133639);font-size:16px;">,</span> <span style="background-color:rgba(0,0,255,0.002859542331954672);font-size:16px;">we</span> <span style="background-color:rgba(0,0,255,0.007937976743716902);font-size:16px;">briefly</span> <span style="background-color:rgba(0,0,255,0.004674681473790687);font-size:16px;">cut</span> <span style="background-color:rgba(0,0,255,0.007356644247277597);font-size:16px;">away</span> <span style="background-color:rgba(0,0,255,0.00613151254378937);font-size:16px;">from</span> <span style="background-color:rgba(0,0,255,0.011810869359970423);font-size:16px;">our</span> <span style="background-color:rgba(0,0,255,0.02522546625264316);font-size:16px;">nonstop</span> <span style="background-color:rgba(0,0,255,0.015550009921839653);font-size:16px;">coverage</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.008925258203804763);font-size:16px;">of</span> <span style="background-color:rgba(0,0,255,0.016238086085855984);font-size:16px;">flight</span> <span style="background-color:rgba(0,0,255,0.022970723204320517);font-size:16px;">370</span> <span style="background-color:rgba(0,0,255,0.006366211604516508);font-size:16px;">to</span> <span style="background-color:rgba(0,0,255,0.010686967589795354);font-size:16px;">talk</span> <span style="background-color:rgba(0,0,255,0.007798670855183193);font-size:16px;">about</span> <span style="background-color:rgba(0,0,255,0.014694351880544517);font-size:16px;">something</span> <span style="background-color:rgba(0,0,255,0.03439082976146793);font-size:16px;">else</span>

<span style="background-color:rgba(255,0,0,0.20533621311187744);font-size:16px;color:rgba(255,0,0,0);">_____</span> <span style="background-color:rgba(0,0,255,0.046498688702378346);font-size:16px;">were</span> <span style="background-color:rgba(0,0,255,0.03024411137347906);font-size:16px;">not</span> <span style="background-color:rgba(0,0,255,0.03886646819640896);font-size:16px;">going</span> <span style="background-color:rgba(0,0,255,0.036200423050911534);font-size:16px;">to</span> <span style="background-color:rgba(0,0,255,0.1218194811302222);font-size:16px;">sugarcoat</span> <span style="background-color:rgba(0,0,255,0.03433919140873926);font-size:16px;">it</span> <span style="background-color:rgba(0,0,255,0.05036888212525839);font-size:16px;">:</span> <span style="background-color:rgba(0,0,255,0.0552907778666426);font-size:16px;">we</span> <span style="background-color:rgba(0,0,255,0.20533621311187744);font-size:16px;">messed</span> <span style="background-color:rgba(0,0,255,0.11446427415928825);font-size:16px;">up</span> <span style="background-color:rgba(0,0,255,0.11636646054406043);font-size:16px;">.</span>

<span style="background-color:rgba(255,0,0,0.2701747417449951);font-size:16px;color:rgba(255,0,0,0);">_____</span> <span style="background-color:rgba(0,0,255,0.18462358252074268);font-size:16px;">cnn</span> <span style="background-color:rgba(0,0,255,0.12192390469922851);font-size:16px;">regrets</span> <span style="background-color:rgba(0,0,255,0.04734815943184023);font-size:16px;">the</span> <span style="background-color:rgba(0,0,255,0.06829797851267416);font-size:16px;">error</span> <span style="background-color:rgba(0,0,255,0.032604771279978684);font-size:16px;">and</span> <span style="background-color:rgba(0,0,255,0.08450575802784707);font-size:16px;">promises</span> <span style="background-color:rgba(0,0,255,0.052300251325189916);font-size:16px;">our</span> <span style="background-color:rgba(0,0,255,0.050560102972480515);font-size:16px;">viewers</span> <span style="background-color:rgba(0,0,255,0.020283860226074278);font-size:16px;">that</span> <span style="background-color:rgba(0,0,255,0.03089278975774513);font-size:16px;">it</span> <span style="background-color:rgba(0,0,255,0.11843866558609938);font-size:16px;">wont</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.09777943071779445);font-size:16px;">happen</span> <span style="background-color:rgba(0,0,255,0.07655611543207946);font-size:16px;">again</span> <span style="background-color:rgba(0,0,255,0.07871297236991769);font-size:16px;">.</span>

In [190]:
short_trusted = (test_df_categorised['lengths'] < 100) & (test_df_categorised['Label'] == 4)
test_df_categorised[(test_df_categorised['has_nums'] > 0) & short_trusted].head(20)

Unnamed: 0,Text,Label,Category,lengths,has_hedge,has_2p,has_1ps,has_nums,has_bri_ire_afh_am_hondu,has_days
2268,British American Tobacco announced Tuesday tha...,4,3,55,0,0,0,9,3,1
2318,The plaintiff in the lawsuit that legalized ab...,4,4,98,2,0,0,6,2,0
2334,"Gold for current delivery closed at $1,107.80 ...",4,0,22,0,0,0,14,1,2
2399,Triple Olympic gold medalist Stephanie Rice sa...,4,5,91,0,1,0,4,0,1
2401,Singapore exchange to buy Australian bourse fo...,4,3,13,0,0,0,4,1,0
2407,West Indies beat England by five wickets under...,4,5,74,0,0,0,25,0,1
2413,Eurozone recovery falters in Q4 as economy gro...,4,0,15,0,0,0,4,0,0
2467,Coast Guard Adm. Thad Allen: cap now funneling...,4,3,20,0,0,0,9,0,0
2487,Spanish bank BBVA reported Wednesday its fourt...,4,3,99,0,0,0,45,0,1
2526,Results Thursday from the St. Petersburg Open ...,4,5,53,0,0,0,18,7,1


In [290]:
a,b,c = av.visualize_attention(test_df_categorised.loc[2467], MAX_SENT_LEN, MAX_NUM_SENT, preprocess_label=True)
a,b,c = av.visualize_attention(test_df_categorised.loc[2334], MAX_SENT_LEN, MAX_NUM_SENT, preprocess_label=True)

<p style="font-size:18px"> Ground Truth: Trusted&emsp;&emsp;&emsp;Prediction: Trusted</p>

<span style="background-color:rgba(255,0,0,1.0);font-size:16px;color:rgba(255,0,0,0);">_____</span> <span style="background-color:rgba(0,0,255,0.04612308740615845);font-size:16px;">coast</span> <span style="background-color:rgba(0,0,255,0.04704555124044418);font-size:16px;">guard</span> <span style="background-color:rgba(0,0,255,0.0670442059636116);font-size:16px;">adm.</span> <span style="background-color:rgba(0,0,255,0.07110335677862167);font-size:16px;">thad</span> <span style="background-color:rgba(0,0,255,0.04284561425447464);font-size:16px;">allen</span> <span style="background-color:rgba(0,0,255,0.02684144489467144);font-size:16px;">:</span> <span style="background-color:rgba(0,0,255,0.04974472522735596);font-size:16px;">cap</span> <span style="background-color:rgba(0,0,255,0.03876177594065666);font-size:16px;">now</span> <span style="background-color:rgba(0,0,255,0.1004386618733406);font-size:16px;">funneling</span> <span style="background-color:rgba(0,0,255,0.07699908316135406);font-size:16px;">462,000</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.04530829191207886);font-size:16px;">gallons</span> <span style="background-color:rgba(0,0,255,0.034812189638614655);font-size:16px;">(</span> <span style="background-color:rgba(0,0,255,0.045579664409160614);font-size:16px;">1.7</span> <span style="background-color:rgba(0,0,255,0.026468364521861076);font-size:16px;">million</span> <span style="background-color:rgba(0,0,255,0.0351150780916214);font-size:16px;">liters</span> <span style="background-color:rgba(0,0,255,0.019862961024045944);font-size:16px;">)</span> <span style="background-color:rgba(0,0,255,0.01809554174542427);font-size:16px;">of</span> <span style="background-color:rgba(0,0,255,0.023657288402318954);font-size:16px;">oil</span> <span style="background-color:rgba(0,0,255,0.01755109801888466);font-size:16px;">a</span> <span style="background-color:rgba(0,0,255,0.01983649656176567);font-size:16px;">day</span> <span style="background-color:rgba(0,0,255,0.023630375042557716);font-size:16px;">from</span> <span style="background-color:rgba(0,0,255,0.04681131988763809);font-size:16px;">gulf</span> <span style="background-color:rgba(0,0,255,0.043448325246572495);font-size:16px;">spill</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.03287554532289505);font-size:16px;">.</span>

<p style="font-size:18px"> Ground Truth: Trusted&emsp;&emsp;&emsp;Prediction: Trusted</p>

<span style="background-color:rgba(255,0,0,1.0);font-size:16px;color:rgba(255,0,0,0);">_____</span> <span style="background-color:rgba(0,0,255,0.035253413021564484);font-size:16px;">gold</span> <span style="background-color:rgba(0,0,255,0.021404463797807693);font-size:16px;">for</span> <span style="background-color:rgba(0,0,255,0.028654640540480614);font-size:16px;">current</span> <span style="background-color:rgba(0,0,255,0.02937629260122776);font-size:16px;">delivery</span> <span style="background-color:rgba(0,0,255,0.03306358680129051);font-size:16px;">closed</span> <span style="background-color:rgba(0,0,255,0.0252033993601799);font-size:16px;">at</span> <span style="background-color:rgba(0,0,255,0.03220895305275917);font-size:16px;">\$</span> <span style="background-color:rgba(0,0,255,0.07705462723970413);font-size:16px;">1,107.80</span> <span style="background-color:rgba(0,0,255,0.04180104285478592);font-size:16px;">per</span> <span style="background-color:rgba(0,0,255,0.03937329351902008);font-size:16px;">troy</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.032713644206523895);font-size:16px;">ounce</span> <span style="background-color:rgba(0,0,255,0.02076738514006138);font-size:16px;">thursday</span> <span style="background-color:rgba(0,0,255,0.020397823303937912);font-size:16px;">on</span> <span style="background-color:rgba(0,0,255,0.021511899307370186);font-size:16px;">the</span> <span style="background-color:rgba(0,0,255,0.02474067360162735);font-size:16px;">new</span> <span style="background-color:rgba(0,0,255,0.030498450621962547);font-size:16px;">york</span> <span style="background-color:rgba(0,0,255,0.04653860628604889);font-size:16px;">mercantile</span> <span style="background-color:rgba(0,0,255,0.03625035658478737);font-size:16px;">exchange</span> <span style="background-color:rgba(0,0,255,0.019604403525590897);font-size:16px;">,</span> <span style="background-color:rgba(0,0,255,0.030437614768743515);font-size:16px;">up</span>

<span style="color:rgba(255,255,255,0);font-size:16px">_____</span> <span style="background-color:rgba(0,0,255,0.031183507293462753);font-size:16px;">from</span> <span style="background-color:rgba(0,0,255,0.04347136616706848);font-size:16px;">\$</span> <span style="background-color:rgba(0,0,255,0.09160143882036209);font-size:16px;">1,096.50</span> <span style="background-color:rgba(0,0,255,0.04823124781250954);font-size:16px;">late</span> <span style="background-color:rgba(0,0,255,0.092160165309906);font-size:16px;">wedensday</span> <span style="background-color:rgba(0,0,255,0.04649776220321655);font-size:16px;">.</span>

# 2. FAN Analysis

First, define the model:

In [3]:
class AttentionUnit(nn.Module):
    def __init__(self, input_dim, hidden_dim=None, num_outputs=1, attn_dropout=0.0):
        super(AttentionUnit, self).__init__()
        if hidden_dim is None:
            hidden_dim = input_dim
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.query = nn.Linear(hidden_dim, num_outputs, bias=False)
    def forward(self, encoder_output, padding_positions=None, return_weights=False):
        #Calculate u_{i} = tanh(Wh_{i}+b) [B,L,H]-->[B,L,H]
        hidden_rep = F.tanh(self.hidden(encoder_output))
        #Calculate a_{i} = softmax(u_{i}^Tc) with masking [B,L,H]-->[B,L,1]
        similarity = self.query(hidden_rep)
        if padding_positions is not None:
            similarity = similarity.masked_fill(padding_positions, -float('inf'))
        attention_weights = F.softmax(similarity, dim=1)
        #Return weighted sum [B,L,1], [B,L,H]-->[B,H]
        if return_weights:
            return torch.bmm(attention_weights.transpose(1,2), hidden_rep).squeeze(1), attention_weights
        return torch.bmm(attention_weights.transpose(1,2), hidden_rep).squeeze(1)

class LSTMFlatAttentionFCNNClassifier(torch.nn.Module):
    '''
    Classifier that uses an LSTM as an encoder followed by an attention block
    and a Fully-Connected Neural Network(FCNN) as a decoder.
    '''
    def __init__(self, vocab_len, embed_dim, hidden_dim, num_lstm_layers, num_classes, attn_dropout=0.0, pretrained_embeddings=None, freeze_embeds=False):
        super(LSTMFlatAttentionFCNNClassifier, self).__init__()
        if pretrained_embeddings is not None:
            self.embedding = nn.Embedding.from_pretrained(pretrained_embeddings, freeze=freeze_embeds)
        else:
            self.embedding = nn.Embedding(num_embeddings=vocab_len, embedding_dim=embed_dim)

        self.encoder = nn.LSTM(input_size=embed_dim, hidden_size=hidden_dim, num_layers=num_lstm_layers, batch_first=True, bidirectional=True)
        self.attn = AttentionUnit(2*hidden_dim)
        self.decoder = nn.Linear(2*hidden_dim, num_classes)

    def forward(self, X_batch, lengths, return_attn_weights=False):
        embeddings = self.embedding(X_batch)

        embeddings = nn.utils.rnn.pack_padded_sequence(embeddings, lengths.cpu(), enforce_sorted=False, batch_first=True)
        output, (_, _) = self.encoder(embeddings)
        output, _ = nn.utils.rnn.pad_packed_sequence(output,batch_first=True)

        padding_positions = self.__get_padding_masks(lengths).to(output.device)
        doc_embeddings = self.attn(output,padding_positions=padding_positions,return_weights=return_attn_weights)
        
        if return_attn_weights:
            return self.decoder(doc_embeddings[0]), doc_embeddings[1]
        else:
            return self.decoder(doc_embeddings)
    
    def __get_padding_masks(self, lengths):
        '''
        Returns a mask (shape BxLx1) that indicates the position of pad tokens
        '''
        max_len = lengths.max()
        return torch.tensor([[False]*i + [True]*(max_len-i) for i in lengths]).unsqueeze(2)

## 1. Categorized evals

Set hyperparameters and load model

In [115]:
MODEL_MAX_LEN = 500
NUM_CLASSES = 4
EMBED_DIM = 100
HIDDEN_DIM = 100
NUM_LSTM_LAYERS = 1

VOCAB_LEN = 400001 #harcoded for convenience; see below for how it was obtained
# glove, _ = DataPreprocessorFlat.from_pretrained_embeds(NUM_CLASSES,'/kaggle/input/lun-glove/glove.6B.100d.txt', EMBED_DIM)
# VOCAB_LEN = len(glove.vocab)

MODEL_PATH = './outputs/model/bestFAN_ml500_ba256_emb100hid100lay1cla4_ep10lr0.0005wd5e-06_af0.5ap2_model.pt'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

collate_fn = make_flat_collate_function(MODEL_MAX_LEN)
model = LSTMFlatAttentionFCNNClassifier(VOCAB_LEN, EMBED_DIM, HIDDEN_DIM, NUM_LSTM_LAYERS, NUM_CLASSES)
model.to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))

<All keys matched successfully>

Load Preprocessed test data and sub-categorized test data

In [119]:
X_test = pd.read_parquet('./FAN_prepro_data/X_test_prep_flat.parquet')['Text']
y_test = pd.read_parquet('./FAN_prepro_data/y_test_prep_flat.parquet')['Label']
embeds = torch.tensor(np.load('./glove_embs.npy'))

In [102]:
test_df_categorised = pd.read_csv('/kaggle/input/lun-glove/balancedtestwithclass_new_cleaned.csv')

In [100]:
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
def categorised_eval_flat(categories_df, X, ylens, model, device, category_list=[0,1,2,3,4,5]):
    records = {'category':[], 'support':[], 'acc':[], 'f1':[], 'precision':[], 'recall':[]}
    all_preds = []
    all_truths = []
    idxes = []
    for cat in category_list:
        idx = categories_df[categories_df['Category']==cat].index
        ylens_cat = ylens.loc[idx]
        X_cat = X.loc[idx]
        loader = DataLoader(WrapperDatasetFlat(X_cat, ylens_cat),
                          batch_size=128,
                          collate_fn=collate_fn,
                          shuffle=False)
        model.to(device)
        preds=[]
        truths=[]
        for X_batch, lengths, y_batch in tqdm(loader):
            #Move to correct device
            X_batch = X_batch.to(device)

            #Forward pass
            outputs = model(X_batch, lengths)
            if type(outputs)==tuple:
                logits = outputs[0]
            else:
                logits = outputs

            #Logging
            preds.append(torch.argmax(logits, dim=-1).cpu())
            truths.append(y_batch)
        preds = torch.cat(preds)
        truths = torch.cat(truths)
        records['category'].append(cat)
        records['support'].append(len(X_cat))
        records['acc'].append(accuracy_score(truths, preds))
        records['f1'].append(f1_score(truths, preds, average='macro'))
        records['precision'].append(precision_score(truths, preds, average='macro'))
        records['recall'].append(recall_score(truths, preds, average='macro'))
        all_preds.append(preds)
        all_truths.append(truths)
        idxes.append(idx)
    return records, all_preds, all_truths, idxes

In [105]:
results = categorised_eval_flat(test_df_categorised, X_test_flat, y_test_flat, model, DEVICE, category_list=[0,1,2,3,4,5])

100%|██████████| 2/2 [00:00<00:00,  2.10it/s]
100%|██████████| 1/1 [00:00<00:00,  1.76it/s]
100%|██████████| 5/5 [00:03<00:00,  1.26it/s]
100%|██████████| 8/8 [00:05<00:00,  1.37it/s]
100%|██████████| 9/9 [00:06<00:00,  1.33it/s]
100%|██████████| 2/2 [00:00<00:00,  2.15it/s]


In [None]:
pd.DataFrame(results)

In [None]:
pd.DataFrame(results).to_csv('./outputs/FAN_categorized_eval_results.csv', index=False)

In [144]:
print(classification_report(test_df_categorised[test_df_categorised['Category']==2]['Label'], test_df_categorised[test_df_categorised['Category']==2]['pred_fan']))

              precision    recall  f1-score   support

           1       0.41      0.69      0.51        16
           2       0.33      0.22      0.27        18
           3       0.96      0.70      0.81       493
           4       0.18      0.88      0.29        32

    accuracy                           0.70       559
   macro avg       0.47      0.62      0.47       559
weighted avg       0.88      0.70      0.75       559



## Visualize attention

In [23]:
class AttnVizPreprocessorFlat():

    def __init__(self, data_vocab):
        self.vocab = data_vocab
        print("Vocab created: {} unique tokens".format(len(self.vocab)))
        
    @classmethod
    def from_pretrained_embeds(cls, embed_path, embed_dim, sep=" ",  specials=['<unk>']):
        # start with all '0's for special tokens
        embeds = [np.asarray([0]*embed_dim, dtype=np.float32)]*len(specials)
        words = OrderedDict()
        with open(embed_path, encoding="utf-8") as f:
            for i, line in enumerate(f):
                if i == 38522 and 'twitter.27B.100d' in embed_path:
                    continue
                splitline = line.split()
                
                word = splitline[0]
                if word not in words:
                    words[word] = 0
                words[word]+=1
                embeds.append(np.asarray(splitline[1:], dtype=np.float32))
                
        embeds = torch.tensor(np.array(embeds))
        data_vocab = vocab(words, specials=specials)
        data_vocab.set_default_index(data_vocab['<unk>'])
        return cls(data_vocab)

    def get_vocab_size(self):
        return len(self.vocab)
    
    def preprocess_single_row(self, row, model_max_len, preprocess_label=False):
        '''
        Converts text into integers that index the vocab,
        and converts labels into the range [0,num_classes-1]
        
        Return tokens by sentence (unpadded), idx by sentence (padded), label, num_sentences, num_tokens
        '''
        text = row['Text']
        label = row['Label']
        
        words = [word_tokenize(sent.lower()) for sent in sent_tokenize(text.replace("'",""))]
        words = [word for sent in words for word in sent][:model_max_len] # flatten and truncate
        token_idxs = self.vocab(words)
        num_tokens = len(token_idxs)
        
        if preprocess_label:
            label -= 1
        return words, torch.tensor(token_idxs, dtype=torch.long), label,\
                torch.tensor(num_tokens, dtype=torch.long).unsqueeze(0)
                
    

In [144]:
MODEL_MAX_LEN = 500
NUM_CLASSES = 4
EMBED_DIM = 100
HIDDEN_DIM = 100
NUM_LSTM_LAYERS = 1

VOCAB_LEN = 400001 #harcoded for convenience; see below for how it was obtained
# glove, _ = DataPreprocessorFlat.from_pretrained_embeds(NUM_CLASSES,'/kaggle/input/lun-glove/glove.6B.100d.txt', EMBED_DIM)
# VOCAB_LEN = len(glove.vocab)

EMBED_PATH = '../glove.6B.100d.txt'
MODEL_PATH = './outputs/model/bestFAN_ml500_ba256_emb100hid100lay1cla4_ep10lr0.0005wd5e-06_af0.5ap2_model.pt'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

collate_fn = make_flat_collate_function(MODEL_MAX_LEN)
model = LSTMFlatAttentionFCNNClassifier(VOCAB_LEN, EMBED_DIM, HIDDEN_DIM, NUM_LSTM_LAYERS, NUM_CLASSES)
model.to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
ppflat = AttnVizPreprocessorFlat.from_pretrained_embeds(EMBED_PATH, EMBED_DIM)

Vocab created: 400001 unique tokens


In [90]:

import torch
import json
from IPython.display import Markdown, display

labelid2label = ['Satire', 'Hoax', 'Propaganda', 'Trusted']

class AttentionVisualizerFlat():
    def __init__(self, model, preprocessor):
        
        self.model = model
        self.preprocessor = preprocessor

    def visualize_attention(self, doc_and_label_row, model_max_len, device='cpu', preprocess_label=False):
        # doc_and_label_row should contain a row of the df with columns ['Text'] and ['Label']
        
        words, X, y, num_tokens = self.preprocessor.preprocess_single_row(doc_and_label_row, model_max_len, preprocess_label)
        pred, word_weights = self.model(X.unsqueeze(0), num_tokens, return_attn_weights=True)
        pred = torch.argmax(pred, dim = -1).cpu().item()
        word_weights = word_weights.squeeze(0,2).cpu()
        
        max_weight = word_weights.max()
        min_weight = word_weights.min()
        line = []
        line_length = 0
        
        display(Markdown('<p style="font-size:18px"> Ground Truth: '+ labelid2label[y] + '&emsp;&emsp;&emsp;Prediction: '+labelid2label[pred] +'</p>'))
        
        for word, weight in zip(words, word_weights):
            line_length += len(word)
            line.append(self.__make_word(word, self.__scale_weight(weight, max_weight, min_weight)))
            if line_length > 60:
                display(Markdown(" ".join(line)))
                line = []
                line_length = 0
        if len(line) > 0:
            display(Markdown(" ".join(line)))
        return words, pred, word_weights
    
    def __make_word(self, word, word_weight):
        return '<span style="background-color:rgba(0,0,255,' +\
                        str(word_weight.item()) + ');font-size:16px;">' +\
                        word.replace('$', '\$').replace("'", "\'") + '</span>'

    def __scale_weight(self, orig_weight, max_weight, min_weight):
        return (orig_weight-min_weight)/(max_weight-min_weight) * 0.5


In [91]:
avflat = AttentionVisualizerFlat(model, ppflat)

In [7]:
test_df = pd.read_csv('/kaggle/input/lun-glove/balancedtest.csv', header=None, names=['Label', 'Text'])

In [97]:
X_test_flat = pd.read_parquet('./FAN_prepro_data/X_test_prep_flat.parquet')['Text']
y_test_flat = pd.read_parquet('./FAN_prepro_data/y_test_prep_flat.parquet')['Label']

In [147]:
def predic(te):
    w, tok, lab, pad = ppflat.preprocess_single_row({'Text':te, 'Label':1}, 500, False)
    pred = model(tok.unsqueeze(0), pad, return_attn_weights=False)
    return torch.argmax(pred, dim = -1).cpu().item()

In [148]:
test_df['pred_flat_2'] = -1
test_df['pred_flat_2'] = test_df['Text'].apply(predic)

In [113]:
test_df['pred_flat'] = -1
test_df['truths'] = -1
for p, t, i in zip(results[1],results[2],results[3]):
    test_df.loc[i, 'pred_flat'] = p.tolist()
    test_df.loc[i, 'truths'] = t.tolist()

In [154]:
test_df['length'] = test_df['Text'].apply(lambda s: len(s.split()))

In [258]:
test_df[(test_df['pred_hier']==test_df['pred_flat']) & (test_df['pred_hier']==test_df['truths'])  & (test_df['length'] < 100) & (test_df['pred_flat'] == 3) & (test_df['Text'].apply(lambda t: re.search(r'[0-9]', t) is not None))]

Unnamed: 0,Label,Text,pred_flat,truths,pred_hier,pred_flat_2,length
2268,4,British American Tobacco announced Tuesday tha...,3,3,3,3,55
2334,4,"Gold for current delivery closed at $1,107.80 ...",3,3,3,3,22
2399,4,Triple Olympic gold medalist Stephanie Rice sa...,3,3,3,3,91
2401,4,Singapore exchange to buy Australian bourse fo...,3,3,3,3,13
2407,4,West Indies beat England by five wickets under...,3,3,3,3,74
2413,4,Eurozone recovery falters in Q4 as economy gro...,3,3,3,3,15
2467,4,Coast Guard Adm. Thad Allen: cap now funneling...,3,3,3,3,20
2487,4,Spanish bank BBVA reported Wednesday its fourt...,3,3,3,3,99
2526,4,Results Thursday from the St. Petersburg Open ...,3,3,3,3,53
2542,4,"Brome Howard Inn 18281 Rosecroft Rd., St. Mary...",3,3,3,3,35
