# Episode Generation for DeepQA using NLP

In [6]:
from keras.models import Sequential, Model
from keras.layers.embeddings import Embedding
from keras.layers import Input, Activation, Dense, Permute, Dropout, add, dot, concatenate
from keras.layers import LSTM
from keras.utils.data_utils import get_file
from keras.preprocessing.sequence import pad_sequences
from pymongo import MongoClient
from functools import reduce
import nltk
import tarfile
import numpy as np
import re

# MongoDB connection string, database name and collection name used to build
# the `db` handle below.
NATMED_CN = "mongodb://localhost:27017"
NATMED_DB = "natmed"
NATMED_COL = "foods"

# NOTE(review): not referenced in the visible part of the notebook; presumably
# the output directory for generated episodes -- confirm.
DEEPQA_DIR = "../Dumps/natmed_dqa"

# NOTE(review): not referenced in the visible part of the notebook;
# family_name_narrative hard-codes the same value (15) -- confirm whether it
# was meant to use this constant.
DIALOGUE_SIZE = 15

# Episode categories. Narrative creates one (initially empty) list per entry
# and files every Episode under its `type`, so an Episode whose type is not
# listed here raises KeyError when added to a Narrative.
# NOTE(review): 'BASIC_COFERENCE' and 'PATH_FIDING' look like misspellings of
# "coreference" and "finding". They are runtime keys, so correcting them needs
# a coordinated change in every place that uses these strings.
EPISODE_TYPES = [
    'SINGLE_FACT',
    'TWO_FACT',
    'THREE_FACT',
    'TWO_ARGS',
    'THREE_ARGS',
    'YES_NO',
    'LISTS',
    'NEGATION',
    'INDEFINITE',
    'BASIC_COFERENCE',
    'CONJUNCTION',
    'COMPOUND',
    'BASIC_DEDUCTION',
    'BASIC_INDUCTION',
    'PATH_FIDING',
    'AGENT_MOTIVATION']

In [7]:
# Open the MongoDB client and bind `db` to the NATMED_COL collection of the
# NATMED_DB database (despite its name, `db` is the collection handle that
# family_names() runs its aggregation on).
client = MongoClient(NATMED_CN)
db = client[NATMED_DB][NATMED_COL]

In [8]:
class Episode(object):
    """An ordered list of numbered fact and question lines of one type.

    Every added line receives the next id, starting at 1. ``str(episode)``
    renders one line per entry with the fields separated by tabs.
    """

    def __init__(self, _type):
        self.type = _type  # category label; Narrative files the episode under it
        self.lines = []    # tuples in insertion order, id first
        self.counter = 0   # id of the most recently added line

    def _next_id(self):
        """Advance the line counter and return the new id."""
        self.counter += 1
        return self.counter

    def fact(self, fact):
        """Append the statement ``fact`` as a new line and return its id."""
        line_id = self._next_id()
        self.lines.append((line_id, fact))
        return line_id

    def dialoge(self, question, answer, fact):
        """Append a question line and return its id.

        ``fact`` is stored as the last field of the line; the notebook passes
        the id returned by ``fact()`` for the supporting statement.
        """
        line_id = self._next_id()
        self.lines.append((line_id, question, answer, fact))
        return line_id

    def __str__(self):
        # One text row per stored tuple, fields joined by tabs.
        return "\n".join(
            "\t".join(str(field) for field in entry) for entry in self.lines
        )

class Narrative(object):
    """A named collection of episodes grouped by episode type."""

    def __init__(self, name):
        self.name = name
        # One bucket per known episode type, in EPISODE_TYPES order.
        self.episodes = {t: [] for t in EPISODE_TYPES}

    def episode(self, episode):
        """File ``episode`` under its ``type``.

        Raises:
            KeyError: if ``episode.type`` is not one of the known types.
        """
        self.episodes[episode.type].append(episode)

    def dump(self):
        """Print the narrative name, then every non-empty episode group."""
        print("Narrative", self.name)
        for kind, group in self.episodes.items():
            if not group:
                continue
            print("Episode", kind)
            # Plain loop: printing is a side effect, so no throwaway list
            # of None values is built.
            for item in group:
                print(item)

In [9]:
# Smoke test: build a one-fact YES_NO episode and print the narrative.
nr = Narrative("teste")

ep = Episode("YES_NO")

# fact() returns the id of the new line; the question stores it as its last
# field so the answer can be traced back to the supporting fact.
fid = ep.fact("Teste is cool!")
ep.dialoge("Is Teste cool?", "yes", fid)

nr.episode(ep)

nr.dump()

Narrative teste
Episode YES_NO
1	Teste is cool!
2	Is Teste cool?	yes	1


In [10]:
def family_names(limit):
    """Run an aggregation on ``db`` that projects ``name`` and ``familyName``
    and caps the result at ``limit`` documents; return whatever
    ``db.aggregate`` returns.
    """
    keep_fields = {"$project": {"name": 1, "familyName": 1}}
    cap = {"$limit": limit}
    return db.aggregate([keep_fields, cap])

def family_name_narrative(limit=15):
    """Build the "family_name" narrative from the database.

    For every fetched document with a non-empty ``familyName``, one fact line
    and one question about it are added to a single SINGLE_FACT episode.
    Slashes in the family name are replaced by spaces.

    Args:
        limit: maximum number of documents to fetch (default 15, the value
            that was previously hard-coded).

    Returns:
        The Narrative holding that single episode.
    """
    nr = Narrative("family_name")
    single = Episode("SINGLE_FACT")

    for doc in family_names(limit):
        raw_family = doc.get('familyName')
        if not raw_family:
            # Documents without a family name contribute nothing.
            continue

        f_name = raw_family.replace("/", " ")
        name = doc['name']

        fid = single.fact("{} is the family name of {}.".format(f_name, name))
        single.dialoge("What is the family name of {}?".format(name), f_name, fid)

    nr.episode(single)

    return nr

In [11]:
# Build the family-name narrative from the database and print it.
nr = family_name_narrative()
nr.dump()

Narrative family_name
Episode SINGLE_FACT
1	Polemoniaceae is the family name of Abscess Root.
2	What is the family name of Abscess Root?	Polemoniaceae	1
3	Menispermaceae is the family name of Abuta.
4	What is the family name of Abuta?	Menispermaceae	3
5	Fabaceae Leguminosae is the family name of Acacia.
6	What is the family name of Acacia?	Fabaceae Leguminosae	5
7	Fabaceae is the family name of Acacia rigidula.
8	What is the family name of Acacia rigidula?	Fabaceae	7
9	Arecaceae Palmae is the family name of Acai.
10	What is the family name of Acai?	Arecaceae Palmae	9
11	Malpighiaceae is the family name of Acerola.
12	What is the family name of Acerola?	Malpighiaceae	11
13	Sapindaceae is the family name of Ackee.
14	What is the family name of Ackee?	Sapindaceae	13
15	Ranunculaceae is the family name of Aconite.
16	What is the family name of Aconite?	Ranunculaceae	15
17	Cyperaceae is the family name of Adrue.
18	What is the family name of Adrue?	Cyperaceae	17
19	Hypoxidaceae or Liliaceae

In [12]:
def tokenize(sent):
    '''Return the tokens of a sentence including punctuation.

    The sentence is split on runs of non-word characters; the separators are
    kept, and tokens that are empty after stripping whitespace are dropped.

    >>> tokenize('Bob dropped the apple. Where is the apple?')
    ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?']
    '''
    # The group must not be optional: a pattern that can match the empty
    # string makes re.split (Python 3.7+) split between every character and
    # yield None for the unmatched group, which breaks the .strip() below.
    return [x.strip() for x in re.split(r'(\W+)', sent) if x.strip()]

In [13]:
# Quick check of tokenize on a generated fact sentence (no trailing period).
tokenize("Polemoniaceae is the family name of Abscess Root")

  return _compile(pattern, flags).split(string, maxsplit)


['Polemoniaceae', 'is', 'the', 'family', 'name', 'of', 'Abscess', 'Root']

In [37]:
def parse_stories(lines, only_supporting=False):
    '''Parse stories provided in the bAbi tasks format.

    Each line is either ``<id>\\t<sentence>`` (a fact) or
    ``<id>\\t<question>\\t<answer>\\t<supporting ids>`` (a question, with the
    supporting ids separated by whitespace). A line with id 1 starts a new
    story. Blank lines are skipped.

    Args:
        lines: iterable of text lines.
        only_supporting: if true, each question keeps only the sentences
            named by its supporting ids; otherwise it keeps every fact seen
            so far in the current story.

    Returns:
        A list of ``(substory, question_tokens, answer)`` tuples, where
        ``substory`` is a list of tokenized sentences and ``answer`` is the
        raw answer string (it is not tokenized).
    '''
    data = []
    story = []
    for line in lines:
        if not line.strip():
            # e.g. the empty string produced by splitting text that ends
            # with a newline; it has no id to unpack.
            continue
        nid, text = line.split('\t', 1)
        if int(nid) == 1:
            story = []
        if '\t' in text:
            q, a, supporting = text.split('\t')
            q = tokenize(q)
            if only_supporting:
                # Only select the related substory (ids are 1-based).
                substory = [story[int(i) - 1] for i in supporting.split()]
            else:
                # Provide all the facts so far, without the placeholders.
                substory = [x for x in story if x]
            data.append((substory, q, a))
            # Placeholder so that line ids keep matching positions in story.
            story.append('')
        else:
            story.append(tokenize(text))
    return data

In [22]:
# Sample episode text (tab-separated: id, then either a fact or
# question / answer / supporting id) fed to parse_stories and get_stories
# in the cells below.
stories = """1	Polemoniaceae is the family name of Abscess Root.
2	What is the family name of Abscess Root?	Polemoniaceae	1
3	Menispermaceae is the family name of Abuta.
4	What is the family name of Abuta?	Menispermaceae	3
5	Fabaceae Leguminosae is the family name of Acacia.
6	What is the family name of Acacia?	Fabaceae Leguminosae	5
7	Fabaceae is the family name of Acacia rigidula.
8	What is the family name of Acacia rigidula?	Fabaceae	7
9	Arecaceae Palmae is the family name of Acai.
10	What is the family name of Acai?	Arecaceae Palmae	9
11	Malpighiaceae is the family name of Acerola.
12	What is the family name of Acerola?	Malpighiaceae	11
13	Sapindaceae is the family name of Ackee.
14	What is the family name of Ackee?	Sapindaceae	13
15	Ranunculaceae is the family name of Aconite.
16	What is the family name of Aconite?	Ranunculaceae	15
17	Cyperaceae is the family name of Adrue.
18	What is the family name of Adrue?	Cyperaceae	17
19	Hypoxidaceae or Liliaceae is the family name of African Wild Potato.
20	What is the family name of African Wild Potato?	Hypoxidaceae or Liliaceae	19
21	Amanitaceae is the family name of Aga.
22	What is the family name of Aga?	Amanitaceae	21"""

In [39]:
# Parse the sample text, keeping only the supporting fact of each question.
data_stories = parse_stories(stories.split("\n"), only_supporting=True)

  return _compile(pattern, flags).split(string, maxsplit)


In [40]:
# Inspect the first (substory, question tokens, answer) triple.
data_stories[0]

([['Polemoniaceae',
   'is',
   'the',
   'family',
   'name',
   'of',
   'Abscess',
   'Root',
   '.']],
 ['What', 'is', 'the', 'family', 'name', 'of', 'Abscess', 'Root', '?'],
 'Polemoniaceae')

In [41]:
def get_stories(stories, only_supporting=False, max_length=None):
    '''Parse story text and flatten each story into a single token list.

    Args:
        stories: the story text itself (not a file name), one bAbi-format
            line per row, separated by newlines.
        only_supporting: forwarded to parse_stories.
        max_length: if supplied (and non-zero), only stories with fewer than
            max_length tokens are kept.

    Returns:
        A list of ``(story_tokens, question_tokens, answer)`` tuples.
    '''
    parsed = parse_stories(stories.split("\n"), only_supporting=only_supporting)
    data = []
    for story, q, answer in parsed:
        # Flatten once per story. Unlike reduce() without an initial value,
        # this also works for a question that has no preceding facts.
        flat = [token for sentence in story for token in sentence]
        if not max_length or len(flat) < max_length:
            data.append((flat, q, answer))
    return data

In [43]:
# Rebind data_stories: re-parse the sample with all facts of each story
# flattened into one token list.
data_stories = get_stories(stories)

  return _compile(pattern, flags).split(string, maxsplit)


In [44]:
def vectorize_stories(data, word_idx, story_maxlen, query_maxlen):
    """Turn (story, query, answer) triples into index sequences and targets.

    Story and query tokens are looked up in ``word_idx`` and the resulting
    sequences are passed through ``pad_sequences`` with ``story_maxlen`` and
    ``query_maxlen``. Each answer becomes a vector of length
    ``len(word_idx) + 1`` with a 1 at the answer's index.

    Returns:
        (padded stories, padded queries, array of answer vectors)

    Raises:
        KeyError: if a token or an answer is missing from ``word_idx``.
    """
    # Index 0 is reserved, hence one slot more than there are words.
    target_width = len(word_idx) + 1

    story_rows, query_rows, answer_rows = [], [], []
    for story, query, answer in data:
        story_ids = [word_idx[token] for token in story]
        query_ids = [word_idx[token] for token in query]

        target = np.zeros(target_width)
        target[word_idx[answer]] = 1

        story_rows.append(story_ids)
        query_rows.append(query_ids)
        answer_rows.append(target)

    padded_stories = pad_sequences(story_rows, maxlen=story_maxlen)
    padded_queries = pad_sequences(query_rows, maxlen=query_maxlen)
    return (padded_stories, padded_queries, np.array(answer_rows))