# Imports

In [None]:
!pip install colorama

In [None]:
import os
import math
import json
import random as rnd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
import matplotlib.pyplot as plt
import torchvision.utils as vision_utils
from PIL import Image
import torchvision
from colorama import Fore, Back, Style

from matplotlib.ticker import NullFormatter


# Use the GPU when one is visible; otherwise fall back to the CPU so that every
# later `.to(DEVICE)` call still works on machines without CUDA.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
!nvidia-smi

# Build MF-Dominoes dataset $\hat{\mathcal{D}}$

In [None]:
def plot_samples(dataset, nrow=13, figsize=(10,7)):
  """Show the first `nrow` images of `dataset` side by side in one figure.

  Args:
    dataset: either an object exposing a non-empty `.tensors` sequence (the
      first entry is taken as the image batch) or something that can be
      sliced directly and moved to the CPU, e.g. an image tensor.
    nrow: number of images to display, all on a single row.
    figsize: matplotlib figure size in inches.
  """
  # Look the attribute up explicitly instead of relying on bare `except:`
  # clauses, which would also hide unrelated errors and keyboard interrupts.
  tensors = getattr(dataset, 'tensors', None)
  X = tensors[0] if tensors else dataset

  plt.figure(figsize=figsize, dpi=130)
  grid_img = vision_utils.make_grid(X[:nrow].cpu(), nrow=nrow, normalize=True, padding=1)
  # make_grid returns channels-first data; imshow expects channels-last.
  plt.imshow(grid_img.permute(1, 2, 0), interpolation='nearest')
  # Hide tick marks and tick labels: the axes carry no information here.
  plt.tick_params(axis=u'both', which=u'both',length=0)
  ax = plt.gca()
  ax.xaxis.set_major_formatter(NullFormatter())
  ax.yaxis.set_major_formatter(NullFormatter())
  plt.show()


def keep_only_lbls(dataset, lbls):
  """Filter `dataset` down to the classes in `lbls` and re-index their labels.

  Args:
    dataset: iterable of `(x, y)` pairs, `x` being a tensor and `y` a hashable label.
    lbls: labels to keep; a kept label is replaced by its position in `lbls`.

  Returns:
    `(X, Y)`: the kept samples stacked along a new leading dimension, and the
    re-indexed labels as a float column tensor of shape `(num_kept, 1)`.
  """
  new_index = {lbl: pos for pos, lbl in enumerate(lbls)}
  kept = [(sample, new_index[target]) for sample, target in dataset if target in new_index]
  X = torch.stack([sample for sample, _ in kept])
  Y = torch.tensor([target for _, target in kept]).float().view(-1,1)
  return X, Y


# Only conversion to tensors is applied; there is no augmentation or normalization.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# Download (if needed) the official training splits of both datasets.
mnist_train = torchvision.datasets.MNIST('./data/mnist/', train=True, download=True, transform=transform)
fashm_train = torchvision.datasets.FashionMNIST('./data/FashionMNIST/', train=True, download=True, transform=transform)
# Carve each training split into 10000 samples reserved as the base for the OOD
# ("perturbation") sets, 45000 for training and 5000 for validation. The fixed seed (42)
# makes the partition reproducible. Note that `mnist_train` / `fashm_train` are rebound
# to the 45000-sample subsets from here on.
mnist_perturb_base, mnist_train, mnist_valid = random_split(mnist_train, [10000, 45000, 5000], generator=torch.Generator().manual_seed(42))
fashm_perturb_base, fashm_train, fashm_valid = random_split(fashm_train, [10000, 45000, 5000], generator=torch.Generator().manual_seed(42))

# Official test splits, kept separate from everything above.
mnist_test = torchvision.datasets.MNIST('./data/mnist/', train=False, download=True, transform=transform)
fashm_test = torchvision.datasets.FashionMNIST('./data/FashionMNIST/', train=False, download=True, transform=transform)

def build_mf_dataset(mnist_data, fashm_data, randomize_m=False, randomize_f=False):
  """Assemble a binary "domino" dataset from MNIST and Fashion-MNIST images.

  Each sample joins one MNIST image (the "top" part) and one Fashion-MNIST
  image (the "bottom" part) along dim 2. Label 0 pairs MNIST class 0 with
  Fashion class 4; label 1 pairs MNIST class 1 with Fashion class 3.

  Args:
    mnist_data: iterable of `(image, label)` pairs providing the top images.
    fashm_data: iterable of `(image, label)` pairs providing the bottom images.
    randomize_m: if True, permute the top images across samples so that they
      no longer line up with the labels.
    randomize_f: same as `randomize_m`, for the bottom images.

  Returns:
    A `TensorDataset` of `(image, label)` on `DEVICE`; labels are float
    tensors of shape `(N, 1)` and the samples are shuffled.
  """
  # Top half: MNIST zeros and ones, each class shuffled on its own.
  digits_0, _ = keep_only_lbls(mnist_data, lbls=[0])
  digits_1, _ = keep_only_lbls(mnist_data, lbls=[1])
  digits_0 = digits_0[torch.randperm(digits_0.shape[0])]
  digits_1 = digits_1[torch.randperm(digits_1.shape[0])]

  # Bottom half: Fashion-MNIST classes 4 and 3, each class shuffled on its own.
  fashion_4, _ = keep_only_lbls(fashm_data, lbls=[4])
  fashion_3, _ = keep_only_lbls(fashm_data, lbls=[3])
  fashion_4 = fashion_4[torch.randperm(fashion_4.shape[0])]
  fashion_3 = fashion_3[torch.randperm(fashion_3.shape[0])]

  # Truncate each pairing to the size of its smaller class.
  n_label_0 = min(digits_0.shape[0], fashion_4.shape[0])
  n_label_1 = min(digits_1.shape[0], fashion_3.shape[0])
  top_half = torch.cat([digits_0[:n_label_0], digits_1[:n_label_1]], dim=0)
  bottom_half = torch.cat([fashion_4[:n_label_0], fashion_3[:n_label_1]], dim=0)

  # Optionally decorrelate one half from the labels.
  if randomize_m:
    top_half = top_half[torch.randperm(top_half.shape[0])]
  if randomize_f:
    bottom_half = bottom_half[torch.randperm(bottom_half.shape[0])]

  images = torch.cat([top_half, bottom_half], dim=2)
  labels = torch.cat([torch.zeros(n_label_0), torch.ones(n_label_1)])

  # Final shuffle so that the two labels are interleaved.
  order = torch.randperm(images.shape[0])
  images = images[order]
  labels = labels[order].float().view(-1,1)
  return torch.utils.data.TensorDataset(images.to(DEVICE), labels.to(DEVICE))


# Clean (non-randomized) train / validation / test datasets and their loaders.
data_train = build_mf_dataset(mnist_data=mnist_train, fashm_data=fashm_train)
data_valid = build_mf_dataset(mnist_data=mnist_valid, fashm_data=fashm_valid)
data_test = build_mf_dataset(mnist_data=mnist_test, fashm_data=fashm_test)

train_dl = DataLoader(data_train, batch_size=256, shuffle=True)
valid_dl = DataLoader(data_valid, batch_size=1024, shuffle=False)
test_dl = DataLoader(data_test, batch_size=1024, shuffle=False)


# Test / validation datasets whose MNIST half is randomized.
data_test_rm = build_mf_dataset(mnist_data=mnist_test, fashm_data=fashm_test, randomize_m=True)
data_valid_rm = build_mf_dataset(mnist_data=mnist_valid, fashm_data=fashm_valid, randomize_m=True)

test_rm_dl = DataLoader(data_test_rm, batch_size=1024, shuffle=False)
valid_rm_dl = DataLoader(data_valid_rm, batch_size=1024, shuffle=False)

# Test / validation datasets whose Fashion-MNIST half is randomized.
data_test_rf = build_mf_dataset(mnist_data=mnist_test, fashm_data=fashm_test, randomize_f=True)
data_valid_rf = build_mf_dataset(mnist_data=mnist_valid, fashm_data=fashm_valid, randomize_f=True)

test_rf_dl = DataLoader(data_test_rf, batch_size=1024, shuffle=False)
valid_rf_dl = DataLoader(data_valid_rf, batch_size=1024, shuffle=False)


# Report dataset sizes. The second randomized variant shuffles the
# Fashion-MNIST half, so it is labelled accordingly.
print(f"Train length: {len(train_dl.dataset)}")
print(f"Test length: {len(test_dl.dataset)}")
print(f"Test length randomized mnist: {len(test_rm_dl.dataset)}")
print(f"Test length randomized f-mnist: {len(test_rf_dl.dataset)}")

# Visual sanity check: one row of samples per dataset variant.
print("Non-randomized train dataset:")
plot_samples(data_train)

print("Non-randomized test dataset:")
plot_samples(data_test)

print("MNIST-randomized test dataset:")
plot_samples(data_test_rm)

print("F-MNIST-randomized test dataset:")
plot_samples(data_test_rf)

# Utils

In [None]:
@torch.no_grad()
def get_acc(model, dl):
  """Return the binary classification accuracy of `model` over `dl`.

  A sample counts as class 1 when the sigmoid of the model output exceeds 0.5.

  Args:
    model: module mapping a batch `X` to logits comparable element-wise with `y`.
    dl: iterable of `(X, y)` batches; it must yield at least one batch.

  Returns:
    The fraction of correct predictions as a Python float.

  The model is evaluated in eval mode and afterwards put back into whichever
  mode (train or eval) it was in when the function was called, even if the
  evaluation raises.
  """
  was_training = model.training
  model.eval()
  try:
    hits = [(torch.sigmoid(model(X)) > 0.5) == y for X, y in dl]
  finally:
    # Restore the caller's mode rather than forcing train mode.
    model.train(was_training)
  hits = torch.cat(hits)
  return (torch.sum(hits)/len(hits)).item()


def dl_to_sampler(dl):
  """Turn the iterable `dl` into a zero-argument function yielding batches forever.

  The returned function hands out the next item of `dl` on every call; once
  `dl` is exhausted, a fresh iterator over it is started.
  """
  exhausted = object()  # private marker, cannot collide with a real batch
  batches = iter(dl)

  def sample():
    nonlocal batches
    batch = next(batches, exhausted)
    if batch is exhausted:
      # End of a pass: start over and take the first batch of the new pass.
      batches = iter(dl)
      batch = next(batches)
    return batch

  return sample


def print_stats(stats):
  """Plot the training curves recorded for each model in five side-by-side panels.

  Args:
    stats: dict mapping a model id (a string starting with 'm') to a dict
      holding, under the keys 'loss', 'adv-loss', 'acc', 'rm-acc' and
      'rf-acc', a list of `(iteration, value)` pairs. Entries whose key does
      not start with 'm' are ignored.
  """
  # (panel title, key inside the per-model stats), in display order.
  panels = (
      ("ERM loss", 'loss'),
      ("Adv Loss", 'adv-loss'),
      ("Acc", 'acc'),
      ("Randomized MNIST Acc", 'rm-acc'),
      ("Randomized F-MNIST Acc", 'rf-acc'),
  )

  fig, axes = plt.subplots(1, len(panels), figsize=(16,3), dpi=110)
  for ax, (title, _) in zip(axes, panels):
    ax.grid()
    ax.set_title(title)
    ax.set_xlabel("iterations")

  plotted = False
  for m_id, m_stats in stats.items():
    if not m_id.startswith('m'):
      continue
    # All series of one model are logged at the same iterations.
    itrs = [x[0] for x in m_stats['loss']]
    for ax, (_, key) in zip(axes, panels):
      ax.plot(itrs, [x[1] for x in m_stats[key]], label=m_id)
    plotted = True

  # The three accuracy panels share a fixed y-range so runs can be compared.
  for ax in axes[2:]:
    ax.set_ylim(0.45, 1.05)

  # The curves carry labels; draw the legend so they can be told apart.
  if plotted:
    for ax in axes:
      ax.legend()

# Model

In [None]:
class LeNet(nn.Module):
  """LeNet-style CNN: two conv + max-pool stages followed by three linear layers.

  The first linear layer expects 960 flattened features (16 channels x 12 x 5),
  which is what a single-channel 56x28 input produces after the two
  convolution / pooling stages.
  """

  def __init__(self, num_classes=10) -> None:
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)
    self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
    self.fc1 = nn.Linear(in_features=960, out_features=120)
    self.fc2 = nn.Linear(in_features=120, out_features=84)
    self.fc3 = nn.Linear(in_features=84, out_features=num_classes)
    self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Return the raw `num_classes` outputs (no final activation) for batch `x`."""
    # Feature extractor: conv -> ReLU -> 2x2 max-pool, twice.
    for conv in (self.conv1, self.conv2):
      x = self.maxPool(F.relu(conv(x)))
    x = x.flatten(start_dim=1)
    # Classifier head: two hidden ReLU layers, then a linear output layer.
    for hidden in (self.fc1, self.fc2):
      x = F.relu(hidden(x))
    return self.fc3(x)

    
def set_train_mode(models):
  """Call `.train()` on every element of `models`."""
  for net in models:
    net.train()


def set_eval_mode(models):
  """Call `.eval()` on every element of `models`."""
  for net in models:
    net.eval()

# Training code

In [None]:
def sequential_train(num_models, train_dl, valid_dl, valid_rm_dl, valid_rf_dl, test_dl, test_rm_dl, test_rf_dl, 
                     perturb_dl, alpha=10, max_epoch=100, use_diversity_reg=True, reg_model_weights=None, lr=0.01, weight_decay=1e-6):
  """Train `num_models` binary LeNet classifiers one after another.

  Every model minimises a binary cross-entropy loss on `train_dl`. From the
  second model on (and only if `use_diversity_reg` is True) a disagreement
  term computed on batches from `perturb_dl` is added: for each previously
  trained model with probability `p` and the current model with probability
  `q`, the term is `-log((1 - p) * q + p * (1 - q) + 1e-7)` averaged over the
  batch, i.e. it is small when the two models predict different classes.

  Args:
    num_models: number of models to train sequentially.
    train_dl: loader of `(x, y)` training batches.
    valid_dl, valid_rm_dl, valid_rf_dl: loaders evaluated with `get_acc`
      every 200 iterations (logged as 'acc', 'rm-acc', 'rf-acc').
    test_dl, test_rm_dl, test_rf_dl: loaders evaluated once per model after
      training (stored as 'test-acc', 'test-rm-acc', 'test-rf-acc').
    perturb_dl: loader yielding 1-tuples `(x_tilde,)` of unlabeled inputs.
    alpha: multiplier of the disagreement term in the total loss.
    max_epoch: number of passes over `train_dl` for each model.
    use_diversity_reg: disable to train every model with the ERM loss only.
    reg_model_weights: weight of each earlier model inside the disagreement
      term (indexed by model position); defaults to 1.0 for all. The term
      is normalised by the sum of the weights in use.
    lr, weight_decay: AdamW hyper-parameters.

  Returns:
    Dict keyed by "m1", "m2", ... containing the logged `(iteration, value)`
    curves and the final test accuracies of each model.
  """
  models = [LeNet(num_classes=1).to(DEVICE) for _ in range(num_models)]
  
  stats = {f"m{i+1}": {"acc": [], "rm-acc": [], "rf-acc": [], "loss": [], "adv-loss": []} for i in range(len(models))}

  if reg_model_weights is None:
    reg_model_weights = [1.0 for _ in range(num_models)]

  for m_idx, m in enumerate(models):
    m_stats = stats[f"m{m_idx+1}"]

    opt = torch.optim.AdamW(m.parameters(), lr=lr, weight_decay=weight_decay)
    perturb_sampler = dl_to_sampler(perturb_dl)

    for epoch in range(max_epoch):
      for itr, (x, y) in enumerate(train_dl):
        (x_tilde,) = perturb_sampler()
        erm_loss = F.binary_cross_entropy_with_logits(m(x), y)
        
        if use_diversity_reg and m_idx != 0:
          # Earlier models only provide targets, so no gradient flows into them.
          with torch.no_grad():
            ps = [torch.sigmoid(m_(x_tilde)) for m_ in models[:m_idx]]
          psm = torch.sigmoid(m(x_tilde))
          adv_terms = []
          for i, p in enumerate(ps):
            al = - ((1.-p) * psm + p * (1.-psm) + 1e-7).log().mean()
            adv_terms.append(al*reg_model_weights[i])
          adv_loss = sum(adv_terms)/sum(reg_model_weights[:len(adv_terms)])
        else:
          # No regulariser: use a plain zero. It must not be divided by any
          # regulariser weight, otherwise a zero weight would turn the loss into NaN.
          adv_loss = torch.zeros((), device=DEVICE)

        loss = erm_loss + alpha * adv_loss

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Global iteration counter across epochs; log every 200 iterations.
        itr_ = itr + epoch * len(train_dl)
        if itr_ % 200 == 0:
          print_str = f"[m{m_idx+1}] {epoch}/{itr_} [train] loss: {erm_loss.item():.2f} adv-loss: {adv_loss.item():.2f} "
          m_stats["loss"].append((itr_, erm_loss.item()))
          m_stats["adv-loss"].append((itr_, adv_loss.item()))
          acc = get_acc(m, valid_dl)
          acc_rm = get_acc(m, valid_rm_dl)
          acc_rf = get_acc(m, valid_rf_dl)
          m_stats["acc"].append((itr_, acc))
          m_stats["rm-acc"].append((itr_, acc_rm))
          m_stats["rf-acc"].append((itr_, acc_rf))
          print_str += f" acc: {acc:.2f}, {Fore.BLUE} rm-acc: {acc_rm:.2f} {Style.RESET_ALL}"

          print(print_str)

    test_acc = get_acc(m, test_dl)
    test_rm_acc = get_acc(m, test_rm_dl)
    test_rf_acc = get_acc(m, test_rf_dl)
    m_stats["test-acc"] = test_acc
    m_stats["test-rm-acc"] = test_rm_acc
    m_stats["test-rf-acc"] = test_rf_acc
    print(f"[m{m_idx+1}] [test] acc: {test_acc:.3f}, rm-acc: {test_rm_acc:.3f}, rf-acc: {test_rf_acc:.3f}")

  return stats

# Experiments with $\mathcal{D}_\text{ood}^{(1)}$

In [None]:
# The OOD set is built from the held-out "perturbation" subsets. Note that this
# rebinds `mnist_test` / `fashm_test`; the test datasets themselves were built earlier.
mnist_test = mnist_perturb_base
fashm_test = fashm_perturb_base

X_m_test_0, _ = keep_only_lbls(mnist_test, lbls=[0])
X_m_test_1, _ = keep_only_lbls(mnist_test, lbls=[1])

# NOTE: despite their suffixes, `X_c_test_1` holds Fashion-MNIST label 4 and
# `X_c_test_9` holds Fashion-MNIST label 3.
X_c_test_1, _ = keep_only_lbls(fashm_test, lbls=[4])
X_c_test_9, _ = keep_only_lbls(fashm_test, lbls=[3])
X_c_test_1 = X_c_test_1[torch.randperm(X_c_test_1.shape[0])]
X_c_test_9 = X_c_test_9[torch.randperm(X_c_test_9.shape[0])]

# Pair MNIST 0 with Fashion label 3 and MNIST 1 with Fashion label 4, i.e. the
# opposite of the pairing used by `build_mf_dataset`.
min_09 = min(X_m_test_0.shape[0], X_c_test_9.shape[0])
min_11 = min(X_m_test_1.shape[0], X_c_test_1.shape[0])
X_perturb_09 = torch.cat([X_m_test_0[:min_09], X_c_test_9[:min_09]], dim=2)
X_perturb_11 = torch.cat([X_m_test_1[:min_11], X_c_test_1[:min_11]], dim=2)

X_perturb = torch.cat([X_perturb_09, X_perturb_11], dim=0)
X_perturb = X_perturb[torch.randperm(X_perturb.shape[0])]

# Unlabeled dataset: each item is a 1-tuple holding only the image.
data_perturb = torch.utils.data.TensorDataset(X_perturb.to(DEVICE))
perturb_dl = DataLoader(data_perturb, batch_size=256, shuffle=True)

print(f"OOD dataset size: {len(perturb_dl.dataset)}")

print("OOD dataset:")
plot_samples(data_perturb)

In [None]:
# Five independent runs of two sequentially trained models, using the
# `perturb_dl` currently in scope; each run is stored and plotted.
all_stats = []
for _ in range(5):
  stats = sequential_train(
      num_models=2,
      train_dl=train_dl, valid_dl=valid_dl, valid_rm_dl=valid_rm_dl, valid_rf_dl=valid_rf_dl,
      test_dl=test_dl, test_rm_dl=test_rm_dl, test_rf_dl=test_rf_dl,
      perturb_dl=perturb_dl,
      alpha=0.2, max_epoch=200, lr=0.0001,
  )
  all_stats.append(stats)
  print_stats(stats)

# Experiments with $\mathcal{D}_\text{ood}^{(2)}$

In [None]:
# The OOD set is built from the held-out "perturbation" subsets (this rebinds
# `mnist_test` / `fashm_test`).
mnist_test = mnist_perturb_base
fashm_test = fashm_perturb_base

# Keep all ten classes of both datasets and shuffle each side independently,
# so the two parts of each OOD sample are paired at random.
X_m_test, _ = keep_only_lbls(mnist_test, lbls=list(range(10)))
X_m_test = X_m_test[torch.randperm(X_m_test.shape[0])]

X_c_test, _ = keep_only_lbls(fashm_test, lbls=list(range(10)))
X_c_test = X_c_test[torch.randperm(X_c_test.shape[0])]

min_l = min(X_m_test.shape[0], X_c_test.shape[0])
X_perturb = torch.cat([X_m_test[:min_l], X_c_test[:min_l]], dim=2)

# Unlabeled dataset: each item is a 1-tuple holding only the image.
data_perturb = torch.utils.data.TensorDataset(X_perturb.to(DEVICE))
perturb_dl = DataLoader(data_perturb, batch_size=256, shuffle=True)

print(f"OOD dataset size: {len(perturb_dl.dataset)}")

print("OOD dataset:")
plot_samples(data_perturb)

In [None]:
# Five independent runs of two sequentially trained models, using the
# `perturb_dl` currently in scope; each run is stored and plotted.
all_stats = []
for _ in range(5):
  stats = sequential_train(
      num_models=2,
      train_dl=train_dl, valid_dl=valid_dl, valid_rm_dl=valid_rm_dl, valid_rf_dl=valid_rf_dl,
      test_dl=test_dl, test_rm_dl=test_rm_dl, test_rf_dl=test_rf_dl,
      perturb_dl=perturb_dl,
      alpha=1, max_epoch=200, lr=0.0001,
  )
  all_stats.append(stats)
  print_stats(stats)

# Experiments with $\mathcal{D}_\text{ood}^{(3)}$

In [None]:
# The OOD set is built from the held-out "perturbation" subsets (this rebinds
# `mnist_test` / `fashm_test`).
mnist_test = mnist_perturb_base
fashm_test = fashm_perturb_base

X_m_test_0, _ = keep_only_lbls(mnist_test, lbls=[0])
X_m_test_1, _ = keep_only_lbls(mnist_test, lbls=[1])
X_m_test_0 = X_m_test_0[torch.randperm(X_m_test_0.shape[0])]
X_m_test_1 = X_m_test_1[torch.randperm(X_m_test_1.shape[0])]

# Fashion-MNIST images of every class except 3 and 4, the two classes that
# `build_mf_dataset` uses.
X_c_test_ood, _ = keep_only_lbls(fashm_test, lbls=[0,1,2,5,6,7,8,9])
X_c_test_ood = X_c_test_ood[torch.randperm(X_c_test_ood.shape[0])]

# Repeat each digit set K times so the many more Fashion images get used.
K = int((len(X_c_test_ood) / len(X_m_test_0)) // 2)
X_m_test_0 = torch.cat([X_m_test_0] * K, dim=0)
X_m_test_1 = torch.cat([X_m_test_1] * K, dim=0)

# Zeros take Fashion images from the head of the shuffled pool, ones from its
# tail; the two slices overlap if their combined length exceeds the pool size.
fashion_for_0 = X_c_test_ood[:len(X_m_test_0)]
fashion_for_1 = X_c_test_ood[-len(X_m_test_1):]
X_perturb_0ood = torch.cat([X_m_test_0, fashion_for_0], dim=2)
X_perturb_1ood = torch.cat([X_m_test_1, fashion_for_1], dim=2)

X_perturb = torch.cat([X_perturb_0ood, X_perturb_1ood], dim=0)
X_perturb = X_perturb[torch.randperm(X_perturb.shape[0])]

# Unlabeled dataset: each item is a 1-tuple holding only the image.
data_perturb = torch.utils.data.TensorDataset(X_perturb.to(DEVICE))
perturb_dl = DataLoader(data_perturb, batch_size=256, shuffle=True)

print(f"OOD dataset size: {len(perturb_dl.dataset)}")

print("OOD dataset:")
plot_samples(data_perturb)

In [None]:
# Five independent runs of two sequentially trained models, using the
# `perturb_dl` currently in scope; each run is stored and plotted.
all_stats = []
for _ in range(5):
  stats = sequential_train(
      num_models=2,
      train_dl=train_dl, valid_dl=valid_dl, valid_rm_dl=valid_rm_dl, valid_rf_dl=valid_rf_dl,
      test_dl=test_dl, test_rm_dl=test_rm_dl, test_rf_dl=test_rf_dl,
      perturb_dl=perturb_dl,
      alpha=1.0, max_epoch=200, lr=0.0001,
  )
  all_stats.append(stats)
  print_stats(stats)