In [1]:
#if you dont have datasets installed do so with the following line
#!pip install datasets
from datasets import load_dataset

# load AG News dataset
dataset = load_dataset("ag_news")

# check available splits
print(dataset)

# take a look at one sample
print(dataset["train"][0])

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})
{'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.", 'label': 2}


In [2]:
print(dataset["train"][0:20])



In [12]:
df_train1=dataset["train"].to_pandas()
df_test1=dataset["test"].to_pandas()
#get an overview of dataset and its content
print(dataset['train'].info)
print(df_train.head())

DatasetInfo(description='', citation='', homepage='', license='', features={'text': Value('string'), 'label': ClassLabel(names=['World', 'Sports', 'Business', 'Sci/Tech'])}, post_processed=None, supervised_keys=None, builder_name='parquet', dataset_name='ag_news', config_name='default', version=0.0.0, splits={'train': SplitInfo(name='train', num_bytes=29832303, num_examples=120000, shard_lengths=None, dataset_name='ag_news'), 'test': SplitInfo(name='test', num_bytes=1880424, num_examples=7600, shard_lengths=None, dataset_name='ag_news')}, download_checksums={'hf://datasets/ag_news@eb185aade064a813bc0b7f42de02595523103ca4/data/train-00000-of-00001.parquet': {'num_bytes': 18585438, 'checksum': None}, 'hf://datasets/ag_news@eb185aade064a813bc0b7f42de02595523103ca4/data/test-00000-of-00001.parquet': {'num_bytes': 1234829, 'checksum': None}}, download_size=19820267, post_processing_size=None, dataset_size=31712727, size_in_bytes=51532994)
                                                    

In [71]:
#the classes are balanced
print(df_train['label'].value_counts())

label
2    30000
3    30000
1    30000
0    30000
Name: count, dtype: int64


In [9]:
df_train["text"]=df_train1["text"].str.lower()
df_test["text"]=df_test1["text"].str.lower()
print(df_train.head())

                                                                                                                                                                                                                                                                         text  \
0                                                                                                                            wall st. bears claw back into the black (reuters) reuters - short-sellers, wall street's dwindling\band of ultra-cynics, are seeing green again.   
1  carlyle looks toward commercial aerospace (reuters) reuters - private investment firm carlyle group,\which has a reputation for making well-timed and occasionally\controversial plays in the defense industry, has quietly placed\its bets on another part of the market.   
2                                    oil and economy cloud stocks' outlook (reuters) reuters - soaring crude prices plus worries\about the economy and the outlook for earnings are e

In [10]:
print(df_train["text"][0:5])

0                                                                                                                              wall st. bears claw back into the black (reuters) reuters - short-sellers, wall street's dwindling\band of ultra-cynics, are seeing green again.
1    carlyle looks toward commercial aerospace (reuters) reuters - private investment firm carlyle group,\which has a reputation for making well-timed and occasionally\controversial plays in the defense industry, has quietly placed\its bets on another part of the market.
2                                      oil and economy cloud stocks' outlook (reuters) reuters - soaring crude prices plus worries\about the economy and the outlook for earnings are expected to\hang over the stock market next week during the depth of the\summer doldrums.
3              iraq halts oil exports from main southern pipeline (reuters) reuters - authorities have halted oil export\flows from the main pipeline in southern iraq after\intelligenc

In [19]:
#remove all other characters and keep only alphabet
#numbers were removed due to the fact that they would represent different encodings  
#would remove also news outlets as they might affect the resultbut due to 
#the different structures of text proved more difficult so didnt do it
df_train["text"]=df_train["text"].str.replace(r'[^a-z\s]', ' ', regex=True)
df_test["text"]=df_test["text"].str.replace(r'[^a-z\s]', ' ', regex=True)
print(df_train.head())

                                                                                                                                                                                                                                                                         text  \
0                                                                                                                            wall st  bears claw back into the black  reuters  reuters   short sellers  wall street s dwindling band of ultra cynics  are seeing green again    
1  carlyle looks toward commercial aerospace  reuters  reuters   private investment firm carlyle group  which has a reputation for making well timed and occasionally controversial plays in the defense industry  has quietly placed its bets on another part of the market    
2                                    oil and economy cloud stocks  outlook  reuters  reuters   soaring crude prices plus worries about the economy and the outlook for earnings are e

In [18]:
import pandas as pd
#did this so i could see whole content
pd.set_option('display.max_colwidth', None)
#check if any are shorter than certain length since they may be invalid
empty = (df_train["text"].str.len() <= 30)
#Checked to see which ones are shorter and see if they make sense
df_sorted = df_train.sort_values(by="text", key=lambda x: x.str.len())
print(df_sorted.head(5))
print("\nEmpty strings :\n", empty.sum())
print(df_train.head())

                                                                                                        text  \
27308   grocer takes a tumble labor problems and debt retirement contribute to kroger's second-quarter slip.   
44886   apple to open 'borderless' euro music store questions company's 'one country, one store' restriction   
110719  amazon's 'morning nightmare' lasts 11 days, and counting peak-time woes bedevil merchants, customers   
103555  lakers' buss wants to make up with o'neal (ap) ap - hey, shaq, jerry buss wants to be buddies again.   
105338  sony cyber-shot dsc p150 &lt;strong&gt;review&lt;/strong&gt; fast, inexpensive and highly pocketable   

        label  
27308       2  
44886       3  
110719      3  
103555      1  
105338      3  

Empty strings :
 0
                                                                                                                                                                                                                   

In [57]:
#if you dont have the following you need to install
#!pip install nltk
!pip install spacy
!python -m spacy download en_core_web_sm
!pip install tqdm

import nltk
from nltk.corpus import wordnet, stopwords
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
from tqdm.notebook import tqdm

import spacy
import re

nltk.download('stopwords')
nlp = spacy.load("en_core_web_sm")
stop_words = set(stopwords.words('english'))
#made functions so i wouldnt have to run them again each time
def remove_stopwords(sentence):
    words = sentence.split()
    filtered = [w for w in words if w not in stop_words]
    return " ".join(filtered)
    
def lemmatize_texts(texts, batch_size=50):
    lemmatized_texts = []
    for doc in tqdm(nlp.pipe(texts, batch_size=batch_size), total=len(texts)):
        lemmatized_texts.append(" ".join([token.lemma_ for token in doc]))
    return lemmatized_texts

Collecting en-core-web-sm==3.8.0
  Using cached https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
[38;5;2m[+] Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')


[nltk_data] Downloading package stopwords to C:\Users\MY
[nltk_data]     COMPUTER\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [58]:

df_train["text"] = df_train["text"].apply(remove_stopwords)
print(df_train.head())

                                                                                                                                                                                                                     text  \
0                                                                                                        wall st bears claw back black reuters reuters short sellers wall street dwindling band ultra cynics seeing green   
1  carlyle looks toward commercial aerospace reuters reuters private investment firm carlyle group reputation making well timed occasionally controversial plays defense industry quietly placed bets another part market   
2                                                  oil economy cloud stocks outlook reuters reuters soaring crude prices plus worries economy outlook earnings expected hang stock market next week depth summer doldrums   
3  iraq halts oil exports main southern pipeline reuters reuters authorities halted oil export flows main pipeline s

In [59]:
df_train['text_lemmatized'] = lemmatize_texts(df_train['text'].tolist())

  0%|          | 0/120000 [00:00<?, ?it/s]

In [60]:
#perform lemmatizing
print(df_train["text_lemmatized"].head())

0                                                                                                           wall st bears claw back black reuters reuters short sellers wall street dwindle band ultra cynic see green
1    carlyle look toward commercial aerospace reuters reuters private investment firm carlyle group reputation making well time occasionally controversial play defense industry quietly place bet another part market
2                                                          oil economy cloud stock outlook reuters reuters soar crude price plus worry economy outlook earning expect hang stock market next week depth summer doldrum
3           iraq halt oil export main southern pipeline reuter reuter authority halt oil export flow main pipeline southern iraq intelligence show rebel militia could strike infrastructure oil official say saturday
4                           oil price soar time record pose new menace we economy afp afp tearaway world oil price topple record strain wall

In [61]:
df_test["text"] = df_test["text"].apply(remove_stopwords)
df_test["text_lemmatized"]=lemmatize_texts(df_test['text'].tolist())

  0%|          | 0/7600 [00:00<?, ?it/s]

In [65]:
#perform tfIDF
from sklearn.feature_extraction.text import TfidfVectorizer

# Initialize TF-IDF Vectorizer
tfidf_vectorizer = TfidfVectorizer(
    # remove stopwords
    ngram_range=(1, 2)        # use unigrams and bigrams only
)

# Fit and transform the lemmatized text column
X_train = tfidf_vectorizer.fit_transform(df_train['text_lemmatized'])
X_test = tfidf_vectorizer.transform(df_test['text_lemmatized'])



In [72]:
# checks to see which are the top words acros dataset some of them are topic specific
# like oil stock world but others are meaningless as single words
import numpy as np

feature_names = tfidf_vectorizer.get_feature_names_out()
tfidf_sum = np.array(X_train.sum(axis=0)).flatten()
tfidf_df = pd.DataFrame({
    'word': feature_names,
    'tfidf_sum': tfidf_sum
})
top_words = tfidf_df.sort_values(by='tfidf_sum', ascending=False).head(20)
print(top_words)

            word    tfidf_sum
968927       say  1110.094042
736607       new   954.266423
51107         ap   904.802301
933203   reuters   739.032309
1264486     year   721.213930
478237        gt   705.275096
646249        lt   703.485487
882573      quot   592.963317
222624   company   560.019814
1174395      two   558.769236
410063     first   542.536872
763325       oil   519.973279
1255476    world   514.930897
646283     lt gt   505.181288
443567      game   487.726324
767034       one   473.020706
919154    report   471.143897
852205     price   455.607756
51224      ap ap   441.883095
1071815    stock   435.566774


In [77]:
#saveing files so we dont have to calculate again
import joblib

to_save = {
    "df_train": df_train,
    "df_test": df_test,
    "tfidf_vectorizer": tfidf_vectorizer,

}


joblib.dump(to_save, "my_vars.joblib")

['my_vars.joblib']