In [1]:
!git clone https://github.com/chrishendra93/MI_Workshop


Cloning into 'MI_Workshop'...
remote: Enumerating objects: 86, done.[K
remote: Counting objects: 100% (86/86), done.[K
remote: Compressing objects: 100% (63/63), done.[K
remote: Total 86 (delta 38), reused 60 (delta 22), pack-reused 0[K
Unpacking objects: 100% (86/86), done.


In [2]:
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm

In [9]:
# Seed NumPy's global random generator up front so that the np.random calls
# made later in the notebook draw the same values on every fresh run.
SEED = 0
np.random.seed(SEED)

In [12]:
# Directories (inside the cloned MI_Workshop checkout) holding the .npy arrays loaded below.
root_dir = "/content/MI_Workshop/mnist_clr"
data_dir = "/content/MI_Workshop/lab_3"

# The CSVs are read with explicit column names: the label first, then 28 * 28 feature columns.
mnist_columns = ["label"] + ["features_{}".format(i) for i in range(28 ** 2)]
mnist_train = pd.read_csv("./sample_data/mnist_train_small.csv", names=mnist_columns)
mnist_test = pd.read_csv("./sample_data/mnist_test.csv", names=mnist_columns)

# Everything after the leading "label" column is a feature column.
mnist_arr_train = mnist_train[mnist_columns[1:]].values
mnist_arr_test = mnist_test[mnist_columns[1:]].values

# Instance-level feature matrices and labels stored under root_dir.
X_train = np.load(os.path.join(root_dir, "train_features.npy"))
X_test = np.load(os.path.join(root_dir, "test_features.npy"))
y_train = np.load(os.path.join(root_dir, "train_labels.npy"))
y_test = np.load(os.path.join(root_dir, "test_labels.npy"))

# Sanity check: the stored labels must line up with the labels in the CSV files.
print(np.all(y_train == mnist_train["label"].values))
print(np.all(y_test == mnist_test["label"].values))

# Bag definitions and bag-level labels stored under data_dir.
# NOTE(review): allow_pickle=True makes np.load unpickle object arrays, which can
# execute arbitrary code -- only acceptable because these files come from a trusted repo.
train_bags = np.load(os.path.join(data_dir, "train_bags.npy"), allow_pickle=True)
train_labels = np.load(os.path.join(data_dir, "train_labels.npy"), allow_pickle=True)
test_bags = np.load(os.path.join(data_dir, "test_bags.npy"), allow_pickle=True)
test_labels = np.load(os.path.join(data_dir, "test_labels.npy"), allow_pickle=True)


True
True


In [None]:
# Visualize distribution of labels in train and test sets


In [None]:
# Visualize distribution of instances in train and test sets


In [23]:
def visualize_bags(mnist_arr, bags, bag_labels, n_bags):
  """Plot a random sample of positive and negative bags as a grid of images.

  Every row of the grid is one sampled bag and every column one instance slot;
  bags with fewer instances than the longest sampled bag are padded with blank
  (all-zero) images. The bag label is written as the title of the middle column.

  Args:
    mnist_arr: indexable collection of flattened square images; entry k is the
      image of instance k.
    bags: NumPy array of bags, each bag being a sequence of indices into
      mnist_arr.
    bag_labels: NumPy array of bag labels (1 = positive, 0 = negative), aligned
      with bags.
    n_bags: number of bags to draw; n_bags // 2 positive bags and the remainder
      negative bags are sampled without replacement.

  Raises:
    ValueError: from np.random.choice when more bags of a class are requested
      than are available.

  Note:
    Sampling uses NumPy's global random generator, so the selection depends on
    the current np.random state.
  """
  pos_bags = np.argwhere(bag_labels == 1).flatten()
  neg_bags = np.argwhere(bag_labels == 0).flatten()
  n_pos = n_bags // 2
  n_neg = n_bags - n_pos
  sampled_indices = np.concatenate([
      np.random.choice(pos_bags, n_pos, replace=False),
      np.random.choice(neg_bags, n_neg, replace=False),
  ])
  sampled_bags = bags[sampled_indices]
  sampled_bag_labels = bag_labels[sampled_indices]
  max_instances_num = np.max([len(bag) for bag in sampled_bags])

  # squeeze=False keeps `axes` two-dimensional even for a single row or a single
  # column; with the default squeezing, axes[row, col] fails in those cases.
  # NOTE(review): figsize is (width, height), yet width here scales with the row
  # count and height with the column count. Left unchanged because the
  # subplots_adjust call below appears tuned to it -- confirm the intended layout.
  _, axes = plt.subplots(len(sampled_bags), max_instances_num,
                         squeeze=False,
                         figsize=(10 * len(sampled_bags),
                                  10 * max_instances_num))

  for row, (bag, bag_label) in enumerate(zip(sampled_bags, sampled_bag_labels)):
    # The image side length comes from this bag's first instance. An empty bag
    # has no instance to measure, so fall back to the first image in mnist_arr
    # instead of padding with an undefined size.
    reference_img = mnist_arr[bag[0]] if len(bag) > 0 else mnist_arr[0]
    w = int(np.sqrt(len(reference_img)))
    instance_size = (w, w)
    for col in range(max_instances_num):
      ax = axes[row, col]
      if col < len(bag):
        ax.imshow(mnist_arr[bag[col]].reshape(instance_size))
      else:
        # Pad with empty images if bag has fewer instances than max
        ax.imshow(np.zeros(instance_size))
      if col == max_instances_num // 2:
        ax.set_title('Bag Label: {}'.format(bag_label), fontsize=50)
  plt.subplots_adjust(bottom=0.1, top=0.3, hspace=0.2)


In [None]:
# Visualize instances in positive bags


In [None]:
import torch.nn.functional as F
import torch
from torch import nn
from torch.nn import BCELoss
from torch.optim import LBFGS, Adam
from torch.utils.data import Dataset, DataLoader

In [None]:
class LogisticRegressionMI(nn.Module):
  """Logistic regression for bags of instances.

  One linear layer followed by a sigmoid scores every instance; the scores of
  the instances listed in a bag are then reduced to a single bag-level value
  with either their maximum or their mean.
  """

  def __init__(self, n_dim, mode='max'):
    """Create the model.

    Args:
      n_dim: number of input features per instance.
      mode: how instance scores are pooled per bag, 'max' or 'mean'.

    Raises:
      ValueError: if mode is neither 'max' nor 'mean'.
    """
    super().__init__()
    if mode != 'max' and mode != 'mean':
      raise ValueError("Invalid mode {}, must be one of max or mean".format(mode))
    self.mode = mode
    self.encoder = nn.Linear(n_dim, 1)

  def _score_instances(self, x):
    """Return the sigmoid of the linear layer's output for every row of x."""
    return torch.sigmoid(self.encoder(x))

  def forward(self, x, indices):
    """Return one pooled score per bag.

    Args:
      x: instance feature tensor accepted by the linear layer.
      indices: iterable of index collections; each one selects the rows of x
        that belong to one bag.

    Returns:
      Tensor stacking the pooled (max or mean) score of each bag.
    """
    scores = self._score_instances(x)
    pool = torch.max if self.mode == 'max' else torch.mean
    return torch.stack([pool(scores[idx]) for idx in indices])

  def get_max_indices(self, x, indices):
    """Find, for every bag, the instance with the highest score.

    Args:
      x: instance feature tensor accepted by the linear layer.
      indices: iterable of index collections, one per bag.

    Returns:
      Tuple (pred, max_indices): pred concatenates the top score of each bag,
      and max_indices is a NumPy array holding, per bag, the entry of that
      bag's index collection at which the top score was found.
    """
    scores = self._score_instances(x)
    pred, max_indices = [], []
    for idx in indices:
      bag_scores = scores[idx]
      # argmax is taken over the flattened bag scores; with one score column
      # (assuming x is 2-D) this is the position of the instance within the bag.
      max_idx = torch.argmax(bag_scores)
      max_indices.append(idx[max_idx])
      pred.append(bag_scores[max_idx])
    return torch.cat(pred), np.array(max_indices)
    

In [4]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_curve, precision_recall_curve, auc
from scipy.stats import mode

def get_roc_auc(y_true, y_pred):
    """Return the area under the ROC curve of scores y_pred against labels y_true.

    The curve points come from roc_curve and the area from auc, with the false
    positive rate on the x axis and the true positive rate on the y axis.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(y_true, y_pred)
    return auc(false_pos_rate, true_pos_rate)


def get_pr_auc(y_true, y_pred):
    """Return the area under the precision-recall curve, treating label 1 as positive.

    The curve points come from precision_recall_curve and the area from auc,
    with recall on the x axis and precision on the y axis.
    """
    prec, rec, _thresholds = precision_recall_curve(y_true, y_pred, pos_label=1)
    return auc(rec, prec)


def get_accuracy(y_true, y_pred):
    """Return the balanced accuracy of predictions y_pred against labels y_true.

    Despite the generic name, this delegates to balanced_accuracy_score rather
    than plain accuracy.
    """
    score = balanced_accuracy_score(y_true, y_pred)
    return score

In [None]:
# logit_mi = LogisticRegressionMI(64, mode='mean')
# criterion = BCELoss()
# optimizer = LBFGS(logit_mi.parameters(), lr=0.001, max_iter=1000)
# logit_mi.train()


# def closure():
#     optimizer.zero_grad()
#     output = logit_mi(torch.Tensor(X_train), train_bags)
#     loss = criterion(output, torch.Tensor(train_labels))
#     loss.backward()
#     return loss

# optimizer.step(closure) 

# logit_mi.eval()
# with torch.no_grad():
#   y_pred_logit_mi_test = logit_mi(torch.Tensor(X_test), test_bags).detach().numpy()
#   print("ROC AUC: {}".format(get_roc_auc(test_labels, y_pred_logit_mi_test)))
#   print("PR AUC: {}".format(get_pr_auc(test_labels, y_pred_logit_mi_test)))
#   print("Accuracy Score: {}".format(accuracy_score(test_labels, y_pred_logit_mi_test >= 0.5)))
#   print("Balanced Accuracy Score: {}".format(balanced_accuracy_score(test_labels, y_pred_logit_mi_test >= 0.5)))

ROC AUC: 0.8444277901716553
PR AUC: 0.3773344528695057
Accuracy Score: 0.8806306306306306
Balanced Accuracy Score: 0.49974437627811863


In [14]:
import torch
from torch import nn
from torch.optim import Adam

# Seed torch's global generator so that layer initialisation is repeatable.
torch.random.manual_seed(0)

class nnMI(nn.Module):
  """Multiple-instance network: encode instances, pool per bag, decode the bag.

  The encoder is applied to every instance, the encoded instances belonging to
  a bag are pooled feature-wise (max or mean over the instances), and the
  decoder maps each pooled bag representation to the output.
  """

  def __init__(self, encoder, decoder, mode='max'):
    """Create the model.

    Args:
      encoder: module applied to the instance tensor.
      decoder: module applied to the stacked per-bag representations.
      mode: feature-wise pooling over a bag's instances, 'max' or 'mean'.

    Raises:
      ValueError: if mode is neither 'max' nor 'mean'.
    """
    super(nnMI, self).__init__()
    if mode not in ('max', 'mean'):
      raise ValueError("Invalid mode {}, must be one of max or mean".format(mode))
    self.mode = mode
    self.encoder = encoder
    self.decoder = decoder

  def forward(self, x, indices):
    """Return the decoder output for every bag.

    Args:
      x: instance tensor accepted by the encoder.
      indices: iterable of index collections; each one selects the rows of the
        encoded tensor that belong to one bag.

    Returns:
      The decoder's output for the stacked bag representations (one row per bag).
    """
    x = self.encoder(x)
    if self.mode == 'max':
      # torch.max over a dim returns (values, indices); only the values are pooled.
      x = torch.stack([torch.max(x[idx], dim=0).values for idx in indices])
    else:
      # torch.mean over a dim returns a plain tensor, so there is nothing to
      # unpack; taking `.values` here made 'mean' mode fail in torch.stack.
      x = torch.stack([torch.mean(x[idx], dim=0) for idx in indices])
    x = self.decoder(x)
    return x

In [21]:
device = 'cpu'

# Instance-level network: 64 input features -> 128 hidden units -> 64 outputs.
encoder = nn.Sequential(
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
)

# Bag-level network: maps the pooled 64-dimensional bag representation to a single output.
decoder = nn.Sequential(
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 1),
)

# Max-pooling multiple-instance model, moved to the selected device.
nn_mi = nnMI(encoder, decoder, mode='max')
nn_mi = nn_mi.to(device)


# Computing per-sample weights for BCEWithLogitsLoss.
# The loss is given one weight per data point, so we build a tensor of size (N,),
# where N is the number of bags in the training set and each entry is the weight
# that we want to assign to that particular bag.

classes, counts = np.unique(train_labels, return_counts=True)

# Inverse class frequency, rescaled so that the class weights sum to one.
inverse_counts = 1. / torch.tensor(counts, dtype=torch.float)
class_weights = inverse_counts / inverse_counts.sum()

# Every bag receives the weight of its own class (labels 0 and 1).
weights = torch.zeros(len(train_labels))
for class_id in (0, 1):
  weights[np.argwhere(train_labels == class_id)] = class_weights[class_id]

# Initialize loss function, optimizer and number of epochs


# Implement the torch training loop in here

# Score the test bags with the model in evaluation mode and without gradient tracking.
nn_mi.eval()
with torch.no_grad():
  test_inputs = torch.Tensor(X_test).to(device)
  test_logits = nn_mi(test_inputs, test_bags)
  # The model emits logits; the sigmoid turns them into probabilities in [0, 1].
  y_pred_nn_test = torch.sigmoid(test_logits).detach().cpu().numpy().flatten()

  # Threshold-free metrics use the probabilities directly.
  print("ROC AUC: {}".format(get_roc_auc(test_labels, y_pred_nn_test)))
  print("PR AUC: {}".format(get_pr_auc(test_labels, y_pred_nn_test)))

  # Accuracy metrics need hard decisions: predict positive at probability >= 0.5.
  y_pred_nn_test_hard = y_pred_nn_test >= 0.5
  print("Accuracy Score: {}".format(accuracy_score(test_labels, y_pred_nn_test_hard)))
  print("Balanced Accuracy Score: {}".format(balanced_accuracy_score(test_labels, y_pred_nn_test_hard)))


ROC AUC: 0.9306891770465391
PR AUC: 0.6986790905666337
Accuracy Score: 0.918018018018018
Balanced Accuracy Score: 0.8420710169176426


In [None]:
## Implement a dataset class here

# class MIDataSet(Dataset):

#   def __init__(self, X, bags, labels):


#   def __getitem__(self, idx):

  
#   def __len__(self):


# def collate_fn(batch):
#   X = torch.cat([sample[0] for sample in batch])
#   y = torch.cat([sample[1] for sample in batch])
#   indices = []
#   i = 0
#   for sample in batch:
#     n_instances = sample[2]
#     indices.append(np.arange(i, i + n_instances))
#     i += n_instances
#   return X, y, indices

# train_ds = MIDataSet(X_train, train_bags, train_labels)
# train_dl = DataLoader(train_ds, num_workers=10, batch_size=64, shuffle=True, collate_fn=collate_fn)

# test_ds = MIDataSet(X_test, test_bags, test_labels)
# test_dl = DataLoader(test_ds, num_workers=10, batch_size=64, shuffle=False, collate_fn=collate_fn)

In [22]:
# Implement a training loop just as before but with torch 