In [2]:
import torch
from torch import nn

from mil.data.mnist import Bag, OneHotMNISTBags
from mil.utils import device
from mil.models import MILModel


  from .autonotebook import tqdm as notebook_tqdm


# One-hot MNIST bags baseline

We can train a model to overfit `OneHotMNISTBags` using max pooling. 
However, this only works for the *mnist-bags* variant, i.e. when there is one target number. 
In the *multi-mnist-bags* variant, this baseline fails (because it cannot "focus" on two target numbers simultaneously).

Try changing `TARGET_NUMBERS` below and rerunning the notebook.

In [3]:
# Digit(s) whose presence makes a bag positive; forwarded to the dataset as `target_numbers`.
# A single int gives the mnist-bags variant; switch to the tuple below for multi-mnist-bags.
TARGET_NUMBERS = 9
# TARGET_NUMBERS = (9, 7)

def make_data_loader(train: bool = True):
    """Build a data loader that yields one `OneHotMNISTBags` bag per step.

    Args:
        train: Forwarded to the dataset's `train` argument and also used as
            the loader's `shuffle` flag.

    Returns:
        A `torch.utils.data.DataLoader` over the generated bags.
    """
    dataset_kwargs = dict(
        target_numbers=TARGET_NUMBERS,   # target digit(s)
        min_instances_per_target=1,      # a single target instance suffices for a positive bag
        num_digits=10,                   # sample from all 10 MNIST digits
        mean_bag_size=10,                # mean bag length
        var_bag_size=2,                  # variance of bag length
        num_bags=250,                    # number of bags
        seed=1,
        train=train,
    )
    bags = OneHotMNISTBags(**dataset_kwargs)

    def take_only_item(batch):
        # batch_size is 1, so hand back the single bag instead of a one-element list.
        return batch[0]

    return torch.utils.data.DataLoader(
        bags,
        batch_size=1,
        shuffle=train,
        collate_fn=take_only_item,
    )

## Define the model

This is a very simple baseline model for `OneHotMNISTBags`. The three parts of the MIL model are:
1. **feature extractor**: no feature extractor, as the features are already the one-hot encoded digits (feature vectors are 10-dimensional)
2. **max pooling**: take the maximum over all instances in the bag for each of the 10 dimensions of the feature vector
3. **classifier**: linear layer from 10 to one unit followed by sigmoid


In [4]:
class OneHotFeatureExtractor(nn.Module):
    """Pass-through feature extractor that returns a bag's instances unchanged."""

    def forward(self, bag: Bag):
        # One-hot bags need no embedding step: the instances themselves serve as features.
        features = bag.instances
        return features

class SimplePooler(nn.Module):
    """Pool features along one dimension using a named `torch` reduction.

    Args:
        pool: Name of a function in the `torch` namespace that can be called as
            `fn(features, dim=...)`, e.g. "mean", "sum", "max", "min" or "median".
        dim: Dimension to reduce over.
    """
    def __init__(self, pool: str = "mean", dim: int = 0):
        super().__init__()
        self.pool = pool
        self.dim = dim

    def forward(self, bag: Bag, features):
        """Reduce `features` along `self.dim` and return the pooled tensor.

        `bag` is accepted but not used by this pooler.
        """
        pool = getattr(torch, self.pool)
        result = pool(features, dim=self.dim)
        # Reductions such as max/min/median return a (values, indices) named
        # tuple when `dim` is given; keep only the pooled values so every
        # supported reduction yields a plain tensor.
        if not torch.is_tensor(result):
            result = result.values
        return result

class Classifier(nn.Sequential):
    """Bag-level head: a single linear unit followed by a sigmoid.

    Args:
        hidden_dim: Size of the feature vector fed to the linear layer.
    """

    def __init__(self, hidden_dim: int):
        to_logit = nn.Linear(in_features=hidden_dim, out_features=1)
        to_probability = nn.Sigmoid()
        super().__init__(to_logit, to_probability)

# Define model
# Assemble the three MIL stages: a pass-through feature extractor, max pooling
# (over dim 0, the SimplePooler default), and a linear + sigmoid classifier on
# the 10-dimensional pooled vector.
model = MILModel(feature_extractor=OneHotFeatureExtractor(),
                 pooler=SimplePooler("max"),
                 classifier=Classifier(hidden_dim=10))

## Define loss function and optimizer

We use binary cross-entropy loss.

In [5]:
# Binary cross-entropy on probabilities (the classifier already ends in a sigmoid).
loss_function = nn.BCELoss()
# Adam over all model parameters; no hyperparameters are overridden, so torch's defaults apply.
optimizer = torch.optim.Adam(model.parameters())

## Train

In [6]:
NUM_EPOCHS = 20  # number of full passes over the training bags

loader = make_data_loader(train=True)

model.train()
for epoch in range(NUM_EPOCHS):
    # Running totals are plain Python floats so no tensors are kept across steps.
    total_loss = 0.
    total_error = 0.
    for bag in loader:
        bag = device(bag)
        # BCELoss requires a floating-point target.
        y = bag.bag_label.float()

        optimizer.zero_grad()

        # Forward pass and loss
        y_pred = model(bag).squeeze(0)
        loss = loss_function(y_pred, y)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # Metrics: 0/1 classification error at a 0.5 threshold, and the scalar loss.
        # `y_pred` was computed before the step, so the update above does not affect them.
        error = 1. - ((y_pred > .5).float() == y).float().item()
        total_error += error
        total_loss += loss.item()

    print(
        f"Epoch: {epoch}, loss: {total_loss/len(loader):.4f}, error: {total_error/len(loader):.4f}")

Epoch: 0, loss: 0.6970, error: 0.4880
Epoch: 1, loss: 0.6552, error: 0.3960
Epoch: 2, loss: 0.6254, error: 0.3000
Epoch: 3, loss: 0.6000, error: 0.2040
Epoch: 4, loss: 0.5755, error: 0.1520
Epoch: 5, loss: 0.5548, error: 0.1160
Epoch: 6, loss: 0.5343, error: 0.0880
Epoch: 7, loss: 0.5154, error: 0.0840
Epoch: 8, loss: 0.4979, error: 0.0720
Epoch: 9, loss: 0.4812, error: 0.0640
Epoch: 10, loss: 0.4655, error: 0.0520
Epoch: 11, loss: 0.4508, error: 0.0400
Epoch: 12, loss: 0.4372, error: 0.0360
Epoch: 13, loss: 0.4239, error: 0.0320
Epoch: 14, loss: 0.4116, error: 0.0320
Epoch: 15, loss: 0.3993, error: 0.0280
Epoch: 16, loss: 0.3886, error: 0.0280
Epoch: 17, loss: 0.3778, error: 0.0280
Epoch: 18, loss: 0.3676, error: 0.0240
Epoch: 19, loss: 0.3581, error: 0.0240


## Test

In [7]:
NUM_BAGS_TO_PRINT = 10  # how many individual bag predictions to show

loader = make_data_loader(train=False)
model.eval()

# Running totals are plain Python floats.
total_loss = 0.
total_error = 0.
with torch.no_grad():  # evaluation only: no gradients are needed
    for i, bag in enumerate(loader):
        bag = device(bag)
        y = bag.bag_label.float()

        # Calculate loss and metrics
        y_pred = model(bag).squeeze(0)
        y_hat = (y_pred > .5).float()  # hard 0/1 prediction at a 0.5 threshold
        loss = loss_function(y_pred, y)

        error = 1. - (y_hat == y).float().item()
        total_error += error
        total_loss += loss.item()

        if i < NUM_BAGS_TO_PRINT:  # Print true vs. predicted bag label for the first few bags
            print(f"#{i}: true, predicted bag label: {y.item():.0f}, {y_hat.item():.0f}")

print(
    f"Test loss: {total_loss/len(loader):.4f}, error: {total_error/len(loader):.4f}")


#0: true, predicted bag label: 0, 0
#1: true, predicted bag label: 0, 0
#2: true, predicted bag label: 0, 0
#3: true, predicted bag label: 1, 1
#4: true, predicted bag label: 0, 0
#5: true, predicted bag label: 1, 1
#6: true, predicted bag label: 0, 0
#7: true, predicted bag label: 0, 0
#8: true, predicted bag label: 0, 0
#9: true, predicted bag label: 1, 1
Test loss: 0.3524, error: 0.0240
