# CNTK 302b: Evaluating ReasoNet for Machine Comprehension with the CNN Dataset



This tutorial loads a pre-trained ReasoNet model and shows how a cached model can be used to make predictions (a.k.a. evaluation) on CNN data that was not used in training.

## Data preparation

### Download data


In [19]:
```python
import os
import sys

sys.path.insert(0, "../Examples/LanguageUnderstanding/")
from ReasoNet.prepare_cnn_data import file_exists, merge_files, download_cnn, download

# Check for an environment variable defined in CNTK's test infrastructure
envvar = 'CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY'
def is_test():
  return envvar in os.environ

data_root = "../Examples/LanguageUnderstanding/ReasoNet/Data"

if is_test():
  raw_train_data = os.path.join(data_root, "cnn_test/training.txt")
  raw_test_data = os.path.join(data_root, "cnn_test/test.txt")
else:
  raw_train_data = os.path.join(data_root, "cnn/training.txt")
  raw_test_data = os.path.join(data_root, "cnn/test.txt")
  if not (file_exists(raw_train_data) and file_exists(raw_test_data)):
    download_cnn(data_root)
  merge_files(os.path.join(data_root, "cnn/questions/training"), raw_train_data)
  merge_files(os.path.join(data_root, "cnn/questions/test"), raw_test_data)
print("All necessary data are downloaded to {0}".format(data_root))
```

All necessary data are downloaded to ../Examples/LanguageUnderstanding/ReasoNet/Data
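The evaluation code later in this notebook reads the merged `test.txt` back as one tab-separated record per line: query, answer, document, then `@entityN:name` pairs. The sketch below shows one way a single question file could be flattened into that layout. It assumes the blank-line-separated layout of the DeepMind Q&A release (URL, context, query, answer, entity mappings). `question_file_to_record` is a hypothetical helper for illustration, not the actual implementation of `merge_files`.

```python
def question_file_to_record(text):
    """Flatten one question file into a tab-separated record (hypothetical helper).

    Assumed layout, with blocks separated by blank lines:
    URL, context, query, answer, then one "@entityN:name" mapping per line.
    """
    blocks = text.strip().split("\n\n")
    # blocks[0] is the source URL, which the merged record does not keep
    context = blocks[1].strip()
    query = blocks[2].strip()
    answer = blocks[3].strip()
    mappings = blocks[4].strip().split("\n") if len(blocks) > 4 else []
    # Field order matches what the evaluation code reads back:
    # query, answer, doc, followed by the entity mappings
    return "\t".join([query, answer, context] + mappings)


sample = ("http://example.com/story\n\n"
          "@entity0 said hi to @entity1 .\n\n"
          "@placeholder said hi\n\n"
          "@entity0\n\n"
          "@entity0:Alice\n@entity1:Bob\n")
print(question_file_to_record(sample).split("\t"))
```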


### Convert to CNTK Text Format


In [20]:
```python
from ReasoNet.wordvocab import *

if is_test():
  vocab_path = os.path.join(data_root, "cnn_test/cnn.vocab")
  train_ctf = os.path.join(data_root, "cnn_test/training.ctf")
  test_ctf = os.path.join(data_root, "cnn_test/test.ctf")
  test_size = 379913
else:
  vocab_path = os.path.join(data_root, "cnn/cnn.vocab")
  train_ctf = os.path.join(data_root, "cnn/training.ctf")
  test_ctf = os.path.join(data_root, "cnn/test.ctf")
  test_size = 2291183
vocab_size = 101000
if not (file_exists(train_ctf) and file_exists(test_ctf)):
  entity_vocab, word_vocab = Vocabulary.build_vocab(raw_train_data, vocab_path, vocab_size)
  Vocabulary.build_corpus(entity_vocab, word_vocab, raw_test_data, test_ctf)
print("Data conversion finished.")
```

Data conversion finished.
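CNTK Text Format (CTF) stores one time step per line. Each line holds a sequence id followed by `|stream` fields, and sparse one-hot inputs are written as `index:value`. The hypothetical helper below writes one sequence of token ids in that shape. The stream names `query` and `context` are placeholders. The streams actually written by `Vocabulary.build_corpus` may be named differently, and they likely include additional streams, such as the entity mask that `predict` reads.

```python
def to_ctf_lines(seq_id, streams):
    """Write one sequence in CNTK Text Format with one-hot sparse values.

    streams maps a (hypothetical) stream name to a list of token ids, one per
    time step. Each output line holds the sequence id followed by
    "|name index:1" for every stream that still has a token at that step.
    """
    length = max(len(ids) for ids in streams.values())
    lines = []
    for t in range(length):
        parts = [str(seq_id)]
        for name, ids in streams.items():
            if t < len(ids):
                parts.append("|{0} {1}:1".format(name, ids[t]))
        lines.append(" ".join(parts))
    return lines


for line in to_ctf_lines(0, {"query": [5, 7], "context": [3, 9, 11]}):
    print(line)
```

Streams of different lengths share a sequence id. Shorter streams simply stop appearing on later lines, which is how CTF lets a short query and a long paragraph live in the same sequence.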


### Download model

In [21]:
```python
if is_test():
  model_src = "http://cntk.ai/jup/models/reasonet/model_training.ctf_final.dnn.bin"
  model_path = "model/model_training.ctf_final.dnn"
else:
  model_src = "http://cntk.ai/jup/models/reasonet/model_cnn.epoch.00.bin"
  model_path = "model/model_cnn.epoch.00.bin"
if not file_exists(model_path):
  download(model_src, model_path)

print("Succeeded to download model to local.")
```

Succeeded to download model to local.


## Basic CNTK imports

In [22]:
```python
import os
import sys
from datetime import datetime
import numpy as np
import cntk
from cntk import device
from cntk.ops import sequence, element_times, reshape, greater, slice, hardmax, input
from io import open

# Select the right target device when this notebook is being tested.
# Only GPU is currently supported, so CPU test runs are rejected
# and the default device is always set to the first GPU.
if os.environ.get('TEST_DEVICE') == 'cpu':
    raise ValueError('This notebook is currently not supported on CPU')
cntk.device.set_default_device(cntk.device.gpu(0))
```

True

### Predict
The original CNN data has been pre-processed by replacing the entities in the text with *@entityXX* markers, and the answer is always one of these entities. Here is an example:

* The original paragraph,
>april 2 , 2015 an unstable Middle Eastern country has become a potential battlefield for a proxy war . today on CNN Student News , hear an explainer on why Yemen is the focus of global concern . we also report on the origins of April Fools ' Day , we detail how a 1,000 - year - old recipe could cure a modern - day superbug , and we feature a Character Study on a woman who 's steering kids to a better life . on this page you will find today 's show transcript and a place for you to request to be on the CNN Student News Roll Call . transcript click here to access the transcript of today 's CNN Student News program . please note that there may be a delay between the time when the video is available and when the transcript is published . CNN Student News is created by a team of journalists who consider the Common Core State Standards , national standards in different subject areas , and state standards when producing the show . ROLL CALL for a chance to be mentioned on the next CNN Student News , comment on the bottom of this page with your school name , mascot , city and state . we will be selecting schools from the comments of the previous show . you must be a teacher or a student age 13 or older to request a mention on the CNN Student News Roll Call ! thank you for using CNN student news !

* The original query,
>at the bottom of the page , comment for a chance to be mentioned on CNN Student News . you must be a teacher or a student age 13 or older to request a mention on the @placeholder .

* The answer
>CNN Student News Roll Call

After pre-processing, the example looks like this:

* Paragraph
>april 2 , 2015 an unstable @entity1 country has become a potential battlefield for a proxy war . today on @entity4 , hear an explainer on why @entity6 is the focus of global concern . we also report on the origins of @entity11 , we detail how a 1,000 - year - old recipe could cure a modern - day superbug , and we feature a @entity14 on a woman who 's steering kids to a better life . on this page you will find today 's show transcript and a place for you to request to be on the @entity22 . transcript click here to access the transcript of today 's @entity25 . please note that there may be a delay between the time when the video is available and when the transcript is published . @entity4 is created by a team of journalists who consider the @entity33 , national standards in different subject areas , and state standards when producing the show . @entity38 for a chance to be mentioned on the next @entity4 , comment on the bottom of this page with your school name , mascot , city and state . we will be selecting schools from the comments of the previous show . you must be a teacher or a student age 13 or older to request a mention on the @entity22 ! thank you for using @entity56 student news !

* Query
>at the bottom of the page , comment for a chance to be mentioned on @entity4 . you must be a teacher or a student age 13 or older to request a mention on the @placeholder .

* Answer
>@entity22
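The replacement itself can be pictured as plain string substitution. The sketch below is a naive illustration that assumes a hand-written `entity_map` from real names to markers. It is not how the CNN dataset was actually produced, and it ignores word boundaries and case. It replaces longer names first so that *CNN Student News Roll Call* is not broken up by the shorter *CNN Student News*.

```python
def anonymize(text, entity_map):
    """Replace each real entity name with its @entityN marker (naive sketch).

    entity_map is a hypothetical {real name: marker} dictionary.
    Longer names are replaced first so that a shorter name that is a prefix
    of a longer one does not consume part of it.
    """
    for name in sorted(entity_map, key=len, reverse=True):
        text = text.replace(name, entity_map[name])
    return text


entity_map = {"CNN Student News": "@entity4",
              "CNN Student News Roll Call": "@entity22"}
print(anonymize("mentioned on CNN Student News . a mention on the "
                "CNN Student News Roll Call .", entity_map))
```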

Once we have the model, we can use it to predict answers for a given paragraph and query. The inputs to our `predict` function are the *pre-processed* paragraphs and queries. Because the answer is always picked from the entities in the paragraph, the output is a one-hot vector over entity ids. A 1 means the entity at that position (*the position index is the same as the entity id*) is the predicted answer, and a 0 means it is not.
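To make the inputs concrete, the sketch below derives the two entity-related inputs used inside `predict` from a pre-processed paragraph: a token-level entity mask and the ordered entity ids. It assumes that the number after `@entity` is the entity id, as the comments in `predict` do for illustration. The real `Vocabulary` may assign ids differently, and `entity_inputs` is a hypothetical helper.

```python
PREFIX = "@entity"


def is_entity(token):
    # Entity markers look like "@entity" followed by digits, e.g. "@entity12"
    return token.startswith(PREFIX) and token[len(PREFIX):].isdigit()


def entity_inputs(paragraph):
    """Return (mask, entity_ids) for a space-tokenized, pre-processed paragraph."""
    tokens = paragraph.split(" ")
    # 1 where the token is an entity marker, 0 elsewhere
    mask = [1 if is_entity(tok) else 0 for tok in tokens]
    # Entity ids in paragraph order; repeated entities appear repeatedly
    entity_ids = [int(tok[len(PREFIX):]) for tok in tokens if is_entity(tok)]
    return mask, entity_ids


mask, ids = entity_inputs("Abc efg @entity1 xyz @entity1 @entity3")
print(mask, ids)
```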

In [23]:
```python
from ReasoNet.reasonet import *
def predict(model, params):
  """
  Compute the prediction result of the given model
  """
  model_args = {arg.name:arg for arg in model.arguments}

  # entity_ids_mask is a sequence of booleans with the same length
  #  as the number of tokens in the paragraph,
  #  where a non-zero value means the corresponding token is an entity.
  # E.g.
  # Paragraph
  # Abc efg @entity1 xyz @entity1 @entity3
  # The corresponding entity ids mask:
  #  0  0  1  0  1  1
  entity_ids_mask = model_args['entity_ids_mask']

  # Normalize the input so that all non-zero values become 1s
  entity_condition = greater(entity_ids_mask, 0, name='condition')

  # entities_all is a sequence of all 1s whose length equals the number of entities in the paragraph.
  # The gather operation creates a new dynamic sequence axis.
  # E.g.
  # The entities, in paragraph order, are
  #  @entity1 @entity1 @entity3
  # The output of the operation will be
  # 1 1 1
  entities_all = sequence.gather(entity_condition,
                                 entity_condition,
                                 name='entities_all')

  # The model prediction is a sequence of probabilities over all the tokens in the paragraph,
  # but we only pick the answer from the entities.
  # The gather operation filters out the probabilities of non-entity tokens.
  # E.g.
  # The model prediction for the above example would be something like
  #  0.1 0.2 0.1 0.3 0.2 0.1
  # The probabilities of the entities being the answer will be
  #  0.1 0.2 0.1
  # The scatter operation assigns the dynamic axes of entities_all to answers.
  answers = sequence.scatter(
                             # Only keep the predicted probabilities of the entities in the paragraph
                             sequence.gather(model.outputs[-1], entity_condition),
                             entities_all,
                             name='Final_Ans')

  # entity_ids holds the ids of the entities in the paragraph, in order.
  # It's a sequence of one-hot encoded vectors.
  # The dimension is the maximum number of unique entities across all the paragraphs.
  # E.g.
  # The ids for the above example will be
  #  1:1 1:1 3:1
  entity_ids = input(shape=(params.entity_dim),
                     is_sparse=True,
                     # The sequence length equals the number of entities in the paragraph,
                     #   so it shares the dynamic axes of entities_all, as well as answers.
                     dynamic_axes=entities_all.dynamic_axes,
                     name='entity_ids')

  # The global token id zero is used for unknown tokens, and entity ids start at 1.
  # So we trim the first column of the entity id matrix.
  # E.g. the output for the above example will be
  # [1 0 0] [1 0 0] [0 0 1]
  entity_id_matrix = slice(
                           # entity_ids is a sequence of one-hot encoded sparse vectors;
                           # reshaping converts them to dense vectors.
                           reshape(entity_ids, params.entity_dim),
                           # Slice along the last axis
                           axis = -1,
                           begin_index = 1,
                           end_index = params.entity_dim)

  # Multiplying answers with entity_id_matrix gives a probability matrix like
  # [0.1 0 0] [0.2 0 0] [0 0 0.1]
  entity_probs = element_times(answers, entity_id_matrix)

  # Summing over the sequence dynamic axis aggregates the probabilities
  #  of the same entity appearing at different positions in the paragraph.
  # This gives the probabilities of the unique entities in the paragraph.
  # E.g. the output for the above example input will be
  # [0.3 0 0.1]
  agg_pred = sequence.reduce_sum(entity_probs)

  # Pick the entity with the maximum probability as the answer
  # E.g. the output for the above example will be
  # [1 0 0]
  pred_max = hardmax(agg_pred, name='pred_max')
  return pred_max
```
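The CNTK graph above is easier to follow with concrete numbers. The following NumPy sketch reproduces the worked example from the comments for a single paragraph. It gathers the entity probabilities, multiplies them by one-hot entity rows with column 0 dropped, sums over the sequence, and then takes the hardmax. `aggregate_entity_prediction` is a hypothetical helper for illustration only. In the notebook, `num_entities` would correspond to `params.entity_dim - 1`. Ties are broken here by `np.argmax` picking the first maximum, which may not match CNTK's `hardmax`.

```python
import numpy as np


def aggregate_entity_prediction(token_probs, entity_mask, entity_ids, num_entities):
    """NumPy mirror of the gather -> element_times -> reduce_sum -> hardmax steps.

    token_probs:  per-token answer probabilities for one paragraph
    entity_mask:  1 where the token is an entity, 0 otherwise
    entity_ids:   ids (starting at 1) of the entity tokens, in paragraph order
    num_entities: number of entity columns kept after dropping id 0
    """
    token_probs = np.asarray(token_probs, dtype=float)
    mask = np.asarray(entity_mask) > 0
    # gather: keep only the probabilities at entity positions
    entity_probs = token_probs[mask]
    # One-hot rows with the unknown-token column (id 0) sliced off
    one_hot = np.zeros((len(entity_ids), num_entities))
    one_hot[np.arange(len(entity_ids)), np.asarray(entity_ids) - 1] = 1.0
    # element_times followed by reduce_sum over the sequence axis
    agg = (entity_probs[:, None] * one_hot).sum(axis=0)
    # hardmax: 1 at the highest aggregated probability, 0 elsewhere
    pred = np.zeros_like(agg)
    pred[np.argmax(agg)] = 1.0
    return agg, pred


agg, pred = aggregate_entity_prediction(
    [0.1, 0.2, 0.1, 0.3, 0.2, 0.1], [0, 0, 1, 0, 1, 1], [1, 1, 3], num_entities=3)
print(agg, pred)  # agg is approximately [0.3, 0.0, 0.1]; pred is [1, 0, 0]
```

Note that the non-entity token with probability 0.3 never competes, because the mask removes it before aggregation. The repeated `@entity1` wins by accumulating 0.1 + 0.2.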

#### Mapping the prediction to entities
The prediction result is a one-hot vector in which a 1 marks the position of the predicted entity and a 0 marks every other position. To make the prediction readable, we convert the vector to an entity id and map that id back to the real entity name.
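Outside of CNTK, the remapping comes down to parsing each raw record and substituting the entity markers. The sketch below follows the record layout read by `pred_cnn_model`: query, answer, document, then `@entityN:name` pairs, all tab-separated. It splits each pair on the first `:` only. The sample record is made up from the example above, and `parse_record` is a hypothetical helper.

```python
def parse_record(line):
    """Split one line of the merged raw file into query, answer, doc and entity map."""
    fields = line.strip().split("\t")
    query, answer, doc = fields[0], fields[1], fields[2]
    # Split on the first ':' only, so names that contain ':' stay intact
    entity_dict = dict(pair.split(":", 1) for pair in fields[3:])
    return query, answer, doc, entity_dict


def unroll_entities(text, entity_dict):
    """Replace every @entityN token with its real entity name."""
    return " ".join(entity_dict.get(tok, tok) for tok in text.split(" "))


record = ("request a mention on the @placeholder .\t@entity22\t"
          "be on the @entity22 .\t"
          "@entity22:CNN Student News Roll Call\t@entity4:CNN Student News")
query, answer, doc, entities = parse_record(record)
print(unroll_entities(answer, entities))
```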

In [24]:
```python
import sys
import os
import cntk.device as device
import numpy as np
import math
from cntk import load_model
import time

def unroll_entities(doc, entity_dict):
  """Replace every @entityN token in doc with its real entity name."""
  tokens = doc.split(u' ')
  for i in range(len(tokens)):
    if tokens[i] in entity_dict:
      tokens[i] = entity_dict[tokens[i]]
  return u' '.join(tokens)

def pred_cnn_model(model_path):
  logger.init("cnn_test")
  vocab_dim = 101585
  entity_dim = 586
  hidden_dim = 256
  max_rl_steps = 5
  embedding_dim = 300
  att_dim = 384
  minibatch_size = 1
  share_rnn = True

  test_data = create_reader(test_ctf, vocab_dim, entity_dim, False)
  embedding_init = None

  params = model_params(vocab_dim = vocab_dim, entity_dim = entity_dim, hidden_dim = hidden_dim,
                        embedding_dim = embedding_dim, attention_dim = att_dim, max_rl_steps = max_rl_steps,
                        embedding_init = embedding_init, dropout_rate = 0.2, share_rnn_param = share_rnn)

  entity_table, word_table = Vocabulary.load_vocab(vocab_path)
  model = load_model(model_path)
  predict_func = predict(model, params)
  bind = bind_data(predict_func, test_data)
  context_stream = get_context_bind_stream(bind)
  samples_sum = 0
  i = 0
  predicted_results = []
  # Only evaluate the first few samples to keep the demo short
  max_num = 5
  start = time.time()
  while i < test_size:
    mbs = min(test_size - i, minibatch_size)
    mb = test_data.next_minibatch(mbs, bind)
    pred = predict_func.eval(mb)
    # Convert the entity one-hot vector to an entity id
    ans = np.nonzero(pred)
    # Map the entity id back to the real entity
    for entity_id in ans[1]:
      predicted_results += [ entity_table.lookup_by_id(entity_id) ]
    i += mb[context_stream].num_samples
    samples = mb[context_stream].num_sequences
    samples_sum += samples
    sys.stdout.write('.')
    sys.stdout.flush()
    if samples_sum >= max_num:
      break
  end = time.time()
  total = end - start
  print("")
  print("Evaluated samples: {0} in {1} seconds".format(samples_sum, total))
  instance_id = 0
  with open(raw_test_data, 'r', encoding='utf-8') as raw:
    # Each record: query \t answer \t doc \t @entityN:name pairs
    for record in raw:
      fields = record.strip().split(u'\t')
      query = fields[0]
      answer = fields[1]
      doc = fields[2]
      entity_dict = {}
      for j in range(3, len(fields)):
        # Split on the first ':' only, so names containing ':' stay intact
        pair = fields[j].split(u':', 1)
        entity_dict[pair[0]] = pair[1]
      print("===============")
      print("[{0}] Doc: {1}\n Query: {2}\n Answer: {3}\n Expected: {4}".format(instance_id,
                                                                               doc,
                                                                               query,
                                                                               predicted_results[instance_id],
                                                                               answer))
      print("=>Unrolled=>\n Query: {0}\n Answer: {1}".format(unroll_entities(query, entity_dict),
                                                     unroll_entities(answer, entity_dict)))
      print()
      instance_id += 1
      if instance_id >= len(predicted_results):
        break

pred_cnn_model(model_path)
```

Log with log file: log/cnn_test_04-15_19.40.43.log
.....
Evaluated samples: 5 in 1.1304600238800049 seconds
[0] Doc: ( @entity0 ) -- a @entity2 court of appeals gave @entity3 a thumbs up on wednesday when it ruled that " likes " on the social network are protected as free speech under the @entity6 . in 2009 , six employees at the @entity10 in @entity11 lost their jobs after expressing support for their boss ' opponent in an upcoming election for sheriff , some by liking and commenting on the opponent 's facebook page . @entity2 circuit judge @entity23 found that " liking " something on the social network was the " internet equivalent of displaying a political sign in one 's front yard , " an act the @entity22 has already ruled as protected speech . the decision from the 4th @entity2 @entity33 in @entity34 , @entity11 , was a reversal of an earlier ruling on the case , in which district judge @entity38 said liking a @entity3 page was " insufficient speech to merit constitutional protect