In [7]:
from data.openml import get_openml_data
from auto_encoder.sklearn import AutoTransformer, ConvolutionalAutoTransformer, Transformer, IdentityTransformer
from sklearn.utils import resample
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np

In [8]:
encoded_data = {}
reconstructed_data = {}
transformers = {t: Transformer(type=t) for t in ['ae', 'vae', 'dae', 'sae']}
transformers['original'] = IdentityTransformer()
transformers['pca'] = PCA(n_components=274)
sampling = True
n_samples = 10

for dataset_id in [40996]:
    x, y = get_openml_data(dataset_id)
    sample_idcs = resample(np.arange(len(y)), stratify=y, replace=False, n_samples=n_samples) if sampling else np.arange(len(y))
    for t_name, transformer in transformers.items():
        x_encoded = transformer.fit_transform(x)
        x_reconstructed = transformer.inverse_transform(x_encoded)
        encoded_data[(dataset_id, t_name)] = {'x': x_encoded, 'y': y}
        reconstructed_data[t_name] = x_reconstructed[sample_idcs]



Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100


AttributeError: 'IdentityTransformer' object has no attribute 'inverse_transform'

In [44]:
true_labels = {40996: 
               {0:'T-shirt / Top',
                1:'Trouser',
                2:'Pullover',
                3:'Dress',
                4:'Coat',
                5:'Sandal',
                6:'Shirt',
                7:'Sneaker',
                8:'Bag',
                9:'Ankle boot'},
               40668:
               {2: 'Win', 
                1: 'Loss', 
                0: 'Draw'},
               44:
               {0: 'No Spam',
                1: 'Spam'}
              }

def plot_latent_space(dataset_id, x, y, t_name, n_samples=1000, figsize=(5, 5)):
    x_samples, y_samples = resample(x, y, n_samples=n_samples, stratify=y, replace=False)
    if x_samples.shape[1] > 2:
        tsne = TSNE()
        x_samples = tsne.fit_transform(x_samples)
    fig, ax = plt.subplots(figsize=figsize)
    for clss in np.unique(y_samples):
        x_clss = x_samples[y_samples == clss]
        ax.scatter(x_clss[:, 0], x_clss[:, 1], alpha=0.8, label=true_labels[dataset_id][clss])
        
    plt.legend()
    plt.tight_layout()
    plt.savefig(f'visualizations/{dataset_id}_{t_name}.svg', dpi=300)
    plt.close(fig)
    
def plot_reconstructions(data, figsize=(3, 3), save=True, title=None):
    figsize = (figsize[0] * len(data), figsize[1])
    fig, axs = plt.subplots(nrows=1, ncols=len(data), figsize=figsize)
    data = data.reshape((-1, 28, 28))
    for sample, ax in zip(data, axs):
        ax.imshow(sample, cmap='gray')
        ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, labelbottom=False, labelleft=False) 
        
    plt.tight_layout()
    if save:
        plt.savefig(f'visualizations/reconstructions/{title}.svg')
        plt.close(fig)

In [4]:
reconstructed_data = np.load('reconstruction_samples.npy', allow_pickle=True)[()]

In [40]:
for (dataset_id, t_name), data in encoded_data.items():
    x, y = data['x'], data['y']
    plot_latent_space(dataset_id, x, y, t_name)

In [45]:
for t_type, data in reconstructed_data.items():
    plot_reconstructions(data, save=True, title=t_type)