# Various Tests of opt-SNE on flow data

We will investigate the 1NN accuracy and KL divergence (KLD) values of the embeddings across: 
- negative-gradient method: "bh" vs "fft" 
- random vs PCA initialization 
- runs with vs without opt-SNE automated stopping 

## Utils

### Data Preprocessing

In [2]:
%matplotlib inline

# imports 
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import flowkit as fk 
from openTSNE import TSNE

# Pre-processing data 
# Load the annotated flow18 FCS file and attach a label to every channel.
# NOTE(review): the labels are listed in lexicographic order (Parameter_1,
# Parameter_10, ...) rather than numeric order -- confirm that this matches
# the channel order stored in the file.
sample = fk.Sample(
    'data/flow18_annotated.fcs',
    sample_id='flow18',
    channel_labels=(
        'Parameter_1', 'Parameter_10', 'Parameter_11', 'Parameter_12',
        'Parameter_13', 'Parameter_14', 'Parameter_15', 'Parameter_16',
        'Parameter_17', 'Parameter_18', 'Parameter_19', 'Parameter_2',
        'Parameter_20', 'Parameter_21', 'Parameter_22', 'Parameter_23',
        'Parameter_24', 'Parameter_3', 'Parameter_4', 'Parameter_5',
        'Parameter_6', 'Parameter_7', 'Parameter_8', 'Parameter_9',
        'SampleID', 'class',
    ),
)

# Export the events requested with source="raw" as a pandas DataFrame.
df_events = sample.as_dataframe(source="raw")

# Restrict the table to the selected parameter columns plus the `class` column
# (presumably the annotation used as ground-truth label -- verify downstream).
selected_columns = [
    'Parameter_10',
    'Parameter_11',
    'Parameter_12',
    'Parameter_13',
    'Parameter_15',
    'Parameter_18',
    'Parameter_20',
    'Parameter_21',
    'Parameter_23',
    'Parameter_8',
    'Parameter_9',
    'class',
]

df_filtered = df_events[selected_columns]

### KLD Monitoring

In [1]:
from openTSNE import callbacks
# Module-level defaults for the early-exaggeration (EE) monitor. Monitors that
# were not given explicit overrides look these up each time they are needed,
# so re-assigning them also affects already-constructed monitors.
buffer_ee = 15  # iteration count that must be exceeded before stop checks run
switch_buffer = 2  # number of confirming polls required before EE is stopped


class KLDRCMonitorEE(callbacks.Callback):
    """Record the KL divergence and signal when Early Exaggeration should stop.

    Every `record_every` iterations the reported error is stored and its
    relative change (in percent) against the previously polled error is
    computed. Once the iteration count exceeds the EE buffer, each poll whose
    relative change is smaller than the previous one counts the switch buffer
    down; when the countdown is exhausted, `__call__` returns True.
    """

    def __init__(self, record_every=3, buffer_ee=None, switch_buffer=None):
        """
        Parameters:
            record_every (int): Check KL divergence every this many iterations.
            buffer_ee (int, optional): Minimum iterations before monitoring KL
                divergence. When None, the module-level `buffer_ee` is used.
            switch_buffer (int, optional): Extra iterations to confirm EE phase
                exit. When None, the module-level `switch_buffer` is used.
        """
        self.record_every = record_every  # Equivalent to `auto_iter_pollrate_ee = 3`

        # Optional per-instance overrides; None means "use the module default".
        self._buffer_ee_override = buffer_ee
        self._switch_buffer_override = switch_buffer

        self.kl_divergences = []  # (iteration, error) pairs, one per poll
        self.last_error = None  # error seen at the previous poll
        self.last_rel_change = None  # relative change computed at the previous poll
        # Tracks remaining confirming polls before exiting EE
        self.switch_buffer_count = self._switch_buffer_size()

    def _min_iteration(self):
        """Return the iteration count that must be exceeded before stop checks run."""
        if self._buffer_ee_override is not None:
            return self._buffer_ee_override
        return buffer_ee  # module-level default, resolved at call time

    def _switch_buffer_size(self):
        """Return the value the switch-buffer countdown starts from / resets to."""
        if self._switch_buffer_override is not None:
            return self._switch_buffer_override
        return switch_buffer  # module-level default, resolved at call time

    def __call__(self, iteration, error, embedding):
        """
        Monitors KL divergence and determines when to stop Early Exaggeration.
        Returns True if EE should stop, False otherwise.
        """
        # Only check every `record_every` iterations
        if iteration % self.record_every != 0:
            return False

        self.kl_divergences.append((iteration, error))

        if self.last_error is not None:
            # Relative change in percent: (prev_error - current_error) / prev_error
            rel_change = 100 * (self.last_error - error) / self.last_error

            print(f"Iteration {iteration}: KL Divergence = {error:.4f}, Relative Change = {rel_change:.4f}%")

            # Start checking only after the EE buffer has been exceeded
            if iteration > self._min_iteration():
                if self.last_rel_change is not None and rel_change < self.last_rel_change:
                    # The relative change shrank again: stop once the
                    # countdown is used up, otherwise consume one step of it.
                    if self.switch_buffer_count < 1:
                        print("Relative change has consistently decreased. Stopping Early Exaggeration.")
                        print(f"EE Iteration stopped at {iteration}")
                        return True  # Signal to stop EE phase
                    self.switch_buffer_count -= 1
                else:
                    # Reset the countdown if the relative change did not shrink
                    self.switch_buffer_count = self._switch_buffer_size()

            self.last_rel_change = rel_change

        # Remember this error for the next poll
        self.last_error = error

        return False  # Continue EE phase if conditions are not met


In [3]:
class KLDRCMonitorNoOpt(callbacks.Callback):
    """Record the KL divergence at a fixed interval without ever requesting a stop."""

    def __init__(self, record_every=5):
        """
        Parameters:
            record_every (int): Record KL divergence every this many iterations.
        """
        self.record_every = record_every
        self.kl_divergences = []  # (iteration, error) pairs, one per poll

    def __call__(self, iteration, error, embedding):
        """
        Monitors KL divergence.
        Always returns False: this monitor only records and never asks for the
        optimization to stop.
        """
        # Only record every `record_every` iterations
        if iteration % self.record_every == 0:
            self.kl_divergences.append((iteration, error))

        # Explicit boolean, matching the return convention of the other monitors.
        return False

In [4]:
# Module-level defaults for the main-run monitor. Monitors that were not given
# explicit overrides look these up on every poll, so re-assigning them also
# affects already-constructed monitors.
buffer_run = 150  # iteration count that must be exceeded before stop checks run
auto_iter_end = 100  # divisor of the error in the stopping threshold


class KLDRCMonitorRun(callbacks.Callback):
    """Record the KL divergence and signal when the t-SNE run should stop.

    Every `record_every` iterations the reported error is stored and compared
    with the previously polled error. Once the iteration count exceeds the run
    buffer, `__call__` returns True as soon as the per-iteration change drops
    below `error / auto_iter_end`.
    """

    def __init__(self, record_every=5, buffer_run=None, auto_iter_end=None):
        """
        Parameters:
            record_every (int): Check KL divergence every this many iterations.
            buffer_run (int, optional): Minimum iterations before monitoring
                for stopping. When None, the module-level `buffer_run` is used.
            auto_iter_end (float, optional): Threshold for stopping, lower
                values stop earlier. When None, the module-level
                `auto_iter_end` is used.
        """
        self.record_every = record_every  # Equivalent to `auto_iter_pollrate_run = 5`

        # Optional per-instance overrides; None means "use the module default".
        self._buffer_run_override = buffer_run
        self._auto_iter_end_override = auto_iter_end

        self.kl_divergences = []  # (iteration, error) pairs, one per poll
        self.last_error = None  # error seen at the previous poll

    def _min_iteration(self):
        """Return the iteration count that must be exceeded before stop checks run."""
        if self._buffer_run_override is not None:
            return self._buffer_run_override
        return buffer_run  # module-level default, resolved at call time

    def _end_threshold(self):
        """Return the divisor applied to the error in the stopping condition."""
        if self._auto_iter_end_override is not None:
            return self._auto_iter_end_override
        return auto_iter_end  # module-level default, resolved at call time

    def __call__(self, iteration, error, embedding):
        """
        Monitors KL divergence and determines when to stop the full t-SNE run.
        Returns True if the run should stop, False otherwise.
        """
        # Only check KL divergence every `record_every` iterations
        if iteration % self.record_every != 0:
            return False

        self.kl_divergences.append((iteration, error))

        if self.last_error is not None:
            # Absolute change of the error since the previous poll
            error_diff = abs(self.last_error - error)

            print(f"Iteration {iteration}: KL Divergence = {error:.4f}, Error Diff = {error_diff:.6f}")

            # Start monitoring only after the run buffer has been exceeded
            if iteration > self._min_iteration():
                # Stopping condition from C++: abs(error_diff)/pollrate < error/auto_iter_end
                if (error_diff / self.record_every) < (error / self._end_threshold()):
                    print("KL divergence change is below threshold. Stopping optimization.")
                    print(f"Run iteration stopped at {iteration}")
                    return True  # Signal to stop t-SNE run

        # Remember this error for the next poll
        self.last_error = error

        return False  # Continue t-SNE run


### 1NN Accuracy

In [5]:
from sklearn.neighbors import NearestNeighbors

def compute_1nn_accuracy(Y, labels):
    """
    Computes the 1-Nearest Neighbor (1NN) accuracy of the t-SNE embedding.

    Each point is "classified" with the label of its nearest other point in
    the embedding (Euclidean distance); the accuracy is the fraction of points
    whose predicted label equals their own label.

    Parameters:
    - Y (numpy array): t-SNE embedding of shape (N, no_dims)
    - labels (array-like): Ground truth labels of shape (N,)

    Returns:
    - accuracy (float): 1NN classification accuracy

    Raises:
    - ValueError: if `labels` does not contain one entry per row of `Y`
    """
    labels = np.asarray(labels)
    if labels.shape[0] != Y.shape[0]:
        raise ValueError(
            f"labels has {labels.shape[0]} entries but Y has {Y.shape[0]} rows"
        )

    nn = NearestNeighbors(n_neighbors=1, metric='euclidean')
    nn.fit(Y)

    # Querying without an explicit X returns the neighbors of the fitted points
    # themselves, and scikit-learn then never reports a point as its own
    # neighbor. This stays correct when the embedding contains duplicate
    # points, where "the first hit is the point itself" does not hold.
    nearest_neighbor_indices = nn.kneighbors(return_distance=False)[:, 0]

    # The 1NN prediction is the label of the nearest neighbor (not itself)
    predicted_labels = labels[nearest_neighbor_indices]

    # Fraction of points whose nearest neighbor carries the same label
    return np.mean(predicted_labels == labels)

# Example Usage
# Y: Your t-SNE embedding of shape (N, 2) or (N, 3)
# labels: Ground truth labels of shape (N,)
# accuracy = compute_1nn_accuracy(Y, labels)
# print(f"1NN Accuracy: {accuracy * 100:.2f}%")

### Plotting

In [None]:
def plot_embedding(embedding, labels, data_percentage, neg_grad_method, init, learning_rate):
    """
    Draws a scatter plot of the first two embedding dimensions and shows it.

    Parameters:
    - embedding: 2D-indexable array; columns 0 and 1 are used as x and y
    - labels: per-point values passed to matplotlib as the scatter colors
    - data_percentage: multiplied by 100 and printed in the title as a percentage
    - neg_grad_method, init, learning_rate: only interpolated into the title
    """
    plt.figure(figsize=(6, 6))

    # Split the embedding into its plotted axes and look up the color map.
    xs = embedding[:, 0]
    ys = embedding[:, 1]
    palette = plt.colormaps.get_cmap('Paired')

    plt.scatter(xs, ys, c=labels, cmap=palette, s=10, alpha=0.4)

    # The title records the run configuration this embedding was produced with.
    title = f"opt-SNE on flow18, {100*data_percentage}% of dataset, {neg_grad_method}, {init} initialization, {learning_rate} learning rate"
    plt.title(title)
    plt.show()