In [1]:
!git clone https://github.com/chrishendra93/MI_Workshop


Cloning into 'MI_Workshop'...
remote: Enumerating objects: 74, done.[K
remote: Counting objects: 100% (74/74), done.[K
remote: Compressing objects: 100% (57/57), done.[K
remote: Total 74 (delta 30), reused 46 (delta 16), pack-reused 0[K
Unpacking objects: 100% (74/74), done.


In [32]:
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm

In [33]:
# Setting random seed so that everything is replicable

np.random.seed(0)
torch.random.manual_seed(0)

<torch._C.Generator at 0x7f890acbba10>

In [34]:
root_dir = "/content/MI_Workshop/mnist_clr"
data_dir = "/content/MI_Workshop/lab_3"

mnist_columns = ["label"] + ["features_{}".format(i) for i in range(28 ** 2)]
mnist_train = pd.read_csv("./sample_data/mnist_train_small.csv", names=mnist_columns)
mnist_test = pd.read_csv("./sample_data/mnist_test.csv", names=mnist_columns)
mnist_arr_train = mnist_train[["features_{}".format(i) for i in range(28 ** 2)]].values
mnist_arr_test = mnist_test[["features_{}".format(i) for i in range(28 ** 2)]].values

X_train = np.load(os.path.join(root_dir, "train_features.npy"))
X_test = np.load(os.path.join(root_dir, "test_features.npy"))
y_train = np.load(os.path.join(root_dir, "train_labels.npy"))
y_test = np.load(os.path.join(root_dir, "test_labels.npy"))

print(np.all(y_train == mnist_train["label"].values))
print(np.all(y_test == mnist_test["label"].values))

train_bags, train_labels = np.load(os.path.join(data_dir, "train_bags.npy"), allow_pickle=True), np.load(os.path.join(data_dir, "train_labels.npy"),allow_pickle=True)
test_bags, test_labels = np.load(os.path.join(data_dir, "test_bags.npy"), allow_pickle=True), np.load(os.path.join(data_dir, "test_labels.npy"), allow_pickle=True)


True
True


In [35]:
# Visualize distribution of labels in train and test sets


In [36]:
# Visualize distribution of instances in train and test sets


In [37]:
import torch.nn.functional as F
import torch
from torch import nn
from torch.nn import BCELoss
from torch.optim import LBFGS, Adam
from torch.utils.data import Dataset, DataLoader

In [38]:
class LogisticRegressionMI(nn.Module):

  def __init__(self, n_dim, mode='max'):
    super(LogisticRegressionMI, self).__init__()
    if mode not in ('max', 'mean'):
      raise ValueError("Invalid mode {}, must be one of max or mean".format(mode))
    self.mode = mode
    self.encoder = nn.Linear(n_dim, 1)
    
  def forward(self, x, indices):
    x = self.encoder(x)
    x = torch.sigmoid(x)
    if self.mode == 'max':
      x = torch.stack([torch.max(x[idx]) for idx in indices])
    else:
      x = torch.stack([torch.mean(x[idx]) for idx in indices])
    return x
  
  def get_max_indices(self, x, indices):
    x = self.encoder(x)
    x = torch.sigmoid(x)
    pred, max_indices = [], []
    for idx in indices:
      max_idx = torch.argmax(x[idx])
      max_indices.append(idx[max_idx])
      pred.append(x[idx][max_idx])
    return torch.cat(pred), np.array(max_indices)
    

In [39]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_curve, precision_recall_curve, auc
from scipy.stats import mode

def get_roc_auc(y_true, y_pred):
    fpr, tpr, _  = roc_curve(y_true, y_pred)
    roc_auc = auc(fpr, tpr)
    return roc_auc


def get_pr_auc(y_true, y_pred):
    precision, recall, _ = precision_recall_curve(y_true, y_pred, pos_label=1)
    pr_auc = auc(recall, precision)
    return pr_auc


def get_accuracy(y_true, y_pred):
    return balanced_accuracy_score(y_true, y_pred)

In [41]:
logit_mi = LogisticRegressionMI(64, mode='mean')
criterion = BCELoss()
optimizer = LBFGS(logit_mi.parameters(), lr=0.001, max_iter=1000)
logit_mi.train()


def closure():
    optimizer.zero_grad()
    output = logit_mi(torch.Tensor(X_train), train_bags)
    loss = criterion(output, torch.Tensor(train_labels))
    loss.backward()
    return loss

optimizer.step(closure) 

logit_mi.eval()
with torch.no_grad():
  y_pred_logit_mi_test = logit_mi(torch.Tensor(X_test), test_bags).detach().numpy()
  print("ROC AUC: {}".format(get_roc_auc(test_labels, y_pred_logit_mi_test)))
  print("PR AUC: {}".format(get_pr_auc(test_labels, y_pred_logit_mi_test)))
  print("Accuracy Score: {}".format(accuracy_score(test_labels, y_pred_logit_mi_test >= 0.5)))
  print("Balanced Accuracy Score: {}".format(balanced_accuracy_score(test_labels, y_pred_logit_mi_test >= 0.5)))

ROC AUC: 0.8444277901716553
PR AUC: 0.3773344528695057
Accuracy Score: 0.8806306306306306
Balanced Accuracy Score: 0.49974437627811863


In [42]:
logit_mi = LogisticRegressionMI(64, mode='max')

classes, counts = np.unique(train_labels, return_counts=True)
class_weights = 1./torch.tensor(counts, dtype=torch.float)
class_weights = class_weights / class_weights.sum()
weights = torch.zeros(len(train_labels))
weights[np.argwhere(train_labels == 0)] = class_weights[0] / counts[0]
weights[np.argwhere(train_labels == 1)] = class_weights[1] / counts[1]

criterion = BCELoss(weight=weights)

logit_mi.train()


optimizer = Adam(logit_mi.parameters(), lr=4e-3)
n_epochs = 100

for i in range(n_epochs):
  optimizer.zero_grad()
  output = logit_mi(torch.Tensor(X_train), train_bags)
  loss = criterion(output, torch.Tensor(train_labels))
  loss.backward()
  optimizer.step() 

logit_mi.eval()
with torch.no_grad():
  y_pred_logit_mi_test = logit_mi(torch.Tensor(X_test), test_bags).detach().numpy()
  print("ROC AUC: {}".format(get_roc_auc(test_labels, y_pred_logit_mi_test)))
  print("PR AUC: {}".format(get_pr_auc(test_labels, y_pred_logit_mi_test)))
  print("Accuracy Score: {}".format(accuracy_score(test_labels, y_pred_logit_mi_test >= 0.5)))
  print("Balanced Accuracy Score: {}".format(balanced_accuracy_score(test_labels, y_pred_logit_mi_test >= 0.5)))

ROC AUC: 0.7338860228047344
PR AUC: 0.23332103380168645
Accuracy Score: 0.11891891891891893
Balanced Accuracy Score: 0.5


In [None]:
class nnMI(nn.Module):

  def __init__(self, encoder, mode='max'):
    super(nnMI, self).__init__()
    if mode not in ('max', 'mean'):
      raise ValueError("Invalid mode {}, must be one of max or mean".format(mode))
    self.mode = mode
    self.encoder = encoder
    
  def forward(self, x, indices):
    x = self.encoder(x)
    if self.mode == 'max':
      x = torch.stack([torch.max(x[idx]) for idx in indices])
    else:
      x = torch.stack([torch.mean(x[idx]) for idx in indices])
    return x


In [43]:
device = 'cuda'

nn_mi = nnMI(nn.Sequential(*[nn.Linear(64, 128), nn.ReLU(),
                             nn.Linear(128, 64), nn.ReLU(),
                             nn.Linear(64, 1)]), mode='max').to(device)

classes, counts = np.unique(train_labels, return_counts=True)
class_weights = 1 /torch.tensor(counts, dtype=torch.float)
class_weights = class_weights / class_weights.sum()

weights = torch.zeros(len(train_labels))
weights[np.argwhere(train_labels == 0)] = class_weights[0]
weights[np.argwhere(train_labels == 1)] = class_weights[1]

criterion = nn.BCEWithLogitsLoss(weight=weights).to(device)

optimizer = Adam(nn_mi.parameters(), lr=4e-5)
n_epochs = 100

for i in range(n_epochs):
  nn_mi.train()
  optimizer.zero_grad()
  output = nn_mi(torch.Tensor(X_train).to(device), train_bags)
  loss = criterion(output, torch.Tensor(train_labels).to(device))
  loss.backward()
  optimizer.step() 

nn_mi.eval()
with torch.no_grad():
  y_pred_nn_test = torch.sigmoid(nn_mi(torch.Tensor(X_test).to(device), test_bags)).detach().cpu().numpy().flatten()
  print("ROC AUC: {}".format(get_roc_auc(test_labels, y_pred_nn_test)))
  print("PR AUC: {}".format(get_pr_auc(test_labels, y_pred_nn_test)))
  print("Accuracy Score: {}".format(accuracy_score(test_labels, y_pred_nn_test >= 0.5)))
  print("Balanced Accuracy Score: {}".format(balanced_accuracy_score(test_labels, y_pred_nn_test >= 0.5)))


ROC AUC: 0.7740915675156472
PR AUC: 0.2557849172870432
Accuracy Score: 0.11891891891891893
Balanced Accuracy Score: 0.5


In [None]:
class nnMIMax(nn.Module):

  def __init__(self, encoder, decoder, mode='max'):
    super(nnMIMax, self).__init__()
    if mode not in ('max', 'mean'):
      raise ValueError("Invalid mode {}, must be one of max or mean".format(mode))
    self.mode = mode
    self.encoder = encoder
    self.decoder = decoder
    
  def forward(self, x, indices):
    x = self.encoder(x)
    if self.mode == 'max':
      x = torch.stack([torch.max(x[idx], axis=0).values for idx in indices])
    else:
      x = torch.stack([torch.mean(x[idx], axis=0).values for idx in indices])
    x = self.decoder(x)
    return x

In [44]:
device = 'cuda'

nn_mi = nnMIMax(nn.Sequential(*[nn.Linear(64, 128), nn.ReLU(),
                               nn.Linear(128, 64), nn.ReLU()]),
                nn.Sequential(*[nn.Linear(64, 128), nn.ReLU(), 
                                nn.Linear(128, 64), nn.ReLU(),
                                nn.Linear(64, 1)]), mode='max').to(device)

classes, counts = np.unique(train_labels, return_counts=True)
class_weights = 1./torch.tensor(counts, dtype=torch.float)
class_weights = class_weights / class_weights.sum()

weights = torch.zeros(len(train_labels))
weights[np.argwhere(train_labels == 0)] = class_weights[0]
weights[np.argwhere(train_labels == 1)] = class_weights[1]

criterion = nn.BCEWithLogitsLoss(weight=weights).to(device)

optimizer = Adam(nn_mi.parameters(), lr=4e-3)
n_epochs = 100

for i in range(n_epochs):
  nn_mi.train()
  optimizer.zero_grad()
  output = nn_mi(torch.Tensor(X_train).to(device), train_bags)
  loss = criterion(output.flatten(), torch.Tensor(train_labels).to(device))
  loss.backward()
  optimizer.step() 

nn_mi.eval()
with torch.no_grad():
  y_pred_nn_test = torch.sigmoid(nn_mi(torch.Tensor(X_test).to(device), test_bags)).detach().cpu().numpy().flatten()
  print("ROC AUC: {}".format(get_roc_auc(test_labels, y_pred_nn_test)))
  print("PR AUC: {}".format(get_pr_auc(test_labels, y_pred_nn_test)))
  print("Accuracy Score: {}".format(accuracy_score(test_labels, y_pred_nn_test >= 0.5)))
  print("Balanced Accuracy Score: {}".format(balanced_accuracy_score(test_labels, y_pred_nn_test >= 0.5)))


ROC AUC: 0.9342233686558841
PR AUC: 0.7028383752189724
Accuracy Score: 0.9198198198198199
Balanced Accuracy Score: 0.8283486707566463


In [None]:
class nnMIAttention(nn.Module):

  def __init__(self, encoder, decoder, mode='max'):
    super(nnMIMax, self).__init__()
    if mode not in ('max', 'mean'):
      raise ValueError("Invalid mode {}, must be one of max or mean".format(mode))
    self.mode = mode
    self.encoder = encoder
    self.decoder = decoder
    
  def forward(self, x, indices):
    x = self.encoder(x)
    if self.mode == 'max':
      x = torch.stack([torch.max(x[idx], axis=0).values for idx in indices])
    else:
      x = torch.stack([torch.mean(x[idx], axis=0).values for idx in indices])
    x = self.decoder(x)
    return x