# Preliminaries 

The `all-the-news` dataset is really quite nice and has a **TON** of data, this is actually a problem. We address this and other issues in this notebook: 

* We trim the dataset down from 2.1M rows to a much more manageable 150k rows 
* We filter out the 10% longest and 10% shortest articles (These are usually either transcripts or had a scraping error)
* Also we remove Chinese (some of the articles are non English)
* We assign a train test split that is representative of each publication 
* We remove some words that would be considered cheating like publication names, author names, and urls. If we are just learning to recognize keywords from a source like "NYT" or nyt.com in the body of an article - we aren't actually learning anything beyond a rule based approach! 

#### 1.) Downsize the data

In [1]:
import polars as pl
import random
import re 
from tqdm import tqdm

In [2]:
# Load the dataset
df = pl.read_csv("all-the-news-2-1-LARGE.csv")

In [3]:
# Define target sources and their desired counts
target_counts = {
    "The New York Times": 15000,
    "The Hill": 15000,
    "Reuters": 15000,
    "People": 15000,
    "CNN": 15000,
    "Vice": 15000,
    "Politico": 15000,
    "Buzzfeed News": 15000,
    "Economist": 15000,
    "Fox News": 15000,
}

# Filter to only target sources
df = df.filter(pl.col("publication").is_in(list(target_counts.keys())))

# remove articles with length 0 or null 
df = df.filter(pl.col("article").is_not_null())

# Compute 10th and 90th percentiles of article length
article_lengths = df.select(pl.col("article").str.len_chars().alias("length"))

lower_bound = article_lengths.select(pl.col("length").quantile(0.10)).item()
upper_bound = article_lengths.select(pl.col("length").quantile(0.90)).item()

# Filter based on these quantiles
df = df.with_columns([
    pl.col("article").str.len_chars().alias("article_length")
]).filter(
    (pl.col("article_length") > lower_bound) & (pl.col("article_length") < upper_bound)
).drop("article_length")

#### 1b.) Quick fix to remove Chinese 

In [4]:
# Remove any article with Chinese characters in the title 
df = df.filter(~pl.col("title").str.contains(r"[\u4e00-\u9fff]"))

In [5]:
# Downsample by longest articles for each publication
filtered_parts = []

for pub, target_size in target_counts.items():
    pub_df = (
        df.filter(pl.col("publication") == pub)
          .head(target_size)
    )
    filtered_parts.append(pub_df)

# Combine all filtered parts
trimmed_df = pl.concat(filtered_parts)

# Group and count by publication
final_counts = (
    trimmed_df.group_by("publication")
              .count()
              .sort("count", descending=True)
)

# Print results
print(final_counts)

shape: (10, 2)
┌────────────────────┬───────┐
│ publication        ┆ count │
│ ---                ┆ ---   │
│ str                ┆ u32   │
╞════════════════════╪═══════╡
│ People             ┆ 15000 │
│ The Hill           ┆ 15000 │
│ Vice               ┆ 15000 │
│ Economist          ┆ 15000 │
│ CNN                ┆ 15000 │
│ Buzzfeed News      ┆ 15000 │
│ Politico           ┆ 15000 │
│ The New York Times ┆ 15000 │
│ Reuters            ┆ 15000 │
│ Fox News           ┆ 15000 │
└────────────────────┴───────┘


  .count()


#### 2.) Split into train and test segments

In [6]:
import polars as pl
import numpy as np

random_state = 42
np.random.seed(random_state)

def assign_splits(group: pl.DataFrame) -> pl.DataFrame:
    n = len(group)
    indices = np.random.permutation(n)  # Shuffle indices
    split = np.where(indices < 1000, "test", "train")  # Assign first 1k as test
    return group.with_columns(pl.Series("split", split))

# Apply to each group
trimmed_df = (
    trimmed_df
    .group_by("publication", maintain_order=True)
    .map_groups(assign_splits)
)

In [7]:
print(
    trimmed_df
    .group_by("publication", "split")
    .agg(pl.count())
    .sort("publication", "split")
)

shape: (20, 3)
┌────────────────────┬───────┬───────┐
│ publication        ┆ split ┆ count │
│ ---                ┆ ---   ┆ ---   │
│ str                ┆ str   ┆ u32   │
╞════════════════════╪═══════╪═══════╡
│ Buzzfeed News      ┆ test  ┆ 1000  │
│ Buzzfeed News      ┆ train ┆ 14000 │
│ CNN                ┆ test  ┆ 1000  │
│ CNN                ┆ train ┆ 14000 │
│ Economist          ┆ test  ┆ 1000  │
│ …                  ┆ …     ┆ …     │
│ The Hill           ┆ train ┆ 14000 │
│ The New York Times ┆ test  ┆ 1000  │
│ The New York Times ┆ train ┆ 14000 │
│ Vice               ┆ test  ┆ 1000  │
│ Vice               ┆ train ┆ 14000 │
└────────────────────┴───────┴───────┘


  .agg(pl.count())


In [8]:
trimmed_df.write_csv("all-the-news-2-1-SMALL.csv")

#### 3.) Clean the data

In [9]:
publications = [
    "Politico", "The Hill", "The New York Times", "Economist",
    "Reuters", "Fox News", "Vice", "CNN", "Buzzfeed News", "People"
]

# Base patterns
base_patterns = [
    r"\b" + re.escape(pub) + r"\b" for pub in publications
] + [
    r"\bNYT\b", r"\bFox\b", r"\bBF\b", r"\bCNN\.com\b", r"\bVICE\b"
]

# Domain patterns
domain_patterns = [
    r"politico\.com", r"thehill\.com", r"nytimes\.com", 
    r"economist\.com", r"reuters\.com", r"foxnews\.com",
    r"vice\.com", r"cnn\.com", r"buzzfeednews\.com", r"people\.com"
]

# Combined regex pattern (case-insensitive)
pattern = r"(?i)(" + "|".join(base_patterns + domain_patterns) + ")"

In [10]:
# Initialize counter
row_counter = 0
total_rows = len(trimmed_df)
progress = tqdm(total=total_rows, desc="Cleaning articles")

# Modified cleaning function with progress update
def clean_article_with_progress(text: str) -> str:
    global row_counter
    text = re.sub(pattern, "[PUB]", text)
    
    # Update progress every 1000 rows (reduces overhead)
    row_counter += 1
    if row_counter % 1000 == 0:
        progress.update(1000)
    
    return text

# Apply with map_elements
cleaned_df = trimmed_df.with_columns(
    pl.col("article").map_elements(clean_article_with_progress).alias("clean_article")
)

# Close progress bar
progress.close()

  cleaned_df = trimmed_df.with_columns(
Cleaning articles: 100%|██████████| 150000/150000 [01:19<00:00, 1893.70it/s]


In [11]:
# Remove URLs
cleaned_df = cleaned_df.with_columns(
    pl.col("clean_article").str.replace_all(r"https?://\S+", "[URL]"),
)

In [12]:
# Step 1: Count how many articles each author appears in
author_counts = (
    cleaned_df
    .group_by("author")
    .agg(pl.count().alias("n_articles"))
    .filter(pl.col("author").is_not_null())
)

# Step 2: Identify authors who only appear in one article
single_use_authors = (
    author_counts
    .filter(pl.col("n_articles") == 1)
    .get_column("author")
    .to_list()
)

# Step 3: Create a mapping of article to its author (for single-use authors only)
author_mapping = (
    cleaned_df
    .select(["clean_article", "author"])
    .filter(pl.col("author").is_in(single_use_authors))
)

# Step 4: Remove only the specific author from their own article
cleaned_df = cleaned_df.with_columns(
    pl.struct(["clean_article", "author"]).map_elements(
        lambda x: x["clean_article"].replace(x["author"], "[AUTHOR]") 
        if x["author"] in single_use_authors 
        else x["clean_article"],
        return_dtype=pl.String
    ).alias("clean_article")
)

  .agg(pl.count().alias("n_articles"))


In [13]:
# Re-apply publication name cleaning to catch any missed instances
cleaned_df = cleaned_df.with_columns(
    pl.col("clean_article").str.replace_all(pattern, "[PUB]").alias("clean_article")
)

# Check if cleaning was successful
sample = cleaned_df.sample(5)
for pub in publications:
    sample_with_pub = sample.filter(pl.col("clean_article").str.contains(pub, literal=True))
    if len(sample_with_pub) > 0:
        print(f"Warning: Found {pub} in sample")



#### Sanity check - lets look at average article length

In [14]:
avg_length = (
    cleaned_df
    .group_by("publication")
    .agg(
        pl.col("clean_article").str.len_chars().mean().alias("avg_length"),
        pl.col("clean_article").str.len_chars().median().alias("median_length"),
        pl.col("clean_article").str.len_chars().min().alias("min_length"),
        pl.col("clean_article").str.len_chars().max().alias("max_length")
    )
    .sort("avg_length", descending=True)
)
print(avg_length)

shape: (10, 5)
┌────────────────────┬─────────────┬───────────────┬────────────┬────────────┐
│ publication        ┆ avg_length  ┆ median_length ┆ min_length ┆ max_length │
│ ---                ┆ ---         ┆ ---           ┆ ---        ┆ ---        │
│ str                ┆ f64         ┆ f64           ┆ u32        ┆ u32        │
╞════════════════════╪═════════════╪═══════════════╪════════════╪════════════╡
│ Economist          ┆ 4126.166067 ┆ 4127.5        ┆ 382        ┆ 6528       │
│ The New York Times ┆ 3313.591133 ┆ 3194.5        ┆ 386        ┆ 6531       │
│ CNN                ┆ 3134.509933 ┆ 2850.0        ┆ 395        ┆ 6533       │
│ Politico           ┆ 3002.584467 ┆ 2645.0        ┆ 382        ┆ 6528       │
│ Vice               ┆ 2954.5134   ┆ 2641.0        ┆ 386        ┆ 6529       │
│ Buzzfeed News      ┆ 2918.1836   ┆ 2737.0        ┆ 380        ┆ 6526       │
│ The Hill           ┆ 2907.2376   ┆ 2448.0        ┆ 373        ┆ 6525       │
│ Fox News           ┆ 2479.0442   ┆ 

In [15]:
# Drop the original article column
cleaned_df = cleaned_df.drop("article")

#### Save our cleaned data! 

In [16]:
# Save the cleaned DataFrame to a CSV file
cleaned_df.write_csv("all-the-news-2-1-SMALL-CLEANED.csv")