# Imports and Utils

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pickle
from collections import defaultdict
from time import time
from typing import Callable, List

import matplotlib.pyplot as plt
import numpy as np

from dslr.distance_utils import (
    energy_distance, local_energy_distance, sliced_wasserstein_distance_persistence_features
)
from dslr.distribution_shift_utils import (
    run_approximate_shift_test, wrap_subsample_test
)

___

# 1. Load MNIST Data

In [None]:
# Root directory of the pickled embedding files (a path relative to the
# current working directory); used to build TEMPLATE.
BASE = '../embeddings'

In [None]:
# Path pattern for one dataset's pickle file; load_data fills the '{}'
# placeholder with a dataset name via str.format.
TEMPLATE = BASE + '/mnist/{}_embeddings.pkl'
# Dataset names that load_data reads, one pickle file per name.
DATASET_NAMES = ['mnist_small', 'mnist_small_ds', 'mnist_small_sp']

In [None]:
def load_data():
    """Loads the pickled embedding data for every name in DATASET_NAMES.

    Each file is read from the path produced by TEMPLATE.format(name). After
    a dataset is loaded, the shape of the 'embeddings' entry of each of its
    splits is printed.

    NOTE: pickle.load executes arbitrary code from the file, so the embedding
    files must come from a trusted source.

    Returns:
        loaded: Dict with dataset name as key, and the unpickled object
            (indexed by split name) as value.
    """
    loaded = {}
    for name in DATASET_NAMES:
        path = TEMPLATE.format(name)
        with open(path, 'rb') as handle:
            splits = pickle.load(handle)
        loaded[name] = splits

        # Report the embedding shape of each split of this dataset.
        for split, payload in splits.items():
            print(name, split, payload['embeddings'].shape)

    return loaded

In [None]:
# Load all datasets once; run_ablation reads this module-level dict.
DATA = load_data()

# 2. Run Dirichlet Ablation

In [None]:
def _distance_label(distance_fn: Callable) -> str:
    """Returns the label for `distance_fn` used in the axis title and filename."""
    if distance_fn == local_energy_distance:
        return 'local_energy_distance'
    if distance_fn == energy_distance:
        return 'energy_distance'
    if distance_fn == sliced_wasserstein_distance_persistence_features:
        return 'sliced_wasserstein_persistence'
    # Any other callable still gets a usable label instead of leaving the
    # name unbound when the plot is being labelled.
    return getattr(distance_fn, '__name__', 'custom_distance')


def _sample_reweighted_indices(class_map: dict, alpha: float) -> list:
    """Draws Dirichlet class weights and subsamples every class accordingly.

    Args:
        class_map: Dict with class label as key, and list of test-set indices
            carrying that label as value. Labels are used directly as
            positions into the 10-element weight vector.
        alpha: Symmetric Dirichlet concentration parameter.

    Returns:
        new_indices: Shuffled list of the selected test-set indices.
    """
    weights = np.random.dirichlet([alpha] * 10)

    # Scale max weight to 1.0, while keeping pairwise ratios, so the
    # heaviest class keeps all of its examples.
    weights_one_scaled = weights / weights.max()

    new_indices = []
    for target_class, class_indices in class_map.items():
        # Shuffle a copy so the shared class_map stays in its original order
        # for the next draw.
        target_indices = list(class_indices)
        np.random.shuffle(target_indices)
        target_count = int(weights_one_scaled[target_class] * len(target_indices))
        new_indices.extend(target_indices[:target_count])

    np.random.shuffle(new_indices)
    return new_indices


def run_ablation(
    data_name: str,
    shift_test: Callable,
    distance_fn: Callable,
    sample_sizes: List = [100],
    repetitions: int = 1,
    logfile: str = 'ablation_subsample.log'
) -> dict:
    """Runs ablation on shift magnitude.

    For every sample size, the test split is re-sampled `repetitions` times
    with Dirichlet class weights. The Dirichlet alpha sweeps log-uniformly
    from 10^-1.5 to 10^4 across the repetitions. Each re-sampled test set is
    compared against the train split with `distance_fn` and passed to
    `run_approximate_shift_test`. The resulting (distance, detection rate)
    points are drawn as one scatter series per sample size, and the plot is
    saved as a PNG in the current working directory.

    Args:
        data_name: Name of dataset to use (key into the global DATA).
        shift_test: Function that runs shift test.
        distance_fn: Distance function between two sets of embeddings.
        sample_sizes: Sizes of subsamples to take in shift tests. At most
            five, one plot color per size. Never mutated here.
        repetitions: Number of Dirichlet runs for a particular setting.
        logfile: Log file for shift test.

    Returns:
        results: Dict with sample size as key, and tuple of distances
            and sensitivity proportions.

    Raises:
        ValueError: If more sample sizes are given than plot colors exist.
    """
    colors = ['blue', 'green', 'red', 'gray', 'orange']
    if len(sample_sizes) > len(colors):
        raise ValueError(
            f'At most {len(colors)} sample sizes are supported, '
            f'got {len(sample_sizes)}.'
        )

    data_train = np.array(DATA[data_name]['train']['embeddings'])
    data_test = np.array(DATA[data_name]['test']['embeddings'])
    labels_test = np.array(DATA[data_name]['test']['ids'])

    # Number of shift-test runs per setting. Kept outside the loops because
    # the output filename needs it even if no loop iteration executes.
    num_runs = 20
    distance_str = _distance_label(distance_fn)

    # Aggregate indices for each class. This depends only on the labels, so
    # it is built once rather than once per repetition.
    class_map = defaultdict(list)
    for j, label in enumerate(labels_test):
        class_map[label].append(j)

    # TODO: Choose alpha parameter to regulate variance in shifts.
    alphas = np.logspace(-1.5, 4, num=max(repetitions, 0))

    results = {}

    # Draw on an explicit figure/axes so that plotting cannot land on some
    # other pyplot figure that happens to be current.
    fig, ax = plt.subplots()
    try:
        # Collect separate ablation results for each sample size.
        for i, sample_size in enumerate(sample_sizes):

            distances = []
            proportions = []

            for k in range(repetitions):
                # Sample new class weights and assign newly sampled data.
                new_indices = _sample_reweighted_indices(class_map, alphas[k])
                data_test_reweighted = data_test[new_indices]

                # Compute distribution distance between datasets.
                dist_xy = distance_fn(data_train, data_test_reweighted)

                # Compute proportion of runs that detected a shift.
                config = {
                    'dataset_name': data_name,
                    'data': {
                        data_name: {
                            'train': {'embeddings': data_train},
                            'test': {'embeddings': data_test_reweighted},
                        }
                    },
                    'pair': ('train', 'test'),
                    'test_name': 'subsample',
                    'shift_test': shift_test,
                    'distance_measure': distance_fn,
                    'sample_size': sample_size,
                    'logfile': logfile,
                    'num_runs': num_runs
                }
                decision_counts, _, _, _ = run_approximate_shift_test(config)
                # .get covers the case where no run detected a shift and the
                # True key is therefore missing.
                proportion_detection = (
                    decision_counts.get(True, 0) / sum(decision_counts.values())
                )

                # Add to collection of results.
                distances.append(dist_xy)
                proportions.append(proportion_detection)

            # Collect and plot results for this sample size.
            ax.scatter(
                distances, proportions, c=colors[i], label=sample_size, alpha=0.35
            )
            results[sample_size] = (distances, proportions)

        ax.set_title('Detection Sensitivity by Sample Size')
        ax.set_xlabel(f'{distance_str}(X, Y)')
        ax.set_ylabel('Positive Rate')
        ax.legend()
        filename = (
            f'ablation_plot_{distance_str}_'
            f'rep{str(repetitions)}_'
            f'n{str(num_runs)}_'
            f'ss{"-".join([str(n) for n in sample_sizes])}.alpharange.png'
        )
        fig.savefig(filename)
        print(f'Saved output to {filename}')
    finally:
        # Release the figure even when a shift test or the save fails.
        plt.close(fig)

    return results

In [None]:
# Arguments for the run_ablation call in the final cell.
data_name = 'mnist_small'  # key into DATA
shift_test = wrap_subsample_test
sample_sizes = [25, 50, 100]  # one scatter series per size
repetitions = 100  # Dirichlet re-samplings per sample size

In [None]:
import matplotlib as mpl
# Turn off LaTeX text rendering for the plots produced below.
mpl.rcParams['text.usetex'] = False

In [None]:
# Run the ablation once per distance function; only local_energy_distance
# is listed here.
distance_fns = [local_energy_distance]
for distance_fn in distance_fns:
    # The returned results dict is discarded; run_ablation saves the plot.
    run_ablation(data_name, shift_test, distance_fn, sample_sizes, repetitions)