# COMP 345: Assignment 1 (104 points)

This assignment will help you practice regular expressions for text pattern matching, working with dictionaries to analyze word frequencies and compute statistics, and implementing a Naive Bayes text classifier from scratch including tokenization, probability computation, and model evaluation.


## Instructions

For each exercise in this notebook:
- Read the problem description carefully
- Write your solution **only** between the designated code markers:
  ```python
  ### WRITE YOUR CODE BELOW
  # Your code here
  ### END CODE HERE
  ```
- Do **not** modify any code outside these markers
- Do **not** change function signatures (function name, parameters, return type)
- Do **not** change test cells or example outputs
- Run each cell to verify your solution works correctly

## How to Submit the Assignment?

1. **Create a copy of the assignment**: Before starting, create a copy of this notebook in your Google Drive by clicking `File > Save a copy in Drive`. This ensures your progress is saved as you work.

2. **Complete all exercises**: Work through each exercise in your copied notebook, writing your solutions between the designated code markers.

3. **Download the notebook**: Once you have completed all exercises, download the notebook as an `.ipynb` file by clicking `File > Download > Download .ipynb` as shown below:

<p align="center">
  <img src="https://drive.google.com/thumbnail?id=1SZc-bK8PzBmMsgI5iUY4g9AvKk128qO4&sz=w800" alt="Download notebook from Colab" width="800">
</p>

4. **Rename the file**: Rename the downloaded file to `<student_id>_A1.ipynb` (e.g., `9284827_A1.ipynb`).

5. **Submit on Gradescope**: Upload the renamed notebook file to Gradescope.
https://www.gradescope.ca/courses/35049/assignments/178742

## Questions?

If you have any questions about the assignment, please reach out to the TA:
- Slack: `#assignment-1`
- Email: `jay.gala@mila.quebec` (**Note:** Please include `[COMP 345]` in the subject)
- Office Hours: Tuesdays and Thursdays, 2:45 pm – 3:45 pm in McConnell Engineering Building Room 110 (Jan 19, 22, 27, 29)



## AI Usage Disclosure

If you used any AI tools while completing this assignment, you must disclose this below. Please specify:

1. **Which AI tool(s) did you use?**
2. **For which exercises or tasks did you use them?**
3. **How did you use them?**

**Note:** Please refer to the course website for details on the [Generative AI Policy](https://mcgill-nlp.github.io/teaching/comp345-ling345-W26/#generative-ai-policy).

---

**Your disclosure (edit this cell):**

- AI tool(s) used: <Claude.ai>
- Exercises: <for each exercise, I used it to check my work>
- How used: <I provided the questions and my code, and asked if it was correct or it needed improvements. I improved my answers based on the feedback when necessary.>

---


In [None]:
from typing import List, Dict, Tuple, Optional

import re
import math
import random
import pandas as pd

## 1. Regular Expressions (35 points)

This section tests your understanding of regular expressions (regex) for pattern matching in text. You'll work with Python's `re` module to find patterns, extract information, validate formats, and process text data using various regex features including character classes, quantifiers, anchors, and groups.

**Note on `re` module methods:** While the lecture primarily covered `re.findall()`, this assignment will also use other methods from the `re` module:

- **`re.findall(pattern, string)`**: Returns a list of all matches in the string. Use this when you need to extract all occurrences of a pattern.
- **`re.search(pattern, string)`**: Searches for the first occurrence of the pattern anywhere in the string. Returns a match object if found, `None` otherwise. Use this when you need to check if a pattern exists.
- **`re.match(pattern, string)`**: Checks whether the pattern matches at the beginning of the string. Returns a match object if it does, `None` otherwise. Use this for validation when the entire string should match a pattern. Because `re.match` already anchors at the start, the `^` anchor is optional, but you still need a trailing `$` (or `re.fullmatch`) to anchor the end.

For more details and examples, refer to the [Python `re` module documentation](https://docs.python.org/3/library/re.html) and [Python Regular Expressions Google blogpost](https://developers.google.com/edu/python/regular-expressions).
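
The difference between the three calls is easiest to see on a single string. The snippet below is a small illustration only; the sample text and pattern are made up for this purpose and are not part of the assignment data.

```python
import re

text = "Order 66 shipped; order 7 is pending."
pattern = r"\d+"

# findall: every non-overlapping match, returned as a list of strings
all_numbers = re.findall(pattern, text)

# search: the first match anywhere in the string (match object or None)
first = re.search(pattern, text)

# match: succeeds only if the pattern matches starting at position 0
at_start = re.match(pattern, text)

print(all_numbers)    # ['66', '7']
print(first.group())  # 66
print(at_start)       # None, because the text starts with "Order"
```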

### 1.1: Email Validator (9 points)

You're building a simple form validator that needs to check if user input contains valid email addresses. While real email validation is complex, you'll implement a basic version using regex patterns to match common email formats.

#### 1.1.1: Extract All Emails (3 points)

This function extracts all email addresses from a given text. An email should match the pattern:
- Local part (before @): one or more characters including letters (a-z, A-Z), digits (0-9), and punctuation characters (. _ % + -)
- @ symbol
- Domain name (after @): one or more characters including letters, digits, and punctuation characters (. -)
- A dot (.) followed by 2 or more letters for the domain extension

**Note:** The first character must be alphanumeric, so a string such as `%abc123@gmail.com` is not a valid email.

Arguments:
- `text (str)`: The input text to search for email addresses.

Returns:
- `emails (List[str])`: A list of all email addresses found in the text (in order of appearance).

Examples:
```python
>>> text = "Contact us at support@example.com or sales@company.org for help."
>>> extract_emails(text)
['support@example.com', 'sales@company.org']
>>> text = "No emails here!"
>>> extract_emails(text)
[]
```

In [None]:
def extract_emails(text: str) -> List[str]:
    ### WRITE YOUR CODE BELOW
    # first char: alphanumeric, rest: can have letters, digits, . _ % + -
    # @ + (domain: letters, digits, . - ) + "."  +  (2+ letters) 
    
    pattern = r'[a-zA-Z0-9][a-zA-Z0-9._%+-]*@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
    emails = re.findall(pattern, text)
    
    ### END CODE HERE
    return emails


#### 1.1.2: Validate Email Format (4 points)

This function checks if a single string is a valid email address. The entire string must match the email pattern exactly with no extra characters before or after. A valid email should have:
- Local part (before @): one or more characters including letters (a-z, A-Z), digits (0-9), and punctuation characters (. _ % + -)
- @ symbol
- Domain name (after @): one or more characters including letters, digits, and punctuation characters (. -)
- A dot (.) followed by 2 or more letters for the domain extension

**Note:** The first character must be alphanumeric, so a string such as `%abc123@gmail.com` is not a valid email.

Arguments:
- `email (str)`: The string to validate.

Returns:
- `is_valid (bool)`: True if the string is a valid email format, False otherwise.

Examples:
```python
>>> validate_email("user@example.com")
True
>>> validate_email("invalid.email@")
False
>>> validate_email("also@invalid")
False
>>> validate_email("This is not an email")
False
```
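
One subtlety of whole-string validation is worth seeing in isolation: in Python, `$` also matches just before a trailing newline, whereas `re.fullmatch` requires the pattern to consume the entire string. The sketch below uses a deliberately simplified email pattern (an assumption for brevity, not the pattern this exercise asks for) to show the difference.

```python
import re

# Simplified pattern for illustration only (lowercase letters, one dot)
pattern = r"[a-z]+@[a-z]+\.[a-z]{2,}"

candidate = "user@example.com\n"  # note the trailing newline

# `$` matches at the end of the string OR just before a final newline
with_dollar = re.match(pattern + r"$", candidate) is not None

# fullmatch succeeds only if the pattern covers the whole string
with_fullmatch = re.fullmatch(pattern, candidate) is not None

print(with_dollar, with_fullmatch)  # True False
```

Whether a trailing newline should count as an "extra character" depends on how the input is read; if it should, `re.fullmatch` (or the `\Z` anchor) is the stricter choice.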

In [None]:
def validate_email(email: str) -> bool:
    ### WRITE YOUR CODE BELOW
    # same pattern as extract_emails; fullmatch requires the whole string to match
    # (unlike `$`, it does not accept a trailing newline after the email)
    
    pattern = r'[a-zA-Z0-9][a-zA-Z0-9._%+-]*@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
    is_valid = re.fullmatch(pattern, email) is not None
    
    ### END CODE HERE
    return is_valid


#### 1.1.3: Extract Email Domains (2 points)

This function extracts unique domain names (the part after @) from all email addresses found in a text, returning them sorted alphabetically.

Arguments:
- `text (str)`: The input text containing email addresses.

Returns:
- `domains (List[str])`: A sorted list of unique domain names (e.g., "example.com").

Examples:
```python
>>> text = "Email us at support@example.com or info@example.com and admin@test.org"
>>> extract_domains(text)
['example.com', 'test.org']
>>> text = "No emails here"
>>> extract_domains(text)
[]
```

In [None]:
def extract_domains(text: str) -> List[str]:
    ### WRITE YOUR CODE BELOW
    
    emails = extract_emails(text)
    # keep the part after '@', de-duplicate with a set, then sort alphabetically
    domains = sorted({email.split('@')[1] for email in emails})
    
    ### END CODE HERE
    return domains


### 1.2: Phone Number Extractor (5 points)

You're working on a contact management system that needs to extract and standardize phone numbers from various text sources. Phone numbers can appear in different formats, and you need to identify and process them consistently.

#### 1.2.1: Find Simple Phone Numbers (2 points)

This function finds all phone numbers in the format XXX-XXXX (3 digits, hyphen, 4 digits) in a text.

Arguments:
- `text (str)`: The input text to search.

Returns:
- `phone_numbers (List[str])`: A list of all phone numbers found in XXX-XXXX format.

Examples:
```python
>>> text = "Call me at 555-1234 or 555-5678 for more info."
>>> find_simple_phones(text)
['555-1234', '555-5678']
>>> text = "My number is 12345678"
>>> find_simple_phones(text)
[]
```

In [None]:
def find_simple_phones(text: str) -> List[str]:
    ### WRITE YOUR CODE BELOW
    # XXX-XXXX: 3 digits, hyphen, 4 digits
    pattern = r'\d{3}-\d{4}'
    phone_numbers = re.findall(pattern, text)
    ### END CODE HERE
    return phone_numbers


#### 1.2.2: Find US Phone Numbers (3 points)

This function finds phone numbers in the format (XXX) XXX-XXXX, where X is a digit. This represents the common US phone number format with area code in parentheses.

Arguments:
- `text (str)`: The input text to search.

Returns:
- `phone_numbers (List[str])`: A list of all US-format phone numbers found.

Examples:
```python
>>> text = "Contact: (555) 123-4567 or (800) 555-0199"
>>> find_us_phones(text)
['(555) 123-4567', '(800) 555-0199']
>>> text = "Call 555-1234"
>>> find_us_phones(text)
[]
```

In [None]:
def find_us_phones(text: str) -> List[str]:
    ### WRITE YOUR CODE BELOW
    # (XXX) XXX-XXXX - escape parentheses
    pattern = r'\(\d{3}\) \d{3}-\d{4}'
    phone_numbers = re.findall(pattern, text)
    ### END CODE HERE
    return phone_numbers


### 1.3: Date and Time Parser (6 points)

You're building a log analyzer that needs to extract dates and timestamps from various log files. Different systems use different date formats, so you need flexible regex patterns to handle them.

#### 1.3.1: Extract Dates in MM/DD/YYYY Format (3 points)

This function finds all dates in the format MM/DD/YYYY where MM and DD can be 1 or 2 digits, and YYYY is exactly 4 digits.

Arguments:
- `text (str)`: The input text containing dates.

Returns:
- `dates (List[str])`: A list of all dates found in MM/DD/YYYY format.

Examples:
```python
>>> text = "Important dates: 12/25/2023 and 1/1/2024"
>>> extract_dates_mdy(text)
['12/25/2023', '1/1/2024']
>>> text = "No dates here"
>>> extract_dates_mdy(text)
[]
```

In [None]:
def extract_dates_mdy(text: str) -> List[str]:
    ### WRITE YOUR CODE BELOW
    # MM/DD/YYYY: MM and DD --> 1-2 digits, YYYY --> 4 digits
    pattern = r'\d{1,2}/\d{1,2}/\d{4}'
    dates = re.findall(pattern, text)
    ### END CODE HERE
    return dates


#### 1.3.2: Extract Timestamps (3 points)

This function extracts timestamps in the format HH:MM:SS where each component is exactly 2 digits.

Arguments:
- `text (str)`: The input text containing timestamps.

Returns:
- `timestamps (List[str])`: A list of all timestamps found in HH:MM:SS format.

Examples:
```python
>>> text = "Events at 14:30:00 and 09:15:45"
>>> extract_timestamps(text)
['14:30:00', '09:15:45']
>>> text = "Time is 9:5:3"
>>> extract_timestamps(text)
[]
```

In [None]:
def extract_timestamps(text: str) -> List[str]:
    ### WRITE YOUR CODE BELOW
    # HH:MM:SS --> each component: 2 digits
    pattern = r'\d{2}:\d{2}:\d{2}'
    timestamps = re.findall(pattern, text)
    ### END CODE HERE
    return timestamps


### 1.4: Social Media Text Processor (8 points)

You're developing a social media analytics tool that needs to extract and analyze various elements from posts and tweets, including hashtags, mentions, and URLs.

#### 1.4.1: Extract Hashtags (2 points)

This function extracts all hashtags from social media text. A valid hashtag must:
- Start with a # symbol that is NOT immediately preceded by a word character (letter, digit, or underscore)
- Be followed by one or more word characters (letters, digits, or underscores)

This means:
- `#Python` is a valid hashtag
- `#AI_ML` is a valid hashtag (underscores are allowed)
- `#123` is a valid hashtag (digits are allowed)
- `test#nothashtag` is NOT a valid hashtag (# is preceded by a word character)
- `#` alone is NOT a valid hashtag (no word characters after #)

Arguments:
- `text (str)`: The social media post text.

Returns:
- `hashtags (List[str])`: A list of all hashtags found (including the # symbol).

Examples:
```python
>>> text = "Check out #Python and #DataScience for #AI_ML projects!"
>>> extract_hashtags(text)
['#Python', '#DataScience', '#AI_ML']
>>> text = "No hashtags here"
>>> extract_hashtags(text)
[]
```
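
The "not immediately preceded by a word character" rule corresponds to a negative lookbehind, `(?<!\w)`. As a rough illustration (the sample text is invented), compare a pattern without the lookbehind to one with it:

```python
import re

text = "test#nothashtag #ok"

# Without a lookbehind, the '#' inside "test#nothashtag" is also picked up
naive = re.findall(r"#\w+", text)

# (?<!\w) rejects any '#' that directly follows a letter, digit, or underscore
strict = re.findall(r"(?<!\w)#\w+", text)

print(naive)   # ['#nothashtag', '#ok']
print(strict)  # ['#ok']
```

The same idea applies to mentions in the next exercise, with `@` in place of `#`.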

In [None]:
def extract_hashtags(text: str) -> List[str]:
    ### WRITE YOUR CODE BELOW
    # '#' is not preceded by \w
    pattern = r'(?<!\w)#\w+'
    hashtags = re.findall(pattern, text)
    ### END CODE HERE
    return hashtags


#### 1.4.2: Extract Mentions (2 points)

This function extracts all user mentions from social media text. A valid mention must:
- Start with an @ symbol that is NOT immediately preceded by a word character (letter, digit, or underscore)
- Be followed by one or more word characters (letters, digits, or underscores)

This means:
- `@user123` is a valid mention
- `@admin` is a valid mention
- `@user_name` is a valid mention (underscores are allowed)
- `info@example.com` is NOT a valid mention (@ is preceded by a word character, so it's recognized as an email)
- `@` alone is NOT a valid mention (no word characters after @)

Arguments:
- `text (str)`: The social media post text.

Returns:
- `mentions (List[str])`: A list of all mentions found (including the @ symbol).

Examples:
```python
>>> text = "Thanks @user123 and @admin for the help! You can also reach out to info@example.com"
>>> extract_mentions(text)
['@user123', '@admin']
>>> text = "No mentions in this post"
>>> extract_mentions(text)
[]
```

In [None]:
def extract_mentions(text: str) -> List[str]:
    ### WRITE YOUR CODE BELOW
    # @ not preceded by word char  +    1+ word characters
    pattern = r'(?<!\w)@\w+'
    mentions = re.findall(pattern, text)
    ### END CODE HERE
    return mentions


#### 1.4.3: Extract URLs (4 points)

This function extracts all URLs from text. A URL starts with http:// or https:// followed by one or more non-whitespace characters.

Arguments:
- `text (str)`: The text containing URLs.

Returns:
- `urls (List[str])`: A list of all URLs found.

Examples:
```python
>>> text = "Visit https://example.com and http://test.org for more info"
>>> extract_urls(text)
['https://example.com', 'http://test.org']
>>> text = "No links here"
>>> extract_urls(text)
[]
```

In [None]:
def extract_urls(text: str) -> List[str]:
    ### WRITE YOUR CODE BELOW
    # http:// or https://      +      1+ non-whitespace chars
    pattern = r'https?://\S+'
    urls = re.findall(pattern, text)
    ### END CODE HERE
    return urls


### 1.5: Advanced Pattern Matching (7 points)

These exercises involve more complex regex patterns using groups, alternation, and advanced pattern matching techniques.

#### 1.5.1: Extract Price Information (3 points)

This function extracts monetary amounts from text. Prices appear in one of two formats: \$X or \$X.XX. The function should handle optional cents (exactly two digits after the decimal point) and return all unique prices found, sorted in ascending order by their numeric value.

Arguments:
- `text (str)`: The input text containing prices.

Returns:
- `prices (List[str])`: A sorted list of unique price strings found (keeping their original format).

Examples:
```python
>>> text = "Items cost $19.99, $5, and $125.50 respectively"
>>> extract_prices(text)
['$5', '$19.99', '$125.50']
>>> text = "The book is $29 and the pen is $2.50"
>>> extract_prices(text)
['$2.50', '$29']
>>> text = "No prices here"
>>> extract_prices(text)
[]
```
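
Sorting "by numeric value" matters because price strings sort differently as text than as numbers. A minimal sketch of the difference, using the prices from the first example above:

```python
prices = ["$19.99", "$5", "$125.50"]

# Plain string sort compares character by character, so "$1..." comes before "$5"
lexicographic = sorted(prices)

# Converting to float (after dropping the leading "$") sorts by actual amount
numeric = sorted(prices, key=lambda p: float(p.lstrip("$")))

print(lexicographic)  # ['$125.50', '$19.99', '$5']
print(numeric)        # ['$5', '$19.99', '$125.50']
```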

In [None]:
def extract_prices(text: str) -> List[str]:
    ### WRITE YOUR CODE BELOW
    # $ + digits + (optional) .XX
    pattern = r'\$\d+(?:\.\d{2})?'
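    prices = set(re.findall(pattern, text))  # unique price strings
    # sort by numeric value; break ties (e.g. $5 vs $5.00) by the string itself
    # so the result does not depend on set iteration order
    prices = sorted(prices, key=lambda p: (float(p[1:]), p))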
    ### END CODE HERE
    return prices


#### 1.5.2: Validate Password Strength (4 points)

This function validates whether a password meets security requirements using regex. A strong password must:
1. Be at least 8 characters long
2. Contain at least one uppercase letter
3. Contain at least one lowercase letter
4. Contain at least one digit
5. Contain at least one special character from: !@#$%^&*()

Arguments:
- `password (str)`: The password string to validate.

Returns:
- `is_valid (bool)`: True if the password meets all requirements, False otherwise.

Examples:
```python
>>> validate_password("Abc123!@")
True
>>> validate_password("weak")
False
>>> validate_password("NoDigits!")
False
>>> validate_password("nouppercas3!")
False
```
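
One possible way to express all five requirements in a single pattern is to chain lookaheads, each of which checks one condition without consuming characters. This is only a sketch of that technique; the names `STRONG_PASSWORD` and `is_strong` are hypothetical and not part of the assignment, and separate `re.search` checks are an equally valid approach.

```python
import re

# Hypothetical single-pattern version: each (?=...) tests one requirement
STRONG_PASSWORD = re.compile(
    r"(?=.*[A-Z])"          # at least one uppercase letter
    r"(?=.*[a-z])"          # at least one lowercase letter
    r"(?=.*\d)"             # at least one digit
    r"(?=.*[!@#$%^&*()])"   # at least one of the listed special characters
    r".{8,}",               # at least 8 characters in total
    re.DOTALL,              # let '.' match newlines too
)

def is_strong(password: str) -> bool:
    """Return True if the password satisfies every lookahead and the length check."""
    return STRONG_PASSWORD.fullmatch(password) is not None

print(is_strong("Abc123!@"))      # True
print(is_strong("NoDigits!"))     # False
```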

In [None]:
def validate_password(password: str) -> bool:
    ### WRITE YOUR CODE BELOW
    if len(password) < 8:
        return False
    if not re.search(r'[A-Z]', password):  # 1+ uppercase
        return False
    if not re.search(r'[a-z]', password):  # 1+ lowercase
        return False
    if not re.search(r'\d', password):  # 1+ digit
        return False
    if not re.search(r'[!@#$%^&*()]', password):  # 1+ special char
        return False
    is_valid = True
    ### END CODE HERE
    return is_valid


## 2. Tokenization (16 points)

This section tests your understanding of tokenization. You'll build a simple Byte-Pair Encoding (BPE) tokenizer from scratch, learning how modern language models break text into tokens. BPE is used by models like GPT to handle rare words and reduce vocabulary size by learning common character sequences.

### Background

BPE works by iteratively merging the most frequent pairs of characters (or character sequences) in a corpus. Starting with individual characters, it builds up a vocabulary of subwords that efficiently represent the text.

In this implementation, we use a special start-of-word marker `_` (underscore) to distinguish word boundaries. Following the convention used by GPT-2 (which uses `Ġ`) and SentencePiece (which uses `▁`), the marker is **attached to the first character** of each word during the initial tokenization. For example, the text "hello world" would be tokenized as `['_h', 'e', 'l', 'l', 'o', ' ', '_w', 'o', 'r', 'l', 'd']`. This way, tokens like `_t` (start of "the", "to", etc.) can be learned directly through BPE merges.

You can read more about BPE tokenization on [Hugging Face](https://huggingface.co/learn/llm-course/en/chapter6/5).
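
To make the merge step concrete, here is a rough sketch of a single BPE iteration on a toy token list. The token list for "low low lot" is written out by hand following the marker convention described above, and `collections.Counter` is used purely for brevity; the exercises below build the same steps as separate functions.

```python
from collections import Counter

# Character tokens for "low low lot", with "_" attached to each word's first character
tokens = ["_l", "o", "w", " ", "_l", "o", "w", " ", "_l", "o", "t"]

# Step 1: count every adjacent pair of tokens
pair_counts = Counter(zip(tokens, tokens[1:]))

# Step 2: pick the most frequent pair
best_pair, best_count = pair_counts.most_common(1)[0]

# Step 3: replace each occurrence of that pair with its concatenation
merged = []
i = 0
while i < len(tokens):
    if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == best_pair:
        merged.append(tokens[i] + tokens[i + 1])
        i += 2
    else:
        merged.append(tokens[i])
        i += 1

print(best_pair, best_count)  # ('_l', 'o') 3
print(merged)                 # ['_lo', 'w', ' ', '_lo', 'w', ' ', '_lo', 't']
```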

In [None]:
### DO NOT MODIFY THIS CELL ###
CORPUS = """To be or not to be that is the question.
Whether tis nobler in the mind to suffer the slings and arrows of outrageous fortune.
Or to take arms against a sea of troubles and by opposing end them.
To die to sleep no more and by a sleep to say we end the heartache.
And the thousand natural shocks that flesh is heir to."""

print("Training corpus:")
print(CORPUS)
print(f"\nCorpus size: {len(CORPUS)} characters")

### 2.1: Tokenize Text into Characters with Word Markers (2 points)

Before we can perform BPE merges, we need to break our text into individual character tokens. This function should split a text string into a list of individual characters, with the start-of-word marker `_` **attached to the first character** of each word. Words are separated by spaces. This follows the standard convention used by tokenizers like GPT-2 (which uses `Ġ`) and SentencePiece (which uses `▁`).

Arguments:
- `text (str)`: The input text to tokenize.

Returns:
- `tokens (List[str])`: A list of character tokens where the first character of each word has `_` prefixed to it.

Examples:
```python
>>> tokenize_characters("hello")
['_h', 'e', 'l', 'l', 'o']
>>> tokenize_characters("hi there")
['_h', 'i', ' ', '_t', 'h', 'e', 'r', 'e']
>>> tokenize_characters("a b c")
['_a', ' ', '_b', ' ', '_c']
```

In [None]:
def tokenize_characters(text: str) -> List[str]:
    ### WRITE YOUR CODE BELOW
    tokens = []
    at_word_start = True
    
    for char in text:
        if char == ' ':
            tokens.append(' ')
            at_word_start = True
        else:
            if at_word_start:
                tokens.append('_' + char)
                at_word_start = False
            else:
                tokens.append(char)
    ### END CODE HERE
    return tokens


### 2.2: Count Consecutive Pairs (3 points)

A key part of BPE is identifying which pairs of tokens appear most frequently. This function takes a list of tokens and counts how many times each consecutive pair appears.

Arguments:
- `tokens (List[str])`: A list of tokens (can be characters or merged tokens).

Returns:
- `pair_counts (Dict[Tuple[str, str], int])`: A dictionary mapping each pair (as a tuple) to its frequency count.

Examples:
```python
>>> count_pairs(['_h', 'e', 'l', 'l', 'o'])
{('_h', 'e'): 1, ('e', 'l'): 1, ('l', 'l'): 1, ('l', 'o'): 1}
>>> count_pairs(['_a', 'b', ' ', '_a', 'b'])
{('_a', 'b'): 2, ('b', ' '): 1, (' ', '_a'): 1}
```

In [None]:
def count_pairs(tokens: List[str]) -> Dict[Tuple[str, str], int]:
    ### WRITE YOUR CODE BELOW
    pair_counts = {}
    for i in range(len(tokens) - 1):
        pair = (tokens[i], tokens[i + 1])
        pair_counts[pair] = pair_counts.get(pair, 0) + 1
    ### END CODE HERE
    return pair_counts


### 2.3: Find Most Frequent Pair (3 points)

After counting pairs, we need to identify which pair occurs most frequently. This is the pair we'll merge in BPE. If there are no pairs (empty or single-token list), return None.

Arguments:
- `pair_counts (Dict[Tuple[str, str], int])`: A dictionary of pair counts from the previous function.

Returns:
- `most_frequent (Optional[Tuple[str, str]])`: The pair with the highest count, or None if no pairs exist.

Examples:
```python
>>> pairs = {('_h', 'e'): 1, ('e', 'l'): 1, ('l', 'l'): 2, ('l', 'o'): 1}
>>> find_most_frequent_pair(pairs)
('l', 'l')
>>> find_most_frequent_pair({})
None
```

In [None]:
def find_most_frequent_pair(pair_counts: Dict[Tuple[str, str], int]) -> Optional[Tuple[str, str]]:
    ### WRITE YOUR CODE BELOW
    if not pair_counts:
        return None
    most_frequent = max(pair_counts.items(), key=lambda x: x[1])[0]
    ### END CODE HERE
    return most_frequent


### 2.4: Merge Token Pair (3 points)

Once we've identified the most frequent pair, we need to merge all occurrences of it in our token list. This function takes a list of tokens and a target pair, and replaces every consecutive occurrence of that pair with a single merged token (by concatenating the pair elements).

Arguments:
- `tokens (List[str])`: The current list of tokens.
- `pair (Tuple[str, str])`: The pair to merge (e.g., ('e', 'l')).

Returns:
- `merged_tokens (List[str])`: A new list with all occurrences of the pair merged into single tokens.

Examples:
```python
>>> merge_pair(['_h', 'e', 'l', 'l', 'o'], ('l', 'l'))
['_h', 'e', 'll', 'o']
>>> merge_pair(['_h', 'e', 'l', 'l', 'o'], ('_h', 'e'))
['_he', 'l', 'l', 'o']
>>> merge_pair(['_h', 'e', 'l', 'l', 'o'], ('x', 'y'))
['_h', 'e', 'l', 'l', 'o']
```

In [None]:
def merge_pair(tokens: List[str], pair: Tuple[str, str]) -> List[str]:
    ### WRITE YOUR CODE BELOW
    merged_tokens = []
    i = 0
    while i < len(tokens):
        if i < len(tokens) - 1 and tokens[i] == pair[0] and tokens[i + 1] == pair[1]:
            merged_tokens.append(tokens[i] + tokens[i + 1])  # merge pair
            i += 2
        else:
            merged_tokens.append(tokens[i])
            i += 1
    ### END CODE HERE
    return merged_tokens


### 2.5: Build BPE Vocabulary (5 points)

Now we'll put it all together! This function implements the complete BPE algorithm by iteratively finding and merging the most frequent pairs. It should:
1. Start with character-level tokens (with word markers using `tokenize_characters`)
2. Repeat for `num_merges` iterations:
   - Count all pairs
   - Find the most frequent pair
   - Merge that pair in the token list
3. Return the final vocabulary (unique tokens after all merges)

Use the helper functions you implemented above.

Arguments:
- `text (str)`: The input text corpus to train on.
- `num_merges (int)`: The number of merge operations to perform.

Returns:
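- `vocab (List[str])`: A list of unique tokens in the final vocabulary, ordered by their frequency in the final token sequence (most frequent first). As the examples below show, base tokens that were consumed by a merge (such as `_h` and `e`) remain in the vocabulary and appear last.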

Examples:
```python
>>> build_bpe_vocab("hello", num_merges=1)
['l', '_he', 'o', '_h', 'e']  # ('_h', 'e') merged into '_he'
>>> build_bpe_vocab("hi hi", num_merges=1)
['_hi', ' ', '_h', 'i']  # ('_h', 'i') merged into '_hi'
```

In [None]:
def build_bpe_vocab(text: str, num_merges: int) -> List[str]:
    ### WRITE YOUR CODE BELOW
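    tokens = tokenize_characters(text)
    vocab_set = set(tokens)  # base character tokens always stay in the vocabulary
    
    for _ in range(num_merges):
        pair_counts = count_pairs(tokens)
        most_frequent_pair = find_most_frequent_pair(pair_counts)
        if most_frequent_pair is None:
            break
        tokens = merge_pair(tokens, most_frequent_pair)
        vocab_set.add(most_frequent_pair[0] + most_frequent_pair[1])  # record new token
    
    # frequency of each token in the final sequence (0 if it was merged away)
    token_freq = {}
    for token in tokens:
        token_freq[token] = token_freq.get(token, 0) + 1
    
    # most frequent first, then alphabetical; matches the documented examples
    vocab = sorted(vocab_set, key=lambda x: (-token_freq.get(x, 0), x))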
    ### END CODE HERE
    return vocab


### 2.6: Test Your Tokenizer on the Corpus [ungraded]

Now let's test your BPE implementation on the Shakespeare corpus! Run the cell below to build a vocabulary with 10 merges and see what tokens are learned. The vocabulary will be ordered by token frequency (most common tokens first).

In [None]:
### DO NOT MODIFY THIS CELL ###
# Build BPE vocabulary with 10 merges
bpe_vocab = build_bpe_vocab(CORPUS, num_merges=10)

print(f"Vocabulary size: {len(bpe_vocab)}")
print("\nTop 20 most frequent tokens:")
print(bpe_vocab[:20])
print("\nLeast 20 frequent tokens:")
print(bpe_vocab[-20:])

# Show some interesting multi-character tokens
multi_char = [token for token in bpe_vocab if len(token) > 1]
print(f"\nNumber of multi-character tokens: {len(multi_char)}")
print(f"Most common multi-character tokens: {[t for t in bpe_vocab if len(t) > 1][:10]}")

## 3. Calculating Relative Frequencies (15 points)

This section tests your understanding of text frequency analysis, including computing word frequencies, relative frequencies, and comparing word usage across different texts. These concepts are fundamental to computational linguistics and NLP, helping us understand what makes texts distinctive and how to compare language use across different corpora.

**Note on Tokenization:** For this section, we'll use a very simple tokenization approach (lowercase, split on whitespace, keep only alphabetic words). The tokenization is already handled in the provided code. **Do not modify the tokenization** - focus on implementing the frequency analysis functions.

### 3.1: Baby Corpus Setup

We'll work with a small "baby corpus" to understand frequency calculations. The corpus contains three short texts about different topics: cooking, technology, and nature. Run the cell below to load the corpus.

In [None]:
### DO NOT MODIFY THIS CELL ###
# Baby Corpus: Three short texts about different topics

COOKING_TEXT = """
The chef prepared a delicious meal in the kitchen. She chopped vegetables and
added fresh herbs to the soup. The aroma filled the kitchen as the soup simmered
on the stove. Fresh bread was baking in the oven while the chef stirred the pot.
The meal was ready and the chef served the delicious soup with fresh bread.
"""

TECH_TEXT = """
The programmer wrote code on the computer all day. She debugged the software and
fixed several bugs in the code. The computer processed the data quickly and the
software ran smoothly. The programmer tested the code again and the software
worked perfectly on the computer. She saved the code and shut down the computer.
"""

NATURE_TEXT = """
The birds sang in the forest as the sun rose over the mountains. Flowers bloomed
in the meadow and butterflies danced in the breeze. The river flowed through the
forest and deer drank from the clear water. The sun warmed the forest and the
birds continued their beautiful song in the trees.
"""

# Simple tokenizer: lowercase and split on whitespace, keeping only alphabetic tokens
def simple_tokenize(text: str) -> List[str]:
    """Tokenize text into lowercase alphabetic words."""
    return [word.lower() for word in text.split() if word.isalpha()]

# Pre-tokenized texts
cooking_tokens = simple_tokenize(COOKING_TEXT)
tech_tokens = simple_tokenize(TECH_TEXT)
nature_tokens = simple_tokenize(NATURE_TEXT)

print(f"Cooking text: {len(cooking_tokens)} tokens")
print(f"Tech text: {len(tech_tokens)} tokens")
print(f"Nature text: {len(nature_tokens)} tokens")

#### 3.1.1: Count Word Frequencies (2 points)

This function counts how many times each word appears in a list of tokens and returns a dictionary mapping words to their counts (raw frequencies).

Arguments:
- `tokens (List[str])`: A list of word tokens.

Returns:
- `freq_dict (Dict[str, int])`: A dictionary where keys are words and values are their counts.

Examples:
```python
>>> count_frequencies(["the", "cat", "sat", "the"])
{'the': 2, 'cat': 1, 'sat': 1}
>>> count_frequencies(["hello", "hello", "world"])
{'hello': 2, 'world': 1}
```

In [None]:
def count_frequencies(tokens: List[str]) -> Dict[str, int]:
    ### WRITE YOUR CODE BELOW
    freq_dict = {}
    for token in tokens:
        freq_dict[token] = freq_dict.get(token, 0) + 1
    ### END CODE HERE
    return freq_dict


#### 3.1.2: Compute Relative Frequencies (2 points)

Raw counts don't allow fair comparison between texts of different lengths. **Relative frequency** (or normalized frequency) tells us what proportion of the text is made up of each word.

**Relative Frequency** = (Word Count) / (Total Words)

This function takes a list of tokens, computes their frequencies (by calling `count_frequencies`), and then converts those counts to relative frequencies.

Arguments:
- `tokens (List[str])`: A list of word tokens.

Returns:
- `rel_freq_dict (Dict[str, float])`: A dictionary where keys are words and values are their relative frequencies (between 0 and 1).

Examples:
```python
>>> compute_relative_frequencies(['the', 'cat', 'sat', 'on', 'the', 'mat'])
{'the': 0.3333333333333333, 'cat': 0.16666666666666666, 'sat': 0.16666666666666666, 'on': 0.16666666666666666, 'mat': 0.16666666666666666}
>>> compute_relative_frequencies(['hello', 'world'])
{'hello': 0.5, 'world': 0.5}
```
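
A quick way to sanity-check a relative-frequency table is to confirm that its values sum to 1. The sketch below works through the first example with `collections.Counter`, which is used here only as a shortcut for illustration; the exercise itself asks you to build on `count_frequencies`.

```python
from collections import Counter

tokens = ["the", "cat", "sat", "on", "the", "mat"]

counts = Counter(tokens)          # raw counts, e.g. 'the' -> 2
total = sum(counts.values())      # 6 tokens in total

# Relative frequency = word count / total words
rel_freq = {word: count / total for word, count in counts.items()}

print(rel_freq["the"])            # 0.3333333333333333
print(sum(rel_freq.values()))     # approximately 1.0
```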

In [None]:
def compute_relative_frequencies(tokens: List[str]) -> Dict[str, float]:
    ### WRITE YOUR CODE BELOW
    freq_dict = count_frequencies(tokens)
    total = len(tokens)
    rel_freq_dict = {}
    for word, count in freq_dict.items():
        rel_freq_dict[word] = count / total
    ### END CODE HERE
    return rel_freq_dict


#### 3.1.3: Get Top N Words (3 points)

This function returns the top N most frequent words from a frequency dictionary, sorted by frequency in descending order. If there are ties, sort alphabetically as a tiebreaker.

Arguments:
- `freq_dict (Dict[str, int])`: A dictionary mapping words to their counts.
- `n (int)`: The number of top words to return.

Returns:
- `top_words (List[Tuple[str, int]])`: A list of (word, count) tuples for the top N words, sorted by count (descending), then alphabetically for ties.

Examples:
```python
>>> get_top_n_words({'the': 5, 'cat': 3, 'sat': 3, 'on': 1}, 2)
[('the', 5), ('cat', 3)]
>>> get_top_n_words({'a': 2, 'b': 2, 'c': 2}, 2)
[('a', 2), ('b', 2)]
```

In [None]:
def get_top_n_words(freq_dict: Dict[str, int], n: int) -> List[Tuple[str, int]]:
    ### WRITE YOUR CODE BELOW
    # Sort by count (descending), then alphabetically for ties
    sorted_words = sorted(freq_dict.items(), key=lambda x: (-x[1], x[0]))
    top_words = sorted_words[:n]
    ### END CODE HERE
    return top_words


#### 3.1.4: Compare Word Frequency (2 points)

This function compares how a specific word is used across two texts by computing its relative frequency in each. This helps us understand which text uses the word more prominently.

Arguments:
- `word (str)`: The word to compare.
- `freq_dict1 (Dict[str, int])`: Frequency dictionary for the first text.
- `total1 (int)`: Total tokens in the first text.
- `freq_dict2 (Dict[str, int])`: Frequency dictionary for the second text.
- `total2 (int)`: Total tokens in the second text.

Returns:
- `comparison (Tuple[float, float])`: A tuple of (relative_freq_in_text1, relative_freq_in_text2). If a word doesn't appear in a text, its relative frequency should be 0.0.

Examples:
```python
>>> compare_word_frequency("the", freq_dict1={'the': 10}, total1=100, freq_dict2={'the': 5}, total2=50)
(0.1, 0.1)
>>> compare_word_frequency("cat", freq_dict1={'cat': 5}, total1=100, freq_dict2={'dog': 5}, total2=100)
(0.05, 0.0)
```

In [None]:
def compare_word_frequency(word: str, freq_dict1: Dict[str, int], total1: int,
                           freq_dict2: Dict[str, int], total2: int) -> Tuple[float, float]:
    ### WRITE YOUR CODE BELOW
    # relative frequency in text1
    count1 = freq_dict1.get(word, 0)
    rel_freq1 = count1 / total1 if total1 > 0 else 0.0
    
    # relative frequency in text2
    count2 = freq_dict2.get(word, 0)
    rel_freq2 = count2 / total2 if total2 > 0 else 0.0
    
    comparison = (rel_freq1, rel_freq2)
    ### END CODE HERE
    return comparison


#### 3.1.5: Calculate Frequency Ratio (3 points)

A **frequency ratio** tells us how many times more frequent a word is in one text compared to another. This is useful for identifying **keywords**: words that are characteristic of a particular text.

**Frequency Ratio** = (Relative Frequency in Target) / (Relative Frequency in Reference)

To avoid division by zero when a word doesn't appear in the reference text, use smoothing: if the reference relative frequency is 0, replace it with the small value `0.5 / reference_total` (as if the word had occurred half a time in the reference text).
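
The smoothing step can be traced by hand. The sketch below uses the same numbers as the `computer` example in this exercise; the variable names are illustrative and are not part of the required function.

```python
# 5 hits in 100 target tokens, 0 hits in 100 reference tokens.
target_count, target_total = 5, 100
reference_count, reference_total = 0, 100

target_rel = target_count / target_total            # 0.05
reference_rel = reference_count / reference_total   # 0.0, so a plain ratio would divide by zero
if reference_rel == 0:
    reference_rel = 0.5 / reference_total           # smoothed stand-in: 0.005

ratio = target_rel / reference_rel                  # approximately 10
print(round(ratio, 6))
```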

Arguments:
- `word (str)`: The word to analyze.
- `target_freq (Dict[str, int])`: Frequency dictionary for the target text.
- `target_total (int)`: Total tokens in the target text.
- `reference_freq (Dict[str, int])`: Frequency dictionary for the reference text.
- `reference_total (int)`: Total tokens in the reference text.

Returns:
- `ratio (float)`: The frequency ratio. Returns 0.0 if the word doesn't appear in the target text.

Examples:
```python
>>> calculate_frequency_ratio("whale", target_freq={'whale': 10}, target_total=100, reference_freq={'whale': 1}, reference_total=100)
10.0
>>> calculate_frequency_ratio("computer", target_freq={'computer': 5}, target_total=100, reference_freq={}, reference_total=100)
10.0  # Uses smoothed reference frequency
```

In [None]:
def calculate_frequency_ratio(word: str, target_freq: Dict[str, int], target_total: int,
                              reference_freq: Dict[str, int], reference_total: int) -> float:
    ### WRITE YOUR CODE BELOW
    # word not in target -> return 0.0
    if word not in target_freq:
        return 0.0
    
    target_rel_freq = target_freq[word] / target_total if target_total > 0 else 0.0
    reference_count = reference_freq.get(word, 0)
    reference_rel_freq = reference_count / reference_total if reference_total > 0 else 0.0
    
    # reference frequency is 0 -> use smoothing
    if reference_rel_freq == 0.0:
        reference_rel_freq = 0.5 / reference_total if reference_total > 0 else 0.5
    
    ratio = target_rel_freq / reference_rel_freq if reference_rel_freq > 0 else 0.0
    ### END CODE HERE
    return ratio


#### 3.1.6: Find Keywords (3 points)

Keywords are words that appear significantly more often in a target text compared to a reference text. This function finds all words in the target text that have a frequency ratio above a given threshold.

Arguments:
- `target_freq (Dict[str, int])`: Frequency dictionary for the target text.
- `target_total (int)`: Total tokens in the target text.
- `reference_freq (Dict[str, int])`: Frequency dictionary for the reference text.
- `reference_total (int)`: Total tokens in the reference text.
- `min_ratio (float)`: Minimum frequency ratio to consider a word as a keyword.

Returns:
- `keywords (List[Tuple[str, float]])`: A list of (word, ratio) tuples for words with ratio >= min_ratio, sorted by ratio in descending order.

Examples:
```python
>>> find_keywords(target_freq={'whale': 10, 'the': 5}, target_total=100, reference_freq={'whale': 1, 'the': 10}, reference_total=100, min_ratio=2.0)
[('whale', 10.0)]
>>> find_keywords(target_freq={'code': 8, 'and': 4}, target_total=100, reference_freq={'code': 2, 'and': 4}, reference_total=100, min_ratio=3.0)
[('code', 4.0)]
```

In [None]:
def find_keywords(target_freq: Dict[str, int], target_total: int,
                  reference_freq: Dict[str, int], reference_total: int,
                  min_ratio: float) -> List[Tuple[str, float]]:
    ### WRITE YOUR CODE BELOW
    keywords = []
    for word in target_freq:
        ratio = calculate_frequency_ratio(word, target_freq, target_total,
                                         reference_freq, reference_total)
        if ratio >= min_ratio:
            keywords.append((word, ratio))
    # descending sort by ratio
    keywords.sort(key=lambda x: -x[1])
    ### END CODE HERE
    return keywords


#### 3.1.7: Test Your Frequency Functions [ungraded]

Run the cell below to test your implementations on the baby corpus. This will show you the frequency analysis results for each text.

In [None]:
### DO NOT MODIFY THIS CELL ###
# Test your frequency functions on the baby corpus

# Count frequencies for each text
cooking_freq = count_frequencies(cooking_tokens)
tech_freq = count_frequencies(tech_tokens)
nature_freq = count_frequencies(nature_tokens)

print("=" * 60)
print("FREQUENCY ANALYSIS OF BABY CORPUS")
print("=" * 60)

# Show top 5 words from each text
print("\n--- Top 5 Words in Each Text ---")
print(f"Cooking: {get_top_n_words(cooking_freq, 5)}")
print(f"Tech: {get_top_n_words(tech_freq, 5)}")
print(f"Nature: {get_top_n_words(nature_freq, 5)}")

# Compare 'the' across texts
print("\n--- Comparing 'the' across texts ---")
cooking_the, tech_the = compare_word_frequency("the", cooking_freq, len(cooking_tokens),
                                                tech_freq, len(tech_tokens))
print(f"'the' in Cooking: {cooking_the:.4f} ({cooking_the*100:.2f}%)")
print(f"'the' in Tech: {tech_the:.4f} ({tech_the*100:.2f}%)")

# Find keywords in cooking text vs tech text
print("\n--- Keywords in Cooking (vs Tech) with ratio >= 2.0 ---")
cooking_keywords = find_keywords(cooking_freq, len(cooking_tokens),
                                  tech_freq, len(tech_tokens), 2.0)
for word, ratio in cooking_keywords[:5]:
    print(f"  {word}: {ratio:.2f}x more frequent")


## 4: Collocations and PMI Analysis (17 points)

**Collocations** are word pairs that appear together more frequently than we'd expect by chance. While frequency analysis (Section 3) looks at individual words, collocation analysis examines word associations.

**Pointwise Mutual Information (PMI)** measures the strength of association between two words by comparing their co-occurrence probability with their independent probabilities.

**PMI Formula**: PMI(x, y) = log₂(P(x,y) / (P(x) × P(y)))

- **High PMI**: Words appear together much more than expected → strong collocation
- **PMI ≈ 0**: Words appear together about as expected by chance
- **Negative PMI**: Words appear together less than expected

In this section, you'll implement functions to find and analyze collocations using PMI.
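
As a quick illustration of the three cases above, the sketch below plugs hypothetical probabilities into the PMI formula. The helper name `pmi` and the probability values are assumptions chosen so the arithmetic is exact; they are not part of the assignment.

```python
import math

def pmi(p_xy: float, p_x: float, p_y: float) -> float:
    """PMI(x, y) = log2(P(x,y) / (P(x) * P(y))) for already-computed probabilities."""
    return math.log2(p_xy / (p_x * p_y))

# With P(x) = P(y) = 1/8, chance co-occurrence would be 1/64.
high = pmi(1 / 16, 1 / 8, 1 / 8)      # 4x more often than chance
neutral = pmi(1 / 64, 1 / 8, 1 / 8)   # exactly as often as chance
low = pmi(1 / 256, 1 / 8, 1 / 8)      # 4x less often than chance

print(high, neutral, low)  # 2.0 0.0 -2.0
```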

### 4.1: PMI Collocation Functions

Implement the following functions to analyze collocations using PMI.

#### 4.1.1: Generate Bigrams (2 points)

A bigram is a sequence of two consecutive words. This function generates all bigrams from a list of tokens.

Arguments:
- `tokens (List[str])`: A list of word tokens.

Returns:
- `bigrams (List[Tuple[str, str]])`: A list of bigrams, where each bigram is a tuple of two consecutive words.

Examples:
```python
>>> generate_bigrams(["the", "cat", "sat", "on", "the", "mat"])
[('the', 'cat'), ('cat', 'sat'), ('sat', 'on'), ('on', 'the'), ('the', 'mat')]
>>> generate_bigrams(["hello", "world"])
[('hello', 'world')]
>>> generate_bigrams(["single"])
[]
```

In [None]:
def generate_bigrams(tokens: List[str]) -> List[Tuple[str, str]]:
    ### WRITE YOUR CODE BELOW
    bigrams = []
    # tuples of consecutive word pairs
    for i in range(len(tokens) - 1):
        bigrams.append((tokens[i], tokens[i + 1]))
    ### END CODE HERE
    return bigrams


#### 4.1.2: Count Bigram Frequencies (2 points)

This function counts how many times each bigram appears in a list of bigrams and returns a dictionary mapping bigrams to their counts.

Arguments:
- `bigrams (List[Tuple[str, str]])`: A list of bigrams.

Returns:
- `bigram_freq (Dict[Tuple[str, str], int])`: A dictionary where keys are bigrams (tuples) and values are their counts.

Examples:
```python
>>> count_bigram_frequencies([('the', 'cat'), ('the', 'dog'), ('the', 'cat')])
{('the', 'cat'): 2, ('the', 'dog'): 1}
>>> count_bigram_frequencies([('hello', 'world'), ('hello', 'world')])
{('hello', 'world'): 2}
```

In [None]:
def count_bigram_frequencies(bigrams: List[Tuple[str, str]]) -> Dict[Tuple[str, str], int]:
    ### WRITE YOUR CODE BELOW
    bigram_freq = {}
    for bigram in bigrams:
        bigram_freq[bigram] = bigram_freq.get(bigram, 0) + 1
    ### END CODE HERE
    return bigram_freq


#### 4.1.3: Calculate PMI Score (4 points)

This function calculates the PMI (Pointwise Mutual Information) score for a single bigram.

**PMI Formula**: PMI(x, y) = log₂(P(x,y) / (P(x) × P(y)))

Where:
- P(x,y) = bigram_count / total_bigrams (probability of bigram occurring)
- P(x) = word1_count / total_words (probability of first word)
- P(y) = word2_count / total_words (probability of second word)

Arguments:
- `bigram (Tuple[str, str])`: The bigram to analyze.
- `bigram_count (int)`: How many times this bigram appears.
- `word1_count (int)`: How many times the first word appears.
- `word2_count (int)`: How many times the second word appears.
- `total_bigrams (int)`: Total number of bigrams in the corpus.
- `total_words (int)`: Total number of words in the corpus.

Returns:
- `pmi (float)`: The PMI score for this bigram.

Examples:
```python
>>> calculate_pmi(('strong', 'coffee'), 10, 50, 100, 1000, 1001)
# P(strong,coffee) = 10/1000 = 0.01
# P(strong) = 50/1001 ≈ 0.04995, P(coffee) = 100/1001 ≈ 0.0999
# PMI = log2(0.01 / (0.04995 * 0.0999)) ≈ 1.003
1.0028...
```

**Note**: Use `math.log2()` for the logarithm. `math` is already imported at the top.

In [None]:
def calculate_pmi(bigram: Tuple[str, str], bigram_count: int, word1_count: int,
                  word2_count: int, total_bigrams: int, total_words: int) -> float:
    ### WRITE YOUR CODE BELOW
    # P(x,y) = bigram_count / total_bigrams
    p_xy = bigram_count / total_bigrams if total_bigrams > 0 else 0.0
    
    # P(x) = word1_count / total_words
    p_x = word1_count / total_words if total_words > 0 else 0.0
    
    # P(y) = word2_count / total_words
    p_y = word2_count / total_words if total_words > 0 else 0.0
    
    # PMI = log2(P(x,y) / (P(x) * P(y)))
    if p_xy == 0 or p_x * p_y == 0:
        pmi = 0.0  # PMI is undefined here; avoid log2(0) and division by zero
    else:
        pmi = math.log2(p_xy / (p_x * p_y))
    ### END CODE HERE
    return pmi


#### 4.1.4: Calculate PMI for All Bigrams (3 points)

This function calculates PMI scores for all bigrams in a corpus and returns them sorted by PMI score in descending order.

Arguments:
- `bigram_freq (Dict[Tuple[str, str], int])`: Dictionary of bigram frequencies.
- `word_freq (Dict[str, int])`: Dictionary of word (unigram) frequencies.
- `total_bigrams (int)`: Total number of bigrams.
- `total_words (int)`: Total number of words.

Returns:
- `pmi_scores (List[Tuple[Tuple[str, str], float, int]])`: A list of tuples, where each tuple contains (bigram, pmi_score, frequency), sorted by PMI score in descending order.

Example:
```python
>>> bigram_freq = {('strong', 'coffee'): 10, ('the', 'cat'): 5}
>>> word_freq = {'strong': 50, 'coffee': 100, 'the': 200, 'cat': 30}
>>> calculate_all_pmi(bigram_freq, word_freq, 1000, 1001)
# Returns list with (bigram, pmi, count) tuples sorted by PMI
```

In [None]:
def calculate_all_pmi(bigram_freq: Dict[Tuple[str, str], int], word_freq: Dict[str, int],
                     total_bigrams: int, total_words: int) -> List[Tuple[Tuple[str, str], float, int]]:
    ### WRITE YOUR CODE BELOW
    pmi_scores = []
    for bigram, bigram_count in bigram_freq.items():
        word1, word2 = bigram
        word1_count = word_freq.get(word1, 0)
        word2_count = word_freq.get(word2, 0)
        pmi = calculate_pmi(bigram, bigram_count, word1_count, word2_count, total_bigrams, total_words)
        pmi_scores.append((bigram, pmi, bigram_count))
    # sort by PMI score - descending
    pmi_scores.sort(key=lambda x: -x[1])
    ### END CODE HERE
    return pmi_scores


#### 4.1.5: Filter Collocations by Minimum Frequency (2 points)

PMI can give very high scores to rare bigrams that appear only once or twice. To find meaningful collocations, we need to filter by minimum frequency.

This function filters a list of PMI scores to only include bigrams that appear at least `min_freq` times.
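
The effect is easy to reproduce with hypothetical counts. In the sketch below, `pmi_from_counts` is an illustrative helper (not one of the graded functions) and the corpus size is made up. A pair of words that each occur once, and only together, scores far higher than a pair seen 50 times.

```python
import math

def pmi_from_counts(bigram_count, word1_count, word2_count, total_bigrams, total_words):
    """PMI from raw counts, using the same probability definitions as in 4.1.3."""
    p_xy = bigram_count / total_bigrams
    p_x = word1_count / total_words
    p_y = word2_count / total_words
    return math.log2(p_xy / (p_x * p_y))

# Hypothetical corpus: 10,001 words, hence 10,000 bigrams.
rare = pmi_from_counts(1, 1, 1, 10_000, 10_001)          # seen once, both words seen once
common = pmi_from_counts(50, 200, 100, 10_000, 10_001)   # seen 50 times

print(f"rare pair: {rare:.2f}, frequent pair: {common:.2f}")
```

The rare pair scores roughly 13.3 and the frequent pair roughly 4.6, even though a single co-occurrence is weak evidence. A frequency threshold removes such cases.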

Arguments:
- `pmi_scores (List[Tuple[Tuple[str, str], float, int]])`: List of (bigram, pmi, frequency) tuples.
- `min_freq (int)`: Minimum frequency threshold.

Returns:
- `filtered (List[Tuple[Tuple[str, str], float, int]])`: Filtered list containing only bigrams with frequency >= min_freq.

Examples:
```python
>>> pmi_scores = [(('white', 'whale'), 8.5, 50), (('rare', 'word'), 12.0, 1), (('the', 'cat'), 3.2, 100)]
>>> filter_by_frequency(pmi_scores, 10)
[(('white', 'whale'), 8.5, 50), (('the', 'cat'), 3.2, 100)]
```

In [None]:
def filter_by_frequency(pmi_scores: List[Tuple[Tuple[str, str], float, int]],
                       min_freq: int) -> List[Tuple[Tuple[str, str], float, int]]:
    ### WRITE YOUR CODE BELOW
    # keep only bigrams with frequency >= min_freq
    filtered = [item for item in pmi_scores if item[2] >= min_freq]
    ### END CODE HERE
    return filtered


#### 4.1.6: Find Collocations for a Specific Word (4 points)

This function finds all collocations containing a specific target word, showing which words commonly appear with it and whether they appear to the left or right.

Arguments:
- `target_word (str)`: The word to find collocations for.
- `pmi_scores (List[Tuple[Tuple[str, str], float, int]])`: List of (bigram, pmi, frequency) tuples.
- `min_freq (int)`: Minimum frequency threshold (default 1).

Returns:
- `collocations (List[Tuple[str, float, int, str]])`: A list of tuples (collocate_word, pmi, frequency, position), where position is 'left' or 'right' depending on which side of `target_word` the collocate appears, sorted by PMI in descending order.

Example:
```python
>>> pmi_scores = [(('white', 'whale'), 8.5, 50), (('whale', 'ship'), 7.2, 30), (('the', 'whale'), 2.1, 100)]
>>> find_word_collocations('whale', pmi_scores, min_freq=10)
[('white', 8.5, 50, 'left'), ('ship', 7.2, 30, 'right'), ('the', 2.1, 100, 'left')]
```

In [None]:
def find_word_collocations(target_word: str, pmi_scores: List[Tuple[Tuple[str, str], float, int]],
                          min_freq: int = 1) -> List[Tuple[str, float, int, str]]:
    ### WRITE YOUR CODE BELOW
    collocations = []
    for bigram, pmi, frequency in pmi_scores:
        # check if frequency meets minimum threshold
        if frequency < min_freq:
            continue
        
        word1, word2 = bigram
        # check if target_word is in the bigram
        if word1 == target_word:
            # target_word: left, collocate: right
            collocations.append((word2, pmi, frequency, 'right'))
        elif word2 == target_word:
            # target_word: right, collocate: left
            collocations.append((word1, pmi, frequency, 'left'))
    
    # sort by PMI - descending
    collocations.sort(key=lambda x: -x[1])
    ### END CODE HERE
    return collocations

## 5: Building a Naive Bayes Text Classifier from Scratch (21 points)

In this section, you will implement a **Naive Bayes text classifier** from scratch for news article classification. Naive Bayes is a probabilistic classifier based on Bayes' theorem with the "naive" assumption that features (words) are conditionally independent given the class.

**Bayes' Theorem:**
$$P(class|document) = \frac{P(document|class) \times P(class)}{P(document)}$$

Since $P(document)$ is the same for all classes, we can simplify to:
$$P(class|document) \propto P(document|class) \times P(class)$$

**Naive Assumption:** We assume words are independent given the class:
$$P(document|class) = \prod_{word \in document} P(word|class)$$

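To avoid numerical underflow when multiplying many small word probabilities, we use **log probabilities**, which turn the product into a sum. The constant $\log P(document)$ is dropped because it is the same for every class, so the right-hand side is a score for comparing classes rather than a normalized log probability: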
$$\log P(class|document) = \log P(class) + \sum_{word \in document} \log P(word|class)$$
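
The underflow problem can be demonstrated directly. The sketch below assumes a hypothetical 400-word document in which every word has probability `1e-4`; the numbers are illustrative only.

```python
import math

# Hypothetical document: 400 words, each with P(word|class) = 1e-4.
word_probs = [1e-4] * 400

product = 1.0
for p in word_probs:
    product *= p  # 1e-1600 is far below the smallest positive float, so this ends at 0.0

log_sum = sum(math.log(p) for p in word_probs)  # stays finite, about -3684

print(product)  # 0.0
print(round(log_sum, 2))
```

Once the product reaches 0.0, every class looks equally impossible, whereas the log sums can still be compared.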

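**Laplace Smoothing:** A word that never occurs with a class in training would get $P(word|class) = 0$, which zeroes the whole product (and makes the log undefined). To avoid this, we add a small constant (usually 1) to every word count: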
$$P(word|class) = \frac{count(word, class) + 1}{total\_words\_in\_class + vocabulary\_size}$$
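
A small sketch of the smoothed estimate, using hypothetical counts for a single class (the dictionary, vocabulary, and helper name `smoothed_prob` are illustrative assumptions). The unseen word gets a small non-zero probability, and the probabilities over the vocabulary still sum to 1.

```python
# Hypothetical counts for one class: 8 words in total, vocabulary of 3 distinct words.
counts_in_class = {"hello": 5, "world": 3}
total_words_in_class = 8
vocabulary = {"hello", "world", "there"}

def smoothed_prob(word: str) -> float:
    """Add-1 estimate: (count + 1) / (total words in class + vocabulary size)."""
    return (counts_in_class.get(word, 0) + 1) / (total_words_in_class + len(vocabulary))

probs = {word: smoothed_prob(word) for word in sorted(vocabulary)}
total_prob = sum(probs.values())

print(probs)       # 'there' gets 1/11 instead of 0
print(total_prob)  # 1 up to floating-point rounding
```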

You will build a classifier to categorize news articles into categories like: **tech**, **sports**, **business**, **health**, and **politics**.

### 5.1: Load and Explore the Dataset

First, let's load the news classification dataset and explore its structure. The dataset contains news article text snippets and their corresponding categories.

In [None]:
### DO NOT MODIFY THIS CELL ###

def load_news_dataset():
    df = pd.read_csv("https://huggingface.co/datasets/okite97/news-data/raw/main/train.csv")
    df = df.rename(columns={'Title': 'title', 'Excerpt': 'text', 'Category': 'category'})
    # stratified sampling to work with smaller dataset
    n_samples = 500
    # Sample proportionally from each category
    sampled_indices = df.groupby('category').apply(
        lambda x: x.sample(n=int(n_samples * len(x) / len(df)), random_state=42).index,
        include_groups=False
    )
    # flatten the indices and select rows
    all_indices = [idx for indices in sampled_indices for idx in indices]
    df = df.loc[all_indices]
    df = df.sample(frac=1, random_state=42)
    return df.reset_index(drop=True)


news_df = load_news_dataset()
print("Dataset Shape:", news_df.shape)
print("\nColumn Names:", news_df.columns.tolist())
print("\nCategory Distribution:")
print(news_df['category'].value_counts())
print("\nSample articles:")
print(news_df[['title', 'category']].head(2))

Below we perform a crucial step in machine learning: splitting our dataset into training and test sets using an 80/20 ratio.

**Why is this necessary?** To properly evaluate our classifier, we need to test it on data it has never seen before. If we trained and tested on the same data, we couldn't tell if the model truly learned patterns or just memorized the training examples. The test set simulates real-world usage where the model encounters new, unseen text.

**The approach:** We use a simple split where the first 80% of documents become the training set (used to learn word probabilities and patterns) and the remaining 20% become the test set (used to evaluate performance). The texts and labels are extracted into separate lists for convenience in later processing.

In [None]:
### DO NOT MODIFY THIS CELL ###
split_index = int(0.8 * len(news_df))
train_df = news_df[:split_index].reset_index(drop=True)
test_df = news_df[split_index:].reset_index(drop=True)

train_texts = train_df['text'].tolist()
train_labels = train_df['category'].tolist()
test_texts = test_df['text'].tolist()
test_labels = test_df['category'].tolist()

print(f"\nTraining set size: {len(train_texts)}")
print(f"Testing set size: {len(test_texts)}")

### 5.2: Preprocess Text (2 points)

Before building our classifier, we need to tokenize the text (convert to lowercase and split into words).

#### 5.2.1: Implement Text Preprocessing (2 points)

Implement a function that takes raw text and returns a list of lowercase word tokens. Convert the text to lowercase, replace all non-alphabetic characters (except spaces) with empty strings, split on whitespace, and filter out empty strings.

Arguments:
- `text (str)`: The raw text to preprocess.

Returns:
- `tokens (List[str])`: A list of lowercase word tokens.

Examples:
```python
>>> preprocess_text("Hello World! This is NLP.")
['hello', 'world', 'this', 'is', 'nlp']
>>> preprocess_text("AI and ML are cool!")
['ai', 'and', 'ml', 'are', 'cool']
```

In [None]:
def preprocess_text(text: str) -> List[str]:
    ### WRITE YOUR CODE BELOW

    text = text.lower()
    # all non-alphabetic characters except spaces -> empty strings
    text = re.sub(r'[^a-z ]', '', text)
    
    tokens = [word for word in text.split() if word]
    ### END CODE HERE
    return tokens


### 5.3: Train the Naive Bayes Classifier (11 points)

Now we'll implement the training phase of our Naive Bayes classifier. Training involves computing:

1. **Prior probabilities**: $P(class)$ - the probability of each class based on training data
2. **Likelihood probabilities**: $P(word|class)$ - the probability of each word given a class
3. **Vocabulary**: the set of all unique words in the training data

#### 5.3.1: Compute Prior Probabilities (3 points)

Implement a function that computes the prior probability for each class, which represents the fraction of training documents belonging to each category.

**The approach:** For each unique class label in the training data, count how many documents belong to that class, then divide by the total number of training documents. This gives us $P(class)$, which represents our baseline expectation of seeing each category before looking at the document's content.

Arguments:
- `labels (List[str])`: List of class labels from training data.

Returns:
- `priors (Dict[str, float])`: Dictionary mapping each class to its prior probability.

Examples:
```python
>>> labels = ["tech", "tech", "sports", "sports", "tech"]
>>> compute_priors(labels)
{'tech': 0.6, 'sports': 0.4}
```

In [None]:
def compute_priors(labels: List[str]) -> Dict[str, float]:
    ### WRITE YOUR CODE BELOW
    priors = {}
    total = len(labels)
    # documents per class
    for label in labels:
        priors[label] = priors.get(label, 0) + 1
    
    # counts --> probabilities conversion
    for label in priors:
        priors[label] = priors[label] / total
    ### END CODE HERE
    return priors


#### 5.3.2: Compute Word Counts per Class (3 points)

Implement a function that builds the word frequency statistics needed for computing likelihood probabilities in Naive Bayes classification.

**The approach:** This function processes all training documents to create a comprehensive statistical profile for each class. It tracks three key pieces of information:
1. How many times each word appears in documents of each class (word counts per class)
2. The total number of words in all documents of each class (for normalization)
3. The complete vocabulary of unique words across all training documents

Arguments:
- `texts (List[str])`: List of text documents.
- `labels (List[str])`: Corresponding list of class labels.

Returns:
- `word_counts (Dict[str, Dict[str, int]])`: Nested dictionary where `word_counts[class][word]` gives the count of `word` in documents of `class`.
- `class_total_words (Dict[str, int])`: Dictionary mapping each class to its total word count.
- `vocabulary (set)`: Set of all unique words across all documents.

Examples:
```python
>>> texts = ["hello world", "hello there"]
>>> labels = ["a", "b"]
>>> word_counts, class_totals, vocab = compute_word_counts(texts, labels)
>>> word_counts['a']['hello']
1
>>> class_totals['a']
2
>>> 'world' in vocab
True
```

In [None]:
def compute_word_counts(texts: List[str], labels: List[str]) -> Tuple[Dict[str, Dict[str, int]], Dict[str, int], set]:
    ### WRITE YOUR CODE BELOW
    word_counts = {}
    class_total_words = {}
    vocabulary = set()
    
    for text, label in zip(texts, labels):
        tokens = preprocess_text(text)
        
        if label not in word_counts:
            word_counts[label] = {}
        if label not in class_total_words:
            class_total_words[label] = 0
        
        for word in tokens:
            word_counts[label][word] = word_counts[label].get(word, 0) + 1
            class_total_words[label] += 1
            vocabulary.add(word)
    ### END CODE HERE
    return word_counts, class_total_words, vocabulary


#### 5.3.3: Compute Word Likelihood with Laplace Smoothing (5 points)

Implement a function that computes the log probability of a word given a class, using Laplace (add-1) smoothing to handle unseen words.

**Formula with Laplace Smoothing:**
$$P(word|class) = \frac{count(word, class) + 1}{total\_words\_in\_class + vocabulary\_size}$$

We return the **log probability** to avoid underflow issues when multiplying many small probabilities.

Arguments:
- `word (str)`: The word to compute probability for.
- `class_label (str)`: The class to compute probability for.
- `word_counts (Dict[str, Dict[str, int]])`: Word counts per class.
- `class_total_words (Dict[str, int])`: Total words per class.
- `vocab_size (int)`: Size of the vocabulary.

Returns:
- `log_prob (float)`: The log probability of the word given the class.

Examples:
```python
>>> word_counts = {'a': {'hello': 5, 'world': 3}, 'b': {'hello': 2}}
>>> class_totals = {'a': 8, 'b': 2}
>>> vocab_size = 3
>>> compute_word_log_likelihood('hello', 'a', word_counts, class_totals, vocab_size)
-0.6061...  # log((5+1)/(8+3))
```

**Note:** Use `math.log` for the logarithm.


In [None]:
def compute_word_log_likelihood(word: str, class_label: str, word_counts: Dict[str, Dict[str, int]],
                                  class_total_words: Dict[str, int], vocab_size: int) -> float:
    ### WRITE YOUR CODE BELOW
    # word count for class (0 if word not seen)
    word_count = word_counts.get(class_label, {}).get(word, 0)
    # total words for class
    total_words_in_class = class_total_words.get(class_label, 0)
    
    # Laplace smoothing: P(word|class) = (count + 1) / (total_words + vocab_size)
    prob = (word_count + 1) / (total_words_in_class + vocab_size)
    
    log_prob = math.log(prob)
    ### END CODE HERE
    return log_prob


#### 5.3.4: Train the Classifier [ungraded]

Run the cell below to train the Naive Bayes classifier on the training data. This computes all the parameters needed for classification.

In [None]:
### DO NOT MODIFY THIS CELL ###

# Compute prior probabilities
priors = compute_priors(train_labels)
print("Prior probabilities:")
for cls, prob in sorted(priors.items()):
    print(f"  P({cls}) = {prob:.4f}")

# Compute word counts and vocabulary
word_counts, class_total_words, vocabulary = compute_word_counts(train_texts, train_labels)
vocab_size = len(vocabulary)

print(f"\nVocabulary size: {vocab_size} unique words")
print("\nTotal words per class:")
for cls, count in sorted(class_total_words.items()):
    print(f"  {cls}: {count} words")


### 5.4: Classify New Documents (5 points)

Now implement the prediction function that uses the trained model to classify new documents.

#### 5.4.1: Predict Class for a Single Document (5 points)

Implement a function that predicts the most likely class for a given document. For each class, compute:
$$\log P(class|document) = \log P(class) + \sum_{word \in document} \log P(word|class)$$

Return the class with the highest log probability.
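
The scoring rule can be checked by hand on a toy model. Everything in the sketch below (`toy_priors`, `toy_word_probs`, and the probabilities themselves) is hypothetical and stands in for values that the real function derives from `word_counts` with Laplace smoothing.

```python
import math

# Hypothetical two-class model with already-smoothed word probabilities.
toy_priors = {"tech": 0.5, "sports": 0.5}
toy_word_probs = {
    "tech":   {"new": 0.3, "app": 0.5, "game": 0.2},
    "sports": {"new": 0.3, "app": 0.1, "game": 0.6},
}
doc_tokens = ["new", "app"]

# score(class) = log P(class) + sum of log P(word|class) over the document's tokens
scores = {
    cls: math.log(toy_priors[cls]) + sum(math.log(toy_word_probs[cls][w]) for w in doc_tokens)
    for cls in toy_priors
}
best = max(scores, key=scores.get)

print(best)  # tech
```

Both classes share the same prior and the same probability for "new", so the decision comes down to "app", which is five times more likely under `tech`.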

Arguments:
- `text (str)`: The document to classify.
- `priors (Dict[str, float])`: Prior probabilities for each class.
- `word_counts (Dict[str, Dict[str, int]])`: Word counts per class.
- `class_total_words (Dict[str, int])`: Total words per class.
- `vocab_size (int)`: Size of the vocabulary.

Returns:
- `predicted_class (str)`: The predicted class label.
- `log_probs (Dict[str, float])`: Dictionary mapping each class to its log probability for this document.

Examples:
```python
>>> text = "new smartphone app released for mobile devices"
>>> predicted, log_probs = predict_class(text, priors, word_counts, class_total_words, vocab_size)
>>> predicted
'tech'
```

**Note:** Use `math.log` for the logarithm.

In [None]:
def predict_class(text: str, priors: Dict[str, float], word_counts: Dict[str, Dict[str, int]],
                  class_total_words: Dict[str, int], vocab_size: int) -> Tuple[str, Dict[str, float]]:
    ### WRITE YOUR CODE BELOW
    # Preprocess the text
    tokens = preprocess_text(text)
    
    log_probs = {}
    # log P(class|document) = log P(class) + sum(log P(word|class))
    for class_label in priors:
        # log prior
        log_prob = math.log(priors[class_label])
        # log likelihood for each word
        for word in tokens:
            log_prob += compute_word_log_likelihood(word, class_label, word_counts, 
                                                   class_total_words, vocab_size)
        log_probs[class_label] = log_prob
    
    # class with highest log probability
    predicted_class = max(log_probs.items(), key=lambda x: x[1])[0]
    ### END CODE HERE
    return predicted_class, log_probs


#### 5.4.2: Test on Sample Documents [ungraded]

Run the cell below to test your classifier on some sample news headlines.

In [None]:
### DO NOT MODIFY THIS CELL ###
# Test the classifier on sample documents
sample_texts = [
    "New artificial intelligence model can write code and solve complex problems",
    "Championship game goes into overtime as teams battle for victory",
    "Stock prices surge as company reports record quarterly profits",
    "Study finds exercise reduces risk of heart disease significantly",
    "Senator proposes new bill to reform healthcare policy nationwide"
]

print("Sample Predictions:")
print("=" * 70)
for text in sample_texts:
    predicted, log_probs = predict_class(text, priors, word_counts, class_total_words, vocab_size)
    print(f"\nText: {text[:60]}...")
    print(f"Predicted: {predicted}")
    print("Log probabilities:", {k: f"{v:.2f}" for k, v in sorted(log_probs.items())})

### 5.5: Evaluate the Classifier (3 points)

Finally, let's evaluate how well our classifier performs on the test set.

#### 5.5.1: Compute Accuracy (3 points)

Implement a function that computes the classification accuracy on a test set.

Arguments:
- `test_texts (List[str])`: List of test documents.
- `test_labels (List[str])`: True labels for test documents.
- `priors (Dict[str, float])`: Prior probabilities.
- `word_counts (Dict[str, Dict[str, int]])`: Word counts per class.
- `class_total_words (Dict[str, int])`: Total words per class.
- `vocab_size (int)`: Vocabulary size.

Returns:
- `accuracy (float)`: Proportion of correctly classified documents (between 0 and 1).
- `predictions (List[str])`: List of predicted labels for each test document.

Examples:
```python
>>> accuracy, predictions = compute_accuracy(test_texts, test_labels, priors, word_counts, class_total_words, vocab_size)
>>> print(f"Accuracy: {accuracy:.2%}")
Accuracy: 85.00%
```

In [None]:
def compute_accuracy(test_texts: List[str], test_labels: List[str], priors: Dict[str, float],
                     word_counts: Dict[str, Dict[str, int]], class_total_words: Dict[str, int],
                     vocab_size: int) -> Tuple[float, List[str]]:
    ### WRITE YOUR CODE BELOW
    predictions = []
    correct = 0
    
    # predict a label for every test document
    for text in test_texts:
        predicted, _ = predict_class(text, priors, word_counts, class_total_words, vocab_size)
        predictions.append(predicted)
    
    # count predictions that match the true labels
    for i in range(len(test_labels)):
        if predictions[i] == test_labels[i]:
            correct += 1
    
    # accuracy = correct / total (0.0 for an empty test set)
    accuracy = correct / len(test_labels) if len(test_labels) > 0 else 0.0
    ### END CODE HERE
    return accuracy, predictions


#### 5.5.2: Evaluate on Test Set [ungraded]

Run the cell below to evaluate your Naive Bayes classifier on the test set and see detailed results.

In [None]:
### DO NOT MODIFY THIS CELL ###
accuracy, predictions = compute_accuracy(test_texts, test_labels, priors, word_counts, class_total_words, vocab_size)

print("=" * 60)
print("NAIVE BAYES CLASSIFIER EVALUATION")
print("=" * 60)
print(f"\nOverall Accuracy: {accuracy:.2%} ({round(accuracy * len(test_labels))}/{len(test_labels)} correct)")

# Compute per-class accuracy
print("\nPer-class Results:")
classes = sorted(set(test_labels))
for cls in classes:
    # Get indices for this class
    cls_indices = [i for i, label in enumerate(test_labels) if label == cls]
    cls_correct = sum(1 for i in cls_indices if predictions[i] == cls)
    cls_total = len(cls_indices)
    cls_accuracy = cls_correct / cls_total if cls_total > 0 else 0
    print(f"  {cls}: {cls_accuracy:.2%} ({cls_correct}/{cls_total})")

# Show some misclassified examples
print("\nMisclassified Examples:")
misclassified = [(text, true, pred) for text, true, pred in zip(test_texts, test_labels, predictions) if true != pred]
for text, true, pred in misclassified[:3]:
    print(f"  Text: {text[:50]}...")
    print(f"  True: {true}, Predicted: {pred}\n")