<a href="https://colab.research.google.com/github/mgwoozdz/Neural-Trees-for-Multiple-Instance-Learning/blob/main/Neural_Trees_for_Multiple_Instance_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
# dataloader.py

"""Pytorch dataset object that loads MNIST dataset as bags."""

import numpy as np
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms


class MnistBags(data_utils.Dataset):
    """Multiple-instance-learning dataset of randomly sampled MNIST bags.

    Each bag is a set of MNIST images drawn uniformly *with replacement* from
    the train or the test split. A bag is positive iff it contains at least
    one image of ``target_number``.

    Bag lengths are drawn from ``RandomState.normal(mean_bag_length,
    var_bag_length)`` -- the second value is used as the standard deviation
    despite its name -- truncated towards zero and clipped to at least 1.

    Items are ``(bag, [bag_label, instance_labels])`` where ``bag`` holds the
    selected images (one row per instance), ``instance_labels`` is a bool
    tensor marking the ``target_number`` instances and ``bag_label`` is their
    maximum (True iff any instance is positive).
    """

    def __init__(self, target_number=9, mean_bag_length=10, var_bag_length=2, num_bag=250, seed=1, train=True):
        self.target_number = target_number
        self.mean_bag_length = mean_bag_length
        self.var_bag_length = var_bag_length
        self.num_bag = num_bag
        self.train = train

        # Private RNG so bag generation is reproducible for a given seed.
        self.r = np.random.RandomState(seed)

        self.num_in_train = 60000
        self.num_in_test = 10000

        if self.train:
            self.train_bags_list, self.train_labels_list = self._create_bags()
        else:
            self.test_bags_list, self.test_labels_list = self._create_bags()

    def _load_mnist(self):
        """Return ``(images, labels)`` of the selected MNIST split as one batch."""
        num_in = self.num_in_train if self.train else self.num_in_test
        loader = data_utils.DataLoader(datasets.MNIST('../datasets',
                                                      train=bool(self.train),
                                                      download=True,
                                                      transform=transforms.Compose([
                                                          transforms.ToTensor(),
                                                          transforms.Normalize((0.1307,), (0.3081,))])),
                                       batch_size=num_in,
                                       shuffle=False)
        # batch_size equals the split size, so the first batch is the whole split.
        all_imgs, all_labels = next(iter(loader))
        return all_imgs, all_labels

    def _create_bags(self):
        """Load the selected split and build ``num_bag`` random bags from it."""
        all_imgs, all_labels = self._load_mnist()
        num_in = self.num_in_train if self.train else self.num_in_test
        return self._sample_bags(all_imgs, all_labels, num_in)

    def _sample_bags(self, all_imgs, all_labels, num_in):
        """Draw ``num_bag`` bags from the first ``num_in`` images.

        Returns ``(bags_list, labels_list)`` where ``labels_list[i]`` is the
        bool instance-label tensor of ``bags_list[i]``.
        """
        bags_list = []
        labels_list = []
        for _ in range(self.num_bag):
            # size=1 then [0] keeps the RNG stream identical to earlier versions;
            # np.int no longer exists in NumPy >= 1.24, so use the builtin int.
            bag_length = max(1, int(self.r.normal(self.mean_bag_length, self.var_bag_length, 1)[0]))
            indices = torch.LongTensor(self.r.randint(0, num_in, bag_length))

            bags_list.append(all_imgs[indices])
            labels_list.append(all_labels[indices] == self.target_number)

        return bags_list, labels_list

    def _active_lists(self):
        """Return ``(bags, instance_labels)`` lists for the split chosen at construction."""
        if self.train:
            return self.train_bags_list, self.train_labels_list
        return self.test_bags_list, self.test_labels_list

    def __len__(self):
        return len(self._active_lists()[1])

    def __getitem__(self, index):
        bags, labels = self._active_lists()
        instance_labels = labels[index]
        return bags[index], [max(instance_labels), instance_labels]


# if __name__ == "__main__":

#     train_loader = data_utils.DataLoader(MnistBags(target_number=9,
#                                                    mean_bag_length=10,
#                                                    var_bag_length=2,
#                                                    num_bag=100,
#                                                    seed=1,
#                                                    train=True),
#                                          batch_size=1,
#                                          shuffle=True)

#     test_loader = data_utils.DataLoader(MnistBags(target_number=9,
#                                                   mean_bag_length=10,
#                                                   var_bag_length=2,
#                                                   num_bag=100,
#                                                   seed=1,
#                                                   train=False),
#                                         batch_size=1,
#                                         shuffle=False)

#     len_bag_list_train = []
#     mnist_bags_train = 0
#     for batch_idx, (bag, label) in enumerate(train_loader):
#         len_bag_list_train.append(int(bag.squeeze(0).size()[0]))
#         mnist_bags_train += label[0].numpy()[0]
#     print('Number positive train bags: {}/{}\n'
#           'Number of instances per bag, mean: {}, max: {}, min {}\n'.format(
#         mnist_bags_train, len(train_loader),
#         np.mean(len_bag_list_train), np.max(len_bag_list_train), np.min(len_bag_list_train)))

#     len_bag_list_test = []
#     mnist_bags_test = 0
#     for batch_idx, (bag, label) in enumerate(test_loader):
#         len_bag_list_test.append(int(bag.squeeze(0).size()[0]))
#         mnist_bags_test += label[0].numpy()[0]
#     print('Number positive test bags: {}/{}\n'
#           'Number of instances per bag, mean: {}, max: {}, min {}\n'.format(
#         mnist_bags_test, len(test_loader),
#         np.mean(len_bag_list_test), np.max(len_bag_list_test), np.min(len_bag_list_test)))


In [2]:
# mnist_bags_loader.py

"""Pytorch Dataset object that loads perfectly balanced MNIST dataset in bag form."""

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms


class MnistBags(data_utils.Dataset):
    """Perfectly balanced MNIST multiple-instance-learning dataset.

    Bags alternate strictly between positive and negative, starting with a
    positive one:

    * a positive bag is a uniform random draw (with replacement) of
      ``bag_length`` images that contains ``target_number``; draws without it
      are rejected and a fresh length and draw are sampled;
    * a negative bag is built one instance at a time, rejecting every image of
      ``target_number``.

    Bag lengths come from ``RandomState.normal(mean_bag_length,
    var_bag_length)`` (the second value acts as a standard deviation),
    truncated towards zero and clipped to at least 1.

    Items are ``(bag, [bag_label, instance_labels])`` where ``bag`` holds the
    selected images (one row per instance), ``instance_labels`` is a 1-D bool
    tensor marking ``target_number`` instances, and ``bag_label`` is its max.
    """

    def __init__(self, target_number=9, mean_bag_length=10, var_bag_length=1, num_bag=1000, seed=7, train=True):
        self.target_number = target_number
        self.mean_bag_length = mean_bag_length
        self.var_bag_length = var_bag_length
        self.num_bag = num_bag
        self.seed = seed
        self.train = train

        # Private RNG so bag generation is reproducible for a given seed.
        self.r = np.random.RandomState(seed)

        self.num_in_train = 60000
        self.num_in_test = 10000

        if self.train:
            self.train_bags_list, self.train_labels_list = self._form_bags()
        else:
            self.test_bags_list, self.test_labels_list = self._form_bags()

    def _form_bags(self):
        """Load the selected MNIST split and build ``num_bag`` balanced bags from it."""
        num_in = self.num_in_train if self.train else self.num_in_test
        loader = data_utils.DataLoader(datasets.MNIST('../datasets',
                                                      train=bool(self.train),
                                                      download=True,
                                                      transform=transforms.Compose([
                                                          transforms.ToTensor(),
                                                          transforms.Normalize((0.1307,), (0.3081,))])),
                                       batch_size=num_in,
                                       shuffle=False)
        # batch_size equals the split size, so the first batch is the whole split.
        numbers, labels = next(iter(loader))
        return self._sample_balanced_bags(numbers, labels, num_in)

    def _sample_bag_length(self):
        """Draw one bag length: a truncated normal sample clipped to >= 1."""
        # size=1 then [0] keeps the RNG stream identical to earlier versions;
        # np.int no longer exists in NumPy >= 1.24, so use the builtin int.
        return max(1, int(self.r.normal(self.mean_bag_length, self.var_bag_length, 1)[0]))

    def _draw_negative_indices(self, labels, num_in, bag_length):
        """Draw ``bag_length`` indices below ``num_in`` one at a time, skipping ``target_number``."""
        picked = []
        while len(picked) < bag_length:
            index = int(self.r.randint(0, num_in, 1)[0])
            if labels[index].item() != self.target_number:
                picked.append(index)
        # A flat LongTensor keeps negative bags shaped exactly like positive ones
        # (an np.array of 1-element tensors would add a spurious extra axis).
        return torch.LongTensor(picked)

    def _sample_balanced_bags(self, numbers, labels, num_in):
        """Return ``(bags_list, labels_list)`` alternating positive and negative bags.

        Raises:
            ValueError: if the first ``num_in`` labels make a required bag type
                impossible (which would otherwise loop forever).
        """
        is_target = labels[:num_in] == self.target_number
        if self.num_bag > 0 and not bool(is_target.any()):
            raise ValueError('target_number {} never occurs; cannot form a positive bag'.format(self.target_number))
        if self.num_bag > 1 and bool(is_target.all()):
            raise ValueError('only target_number {} occurs; cannot form a negative bag'.format(self.target_number))

        bags_list = []
        labels_list = []
        label_of_last_bag = 0
        while len(bags_list) < self.num_bag:
            bag_length = self._sample_bag_length()
            # Always drawn, even when a negative bag is due, so the RNG stream
            # matches the original sampler.
            indices = torch.LongTensor(self.r.randint(0, num_in, bag_length))

            if label_of_last_bag == 0:
                if self.target_number not in labels[indices]:
                    continue  # rejection: a positive bag must contain the target
                label_of_last_bag = 1
            else:
                indices = self._draw_negative_indices(labels, num_in, bag_length)
                label_of_last_bag = 0

            bags_list.append(numbers[indices])
            # '==' (not '>='): digits larger than target_number are negatives.
            labels_list.append(labels[indices] == self.target_number)

        return bags_list, labels_list

    def _active_lists(self):
        """Return ``(bags, instance_labels)`` lists for the split chosen at construction."""
        if self.train:
            return self.train_bags_list, self.train_labels_list
        return self.test_bags_list, self.test_labels_list

    def __len__(self):
        return len(self._active_lists()[1])

    def __getitem__(self, index):
        bags, labels = self._active_lists()
        instance_labels = labels[index]
        return bags[index], [max(instance_labels), instance_labels]


# if __name__ == "__main__":
#     to_pil = transforms.Compose([transforms.ToPILImage()])

#     kwargs = {}
#     batch_size = 1

#     train_loader = data_utils.DataLoader(MnistBags(target_number=9,
#                                                    mean_bag_length=10,
#                                                    var_bag_length=2,
#                                                    num_bag=100,
#                                                    seed=98,
#                                                    train=True),
#                                          batch_size=batch_size,
#                                          shuffle=False, **kwargs)

#     test_loader = data_utils.DataLoader(MnistBags(target_number=9,
#                                                   mean_bag_length=10,
#                                                   var_bag_length=2,
#                                                   num_bag=10,
#                                                   seed=98,
#                                                   train=False),
#                                         batch_size=batch_size,
#                                         shuffle=False, **kwargs)

#     len_bag_list = []
#     mnist_bags_train = 0
#     for batch_idx, data in enumerate(train_loader):
#         plot_data = data[0].squeeze(0)
#         len_bag_list.append(int(plot_data.size()[0]))
#         # plot_data = data[0].squeeze(0)
#         # num_instances = int(plot_data.size()[0])
#         # print(data[1][0])
#         # for i in range(num_instances):
#         #     plt.subplot(num_instances, 1, i + 1)
#         #     to_pil(plot_data[i, :, :, :]).show()
#         # plt.show()
#         if data[1][0][0] == 1:
#             mnist_bags_train += 1
#     print('number of bags with 9(s): ', mnist_bags_train)
#     print('total number of bags', len(train_loader))
#     print(np.mean(len_bag_list), np.min(len_bag_list), np.max(len_bag_list))

#     len_bag_list = []
#     mnist_bags_test = 0
#     for batch_idx, data in enumerate(test_loader):
#         plot_data = data[0].squeeze(0)
#         len_bag_list.append(int(plot_data.size()[0]))
#         if data[1][0][0] == 1:
#             mnist_bags_test += 1
#     print('number of bags with 9(s): ', mnist_bags_test)
#     print('total number of bags', len(test_loader))
#     print(np.mean(len_bag_list), np.min(len_bag_list), np.max(len_bag_list))

In [3]:
# model.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """Attention-based MIL pooling network for bags of single-channel images.

    A small CNN embeds every instance into an ``L``-dimensional vector. A
    two-layer tanh MLP scores each instance. The softmax-normalised scores
    weight-average the embeddings into one bag embedding, and a sigmoid
    classifier maps it to the bag probability.

    ``forward`` expects one bag with a leading batch dimension of 1, which is
    squeezed away. The hard-coded ``50 * 4 * 4`` flatten size corresponds to
    28x28 inputs (28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8 -> pool -> 4).
    """

    def __init__(self):
        super(Attention, self).__init__()
        self.L = 500  # instance embedding size
        self.D = 128  # hidden size of the attention MLP
        self.K = 1    # number of attention heads

        self.feature_extractor_part1 = nn.Sequential(
            nn.Conv2d(1, 20, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(20, 50, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2)
        )

        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(50 * 4 * 4, self.L),
            nn.ReLU(),
        )

        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )

        self.classifier = nn.Sequential(
            nn.Linear(self.L*self.K, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``(Y_prob, Y_hat, A)``: bag probability (1x1), 0/1 prediction, attention (KxN)."""
        x = x.squeeze(0)

        H = self.feature_extractor_part1(x)
        H = H.view(-1, 50 * 4 * 4)
        H = self.feature_extractor_part2(H)  # NxL

        A = self.attention(H)  # NxK
        A = torch.transpose(A, 1, 0)  # KxN
        A = F.softmax(A, dim=1)  # softmax over N

        M = torch.mm(A, H)  # KxL

        Y_prob = self.classifier(M)
        Y_hat = torch.ge(Y_prob, 0.5).float()

        return Y_prob, Y_hat, A

    # AUXILIARY METHODS
    def calculate_classification_error(self, X, Y):
        """Return ``(error, Y_hat)`` where ``error`` is a Python float in [0, 1].

        Runs its own forward pass. ``.item()`` makes the error a plain float,
        consistent with ``GatedAttention``.
        """
        Y = Y.float()
        _, Y_hat, _ = self.forward(X)
        error = 1. - Y_hat.eq(Y).cpu().float().mean().item()

        return error, Y_hat

    def calculate_objective(self, X, Y):
        """Return ``(neg_log_likelihood, A)`` for bag label ``Y``.

        The probability is clamped to [1e-5, 1 - 1e-5] so the logs stay finite.
        """
        Y = Y.float()
        Y_prob, _, A = self.forward(X)
        Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5)
        neg_log_likelihood = -1. * (Y * torch.log(Y_prob) + (1. - Y) * torch.log(1. - Y_prob))  # negative log bernoulli

        return neg_log_likelihood, A

class GatedAttention(nn.Module):
    """MIL pooling network with gated attention over instance embeddings.

    Instances are embedded by a small CNN. Each embedding then passes through
    a tanh branch and a sigmoid branch. Their element-wise product is scored
    by a linear layer, softmax-normalised across the bag, and used to average
    the embeddings into a bag vector. A sigmoid classifier maps that vector to
    the bag probability.

    ``forward`` takes one bag with a leading batch dimension of 1; the
    ``50 * 4 * 4`` flatten size corresponds to 28x28 single-channel inputs.
    """

    def __init__(self):
        super().__init__()
        self.L, self.D, self.K = 500, 128, 1  # embedding, attention-hidden, heads

        # Sub-modules are registered in the same order as before so parameter
        # names and initialisation order are unchanged.
        conv_layers = [
            nn.Conv2d(1, 20, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2, stride=2),
            nn.Conv2d(20, 50, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2, stride=2),
        ]
        self.feature_extractor_part1 = nn.Sequential(*conv_layers)
        self.feature_extractor_part2 = nn.Sequential(nn.Linear(50 * 4 * 4, self.L), nn.ReLU())

        self.attention_V = nn.Sequential(nn.Linear(self.L, self.D), nn.Tanh())
        self.attention_U = nn.Sequential(nn.Linear(self.L, self.D), nn.Sigmoid())
        self.attention_weights = nn.Linear(self.D, self.K)

        self.classifier = nn.Sequential(nn.Linear(self.L * self.K, 1), nn.Sigmoid())

    def forward(self, x):
        """Return ``(bag probability, 0/1 prediction, attention weights K x N)``."""
        bag = x.squeeze(0)

        feats = self.feature_extractor_part1(bag).view(-1, 50 * 4 * 4)
        feats = self.feature_extractor_part2(feats)                   # N x L

        gated = self.attention_V(feats) * self.attention_U(feats)     # N x D
        scores = self.attention_weights(gated).t()                    # K x N
        weights = scores.softmax(dim=1)                               # normalise over N

        pooled = weights @ feats                                      # K x L
        prob = self.classifier(pooled)
        decision = (prob >= 0.5).float()
        return prob, decision, weights

    # AUXILIARY METHODS
    def calculate_classification_error(self, X, Y):
        """Return ``(error, prediction)``; ``error`` is a float, 0.0 when correct."""
        target = Y.float()
        _, decision, _ = self.forward(X)
        accuracy = decision.eq(target).cpu().float().mean().item()
        return 1. - accuracy, decision

    def calculate_objective(self, X, Y):
        """Return ``(negative Bernoulli log-likelihood, attention weights)``.

        The probability is clamped to [1e-5, 1 - 1e-5] to keep both logs finite.
        """
        target = Y.float()
        prob, _, weights = self.forward(X)
        prob = prob.clamp(min=1e-5, max=1. - 1e-5)
        nll = -1. * (target * prob.log() + (1. - target) * (1. - prob).log())
        return nll, weights


In [12]:
# main.py

from __future__ import print_function

import numpy as np

import argparse
import torch
import torch.utils.data as data_utils
import torch.optim as optim
from torch.autograd import Variable

# from dataloader import MnistBags
# from model import Attention, GatedAttention

# Training settings
# parser = argparse.ArgumentParser(description='PyTorch MNIST bags Example')
# parser.add_argument('--epochs', type=int, default=20, metavar='N',
#                     help='number of epochs to train (default: 20)')
# parser.add_argument('--lr', type=float, default=0.0005, metavar='LR',
#                     help='learning rate (default: 0.0005)')
# parser.add_argument('--reg', type=float, default=10e-5, metavar='R',
#                     help='weight decay')
# parser.add_argument('--target_number', type=int, default=9, metavar='T',
#                     help='bags have a positive labels if they contain at least one 9')
# parser.add_argument('--mean_bag_length', type=int, default=10, metavar='ML',
#                     help='average bag length')
# parser.add_argument('--var_bag_length', type=int, default=2, metavar='VL',
#                     help='variance of bag length')
# parser.add_argument('--num_bags_train', type=int, default=200, metavar='NTrain',
#                     help='number of bags in training set')
# parser.add_argument('--num_bags_test', type=int, default=50, metavar='NTest',
#                     help='number of bags in test set')
# parser.add_argument('--seed', type=int, default=1, metavar='S',
#                     help='random seed (default: 1)')
# parser.add_argument('--no-cuda', action='store_true', default=False,
#                     help='disables CUDA training')
# parser.add_argument('--model', type=str, default='attention', help='Choose b/w attention and gated_attention')

# args = parser.parse_args()

class Arguments:
    """Plain settings object standing in for the commented-out argparse namespace.

    Every constructor argument is stored unchanged as an attribute of the same
    name. Note the literals: ``10e-4`` equals 1e-3 and ``10e-5`` equals 1e-4.
    """

    def __init__(self,
                 epochs=20,
                 lr=10e-4,
                 reg=10e-5,
                 target_number=9,
                 mean_bag_length=10,
                 var_bag_length=2,
                 num_bags_train=20,
                 num_bags_test=50,
                 seed=1,
                 no_cuda=True,
                 model="attention"):
        # Optimisation settings.
        self.epochs = epochs
        self.lr = lr                              # 10e-4 == 1e-3
        self.reg = reg                            # 10e-5 == 1e-4; used as weight decay
        # Bag-dataset settings.
        self.target_number = target_number        # digit that makes a bag positive
        self.mean_bag_length = mean_bag_length
        self.var_bag_length = var_bag_length
        self.num_bags_train = num_bags_train
        self.num_bags_test = num_bags_test
        # Run settings.
        self.seed = seed
        self.no_cuda = no_cuda
        self.model = model                        # 'attention' or 'gated_attention'


args = Arguments()

# Use CUDA only if it was not disabled and a device is actually available.
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Validate the model choice up front: an unknown name previously left `model`
# undefined and only failed (NameError) after both datasets had been built.
_MODEL_CLASSES = {'attention': Attention, 'gated_attention': GatedAttention}
if args.model not in _MODEL_CLASSES:
    raise ValueError('Unknown model {!r}; choose one of {}'.format(args.model, sorted(_MODEL_CLASSES)))

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
    print('\nGPU is ON!')

print('Load Train and Test Set')
loader_kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

# batch_size=1: every batch is a single bag (bags have different lengths).
train_loader = data_utils.DataLoader(MnistBags(target_number=args.target_number,
                                               mean_bag_length=args.mean_bag_length,
                                               var_bag_length=args.var_bag_length,
                                               num_bag=args.num_bags_train,
                                               seed=args.seed,
                                               train=True),
                                     batch_size=1,
                                     shuffle=True,
                                     **loader_kwargs)

test_loader = data_utils.DataLoader(MnistBags(target_number=args.target_number,
                                              mean_bag_length=args.mean_bag_length,
                                              var_bag_length=args.var_bag_length,
                                              num_bag=args.num_bags_test,
                                              seed=args.seed,
                                              train=False),
                                    batch_size=1,
                                    shuffle=False,
                                    **loader_kwargs)

print('Init Model')
model = _MODEL_CLASSES[args.model]()
if args.cuda:
    model.cuda()

optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=args.reg)


def train(epoch):
    """Run one training epoch and print the mean loss and classification error.

    Uses the module-level ``model``, ``optimizer``, ``train_loader`` and ``args``.
    Each loader item is ``(bag, [bag_label, instance_labels])``; only the bag
    label is used for training.
    """
    model.train()
    train_loss = 0.
    train_error = 0.
    for data, label in train_loader:
        bag_label = label[0]
        if args.cuda:
            data, bag_label = data.cuda(), bag_label.cuda()

        # reset gradients
        optimizer.zero_grad()
        # calculate loss and metrics; accumulate plain Python floats
        loss, _ = model.calculate_objective(data, bag_label)
        train_loss += loss.item()
        error, _ = model.calculate_classification_error(data, bag_label)
        # float() accepts both a Python float and a 0-d tensor error value.
        train_error += float(error)
        # backward pass
        loss.backward()
        # step
        optimizer.step()

    # calculate loss and error for epoch
    train_loss /= len(train_loader)
    train_error /= len(train_loader)

    print('Epoch: {}, Loss: {:.4f}, Train error: {:.4f}'.format(epoch, train_loss, train_error))


def test():
    """Evaluate ``model`` on ``test_loader`` and print the mean loss and error.

    For the first five bags it also prints the true and predicted bag labels,
    and each instance label paired with its attention weight (rounded to 3 dp).
    Runs under ``torch.no_grad()`` because no gradients are needed.
    """
    model.eval()
    test_loss = 0.
    test_error = 0.
    with torch.no_grad():
        for batch_idx, (data, label) in enumerate(test_loader):
            bag_label = label[0]
            instance_labels = label[1]
            if args.cuda:
                data, bag_label = data.cuda(), bag_label.cuda()
            loss, attention_weights = model.calculate_objective(data, bag_label)
            test_loss += loss.item()
            error, predicted_label = model.calculate_classification_error(data, bag_label)
            # float() accepts both a Python float and a 0-d tensor error value.
            test_error += float(error)

            if batch_idx < 5:  # plot bag labels and instance labels for first 5 bags
                bag_level = (bag_label.cpu().numpy()[0], int(predicted_label.cpu().numpy()[0][0]))
                instance_level = list(zip(instance_labels.numpy()[0].tolist(),
                                          np.round(attention_weights.cpu().numpy()[0], decimals=3).tolist()))

                print('\nTrue Bag Label, Predicted Bag Label: {}\n'
                      'True Instance Labels, Attention Weights: {}'.format(bag_level, instance_level))

    test_error /= len(test_loader)
    test_loss /= len(test_loader)

    print('\nTest Set, Loss: {:.4f}, Test error: {:.4f}'.format(test_loss, test_error))


# if __name__ == "__main__":
#     print('Start Training')
#     for epoch in range(1, args.epochs + 1):
#         train(epoch)
#     print('Start Testing')
#     test()

Load Train and Test Set
Init Model


In [13]:
# Run the full schedule: epochs numbered 1..args.epochs inclusive, followed by
# a single evaluation pass over the test bags.
print('Start Training')
first_epoch, last_epoch = 1, args.epochs
for epoch in range(first_epoch, last_epoch + 1):
    train(epoch)

print('Start Testing')
test()

Start Training
Epoch: 1, Loss: 0.7381, Train error: 0.5500
Epoch: 2, Loss: 0.6817, Train error: 0.4000
Epoch: 3, Loss: 0.6605, Train error: 0.4000
Epoch: 4, Loss: 0.5994, Train error: 0.4000
Epoch: 5, Loss: 0.7965, Train error: 0.3000
Epoch: 6, Loss: 0.5494, Train error: 0.1000
Epoch: 7, Loss: 0.3999, Train error: 0.1000
Epoch: 8, Loss: 0.2606, Train error: 0.1000
Epoch: 9, Loss: 0.1657, Train error: 0.0500
Epoch: 10, Loss: 0.0883, Train error: 0.0500
Epoch: 11, Loss: 0.0214, Train error: 0.0000
Epoch: 12, Loss: 0.1629, Train error: 0.1000
Epoch: 13, Loss: 0.4894, Train error: 0.3500
Epoch: 14, Loss: 0.3745, Train error: 0.1500
Epoch: 15, Loss: 0.1238, Train error: 0.0000
Epoch: 16, Loss: 0.0393, Train error: 0.0000
Epoch: 17, Loss: 0.0146, Train error: 0.0000
Epoch: 18, Loss: 0.0572, Train error: 0.0500
Epoch: 19, Loss: 0.0235, Train error: 0.0000
Epoch: 20, Loss: 0.0033, Train error: 0.0000
Start Testing

True Bag Label, Predicted Bag Label: (True, 0)
True Instance Labels, Attention 