In [None]:
# Install the notebook's dependencies with the interpreter that backs the running kernel
# (sys.executable), so the packages land in the kernel's own environment.
# NOTE(review): no versions are pinned here, so results may differ between installs.
import sys
!{sys.executable} -m pip install torch transformers datasets nltk jupyter ipywidgets sentence-transformers wandb

In [None]:
# Standard library
import copy
import os
import pickle
import random
import re
import sys

# Third-party
import nltk
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from matplotlib import pyplot as plt
# nltk.download('stopwords')
from nltk.corpus import stopwords
from scipy.spatial.distance import cdist
from sentence_transformers import SentenceTransformer
from sklearn.model_selection import train_test_split
from torch import nn
# StepLR and tqdm are used by the training / evaluation cells
from torch.optim.lr_scheduler import StepLR
from tqdm.auto import tqdm

# To import the Transformer Models
from transformers import AutoTokenizer, DataCollatorWithPadding
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
from transformers import BertTokenizer, BertModel

# to convert to Dataset datatype - the transformers library does not work well with pandas
from datasets import Dataset

# Keep this import AFTER `from datasets import Dataset`: the name `Dataset` used by
# CustomDataset must resolve to the torch class, exactly as in the original import order.
from torch.utils.data import DataLoader, Dataset

random.seed(0)

### Dataset Preprocessing

In [None]:
# The CSV files are read without a header row, so the column names are supplied explicitly
data_path = '../data/ag-news-classification-dataset'
train_df, val_df = (
    pd.read_csv(os.path.join(data_path, csv_name), names=['label', 'Title', 'Description'])
    for csv_name in ('train.csv', 'test.csv')
)

In [None]:
# Concatenate the 'Title' and 'Description' columns into a single 'text' column.
# A space separator is required: without it the last word of the title is fused with the
# first word of the description, which produces a bogus token for the per-word routing.
for frame in (train_df, val_df):
    frame['text'] = frame['Title'] + ' ' + frame['Description']
    frame.drop(columns=['Title', 'Description'], inplace=True)

# Only the last expression of a cell is displayed, so a single preview is kept
val_df.head()

In [None]:
def remove_punctuations(text):
    """Replace selected punctuation characters and all digits in ``text`` with spaces.

    Two substitutions are applied in order: first backslashes, forward slashes and
    hyphens, then ``, . $ # ? ; : ' ( ) { } ! |`` and the digits 0-9. Each matched
    character becomes exactly one space, so the length of the string is preserved.
    """
    cleaned = text
    for pattern in (r'[\\//-]', r'[,.$#?;:\'(){}!|0-9]'):
        cleaned = re.sub(pattern, ' ', cleaned)
    return cleaned

# Run remove_punctuations on every training text (element-wise), then preview the frame
train_df['text'] = train_df['text'].map(remove_punctuations)
train_df.head()

In [None]:
# Same punctuation clean-up for the validation texts
val_df['text'] = val_df['text'].map(remove_punctuations)
val_df.head()

In [None]:
# English stop-word list from the NLTK corpus reader. The 'stopwords' corpus has to be
# available locally (see the commented-out nltk.download call in the imports cell).
english_stopwords = stopwords.words('english')

def remove_stopwords(text, stop_words=None):
    """Drop stop words from a space-separated string.

    Args:
        text: input string; it is split on single spaces, so consecutive spaces
            yield empty tokens, which are kept unless ``''`` is itself a stop word.
        stop_words: optional iterable of words to remove. Defaults to the
            module-level ``english_stopwords`` list.

    Returns:
        The remaining tokens joined back together with single spaces.

    NOTE(review): matching is case-sensitive and ``text`` is not lower-cased here, so a
    capitalised token is only removed if the stop-word collection contains that exact form.
    """
    if stop_words is None:
        stop_words = english_stopwords
    # Membership tests against a list are linear per word; a set makes each lookup O(1)
    stop_set = frozenset(stop_words)
    return ' '.join(word for word in text.split(' ') if word not in stop_set)

# For both splits: strip the stop words from the text, then shift the class labels.
# The dataset labels are 1,2,3,4 but the model needs 0,1,2,3, so 1 is subtracted from all.
# Re-running this cell subtracts 1 again, so run it only once per kernel session.
for frame in (train_df, val_df):
    frame['text'] = frame['text'].apply(remove_stopwords)
    frame['label'] = frame['label'] - 1


In [None]:
# Seeded, shuffled split: 30% of the rows become the training split and the rest is held out.
# reset_index() renumbers the rows from 0 and keeps the previous row labels in an 'index' column.
train_df, test_df = (
    part.reset_index()
    for part in train_test_split(
        train_df[['text', 'label']],
        train_size=.3,
        shuffle=True,
        random_state=0,
    )
)

In [None]:
# Keep only the first 10000 held-out rows for the purpose of a fast training loop,
# then show the shapes of both splits
test_df = test_df.head(10000)
(train_df.shape, test_df.shape)

### Router: Classify words into Dewey Decimal Code Categories

In [None]:
# Load the Dewey Decimal Code subcategorization from its pickle file.
# It is iterated below as a mapping from a main-category code to its subcategory names.
# pickle can execute arbitrary code while loading: only open files you trust.

with open('../data/ddc_subcategories.pkl', 'rb') as pickle_file:
    ddc_subcategories = pickle.load(pickle_file)

In [None]:
# Load a pre-trained Sentence-BERT model. Its encode() method is used in this notebook to
# embed the DDC subcategory names, the individual words of each text, and whole sentences.
sent_model = SentenceTransformer('sentence-transformers/bert-base-nli-mean-tokens')

def get_sent_embedding(text):
    """Return ``sent_model.encode(text)``, the embedding produced by the module-level model."""
    return sent_model.encode(text)

def get_word_embeddings_from_sent(sent):
    """Embed each space-separated token of ``sent`` on its own.

    The string is split on single spaces (so repeated spaces yield empty tokens) and
    every token is passed to ``get_sent_embedding``; the results are returned as a list
    in token order.
    """
    return [get_sent_embedding(token) for token in sent.split(' ')]
    
# Flatten the subcategories into one list (in mapping order) ...
flattened_subcategories = [
    subcat
    for subcats in ddc_subcategories.values()
    for subcat in subcats
]
# ... and remember which main category every subcategory belongs to
category_map = {
    subcat: main_cat
    for main_cat, subcats in ddc_subcategories.items()
    for subcat in subcats
}

# One embedding per subcategory, stacked in the same order as flattened_subcategories
subcategory_embeddings = np.array(list(map(get_sent_embedding, flattened_subcategories)))

In [None]:
# Example word sanity check: route a single word to its nearest DDC subcategory
word = "school"
word_embedding = get_sent_embedding(word)
print("word:", word)

# Cosine distance between the word and every subcategory embedding
distances = cdist([word_embedding], subcategory_embeddings, metric='cosine').squeeze()

# Find the closest category
closest_idx = np.argmin(distances)
closest_category = flattened_subcategories[closest_idx]
# Integer hundreds bucket of the main-category code. Floor division keeps this an int,
# matching the int(int(code)/100) bucket computed in CustomDataset.__getitem__
# (plain `/` produced a float here).
closest_main_category = int(category_map[closest_category]) // 100
print(closest_idx, closest_category, closest_main_category)

# The three nearest subcategories with their distances
sorted(zip(distances, flattened_subcategories))[:3]

### Load BERT Models

In [None]:
# Tokenizer and sequence-classification model are loaded from the same pretrained checkpoint
bert_checkpoint = "wesleyacheng/news-topic-classification-with-bert"
tokenizer = AutoTokenizer.from_pretrained(bert_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(bert_checkpoint)

### Specify Custom Dataset

In [None]:
class CustomDataset(Dataset):
    """Dataset that yields a text, its tokenization, a per-word routing vector and its label.

    Relies on the module-level router objects ``get_word_embeddings_from_sent``,
    ``subcategory_embeddings``, ``flattened_subcategories`` and ``category_map``.

    Args:
        dataframe: frame with a ``'text'`` and a ``'label'`` column.
        tokenizer: callable used to tokenize each text (called with ``padding``,
            ``truncation``, ``return_tensors`` and ``max_length`` keyword arguments).
        max_length: length of both the tokenized sequence and the routing vector.
    """

    def __init__(self, dataframe, tokenizer, max_length=512):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Not read anywhere in this class; kept so existing attribute access keeps working
        self.encoding = None

    def __len__(self):
        """Number of rows in the wrapped dataframe."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(text, tokenizer_output, routing, label_tensor)`` for the row at position ``idx``.

        ``routing`` is an integer array of shape ``(max_length, 1)``: position 0 is 0,
        positions ``1..n`` hold ``main_category + 1`` for the n space-separated words of
        the text, and the remaining positions are 0.

        NOTE(review): routing is computed per space-separated word while the tokenizer
        may split words differently, so positions are not guaranteed to line up with
        token positions - confirm this is intended.
        """
        # Positional lookup: DataLoader indices run 0..len-1, whatever the frame's index labels are
        row = self.dataframe.iloc[idx]

        # routing: nearest DDC subcategory (cosine distance) for every word, in one cdist call
        word_embeddings = get_word_embeddings_from_sent(row['text'])
        distances = cdist(np.asarray(word_embeddings), subcategory_embeddings, metric='cosine')
        nearest = distances.argmin(axis=1)
        # hundreds bucket of the main-category code, shifted by 1 so that 0 stays free
        # for the leading position and the padding
        main_categories = [
            int(int(category_map[flattened_subcategories[i]]) / 100) + 1
            for i in nearest
        ]
        # Truncate so the routing vector is always exactly max_length long; without this,
        # texts with more than max_length - 1 words produced over-long vectors that
        # cannot be collated into a batch
        main_categories = main_categories[:self.max_length - 1]
        routing = np.zeros((self.max_length, 1), dtype=int)
        routing[1:1 + len(main_categories), 0] = main_categories

        # Use the tokenizer and length given to this dataset, not the notebook globals
        encoding = self.tokenizer(
            row['text'],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            max_length=self.max_length,
        )

        return row['text'], encoding, routing, torch.tensor(row['label'])


### Training the InterpretCC Gated Routing Model

In [None]:
# Run on the GPU when one is available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Gating
t = 0.5                       # threshold handed to gumbel_sigmoid
tau = 1                       # temperature recorded in the run config
mul_thres = 0.1               # cut-off applied to (gate * route mask) to build attention masks

# Model / optimisation
b = 8                         # DataLoader batch size
l = 16                        # width of the discriminator's hidden layer
num_of_subnetworks = 10
num_classes = 4
num_epochs = 5
lr = 1e-5
scheduler_flag = False        # enables the StepLR scheduler
ss = 10                       # StepLR step_size
gamma = 0.5                   # StepLR gamma

# Data sizes
train_num = 36000             # rows of train_df used for training
test_num = 3000               # rows of test_df used per evaluation slice
sentence_embedding_size = 768
max_word_length = 512

# Early-stopping bookkeeping
best_val_accuracy = 0
val_count = 0

experiment = "interpretcc_gated_routing"

# Values recorded with the wandb run
config = dict(
    t=t,
    tau=tau,
    batch=b,
    layer=l,
    architecture=experiment,
    num_of_subnetworks=num_of_subnetworks,
    num_classes=num_classes,
    num_epochs=num_epochs,
    lr=lr,
    scheduler_flag=scheduler_flag,
    ss=ss,
    gamma=gamma,
    train_num=train_num,
    test_num=test_num,
    sentence_embedding_size=sentence_embedding_size,
    max_word_length=max_word_length,
)

# wandb entity / project / run name
user = "vinitra"
project = "interpretcc"
display_name = experiment

In [None]:
def gumbel_sigmoid(logits: torch.Tensor, tau: float = 1, hard: bool = False, threshold: float = 0.5) -> torch.Tensor:
    """
    Samples from the Gumbel-Sigmoid distribution and optionally discretizes.
    The discretization converts the values greater than `threshold` to 1 and the rest to 0.
    The code is adapted from the official PyTorch implementation of gumbel_softmax:
    https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax

    Args:
      logits: `[..., num_features]` unnormalized log probabilities (any number of dimensions)
      tau: positive scalar temperature (the noisy logits are divided by it)
      hard: if ``True``, the returned samples will be discretized,
            but will be differentiated as if it is the soft sample in autograd
      threshold: threshold for the discretization,
                 values greater than this will be set to 1 and the rest to 0

    Returns:
      Sampled tensor of same shape as `logits` from the Gumbel-Sigmoid distribution.
      If ``hard=True``, the returned samples are discretized according to `threshold`, otherwise they will
      be probability distributions.

    """
    gumbels = (
        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
    )  # ~Gumbel(0, 1)
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits, tau)
    y_soft = gumbels.sigmoid()

    if hard:
        # Straight through. The element-wise comparison works for any number of
        # dimensions; indexing with only the first two index arrays raised IndexError
        # on 1-D input and set whole trailing slices on inputs with more than two dims.
        y_hard = (y_soft > threshold).to(y_soft.dtype)
        # Forward value is y_hard (up to float rounding); the gradient flows through y_soft
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret


In [None]:
def _run_interpretcc_batch(models, batch, correct_counts, total_counts, history, hard_gate):
    """Forward one batch through the discriminator and every routed subnetwork.

    Args:
        models: ModuleList whose item 0 is the discriminator and items 1.. are the subnetworks.
        batch: ``(raw_text, tokenizer_output, routing, labels)`` as collated from CustomDataset.
        correct_counts, total_counts: per-subnetwork counters, updated in place
            (key ``num_of_subnetworks`` holds the counters of the combined prediction).
        history: tracking dict; this batch's inputs, gate and masks are appended to it.
        hard_gate: if True the discriminator output is discretized with gumbel_sigmoid
            (validation); otherwise the soft output is used directly (training).

    Returns:
        ``(weighted_predictions, labels)``: combined logits of shape (batch, num_classes)
        and the labels on ``device``.
    """
    raw_text, tokenizer_output_raw, routing, labels = batch
    # send to GPU
    labels = labels.long().to(device)
    routing = routing.to(device)
    input_ids = tokenizer_output_raw['input_ids'].to(device)
    # Remove only the singleton axes added by the dataset - routing is (batch, length, 1)
    # and input_ids is (batch, 1, length). A blanket squeeze() would also drop the batch
    # axis whenever a batch holds a single sample.
    routing_flat = routing.squeeze(-1).long()
    token_ids = input_ids.squeeze(1)

    history['text'].append(raw_text)
    history['input_ids'].append(input_ids)
    history['routing'].append(routing_flat)
    history['labels'].append(labels)

    # for all subnetworks, mask each subnetwork's assignment
    # batch x max_word_length x subnetworks (routing value 0 is dropped)
    route_mask = nn.functional.one_hot(routing_flat, num_classes=num_of_subnetworks + 1)[:, :, 1:]

    # discriminator input: routing vector followed by the sentence embedding
    embeddings = torch.Tensor(sent_model.encode(raw_text)).to(device)
    concat_inputs_routing = torch.cat((routing_flat.to(embeddings.dtype), embeddings), 1)
    output = models[0](concat_inputs_routing)
    if hard_gate:
        gate = gumbel_sigmoid(output, tau=tau, hard=True, threshold=t)
    else:
        gate = output
    # detach so the tracked copy does not keep the autograd graph of every batch alive
    history['adaptive'].append(gate.detach())

    # AND both masks together: the gate (batch x subnetworks) is broadcast over word positions
    attn_mask = ((gate.unsqueeze(1) * route_mask) > mul_thres).int()
    history['attn_mask'].append(attn_mask)

    subnet_predictions = []
    for subnet in range(num_of_subnetworks):
        subnet_attn = attn_mask[:, :, subnet]
        if subnet_attn.any():
            logits = models[subnet + 1](token_ids, attention_mask=subnet_attn).logits
            subnet_predictions.append(logits)

            correct_counts[subnet] += torch.sum(torch.argmax(logits, 1) == labels)
            total_counts[subnet] += len(labels)
        else:
            # subnetwork unused for this whole batch: it contributes zero logits
            subnet_predictions.append(torch.zeros((len(labels), num_classes), device=device))

    # Weight each subnetwork's logits by its gate value. (subnetworks, batch, 1) broadcasts
    # over the class axis; the previous repeat/reshape/slice construction produced the same
    # values but only worked when the batch had at least num_classes samples.
    g_weighting = gate.t().unsqueeze(-1)
    weighted_predictions = torch.sum(g_weighting * torch.stack(subnet_predictions, dim=0), dim=0)

    correct_counts[num_of_subnetworks] += torch.sum(torch.argmax(weighted_predictions, 1) == labels)
    total_counts[num_of_subnetworks] += len(labels)
    return weighted_predictions, labels


with wandb.init(entity=user, project=project, name=display_name, config=config, mode="online") as run:

    # Initializations of parameters
    model.to(device)

    best_val_accuracy = 0
    val_count = 0

    dataset = CustomDataset(train_df[:train_num], tokenizer, max_length=max_word_length)
    data_loader = DataLoader(dataset, batch_size=b, shuffle=True)

    # The validation loader is built once and reused every epoch
    test_dataset = CustomDataset(test_df[:test_num], tokenizer, max_length=max_word_length)
    test_loader = DataLoader(test_dataset, batch_size=b, shuffle=True)

    # Define discriminator network
    discriminator = nn.Sequential(
        nn.Linear(sentence_embedding_size + max_word_length, l),
        nn.Linear(l, num_of_subnetworks), # predict for each subnetwork at once
        nn.Softmax(dim=1)
    )

    # 10 copies of subnetworks
    subnetworks = [copy.deepcopy(model) for _ in range(num_of_subnetworks)]

    # assemble discriminators and subnetworks in module lists
    interpret_models = nn.ModuleList([discriminator] + subnetworks)
    # interpret_models = torch.load('26Jan_BestVal_36000Train')
    interpret_models.to(device)

    # Initialize optimizers
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(interpret_models.parameters(), lr=lr)

    # Define a learning rate scheduler
    if scheduler_flag:
        scheduler = StepLR(optimizer, step_size=ss, gamma=gamma)

    tracking = {'text': [], 'input_ids': [], 'routing': [], 'adaptive': [], 'attn_mask': [], 'labels': []}

    print('initialized model training')
    for epoch in range(num_epochs):
        # track correct and total for accuracy
        correct = {k: 0 for k in np.arange(num_of_subnetworks + 1)}
        total = {k: 0 for k in np.arange(num_of_subnetworks + 1)}
        interpret_models.train()

        # for each batch
        for batch in tqdm(data_loader):
            optimizer.zero_grad()

            weighted_predictions, labels = _run_interpretcc_batch(
                interpret_models, batch, correct, total, tracking, hard_gate=False
            )

            loss = criterion(weighted_predictions, labels)
            loss.backward()
            optimizer.step()

            if scheduler_flag:
                scheduler.step()

        train_accuracy = correct[num_of_subnetworks] / total[num_of_subnetworks]
        print(f"Epoch {epoch + 1}, Loss: {loss.item()}, Train Accuracy: {train_accuracy}")
        run.log({"Epoch": epoch + 1, "Subnet": -1, "Loss": loss.item(), "Accuracy": train_accuracy})
        for subnet in range(num_of_subnetworks):
            subnet_accuracy = correct[subnet] / (total[subnet] + 0.00001)
            print("Subnet: ", subnet, "| Train Accuracy:", subnet_accuracy, "| Total:", total[subnet])
            run.log({"Epoch": epoch + 1, "Subnet": subnet, "Train Accuracy": subnet_accuracy, "Train Total": total[subnet]})

        # The model checkpoint is NOT written here: an unconditional per-epoch save to
        # `experiment` overwrote the best-validation checkpoint with later, possibly worse,
        # weights. It is written below only when the validation accuracy improves.

        with open('correct_' + experiment + '.pickle', 'wb') as file:
            pickle.dump(correct, file)

        with open('total_' + experiment + '.pickle', 'wb') as file:
            pickle.dump(total, file)

        with open('tracking_' + experiment + '.pickle', 'wb') as file:
            pickle.dump(tracking, file)

        interpret_models.eval()

        with torch.no_grad():
            # track correct and total for accuracy
            correct_test = {k: 0 for k in np.arange(num_of_subnetworks + 1)}
            total_test = {k: 0 for k in np.arange(num_of_subnetworks + 1)}
            tracking_test = {'text': [], 'input_ids': [], 'routing': [], 'adaptive': [], 'attn_mask': [], 'labels': []}

            # for each batch; gates and masks go to tracking_test, not to the training history
            for batch in tqdm(test_loader):
                _run_interpretcc_batch(
                    interpret_models, batch, correct_test, total_test, tracking_test, hard_gate=True
                )

            val_accuracy = correct_test[num_of_subnetworks] / total_test[num_of_subnetworks]
            if val_accuracy > best_val_accuracy:
                # best_model is a reference to the live model, not a snapshot; the file
                # written here is what preserves the best weights
                best_model = interpret_models
                torch.save(best_model, experiment)
                best_val_accuracy = val_accuracy
            else:
                val_count += 1

            print("Val Accuracy: ", val_accuracy)
            # Validation metrics come from the validation counters (the train counters were
            # logged here before). "Loss" is still the last training-batch loss of the epoch.
            run.log({"Epoch": epoch + 1, "Subnet": -1, "Loss": loss.item(), "Val Accuracy": val_accuracy})
            for subnet in range(num_of_subnetworks):
                subnet_val_accuracy = correct_test[subnet] / (total_test[subnet] + 0.00001)
                print("Subnet: ", subnet, "| Val Accuracy:", subnet_val_accuracy, "| Val Total:", total_test[subnet])
                run.log({"Epoch": epoch + 1, "Subnet": subnet, "Val Accuracy": subnet_val_accuracy, "Val Total": total_test[subnet]})

            # stop after the second epoch without a new best validation accuracy
            if val_count > 1:
                break

# No final torch.save: `best_model` aliases the live model, so saving it here would replace
# the best-validation checkpoint on disk with the weights of the last epoch.

with open('correct_' + experiment + '.pickle', 'wb') as file:
    pickle.dump(correct, file)

with open('total_' + experiment + '.pickle', 'wb') as file:
    pickle.dump(total, file)

with open('tracking_' + experiment + '.pickle', 'wb') as file:
    pickle.dump(tracking, file)

### Evaluate InterpretCC Gated Routing Model

In [None]:
# Reload the checkpoint written by the training cell. It is a fully pickled nn.ModuleList
# (not a state_dict), so weights_only=False is required on PyTorch releases where
# torch.load defaults to weights_only=True. Unpickling can execute arbitrary code:
# only load checkpoint files you created yourself.
# map_location lets a checkpoint saved on a GPU be opened on a CPU-only machine.
new_interpret = torch.load(experiment, map_location=device, weights_only=False)

In [None]:
new_interpret.eval()
new_interpret.to(device)

# Final evaluation uses held-out rows [test_num, 2 * test_num), i.e. the slice right after
# the rows used for validation during training. reset_index() renumbers them from 0.
test_dataset = CustomDataset(test_df[test_num:test_num*2].reset_index(), tokenizer, max_length=max_word_length)
test_loader = DataLoader(test_dataset, batch_size=b, shuffle=True)

with torch.no_grad():
    # track correct and total for accuracy
    correct_test = {k: 0 for k in np.arange(num_of_subnetworks + 1)}
    total_test = {k: 0 for k in np.arange(num_of_subnetworks + 1)}
    tracking_test = {'text': [], 'input_ids': [], 'routing': [], 'adaptive': [], 'attn_mask': [], 'labels': []}

    # for each batch
    for batch in tqdm(test_loader):

        raw_text, tokenizer_output_raw, routing, labels = batch
        # send to GPU
        labels = labels.long().to(device)
        routing, input_ids = routing.to(device), tokenizer_output_raw['input_ids'].to(device)
        # Remove only the singleton axes added by the dataset - routing is (batch, length, 1)
        # and input_ids is (batch, 1, length). A blanket squeeze() would also drop the batch
        # axis whenever a batch holds a single sample.
        routing_flat = routing.squeeze(-1).long()
        token_ids = input_ids.squeeze(1)
        tracking_test['text'].append(raw_text)
        tracking_test['input_ids'].append(input_ids)
        tracking_test['routing'].append(routing_flat)
        tracking_test['labels'].append(labels)
        # local name: the global batch size `b` is no longer overwritten inside the loop
        n_samples = len(labels)

        # for all subnetworks, mask each subnetwork's assignment
        # batch x max_word_length x subnetworks (routing value 0 is dropped)
        route_mask = nn.functional.one_hot(routing_flat, num_classes=num_of_subnetworks+1)[:,:,1:]

        # create gumbel mask
        embeddings = torch.Tensor(sent_model.encode(raw_text)).to(device)
        concat_inputs_routing = torch.cat((routing_flat.to(embeddings.dtype), embeddings), 1)
        output = new_interpret[0](concat_inputs_routing)
        # gumbel_sigmoid already returns a tensor on the device of `output`
        gumbel_mask = gumbel_sigmoid(output, tau=tau, hard=True, threshold=t)
        tracking_test['adaptive'].append(gumbel_mask)

        # AND both masks together: the gate (batch x subnetworks) is broadcast over word positions
        attn_mask = ((gumbel_mask.unsqueeze(1) * route_mask) > mul_thres).int()
        tracking_test['attn_mask'].append(attn_mask)

        subnet_predictions = []
        for subnet in range(num_of_subnetworks):
            subnet_attn = attn_mask[:,:, subnet]
            if subnet_attn.any():
                predictions = new_interpret[subnet+1](token_ids, attention_mask=subnet_attn)
                subnet_predictions.append(predictions.logits)

                y_pred = torch.argmax(predictions.logits, 1)
                correct_test[subnet] += torch.sum(y_pred == labels)
                total_test[subnet] += len(labels)

            else:
                # subnetwork unused for this whole batch: it contributes zero logits
                subnet_predictions.append(torch.zeros((n_samples, num_classes), device=device))

        # (subnetworks, batch, 1) broadcasts over the class axis; the previous
        # repeat/reshape/slice construction produced the same values but only worked
        # when the batch had at least num_classes samples.
        # NOTE(review): the weights here are the soft discriminator output, whereas the
        # validation loop in the training cell weights by the hard gumbel mask - confirm
        # which of the two is intended.
        g_weighting = output.t().unsqueeze(-1)
        subnet_predictions = torch.stack(subnet_predictions, dim=0)
        weighted_predictions = torch.mul(g_weighting, subnet_predictions)
        weighted_predictions = torch.sum(weighted_predictions, dim=0)

        y_pred = torch.argmax(weighted_predictions, 1)
        correct_test[num_of_subnetworks] += torch.sum(y_pred == labels)
        total_test[num_of_subnetworks] += len(labels)

    print("Test Accuracy: ", correct_test[num_of_subnetworks]/(total_test[num_of_subnetworks]))
    for subnet in range(num_of_subnetworks):
        print("Subnet: ", subnet, "| Test Accuracy:", correct_test[subnet]/(total_test[subnet]+0.00001), "| Test Total:", total_test[subnet])


with open('correct_test_' + experiment + '.pickle', 'wb') as file:
    pickle.dump(correct_test, file)

with open('total_test_' + experiment + '.pickle', 'wb') as file:
    pickle.dump(total_test, file)

with open('tracking_test_' + experiment + '.pickle', 'wb') as file:
    pickle.dump(tracking_test, file)