In [1]:
import torch
from torch import nn
import matplotlib.pyplot as plt
import itertools
import pandas as pd
import torch_geometric as pyg
import numpy as np
import pandas as pd
import torch_geometric as pyg
import hydra
from hydra.core.global_hydra import GlobalHydra
import os
os.environ["HYDRA_FULL_ERROR"] = "1"


from mil.data.mnist import MNISTBags, OneHotMNISTBags, MNISTCollage, OneHotMNISTCollage
from mil.utils import device, human_format, set_seed
from mil.utils.visualize import print_one_hot_bag_with_attention, print_one_hot_bag, plot_attention_head, plot_bag, plot_one_hot_collage
from mil.utils.stats import print_prediction_stats
from mil.models.abmil import WeightedAverageAttention
from mil.models.self_attention import MultiHeadSelfAttention
from mil.models.distance_aware_self_attention import DistanceAwareSelfAttentionHead

RESULTS_FILE = "train.csv"

GlobalHydra().clear()
hydra.initialize(config_path="conf")
cfg = hydra.compose("config.yaml", overrides=["+experiment=mnist_bags", "+model=self_attention"])

set_seed(cfg.seed)

  from .autonotebook import tqdm as notebook_tqdm
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  hydra.initialize(config_path="conf")


# MNIST bags / MNIST collage

This notebook trains models on variations of the *mnist-bags*, *multi-mnist-bags*, and *mnist-collage* datasets. 
The goal of this notebook is to see which models are able to overfit on these datasets.


The following cell defines three variables, `DATASET`, `TARGET_NUMBERS` and `MODEL` which can be used to run different dataset/model configurations.

`DATASET`:
- `OneHotMNISTBags`: one-hot version of *mnist-bags*, where the dataset yields 10-dimensional one-hot encoded feature vectors directly (i.e. we are not yet working with MNIST digits)
- `MNISTBags`: the *mnist-bags* dataset
- `OneHotMNISTCollage`: one-hot version of *mnist-collage*
- `MNISTCollage`: *mnist-collage* dataset

`TARGET_NUMBERS`:
- `0` corresponds to the *mnist-bags* dataset
- `(0, 1)` corresponds to the *multi-mnist-bags* dataset

`MODELS`:
- `"mean_pool"`: simple baseline that uses mean pooling. Works neither dataset.
- `"max_pool"`: simple baseline that uses max pooling. Works for *mnist-bags*, but not *multi-mnist-bags*.
- `"weighted_average_attention"`: uses attention mechanism from "Attention Based Deep Multiple Instance Learning" paper which can only "focus" on one target number at a time. Works for *mnist-bags*, but not *multi-mnist-bags*.
- `"self_attention_mean_pooling"`: uses a single transformer layer (self attention) followed by mean pooling. Works for both datasets.
- `"self_attention_max_pooling"`: uses a single transformer layer (self attention) followed by max pooling. Works for both datasets (but better than mean pooling).


Try changing `DATASET`, `TARGET_NUMBERS` and `MODELS` below and rerunning the notebook.

In [2]:
# Instantiate datasets and loaders
train_dataset = hydra.utils.instantiate(cfg.dataset.train)
test_dataset = hydra.utils.instantiate(cfg.dataset.test)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=True, collate_fn=lambda x: x[0], num_workers=0, pin_memory=False)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x[0], num_workers=0, pin_memory=False)

## Define the model

The three parts of the MIL model are:
1. **feature extractor**: extract a feature vector $z \in \mathbb{R}^D$ from each instance. In the case of the one-hot dataset, this is just the identity function. For the actual *mnist-bags* dataset, this is a CNN.
2. **pooling**: a function $f : \mathbb{R}^{N \times D} \to \mathbb{R}^D$ that aggregates the $N$ feature vectors in the bag to a single feature vector.
3. **classifier**: a function $g : \mathbb{R}^D \to \mathbb{R}$ that transforms the aggregated feature vector into a binary classification prediction (we parameterise $g$ using a linear layer followed by a sigmoid)


In [10]:
import typing
from mil.utils import identity
from torch import nn
import torch
from torch_geometric.data import Data

class ABMIL(nn.Module):
    """
    Implementation of the attention layer from the paper: "Attention-Based Deep Multiple Instance Learning", https://arxiv.org/pdf/1802.04712.pdf.

    The attention layer is a weighted average of the features, where the weights are calculated by a neural network.
    """

    def __init__(self, feature_size: int, hidden_dim: int):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(feature_size, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Softmax(dim=-2)
        )

    def forward(self, features):
        H = features  # BxNxL

        # Attention weights
        A = self.attention(H)  # BxNx1

        # Context vector (weighted average of the features)
        M = torch.sum(A * H, dim=-2)  # BxL

        self.A = A
        return M

class MILModel(nn.Module):
    """Structure of a multiple instance learning model.

    The model consists of three parts:
    1. A feature extractor that takes a bag and returns a Z-dimensional feature vector for each instance. A bag with N instances will thus be represented by a N x Z matrix.
    2. A pooling function that takes the feature matrix and returns a single feature vector for the bag. (N x Z) -> (1 x Z)
    3. A classifier that takes a bag and returns a single scalar value. (1 x Z) -> (1 x 1)
    """

    def __init__(self,
                 feature_extractor: nn.Module,
                 pooler: nn.Module,
                 classifier: nn.Module,
                 logit_to_prob: typing.Callable[[torch.Tensor], torch.Tensor] = identity):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.pooler = pooler
        self.classifier = classifier
        self.logit_to_prob = logit_to_prob

    def forward(self, bag: Data):
        features = self.feature_extractor(bag.x)
        pooled = self.pooler(features, bag.edge_index, bag.edge_attr, bag.pos)
        logit = self.classifier(pooled).squeeze(-1)
        prob = self.logit_to_prob(logit)
        return prob, logit

class DefaultClassifier(nn.Module):
    def __init__(self, input_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)
    
    def forward(self, x):
        print(x.shape)

class AdditiveClassifier(nn.Module):
    def __init__(self, input_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)
    
    def forward(self, x):
        return self.linear(x).squeeze(-1).sum(-1)

model = MILModel(
    feature_extractor=hydra.utils.instantiate(cfg.model.feature_extractor),
    pooler=pyg.nn.Sequential("x, edge_index, edge_attr", [
        (
            ABMIL(feature_size=cfg.settings.feature_size, hidden_dim=cfg.settings.hidden_dim),
            "x -> x"
        )
    ]),
    classifier=DefaultClassifier(input_dim=cfg.settings.hidden_dim),
)

# model = hydra.utils.instantiate(cfg.model, _convert_="partial")


## Define loss function and optimizer

We use binary cross-entropy loss.

In [11]:
loss_function = nn.BCELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

def error_score(y_pred, y):
    return 1. - ((y_pred > .5).float() == y).cpu().detach().float()

Helper code to evaluate on test set:

In [12]:
def test_loss_and_error(model, loader):
    model.eval()

    total_loss = 0.
    total_error = 0.
    predictions = []

    with torch.no_grad():
        for i, bag in enumerate(loader):
            bag = device(bag)
            y = bag.y.float()

            # Calculate loss and metrics
            y_pred = model(bag).squeeze()
            loss = loss_function(y_pred, y)

            predictions.append((bag.cpu().detach(), y_pred.detach().cpu()))

            error = error_score(y_pred, y)
            total_error += error
            total_loss += loss.detach().cpu()
    return total_loss / len(loader), total_error / len(loader), predictions

## Train

In [13]:
stats = []

model.train()
print(f"Training model with {human_format(sum(p.numel() for p in model.parameters() if p.requires_grad))} parameters")

for epoch in range(50):
    model.train()

    total_loss = 0.
    total_error = 0.
    for bag in train_loader:
        bag = device(bag)
        y = bag.y

        optimizer.zero_grad()

        # Calculate loss and metrics
        y_pred = model(bag.x, bag.edge_index, bag.edge_attr).squeeze()
        loss = loss_function(y_pred, y)

        error = error_score(y_pred, y)
        total_error += error
        
        # Backward pass
        loss.backward()

        total_loss += loss.detach().cpu()
        # Step
        optimizer.step()
    
    test_loss, test_error, _ = test_loss_and_error(model, test_loader)

    stats.append({
        "epoch": epoch,
        "loss": total_loss / len(train_loader),
        "error": total_error / len(train_loader),
        "test_loss": test_loss,
        "test_error": test_error
    })
    print(
        f"Epoch: {epoch:3d}, loss: {total_loss/len(train_loader):.4f}, error: {total_error/len(train_loader):.4f}, test_loss: {test_loss:.4f}, test_error: {test_error:.4f}")

# Plot training and test loss/error
stats = pd.DataFrame(stats)
float_cols = [col for col in stats.columns if col != "epoch"]
stats[float_cols] = stats[float_cols].astype(float)
stats.to_csv(RESULTS_FILE, index=False)

plt.figure(figsize=(12, 4))
plt.subplot(121)
plt.title("Loss")
plt.plot(stats["epoch"], stats["loss"], label="train")
plt.plot(stats["epoch"], stats["test_loss"], label="test")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

plt.subplot(122)
plt.title("Error")
plt.plot(stats["epoch"], stats["error"], label="train")
plt.plot(stats["epoch"], stats["test_error"], label="test")
plt.xlabel("Epoch")
plt.ylabel("Error")
plt.legend()

Training model with 15.9K parameters


TypeError: forward() takes 2 positional arguments but 4 were given

## Test

In [None]:
test_loss, test_error, predictions = test_loss_and_error(model, test_loader)
print(f"Test loss: {test_loss:.4f}, test error: {test_error:.4f}")

print_prediction_stats(predictions, target_numbers=cfg.settings.mnist.target_numbers)

### First 10 bags in test dataset

In [None]:
def plot_dist_aware_attention(bag):
    plt.figure(figsize=(12, 4))
    plt.suptitle(f"Bag label: {y.item():.0f}, pred: {y_pred.item():.2f}")
    plt.subplot(141)
    plt.title("dist")
    data = attention_layer.data
    dist = pyg.utils.to_dense_adj(data.edge_index, edge_attr=data.edge_attr.squeeze(-1), max_num_nodes=data.num_nodes).squeeze(0)  # NxN
    plot_attention_head(bag, dist, limit_range=False)
    plt.subplot(142)
    plt.title("A0")
    plot_attention_head(bag, attention_layer.A0)
    plt.subplot(143)
    plt.title("A")
    plot_attention_head(bag, attention_layer.A)

def visualize_prediction(bag, y_pred):
    y = bag.y
    if DATASET == OneHotMNISTBags:
        if isinstance(attention_layer, WeightedAverageAttention):
            print_one_hot_bag_with_attention(bag, attention_layer.A, y_pred>.5)
            print()
        elif isinstance(attention_layer, MultiHeadAttention):
            plt.figure()
            plot_attention_head(bag, attention_layer.A[0])
            plt.title(f"Bag label: {y.item():.0f}, pred: {y_pred.item():.2f}")
        else:
            print_one_hot_bag(bag, y_pred>.5)
    elif DATASET == MNISTBags:
        if isinstance(attention_layer, WeightedAverageAttention):
            plot_bag(bag, y_pred=y_pred, attention=attention_layer.A.squeeze(-1))
        elif isinstance(attention_layer, MultiHeadAttention):
            plot_bag(bag, y_pred=y_pred)
            plt.figure()
            plot_attention_head(bag, attention_layer.A[0])
            plt.title(f"Bag label: {y.item():.0f}, pred: {y_pred.item():.2f}")
        else:
            plot_bag(bag, y_pred=y_pred)
    elif DATASET == OneHotMNISTCollage:
        plt.figure()
        plot_one_hot_collage(bag, y_pred=y_pred)
        plt.title(f"Bag label: {y.item():.0f}, pred: {y_pred.item():.2f}")
        if isinstance(attention_layer, DistanceAwareSelfAttentionHead):
            plot_dist_aware_attention(bag)
    elif DATASET == MNISTCollage:
        plot_bag(bag, y_pred=y_pred, collage_size=COLLAGE_SIZE)
        if isinstance(attention_layer, DistanceAwareSelfAttentionHead):
            plot_dist_aware_attention(bag)

# Visualize first 10 bags
model.eval()
with torch.no_grad():
    for bag in itertools.islice(test_loader, 10):
        bag = device(bag)
        y = bag.y.float()
        y_pred = model(bag).squeeze(0)
        visualize_prediction(bag, y_pred)

### First 10 mistakes in test dataset

In [None]:
# Visualize first 10 mistakes
model.eval()
with torch.no_grad():
    i = 0
    for bag in test_loader:
        if i == 10:
            break
        bag = device(bag)
        y = bag.y.float()
        y_pred = model(bag).squeeze(0)
        if ((y_pred > .5).float() != y).cpu().detach():
            visualize_prediction(bag, y_pred)
            i += 1