In [1]:
%matplotlib inline
import re
import nltk
import string
from nltk import word_tokenize
from nltk.corpus import stopwords
from nltk.stem.porter import PorterStemmer
from nltk.stem import WordNetLemmatizer
from nltk import pos_tag
import joblib
from joblib import Parallel, delayed
from tqdm.notebook import tqdm as tqdm
from tqdm.notebook import trange
import contextlib
import numpy as np
import pandas as pd
from scipy.special import digamma
from scipy.stats import norm as normal
from pickle import dump, load
from scipy.sparse import csr_matrix
from gensim.corpora.dictionary import Dictionary
from gensim.models import CoherenceModel
from gensim.parsing import strip_tags, strip_numeric, strip_multiple_whitespaces, stem_text, strip_punctuation, remove_stopwords
from gensim.parsing import preprocess_string
import os
import sys
import time






In [2]:
# Register tqdm's `progress_apply` on pandas objects; it is used further down
# when the documents are encoded.
tqdm.pandas()

@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """Temporarily route joblib batch-completion events into ``tqdm_object``.

    While the ``with`` block is active,
    ``joblib.parallel.BatchCompletionCallBack`` is swapped for a subclass that
    advances the progress bar by the size of every finished batch. On exit the
    original callback class is put back and the bar is closed, whether or not
    the body raised.

    Yields:
        The same ``tqdm_object`` that was passed in.
    """
    original_callback = joblib.parallel.BatchCompletionCallBack

    class _ProgressCallback(original_callback):
        # Only __call__ is customised; construction is inherited unchanged.
        def __call__(self, *args, **kwargs):
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    joblib.parallel.BatchCompletionCallBack = _ProgressCallback
    try:
        yield tqdm_object
    finally:
        # Always undo the monkey-patch so later Parallel calls are unaffected.
        joblib.parallel.BatchCompletionCallBack = original_callback
        tqdm_object.close()

In [3]:
# Make sure the NLTK resources are available locally. WordNet is needed by the
# WordNetLemmatizer used during preprocessing.
# NOTE(review): 'stopwords' and 'punkt' do not appear to be used by the code
# visible in this notebook — confirm before removing.
nltk.download('wordnet')
nltk.download('stopwords')
nltk.download('punkt')

[nltk_data] Downloading package wordnet to /home/iron/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to /home/iron/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /home/iron/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [4]:
n = 100  # number of papers drawn for the working subsample

seed = 42
# Seed NumPy's global RNG so that every draw made through `np.random`
# (the pandas subsample and the Dirichlet-process sampling further down)
# is repeatable under Restart Kernel -> Run All.
np.random.seed(seed)
data_train = pd.read_csv("data/papers.csv")["paper_text"]
data_train.head()

0    767\n\nSELF-ORGANIZATION OF ASSOCIATIVE DATABA...
1    683\n\nA MEAN FIELD THEORY OF LAYER IV OF VISU...
2    394\n\nSTORING COVARIANCE BY THE ASSOCIATIVE\n...
3    Bayesian Query Construction for Neural\nNetwor...
4    Neural Network Ensembles, Cross\nValidation, a...
Name: paper_text, dtype: object

In [5]:
# Show the whole corpus; pandas truncates the display to the first and last rows.
data_train

0       767\n\nSELF-ORGANIZATION OF ASSOCIATIVE DATABA...
1       683\n\nA MEAN FIELD THEORY OF LAYER IV OF VISU...
2       394\n\nSTORING COVARIANCE BY THE ASSOCIATIVE\n...
3       Bayesian Query Construction for Neural\nNetwor...
4       Neural Network Ensembles, Cross\nValidation, a...
                              ...                        
7236    Single Transistor Learning Synapses\n\nPaul Ha...
7237    Bias, Variance and the Combination of\nLeast S...
7238    A Real Time Clustering CMOS\nNeural Engine\nT....
7239    Learning direction in global motion: two\nclas...
7240    Correlation and Interpolation Networks for\nRe...
Name: paper_text, Length: 7241, dtype: object

In [6]:
# Draw a reproducible subsample of `n` papers and renumber it 0..n-1.
# `sample` already returns a new Series, so no extra wrapping or copy is
# needed, and `reset_index` returns a fresh object instead of mutating in place.
data = data_train.sample(n, random_state=seed).reset_index(drop=True)
data.head()

0    Learning Monotonic Transformations for\nClassi...
1    Selective Attention for Handwritten\nDigit Rec...
2    Predicting User Activity Level In Point Proces...
3    Learning Deep Parsimonious Representations\n\n...
4    Sketch-Based Linear Value Function Approximati...
Name: paper_text, dtype: object

In [7]:
def _to_lowercase(text):
    return text.lower()


def _blank_short_words(text):
    # Words of one or two characters are replaced by a single space.
    return re.sub(r'\b\w{1,2}\b', ' ', text)


# Ordered filter pipeline handed to gensim's `preprocess_string`.
# NOTE: despite the name, no stemming filter is part of this list.
clean_stem_filters = [
    strip_tags,
    strip_numeric,
    strip_punctuation,
    _to_lowercase,
    _blank_short_words,
    strip_multiple_whitespaces,
    remove_stopwords,
]

def text_processing(document):
    """Clean one raw document and return its lemmatized tokens.

    The text is run through ``clean_stem_filters`` with gensim's
    ``preprocess_string``; every resulting token is then passed through a
    ``WordNetLemmatizer``.

    Args:
        document: raw text of a single document.

    Returns:
        list: lemmatized tokens in their original order.
    """
    tokens = preprocess_string(document, clean_stem_filters)
    lemmatize = WordNetLemmatizer().lemmatize
    return list(map(lemmatize, tokens))

In [8]:
def proc_func(data):
    """Run ``text_processing`` over every document while showing a progress bar.

    Args:
        data: pandas Series of raw document strings.

    Returns:
        pandas.Series named ``'data'`` that shares ``data``'s index and holds
        one token list per document.
    """
    bar = tqdm(desc="Preprocessing", total=len(data))
    with tqdm_joblib(bar):
        jobs = (delayed(text_processing)(document) for document in data)
        token_lists = Parallel(n_jobs=1)(jobs)
    return pd.Series(token_lists, index=data.index, name='data')

In [9]:
# Tokenise and lemmatise the sampled papers, then peek at the first few results.
data_proc = proc_func(data)
data_proc.head()

Preprocessing:   0%|          | 0/100 [00:00<?, ?it/s]

0    [learning, monotonic, transformation, classifi...
1    [selective, attention, handwritten, digit, rec...
2    [predicting, user, activity, level, point, pro...
3    [learning, deep, parsimonious, representation,...
4    [sketch, based, linear, value, function, appro...
Name: data, dtype: object

In [10]:
def encode2(text, word_dict):
    """Translate a token list into an array of vocabulary ids.

    Args:
        text: sequence of tokens for one document.
        word_dict: object exposing ``doc2idx`` (here a gensim ``Dictionary``).

    Returns:
        numpy.ndarray built from whatever ``word_dict.doc2idx(text)`` returns.
    """
    token_ids = word_dict.doc2idx(text)
    return np.asarray(token_ids)

In [11]:
# Build the vocabulary from the preprocessed corpus, then replace every
# document by the array of ids that `encode2` produces for its tokens.
word_dict = Dictionary(data_proc)
data_enc = data_proc.progress_apply(lambda x: encode2(x, word_dict))

  0%|          | 0/100 [00:00<?, ?it/s]

In [12]:
# Corpus-level sizes.
D = len(data_enc)                                # number of documents
V = len(word_dict)                               # vocabulary size
Ns = data_enc.apply(len).to_numpy().astype(int)  # tokens per document
N = Ns.sum()                                     # total number of tokens
Gs = [None] * D                                  # one empty slot per document

In [13]:
# Sanity-check the corpus dimensions computed in the previous cell.
print(D, Ns, N, V, len(Gs))

100 [1604  592 2862 2611 2446 2589 2907 2131 1608 1799 2259 2395 1047 1321
 2196 1310 2696 1282  801 2132 3419 2015 1826 3547 2407 1780 2038 1516
 1193  777 2846 1828 1572 2541 2847 2528 2827 1773 1419 2530  700 2764
  660 1116 2436  808  559 1652 2088 1526 1293 1869 1726 1342 2565 2642
 2086 2649 1453 2138 2054 1521 2602 2845 3020 2746  994 1920 1290 2400
 2016 2622 1118 1811 2468 2720 2158 2841 3351 1255 1848 2768 2135 2384
 1469 1850  936 1339 2267 1547 1985 3425 1183 2510 2061 2477 1228 2846
 2941 2844] 201684 14001 100


In [None]:
# Reference: Teh, Jordan, Beal & Blei, "Hierarchical Dirichlet Processes"
# https://people.eecs.berkeley.edu/~jordan/papers/hdp.pdf

In [119]:
class DP0:
    """Top-level Dirichlet process shared by every document-level process.

    In Chinese-restaurant-franchise terms this object owns the global menu:
    it hands out dishes to newly opened tables and keeps track of how many
    tables serve each dish.
    """

    def __init__(self, a, H):
        """Create an empty menu.

        Args:
            a: concentration parameter; must be strictly positive.
            H: base distribution; only ``H.rvs()`` is used, to draw the
                parameter of a newly created dish.

        Raises:
            ValueError: if ``a`` is not strictly positive.
        """
        if a <= 0:
            raise ValueError("concentration parameter `a` must be positive")
        self.a = a
        self.H = H
        self.K = 0  # number of dishes created so far
        self.phi_k = []  # phi_k[k]: parameter drawn from H for dish k
        self.t_k = []  # t_k[k]: number of tables serving dish k

    def sample(self):
        """Draw the dish for a newly opened table.

        An existing dish ``k`` is picked with probability proportional to
        ``t_k[k]`` and a brand-new dish with probability proportional to
        ``a``. The weights are normalised directly: drawing a Dirichlet vector
        first and then a category from it has the same marginal distribution,
        but fails as soon as any count is zero.

        Returns:
            tuple: ``(phi, k)`` where ``k`` is the dish index (a Python int)
            and ``phi`` is ``phi_k[k]``.
        """
        weights = np.zeros(self.K + 1, dtype=float)
        # Dishes that have no recorded count yet keep a weight of zero.
        tracked = min(self.K, len(self.t_k))
        weights[:tracked] = self.t_k[:tracked]
        weights[self.K] = self.a  # last slot stands for "create a new dish"
        probs = weights / weights.sum()
        k = int(np.random.choice(self.K + 1, p=probs))
        if k == self.K:
            self.K += 1
            # Use the base distribution given to this instance, not a global.
            self.phi_k.append(self.H.rvs())

        return self.phi_k[k], k

    def increase_dish_count(self, dish):
        """Record one more table serving ``dish``, growing ``t_k`` if needed."""
        missing = dish + 1 - len(self.t_k)
        if missing > 0:
            self.t_k.extend([0] * missing)
        self.t_k[dish] += 1

    def decrease_dish_count(self, dish):
        """Record one table fewer serving ``dish``.

        Unknown dishes are ignored, and a count never drops below zero so the
        sampling weights stay valid.
        """
        if 0 <= dish < len(self.t_k) and self.t_k[dish] > 0:
            self.t_k[dish] -= 1

class DP1:
    """Document-level Dirichlet process (one "restaurant" per document).

    Every word is a customer that sits at a table; every table serves exactly
    one dish obtained from the shared top-level process ``DP``.
    """

    def __init__(self, a, DP, words):
        """Seat all words of one document.

        Args:
            a: concentration parameter; must be strictly positive.
            DP: shared top-level process; must provide ``sample()`` returning
                ``(phi, k)`` and ``increase_dish_count(k)``.
            words: sequence of the document's words (only its length and
                order are used here).

        Raises:
            ValueError: if ``a`` is not strictly positive.
        """
        if a <= 0:
            raise ValueError("concentration parameter `a` must be positive")
        self.a = a
        self.G = DP
        self.t = 0  # number of tables opened in this document
        self.psi_t = []  # psi_t[t]: dish parameter served at table t
        self.m = []  # m[k]: number of tables in this document serving dish k
        self.n_tk = []  # n_tk[t][k]: customers at table t eating dish k
        self.k_t = []  # k_t[t]: index of the dish served at table t
        self.theta_i = [None] * len(words)  # theta_i[i]: table index of word i
        self.words = words

        self._init_sample()

    def _init_sample(self):
        """Assign every word to a table, opening tables as required.

        Word ``i`` joins existing table ``t`` with probability proportional to
        the number of customers already seated there, or opens a new table
        with probability proportional to ``a``. A new table asks the top-level
        process for its dish; a word joining an existing table inherits that
        table's dish.
        """
        # Running customer count per table, kept in step with n_tk so the
        # weights do not have to be re-summed for every word.
        table_sizes = [sum(row) for row in self.n_tk]
        for i in range(len(self.words)):
            weights = np.asarray(table_sizes + [self.a], dtype=float)
            probs = weights / weights.sum()
            t = int(np.random.choice(self.t + 1, p=probs))
            if t == self.t:
                # New table: draw its dish from the shared top-level process.
                phi_k, k = self.G.sample()
                self.t += 1
                self.psi_t.append(phi_k)
                self.k_t.append(k)
                self.n_tk.append([])
                table_sizes.append(0)
                if len(self.m) <= k:
                    # Dishes created by other documents may leave gaps here.
                    self.m.extend([0] * (k + 1 - len(self.m)))
                self.m[k] += 1
                # The top-level count tracks tables, so bump it once per new
                # table rather than once per word.
                self.G.increase_dish_count(k)
            else:
                # Existing table: reuse the dish it already serves.
                k = self.k_t[t]

            row = self.n_tk[t]
            if len(row) <= k:
                row.extend([0] * (k + 1 - len(row)))
            row[k] += 1
            table_sizes[t] += 1
            self.theta_i[i] = t

In [131]:
# Base distribution that dish parameters are drawn from: a normal with
# loc 0 and scale 1/2 (scipy's `norm` takes loc and scale).
H = normal(0, 1/2)
a = 5  # concentration parameter handed to the top-level DP0
a, H

(5, <scipy.stats._distn_infrastructure.rv_frozen at 0x7ff7c7beeee0>)

In [132]:
# Shared top-level process, plus the list that will hold one DP1 per document.
dp = DP0(a, H)
gjs = []

In [None]:
# Build one document-level process per encoded document, all sharing `dp`.
# The list is rebuilt from scratch so that re-running this cell cannot leave
# duplicate entries behind.
# NOTE(review): `dp` itself still accumulates dishes and counts across
# re-runs — re-run the cell that creates `dp` first for a clean state.
# NOTE(review): the literal 5 equals `a`; confirm whether the document-level
# concentration is meant to be a separate hyperparameter.
gjs = [DP1(5, dp, list(d)) for d in data_enc]

In [None]:
# Number of dishes the shared top-level process has created.
dp.K

In [None]:
# Number of tables opened for the first document.
gjs[0].t