<a href="https://colab.research.google.com/github/constantin50/machine_learning/blob/master/Bot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!pip install deeppavlov
!python -m deeppavlov install squad_bert
!pip install pyspellchecker
!pip install wikipedia-api

In [0]:
import re
import nltk
import wikipediaapi
from nltk import ne_chunk, pos_tag
from nltk.tree import Tree
from nltk.stem import WordNetLemmatizer
from spellchecker import SpellChecker
from nltk import ne_chunk, pos_tag
nltk.download('averaged_perceptron_tagger') 
nltk.download('words')
nltk.download('wordnet')

from deeppavlov import build_model, configs

In [0]:
# Evaluation set consumed by Bot.evaluate: each entry is a
# [question, expected] pair, and an answer counts as correct when the
# `expected` string occurs inside the answer produced by the bot.
# NOTE(review): some entries look misspelled (e.g. 'finit', 'desribe',
# 'proprities'); they are reproduced unchanged - confirm which are deliberate.
test = [
    ['what is an unit vector?', 'length'],
    ['what is a basis?', 'independent'],
    ['what is a linear span?', 'smallest'],
    ['where are vector spaces applied?', 'engineering'],
    ['is a tensor independent of any basis?', 'independent'],
    ['how can a linear map be represented?', 'matrices'],
    ['what is a imaginary unit?', 'solution'],
    ['what does a complex number mean geometrically?', 'plane'],
    ['what is a determinant geometrically speaking?', 'volume'],
    ['when is a determinant positive?', 'orientation'],
    ['what is pi', 'constant'],
    ['is an orthogonal matrix invertible?', 'invertible'],
    ['what is a real number?', 'continuous'],
    ['what is a example of a group?', 'integers'],
    ['what is a finit group?', 'finit'],
    ['what is an inverse function?', 'reverses'],
    ['what is a length of a point?', 'any'],
    ['what is a gradient?', 'function'],
    ['does the real numbers includes the rational numbers?', 'include'],
    ['how can a gradient be interpreted?', 'direction'],
    ['how can a gradient be used to maximize a function?', 'ascent'],
    ['what is a cotangent space?', 'smooth'],
    ['what is a field?', 'structure'],
    ['what is a example of a ring?', 'integers'],
    ['is a division defined on field?', 'defined'],
    ['what is a common fraction?', 'numeral'],
    ['what is a sequence?', 'enumerated'],
    ['what is a length of a sequence?', 'elements'],
    ['what are transfinite numbers?', '"infinite'],
    ['what are a positive numbers?', 'greater'],
    ['what is a complex plane?', 'representation'],
    ['what is a gauss plane?', 'complex'],
    ['what can a rotation desribe?', 'body'],
    ['what is a zero matrix?', 'zero'],
    ['when can two matrices be added?', 'same'],
    ['when are vectors orthogonal?', '0'],
    ['what are examples of reflection?', 'isometry'],
    ['how are angels formed?', 'intersection'],
    ['what is a vertex?', 'meet'],
    ['what does a projective geometry study?', 'proprities'],
    ['what is a norm?', 'function'],
    ['what is a vector space in geometric sense?', 'displacements'],
    ['what is a metric?', 'function'],
    ['what does a metric tensor define?', 'length'],
    ['what is a positive-definite metric tensor?', '0'],
    ['is a symmetric tensor invariant under a permutation?', 'invariant'],
    ['what is an antisymmetric tensor?', 'sign'],
    ['is a result of a dot product scalar?', 'scalar'],
    ['what is a variable?', 'arbitrary'],
    ['what is a example of a function?', 'integers'],
]



In [0]:
# this class handles text data: lemmatization, tagging and correction of spelling.
class Analyzer:
  """
  Text pre-processing helper: tokenization, lemmatization, spelling
  correction and part-of-speech tagging.

  Attributes
  -----------
  lemmatizer : nltk.stem.WordNetLemmatizer
    turns a word into its lemma

  spell : spellchecker.SpellChecker
    detects and corrects misspelled words
  """

  # Words that must reach the tagger unchanged. The original guard
  # (`w != "was" or w != "does"`) was always true, so these words were
  # never actually protected from the lemmatizer.
  _KEEP_AS_IS = ("was", "does")

  # Auxiliary verbs that may open a yes/no question. "doe" is kept next to
  # "does" so that input tagged by the previous behaviour is still recognised.
  _AUX_VERBS = ("am", "is", "are", "was", "were", "will", "did",
                "does", "doe", "shall")

  # Mathematical constants that are always tagged as nouns. "\u03c0" is the
  # Greek letter pi; the original list held its mis-decoded (mojibake) form.
  _MATH_CONSTS = ("\u03c0", "pi", "e")

  def __init__(self):
    self.lemmatizer = WordNetLemmatizer()
    self.spell = SpellChecker()

  def _lemmatize(self, tokens):
    """Return `tokens` lemmatized, leaving the protected words untouched."""
    return [
      tok if tok in self._KEEP_AS_IS else self.lemmatizer.lemmatize(tok)
      for tok in tokens
    ]

  def normalize(self, exp):
    """
    Parameters
    -----------
    exp : string
      expression to normalize

    Returns
    --------
    result : list
      list of lemmatized and correctly spelled words
      (empty list for an empty expression)
    """
    tokens = self._lemmatize(nltk.word_tokenize(exp))
    return self.correct_spelling(tokens)

  def normalize_and_tag(self, exp, deter=True, aux=True):
    """
    Parameters
    -----------
    exp : string
      expression to normalize and tag

    deter : bool
      if true then determiners (tag "DT") are removed from the expression

    aux : bool
      if true and the first remaining word is an auxiliary verb,
      that word is marked with the "AUX" tag

    Returns
    --------
    result : list
      list of (word, tag) pairs for the lemmatized and correctly spelled
      words; pairs whose tag was overridden here are two-element lists,
      the others are the tuples produced by pos_tag. An empty list is
      returned for an empty expression.
    """
    tokens = self.correct_spelling(self._lemmatize(nltk.word_tokenize(exp)))
    result = pos_tag(tokens)

    # math constants are always treated as nouns
    for i, (word, _tag) in enumerate(result):
      if word in self._MATH_CONSTS:
        result[i] = [word, "NN"]

    if deter:
      result = [pair for pair in result if pair[1] != "DT"]

    # guard against an empty expression before looking at the first word
    if aux and result and result[0][0] in self._AUX_VERBS:
      result[0] = [result[0][0], "AUX"]

    return result

  def correct_spelling(self, exp):
    """
    Parameters
    -----------
    exp : list of strings
      words to check

    Returns
    --------
    result : list
      the same words in the same order, with every word reported as
      unknown by the spell checker replaced by its suggested correction.
      A word is kept as it is when the checker offers no suggestion.
      Each replacement is announced with a "did you mean" message.
    """
    misspelled = self.spell.unknown(exp)
    corrected = []
    for word in exp:
      if word not in misspelled:
        corrected.append(word)
        continue
      # ask the checker only once per word
      suggestion = self.spell.correction(word)
      if suggestion is None or suggestion == word:
        # nothing better is known - keep the user's word instead of None
        corrected.append(word)
        continue
      print("did you mean: ", suggestion)
      corrected.append(suggestion)
    return corrected

In [0]:

class Bot:
  """
  Question-answering bot for mathematical topics.

  A query is split into questions, pronouns are replaced with the entity
  of the preceding question, entities are looked up on Wikipedia and the
  page summaries are handed to a reading-comprehension model as context.

  Attributes
  -----------
  knw_base : wikipediaapi.Wikipedia
    knowledge base the contexts are taken from

  model : model built from configs.squad.squad_bert
    called as model([context], [question])

  analyzer : Analyzer
    text pre-processing helper
  """

  # tags that mark the first word of a question
  # ("AUX" is assigned by Analyzer.normalize_and_tag)
  _QUESTION_TAGS = ("WDT", "WP", "WRB", "AUX")

  # tags of question words; a noun directly after one does not start a phrase
  _WH_TAGS = ("WRB", "WP", "WDT")

  # removes line breaks, double spaces, backslashes and bracketed fragments
  # from a page summary
  _CLEANUP = r'\n|  |\{[^)]*\}|\\|\([^)]*\)|\S*\)|\S*\}'

  def __init__(self):
    # NOTE(review): recent wikipedia-api releases appear to expect a
    # user agent argument here - confirm against the installed version.
    self.knw_base = wikipediaapi.Wikipedia('en')
    self.model = build_model(configs.squad.squad_bert, download=True)
    self.analyzer = Analyzer()

  # takes user's input
  def take_query(self):
    """
    Reads one query from standard input and prints an answer for every
    question found in it.

    Local values
    ------------
    query : string
      user's input

    qtns : list
      questions extracted from query

    ents : list
      entities extracted from questions

    cntxt : list
      wiki-pages
    """
    query = input("query: ")
    qtns = self.exract_questions(query)
    qtns = self.solve_anaphora(qtns)
    ents = self.extract_entities(qtns)
    cntxt = self.make_requests(ents)
    self.response(qtns, cntxt, ents)

  def _build_context(self, pages):
    """Joins the cleaned summaries of `pages` into one context string."""
    return ' '.join(re.sub(self._CLEANUP, '', page.summary) for page in pages)

  def _answer(self, question, pages):
    """Returns the model's answer to `question`, or '' without any context."""
    context = self._build_context(pages)
    if not context.strip():
      # nothing to read from, so the model is not consulted
      return ''
    return self.model([context], [question])[0][0]

  def response(self, qtns, contxts, ents):
    """
    Prints every question together with its answer.

    Parameters
    -----------
    qtns : list of strings
      questions

    contxts : list of lists of wiki pages
      contxts[i] holds the pages used as context for qtns[i]

    ents : list
      entities of the questions (not used here, kept for the callers)
    """
    for question, pages in zip(qtns, contxts):
      answer = self._answer(question, pages)
      print("QUESTION:", question)
      if answer != '':
        print("ANSWER:", answer)
      else:
        print("nothing about it")
      print("\n")

  @staticmethod
  def _noun_phrases(tagged):
    """
    Collects runs of words that start with a noun or an adjective
    ("NN"/"JJ") not directly preceded by a question word and continue
    over nouns, adjectives and prepositions ("NN"/"JJ"/"IN").
    """
    phrases = []
    total = len(tagged)
    i = 0
    while i < total:
      # the first word has no predecessor; index -1 would wrap to the end
      prev_tag = tagged[i - 1][1] if i > 0 else None
      if tagged[i][1] in ("NN", "JJ") and prev_tag not in Bot._WH_TAGS:
        words = [tagged[i][0]]
        i += 1
        while i < total and tagged[i][1] in ("NN", "IN", "JJ"):
          words.append(tagged[i][0])
          i += 1
        phrases.append(' '.join(words))
      else:
        i += 1
    return phrases

  # extract entities from user's query
  def extract_entities(self, qtns):
    """
    Parameters
    -----------
    qtns : list of strings
      questions

    Returns
    --------
    result : list of lists of strings
      result[i] holds the entities found in qtns[i]
      (an empty list when nothing was found)
    """
    result = []
    for qtn in qtns:
      # the analyzer may hand back nothing for an empty question
      tagged = self.analyzer.normalize_and_tag(qtn) or []
      found = []
      for phrase in self._noun_phrases(tagged):
        # one lookup per phrase: find_entity queries the knowledge base
        entity = self.find_entity(phrase)
        if entity is not None and entity != "_":
          found.append(entity)
      result.append(found)
    return result

  def exract_questions(self, query):
    """
    Parameters
    -----------
    query : string
      query from user

    Returns
    --------
    result : list of strings
      questions extracted from the query. When the query holds more than
      one '?', it is split at the question marks and every part is
      returned as its normalized words joined by spaces, with one leading
      and one trailing space. Otherwise the tagged query is cut at every
      question word or leading auxiliary verb.
    """
    # remove '!' and repetitive '?'
    query = re.sub(r'[!]', '', query)
    query = re.sub(r'[?](?=\?)', '', query)

    if query.count('?') > 1:
      result = []
      for part in query.split("?"):
        words = self.analyzer.normalize(part) or []
        if words:
          result.append(' ' + ' '.join(words) + ' ')
      return result

    tagged = self.analyzer.normalize_and_tag(query) or []
    result = []
    total = len(tagged)
    i = 0
    while i < total:
      if tagged[i][1] not in self._QUESTION_TAGS:
        i += 1
        continue
      words = [tagged[i][0]]
      i += 1
      # the bound check also covers a question word that ends the query
      while i < total and tagged[i][1] not in self._QUESTION_TAGS:
        words.append(tagged[i][0])
        i += 1
      result.append(' '.join(words))
    return result

  # correctly spelled alias; the original name is kept for existing callers
  extract_questions = exract_questions

  # requests wiki articles on extracted entities
  def make_requests(self, ents):
    """
    Parameters
    -----------
    ents : list of lists of strings
      entities extracted from user's query, one list per question

    Returns
    --------
    result : list of lists of wiki pages
      result[i] holds the pages related to ents[i]. For every entity the
      "_(mathematics)" and "_(geometry)" pages are used when they exist;
      the page named exactly like the entity is used only when neither
      of them exists.
    """
    result = []
    for group in ents:
      pages = []
      for entity in group:
        # search only in realm of maths
        specialised = [
          self.knw_base.page(entity + suffix)
          for suffix in ("_(mathematics)", "_(geometry)")
        ]
        existing = [page for page in specialised if page.exists()]
        if existing:
          pages.extend(existing)
        else:
          pages.append(self.knw_base.page(entity))
      result.append(pages)
    return result

  def find_entity(self, np):
    """
    Parameters
    -----------
    np : string
      noun phrase that may contain entity

    Returns
    --------
    entity : string
      the entity found in the noun phrase (surrounding spaces removed),
      or "_" when no entity could be found
    """
    np = np.strip()
    if not np:
      return "_"

    # scenario 1: page with name 'np' exists on Wikipedia in math section
    if self.knw_base.page(np + "_(mathematics)").exists():
      return np

    # scenario 2: page with name 'np' exists on Wikipedia and contains certain words
    page = self.knw_base.page(np)
    if page.exists():
      text = page.text  # read once
      if any(word in text for word in ("mathematics", "algebra", "calculus")):
        return np
      return "_"

    # scenario 3: noun phrase contains preposition 'of' (e.g. result of cross product)
    # 'of' is matched as a whole word, so "offset" does not count
    padded = " " + np + " "
    if " of " in padded:
      return self.find_entity(padded.split(" of ")[1])

    words = np.split()

    # scenario 4: noun phrase np consists of 3 words (e.g. cross product commutative)
    if len(words) == 3:
      for candidate in (' '.join(words[1:]), ' '.join(words[:2])):
        if self.knw_base.page(candidate).exists():
          return self.find_entity(candidate)
      return "_"

    # scenario 5: noun phrase np consists of 2 words (e.g. lines perpendicular)
    if len(words) == 2:
      for candidate in words:
        if self.knw_base.page(candidate).exists():
          return self.find_entity(candidate)
      return "_"

    return "_"

  def solve_anaphora(self, qts):
    """
    takes questions, if it finds the pronoun 'it' or 'they' in some
    question then it tries to extract entity from the previous question
    (with its own pronouns already replaced) and replaces the pronoun in
    the current one with 'a ' followed by that entity.

    A pronoun is left as it is in the first question and whenever the
    previous question yields no entity.

    returns list of questions, one per given question
    """
    pronouns = ('it', 'they')
    result = []
    for qtn in qts:
      words = qtn.split(" ")
      if result and any(word in pronouns for word in words):
        found = self.extract_entities([result[-1]])
        if found and found[0]:
          replacement = 'a ' + found[0][0]
          words = [replacement if word in pronouns else word for word in words]
      result.append(" ".join(words))
    return result

  def evaluate(self, data):
    """
    Parameters
    -----------
    data : list of [query, expected] pairs
      `expected` is a string that has to occur in the answer

    Returns
    --------
    error rate in percent: the share of pairs for which no answer was
    produced or some produced answer does not contain `expected`
    (0.0 for empty data)
    """
    if not data:
      return 0.0
    errors = 0
    for sample in data:
      query, expected = sample[0], sample[1]
      qtns = self.solve_anaphora(self.exract_questions(query))
      contxts = self.make_requests(self.extract_entities(qtns))
      answers = [self._answer(q, pages) for q, pages in zip(qtns, contxts)]
      # a query that yields no answer at all is a miss as well
      if not answers or any(expected not in answer for answer in answers):
        errors += 1
    return (errors / len(data)) * 100
    
# Start an interactive session only when this file is executed directly
# (a notebook cell also runs as "__main__"); importing the module must not
# build the model or wait for keyboard input.
if __name__ == "__main__":
  bot = Bot()
  bot.take_query()