In [None]:
%matplotlib inline
import os
import pickle

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.utils import shuffle

################################################################################################

def create_dataframe(doc_tf, doc_targets):
    """Build a DataFrame holding one row per document.

    Args:
        doc_tf: iterable of per-document bag-of-words rows (in this notebook,
            the rows of a sparse document-term matrix).
        doc_targets: indexable sequence of labels, aligned position-by-position
            with ``doc_tf``.

    Returns:
        pandas.DataFrame with columns ``doc_id`` (0-based position in
        ``doc_tf``), ``bow`` (the row object itself) and ``label``.
    """
    records = [
        {'doc_id': position, 'bow': row, 'label': doc_targets[position]}
        for position, row in enumerate(doc_tf)
    ]
    return pd.DataFrame(records)

# Load both official splits with headers, signature-like footers and quoted
# replies stripped, so only the body text of each post remains.
train = fetch_20newsgroups(remove=('headers', 'footers', 'quotes'),
                           subset='train')
test = fetch_20newsgroups(remove=('headers', 'footers', 'quotes'),
                          subset='test')

################################################################################################
# Fit the vocabulary on the training split only, then map the test split onto
# that same vocabulary. English stop words are dropped. A term is kept only if
# it occurs in at least 3 documents and in at most 80% of them, and the
# vocabulary is capped at 10,000 terms.
count_vect = CountVectorizer(
    stop_words='english',
    max_df=0.8,
    min_df=3,
    max_features=10000,
)
train_tf = count_vect.fit_transform(train.data)
test_tf = count_vect.transform(test.data)

# One row per document: positional doc_id, sparse BoW row, newsgroup label.
train_df = create_dataframe(train_tf, train.target)
test_df = create_dataframe(test_tf, test.target)

def get_doc_length(doc_bow):
    """Return the total term count of a bag-of-words row (the sum of its entries)."""
    total_terms = doc_bow.sum()
    return total_terms

# Remove empty documents: rows with no in-vocabulary terms after header/footer/
# quote stripping and vocabulary pruning. They carry no signal for a BoW model.
train_df = train_df[train_df.bow.apply(get_doc_length) > 0]
test_df = test_df[test_df.bow.apply(get_doc_length) > 0]

# Split the official test set in half: one half becomes the cv set, the other
# stays as the held-out test set. When the count is odd, cv gets the extra row.
num_train = len(train_df)
num_test = len(test_df) // 2
num_cv = len(test_df) - num_test

# A fixed random_state makes the split reproducible across kernel restarts, so
# the pickles saved afterwards always describe the same partition.
SPLIT_SEED = 0
test_df = shuffle(test_df, random_state=SPLIT_SEED)
cv_df = test_df.iloc[:num_cv]
test_df = test_df.iloc[num_cv:]

# Use doc_id as the index. Reassigning the result of set_index (instead of
# inplace=True on the filtered/sliced frames above) yields fresh frames and
# avoids pandas' SettingWithCopyWarning.
train_df = train_df.set_index('doc_id')
test_df = test_df.set_index('doc_id')
cv_df = cv_df.set_index('doc_id')

assert len(cv_df) == num_cv and len(test_df) == num_test
assert train_df.index.is_unique and test_df.index.is_unique and cv_df.index.is_unique

In [None]:
# Persist the processed splits and the vocabulary, so the Dataset class can load
# them without refitting the vectorizer.
DATA_DIR = '../dataset/ng20'
# Neither to_pickle nor open() creates missing directories; ensure DATA_DIR exists.
os.makedirs(DATA_DIR, exist_ok=True)

for split_name, split_df in (('train', train_df), ('test', test_df), ('cv', cv_df)):
    split_df.to_pickle(os.path.join(DATA_DIR, '{}.df.pkl'.format(split_name)))

# vocabulary_ maps each term to its feature (column) index in the BoW rows.
with open(os.path.join(DATA_DIR, 'vocab.pkl'), 'wb') as handle:
    pickle.dump(count_vect.vocabulary_, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [9]:
import os
from os.path import join
import numpy as np
import pandas as pd
import pickle
import torch
from torch.utils.data import Dataset

class Newsgroups20Dataset(Dataset):
    """Newsgroups20 bag-of-words dataset backed by a pickled pandas DataFrame.

    The DataFrame for ``subset`` is read from ``{data_dir}/{subset}.df.pkl``. It
    must have a ``bow`` column, whose entries expose ``toarray()`` (e.g. a scipy
    sparse 1 x V row), and a ``label`` column. Each item is ``(doc_bow, label)``,
    where ``doc_bow`` is a dense 1-D float32 tensor.
    """

    # Subset names accepted by ``__init__``.
    SUBSETS = ('train', 'test', 'cv')
    # Weighting schemes that __getitem__ actually implements.
    SUPPORTED_BOW_FORMATS = ('tf',)
    # Schemes named in the API that are not implemented yet.
    UNIMPLEMENTED_BOW_FORMATS = ('tfidf', 'bm25')

    def __init__(self, data_dir, download=False, subset='train', bow_format='tf'):
        """
        Args:
            data_dir (string): Directory holding the train, test, and cv dataframes.
            download (boolean): Currently unused; kept for API compatibility. The
                ``{subset}.df.pkl`` file must already exist in ``data_dir``.
            subset (string): Subset of the dataset to load. One of: train, test, cv.
            bow_format (string): Weighting scheme of a bag-of-words document. Only
                'tf' (raw term frequency) is implemented. 'tfidf' and 'bm25' raise
                NotImplementedError rather than silently returning tf.

        Raises:
            ValueError: if ``subset`` or ``bow_format`` is not a known choice.
            NotImplementedError: if ``bow_format`` is 'tfidf' or 'bm25'.
        """
        if subset not in self.SUBSETS:
            raise ValueError('subset must be one of {}, got {!r}'.format(self.SUBSETS, subset))
        if bow_format in self.UNIMPLEMENTED_BOW_FORMATS:
            raise NotImplementedError(
                "bow_format={!r} is not implemented yet; use 'tf'".format(bow_format))
        if bow_format not in self.SUPPORTED_BOW_FORMATS:
            raise ValueError('bow_format must be one of {}, got {!r}'.format(
                self.SUPPORTED_BOW_FORMATS + self.UNIMPLEMENTED_BOW_FORMATS, bow_format))
        self.data_dir = data_dir
        self.subset = subset
        self.bow_format = bow_format
        self.df = self.load_df('{}.df.pkl'.format(subset))

    def load_df(self, df_file):
        """Read ``df_file`` (a path relative to ``data_dir``) with ``pd.read_pickle``.

        Unpickling can execute arbitrary code, so only load files you created
        yourself, never pickles from untrusted sources.
        """
        df_file = os.path.join(self.data_dir, df_file)
        return pd.read_pickle(df_file)

    def __len__(self):
        """Number of documents in the loaded subset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(doc_bow, label)`` for the document at position ``idx``.

        ``doc_bow`` is the stored bag-of-words row densified into a 1-D float32
        tensor. ``label`` is the raw value from the ``label`` column.
        """
        # Read the two cells directly instead of materializing the whole row
        # (an object-dtype Series) twice via ``df.iloc[idx]``.
        doc_bow = self.df['bow'].iat[idx]
        label = self.df['label'].iat[idx]
        # ravel() rather than squeeze(): the result stays 1-D even when the
        # vocabulary has a single term.
        dense = np.asarray(doc_bow.toarray(), dtype=np.float32).ravel()
        return (torch.from_numpy(dense), label)

In [10]:
# Load the pickled training split and wrap it in a shuffling mini-batch loader.
train_set = Newsgroups20Dataset(
    '../dataset/ng20',
    subset='train',
    download=True,
    bow_format='tf',
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=64,
    shuffle=True,
)

In [12]:
# Pull a single mini-batch for inspection. next(iter(...)) raises StopIteration
# on an empty loader, instead of leaving xb/yb undefined or holding stale
# values from an earlier run of this cell.
xb, yb = next(iter(train_loader))

In [13]:
xb.shape  # (batch_size, vocabulary size) of the dense BoW batch

torch.Size([64, 10000])

In [14]:
# Newsgroup label ids (values of the `label` column) for the sampled batch.
yb

tensor([  5,  10,   2,  14,   8,   4,   5,   3,   1,   6,  16,  15,
          3,  12,   4,  14,   4,   8,   9,  15,  13,  15,   9,   0,
         15,   1,   9,  17,  13,   2,  10,   2,  13,  12,   1,   9,
         12,   9,  10,  17,  13,  10,   6,   9,   2,   1,  10,  17,
         13,  14,  11,   9,   1,   6,   9,  12,   5,   6,   5,  15,
          1,   2,   0,  11])