# 5.0 MNIST
We are getting ~80% correct predictions. This might be because the ten digit signatures are insufficient to capture the different variations of each digit. We might try to define the different flavors of each digit by running a round of k-means clustering and labeling on the training dataset, defining a set of "narrow" signatures from those clusters, and then predicting on the test dataset using this narrow set of signatures. Also, visualizing these digits as images will be useful. We can also carry over the similarity score as a value-based category or use it to generate an ROC curve. 

In [1]:
from clustergrammer2 import net
# Registry of named DataFrames; later cells store each dataset stage under a string key.
df = dict()

clustergrammer2 backend version 0.2.9


In [2]:
import clustergrammer_groupby as cby
from copy import deepcopy
import random
# Seed Python's global `random` generator so any later use of it is repeatable.
SEED = 99
random.seed(SEED)

In [3]:
import pandas as pd

In [4]:
# Load the MNIST matrix file into clustergrammer, export it, and register it as df['mnist'].
mnist_path = '../data/big_data/MNIST_row_labels.txt'
net.load_file(mnist_path)
mnist_df = net.export_df()
df['mnist'] = mnist_df
mnist_df.shape

(784, 70000)

In [5]:
# Turn each column name (e.g. 'Zero-0') into a (name, 'Digit: <digit>') tuple, where
# <digit> is the text before the first '-', and store the relabelled copy as df['mnist-cat'].
labelled_cols = []
for col_name in df['mnist'].columns.tolist():
    digit_category = 'Digit: ' + col_name.split('-')[0]
    labelled_cols.append((col_name, digit_category))

mnist_cat = deepcopy(df['mnist'])
mnist_cat.columns = labelled_cols
df['mnist-cat'] = mnist_cat
print(labelled_cols[0])

('Zero-0', 'Digit: Zero')


### Make Train and Test Datasets

In [91]:
# Split the labelled columns into equally sized train and test sets.
#
# A dedicated, seeded generator is used instead of the module-level `random` state so
# that this cell is idempotent: re-running it always produces the same split, no matter
# how much of the global random stream was consumed earlier in the session.
split_rng = random.Random(99)
cols = df['mnist-cat'].columns.tolist()
split_rng.shuffle(cols)

num_train = len(cols) // 2  # half of the columns (35000 of the 70000 MNIST columns)
df['mnist-train'] = df['mnist-cat'][cols[:num_train]]
df['mnist-test'] = df['mnist-cat'][cols[num_train:]]
print(df['mnist-train'].shape, df['mnist-test'].shape)

# Row-wise z-score each split separately and keep the result next to the raw split.
for split_name in ('mnist-train', 'mnist-test'):
    net.load_df(df[split_name])
    net.normalize(axis='row', norm_type='zscore')
    df[split_name + '-z'] = net.export_df()

(784, 35000) (784, 35000)


In [92]:
def set_cat_colors(cat_color, axis, cat_index, cat_title=False):
    """Register a color with `net` for every category listed in `cat_color`.

    Parameters
    ----------
    cat_color : mapping of category label -> color
    axis : axis name forwarded to `net.set_cat_color`
    cat_index : category index forwarded to `net.set_cat_color`
    cat_title : optional title; when given, each category name is sent as
        '<cat_title>: <label>', otherwise the bare label is used.
    """
    for label in cat_color:
        # prefix the label with its title only when a title was supplied
        full_name = cat_title + ': ' + label if cat_title != False else label
        net.set_cat_color(axis=axis, cat_index=cat_index,
                          cat_name=full_name, inst_color=cat_color[label])

# Make Signatures

### Make Very Narrow Digit Categories

In [93]:
# Display the dimensions of the training matrix before clustering it.
df['mnist-train'].shape

(784, 35000)

In [94]:
# Number of k-means clusters requested per digit when building the narrow categories.
num_clusters = 5

In [95]:
# Distinct digit names, taken from the text after ': ' in each column's category, sorted.
cols = df['mnist-train'].columns.tolist()
all_digits = sorted({col[1].split(': ')[1] for col in cols})

In [96]:
# For every digit: take its training columns, k-means them into `num_clusters` groups,
# and append the cluster label to the category ('Digit: <digit>' -> 'Digit: <digit>-<cluster>').
train_cols = df['mnist-train'].columns.tolist()
df_list = []
for inst_digit in all_digits:

    # training columns whose category names this digit
    digit_cols = [col for col in train_cols if col[1].split(': ')[1] == inst_digit]
    digit_df = df['mnist-train'][digit_cols]
    print(digit_df.shape)

    # the downsample() result is consumed as one cluster label per column, in column order
    net.load_df(digit_df)
    cluster_ids = [
        str(label)
        for label in net.downsample(axis='col', num_samples=num_clusters,
                                    ds_type='kmeans', random_state=99)
    ]

    digit_df.columns = [
        (col[0], col[1] + '-' + cluster_ids[pos])
        for pos, col in enumerate(digit_df.columns.tolist())
    ]
    df_list.append(digit_df)

# stitch the per-digit frames back together side by side
df['mnist-train-vn'] = pd.concat(df_list, axis=1)
df['mnist-train-vn'].shape

(784, 3445)
(784, 3152)
(784, 3442)
(784, 3412)
(784, 3907)
(784, 3601)
(784, 3471)
(784, 3595)
(784, 3501)
(784, 3474)


(784, 35000)

## Make Narrow Digit Signatures

In [97]:
# Build signatures from the cluster-labelled training data, grouped on the 'Digit' category.
# NOTE(review): pval_cutoff is defined here but never passed to generate_signatures —
# confirm whether it was meant to be forwarded.
pval_cutoff = 0.00001
num_top_dims = 50

sig_outputs = cby.generate_signatures(df['mnist-train-vn'], 'Digit', num_top_dims=num_top_dims)
df['sig'], keep_genes_dict, df_gene_pval, fold_info = sig_outputs
print(df['sig'].shape)

(438, 50)


# Predict Digit Type Using Signatures

### Predict using Narrow Signatures

In [98]:
# Peek at the first test column label: a (name, 'Digit: <digit>') tuple.
df['mnist-test'].columns.tolist()[0]

('Nine-6292', 'Digit: Nine')

### Predict on Training Data

In [128]:
# Predict on the training columns
##################
prediction = cby.predict_cats_from_sigs(df['mnist-train'], df['sig'], truth_level=1,
                                        predict_level='Pred Digit', unknown_thresh=0.0)
df_pred_cat, df_sig_sim, y_info = prediction

# Rebuild y_info from the column tuples: entry 1 holds the true digit and entry 2 the
# prediction, whose '-<cluster>' suffix is dropped so it compares against the plain digit.
cols = df_pred_cat.columns.tolist()
y_info = {
    'true': [x[1].split(': ')[1] for x in cols],
    'pred': [x[2].split(': ')[1].split('-')[0] for x in cols],
}

df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)
print('Predict: ', fraction_correct)
ser_correct.sort_values(ascending=False)

Predict:  0.904514285714


One      0.983619
Zero     0.961716
Six      0.958513
Two      0.906598
Three    0.882893
Five     0.880076
Nine     0.878664
Seven    0.877812
Eight    0.864151
Four     0.839338
dtype: float64

### Predict on Test Data

In [131]:
# Predict on the held-out test columns
##################
df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(df['mnist-test'], df['sig'], truth_level=1,
                                                                   predict_level='Pred Digit', unknown_thresh=0.0)

# Column tuples come back as (name, true category, predicted category); strip the
# '-<cluster>' suffix from the prediction so it reads as a plain digit category.
cols = df_pred_cat.columns.tolist()
new_cols = [(x[0], x[1], x[2].split('-')[0]) for x in cols]

print(new_cols[0])

# The predicted labels are transferred onto a copy of the test matrix purely by position,
# so fail loudly if the prediction columns are not in the same order as the test columns
# (otherwise digits would silently receive another column's prediction).
test_names = [x[0] for x in df['mnist-test'].columns.tolist()]
assert [x[0] for x in cols] == test_names, 'prediction columns are not aligned with mnist-test columns'

df['mnist-pred'] = deepcopy(df['mnist-test'])
df['mnist-pred'].columns = new_cols

# Rebuild y_info from the column labels (this replaces the y_info returned above).
y_info = {}
y_info['true'] = [x[1].split(': ')[1] for x in cols]
y_info['pred'] = [x[2].split(': ')[1].split('-')[0] for x in cols]

df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)
print('Predict: ', fraction_correct)
ser_correct.sort_values(ascending=False)

('Nine-6292', 'Digit: Nine', 'Pred Digit: Three')
Predict:  0.900457142857


One      0.979597
Six      0.962702
Zero     0.961505
Two      0.894239
Seven    0.888407
Nine     0.873378
Three    0.872250
Five     0.869029
Eight    0.859763
Four     0.830574
dtype: float64

In [132]:
# Show the first relabelled test column: (name, true digit category, predicted digit category).
new_cols[0]

('Nine-6292', 'Digit: Nine', 'Pred Digit: Three')

In [133]:
# NOTE(review): this repeats the set_cat_colors definition from an earlier cell of this
# notebook; one of the two definitions can be dropped.
def set_cat_colors(cat_color, axis, cat_index, cat_title=False):
    """Forward each (label, color) pair in `cat_color` to `net.set_cat_color`.

    When `cat_title` is supplied the category name sent is '<cat_title>: <label>';
    otherwise the label is sent unchanged. `axis` and `cat_index` are passed through.
    """
    for label in cat_color:
        if cat_title != False:
            full_name = cat_title + ': ' + label
        else:
            full_name = label

        net.set_cat_color(axis=axis, cat_index=cat_index,
                          cat_name=full_name, inst_color=cat_color[label])

In [134]:
# Cluster the signatures, then copy the column category colors that net.viz holds for 'cat-0'.
net.load_df(df['sig'])
net.cluster()
default_colors = deepcopy(net.viz['cat_colors']['col']['cat-0'])

# Re-key the colors by the text after ': ' in each category name.
cat_color = {key.split(': ')[1]: default_colors[key] for key in default_colors}

# Hand-picked colors for a few digits.
cat_color.update({
    'Zero': 'yellow',
    'Four': 'blue',
    'Seven': 'red',
    'Nine': 'grey',
    'One': 'black',
})

# Apply the same palette to the true-digit (index 1) and predicted-digit (index 2) categories.
for cat_index, cat_title in ((1, 'Digit'), (2, 'Pred Digit')):
    set_cat_colors(cat_color, axis='col', cat_index=cat_index, cat_title=cat_title)

In [136]:
# Visualize a seeded random sample of 1000 labelled test columns, with values rounded to 2 decimals.
net.load_df(df['mnist-pred'])
net.random_sample(axis='col', num_samples=1000, random_state=99)
sampled_df = net.export_df().round(2)
net.load_df(sampled_df)
net.widget()

ExampleWidget(network='{"row_nodes": [{"name": "pos_1-10", "ini": 603, "clust": 39, "rank": 48, "rankvar": 56,…