In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import Counter
import numpy as np
from typing import List, Union, Optional, Dict


# ============================================================================
# 1. BAG OF WORDS (BoW) - Enhanced Implementation
# ============================================================================

# Imagine you have a bag (like a shopping bag) and you throw all the words from a
# document into it:

# Word order is lost - you can't tell what came first
# You can only count how many times each word appears
# The "bag" represents the document as word frequencies

class BagOfWords:
    """
    Represent tokenized documents as word-count vectors.

    Features:
    - Ignores word order completely (bag assumption)
    - Supports binary (presence/absence) or count mode
    - Vocabulary size limiting (``max_features``) and rare-word filtering
      (``min_df``)
    - Words not seen during ``fit`` are silently ignored by ``transform``
    - Output is a dense ``torch.Tensor``

    Example:
        >>> bow = BagOfWords(max_features=1000, binary=False)
        >>> docs = [["hello", "world"], ["hello", "python"]]
        >>> X = bow.fit_transform(docs)
        >>> print(X.shape)  # torch.Size([2, 3]) - 2 docs, 3 unique words
    """

    def __init__(
        self,
        binary: bool = False,
        max_features: Optional[int] = None,
        min_df: int = 1,
        dtype: torch.dtype = torch.float32
    ):
        """
        Initialize Bag of Words vectorizer.

        Args:
            binary: If True, use 1/0 (presence/absence) instead of counts
            max_features: Maximum vocabulary size (keeps most frequent words)
            min_df: Minimum document frequency (ignore rare words)
            dtype: PyTorch dtype for output tensors
        """
        self.binary = binary
        self.max_features = max_features
        self.min_df = min_df
        self.dtype = dtype

        # Will be set during fit()
        self.vocab: Dict[str, int] = {}
        self.vocab_size_: int = 0
        self.is_fitted_: bool = False

    def fit(self, documents: List[List[str]]) -> 'BagOfWords':
        """
        Learn vocabulary from training documents.

        Args:
            documents: List of tokenized documents (each doc is list of words)

        Returns:
            self (for method chaining)

        Raises:
            ValueError: If documents is empty or invalid
        """
        # Input validation
        if not documents:
            raise ValueError("Cannot fit on empty document list")

        if not all(isinstance(doc, (list, tuple)) for doc in documents):
            raise ValueError("All documents must be lists or tuples of tokens")

        # Step 1: Count document frequencies (how many docs contain each word).
        # set(doc) makes a word count at most once per document.
        doc_freq = Counter()
        for doc in documents:
            doc_freq.update(set(doc))

        # Step 2: Filter by minimum document frequency
        valid_words = {
            word for word, freq in doc_freq.items()
            if freq >= self.min_df
        }

        # Step 3: Limit vocabulary size if specified
        if self.max_features is not None and len(valid_words) > self.max_features:
            # Rank ONLY the words that survived the min_df filter. Ranking all
            # words and filtering afterwards lets a rare-but-repeated word
            # occupy a top slot and then vanish, which shrinks the vocabulary
            # below max_features while eligible words are dropped.
            word_counts = Counter(
                word
                for doc in documents
                for word in doc
                if word in valid_words
            )

            # most_common() orders ties by first occurrence in the corpus.
            most_common = word_counts.most_common(self.max_features)
            valid_words = {word for word, _ in most_common}

        # Step 4: Create vocabulary mapping (sorted for reproducibility)
        self.vocab = {
            word: idx
            for idx, word in enumerate(sorted(valid_words))
        }
        self.vocab_size_ = len(self.vocab)
        self.is_fitted_ = True

        return self

    def transform(self, documents: List[List[str]]) -> torch.Tensor:
        """
        Transform documents to BoW representation.

        Args:
            documents: List of tokenized documents

        Returns:
            Tensor of shape (n_documents, vocab_size)

        Raises:
            RuntimeError: If called before fit()
            ValueError: If documents is empty
        """
        if not self.is_fitted_:
            raise RuntimeError("BagOfWords must be fitted before transform()")

        if not documents:
            raise ValueError("Cannot transform empty document list")

        # Initialize output matrix; rows of empty/unknown-only docs stay zero
        bow_matrix = torch.zeros(
            len(documents),
            self.vocab_size_,
            dtype=self.dtype
        )

        # Fill matrix
        for doc_idx, doc in enumerate(documents):
            if not doc:  # Handle empty documents
                continue

            for word, count in Counter(doc).items():
                word_idx = self.vocab.get(word)
                if word_idx is None:
                    # Word was not in the fitted vocabulary: ignore it
                    continue
                bow_matrix[doc_idx, word_idx] = 1.0 if self.binary else float(count)

        return bow_matrix

    def fit_transform(self, documents: List[List[str]]) -> torch.Tensor:
        """
        Fit vocabulary and transform documents in one step.

        Args:
            documents: List of tokenized documents

        Returns:
            Tensor of shape (n_documents, vocab_size)
        """
        return self.fit(documents).transform(documents)

    def get_feature_names(self) -> List[str]:
        """
        Get list of feature names (words in vocabulary).

        Returns:
            List of words, ordered by their index

        Raises:
            RuntimeError: If called before fit()
        """
        if not self.is_fitted_:
            raise RuntimeError("BagOfWords must be fitted before getting feature names")

        # Sort by index to get correct order
        return [word for word, _ in sorted(self.vocab.items(), key=lambda x: x[1])]

    def inverse_transform(self, bow_matrix: torch.Tensor) -> List[List[str]]:
        """
        Convert BoW vectors back to approximate word lists.
        Note: Order is lost, counts are preserved (words come out in
        vocabulary-index order).

        Args:
            bow_matrix: Tensor of shape (n_documents, vocab_size)

        Returns:
            List of word lists (repeated according to counts)

        Raises:
            RuntimeError: If called before fit()
        """
        if not self.is_fitted_:
            raise RuntimeError("BagOfWords must be fitted before inverse_transform()")

        idx_to_word = {idx: word for word, idx in self.vocab.items()}
        documents = []

        for doc_vector in bow_matrix:
            words = []
            for idx, count in enumerate(doc_vector):
                if count > 0:
                    word = idx_to_word[idx]
                    # Repeat word according to count (or just once if binary)
                    repetitions = 1 if self.binary else int(count.item())
                    words.extend([word] * repetitions)
            documents.append(words)

        return documents

    def __repr__(self) -> str:
        """String representation (includes vocab_size once fitted)."""
        params = f"binary={self.binary}, max_features={self.max_features}, min_df={self.min_df}"
        if self.is_fitted_:
            params += f", vocab_size={self.vocab_size_}"
        return f"BagOfWords({params})"


# ============================================================================
# DEMONSTRATION & TESTING
# ============================================================================


# 1. PREPARE DATA (already tokenized)
# train_docs = [
#     ["i", "love", "machine", "learning"],
#     ["i", "love", "deep", "learning"],
#     ["machine", "learning", "is", "fun"],
#     ["deep", "learning", "is", "powerful"]
# ]

# test_docs = [
#     ["i", "love", "learning"],
#     ["quantum", "computing"]  # New words!
# ]

# # 2. CREATE AND FIT MODEL
# bow = BagOfWords(binary=False, max_features=5)
# X_train = bow.fit_transform(train_docs)

# print("Vocabulary:", bow.get_feature_names())
# # Output: ['deep', 'i', 'learning', 'love', 'machine']
# # (Only 5 words kept due to max_features=5. "learning" occurs 4 times; the
# #  five words that occur twice are tied, and ties are broken by first
# #  occurrence in the corpus, so "is" is the one that gets dropped.)

# print("Training matrix:")
# print(X_train)
# # Output:
# #         deep  i  learning  love  machine
# # Doc 0   [0    1   1         1     1]  "i love machine learning"
# # Doc 1   [1    1   1         1     0]  "i love deep learning"
# # Doc 2   [0    0   1         0     1]  "machine learning is fun"
# # Doc 3   [1    0   1         0     0]  "deep learning is powerful"

# # 3. TRANSFORM NEW DATA
# X_test = bow.transform(test_docs)
# print("Test matrix:")
# print(X_test)
# # Output:
# #         deep  i  learning  love  machine
# # Doc 0   [0    1   1         1     0]  "i love learning" ✓
# # Doc 1   [0    0   0         0     0]  "quantum computing" (all unknown!)


# Limitations of BoW
# 1. Word Order Lost
# "not good" vs "good not" → SAME vector!
# "dog bites man" vs "man bites dog" → SAME vector!
# 2. No Semantic Understanding
# "car" and "automobile" → Treated as completely different words
# "happy" and "joyful" → No relationship captured
# 3. Sparse High-Dimensional Vectors
# Vocabulary size: 10,000 words
# Most documents use: 50-100 words
# Result: 99% of vector is zeros! (very sparse)
# 4. Context Ignored
# "bank" in "river bank" vs "bank account" → Same representation!


# When to Use BoW
# ✅ Use BoW for:

# Quick baseline models
# Text classification (spam detection, sentiment analysis)
# Document similarity with simple metrics
# Small to medium vocabularies (<10,000 words)
# When word order doesn't matter much


# Demo script: exercises BagOfWords in count mode, binary mode, with
# max_features, with min_df, through inverse_transform, and on an empty
# document, printing the vocabulary and matrices of each run.
if __name__ == "__main__":
    print("=" * 70)
    print("BAG OF WORDS DEMONSTRATION")
    print("=" * 70)

    # Sample documents (already tokenized).
    # Document frequencies: cat=2, dog=2, bird=2, fish=1.
    train_docs = [
        ["cat", "dog", "dog"],
        ["cat", "bird"],
        ["dog", "bird", "bird"],
        ["fish"]
    ]

    test_docs = [
        ["cat", "cat", "dog"],
        ["bird", "fish"],
        ["unknown", "word", "test"]  # Contains unknown words
    ]




    # Test 1: Basic count mode
    print("\n1. COUNT MODE (default)")
    print("-" * 70)
    bow_count = BagOfWords()
    X_train = bow_count.fit_transform(train_docs)
    # The vocabulary is fixed by fit(); words unseen in training are ignored
    # here, so the third test document becomes an all-zero row.
    X_test = bow_count.transform(test_docs)

    print(f"Vocabulary: {bow_count.get_feature_names()}")
    print(f"Training matrix shape: {X_train.shape}")
    print(f"Training matrix:\n{X_train}")
    print(f"\nTest matrix:\n{X_test}")

    # Test 2: Binary mode (repeated words are recorded as 1, not their count)
    print("\n2. BINARY MODE (presence/absence)")
    print("-" * 70)
    bow_binary = BagOfWords(binary=True)
    X_binary = bow_binary.fit_transform(train_docs)
    print(f"Binary matrix:\n{X_binary}")

    # Test 3: Max features (4 distinct words in the corpus, only 3 are kept;
    # "fish" has the lowest total count and is dropped)
    print("\n3. MAX FEATURES (vocabulary size limit)")
    print("-" * 70)
    bow_limited = BagOfWords(max_features=3)
    X_limited = bow_limited.fit_transform(train_docs)
    print(f"Limited vocabulary: {bow_limited.get_feature_names()}")
    print(f"Limited matrix:\n{X_limited}")

    # Test 4: Min document frequency ("fish" occurs in a single document, so
    # it is filtered out)
    print("\n4. MIN DOCUMENT FREQUENCY (filter rare words)")
    print("-" * 70)
    bow_filtered = BagOfWords(min_df=2)  # Word must appear in at least 2 docs
    X_filtered = bow_filtered.fit_transform(train_docs)
    print(f"Filtered vocabulary: {bow_filtered.get_feature_names()}")
    print(f"Filtered matrix:\n{X_filtered}")

    # Test 5: Inverse transform (reuses the count-mode model fitted in Test 1)
    print("\n5. INVERSE TRANSFORM (reconstruct documents)")
    print("-" * 70)
    reconstructed = bow_count.inverse_transform(X_train)
    print("Original:      ", train_docs[0])
    print("Reconstructed: ", reconstructed[0])
    print("Note: Order is lost, but counts are preserved")

    # Test 6: Edge cases
    print("\n6. EDGE CASES")
    print("-" * 70)
    bow_edge = BagOfWords()

    # Empty document: contributes nothing to the vocabulary and yields an
    # all-zero row in the output matrix
    docs_with_empty = [["hello", "world"], [], ["test"]]
    X_edge = bow_edge.fit_transform(docs_with_empty)
    print(f"Documents with empty doc: {docs_with_empty}")
    print(f"Result:\n{X_edge}")

    print("\n" + "=" * 70)
    print("All tests completed successfully!")
    print("=" * 70)

BAG OF WORDS DEMONSTRATION

1. COUNT MODE (default)
----------------------------------------------------------------------
Vocabulary: ['bird', 'cat', 'dog', 'fish']
Training matrix shape: torch.Size([4, 4])
Training matrix:
tensor([[0., 1., 2., 0.],
        [1., 1., 0., 0.],
        [2., 0., 1., 0.],
        [0., 0., 0., 1.]])

Test matrix:
tensor([[0., 2., 1., 0.],
        [1., 0., 0., 1.],
        [0., 0., 0., 0.]])

2. BINARY MODE (presence/absence)
----------------------------------------------------------------------
Binary matrix:
tensor([[0., 1., 1., 0.],
        [1., 1., 0., 0.],
        [1., 0., 1., 0.],
        [0., 0., 0., 1.]])

3. MAX FEATURES (vocabulary size limit)
----------------------------------------------------------------------
Limited vocabulary: ['bird', 'cat', 'dog']
Limited matrix:
tensor([[0., 1., 2.],
        [1., 1., 0.],
        [2., 0., 1.],
        [0., 0., 0.]])

4. MIN DOCUMENT FREQUENCY (filter rare words)
--------------------------------------------

In [None]:
import torch
import torch.nn as nn
from collections import Counter
from typing import List, Tuple, Union, Optional, Dict, Set
import itertools

# ============================================================================
# 2. N-GRAMS - Enhanced Implementation with Local Word Order
# ============================================================================


# N-grams are sequences of N consecutive tokens (words or characters). Unlike Bag
# of Words, they capture local word order.


class NGramVectorizer:
    """
    Extract n-grams (sequences of n consecutive words/characters).
    Captures local word order unlike basic Bag of Words.

    Features:
    - Support for multiple n-gram ranges (e.g., unigrams + bigrams)
    - Word-level or character-level n-grams
    - Binary or count mode
    - Vocabulary size limiting
    - Documents shorter than n simply produce no n-grams of that size

    Examples:
        >>> # Bigrams only
        >>> ngram = NGramVectorizer(n=2)
        >>> docs = [["the", "cat", "sat"], ["the", "dog", "ran"]]
        >>> X = ngram.fit_transform(docs)

        >>> # Unigrams + Bigrams + Trigrams
        >>> ngram = NGramVectorizer(ngram_range=(1, 3))
        >>> X = ngram.fit_transform(docs)

        >>> # Character-level n-grams
        >>> ngram = NGramVectorizer(n=3, analyzer='char')
        >>> docs = ["hello", "world"]
        >>> X = ngram.fit_transform(docs)
    """

    def __init__(
        self,
        n: Optional[int] = None,
        ngram_range: Optional[Tuple[int, int]] = None,
        analyzer: str = 'word',
        binary: bool = False,
        max_features: Optional[int] = None,
        min_df: int = 1,
        token_separator: str = ' ',
        dtype: torch.dtype = torch.float32
    ):
        """
        Initialize N-Gram Vectorizer.

        Args:
            n: Size of n-grams (2=bigrams, 3=trigrams).
               Cannot be used with ngram_range.
            ngram_range: Tuple (min_n, max_n) for multiple n-gram sizes.
                        E.g., (1,2) = unigrams + bigrams.
                        Cannot be used with n.
            analyzer: 'word' for word-level or 'char' for character-level n-grams
            binary: If True, use presence/absence instead of counts
            max_features: Maximum vocabulary size (keeps most frequent n-grams)
            min_df: Minimum document frequency (ignore rare n-grams)
            token_separator: String to join tokens in n-gram representation
            dtype: PyTorch dtype for output tensors

        Raises:
            ValueError: If both n and ngram_range are specified, or neither;
                if ngram_range is not a (min_n, max_n) pair; if min_n < 1 or
                min_n > max_n; or if analyzer is not 'word' or 'char'
        """
        # Validate parameters
        if n is not None and ngram_range is not None:
            raise ValueError("Cannot specify both 'n' and 'ngram_range'")
        if n is None and ngram_range is None:
            raise ValueError("Must specify either 'n' or 'ngram_range'")

        # Set n-gram range
        if n is not None:
            self.ngram_range = (n, n)
        else:
            # Reject anything that is not a 2-element sequence up front, so a
            # malformed range is reported as ValueError rather than surfacing
            # later as an IndexError/TypeError.
            try:
                _min_n, _max_n = ngram_range
            except (TypeError, ValueError):
                raise ValueError(
                    f"ngram_range must be a (min_n, max_n) pair, got {ngram_range!r}"
                ) from None
            self.ngram_range = ngram_range

        if self.ngram_range[0] < 1:
            raise ValueError(f"min n must be >= 1, got {self.ngram_range[0]}")
        if self.ngram_range[0] > self.ngram_range[1]:
            raise ValueError(f"min n ({self.ngram_range[0]}) must be <= max n ({self.ngram_range[1]})")

        if analyzer not in ['word', 'char']:
            raise ValueError(f"analyzer must be 'word' or 'char', got '{analyzer}'")

        self.analyzer = analyzer
        self.binary = binary
        self.max_features = max_features
        self.min_df = min_df
        self.token_separator = token_separator
        self.dtype = dtype

        # Will be set during fit()
        self.vocab: Dict[str, int] = {}
        self.vocab_size_: int = 0
        self.is_fitted_: bool = False

    def _get_ngrams(self, tokens: Union[List[str], str]) -> List[str]:
        """
        Extract n-grams from token list or string.

        N-grams are emitted size by size (all min_n-grams first, then the
        next size, ...), each size scanned left to right.

        Args:
            tokens: List of words or string (for character n-grams)

        Returns:
            List of n-gram strings

        Raises:
            ValueError: If the analyzer is 'word' and tokens is a string
        """
        min_n, max_n = self.ngram_range

        if self.analyzer == 'char':
            # A token sequence is concatenated (no separator) before slicing
            if isinstance(tokens, (list, tuple)):
                tokens = ''.join(tokens)

            # range() is empty when the text is shorter than n
            return [
                tokens[i:i + n]
                for n in range(min_n, max_n + 1)
                for i in range(len(tokens) - n + 1)
            ]

        # Word-level n-grams
        if isinstance(tokens, str):
            raise ValueError("For word analyzer, input must be list of tokens")

        return [
            self.token_separator.join(tokens[i:i + n])
            for n in range(min_n, max_n + 1)
            for i in range(len(tokens) - n + 1)
        ]

    def fit(self, documents: Union[List[List[str]], List[str]]) -> 'NGramVectorizer':
        """
        Learn vocabulary from training documents.

        Args:
            documents: List of tokenized documents (word analyzer) or
                      list of strings (char analyzer)

        Returns:
            self (for method chaining)

        Raises:
            ValueError: If documents is empty or invalid
        """
        if not documents:
            raise ValueError("Cannot fit on empty document list")

        # Validate input format
        if self.analyzer == 'word':
            if not all(isinstance(doc, (list, tuple)) for doc in documents):
                raise ValueError("For word analyzer, documents must be lists of tokens")
        elif self.analyzer == 'char':
            if not all(isinstance(doc, str) for doc in documents):
                raise ValueError("For char analyzer, documents must be strings")

        # Step 1: Count document frequencies. Total occurrence counts are only
        # needed to rank n-grams for max_features, so they are accumulated
        # incrementally (and only then) instead of storing every n-gram
        # occurrence of the corpus in one big list.
        need_totals = self.max_features is not None
        doc_freq = Counter()
        total_counts = Counter()

        for doc in documents:
            ngrams = self._get_ngrams(doc)
            doc_freq.update(set(ngrams))  # once per document
            if need_totals:
                total_counts.update(ngrams)

        # Step 2: Filter by minimum document frequency
        valid_ngrams = {
            ngram for ngram, freq in doc_freq.items()
            if freq >= self.min_df
        }

        # Step 3: Limit vocabulary size if specified
        if self.max_features is not None and len(valid_ngrams) > self.max_features:
            # Rank only n-grams that passed the min_df filter; ties keep
            # first-occurrence order.
            valid_counts = Counter({
                ngram: count for ngram, count in total_counts.items()
                if ngram in valid_ngrams
            })
            most_common = valid_counts.most_common(self.max_features)
            valid_ngrams = {ngram for ngram, _ in most_common}

        # Step 4: Create vocabulary mapping (sorted for reproducibility)
        self.vocab = {
            ngram: idx
            for idx, ngram in enumerate(sorted(valid_ngrams))
        }
        self.vocab_size_ = len(self.vocab)
        self.is_fitted_ = True

        return self

    def transform(self, documents: Union[List[List[str]], List[str]]) -> torch.Tensor:
        """
        Transform documents to n-gram representation.

        Args:
            documents: List of tokenized documents or strings

        Returns:
            Tensor of shape (n_documents, vocab_size)

        Raises:
            RuntimeError: If called before fit()
            ValueError: If documents is empty, or a string document is given
                to the word analyzer
        """
        if not self.is_fitted_:
            raise RuntimeError("NGramVectorizer must be fitted before transform()")

        if not documents:
            raise ValueError("Cannot transform empty document list")

        # Initialize output matrix; rows with no known n-grams stay zero
        ngram_matrix = torch.zeros(
            len(documents),
            self.vocab_size_,
            dtype=self.dtype
        )

        # Fill matrix
        for doc_idx, doc in enumerate(documents):
            # Handle empty documents
            if not doc:
                continue

            for ngram, count in Counter(self._get_ngrams(doc)).items():
                ngram_idx = self.vocab.get(ngram)
                if ngram_idx is None:
                    # N-gram was not in the fitted vocabulary: ignore it
                    continue
                ngram_matrix[doc_idx, ngram_idx] = 1.0 if self.binary else float(count)

        return ngram_matrix

    def fit_transform(
        self,
        documents: Union[List[List[str]], List[str]]
    ) -> torch.Tensor:
        """
        Fit vocabulary and transform documents in one step.

        Args:
            documents: List of tokenized documents or strings

        Returns:
            Tensor of shape (n_documents, vocab_size)
        """
        return self.fit(documents).transform(documents)

    def get_feature_names(self) -> List[str]:
        """
        Get list of feature names (n-grams in vocabulary).

        Returns:
            List of n-gram strings, ordered by their index

        Raises:
            RuntimeError: If called before fit()
        """
        if not self.is_fitted_:
            raise RuntimeError("NGramVectorizer must be fitted before getting feature names")

        return [ngram for ngram, _ in sorted(self.vocab.items(), key=lambda x: x[1])]

    def inverse_transform(self, ngram_matrix: torch.Tensor) -> List[List[str]]:
        """
        Convert n-gram vectors back to approximate token lists.
        Note: Original text cannot be perfectly reconstructed from n-grams.

        Args:
            ngram_matrix: Tensor of shape (n_documents, vocab_size)

        Returns:
            List of n-gram lists (in vocabulary-index order, repeated
            according to counts unless binary)

        Raises:
            RuntimeError: If called before fit()
        """
        if not self.is_fitted_:
            raise RuntimeError("NGramVectorizer must be fitted before inverse_transform()")

        idx_to_ngram = {idx: ngram for ngram, idx in self.vocab.items()}
        documents = []

        for doc_vector in ngram_matrix:
            ngrams = []
            for idx, count in enumerate(doc_vector):
                if count > 0:
                    ngram = idx_to_ngram[idx]
                    repetitions = 1 if self.binary else int(count.item())
                    ngrams.extend([ngram] * repetitions)
            documents.append(ngrams)

        return documents

    def __repr__(self) -> str:
        """String representation (includes vocab_size once fitted)."""
        params = (
            f"ngram_range={self.ngram_range}, "
            f"analyzer='{self.analyzer}', "
            f"binary={self.binary}, "
            f"max_features={self.max_features}, "
            f"min_df={self.min_df}"
        )
        if self.is_fitted_:
            params += f", vocab_size={self.vocab_size_}"
        return f"NGramVectorizer({params})"



# N-grams improve upon Bag of Words by:

# ✅ Capturing local word order
# ✅ Representing phrases and idioms
# ✅ Distinguishing word sequences

# But with trade-offs:

# ❌ Much larger vocabulary (exponential growth)
# ❌ Still no long-range dependencies
# ❌ Still no semantic understanding



# ============================================================================
# DEMONSTRATION & TESTING
# ============================================================================
# Demo script: exercises NGramVectorizer with bigrams, mixed n-gram ranges,
# trigrams, character n-grams, binary mode, vocabulary limiting, a
# word-order comparison against unigrams, and short/empty documents.
if __name__ == "__main__":
    print("=" * 80)
    print("N-GRAM VECTORIZER DEMONSTRATION")
    print("=" * 80)

    # Sample documents (already tokenized)
    train_docs = [
        ["the", "cat", "sat", "on", "mat"],
        ["the", "dog", "sat", "on", "log"],
        ["the", "cat", "and", "dog"],
    ]

    test_docs = [
        ["the", "cat", "sat"],
        ["dog", "on", "mat"]
    ]

    # ========================================================================
    # Test 1: Basic Bigrams
    # ========================================================================
    print("\n1. BIGRAMS ONLY (n=2)")
    print("-" * 80)
    bigram = NGramVectorizer(n=2)
    X_train = bigram.fit_transform(train_docs)
    # Test documents are encoded against the vocabulary learned above;
    # bigrams not seen in training are ignored.
    X_test = bigram.transform(test_docs)

    print(f"Vocabulary ({bigram.vocab_size_} bigrams):")
    for i, ngram in enumerate(bigram.get_feature_names()):
        print(f"  [{i}] '{ngram}'")

    print(f"\nTraining matrix shape: {X_train.shape}")
    print("Training matrix:")
    print(X_train)

    print("\nExample: Document 'the cat sat on mat'")
    print("Bigrams: ['the cat', 'cat sat', 'sat on', 'on mat']")
    print(f"Vector: {X_train[0]}")

    # ========================================================================
    # Test 2: Unigrams + Bigrams (captures both individual words and pairs)
    # ========================================================================
    print("\n2. UNIGRAMS + BIGRAMS (ngram_range=(1,2))")
    print("-" * 80)
    combined = NGramVectorizer(ngram_range=(1, 2))
    X_combined = combined.fit_transform(train_docs)

    print(f"Vocabulary ({combined.vocab_size_} features):")
    features = combined.get_feature_names()
    # Bigrams are joined with the default ' ' separator, so a space in the
    # feature name tells the two groups apart.
    unigrams = [f for f in features if ' ' not in f]
    bigrams = [f for f in features if ' ' in f]
    print(f"  Unigrams ({len(unigrams)}): {unigrams[:5]}...")
    print(f"  Bigrams ({len(bigrams)}): {bigrams[:5]}...")

    print(f"\nCombined matrix shape: {X_combined.shape}")

    # ========================================================================
    # Test 3: Trigrams (3-word sequences)
    # ========================================================================
    print("\n3. TRIGRAMS (n=3)")
    print("-" * 80)
    trigram = NGramVectorizer(n=3)
    X_trigram = trigram.fit_transform(train_docs)

    print(f"Vocabulary ({trigram.vocab_size_} trigrams):")
    for ngram in trigram.get_feature_names():
        print(f"  '{ngram}'")

    print("\nTrigram matrix:")
    print(X_trigram)

    # ========================================================================
    # Test 4: Character-level n-grams
    # ========================================================================
    print("\n4. CHARACTER-LEVEL N-GRAMS (analyzer='char', n=3)")
    print("-" * 80)
    # The char analyzer takes raw strings, not token lists
    char_docs = ["hello", "helloworld", "world"]

    char_ngram = NGramVectorizer(n=3, analyzer='char')
    X_char = char_ngram.fit_transform(char_docs)

    print(f"Vocabulary ({char_ngram.vocab_size_} character trigrams):")
    for ngram in char_ngram.get_feature_names():
        print(f"  '{ngram}'")

    print("\nCharacter n-gram matrix:")
    print(X_char)

    print("\nExample: 'hello' character trigrams:")
    print("  ['hel', 'ell', 'llo']")

    # ========================================================================
    # Test 5: Binary mode (presence/absence)
    # ========================================================================
    print("\n5. BINARY MODE (binary=True)")
    print("-" * 80)
    # The second document contains the bigram 'dog dog' twice; binary mode
    # still records it as 1.
    docs_repeated = [
        ["cat", "cat", "dog"],
        ["cat", "dog", "dog", "dog"]
    ]

    binary_ngram = NGramVectorizer(n=2, binary=True)
    X_binary = binary_ngram.fit_transform(docs_repeated)

    print("Documents with repeated n-grams:")
    for i, doc in enumerate(docs_repeated):
        print(f"  Doc {i}: {doc}")

    print(f"\nVocabulary: {binary_ngram.get_feature_names()}")
    print("Binary matrix (1 = present, 0 = absent):")
    print(X_binary)

    # ========================================================================
    # Test 6: Vocabulary limiting
    # ========================================================================
    print("\n6. VOCABULARY LIMITING (max_features=3)")
    print("-" * 80)
    limited = NGramVectorizer(n=2, max_features=3)
    X_limited = limited.fit_transform(train_docs)

    # `bigram` is still the unrestricted model fitted in Test 1
    print(f"All possible bigrams from training: ~{bigram.vocab_size_}")
    print(f"Limited vocabulary (top 3): {limited.get_feature_names()}")
    print(f"Limited matrix shape: {X_limited.shape}")

    # ========================================================================
    # Test 7: Comparison with Bag of Words
    # ========================================================================
    print("\n7. COMPARISON: BAG OF WORDS vs N-GRAMS")
    print("-" * 80)

    test_doc = ["not", "good"]
    opposite_doc = ["good", "not"]  # Same words, different order

    # Unigrams (equivalent to BoW)
    unigram = NGramVectorizer(n=1)
    unigram.fit([test_doc, opposite_doc])

    vec1 = unigram.transform([test_doc])
    vec2 = unigram.transform([opposite_doc])

    print(f"Document 1: {test_doc}")
    print(f"Document 2: {opposite_doc}")
    print(f"\nUnigrams (BoW) - SAME vectors:")
    print(f"  Doc 1: {vec1[0]}")
    print(f"  Doc 2: {vec2[0]}")
    print(f"  Equal? {torch.equal(vec1, vec2)}")

    # Bigrams (captures order)
    # NOTE: this rebinds `bigram`, replacing the model fitted in Test 1
    bigram = NGramVectorizer(n=2)
    bigram.fit([test_doc, opposite_doc])

    vec1_bi = bigram.transform([test_doc])
    vec2_bi = bigram.transform([opposite_doc])

    print(f"\nBigrams - DIFFERENT vectors:")
    print(f"  Vocabulary: {bigram.get_feature_names()}")
    print(f"  Doc 1: {vec1_bi[0]}")
    print(f"  Doc 2: {vec2_bi[0]}")
    print(f"  Equal? {torch.equal(vec1_bi, vec2_bi)}")

    # ========================================================================
    # Test 8: Edge cases
    # ========================================================================
    print("\n8. EDGE CASES")
    print("-" * 80)

    # Document shorter than n: it yields no n-grams, so its row is all zeros
    short_docs = [["hi"], ["hello", "world"], ["a", "b", "c"]]
    safe_ngram = NGramVectorizer(n=2)
    X_safe = safe_ngram.fit_transform(short_docs)

    print(f"Documents: {short_docs}")
    print(f"Bigrams vocabulary: {safe_ngram.get_feature_names()}")
    print(f"Matrix (note: 'hi' produces no bigrams):")
    print(X_safe)

    # Empty document: skipped by transform(), leaving an all-zero row
    docs_with_empty = [["hello", "world"], [], ["test"]]
    ngram_empty = NGramVectorizer(n=2)
    X_empty = ngram_empty.fit_transform(docs_with_empty)
    print(f"\nDocuments with empty: {docs_with_empty}")
    print(f"Result matrix:")
    print(X_empty)

    print("\n" + "=" * 80)
    print("All tests completed successfully!")
    print("=" * 80)

N-GRAM VECTORIZER DEMONSTRATION

1. BIGRAMS ONLY (n=2)
--------------------------------------------------------------------------------
Vocabulary (9 bigrams):
  [0] 'and dog'
  [1] 'cat and'
  [2] 'cat sat'
  [3] 'dog sat'
  [4] 'on log'
  [5] 'on mat'
  [6] 'sat on'
  [7] 'the cat'
  [8] 'the dog'

Training matrix shape: torch.Size([3, 9])
Training matrix:
tensor([[0., 0., 1., 0., 0., 1., 1., 1., 0.],
        [0., 0., 0., 1., 1., 0., 1., 0., 1.],
        [1., 1., 0., 0., 0., 0., 0., 1., 0.]])

Example: Document 'the cat sat on mat'
Bigrams: ['the cat', 'cat sat', 'sat on', 'on mat']
Vector: tensor([0., 0., 1., 0., 0., 1., 1., 1., 0.])

2. UNIGRAMS + BIGRAMS (ngram_range=(1,2))
--------------------------------------------------------------------------------
Vocabulary (17 features):
  Unigrams (8): ['and', 'cat', 'dog', 'log', 'mat']...
  Bigrams (9): ['and dog', 'cat and', 'cat sat', 'dog sat', 'on log']...

Combined matrix shape: torch.Size([3, 17])

3. TRIGRAMS (n=3)
--------------

In [None]:
import torch
import torch.nn as nn
from collections import Counter
from typing import List, Optional, Dict, Tuple
import math

# ============================================================================
# 4. GloVe-style Co-occurrence Matrix - Enhanced Implementation
# ============================================================================


# A co-occurrence matrix captures how often words appear near each other in text.
# It's based on the distributional hypothesis: "You shall know a word by the company it keeps" (J.R. Firth, 1957).
# Core Concept
# Words that appear in similar contexts tend to have similar meanings:

# "cat" and "dog" both appear near: "pet", "sat", "furry"
# "king" and "queen" both appear near: "throne", "crown", "rules"

# Corpus:
#   "the cat sat on mat"
#   "the dog sat on log"

# Window size = 2 (look 2 words left/right)

# For "sat":
#   Context words: [cat, on] (distance 1 each) + [the, mat] (distance 2)

# For "cat":
#   Context words: [the, sat] (distance 1 each) + [on] (distance 2)

# Co-occurrence Matrix (uniform weights, window size = 2):
#            cat  dog  log  mat  on   sat  the
#     cat  [  0    0    0    0   1    1    1  ]
#     dog  [  0    0    0    0   1    1    1  ]
#     log  [  0    0    0    0   1    1    0  ]
#     mat  [  0    0    0    0   1    1    0  ]
#     on   [  1    1    1    1   0    2    0  ]
#     sat  [  1    1    1    1   2    0    2  ]
#     the  [  1    1    0    0   0    2    0  ]

# Reading: Row = center word, Column = context word
# Example: sat[on] = 2 means "on" appears near "sat" 2 times


# Co-occurrence Matrix captures:

# ✅ Distributional semantics (context = meaning)
# ✅ Word associations
# ✅ Semantic similarity

# Improvements in this implementation:

# ✅ Fixed double-counting bug in symmetric mode
# ✅ Multiple weighting schemes
# ✅ PPMI transformation for better quality
# ✅ Vocabulary limiting for scalability
# ✅ Similarity search functionality
# ✅ Proper normalization options

# Limitations:

# ❌ Dense matrix (memory intensive for large vocab)
# ❌ No dimensionality reduction (unlike GloVe embeddings)
# ❌ Linear relationships only (no non-linear patterns)



class CooccurrenceMatrix:
    """
    Build a word co-occurrence matrix (the statistic GloVe is trained on).

    Counts how often words appear near each other inside a sliding context
    window. Row index = center word, column index = context word.

    Features:
    - Distance-based weighting (closer words = higher weight)
    - Symmetric or directed co-occurrence
    - PPMI transformation option
    - Vocabulary size limiting
    - Nearest-neighbour lookup over the rows of the matrix

    Symmetric vs directed:
    - symmetric=True: every pair found inside a window contributes to both
      matrix[w1, w2] and matrix[w2, w1].
    - symmetric=False: only context words that FOLLOW the center word are
      counted, so matrix[w1, w2] accumulates "w2 appears after w1".

    The co-occurrence matrix captures distributional semantics:
    "You shall know a word by the company it keeps" - J.R. Firth

    Example:
        >>> docs = [["cat", "sat", "on", "mat"], ["dog", "sat", "on", "log"]]
        >>> cooc = CooccurrenceMatrix(window_size=2)
        >>> _ = cooc.fit(docs)
        >>> matrix = cooc.get_matrix()
        >>> similar = cooc.most_similar("sat", k=3)
    """

    def __init__(
        self,
        window_size: int = 5,
        weighting: str = 'harmonic',
        symmetric: bool = True,
        max_vocab_size: Optional[int] = None,
        min_count: int = 1,
        normalize: bool = False,
        use_ppmi: bool = False,
        dtype: torch.dtype = torch.float32
    ):
        """
        Initialize Co-occurrence Matrix builder.

        Args:
            window_size: Context window size (words left/right of target)
            weighting: How to weight context words by distance:
                      - 'uniform': All context words weighted equally (1.0)
                      - 'harmonic': 1/distance (GloVe default)
                      - 'distance': 1 - (distance/window_size); note that this
                        is 0 for a word exactly window_size positions away
            symmetric: If True, treat (w1, w2) same as (w2, w1). If False,
                      only count context words that follow the center word.
            max_vocab_size: Limit vocabulary to most frequent words
            min_count: Minimum word frequency to include in vocabulary
            normalize: If True, normalize rows to sum to 1 (probability)
            use_ppmi: If True, apply PPMI transformation (recommended)
            dtype: PyTorch dtype for matrix

        Raises:
            ValueError: If window_size < 1 or weighting is not recognised.
        """
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")

        if weighting not in ['uniform', 'harmonic', 'distance']:
            raise ValueError(f"weighting must be 'uniform', 'harmonic', or 'distance', got '{weighting}'")

        self.window_size = window_size
        self.weighting = weighting
        self.symmetric = symmetric
        self.max_vocab_size = max_vocab_size
        self.min_count = min_count
        self.normalize = normalize
        self.use_ppmi = use_ppmi
        self.dtype = dtype

        # Will be set during fit()
        self.vocab: Dict[str, int] = {}
        self.vocab_size_: int = 0
        self.word_counts: Counter = Counter()
        self.cooccur: Optional[torch.Tensor] = None
        self.is_fitted_: bool = False

    def _get_weight(self, distance: int) -> float:
        """
        Calculate weight based on distance from center word.

        Args:
            distance: Distance from center word (1, 2, 3, ...)

        Returns:
            Weight value (higher = more important)
        """
        if self.weighting == 'harmonic':
            return 1.0 / distance
        if self.weighting == 'distance':
            return 1.0 - (distance / self.window_size)
        # 'uniform' (and any value assigned to self.weighting after __init__)
        return 1.0

    def _build_vocabulary(self) -> Dict[str, int]:
        """Map every word kept by min_count / max_vocab_size to a row index."""
        valid_words = {
            word for word, count in self.word_counts.items()
            if count >= self.min_count
        }

        # Truncation only triggers when more than max_vocab_size words passed
        # min_count, so the most frequent ones all satisfy min_count as well.
        if self.max_vocab_size is not None and len(valid_words) > self.max_vocab_size:
            most_common = self.word_counts.most_common(self.max_vocab_size)
            valid_words = {word for word, _ in most_common}

        # Alphabetical order keeps indices reproducible between runs.
        return {word: idx for idx, word in enumerate(sorted(valid_words))}

    def _count_cooccurrences(self, documents: List[List[str]]) -> torch.Tensor:
        """Accumulate the weighted co-occurrence counts for all documents."""
        forward = torch.zeros(self.vocab_size_, self.vocab_size_, dtype=self.dtype)

        # The weight only depends on the distance, so compute each one once.
        weights = [self._get_weight(d) for d in range(1, self.window_size + 1)]

        for doc in documents:
            # Out-of-vocabulary words are dropped BEFORE windowing, so the
            # window spans over removed words.
            indices = [self.vocab[word] for word in doc if word in self.vocab]

            for i, center_idx in enumerate(indices):
                window_end = min(len(indices), i + self.window_size + 1)
                # Visit each ordered pair (i < j) exactly once; the mirrored
                # half is added below instead of being counted a second time.
                for j in range(i + 1, window_end):
                    forward[center_idx, indices[j]] += weights[j - i - 1]

        if self.symmetric:
            # matrix[a, b] == matrix[b, a]: each pair counted once per side.
            return forward + forward.T
        return forward

    def fit(self, documents: List[List[str]]) -> 'CooccurrenceMatrix':
        """
        Build co-occurrence matrix from documents.

        Calling fit() again discards the statistics of any previous fit.

        Args:
            documents: List of tokenized documents

        Returns:
            self (for method chaining)

        Raises:
            ValueError: If documents is empty, contains something that is not
                a list/tuple, or no word reaches min_count.
        """
        if not documents:
            raise ValueError("Cannot fit on empty document list")

        if not all(isinstance(doc, (list, tuple)) for doc in documents):
            raise ValueError("All documents must be lists or tuples of tokens")

        # Start from a clean state so a second fit() does not inherit the
        # word frequencies of an earlier corpus.
        self.is_fitted_ = False
        self.word_counts = Counter()
        for doc in documents:
            self.word_counts.update(doc)

        self.vocab = self._build_vocabulary()
        self.vocab_size_ = len(self.vocab)

        if self.vocab_size_ == 0:
            raise ValueError("No words passed min_count threshold")

        cooccur = self._count_cooccurrences(documents)

        if self.use_ppmi:
            cooccur = self._compute_ppmi(cooccur)

        if self.normalize:
            row_sums = cooccur.sum(dim=1, keepdim=True)
            # Rows without any mass stay all-zero instead of becoming NaN.
            row_sums = torch.where(row_sums > 0, row_sums, torch.ones_like(row_sums))
            cooccur = cooccur / row_sums

        self.cooccur = cooccur
        self.is_fitted_ = True
        return self

    def _compute_ppmi(self, cooccur_matrix: torch.Tensor) -> torch.Tensor:
        """
        Compute Positive Pointwise Mutual Information (PPMI).

        PPMI measures how much more likely two words co-occur than expected by chance.
        PMI(w1, w2) = log(P(w1, w2) / (P(w1) * P(w2)))
        PPMI = max(0, PMI)  # Only keep positive values

        P(w1) is the marginal of the center word (row sums) and P(w2) the
        marginal of the context word (column sums); both coincide when the
        matrix is symmetric.

        Args:
            cooccur_matrix: Raw co-occurrence counts

        Returns:
            PPMI matrix (the input itself when it contains no counts)
        """
        total = cooccur_matrix.sum()

        if total == 0:
            return cooccur_matrix

        p_joint = cooccur_matrix / total
        p_center = cooccur_matrix.sum(dim=1) / total
        p_context = cooccur_matrix.sum(dim=0) / total

        # Expected probability under independence (outer product).
        p_expected = p_center.unsqueeze(1) * p_context.unsqueeze(0)

        # Small epsilon avoids log(0) for pairs that never co-occur.
        epsilon = 1e-10
        pmi = torch.log((p_joint + epsilon) / (p_expected + epsilon))

        return torch.clamp(pmi, min=0.0)

    def get_matrix(self) -> torch.Tensor:
        """
        Get the co-occurrence matrix.

        Returns:
            Tensor of shape (vocab_size, vocab_size)

        Raises:
            RuntimeError: If fit() has not been called.
        """
        if not self.is_fitted_:
            raise RuntimeError("CooccurrenceMatrix must be fitted first")
        return self.cooccur

    def get_vocabulary(self) -> List[str]:
        """Get vocabulary words in index order."""
        if not self.is_fitted_:
            raise RuntimeError("CooccurrenceMatrix must be fitted first")
        return sorted(self.vocab, key=self.vocab.get)

    def most_similar(
        self,
        word: str,
        k: int = 10,
        metric: str = 'cosine'
    ) -> List[Tuple[str, float]]:
        """
        Find most similar words based on co-occurrence patterns.

        Args:
            word: Query word
            k: Number of similar words to return
            metric: Similarity metric ('cosine' or 'correlation')

        Returns:
            List of (word, similarity_score) tuples, sorted by similarity.
            The query word is never part of the result, so at most
            vocab_size - 1 entries are returned.

        Raises:
            RuntimeError: If fit() has not been called.
            ValueError: If word is not in the vocabulary or metric is unknown.
        """
        if not self.is_fitted_:
            raise RuntimeError("CooccurrenceMatrix must be fitted first")

        if word not in self.vocab:
            raise ValueError(f"Word '{word}' not in vocabulary")

        word_idx = self.vocab[word]
        word_vector = self.cooccur[word_idx]

        if metric == 'cosine':
            similarities = self._cosine_similarity(word_vector, self.cooccur)
        elif metric == 'correlation':
            similarities = self._correlation(word_vector, self.cooccur)
        else:
            raise ValueError(f"Unknown metric: {metric}")

        # Leave room for the excluded query word; otherwise asking for
        # k >= vocab_size would hand back the query itself with score -inf.
        top_k = min(k, self.vocab_size_ - 1)
        if top_k <= 0:
            return []

        similarities[word_idx] = -float('inf')  # Exclude self
        best = torch.topk(similarities, k=top_k)

        vocabulary = self.get_vocabulary()
        return [
            (vocabulary[idx], score)
            for idx, score in zip(best.indices.tolist(), best.values.tolist())
        ]

    def _cosine_similarity(
        self,
        vector: torch.Tensor,
        matrix: torch.Tensor
    ) -> torch.Tensor:
        """Compute cosine similarity between vector and all rows of matrix."""
        # The epsilon keeps all-zero rows from dividing by zero.
        vector_norm = vector / (torch.norm(vector) + 1e-10)
        matrix_norm = matrix / (torch.norm(matrix, dim=1, keepdim=True) + 1e-10)

        return torch.matmul(matrix_norm, vector_norm)

    def _correlation(
        self,
        vector: torch.Tensor,
        matrix: torch.Tensor
    ) -> torch.Tensor:
        """Compute Pearson correlation between vector and all rows of matrix."""
        # Center the data (subtract mean)
        vector_centered = vector - vector.mean()
        matrix_centered = matrix - matrix.mean(dim=1, keepdim=True)

        numerator = torch.matmul(matrix_centered, vector_centered)
        vector_std = torch.sqrt(torch.sum(vector_centered ** 2))
        matrix_std = torch.sqrt(torch.sum(matrix_centered ** 2, dim=1))

        # The epsilon keeps constant rows from dividing by zero.
        return numerator / (matrix_std * vector_std + 1e-10)

    def get_context_words(
        self,
        word: str,
        k: int = 10
    ) -> List[Tuple[str, float]]:
        """
        Get words with the largest matrix entries in the row of the given word.

        The values are read from the fitted matrix, so they are PPMI scores
        or row-normalized values when use_ppmi / normalize are enabled.

        Args:
            word: Query word
            k: Number of context words to return

        Returns:
            List of (context_word, matrix_value) tuples, largest first

        Raises:
            RuntimeError: If fit() has not been called.
            ValueError: If word is not in the vocabulary.
        """
        if not self.is_fitted_:
            raise RuntimeError("CooccurrenceMatrix must be fitted first")

        if word not in self.vocab:
            raise ValueError(f"Word '{word}' not in vocabulary")

        counts = self.cooccur[self.vocab[word]]
        best = torch.topk(counts, k=min(k, self.vocab_size_))

        vocabulary = self.get_vocabulary()
        return [
            (vocabulary[idx], value)
            for idx, value in zip(best.indices.tolist(), best.values.tolist())
        ]

    def __repr__(self) -> str:
        """String representation."""
        params = (
            f"window_size={self.window_size}, "
            f"weighting='{self.weighting}', "
            f"symmetric={self.symmetric}, "
            f"use_ppmi={self.use_ppmi}"
        )
        if self.is_fitted_:
            params += f", vocab_size={self.vocab_size_}"
        return f"CooccurrenceMatrix({params})"


# ============================================================================
# DEMONSTRATION & TESTING
# ============================================================================
if __name__ == "__main__":
    # Walk-through of the CooccurrenceMatrix features; everything below only
    # fits small toy corpora and prints the results.
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    print(heavy_rule)
    print("CO-OCCURRENCE MATRIX DEMONSTRATION")
    print(heavy_rule)

    # Toy corpus shared by several sections.
    train_docs = [
        "the cat sat on the mat".split(),
        "the dog sat on the log".split(),
        "the cat and dog are friends".split(),
        "cats and dogs like to play".split(),
    ]

    # ------------------------------------------------------------------
    # Section 1: plain counts with uniform weights
    # ------------------------------------------------------------------
    print("\n1. BASIC CO-OCCURRENCE MATRIX")
    print(light_rule)

    cooc = CooccurrenceMatrix(window_size=2, weighting='uniform')
    cooc.fit(train_docs)
    basic_matrix = cooc.get_matrix()

    print(f"Vocabulary ({cooc.vocab_size_} words): {cooc.get_vocabulary()}")
    print(f"\nCo-occurrence matrix shape: {basic_matrix.shape}")
    print("Matrix (first 5x5):")
    print(basic_matrix[:5, :5])

    # ------------------------------------------------------------------
    # Section 2: the three distance weighting schemes side by side
    # ------------------------------------------------------------------
    print("\n2. DIFFERENT WEIGHTING SCHEMES")
    print(light_rule)

    test_doc = [["cat", "sat", "on", "mat"]]

    # fit() returns the fitted instance, so construction and fitting chain.
    cooc_uniform = CooccurrenceMatrix(window_size=2, weighting='uniform').fit(test_doc)
    cooc_harmonic = CooccurrenceMatrix(window_size=2, weighting='harmonic').fit(test_doc)
    cooc_distance = CooccurrenceMatrix(window_size=2, weighting='distance').fit(test_doc)

    print("Document: 'cat sat on mat'")
    print("Context of 'sat': [cat (distance=1), on (distance=1), mat (distance=2)]")
    print()

    # All three models were fitted on the same document, so they share indices.
    cat_idx, sat_idx, on_idx, mat_idx = (
        cooc_uniform.vocab[token] for token in ("cat", "sat", "on", "mat")
    )

    for label, model in (
        ("Uniform:", cooc_uniform),
        ("Harmonic:", cooc_harmonic),
        ("Distance:", cooc_distance),
    ):
        sat_row = model.cooccur[sat_idx]
        print(f"{label:<10}cat-sat={sat_row[cat_idx]:.2f}, "
              f"sat-on={sat_row[on_idx]:.2f}, "
              f"sat-mat={sat_row[mat_idx]:.2f}")

    print("\nNote: Harmonic (1/distance) gives more weight to closer words")

    # ------------------------------------------------------------------
    # Section 3: symmetric vs directed counting
    # ------------------------------------------------------------------
    print("\n3. SYMMETRIC vs DIRECTED")
    print(light_rule)

    pair_doc = [["cat", "dog", "bird"]]
    cooc_sym = CooccurrenceMatrix(window_size=2, symmetric=True).fit(pair_doc)
    cooc_dir = CooccurrenceMatrix(window_size=2, symmetric=False).fit(pair_doc)

    cat_idx = cooc_sym.vocab['cat']
    dog_idx = cooc_sym.vocab['dog']
    sym_matrix = cooc_sym.cooccur
    dir_matrix = cooc_dir.cooccur

    print("Document: 'cat dog bird'")
    print(f"\nSymmetric: cat-dog = {sym_matrix[cat_idx, dog_idx]:.2f}, "
          f"dog-cat = {sym_matrix[dog_idx, cat_idx]:.2f}")
    print(f"Directed:  cat-dog = {dir_matrix[cat_idx, dog_idx]:.2f}, "
          f"dog-cat = {dir_matrix[dog_idx, cat_idx]:.2f}")

    # ------------------------------------------------------------------
    # Section 4: raw values next to their PPMI transformation
    # ------------------------------------------------------------------
    print("\n4. PPMI TRANSFORMATION")
    print(light_rule)

    cooc_raw = CooccurrenceMatrix(window_size=2, use_ppmi=False).fit(train_docs)
    cooc_ppmi = CooccurrenceMatrix(window_size=2, use_ppmi=True).fit(train_docs)

    print("PPMI reduces effect of common word pairs")
    print("\nRaw counts (first 3x3):")
    print(cooc_raw.get_matrix()[:3, :3])
    print("\nPPMI values (first 3x3):")
    print(cooc_ppmi.get_matrix()[:3, :3])

    # ------------------------------------------------------------------
    # Section 5: nearest neighbours by co-occurrence profile
    # ------------------------------------------------------------------
    print("\n5. FINDING SIMILAR WORDS")
    print(light_rule)

    # Slightly larger corpus so the neighbours are more meaningful.
    larger_docs = [
        "cat sat on mat".split(),
        "cat slept on mat".split(),
        "dog sat on log".split(),
        "dog slept on log".split(),
        "cat and dog are pets".split(),
        "mat and log are objects".split(),
    ]

    cooc_large = CooccurrenceMatrix(window_size=3, use_ppmi=True).fit(larger_docs)

    similar_to_cat = cooc_large.most_similar("cat", k=3)
    print("Query: 'cat'")
    for word, score in similar_to_cat:
        print(f"  {word}: {score:.3f}")

    similar_to_sat = cooc_large.most_similar("sat", k=3)
    print("\nQuery: 'sat'")
    for word, score in similar_to_sat:
        print(f"  {word}: {score:.3f}")

    # ------------------------------------------------------------------
    # Section 6: strongest context words of a single word
    # ------------------------------------------------------------------
    print("\n6. MOST FREQUENT CONTEXT WORDS")
    print(light_rule)

    context_cat = cooc_large.get_context_words("cat", k=5)
    print("Words that appear near 'cat':")
    for word, count in context_cat:
        print(f"  {word}: {count:.2f}")

    # ------------------------------------------------------------------
    # Section 7: effect of max_vocab_size
    # ------------------------------------------------------------------
    print("\n7. VOCABULARY LIMITING")
    print(light_rule)

    cooc_unlimited = CooccurrenceMatrix(window_size=2).fit(train_docs)
    cooc_limited = CooccurrenceMatrix(window_size=2, max_vocab_size=5).fit(train_docs)

    print(f"Unlimited vocab: {cooc_unlimited.get_vocabulary()}")
    print(f"Limited vocab (top 5): {cooc_limited.get_vocabulary()}")

    # ------------------------------------------------------------------
    # Section 8: distributional semantics on a themed corpus
    # ------------------------------------------------------------------
    print("\n8. DISTRIBUTIONAL SEMANTICS EXAMPLE")
    print(light_rule)
    print("'You shall know a word by the company it keeps' - J.R. Firth")
    print()

    semantic_docs = [
        "king rules kingdom crown throne".split(),
        "queen rules kingdom crown throne".split(),
        "man works in office".split(),
        "woman works in office".split(),
        "king and queen are royalty".split(),
        "man and woman are people".split(),
    ]

    cooc_semantic = CooccurrenceMatrix(window_size=3, use_ppmi=True).fit(semantic_docs)

    for position, query in enumerate(("king", "man")):
        # Every heading after the first is preceded by a blank line.
        lead = "\n" if position else ""
        print(f"{lead}Similar to '{query}':")
        for word, score in cooc_semantic.most_similar(query, k=3):
            print(f"  {word}: {score:.3f}")

    print("\nNote: Words with similar contexts are semantically related!")

    print(f"\n{heavy_rule}")
    print("All tests completed successfully!")
    print(heavy_rule)

CO-OCCURRENCE MATRIX DEMONSTRATION

1. BASIC CO-OCCURRENCE MATRIX
--------------------------------------------------------------------------------
Vocabulary (15 words): ['and', 'are', 'cat', 'cats', 'dog', 'dogs', 'friends', 'like', 'log', 'mat', 'on', 'play', 'sat', 'the', 'to']

Co-occurrence matrix shape: torch.Size([15, 15])
Matrix (first 5x5):
tensor([[0., 1., 1., 1., 1.],
        [1., 0., 0., 0., 1.],
        [1., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0.]])

2. DIFFERENT WEIGHTING SCHEMES
--------------------------------------------------------------------------------
Document: 'cat sat on mat'
Context of 'sat': [cat (distance=1), on (distance=1), mat (distance=2)]

Uniform:  cat-sat=1.00, sat-on=1.00, sat-mat=1.00
Harmonic: cat-sat=1.00, sat-on=1.00, sat-mat=0.50
Distance: cat-sat=0.50, sat-on=0.50, sat-mat=0.00

Note: Harmonic (1/distance) gives more weight to closer words

3. SYMMETRIC vs DIRECTED
---------------------------------------------

In [None]:



# ============================================================================
# 3. CHARACTER N-GRAMS - Subword information
# ============================================================================

class CharNGramVectorizer:
    """
    Extract character-level n-grams and count them per document.

    Useful for: morphology, typos, out-of-vocabulary words
    E.g., "cat" with n=3 → ["<ca", "cat", "at>"]

    In word-level mode every word is wrapped in "<" and ">" and the wrapped
    words are re-joined with single spaces before the window slides over the
    resulting string, so a multi-word document also yields n-grams that span
    a word boundary (e.g. "t> ", "> <").
    """
    def __init__(self, n=3, word_level=True):
        """
        Args:
            n: Character n-gram size
            word_level: If True, add word boundaries <>, else do character stream

        Raises:
            ValueError: If n < 1.
        """
        if n < 1:
            raise ValueError(f"n must be >= 1, got {n}")

        self.n = n
        self.word_level = word_level
        self.vocab = {}
        self.is_fitted_ = False

    @staticmethod
    def _as_text(doc):
        """Join a tokenized document into one space-separated string."""
        if isinstance(doc, (list, tuple)):
            return ' '.join(doc)
        return doc

    def _get_char_ngrams(self, text):
        """
        Extract character n-grams.

        Args:
            text: A string, or a list/tuple of tokens.

        Returns:
            List of n-gram strings in order of appearance (empty when the
            text is shorter than n).
        """
        if self.word_level:
            # Add boundaries for each word
            words = text if isinstance(text, (list, tuple)) else text.split()
            text = ' '.join(f'<{word}>' for word in words)
        else:
            # Slicing a token list would produce list slices, not strings.
            text = self._as_text(text)

        return [text[i:i + self.n] for i in range(len(text) - self.n + 1)]

    def fit(self, documents):
        """
        Learn the n-gram vocabulary (sorted, so indices are reproducible).

        Args:
            documents: Iterable of documents; each one is a string or a
                list/tuple of tokens.

        Returns:
            self (for method chaining)
        """
        vocab_set = set()
        for doc in documents:
            vocab_set.update(self._get_char_ngrams(self._as_text(doc)))

        self.vocab = {ngram: idx for idx, ngram in enumerate(sorted(vocab_set))}
        self.is_fitted_ = True
        return self

    def transform(self, documents):
        """
        Count the known n-grams of each document.

        Args:
            documents: Sequence of documents (strings or token lists/tuples).

        Returns:
            Tensor of shape (len(documents), len(vocab)); n-grams that were
            not seen during fit() are ignored.

        Raises:
            RuntimeError: If the vectorizer has no vocabulary yet.
        """
        # A vocabulary assigned by hand is accepted as "fitted" as well.
        if not self.is_fitted_ and not self.vocab:
            raise RuntimeError("CharNGramVectorizer must be fitted first")

        char_ngram_matrix = torch.zeros(len(documents), len(self.vocab))

        for doc_idx, doc in enumerate(documents):
            counts = Counter(self._get_char_ngrams(self._as_text(doc)))

            for ngram, count in counts.items():
                column = self.vocab.get(ngram)
                if column is not None:
                    char_ngram_matrix[doc_idx, column] = count

        return char_ngram_matrix

    def fit_transform(self, documents):
        """Fit on documents and return their n-gram count matrix."""
        return self.fit(documents).transform(documents)


# ============================================================================
# 4. GloVe-style Co-occurrence Matrix
# ============================================================================

class CooccurrenceMatrix:
    """
    Build word co-occurrence matrix (foundation of GloVe)
    Counts how often words appear near each other, weighting each
    co-occurrence by 1/distance.

    Row index = center word, column index = context word.
    """
    def __init__(self, window_size=5, symmetric=True):
        """
        Args:
            window_size: How many words left/right to consider as context
            symmetric: If True, (w1, w2) and (w2, w1) count equally.
                If False, only context words that follow the center word
                are counted (matrix[w1, w2] = "w2 appears after w1").
        """
        self.window_size = window_size
        self.symmetric = symmetric
        self.vocab = {}
        self.cooccur = None

    def fit(self, documents):
        """
        Build the vocabulary and the co-occurrence matrix.

        Args:
            documents: List of tokenized documents (lists of words).

        Returns:
            self (for method chaining)
        """
        # Build vocabulary (sorted so that indices are reproducible)
        vocab_set = set()
        for doc in documents:
            vocab_set.update(doc)
        self.vocab = {word: idx for idx, word in enumerate(sorted(vocab_set))}
        vocab_size = len(self.vocab)

        # Initialize co-occurrence matrix
        self.cooccur = torch.zeros(vocab_size, vocab_size)

        # Count co-occurrences
        for doc in documents:
            # Every token is in the vocabulary because it was built from
            # these same documents.
            indices = [self.vocab[word] for word in doc]

            for i, center_idx in enumerate(indices):
                end = min(len(indices), i + self.window_size + 1)

                # Visit each pair once (i < j). Scanning both sides of the
                # window AND mirroring the entry would count every pair twice.
                for j in range(i + 1, end):
                    context_idx = indices[j]
                    weight = 1.0 / (j - i)  # closer words weigh more

                    self.cooccur[center_idx, context_idx] += weight

                    if self.symmetric:
                        self.cooccur[context_idx, center_idx] += weight

        return self

    def get_matrix(self):
        """Return the co-occurrence matrix (None before fit() is called)"""
        return self.cooccur


# ============================================================================
# 5. SIMPLE NEURAL BAG OF WORDS (NBoW)
# ============================================================================

class NeuralBagOfWords(nn.Module):
    """
    Learnable word embeddings averaged to represent document
    - Each word gets a learned embedding
    - Document = average of its word embeddings
    - Can be trained end-to-end for classification
    """
    def __init__(self, vocab_size, embedding_dim, num_classes):
        """
        Args:
            vocab_size: Number of rows in the embedding table
            embedding_dim: Size of each word embedding
            num_classes: Number of output logits
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, x, lengths=None):
        """
        Args:
            x: [batch_size, max_seq_len] - padded word indices
            lengths: [batch_size] - actual lengths (optional, for masking).
                Without it, padding positions are averaged in as well.

        Returns:
            logits: [batch_size, num_classes]
        """
        embedded = self.embedding(x)  # [batch, seq_len, embed_dim]

        if lengths is None:
            # Simple average over every position, padding included
            doc_embedding = embedded.mean(dim=1)  # [batch, embed_dim]
        else:
            # Build the mask on the same device as the input; a CPU-side
            # arange cannot be compared with lengths living on another device.
            lengths = lengths.to(x.device)
            positions = torch.arange(x.size(1), device=x.device)
            mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
            mask = mask.unsqueeze(2).to(embedded.dtype)  # [batch, seq_len, 1]

            # An empty document has an all-zero sum; dividing by 1 instead of
            # 0 yields a zero embedding rather than NaN.
            token_counts = lengths.clamp(min=1).unsqueeze(1).to(embedded.dtype)

            # Masked average
            doc_embedding = (embedded * mask).sum(dim=1) / token_counts

        # Classification
        return self.fc(doc_embedding)  # [batch, num_classes]


# ============================================================================
# 6. FASTTEXT-STYLE (Subword embeddings approximation)
# ============================================================================

class SubwordEmbedding:
    """
    Approximate FastText by representing words as the average of their
    character n-gram embeddings.

    Handles out-of-vocabulary words better than pure word embeddings: an
    unseen word still gets a non-zero vector if it shares at least one
    character n-gram with the words passed to ``fit``.

    Note: the n-gram vectors are randomly initialized in ``fit``; nothing in
    this class trains them.
    """

    def __init__(self, embedding_dim=100, n=3):
        """
        Args:
            embedding_dim: Size of each n-gram (and word) embedding
            n: Length of the character n-grams

        Raises:
            ValueError: If ``n`` is smaller than 1
        """
        if n < 1:
            raise ValueError(f"n must be a positive integer, got {n}")
        self.embedding_dim = embedding_dim
        self.n = n
        self.ngram_embeddings = {}

    def _get_char_ngrams(self, word):
        """
        Extract character n-grams with boundaries.

        The word is wrapped in ``<`` and ``>`` first, so prefixes and
        suffixes produce distinct n-grams. If the wrapped word is shorter
        than ``n``, the whole wrapped word is returned as a single unit so
        that short words are not left without any representation.
        """
        word = f'<{word}>'
        if len(word) < self.n:
            return [word]
        return [word[i:i + self.n] for i in range(len(word) - self.n + 1)]

    def fit(self, words):
        """
        Build vocabulary of character n-grams.

        Args:
            words: Iterable of word strings

        Returns:
            self, to allow call chaining
        """
        ngram_set = set()
        for word in words:
            ngram_set.update(self._get_char_ngrams(word))

        # Initialize random embeddings for each n-gram. Iterating in sorted
        # order makes the assignment reproducible under a fixed torch seed;
        # plain set order depends on string hashing and insertion history.
        for ngram in sorted(ngram_set):
            self.ngram_embeddings[ngram] = torch.randn(self.embedding_dim)

        return self

    def get_word_embedding(self, word):
        """
        Get word embedding as the average of its known n-gram embeddings.

        Args:
            word: Word string; it does not need to have been seen in ``fit``

        Returns:
            Tensor of shape [embedding_dim]. All zeros when none of the
            word's n-grams are in the vocabulary.
        """
        known = [
            self.ngram_embeddings[ngram]
            for ngram in self._get_char_ngrams(word)
            if ngram in self.ngram_embeddings
        ]

        if not known:
            # Fallback for unknown word
            return torch.zeros(self.embedding_dim)

        # Average n-gram embeddings
        return torch.stack(known).mean(dim=0)


# ============================================================================
# EXAMPLE USAGE & COMPARISON
# ============================================================================

if __name__ == "__main__":
    # Demo: run each representation defined in this file on a tiny corpus
    # and print a short summary of what it produces.

    # Sample documents (already tokenized)
    documents = [
        ['the', 'cat', 'sat', 'on', 'the', 'mat'],
        ['the', 'dog', 'sat', 'on', 'the', 'log'],
        ['cats', 'and', 'dogs', 'are', 'pets'],
    ]

    print("=" * 70)
    print("1. BAG OF WORDS")
    print("=" * 70)
    bow = BagOfWords()
    bow_matrix = bow.fit_transform(documents)
    print(f"Shape: {bow_matrix.shape}")
    print(f"Doc 1: {bow_matrix[0][:10]}...")  # First 10 dimensions
    # Fraction of entries that are exactly zero
    print(f"Sparsity: {(bow_matrix == 0).sum().item() / bow_matrix.numel():.2%}")

    print("\n" + "=" * 70)
    print("2. BIGRAMS (2-grams)")
    print("=" * 70)
    bigram = NGramVectorizer(n=2)
    bigram_matrix = bigram.fit_transform(documents)
    print(f"Shape: {bigram_matrix.shape}")
    print(f"Sample bigrams: {list(bigram.vocab.keys())[:5]}")

    print("\n" + "=" * 70)
    print("3. CHARACTER TRIGRAMS")
    print("=" * 70)
    char_trigram = CharNGramVectorizer(n=3)
    char_matrix = char_trigram.fit_transform(documents)
    print(f"Shape: {char_matrix.shape}")
    print(f"Sample char-3grams: {list(char_trigram.vocab.keys())[:10]}")

    print("\n" + "=" * 70)
    print("4. CO-OCCURRENCE MATRIX")
    print("=" * 70)
    cooccur = CooccurrenceMatrix(window_size=2)
    cooccur.fit(documents)
    cooccur_matrix = cooccur.get_matrix()
    print(f"Shape: {cooccur_matrix.shape}")

    # Show co-occurrence for "cat"
    if 'cat' in cooccur.vocab:
        cat_idx = cooccur.vocab['cat']
        print("\nWords co-occurring with 'cat':")
        # Invert word -> index so matrix columns can be printed as words
        vocab_reverse = {v: k for k, v in cooccur.vocab.items()}
        cat_row = cooccur_matrix[cat_idx]
        # topk raises if k exceeds the row length, so cap it for small vocabularies
        top_k = min(5, cat_row.numel())
        for idx in torch.topk(cat_row, k=top_k).indices:
            word = vocab_reverse[idx.item()]
            count = cooccur_matrix[cat_idx, idx].item()
            print(f"  {word}: {count:.2f}")

    print("\n" + "=" * 70)
    print("5. NEURAL BAG OF WORDS (Architecture only)")
    print("=" * 70)
    vocab_size = 100
    nbow = NeuralBagOfWords(vocab_size=vocab_size, embedding_dim=50, num_classes=3)

    # Dummy input: random word indices standing in for two padded documents
    dummy_input = torch.randint(0, vocab_size, (2, 10))  # 2 docs, max len 10
    lengths = torch.tensor([6, 8])
    output = nbow(dummy_input, lengths)
    print(f"Input shape: {dummy_input.shape}")
    print(f"Output shape: {output.shape}")
    print("(This would be trained end-to-end with a classification loss)")

    print("\n" + "=" * 70)
    print("6. SUBWORD EMBEDDINGS (FastText-style)")
    print("=" * 70)
    words = ['cat', 'cats', 'catlike', 'dog']
    subword = SubwordEmbedding(embedding_dim=50, n=3)
    subword.fit(words)

    cat_embed = subword.get_word_embedding('cat')
    cats_embed = subword.get_word_embedding('cats')

    # Similarity between 'cat' and 'cats'
    similarity = F.cosine_similarity(cat_embed.unsqueeze(0), cats_embed.unsqueeze(0))
    print(f"Cosine similarity between 'cat' and 'cats': {similarity.item():.4f}")
    print("(Should be high because they share many character n-grams)")

    # Out-of-vocabulary word: 'category' was not passed to fit(), but it
    # shares the n-grams '<ca' and 'cat' with the fitted words
    unknown_embed = subword.get_word_embedding('category')
    sim_unknown = F.cosine_similarity(cat_embed.unsqueeze(0), unknown_embed.unsqueeze(0))
    print(f"Similarity between 'cat' and 'category': {sim_unknown.item():.4f}")
    print("(Non-zero even though 'category' wasn't in training!)")