## Experiments on the disentanglement of the models' embeddings

### VQ-VAE: t-SNE

In [None]:
# TODO: TSNE
# Setup for the t-SNE helpers below. Note that this cell also uses `plt`,
# `TSNE`, `pickle`, `vae_path`, `load_vae_model`, `test_images`, `test_labels`
# and `device`, none of which are imported/defined here — presumably they come
# from earlier notebook cells (confirm before running this cell standalone).
from ccbir.configuration import config
import torch 

# NOTE(review): this runs after `ccbir.configuration` has already been imported,
# so whatever it patches presumably only affects imports made afterwards — confirm.
config.pythonpath_fix()

def _plot_tsne_static(embedding, labels, beta, perplexity, num_points=1000, n_iter='default'):
    """Scatter-plot a random subsample of a 2-D embedding, one series per label.

    Args:
        embedding: array-like with one row per point; only columns 0 and 1 are
            plotted.
        labels: array-like with one label per row of ``embedding``.
        beta: model hyper-parameter; used only in the plot title.
        perplexity: t-SNE perplexity; used only in the plot title.
        num_points: maximum number of points to plot. Points are sampled
            uniformly without replacement; if there are fewer rows than
            ``num_points``, all of them are plotted.
        n_iter: iteration count shown in the title; ``'default'`` is displayed
            as 1000 (sklearn's default).
    """
    # Open a fresh figure *before* drawing, so the plot never lands on a
    # previously-current figure and no blank figure is left behind.
    plt.figure()
    plt.gca().set_aspect('auto', 'box')

    embedding = torch.as_tensor(embedding)
    labels = torch.as_tensor(labels)

    # Sample without replacement; the slice caps the sample at len(embedding).
    sample_idxs = torch.randperm(len(embedding))[:num_points]
    embedding_sample = torch.index_select(embedding, dim=0, index=sample_idxs)
    labels_sample = torch.index_select(labels, dim=0, index=sample_idxs)

    # One scatter series per distinct label. Labels come from the full label
    # set (not the sample) and torch.unique returns them sorted, so the legend
    # order is stable across calls. For 0..9 digit labels this is the same as
    # iterating range(10), but other label values are no longer dropped.
    for label in torch.unique(labels):
        points = embedding_sample[labels_sample == label, :]
        plt.scatter(points[:, 0], points[:, 1], label=str(label.item()))

    if n_iter == 'default':
        n_iter = 1000  # as per sklearn's default

    plt.title(f"{beta=}, {perplexity=} {n_iter=}")
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.show()
 

def plot_tsne_static(beta, perplexity=30, recompute_embedding=False, n_iter='default'):
    """Plot a 2-D t-SNE of the model's test-set latents, caching the embedding on disk.

    The latents are ``model.reparametrize(*model.encode(test_images))`` for
    ``model = load_vae_model(beta)``. The t-SNE result is pickled under
    ``vae_path(beta)`` and reused on later calls unless ``recompute_embedding``
    is True.

    Args:
        beta: selects both the model (``load_vae_model(beta)``) and the cache
            directory (``vae_path(beta)``).
        perplexity: t-SNE perplexity. Must be an ``int`` because it is embedded
            in the cache file name.
        recompute_embedding: if True, recompute the t-SNE and overwrite the cache.
        n_iter: ``'default'`` to use TSNE's default iteration count; otherwise
            it is passed to TSNE as ``n_iter`` and added to the cache file name.

    Relies on names not defined in this function (presumably from other notebook
    cells): ``vae_path``, ``load_vae_model``, ``test_images``, ``test_labels``,
    ``device``, ``TSNE`` and ``pickle``.
    """
    assert isinstance(perplexity, int), (
        "perplexity must be an int (it is part of the cache file name)"
    )

    # Cache file name is unchanged from before: the n_iter suffix is only added
    # when a non-default iteration count is requested.
    n_iter_suffix = '' if n_iter == 'default' else f"_n_iter_{n_iter}"
    pickle_file = vae_path(beta) / f"z_tsne_embedded_perplexity_{perplexity}{n_iter_suffix}.obj"

    if recompute_embedding or not pickle_file.exists():
        model = load_vae_model(beta)
        # Inference only: no_grad avoids building an autograd graph over the
        # entire test set.
        with torch.no_grad():
            z = model.reparametrize(*model.encode(test_images.to(device)))
        z_cpu = z.detach().cpu().numpy()

        # NOTE: scikit-learn >= 1.5 renamed TSNE's `n_iter` to `max_iter`
        # (`n_iter` is removed in 1.7); adjust this kwarg when upgrading.
        tsne_kwargs = {} if n_iter == 'default' else {'n_iter': n_iter}
        tsne = TSNE(perplexity=perplexity, n_jobs=-1, **tsne_kwargs)
        z_embedded = tsne.fit_transform(z_cpu)

        # `with` guarantees the cache file is flushed and closed.
        with open(pickle_file, 'wb') as f:
            pickle.dump(z_embedded, f)
    else:
        # pickle can execute arbitrary code: only load cache files this notebook wrote.
        with open(pickle_file, 'rb') as f:
            z_embedded = pickle.load(f)

    _plot_tsne_static(z_embedded, test_labels, beta, perplexity, n_iter=n_iter)