In [75]:
from tensorflow.keras.datasets import imdb
import numpy as np 
import pandas as pd
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.sequence import pad_sequences


In [71]:
abbreviations = {
    "$" : " dollar ",
    "€" : " euro ",
    "4ao" : "for adults only",
    "a.m" : "before midday",
    "a3" : "anytime anywhere anyplace",
    "aamof" : "as a matter of fact",
    "acct" : "account",
    "adih" : "another day in hell",
    "afaic" : "as far as i am concerned",
    "afaict" : "as far as i can tell",
    "afaik" : "as far as i know",
    "afair" : "as far as i remember",
    "afk" : "away from keyboard",
    "app" : "application",
    "approx" : "approximately",
    "apps" : "applications",
    "asap" : "as soon as possible",
    "asl" : "age, sex, location",
    "atk" : "at the keyboard",
    "ave." : "avenue",
    "aymm" : "are you my mother",
    "ayor" : "at your own risk", 
    "b&b" : "bed and breakfast",
    "b+b" : "bed and breakfast",
    "b.c" : "before christ",
    "b2b" : "business to business",
    "b2c" : "business to customer",
    "b4" : "before",
    "b4n" : "bye for now",
    "b@u" : "back at you",
    "bae" : "before anyone else",
    "bak" : "back at keyboard",
    "bbbg" : "bye bye be good",
    "bbc" : "british broadcasting corporation",
    "bbias" : "be back in a second",
    "bbl" : "be back later",
    "bbs" : "be back soon",
    "be4" : "before",
    "bfn" : "bye for now",
    "blvd" : "boulevard",
    "bout" : "about",
    "brb" : "be right back",
    "bros" : "brothers",
    "brt" : "be right there",
    "bsaaw" : "big smile and a wink",
    "btw" : "by the way",
    "bwl" : "bursting with laughter",
    "c/o" : "care of",
    "cet" : "central european time",
    "cf" : "compare",
    "cia" : "central intelligence agency",
    "csl" : "can not stop laughing",
    "cu" : "see you",
    "cul8r" : "see you later",
    "cv" : "curriculum vitae",
    "cwot" : "complete waste of time",
    "cya" : "see you",
    "cyt" : "see you tomorrow",
    "dae" : "does anyone else",
    "dbmib" : "do not bother me i am busy",
    "diy" : "do it yourself",
    "dm" : "direct message",
    "dwh" : "during work hours",
    "e123" : "easy as one two three",
    "eet" : "eastern european time",
    "eg" : "example",
    "embm" : "early morning business meeting",
    "encl" : "enclosed",
    "encl." : "enclosed",
    "etc" : "and so on",
    "faq" : "frequently asked questions",
    "fawc" : "for anyone who cares",
    "fb" : "facebook",
    "fc" : "fingers crossed",
    "fig" : "figure",
    "fimh" : "forever in my heart", 
    "ft." : "feet",
    "ft" : "featuring",
    "ftl" : "for the loss",
    "ftw" : "for the win",
    "fwiw" : "for what it is worth",
    "fyi" : "for your information",
    "g9" : "genius",
    "gahoy" : "get a hold of yourself",
    "gal" : "get a life",
    "gcse" : "general certificate of secondary education",
    "gfn" : "gone for now",
    "gg" : "good game",
    "gl" : "good luck",
    "glhf" : "good luck have fun",
    "gmt" : "greenwich mean time",
    "gmta" : "great minds think alike",
    "gn" : "good night",
    "g.o.a.t" : "greatest of all time",
    "goat" : "greatest of all time",
    "goi" : "get over it",
    "gps" : "global positioning system",
    "gr8" : "great",
    "gratz" : "congratulations",
    "gyal" : "girl",
    "h&c" : "hot and cold",
    "hp" : "horsepower",
    "hr" : "hour",
    "hrh" : "his royal highness",
    "ht" : "height",
    "ibrb" : "i will be right back",
    "ic" : "i see",
    "icq" : "i seek you",
    "icymi" : "in case you missed it",
    "idc" : "i do not care",
    "idgadf" : "i do not give a damn fuck",
    "idgaf" : "i do not give a fuck",
    "idk" : "i do not know",
    "ie" : "that is",
    "i.e" : "that is",
    "ifyp" : "i feel your pain",
    "IG" : "instagram",
    "iirc" : "if i remember correctly",
    "ilu" : "i love you",
    "ily" : "i love you",
    "imho" : "in my humble opinion",
    "imo" : "in my opinion",
    "imu" : "i miss you",
    "iow" : "in other words",
    "irl" : "in real life",
    "j4f" : "just for fun",
    "jic" : "just in case",
    "jk" : "just kidding",
    "jsyk" : "just so you know",
    "l8r" : "later",
    "lb" : "pound",
    "lbs" : "pounds",
    "ldr" : "long distance relationship",
    "lmao" : "laugh my ass off",
    "lmfao" : "laugh my fucking ass off",
    "lol" : "laughing out loud",
    "ltd" : "limited",
    "ltns" : "long time no see",
    "m8" : "mate",
    "mf" : "motherfucker",
    "mfs" : "motherfuckers",
    "mfw" : "my face when",
    "mofo" : "motherfucker",
    "mph" : "miles per hour",
    "mr" : "mister",
    "mrw" : "my reaction when",
    "ms" : "miss",
    "mte" : "my thoughts exactly",
    "nagi" : "not a good idea",
    "nbc" : "national broadcasting company",
    "nbd" : "not big deal",
    "nfs" : "not for sale",
    "ngl" : "not going to lie",
    "nhs" : "national health service",
    "nrn" : "no reply necessary",
    "nsfl" : "not safe for life",
    "nsfw" : "not safe for work",
    "nth" : "nice to have",
    "nvr" : "never",
    "nyc" : "new york city",
    "oc" : "original content",
    "og" : "original",
    "ohp" : "overhead projector",
    "oic" : "oh i see",
    "omdb" : "over my dead body",
    "omg" : "oh my god",
    "omw" : "on my way",
    "p.a" : "per annum",
    "p.m" : "after midday",
    "pm" : "prime minister",
    "poc" : "people of color",
    "pov" : "point of view",
    "pp" : "pages",
    "ppl" : "people",
    "prw" : "parents are watching",
    "ps" : "postscript",
    "pt" : "point",
    "ptb" : "please text back",
    "pto" : "please turn over",
    "qpsa" : "what happens", #"que pasa",
    "ratchet" : "rude",
    "rbtl" : "read between the lines",
    "rlrt" : "real life retweet", 
    "rofl" : "rolling on the floor laughing",
    "roflol" : "rolling on the floor laughing out loud",
    "rotflmao" : "rolling on the floor laughing my ass off",
    "rt" : "retweet",
    "ruok" : "are you ok",
    "sfw" : "safe for work",
    "sk8" : "skate",
    "smh" : "shake my head",
    "sq" : "square",
    "srsly" : "seriously", 
    "ssdd" : "same stuff different day",
    "tbh" : "to be honest",
    "tbs" : "tablespooful",
    "tbsp" : "tablespooful",
    "tfw" : "that feeling when",
    "thks" : "thank you",
    "tho" : "though",
    "thx" : "thank you",
    "tia" : "thanks in advance",
    "til" : "today i learned",
    "tl;dr" : "too long i did not read",
    "tldr" : "too long i did not read",
    "tmb" : "tweet me back",
    "tntl" : "trying not to laugh",
    "ttyl" : "talk to you later",
    "u" : "you",
    "u2" : "you too",
    "u4e" : "yours for ever",
    "utc" : "coordinated universal time",
    "w/" : "with",
    "w/o" : "without",
    "w8" : "wait",
    "wassup" : "what is up",
    "wb" : "welcome back",
    "wtf" : "what the fuck",
    "wtg" : "way to go",
    "wtpa" : "where the party at",
    "wuf" : "where are you from",
    "wuzup" : "what is up",
    "wywh" : "wish you were here",
    "yd" : "yard",
    "ygtr" : "you got that right",
    "ynk" : "you never know",
    "zzz" : "sleeping bored and tired"
}

In [72]:
base_csv = './data/IMDB Dataset.csv'
df = pd.read_csv(base_csv)
import pandas as pd
from collections import Counter
import re
import nltk
from nltk.corpus import stopwords

# Download NLTK stop words (if not already downloaded)
nltk.download('stopwords')
# Preprocess the text to extract the vocabulary
def preprocess_text(text):
    text = text.lower()
    text = abbreviations[text.lower()] if text.lower() in abbreviations.keys() else text
    text = re.sub(r'<.*?>', '', text)
    text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
    text = text.strip()
    stop_words = set(stopwords.words('english'))  
    text_tokens = text.split()  
    filtered_text = [word for word in text_tokens if word not in stop_words]
    text = ' '.join(filtered_text)
    return text

# Tokenize and build the vocabulary
def build_vocab(dataframe, column_name):
    word_list = []
    for review in dataframe[column_name]:
        processed_review = preprocess_text(review)
        word_list.extend(processed_review.split())
    # Count word frequencies
    word_counts = Counter(word_list)
    return word_counts


vocab = build_vocab(df, 'review')
print(f"Vocabulary size: {len(vocab)}")

[nltk_data] Downloading package stopwords to /Users/jack/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


Vocabulary size: 221464


In [78]:
from tensorflow.keras.preprocessing.sequence import pad_sequences
# Create a word-to-index dictionary
word_to_index = {word: i+1 for i, word in enumerate(vocab.keys())}  # Start indexing from 1
# Convert reviews to sequences of integers
def text_to_sequence(text, word_to_index):
    processed_text = preprocess_text(text)
    return [word_to_index[word] for word in processed_text.split() if word in word_to_index]

# Tokenize the reviews
df['review_sequences'] = df['review'].apply(lambda x: text_to_sequence(x, word_to_index))
df['label_convert'] = df['sentiment'].apply(lambda x: 1 if x =='positive' else 0)

Word to index mapping: [('one', 1), ('reviewers', 2), ('mentioned', 3), ('watching', 4), ('1', 5), ('oz', 6), ('episode', 7), ('youll', 8), ('hooked', 9), ('right', 10)]


In [79]:
# Train-test split with shuffle
X_train, y_train, X_test, y_test = train_test_split(df['review_sequences'], df['label_convert'], test_size=0.2, random_state=42, shuffle=True)
# Pad the sequences to a fixed length
max_length = 100  # Maximum sequence length
x_train_pad = pad_sequences(X_train, maxlen=max_length, padding='post', truncating='post')
y_train_pad = pad_sequences(y_train, maxlen=max_length, padding='post', truncating='post')

print(f"Padded sequences shape: {x_train_pad.shape}")
print("Padded Array:", x_train_pad)
print("Padded Array:", y_train_pad)

Padded sequences shape: (40000, 100)
Padded Array: [[  626  1251  3129 ...   214 27947 10648]
 [  336   158   295 ...   919  3196  2333]
 [ 4438   332   420 ...     0     0     0]
 ...
 [ 2207     1  1167 ...     0     0     0]
 [  146   878  1069 ...     0     0     0]
 [ 3266   295  1204 ...     0     0     0]]
Padded Array: [[   186    607  84411 ...    239   1456 106497]
 [   251   2249     88 ...   2120   3037   4395]
 [   306   4404    380 ...      0      0      0]
 ...
 [   747  63135    565 ...     81   2623   7873]
 [    79    573     23 ...      0      0      0]
 [   103   4795    259 ...      0      0      0]]


In [80]:
x_train_pad

array([[  626,  1251,  3129, ...,   214, 27947, 10648],
       [  336,   158,   295, ...,   919,  3196,  2333],
       [ 4438,   332,   420, ...,     0,     0,     0],
       ...,
       [ 2207,     1,  1167, ...,     0,     0,     0],
       [  146,   878,  1069, ...,     0,     0,     0],
       [ 3266,   295,  1204, ...,     0,     0,     0]], dtype=int32)

In [82]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# ------------------------------
# 1. Define Vocabulary and Utility Functions
# ------------------------------

# Define a simple vocabulary mapping.
vocab = {
    "i": 0,
    "love": 1,
    "hate": 2,
    "this": 3,
    "movie": 4,
    "film": 5,
    "is": 6,
    "amazing": 7,
    "awful": 8
}

def text_to_indices(text, vocab):
    """
    Converts a text string to a list of indices based on the given vocabulary.
    """
    return [vocab[word] for word in text.split() if word in vocab]


# ------------------------------
# 2. Define the Supervised HMM Classifier Model
# ------------------------------

class SupervisedHMMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_states, num_classes):
        """
        vocab_size: the size of the vocabulary.
        embed_dim: dimensionality of word embeddings.
        num_states: number of hidden states in the HMM.
        num_classes: number of classes to predict.
        """
        super(SupervisedHMMClassifier, self).__init__()
        # Embedding layer to convert word indices into embeddings.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        
        # HMM parameters:
        # Initial state distribution (pi) for the hidden states.
        self.pi = nn.Parameter(torch.rand(num_states))
        # Transition matrix A between states.
        self.A = nn.Parameter(torch.rand(num_states, num_states))
        # For each state, define a linear layer that computes a scalar emission score
        # from the word embedding. We exponentiate these scores to ensure positive emissions.
        self.emission_layers = nn.ModuleList(
            [nn.Linear(embed_dim, 1) for _ in range(num_states)]
        )
        
        # A simple classifier: maps the final latent state (vector of size num_states)
        # to class logits.
        self.classifier = nn.Linear(num_states, num_classes)
    
    def forward(self, x, embeddings=None, return_embeddings=False):
        """
        x: tensor of word indices with shape (T,), where T is the sentence length.
        embeddings: (optional) precomputed embeddings for x. If not provided, they are computed.
        return_embeddings: if True, also return intermediate values for saliency analysis.
        
        Returns:
          logits: class scores (before softmax).
          final_alpha: final latent state representation from the HMM (normalized).
          (optionally) embeddings, the full list of alphas (for each time step), and emissions.
        """
        if embeddings is None:
            embeddings = self.embedding(x)  # shape: (T, embed_dim)
        
        T = embeddings.size(0)  # Sentence length
        num_states = self.pi.size(0)
        
        # Compute emission probabilities for each state and time step.
        emission_list = []
        for state in range(num_states):
            out = self.emission_layers[state](embeddings).squeeze(-1)  # shape: (T,)
            emission_list.append(torch.exp(out))  # Ensure positive emissions.
        # Stack into shape (T, num_states)
        emissions = torch.stack(emission_list, dim=1)
        
        # --- Forward Algorithm ---
        alpha_list = []
        # Initialization (t=0)
        alpha_t = self.pi * emissions[0]  # shape: (num_states,)
        alpha_list.append(alpha_t)
        
        # Recursion for t = 1, …, T-1:
        for t in range(1, T):
            alpha_prev = alpha_list[-1]
            # For each state j:
            # alpha[t, j] = emissions[t, j] * sum_i (alpha[t-1, i] * A[i, j])
            alpha_t = emissions[t] * torch.matmul(alpha_prev, self.A)
            alpha_list.append(alpha_t)
        
        # Stack along time steps: shape (T, num_states)
        alpha = torch.stack(alpha_list, dim=0)
        # Normalize final alpha
        final_alpha = alpha[-1] / (alpha[-1].sum() + 1e-8)
        
        # --- Classification ---
        logits = self.classifier(final_alpha)  # shape: (num_classes,)
        
        if return_embeddings:
            return logits, final_alpha, embeddings, alpha, emissions
        else:
            return logits, final_alpha


# ------------------------------
# 3. Define a Saliency Computation Function
# ------------------------------

def compute_saliency(text, model):
    """
    Computes word-level saliency for a given text using gradient norms.
    
    Returns:
      saliency: a tensor of shape (T,) with a saliency value for each word.
      logits: the class logits computed by the model.
      predicted_class: the index of the predicted class.
    """
    indices = text_to_indices(text, vocab)
    x = torch.tensor(indices, dtype=torch.long)
    
    model.zero_grad()
    
    # Get embeddings and enable gradient tracking.
    embeddings = model.embedding(x)  # shape: (T, embed_dim)
    embeddings.requires_grad_(True)
    embeddings.retain_grad()

    # Forward pass (with intermediate values for saliency analysis)
    logits, final_alpha, used_embeddings, alpha, emissions = model(x, embeddings=embeddings, return_embeddings=True)
    
    # Select predicted class as the one with maximum logit.
    predicted_class = torch.argmax(logits)
    
    # Use the logit for the predicted class as the target.
    target_logit = logits[predicted_class]
    
    # Backpropagate to compute gradients with respect to embeddings.
    target_logit.backward()
    
    # Compute the L2 norm of the gradients for each word.
    saliency = embeddings.grad.norm(dim=1)
    return saliency, logits, predicted_class.item()


# ------------------------------
# 4. Create Fake Data and Train the Model
# ------------------------------

# Define a small fake training dataset.
# Each training instance is a tuple (text, label)
# For this example, label 1 will denote a positive sentiment and label 0 a negative sentiment.
train_data = [
    ("i love this movie", 1),
    ("i hate this movie", 0),
    ("this film is amazing", 1),
    ("this movie is awful", 0),
]

# Hyperparameters.
vocab_size = len(vocab)
embed_dim = 100
num_states = 2
num_classes = 2
num_epochs = 10
learning_rate = 0.01

# Instantiate the model and move it to device.
model = SupervisedHMMClassifier(vocab_size, embed_dim, num_states, num_classes)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

# Training loop.
model.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for text, label in train_data:
        optimizer.zero_grad()  # Reset gradients for each example.
        
        # Convert text to indices tensor.
        indices = text_to_indices(text, vocab)
        x = torch.tensor(indices, dtype=torch.long)
        target = torch.tensor([label], dtype=torch.long)
        
        # Forward pass.
        logits, final_alpha = model(x)
        # logits shape: (num_classes,) => unsqueeze to (1, num_classes)
        loss = loss_fn(logits.unsqueeze(0), target)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(train_data):.4f}")

# ------------------------------
# 5. Evaluate and Compute Saliency for a Sample Input
# ------------------------------

model.eval()
sample_text = "i love this movie"
saliency, logits, pred_class = compute_saliency(sample_text, model)

print("\n--- Evaluation on Sample Input ---")
print(f"Text: '{sample_text}'")
print(f"Predicted class: {pred_class}")
print(f"Logits: {logits.detach().numpy()}")
print("Saliency for each word (higher value means more influence):")
for word, sal in zip(sample_text.split(), saliency):
    print(f"  {word}: {sal.item():.4f}")

Epoch 1/10, Loss: 0.7840
Epoch 2/10, Loss: 0.6433
Epoch 3/10, Loss: 0.6079
Epoch 4/10, Loss: 0.6003
Epoch 5/10, Loss: 0.5962
Epoch 6/10, Loss: 0.5924
Epoch 7/10, Loss: 0.5886
Epoch 8/10, Loss: 0.5848
Epoch 9/10, Loss: 0.5811
Epoch 10/10, Loss: 0.5774

--- Evaluation on Sample Input ---
Text: 'i love this movie'
Predicted class: 1
Logits: [-0.32629627  0.03500685]
Saliency for each word (higher value means more influence):
  i: 0.0000
  love: 0.0000
  this: 0.0001
  movie: 0.0132
