In [5]:
import csv
import math
import random
from collections import defaultdict

In [10]:
# Load the dataset: column 0 of each row is kept as the review text and
# column 1 is parsed as an integer rating. The first row is skipped
# (presumably a header line).
reviews = []
ratings = []

# newline='' hands line-ending handling to the csv module, so quoted fields
# that contain embedded newlines are parsed correctly on every platform.
with open('reviews.csv', 'r', newline='') as f:
  reader = csv.reader(f)
  next(reader, None)  # skip the first row; tolerate a completely empty file
  for row in reader:
    if not row:  # csv.reader yields [] for blank lines
      continue
    reviews.append(row[0])
    ratings.append(int(row[1]))


In [12]:
# Collapse each numeric rating into a sentiment label:
# 2 or lower -> 'bad', exactly 3 -> 'neutral', anything else -> 'good'.
categories = [
  'bad' if rating <= 2 else ('neutral' if rating == 3 else 'good')
  for rating in ratings
]


In [15]:
# Pair each review with its label so the two stay aligned through the shuffle.
combined = list(zip(reviews, categories))

# Shuffle with a dedicated, seeded generator so the train/validation/test
# split (and every accuracy derived from it) is identical on each run,
# without disturbing the global `random` state.
SEED = 42
random.Random(SEED).shuffle(combined)
n = len(combined)

In [20]:
# Slice the shuffled pairs: the first 70% for training, the next 10% for
# validation, and everything after the 80% mark for testing.
train_data, val_data, test_data = (
    combined[: int(0.7 * n)],
    combined[int(0.7 * n): int(0.8 * n)],
    combined[int(0.8 * n):],
)

# Separate each split into parallel lists of review texts and labels.
train_reviews = [text for text, _ in train_data]
train_cats = [label for _, label in train_data]
val_reviews = [text for text, _ in val_data]
val_cats = [label for _, label in val_data]
test_reviews = [text for text, _ in test_data]
test_cats = [label for _, label in test_data]

In [21]:
# Peek at up to ten test examples. Slicing (instead of indexing 0..9)
# avoids an IndexError when the test split holds fewer than ten items.
for example in test_data[:10]:
  print(example)

('Highly recommended. Pull bring position several everyone our return. Bill little citizen particular exist vote.', 'good')
('Highly recommended. Compare federal support sound. Of ready else.', 'good')
('Unsatisfied. Anyone since yard. Authority spend enter treatment rise medical consider age.', 'bad')
('Fine for the price. Most relationship worry general treat girl threat. Check detail draw deal.', 'neutral')
('Amazing product. Behind story discussion good in word. Bed answer read.', 'good')
('Exceeded expectations. Agreement throw job school born yes by decide. Teacher team food improve capital economic deep.', 'good')
('Worst ever. According coach product recent matter Congress once. Former conference within poor physical whether group.', 'bad')
('Totally disappointed. And since whom spend try oil require. Main project some commercial write smile black.', 'bad')
('Amazing product. Individual west plant staff quickly support. Machine product news what stock.', 'good')
('Not worth it.

In [19]:
# Empty per-class accumulators filled in by the training pass:
#   word_counts[c][w] - occurrences of word w in training reviews of class c
#   class_counts[c]   - number of training reviews of class c
#   vocab             - every distinct word seen during training
word_counts = {label: defaultdict(int) for label in ('bad', 'neutral', 'good')}
class_counts = dict.fromkeys(('bad', 'neutral', 'good'), 0)
vocab = set()

In [25]:
# Single pass over the labelled training reviews: count the review for its
# class, count each lower-cased whitespace token for that class, and record
# every token in the vocabulary.
for text, label in zip(train_reviews, train_cats):
  class_counts[label] += 1
  per_class = word_counts[label]
  tokens = text.lower().split()
  for token in tokens:
    per_class[token] += 1
  vocab.update(tokens)

# Number of training reviews; used as the denominator of the class priors.
total_train = len(train_reviews)

In [26]:
 # 6. Make predictions on validation set
val_correct = 0

# The add-one smoothing denominator (words in the class + vocabulary size)
# depends only on the training counts, so compute it once per class rather
# than once per review.
smoothed_totals = {
    category: sum(word_counts[category].values()) + len(vocab)
    for category in ['bad', 'neutral', 'good']
}

for review, true_cat in zip(val_reviews, val_cats):
    words = review.lower().split()
    best_category = None
    best_score = float('-inf')

    for category in ['bad', 'neutral', 'good']:
        prior = class_counts[category] / total_train
        if prior == 0:
            # Class never seen in training: log(0) is undefined and the
            # class could not win anyway.
            continue

        # Score in log space. Multiplying many small probabilities underflows
        # to 0.0 for longer reviews, which made every class tie and the first
        # one ('bad') win by default.
        score = math.log(prior)
        denominator = smoothed_totals[category]
        for word in words:
            # .get() keeps the defaultdict from growing on unseen words.
            count = word_counts[category].get(word, 0) + 1
            score += math.log(count / denominator)

        # Strict '>' keeps the earliest category on ties, as before.
        if score > best_score:
            best_score = score
            best_category = category

    if best_category == true_cat:
        val_correct += 1

# Fraction of validation reviews classified correctly (0 if the split is empty).
val_accuracy = val_correct / len(val_cats) if val_cats else 0

In [27]:
# Show the validation accuracy (correct predictions / validation reviews).
print(val_accuracy)

0.998


In [28]:
# 7. Make predictions on test set
test_correct = 0

# The add-one smoothing denominator (words in the class + vocabulary size)
# depends only on the training counts, so compute it once per class rather
# than once per review.
smoothed_totals = {
    category: sum(word_counts[category].values()) + len(vocab)
    for category in ['bad', 'neutral', 'good']
}

for review, true_cat in zip(test_reviews, test_cats):
    words = review.lower().split()
    best_category = None
    best_score = float('-inf')

    for category in ['bad', 'neutral', 'good']:
        prior = class_counts[category] / total_train
        if prior == 0:
            # Class never seen in training: log(0) is undefined and the
            # class could not win anyway.
            continue

        # Score in log space. Multiplying many small probabilities underflows
        # to 0.0 for longer reviews, which made every class tie and the first
        # one ('bad') win by default.
        score = math.log(prior)
        denominator = smoothed_totals[category]
        for word in words:
            # .get() keeps the defaultdict from growing on unseen words.
            count = word_counts[category].get(word, 0) + 1
            score += math.log(count / denominator)

        # Strict '>' keeps the earliest category on ties, as before.
        if score > best_score:
            best_score = score
            best_category = category

    if best_category == true_cat:
        test_correct += 1

# Fraction of test reviews classified correctly (0 if the split is empty).
test_accuracy = test_correct / len(test_cats) if test_cats else 0

In [29]:
# Show the test accuracy (correct predictions / test reviews).
print(test_accuracy)

0.9976
