In [15]:
import os
import numpy as np
from collections import Counter
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.svm import LinearSVC

In [16]:
data_root = '20_newsgroups'

folders = [f for f in os.listdir(data_root) if os.path.isdir(os.path.join(data_root, f))]
folders = sorted(folders)
classes = folders
class_to_index = {c: i for i, c in enumerate(classes)}

print("Detected classes:", classes)
print("Number of classes:", len(classes))

if len(classes) != 20:
    raise ValueError("The number of detected classes is not 20. Please check your dataset.")

data_paths = []
labels = []
for c in classes:
    class_path = os.path.join(data_root, c)
    files = os.listdir(class_path)
    files = [f for f in files if os.path.isfile(os.path.join(class_path, f))]
    for f in files:
        data_paths.append(os.path.join(class_path, f))
        labels.append(class_to_index[c])

data_paths = np.array(data_paths)
labels = np.array(labels)

print("Total documents:", len(data_paths))

Detected classes: ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x', 'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med', 'sci.space', 'soc.religion.christian', 'talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc']
Number of classes: 20
Total documents: 19997


In [17]:
np.random.seed(42)
indices = np.arange(len(data_paths))
np.random.shuffle(indices)
data_paths = data_paths[indices]
labels = labels[indices]

In [18]:
#set 300 stopwords
word_counts = Counter()
for path in data_paths:
    with open(path, 'r', encoding='utf-8', errors='ignore') as f:
        text = f.read()
        words = text.strip().split()
        word_counts.update(words)

most_common_300 = [w for w, _ in word_counts.most_common(300)]
stop_words = list(most_common_300)

print("Number of custom stop words:", len(stop_words))

Number of custom stop words: 300


In [19]:
pipeline = Pipeline([
    ('tfidf', TfidfVectorizer(
        input='filename',
        encoding='utf-8',
        decode_error='ignore',
        lowercase=True,
        stop_words=stop_words,
        token_pattern=r'\b\w+\b'
    )),
    ('clf', LinearSVC(random_state=42, max_iter=2000))
])

In [20]:
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
cv_scores = cross_val_score(pipeline, data_paths, labels, cv=skf, scoring='accuracy')



In [21]:
print("5-fold cross-validation scores:", cv_scores)
print("Mean accuracy:", np.mean(cv_scores))

5-fold cross-validation scores: [0.94025    0.94325    0.93873468 0.93848462 0.93473368]
Mean accuracy: 0.9390905976494125
