In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports all modules
# before each cell runs, so edits to imported project code take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import pickle as pkl
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torchvision
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import TQDMProgressBar
from sklearn.ensemble import RandomForestClassifier
from sklearn.manifold import TSNE
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.naive_bayes import GaussianNB
from sklearn.neural_network import MLPClassifier
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchmetrics import MeanSquaredError
from torchvision.transforms import Compose, ToTensor, Normalize
from tqdm.notebook import tqdm

In [None]:
# Prepend '../../' (two directory levels above the working directory) to the import path.
# Presumably this is the repository root that holds `continual_learning` and `settings` — TODO confirm.
sys.path.insert(0, '../' * 2)

from continual_learning.models.autoencoder.omniglot import OmniglotAutoencoderModel, OmniglotAutoencoder, ThresholdStopping
from settings import DATASETS_DIR, MODELS_DIR

## Setup

In [None]:
# Checkpoint file the trained autoencoder is written to and reloaded from further down.
ENCODER_OUTPUT_PATH = MODELS_DIR / 'ensemble_omniglot_autoencoder' / 'encoder.ckpt'
# Root directory passed to the torchvision Omniglot datasets (created on demand below).
OMNIGLOT_DATASET_DIR = DATASETS_DIR / 'omniglot'
# Directory named for the MNIST data.
MNIST_DATASET_DIR = DATASETS_DIR / 'mnist'

In [None]:
# Autoencoder
# Passed to OmniglotAutoencoder as encoder_size.
ENCODER_SIZE = 512
# Passed to OmniglotAutoencoder as input_size; the datasets are resized to 28x28 images.
INPUT_SIZE = 28
LEARNING_RATE = 0.001
# Batch size of the Omniglot training loader.
BATCH_SIZE = 48
# Batch size of the Omniglot validation loader.
BATCH_SIZE_TEST = 256
# Training
# Forwarded to Trainer(max_epochs=...); per the Lightning docs -1 means no epoch limit,
# so the run is expected to be ended by the ThresholdStopping callback instead.
MAX_EPOCHS = -1
# Threshold handed to ThresholdStopping; presumably training stops once
# THRESHOLD_METRIC reaches this value — TODO confirm against the callback.
LOSS_THRESHOLD = 0.020
# Name of the logged metric that ThresholdStopping monitors.
THRESHOLD_METRIC = 'val/reconstruction_loss'

## Data loading

In [None]:
# Create the Omniglot download directory (and any missing parents); a no-op if it already exists.
OMNIGLOT_DATASET_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
# Both Omniglot splits use exactly the same preprocessing, so the pipeline is built once
# and the target resolution comes from INPUT_SIZE instead of a repeated literal.
omniglot_transform = Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
    # Invert pixel intensities (x -> 1 - x); presumably so strokes are bright on a
    # dark background like MNIST — TODO confirm.
    torchvision.transforms.Lambda(lambda x: 1.0 - x),
])

# background=True selects Omniglot's "background" split, used here for training.
train_dataset = torchvision.datasets.Omniglot(
    root=OMNIGLOT_DATASET_DIR,
    download=True,
    background=True,
    transform=omniglot_transform,
)
# background=False selects the "evaluation" split, used here for validation.
val_dataset = torchvision.datasets.Omniglot(
    root=OMNIGLOT_DATASET_DIR,
    download=True,
    background=False,
    transform=omniglot_transform,
)

# NOTE(review): the training loader iterates in dataset order (no shuffle=True);
# confirm this is intended before relying on the training results.
train_data_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
val_data_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE_TEST)

len(train_dataset), len(val_dataset)

## Model preparation

In [None]:
# Build the trainable autoencoder module from the notebook's hyperparameter constants.
autoencoder = OmniglotAutoencoder(
    input_size=INPUT_SIZE,
    encoder_size=ENCODER_SIZE,
    learning_rate=LEARNING_RATE,
)

## Training

In [None]:
# The deprecated Trainer arguments (progress_bar_refresh_rate, checkpoint_callback,
# weights_summary) are replaced by their supported equivalents:
#   - refresh rate    -> TQDMProgressBar(refresh_rate=...) callback
#   - no checkpoints  -> enable_checkpointing=False (checkpoint_callback was a duplicate)
#   - no summary      -> enable_model_summary=False
trainer = Trainer(
    max_epochs=MAX_EPOCHS,
    enable_progress_bar=True,
    enable_checkpointing=False,
    enable_model_summary=False,
    logger=True,
    callbacks=[
        TQDMProgressBar(refresh_rate=10),
        # Stopping criterion configured from THRESHOLD_METRIC / LOSS_THRESHOLD.
        ThresholdStopping(metric=THRESHOLD_METRIC, threshold=LOSS_THRESHOLD),
    ],
)

trainer.fit(autoencoder, train_data_loader, val_data_loader)

In [None]:
# Display the metrics the trainer collected during the run.
trainer.callback_metrics

Save the model

In [None]:
# Make sure the target directory exists, then write the trainer state to the checkpoint file.
checkpoint_dir = ENCODER_OUTPUT_PATH.parent
checkpoint_dir.mkdir(parents=True, exist_ok=True)
trainer.save_checkpoint(ENCODER_OUTPUT_PATH)

Load the model (for testing)

In [None]:
# Restore the autoencoder from the saved checkpoint, supplying the same constructor arguments.
autoencoder = OmniglotAutoencoder.load_from_checkpoint(
    checkpoint_path=ENCODER_OUTPUT_PATH,
    input_size=INPUT_SIZE,
    encoder_size=ENCODER_SIZE,
    learning_rate=LEARNING_RATE,
)

## Verifying reconstruction quality

### Comparison with original examples

In [None]:
originals = []
encodings = []
reconstructions = []
labels = []

# Inference only: switch to evaluation mode and skip autograd bookkeeping.
autoencoder.eval()
with torch.no_grad():
    for items, batch_labels in train_data_loader:
        reconstruction, joint_encoding, tanh_encoding, relu_encoding = autoencoder.model.forward(items)
        # squeeze(1) removes only a size-1 axis at position 1 (the channel axis of the
        # loader's image batches). A bare squeeze() would also drop the batch axis
        # whenever a batch holds a single image and break the concatenation below.
        reconstructions.append(reconstruction.squeeze(1))
        originals.append(items.squeeze(1))
        encodings.append(joint_encoding)
        # Plain Python ints instead of a list of 0-d tensors.
        labels.extend(batch_labels.tolist())

# Stack the per-batch results along the sample axis.
reconstructions = torch.cat(reconstructions, dim=0)
originals = torch.cat(originals, dim=0)
encodings = torch.cat(encodings, dim=0)

print(f"Original images: {np.shape(originals)}")
print(f"Reconstructions: {np.shape(reconstructions)}")

In [None]:
# Top row: original images; bottom row: the autoencoder's reconstructions of the same images.
fig, axs = plt.subplots(figsize=(25, 6), nrows=2, ncols=5)

indexes_to_show = (0, 5)
first_index, last_index = indexes_to_show
original_row, reconstruction_row = axs
for column, img_index in enumerate(range(first_index, last_index)):
    original_row[column].imshow(originals[img_index])
    reconstruction_row[column].imshow(reconstructions[img_index])

### Encoding MNIST dataset

In [None]:
# mnist_train_dataset = torchvision.datasets.MNIST(
#     root=MNIST_DATASET_DIR,
#     download=True,
#     train=True,
#     transform=Compose([
#         torchvision.transforms.ToTensor(),
#         torchvision.transforms.Resize((28, 28)),
#     ])
# )

# mnist_test_dataset = torchvision.datasets.MNIST(
#     root=MNIST_DATASET_DIR,
#     download=True,
#     train=False,
#     transform=Compose([
#         torchvision.transforms.ToTensor(),
#         torchvision.transforms.Resize((28, 28)),
#     ])
# )

# mnist_data_loader_train = DataLoader(mnist_train_dataset, BATCH_SIZE)
# mnist_data_loader_test = DataLoader(mnist_test_dataset, BATCH_SIZE)

In [None]:
# MNIST gets its own root directory (MNIST_DATASET_DIR) rather than being downloaded
# into the Omniglot folder. Both splits share one transform.
# No Normalize step is applied: images are only converted to tensors.
mnist_transform = Compose([
    ToTensor(),
])

mnist_dataset_train = torchvision.datasets.MNIST(
    root=MNIST_DATASET_DIR,
    download=True,
    train=True,
    transform=mnist_transform,
)
mnist_dataset_test = torchvision.datasets.MNIST(
    root=MNIST_DATASET_DIR,
    download=True,
    train=False,
    transform=mnist_transform,
)

# Same batch size as the autoencoder's training loader, taken from the config constant.
mnist_data_loader_train = DataLoader(mnist_dataset_train, batch_size=BATCH_SIZE)
mnist_data_loader_test = DataLoader(mnist_dataset_test, batch_size=BATCH_SIZE)

In [None]:
# Encode the MNIST training split. The baseline classifiers further down are fitted on
# `mnist_encodings` / `mnist_labels`, so this cell must run for them to work.
mnist_encodings = []
mnist_labels = []

# Inference only: no autograd graph is needed for the encodings.
with torch.no_grad():
    for items, batch_labels in tqdm(mnist_data_loader_train, total=len(mnist_data_loader_train)):
        tanh_encoding, relu_encoding = autoencoder.model.encode(items)
        mnist_encodings.append(tanh_encoding)
        mnist_labels.extend(batch_labels.tolist())

# Labels are stored as strings, matching how the test labels are stored, so that
# predictions and reference labels share one type.
mnist_labels = [str(label) for label in mnist_labels]
mnist_encodings = torch.cat(mnist_encodings, dim=0).detach().numpy()

print(mnist_encodings.shape)

In [None]:
mnist_encodings_test = []
mnist_labels_test = []

# Inference only: without no_grad every batch would keep its autograd graph alive
# until the final detach, which only wastes memory.
with torch.no_grad():
    for items, batch_labels in tqdm(mnist_data_loader_test, total=len(mnist_data_loader_test)):
        tanh_encoding, relu_encoding = autoencoder.model.encode(items)
        mnist_encodings_test.append(tanh_encoding)
        mnist_labels_test.extend(batch_labels.tolist())

# String labels, as before (they are used as categorical hue values when plotting).
mnist_labels_test = [str(label) for label in mnist_labels_test]
mnist_encodings_test = torch.cat(mnist_encodings_test, dim=0).detach().numpy()

print(mnist_encodings_test.shape)

t-SNE embeddings of the MNIST test-set encodings

In [None]:
# Project the test encodings to 2-D for plotting. With init='random' the result depends
# on the random generator, so random_state is pinned to make the figure reproducible.
X_embedded_test = TSNE(
    n_components=2,
    learning_rate='auto',
    init='random',
    random_state=0,
).fit_transform(mnist_encodings_test)

X_embedded_test.shape

In [None]:
# Order the test samples by label. Every per-sample array has to be permuted with the
# same index, otherwise encodings, labels and embedding rows no longer refer to the
# same sample in the evaluation cells below. A stable sort makes re-running this cell
# a no-op once the data is already ordered.
arg_sorted = np.argsort(mnist_labels_test, kind='stable')
mnist_labels_test = np.array(mnist_labels_test)[arg_sorted]
mnist_encodings_test = mnist_encodings_test[arg_sorted]
X_embedded_test = X_embedded_test[arg_sorted]

In [None]:
# Scatter plot of the 2-D embedding, one colour per label.
tsne_fig, ax_test = plt.subplots(figsize=(10, 7))
sns.scatterplot(
    x=X_embedded_test[:, 0],
    y=X_embedded_test[:, 1],
    hue=mnist_labels_test,
    palette="gist_rainbow",
    ax=ax_test,
)
# Rebuild the legend from the existing entries so it can carry a title.
handles, labels = ax_test.get_legend_handles_labels()
ax_test.legend(handles, labels, title='Class')

## Offline baselines on encodings

MLP

In [None]:
# MLP baseline configured with an empty tuple of hidden layer sizes.
clf = MLPClassifier(
    hidden_layer_sizes=(),
    learning_rate_init=0.001,
    verbose=False,
    early_stopping=True,
    max_iter=500,
)
clf.fit(mnist_encodings, mnist_labels)

In [None]:
# Predict on the test encodings and print the per-class precision/recall report.
y_pred = clf.predict(mnist_encodings_test)
mlp_report = classification_report(mnist_labels_test, y_pred, output_dict=False)

print(mlp_report)

Random Forest

In [None]:
# Random forest baseline with 300 trees, fitted on the training encodings.
forest_params = {"n_estimators": 300}
clf = RandomForestClassifier(**forest_params)
clf.fit(mnist_encodings, mnist_labels)

In [None]:
# Predict on the test encodings and print the per-class precision/recall report.
y_pred = clf.predict(mnist_encodings_test)
forest_report = classification_report(mnist_labels_test, y_pred, output_dict=False)

print(forest_report)

GaussianNB

In [None]:
clf = GaussianNB()

# partial_fit needs the full set of classes on its first call, and those classes must
# be the very label values that occur in the targets. The labels are strings here, so
# the class list is taken from the labels themselves rather than built from int indices.
all_classes = np.unique(mnist_labels_test)

# Feed the classifier one sample at a time (incremental fitting).
# NOTE(review): the model is fitted on the same test encodings it is evaluated on
# in the next cell — confirm this is intended.
for item, label in tqdm(zip(mnist_encodings_test, mnist_labels_test), total=len(mnist_labels_test)):
    clf.partial_fit(item.reshape(1, -1), [label], classes=all_classes)

In [None]:
# Predict on the test encodings; keep a per-sample correctness mask for the plot below.
y_pred = clf.predict(mnist_encodings_test)
is_correct = (y_pred == mnist_labels_test)

nb_report = classification_report(mnist_labels_test, y_pred, output_dict=False)
print(nb_report)

In [None]:
# Embedding scatter plot where marker style and size encode whether the prediction was correct.
# NOTE(review): the figure has two panels but only the right-hand one (axs[1]) is drawn.
fig, axs = plt.subplots(figsize=(20, 8), nrows=1, ncols=2)

# Size value 50 for correctly classified samples, 100 for misclassified ones.
marker_sizes = [50 if hit else 100 for hit in is_correct]
ax_test = sns.scatterplot(
    x=X_embedded_test[:, 0],
    y=X_embedded_test[:, 1],
    hue=mnist_labels_test,
    style=is_correct,
    size=marker_sizes,
    palette="gist_rainbow",
    ax=axs[1],
)

In [None]:
def generate_example_from_nb(model, class_to_sample: int, examples_count: int = 1):
    """Sample synthetic feature vectors for one class of a fitted Gaussian naive Bayes model.

    Each feature is drawn independently from a normal distribution whose mean and
    standard deviation come from the model's per-class statistics, and the result
    is clipped to ``[-1, 1]``.

    Args:
        model: Object exposing ``theta_`` (per-class feature means) and ``var_``
            (per-class feature variances), both indexable by class position.
        class_to_sample: Row index into ``theta_`` / ``var_`` of the class to sample.
        examples_count: Number of vectors to draw; ``0`` yields an empty array.

    Returns:
        ``numpy.ndarray`` of shape ``(examples_count, n_features)``.
    """
    class_mean = np.asarray(model.theta_[class_to_sample])
    class_std = np.sqrt(model.var_[class_to_sample])

    # `size` already fixes the output shape, so no reshape is needed. The former
    # predict_proba call was dropped: its result was never used, and indexing row 0
    # failed when no examples were requested.
    sampled = np.random.normal(
        loc=class_mean,
        scale=class_std,
        size=(examples_count, len(class_mean)),
    )

    # Keep the samples inside [-1, 1]; presumably the value range of the encodings
    # the model was fitted on — TODO confirm.
    return np.clip(sampled, -1, 1)


# Draw one synthetic encoding for class index 9 and decode it back into a 28x28 image.
sampled = generate_example_from_nb(clf, 9, examples_count=1)
print(sampled.shape)

sampled_tensor = torch.tensor(sampled).float()
decoded = autoencoder.model.decode(sampled_tensor).detach()
reconstruction = decoded.reshape(28, 28).numpy()

plt.imshow(reconstruction)

In [None]:
def generate_batch(nb_model, classes_to_sample, shuffle: bool = True) -> "tuple[torch.Tensor, torch.Tensor]":
    """Build a labelled batch of synthetic examples sampled from a naive Bayes model.

    Args:
        nb_model: Fitted model passed through to ``generate_example_from_nb``.
        classes_to_sample: Mapping ``class index -> number of examples`` to draw
            for that class. Must contain at least one entry, because an empty
            list of tensors cannot be concatenated.
        shuffle: When true, examples and labels are permuted together with one
            random permutation.

    Returns:
        ``(batch, labels)``: ``batch`` is a float tensor with one row per generated
        example, ``labels`` holds the class index of each row.
    """
    generated_examples = []
    generated_labels = []
    for class_index, examples_count in classes_to_sample.items():
        generated_example = generate_example_from_nb(nb_model, class_index, examples_count)
        generated_examples.append(torch.from_numpy(generated_example))
        generated_labels.extend([class_index] * examples_count)

    batch = torch.cat(generated_examples, dim=0).float()
    labels = torch.tensor(generated_labels)

    if shuffle:
        # One shared permutation keeps every example aligned with its label.
        # Indexing along the first axis already preserves the shape, so no view is needed.
        indexes_shuffled = torch.randperm(len(labels))
        batch = batch[indexes_shuffled]
        labels = labels[indexes_shuffled]

    return batch, labels


# Example call: request 2 examples for class index 1 and 5 for class index 2, and display the result.
generate_batch(clf, {1: 2, 2: 5})