## Load Data

In [20]:
from google.colab import drive
import os
import pandas as pd
import numpy as np

# Mount Google Drive so the project folder is readable from this Colab runtime.
drive.mount('/content/drive', force_remount=True)

# Define project paths; data_dir is derived from project_dir so the project
# root is written in exactly one place.
project_dir = '/content/drive/MyDrive/4awesome/'
data_dir = os.path.join(project_dir, 'Data')

reviews = pd.read_csv(os.path.join(data_dir, 'reviews_with_policy_flags.csv'))

# Fail fast if the CSV lacks a column that the split, dataset and model cells read.
required_columns = {
    "cleaned_text", "rating", "rating_category", "rating_category_encoded",
    "policy_ads", "policy_short", "policy_mismatch",
}
missing_columns = required_columns - set(reviews.columns)
assert not missing_columns, f"reviews CSV is missing columns: {sorted(missing_columns)}"

Mounted at /content/drive


## Import Libraries

In [21]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from sklearn.preprocessing import LabelEncoder

## Train/Test Split

In [22]:
# Hold out 20% of the reviews for evaluation. Stratifying on rating_category
# keeps its class proportions (approximately) equal across both splits, and the
# fixed random_state makes the split reproducible across runs.
train_texts, test_texts, train_labels, test_labels = train_test_split(
    reviews["cleaned_text"].tolist(),
    reviews.loc[
        :,
        [
            "rating",
            "rating_category_encoded",
            "policy_ads",
            "policy_short",
            "policy_mismatch",
        ],
    ],
    stratify=reviews["rating_category"],
    random_state=42,
    test_size=0.2,
)

## Tokenizer

In [23]:
# WordPiece tokenizer for the same "bert-base-uncased" checkpoint the model loads.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


def tokenize(batch_texts, max_len=128):
    """Encode a list of texts into PyTorch tensors with the module-level tokenizer.

    Texts longer than ``max_len`` tokens are truncated; shorter ones are padded
    to the longest encoded text in ``batch_texts`` (``padding=True``).

    Args:
        batch_texts: List of strings to encode.
        max_len: Maximum sequence length in tokens.

    Returns:
        The tokenizer's dict-like batch of tensors (e.g. ``input_ids`` and
        ``attention_mask``), one row per input text.
    """
    encode_kwargs = dict(
        padding=True,
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    )
    return tokenizer(batch_texts, **encode_kwargs)

## Dataset Class

In [24]:
class ReviewsDataset(Dataset):
    """Map-style dataset pairing pre-tokenized review texts with multi-task targets.

    All texts are tokenized once up front via ``tokenize``. Each item is a dict
    holding that review's encoding tensors plus one ``torch.long`` target per
    task head:

    - ``rating``: the ``rating`` column shifted by -1 (1-5 stars -> 0-4)
    - ``rating_category``: taken from ``rating_category_encoded``
    - ``policy_ads`` / ``policy_short`` / ``policy_mismatch``: flags cast to 0/1

    Args:
        texts: List of review strings, aligned row-for-row with ``labels``.
        labels: DataFrame containing the source columns above (index is reset).

    Raises:
        ValueError: if any target column contains missing values.
    """

    # (item key, source column in ``labels``, offset added to the raw value)
    _TARGETS = (
        ("rating", "rating", -1),
        ("rating_category", "rating_category_encoded", 0),
        ("policy_ads", "policy_ads", 0),
        ("policy_short", "policy_short", 0),
        ("policy_mismatch", "policy_mismatch", 0),
    )

    def __init__(self, texts, labels):
        self.encodings = tokenize(texts)
        self.labels = labels.reset_index(drop=True)
        # Convert every target column to a long tensor once, rather than
        # building a mixed-dtype pandas row with .iloc on every __getitem__.
        self.targets = self._build_targets(self.labels)

    @classmethod
    def _build_targets(cls, labels):
        """Return ``{item key: 1-D torch.long tensor}`` for every task target."""
        targets = {}
        for key, column, offset in cls._TARGETS:
            values = labels[column]
            if values.isna().any():
                raise ValueError(f"label column {column!r} contains missing values")
            # astype("int64") handles bool, int and integral-float columns alike,
            # so every head gets the integer class index CrossEntropyLoss expects.
            targets[key] = torch.as_tensor(
                values.to_numpy().astype("int64") + offset, dtype=torch.long
            )
        return targets

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        for key, target in self.targets.items():
            item[key] = target[idx]
        return item

train_dataset = ReviewsDataset(train_texts, train_labels)
test_dataset = ReviewsDataset(test_texts, test_labels)

# A dedicated, seeded generator makes the training shuffle order reproducible
# across Restart & Run All (42 matches the train/test split's random_state).
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
# Evaluation order does not affect the metrics, so the test loader is not shuffled.
test_loader = DataLoader(test_dataset, batch_size=16)


## Multi-Task BERT Model

In [25]:
class MultiTaskBERT(nn.Module):
    """BERT encoder with one linear classification head per task.

    Every head reads the encoder's ``pooler_output``; ``forward`` returns a dict
    of logits keyed "rating", "rating_category", "policy_ads", "policy_short"
    and "policy_mismatch".

    Args:
        num_categories: Number of classes for the ``rating_category`` head.
        num_ratings: Number of classes for the ``rating`` head. Defaults to 5
            for star ratings shifted to 0-4.
        pretrained_name: Hugging Face checkpoint to load; it should match the
            tokenizer used to encode the inputs.
        dropout: Dropout probability applied to the pooled output before the
            heads. The default 0.0 is an identity, which keeps the original
            behaviour.
    """

    def __init__(self, num_categories, num_ratings=5,
                 pretrained_name="bert-base-uncased", dropout=0.0):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_name)
        hidden_size = self.bert.config.hidden_size

        # Shared regularisation on the pooled representation (no parameters,
        # so existing state_dicts stay loadable).
        self.dropout = nn.Dropout(dropout)

        # One head per task; attribute names are unchanged so saved
        # checkpoints keep matching.
        self.fc_rating = nn.Linear(hidden_size, num_ratings)
        self.fc_category = nn.Linear(hidden_size, num_categories)
        self.fc_ads = nn.Linear(hidden_size, 2)       # binary policy flags
        self.fc_short = nn.Linear(hidden_size, 2)
        self.fc_mismatch = nn.Linear(hidden_size, 2)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return a dict of per-task logits, one row per input sequence."""
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled_output = self.dropout(outputs.pooler_output)

        return {
            "rating": self.fc_rating(pooled_output),
            "rating_category": self.fc_category(pooled_output),
            "policy_ads": self.fc_ads(pooled_output),
            "policy_short": self.fc_short(pooled_output),
            "policy_mismatch": self.fc_mismatch(pooled_output),
        }

# Size the category head from the column its targets are read from
# (rating_category_encoded). CrossEntropyLoss needs every target in
# [0, num_categories); max + 1 guarantees that and equals nunique() whenever
# the encoding is a contiguous 0..k-1 range.
assert reviews["rating_category_encoded"].min() >= 0, "encoded categories must be non-negative"
num_categories = int(reviews["rating_category_encoded"].max()) + 1
model = MultiTaskBERT(num_categories)

## Evaluation Function

In [26]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import torch

def evaluate(model, dataloader, device):
    """Evaluate a multi-task model over every batch of ``dataloader``.

    Each head's prediction is the argmax of its logits. The model runs in eval
    mode under ``torch.no_grad()``, and its previous train/eval mode is
    restored afterwards, even if an exception is raised.

    Args:
        model: Module whose ``forward(input_ids=..., attention_mask=...)``
            returns a dict of logits under "rating", "rating_category",
            "policy_ads", "policy_short" and "policy_mismatch".
        dataloader: Iterable of batch dicts containing "input_ids",
            "attention_mask" and an integer target under each key above.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        dict with "rating_accuracy", "category_accuracy" and, for each of
        "ads", "short" and "mismatch", "<name>_precision", "<name>_recall" and
        "<name>_f1" (binary average, zero_division=0).
    """
    # Model output / batch key -> short name used in the result keys.
    tasks = {
        "rating": "rating",
        "rating_category": "category",
        "policy_ads": "ads",
        "policy_short": "short",
        "policy_mismatch": "mismatch",
    }
    all_true = {name: [] for name in tasks.values()}
    all_pred = {name: [] for name in tasks.values()}

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)

                for key, name in tasks.items():
                    all_pred[name].extend(torch.argmax(outputs[key], dim=1).cpu().tolist())
                    all_true[name].extend(batch[key].cpu().tolist())
    finally:
        # Don't leak eval mode (e.g. disabled dropout) into the caller's training.
        model.train(was_training)

    results = {
        "rating_accuracy": accuracy_score(all_true["rating"], all_pred["rating"]),
        "category_accuracy": accuracy_score(all_true["category"], all_pred["category"]),
    }

    # Binary policy flags: precision/recall/F1 for the positive class.
    for name in ("ads", "short", "mismatch"):
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_true[name], all_pred[name], average="binary", zero_division=0
        )
        results[f"{name}_precision"] = precision
        results[f"{name}_recall"] = recall
        results[f"{name}_f1"] = f1

    return results


## Training + Evaluation Loop

In [28]:
import torch.optim as optim

num_epochs = 5

# Seed torch's global RNG so the run's randomness (any dropout, and any
# DataLoader shuffling without its own generator) repeats; 42 matches the
# split's random_state. Head initialisation happened earlier and is not covered.
torch.manual_seed(42)

# Move the model to its device *before* constructing the optimizer, as the
# torch.optim documentation recommends.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()

# Each head's output and its batch target share the same key; all heads use
# one criterion and their losses are summed with equal weight.
task_keys = ("rating", "rating_category", "policy_ads", "policy_short", "policy_mismatch")

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for batch in train_loader:
        optimizer.zero_grad()

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

        loss = sum(criterion(outputs[key], batch[key].to(device)) for key in task_keys)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs} | Training Loss: {avg_loss:.4f}")

    # NOTE: test_loader is the held-out *test* split (there is no separate
    # validation split), so don't select epochs or hyperparameters from these numbers.
    val_results = evaluate(model, test_loader, device)
    print("Validation results:", val_results)
    print("-" * 50)

Epoch 1/5 | Training Loss: 2.6922
Validation results: {'rating_accuracy': 0.45794392523364486, 'category_accuracy': 0.35046728971962615, 'ads_precision': 0.0, 'ads_recall': 0.0, 'ads_f1': 0.0, 'short_precision': 0.9130434782608695, 'short_recall': 0.84, 'short_f1': 0.875, 'mismatch_precision': 0.0, 'mismatch_recall': 0.0, 'mismatch_f1': 0.0}
--------------------------------------------------
Epoch 2/5 | Training Loss: 2.5238
Validation results: {'rating_accuracy': 0.4719626168224299, 'category_accuracy': 0.3364485981308411, 'ads_precision': 0.0, 'ads_recall': 0.0, 'ads_f1': 0.0, 'short_precision': 0.9230769230769231, 'short_recall': 0.96, 'short_f1': 0.9411764705882353, 'mismatch_precision': 0.0, 'mismatch_recall': 0.0, 'mismatch_f1': 0.0}
--------------------------------------------------
Epoch 3/5 | Training Loss: 2.3644
Validation results: {'rating_accuracy': 0.46261682242990654, 'category_accuracy': 0.3317757009345794, 'ads_precision': 0.0, 'ads_recall': 0.0, 'ads_f1': 0.0, 'short_