In [38]:
import os
import sys
import math
import random
import pandas as pd
import regex as re

from nltk.corpus import reuters
from pprint import pprint
from sklearn.datasets import fetch_20newsgroups
from sklearn.datasets import fetch_rcv1
from sklearn.utils import shuffle
from cleantext import clean

In [39]:
# Constants
twentynews_dir = '../data/20news'

In [89]:
def load_from_path(df_path, rand=False, rand_seed=4079):
    df = pd.read_csv(df_path)
    if rand:
        df = shuffle(df, random_state=rand_seed)
    return df

In [40]:
def cust_clean(text):
    text = clean(
        text,
        lower=False,
        no_line_breaks=True,           # fully strip line breaks as opposed to only normalizing them
        no_urls=True,                  # replace all URLs with a special token
        no_emails=True,                # replace all email addresses with a special token
        no_punct=True,                 # fully remove punctuation
        )
    return re.sub(r'[#$%^&*)(|/><"\\}{]', ' ', text)
# end def

In [41]:
def save_to_path(train_df, train_path, test_df, test_path, compress=3):
    train_df.to_csv(train_path, index=False)
    test_df.to_csv(test_path, index=False)
# end def

In [87]:
# download 20news from online
def save_twentynews(save=True, default_split=True, train_test_ratio=0.7, rand_seed=4079):
    def _process_doc(doc):
        lines = doc.split('\n')
        keep = list()

        for i, l in enumerate(lines):
            if l.startswith('Subject:'):
                keep.append(l[9:])
            if l.startswith('Lines:'):
                break
        # end for

        keep += lines[i+1:]

        return ' '.join(keep)
    # end def

    train_list = list()
    test_list = list()

    if default_split:
        train_newsgroups = fetch_20newsgroups(subset='train', remove=('header', 'footers', 'quotes'))  # , remove=('headers', 'footers', 'quotes'))
        test_newsgroups = fetch_20newsgroups(subset='test', remove=('header', 'footers', 'quotes'))  # , remove=('headers', 'footers', 'quotes'))

        train_file_names = train_newsgroups.filenames
        train_target_names = train_newsgroups.target_names
        train_targets = train_newsgroups.target
        train_docs = train_newsgroups.data

        test_file_names = test_newsgroups.filenames
        test_target_names = test_newsgroups.target_names
        test_targets = test_newsgroups.target
        test_docs = test_newsgroups.data

        train_list += [dict(id=train_file_names[i].split('/')[-1], doc=_process_doc(_doc), cat=train_target_names[train_targets[i]]) for i, _doc in enumerate(train_docs)]
        pprint(x for x in train_list if x['id'] == 'C:\\Users\\User\\scikit_learn_data\\20news_home\\20news-bydate-train\\talk.politics.mideast\\75889')
        test_list += [dict(id=test_file_names[i].split('/')[-1], doc=_process_doc(_doc), cat=test_target_names[test_targets[i]]) for i, _doc in enumerate(test_docs)]
    else:
        all_newsgroups = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'), shuffle=True, random_state=rand_seed)
        all_file_names = all_newsgroups.filenames
        all_file_names = [fname.split('/')[-1] for fname in all_file_names]
        all_target_names = all_newsgroups.target_names
        all_targets = all_newsgroups.target
        all_docs = all_newsgroups.data

        idx = math.floor(train_test_ratio * len(all_docs))

        train_file_names = all_file_names[:idx]
        train_targets = all_targets[:idx]
        train_docs = all_docs[:idx]

        test_file_names = all_file_names[idx:]
        test_targets = all_targets[idx:]
        test_docs = all_docs[idx:]
        del all_newsgroups, all_file_names, all_targets, all_docs

        train_list += [dict(id=train_file_names[i], doc=_doc.replace('\n', ' ').replace('\t', ' '), cat=all_target_names[train_targets[i]]) for i, _doc in enumerate(train_docs)]
        test_list += [dict(id=test_file_names[i], doc=_doc.replace('\n', ' ').replace('\t', ' '), cat=all_target_names[test_targets[i]]) for i, _doc in enumerate(test_docs)]
    # end if

    train_df = pd.DataFrame(train_list).sample(frac=1).reset_index(drop=True)
    pprint(train_df.iloc[2324]['id'])
    train_df['doc'] = train_df['doc'].apply(lambda x: cust_clean(x))
    pprint(train_df.iloc[2324])
    test_df = pd.DataFrame(test_list).sample(frac=1).reset_index(drop=True)
    test_df['doc'] = test_df['doc'].apply(lambda x: cust_clean(x))

    pprint("Number of train: {}".format(train_df.shape[0]))
    pprint("Number of test: {}".format(test_df.shape[0]))
    if save:
        train_path = os.path.join(twentynews_dir, 'train.csv')
        test_path = os.path.join(twentynews_dir, 'test.csv')
        save_to_path(train_df, train_path, test_df, test_path)
# end def

In [88]:
save_twentynews(save=True)

<generator object save_twentynews.<locals>.<genexpr> at 0x0000020FBD6C6350>
'C:\\Users\\User\\scikit_learn_data\\20news_home\\20news-bydate-train\\talk.politics.guns\\54719'
id     C:\Users\User\scikit_learn_data\20news_home\20...
doc    Re NRA address Distribution world NNTPPostingH...
cat                                   talk.politics.guns
Name: 2324, dtype: object
'Number of train: 11314'
'Number of test: 7532'


