In [None]:
import math
from collections import Counter

class KNNTextClassifier:
    """k-nearest-neighbours text classifier over bag-of-words count vectors.

    Texts are lowercased and split on whitespace, turned into word-count
    dictionaries keyed by the training vocabulary, and compared with
    Euclidean distance.  A test text receives the majority label among its
    ``k`` closest training texts.

    Attributes:
        k: Number of neighbours that vote on each prediction.
        vocab: Set of every lowercased token seen in the training texts.
        train_data: List of ``(count_vector, label)`` pairs, one per
            training text, in the order they were passed to ``fit``.
    """

    def __init__(self, k=3):
        """Create an untrained classifier that votes among ``k`` neighbours."""
        self.k = k
        self.vocab = set()  # Vocabulary of all words
        self.train_data = []  # List of (vectorized_text, label)

    def preprocess(self, text):
        """Return the lowercased tokens of ``text``.

        Args:
            text: Either a string (split on whitespace) or a list of
                already-tokenized words.

        Returns:
            A list of lowercase tokens.

        Raises:
            ValueError: If ``text`` is neither a string nor a list.
        """
        if isinstance(text, str):
            return text.lower().split()
        if isinstance(text, list):
            # Already tokenized: only normalise the case.
            return [word.lower() for word in text]
        raise ValueError("Input should be a string or a list of words")

    def build_vocab(self, X_train):
        """Add every token of every text in ``X_train`` to the vocabulary."""
        for text in X_train:
            self.vocab.update(self.preprocess(text))

    def vectorize(self, text):
        """Convert ``text`` into a word-count vector over the vocabulary.

        Args:
            text: A string or a list of tokens (see ``preprocess``).

        Returns:
            A dict with one key per vocabulary word whose value is how many
            times that word occurs in ``text``.  Tokens outside the
            vocabulary are ignored.
        """
        # One counting pass instead of a list.count() scan per vocab word.
        counts = Counter(self.preprocess(text))
        return {word: counts[word] for word in self.vocab}

    def fit(self, X_train, y_train):
        """Store the vectorized training texts together with their labels.

        Calling ``fit`` again adds to the vocabulary and to the stored
        examples rather than replacing them.

        Args:
            X_train: Iterable of training texts (strings or token lists).
            y_train: Iterable of labels, one per training text.

        Raises:
            ValueError: If the number of texts and labels differ, or if a
                text is neither a string nor a list.
        """
        # Materialise once: the texts are walked twice (vocabulary, then
        # vectors), which would silently exhaust a generator.
        texts = list(X_train)
        labels = list(y_train)
        if len(texts) != len(labels):
            raise ValueError(
                f"X_train and y_train must have the same length "
                f"(got {len(texts)} and {len(labels)})"
            )

        self.build_vocab(texts)
        for text, label in zip(texts, labels):
            self.train_data.append((self.vectorize(text), label))

    def euclidean_distance(self, vec1, vec2):
        """Return the Euclidean distance between two word-count vectors.

        The sum runs over the current vocabulary.  A word missing from
        either vector counts as 0, so vectors stored by an earlier ``fit``
        call (before the vocabulary grew) remain comparable.
        """
        return math.sqrt(
            sum(
                (vec1.get(word, 0) - vec2.get(word, 0)) ** 2
                for word in self.vocab
            )
        )

    def predict(self, X_test):
        """Predict a label for each text in ``X_test``.

        Args:
            X_test: Iterable of texts (strings or token lists).

        Returns:
            A list with one predicted label per input text.  When fewer
            than ``k`` training examples are stored, all of them vote.
            Equal distances are ranked in training order, and a tied vote
            goes to the label whose neighbour ranks closest.

        Raises:
            ValueError: If ``k`` is smaller than 1, if the classifier has
                no training data, or if a text has an unsupported type.
        """
        if self.k < 1:
            raise ValueError("k must be a positive integer")
        if not self.train_data:
            raise ValueError("Classifier has no training data; call fit() first")

        predictions = []
        for text in X_test:
            test_vector = self.vectorize(text)
            scored = [
                (self.euclidean_distance(test_vector, train_vector), label)
                for train_vector, label in self.train_data
            ]
            # Sort on the distance alone: comparing whole tuples would fall
            # back to comparing labels on ties, which fails for labels that
            # cannot be ordered.  sorted() is stable, so ties keep training
            # order.
            scored.sort(key=lambda pair: pair[0])
            # Slicing (rather than indexing 0..k-1) tolerates k larger than
            # the number of stored examples.
            nearest_labels = [label for _, label in scored[: self.k]]
            predictions.append(Counter(nearest_labels).most_common(1)[0][0])

        return predictions

# ---- Testing the Classifier ----
X_train = ["I love cats", "Dogs are great", "Cats are better than dogs", "I adore my dog"]
y_train = ["positive", "positive", "negative", "positive"]

X_test = ["I love my dog", "Cats are the best"]

knn = KNNTextClassifier(k=3)
knn.fit(X_train, y_train)

predictions = knn.predict(X_test)
print(predictions)  # Output: ['positive', 'negative']

['positive', 'positive']


In [5]:
# Inspect what fit() stored: one (word-count dict, label) pair per training text.
knn.train_data

[({'are': 0,
   'love': 1,
   'i': 1,
   'dog': 0,
   'than': 0,
   'dogs': 0,
   'better': 0,
   'cats': 1,
   'my': 0,
   'great': 0,
   'adore': 0},
  'positive'),
 ({'are': 1,
   'love': 0,
   'i': 0,
   'dog': 0,
   'than': 0,
   'dogs': 1,
   'better': 0,
   'cats': 0,
   'my': 0,
   'great': 1,
   'adore': 0},
  'positive'),
 ({'are': 1,
   'love': 0,
   'i': 0,
   'dog': 0,
   'than': 1,
   'dogs': 1,
   'better': 1,
   'cats': 1,
   'my': 0,
   'great': 0,
   'adore': 0},
  'negative'),
 ({'are': 0,
   'love': 0,
   'i': 1,
   'dog': 1,
   'than': 0,
   'dogs': 0,
   'better': 0,
   'cats': 0,
   'my': 1,
   'great': 0,
   'adore': 1},
  'positive')]