In [1]:
import io
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import umap.plot
from umap.umap_ import UMAP

## Generate a UMAP embedding of the raw pixel data

In [None]:
# Validation set. It is indexed below as record[0] (image) / record[1] (label), so it is presumably
# an object array of tuples. Such arrays need allow_pickle=True; only load a file you trust.
val_data = np.load(file="val_data.npy", allow_pickle=True)

In [None]:
# Spot-check the shape of one sketch's image.
np.shape(val_data[34983][0])

In [None]:
# Pixel matrix for UMAP, with one 784-long row per sketch. Reshaping to 28x28 first also checks
# that every image has exactly 784 values.
all_x = np.array([sample[0].reshape(28, 28).ravel() for sample in val_data])

In [None]:
# UMAP is stochastic. A fixed random_state makes this embedding, and the umap.csv written from
# it, repeatable across re-runs. Per the umap-learn docs, setting random_state turns off its
# parallelism, so expect a slower fit.
embedding = UMAP(random_state=42).fit(all_x);

In [None]:
# 2-D embedding coordinates, with one row per validation sketch.
df = pd.DataFrame(data=embedding.embedding_, columns=['x', 'y'])

In [None]:
# One point per validation sketch. With this many points, small translucent markers keep the
# dense clusters readable. The trailing ';' suppresses the Axes repr.
df.plot.scatter(x='x', y='y', s=1, alpha=0.3, title='UMAP of flattened 28x28 validation sketches');

In [None]:
# Integer class label of each sketch, taken from the second element of each record.
all_y = np.array([sample[1] for sample in val_data])

In [None]:
# Class names indexed by integer label, so classes[i] is the name of label i.
# Presumably this is the same order the model was trained with (TODO confirm).
classes = [
    'airplane', 'apple', 'bee', 'car', 'dragon',
    'mosquito', 'moustache', 'mouth', 'pear', 'piano',
    'pineapple', 'smiley face', 'train', 'umbrella', 'wine bottle',
]

In [None]:
# True label column, aligned row-for-row with the embedding.
df['class'] = pd.Series(all_y, index=df.index)

In [None]:
# Human-readable name for each true label.
df['class_name'] = df['class'].map(lambda label: classes[label])

In [None]:
# CNN classifier whose architecture must match the trained checkpoint exactly.
# nn.Sequential names its children by position, so the checkpoint's state_dict keys and the
# `model[:7]` slice used later both depend on this layer order:
#   0 Conv2d(1->16)   1 ReLU   2 MaxPool2d
#   3 Conv2d(16->32)  4 ReLU   5 MaxPool2d
#   6 Conv2d(32->32)  7 ReLU   8 MaxPool2d
#   9 Flatten  10 Linear(288->128)  11 ReLU  12 Linear(128->len(classes))
# The flattened size is 288 = 32 channels * 3 * 3. The 'same'-padded convs keep the spatial size,
# and each MaxPool2d(2) halves it: 28 -> 14 -> 7 -> 3. This assumes 28x28 inputs, which matches
# the reshape used for all_x.
model = nn.Sequential(
    *(
        layer
        for in_ch, out_ch in ((1, 16), (16, 32), (32, 32))
        for layer in (
            nn.Conv2d(in_ch, out_ch, 3, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
    ),
    nn.Flatten(),
    nn.Linear(288, 128),
    nn.ReLU(),
    nn.Linear(128, len(classes)),
)

# map_location='cpu' loads the weights onto the CPU whatever device they were saved from.
# NOTE(review): torch.load unpickles the file. It only holds a state_dict, so weights_only=True
# could be passed on torch versions that support it. Only load trusted checkpoints.
checkpoint = torch.load('./model_lessCapacity.pth', map_location='cpu')
model.load_state_dict(checkpoint)

# Evaluation mode. This model has no Dropout or BatchNorm, so no outputs change, but it is the
# right state for inference. As the last expression, it also displays the model.
model.eval()

In [None]:
# Predicted class index for every validation sketch, in val_data order. The loop covers all of
# val_data rather than a hard-coded count, so y_hats has exactly one entry per row of df.
# Gradients are never needed here; torch.no_grad() stops autograd from recording a graph on
# every forward pass.
y_hats = []
with torch.no_grad():
    for sketch_index in range(len(val_data)):
        x = val_data[sketch_index][0]
        # unsqueeze(1) adds the channel axis for nn.Conv2d(1, ...). Presumably x is
        # (1, 28, 28), which gives a (1, 1, 28, 28) batch of one (TODO confirm).
        logits = model(torch.tensor(x).unsqueeze(1))
        # Flat argmax over the single row of logits is the class index.
        y_hat = np.argmax(logits.numpy())
        y_hats.append(y_hat)

In [None]:
# Predicted label column, aligned row-for-row with the embedding.
df['predicted'] = pd.Series(y_hats, index=df.index)

In [None]:
# Human-readable name for each predicted label.
df['predicted_name'] = df['predicted'].map(lambda label: classes[label])

In [None]:
# First rows of the raw-pixel embedding table with true and predicted labels.
df.head(n=5)

In [None]:
# Persist the raw-pixel embedding with labels and predictions; umap.csv is read back further down.
# The file is written only when it is missing, so re-running never overwrites an existing copy.
if not Path('umap.csv').exists():
    df.to_csv('umap.csv')

## UMAP embedding of the last convolutional layer's output

In [None]:
# Output of the last convolutional layer for every validation sketch.
# model[:7] is layers 0-6 of the Sequential: everything up to and including the third Conv2d,
# before its ReLU and MaxPool. It is sliced once here, not on every iteration.
# torch.no_grad() is essential. Without it, each stored tensor keeps its whole autograd graph
# and the intermediate activations alive, for every sketch. The loop also covers all of
# val_data rather than a hard-coded count.
last_cl_extractor = model[:7]
last_cl_representations = []
with torch.no_grad():
    for sketch_index in range(len(val_data)):
        x = val_data[sketch_index][0]
        last_cl = last_cl_extractor(torch.tensor(x).unsqueeze(1))
        last_cl_representations.append(last_cl)

In [None]:
# Build one row per sketch: drop the leading batch axis, move to NumPy, and flatten each
# conv output. The rows are then stacked into a 2-D array for UMAP.
last_cl_np = np.array([
    representation.squeeze(0).detach().numpy().ravel()
    for representation in last_cl_representations
])



In [2]:
# On-disk cache for the last-conv-layer features.
# - If the cache exists, it takes precedence. This lets the UMAP steps below resume from a fresh
#   kernel without the forward passes above.
# - Otherwise the freshly computed last_cl_np is saved to create the cache, instead of failing
#   on a missing file.
LAST_CL_CACHE = Path("last_convolutional_layer.npy")
if LAST_CL_CACHE.exists():
    # NOTE(review): allow_pickle=True permits unpickling, so only load a file you trust.
    # If the cache holds a plain float array, allow_pickle can be dropped (confirm).
    last_cl_np = np.load(LAST_CL_CACHE, allow_pickle=True)
else:
    np.save(LAST_CL_CACHE, last_cl_np)

In [8]:
# UMAP is stochastic. A fixed random_state makes the conv-layer embedding, and umap_cl.csv,
# repeatable across re-runs. Per the umap-learn docs, setting random_state turns off its
# parallelism, so expect a slower fit.
embedding_cl = UMAP(random_state=42).fit(last_cl_np);

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


In [9]:
# Name the columns x_cl / y_cl. This stops them colliding with the raw-pixel x / y columns when
# the frame is concatenated with umap_df below, which selects 'x_cl' and 'y_cl' by name.
df_cl = pd.DataFrame(embedding_cl.embedding_, columns=['x_cl', 'y_cl'])

In [11]:
# Sanity check: one row per sketch, two embedding columns.
np.shape(df_cl)

(300000, 2)

In [27]:
# Persist the conv-layer embedding; the default index=True also writes the row index.
df_cl.to_csv(path_or_buf='umap_cl.csv')

In [13]:
# Reload the raw-pixel embedding table, which was presumably written by df.to_csv('umap.csv') above.
umap_df = pd.read_csv(filepath_or_buffer="./umap.csv")

In [19]:
# umap.csv was presumably written with to_csv's default index=True. read_csv therefore brings the
# old index back as an unnamed 'Unnamed: 0' column. DataFrame has no drop_index(), so drop that
# column explicitly; errors='ignore' keeps this cell safe to re-run.
umap_df = umap_df.drop(columns='Unnamed: 0', errors='ignore')

In [32]:
# Side-by-side table of the raw-pixel embedding (umap_df) and the conv-layer embedding (df_cl).
# pd.concat(axis=1) aligns on the index and silently pads with NaN when the lengths differ,
# so check the lengths first.
assert len(umap_df) == len(df_cl), "umap_df and df_cl cover different numbers of sketches"
# If df_cl still uses plain x / y column names, rename them so they do not duplicate umap_df's
# x / y. The rename is a no-op when df_cl already uses x_cl / y_cl.
extended_df = pd.concat(
    [umap_df, df_cl.rename(columns={'x': 'x_cl', 'y': 'y_cl'})],
    axis=1,
)[['x', 'y', 'class', 'class_name', 'predicted', 'predicted_name', 'x_cl', 'y_cl']]

In [33]:
# Persist the combined table of both embeddings plus labels and predictions.
extended_df.to_csv(path_or_buf='umap_extended.csv')

In [34]:
# First rows of the combined table.
extended_df.head(n=5)

Unnamed: 0,x,y,class,class_name,predicted,predicted_name,x_cl,y_cl
0,10.439865,11.399223,14,wine bottle,14,wine bottle,10.26177,12.589383
1,-1.522812,9.091824,9,piano,9,piano,-1.679909,6.546764
2,-0.078322,6.752753,9,piano,9,piano,-2.918781,5.499226
3,-1.777769,8.416865,12,train,12,train,-1.568735,6.539262
4,3.49268,4.150504,11,smiley face,11,smiley face,4.227337,5.325876


In [36]:
# Inspect the sketches inside a small 0.1 x 0.1 window of the (x, y) embedding.
# All four bounds are strict.
extended_df.query('10 < x < 10.1 and 10 < y < 10.1')

Unnamed: 0,x,y,class,class_name,predicted,predicted_name,x_cl,y_cl
459,10.058564,10.078728,10,pineapple,10,pineapple,8.655949,10.81834
7635,10.029596,10.046811,10,pineapple,10,pineapple,7.941402,10.687274
11596,10.047059,10.005007,10,pineapple,10,pineapple,8.62855,10.716096
13050,10.058895,10.09474,10,pineapple,10,pineapple,8.575015,10.72392
14394,10.071191,10.006775,10,pineapple,10,pineapple,8.515562,10.763453
30509,10.075893,10.012474,10,pineapple,10,pineapple,8.249722,10.346437
42004,10.096755,10.003428,10,pineapple,10,pineapple,8.731153,10.892737
45274,10.08696,10.062394,10,pineapple,10,pineapple,8.645429,10.264174
49545,10.000556,10.03733,10,pineapple,10,pineapple,8.363608,10.825409
52399,10.057635,10.088057,10,pineapple,10,pineapple,8.495973,10.942862
