# Imports and Utils

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules before
# each cell runs, so edits to the local dslr package are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import glob
import os
import pickle
from time import time

import numpy as np

from dslr.distance_utils import energy_distance, local_energy_distance
from dslr.distribution_shift_utils import get_pairs, run, analyze_local_energy_metric

___

# 1. Load Data

In [None]:
# Root directory holding the pickled embedding files.
BASE = '../embeddings'
# Selects which family of datasets to analyze: 'mnist' or 'wilds'.
TAG = 'mnist'

if TAG == 'mnist':
    # SETUP MNIST
    # Full-size alternatives: ['mnist', 'mnist_ds', 'mnist_sp'].
    TEMPLATE = BASE + '/mnist/{}_embeddings.pkl'
    DATASET_NAMES = ['mnist_small', 'mnist_small_ds', 'mnist_small_sp']
elif TAG == 'wilds':
    # SETUP WILDS
    TEMPLATE = BASE + '/wilds/{}_featurizer_embeddings.pkl'
    DATASET_NAMES = ['poverty', 'ogb-molpcba', 'camelyon17', 'civilcomments']
else:
    # Fail immediately instead of with a NameError on TEMPLATE/DATASET_NAMES
    # later, when load_data() runs.
    raise ValueError(f"Unknown TAG {TAG!r}; expected 'mnist' or 'wilds'.")

In [None]:
# Load pickled embedding data from the local files named by TEMPLATE.
def load_data(template=None, dataset_names=None):
    """Load the pickled embeddings for each dataset and print per-split shapes.

    Args:
        template: Path format string with one ``{}`` placeholder that is
            filled with each dataset name. Defaults to the module-level
            ``TEMPLATE``.
        dataset_names: Iterable of dataset names to load. Defaults to the
            module-level ``DATASET_NAMES``.

    Returns:
        dict mapping each dataset name to the object unpickled from its file.
        That object is expected to map split names to dicts holding an
        ``'embeddings'`` array (required by the shape printout below).
    """
    if template is None:
        template = TEMPLATE
    if dataset_names is None:
        dataset_names = DATASET_NAMES

    data = {}
    for dataset_name in dataset_names:
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(template.format(dataset_name), 'rb') as f:
            data[dataset_name] = pickle.load(f)

        # Print shapes as a quick sanity check of what was loaded.
        for split_name, split in data[dataset_name].items():
            print(dataset_name, split_name, split['embeddings'].shape)

    return data

In [None]:
# Dict keyed by dataset name holding each dataset's unpickled embeddings.
DATA = load_data()

# 2. Define Split Pairs

In [None]:
# Build the split pairs to compare; the pairing rules live in
# dslr.distribution_shift_utils.get_pairs.
PAIRS = get_pairs(DATA)

In [None]:
# Show the generated split pairs for inspection.
print(PAIRS)

# 3. Analyze Local Energy Distance

In [None]:
# Turn off LaTeX text rendering in matplotlib (presumably so figures render
# without a local TeX installation — confirm if plots need TeX labels).
import matplotlib as mpl
mpl.rcParams['text.usetex'] = False

In [None]:
# Run the local energy distance analysis on the split pairs in PAIRS.
# NOTE(review): num_dist_samples / subsample_size semantics are defined in
# dslr.distribution_shift_utils.analyze_local_energy_metric — confirm there
# before tuning.
fig = analyze_local_energy_metric(DATA, PAIRS, tag=TAG, num_dist_samples=50, subsample_size=1000)