In [None]:
import os
import pickle
import random

import numpy as np
import torch
import umap
from torch import nn

from bokeh.plotting import figure, output_notebook, show, ColumnDataSource
from bokeh.transform import factor_cmap, linear_cmap
from bokeh.palettes import inferno, viridis

In [None]:
# Pickle of embeddings to analyse. Set the MSEMBEDDING_EMBEDDINGS environment variable to
# point the notebook at another run without editing this hard-coded default path.
EMBEDDINGS = os.environ.get(
    "MSEMBEDDING_EMBEDDINGS",
    "/media/eduseiti/bigdata02/unicamp/doutorado/bootstrap.pytorch/logs/MSEmbedding/mixedSpectraCrux_all_lstm_40_3_layers_double_n_pair_LR1e-2_q0.01_pvalue_0.3_epsilon_1e-8_plot_validation/sample_embeddings_q0.01_all_lstm40_3layer_pvalue_0.3_double_n_pair.pkl",
)

In [None]:
# Load the pickled records. pickle can execute arbitrary code while loading, so only
# open embedding files you produced yourself or otherwise trust.
with open(EMBEDDINGS, mode="rb") as embeddings_file:
    data = pickle.load(embeddings_file)

In [None]:
# Number of records loaded from the pickle.
len(data)

### Assemble the embedding, sequence and sequence-length arrays

In [None]:
# One entry per record: record[0] is the sequence string and record[2] its embedding
# (a tensor-like object exposing .numpy()).
only_sequences = np.array([record[0] for record in data])
only_embeddings = np.array([record[2].numpy() for record in data])
sequences_len = np.array(list(map(len, only_sequences)))

In [None]:
# (records, embedding dimension); must be 2-D for the pairwise-distance matrix product below.
only_embeddings.shape

In [None]:
# Wrap the embedding matrix as a tensor; torch.from_numpy shares memory with only_embeddings (no copy).
embeddings = torch.from_numpy(only_embeddings)

In [None]:
# L2-normalise every row (dim=1) so that a plain dot product equals cosine similarity.
embeddingsNorm = nn.functional.normalize(embeddings, p=2.0, dim=1)
# Pairwise cosine distance (1 - cosine similarity) between every pair of embeddings.
allCosineDistances = 1 - embeddingsNorm.mm(embeddingsNorm.t())

In [None]:
# Cosine distance strictly below which another embedding counts as a "near" neighbour.
NEAR_DISTANCE_THRESHOLD = 0.05


def rank_positive_pairs(distances, near_threshold=NEAR_DISTANCE_THRESHOLD):
    """Rank each anchor's positive among all embeddings and count its near neighbours.

    Rows 2i and 2i + 1 are treated as an anchor and its positive, matching the pair
    layout used when sampling points for plotting.

    Args:
        distances: square (N, N) tensor of pairwise cosine distances. It is not modified.
        near_threshold: distances strictly below this value count as "near".

    Returns:
        (ranks, near_counts), two lists with one entry per anchor (N // 2 entries):
        ranks[i] is the 0-based position of the positive among the anchor's other
        neighbours sorted by distance (0 means the positive is the nearest one);
        near_counts[i] is how many *other* embeddings, across the whole set, lie
        within near_threshold of the anchor.
    """
    ranks = []
    near_counts = []

    for pair in range(len(distances) // 2):
        anchor = pair * 2
        positive = anchor + 1

        # Work on a copy of the row so the caller's distance matrix stays untouched.
        anchor_distances = distances[anchor].clone()
        # Cosine distances lie in [0, 2] (up to rounding), so -1 forces the anchor to sort first.
        anchor_distances[anchor] = -1

        ordered = torch.argsort(anchor_distances).tolist()
        # Subtract 1 to discount the anchor itself at position 0.
        ranks.append(ordered.index(positive) - 1)

        # Count every other embedding within the threshold (previously the scan stopped
        # after N // 2 - 2 neighbours, silently capping the count).
        is_near = anchor_distances < near_threshold
        is_near[anchor] = False
        near_counts.append(int(is_near.sum()))

    return ranks, near_counts


ranks, near_counts = rank_positive_pairs(allCosineDistances)

In [None]:
# Per-anchor near-neighbour counts (one entry per anchor/positive pair; can be long).
near_counts

In [None]:
# Send Bokeh's show() output inline into this notebook.
output_notebook()

In [None]:
# Total points to plot, drawn below as NUMBER_OF_POINTS // 2 anchor/positive pairs.
NUMBER_OF_POINTS = 500 * 2

In [None]:
# Pick NUMBER_OF_POINTS // 2 anchor/positive pairs (rows 2p and 2p + 1) and plot both
# members of each. Sampling *pair* indices rather than row indices guarantees that no
# pair is drawn twice (rows 4 and 5 used to map to the same pair, duplicating points fed
# to UMAP) and never reaches past the last row when len(data) is odd.
random.seed(4589)

number_of_pairs = len(data) // 2
sampled_pairs = random.sample(range(number_of_pairs), min(NUMBER_OF_POINTS // 2, number_of_pairs))

which_points = []

for pair in sampled_pairs:
    which_points += [pair * 2, pair * 2 + 1]

In [None]:
# One categorical colour per distinct sequence among the sampled points, drawn from a
# 512-colour palette (Inferno followed by Viridis).
unique_sequences = sorted(set(only_sequences[which_points]))
sequence_color_map = factor_cmap(
    "sequence",
    palette=inferno(256) + viridis(256),
    factors=unique_sequences,
)

In [None]:
# Continuous colour scale over sequence length for the whole dataset: one Inferno colour
# per distinct length, spanning the shortest to the longest sequence.
unique_ordered_len = sorted(set(sequences_len))
sequence_color_map_len = linear_cmap(
    "sequence_length",
    palette=inferno(len(unique_ordered_len)),
    low=min(unique_ordered_len),
    high=max(unique_ordered_len),
)

In [None]:
# Hover tooltip for the UMAP scatter: "$index" is Bokeh's built-in row index, while
# "@sequence" / "@sequence_length" read the matching ColumnDataSource columns.
_tooltip_labels = ("index", "sequence", "len")
_tooltip_fields = ("$index", "@sequence", "@sequence_length")
SEQUENCE_TOOLTIP = list(zip(_tooltip_labels, _tooltip_fields))

In [None]:
def plot_umap(n_neighbors, min_dist, color_map=None, random_state=None):
    """Project the sampled embeddings to 2-D with cosine-metric UMAP and show a Bokeh scatter.

    Reads the notebook globals ``which_points`` (rows to plot), ``only_embeddings``,
    ``only_sequences``, ``sequences_len`` and ``SEQUENCE_TOOLTIP``.

    Args:
        n_neighbors: UMAP ``n_neighbors`` parameter.
        min_dist: UMAP ``min_dist`` parameter.
        color_map: Bokeh colour spec for the points. Defaults to the global
            ``sequence_color_map`` looked up at *call* time, so it always matches the
            current sample (a definition-time default went stale after resampling).
        random_state: optional seed forwarded to UMAP for a reproducible layout;
            ``None`` (the default) keeps UMAP unseeded, as before.
    """
    if color_map is None:
        color_map = sequence_color_map

    fit = umap.UMAP(metric="cosine", n_neighbors=n_neighbors, min_dist=min_dist, random_state=random_state)
    result = fit.fit_transform(only_embeddings[which_points])

    data_source = ColumnDataSource(data=dict(
        x=result[:, 0],
        y=result[:, 1],
        sequence=only_sequences[which_points],
        sequence_length=sequences_len[which_points],
    ))

    # width/height work on Bokeh 2.x and 3.x; plot_width/plot_height were removed in 3.0.
    chart = figure(
        width=750,
        height=750,
        tooltips=SEQUENCE_TOOLTIP,
        title="UMAP (n_neighbors={}, min_dist={})".format(n_neighbors, min_dist),
    )
    # scatter() draws circles by default; circle(size=...) is deprecated in recent Bokeh.
    chart.scatter("x", "y", size=10, source=data_source, alpha=0.5, color=color_map)
    show(chart)

### Compare UMAP projections across `n_neighbors` / `min_dist` settings

In [None]:
# n_neighbors=2, min_dist=1, coloured by sequence (default colour map).
plot_umap(2, 1)

In [None]:
# n_neighbors=5, min_dist=0.01, coloured by sequence length.
plot_umap(5, 0.01, color_map=sequence_color_map_len)

In [None]:
# Same parameters, coloured by sequence.
plot_umap(5, 0.01)

In [None]:
# Identical to the previous cell. plot_umap does not seed UMAP by default, so this
# presumably checks run-to-run layout variation.
plot_umap(5, 0.01)

In [None]:
# n_neighbors=5, min_dist=0.001, coloured by sequence length.
plot_umap(5, 0.001, color_map=sequence_color_map_len)

In [None]:
# Same parameters, coloured by sequence.
plot_umap(5, 0.001)

In [None]:
# Wide neighbourhood (60) with near-zero min_dist (1e-7), coloured by sequence length.
plot_umap(60, 0.0000001, color_map=sequence_color_map_len)

In [None]:
# n_neighbors=7, min_dist=1, coloured by sequence.
plot_umap(7, 1)

In [None]:
# n_neighbors=10, min_dist=0.1, coloured by sequence.
plot_umap(10, 0.1)

In [None]:
# n_neighbors=50, min_dist=1, coloured by sequence.
plot_umap(50, 1)