# 5.0 MNIST
We are getting ~80% correct predictions. This might be because the ten digit signatures are insufficient to capture the variation within each digit. We might try to define the different flavors of each digit using a round of k-means clustering and labeling in the training dataset, then define a "narrow" signature for each flavor, and then predict on the test dataset using this narrow set of signatures. Also, visualizing these digits as images would be useful. We could also carry over the similarity score as a value-based category or use it to generate an ROC curve.

In [1]:
from clustergrammer2 import net
# Registry of every DataFrame produced in this notebook, keyed by a short name
# (e.g. 'mnist', 'mnist-train', 'sig').
df = dict()

clustergrammer2 backend version 0.2.9


In [2]:
import clustergrammer_groupby as cby
from copy import deepcopy
import random
# Seed Python's module-level RNG; the column shuffle used for the
# train/predict split draws from this generator.
SEED = 99
random.seed(SEED)

In [3]:
# Read the MNIST matrix from disk into the shared `net` object, then keep the
# exported DataFrame under the 'mnist' key.
net.load_file('../data/big_data/MNIST_row_labels.txt')
mnist_df = net.export_df()
df['mnist'] = mnist_df

# Displayed as the cell result: (rows, columns) of the loaded matrix.
mnist_df.shape

(784, 70000)

In [4]:
# Turn every column label into a (name, category) tuple. The category is the
# text before the first '-' in the name, prefixed with 'Digit: '
# (e.g. 'Zero-0' -> ('Zero-0', 'Digit: Zero')).
cols = df['mnist'].columns.tolist()
new_cols = []
for col_name in cols:
    digit_name = col_name.split('-')[0]
    new_cols.append((col_name, 'Digit: ' + digit_name))

# Relabel a deep copy so the un-categorized matrix in df['mnist'] stays intact.
mnist_cat = deepcopy(df['mnist'])
mnist_cat.columns = new_cols
df['mnist-cat'] = mnist_cat

# Sanity check: show the first relabeled column.
print(new_cols[0])

('Zero-0', 'Digit: Zero')


### Make Train and Predict

In [5]:
# Shuffle the labeled columns in place (using the module-level `random`
# generator), then split them at a fixed index: the first 35000 columns form
# the training set and the remainder form the prediction set.
cols = df['mnist-cat'].columns.tolist()
random.shuffle(cols)

num_train = 35000
train_cols, pred_cols = cols[:num_train], cols[num_train:]

df['mnist-train'] = df['mnist-cat'][train_cols]
df['mnist-pred'] = df['mnist-cat'][pred_cols]
print(df['mnist-train'].shape, df['mnist-pred'].shape)

(784, 35000) (784, 35000)


### Make Z-scored Data

In [6]:
# Z-score the training matrix along the row axis via the shared `net` object
# and store the normalized copy under 'mnist-train-z'.
net.load_df(df['mnist-train'])
net.normalize(norm_type='zscore', axis='row')
train_z = net.export_df()
df['mnist-train-z'] = train_z

In [7]:
# Z-score the prediction matrix along the row axis. It is loaded and
# normalized on its own, i.e. separately from the training matrix.
net.load_df(df['mnist-pred'])
net.normalize(norm_type='zscore', axis='row')
pred_z = net.export_df()
df['mnist-pred-z'] = pred_z

### Make Signatures

In [8]:
# Parameters for signature generation.
# NOTE(review): pval_cutoff is not passed to any call visible in this section
# of the notebook -- confirm whether cby.generate_signatures should receive it.
pval_cutoff = 1e-5
# Forwarded as `num_top_dims` to cby.generate_signatures.
num_top_dims = 50

In [9]:
# Build per-digit signatures from the raw training data, grouping columns by
# their 'Digit' category. The call returns four values; the signature matrix
# is stored under 'sig' and the other three are kept as globals.
sig_outputs = cby.generate_signatures(
    df['mnist-train'], 'Digit', num_top_dims=num_top_dims
)
df['sig'], keep_genes_dict, df_gene_pval, fold_info = sig_outputs

# Displayed as the cell result.
df['sig'].shape

(276, 10)

In [10]:
# Build per-digit signatures from the z-scored training data. This rebinds
# keep_genes_dict, df_gene_pval and fold_info, so after this cell they
# describe the z-scored run rather than the raw-data run.
sig_z_outputs = cby.generate_signatures(
    df['mnist-train-z'], 'Digit', num_top_dims=num_top_dims
)
df['sig-z'], keep_genes_dict, df_gene_pval, fold_info = sig_z_outputs

# Displayed as the cell result.
df['sig-z'].shape

  return (self.a < x) & (x < self.b)
  return (self.a < x) & (x < self.b)
  cond2 = cond0 & (x <= self.a)


(276, 10)

In [11]:
# Show the raw-data signature matrix as an interactive widget.
sig_raw = df['sig']
net.load_df(sig_raw)
net.widget()

ExampleWidget(network='{"row_nodes": [{"name": "pos_10-10", "ini": 276, "clust": 234, "rank": 218, "rankvar": …

In [12]:
# Show the z-scored signature matrix as an interactive widget.
sig_zscored = df['sig-z']
net.load_df(sig_zscored)
net.widget()

ExampleWidget(network='{"row_nodes": [{"name": "pos_10-10", "ini": 276, "clust": 60, "rank": 264, "rankvar": 1…

# Predict Digit Type Using Signatures

### Raw Data

In [13]:
# Predict a digit for every held-out column by comparing the raw prediction
# matrix against the raw-data signatures.
# NOTE(review): truth_level=1 presumably points at the 'Digit: ...' entry of
# each column tuple -- confirm against clustergrammer_groupby.
df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(
    df['mnist-pred'], df['sig'],
    truth_level=1, predict_level='Pred Digit', unknown_thresh=0.0
)

# Score the predictions once. Both summaries below are derived from the same
# y_info, so a single call provides everything that gets printed.
df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)

# Overall fraction of correct predictions.
print('Predict: ', fraction_correct)

# Same overall fraction, followed by the per-digit fraction correct (best first).
# NOTE(review): the 'broad cell type' label does not describe digits; it is
# kept byte-for-byte so the recorded output of this cell remains valid.
print('\nbroad cell type: ', fraction_correct, '\n')
print(ser_correct.sort_values(ascending=False))

Predict:  0.813285714286

broad cell type:  0.813285714286 

One      0.929062
Zero     0.901891
Six      0.884470
Seven    0.819078
Four     0.806868
Three    0.795187
Two      0.789069
Nine     0.784588
Eight    0.749117
Five     0.644683
dtype: float64


In [14]:
# Dimensions (rows, columns) of the prediction matrix returned by
# cby.predict_cats_from_sigs.
n_rows, n_cols = df_pred_cat.shape
(n_rows, n_cols)

(276, 35000)

In [15]:
# Visualize the raw-data predictions. A 2500-sample subset along the column
# axis is drawn before rendering; random_state is fixed, presumably so the
# same subset is drawn on every run.
net.load_df(df_pred_cat)
net.random_sample(random_state=100, num_samples=2500, axis='col')
net.widget()

ExampleWidget(network='{"row_nodes": [{"name": "pos_3-15", "ini": 276, "clust": 18, "rank": 9, "rankvar": 9, "…

### Z-scored Data

In [18]:
# Predict a digit for every held-out column by comparing the z-scored
# prediction matrix against the z-scored signatures. This rebinds
# df_pred_cat, df_sig_sim and y_info from the raw-data prediction.
# NOTE(review): truth_level=1 presumably points at the 'Digit: ...' entry of
# each column tuple -- confirm against clustergrammer_groupby.
df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(
    df['mnist-pred-z'], df['sig-z'],
    truth_level=1, predict_level='Pred Digit', unknown_thresh=0.0
)

# Score the predictions once. Both summaries below are derived from the same
# y_info, so a single call provides everything that gets printed.
df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)

# Overall fraction of correct predictions.
print('Predict: ', fraction_correct)

# Same overall fraction, followed by the per-digit fraction correct (best first).
# NOTE(review): the 'broad cell type' label does not describe digits; it is
# kept byte-for-byte so the recorded output of this cell remains valid.
print('\nbroad cell type: ', fraction_correct, '\n')
print(ser_correct.sort_values(ascending=False))

Predict:  0.790514285714

broad cell type:  0.790514285714 

One      0.964404
Zero     0.904255
Six      0.837446
Seven    0.835332
Four     0.793367
Three    0.773923
Two      0.750071
Nine     0.739054
Eight    0.662544
Five     0.602398
dtype: float64


In [19]:
# Visualize the z-scored predictions. A 2500-sample subset along the column
# axis is drawn before rendering; random_state is fixed, presumably so the
# same subset is drawn on every run.
net.load_df(df_pred_cat)
net.random_sample(random_state=100, num_samples=2500, axis='col')
net.widget()

ExampleWidget(network='{"row_nodes": [{"name": "pos_3-15", "ini": 276, "clust": 49, "rank": 1, "rankvar": 0, "…