# Imports

In [None]:
!pip install colorama

In [None]:
import os
import copy
import math
import json
import random as rnd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
import matplotlib.pyplot as plt
import pandas as  pd
import torchvision.utils as vision_utils
from PIL import Image
import torchvision
from colorama import Fore, Back, Style

from matplotlib.ticker import NullFormatter


# Use the GPU when one is visible to PyTorch; otherwise fall back to the CPU so
# that the `.to(DEVICE)` calls in this notebook still work on GPU-less machines.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
!nvidia-smi

# Utils

In [None]:
@torch.no_grad()
def get_acc(model, dl):
  """Return the binary accuracy of `model` over every batch yielded by `dl`.

  Each batch is an `(X, y)` pair. The model output is treated as a logit:
  it is passed through a sigmoid and thresholded at 0.5, and the result is
  compared element-wise with `y`.
  NOTE(review): `y` is assumed to have the same shape as `model(X)`; a
  different shape would broadcast in the comparison - confirm with callers.

  The model is switched to eval mode while scoring and is put back into
  whatever mode (train or eval) it was in before the call, even if scoring
  raises.

  Returns:
    A Python float: the fraction of element-wise matches.
  """
  was_training = model.training
  model.eval()
  try:
    hits = [(torch.sigmoid(model(X)) > 0.5) == y for X, y in dl]
    return torch.cat(hits).float().mean().item()
  finally:
    # Restore the caller's mode instead of unconditionally forcing train mode.
    model.train(was_training)


def plot_samples(X):
  """Display the first 13 images of the tensor `X` in a single row.

  The images are tiled with torchvision's `make_grid` (normalized, 1px
  padding) and shown with matplotlib with tick marks and tick labels hidden.
  NOTE: a second `plot_samples` defined further down in this file rebinds
  the name once that later definition is executed.
  """
  plt.figure(figsize=(8,3), dpi=130)
  first_images = X[:13].cpu()
  grid = vision_utils.make_grid(first_images, nrow=13, normalize=True, padding=1)
  # make_grid returns channels-first data; imshow wants channels last.
  plt.imshow(grid.permute(1, 2, 0), interpolation='nearest')
  plt.tick_params(axis=u'both', which=u'both',length=0)
  axes = plt.gca()
  for axis in (axes.xaxis, axes.yaxis):
    axis.set_major_formatter(NullFormatter())
  plt.show()


def dl_to_sampler(dl):
  """Wrap the iterable `dl` into a zero-argument function that yields batches forever.

  The returned `sample()` hands out the next item of `dl`; once the current
  pass is exhausted it starts a new pass with `iter(dl)` and returns that
  pass's first item. The first iterator is created immediately, when
  `dl_to_sampler` is called.
  """
  exhausted = object()      # sentinel: cannot collide with a real batch
  current = [iter(dl)]      # one-slot holder so the closure can swap iterators

  def sample():
    batch = next(current[0], exhausted)
    if batch is exhausted:
      current[0] = iter(dl)
      batch = next(current[0])
    return batch

  return sample

In [None]:
def print_stats(stats):
  """Plot the training curves stored in `stats` as five side-by-side panels.

  `stats` maps an id to a dict of curves; only ids whose first character is
  'm' are drawn. Each curve is a list of `(iteration, value)` pairs stored
  under the keys 'loss', 'adv-loss', 'acc', 'r0/1-acc' and 'r7/9-acc'. The
  x positions of every panel come from the 'loss' curve.
  """
  # (panel title, key of the curve drawn in that panel), left to right.
  panels = (
    ("ERM loss", 'loss'),
    ("Adv Loss", 'adv-loss'),
    ("Acc", 'acc'),
    ("Randomized 0/1 Acc", 'r0/1-acc'),
    ("Randomized 7/9 Acc", 'r7/9-acc'),
  )

  fig, axes = plt.subplots(1,5,figsize=(16,3), dpi=110)

  for ax, (title, _) in zip(axes, panels):
    ax.grid()
    ax.set_title(title)
    ax.set_xlabel("iterations")

  for m_id, m_stats in stats.items():
    if m_id[0] == 'm':
      itrs = [point[0] for point in m_stats['loss']]
      for ax, (_, key) in zip(axes, panels):
        ax.plot(itrs, [point[1] for point in m_stats[key]], label=m_id)

  # The three accuracy panels share a fixed y-range.
  for ax in axes[2:]:
    ax.set_ylim(0.45, 1.05)

# Model

In [None]:
class LeNet(nn.Module):
  """LeNet-style CNN: two conv + max-pool stages followed by three linear layers.

  The first linear layer expects 960 flattened features, i.e. 16 x 12 x 5,
  which is what a single-channel 56 x 28 input produces after the two
  conv/pool stages. The forward pass returns raw outputs of size
  `num_classes` (no final activation).
  """

  def __init__(self, num_classes=10) -> None:
    super().__init__()
    # Layers are created in this fixed order: it determines parameter order
    # and the order in which random initial weights are drawn.
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)
    self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
    self.fc1 = nn.Linear(in_features=960, out_features=120)
    self.fc2 = nn.Linear(in_features=120, out_features=84)
    self.fc3 = nn.Linear(in_features=84, out_features=num_classes)
    self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    feats = x
    for conv in (self.conv1, self.conv2):
      feats = self.maxPool(F.relu(conv(feats)))
    # Keep the batch dimension, flatten everything else.
    hidden = feats.flatten(start_dim=1)
    for fc in (self.fc1, self.fc2):
      hidden = F.relu(fc(hidden))
    return self.fc3(hidden)

# Training utils

In [None]:
def sequential_train(num_models, train_dl, valid_dl, valid_r01_dl, valid_r79_dl, test_dl, test_r01_dl, test_r79_dl,
                     perturb_dl, alpha=10, max_epoch=100, use_diversity_reg=True, reg_model_weights=None):
  """Train `num_models` single-logit LeNet classifiers one after another.

  Every model minimises binary cross-entropy on `train_dl`. When
  `use_diversity_reg` is set, each model after the first additionally
  minimises, on batches drawn from `perturb_dl`, the negative log of
  `(1 - p_i) * p + p_i * (1 - p)`, where `p` is its own sigmoid output and
  `p_i` the (gradient-free) sigmoid output of the i-th previously trained
  model. The per-model terms are combined as a weighted average using
  `reg_model_weights` and scaled by `alpha`.

  Args:
    num_models: number of models to train sequentially.
    train_dl: yields `(x, y)` training batches.
    valid_dl, valid_r01_dl, valid_r79_dl: loaders scored with `get_acc`
      every 40 optimisation steps.
    test_dl, test_r01_dl, test_r79_dl: loaders scored with `get_acc` once a
      model has finished training.
    perturb_dl: yields 1-tuples `(x_tilde,)`; only read when the diversity
      term is active.
    alpha: multiplier of the diversity term in the total loss.
    max_epoch: number of passes over `train_dl` per model.
    use_diversity_reg: disable to train every model with the ERM loss only.
    reg_model_weights: per-previous-model weights of the diversity term;
      defaults to 1.0 for every model.

  Returns:
    Dict keyed "m1", "m2", ...; each value holds lists of
    `(iteration, value)` pairs under "loss", "adv-loss", "acc", "r0/1-acc"
    and "r7/9-acc", plus the scalars "test-acc", "test-r0/1-acc" and
    "test-r7/9-acc".

  Raises:
    ValueError: if the diversity term is enabled and `reg_model_weights`
      has fewer than `num_models - 1` entries.
  """
  models = [LeNet(num_classes=1).to(DEVICE) for _ in range(num_models)]

  stats = {f"m{i+1}": {"acc": [], "r0/1-acc": [], "r7/9-acc": [], "loss": [], "adv-loss": []} for i in range(num_models)}

  if reg_model_weights is None:
    reg_model_weights = [1.0 for _ in range(num_models)]
  elif use_diversity_reg and len(reg_model_weights) < num_models - 1:
    # Fail before training starts rather than part-way through a later model.
    raise ValueError("reg_model_weights needs one weight per previously trained model")

  itrs_per_epoch = len(train_dl)

  for m_idx, m in enumerate(models):
    m_stats = stats[f"m{m_idx+1}"]
    prev_models = models[:m_idx]
    prev_weights = reg_model_weights[:m_idx]
    apply_reg = use_diversity_reg and m_idx != 0

    opt = torch.optim.Adam(m.parameters(), lr=0.0001)
    # Perturbation batches are only needed when the diversity term is active.
    perturb_sampler = dl_to_sampler(perturb_dl) if apply_reg else None

    for epoch in range(max_epoch):
      for itr, (x, y) in enumerate(train_dl):
        erm_loss = F.binary_cross_entropy_with_logits(m(x), y)

        if apply_reg:
          (x_tilde,) = perturb_sampler()
          with torch.no_grad():  # previously trained models act as fixed references
            ps = [torch.sigmoid(m_(x_tilde)) for m_ in prev_models]
          psm = torch.sigmoid(m(x_tilde))
          # (1-p_i)*p + p_i*(1-p): probability that the two models predict
          # different labels; 1e-7 keeps the log finite.
          weighted = [
            -((1. - p) * psm + p * (1. - psm) + 1e-7).log().mean() * w
            for p, w in zip(ps, prev_weights)
          ]
          adv_loss = sum(weighted) / sum(prev_weights)
        else:
          # Plain zero on the same device/dtype as the ERM loss. It is not
          # divided by any weight, so a zero weight cannot turn it into NaN.
          adv_loss = erm_loss.new_zeros(())

        loss = erm_loss + alpha * adv_loss

        opt.zero_grad()
        loss.backward()
        opt.step()

        itr_ = itr + epoch * itrs_per_epoch  # global step count for this model
        if itr_ % 40 == 0:
          print_str = f"[m{m_idx+1}] {epoch}/{itr_} [train] loss: {erm_loss.item():.2f} adv-loss: {adv_loss.item():.2f} "
          m_stats["loss"].append((itr_, erm_loss.item()))
          m_stats["adv-loss"].append((itr_, adv_loss.item()))
          acc = get_acc(m, valid_dl)
          acc_r01 = get_acc(m, valid_r01_dl)
          acc_r79 = get_acc(m, valid_r79_dl)
          m_stats["acc"].append((itr_, acc))
          m_stats["r0/1-acc"].append((itr_, acc_r01))
          m_stats["r7/9-acc"].append((itr_, acc_r79))
          print_str += f" acc: {acc:.2f}, {Fore.BLUE} r0/1-acc: {acc_r01:.2f} {Style.RESET_ALL}"

          print(print_str)

    # Final evaluation of the model that just finished training.
    test_acc = get_acc(m, test_dl)
    test_r01_acc = get_acc(m, test_r01_dl)
    test_r79_acc = get_acc(m, test_r79_dl)
    m_stats["test-acc"] = test_acc
    m_stats["test-r0/1-acc"] = test_r01_acc
    m_stats["test-r7/9-acc"] = test_r79_acc
    print(f"[m{m_idx+1}] [test] acc: {test_acc:.3f}, r-acc: {test_r01_acc:.3f}")

  return stats

# Build MM-Dominoes dataset $\hat{\mathcal{D}}$

In [None]:
def plot_samples(dataset, nrow=13):
  """Display the first `nrow` images of `dataset` in a single row.

  Args:
    dataset: either a raw image tensor, or an object exposing a `.tensors`
      tuple (such as `torch.utils.data.TensorDataset`) whose first entry is
      the image tensor. Any further entries (e.g. labels) are ignored.
    nrow: number of images to show.
  """
  # This notebook calls the function both with TensorDatasets and with plain
  # tensors, so handle both explicitly instead of relying on a bare `except`.
  X = dataset if torch.is_tensor(dataset) else dataset.tensors[0]
  fig = plt.figure(figsize=(10,7), dpi=130)
  grid_img = vision_utils.make_grid(X[:nrow].cpu(), nrow=nrow, normalize=True, padding=1, pad_value=0.1)
  # make_grid returns channels-first data; imshow wants channels last.
  _ = plt.imshow(grid_img.permute(1, 2, 0), interpolation='nearest')
  _ = plt.tick_params(axis=u'both', which=u'both',length=0)
  ax = plt.gca()
  _ = ax.xaxis.set_major_formatter(NullFormatter())
  _ = ax.yaxis.set_major_formatter(NullFormatter())
  plt.show()


In [None]:
def keep_only_lbls(dataset, lbls):
  """Filter `dataset` down to the samples whose label appears in `lbls`.

  `dataset` is iterated as `(x, y)` pairs. Kept labels are re-encoded as
  their position in `lbls` (first entry -> 0, second -> 1, ...).

  Returns:
    `(X, Y)`: `X` stacks the kept `x` tensors along a new first dimension;
    `Y` is a float column tensor of shape (num_kept, 1) with the re-encoded
    labels.
  """
  position_of = {lbl: pos for pos, lbl in enumerate(lbls)}
  kept = [(x, position_of[y]) for x, y in dataset if y in position_of]
  X = torch.stack([x for x, _ in kept])
  Y = torch.tensor([target for _, target in kept]).float().view(-1,1)
  return X, Y


def _split_by_class(X, Y, randomize):
  """Return `(X_0, X_1)`, the images used for class 0 and class 1, as (-1, 1, 28, 28) views.

  Without `randomize` the split follows the labels in `Y` (`Y == 0` and
  `Y == 1`). With `randomize` the labels are ignored: `X` is shuffled and
  cut into two halves, so the images carry no information about the class.
  """
  if randomize:
    X = X[torch.randperm(len(X))]
    half = len(X) // 2
    X_0, X_1 = X[:half], X[half:]
  else:
    X_0, X_1 = X[Y == 0], X[Y == 1]
  return X_0.view(-1, 1, 28, 28), X_1.view(-1, 1, 28, 28)


def merge_datasets(X1, Y1, X2, Y2, randomize_1=False, randomize_2=False, device=None):
  """Build a binary dataset of stacked image pairs from two 28x28 binary datasets.

  Each sample concatenates an image from dataset 1 with an image from
  dataset 2 along dimension 2, giving tensors of shape (1, 56, 28). The
  pair is labelled 0 or 1 according to the class both halves were taken
  from. Per class, the number of pairs is the smaller of the two sources'
  counts. The samples are shuffled before being returned.

  Args:
    X1, Y1: images and 0/1 labels of the first dataset (first 28 rows).
    X2, Y2: images and 0/1 labels of the second dataset (last 28 rows).
    randomize_1: if True, images of dataset 1 are assigned to the classes
      at random instead of by `Y1`.
    randomize_2: same, for dataset 2.
    device: if given, both tensors are moved to this device.

  Returns:
    `torch.utils.data.TensorDataset(X, Y)` with `Y` a float column tensor.
  """
  # Dataset 1 is processed before dataset 2 so random draws happen in a fixed order.
  X1_0, X1_1 = _split_by_class(X1, Y1, randomize_1)
  X2_0, X2_1 = _split_by_class(X2, Y2, randomize_2)

  def _stack_pairs(first, second):
    # Truncate to the shorter source, then join the two images of each pair.
    m = min(len(first), len(second))
    return torch.cat((first[:m], second[:m]), dim=2)

  X_0 = _stack_pairs(X1_0, X2_0)
  X_1 = _stack_pairs(X1_1, X2_1)

  X = torch.cat([X_0, X_1], dim=0)
  Y = torch.cat([torch.zeros(len(X_0), 1), torch.ones(len(X_1), 1)], dim=0).float()

  # Shuffle so the two classes are interleaved.
  perm = torch.randperm(len(Y))
  X, Y = X[perm], Y[perm]

  if device is not None:
    X = X.to(device)
    Y = Y.to(device)

  return torch.utils.data.TensorDataset(X, Y)


# Build the train / validation / test splits of the stacked-digit dataset.
# NOTE: `data_train` and `data_test` are rebound several times below (raw
# MNIST -> split -> merged TensorDataset); statement order matters.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# Split the 60000 MNIST training images (fixed seed 42) into 10000 images
# reserved for the perturbation sets, 45000 for training and 5000 for validation.
data_train = torchvision.datasets.MNIST('./data/mnist/', train=True, download=True, transform=transform)
data_perturb_base, data_train, data_valid = random_split(data_train, [10000, 45000, 5000], generator=torch.Generator().manual_seed(42))
X_train_7_9, Y_train_7_9 = keep_only_lbls(data_train, lbls=[7,9]) # harder to separate
X_train_0_1, Y_train_0_1 = keep_only_lbls(data_train, lbls=[0,1]) # easier to separate

X_valid_7_9, Y_valid_7_9 = keep_only_lbls(data_valid, lbls=[7,9]) # harder to separate
X_valid_0_1, Y_valid_0_1 = keep_only_lbls(data_valid, lbls=[0,1]) # easier to separate

data_test = torchvision.datasets.MNIST('./data/mnist/', train=False, download=True, transform=transform)
X_test_7_9, Y_test_7_9 = keep_only_lbls(data_test, lbls=[7,9]) # harder to separate
X_test_0_1, Y_test_0_1 = keep_only_lbls(data_test, lbls=[0,1]) # easier to separate

# Stack a 0/1 image on top of a 7/9 image: class 0 pairs digit 0 with digit 7,
# class 1 pairs digit 1 with digit 9.
data_train = merge_datasets(X_train_0_1, Y_train_0_1, X_train_7_9, Y_train_7_9, randomize_1=False, randomize_2=False, device=DEVICE)

# The "r01" / "r79" variants assign the 0/1 (resp. 7/9) half to the classes at
# random, so only the other half still matches the label.
data_test = merge_datasets(X_test_0_1, Y_test_0_1, X_test_7_9, Y_test_7_9, randomize_1=False, randomize_2=False, device=DEVICE)
data_test_r01 = merge_datasets(X_test_0_1, Y_test_0_1, X_test_7_9, Y_test_7_9, randomize_1=True, randomize_2=False, device=DEVICE)
data_test_r79 = merge_datasets(X_test_0_1, Y_test_0_1, X_test_7_9, Y_test_7_9, randomize_1=False, randomize_2=True, device=DEVICE)

data_valid = merge_datasets(X_valid_0_1, Y_valid_0_1, X_valid_7_9, Y_valid_7_9, randomize_1=False, randomize_2=False, device=DEVICE)
data_valid_r01 = merge_datasets(X_valid_0_1, Y_valid_0_1, X_valid_7_9, Y_valid_7_9, randomize_1=True, randomize_2=False, device=DEVICE)
data_valid_r79 = merge_datasets(X_valid_0_1, Y_valid_0_1, X_valid_7_9, Y_valid_7_9, randomize_1=False, randomize_2=True, device=DEVICE)

# Training uses batches of 256; the evaluation loaders use batches of 1024.
train_dl = torch.utils.data.DataLoader(data_train, batch_size=256, shuffle=True)
test_dl = torch.utils.data.DataLoader(data_test, batch_size=1024, shuffle=True)
test_r79_dl = torch.utils.data.DataLoader(data_test_r79, batch_size=1024, shuffle=True)
test_r01_dl = torch.utils.data.DataLoader(data_test_r01, batch_size=1024, shuffle=True)

valid_dl = torch.utils.data.DataLoader(data_valid, batch_size=1024, shuffle=True)
valid_r79_dl = torch.utils.data.DataLoader(data_valid_r79, batch_size=1024, shuffle=True)
valid_r01_dl = torch.utils.data.DataLoader(data_valid_r01, batch_size=1024, shuffle=True)

# Report dataset sizes.
print(f"Train length: {len(train_dl.dataset)}")
print(f"Test length: {len(test_dl.dataset)}")
print(f"Test length randomized 7/9: {len(test_r79_dl.dataset)}")
print(f"Test length randomized 0/1: {len(test_r01_dl.dataset)}")
print(f"Reserved for perturbations: {len(data_perturb_base)}")

# Visual sanity check: one row of samples from each variant.
print("Non-randomized dataset:")
plot_samples(data_train)

print("7/9-randomized dataset:")
plot_samples(data_test_r79)

print("0/1-randomized dataset:")
plot_samples(data_test_r01)

In [None]:
# Save a one-row figure of five training samples (indices 1..5) as a PDF.
X, Y = data_train.tensors
fig = plt.figure(figsize=(6,6), dpi=100)
grid_img = vision_utils.make_grid(X[1:6].cpu(), nrow=5, normalize=True, padding=1, pad_value=0.1)
# make_grid returns channels-first data; imshow wants channels last.
_ = plt.imshow(grid_img.permute(1, 2, 0), interpolation='nearest')
ax = plt.gca()
# Hide tick labels and tick marks on both axes.
for axis in (ax.xaxis, ax.yaxis):
  axis.set_major_formatter(NullFormatter())
ax.tick_params(axis=u'both', which=u'both',length=0)
plt.savefig('MM-dominoes-train.pdf', dpi = 200, bbox_inches='tight')

# Experiments with $\mathcal{D}_\text{ood}^{(1)}$

In [None]:
# OOD set 1: stacked digits whose halves disagree under the training pairing
# (training pairs 0 with 7 and 1 with 9; here 0 is paired with 9 and 1 with 7).
# The reserved split is read directly so the global `data_test` keeps
# referring to the real test set.
X_perturb_0, _ = keep_only_lbls(data_perturb_base, lbls=[0])
X_perturb_1, _ = keep_only_lbls(data_perturb_base, lbls=[1])
X_perturb_7, _ = keep_only_lbls(data_perturb_base, lbls=[7])
X_perturb_9, _ = keep_only_lbls(data_perturb_base, lbls=[9])

# Truncate each pairing to the shorter of its two digit sets before stacking.
min_09 = min(len(X_perturb_0), len(X_perturb_9))
X_perturb_09 = torch.cat((X_perturb_0[:min_09], X_perturb_9[:min_09]), dim=2)
min_17 = min(len(X_perturb_1), len(X_perturb_7))
X_perturb_17 = torch.cat((X_perturb_1[:min_17], X_perturb_7[:min_17]), dim=2)
X_perturb = torch.cat((X_perturb_09, X_perturb_17), dim=0)
X_perturb = X_perturb[torch.randperm(len(X_perturb))]

# Unlabelled dataset: every item is a 1-tuple `(x_tilde,)`.
data_perturb = torch.utils.data.TensorDataset(X_perturb.to(DEVICE))

perturb_dl = torch.utils.data.DataLoader(data_perturb, batch_size=256, shuffle=True)

print(f"OOD dataset size: {len(perturb_dl.dataset)}")

print("OOD dataset:")
# Pass the dataset (not the raw tensor): the `plot_samples` in effect at this
# point reads `.tensors` from its argument.
plot_samples(data_perturb)

In [None]:
# Five independent 2-model training runs on the current `perturb_dl`
# (alpha=0.1); the curves of each run are plotted as soon as it finishes.
all_stats = []
for _ in range(5):
  stats = sequential_train(
    num_models=2,
    train_dl=train_dl,
    valid_dl=valid_dl, valid_r01_dl=valid_r01_dl, valid_r79_dl=valid_r79_dl,
    test_dl=test_dl, test_r01_dl=test_r01_dl, test_r79_dl=test_r79_dl,
    perturb_dl=perturb_dl,
    alpha=0.1,
    max_epoch=200,
  )
  all_stats.append(stats)
  print_stats(stats)

# Experiments with $\mathcal{D}_\text{ood}^{(2)}$

In [None]:
# OOD set 2: every image of the reserved split (all ten digits) stacked on
# top of a randomly chosen image from the same split.
# The reserved split is read directly so the global `data_test` keeps
# referring to the real test set.
X_perturb, _ = keep_only_lbls(data_perturb_base, lbls=[0,1,2,3,4,5,6,7,8,9])
X_perturb = torch.cat((X_perturb, X_perturb[torch.randperm(len(X_perturb))]), dim=2)
# Unlabelled dataset: every item is a 1-tuple `(x_tilde,)`.
data_perturb = torch.utils.data.TensorDataset(X_perturb.to(DEVICE))

perturb_dl = torch.utils.data.DataLoader(data_perturb, batch_size=256, shuffle=True)

print(f"OOD dataset size: {len(perturb_dl.dataset)}")

print("OOD dataset:")
# Pass the dataset (not the raw tensor): the `plot_samples` in effect at this
# point reads `.tensors` from its argument.
plot_samples(data_perturb)

In [None]:
# Five independent 2-model training runs on the current `perturb_dl`
# (alpha=1.0); the curves of each run are plotted as soon as it finishes.
all_stats = []
for _ in range(5):
  stats = sequential_train(
    num_models=2,
    train_dl=train_dl,
    valid_dl=valid_dl, valid_r01_dl=valid_r01_dl, valid_r79_dl=valid_r79_dl,
    test_dl=test_dl, test_r01_dl=test_r01_dl, test_r79_dl=test_r79_dl,
    perturb_dl=perturb_dl,
    alpha=1.0,
    max_epoch=200,
  )
  all_stats.append(stats)
  print_stats(stats)

# Experiments with $\mathcal{D}_\text{ood}^{(3)}$

In [None]:
# OOD set 3: a 0/1 image stacked on top of a digit from {2,3,4,5,6,8}, i.e.
# the second half never shows a 7 or a 9.
# The reserved split is read directly so the global `data_test` keeps
# referring to the real test set.
X_perturb_01, _ = keep_only_lbls(data_perturb_base, lbls=[0,1])
X_perturb_ood, _ = keep_only_lbls(data_perturb_base, lbls=[2,3,4,5,6,8])

# Truncate to the shorter of the two sets before stacking.
min_l = min(len(X_perturb_01), len(X_perturb_ood))
X_perturb = torch.cat((X_perturb_01[:min_l], X_perturb_ood[:min_l]), dim=2)
X_perturb = X_perturb[torch.randperm(len(X_perturb))]

# Unlabelled dataset: every item is a 1-tuple `(x_tilde,)`.
data_perturb = torch.utils.data.TensorDataset(X_perturb.to(DEVICE))

perturb_dl = torch.utils.data.DataLoader(data_perturb, batch_size=256, shuffle=True)

print(f"Perturb length: {len(perturb_dl.dataset)}")

print("Perturbation dataset:")
# Pass the dataset (not the raw tensor): the `plot_samples` in effect at this
# point reads `.tensors` from its argument.
plot_samples(data_perturb)

In [None]:
# Five independent 2-model training runs on the current `perturb_dl`
# (alpha=1.0); the curves of each run are plotted as soon as it finishes.
all_stats = []
for _ in range(5):
  stats = sequential_train(
    num_models=2,
    train_dl=train_dl,
    valid_dl=valid_dl, valid_r01_dl=valid_r01_dl, valid_r79_dl=valid_r79_dl,
    test_dl=test_dl, test_r01_dl=test_r01_dl, test_r79_dl=test_r79_dl,
    perturb_dl=perturb_dl,
    alpha=1.0,
    max_epoch=200,
  )
  all_stats.append(stats)
  print_stats(stats)