# 5.0 MNIST
We are getting ~80% correct predictions (fraction correct ≈ 0.81). This might be because the ten digit signatures are insufficient to capture the different variations of each digit. We might try to define the different flavors of each digit using a round of k-means clustering and labeling in the training dataset, then a "narrow" signature definition, and then prediction on the test dataset using this narrow set of signatures. Visualizing these digits as images will also be useful. We can also carry over the similarity score as a value-based category or use it to generate a ROC curve.

In [1]:
from clustergrammer2 import net
df = dict()  # registry of named DataFrames for each stage of the analysis, keyed by string

clustergrammer2 backend version 0.2.9


In [2]:
import clustergrammer_groupby as cby
from copy import deepcopy
import random
# Seed the stdlib RNG that drives the random.shuffle train/pred splits below.
random.seed(a=99)

In [3]:
# Load the full MNIST matrix into the shared clustergrammer2 `net` object and
# keep a DataFrame copy; the displayed shape is (rows, images).
mnist_file = '../data/big_data/MNIST_row_labels.txt'
net.load_file(mnist_file)
df['mnist'] = net.export_df()
df['mnist'].shape

(784, 70000)

In [4]:
# Attach a 'Digit: <name>' category to every column. Column names look like
# 'Zero-0', so the digit name is the text before the first '-'.
labelled_cols = [(col_name, 'Digit: ' + col_name.split('-')[0])
                 for col_name in df['mnist'].columns.tolist()]
mnist_cat = deepcopy(df['mnist'])
mnist_cat.columns = labelled_cols
df['mnist-cat'] = mnist_cat
print(labelled_cols[0])

('Zero-0', 'Digit: Zero')


### Make Training and Prediction Sets

In [9]:
# Random 50/50 split of the images into a training half (used to build
# signatures) and a held-out half (used to score predictions).
cols = df['mnist-cat'].columns.tolist()
random.shuffle(cols)
split_at = 35000
df['mnist-train'], df['mnist-pred'] = df['mnist-cat'][cols[:split_at]], df['mnist-cat'][cols[split_at:]]
print(df['mnist-train'].shape, df['mnist-pred'].shape)

# Row z-score each half separately and store it under '<name>-z'.
for split_name in ['mnist-train', 'mnist-pred']:
    net.load_df(df[split_name])
    net.normalize(axis='row', norm_type='zscore')
    df[split_name + '-z'] = net.export_df()

(784, 35000) (784, 35000)


In [10]:
def set_cat_colors(cat_color, axis, cat_index, cat_title=False, net_obj=None):
    """Assign a fixed color to each category value on one clustergram axis.

    Calls ``set_cat_color`` once per entry of ``cat_color``.

    Parameters
    ----------
    cat_color : dict
        Maps a bare category value (e.g. 'Zero') to a color string.
    axis : str
        Axis whose categories are colored, e.g. 'col'.
    cat_index : int
        Category index, passed straight through to ``set_cat_color``.
    cat_title : str or False
        Category title. When non-empty, names are built as
        ``'<title>: <value>'`` to match the ``'Title: value'`` labels used in
        the column tuples. When falsy (False or ''), the keys of
        ``cat_color`` are used unchanged.
    net_obj : optional
        Object exposing ``set_cat_color``. Defaults to the module-level
        clustergrammer2 ``net``.
    """
    if net_obj is None:
        net_obj = net
    for inst_ct, inst_color in cat_color.items():
        # A falsy title means the keys already are the full category names;
        # this avoids producing names like ': Zero' for an empty title.
        cat_name = cat_title + ': ' + inst_ct if cat_title else inst_ct
        net_obj.set_cat_color(axis=axis, cat_index=cat_index, cat_name=cat_name, inst_color=inst_color)

# Make Signatures
## Make Narrow Digit Signatures

In [11]:
# Narrow (one per digit) signatures from the training half, built from the raw
# pixel values ('') and from the row z-scored copy ('-z').
pval_cutoff = 0.00001
num_top_dims = 50
for inst_norm in ['', '-z']:
    # Pass pval_cutoff explicitly; without it the library default applies and
    # the value set above is silently ignored.
    df['sig' + inst_norm], keep_genes_dict, df_gene_pval, fold_info = cby.generate_signatures(
        df['mnist-train' + inst_norm], 'Digit', pval_cutoff=pval_cutoff, num_top_dims=num_top_dims)
    print(inst_norm, df['sig' + inst_norm].shape)

 (285, 10)


  return (self.a < x) & (x < self.b)
  return (self.a < x) & (x < self.b)
  cond2 = cond0 & (x <= self.a)


-z (270, 10)


In [20]:
net.load_df(df['sig'])
net.cluster()

# Start from the colors clustergrammer assigned to the first column category
# ('Digit: ...') and re-key them by the bare digit name.
tmp_cat_color = deepcopy(net.viz['cat_colors']['col']['cat-0'])
cat_color = {full_name.split(': ')[1]: color for full_name, color in tmp_cat_color.items()}

# Override a few of the default colors with hand-picked ones.
cat_color.update({
    'Zero': 'yellow',
    'Four': 'blue',
    'Seven': 'red',
    'Nine': 'grey',
    'One': 'black',
})

set_cat_colors(cat_color, axis='col', cat_index=1, cat_title='Digit')
cat_color

{'Eight': '#393b79',
 'Five': '#ff7f0e',
 'Four': 'blue',
 'Nine': 'grey',
 'One': 'black',
 'Seven': 'red',
 'Six': '#FFDB58',
 'Three': '#e377c2',
 'Two': '#2ca02c',
 'Zero': 'yellow',
 'Four-Seven-Nine': '#393b79',
 'Three-Five-Eight': '#98df8a',
 'Zero-Two-Six': '#404040'}

In [14]:
# Interactive clustergram of the ten narrow digit signatures.
sig_narrow = df['sig']
net.load_df(sig_narrow)
net.widget()

ExampleWidget(network='{"row_nodes": [{"name": "pos_10-10", "ini": 285, "clust": 91, "rank": 224, "rankvar": 2…

# Predict Digit Type Using Signatures

### Predict using Narrow Signatures

In [17]:
# Predict the digit of every held-out image from its similarity to the ten
# narrow digit signatures (truth_level=1 is the 'Digit: ...' category).
df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(df['mnist-pred'], df['sig'], truth_level=1,
                                                             predict_level='Pred Digit', unknown_thresh=0.0)

# Score once: overall fraction correct, then fraction correct per digit.
df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)
print('Predict: ', fraction_correct, '\n')
print(ser_correct.sort_values(ascending=False))

Predict:  0.8058

broad cell type:  0.8058 

One      0.925441
Zero     0.906678
Six      0.881938
Seven    0.828007
Three    0.802312
Four     0.778238
Two      0.773001
Nine     0.753807
Eight    0.741716
Five     0.634609
dtype: float64


### Artificially Broadening the Narrow Digits Improves Performance
Here we collapse the narrow (ten-digit) true and predicted labels into broad digit groups and re-score the same predictions.

In [21]:
# Collapse the narrow true and predicted digit labels into broad groups and
# re-score. The grouping is defined here, so this cell does not depend on the
# merge lists created in a later cell; using those lists raised NameError on a
# fresh kernel.
broad_digit_groups = {
    'Three-Five-Eight': ['Three', 'Five', 'Eight'],
    'Four-Seven-Nine': ['Four', 'Seven', 'Nine'],
    'Zero-Two-Six': ['Zero', 'Two', 'Six'],
}
digit_to_broad = {digit: group for group, digits in broad_digit_groups.items() for digit in digits}

# Labels outside every group (e.g. 'One') are kept unchanged.
y_broad = {
    'true': [digit_to_broad.get(inst_cat, inst_cat) for inst_cat in y_info['true']],
    'pred': [digit_to_broad.get(inst_cat, inst_cat) for inst_cat in y_info['pred']],
}

# Score once: overall fraction correct, then fraction correct per broad group.
df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_broad)
print('Predict: ', fraction_correct, '\n')
print(ser_correct.sort_values(ascending=False))

Predict:  0.896114285714

broad cell type:  0.896114285714 

One                 0.925441
Four-Seven-Nine     0.910169
Zero-Two-Six        0.904001
Three-Five-Eight    0.861703
dtype: float64


### Predict using Broad then Narrow

In [32]:
# Broad digit groups and the individual digits each one merges.
coarse_digits = {
    'Three-Five-Eight': ['Three', 'Five', 'Eight'],
    'Four-Seven-Nine': ['Four', 'Seven', 'Nine'],
    'Zero-Two-Six': ['Zero', 'Two', 'Six'],
}

In [34]:
dict(coarse_digits)  # display the broad-group -> digits mapping

{'Three-Five-Eight': ['Three', 'Five', 'Eight'],
 'Four-Seven-Nine': ['Four', 'Seven', 'Nine'],
 'Zero-Two-Six': ['Zero', 'Two', 'Six']}

In [43]:
# Predict with the ten narrow signatures, then relabel every prediction with
# its broad digit group so the result can be viewed at the coarse level.
df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(df['mnist-pred'], df['sig'], truth_level=1,
                                                             predict_level='Broad Digit', unknown_thresh=0.0)

ini_broad_cols = df_pred_cat.columns.tolist()

broad_cols = []
for inst_col in ini_broad_cols:
    # inst_col[2] is presumably the appended prediction, e.g. 'Broad Digit: Five'.
    inst_cat = inst_col[2].split(': ')[1]

    # Digits outside every group (e.g. 'One') keep their own label.
    broad_predict = inst_col[2]
    for inst_group in coarse_digits:
        if inst_cat in coarse_digits[inst_group]:
            # Use the same 'Broad Digit' title as the unmerged digits so all
            # values fall under one category ('Broad Digits' split it in two).
            broad_predict = 'Broad Digit: ' + inst_group

    # Keep the relabelled column; it used to be built and then discarded,
    # leaving broad_cols empty.
    broad_cols.append((inst_col[0], inst_col[1], broad_predict))

# Leave df_pred_cat as-is for the cells below; store the relabelled copy.
df_pred_broad = deepcopy(df_pred_cat)
df_pred_broad.columns = broad_cols

In [None]:
(len(df_pred_cat.index), len(df_pred_cat.columns))  # (rows, columns) of the prediction table

In [None]:
# Apply the digit palette to column category 2, titled 'Pred Digit'.
net.load_df(df_pred_cat)
set_cat_colors(cat_color, cat_index=2, axis='col', cat_title='Pred Digit')

In [None]:
# Interactive view of a reproducible random sample of 2,500 prediction columns.
net.load_df(df_pred_cat)
net.random_sample(num_samples=2500, axis='col', random_state=100)
net.widget()

### Z-scored Data

In [None]:
# Same narrow prediction, but on the row z-scored held-out images against the
# z-scored signatures.
df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(df['mnist-pred-z'], df['sig-z'], truth_level=1,
                                                             predict_level='Pred Digit', unknown_thresh=0.0)

# Label the similarity columns with the category tuples of the prediction
# table so both can be visualized with the same categories.
# NOTE(review): assumes both frames list the images in the same column order.
df_sig_sim.columns = df_pred_cat.columns.tolist()

# Score once: overall fraction correct, then fraction correct per digit.
df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)
print('Predict: ', fraction_correct, '\n')
print(ser_correct.sort_values(ascending=False))

In [None]:
(len(df_conf.index), len(df_conf.columns))  # (rows, columns) of the confusion matrix

In [None]:
# Disabled: interactive view of the confusion matrix df_conf.
# net.load_df(df_conf)
# net.widget()

In [None]:
(len(df_sig_sim.index), len(df_sig_sim.columns))  # (rows, columns) of the similarity table

# Coarse Grained Digits
At the coarse-grained level we appear to be able to distinguish:
* Three-Five-Eight
* One
* Zero-Two-Six
* Four-Seven-Nine

In [None]:
# Signature similarities for a reproducible random sample of 2,500 images,
# rounded to 2 decimals before display.
net.load_df(df_sig_sim)
net.random_sample(axis='col', num_samples=2500, random_state=99)
sampled_sim = net.export_df()
net.load_df(sampled_sim.round(decimals=2))
net.widget()

In [None]:
# Prediction table for a reproducible random sample of 2,500 images, rounded
# to 2 decimals before display.
net.load_df(df_pred_cat)
net.random_sample(axis='col', num_samples=2500, random_state=99)
sampled_pred = net.export_df()
net.load_df(sampled_pred.round(decimals=2))
net.widget()

# Experiment: Broad and Narrow Digits
## Make Broad Signature

### Merge Categories

In [None]:
# Digits merged into each coarse group; other digits (e.g. 'One') stay alone.
merge_358 = ['Three', 'Five', 'Eight']
merge_479 = ['Four', 'Seven', 'Nine']
merge_026 = ['Zero', 'Two', 'Six']

# Fresh random 50/50 train/pred split for this experiment.
cols = df['mnist-cat'].columns.tolist()
random.shuffle(cols)

df['mnist-train'] = df['mnist-cat'][cols[:35000]]
df['mnist-pred'] = df['mnist-cat'][cols[35000:]]
print(df['mnist-train'].shape, df['mnist-pred'].shape)

# digit name -> coarse group name
coarse_lookup = {}
for group_name, members in [('Three-Five-Eight', merge_358),
                            ('Four-Seven-Nine', merge_479),
                            ('Zero-Two-Six', merge_026)]:
    for digit in members:
        coarse_lookup[digit] = group_name

for inst_data in ['mnist-train', 'mnist-pred']:
    # Insert a 'Coarse: <group>' category between the image name and its digit.
    coarse_cols = []
    for col_tuple in df[inst_data]:
        digit = col_tuple[1].split(': ')[1]
        coarse_cols.append((col_tuple[0], 'Coarse: ' + coarse_lookup.get(digit, digit), col_tuple[1]))

    coarse_key = inst_data + '-coarse'
    df[coarse_key] = deepcopy(df[inst_data])
    df[coarse_key].columns = coarse_cols
    print(df[coarse_key].shape)

    # Row z-scored copy of the coarse-tagged frame.
    net.load_df(df[coarse_key])
    net.normalize(axis='row', norm_type='zscore')
    df[coarse_key + '-z'] = net.export_df()

### Make Broad Signature

In [None]:
# Broad (coarse-group) signatures from the coarse-tagged training half, built
# from raw ('') and row z-scored ('-z') pixels.
pval_cutoff = 0.00001
num_top_dims = 50
for inst_norm in ['', '-z']:
    # Pass pval_cutoff explicitly; without it the library default applies and
    # the value set above is silently ignored.
    df['sig-broad' + inst_norm], keep_genes_dict, df_gene_pval, fold_info = cby.generate_signatures(
        df['mnist-train-coarse' + inst_norm], 'Coarse', pval_cutoff=pval_cutoff, num_top_dims=num_top_dims)
    print(inst_norm, df['sig-broad' + inst_norm].shape)

## Make Narrow Signatures

### Make Group DataFrames

In [None]:
# Per-group training frames for the narrow signatures.
# Reuse the train/pred split made in the "Merge Categories" cell rather than
# reshuffling. A fresh shuffle here produced a different split, so the narrow
# signatures were trained on images that are also in df['mnist-pred-coarse'],
# the held-out set the broad step is scored on (train/test leakage). It also
# made df['mnist-train']/df['mnist-pred'] disagree with the '-coarse' frames.
print(df['mnist-train'].shape, df['mnist-pred'].shape)

for inst_group in coarse_digits:
    # Training images whose true digit ('Digit: <name>') belongs to this group.
    keep_cols = [x for x in df['mnist-train'] if x[1].split(': ')[1] in coarse_digits[inst_group]]
    df[inst_group] = df['mnist-train'][keep_cols]
    print(inst_group, df[inst_group].shape)

    # Row z-scored within the group.
    net.load_df(df[inst_group])
    net.normalize(axis='row', norm_type='zscore')
    df[inst_group + '-z'] = net.export_df()

### Make Narrow Signatures

In [None]:
# Narrow (per-digit) signatures within each broad group, from raw ('') and
# z-scored ('-z') group frames, with a much stricter p-value cutoff.
pval_cutoff = 1e-10
num_top_dims = 50

for inst_group in coarse_digits:
    for inst_norm in ['', '-z']:
        # Resulting keys look like 'sig-Three-Five-Eight-' and 'sig-Three-Five-Eight--z'.
        sig_key = 'sig-' + inst_group + '-' + inst_norm
        (df[sig_key], keep_genes_dict, df_gene_pval, fold_info) = cby.generate_signatures(
            df[inst_group + inst_norm], 'Digit', pval_cutoff=pval_cutoff, num_top_dims=num_top_dims)
        print(inst_group + inst_norm, df[sig_key].shape)

# Predict Broad then Narrow
We first predict the broad digit group of each held-out image, then split the images by broad group and predict the individual digit using that group's narrow signatures.

### Predict Broad Digits

In [None]:
# Predict the broad digit group of each held-out image from the broad
# signatures (truth_level=1 is presumably the 'Coarse: ...' category).
df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(df['mnist-pred-coarse'], df['sig-broad'], truth_level=1,
                                                             predict_level='Pred Digit', unknown_thresh=0.0)

# Score once: overall fraction correct, then fraction correct per broad group.
df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)
print('Predict: ', fraction_correct, '\n')
print(ser_correct.sort_values(ascending=False))

In [None]:
# Narrow prediction within the Three-Five-Eight group.
# df['mnist-pred-358-z'] was never created by any cell (KeyError). Build it
# here from the held-out images whose true digit is in the group, z-scored the
# same way the group's training frame was. Score it against that group's
# narrow signatures, not the ten-digit df['sig-z'].
# NOTE(review): selecting by the *true* group gives an upper bound that assumes
# the broad step was correct. TODO: select by the broad-step prediction to
# evaluate the full broad-then-narrow pipeline.
inst_group = 'Three-Five-Eight'
keep_cols = [x for x in df['mnist-pred'] if x[1].split(': ')[1] in coarse_digits[inst_group]]
net.load_df(df['mnist-pred'][keep_cols])
net.normalize(axis='row', norm_type='zscore')
df['mnist-pred-358-z'] = net.export_df()

# Separate output names keep df_pred_cat (broad prediction, with its 'Coarse'
# category) intact for the visualization cells below.
df_pred_cat_358, df_sig_sim_358, y_info_358 = cby.predict_cats_from_sigs(
    df['mnist-pred-358-z'], df['sig-' + inst_group + '--z'], truth_level=1,
    predict_level='Pred Digit', unknown_thresh=0.0)

df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info_358)
print('Predict: ', fraction_correct, '\n')
print(ser_correct.sort_values(ascending=False))

In [None]:
(len(df_pred_cat.index), len(df_pred_cat.columns))  # (rows, columns) of the prediction table

In [None]:
# Prediction table for a reproducible random sample of 2,500 images, rounded
# to 2 decimals before display.
net.load_df(df_pred_cat)
net.random_sample(axis='col', num_samples=2500, random_state=99)
sampled_pred = net.export_df()
net.load_df(sampled_pred.round(decimals=2))
net.widget()

In [None]:
# Colors for the broad digit groups.
cat_color.update({
    'Three-Five-Eight': 'red',
    'Four-Seven-Nine': 'blue',
    'Zero-Two-Six': 'yellow',
})

In [None]:
net.load_df(df_pred_cat)
# Color column category 1 (titled 'Coarse') and category 3 (titled 'Pred Digit').
for cat_index, cat_title in [(1, 'Coarse'), (3, 'Pred Digit')]:
    set_cat_colors(cat_color, axis='col', cat_index=cat_index, cat_title=cat_title)

In [None]:
# Re-render the (recolored) prediction table for a reproducible random sample
# of 2,500 images, rounded to 2 decimals.
net.load_df(df_pred_cat)
net.random_sample(axis='col', num_samples=2500, random_state=99)
sampled_pred = net.export_df()
net.load_df(sampled_pred.round(decimals=2))
net.widget()