In [5]:
import numpy as np
import pandas as pd
from nltk.corpus import reuters, stopwords
from nltk.stem import PorterStemmer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
import math
from collections import Counter

In [7]:

import nltk
nltk.download('reuters')
nltk.download('stopwords')

[nltk_data] Downloading package reuters to
[nltk_data]     /Users/deepanshurao0001/nltk_data...
[nltk_data]   Package reuters is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/deepanshurao0001/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


True

In [9]:

stemmer = PorterStemmer()
stop_words = set(stopwords.words('english'))

In [11]:
def load_reuters():
    documents = [(reuters.raw(fileid), reuters.categories(fileid)[0])
                 for fileid in reuters.fileids()]
    df = pd.DataFrame(documents, columns=['text', 'category'])
    return df


In [13]:
def preprocess_text(text):
    tokens = nltk.word_tokenize(text.lower())
    tokens = [stemmer.stem(word) for word in tokens if word.isalpha() and word not in stop_words]
    return tokens


In [15]:
def calculate_class_frequencies(df):
    class_freq = df['category'].value_counts()
    return class_freq.to_dict()



In [17]:
def afe_mert(df, class_frequencies):
    terms = Counter()
    class_term_frequencies = {}

    # Calculate term frequencies per class
    for _, row in df.iterrows():
        class_label = row['category']
        
        # Ensure text preprocessing happens here
        text_tokens = preprocess_text(" ".join(row['text']) if isinstance(row['text'], list) else row['text'])
        
        terms.update(text_tokens)
        if class_label not in class_term_frequencies:
            class_term_frequencies[class_label] = Counter()
        class_term_frequencies[class_label].update(text_tokens)

    # Calculate RIR (Relative Imbalance Ratio)
    max_class_size = max(class_frequencies.values())
    rir = {cls: max_class_size / freq for cls, freq in class_frequencies.items()}

    # Calculate weights for terms
    term_weights = {}
    for term in terms:
        for cls in class_frequencies.keys():
            n_tk_ci = class_term_frequencies[cls].get(term, 0)
            idf = math.log(len(df) / (1 + terms[term]))  # Smooth IDF
            A = class_term_frequencies[cls].get(term, 0)
            B = sum(class_term_frequencies[c].get(term, 0) for c in class_frequencies.keys() if c != cls)
            RIR_cls = rir[cls]

            # Weight formula as in AFE-MERT
            weight = (
                math.log(1 + n_tk_ci) * idf *
                math.log(1 + RIR_cls ** (1 / 2)) *  # p = 2
                math.log(1 + A / max(1, B))
            )
            term_weights[(term, cls)] = weight

    return term_weights


In [19]:
reuters_df = load_reuters()

In [49]:
import nltk
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to
[nltk_data]     /Users/deepanshurao0001/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

In [51]:
def preprocess_text(text):
    tokens = nltk.word_tokenize(text.lower())  # This will now work
    tokens = [stemmer.stem(word) for word in tokens if word.isalpha() and word not in stop_words]
    return tokens

In [53]:
reuters_df['text'] = reuters_df['text'].apply(preprocess_text)

In [54]:
class_freq = calculate_class_frequencies(reuters_df)

In [55]:
term_weights = afe_mert(reuters_df, class_freq)

In [56]:
print("Sample Term Weights:", f' {list(term_weights.items())[:100]}')


Sample Term Weights:  [(('asian', 'earn'), 0.03941571943799372), (('asian', 'acq'), 2.299301234071236), (('asian', 'crude'), 1.2974743115017033), (('asian', 'interest'), 0.24935569125416712), (('asian', 'money-fx'), 2.651679760693727), (('asian', 'trade'), 6.358267367335734), (('asian', 'grain'), 0.27920888184244064), (('asian', 'corn'), 0.0), (('asian', 'dlr'), 0.0), (('asian', 'money-supply'), 0.0), (('asian', 'ship'), 0.0), (('asian', 'coffee'), 0.0), (('asian', 'sugar'), 0.0), (('asian', 'gold'), 0.0), (('asian', 'bop'), 0.0), (('asian', 'gnp'), 0.0), (('asian', 'cpi'), 0.0), (('asian', 'cocoa'), 0.7374181531698444), (('asian', 'carcass'), 0.39338925408290176), (('asian', 'oilseed'), 0.0), (('asian', 'copper'), 0.0), (('asian', 'alum'), 0.0), (('asian', 'reserves'), 0.0), (('asian', 'jobs'), 0.0), (('asian', 'barley'), 0.0), (('asian', 'ipi'), 0.0), (('asian', 'iron-steel'), 0.0), (('asian', 'cotton'), 0.0), (('asian', 'rubber'), 0.43421075490700134), (('asian', 'nat-gas'), 0.0), (

In [57]:
import pandas as pd

df = pd.DataFrame(list(term_weights.items())[:100], columns=["Term", "Weight"])
print(df)


                 Term    Weight
0       (asian, earn)  0.039416
1        (asian, acq)  2.299301
2      (asian, crude)  1.297474
3   (asian, interest)  0.249356
4   (asian, money-fx)  2.651680
..                ...       ...
95      (export, cpi)  0.084300
96    (export, cocoa)  0.281241
97  (export, carcass)  0.269107
98  (export, oilseed)  0.188448
99   (export, copper)  0.068656

[100 rows x 2 columns]
