In [1]:
import numpy as np
import matplotlib.pyplot as plt 

from adaptive import * 
from energies import *
from datasets import load_dataset

In [2]:
X, labels = load_dataset("test", n_test=500)

# Usage 

Given a dataset $X \in \mathbb{R}^{d \times n}$ (i.e., the columns are the datapoints of interest), we identify a prototype set in two steps:
* __Define an Energy object__: ``energy = LowRankEnergy(X, p=p)``
    * This object handles all distance updates, based on the choice of (i) the distance $d(x_i, \mathcal{S})$ and (ii) the power $p$.
* __Define an AdaptiveAlgorithm object__: ``algorithm = AdaptiveAlgorithm(energy)``
    * This object handles the interactive selection of points to add to the prototype set. It references the ``energy`` object defined above.

Prototype set selection then proceeds as follows:
```python
# Build phase: select k prototypes; method is "sampling" or "search"
algorithm.build_phase(k, method="sampling")

# Swap phase: method is "sampling" or "search"
algorithm.swap_phase(method="search")
```
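
To make the difference between the two build methods concrete, here is a minimal NumPy sketch of a build phase. It is not the implementation behind ``AdaptiveAlgorithm`` or ``LowRankEnergy``. It rests on several assumptions: $d(x_i, \mathcal{S})$ is taken to be the Euclidean distance to the nearest selected column (a stand-in, not necessarily the low-rank distance), the total energy is taken to be $\sum_i d(x_i, \mathcal{S})^p$, "search" is taken to mean picking the point with the largest $d(x_i, \mathcal{S})^p$, and "sampling" is taken to mean drawing a point with probability proportional to $d(x_i, \mathcal{S})^p$. The names ``point_energies`` and ``build`` are hypothetical.

```python
import numpy as np


def point_energies(X, selected, p=2):
    """Return d(x_i, S)^p for every column x_i of X.

    Assumption: d is the Euclidean distance to the nearest selected column.
    """
    if not selected:
        # Hypothetical convention for an empty prototype set: equal weights.
        return np.ones(X.shape[1])
    diffs = X[:, :, None] - X[:, selected][:, None, :]  # shape (d, n, |S|)
    dists = np.linalg.norm(diffs, axis=0)               # shape (n, |S|)
    return dists.min(axis=1) ** p


def build(X, k, method="sampling", p=2, seed=0):
    """Greedily grow a prototype set of size k, one column index at a time."""
    rng = np.random.default_rng(seed)
    selected = []
    for _ in range(k):
        e = point_energies(X, selected, p)
        if method == "search":
            # Assumed meaning of "search": take the worst-represented point.
            idx = int(np.argmax(e))
        else:
            # Assumed meaning of "sampling": draw with probability ~ d^p.
            idx = int(rng.choice(X.shape[1], p=e / e.sum()))
        selected.append(idx)
    total = float(point_energies(X, selected, p).sum())
    return selected, total


# Toy data only: 200 random points in R^3, stored as columns.
X_toy = np.random.default_rng(0).normal(size=(3, 200))
for method in ("search", "sampling"):
    inds, total = build(X_toy, k=5, method=method, seed=10)
    print(method, inds, total)
```

Under these assumptions an already selected point has zero energy, so neither method picks the same column twice as long as some unselected point has a nonzero distance.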

## Adaptive Search Build

In [5]:
energy = LowRankEnergy(X, p=2)
adapsearch = AdaptiveAlgorithm(energy, seed=10)
adapsearch.build_phase(5, "search")  # build a prototype set of k = 5 points
print(energy.energy)                 # energy of the selected prototype set
search_inds = energy.indices         # keep the selected indices for later use

10.28591770473395


## Adaptive Sampling Build

In [6]:
energy = LowRankEnergy(X, p=2)  
adapsampling = AdaptiveAlgorithm(energy, seed=10)
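adapsampling.build_phase(5, "sampling")  # build a prototype set of k = 5 points
print(energy.energy)                     # energy of the selected prototype set
sampling_inds = energy.indices           # keep the selected indices for later use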

12.021834417833693


## Adaptive Sampling Build + Adaptive Search Swap

In [9]:
energy = LowRankEnergy(X, p=2)  
sampling_search = AdaptiveAlgorithm(energy, seed=10)
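sampling_search.build_phase(5, "sampling")  # build a prototype set of k = 5 points
sampling_search.swap_phase("search")        # refine the set with search-based swaps
print(energy.energy)                        # energy after the swap phase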

9.385902447438058
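
The three printed energies can be put side by side with a few lines of Python. The snippet below only reuses the values printed above (``seed=10``, ``k = 5``, ``p=2``) and reports each one relative to the sampling build. It assumes that lower energy means a better prototype set, and a single seed on one dataset does not establish a general ranking of the methods.

```python
# Energies copied from the outputs above (seed=10, k=5, p=2).
energies = {
    "search build": 10.28591770473395,
    "sampling build": 12.021834417833693,
    "sampling build + search swap": 9.385902447438058,
}

baseline = energies["sampling build"]
for name, value in energies.items():
    # Relative change with respect to the sampling-only build, in percent.
    change = 100 * (value - baseline) / baseline
    print(f"{name:30s} energy = {value:.4f}  ({change:+.1f}% vs. sampling build)")
```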
