# 5.0 MNIST
We are getting ~80 correct predictions. This might be because the ten digit signatures are insufficient to capture the different variation of each digit. We might try to define the different flavors of each digit using a round of k-means clustering and labeling in the training dataset then a "narrow" signature definition, then prediction on the test dataset using this narrow set of signatures. Also, visualization of these digits as images will be useful. We can also carry over the similarity score as a value based category or use it to generate a ROC curve. 

In [1]:
from clustergrammer2 import net
df = {}

clustergrammer2 backend version 0.2.9


In [2]:
import clustergrammer_groupby as cby
from copy import deepcopy
import random
random.seed(99)

In [3]:
net.load_file('../data/big_data/MNIST_row_labels.txt')
df['mnist'] = net.export_df()
df['mnist'].shape

(784, 70000)

In [4]:
cols = df['mnist'].columns.tolist()
new_cols = [(x, 'Digit: ' + x.split('-')[0]) for x in cols]
df['mnist-cat'] = deepcopy(df['mnist'])
df['mnist-cat'].columns = new_cols
print(new_cols[0])

('Zero-0', 'Digit: Zero')


### Make Train and Predict

In [5]:
cols = df['mnist-cat'].columns.tolist()
random.shuffle(cols)
df['mnist-train'] = df['mnist-cat'][cols[:35000]]
df['mnist-pred'] = df['mnist-cat'][cols[35000:]]
print(df['mnist-train'].shape, df['mnist-pred'].shape)

net.load_df(df['mnist-train'])
net.normalize(axis='row', norm_type='zscore')
df['mnist-train-z'] = net.export_df()

net.load_df(df['mnist-pred'])
net.normalize(axis='row', norm_type='zscore')
df['mnist-pred-z'] = net.export_df()

(784, 35000) (784, 35000)


### Make Signatures

In [7]:
pval_cutoff = 0.00001
num_top_dims = 50
for inst_norm in ['', '-z']:
    df['sig' + inst_norm], keep_genes_dict, df_gene_pval, fold_info = cby.generate_signatures(
                                                                         df['mnist-train' + inst_norm],
                                                                         'Digit', num_top_dims=num_top_dims)
    print(inst_norm, df['sig' + inst_norm].shape)

 (276, 10)


  return (self.a < x) & (x < self.b)
  return (self.a < x) & (x < self.b)
  cond2 = cond0 & (x <= self.a)


-z (276, 10)


In [8]:
def set_cat_colors(cat_color, axis, cat_index, cat_title=False):
    for inst_ct in cat_color:
        if cat_title != False:
            cat_name = cat_title + ': ' + inst_ct
        else:
            cat_name = inst_ct
            
        inst_color = cat_color[inst_ct]
        net.set_cat_color(axis=axis, cat_index=cat_index, cat_name=cat_name, inst_color=inst_color)

In [9]:
net.load_df(df['sig'])
net.cluster()
tmp_cat_color = deepcopy(net.viz['cat_colors']['col']['cat-0'])
cat_color = {}
for inst_key in tmp_cat_color:
    cat_color[inst_key.split(': ')[1]] = tmp_cat_color[inst_key]
    
cat_color['Zero'] = 'yellow'
cat_color['Four'] = 'blue'
cat_color['Seven'] = 'red'
cat_color['Nine'] = 'grey'
cat_color['One'] = 'black'
cat_color  

{'Eight': '#393b79',
 'Five': '#ff7f0e',
 'Four': 'blue',
 'Nine': 'grey',
 'One': 'black',
 'Seven': 'red',
 'Six': '#FFDB58',
 'Three': '#e377c2',
 'Two': '#2ca02c',
 'Zero': 'yellow'}

In [10]:
set_cat_colors(cat_color, axis='col', cat_index=1, cat_title='Digit')

In [11]:
net.load_df(df['sig'])
net.widget()

ExampleWidget(network='{"row_nodes": [{"name": "pos_10-10", "ini": 276, "clust": 234, "rank": 218, "rankvar": …

In [12]:
net.load_df(df['sig-z'])
net.widget()

ExampleWidget(network='{"row_nodes": [{"name": "pos_10-10", "ini": 276, "clust": 60, "rank": 264, "rankvar": 1…

# Predict Digit Type Using Signatures

### Raw Data

In [13]:
# Predict
##################
df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(df['mnist-pred'], df['sig'], truth_level=1,
                                                                   predict_level='Pred Digit', unknown_thresh=0.0)

df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)
print('Predict: ', fraction_correct)

df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)
print('\nbroad cell type: ', fraction_correct, '\n')
print(ser_correct.sort_values(ascending=False))

Predict:  0.813285714286

broad cell type:  0.813285714286 

One      0.929062
Zero     0.901891
Six      0.884470
Seven    0.819078
Four     0.806868
Three    0.795187
Two      0.789069
Nine     0.784588
Eight    0.749117
Five     0.644683
dtype: float64


In [14]:
df_pred_cat.shape

(276, 35000)

In [15]:
net.load_df(df_pred_cat)
set_cat_colors(cat_color, axis='col', cat_index=2, cat_title='Pred Digit')

In [16]:
net.load_df(df_pred_cat)
net.random_sample(axis='col', num_samples=2500, random_state=100)
net.widget()

ExampleWidget(network='{"row_nodes": [{"name": "pos_4-12", "ini": 276, "clust": 23, "rank": 56, "rankvar": 54,…

### Z-scored Data

In [17]:
# Predict
##################
df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(df['mnist-pred-z'], df['sig-z'], truth_level=1,
                                                                   predict_level='Pred Digit', unknown_thresh=0.0)

df_sig_sim.columns = df_pred_cat.columns.tolist()
df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)
print('Predict: ', fraction_correct)

df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)
print('\nbroad cell type: ', fraction_correct, '\n')
print(ser_correct.sort_values(ascending=False))

Predict:  0.790514285714

broad cell type:  0.790514285714 

One      0.964404
Zero     0.904255
Six      0.837446
Seven    0.835332
Four     0.793367
Three    0.773923
Two      0.750071
Nine     0.739054
Eight    0.662544
Five     0.602398
dtype: float64


In [18]:
df_conf.shape

(10, 10)

In [19]:
# net.load_df(df_conf)
# net.widget()

In [20]:
df_sig_sim.shape

(10, 35000)

# Coarse Grained Digits
At the coarse grained level we appear to be able to distinguish
* Three-Five-Eight
* One
* Zero-Two-Six
* Four-Seven-Nine

In [21]:
net.load_df(df_sig_sim)
net.random_sample(axis='col', num_samples=2500, random_state=99)
net.load_df(net.export_df().round(2))
net.widget()

ExampleWidget(network='{"row_nodes": [{"name": "Eight", "ini": 10, "clust": 7, "rank": 5, "rankvar": 1, "cat-0…

In [22]:
net.load_df(df_pred_cat)
net.random_sample(axis='col', num_samples=2500, random_state=99)
net.load_df(net.export_df().round(2))
net.widget()

ExampleWidget(network='{"row_nodes": [{"name": "pos_4-12", "ini": 276, "clust": 92, "rank": 45, "rankvar": 26,…

# Experiment Broad and Narrow Digits
## Make Broad Signature

### Merge Categories

In [82]:
cols = df['mnist-cat'].columns.tolist()
merge_358 = ['Three', 'Five', 'Eight']
merge_479 = ['Four', 'Seven', 'Nine']
merge_026 = ['Zero', 'Two', 'Six']

random.shuffle(cols)

df['mnist-train'] = df['mnist-cat'][cols[:35000]]
df['mnist-pred'] = df['mnist-cat'][cols[35000:]]
print(df['mnist-train'].shape, df['mnist-pred'].shape)

for inst_data in ['mnist-train', 'mnist-pred']:
    cols = df[inst_data]
    new_cols = []
    for inst_col in cols:
        inst_cat = inst_col[1].split(': ')[1]
        
        if inst_cat in merge_358:
            inst_cat = 'Three-Five-Eight'
            
        if inst_cat in merge_479:
            inst_cat = 'Four-Seven-Nine' 
            
        if inst_cat in merge_026:
            inst_cat = 'Zero-Two-Six'
            
        new_col = (inst_col[0], 'Coarse: ' + inst_cat, inst_col[1])
        new_cols.append(new_col)
    
    df[inst_data + '-coarse'] = deepcopy(df[inst_data])
    df[inst_data + '-coarse'].columns = new_cols
    print(df[inst_data + '-coarse'].shape)
    
    net.load_df(df[inst_data + '-coarse'])
    net.normalize(axis='row', norm_type='zscore')
    df[inst_data + '-coarse-z'] = net.export_df()

(784, 35000) (784, 35000)
(784, 35000)
(784, 35000)


### Make Broad Signature

In [85]:
pval_cutoff = 0.00001
num_top_dims = 50
for inst_norm in ['', '-z']:
    df['sig-broad' + inst_norm], keep_genes_dict, df_gene_pval, fold_info = cby.generate_signatures(
                                                                         df['mnist-train-coarse' + inst_norm],
                                                                         'Coarse', num_top_dims=num_top_dims)
    print(inst_norm, df['sig-broad' + inst_norm].shape)

 (144, 4)


  return (self.a < x) & (x < self.b)
  return (self.a < x) & (x < self.b)
  cond2 = cond0 & (x <= self.a)


-z (144, 4)


## Make Narrow Signatures

### Make Group DataFrames

In [77]:
cols = df['mnist-cat'].columns.tolist()

keep_digits = {}
keep_digits['Three-Five-Eight'] = ['Three', 'Five', 'Eight']
keep_digits['Four-Seven-Nine'] = ['Four', 'Seven', 'Nine']
keep_digits['Zero-Two-Six'] = ['Zero', 'Two', 'Six']

random.shuffle(cols)

df['mnist-train'] = df['mnist-cat'][cols[:35000]]
df['mnist-pred'] = df['mnist-cat'][cols[35000:]]
print(df['mnist-train'].shape, df['mnist-pred'].shape)

for inst_group in keep_digits:
    cols = df['mnist-train']
    keep_cols = [x for x in cols if x[1].split(': ')[1] in keep_digits[inst_group]]
    df[inst_group] = df['mnist-train'][keep_cols]
    print(inst_group, df[inst_group].shape)
    
    net.load_df(df[inst_group])
    net.normalize(axis='row', norm_type='zscore')
    df[inst_group + '-z'] = net.export_df()

(784, 35000) (784, 35000)
Three-Five-Eight (784, 10196)
Four-Seven-Nine (784, 10495)
Zero-Two-Six (784, 10389)


### Make Narrow Signatures

In [78]:
# Generate Signatures
pval_cutoff = 1e-10
num_top_dims=50

for inst_group in keep_digits:

    for inst_norm in ['', '-z']:
        df['sig-' + inst_group + '-' + inst_norm], keep_genes_dict, df_gene_pval, fold_info = cby.generate_signatures(
                                                                             df[inst_group + inst_norm],
                                                                             'Digit', pval_cutoff=pval_cutoff, 
                                                                              num_top_dims=num_top_dims)
        print(inst_group + inst_norm, df['sig-' + inst_group + '-' + inst_norm].shape)    

Three-Five-Eight (104, 3)


  return (self.a < x) & (x < self.b)
  return (self.a < x) & (x < self.b)
  cond2 = cond0 & (x <= self.a)


Three-Five-Eight-z (104, 3)
Four-Seven-Nine (99, 3)
Four-Seven-Nine-z (99, 3)
Zero-Two-Six (116, 3)
Zero-Two-Six-z (115, 3)


# Predict Broad then Narrow
Need to predict broad digits, then separate each of the broad categories and predict using the narrow signature.

### Predict Broad Digits

In [86]:
# Predict
##################
df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(df['mnist-pred-coarse'], df['sig-broad'], truth_level=1,
                                                                   predict_level='Pred Digit', unknown_thresh=0.0)

df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)
print('Predict: ', fraction_correct)

df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)
print('\nbroad cell type: ', fraction_correct, '\n')
print(ser_correct.sort_values(ascending=False))

Predict:  0.832914285714

broad cell type:  0.832914285714 

One                 0.967963
Four-Seven-Nine     0.869330
Three-Five-Eight    0.809754
Zero-Two-Six        0.767641
dtype: float64


In [38]:
# Predict
##################
df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(df['mnist-pred-358-z'], df['sig-z'], truth_level=1,
                                                                   predict_level='Pred Digit', unknown_thresh=0.0)

df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)
print('Predict: ', fraction_correct)

df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)
print(ser_correct.sort_values(ascending=False))

Predict:  0.909090909091
Zero    0.934527
Six     0.923233
Two     0.870508
dtype: float64


In [39]:
df_pred_cat.shape

(124, 10351)

In [40]:
net.load_df(df_pred_cat)
net.random_sample(axis='col', num_samples=2500, random_state=99)
net.load_df(net.export_df().round(2))
net.widget()

ExampleWidget(network='{"row_nodes": [{"name": "pos_21-5", "ini": 124, "clust": 105, "rank": 45, "rankvar": 11…

In [61]:
cat_color['Three-Five-Eight'] = 'red'
cat_color['Four-Seven-Nine'] = 'blue'
cat_color['Zero-Two-Six'] = 'yellow'

In [64]:
net.load_df(df_pred_cat)
set_cat_colors(cat_color, axis='col', cat_index=1, cat_title='Coarse')
set_cat_colors(cat_color, axis='col', cat_index=3, cat_title='Pred Digit')

In [65]:
net.load_df(df_pred_cat)
net.random_sample(axis='col', num_samples=2500, random_state=99)
net.load_df(net.export_df().round(2))
net.widget()

ExampleWidget(network='{"row_nodes": [{"name": "pos_16-20", "ini": 152, "clust": 111, "rank": 48, "rankvar": 5…