In [None]:
from collections import Counter, defaultdict
from functools import lru_cache
from multiprocessing import Pool
from string import punctuation
from typing import List, Dict, Tuple, Callable, Iterable

import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import regex as re
from sklearn.model_selection import train_test_split

In [None]:
# Sentence-boundary padding symbols used when extracting n-grams.
BOS, EOS = '<BOS>', '<EOS>'

In [None]:
class TrieNode:
    """One character node of a :class:`Trie`.

    Attributes are plain instance fields:
      * ``char``     -- character on the edge into this node ("" for the root)
      * ``is_end``   -- True if at least one inserted word ends here
      * ``count``    -- number of times a word ending here was inserted
      * ``children`` -- child nodes keyed by their character
    """

    def __init__(self, char: str) -> None:
        self.char: str = char
        self.is_end: bool = False
        self.count: int = 0
        self.children: Dict[str, "TrieNode"] = {}


class Trie:
    """Prefix tree over words that also tracks how often each word was inserted."""

    def __init__(self, words: Iterable[str] = None):
        """Create an empty trie, then insert every word of ``words`` (if given)."""
        self.root = TrieNode("")
        # (word, count) pairs collected by the most recent query()/dfs() call.
        # Initialized here so dfs() also works before any query().
        self.output: List[Tuple[str, int]] = []

        if words:
            for word in words:
                self.insert(word)

    def insert(self, word: str) -> None:
        """Add one occurrence of ``word``; repeated inserts increment its count."""
        node: TrieNode = self.root
        for char in word:
            child = node.children.get(char)
            if child is None:
                child = TrieNode(char)
                node.children[char] = child
            node = child
        node.is_end = True
        node.count += 1

    def dfs(self, node, prefix):
        """Append ``(word, count)`` for every word in ``node``'s subtree to ``self.output``.

        ``prefix`` is the text spelled by the path *above* ``node``; ``node.char``
        is appended to it. Words are emitted in pre-order, with children visited
        in insertion order. The traversal uses an explicit stack instead of
        recursion, so very long words cannot exceed Python's recursion limit.
        """
        stack = [(node, prefix)]
        while stack:
            current, above = stack.pop()
            spelled = above + current.char
            if current.is_end:
                self.output.append((spelled, current.count))
            # Push in reverse so the first-inserted child is popped first,
            # reproducing the recursive pre-order exactly.
            for child in reversed(list(current.children.values())):
                stack.append((child, spelled))

    def query(self, prefix: str) -> List[Tuple[str, int]]:
        """Return all stored words starting with ``prefix`` as ``(word, count)`` pairs.

        Results are sorted by descending count; ties keep traversal order.
        Returns ``[]`` when no stored word has this prefix. An empty prefix
        returns every stored word.
        """
        self.output = []
        node = self.root

        for char in prefix:
            node = node.children.get(char)
            if node is None:
                return []

        # dfs() re-appends node.char, so hand it the prefix minus its last
        # character (for the root, char == "" and prefix[:-1] == "").
        self.dfs(node, prefix[:-1])

        return sorted(self.output, key=lambda item: -item[1])


Run the next cell to download the data if you are using Colab (or do not have `./data` locally).

In [None]:
!mkdir ./data
!wget https://raw.githubusercontent.com/vadim0912/MLIntro2021/main/lecture08/data/train.csv.zip -O ./data/train.csv.zip
!wget https://raw.githubusercontent.com/vadim0912/MLIntro2021/main/lecture08/data/test.csv.zip -O ./data/test.csv.zip

In [None]:
# Load both splits; pandas reads the zipped CSVs directly.
train_df, test_df = (
    pd.read_csv("./data/train.csv.zip"),
    pd.read_csv("./data/test.csv.zip"),
)
train_df.head()

In [None]:
def ngrams_count(token_text, n_ngrams):
    """Count which token follows each context of ``n_ngrams - 1`` tokens.

    Each sentence in ``token_text`` is padded with ``BOS`` on the left and
    ``EOS`` on the right (via ``nltk.ngrams``) before n-grams are extracted.

    Returns:
        defaultdict(Counter): ``counts[context][next_token]`` is the number of
        occurrences, where ``context`` is the n-gram's leading tokens joined
        by single spaces.
    """
    padding = dict(pad_left=True, left_pad_symbol=BOS,
                   pad_right=True, right_pad_symbol=EOS)
    table = defaultdict(Counter)
    for sentence in token_text:
        for *context, target in nltk.ngrams(sentence, n=n_ngrams, **padding):
            table[' '.join(context)][target] += 1
    return table

In [None]:
class WordPredict:
    """Word-level n-gram model holding conditional next-token probabilities.

    After :meth:`compute_count`, ``self.proba[context][token]`` is the relative
    frequency of ``token`` following ``context`` in ``self.tokens``. Here
    ``context`` is a key produced by ``ngrams_count``: the previous
    ``n_ngrams - 1`` tokens joined by spaces.
    """

    def __init__(self, tokens, n_ngrams):
        """
        Args:
            tokens: iterable of tokenized sentences (presumably lists of str).
            n_ngrams: n-gram order, e.g. 2 for a bigram model.
        """
        self.n_ngrams = n_ngrams
        self.tokens = tokens
        # Must exist before compute_count() writes into it. defaultdict so an
        # unseen context yields an empty (falsy) Counter instead of KeyError.
        self.proba = defaultdict(Counter)

    def compute_count(self):
        """Fill ``self.proba`` with normalized n-gram counts from ``self.tokens``."""
        for prev, dist in ngrams_count(self.tokens, self.n_ngrams).items():
            total = sum(dist.values())  # once per context, not once per token
            self.proba[prev] = Counter(
                {token: count / total for token, count in dist.items()})

In [None]:
def remove_bad_sym(text):
    """Replace digits and selected punctuation characters in ``text`` with spaces.

    Each matched character becomes exactly one space, so the length is unchanged.
    """
    # Raw string. The old pattern contained "\p": an invalid Python escape, and
    # the `regex` module reads "\p" not followed by a property name as a
    # literal "p", so every lowercase "p" was erased. If all Unicode
    # punctuation was intended, that would be "\p{P}".
    return re.sub(r'[0-9^,.\-?!–«»"":+]', ' ', text)

def get_token(text):
    """Thin wrapper around ``nltk.wordpunct_tokenize``; returns its token list."""
    token_list = nltk.wordpunct_tokenize(text)
    return token_list

In [None]:
# Strip digits/punctuation from the training sentences (in place), then tokenize.
train_df['sentence'] = train_df['sentence'].apply(remove_bad_sym)
df_sentence = train_df['sentence'].apply(get_token)

In [None]:
# One token list per training sentence.
tokens = df_sentence.tolist()

In [None]:
@lru_cache(maxsize=None)
def _default_word_model() -> "WordPredict":
    """Fit the bigram WordPredict model on the global ``tokens`` once per process.

    Cached so predicting many prefixes does not re-count the whole training
    corpus on every call. NOTE: the cache is not invalidated when ``tokens``
    is reassigned; re-run this cell after changing the training data.
    """
    model = WordPredict(tokens, 2)
    model.compute_count()
    return model


def trie_predict(sentence: str, trie: Trie, model: "WordPredict" = None) -> str:
    """Predict the full word that the last (partial) word of ``sentence`` completes to.

    Strategy:
      1. Among words ``model`` has seen after the preceding context, return the
         most probable one that starts with the partial word.
      2. Otherwise return the most frequent word in ``trie`` with that prefix.
      3. Otherwise return the partial word unchanged.

    Args:
        sentence: text typed so far. Its last whitespace-separated chunk is the
            word to complete; an empty/blank sentence completes ``''``.
        trie: prefix tree of training words.
        model: fitted ``WordPredict``. Defaults to a bigram model built once
            from the global ``tokens``.

    Returns:
        The predicted word.
    """
    if model is None:
        model = _default_word_model()

    words = sentence.split()
    what_pred = words[-1] if words else ''

    # Build the lookup key the way ngrams_count built the training keys: the
    # previous n-1 tokens after remove_bad_sym + get_token, left-padded with
    # BOS at the start of a sentence, joined by spaces. A raw split() would
    # keep punctuation such as "word," that was stripped from training text,
    # and would crash on a single-word prefix.
    n_context = model.n_ngrams - 1
    history = [BOS] * n_context + get_token(remove_bad_sym(' '.join(words[:-1])))
    context = ' '.join(history[len(history) - n_context:])

    dist = model.proba.get(context)  # .get: don't insert keys into a defaultdict
    if dist:
        # Never suggest the end-of-sentence padding symbol as a word.
        matches = [(word, prob) for word, prob in dist.items()
                   if word != EOS and word.startswith(what_pred)]
        if matches:
            # max() keeps the first of equally probable words.
            return max(matches, key=lambda item: item[1])[0]

    completions = trie.query(what_pred)
    return completions[0][0] if completions else what_pred

def pd_func(df) -> pd.DataFrame:
    """Add a ``token`` column holding trie_predict's completion of each ``prefix``.

    Uses the module-level ``trie``. ``df`` is modified in place and returned.
    """
    def _complete(prefix):
        return trie_predict(prefix, trie)

    df['token'] = df['prefix'].apply(_complete)
    return df

def parallelize_dataframe(df: pd.DataFrame, func: Callable, n_cores: int) -> pd.DataFrame:
    """Apply ``func`` to ``n_cores`` row chunks of ``df`` and concatenate the results.

    Rows are split by position into contiguous, near-equal chunks. Each chunk
    is an independent copy, so ``func`` may mutate it. Chunk order and the
    original index labels are preserved in the output.

    Args:
        df: frame to process.
        func: callable taking and returning a DataFrame. With more than one
            core it runs in worker processes and must be picklable; notebook
            functions that rely on globals generally need the "fork" start method.
        n_cores: number of chunks / worker processes (>= 1). ``1`` runs
            ``func`` in the current process without starting a pool.

    Returns:
        ``pd.concat`` of the per-chunk results.
    """
    # Split row positions, not the frame itself: np.array_split(DataFrame)
    # relies on DataFrame.swapaxes, which pandas has deprecated.
    chunks = [df.take(positions)
              for positions in np.array_split(np.arange(len(df)), n_cores)]

    if n_cores == 1:
        results = [func(chunk) for chunk in chunks]
    else:
        with Pool(n_cores) as pool:
            results = pool.map(func, chunks)
    return pd.concat(results)

In [None]:
# Prefix tree over every training token; repeated tokens raise the word's count.
trie = Trie(
    token
    for sentence_tokens in tokens
    for token in sentence_tokens
)

In [None]:
# Predict a token for every test prefix using 4 worker processes.
pred = parallelize_dataframe(df=test_df, func=pd_func, n_cores=4)

In [None]:
# Preview the predictions instead of dumping the whole frame.
pred.head()

In [None]:
# Write the submission: only the `index` and `token` columns, no row index.
pred.to_csv("simple_baseline.csv", columns=['index', 'token'], index=False)