# MNIST

In [None]:
import sage
import torch
import pickle
import numpy as np
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy
from torch.utils.data import TensorDataset, DataLoader
import torchvision.datasets as dsets
from sklearn.metrics import log_loss

In [None]:
# Training images: flatten each image to a 784-vector and divide by 255.
train = dsets.MNIST('../data', train=True, download=True)
imgs, labels = train.data.reshape(-1, 784) / 255.0, train.targets

# Apply one random permutation to images and labels together, then hold out
# the first 6000 shuffled examples for validation; the rest is the train set.
inds = torch.randperm(len(train))
imgs, labels = imgs[inds], labels[inds]
val, train = imgs[:6000], imgs[6000:]
Y_val, Y_train = labels[:6000], labels[6000:]

# Test images get the same flatten-and-divide treatment as the train images.
test = dsets.MNIST('../data', train=False, download=True)
Y_test = test.targets
test = test.data.reshape(-1, 784) / 255.0

# NumPy versions of the test data, used when scoring predictions.
test_np = test.cpu().numpy()
Y_test_np = Y_test.cpu().numpy()

In [None]:
def train_model(train, Y_train, val, Y_val, *, device=None, lr=1e-3,
                mbsize=64, max_nepochs=250, lookback=5, verbose=False):
    """Train a small MLP classifier with early stopping on validation loss.

    The network is ``Linear(d, 256) -> ELU -> Linear(256, 256) -> ELU ->
    Linear(256, 10)`` with ``d = train.shape[1]``. It is fit with Adam on
    cross-entropy loss, and the snapshot with the lowest validation loss is
    returned.

    Args:
        train: 2-D tensor of training inputs, one example per row.
        Y_train: Training targets in the form ``nn.CrossEntropyLoss`` accepts
            (class indices in ``[0, 10)``).
        val: 2-D tensor of validation inputs with the same number of columns
            as ``train``.
        Y_val: Validation targets, same form as ``Y_train``.
        device: Device to train on (anything ``torch.device`` accepts).
            ``None`` keeps the previously hard-coded ``torch.device('cuda', 1)``.
        lr: Adam learning rate.
        mbsize: Minibatch size.
        max_nepochs: Maximum number of passes over the training set.
        lookback: Stop once this many epochs have elapsed since the best
            validation loss.
        verbose: Print the validation loss every epoch and a message on
            early stopping.

    Returns:
        The ``nn.Sequential`` model, on ``device``, with the lowest validation
        loss observed. If the validation loss never improves on infinity
        (it is NaN, or ``max_nepochs`` is 0) the model is returned in its
        final state instead of raising ``UnboundLocalError``.
    """
    if device is None:
        device = torch.device('cuda', 1)
    else:
        device = torch.device(device)

    # Create model
    model = nn.Sequential(
        nn.Linear(train.shape[1], 256),
        nn.ELU(),
        nn.Linear(256, 256),
        nn.ELU(),
        nn.Linear(256, 10)).to(device)
    loss_fn = nn.CrossEntropyLoss()

    # Move the full data set to the device once; minibatches drawn from it
    # are then already on the device and need no per-batch transfer.
    train = train.to(device)
    val = val.to(device)
    Y_train = Y_train.to(device)
    Y_val = Y_val.to(device)

    # Data loader
    train_set = TensorDataset(train, Y_train)
    train_loader = DataLoader(train_set, batch_size=mbsize, shuffle=True)

    # Setup
    optimizer = optim.Adam(model.parameters(), lr=lr)
    min_criterion = np.inf
    min_epoch = 0
    # Fallback so the return value is always defined; replaced by a snapshot
    # as soon as a validation loss below infinity is observed.
    best_model = model

    # Train
    for epoch in range(max_nepochs):
        for x, y in train_loader:
            # Take gradient step.
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
            model.zero_grad()

        # Check progress.
        with torch.no_grad():
            # Calculate validation loss.
            val_loss = loss_fn(model(val), Y_val).item()
            if verbose:
                print('{}Epoch = {}{}'.format('-' * 10, epoch + 1, '-' * 10))
                print('Val loss = {:.4f}'.format(val_loss))

            # Check convergence criterion.
            if val_loss < min_criterion:
                min_criterion = val_loss
                min_epoch = epoch
                # Snapshot, because later epochs keep updating `model` in place.
                best_model = deepcopy(model)
            elif (epoch - min_epoch) == lookback:
                if verbose:
                    print('Stopping early')
                break

    # Keep best model
    return best_model


In [None]:
# Load previously saved SAGE results from disk. Later cells read the
# per-feature scores through `sage_values.values`.
# NOTE(review): presumably written by an earlier SAGE run over the 784 MNIST
# pixels -- confirm against the script that produced this file.
sage_values = sage.load('results/mnist_sage_01.pkl')

In [None]:
# Load the pickled "mean importance" scores; used below as the
# 'Mean Importance' baseline.
# NOTE(review): presumably one score per pixel -- confirm against the script
# that wrote this file. Only unpickle files from a trusted source.
with open('results/mnist mean_importance.pkl', 'rb') as f:
    mean_imp = pickle.load(f)

In [None]:
# Collect the 'scores' entry from each of the 512 saved permutation-test
# runs, then average across runs (axis 0) to get a single score vector.
run_scores = []
for i in range(512):
    filename = 'results/mnist permutation_test {}.pkl'.format(i)
    with open(filename, 'rb') as f:
        run = pickle.load(f)
    run_scores.append(run['scores'])
permutation = np.mean(np.array(run_scores), axis=0)

In [None]:
# Load the pickled feature-ablation scores; used below as the
# 'Feature Ablation' baseline.
# NOTE(review): presumably one score per pixel -- confirm against the script
# that wrote this file. Only unpickle files from a trusted source.
with open('results/mnist feature_ablation.pkl', 'rb') as f:
    ablation = pickle.load(f)

In [None]:
# Load the pickled univariate scores; used below as the 'Univariate'
# baseline.
# NOTE(review): presumably one score per pixel -- confirm against the script
# that wrote this file. Only unpickle files from a trusted source.
with open('results/mnist univariate.pkl', 'rb') as f:
    univariate = pickle.load(f)

# Train models

In [None]:
# Display name of each importance method, paired position-by-position with
# that method's per-feature scores.
names = ('SAGE', 'Permutation Test', 'Mean Importance', 'Feature Ablation', 'Univariate')
importance = (sage_values.values, permutation, mean_imp, ablation, univariate)

# One record per method; the experiment cells below add their metrics to it.
mnist_results = {}
for name, imp in zip(names, importance):
    mnist_results[name] = {'values': imp}

In [None]:
# Feature selection experiment: for each importance method, retrain a
# classifier on only its `num` highest-scoring pixels and record test
# log-loss and accuracy for every feature budget in `num_features`.
device = torch.device('cuda', 1)
num_features = list(range(5, 55, 5))

for name in mnist_results:
    values = mnist_results[name]['values']
    # Feature indices sorted from highest to lowest score.
    order = np.argsort(values)[::-1]

    loss_list = []
    acc_list = []
    for num in num_features:
        # Subsample data. A boolean column mask keeps the selected pixels in
        # their original column order, independent of their ranking.
        inds = np.zeros(train.shape[1], dtype=bool)
        inds[order[:num]] = True
        train_small = train[:, inds]
        val_small = val[:, inds]
        test_small = test[:, inds]

        # Train model
        model = train_model(train_small, Y_train, val_small, Y_val)

        # Evaluate on the test set; no gradients are needed for inference.
        with torch.no_grad():
            preds = model(test_small.to(device)).softmax(dim=1).cpu().numpy()
        loss = log_loss(Y_test_np, preds)
        acc = np.mean(np.argmax(preds, axis=1) == Y_test_np)
        loss_list.append(loss)
        acc_list.append(acc)
        print('Done with {} {} (loss = {:.4f}, acc = {:.4f})'.format(name, num, loss, acc))

    mnist_results[name]['selection'] = loss_list
    mnist_results[name]['accuracy'] = acc_list

In [None]:
# Inverse selection experiment: for each importance method, retrain a
# classifier on only its `num` lowest-scoring pixels and record test
# log-loss and accuracy for every feature budget in `num_features`.
device = torch.device('cuda', 1)
num_features = list(range(5, 55, 5))

for name in mnist_results:
    values = mnist_results[name]['values']
    # Feature indices sorted from highest to lowest score, so the tail of
    # `order` holds the lowest-scoring features.
    order = np.argsort(values)[::-1]

    loss_list = []
    acc_list = []
    for num in num_features:
        # Subsample data. A boolean column mask keeps the selected pixels in
        # their original column order, independent of their ranking.
        inds = np.zeros(train.shape[1], dtype=bool)
        inds[order[-num:]] = True
        train_small = train[:, inds]
        val_small = val[:, inds]
        test_small = test[:, inds]

        # Train model
        model = train_model(train_small, Y_train, val_small, Y_val)

        # Evaluate on the test set; no gradients are needed for inference.
        with torch.no_grad():
            preds = model(test_small.to(device)).softmax(dim=1).cpu().numpy()
        loss = log_loss(Y_test_np, preds)
        acc = np.mean(np.argmax(preds, axis=1) == Y_test_np)
        loss_list.append(loss)
        acc_list.append(acc)
        print('Done with {} {} (loss = {:.4f}, acc = {:.4f})'.format(name, num, loss, acc))

    mnist_results[name]['inv_selection'] = loss_list
    mnist_results[name]['inv_accuracy'] = acc_list

# Plot

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Two side-by-side panels: test accuracy when training on each method's most
# important features (left) and on its least important features (right).
fig, axarr = plt.subplots(1, 2, figsize=(16, 5))

names = ('Permutation Test', 'Mean Importance', 'Feature Ablation', 'Univariate', 'SAGE')
colors = ('tab:blue', 'tab:gray', 'tab:green', 'tab:olive', 'tab:pink')

# (results key, panel title, y-axis label) per panel, left to right.
# Only the left panel carries a y-axis label.
panels = (
    ('accuracy', 'MNIST Important Features', 'Accuracy'),
    ('inv_accuracy', 'MNIST Unimportant Features', None),
)

for ax, (key, title, ylabel) in zip(axarr, panels):
    # One dashed line per importance method.
    for name, color in zip(names, colors):
        values = mnist_results[name][key]
        ax.plot(num_features, values, color=color, label=name,
                marker='o', linestyle='--')
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=18)
    ax.set_xlabel('# Features', fontsize=18)
    ax.tick_params(labelsize=16)
    ax.legend(loc='lower right', fontsize=18)
    ax.set_title(title, fontsize=20)

fig.tight_layout()
# plt.show()
fig.savefig('figures/feature_selection.pdf')