In [2]:
"""
This file contains code from https://github.com/lukasruff/CVDD-PyTorch/blob/aa2b033ed8216ce132ef6977da1e4fae665fb0c0/src/utils/misc.py#L94
The repo contains code from the following work:
Ruff et. al - Self-Attentive, Multi-Context One-Class Classification for Unsupervised Anomaly Detection on Text
ACL '19

as well as https://github.com/bit-ml/date
"""
import string
import re
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import nltk
nltk.download('punkt')
nltk.download('stopwords')

def clean_text(text: str, rm_numbers=True, rm_punct=True, rm_stop_words=True, rm_short_words=True):
    """ Function to perform common NLP pre-processing tasks. """

    # make lowercase
    text = text.lower()

    # remove punctuation
    if rm_punct:
        text = text.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation)))

    # remove numbers
    if rm_numbers:
        text = re.sub(r'\d+', '', text)

    # remove whitespaces
    text = text.strip()

    # remove stopwords
    if rm_stop_words:
        stop_words = set(stopwords.words('english'))
        word_tokens = word_tokenize(text)
        text_list = [w for w in word_tokens if not w in stop_words]
        text = ' '.join(text_list)

    # remove short words
    if rm_short_words:
        text_list = [w for w in text.split() if len(w) >= 3]
        text = ' '.join(text_list)

    return text

[nltk_data] Downloading package punkt to
[nltk_data]     /home/sofiacolella/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package stopwords to
[nltk_data]     /home/sofiacolella/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


In [8]:
#!pip install wget
import wget
import pandas as pd
import numpy as np
import os

import string

def dump_data(phase, name, text, path='./ag_od/'):
    full_path = os.path.join(path, phase)

    if not os.path.exists(full_path):
        os.makedirs(full_path)

    with open(f'{full_path}/{name}.txt', 'w') as fp:
        fp.write(text)
        print('Succesfully written', f'{full_path}/{name}.txt')

def export_ds(subsets, phase):
    raw_text = {
        "world" : '',
        "sports" : '',
        "business": '',
        "sci": ''
    }

    for data_label, subset in zip(subsets, raw_text):
        for text in subsets[data_label]:
            text = clean_text(text)
            raw_text[subset] += f'\n\n{text}'

    for subset in raw_text:
        dump_data(phase, subset, raw_text[subset])

! wget "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv"
! wget "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/test.csv"

ag_train = pd.read_csv('./train.csv', header=None)
ag_test = pd.read_csv('./test.csv', header=None)
ag_test.columns = ['label', 'title', 'description']
ag_train.columns = ['label', 'title', 'description']

subsets_test = {
    "1": [],
    "2": [],
    "3": [],
    "4": []
}

for idx, el in enumerate(np.array(ag_test)):
    label = el[0]
    text = el[1] + ' ' + el[2]
    subsets_test[f'{label}'].append(text)

print('Total samples (test)', idx+1)

export_ds(subsets_test, 'test')

subsets_train = {
    "1": [],
    "2": [],
    "3": [],
    "4": []
}

for idx, el in enumerate(np.array(ag_train)):
    label = el[0]
    text = el[1] + ' ' + el[2]
    subsets_train[f'{label}'].append(text)

print("Total samples (train):", idx+1)

export_ds(subsets_train, 'train')

--2022-07-25 04:25:59--  https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 29470338 (28M) [text/plain]
Saving to: ‘train.csv.1’


2022-07-25 04:25:59 (298 MB/s) - ‘train.csv.1’ saved [29470338/29470338]

--2022-07-25 04:25:59--  https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/test.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1857427 (1.8M) [text/plain]
Saving to: ‘test.csv’


2022-07-25 04:25:59 (139 M