## Notebook hypotheses
#### 1. Contrastive learning embeddings have significantly better class separability than PCA or t-SNE, even when overfitted

### Motivation
Compare the contrastive learning model against baselines that should perform worse.
* PCA is a simple, fast linear method; it shouldn't beat the contrastive model's non-linear mapping on a complex problem.
* t-SNE does not learn a mapping from feature space, so it can't be used to embed new samples. If we can't beat t-SNE, then there is a better non-linear mapping that we have not achieved yet.

### Actions:
* Hypothesis 1 is `True`:
    * Make a dataset of embeddings that can be used to train a classifier.
* Hypothesis 1 is `False`:
    * Review the contrastive learning model.

### Results:
* Hypothesis 1: `True` — the contrastive embedding has orders of magnitude more separability. [Link](#techniques-separability).

In [None]:
# Notebook metadata: the directory and file name of this notebook. They are
# joined into the `source` value handed to EmbeddingsSaver when exporting.
NB_PATH = "notebooks"
NB_NAME = "dimentionality reduction.ipynb"

## Contrastive embeddings dimensionality reduction

In [None]:
# Authentication: never hard-code the W&B API key in the notebook. Run
# `wandb login` once, or export WANDB_API_KEY in the shell that starts Jupyter.
# NOTE(review): a key was previously committed in this cell; it should be
# revoked/rotated in the W&B account settings.
from omegaconf import DictConfig
from test_nn_template.run import WandbHandler

WANDB_ENTITY = "fernandoezequiel512"
WANDB_PROJECT = "test_nn_template"

print(f"Runs: {len(WandbHandler.get_runs(entity=WANDB_ENTITY, project=WANDB_PROJECT))}")
# Run whose checkpointed model is analysed in the cells below.
run_id = f"{WANDB_ENTITY}/{WANDB_PROJECT}/runs/28vaq6y9"
run = WandbHandler.get_run(run_id)
print(f"Got {run.name}")
model = WandbHandler.load_run_model_checkpoint(run_id)
print(model)


def replace_dataset(cfg: DictConfig):
    """Swap the contrastive dataset class for ``MyDataset`` in a datamodule config.

    In both ``cfg.data.datamodule.datasets.train._target_`` and
    ``cfg.data.datamodule.datasets.test._target_`` every occurrence of
    ``"MyContrastativeDataset"`` is replaced by ``"MyDataset"``.

    Args:
        cfg: run configuration; modified in place.

    Returns:
        The same ``cfg`` object, so the function can be passed as ``cfg_func``.
    """
    datasets_cfg = cfg.data.datamodule.datasets
    for split_name in ("train", "test"):
        split_cfg = getattr(datasets_cfg, split_name)
        split_cfg._target_ = split_cfg._target_.replace("MyContrastativeDataset", "MyDataset")
    return cfg


# Rebuild the run's datamodule, letting `replace_dataset` rewrite its config so
# the plain `MyDataset` is used instead of the contrastive dataset.
datamodule = WandbHandler.load_run_datamodule(
    run_id,
    cfg_func=replace_dataset,
)

In [None]:
datamodule.setup(stage="test")
# test_dataloader() returns an indexable collection of loaders; only the first is used.
test_loaders = datamodule.test_dataloader()
test_dataloader = test_loaders[0]
# NOTE(review): only the "test" stage was set up above; confirm MyDataModule
# also prepares the train split in that stage before relying on this loader.
train_dataloader = datamodule.train_dataloader()

In [None]:
import random
import matplotlib.pyplot as plt
import numpy as np
import torch
from test_nn_template.data.datamodule import MyDataModule

random.seed(42)

# Number of test points drawn in the scatter plots; `ixs` is reused by the PCA
# and t-SNE cells so all three plots show the same samples.
# NOTE(review): that reuse assumes test_dataloader yields samples in a fixed
# order (no shuffling) every time it is iterated.
N_PLOT_POINTS = 500

# Inference only: eval() switches off training-time behaviour (e.g. dropout),
# and no_grad() avoids building an autograd graph for every test batch.
model.eval()
x_contrast = []
labels = []
with torch.no_grad():
    for x, y in test_dataloader:
        emb = model.model.forward_once(x)
        x_contrast.append(emb.cpu().numpy())
        labels.append(np.asarray(y))
x_contrast = np.concatenate(x_contrast)
labels = np.concatenate(labels)

# Sample WITHOUT replacement so no point is drawn twice; cap at the dataset size.
n_points = x_contrast.shape[0]
ixs = random.sample(range(n_points), k=min(N_PLOT_POINTS, n_points))

# Only the first two embedding dimensions are plotted.
plt.scatter(x_contrast[ixs, 0], x_contrast[ixs, 1], c=labels[ixs])

## PCA dimensionality reduction

In [None]:
import numpy as np
from sklearn.decomposition import PCA

# Collect the raw test inputs, one flattened feature vector per sample.
X = []
labels = []
for x, y in test_dataloader:
    X.append(x.detach().numpy())
    labels.append(np.asarray(y))
X = np.concatenate(X)
# reshape alone flattens every non-sample axis; a prior squeeze() would also
# drop the sample axis when the dataset holds a single sample.
X = X.reshape((X.shape[0], -1))
labels = np.concatenate(labels)
print(X.shape)

pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)

print(f"explained variance: {pca.explained_variance_ratio_}")
print(f"singular values: {pca.singular_values_}")

# `ixs` comes from the contrastive-embedding cell so the same samples are shown.
plt.scatter(X_pca[ixs, 0], X_pca[ixs, 1], c=labels[ixs])

## t-SNE dimensionality reduction

In [None]:
from sklearn.manifold import TSNE

# NOTE(review): far below scikit-learn's default perplexity of 30; a very low
# perplexity focuses t-SNE on tiny neighbourhoods. Kept as in the original run.
TSNE_PERPLEXITY = 3
# Fixed seed: with init="random" the layout otherwise changes on every run.
TSNE_RANDOM_STATE = 42

# Collect the raw test inputs, one flattened feature vector per sample.
X = []
labels = []
for x, y in test_dataloader:
    X.append(x.detach().numpy())
    labels.append(np.asarray(y))
X = np.concatenate(X)
# reshape alone flattens every non-sample axis (no squeeze: it would drop the
# sample axis for a single-sample dataset).
X = X.reshape((X.shape[0], -1))
labels = np.concatenate(labels)

X_tsne = TSNE(
    n_components=2,
    learning_rate="auto",
    init="random",
    perplexity=TSNE_PERPLEXITY,
    random_state=TSNE_RANDOM_STATE,
).fit_transform(X)
print(X_tsne.shape)

# `ixs` comes from the contrastive-embedding cell so the same samples are shown.
plt.scatter(X_tsne[ixs, 0], X_tsne[ixs, 1], c=labels[ixs])

## Techniques separability

In [None]:
import time
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split
import pandas as pd

# One block of columns per technique. Feature names are derived from each
# projection's width (contrast_1, contrast_2, pca_1, ...), so they stay aligned
# with the columns whatever the embedding dimensionality is.
# NOTE(review): row i must be the same test sample in every projection, which
# assumes test_dataloader iterates in a fixed order (no shuffling).
projections = {"contrast": x_contrast, "pca": X_pca, "tsne": X_tsne}
feature_names = [
    f"{technique}_{dim + 1}"
    for technique, projection in projections.items()
    for dim in range(projection.shape[1])
]
features = np.concatenate(list(projections.values()), axis=1)
X_train, X_test, y_train, y_test = train_test_split(features, labels, stratify=labels, random_state=42)

forest = RandomForestClassifier(random_state=0)
forest.fit(X_train, y_train)

# Importance of a column = drop in held-out accuracy when that column is shuffled.
# NOTE(review): correlated columns share credit (shuffling one leaves the others
# usable), so this ranks marginal contributions rather than standalone separability.
start_time = time.time()
result = permutation_importance(forest, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2)
elapsed_time = time.time() - start_time
print(f"Elapsed time to compute the importances: {elapsed_time:.3f} seconds")
forest_importances = pd.Series(result.importances_mean, index=feature_names)

fig, ax = plt.subplots()
forest_importances.plot.bar(yerr=result.importances_std, ax=ax)
ax.set_title("Feature importances using permutation on full model")
ax.set_ylabel("Mean accuracy decrease")
fig.tight_layout()
plt.show()

## Save embeddings dataset

In [None]:
import os
import torch
from test_nn_template.data.export import EmbeddingsSaver


source = os.path.join(NB_PATH, NB_NAME)

# Bind the contrastive encoder NOW. A lambda reading the global `model` at call
# time would silently switch to the classifier that a later cell assigns to
# `model`, if the transform were ever invoked after that cell ran.
encoder = model.model
encoder.eval()


def transform(x, _encoder=encoder):
    """Embed a batch with the contrastive encoder without tracking gradients."""
    with torch.no_grad():
        return _encoder.forward_once(x)


# Label name -> integer class index, in index order 0..9.
class_names = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]
class_to_index = {name: index for index, name in enumerate(class_names)}

embeddings_saver = EmbeddingsSaver(run, train_dataloader, test_dataloader, source, class_to_index, transform)
embeddings_saver.save("data")

In [None]:
import torch

# Sanity check: load one saved test-split embeddings file (index 2).
# NOTE(review): absolute, user-specific path with a fixed run name
# ("sage-field-17"); adjust it for your machine and run.
embeddings_shard_path = "/home/fernando/conversenow/test_nn_template/data/sage-field-17/test/data/embeddings/2.pt"
torch.load(embeddings_shard_path)

In [None]:
# Sanity check: load one saved test-split labels file (index 2).
# NOTE(review): absolute, user-specific path; adjust it for your machine and run.
labels_shard_path = "/home/fernando/conversenow/test_nn_template/data/sage-field-17/test/data/labels/2.pt"
torch.load(labels_shard_path)

## Feed-forward classifier trained on embedding space

In [None]:
# Load the feed-forward classifier trained on the embedding space.
# NOTE: this rebinds the notebook-wide `run` and `model` names.
run_id = "fernandoezequiel512/test_nn_template/runs/2841ki1m"
run = WandbHandler.get_run(run_id)
print(f"Got {run.name}")
model = WandbHandler.load_run_model_checkpoint(run_id)
print(model)
print(f"Accuracy test: {run.summary['acc/test']:.2f}")

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

# Evaluate the classifier on a regular 2-D grid to visualise its decision regions.
grid_axis = np.arange(-1.5, 2, 0.1)
# indexing="ij" + ravel keeps the original point order (first coordinate outer).
grid_x, grid_y = np.meshgrid(grid_axis, grid_axis, indexing="ij")
x = grid_x.ravel()
y = grid_y.ravel()
positions = torch.from_numpy(np.stack([x, y], axis=1)).float()

# One batched forward pass instead of one call per grid point.
# NOTE(review): assumes the classifier accepts a (N, 2) batch and returns (N, n_classes).
model.eval()
with torch.no_grad():
    logits = model(positions)
# softmax is monotonic, so argmax over logits selects the same class as over probabilities.
grid_labels = logits.argmax(dim=-1).numpy()

plt.scatter(x, y, c=grid_labels)