# 5.0 MNIST
We are getting ~80 correct predictions. This might be because the ten digit signatures are insufficient to capture the different variation of each digit. We might try to define the different flavors of each digit using a round of k-means clustering and labeling in the training dataset then a "narrow" signature definition, then prediction on the test dataset using this narrow set of signatures. Also, visualization of these digits as images will be useful. We can also carry over the similarity score as a value based category or use it to generate a ROC curve. 

In [1]:
from clustergrammer2 import net
df = {}

clustergrammer2 backend version 0.2.9


In [2]:
import clustergrammer_groupby as cby
from copy import deepcopy
import random
random.seed(99)

In [3]:
import pandas as pd

In [4]:
net.load_file('../data/big_data/MNIST_row_labels.txt')
df['mnist'] = net.export_df()
df['mnist'].shape

(784, 70000)

In [5]:
cols = df['mnist'].columns.tolist()
new_cols = [(x, 'Digit: ' + x.split('-')[0]) for x in cols]
df['mnist-cat'] = deepcopy(df['mnist'])
df['mnist-cat'].columns = new_cols
print(new_cols[0])

('Zero-0', 'Digit: Zero')


### Make Train and Predict

In [6]:
cols = df['mnist-cat'].columns.tolist()
random.shuffle(cols)
df['mnist-train'] = df['mnist-cat'][cols[:35000]]
df['mnist-test'] = df['mnist-cat'][cols[35000:]]
print(df['mnist-train'].shape, df['mnist-test'].shape)

net.load_df(df['mnist-train'])
net.normalize(axis='row', norm_type='zscore')
df['mnist-train-z'] = net.export_df()

net.load_df(df['mnist-test'])
net.normalize(axis='row', norm_type='zscore')
df['mnist-test-z'] = net.export_df()

(784, 35000) (784, 35000)


In [7]:
def set_cat_colors(cat_color, axis, cat_index, cat_title=False):
    for inst_ct in cat_color:
        if cat_title != False:
            cat_name = cat_title + ': ' + inst_ct
        else:
            cat_name = inst_ct
            
        inst_color = cat_color[inst_ct]
        net.set_cat_color(axis=axis, cat_index=cat_index, cat_name=cat_name, inst_color=inst_color)

# Make Signatures

### Make Very Narrow Digit Categories

In [25]:
df['mnist-train'].shape

(784, 35000)

In [27]:
num_clusters = 10

In [28]:
cols = df['mnist-train'].columns.tolist()
all_digits = sorted(list(set([x[1].split(': ')[1] for x in cols])))

In [44]:
df_list = []
for inst_digit in all_digits:
    
    cols = df['mnist-train'].columns.tolist()
    keep_cols = [x for x in cols if x[1].split(': ')[1] == inst_digit]
    inst_df = df['mnist-train'][keep_cols]
    print(inst_df.shape)
    
    net.load_df(inst_df)
    ds_info = list(net.downsample(axis='col', num_samples=num_clusters, ds_type='kmeans', random_state=99))
    ds_info = [str(x) for x in ds_info]

    cols = inst_df.columns.tolist()
    new_cols = []
    for index in range(len(cols)):
        inst_col = cols[index]
        inst_cluster = ds_info[index]
        new_col = (inst_col[0], inst_col[1] + '-' + inst_cluster)
        new_cols.append(new_col)

    inst_df.columns = new_cols
    df_list.append(inst_df)
    
df['mnist-train-vn'] = pd.concat(df_list, axis=1)
df['mnist-train-vn'].shape

(784, 3429)
(784, 3144)
(784, 3417)
(784, 3532)
(784, 3944)
(784, 3540)
(784, 3431)
(784, 3567)
(784, 3477)
(784, 3519)


(784, 35000)

In [57]:
def make_narrow_cats(df, narrow_clusters, cat_index=1):
    df_list = []
    
    cols = df.columns.tolist()
    all_cats = sorted(list(set([x[cat_index].split(': ')[1] for x in cols])))

    for inst_cat in all_cats:

        cols = df.columns.tolist()
        keep_cols = [x for x in cols if x[1].split(': ')[1] == inst_cat]
        inst_df = df[keep_cols]        

        net.load_df(inst_df)
        if narrow_clusters > inst_df.shape[1]:
            narrow_clusters = inst_df.shape[1]
            
        ds_info = list(net.downsample(axis='col', num_samples=narrow_clusters, ds_type='kmeans', random_state=99))
        ds_info = [str(x) for x in ds_info]

        cols = inst_df.columns.tolist()
        new_cols = []
        for index in range(len(cols)):
            inst_col = cols[index]
            inst_cluster = ds_info[index]
            new_col = (inst_col[0], inst_col[1] + '-C' + inst_cluster)
            new_cols.append(new_col)

        inst_df.columns = new_cols
        df_list.append(inst_df)

    df_nc = pd.concat(df_list, axis=1)

    return df_nc

In [59]:
df['mnist-train-vn'] = make_narrow_cats(df['mnist-train'], 10)

narrowed (784, 35000)


In [60]:
df['mnist-train-vn'].columns.tolist()[0]

('Eight-1719', 'Digit: Eight-C8')

In [30]:
df['mnist-train-vn'].columns.tolist()[0]

('Eight-1719', 'Digit: Eight-C8')

## Make Narrow Digit Signatures

In [61]:
pval_cutoff = 0.00001
num_top_dims = 50

df['sig'], keep_genes_dict, df_gene_pval, fold_info = cby.generate_signatures(
                                                                     df['mnist-train-vn'],
                                                                     'Digit', num_top_dims=num_top_dims)
print(df['sig'].shape)

(471, 100)


# Predict Digit Type Using Signatures

### Predict using Narrow Signatures

In [62]:
df['mnist-test'].columns.tolist()[0]

('Four-709', 'Digit: Four')

### Predict on Training Data

In [63]:
# Predict
##################
df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(df['mnist-train'], df['sig'], truth_level=1,
                                                                   predict_level='Pred Digit', unknown_thresh=0.0)


y_info = {}
cols = df_pred_cat.columns.tolist()

# broaden predicted categories
y_info['true'] = [x[1].split(': ')[1] for x in cols]
y_info['pred'] = [x[2].split(': ')[1].split('-C')[0] for x in cols]

df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)
print('Predict: ', fraction_correct)
ser_correct.sort_values(ascending=False)

Predict:  0.927514285714


One      0.986055
Six      0.971145
Zero     0.970446
Two      0.933276
Nine     0.906852
Seven    0.905367
Three    0.901878
Five     0.901399
Eight    0.897929
Four     0.890840
dtype: float64

### Predict on Test Data

In [65]:
# Predict
##################
df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(df['mnist-test'], df['sig'], truth_level=1,
                                                                   predict_level='Pred Digit', unknown_thresh=0.0)


y_info = {}
cols = df_pred_cat.columns.tolist()
new_cols = [(x[0], x[1], x[2].split('-')[0]) for x in cols]

print(new_cols[0])

df['mnist-pred'] = deepcopy(df['mnist-test'])
df['mnist-pred'].columns = new_cols

y_info['true'] = [x[1].split(': ')[1] for x in cols]
y_info['pred'] = [x[2].split(': ')[1].split('-')[0] for x in cols]


df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)
print('Predict: ', fraction_correct)
ser_correct.sort_values(ascending=False)

('Four-709', 'Digit: Four', 'Pred Digit: Nine')
Predict:  0.923628571429


One      0.983219
Zero     0.968972
Six      0.968360
Two      0.929974
Four     0.901967
Nine     0.901051
Eight    0.898999
Five     0.893973
Seven    0.893152
Three    0.889759
dtype: float64

In [35]:
new_cols[0]

('Four-709', 'Digit: Four', 'Pred Digit: Nine')

In [36]:
def set_cat_colors(cat_color, axis, cat_index, cat_title=False):
    for inst_ct in cat_color:
        if cat_title != False:
            cat_name = cat_title + ': ' + inst_ct
        else:
            cat_name = inst_ct
            
        inst_color = cat_color[inst_ct]
        net.set_cat_color(axis=axis, cat_index=cat_index, cat_name=cat_name, inst_color=inst_color)

In [37]:
cat_color = {}    
cat_color['Zero'] = 'white'
cat_color['One'] = '#6f4e37'
cat_color['Two'] = 'blue'
cat_color['Three'] = 'black'
cat_color['Four'] = 'red'
cat_color['Five'] = 'yellow'
cat_color['Six'] = 'purple'
cat_color['Seven'] = '#FFA500'
cat_color['Eight'] = '#1e90ff'
cat_color['Nine'] = '#ff7518'

set_cat_colors(cat_color, axis='col', cat_index=1, cat_title='Digit')
set_cat_colors(cat_color, axis='col', cat_index=2, cat_title='Pred Digit')

In [39]:
net.load_df(df['mnist-pred'])
net.random_sample(axis='col', num_samples=2000, random_state=99)
net.filter_N_top(inst_rc='row', N_top=250, rank_type='var')
net.load_df(net.export_df().round(2))
net.widget()

ExampleWidget(network='{"row_nodes": [{"name": "pos_13-14", "ini": 250, "clust": 117, "rank": 193, "rankvar": …