In [1]:
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import transformers
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from transformers import AutoModel, BertTokenizerFast

In [2]:
import matplotlib as mpl

In [3]:
from glob import glob

In [4]:
# Download blogs.zip (extracted below into blogs/*.xml) from https://drive.google.com/u/0/uc?id=1cGy4RNDV87ZHEXbiozABr9gsSrZpPaPz&export=download

In [5]:
import matplotlib.pyplot as plt

In [6]:
import seaborn as sns

In [None]:
!ls -l

In [None]:
!unzip -q blogs.zip

In [None]:
blogs = glob('blogs/*.xml')

In [None]:
blogs[0]

In [None]:
from tqdm.auto import tqdm

In [None]:
import re

In [None]:
def _read_blog(path):
    """Return one blog XML file decoded as cp1252, silently dropping undecodable bytes."""
    with open(path, 'rb') as fh:  # `with` closes the handle instead of leaking one per file
        return fh.read().decode('cp1252', errors='ignore')


blog_texts = [_read_blog(b) for b in tqdm(blogs)]
# Collapse each file onto one line, then start a new line after every </post>. Each line
# then holds at most one </post> (at its end), and since `.` cannot cross the '\n', the
# greedy `.*` below stops at that post's own closing tag.
blog_texts = [re.sub(r'[\r\n]', ' ', b) for b in tqdm(blog_texts)]
blog_texts = [re.sub(r'</post>', '</post>\r\n', b) for b in tqdm(blog_texts)]
# One list of post bodies per file.
blog_texts = [re.findall(r'<post>(.*)</post>', b) for b in tqdm(blog_texts)]

In [None]:
# Flatten the per-file lists of posts into one list holding one string per post.
# NOTE: not idempotent — running this cell a second time would iterate over the post
# strings themselves and split every post into single characters; re-run the loading cell first.
blog_texts = [p for b in tqdm(blog_texts) for p in b]

In [None]:
# Total number of posts extracted from all files.
len(blog_texts)

In [None]:
# Lower-case every post and delete every character that is neither a word character
# nor whitespace (i.e. strip punctuation and symbols).
post_text = [re.sub(r'[^\w\s]', '', p.lower()) for p in blog_texts]
post_text[0]

In [None]:
# Tokenise on whitespace. str.split() with no separator already discards leading/trailing
# whitespace and never produces empty strings, so no extra strip/empty-filter passes are needed.
post_words = [p.split() for p in post_text]
post_words[0]

In [None]:
from sklearn.feature_extraction.text import CountVectorizer

In [None]:
len(post_words)

In [None]:
sum([len(p) for p in post_words])

In [None]:
cleaned_texts = [' '.join(p) for p in post_words]

In [None]:
cv = CountVectorizer(min_df=1500)
cv.fit(cleaned_texts)
len(cv.vocabulary_)

In [None]:
cv.vocabulary_

In [None]:
# Map every word of every post to its CountVectorizer column index; words outside the
# vocabulary become -1.
vectorised_text = [
    [cv.vocabulary_.get(word, -1) for word in words]
    for words in tqdm(post_words)
]

In [None]:
import numpy as np

In [None]:
vectorised_text = [np.array(p) for p in vectorised_text]

In [None]:
vectorised_text[:5]

In [None]:
plt.hist([(p == -1).sum() / len(p) for p in vectorised_text if len(p) > 0], bins=20);

In [None]:
vectorised_text = [p[p >= 0] for p in vectorised_text]
vectorised_text = [p for p in vectorised_text if len(p > 0)]

In [None]:
# Post lengths in (in-vocabulary) tokens; everything above 1000 is lumped into the last bin.
plt.hist(np.clip([len(p) for p in vectorised_text], None, 1000), bins=20);

In [None]:
# Keep only posts with at least 100 in-vocabulary tokens.
vectorised_text = [p for p in vectorised_text if len(p) >= 100]

In [None]:
# Same length histogram after the >= 100-token filter.
plt.hist(np.clip([len(p) for p in vectorised_text], None, 1000), bins=20);

In [None]:
# For each post: how many times its single most frequent token occurs (clipped at 50).
# Presumably a check for highly repetitive posts.
plt.figure(figsize=(12, 5))
plt.hist(np.clip([
    np.unique(p, return_counts=True)[1].max()
    for p in vectorised_text], 0, 50), bins=50);

In [None]:
# Inverse vocabulary (column index -> word), used to turn id sequences back into text.
# The "reverce" spelling is kept because the decoding cells further down use this name.
cv_reverce_ix = dict(zip(cv.vocabulary_.values(), cv.vocabulary_.keys()))

# KL divergence between word distributions of text samples

In [None]:
# Length bounds (in tokens) for the random post samples drawn below.
min_len, max_len = 25, 150

# Seed NumPy's global RNG so the random split points, slices and shuffles below
# (all drawn through np.random) are reproducible across Restart & Run All.
SEED = 42
np.random.seed(SEED)

In [None]:
# Cut every post into two disjoint parts (seq_a = prefix, seq_b = suffix). Posts longer than
# 2 * min_len are cut at a random point leaving at least min_len tokens on each side;
# shorter posts are simply halved.
seq_len_orig = np.array([len(p) for p in vectorised_text])
disjoint_split_point = np.where(
    seq_len_orig > 2 * min_len,
    min_len + (np.random.rand(len(seq_len_orig)) * (seq_len_orig - 2 * min_len)).astype(int),
    (seq_len_orig / 2).astype(int),
)
seq_a = [post[:cut] for post, cut in zip(vectorised_text, disjoint_split_point)]
seq_b = [post[cut:] for post, cut in zip(vectorised_text, disjoint_split_point)]

In [None]:
def get_slice(seq, min_length=None, max_length=None):
    """Cut one random contiguous window out of every sequence in ``seq``.

    For each sequence ``p`` a window length is drawn uniformly from
    ``min_length .. min(len(p), max_length)`` (inclusive; just ``min_length`` when that
    range is empty). A start offset is then drawn uniformly from every position at which
    a window of that length still fits. Sequences shorter than the window are returned whole.

    Args:
        seq: list of sized, sliceable sequences (in this notebook: NumPy arrays of token ids).
        min_length: shortest window; defaults to the notebook-level ``min_len``.
        max_length: longest window; defaults to the notebook-level ``max_len``.

    Returns:
        List with one slice ``p[start:start + length]`` per input sequence.
    """
    lo = min_len if min_length is None else min_length
    hi = max_len if max_length is None else max_length

    seq_len_orig = np.array([len(p) for p in seq])
    s_max_sample_len = np.clip(seq_len_orig.clip(0, hi) - lo, 0, None)
    # floor(k * U) with U in [0, 1) only reaches k - 1; the ``+ 1`` makes the longest
    # admissible window (min(len, hi) tokens) reachable.
    s_len = ((s_max_sample_len + 1) * np.random.rand(len(seq_len_orig))).astype(int) + lo
    s_avl_pos = np.clip(seq_len_orig - s_len, 0, None)
    # Same here: without ``+ 1`` a window could never end on the sequence's last token.
    s_start = ((s_avl_pos + 1) * np.random.rand(len(seq_len_orig))).astype(int)

    return [p[s: s + n] for p, s, n in zip(seq, s_start, s_len)]

In [None]:
# Decode the first post's two parts back into words as a sanity check of the split.
' '.join([cv_reverce_ix.get(i) for i in seq_a[0]]), ' '.join([cv_reverce_ix.get(i) for i in seq_b[0]])

In [None]:
# Replace each part by a random sub-window whose length is governed by min_len/max_len (see get_slice).
# NOTE: not idempotent — re-running this cell slices the already-sliced windows again;
# re-run the split cell first.
seq_a = get_slice(seq_a)
seq_b = get_slice(seq_b)

In [None]:
# Same decoded view of the first pair after slicing.
' '.join([cv_reverce_ix.get(i) for i in seq_a[0]]), ' '.join([cv_reverce_ix.get(i) for i in seq_b[0]])

In [None]:
def _kl(a, b, verbose=False):
    a = {k: v / len(a) for k, v in zip(*np.unique(a, return_counts=True))}
    b = {k: v / len(b) for k, v in zip(*np.unique(b, return_counts=True))}

    d = {k: 0.0 if a.get(k, 0.0) == 0.0 else a.get(k, 0.0) * np.log(a.get(k, 1e-12) / b.get(k, 1e-12))
         for k in set(a.keys()).union(b.keys())}
    if verbose:
        print(sorted([(k, v) for k, v in d.items()], key=lambda x: -x[1]))
        
    return sum(d.values())


def kl_distribution(l_a, l_b):
    """Apply ``_kl`` pairwise to two parallel collections of token sequences.

    Args:
        l_a, l_b: iterables of sequences; item i of ``l_a`` is compared with item i
            of ``l_b`` (``zip`` stops at the shorter of the two).

    Returns:
        1-D NumPy array with one divergence per pair.
    """
    divergences = [_kl(seq_x, seq_y) for seq_x, seq_y in tqdm(zip(l_a, l_b))]
    return np.array(divergences)

In [None]:
kl_pos = kl_distribution(seq_a, seq_b)

In [None]:
ix_shuffle = np.random.choice(len(seq_a), len(seq_a), replace=False)
kl_neg = kl_distribution(seq_a, [seq_b[i] for i in ix_shuffle])

In [None]:
# Render all text in subsequent figures through LaTeX, using a serif font.
# NOTE: text.usetex=True requires a working LaTeX installation; text rendering raises an error without one.
plt.rc('text', usetex=True)
plt.rc('font', family='serif')

In [None]:
df = pd.concat([
        pd.DataFrame({'KL': kl_pos.clip(0, None),
                      'type': 'Same post sample'}),
        pd.DataFrame({'KL': kl_neg.clip(0, None), 
                      'type': 'Random post sample'}),
    ], axis=0).reset_index(drop=True)

In [None]:
len(df)

In [None]:
df.sample(frac=0.1).to_parquet('figures/kl_dis_text.parquet')

In [None]:
!ls -lh figures/kl_dis_text.parquet

In [None]:
# Overlaid step histograms of the saved KL sample (same-post vs. random-post pairs),
# written out as a vector PDF. rc_context() restores the global rcParams on exit, so the
# large figure size and font apply to this figure only.
with mpl.rc_context({'figure.figsize': (10, 10), 'font.size': 20}):
    fig, ax = plt.subplots()
    sns.histplot(pd.read_parquet('figures/kl_dis_text.parquet'), x="KL", hue="type",
                 bins=50, element='step', ax=ax)
    fig.savefig('figures/kl_dis_text.pdf', format='pdf', bbox_inches='tight')