In [9]:
# Here we explore feature importance modules provided by captum using toy datasets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch.nn.functional as F
from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split

custom_params = {"axes.spines.right": False, "axes.spines.top": False}
sns.set_theme(style="ticks", font_scale=0.8, rc=custom_params)
%config InlineBackend.figure_format='retina'

x, y = make_classification(
    n_samples=1000,
    n_features=10,
    n_informative=3,
    n_redundant=0,
    n_repeated=0,
    n_classes=4,
    n_clusters_per_class=1,
    weights=None,
    flip_y=0.01,
    class_sep=1.0,
    hypercube=True,
    shift=0.0,
    scale=1.0,
    shuffle=False,
    random_state=0,
)

In [10]:
# x with only informative features
# x = x[:, :3]

In [11]:
feature_names = [f"x{i}" for i in range(1, x.shape[1] + 1)]
df = pd.DataFrame(x, columns=feature_names)
df["y"] = y

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim

In [67]:
batch_size = 50
num_epochs = 2000
learning_rate = 0.01
size_hidden1 = 100
size_hidden2 = 50
size_hidden3 = 10
size_hidden4 = 1

In [68]:
from sklearn.model_selection import StratifiedKFold
from torch_geometric.data import Data

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Keep cells who are participants in the multilayer graph (more than 1 connection)
train_idx, test_idx = next(skf.split(np.arange(x.shape[0]), y))

train_mask = np.zeros(x.shape[0], dtype=bool)
train_mask[train_idx] = True
train_mask = torch.tensor(train_mask, dtype=torch.bool)

test_mask = np.zeros(x.shape[0], dtype=bool)
test_mask[test_idx] = True
test_mask = torch.tensor(test_mask, dtype=torch.bool)

In [69]:
edgelist_self = torch.tensor([[i, i] for i in range(x.shape[0])])
edgelist_self = edgelist_self.T

one_sec_x = torch.tensor(x, dtype=torch.float)
labels = torch.tensor(y, dtype=torch.long)
data_self = Data(x=one_sec_x, edge_index=edgelist_self, y=labels, train_mask=train_mask, test_mask=test_mask)

In [70]:
criterion = nn.CrossEntropyLoss()

In [71]:
def exp_decay_temp_schedule(epoch, total_epoch):
    start_temp = 10
    end_temp = 0.01
    temp = start_temp * (end_temp / start_temp) ** (epoch / total_epoch)
    return temp


def eval_stats_concrete(model_inp, data):
    model_inp.eval()
    with torch.no_grad():
        outputs = model_inp(data.x[data.test_mask], data.edge_index, 0.01, True)
        _, predicted = torch.max(outputs.data, 1)

        total = data.y[data.test_mask].size(0)
        correct = (predicted == data.y[data.test_mask].view(-1)).sum().item()
        print("Accuracy of the network on the test set: %d %%" % (100 * correct / total))


def trainConcrete(model_inp, data, num_epochs=num_epochs):
    optimizer = torch.optim.Adam(model_inp.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        outputs = model_inp(data.x, data.edge_index, exp_decay_temp_schedule(epoch, num_epochs), False)
        loss = criterion(outputs[data.train_mask], data.y[data.train_mask])
        optimizer.zero_grad()
        loss.backward()
        running_loss += loss.item()
        optimizer.step()
        if epoch % 20 == 0:
            eval_stats_concrete(model_inp, data)
            print(
                "Epoch [%d]/[%d] running accumulative loss across all batches: %.3f"
                % (epoch + 1, num_epochs, running_loss)
            )
        running_loss = 0.0

In [72]:
from torch_geometric.nn import GATv2Conv


class ModelConcrete(torch.nn.Module):
    def __init__(self, n_mask, hidden_channels, num_features=10, n_classes=4):
        super().__init__()
        self.n_mask = n_mask
        self.num_classes = n_classes
        self.num_features = num_features
        self.hidden_channels = hidden_channels
        self.concrete = nn.Parameter(torch.randn(self.n_mask, self.num_features))

        self.conv1 = GATv2Conv(self.num_features, self.hidden_channels, heads=8, concat=False)
        self.conv2 = GATv2Conv(self.hidden_channels, self.num_classes, heads=8, concat=False)
        self.lin1 = nn.Linear(num_features, self.num_classes)

        self.dropout = nn.Dropout(0.25)

    def forward(self, x, edge_index, temp, hard_):
        mask = F.gumbel_softmax(self.concrete, tau=temp, hard=hard_)
        mask = torch.sum(mask, axis=0)
        mask = torch.clamp(mask, min=0, max=1)
        x = mask * x

        residual1 = self.lin1(x)

        out = self.conv1(x, edge_index)
        out = out.relu()
        out = self.dropout(out)

        out = self.conv2(out, edge_index)
        out = out + residual1

        return out

    def softmax(self):
        return F.softmax(self.concrete, dim=1)


modelConcrete = ModelConcrete(n_mask=3, hidden_channels=6, num_features=x.shape[1], n_classes=np.unique(y).size)
modelConcrete.train()

ModelConcrete(
  (conv1): GATv2Conv(10, 6, heads=8)
  (conv2): GATv2Conv(6, 4, heads=8)
  (lin1): Linear(in_features=10, out_features=4, bias=True)
  (dropout): Dropout(p=0.25, inplace=False)
)

Consider adding LR scheduler

In [73]:
trainConcrete(modelConcrete, data_self)

Accuracy of the network on the test set: 18 %
Epoch [1]/[2000] running accumulative loss across all batches: 1.409
Accuracy of the network on the test set: 59 %
Epoch [21]/[2000] running accumulative loss across all batches: 1.101
Accuracy of the network on the test set: 47 %
Epoch [41]/[2000] running accumulative loss across all batches: 0.794
Accuracy of the network on the test set: 36 %
Epoch [61]/[2000] running accumulative loss across all batches: 0.586
Accuracy of the network on the test set: 63 %
Epoch [81]/[2000] running accumulative loss across all batches: 0.481
Accuracy of the network on the test set: 40 %
Epoch [101]/[2000] running accumulative loss across all batches: 0.419
Accuracy of the network on the test set: 84 %
Epoch [121]/[2000] running accumulative loss across all batches: 0.364
Accuracy of the network on the test set: 66 %
Epoch [141]/[2000] running accumulative loss across all batches: 0.336
Accuracy of the network on the test set: 83 %
Epoch [161]/[2000] runni

In [74]:
mask = F.gumbel_softmax(modelConcrete.concrete, tau=0.01, hard=True)
mask = torch.sum(mask, axis=0)
mask = torch.clamp(mask, min=0, max=1)
mask

tensor([1., 1., 1., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<ClampBackward1>)