In [None]:
import sys
!{sys.executable} -m pip install torch transformers datasets nltk jupyter ipywidgets

In [None]:
# --- standard library ---
import os
import re

# --- data handling / plotting ---
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
# tqdm.auto renders the notebook widget when available and falls back to the console bar otherwise.
from tqdm.auto import tqdm

# --- NLP preprocessing ---
import nltk
from nltk.corpus import stopwords

# --- PyTorch ---
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# --- Hugging Face ---
from transformers import AutoTokenizer, DataCollatorWithPadding
from transformers import BertTokenizer, BertModel
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
# Aliased: under its own name it collided with (and was silently shadowed by) torch.utils.data.Dataset,
# which is the base class CustomDataset actually needs.
from datasets import Dataset as HFDataset

# Fetch the NLTK stopword corpus read later by `stopwords.words('english')`.
nltk.download('stopwords')

### Dataset Preprocessing

In [None]:
# `names=` is passed, so pandas treats the first line as data: the CSVs are assumed to have no
# header row, and the column names (class label, title, description) are supplied here.
data_path = '../data/ag-news-classification-dataset'
column_names = ['label', 'Title', 'Description']
train_df = pd.read_csv(os.path.join(data_path, 'train.csv'), names=column_names)
val_df = pd.read_csv(os.path.join(data_path, 'test.csv'), names=column_names)

In [None]:
# Merge headline and body into a single space-separated 'text' field, then drop the source columns.
train_df['text'] = train_df['Title'].str.cat(train_df['Description'], sep=' ')
train_df = train_df.drop(columns=['Title', 'Description'])
train_df.head()

In [None]:
# Concatenate title and description with a separating space, exactly as for the training split.
# Without the space, the last word of the title fuses with the first word of the description.
val_df['text'] = val_df['Title'] + " " + val_df['Description']
val_df = val_df.drop(columns=['Title', 'Description'])
val_df.head()

In [None]:
def remove_punctuations(text):
    """Replace backslashes, hyphens, common punctuation and digits with spaces.

    Each removed character becomes a single space, so the result has the same length as `text`.
    """
    # Pass 1: backslash and hyphen.
    without_dashes = re.sub(r'[\\-]', ' ', text)
    # Pass 2: , . $ # ? ; : ' ( ) { } ! | and the digits 0-9.
    return re.sub(r'[,.$#?;:\'(){}!|0-9]', ' ', without_dashes)

# Strip punctuation and digits from every training text (element-wise over the Series).
train_df['text'] = train_df['text'].map(remove_punctuations)
train_df.head()

In [None]:
# Same punctuation/digit stripping for the validation texts.
val_df['text'] = val_df['text'].map(remove_punctuations)
val_df.head()

In [None]:
english_stopwords = stopwords.words('english')
# Set view of the list: O(1) membership tests instead of scanning the whole list once per word.
_english_stopword_set = frozenset(english_stopwords)

def remove_stopwords(text):
    """Drop NLTK English stopwords from `text`.

    Splits on single spaces, so runs of spaces (e.g. left by `remove_punctuations`) survive as
    empty tokens and are re-joined unchanged. Matching is case-sensitive; NLTK's list is presumably
    lowercase, so capitalised forms such as "The" are kept.
    """
    return ' '.join(word for word in text.split(' ') if word not in _english_stopword_set)

# Remove stopwords and shift the 1-based AG News labels (1..4) to the 0-based ids (0..3) the model needs.
for split_df in (train_df, val_df):
    # Guard against re-running this cell: a second shift would silently turn the labels into -1..2.
    assert split_df['label'].between(1, 4).all(), "labels are not 1..4 - was this cell already run?"
    split_df['text'] = split_df['text'].apply(remove_stopwords)
    split_df['label'] = split_df['label'] - 1


In [None]:
# Keep 30% of the rows for training; the remaining 70% is the pool the test split is drawn from.
train_df, test_df = train_test_split(train_df[['text', 'label']], train_size=.3, shuffle=True, random_state=0)
# Renumber rows 0..n-1. drop=True discards the old index instead of inserting it as an extra 'index' column.
train_df = train_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

In [None]:
# Cap the held-out split at the first 10,000 rows so evaluation stays quick
# (the training split keeps its 36,000 rows).
test_df = test_df.iloc[:10000]
train_df.shape, test_df.shape

### Load BERT models

In [None]:
# Plain pretrained BERT, used as a feature extractor for the gating discriminator.
BASE_CHECKPOINT = 'bert-base-uncased'
raw_tokenizer = BertTokenizer.from_pretrained(BASE_CHECKPOINT)
raw_model = BertModel.from_pretrained(BASE_CHECKPOINT, output_hidden_states=True)
# eval() switches off dropout so the extracted embeddings are deterministic.
raw_model.eval()

In [None]:
# Fine-tuned news-topic classifier (tokenizer + sequence-classification head) from the Hugging Face Hub.
classifier_checkpoint = "wesleyacheng/news-topic-classification-with-bert"
tokenizer = AutoTokenizer.from_pretrained(classifier_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(classifier_checkpoint)

### Specify Custom Dataset

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset yielding `(input_ids, label)` pairs from a dataframe.

    Args:
        dataframe: frame with a 'text' column (str) and an integer 'label' column.
        tokenizer: Hugging Face-style tokenizer used to encode each text.
        max_length: every sample is truncated / padded to exactly this many tokens, so samples
            can be stacked by the DataLoader's default collate function.
    """

    def __init__(self, dataframe, tokenizer, max_length=512):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Most recent encoding produced by __getitem__ (kept for inspection only).
        self.encoding = None

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        # Positional lookup: DataLoader indices are 0..len-1, which only match index *labels* when the
        # frame has a fresh RangeIndex. `.iloc` works for any index.
        row = self.dataframe.iloc[idx]
        encoding = self.tokenizer(
            row['text'],
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        self.encoding = encoding

        # return_tensors='pt' adds a leading batch dimension of 1; flatten to (max_length,).
        input_ids = encoding['input_ids'].flatten()
        label = torch.tensor(int(row['label']), dtype=torch.long)
        return input_ids, label
    
# Training pipeline: every sample padded/truncated to MAX_LENGTH tokens, reshuffled each epoch.
MAX_LENGTH = 512
BATCH_SIZE = 8
dataset = CustomDataset(train_df, tokenizer, max_length=MAX_LENGTH)
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

### Train InterpretCC Feature Gating Model

In [None]:
def gumbel_sigmoid(logits: torch.Tensor, tau: float = 1, hard: bool = False, threshold: float = 0.5) -> torch.Tensor:
    """
    Samples from the Gumbel-Sigmoid distribution and optionally discretizes.
    The discretization converts the values greater than `threshold` to 1 and the rest to 0.
    The code is adapted from the official PyTorch implementation of gumbel_softmax:
    https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax

    Args:
      logits: `[..., num_features]` unnormalized log probabilities (any rank, including 1-D)
      tau: positive scalar temperature (the noisy logits are divided by it)
      hard: if ``True``, the returned samples will be discretized,
            but will be differentiated as if it is the soft sample in autograd
      threshold: threshold for the discretization,
            values greater than this will be set to 1 and the rest to 0

    Returns:
      Sampled tensor of same shape as `logits` from the Gumbel-Sigmoid distribution.
      If ``hard=True``, the returned samples are discretized according to `threshold`; otherwise
      each element is an independent probability in [0, 1].
    """
    # -log(Exp(1)) ~ Gumbel(0, 1), sampled the same way torch.nn.functional.gumbel_softmax does.
    gumbels = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
    y_soft = ((logits + gumbels) / tau).sigmoid()  # sigmoid of ~Gumbel(logits, tau)

    if not hard:
        # Reparametrization trick: return the differentiable soft sample.
        return y_soft

    # Element-wise thresholding handles inputs of any rank (indexing with only the first two
    # nonzero() index tensors fails on 1-D input and sets whole trailing slices on rank >= 3).
    y_hard = (y_soft > threshold).to(y_soft.dtype)
    # Straight-through estimator: forward value is y_hard, gradient is that of y_soft.
    return y_hard - y_soft.detach() + y_soft

In [None]:
# Hyperparameters
layer_size = 30          # hidden width of the gating discriminator
input_embedding = 768    # width of each token embedding fed to the discriminator (must match raw_model)
learning_rate = 2e-5
num_epochs = 1
thres = 0.7              # a token is kept when its Gumbel-sigmoid sample exceeds this

# Seed torch's global RNG so discriminator init, DataLoader shuffling and Gumbel noise are reproducible.
torch.manual_seed(0)

# Discriminator: maps each token embedding to one gating logit.
# The ReLU matters: two stacked Linear layers with nothing in between collapse to a single linear map.
discriminator = nn.Sequential(
    nn.Linear(input_embedding, layer_size),
    nn.ReLU(),
    nn.Linear(layer_size, 1),
)

# Container for the two halves: [0] = discriminator (trained), [1] = pretrained classifier.
# ModuleList rather than Sequential: the halves are called separately and are never chained by a forward().
interpret_model = nn.ModuleList([discriminator, model])

# Only the discriminator is optimised. Freezing the other models' parameters stops loss.backward()
# from computing and accumulating gradients that are never used. Gradients with respect to inputs that
# require grad (the gating mask) are unaffected by freezing parameters.
model.requires_grad_(False)
raw_model.requires_grad_(False)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(interpret_model[0].parameters(), lr=learning_rate)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
interpret_model.to(device)
raw_model.to(device);  # trailing ';' suppresses the (very long) module repr

In [None]:
# Training loop. Only the discriminator (interpret_model[0]) is optimised: raw BERT supplies the token
# embeddings it scores, and the classifier (interpret_model[1]) consumes the sampled gate as its attention mask.
print('initialized model training')
for epoch in range(num_epochs):
    interpret_model.train()
    # Reset per epoch so the printed metrics describe this epoch only.
    correct = 0
    total = 0
    running_loss = 0.0
    for input_ids, labels in tqdm(data_loader, desc=f'epoch {epoch + 1}'):
        input_ids, labels = input_ids.to(device), labels.to(device)
        # CustomDataset pads every sample to max_length: 1 for real tokens, 0 for padding.
        padding_mask = (input_ids != tokenizer.pad_token_id).long()

        optimizer.zero_grad()

        # raw_model is a frozen feature extractor, so no autograd graph is needed for it.
        # NOTE(review): assumes `tokenizer` and `raw_model` share the bert-base-uncased vocabulary - confirm.
        with torch.no_grad():
            embeddings = raw_model(input_ids, attention_mask=padding_mask).last_hidden_state
        gate_logits = interpret_model[0](embeddings)  # (batch, seq_len, 1)
        # squeeze(-1), not squeeze(): a batch of size 1 must keep its batch dimension.
        g_mask = gumbel_sigmoid(gate_logits, tau=1, hard=True, threshold=thres).squeeze(-1)
        # Padding positions are never attended to, whatever the gate samples for them.
        g_mask = g_mask * padding_mask

        predictions = interpret_model[1](input_ids, attention_mask=g_mask)
        loss = criterion(predictions.logits, labels)
        loss.backward()
        optimizer.step()

        y_pred = predictions.logits.argmax(dim=1)
        correct += (y_pred == labels).sum().item()
        total += labels.size(0)
        running_loss += loss.item() * labels.size(0)

    # Loss is the sample-weighted mean over the epoch (not just the last batch).
    seen = max(total, 1)
    print(f"Epoch {epoch + 1}, Loss: {running_loss / seen}, Accuracy: {correct / seen}")

In [None]:
# Only the discriminator is trained, so saving `model` (the classifier) alone would discard everything
# training learned. Store the state_dicts of both halves instead.
torch.save(
    {
        'discriminator': interpret_model[0].state_dict(),
        'classifier': interpret_model[1].state_dict(),
    },
    'interpretcc_text_sigmoid.pt',
)

### Evaluate InterpretCC Feature Gating

In [None]:
# Restore the saved checkpoint into the modules that evaluation actually uses (interpret_model).
# weights_only=False is required to unpickle a whole nn.Module on PyTorch >= 2.6; unpickling can execute
# arbitrary code, so only load checkpoints you produced yourself.
checkpoint = torch.load('interpretcc_text_sigmoid.pt', map_location=device, weights_only=False)
if isinstance(checkpoint, nn.Module):
    # A whole pickled classifier (this format holds no discriminator weights). Re-wire it into
    # interpret_model; rebinding `model` alone would leave evaluation on the old object.
    model = checkpoint
    interpret_model[1] = model
else:
    interpret_model[0].load_state_dict(checkpoint['discriminator'])
    interpret_model[1].load_state_dict(checkpoint['classifier'])

In [None]:
# Evaluate the gated classifier on the held-out split, one sample at a time.
# NOTE(review): the gate is still *sampled* (Gumbel noise), so accuracy varies between runs unless
# torch's RNG is seeded.
interpret_model.eval()
interpret_model.to(device)

test_preds = []
correct = 0
total = 0
with torch.no_grad():
    for text, label in tqdm(zip(test_df['text'], test_df['label']), total=len(test_df), desc='test'):
        # Truncate to the same 512 tokens CustomDataset uses during training.
        input_ids = tokenizer.encode(text, add_special_tokens=True, truncation=True, max_length=512)
        input_ids = torch.tensor([input_ids], device=device)  # (1, seq_len)
        labels = torch.tensor([label], dtype=torch.long, device=device)

        embeddings = raw_model(input_ids).last_hidden_state
        gate_logits = interpret_model[0](embeddings)  # (1, seq_len, 1)
        # squeeze(-1) keeps the batch dim, so the mask is (1, seq_len) like input_ids;
        # a bare squeeze() collapsed it to (seq_len,) for this batch of one.
        g_mask = gumbel_sigmoid(gate_logits, tau=1, hard=True, threshold=thres).squeeze(-1)

        predictions = interpret_model[1](input_ids, attention_mask=g_mask)
        test_preds.append(predictions)
        predicted_labels = predictions.logits.argmax(dim=-1)
        correct += (predicted_labels == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total
print(f"Accuracy: {accuracy * 100}%")