In [75]:
# Here we explore feature importance modules provided by captum using toy datasets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch.nn.functional as F
from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split

custom_params = {"axes.spines.right": False, "axes.spines.top": False}
sns.set_theme(style="ticks", font_scale=0.8, rc=custom_params)
%config InlineBackend.figure_format='retina'

x, y = make_classification(
    n_samples=1000,
    n_features=10,
    n_informative=3,
    n_redundant=0,
    n_repeated=0,
    n_classes=4,
    n_clusters_per_class=1,
    weights=None,
    flip_y=0.01,
    class_sep=1.0,
    hypercube=True,
    shift=0.0,
    scale=1.0,
    shuffle=False,
    random_state=0,
)

In [76]:
# x with only informative features
# x = x[:, :3]

In [77]:
feature_names = [f"x{i}" for i in range(1, x.shape[1] + 1)]
df = pd.DataFrame(x, columns=feature_names)
df["y"] = y

In [78]:
import torch
import torch.nn as nn
import torch.optim as optim

In [87]:
batch_size = 50
num_epochs = 2000
learning_rate = 0.001
size_hidden1 = 100
size_hidden2 = 50
size_hidden3 = 10
size_hidden4 = 1

In [88]:
torch.manual_seed(1234)
np.random.seed(1234)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=0)

In [89]:
x_train = torch.tensor(x_train).float()
y_train = torch.tensor(y_train).view(-1, 1).long()

x_test = torch.tensor(x_test).float()
y_test = torch.tensor(y_test).view(-1, 1).long()

datasets = torch.utils.data.TensorDataset(x_train, y_train)
train_iter = torch.utils.data.DataLoader(datasets, batch_size=50, shuffle=True)

In [90]:
class Model(nn.Module):
    def __init__(self, input_size=10, n_classes=4):
        super().__init__()
        self.lin1 = nn.Linear(input_size, size_hidden1)
        self.lin2 = nn.Linear(size_hidden1, size_hidden2)
        self.lin3 = nn.Linear(size_hidden2, size_hidden3)
        self.lin4 = nn.Linear(size_hidden3, n_classes)
        self.gelu = nn.GELU()

    def forward(self, input):
        x = self.gelu(self.lin1(input))
        x = self.gelu(self.lin2(x))
        x = self.gelu(self.lin3(x))
        logits = self.gelu(self.lin4(x))
        return logits


model = Model(input_size=x.shape[1], n_classes=np.unique(y).size)
model.train()

Model(
  (lin1): Linear(in_features=10, out_features=100, bias=True)
  (lin2): Linear(in_features=100, out_features=50, bias=True)
  (lin3): Linear(in_features=50, out_features=10, bias=True)
  (lin4): Linear(in_features=10, out_features=4, bias=True)
  (gelu): GELU(approximate='none')
)

In [91]:
def eval_stats(model_inp, x_test, y_test):
    model_inp.eval()
    with torch.no_grad():
        outputs = model_inp(x_test)
        _, predicted = torch.max(outputs.data, 1)

        total = y_test.size(0)
        correct = (predicted == y_test.view(-1)).sum().item()
        print("Accuracy of the network on the test set: %d %%" % (100 * correct / total))


def train(model_inp, num_epochs=num_epochs):
    optimizer = torch.optim.Adam(model_inp.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for inputs, labels in train_iter:
            outputs = model_inp(inputs)
            loss = criterion(outputs, labels.view(-1))
            optimizer.zero_grad()
            loss.backward()
            running_loss += loss.item()
            optimizer.step()
        if epoch % 20 == 0:
            eval_stats(model_inp, x_test, y_test)
            print(
                "Epoch [%d]/[%d] running accumulative loss across all batches: %.3f"
                % (epoch + 1, num_epochs, running_loss)
            )
        running_loss = 0.0

In [30]:
criterion = nn.CrossEntropyLoss()

In [31]:
train(model)

Accuracy of the network on the test set: 30 %
Epoch [1]/[2000] running accumulative loss across all batches: 19.445
Accuracy of the network on the test set: 72 %
Epoch [21]/[2000] running accumulative loss across all batches: 16.205
Accuracy of the network on the test set: 83 %
Epoch [41]/[2000] running accumulative loss across all batches: 7.993
Accuracy of the network on the test set: 84 %
Epoch [61]/[2000] running accumulative loss across all batches: 5.368
Accuracy of the network on the test set: 85 %
Epoch [81]/[2000] running accumulative loss across all batches: 4.386
Accuracy of the network on the test set: 86 %
Epoch [101]/[2000] running accumulative loss across all batches: 3.845
Accuracy of the network on the test set: 88 %
Epoch [121]/[2000] running accumulative loss across all batches: 3.475
Accuracy of the network on the test set: 87 %
Epoch [141]/[2000] running accumulative loss across all batches: 3.194
Accuracy of the network on the test set: 88 %
Epoch [161]/[2000] run

In [None]:
class ModelConcrete(torch.nn.Module):
    def __init__(self, n_mask, input_size=10, n_classes=4):
        super().__init__()
        self.n_mask = n_mask
        self.num_features = input_size
        self.num_classes = n_classes
        self.concrete = nn.Parameter(torch.randn(self.n_mask, self.num_features))
        self.lin1 = nn.Linear(input_size, size_hidden1)
        self.lin2 = nn.Linear(size_hidden1, size_hidden2)
        self.lin3 = nn.Linear(size_hidden2, size_hidden3)
        self.lin4 = nn.Linear(size_hidden3, n_classes)
        self.gelu = nn.GELU()

    def forward(self, x, temp, hard_):
        mask = F.gumbel_softmax(self.concrete, tau=temp, hard=hard_)
        mask = torch.sum(mask, axis=0)
        mask = torch.clamp(mask, min=0, max=1)
        x = mask * x
        x = self.gelu(self.lin1(x))
        x = self.gelu(self.lin2(x))
        x = self.gelu(self.lin3(x))
        logits = self.gelu(self.lin4(x))
        return logits

    def softmax(self):
        return F.softmax(self.concrete, dim=1)


modelConcrete = ModelConcrete(n_mask=3, input_size=x.shape[1], n_classes=np.unique(y).size)
modelConcrete.train()

ModelConcrete(
  (lin1): Linear(in_features=10, out_features=100, bias=True)
  (lin2): Linear(in_features=100, out_features=50, bias=True)
  (lin3): Linear(in_features=50, out_features=10, bias=True)
  (lin4): Linear(in_features=10, out_features=4, bias=True)
  (gelu): GELU(approximate='none')
)

In [93]:
def exp_decay_temp_schedule(epoch, total_epoch):
    start_temp = 10
    end_temp = 0.01
    temp = start_temp * (end_temp / start_temp) ** (epoch / total_epoch)
    return temp

Consider adding LR scheduler

In [94]:
def eval_statsConcrete(model_inp, x_test, y_test):
    model_inp.eval()
    with torch.no_grad():
        outputs = model_inp(x_test, 0.01, True)
        _, predicted = torch.max(outputs.data, 1)

        total = y_test.size(0)
        correct = (predicted == y_test.view(-1)).sum().item()
        print("Accuracy of the network on the test set: %d %%" % (100 * correct / total))


def trainConcrete(model_inp, num_epochs=num_epochs):
    optimizer = torch.optim.Adam(model_inp.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for inputs, labels in train_iter:
            outputs = model_inp(inputs, exp_decay_temp_schedule(epoch, num_epochs), False)
            loss = criterion(outputs, labels.view(-1))
            optimizer.zero_grad()
            loss.backward()
            running_loss += loss.item()
            optimizer.step()
        if epoch % 20 == 0:
            eval_statsConcrete(model_inp, x_test, y_test)
            print(
                "Epoch [%d]/[%d] running accumulative loss across all batches: %.3f"
                % (epoch + 1, num_epochs, running_loss)
            )
        running_loss = 0.0

In [95]:
trainConcrete(modelConcrete)

Accuracy of the network on the test set: 20 %
Epoch [1]/[2000] running accumulative loss across all batches: 19.385
Accuracy of the network on the test set: 42 %
Epoch [21]/[2000] running accumulative loss across all batches: 4.039
Accuracy of the network on the test set: 37 %
Epoch [41]/[2000] running accumulative loss across all batches: 2.945
Accuracy of the network on the test set: 24 %
Epoch [61]/[2000] running accumulative loss across all batches: 2.635
Accuracy of the network on the test set: 46 %
Epoch [81]/[2000] running accumulative loss across all batches: 2.315
Accuracy of the network on the test set: 36 %
Epoch [101]/[2000] running accumulative loss across all batches: 2.104
Accuracy of the network on the test set: 24 %
Epoch [121]/[2000] running accumulative loss across all batches: 2.086
Accuracy of the network on the test set: 38 %
Epoch [141]/[2000] running accumulative loss across all batches: 1.756
Accuracy of the network on the test set: 24 %
Epoch [161]/[2000] runn

In [97]:
mask = F.gumbel_softmax(modelConcrete.concrete, tau=0.01, hard=True)
mask = torch.sum(mask, axis=0)
mask = torch.clamp(mask, min=0, max=1)
mask

tensor([1., 1., 1., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<ClampBackward1>)