# Prerequisites

## AttentionLSTM
- The attention-LSTM framework proposed by Jain and Wallace; the SEAT paper uses the same implementation

In [None]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.distributions import normal

# Zero-mean Gaussian with scale 1e-3; AttentionLSTM.perturb draws its embedding noise from it.
m = normal.Normal(0, 1e-3)
def masked_softmax(attn_odds, masks):
    """Convert raw attention scores into weights that vanish on masked positions.

    The scores are passed through ReLU and a softmax over the last axis, the
    result is multiplied by ``masks`` and every row is renormalised so that the
    surviving weights sum to one.

    Args:
        attn_odds: raw scores shaped like ``masks``; a trailing singleton axis
            (one extra dimension of size 1) is also accepted.
        masks: tensor of ones (keep) and zeros (drop), same shape as the scores.

    Returns:
        Tensor shaped like ``masks`` holding the renormalised attention weights.
        A row whose mask is entirely zero yields zeros rather than NaN.
    """
    # Only drop an explicit trailing singleton axis; a bare squeeze() would
    # also remove a batch axis of size one.
    if attn_odds.dim() == masks.dim() + 1 and attn_odds.size(-1) == 1:
        attn_odds = attn_odds.squeeze(-1)

    attentions = torch.softmax(F.relu(attn_odds), dim=-1)

    # apply mask and renormalize attention scores (weights)
    masked = attentions * masks
    _sums = masked.sum(-1, keepdim=True)  # sums per row

    # Guard the division so a fully masked row does not produce 0/0.
    return masked.div(_sums.clamp_min(torch.finfo(masked.dtype).tiny))
def perturbed_masked_softmax(attn_odds, masks, delta):
    """Masked softmax whose weights are shifted by ``delta`` before masking.

    The scores go through ReLU and a softmax over the last axis, ``delta`` is
    added, the result is clamped to ``[0, 1]``, multiplied by ``masks`` and
    renormalised per row.

    Args:
        attn_odds: raw scores shaped like ``masks``; a trailing singleton axis
            is also accepted.
        masks: tensor of ones (keep) and zeros (drop).
        delta: perturbation broadcastable to the attention weights.

    Returns:
        Tensor shaped like ``masks`` with the perturbed, renormalised weights.
        A row with no surviving weight yields zeros rather than NaN.
    """
    # Only drop an explicit trailing singleton axis; a bare squeeze() would
    # also remove a batch axis of size one.
    if attn_odds.dim() == masks.dim() + 1 and attn_odds.size(-1) == 1:
        attn_odds = attn_odds.squeeze(-1)

    attentions = torch.softmax(F.relu(attn_odds), dim=-1)
    attentions = torch.clamp(attentions + delta, min=0, max=1)

    # apply mask and renormalize attention scores (weights)
    masked = attentions * masks
    _sums = masked.sum(-1, keepdim=True)  # sums per row

    return masked.div(_sums.clamp_min(torch.finfo(masked.dtype).tiny))
class TanhAttention(nn.Module):
    """Additive attention: ``Linear -> tanh -> Linear`` scores, masked softmax pooling.

    Args:
        hidden_size: size of the last axis of the states passed to ``forward``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attn1 = nn.Linear(hidden_size, hidden_size // 2)
        self.attn2 = nn.Linear(hidden_size // 2, 1, bias=False)
        # Kept for backward compatibility only; forward() follows the device
        # of its input instead of this construction-time guess.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def forward(self, hidden, lengths):
        """Pool ``hidden`` over the sequence axis with attention weights.

        Args:
            hidden: states of shape (B, L, H).
            lengths: one valid length per batch row (tensor or sequence of
                ints); positions at or beyond the length get zero weight.

        Returns:
            Tuple ``(representations, attn)``: the weighted sum over axis 1,
            shape (B, H), and the attention weights, shape (B, L).
        """
        max_len = hidden.shape[1]
        scores = self.attn2(torch.tanh(self.attn1(hidden))).squeeze(-1)

        # Build the padding mask in one shot on the same device as the input,
        # so a CPU model keeps working on a machine that also has a GPU.
        lengths = torch.as_tensor(lengths, device=hidden.device)
        positions = torch.arange(max_len, device=hidden.device)
        masks = (positions.unsqueeze(0) < lengths.unsqueeze(1)).to(scores.dtype)

        attn = masked_softmax(scores, masks)

        # apply attention weights and sum over the sequence axis; no squeeze(),
        # which would drop the batch axis for a batch of one
        representations = (hidden * attn.unsqueeze(-1)).sum(1)

        return representations, attn


class AttentionLSTM(nn.Module):
    """Embedding -> bidirectional LSTM -> TanhAttention -> two-layer head -> sigmoid.

    Args:
        vocab_size: number of rows in the embedding table.
        emb_dim: embedding width.
        hidden_size: LSTM hidden size per direction.
        num_classes: width of the final linear layer.
        dropout: dropout probability used on the embeddings and in the head.
        lstm_layer: multiplier for the width of the fully connected head. It is
            not passed to ``nn.LSTM``.
            NOTE(review): the attention output is ``2 * hidden_size`` wide, so
            values other than 2 look inconsistent with the head - confirm.
    """

    def __init__(
        self,
        vocab_size,
        emb_dim,
        hidden_size,
        num_classes,
        dropout = 0.4,
        lstm_layer = 2

    ):
        super(AttentionLSTM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.dropout = nn.Dropout(p=dropout)

        self.lstm = nn.LSTM(input_size = emb_dim, hidden_size = hidden_size, bidirectional = True)
        self.attention = TanhAttention(hidden_size = hidden_size*2)
        self.fc1 = nn.Sequential(nn.Linear(hidden_size*lstm_layer, hidden_size*lstm_layer),
                                 nn.BatchNorm1d(hidden_size*lstm_layer),
                                 nn.ReLU())
        self.fc2 = nn.Linear(hidden_size*lstm_layer, num_classes)
        self.sigmoid = nn.Sigmoid()

    def _encode(self, embedded, x_len):
        """Run dropout, the packed LSTM and attention; return ``(representation, attn)``."""
        embedded = self.dropout(embedded)
        packed = nn.utils.rnn.pack_padded_sequence(embedded, x_len, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed)
        states, lengths = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        return self.attention(states, lengths)

    def _classify(self, representation):
        """Map a pooled representation to sigmoid outputs."""
        # A batch of one may arrive without its batch axis; BatchNorm1d needs it.
        if representation.dim() == 1:
            representation = representation.unsqueeze(0)
        y = self.fc1(self.dropout(representation))
        y = self.fc2(self.dropout(y))
        # Drop only the class axis; a bare squeeze() would also remove a batch
        # axis of size one.
        return self.sigmoid(y.squeeze(-1))

    def forward(self, x, x_len):
        """Return the sigmoid outputs for token ids ``x`` with lengths ``x_len``."""
        representation, _ = self._encode(self.embedding(x), x_len)
        return self._classify(representation)

    def atten_forward(self, x, x_len):
        """Return ``(representation, prediction)``.

        The first element is the attention-pooled representation returned by
        ``self.attention``; the attention weights themselves are discarded.
        NOTE(review): callers in this notebook bind the first element to names
        such as ``attn_base`` - confirm the pooled vector is what they want.
        """
        representation, _ = self._encode(self.embedding(x), x_len)
        return representation, self._classify(representation)

    def perturb(self, x, x_len, device = 'cpu'):
        """Like ``atten_forward`` but with Gaussian noise on one token embedding per sample.

        For every sample one position inside its valid length is drawn
        uniformly and a noise vector sampled from the module-level
        distribution ``m`` is added to that position's embedding.

        Args:
            x: token ids, shape (B, L).
            x_len: valid length of every sample.
            device: unused; kept so existing callers keep working. The noise is
                placed on the same device as the embeddings.

        Returns:
            Tuple ``(representation, prediction)`` computed from the perturbed
            embeddings.
        """
        embedded = self.embedding(x)
        batch_size, _, emb_dim = embedded.shape

        lengths = torch.as_tensor(x_len, dtype=torch.long, device=embedded.device)
        # One position per sample, restricted to real tokens so the noise is
        # not spent on padding that the packed LSTM never sees.
        positions = (torch.rand(batch_size, device=embedded.device) * lengths).long()
        positions = torch.minimum(positions, lengths - 1)  # guard float rounding

        # Build the perturbation out of place instead of writing into the
        # embedding output.
        noise = torch.zeros_like(embedded)
        rows = torch.arange(batch_size, device=embedded.device)
        noise[rows, positions] = m.sample((batch_size, emb_dim)).to(embedded)

        representation, _ = self._encode(embedded + noise, x_len)
        return representation, self._classify(representation)



## Objective function
- Objective functions used to train SEAT

In [None]:
import torch
import torch.nn as nn

def pgd_attack(model, text, label, seq_length, alpha, seat_attention, loss_fn, step_size, eps, batch_size):
    """Search for a perturbation of the attention's ``attn2`` weight by signed-gradient ascent.

    A fresh ``TanhAttention`` copy of ``seat_attention`` is installed on
    ``model`` for the duration of the search. In every iteration the model is
    run with the clean weight and with ``weight + delta``; ``delta`` is then
    moved by ``alpha`` along the sign of the gradient of
    ``loss_fn(clean, perturbed)`` with respect to the perturbed weight and
    clamped to ``[-eps, eps]``.

    Args:
        model: network exposing ``attention`` and callable as ``model(text, seq_length)``.
        text: batch of token ids.
        label: unused.
        seq_length: valid length per sample.
        alpha: step taken along the gradient sign in every iteration.
        seat_attention: attention module whose state is copied.
        loss_fn: callable ``loss_fn(output, output_d)``.
        step_size: number of iterations (despite the name).
        eps: bound of the clamp applied to ``delta``.
        batch_size: unused.

    Returns:
        The perturbation ``delta``, shaped like ``attn2.weight``.

    Relies on the module-level names ``num_hiddens`` and ``device``.
    """
    seat = TanhAttention(hidden_size = num_hiddens*2).to(device)
    seat.load_state_dict(seat_attention.state_dict())
    delta = torch.zeros_like(model.attention.attn2.weight).to(device)
    init_attention = model.attention
    init_weight = seat.attn2.weight.detach().clone().to(device)

    model.attention = seat
    try:
        for _ in range(step_size):
            model.attention.attn2.weight = nn.Parameter(init_weight)
            output = model(text, seq_length)
            # Apply attention + delta on a specific layer
            model.attention.attn2.weight = nn.Parameter(init_weight + delta)
            output_d = model(text, seq_length)
            loss = loss_fn(output, output_d)
            loss.backward()
            delta.data = (delta + alpha*model.attention.attn2.weight.grad.data.sign()).clamp(-eps, eps)
            model.zero_grad()
    finally:
        # Always put the caller's attention back, even if a pass raised;
        # otherwise the temporary copy stays installed on the model.
        model.attention = init_attention
    return delta


def stability_loss(model, text, seq_length, seat_attention, delta, loss_fn):
    """Compare model outputs under the SEAT attention and under its perturbed version.

    A fresh ``TanhAttention`` copy of ``seat_attention`` is installed on
    ``model``; the model is run once as is and once with ``delta`` added to the
    copy's ``attn2`` weight. The caller's attention module is restored before
    returning.

    Args:
        model: network exposing ``attention`` and callable as ``model(text, seq_length)``.
        text: batch of token ids.
        seq_length: valid length per sample.
        seat_attention: attention module whose state is copied.
        delta: perturbation added to the copy's ``attn2`` weight.
        loss_fn: callable ``loss_fn(output, output_d)``.

    Returns:
        ``loss_fn(output, output_d)``.

    Relies on the module-level names ``num_hiddens`` and ``device``.

    NOTE(review): both passes go through the copy ``seat``, not through
    ``seat_attention`` itself; presumably backward() on the result therefore
    does not populate ``seat_attention``'s gradients - confirm.
    """
    init_attention = model.attention
    seat = TanhAttention(hidden_size = num_hiddens*2).to(device)
    seat.load_state_dict(seat_attention.state_dict())
    init_weight = seat.attn2.weight.clone().to(device)

    model.attention = seat
    try:
        # Apply seat attention on a specific layer
        output = model(text, seq_length)
        # Apply seat attention + delta on a specific layer
        model.attention.attn2.weight = nn.Parameter(init_weight + delta)
        output_d = model(text, seq_length)
    finally:
        # Restore the attention module itself. Assigning it to an attribute of
        # the temporary copy would leave the copy installed on the model and
        # register the old module as its child.
        model.attention = init_attention
    return loss_fn(output, output_d)

def similarity_loss(model, text, seq_length, seat_attention, loss_fn):
    """Compare the model's output under its own attention and under ``seat_attention``.

    Args:
        model: object exposing ``attention`` and callable as ``model(text, seq_length)``.
        text: batch of token ids.
        seq_length: valid length per sample.
        seat_attention: attention module installed for the second pass.
        loss_fn: callable ``loss_fn(output, output_s)``.

    Returns:
        ``loss_fn(output, output_s)``.
    """
    init_attention = model.attention
    output = model(text, seq_length)
    # Apply SEAT on a specific layer
    model.attention = seat_attention
    try:
        output_s = model(text, seq_length)
    finally:
        # Put the original attention back even if the second pass raised.
        model.attention = init_attention
    return loss_fn(output, output_s)

def topk_loss(model, text, seq_length, seat_attention, k=7):
    """L1 distance between the top-k values produced with and without ``seat_attention``.

    ``model.atten_forward`` is called once with the model's own attention and
    once with ``seat_attention`` installed; the ``k`` largest values along the
    last axis of each first return value are compared with ``nn.L1Loss``.

    Args:
        model: object exposing ``attention`` and ``atten_forward(text, seq_length)``.
        text: batch of token ids; its first dimension divides the result.
        seq_length: valid length per sample.
        seat_attention: attention module installed for the second pass.
        k: number of largest values compared.

    Returns:
        ``(1 / (2k)) * (L1(seat, base) + L1(base, seat)) / text.shape[0]``.
    """
    init_attention = model.attention
    criterion = nn.L1Loss()
    attn_base, _ = model.atten_forward(text, seq_length)
    top_k_attention_vectors, _ = torch.topk(attn_base, k = k)
    model.attention = seat_attention
    try:
        attn_seat, _ = model.atten_forward(text, seq_length)
    finally:
        # Put the original attention back even if the second pass raised.
        model.attention = init_attention
    top_k_seat_vectors, _ = torch.topk(attn_seat, k = k)
    return_loss = (1/(2*k))*(criterion(top_k_seat_vectors, top_k_attention_vectors) + criterion(top_k_attention_vectors, top_k_seat_vectors))
    return return_loss/text.shape[0]

## Dataset
We are using the IMDB Reviews Dataset

In [None]:
import torch
from torch.utils.data import Dataset
from torchtext.data.utils import get_tokenizer
from nltk.corpus import stopwords
import nltk
import re
# Fetch the NLTK stop-word corpus needed by stopwords.words() below.
nltk.download('stopwords')
# English stop words; tokenprocess() drops every token contained in this set.
stop_words = set(stopwords.words('english'))
# Define the tokenizer
tokenizer = get_tokenizer('basic_english')
def stringprocess(text):
    """Normalise a raw review string.

    Applies a fixed sequence of regex substitutions (contraction expansion,
    removal of non-word characters, whitespace collapsing), then strips
    surrounding spaces and lower-cases the result.

    NOTE(review): lower-casing happens last, so the case-sensitive contraction
    patterns do not match capitalised forms, and the ``#`` / ``@`` rules run
    after ``\\W`` has already replaced those characters - confirm this order
    is intended.
    """
    # Order matters: every rule sees the output of the previous one.
    substitutions = (
        (r"what's", "what is "),
        (r"\'s", " is"),
        (r"\'ve", " have "),
        (r"can't", "cannot "),
        (r"n't", " not "),
        (r"i'm", "i am "),
        (r"\'re", " are "),
        (r"\'d", " would "),
        (r"\'ll", " will "),
        (r"\'scuse", " excuse "),
        (r"\W", ' '),
        (r"\s+", ' '),
        (r"\#", ""),
        (r"http\S+", "URL"),
        (r"@", ""),
        (r"[^A-Za-z0-9()!?\'\`\"]", " "),
        (r"\s{2,}", " "),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)

    return text.strip(' ').lower()

def tokenprocess(text):
    """Tokenize ``text`` with the module-level ``tokenizer`` and drop stop words."""
    kept = []
    for token in tokenizer(text):
        # Skip anything listed in the module-level stop-word set.
        if token in stop_words:
            continue
        kept.append(token)
    return kept
class IMDBDataset(Dataset):
    """Dataset over a DataFrame with ``review`` and ``sentiment`` columns.

    Args:
        df: DataFrame providing the ``review`` texts and ``sentiment`` targets.
        tokenizer: stored on the instance; ``__getitem__`` tokenizes through
            the module-level ``tokenprocess`` helper instead.
        vocab: iterable of tokens; a token's position becomes its id.
        max_length: maximum number of tokens kept per review.
    """

    def __init__(self, df, tokenizer, vocab, max_length=500):
        self.data = df['review']
        self.targets = df['sentiment']
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.vocab_dict = {token: index for index, token in enumerate(vocab)}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(token_ids, target)`` for the row at position ``index``."""
        tokens = tokenprocess(stringprocess(self.data.iloc[index]))
        target = self.targets.iloc[index]

        # Truncate first, then drop out-of-vocabulary tokens.
        lookup = self.vocab_dict
        data_ids = [lookup[token] for token in tokens[:self.max_length] if token in lookup]

        # Force an integer dtype: torch.tensor([]) would otherwise be float32
        # for a review with no in-vocabulary token.
        return torch.tensor(data_ids, dtype=torch.long), target

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [None]:
from nltk.corpus import stopwords
import nltk
import pandas as pd
import numpy as np
import re
import string
from string import digits
from collections import Counter
from torchtext.data.utils import get_tokenizer
from sklearn.model_selection import train_test_split
from torchtext.vocab import GloVe
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import time
import torch
import matplotlib.pyplot as plt

# English stop words; tokenprocess() drops every token contained in this set.
stop_words = set(stopwords.words('english'))
# Define the tokenizer
tokenizer = get_tokenizer('basic_english')
def stringprocess(text):
    """Clean a raw review: expand contractions, strip punctuation, lower-case.

    The regex rules below are applied one after another, each on the output of
    the previous one; afterwards surrounding spaces are stripped and the text
    is lower-cased.

    NOTE(review): the contraction patterns are case-sensitive and run before
    the final lower-casing, and the ``#`` / ``@`` rules come after ``\\W`` has
    already replaced those characters - confirm this order is intended.
    """
    rules = [
        (r"what's", "what is "),
        (r"\'s", " is"),
        (r"\'ve", " have "),
        (r"can't", "cannot "),
        (r"n't", " not "),
        (r"i'm", "i am "),
        (r"\'re", " are "),
        (r"\'d", " would "),
        (r"\'ll", " will "),
        (r"\'scuse", " excuse "),
        (r"\W", ' '),
        (r"\s+", ' '),
        (r"\#", ""),
        (r"http\S+", "URL"),
        (r"@", ""),
        (r"[^A-Za-z0-9()!?\'\`\"]", " "),
        (r"\s{2,}", " "),
    ]
    cleaned = text
    for pattern, replacement in rules:
        cleaned = re.sub(pattern, replacement, cleaned)

    cleaned = cleaned.strip(' ')
    return cleaned.lower()

def tokenprocess(text):
    """Split ``text`` with the module-level ``tokenizer`` and remove stop words."""
    # Keep a token only when it is absent from the module-level stop-word set.
    return list(filter(lambda token: token not in stop_words, tokenizer(text)))

def collate_fn(batch):
    """Collate ``(token_ids, label)`` pairs into a padded batch.

    Returns:
        Tuple ``(padded_inputs, labels, input_lengths)``: the sequences sorted
        by decreasing length and zero-padded to the longest one, the labels as
        a float32 tensor in the same order, and the unpadded lengths as a list.
    """
    # Longest sequence first; the sort is stable, so ties keep their order.
    ordered = sorted(batch, key=lambda sample: -len(sample[0]))
    inputs, labels = zip(*ordered)

    input_lengths = list(map(len, inputs))
    label_tensor = torch.tensor(labels, dtype=torch.float32)
    padded_inputs = pad_sequence(inputs, batch_first=True)

    return padded_inputs, label_tensor, input_lengths

def load_dataset(data_folder = 'dataset/IMDB Dataset.csv', batch_size = 32):
    """Read the IMDB CSV, build the vocabulary and return train/test DataLoaders.

    Steps: map the ``sentiment`` labels to 1/0, split the rows 50/50 with a
    fixed seed, build a vocabulary from all reviews keeping tokens that occur
    at least four times, write the matching 300-d GloVe vectors to
    ``embeddings.npy`` and wrap both splits in ``IMDBDataset`` loaders.

    Args:
        data_folder: path of the CSV file with ``review`` and ``sentiment`` columns.
        batch_size: batch size of both loaders.

    Returns:
        Tuple ``(train_dataloader, test_dataloader)``.
    """
    df = pd.read_csv(data_folder)

    # Convert the string labels to integers
    label_mapping = {'positive': 1, 'negative': 0}
    df['sentiment'] = df['sentiment'].map(label_mapping)

    # Split the data into train and test sets
    train_df, test_df = train_test_split(df, test_size=0.5, random_state=1234)

    print("Train set size:", len(train_df))
    print("Test set size:", len(test_df))

    # Tokenize every review once and count how often each token occurs
    word_tokens = df["review"].apply(stringprocess).apply(tokenprocess)
    word_frequency = Counter(token for tokens in word_tokens for token in tokens)

    # Keep tokens seen at least `threshold` times. Sorting makes the
    # token -> index mapping identical on every run; set iteration order is
    # not stable across interpreter sessions.
    threshold = 4
    frequent = sorted(word for word, count in word_frequency.items() if count >= threshold)
    vocab = ['<pad>'] + frequent

    print(len(vocab))

    # Example usage: Print the vocabulary
    print(vocab[:50])

    # Load GloVe embeddings and copy the row of every vocabulary token
    glove = GloVe(name='6B', dim=300)
    embedding_matrix = np.zeros((len(vocab), 300))
    for i, token in enumerate(vocab):
        embedding_matrix[i] = glove[token]

    print("---------Saved pretrained embedding---------")
    np.save('embeddings.npy', embedding_matrix)

    train_dataset = IMDBDataset(train_df, tokenizer, vocab)
    test_dataset = IMDBDataset(test_df, tokenizer, vocab)

    # Create a DataLoader for batching and shuffling; the batch size comes
    # from the argument instead of being overwritten by a constant.
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)
    return train_dataloader, test_dataloader



# Training

In [None]:
# Mount Google Drive under /content/drive; the dataset and checkpoint paths used below live there.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Build both loaders from the CSV on Drive; batch_size is left at load_dataset's default.
train_loader, test_loader = load_dataset('/content/drive/MyDrive/SEAT/IMDB Dataset.csv')

Train set size: 25000
Test set size: 25000
44127
['<pad>', 'wrinkled', 'split', 'chill', 'sticking', 'geico', 'bladerunner', 'cunningly', 'aristocrats', 'jordan', 'attach', 'torpedoed', 'sneer', 'millennium', '1988', 'eli', 'cub', 'folds', 'austerity', 'increasing', 'cisco', 'amusing', 'transparent', 'grisham', 'dogfight', 'coppers', 'emptied', 'lazarou', 'customs', 'bathing', 'voiced', 'burly', 'heflin', 'em', 'inopportune', 'incorporate', 'chhappan', 'pointer', 'bossy', 'upping', 'dissipate', 'newcomer', 'bribing', 'inbred', 'lezlie', 'realises', 'bjarne', 'romane', 'crumpet', 'rubbish']
---------Saved pretrained embedding---------


## Initializing the model
Hyperparams are based on the paper

In [None]:
# Global hyper-parameters. `num_hiddens` and `device` are also read by
# pgd_attack() and stability_loss(); `num_layers` is not passed to the model below.
embed_size, num_hiddens, num_layers, device = 300, 128, 1, torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Base classifier; the vocabulary size is taken from the training dataset's token -> id table.
net = AttentionLSTM(
        vocab_size = len(train_loader.dataset.vocab_dict),
        emb_dim = embed_size,
        hidden_size = num_hiddens,
        num_classes = 1,
        dropout = 0.4,
)

net.to(device)

def init_weights(module):
    """Xavier-uniform initialise ``nn.Linear`` weights and every ``nn.LSTM`` weight matrix.

    Biases and all other module types are left untouched. Intended to be used
    through ``Module.apply``.
    """
    kind = type(module)
    if kind == nn.Linear:
        nn.init.xavier_uniform_(module.weight)
    if kind == nn.LSTM:
        # Names containing "weight" select the weight matrices, not the biases.
        weight_names = [name for name in module._flat_weights_names if "weight" in name]
        for name in weight_names:
            nn.init.xavier_uniform_(module._parameters[name])
# Run init_weights on net and every one of its sub-modules.
net.apply(init_weights)



AttentionLSTM(
  (embedding): Embedding(44127, 300)
  (dropout): Dropout(p=0.4, inplace=False)
  (lstm): LSTM(300, 128, bidirectional=True)
  (attention): TanhAttention(
    (attn1): Linear(in_features=256, out_features=128, bias=True)
    (attn2): Linear(in_features=128, out_features=1, bias=False)
  )
  (fc1): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (fc2): Linear(in_features=256, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [None]:
# Load the saved classifier weights from Drive, mapped onto `device`.
net.load_state_dict(torch.load('/content/drive/MyDrive/SEAT/models/imdb_bilstm_tanh_attention_glove_300d.pt', map_location = device))

<All keys matched successfully>

In [None]:
# Move the model to `device` again (it was already moved right after construction).
net = net.to(device)

## SEAT
Initializing the proposed SEAT Attention

In [None]:
# SEAT attention module; same input width (2 * num_hiddens) as the attention inside net.
seat_w = TanhAttention(hidden_size = num_hiddens*2).to(device)

In [None]:
# Display the module structure.
seat_w

TanhAttention(
  (attn1): Linear(in_features=256, out_features=128, bias=True)
  (attn2): Linear(in_features=128, out_features=1, bias=False)
)

In [None]:
import torch.optim as optim
# Only seat_w's parameters are handed to the optimizer; net's parameters are not.
optimizer = optim.Adam(seat_w.parameters(), lr = 0.01)

In [None]:
# Globals read by seat_objective_fn and forwarded to pgd_attack / the loss helpers.
alpha = 1e-4  # step taken along the gradient sign in each pgd_attack iteration
loss_fn = nn.BCELoss()  # compares two sets of model outputs
step_size = 20  # number of pgd_attack iterations (used as range(step_size))
eps = 0.1  # pgd_attack clamps delta to [-eps, eps]

In [None]:
def seat_objective_fn(text, label, seq_length, iters = 100):
    """Combine the stability, similarity and top-k terms for one batch.

    Reads the module-level ``net``, ``seat_w``, ``alpha``, ``loss_fn``,
    ``step_size``, ``eps``, ``lambda1`` and ``lambda2``. ``iters`` is unused.

    Returns:
        ``stability + lambda1 * similarity + lambda2 * topk``.
    """
    # The perturbation is searched first and then fed to the stability term.
    adversarial_delta = pgd_attack(
        net, text, label, seq_length, alpha, seat_w, loss_fn, step_size, eps, 32
    )
    stability_term = stability_loss(net, text, seq_length, seat_w, adversarial_delta, loss_fn)
    similarity_term = similarity_loss(net, text, seq_length, seat_w, loss_fn)
    topk_term = topk_loss(net, text, seq_length, seat_w, k=7)
    return stability_term + lambda1*similarity_term + lambda2*topk_term

In [None]:
from tqdm.auto import tqdm

## Metrics
- As in the paper, we use JSD (Jensen–Shannon divergence) and TVD (total variation distance) as metrics

In [None]:
class JSD(nn.Module):
    """Jensen-Shannon divergence between two batches of logits.

    Both inputs are flattened to 2-D, keeping the last axis, and turned into
    log-probabilities with ``log_softmax`` over that axis. The result is
    ``0.5 * (KL(p || m) + KL(q || m))`` with ``m = 0.5 * (p + q)``, averaged
    over the rows (``reduction='batchmean'``).
    """

    def __init__(self):
        super(JSD, self).__init__()
        self.kl = nn.KLDivLoss(reduction='batchmean', log_target=True)

    def forward(self, p: torch.Tensor, q: torch.Tensor):
        """Return the divergence as a scalar tensor.

        Args:
            p: logits of the first distribution, last axis = categories.
            q: logits of the second distribution, same shape as ``p``.
        """
        log_p = p.view(-1, p.size(-1)).log_softmax(-1)
        log_q = q.view(-1, q.size(-1)).log_softmax(-1)
        # The mixture is the mean of the probabilities, so in log space it is
        # log(exp(log_p) + exp(log_q)) + log(0.5), not the mean of the logs.
        log_m = torch.logaddexp(log_p, log_q) + torch.log(log_p.new_tensor(0.5))
        # KLDivLoss(input, target) computes KL(target || input).
        return 0.5 * (self.kl(log_m, log_p) + self.kl(log_m, log_q))

In [None]:
def total_variation_distance_from_logits(logits_p, logits_q):
    """
    Calculate the Total Variation Distance between two probability distributions.

    Both inputs are turned into probabilities with a softmax over the last
    axis; the result is half the sum of the absolute probability differences.
    When the inputs carry leading batch axes, the per-row distances are summed
    into a single scalar.

    Args:
        logits_p (torch.Tensor): The logits (raw network outputs) for the first distribution.
        logits_q (torch.Tensor): The logits (raw network outputs) for the second distribution.

    Returns:
        torch.Tensor: The Total Variation Distance between the probability distributions.
    """
    # Convert logits to probabilities using the softmax function
    p = F.softmax(logits_p, dim=-1)
    q = F.softmax(logits_q, dim=-1)

    # TVD is half the L1 distance between the probability vectors themselves;
    # comparing cumulative sums would measure a different quantity.
    return 0.5 * torch.sum(torch.abs(p - q))

## Evaluate loop

In [None]:
# Shared metric instance used by evaluate().
jsd = JSD()

In [None]:
def evaluate(net, seat_w):
    """Report JSD / TVD metrics of the SEAT attention over ``test_loader``.

    For every batch the model is run with its own attention
    (``atten_forward``), then with ``seat_w`` installed both on perturbed
    embeddings (``perturb``) and on clean ones (``atten_forward``). Running
    averages are shown on the progress bar.

    Args:
        net: model exposing ``attention``, ``atten_forward`` and ``perturb``.
        seat_w: attention module swapped in for the SEAT passes.

    Returns:
        dict with the per-batch averages under the keys
        ``jsd_on_word_perturb``, ``jsd_base_vs_seat``, ``tvd_on_word_perturb``
        and ``tvd_base_vs_seat``.

    Relies on the module-level names ``test_loader``, ``device``, ``jsd``,
    ``tqdm`` and ``total_variation_distance_from_logits``.
    """
    totals = {
        "jsd_on_word_perturb": 0.0,
        "jsd_base_vs_seat": 0.0,
        "tvd_on_word_perturb": 0.0,
        "tvd_base_vs_seat": 0.0,
    }
    n_batches = 0

    init_attention = net.attention
    was_training = net.training
    # Evaluation mode: otherwise dropout adds its own randomness to every pass
    # and BatchNorm keeps updating its running statistics.
    net.eval()
    try:
        with tqdm(test_loader, unit="batch") as tepoch:
            tepoch.set_description(f"SEAT Evaluation")
            with torch.no_grad():
                for text, label, seq_length in tepoch:
                    text = text.to(device)

                    attn_vanilla, output_vanilla = net.atten_forward(text, seq_length)

                    net.attention = seat_w
                    attn_perturbed, output_perturbed = net.perturb(text, seq_length, device = device)
                    attn_seat, output_seat = net.atten_forward(text, seq_length)
                    net.attention = init_attention

                    totals["jsd_on_word_perturb"] += jsd(attn_seat, attn_perturbed).item()
                    totals["jsd_base_vs_seat"] += jsd(attn_seat, attn_vanilla).item()
                    totals["tvd_on_word_perturb"] += total_variation_distance_from_logits(output_seat, output_perturbed).item()
                    totals["tvd_base_vs_seat"] += total_variation_distance_from_logits(output_seat, output_vanilla).item()
                    n_batches += 1

                    tepoch.set_postfix(
                        jsd_on_word_perturb = totals["jsd_on_word_perturb"]/n_batches,
                        jsd_base_vs_seat = totals["jsd_base_vs_seat"]/n_batches,
                        tvd_on_word_perturb = totals["tvd_on_word_perturb"]/n_batches,
                        tvd_base_vs_seat = totals["tvd_base_vs_seat"]/n_batches
                    )
    finally:
        # Leave the model as it was found, even if a batch raised.
        net.attention = init_attention
        net.train(was_training)

    return {name: total / max(n_batches, 1) for name, total in totals.items()}

In [None]:
# Metrics of the seat_w module before the training loop below has run.
evaluate(net, seat_w)

  0%|          | 0/782 [00:00<?, ?batch/s]

## Train Loop

In [None]:
# SEAT training: 20 epochs; after each one the seat_w weights are saved and evaluate() is run.
# NOTE(review): the batches come from `test_loader`, the same loader evaluate()
# iterates - confirm that `train_loader` was not intended here.
for epoch in range(20):
  # Weights of the similarity and top-k terms; seat_objective_fn reads them as globals.
  lambda1 = 1
  lambda2 = 1000
  running_loss = 0
  with tqdm(test_loader, unit="batch") as tepoch:
    tepoch.set_description(f"SEAT Training")
    for idx, (text, label, seq_length) in enumerate(tepoch):
      # The optimizer was built over seat_w.parameters(), so only those are cleared and stepped.
      optimizer.zero_grad()
      text = text.to(device)
      label = label.to(device)
      #print(seat_w.attn1.weight)
      loss = seat_objective_fn(text, label, seq_length)
      loss.backward()

      running_loss += loss.item()

      optimizer.step()

      #print(seat_w.state_dict())

      #print(seat_w.attn1.weight)
      # Progress bar shows the mean loss over the batches seen so far in this epoch.
      tepoch.set_postfix(loss = running_loss/(idx+1))
  # One checkpoint per epoch, numbered from 1.
  torch.save(seat_w.state_dict(), 'seat_epoch_{}.pth'.format(epoch+1))
  evaluate(net, seat_w)

  0%|          | 0/782 [00:00<?, ?batch/s]

  0%|          | 0/782 [00:00<?, ?batch/s]

  0%|          | 0/782 [00:00<?, ?batch/s]

  0%|          | 0/782 [00:00<?, ?batch/s]

  0%|          | 0/782 [00:00<?, ?batch/s]

  0%|          | 0/782 [00:00<?, ?batch/s]

  0%|          | 0/782 [00:00<?, ?batch/s]

  0%|          | 0/782 [00:00<?, ?batch/s]

  0%|          | 0/782 [00:00<?, ?batch/s]

  0%|          | 0/782 [00:00<?, ?batch/s]

  0%|          | 0/782 [00:00<?, ?batch/s]

RuntimeError: ignored

In [None]:
# Inspect the SEAT weights; the recorded output of this cell is a RecursionError.
seat_w.state_dict()

RecursionError: ignored

In [None]:
# NOTE(review): this cell looks like an earlier draft that treats seat_w as a
# tensor; the recorded output below it is an AttributeError. Confirm whether it
# can be removed.
lambda1 = 1
lambda2 = 1000
lr = 1e-4
for i in range(100):
    # seat_w is a TanhAttention module (see its construction above), which has
    # no `shape` attribute - presumably the source of the AttributeError.
    delta_o = torch.randn(seat_w.shape)
    # The arguments do not line up with pgd_attack's signature: delta_o lands
    # in the `alpha` slot and every later argument is shifted by one.
    delta = pgd_attack(net, text, label, seq_length, delta_o, alpha, seat_w, loss_fn, step_size, eps)
    st_loss = stability_loss(net, text, seq_length, seat_w, nn.Parameter(seat_w + delta), loss_fn)
    sim_loss = similarity_loss(net, text, seq_length, seat_w, loss_fn)
    tk_loss = topk_loss(net, text, seq_length, seat_w, k=7)
    optimizer.zero_grad()
    loss = st_loss + lambda1*sim_loss + lambda2*tk_loss
    loss.backward()
    # The string literal below is a disabled manual-gradient update; it is
    # evaluated as a bare expression and has no effect.
    '''
    st_grad = torch.autograd.grad(st_loss, seat_w,
                                       retain_graph=False, create_graph=False)[0]
    sim_grad = torch.autograd.grad(sim_loss, seat_w, retain_graph=False, create_graph=False)[0]
    tk_loss.backward()
    seat_w = nn.Parameter(seat_w - lr*(sim_grad - (st_grad*lambda1) -(seat_w.grad*lambda2)))
    '''
    optimizer.step()
    print("Loss: {}".format(loss.item()))

AttributeError: ignored