In [6]:
import os
import datetime
import copy
import re
import yaml
import uuid
import warnings
import time
import inspect

import numpy as np
import pandas as pd
from functools import partial, reduce
from random import shuffle
import random

import torch
from torch import nn, optim
from torch import nn
from torch.nn import functional as F
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader
from torchvision.models import resnet
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torchvision.models.resnet import ResNet, BasicBlock
from torchvision.datasets import MNIST
import tensorflow as tf
from tqdm.autonotebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from sklearn import metrics as mtx
from sklearn import model_selection as ms

In [7]:
import warnings
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")

> # 1. Pretrain

In [8]:
def get_data_loaders(train_batch_size, val_batch_size):
    mnist = MNIST(download=True, train=True, root=".").train_data.float()
    
    data_transform = Compose([ Resize((224, 224)),ToTensor(), Normalize((mnist.mean()/255,), (mnist.std()/255,))])

    train_loader = DataLoader(MNIST(download=True, root=".", transform=data_transform, train=True),
                              batch_size=train_batch_size, shuffle=True)
    val_loader = DataLoader(MNIST(download=False, root=".", transform=data_transform, train=False),
                            batch_size=val_batch_size, shuffle=False)
    return train_loader, val_loader

In [9]:
train_batch_size = 256
val_batch_size = 256

train_loader, valid_loader = get_data_loaders(train_batch_size, val_batch_size)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


  3%|▎         | 262144/9912422 [00:00<00:03, 2470535.06it/s]

100%|██████████| 9912422/9912422 [00:00<00:00, 24631079.98it/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 30925630.28it/s]


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 8082843.90it/s]


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 5432143.93it/s]


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



> ### 1.1 define model

In [10]:
class MnistResNet(ResNet):
    def __init__(self):
        super(MnistResNet, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=10)
        self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        
    def forward(self, x):
        return torch.softmax(super(MnistResNet, self).forward(x), dim=-1)

> ### 1.2 helper function

In [11]:
def calculate_metric(metric_fn, true_y, pred_y):
    # multi class problems need to have averaging method
    if "average" in inspect.getfullargspec(metric_fn).args:
        return metric_fn(true_y, pred_y, average="macro")
    else:
        return metric_fn(true_y, pred_y)
    
def print_scores(p, r, f1, a, batch_size):
    # just an utility printing function
    for name, scores in zip(("precision", "recall", "F1", "accuracy"), (p, r, f1, a)):
        print(f"\t{name.rjust(14, ' ')}: {sum(scores)/batch_size:.4f}")


> ### 1.3 train

In [12]:
start_ts = time.time()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# model:
model = MnistResNet().to(device)

# params you need to specify:
epochs = 5
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
loss_function = nn.CrossEntropyLoss() # your loss function, cross entropy works well for multi-class problems

# optimizer, I've used Adadelta, as it wokrs well without any magic numbers
optimizer = optim.Adadelta(model.parameters())


losses = []
batches = len(train_loader)
val_batches = len(val_loader)

# loop for every epoch (training + evaluation)
for epoch in range(epochs):
    total_loss = 0

    # progress bar (works in Jupyter notebook too!)
    progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)

    # ----------------- TRAINING  -------------------- 
    # set model to training
    model.train()
    
    for i, data in progress:
        X, y = data[0].to(device), data[1].to(device)
        # training step for single batch
        model.zero_grad() # to make sure that all the grads are 0 
        """
        model.zero_grad() and optimizer.zero_grad() are the same 
        IF all your model parameters are in that optimizer. 
        I found it is safer to call model.zero_grad() to make sure all grads are zero, 
        e.g. if you have two or more optimizers for one model.

        """
        outputs = model(X) # forward
        loss = loss_function(outputs, y) # get loss
        loss.backward() # accumulates the gradient (by addition) for each parameter.
        optimizer.step() # performs a parameter update based on the current gradient 

        # getting training quality data
        current_loss = loss.item()
        total_loss += current_loss

        # updating progress bar
        progress.set_description("Loss: {:.4f}".format(total_loss/(i+1)))
        
    # releasing unceseccary memory in GPU
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # ----------------- VALIDATION  ----------------- 
    val_losses = 0
    precision, recall, f1, accuracy = [], [], [], []
    
    # set model to evaluating (testing)
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            X, y = data[0].to(device), data[1].to(device)

            outputs = model(X) # this get's the prediction from the network

            val_losses += loss_function(outputs, y)

            predicted_classes = torch.max(outputs, 1)[1] # get class from network's prediction
            
            # calculate P/R/F1/A metrics for batch
            # Convert the tensors to CPU and then to numpy arrays
            predicted_classes_numpy = predicted_classes.cpu().numpy()
            y_numpy = y.cpu().numpy()

            # Calculate Precision, Recall, F1 Score, and Accuracy for this batch
            precision_batch = precision_score(y_numpy, predicted_classes_numpy, average='macro')
            recall_batch = recall_score(y_numpy, predicted_classes_numpy, average='macro')
            f1_batch = f1_score(y_numpy, predicted_classes_numpy, average='macro')
            accuracy_batch = accuracy_score(y_numpy, predicted_classes_numpy)

            # Append the values to their respective lists
            precision.append(precision_batch)
            recall.append(recall_batch)
            f1.append(f1_batch)
            accuracy.append(accuracy_batch)
            # for acc, metric in zip((precision, recall, f1, accuracy), 
            #                        (precision_score, recall_score, f1_score, accuracy_score)):
            #     acc.append(
            #         calculate_metric(metric, y.cpu().data.numpy(), predicted_classes.cpu().data.numpy())
            #     )
          
    print(f"Epoch {epoch+1}/{epochs}, training loss: {total_loss/batches}, validation loss: {val_losses/val_batches}")
    print_scores(precision, recall, f1, accuracy, val_batches)
    losses.append(total_loss/batches) # for plotting learning curve
print(f"Training time: {time.time()-start_ts}s")

Loss: 1.6031: 100%|██████████| 235/235 [01:25<00:00,  2.75it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/5, training loss: 1.6030886132666404, validation loss: 1.8405001163482666
	     precision: 0.8200
	        recall: 0.6232
	            F1: 0.6092
	      accuracy: 0.6191


Loss: 1.4771: 100%|██████████| 235/235 [01:22<00:00,  2.86it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 2/5, training loss: 1.4771335657606735, validation loss: 1.7645254135131836
	     precision: 0.8439
	        recall: 0.7096
	            F1: 0.6944
	      accuracy: 0.7123


Loss: 1.4728: 100%|██████████| 235/235 [01:22<00:00,  2.85it/s]


Epoch 3/5, training loss: 1.4728196742686819, validation loss: 1.4718458652496338
	     precision: 0.9913
	        recall: 0.9913
	            F1: 0.9911
	      accuracy: 0.9914


Loss: 1.4693: 100%|██████████| 235/235 [01:22<00:00,  2.86it/s]


Epoch 4/5, training loss: 1.4693300049355689, validation loss: 1.4815601110458374
	     precision: 0.9846
	        recall: 0.9832
	            F1: 0.9832
	      accuracy: 0.9833


Loss: 1.4675: 100%|██████████| 235/235 [01:21<00:00,  2.87it/s]


Epoch 5/5, training loss: 1.467486802567827, validation loss: 1.4703575372695923
	     precision: 0.9915
	        recall: 0.9915
	            F1: 0.9913
	      accuracy: 0.9916
Training time: 468.6834900379181s


> ### 1.4 save model

In [13]:
torch.save(model.state_dict(), 'mnist_state.pt')

> # 2. Data Generation

> ### 2.1 load data

In [14]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

In [15]:
x_train.shape

(60000, 28, 28)

In [16]:
x_train = x_train[:30001]
y_train = y_train[:30001]
x_test = x_test[:9000]
y_test = y_test[:9000]

In [17]:
# Making sure that the values are float so that we can get decimal points after division
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')

# Normalizing the RGB codes by dividing it to the max RGB value.
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print('Number of images in x_train', x_train.shape[0])
print('Number of images in x_test', x_test.shape[0])

x_train shape: (30001, 28, 28)
Number of images in x_train 30001
Number of images in x_test 9000


> ### 2.2 Create tuple (index, label) for train and test

In [18]:
instance_index_label = [(i, y_train[i]) for i in range(x_train.shape[0])]
instance_index_label_test = [(i, y_test[i]) for i in range(x_test.shape[0])]

In [19]:
# find the index if label is 1
find_index = [instance_index_label[i][0] for i in range(len(instance_index_label)) if instance_index_label[i][1]==1]
# find the index if label is 1
find_index_test = [instance_index_label_test[i][0] for i in range(len(instance_index_label_test))
                   if instance_index_label_test[i][1]==1]

In [20]:
print('index:', instance_index_label[0][0]) #index
print('label:', instance_index_label[0][1]) #label

index: 0
label: 5


> ### 2.3 load pretrained model

In [21]:
import torch
from torchvision.models.resnet import ResNet, BasicBlock
class MnistResNet(ResNet):
    def __init__(self):
        super(MnistResNet, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=10)
        self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        
    def forward(self, x):
        return torch.softmax(super(MnistResNet, self).forward(x), dim=-1)

In [22]:
model = MnistResNet()
model.load_state_dict(torch.load('mnist_state.pt'))
body = nn.Sequential(*list(model.children()))
# extract the last layer
model = body[:9]
# the model we will use
model.eval()

Sequential(
  (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Con

> ### 2.4 get features

In [23]:
train_batch_size = 1
val_batch_size = 1
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
loss_function = nn.CrossEntropyLoss() # your loss function, cross entropy works well for multi-class problems

# optimizer
optimizer = optim.Adadelta(model.parameters())



In [24]:
losses = []
batches = len(train_loader)
val_batches = len(val_loader)

> #### 2.4.1 get features for train

In [25]:
# loop for every epoch (training + evaluation)
meta_table = dict()
feature_result = []

# progress bar
progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)

model.eval()

for i, data in progress:
    if i==30001:
        break
    X, y = data[0], data[1]
    # training step for single batch
    model.zero_grad()
    outputs = model(X)
    feature_result.append(outputs.reshape(-1).tolist())
    meta_table[i] = outputs.reshape(-1).tolist()
    
feature_array = np.array(feature_result)
np.save('feature_array_full',feature_array )


Loss:   0%|          | 0/60000 [00:00<?, ?it/s]

Loss:  50%|█████     | 30001/60000 [17:45<17:45, 28.16it/s]


In [27]:
# load
feature_array = np.load('./feature_array_full.npy', allow_pickle=True)

> #### 2.4.2 get features for test

In [28]:
# loop for every epoch (training + evaluation)
meta_t_table = dict()
feature_t_result = []

# progress bar
progress = tqdm(enumerate(val_loader), desc="Loss: ", total=batches)

model.eval()

for i, data in progress:
    if i==9000:
        break
    X, y = data[0], data[1]
    # training step for single batch
    model.zero_grad()
    outputs_t = model(X)
    feature_t_result.append(outputs_t.reshape(-1).tolist())
    meta_t_table[i] = outputs_t.reshape(-1).tolist()

feature_test_array = np.array(feature_t_result)
# save 
np.save('feature_test_array_full',feature_test_array )


Loss:  15%|█▌        | 9000/60000 [05:21<30:20, 28.01it/s]


In [None]:
#load
feature_test_array = np.load('feature_test_array_full.npy', allow_pickle=True)

> ### 2.5 generate data

> #### 2.5.1 generate data for train

In [None]:
from typing import List, Dict, Tuple
def data_generation(instance_index_label: List[Tuple]) -> List[Dict]:
    """
    bags: {key1: [ind1, ind2, ind3],
           key2: [ind1, ind2, ind3, ind4, ind5],
           ... }
    bag_lbls:
        {key1: 0,
         key2: 1,
         ... }
    """
    bag_size = np.random.randint(3,7,size=len(instance_index_label)//5)
    data_cp = copy.copy(instance_index_label)
    np.random.shuffle(data_cp)
    bags = {}
    bags_per_instance_labels = {}
    bags_labels = {}
    for bag_ind, size in enumerate(bag_size):
        bags[bag_ind] = []
        bags_per_instance_labels[bag_ind] = []
        try:
            for _ in range(size):
                inst_ind, lbl = data_cp.pop()
                bags[bag_ind].append(inst_ind)
                # simplfy, just use a temporary variable instead of bags_per_instance_labels
                bags_per_instance_labels[bag_ind].append(lbl)
            bags_labels[bag_ind] = bag_label_from_instance_labels(bags_per_instance_labels[bag_ind])
        except:
            break
    return bags, bags_labels

def bag_label_from_instance_labels(instance_labels):
    return int(any(((x==1) for x in instance_labels)))

In [None]:
bag_indices, bag_labels = data_generation(instance_index_label)
bag_features = {kk: torch.Tensor(feature_array[inds]) for kk, inds in bag_indices.items()}

In [None]:
# save
import pickle
pickle.dump(bag_indices, open( "bag_indices", "wb" ) )
pickle.dump(bag_labels, open( "bag_labels", "wb" ) )
pickle.dump(bag_features, open( "bag_features", "wb" ) )

In [None]:
import pickle
bag_indices = pickle.load( open( "bag_indices", "rb" ) )
bag_labels = pickle.load( open( "bag_labels", "rb" ) )
bag_features = pickle.load( open( "bag_features", "rb" ) )

> #### 2.5.2 generate data for test

In [None]:
bag_t_indices, bag_t_labels = data_generation(instance_index_label_test)

In [None]:
bag_t_features = {kk: torch.Tensor(feature_test_array[inds]) for kk, inds in bag_t_indices.items()}

In [None]:
pickle.dump(bag_t_indices, open( "bag_t_indices", "wb" ) )
pickle.dump(bag_t_labels, open( "bag_t_labels", "wb" ) )
pickle.dump(bag_t_features, open( "bag_t_features", "wb" ) )

In [None]:
bag_t_indices = pickle.load( open( "bag_t_indices", "rb" ) )
bag_t_labels = pickle.load( open( "bag_t_labels", "rb" ) )
bag_t_features = pickle.load( open( "bag_t_features", "rb" ) )

> # 3. Multiple Instance Learning

> ### 3.1 Prepare data for model

In [None]:
from torch.utils.data import Dataset
class Transform_data(Dataset):
    """
    We want to 1. pad tensor 2. transform the data to the size that fits in the input size.
    
    """

    def __init__(self, data, transform=None):
        self.transform = transform
        self.data = data
        
    def __getitem__(self, index):
        tensor = self.data[index][0]
        if self.transform is not None:
            tensor = self.transform(tensor)
        return (tensor, self.data[index][1])

    def __len__(self):
        return len(self.data)

In [None]:
train_data = [(bag_features[i],bag_labels[i]) for i in range(len(bag_features))]

In [None]:
bag_features[0]

tensor([[2.2300, 1.2076, 0.4186,  ..., 0.4665, 1.6642, 2.2301],
        [0.8595, 1.5297, 0.9046,  ..., 0.5341, 0.9438, 0.8581],
        [2.6144, 0.2394, 0.3311,  ..., 0.0778, 1.6823, 1.7735],
        [0.1552, 1.2130, 1.0170,  ..., 2.8621, 0.8361, 0.2522],
        [0.7179, 0.3636, 0.4274,  ..., 0.2156, 0.3240, 0.0769]])

In [None]:
def pad_tensor(data:list, max_number_instance) -> list:
    """
    Since our bag has different sizes, we need to pad each tensor to have the same shape (max: 7).
    We will look through each one instance and look at the shape of the tensor, and then we will pad 7-n 
    to the existing tensor where n is the number of instances in the bag.
    The function will return a padded data set."""
    new_data = []
    for bag_index in range(len(data)):
        tensor_size = len(data[bag_index][0])
        pad_size = max_number_instance - tensor_size
        p2d = (0,0, 0, pad_size)
        padded = nn.functional.pad(data[bag_index][0], p2d, 'constant', 0)
        new_data.append((padded, data[bag_index][1]))
    return new_data

In [None]:
max_number_instance = 7
padded_train = pad_tensor(train_data, max_number_instance)

In [None]:
test_data = [(bag_t_features[i],bag_t_labels[i]) for i in range(len(bag_t_features))]
padded_test = pad_tensor(test_data, max_number_instance)

In [None]:
def get_data_loaders(train_data, test_data, train_batch_size, val_batch_size):
    train_loader = DataLoader(train_data, batch_size=train_batch_size, shuffle=True)
    val_loader = DataLoader(test_data, batch_size=val_batch_size, shuffle=False)
    return train_loader, val_loader

In [None]:
train_loader,valid_loader = get_data_loaders(padded_train, padded_test, 1, 1)

In [None]:
train_batch_size = 1
val_batch_size = 1

> ### 3.2 Define model

> #### 3.2.1 Aggregation Funtion model

In [None]:
# aggregation functions

class SoftMaxMeanSimple(torch.nn.Module):
    def __init__(self, n, n_inst, dim=0):
        """
        if dim==1:
            given a tensor `x` with dimensions [N * M],
            where M -- dimensionality of the featur vector
                       (number of features per instance)
                  N -- number of instances
            initialize with `AggModule(M)`
            returns:
            - weighted result: [M]
            - gate: [N]
        if dim==0:
            ...
        """
        super(SoftMaxMeanSimple, self).__init__()
        self.dim = dim
        self.gate = torch.nn.Softmax(dim=self.dim)      
        self.mdl_instance_transform = nn.Sequential(
                            nn.Linear(n, n_inst),
                            nn.LeakyReLU(),
                            nn.Linear(n_inst, n),
                            nn.LeakyReLU(),
                            )
    def forward(self, x):
        z = self.mdl_instance_transform(x)
        if self.dim==0:
            z = z.view((z.shape[0],1)).sum(1)
        elif self.dim==1:
            z = z.view((1, z.shape[1])).sum(0)
        gate_ = self.gate(z)
        res = torch.sum(x* gate_, self.dim)
        return res, gate_

    
class AttentionSoftMax(torch.nn.Module):
    def __init__(self, in_features = 3, out_features = None):
        """
        given a tensor `x` with dimensions [N * M],
        where M -- dimensionality of the featur vector
                   (number of features per instance)
              N -- number of instances
        initialize with `AggModule(M)`
        returns:
        - weighted result: [M]
        - gate: [N]
        """
        super(AttentionSoftMax, self).__init__()
        self.otherdim = ''
        if out_features is None:
            out_features = in_features
        self.layer_linear_tr = nn.Linear(in_features, out_features)
        self.activation = nn.LeakyReLU()
        self.layer_linear_query = nn.Linear(out_features, 1)
        
    def forward(self, x):
        keys = self.layer_linear_tr(x)
        keys = self.activation(keys)
        attention_map_raw = self.layer_linear_query(keys)[...,0]
        attention_map = nn.Softmax(dim=-1)(attention_map_raw)
        result = torch.einsum(f'{self.otherdim}i,{self.otherdim}ij->{self.otherdim}j', attention_map, x)
        return result, attention_map
    

> #### 3.3.2 MIL_NN model

In [None]:
class NoisyAnd(torch.nn.Module):
    def __init__(self, a=10, dims=[1,2]):
        super(NoisyAnd, self).__init__()
#         self.output_dim = output_dim
        self.a = a
        self.b = torch.nn.Parameter(torch.tensor(0.01))
        self.dims =dims
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
#         h_relu = self.linear1(x).clamp(min=0)
        mean = torch.mean(x, self.dims, True)
        res = (self.sigmoid(self.a * (mean - self.b)) - self.sigmoid(-self.a * self.b)) / (
              self.sigmoid(self.a * (1 - self.b)) - self.sigmoid(-self.a * self.b))
        return res
    


class NN(torch.nn.Module):

    def __init__(self, n=512, n_mid = 1024,
                 n_out=1, dropout=0.2,
                 scoring = None,
                ):
        super(NN, self).__init__()
        self.linear1 = torch.nn.Linear(n, n_mid)
        self.non_linearity = torch.nn.LeakyReLU()
        self.linear2 = torch.nn.Linear(n_mid, n_out)
        self.dropout = torch.nn.Dropout(dropout)
        if scoring:
            self.scoring = scoring
        else:
            self.scoring = torch.nn.Softmax() if n_out>1 else torch.nn.Sigmoid()
        
    def forward(self, x):
        z = self.linear1(x)
        z = self.non_linearity(z)
        z = self.dropout(z)
        z = self.linear2(z)
        y_pred = self.scoring(z)
        return y_pred
    

class LogisticRegression(torch.nn.Module):
    def __init__(self, n=512, n_out=1):
        super(LogisticRegression, self).__init__()
        self.linear = torch.nn.Linear(n, n_out)
        self.scoring = torch.nn.Softmax() if n_out>1 else torch.nn.Sigmoid()

    def forward(self, x):
        z = self.linear(x)
        y_pred = self.scoring(z)
        return y_pred

    
def regularization_loss(params,
                        reg_factor = 0.005,
                        reg_alpha = 0.5):
    params = [pp for pp in params if len(pp.shape)>1]
    l1_reg = nn.L1Loss()
    l2_reg = nn.MSELoss()
    loss_reg =0
    for pp in params:
        loss_reg+=reg_factor*((1-reg_alpha)*l1_reg(pp, target=torch.zeros_like(pp)) +\
                           reg_alpha*l2_reg(pp, target=torch.zeros_like(pp)))
    return loss_reg

class MIL_NN(torch.nn.Module):

    def __init__(self, n=512,  
                 n_mid=1024, 
                 n_classes=1, 
                 dropout=0.1,
                 agg = None,
                 scoring=None,
                ):
        super(MIL_NN, self).__init__()
        self.agg = agg if agg is not None else AttentionSoftMax(n)
        
        if n_mid == 0:
            self.bag_model = LogisticRegression(n, n_classes)
        else:
            self.bag_model = NN(n, n_mid, n_classes, dropout=dropout, scoring=scoring)
        
    def forward(self, bag_features, bag_lbls=None):
        """
        bag_feature is an aggregated vector of 512 features
        bag_att is a gate vector of n_inst instances
        bag_lbl is a vector a labels
        figure out batches
        """
        bag_feature, bag_att, bag_keys = list(zip(*[list(self.agg(ff.float())) + [idx]
                                                    for idx, ff in (bag_features.items())]))
        bag_att = dict(zip(bag_keys, [a.detach().cpu() for a  in bag_att]))
        bag_feature_stacked = torch.stack(bag_feature)
        y_pred = self.bag_model(bag_feature_stacked)
        return y_pred, bag_att, bag_keys

> #### 3.3.3 helper function

In [None]:
def calculate_metric(metric_fn, true_y, pred_y):
    # multi class problems need to have averaging method
    if "average" in inspect.getfullargspec(metric_fn).args:
        return metric_fn(true_y, pred_y, average="macro")
    else:
        return metric_fn(true_y, pred_y)
    
def print_scores(p, r, f1, a, batch_size):
    # just an utility printing function
    for name, scores in zip(("precision", "recall", "F1", "accuracy"), (p, r, f1, a)):
        print(f"\t{name.rjust(14, ' ')}: {sum(scores)/batch_size:.4f}")

> # 4. Train & Test

In [None]:
import numpy as np
start_ts = time.time()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

lr0 = 1e-4

# model:
model = MIL_NN().to(device)

# params you need to specify:
epochs = 10
train_loader, val_loader = get_data_loaders(padded_train, padded_test, 1, 1)
loss_function = torch.nn.BCELoss(reduction='mean') # your loss function, cross entropy works well for multi-class problems


#optimizer = optim.Adadelta(model.parameters())
optimizer = optim.SGD(model.parameters(), lr=lr0, momentum=0.9)

losses = []
batches = len(train_loader)
val_batches = len(val_loader)

# loop for every epoch (training + evaluation)
for epoch in range(epochs):
    total_loss = 0

    # progress bar (works in Jupyter notebook too!)
    progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)

    # ----------------- TRAINING  -------------------- 
    # set model to training
    model.train()
    for i, data in progress:
        X, y = data[0].to(device), data[1].to(device)
        X = X.reshape([1,7*512])
        y = y.type(torch.cuda.FloatTensor)
        # training step for single batch
        model.zero_grad() # to make sure that all the grads are 0 
        """
        model.zero_grad() and optimizer.zero_grad() are the same 
        IF all your model parameters are in that optimizer. 
        I found it is safer to call model.zero_grad() to make sure all grads are zero, 
        e.g. if you have two or more optimizers for one model.

        """
        outputs = model(X) # forward
        loss = loss_function(outputs, y) # get loss
        loss.backward() # accumulates the gradient (by addition) for each parameter.
        optimizer.step() # performs a parameter update based on the current gradient 

        # getting training quality data
        current_loss = loss.item()
        total_loss += current_loss

        # updating progress bar
        progress.set_description("Loss: {:.4f}".format(total_loss/(i+1)))
        
    # releasing unceseccary memory in GPU
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # ----------------- VALIDATION  ----------------- 
    val_losses = 0
    precision, recall, f1, accuracy = [], [], [], []
    
    # set model to evaluating (testing)
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            X, y = data[0].to(device), data[1].to(device)
            X = X.reshape([1,7*512])
            y = y.type(torch.cuda.FloatTensor)
            outputs = model(X) # this get's the prediction from the network
            prediced_classes =outputs.detach().round()
            #y_pred.extend(prediced_classes.tolist())
            val_losses += loss_function(outputs, y)
            
            # calculate P/R/F1/A metrics for batch
            for acc, metric in zip((precision, recall, f1, accuracy), 
                                   (precision_score, recall_score, f1_score, accuracy_score)):
                acc.append(
                    calculate_metric(metric, y.cpu(), prediced_classes.cpu())
                )
          
    print(f"Epoch {epoch+1}/{epochs}, training loss: {total_loss/batches}, validation loss: {val_losses/val_batches}")
    print_scores(precision, recall, f1, accuracy, val_batches)
    losses.append(total_loss/batches) # for plotting learning curve
print(f"Training time: {time.time()-start_ts}s")

HBox(children=(IntProgress(value=0, description='Loss: ', max=6000, style=ProgressStyle(description_width='ini…

  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)





  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


Epoch 1/10, training loss: 0.3253995268222628, validation loss: 0.17935876548290253
	     precision: 0.9411
	        recall: 0.9411
	            F1: 0.9411
	      accuracy: 0.9411


HBox(children=(IntProgress(value=0, description='Loss: ', max=6000, style=ProgressStyle(description_width='ini…


Epoch 2/10, training loss: 0.17075676332922501, validation loss: 0.1573299914598465
	     precision: 0.9361
	        recall: 0.9361
	            F1: 0.9361
	      accuracy: 0.9361


HBox(children=(IntProgress(value=0, description='Loss: ', max=6000, style=ProgressStyle(description_width='ini…


Epoch 3/10, training loss: 0.1241153416950585, validation loss: 0.14403362572193146
	     precision: 0.9461
	        recall: 0.9461
	            F1: 0.9461
	      accuracy: 0.9461


HBox(children=(IntProgress(value=0, description='Loss: ', max=6000, style=ProgressStyle(description_width='ini…


Epoch 4/10, training loss: 0.10115043728419369, validation loss: 0.11335857212543488
	     precision: 0.9628
	        recall: 0.9628
	            F1: 0.9628
	      accuracy: 0.9628


HBox(children=(IntProgress(value=0, description='Loss: ', max=6000, style=ProgressStyle(description_width='ini…


Epoch 5/10, training loss: 0.07484548372746803, validation loss: 0.07197500765323639
	     precision: 0.9778
	        recall: 0.9778
	            F1: 0.9778
	      accuracy: 0.9778


HBox(children=(IntProgress(value=0, description='Loss: ', max=6000, style=ProgressStyle(description_width='ini…


Epoch 6/10, training loss: 0.060945956385883286, validation loss: 0.07357295602560043
	     precision: 0.9739
	        recall: 0.9739
	            F1: 0.9739
	      accuracy: 0.9739


HBox(children=(IntProgress(value=0, description='Loss: ', max=6000, style=ProgressStyle(description_width='ini…


Epoch 7/10, training loss: 0.05318203048250962, validation loss: 0.074022576212883
	     precision: 0.9756
	        recall: 0.9756
	            F1: 0.9756
	      accuracy: 0.9756


HBox(children=(IntProgress(value=0, description='Loss: ', max=6000, style=ProgressStyle(description_width='ini…


Epoch 8/10, training loss: 0.04461347977752607, validation loss: 0.0600493848323822
	     precision: 0.9783
	        recall: 0.9783
	            F1: 0.9783
	      accuracy: 0.9783


HBox(children=(IntProgress(value=0, description='Loss: ', max=6000, style=ProgressStyle(description_width='ini…


Epoch 9/10, training loss: 0.03741783676902158, validation loss: 0.0561109222471714
	     precision: 0.9856
	        recall: 0.9856
	            F1: 0.9856
	      accuracy: 0.9856


HBox(children=(IntProgress(value=0, description='Loss: ', max=6000, style=ProgressStyle(description_width='ini…


Epoch 10/10, training loss: 0.029147379207246065, validation loss: 0.05226973071694374
	     precision: 0.9850
	        recall: 0.9850
	            F1: 0.9850
	      accuracy: 0.9850
Training time: 690.9660251140594s
