In [249]:
import os
import pickle
import random
from collections import defaultdict
import numpy as np
import scipy.io
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from nltk.corpus import stopwords

In [116]:
# Load the complete 20 Newsgroups corpus (train + test subsets) without
# shuffling, asking scikit-learn to remove each post's headers and footers.
twenty = fetch_20newsgroups(
    remove=('headers', 'footers'),
    shuffle=False,
    subset='all',
)

In [117]:
# Quick sanity check of the corpus size, then display the category names.
print(f'Number of articles: {len(twenty.data)}')
print(f'Number of different categories: {len(twenty.target_names)}')
twenty.target_names

Number of articles: 18846
Number of different categories: 20


['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']

In [119]:
# Bucket the articles by their target index; every bucket entry keeps the
# (group index, article) pair so the label travels with the text.
twenty_grouped = defaultdict(list)

for group_num, article in zip(twenty.target, twenty.data):
    twenty_grouped[group_num].append((group_num, article))

In [120]:
# Split equally by group; returns (group index, data) pair
def tr_va_ts_split(grouped, tr_prop, va_prop, ts_prop):
    """Split every group into train/valid/test parts with the same proportions.

    Args:
        grouped: mapping from group key to a list of items. Groups are
            processed in sorted key order, so keys must be mutually sortable.
        tr_prop: fraction of each group that goes to the training split.
        va_prop: fraction of each group that goes to the validation split.
        ts_prop: fraction for the test split. It is only used to validate that
            the three fractions sum to 1; the test split receives whatever is
            left of each group after the train and valid slices are taken.

    Returns:
        (train, valid, test): three lists of items. Each list is shuffled
        independently with a fixed seed (5), so the result is deterministic.

    Raises:
        AssertionError: if the proportions do not sum to 1 (within 1e-9).
    """
    total = tr_prop + va_prop + ts_prop
    # Tolerant comparison: e.g. 0.7 + 0.2 + 0.1 is 0.9999999999999999 in
    # binary floating point and would fail an exact `== 1.` check.
    assert abs(total - 1.) < 1e-9, f'proportions must sum to 1, got {total}'

    train, valid, test = [], [], []
    # Iterate over the keys that actually exist instead of assuming they are
    # 0..n-1; indexing a defaultdict with a missing key would silently insert
    # an empty group (and a plain dict would raise KeyError).
    for key in sorted(grouped):
        items = grouped[key]
        num_tr = int(tr_prop * len(items))
        num_va = int(va_prop * len(items))
        train.extend(items[:num_tr])
        valid.extend(items[num_tr:num_tr + num_va])
        test.extend(items[num_tr + num_va:])

    # A fresh, identically seeded generator per split keeps each shuffle
    # reproducible and independent of the other splits.
    for split in (train, valid, test):
        random.Random(5).shuffle(split)
    return train, valid, test

In [121]:
# 60% train / 10% validation / 30% test within every newsgroup.
train, valid, test = tr_va_ts_split(
    twenty_grouped,
    tr_prop=0.6,
    va_prop=0.1,
    ts_prop=0.3,
)

In [122]:
# Report how many articles landed in each split (train, valid, test).
for split in (train, valid, test):
    print(len(split))

11301
1876
5669


In [345]:
# Shared vectorizer used by process(). IDF weighting and normalisation are
# both switched off, English stop words are removed, and the token pattern
# only matches runs of two or more ASCII letters.
tf_vect = TfidfVectorizer(
    token_pattern=r"(?u)\b[a-zA-Z][a-zA-Z]+\b",
    norm=None,
    use_idf=False,
    stop_words=stopwords.words('english'),
)

def _rows_as_object_array(values, indptr, min_dtype):
    """Split flat CSR `values` into one array per row, stored in a 1-D object array.

    `min_dtype` is used whenever every value fits in it; otherwise the dtype is
    widened just enough to hold the largest value, so large values are stored
    exactly instead of wrapping around. Building the object array explicitly
    keeps it one-dimensional even when all rows happen to have the same length.
    """
    values = np.asarray(values)
    n_rows = len(indptr) - 1
    if n_rows <= 0:
        # np.split with no cut points would still return one (empty) piece.
        return np.empty(0, dtype=object)

    dtype = np.dtype(min_dtype)
    if values.size:
        dtype = np.promote_types(dtype, np.min_scalar_type(int(values.max())))

    rows = np.split(values.astype(dtype), indptr[1:-1])
    out = np.empty(len(rows), dtype=object)
    for i, row in enumerate(rows):
        out[i] = row
    return out


# drop docs that don't have at least min_cnt words (can only check after tfidf transform)
def split_and_drop(mat, labels, min_cnt=10, drop=True, verbose=True):
    """Convert a CSR matrix into per-document count/token arrays, dropping short docs.

    Args:
        mat: object exposing the CSR attributes `data`, `indices` and `indptr`
            (e.g. a scipy.sparse CSR matrix); one row per document.
        labels: sequence with one label per row of `mat`.
        min_cnt: rows whose stored values sum to less than this are dropped.
        drop: when False nothing is dropped and nothing is printed.
        verbose: print how many documents were dropped.

    Returns:
        (counts, tokens, labels, small_idxs) where `counts` and `tokens` are
        1-D object arrays holding one array per kept row (values are uint8 /
        uint16 unless a larger type is needed), `labels` is a NumPy array of
        the kept labels and `small_idxs` is the list of dropped row indices.
    """
    counts = _rows_as_object_array(mat.data, mat.indptr, np.uint8)
    tokens = _rows_as_object_array(mat.indices, mat.indptr, np.uint16)

    small_idxs = []
    if drop:
        small_idxs = [i for i, row in enumerate(counts) if row.sum() < min_cnt]
        if verbose:
            print(f'Deleted {len(small_idxs)} docs with <{min_cnt} words')

    # Explicit integer index array so an empty selection is always valid.
    drop_idx = np.asarray(small_idxs, dtype=np.intp)
    return (np.delete(counts, drop_idx),
            np.delete(tokens, drop_idx),
            np.delete(labels, drop_idx),
            small_idxs)


def split_and_drop_mult(mats, labels, min_cnt=10, verbose=True):
    """Like split_and_drop, but for several row-aligned matrices at once.

    A row index is dropped from *every* matrix if it is too small in *any* of
    them: the first matrix must have at least `min_cnt` words in the row, the
    remaining matrices only need the row to be non-empty (at least 1 word).

    Args:
        mats: list of CSR-like matrices (see split_and_drop). They are assumed
            to have the same number of rows, with row i describing the same
            document in each - TODO confirm against callers.
        labels: sequence with one label per row.
        min_cnt: minimum word count required in the first matrix.
        verbose: print how many documents were dropped.

    Returns:
        (counts_list, tokens_list, labels, small_idxs): per-matrix lists of the
        object arrays described in split_and_drop, the kept labels as a NumPy
        array, and the sorted list of dropped row indices.
    """
    counts_list, tokens_list = [], []
    small_idxs = set()
    for j, mat in enumerate(mats):
        # Keep `min_cnt` intact so the summary below reports the real threshold.
        threshold = min_cnt if j == 0 else 1
        counts = _rows_as_object_array(mat.data, mat.indptr, np.uint8)
        tokens = _rows_as_object_array(mat.indices, mat.indptr, np.uint16)
        counts_list.append(counts)
        tokens_list.append(tokens)
        small_idxs.update(i for i, row in enumerate(counts) if row.sum() < threshold)

    if verbose:
        reason = f'<{min_cnt} words'
        if len(mats) > 1:
            reason += ' (or <1 word in a later matrix)'
        print(f'Deleted {len(small_idxs)} docs with {reason}')

    small_idxs = sorted(small_idxs)
    drop_idx = np.asarray(small_idxs, dtype=np.intp)
    counts_list = [np.delete(c, drop_idx) for c in counts_list]
    tokens_list = [np.delete(t, drop_idx) for t in tokens_list]
    labels = np.delete(labels, drop_idx)
    return counts_list, tokens_list, labels, small_idxs

def process(train, valid, test):
    """Vectorize the three splits into bag-of-words count/token arrays.

    Uses the module-level `tf_vect` vectorizer: the vocabulary is learned from
    the training documents only (terms must appear in at least 50 training
    documents), and validation/test documents are mapped onto that vocabulary.
    Each test document is additionally vectorized as two halves.

    Args:
        train, valid, test: lists of (label, document) pairs.

    Returns:
        (counts, tokens, labels, vocab) where
        - counts / tokens are 5-element lists ordered as
          [train, valid, test, test first half, test second half];
        - labels is [train labels, valid labels, test labels], filtered to the
          documents that were kept;
        - vocab is the list of vocabulary terms in column order.
    """
    tr_labels, tr_data = [list(t) for t in zip(*train)]
    va_labels, va_data = [list(t) for t in zip(*valid)]
    ts_labels, ts_data = [list(t) for t in zip(*test)]

    # Learn the vocabulary from the training split only.
    tf_vect.set_params(min_df=50, vocabulary=None)
    tr_mat = tf_vect.fit_transform(tr_data).sorted_indices()
    # get_feature_names() was removed in scikit-learn 1.2; prefer the
    # replacement when available and keep `vocab` a plain list of str.
    if hasattr(tf_vect, 'get_feature_names_out'):
        vocab = tf_vect.get_feature_names_out().tolist()
    else:
        vocab = tf_vect.get_feature_names()

    # The fitted vectorizer already holds the pruned training vocabulary, so
    # held-out documents are only transformed, never used for fitting.
    va_mat = tf_vect.transform(va_data).sorted_indices()
    ts_mat = tf_vect.transform(ts_data).sorted_indices()

    tr_counts, tr_tokens, tr_labels, _ = split_and_drop(tr_mat, tr_labels)
    va_counts, va_tokens, va_labels, _ = split_and_drop(va_mat, va_labels)

    # NOTE(review): the halves are cut at the character midpoint of each
    # document, so a word straddling the cut may be tokenized differently
    # than in the full document - confirm this is intended.
    ts_h1_data = [article[: len(article) // 2] for article in ts_data]
    ts_h2_data = [article[len(article) // 2 :] for article in ts_data]
    ts_h1_mat = tf_vect.transform(ts_h1_data).sorted_indices()
    ts_h2_mat = tf_vect.transform(ts_h2_data).sorted_indices()
    ts_counts, ts_tokens, ts_labels, _ = split_and_drop_mult([ts_mat, ts_h1_mat, ts_h2_mat], ts_labels)

    counts = [tr_counts, va_counts] + ts_counts
    tokens = [tr_tokens, va_tokens] + ts_tokens
    return counts, tokens, [tr_labels, va_labels, ts_labels], vocab
    
def save(counts, tokens, labels, vocab, path, prefix):
    """Write the processed dataset to `path`.

    Files written:
        vocab.pkl   - pickled `vocab`
        labels.pkl  - pickled dict with keys 'train', 'valid', 'test'
        {prefix}_{name}_counts.mat / {prefix}_{name}_tokens.mat for each name
        in tr, va, ts, ts_h1, ts_h2 (MATLAB files with a single variable
        'counts' / 'tokens').

    Args:
        counts, tokens: sequences with at least five entries, ordered as
            train, valid, test, test first half, test second half.
        labels: sequence whose first three entries are the train, valid and
            test labels.
        vocab: vocabulary object to pickle.
        path: output directory; created (including parents) if missing.
        prefix: filename prefix for the .mat files.
    """
    # Create the target directory up front so the first open() cannot fail
    # with FileNotFoundError; an empty path means the current directory.
    if path:
        os.makedirs(path, exist_ok=True)

    with open(os.path.join(path, 'vocab.pkl'), 'wb') as f:
        pickle.dump(vocab, f)
    with open(os.path.join(path, 'labels.pkl'), 'wb') as f:
        pickle.dump({'train': labels[0], 'valid': labels[1], 'test': labels[2]}, f)
    for i, name in enumerate(['tr', 'va', 'ts', 'ts_h1', 'ts_h2']):
        scipy.io.savemat(os.path.join(path, f'{prefix}_{name}_counts.mat'), {'counts': counts[i]})
        scipy.io.savemat(os.path.join(path, f'{prefix}_{name}_tokens.mat'), {'tokens': tokens[i]})
    print('Saved!')

In [344]:
# Vectorize all splits, then write the result next to the ETM data.
counts, tokens, labels, vocab = process(train, valid, test)
save(
    counts,
    tokens,
    labels,
    vocab,
    path='./../../../ETM/data/my_20ng',
    prefix='bow',
)

Deleted 544 docs with <10 words
Deleted 91 docs with <10 words
Deleted 290 docs with <1 words


In [348]:
# Number of documents remaining in each split after short ones were dropped.
num_tr, num_va, num_ts = (len(split) for split in counts[:3])
print(f'Num train articles: {num_tr}')
print(f'Num valid articles: {num_va}')
print(f'Num test articles:  {num_ts}')

Num train articles: 10757
Num valid articles: 1785
Num test articles:  5379
