In [25]:
%load_ext autoreload 

from tqdm import tqdm
import matplotlib.pyplot as plt
import plotly.express as px
import numpy as np
from easydict import EasyDict as edict
import yaml
import pandas as pd
import os

from scipy.signal import convolve2d, convolve
from scipy.signal.windows import blackman, gaussian
import copy

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA, KernelPCA, FastICA
from sklearn.metrics import r2_score, make_scorer
from sklearn.linear_model import Ridge, Lasso, LinearRegression
from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler
from sklearn.model_selection import cross_val_score
from sklearn.datasets import make_swiss_roll,\
                             make_s_curve,\
                             make_moons

from collections import defaultdict
from joblib import Parallel, delayed
from umap import UMAP

from IPython.core.debugger import set_trace
from IPython.display import clear_output

import torch
from torch import nn
from torch import optim
from torch import autograd
import torch.nn.functional as F


from train_utils import get_capacity, plot_weights_hist, train
from metric_utils import calculate_Q_metrics, \
                         strain, \
                         l2_loss, \
                         to_numpy, \
                         get_pred_index, \
                         numpy_metric, \
                         cosine_sim

from input_utils import DataGenerator, make_random_affine
from models_utils import init_weights, \
                         universal_approximator, \
                         dJ_criterion, \
                         gained_function, \
                         adjust_learning_rate, \
                         compute_joint_probabilities, \
                         tsne_loss,\
                         tsne_criterion

import warnings
warnings.filterwarnings("ignore")

plt.rcParams['font.size'] = 20
device = torch.device('cuda:0')

from embedding_utils import ConstructUMAPGraph, UMAPLoss, UMAPDataset

from pynndescent import NNDescent
from umap.umap_ import fuzzy_simplicial_set, make_epochs_per_sample
from sklearn.utils import check_random_state

%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [26]:

# SCALER = MinMaxScaler((-1,1))
SCALER = StandardScaler()

input_parameters = {'generator': make_swiss_roll, #make_s_curve, 
                    'generator_kwargs': {'n_samples':10000, 'noise':1e-2}, # 1e-1
                    'unsupervised':True,
                    'whiten':True,
                    'scaler':SCALER,
                    'use_outpt_color':True} 

create_data = DataGenerator(**input_parameters)

inpt, outpt, color = create_data()
inpt_test, outpt_test, color_test = create_data()

In [52]:
inpt_torch = torch.tensor(inpt, dtype=torch.float).to(device)
inpt_torch_test = torch.tensor(inpt_test, dtype=torch.float).to(device)

In [3]:
# plt.ioff()
# plt.figure()
# df = pd.DataFrame(inpt.T, columns=['x','y', 'z'])
# if color is not None:
#     df['target'] = color
# fig = px.scatter_3d(df, x='x', y='y', z='z', color='target' if 'target' in df else None)

# fig.show()

In [4]:
# X = inpt.T.copy()

# # number of trees in random projection forest
# n_trees = 5 + int(round((X.shape[0]) ** 0.5 / 20.0))
# # max number of nearest neighbor iters to perform
# n_iters = max(5, int(round(np.log2(X.shape[0]))))
# # distance metric
# metric="euclidean"
# # number of neighbors for computing k-neighbor graph
# n_neighbors = 10

# nnd = NNDescent(
#     X,
#     n_neighbors=n_neighbors,
#     metric=metric,
#     n_trees=n_trees,
#     n_iters=n_iters,
#     max_candidates=60,
#     verbose=True,
#     n_jobs=-1
# )
# # get indices and distances
# knn_indices, knn_dists = nnd.neighbor_graph

In [5]:
# # get indices and distances
# knn_indices, knn_dists = nnd.neighbor_graph
# random_state = check_random_state(None)
# # build fuzzy_simplicial_set
# out = fuzzy_simplicial_set(
#     X = X,
#     n_neighbors = n_neighbors,
#     metric = metric,
#     random_state = random_state,
#     knn_indices= knn_indices,
#     knn_dists = knn_dists,
#     return_dists=True
# )

In [6]:
# (out[-1].toarray() == out[-1].toarray().T).all()

# Train

In [67]:
class Encoder(nn.Module):

    def __init__ (self, input_dim, output_dim=2):
        super().__init__()
        self.linear1 = nn.Linear(3, 20)
        self.linear2 = nn.Linear(20, 20)
        self.linear3 = nn.Linear(20, 10)
        self.linear4 = nn.Linear(10, output_dim)

    def forward(self, x):
        
        x = F.tanh(self.linear1(x))
        x = F.tanh(self.linear2(x))
        x = F.tanh(self.linear3(x))
        x = self.linear4(x)

        return x
    
model = Encoder(input_dim=3, output_dim=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [68]:
# construct graph of nearest neighbors

X = inpt.T
graph_constructor =  ConstructUMAPGraph(metric='euclidean', n_neighbors=30, batch_size=1000, random_state=42)
epochs_per_sample, head, tail, weight = graph_constructor(X)

dataset = UMAPDataset(X, epochs_per_sample, head, tail, weight, device=device, batch_size=1000)


criterion = UMAPLoss(device=device, 
                     min_dist=0.1,
                     batch_size=1000,
                     negative_sample_rate=5,
                     edge_weight=None,
                     repulsion_strength=1.0)

Thu Jul  7 17:43:55 2022 Building RP forest with 10 trees
Thu Jul  7 17:43:55 2022 NN descent for 13 iterations
	 1  /  13
	 2  /  13
	Stopping threshold met -- exiting after 2 iterations


In [None]:
train_losses = []
for epoch in tqdm(range(500)):
    train_loss = 0.
    for batch_to, batch_from in dataset.get_batches():
        optimizer.zero_grad()
        
        batch_to = batch_to.to(device)
        batch_from = batch_from.to(device)
        
        embedding_to = model(batch_to)
        embedding_from = model(batch_from)
        
        loss = criterion(embedding_to, embedding_from)
        
        train_loss += loss.item()
        loss.backward()
        optimizer.step()

    train_losses.append(train_loss)
    print('epoch: {}, loss: {}'.format(epoch, train_loss))
    

  0%|▎                                                                                                                               | 1/500 [00:03<26:07,  3.14s/it]

epoch: 0, loss: 338.484777957201


  0%|▌                                                                                                                               | 2/500 [00:06<26:02,  3.14s/it]

epoch: 1, loss: 158.03665070980787


  1%|▊                                                                                                                               | 3/500 [00:09<26:00,  3.14s/it]

epoch: 2, loss: 145.99532625824213


  1%|█                                                                                                                               | 4/500 [00:12<25:53,  3.13s/it]

epoch: 3, loss: 140.99359356611967


  1%|█▎                                                                                                                              | 5/500 [00:15<25:48,  3.13s/it]

epoch: 4, loss: 134.8541722819209


  1%|█▌                                                                                                                              | 6/500 [00:18<25:44,  3.13s/it]

epoch: 5, loss: 127.24528189748526


  1%|█▊                                                                                                                              | 7/500 [00:21<25:41,  3.13s/it]

epoch: 6, loss: 120.86740320920944


  2%|██                                                                                                                              | 8/500 [00:25<25:36,  3.12s/it]

epoch: 7, loss: 116.79431929439306


  2%|██▎                                                                                                                             | 9/500 [00:28<25:29,  3.12s/it]

epoch: 8, loss: 113.51284345239401


  2%|██▌                                                                                                                            | 10/500 [00:31<25:24,  3.11s/it]

epoch: 9, loss: 110.70672848820686


  2%|██▊                                                                                                                            | 11/500 [00:34<25:20,  3.11s/it]

epoch: 10, loss: 108.74866299331188


  2%|███                                                                                                                            | 12/500 [00:37<25:18,  3.11s/it]

epoch: 11, loss: 107.77163507044315


  3%|███▎                                                                                                                           | 13/500 [00:40<25:16,  3.11s/it]

epoch: 12, loss: 107.13683497160673


  3%|███▌                                                                                                                           | 14/500 [00:43<25:12,  3.11s/it]

epoch: 13, loss: 106.67290645092726


  3%|███▊                                                                                                                           | 15/500 [00:46<25:08,  3.11s/it]

epoch: 14, loss: 106.37697187066078


  3%|████                                                                                                                           | 16/500 [00:49<25:07,  3.11s/it]

epoch: 15, loss: 106.11909382045269


  3%|████▎                                                                                                                          | 17/500 [00:53<25:04,  3.11s/it]

epoch: 16, loss: 105.62943861633539


  4%|████▌                                                                                                                          | 18/500 [00:56<25:01,  3.11s/it]

epoch: 17, loss: 105.43233702331781


  4%|████▊                                                                                                                          | 19/500 [00:59<24:57,  3.11s/it]

epoch: 18, loss: 104.97323666512966


  4%|█████                                                                                                                          | 20/500 [01:02<24:53,  3.11s/it]

epoch: 19, loss: 104.67148205637932


  4%|█████▎                                                                                                                         | 21/500 [01:05<24:49,  3.11s/it]

epoch: 20, loss: 104.76458992063999


  4%|█████▌                                                                                                                         | 22/500 [01:08<24:47,  3.11s/it]

epoch: 21, loss: 104.42895598709583


  5%|█████▊                                                                                                                         | 23/500 [01:11<24:45,  3.11s/it]

epoch: 22, loss: 103.83365332335234


  5%|██████                                                                                                                         | 24/500 [01:14<24:46,  3.12s/it]

epoch: 23, loss: 103.85624372214079


  5%|██████▎                                                                                                                        | 25/500 [01:17<24:45,  3.13s/it]

epoch: 24, loss: 103.94737497717142


  5%|██████▌                                                                                                                        | 26/500 [01:21<24:38,  3.12s/it]

epoch: 25, loss: 103.36672351509333


  5%|██████▊                                                                                                                        | 27/500 [01:24<24:31,  3.11s/it]

epoch: 26, loss: 103.43015360832214


  6%|███████                                                                                                                        | 28/500 [01:27<24:27,  3.11s/it]

epoch: 27, loss: 103.39892856776714


  6%|███████▎                                                                                                                       | 29/500 [01:30<24:22,  3.10s/it]

epoch: 28, loss: 103.29506955295801


  6%|███████▌                                                                                                                       | 30/500 [01:33<24:17,  3.10s/it]

epoch: 29, loss: 102.99222132563591


  6%|███████▊                                                                                                                       | 31/500 [01:36<24:14,  3.10s/it]

epoch: 30, loss: 102.66658975183964


  6%|████████▏                                                                                                                      | 32/500 [01:39<24:12,  3.10s/it]

epoch: 31, loss: 102.6829483807087


  7%|████████▍                                                                                                                      | 33/500 [01:42<24:09,  3.10s/it]

epoch: 32, loss: 102.63694394379854


  7%|████████▋                                                                                                                      | 34/500 [01:45<24:07,  3.11s/it]

epoch: 33, loss: 102.53990598767996


  7%|████████▉                                                                                                                      | 35/500 [01:48<23:51,  3.08s/it]

epoch: 34, loss: 102.41287647932768


  7%|█████████▏                                                                                                                     | 36/500 [01:51<23:34,  3.05s/it]

epoch: 35, loss: 102.36207477748394


  7%|█████████▍                                                                                                                     | 37/500 [01:54<23:24,  3.03s/it]

epoch: 36, loss: 102.30641531944275


  8%|█████████▋                                                                                                                     | 38/500 [01:57<23:14,  3.02s/it]

epoch: 37, loss: 102.14077414572239


  8%|█████████▉                                                                                                                     | 39/500 [02:00<23:05,  3.01s/it]

epoch: 38, loss: 102.02860028296709


  8%|██████████▏                                                                                                                    | 40/500 [02:03<23:00,  3.00s/it]

epoch: 39, loss: 102.04716634005308


  8%|██████████▍                                                                                                                    | 41/500 [02:06<22:55,  3.00s/it]

epoch: 40, loss: 101.88248163461685


  8%|██████████▋                                                                                                                    | 42/500 [02:09<22:50,  2.99s/it]

epoch: 41, loss: 101.84367223829031


  9%|██████████▉                                                                                                                    | 43/500 [02:12<22:46,  2.99s/it]

epoch: 42, loss: 101.91599417477846


  9%|███████████▏                                                                                                                   | 44/500 [02:15<22:42,  2.99s/it]

epoch: 43, loss: 101.91100866347551


  9%|███████████▍                                                                                                                   | 45/500 [02:18<22:38,  2.99s/it]

epoch: 44, loss: 101.83744624257088


  9%|███████████▋                                                                                                                   | 46/500 [02:21<22:35,  2.98s/it]

epoch: 45, loss: 102.03180320560932


  9%|███████████▉                                                                                                                   | 47/500 [02:24<22:31,  2.98s/it]

epoch: 46, loss: 101.84419547766447


 10%|████████████▏                                                                                                                  | 48/500 [02:27<22:27,  2.98s/it]

epoch: 47, loss: 101.64842235296965


 10%|████████████▍                                                                                                                  | 49/500 [02:30<22:38,  3.01s/it]

epoch: 48, loss: 101.750058747828


 10%|████████████▋                                                                                                                  | 50/500 [02:33<22:52,  3.05s/it]

epoch: 49, loss: 101.8117280676961


 10%|████████████▉                                                                                                                  | 51/500 [02:37<22:58,  3.07s/it]

epoch: 50, loss: 101.45809541642666


 10%|█████████████▏                                                                                                                 | 52/500 [02:40<23:01,  3.08s/it]

epoch: 51, loss: 101.64349554479122


 11%|█████████████▍                                                                                                                 | 53/500 [02:43<23:01,  3.09s/it]

epoch: 52, loss: 101.78141621500254


 11%|█████████████▋                                                                                                                 | 54/500 [02:46<23:01,  3.10s/it]

epoch: 53, loss: 101.67086922377348


 11%|█████████████▉                                                                                                                 | 55/500 [02:49<22:59,  3.10s/it]

epoch: 54, loss: 101.30917491018772


 11%|██████████████▏                                                                                                                | 56/500 [02:52<22:58,  3.10s/it]

epoch: 55, loss: 101.53082120418549


 11%|██████████████▍                                                                                                                | 57/500 [02:55<22:55,  3.11s/it]

epoch: 56, loss: 101.31948654353619


 12%|██████████████▋                                                                                                                | 58/500 [02:58<22:54,  3.11s/it]

epoch: 57, loss: 101.65541085600853


 12%|██████████████▉                                                                                                                | 59/500 [03:01<22:53,  3.11s/it]

epoch: 58, loss: 101.42713379859924


 12%|███████████████▏                                                                                                               | 60/500 [03:05<22:50,  3.11s/it]

epoch: 59, loss: 101.3343391790986


 12%|███████████████▍                                                                                                               | 61/500 [03:08<22:45,  3.11s/it]

epoch: 60, loss: 101.41360133886337


 12%|███████████████▋                                                                                                               | 62/500 [03:11<22:41,  3.11s/it]

epoch: 61, loss: 101.22543373703957


 13%|████████████████                                                                                                               | 63/500 [03:14<22:37,  3.11s/it]

epoch: 62, loss: 101.28459779173136


 13%|████████████████▎                                                                                                              | 64/500 [03:17<22:34,  3.11s/it]

epoch: 63, loss: 101.11793719232082


 13%|████████████████▌                                                                                                              | 65/500 [03:20<22:30,  3.11s/it]

epoch: 64, loss: 101.14747996628284


 13%|████████████████▊                                                                                                              | 66/500 [03:23<22:27,  3.10s/it]

epoch: 65, loss: 101.29962220788002


 13%|█████████████████                                                                                                              | 67/500 [03:26<22:14,  3.08s/it]

epoch: 66, loss: 101.52376524358988


 14%|█████████████████▎                                                                                                             | 68/500 [03:29<21:54,  3.04s/it]

epoch: 67, loss: 101.17763647437096


 14%|█████████████████▌                                                                                                             | 69/500 [03:32<21:39,  3.01s/it]

epoch: 68, loss: 101.12762182205915


 14%|█████████████████▊                                                                                                             | 70/500 [03:35<21:44,  3.03s/it]

epoch: 69, loss: 101.2257237881422


 14%|██████████████████                                                                                                             | 71/500 [03:38<21:48,  3.05s/it]

epoch: 70, loss: 101.2172804698348


 14%|██████████████████▎                                                                                                            | 72/500 [03:41<21:51,  3.07s/it]

epoch: 71, loss: 101.17280331254005


 15%|██████████████████▌                                                                                                            | 73/500 [03:44<21:53,  3.08s/it]

epoch: 72, loss: 101.30094039440155


 15%|██████████████████▊                                                                                                            | 74/500 [03:48<21:53,  3.08s/it]

epoch: 73, loss: 101.18896967172623


 15%|███████████████████                                                                                                            | 75/500 [03:51<21:53,  3.09s/it]

epoch: 74, loss: 101.18668080121279


 15%|███████████████████▎                                                                                                           | 76/500 [03:54<21:53,  3.10s/it]

epoch: 75, loss: 101.10811383277178


 15%|███████████████████▌                                                                                                           | 77/500 [03:57<21:53,  3.11s/it]

epoch: 76, loss: 101.13014883548021


 16%|███████████████████▊                                                                                                           | 78/500 [04:00<21:50,  3.11s/it]

epoch: 77, loss: 101.03514854609966


 16%|████████████████████                                                                                                           | 79/500 [04:03<21:48,  3.11s/it]

epoch: 78, loss: 101.27054848521948


 16%|████████████████████▎                                                                                                          | 80/500 [04:06<21:45,  3.11s/it]

epoch: 79, loss: 101.14956094324589


 16%|████████████████████▌                                                                                                          | 81/500 [04:09<21:41,  3.11s/it]

epoch: 80, loss: 101.14034391194582


 16%|████████████████████▊                                                                                                          | 82/500 [04:12<21:35,  3.10s/it]

epoch: 81, loss: 100.93220705538988


 17%|█████████████████████                                                                                                          | 83/500 [04:16<21:30,  3.10s/it]

epoch: 82, loss: 100.86279848963022


 17%|█████████████████████▎                                                                                                         | 84/500 [04:19<21:29,  3.10s/it]

epoch: 83, loss: 101.01277838647366


 17%|█████████████████████▌                                                                                                         | 85/500 [04:22<21:29,  3.11s/it]

epoch: 84, loss: 100.97662508487701


 17%|█████████████████████▊                                                                                                         | 86/500 [04:25<21:31,  3.12s/it]

epoch: 85, loss: 100.98680716007948


 17%|██████████████████████                                                                                                         | 87/500 [04:28<21:27,  3.12s/it]

epoch: 86, loss: 100.95861967653036


 18%|██████████████████████▎                                                                                                        | 88/500 [04:31<21:22,  3.11s/it]

epoch: 87, loss: 100.98464654386044


 18%|██████████████████████▌                                                                                                        | 89/500 [04:34<21:20,  3.12s/it]

epoch: 88, loss: 100.9549647346139


 18%|██████████████████████▊                                                                                                        | 90/500 [04:37<21:16,  3.11s/it]

epoch: 89, loss: 100.8848369717598


 18%|███████████████████████                                                                                                        | 91/500 [04:40<21:11,  3.11s/it]

epoch: 90, loss: 100.79675459861755


 18%|███████████████████████▎                                                                                                       | 92/500 [04:44<21:06,  3.10s/it]

epoch: 91, loss: 101.11502814292908


 19%|███████████████████████▌                                                                                                       | 93/500 [04:47<21:01,  3.10s/it]

epoch: 92, loss: 100.83308932185173


 19%|███████████████████████▉                                                                                                       | 94/500 [04:50<20:58,  3.10s/it]

epoch: 93, loss: 100.86219940334558


 19%|████████████████████████▏                                                                                                      | 95/500 [04:53<20:51,  3.09s/it]

epoch: 94, loss: 100.8403265401721


 19%|████████████████████████▍                                                                                                      | 96/500 [04:56<20:34,  3.05s/it]

epoch: 95, loss: 100.95748060196638


 19%|████████████████████████▋                                                                                                      | 97/500 [04:59<20:21,  3.03s/it]

epoch: 96, loss: 100.73827377706766


 20%|████████████████████████▉                                                                                                      | 98/500 [05:02<20:12,  3.02s/it]

epoch: 97, loss: 100.69939376413822


In [None]:
plt.plot(train_losses)

In [None]:
output_pred = to_numpy(model(inpt_torch_test.T))

In [None]:
plt.scatter(output_pred[:,0], output_pred[:,1], c=color_test)