In [39]:
import os
import collections
from collections import defaultdict
import numpy as np
import nltk
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
from nltk.tokenize import sent_tokenize, word_tokenize
from tqdm import tqdm

ps = PorterStemmer()
nltk.download()

showing info https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/index.xml


True

In [61]:
class InvertedIndex:
    def __init__(self, directory, stopwords, save=True, name='ii'):
        self.directory = directory
        self.stopwords = stopwords
        self.save = save
        self.name = name
        self.id_to_file = {}
        self.index = defaultdict(lambda: {'count': [], 'words': [], 'rotations': [], 'postings': []})
        self.tgi = defaultdict(lambda: set())
        self.construct()
        self.construct_tgi()
        
    def produce_rotations(self, word):
        term = "$" + word
        res = [term]
        for i in range(len(word) - 1):
            term = term[-1] + term[:-1]
            res.append(term)
        return res
    
    def construct(self):
        for i, filename in tqdm(enumerate(os.listdir(self.directory))):
            self.id_to_file[i] = filename
            with open(os.path.join(self.directory, filename), 'rt') as original:
                sents = sent_tokenize(original.read())
                for s in sents:
                    for w in word_tokenize(s):
                        w = w.lower()
                        stemmed = ps.stem(w)
                        if stemmed not in self.stopwords:
                            if i not in self.index[stemmed]['postings']:
                                self.index[stemmed]['postings'].append(i)
                            if w not in self.index[stemmed]['words']:
                                self.index[stemmed]['words'].append(w)
        
        for t in self.index.keys():
            self.index[t]['count'] = len(self.index[t]['postings'])

            for w in self.index[t]['words']:
                self.index[t]['rotations'] += self.produce_rotations(w)
            
        if self.save:
            np.save(self.name, np.array(dict(self.index)))

    def construct_tgi(self):
        for i in self.index.keys():
            for j in self.index[i]['words']:
                for k in range(len(j) - 1):
                    self.tgi[j[k:k+2]].add(i)

In [62]:
ii = InvertedIndex(directory = 'Datasets/Shakespeare', stopwords = stopwords.words('english'))

42it [00:30,  1.39it/s]


In [64]:
ii.index

defaultdict(<function __main__.InvertedIndex.__init__.<locals>.<lambda>()>,
            {'midsumm': {'count': 4,
              'words': ['midsummer'],
              'rotations': ['$midsummer',
               'r$midsumme',
               'er$midsumm',
               'mer$midsum',
               'mmer$midsu',
               'ummer$mids',
               'summer$mid',
               'dsummer$mi',
               'idsummer$m'],
              'postings': [0, 3, 7, 40]},
             'night': {'count': 41,
              'words': ['night', 'nights', 'nighted'],
              'rotations': ['$night',
               't$nigh',
               'ht$nig',
               'ght$ni',
               'ight$n',
               '$nights',
               's$night',
               'ts$nigh',
               'hts$nig',
               'ghts$ni',
               'ights$n',
               '$nighted',
               'd$nighte',
               'ed$night',
               'ted$nigh',
               'hted$nig',
            

In [96]:
class QueryHandler:
    def __init__(self):
        self.symbols = {}
    
    def rotate(self, wildcard):
        term = '$' + wildcard
        for i, l in enumerate(term, 1):
            if l == "*":
                return term[i:] + term[:i-1], True
        else:
            return wildcard, False
        
    def union(self, p1, p2):
        i = j = 0
        res = []

        while i < len(p1) and j < len(p2):
            if p1[i] == p2[j]:
                res.append(p1[i])
                i += 1
                j += 1
            elif p1[i] < p2[j]:
                res.append(p1[i])
                i += 1
            elif p1[i] > p2[j]:
                res.append(p2[j])
                j += 1
        if i < len(p1):
            res += p1[i:]
        else:
            res += p2[j:]

        return res

    def inverse(self, p1, total):
        return [i for i in total if i not in p1]
    
    def intersection(self, p1, p2):
        i = j = 0
        res = []

        while i < len(p1) and j < len(p2):
            if p1[i] == p2[j]:
                res.append(p1[i])
                i += 1
                j += 1
            elif p1[i] < p2[j]:
                i += 1
            elif p1[i] > p2[j]:
                j += 1

        return res
    
    def and_not(self, p1, p2):
        i = j = 0
        res = []

        while i < len(p1) and j < len(p2):
            if p1[i] == p2[j]:
                i += 1
                j += 1
            elif p1[i] < p2[j]:
                res.append(p1[i])
                i += 1
            elif p1[i] > p2[j]:
                j += 1
        if i < len(p1):
            res += p1[i:]

        return res
    
    def or_not(self, p1, p2, total):    
        return self.union(p1, self.inverse(p2, total))
    
    def levenshtein_distance(self, word1, word2):
        l = max(len(word1), len(word2))
        m = np.zeros((l, l))
        for i in range(len(word1)):
            m[i, 0] = i

        for j in range(len(word2)):
            m[0, j] = j

        for i in range(1, len(word1)):
            for j in range(1, len(word2)):
                if word1[i] == word2[j]:
                    m[i, j] = min(m[i-1, j] + 1, min(m[i, j-1] + 1, m[i-1, j-1]))
                else:
                    m[i, j] = min(m[i-1, j] + 1, min(m[i, j-1] + 1, m[i-1, j-1] + 1))
        return m[len(word1)-1, len(word2)-1]
    
    def spell_correct(self, misspelled, ii):
        res = []
        for i in range(len(misspelled) - 1):
            res += ii.tgi[misspelled[i:i+2]]
        
        freqs = dict(collections.Counter(res))
        freqs = {k: v for k, v in reversed(sorted(freqs.items(), key=lambda item: item[1]))}
        
        ff = defaultdict(lambda: [])
        
        for k, v in freqs.items():
            ff[v].append(k)
            
        ed = defaultdict(lambda: [])
        
        for f in list(ff.keys())[:3]:
            for w in ff[f]:
                ed[self.levenshtein_distance(misspelled, w)].append(w)

        if not ed:
            return ""
        return max([(ii.index[x]['count'], x) for x in ed[min(list(ed.keys()))]])[1]
        
    
    def match(self, term, ii):
        if term[0] == '@':
            return self.symbols[term]
        res = []
        rotated, is_wild = self.rotate(term)
        if is_wild:
            for w in ii.index.keys():
                if len(w) >= len(term)-1:
                    for r in ii.index[w]['rotations']:
                        if r[:len(rotated)] == rotated:
                            res = self.union(res, ii.index[w]['postings'])
                            break
        else:
            rotated = ps.stem(rotated)
            for w in ii.index.keys():
                if w == rotated:
                    res = self.union(res, ii.index[w]['postings'])
                    break
                    
        if not is_wild and not res:
            corrected = self.spell_correct(term, ii)
            print(term + " is corrected to " + corrected)
            if corrected:
                return self.match(corrected, ii)
        
        return res
    
    def evaluate_expr(self, expr, i, ii, total):
        print("evaluating " + expr + " and storing as @" + str(i))
        # var or not var
        # var or var
        # var and not var
        # var and var
        # var
        # not var

        keywords = ["and", "or", "not"]
        expr = expr.split(" ")
        new_symbol = '@' + str(i)

        if expr[0] == "not":
            self.symbols[new_symbol] = self.inverse(self.match(expr[1], ii), total)
            return new_symbol

        else:
            if len(expr) == 1:
                self.symbols[new_symbol] = self.match(expr[0], ii)
                return new_symbol

            if expr[1] == 'and':
                if expr[2] == 'not':
                    self.symbols[new_symbol] = self.and_not(self.match(expr[0], ii), self.match(expr[3], ii))
                    return new_symbol

                else:
                    self.symbols[new_symbol] = self.intersection(self.match(expr[0], ii), self.match(expr[2], ii))
                    return new_symbol

            else:
                if expr[2] == 'not':
                    self.symbols[new_symbol] = self.or_not(self.match(expr[0], ii), self.match(expr[3], ii), total)
                    return new_symbol

                else:
                    self.symbols[new_symbol] = self.union(self.match(expr[0], ii), self.match(expr[2], ii))
                    return new_symbol
            
    def compute(self, query, ii, total):
        stack = []
        self.symbols = {}
        i = 0
        for c in query:
            if c != ')':
                stack.append(c)
            else:
                expr = ""
                while stack:
                    char = stack.pop()
                    if char != '(':
                        expr += char
                    else:
                        stack += list(self.evaluate_expr(expr[::-1], i, ii, total))
                        i += 1
                        break
        if stack:
            self.evaluate_expr("".join(stack), i, ii, total)
            i += 1
        return self.symbols['@' + str(i - 1)]

In [97]:
query = "whore"
qh = QueryHandler()
qh.compute(query, ii, list(ii.id_to_file.keys()))

evaluating whore and storing as @0


[2, 5, 6, 8, 9, 13, 16, 19, 20, 22, 26, 30, 33, 35, 37, 38, 39]

In [98]:
qh.match("whoring", ii)

[2, 5, 6, 8, 9, 13, 16, 19, 20, 22, 26, 30, 33, 35, 37, 38, 39]

In [93]:
ii.index["whore"]

{'count': 17,
 'words': ['whores', 'whore', 'whored', 'whoring'],
 'rotations': ['$whores',
  's$whore',
  'es$whor',
  'res$who',
  'ores$wh',
  'hores$w',
  '$whore',
  'e$whor',
  're$who',
  'ore$wh',
  'hore$w',
  '$whored',
  'd$whore',
  'ed$whor',
  'red$who',
  'ored$wh',
  'hored$w',
  '$whoring',
  'g$whorin',
  'ng$whori',
  'ing$whor',
  'ring$who',
  'oring$wh',
  'horing$w'],
 'postings': [2, 5, 6, 8, 9, 13, 16, 19, 20, 22, 26, 30, 33, 35, 37, 38, 39]}

In [36]:
qh.symbols

{'@0': [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41],
 '@1': [0, 3, 7, 40],
 '@2': [0, 3, 7, 40]}