In [1]:
# Calculates PRECISION and RECALL per CLASS in a MULTI-CLASS scenario where
# number of classes is more than 2 (It is 3 in the example below)
# It also calculates F1 Score per class.
# And it even calculates ACCURACY on the entire data set given.

# IMPORTANT NOTES:
# 1) Since the same tf.metrics ops are used to compute accuracy/precision/recall for
#    both the train and the test set, you need to reset the local variables behind each
#    tf.metrics op before using it on, say, the test set after using it on the train set.
#    Otherwise, your (say) test results will be a combination of train and test results.
# 2) You could use the default TF local-variable initializer (tf.local_variables_initializer()),
#    but it initializes ALL local variables. It does not hurt to use it at the beginning, but
#    before re-using a given tf.metrics function you only need to reset that metric's own local variables.
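# 3) Note that tf.metrics.accuracy(..) works fine in multi-class scenarios.
#    However, tf.metrics.precision(..) and tf.metrics.recall(..) cast labels and predictions to bool,
#    so they only handle the binary case. With more than 2 classes they must be computed one class
#    at a time (one-vs-rest: class k -> 1, every other class -> 0).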


In [2]:
import tensorflow as tf
import numpy as np

In [3]:
# ASSUMPTIONS (otherwise decode_csv() needs updating):
# 1) The first column is NOT a feature (it is most probably a training-example ID or similar).
# 2) The last column is always the label, and there is ONLY 1 label column.
#    If more than 1 column represents the label, decode_csv() needs updating.
# 3) The first row is assumed to hold the column names (i.e. feature names, label, etc.),
#    so it is skipped.

# UPDATE record_defaults IN EACH PROJECT: one entry per CSV column, in column order.
# Each entry is the value tf.decode_csv uses when that column is missing, and its
# Python type also sets the parsed column's dtype: a string first column (presumably
# an ID), five float feature columns, and an int label column.
record_defaults = [[""]] + [[0.0] for _ in range(5)] + [[0]]

def decode_csv(line):
    """Parse one CSV text line into a ``(features, label)`` pair.

    Column layout: the first column is dropped (it is not a feature), the
    last column is the single label column, and every column in between is
    a feature. Missing values fall back to the module-level
    ``record_defaults``.

    Args:
        line: string tensor holding one CSV record (in this notebook, one
            line produced by ``tf.data.TextLineDataset``).

    Returns:
        A tuple ``(features, label)``:
        features -- 1-D tensor stacking the middle (feature) columns, so the
                    forward pass can be vectorized later.
        label    -- 1-element tensor holding the last column, so batched
                    labels come out shaped ``[batch, 1]``.
    """
    columns = tf.decode_csv(line, record_defaults)
    # [1:-1] skips the leading non-feature column and the trailing label column.
    # Slicing leaves `columns` itself unmodified.
    features = tf.stack(columns[1:-1])
    # Only ONE label column; the 1-element slice is stacked to keep a trailing
    # dimension of size 1.
    label = tf.stack(columns[-1:])
    return features, label


In [8]:
tf.reset_default_graph()  # fresh graph, so re-running this cell does not pile up duplicate ops

# Stand-in model outputs for this demo: one list per minibatch, one [class_id] row per example.
train_predictions = [
    [[0], [0], [0], [0], [1], [2]],  # batch 0
    [[1], [2], [1], [1], [0], [2]],  # batch 1
    [[2], [2], [0], [1], [1], [1]],  # batch 2
]

num_classes = 3
minibatch_size = 6
file_names = ["train1_with_3_label_classes.csv"]

# One slot per class; each slot is overwritten by index further down.
# Built with comprehensions so no two slots share the same list object.
precision_per_class = [0 for _ in range(num_classes)]
update_precision_per_class = [[] for _ in range(num_classes)]
recall_per_class = [0 for _ in range(num_classes)]
update_recall_per_class = [[] for _ in range(num_classes)]
f1_score_per_class = [0 for _ in range(num_classes)]

with tf.name_scope("read_next_train_batch"):
    # The file list is fed when the iterator is initialized, so the same
    # pipeline can be re-initialized later with a different list of files.
    filenames = tf.placeholder(tf.string, shape=[None])
    # Every file is read line by line. skip(1) drops that file's header row,
    # and decode_csv turns each remaining line into a (features, label) pair.
    # NOTE: Dataset.batch keeps a final, smaller batch when the row count is
    # not a multiple of minibatch_size.
    dataset = (
        tf.data.Dataset.from_tensor_slices(filenames)
        .flat_map(lambda fname: tf.data.TextLineDataset(fname).skip(1).map(decode_csv))
        .batch(minibatch_size)
    )
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

# Placeholders that take in one batch of labels / predictions, each shaped [batch, 1].
# The batch dimension is None (not minibatch_size): Dataset.batch emits a final,
# smaller batch when the row count is not a multiple of minibatch_size, and a
# fixed-size placeholder would reject that batch at feed time.
tf_labels = tf.placeholder(dtype=tf.int64, shape=[None, 1])
tf_predictions = tf.placeholder(dtype=tf.int64, shape=[None, 1])

# tf.metrics.precision / tf.metrics.recall cast their inputs to bool, i.e. they
# only handle the binary case. For multi-class data we therefore create one
# precision and one recall metric per class. Each is meant to be fed one-vs-rest
# masks (class k -> 1, every other class -> 0) at run time.
with tf.name_scope("precision_per_class_scope"):
    for k in range(num_classes):
        precision_per_class[k], update_precision_per_class[k] = tf.metrics.precision(
            tf_labels, tf_predictions, name="precision_class_" + str(k))

with tf.name_scope("recall_per_class_scope"):
    for k in range(num_classes):
        recall_per_class[k], update_recall_per_class[k] = tf.metrics.recall(
            tf_labels, tf_predictions, name="recall_class_" + str(k))

with tf.name_scope("metric_accuracy_scope"):
    # Plain multi-class accuracy over every example seen (fed raw class ids),
    # not a per-class value.
    accuracy, accuracy_update = tf.metrics.accuracy(tf_labels, tf_predictions)

# tf.metrics keep their running totals in local variables. These must be
# initialized before the first update, and re-initializing them resets the metrics.
init_local = tf.local_variables_initializer()


with tf.Session() as sess:
    # Local vars need to be initialized to be able to use tf.metrics functions.
    sess.run(init_local)
    sess.run(iterator.initializer, feed_dict={filenames: file_names})

    batch_nr = 0
    while True:
        # Only fetching the next batch signals end-of-data. The try is kept narrow
        # so an OutOfRangeError raised anywhere else is not mistaken for it.
        try:
            batch_features, batch_labels = sess.run(next_element)
        except tf.errors.OutOfRangeError:
            print("All data processed.\n")
            break

        batch_predictions = np.asarray(train_predictions[batch_nr])

        # Update the (multi-class) accuracy with the raw class ids of the current batch.
        sess.run(accuracy_update,
                 feed_dict={tf_labels: batch_labels,
                            tf_predictions: batch_predictions})

        # If we have 3 classes, they are class 0, class 1 and class 2.
        # For class k we ask: how many CORRECT class-k predictions are there (TP),
        # how many WRONG ones (FP), and how many class-k examples were missed (FN)?
        #   precision_k = TP_k / (TP_k + FP_k),  recall_k = TP_k / (TP_k + FN_k)
        for k in range(num_classes):
            # One-vs-rest masks over the WHOLE batch, e.g. for k = 0:
            #   batch_labels = [[0],[0],[0],[1],[1],[2]] -> [[1],[1],[1],[0],[0],[0]]
            # (Indexing a single row of batch_labels here would broadcast one
            # label over the entire batch and corrupt every count.)
            labels_k = (batch_labels == k).astype(np.int64)
            predictions_k = (batch_predictions == k).astype(np.int64)

            # Update precision and recall for class k.
            sess.run([update_precision_per_class[k], update_recall_per_class[k]],
                     feed_dict={tf_labels: labels_k,
                                tf_predictions: predictions_k})

        # Running (cumulative) values after this batch, fetched in a single call.
        precisions, recalls, acc = sess.run([precision_per_class, recall_per_class, accuracy])
        for k in range(num_classes):
            print("Batch nr: ", batch_nr, "class: ", k,
                  "  Precision: ", precisions[k],
                  "  Recall:    ", recalls[k],
                  "  Model Accuracy: ", acc)

        batch_nr += 1

    # Final cumulative metrics. F1 is computed in numpy from the fetched values
    # rather than by adding new TF ops to the graph inside the session.
    precisions, recalls, acc = sess.run([precision_per_class, recall_per_class, accuracy])
    precisions = np.asarray(precisions)
    recalls = np.asarray(recalls)

    # F1 = 2PR / (P + R), element-wise per class. It is defined as 0 when P + R == 0
    # (e.g. a class never predicted and never present), instead of producing NaN.
    denominator = precisions + recalls
    f1_score_per_class = np.divide(2. * precisions * recalls, denominator,
                                   out=np.zeros_like(denominator),
                                   where=denominator > 0)

    print("Summary\n")
    for k in range(num_classes):
        print("class: ", k,
              "  Precision:  ", precisions[k],
              "  Recall: ", recalls[k],
              "  F1 Score: ", f1_score_per_class[k],
              "  Model Accuracy: ", acc)
        

Batch nr:  0 class:  0   Precision:  1.0   Recall:     0.6666667   Model Accuracy:  0.6666667
Batch nr:  0 class:  1   Precision:  0.0   Recall:     0.0   Model Accuracy:  0.6666667
Batch nr:  0 class:  2   Precision:  0.0   Recall:     0.0   Model Accuracy:  0.6666667
Batch nr:  1 class:  0   Precision:  0.8   Recall:     0.6666667   Model Accuracy:  0.5833333
Batch nr:  1 class:  1   Precision:  0.75   Recall:     0.5   Model Accuracy:  0.5833333
Batch nr:  1 class:  2   Precision:  0.0   Recall:     0.0   Model Accuracy:  0.5833333
Batch nr:  2 class:  0   Precision:  0.6666667   Recall:     0.6666667   Model Accuracy:  0.5
Batch nr:  2 class:  1   Precision:  0.42857143   Recall:     0.5   Model Accuracy:  0.5
Batch nr:  2 class:  2   Precision:  0.4   Recall:     0.33333334   Model Accuracy:  0.5
All data processed.

Summary

class:  0   Precision:   0.6666667   Recall:  0.6666667   F1 Score:  0.6666667   Model Accuracy:  0.5
class:  1   Precision:   0.42857143   Recall:  0.5   F1