In [None]:
# Install required packages (run this first in Colab)
!pip install --upgrade transformers torch pandas scikit-learn numpy
!pip install accelerate  # Helps with training stability

import pandas as pd
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, classification_report
from transformers import (
    DistilBertTokenizer,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments,
    EvalPrediction,
)
from torch.utils.data import WeightedRandomSampler
from google.colab import files, drive
import warnings
import os

# Turn off Weights & Biases integration entirely (both env switches are set so
# either the legacy or the newer flag is honoured), then silence Python warnings.
os.environ.update({
    "WANDB_DISABLED": "true",
    "WANDB_MODE": "disabled",
})

warnings.filterwarnings("ignore")

# --- Option 1: Upload CSV file directly ---
# Ask how to obtain the CSV. Anything other than "1"/"2" is re-prompted:
# otherwise csv_file_path would never be assigned and pd.read_csv would fail
# with a NameError.
print("Choose your data loading method:")
print("1. Upload file directly")
print("2. Load from Google Drive")
choice = input("Enter 1 or 2: ").strip()
while choice not in ("1", "2"):
    choice = input("Please enter 1 or 2: ").strip()

if choice == "1":
    # Upload file directly to Colab
    print("Please upload your CSV file:")
    uploaded = files.upload()
    if not uploaded:
        # files.upload() returned no files (e.g. the dialog was cancelled).
        raise RuntimeError("No file was uploaded - re-run this cell and select a CSV file.")
    csv_file_path = list(uploaded.keys())[0]
else:
    # Mount Google Drive
    drive.mount('/content/drive')
    # Update this path to your file location in Google Drive
    csv_file_path = "/content/drive/My Drive/reviews_with_labels_completed.csv"
    print(f"Loading from: {csv_file_path}")

# --- 1. Load and Prepare Dataset ---
df = pd.read_csv(csv_file_path)

# First, let's inspect the CSV structure
print("=== CSV INSPECTION ===")
print(f"Dataset shape: {df.shape}")
print(f"Column names: {list(df.columns)}")
print("\nFirst few rows:")
print(df.head())

# Identify the text column: take the first common name present in the CSV.
text_column = None
potential_text_cols = ['text', 'review', 'comment', 'content', 'message']
for col in potential_text_cols:
    if col in df.columns:
        text_column = col
        break

# Otherwise ask the user, re-prompting until the name matches a real column
# (a typo would otherwise only surface later as a KeyError).
while text_column not in df.columns:
    print(f"\nAvailable columns: {list(df.columns)}")
    text_column = input("Enter the name of the text column: ").strip()

print(f"Using '{text_column}' as text column")

# Identify label columns (exclude text column and any ID columns)
exclude_cols = [text_column, 'id', 'ID', 'index', 'Unnamed: 0']
potential_label_columns = [col for col in df.columns if col not in exclude_cols]

print(f"\nPotential label columns found: {potential_label_columns}")

# Check if these are binary columns (0/1 values, ignoring missing cells)
binary_columns = []
for col in potential_label_columns:
    unique_values = df[col].dropna().unique()
    if len(unique_values) <= 2 and all(val in [0, 1, 0.0, 1.0, True, False] for val in unique_values):
        binary_columns.append(col)

if binary_columns:
    print(f"Binary label columns detected: {binary_columns}")
    label_columns = binary_columns
else:
    print("No binary columns found. Please specify label columns manually:")
    print("Available columns:", [col for col in df.columns if col != text_column])
    # Re-prompt until every requested column exists; a typo would otherwise
    # raise a KeyError when the label matrix is built below.
    while True:
        manual_labels = input("Enter label column names separated by commas: ").strip()
        label_columns = [col.strip() for col in manual_labels.split(',') if col.strip()]
        missing_cols = [col for col in label_columns if col not in df.columns]
        if label_columns and not missing_cols:
            break
        print(f"Columns not found: {missing_cols}" if missing_cols else "No column names entered.")

print(f"\nUsing label columns: {label_columns}")

# Filter valid text entries. .copy() yields an independent frame so the column
# assignment below does not operate on a view of the unfiltered data.
df = df[df[text_column].apply(lambda x: isinstance(x, str) and len(x.strip()) > 0)].copy()

# Build float label vectors. Missing label cells (ignored by the binary check
# above via dropna) are treated as negatives: a NaN target would make the
# BCE loss NaN.
label_frame = df[label_columns].fillna(0).astype(float)
labels = label_frame.values.tolist()
df["labels"] = labels
df = df[[text_column, "labels"]].rename(columns={text_column: "text"})

print(f"\nFinal dataset shape: {df.shape}")
print("Class distribution:")
for col in label_columns:
    print(f"{col}: {int(label_frame[col].sum())}")

# --- 2. Calculate Class Weights Based on Actual Distribution ---
# BCEWithLogitsLoss treats every label as an independent binary problem and
# documents pos_weight[c] as (#negatives / #positives) for label c. Rare labels
# then get a large weight so their positives are not drowned out. (The softmax-
# style total / (num_labels * count) heuristic under-weights rare labels here.)
# NOTE(review): counts come from the full dataset, validation rows included.
label_matrix = np.asarray(labels, dtype=float).reshape(-1, len(label_columns))
total_samples = len(df)
positive_counts = label_matrix.sum(axis=0)
class_counts = {col: int(count) for col, count in zip(label_columns, positive_counts)}
class_weights = []

print("\n--- Calculated Class Weights ---")
for label in label_columns:
    count = class_counts[label]
    negatives = total_samples - count
    if count > 0 and negatives > 0:
        weight = negatives / count
    else:
        # Label is never (or always) positive: the ratio is undefined or zero,
        # so fall back to a neutral weight.
        weight = 1.0
    class_weights.append(weight)
    print(f"{label}: {weight:.3f} (count: {count})")

pos_weight_tensor = torch.tensor(class_weights, dtype=torch.float)

# --- 3. Split Data ---
# Hold out 20% of the rows for validation; the fixed seed keeps the split
# reproducible across runs. (The split is random, not stratified by label.)
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

train_texts = train_df["text"].tolist()
train_labels = train_df["labels"].tolist()
val_texts = val_df["text"].tolist()
val_labels = val_df["labels"].tolist()

print(f"\nTraining samples: {len(train_texts)}")
print(f"Validation samples: {len(val_texts)}")

# --- 4. Tokenizer ---
model_name = "distilbert-base-uncased"
tokenizer = DistilBertTokenizer.from_pretrained(model_name)

# Shared encoding settings for both splits. max_length is capped at 256 (rather
# than 512) for speed; padding=True pads each split to its longest sequence.
encode_kwargs = dict(truncation=True, padding=True, max_length=256)
train_encodings = tokenizer(train_texts, **encode_kwargs)
val_encodings = tokenizer(val_texts, **encode_kwargs)

# --- 5. Dataset Class ---
class ReviewDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer output with per-example label vectors.

    ``encodings`` maps each feature name (e.g. ``input_ids``) to a sequence
    indexed by example; ``labels[i]`` is the label vector for example ``i``,
    returned as a float tensor under the ``"labels"`` key.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for feature_name, feature_values in self.encodings.items():
            sample[feature_name] = torch.tensor(feature_values[idx])
        sample["labels"] = torch.tensor(self.labels[idx], dtype=torch.float)
        return sample

train_dataset, val_dataset = (
    ReviewDataset(train_encodings, train_labels),
    ReviewDataset(val_encodings, val_labels),
)

# --- 6. Skip Oversampling for Speed ---
# No WeightedRandomSampler here: it made training take roughly 10x more steps.
# Class imbalance is handled by the pos_weight in the loss function instead.
print("Skipping oversampling for faster training - weighted loss will handle class imbalance")

# --- 7. Custom Trainer with Weighted Loss ---
class MultilabelTrainer(Trainer):
    """Trainer using BCE-with-logits loss weighted by the module-level pos_weight_tensor."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments newer Trainer versions may pass
        # (presumably num_items_in_batch) - they are not used here.
        # Targets are popped so only model inputs are forwarded; the loss is computed below.
        targets = inputs.pop("labels")
        model_output = model(**inputs)
        logits = model_output.logits
        criterion = torch.nn.BCEWithLogitsLoss(
            pos_weight=pos_weight_tensor.to(logits.device)
        )
        loss = criterion(logits, targets)
        if return_outputs:
            return loss, model_output
        return loss

# --- 8. Define Model ---
# Index <-> column-name mappings, passed to from_pretrained so the model's
# config carries human-readable label names.
id2label = dict(enumerate(label_columns))
label2id = {name: idx for idx, name in id2label.items()}

model = DistilBertForSequenceClassification.from_pretrained(
    model_name,
    id2label=id2label,
    label2id=label2id,
    num_labels=len(label_columns),
    # Each label is an independent yes/no decision (sigmoid outputs, not softmax).
    problem_type="multi_label_classification",
)

# --- 9. Metrics Function ---
def compute_metrics(p: EvalPrediction, threshold: float = 0.5):
    """Compute multi-label F1 scores for a Trainer evaluation pass.

    Logits are passed through a sigmoid, and a label is predicted positive when
    its probability is >= ``threshold``. A per-label classification report is
    printed as a side effect.

    Args:
        p: Evaluation output. ``p.predictions`` holds the logits (or a tuple
            whose first element is the logits); ``p.label_ids`` holds the
            true label matrix.
        threshold: Probability cut-off for a positive prediction (default 0.5).

    Returns:
        dict with ``f1_micro``, ``f1_macro`` and one ``f1_<label>`` entry per
        label column (spaces and slashes in the name replaced by underscores).
    """
    logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    probs = torch.sigmoid(torch.as_tensor(logits, dtype=torch.float32))
    y_pred = (probs >= threshold).int().numpy()
    y_true = p.label_ids

    print("\n--- Classification Report ---")
    print(classification_report(y_true, y_pred, target_names=label_columns, zero_division=0))

    # zero_division=0 matches the report above: a label with no predicted/true
    # positives scores 0 without emitting an UndefinedMetricWarning.
    metrics = {
        "f1_micro": f1_score(y_true, y_pred, average="micro", zero_division=0),
        "f1_macro": f1_score(y_true, y_pred, average="macro", zero_division=0),
    }

    per_class_f1 = f1_score(y_true, y_pred, average=None, zero_division=0)
    for lbl, score in zip(label_columns, per_class_f1):
        metric_key = lbl.replace(" ", "_").replace("/", "_")
        metrics[f"f1_{metric_key}"] = score

    return metrics

# --- 10. Training Arguments ---
# Training intensity presets: (epochs, logging_steps, warmup_steps, message).
training_presets = {
    "1": (2, 200, 100, "Quick training selected - good for testing!"),
    "2": (3, 100, 200, "Balanced training selected - recommended for most cases!"),
    "3": (5, 50, 500, "Thorough training selected - best results but takes longer!"),
}

print("Choose training intensity:")
print("1. Quick test (2 epochs, less logging)")
print("2. Balanced (3 epochs, moderate logging)")
print("3. Thorough (5 epochs, detailed logging)")

intensity = input("Enter 1, 2, or 3 (default: 2): ").strip() or "2"
if intensity not in training_presets:
    # Honour the advertised default instead of silently running the longest preset.
    print(f"Unrecognized choice '{intensity}' - using the default (2).")
    intensity = "2"

epochs, logging_steps, warmup_steps, preset_message = training_presets[intensity]
print(preset_message)

# Hugging Face training configuration: evaluate and checkpoint once per epoch,
# and reload the checkpoint with the best validation f1_macro when training ends.
training_args = TrainingArguments(
    # Output locations
    output_dir="/content/results_multilabel_weighted_final",
    logging_dir="/content/logs_multilabel_weighted_final",
    # Optimisation schedule (epochs and warmup come from the chosen intensity)
    num_train_epochs=epochs,
    learning_rate=2e-5,
    warmup_steps=warmup_steps,
    weight_decay=0.01,
    # Per-device batch sizes; lower the train size if the GPU runs out of memory
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    # Evaluation, checkpointing and model selection
    eval_strategy="epoch",  # newer name of the evaluation_strategy argument
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1_macro",
    # Logging and runtime behaviour
    logging_steps=logging_steps,
    report_to=[],  # no external experiment trackers
    dataloader_num_workers=0,  # load data in the main process (Colab compatibility)
    disable_tqdm=False,  # keep progress bars
)

# --- 11. Trainer Initialization ---
# No train sampler is supplied (oversampling was dropped for speed); imbalance
# is handled by the weighted loss in MultilabelTrainer.
# NOTE(review): recent transformers releases deprecate `tokenizer=` in favour
# of `processing_class=` - confirm against the installed version.
trainer_kwargs = {
    "model": model,
    "args": training_args,
    "train_dataset": train_dataset,
    "eval_dataset": val_dataset,
    "tokenizer": tokenizer,
    "compute_metrics": compute_metrics,
    "data_collator": None,  # let Trainer pick its default collator
}
trainer = MultilabelTrainer(**trainer_kwargs)

# --- 12. Start Training ---
# Oversampling is disabled in this notebook; imbalance is handled by the
# class-weighted loss, so the start-up message says so.
print("\nStarting weighted multi-label training (class-weighted loss, no oversampling)...")
print("This may take some time depending on your dataset size...")

# Report whether a GPU is available. `device` is not passed to the Trainer,
# so this is informational only.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

trainer.train()
print("Training complete!")

# --- 13. Save Model ---
# Write the trained model and the tokenizer into one directory so both can be
# reloaded from the same path later.
output_model_dir = "/content/review-classification-model-multilabel-weighted-final"
for save_artifact in (trainer.save_model, tokenizer.save_pretrained):
    save_artifact(output_model_dir)
print(f"Weighted multi-label model saved to: {output_model_dir}")

# --- 14. Save to Google Drive (Optional) ---
save_to_drive = input("Do you want to save the model to Google Drive? (y/n): ")
# Accept "y"/"yes" in any letter case, ignoring stray whitespace around the answer.
if save_to_drive.strip().lower() in ("y", "yes"):
    if choice != "2":  # Drive was only mounted earlier when loading data from it
        drive.mount('/content/drive')

    import shutil
    drive_model_path = "/content/drive/My Drive/review-classification-model-multilabel-weighted-final"
    # dirs_exist_ok=True lets a re-run copy over an existing destination folder.
    shutil.copytree(output_model_dir, drive_model_path, dirs_exist_ok=True)
    print(f"Model also saved to Google Drive: {drive_model_path}")

# Final summary banner, one message per line.
summary_lines = (
    "\n--- Training Summary ---",
    "Model training completed successfully!",
    "You can now use this model for inference on new review data.",
)
print(*summary_lines, sep="\n")

Choose your data loading method:
1. Upload file directly
2. Load from Google Drive
Enter 1 or 2: 1
Please upload your CSV file:


Saving reviews_with_labels_completed.csv to reviews_with_labels_completed.csv
=== CSV INSPECTION ===
Dataset shape: (12648, 6)
Column names: ['text', 'Advertisement', 'Irrelevant Content', 'Rant Without Visit', 'Spam', 'Useful Review']

First few rows:
                                                text  Advertisement  \
0                        Nice place and good service              0   
1  BEST hair experience i’ve ever had! my hair ca...              0   
2                           Always makes a great sub              0   
3  I love a few places in here. I love going to t...              0   
4  These guys are the best. Fair with labor and p...              0   

   Irrelevant Content  Rant Without Visit  Spam  Useful Review  
0                   0                   0     0              1  
1                   0                   0     0              1  
2                   0                   0     0              1  
3                   0                   0     0             

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Skipping oversampling for faster training - weighted loss will handle class imbalance


model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Choose training intensity:
1. Quick test (2 epochs, less logging)
2. Balanced (3 epochs, moderate logging)
3. Thorough (5 epochs, detailed logging)
Enter 1, 2, or 3 (default: 2): 2
Balanced training selected - recommended for most cases!

Starting weighted multi-label training with oversampling...
This may take some time depending on your dataset size...
Using device: cuda


Epoch,Training Loss,Validation Loss,F1 Micro,F1 Macro,F1 Advertisement,F1 Irrelevant Content,F1 Rant Without Visit,F1 Spam,F1 Useful Review
1,0.1294,0.126579,0.920127,0.881772,0.986607,0.939314,0.572165,0.969849,0.940926
2,0.1137,0.132989,0.925097,0.879701,0.986607,0.936842,0.552279,0.974747,0.948031
3,0.1047,0.130993,0.922393,0.884359,0.986547,0.937824,0.567901,0.984694,0.944828



--- Classification Report ---
                    precision    recall  f1-score   support

     Advertisement       0.98      0.99      0.99       223
Irrelevant Content       0.97      0.91      0.94       196
Rant Without Visit       0.85      0.43      0.57       257
              Spam       0.96      0.98      0.97       196
     Useful Review       0.92      0.96      0.94      1840

         micro avg       0.93      0.91      0.92      2712
         macro avg       0.94      0.86      0.88      2712
      weighted avg       0.93      0.91      0.91      2712
       samples avg       0.90      0.90      0.90      2712


--- Classification Report ---
                    precision    recall  f1-score   support

     Advertisement       0.98      0.99      0.99       223
Irrelevant Content       0.97      0.91      0.94       196
Rant Without Visit       0.89      0.40      0.55       257
              Spam       0.96      0.98      0.97       196
     Useful Review       0.92     