In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import multiprocessing as mp
from exp_utils import *

# --- Reproducibility -------------------------------------------------------
# One global seed for NumPy's legacy RNG, set before any data is shifted/split,
# so every stochastic step below is repeatable on Restart & Run All.
random_seed = 42
np.random.seed(random_seed)
# ---------------------------------------------------------------------------

In [2]:
# Experiment configuration (passed positionally to perform_ours / perform_bench)
task = 'class'                    # task type forwarded to the experiment helpers
ds = [.5, .45, .4, .35, .3]       # shift levels; shown as delta on the x-axis of every plot
test = .1                         # presumably the held-out test fraction — TODO confirm in exp_utils
B = 250                           # presumably the number of resamples behind each p-value — TODO confirm
reps = 5                          # presumably repetitions per shift level — TODO confirm

# Shared subplot layout, as fractions of the figure (see matplotlib subplots_adjust)
left, right = 0.125, 0.9          # left / right edges of the subplot grid
bottom, top = 0.1, 0.9            # bottom / top edges of the subplot grid
wspace, hspace = 0.25, 0.35       # horizontal / vertical padding between subplots

# Amazon

In [3]:
# Amazon Reviews features are used as loaded (not standardized, unlike CIFAR-10 / ImageNet).
X = np.load('data/X_amazon.npy')
# Labels are shifted down by one; presumably 1..5 star ratings -> 0-based classes (TODO confirm).
y = (np.load('data/y_amazon.npy') - 1).reshape(-1, 1)
X.shape, y.shape

((30000, 768), (30000, 1))

In [4]:
# Build source/target splits for every shift level, then run our test and the benchmark on them.
Xs_dic, ys_dic, Xt_dic, yt_dic = get_shifted_data(X, y, ds)
run_args = (Xs_dic, ys_dic, Xt_dic, yt_dic, ds, reps, task, B, test)
kls_amazon_ours, pvals_amazon_ours = perform_ours(*run_args)
pvals_amazon_bench = perform_bench(*run_args)

100%|█████████████████████████████████████████████| 5/5 [00:00<00:00,  8.03it/s]
100%|████████████████████████████████████████████| 5/5 [10:13<00:00, 122.71s/it]
100%|████████████████████████████████████████████| 5/5 [17:26<00:00, 209.40s/it]


Note: the benchmark methods for conditional shift take longer to run because they fit an additional classifier and predict on the entire test set.

# CIFAR-10

In [5]:
X = np.load('data/X_cifar10.npy')
# Standardize each feature to zero mean / unit variance; without this the downstream
# fitting runs much more slowly. A constant feature (std == 0) would otherwise turn
# into NaN (0/0), so such columns are only centered (they become all zeros).
feat_mean = X.mean(axis=0)
feat_std = X.std(axis=0)
feat_std = np.where(feat_std == 0, 1.0, feat_std)
X = (X - feat_mean) / feat_std
y = np.load('data/y_cifar10.npy').reshape((-1, 1))
X.shape, y.shape

((30000, 512), (30000, 1))

In [6]:
# Build source/target splits for every shift level, then run our test and the benchmark on them.
Xs_dic, ys_dic, Xt_dic, yt_dic = get_shifted_data(X, y, ds)
run_args = (Xs_dic, ys_dic, Xt_dic, yt_dic, ds, reps, task, B, test)
kls_cifar_ours, pvals_cifar_ours = perform_ours(*run_args)
pvals_cifar_bench = perform_bench(*run_args)

100%|█████████████████████████████████████████████| 5/5 [00:00<00:00,  7.83it/s]
100%|██████████████████████████████████████████| 5/5 [1:04:18<00:00, 771.65s/it]
100%|██████████████████████████████████████████| 5/5 [1:04:31<00:00, 774.30s/it]


# Tiny ImageNet

In [7]:
X = np.load('data/X_imagenet.npy')
# Standardize each feature to zero mean / unit variance; without this the downstream
# fitting runs much more slowly. A constant feature (std == 0) would otherwise turn
# into NaN (0/0), so such columns are only centered (they become all zeros).
feat_mean = X.mean(axis=0)
feat_std = X.std(axis=0)
feat_std = np.where(feat_std == 0, 1.0, feat_std)
X = (X - feat_mean) / feat_std
y = np.load('data/y_imagenet.npy').reshape((-1, 1))
X.shape, y.shape

((30000, 512), (30000, 1))

In [None]:
# Build source/target splits for every shift level, then run our test and the benchmark on them.
Xs_dic, ys_dic, Xt_dic, yt_dic = get_shifted_data(X, y, ds)
run_args = (Xs_dic, ys_dic, Xt_dic, yt_dic, ds, reps, task, B, test)
kls_imagenet_ours, pvals_imagenet_ours = perform_ours(*run_args)
pvals_imagenet_bench = perform_bench(*run_args)

100%|█████████████████████████████████████████████| 5/5 [00:00<00:00,  9.69it/s]
100%|████████████████████████████████████████████| 5/5 [28:45<00:00, 345.07s/it]
 20%|████████▊                                   | 1/5 [07:47<31:08, 467.17s/it]

# Plots

In [None]:
# p-value figure: top row = our test, bottom row = benchmark;
# columns = Amazon Reviews, Tiny ImageNet, CIFAR-10. Mathtext labels use raw
# strings so "\d" / "\h" are not parsed as (invalid) Python escape sequences.
plt.figure(figsize=(10,3))
plt.subplots_adjust(left, bottom, right, top, wspace, hspace)

##Ours
plt.subplot(2, 3, 1)
exp_plots2([str(d) for d in ds], pvals_amazon_ours, xlab="", ylab="p-values", grid='both', legend=False)
plt.ylim(0,1)
plt.text(-0.4, .7, "Ours", transform=plt.gca().transAxes, fontsize=12, va="top", rotation=90)

plt.subplot(2, 3, 2)
exp_plots2([str(d) for d in ds], pvals_imagenet_ours, xlab="", ylab="", grid='both', legend=False)
plt.ylim(0,1)

plt.subplot(2, 3, 3)
exp_plots2([str(d) for d in ds], pvals_cifar_ours, xlab="", ylab="", grid='both', legend=False)
plt.ylim(0,1)

##Bench
plt.subplot(2, 3, 4)
exp_plots2([str(d) for d in ds], pvals_amazon_bench, xlab=r"$\delta$", ylab="p-values", grid='both', legend=False)
plt.ylim(0,1)
plt.text(-0.4, 1.02, "Benchmark", transform=plt.gca().transAxes, fontsize=12, va="top", rotation=90)

plt.subplot(2, 3, 5)
exp_plots2([str(d) for d in ds], pvals_imagenet_bench, xlab=r"$\delta$", ylab="", grid='both', legend=True)
plt.ylim(0,1)

plt.subplot(2, 3, 6)
exp_plots2([str(d) for d in ds], pvals_cifar_bench, xlab=r"$\delta$", ylab="", grid='both', legend=False)
plt.ylim(0,1)

# Column titles, positioned in figure coordinates above each column
plt.text(0.2345, .95, "Amazon Reviews", transform=plt.gcf().transFigure, fontsize=12, ha="center")

plt.text(0.515, .95, "Tiny ImageNet", transform=plt.gcf().transFigure, fontsize=12, ha="center")

plt.text(0.7855, .95, "CIFAR-10", transform=plt.gcf().transFigure, fontsize=12, ha="center")

plt.savefig('plots/deep_pvals.png', bbox_inches='tight', dpi=300, transparent=True)

In [None]:
# Estimated-KL figure for our test; columns = Amazon Reviews, Tiny ImageNet, CIFAR-10.
# The mathtext label is a raw string so "\h" is not parsed as a Python escape sequence.
plt.figure(figsize=(10,1.5))
plt.subplots_adjust(left, bottom, right, top, wspace, hspace)

##Ours
plt.subplot(1, 3, 1)
exp_plots3([str(d) for d in ds], kls_amazon_ours, xlab="", ylab=r"$\hat{KL}$", grid='both', legend=False)
plt.ylim(0,.2)

plt.subplot(1, 3, 2)
exp_plots3([str(d) for d in ds], kls_imagenet_ours, xlab="", ylab="", grid='both', legend=True)
plt.ylim(0,.2)

plt.subplot(1, 3, 3)
exp_plots3([str(d) for d in ds], kls_cifar_ours, xlab="", ylab="", grid='both', legend=False)
plt.ylim(0,.2)

# Column titles, positioned in figure coordinates above each column
plt.text(0.2345, .95, "Amazon Reviews", transform=plt.gcf().transFigure, fontsize=12, ha="center")

plt.text(0.515, .95, "Tiny ImageNet", transform=plt.gcf().transFigure, fontsize=12, ha="center")

plt.text(0.7855, .95, "CIFAR-10", transform=plt.gcf().transFigure, fontsize=12, ha="center")

plt.savefig('plots/deep_kls.png', bbox_inches='tight', dpi=300, transparent=True)

# Extra experiments

In [None]:
# Extra dataset: features are used as loaded (not standardized, unlike CIFAR-10 / ImageNet).
X = np.load('data/X_stack.npy')
y = np.load('data/y_stack.npy').reshape(-1, 1)
X.shape, y.shape

In [None]:
# Build source/target splits for every shift level, then run our test and the benchmark on them.
Xs_dic, ys_dic, Xt_dic, yt_dic = get_shifted_data(X, y, ds)
run_args = (Xs_dic, ys_dic, Xt_dic, yt_dic, ds, reps, task, B, test)
kls_stack_ours, pvals_stack_ours = perform_ours(*run_args)
pvals_stack_bench = perform_bench(*run_args)

In [None]:
# Extra-experiment figure: estimated KL (top) and p-values (bottom) of our test per shift level.
# Mathtext labels are raw strings so "\d" / "\h" are not parsed as Python escape sequences.
plt.figure(figsize=(5,4))
plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
plt.subplot(2, 1, 1)
exp_plots4([str(d) for d in ds], kls_stack_ours, xlab=r"$\delta$", ylab=r"$\hat{KL}$", grid='both')
plt.subplot(2, 1, 2)
exp_plots4([str(d) for d in ds], pvals_stack_ours, xlab=r"$\delta$", ylab="p-values", grid='both')
plt.ylim(0,1.2)  # upper limit above 1, presumably to leave headroom above the p-value range
#plt.savefig('plots/stack.png', bbox_inches='tight', dpi=300, transparent=True)

In [None]:
# Extra-experiment figure: benchmark p-values per shift level.
# The mathtext label is a raw string so "\d" is not parsed as a Python escape sequence.
plt.figure(figsize=(5,1.75))
exp_plots4([str(d) for d in ds], pvals_stack_bench, xlab=r"$\delta$", ylab="p-values", grid='both')
plt.ylim(0,1.2)  # upper limit above 1, presumably to leave headroom above the p-value range
#plt.savefig('plots/stack2.png', bbox_inches='tight', dpi=300, transparent=True)