<a href="https://colab.research.google.com/github/janbanot/msc-project/blob/main/test_notebooks/msc_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install --upgrade transformers datasets captum quantus accelerate

In [None]:
from google.colab import drive
# Mount Google Drive at /drive; the dataset and model paths used later in this
# notebook are all located under /drive/MyDrive.
drive.mount('/drive')

In [None]:
import pandas as pd

# Location of the Jigsaw training CSV on the mounted Google Drive.
csv_path = '/drive/MyDrive/msc-project/jigsaw-toxic-comment/train.csv'
try:
    df = pd.read_csv(csv_path)
except FileNotFoundError:
    print(f"Error: The file was not found at {csv_path}")
    # Re-raise: every later cell needs `df`, so continuing would only turn
    # this into a confusing NameError further down the notebook.
    raise
# Any other read/parse error propagates unchanged with its full traceback.
print("CSV file loaded successfully!")
display(df.head())

In [None]:
import re

def clean_text(example):
    """Normalize the 'comment_text' field of a single dataset row.

    Steps, in order: lowercase, remove URLs, remove IPv4-looking addresses,
    remove Wikipedia signature residue ("(talk)" and UTC timestamps), replace
    newlines / non-breaking spaces, strip surrounding quotes and spaces, and
    collapse runs of whitespace.

    Args:
        example: Mapping with a 'comment_text' entry. A missing value (None)
            is treated as an empty comment; any other non-string value is
            converted with str().

    Returns:
        The same `example` object, with 'comment_text' replaced in place by
        the cleaned text.
    """

    # 1. Get the text, tolerating missing / non-string values
    text = example['comment_text']
    if text is None:
        text = ''
    elif not isinstance(text, str):
        text = str(text)

    # 2. Lowercasing
    # Required for "uncased" BERT models; also lets the patterns below be
    # written in lowercase only (e.g. "(utc)").
    text = text.lower()

    # 3. Remove URLs
    # Require "http://" / "https://" or a "www." that starts a word, so that
    # elongated words such as "awwww" are not mistaken for a URL.
    text = re.sub(r'https?://\S+|\bwww\.\S+', '', text)

    # 4. Remove IP addresses
    # \d{1,3} means "a digit, 1-to-3 times". \. means "a literal dot".
    text = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', '', text)

    # 5. Remove Wikipedia signature residue: "(talk)" and UTC timestamps in
    # either "21:51, january 11, 2016 (utc)" or "21:51, 11 january 2016 (utc)"
    # form.
    text = re.sub(r'\(talk\)', '', text)
    text = re.sub(
        r'\d{1,2}:\d{2}, (?:\w+ \d{1,2},|\d{1,2} \w+) \d{4} \(utc\)',
        '',
        text,
    )

    # 6. Replace newlines and non-breaking spaces with plain spaces
    text = text.replace('\n', ' ')
    text = text.replace('\xa0', ' ')

    # 7. Strip leading/trailing spaces and double-quote characters
    # (only the edge characters are removed, not quoted text in the middle).
    text = text.strip(' "')

    # 8. Collapse any run of whitespace into a single space
    text = re.sub(r'\s+', ' ', text).strip()

    # 9. Update the example in place and hand it back
    example['comment_text'] = text
    return example

In [None]:
import datasets

# Keep only the first 2000 rows of the CSV, then wrap that slice in a
# Hugging Face Dataset for the map-based preprocessing below.
train_df = df.iloc[:2000]
data = datasets.Dataset.from_pandas(train_df)

In [None]:
# Run clean_text over the dataset. The result is bound to a new name so the
# untouched `data` stays available for the before/after comparison below.
print("\nCleaning data...")
cleaned_data = data.map(clean_text)
print("Data cleaned!")

In [None]:
# Show the same three rows before and after cleaning. The first row of each
# group is printed as-is; the following ones are preceded by a blank line.
preview_rows = (1, 6, 0)
for banner, stage in (
    ("\n--- BEFORE CLEANING ---", data),
    ("\n\n--- AFTER CLEANING ---", cleaned_data),
):
    print(banner)
    first_row, *other_rows = preview_rows
    print(stage[first_row]['comment_text'])
    for row in other_rows:
        print("\n" + stage[row]['comment_text'])

In [None]:
from transformers import AutoTokenizer

# Name of the pretrained checkpoint to load.
# 'uncased' matches the .lower() step applied during cleaning.
model_checkpoint = "distilbert-base-uncased"

try:
    # This downloads and caches the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    # Re-raise: later cells need `tokenizer`, so swallowing the error here
    # would only defer the failure to a NameError during tokenization.
    raise
print("Tokenizer loaded successfully!")

In [None]:
def tokenize_function(examples):
    """Tokenize the 'comment_text' entries of a batch with the notebook's tokenizer.

    Sequences longer than 256 tokens are truncated and shorter ones are padded
    up to that length, so every row ends up with the same length. 256 trades
    some context for speed compared with a longer limit.

    Args:
        examples: Batch mapping that contains a 'comment_text' entry.

    Returns:
        Whatever the tokenizer call returns for the batch.
    """
    tokenizer_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": 256,
    }
    batch_texts = examples["comment_text"]
    return tokenizer(batch_texts, **tokenizer_options)

# Tokenize the cleaned comments with .map().
# batched=True hands tokenize_function many rows per call instead of one,
# which makes tokenization MUCH faster.
print("\nTokenizing data...")
tokenized_data = cleaned_data.map(tokenize_function, batched=True)
print("Data tokenized!")

In [None]:
# Print the first tokenized row to inspect the fields produced by tokenization.
print("\n--- Example of a Tokenized Entry ---")
print(tokenized_data[0])

In [None]:
import numpy as np

# 1. Target columns, in the order used for the combined label vector
label_columns = [
    'toxic',
    'severe_toxic',
    'obscene',
    'threat',
    'insult',
    'identity_hate',
]


def create_labels_column(example):
    """
    Add a 'labels' entry that gathers the six per-category flags into one list.

    The values are read in `label_columns` order and each is converted to a
    Python float. The row is updated in place and returned.
    """
    example['labels'] = list(map(float, (example[name] for name in label_columns)))
    return example

# 2. Apply create_labels_column to every tokenized row
print("\nConsolidating labels...")
final_data = tokenized_data.map(create_labels_column)
print("Labels consolidated!")

# 3. Inspect row 6 of the result
# (presumably a toxic comment - TODO confirm against the data)
print("\n--- Example of a Processed Entry ---")
print(final_data[6])

In [None]:
# 1. Raw columns that are no longer needed once the text has been tokenized
#    and the six flags have been merged into 'labels'.
columns_to_remove = [
    'id', 'comment_text', 'toxic', 'severe_toxic',
    'obscene', 'threat', 'insult', 'identity_hate'
]

# remove_columns() returns a new dataset (it does not modify in place), so the
# result is re-assigned. Only columns that are still present are dropped, which
# keeps this cell safe to re-run after the columns have already been removed.
print(f"\nOriginal columns: {final_data.column_names}")
present_columns = [col for col in columns_to_remove if col in final_data.column_names]
if present_columns:
    final_data = final_data.remove_columns(present_columns)
print(f"Cleaned columns: {final_data.column_names}")

# 2. Set the dataset format to "torch" (for PyTorch)
try:
    final_data.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    print("\nDataset format set to 'torch'!")
except ImportError:
    # Deliberate best-effort: report the missing dependency and carry on.
    print("\nPyTorch not installed. Skipping .set_format('torch').")
    print("Please install with: pip install torch")

# Let's check the final, final output
print("\n--- Final, Model-Ready Item ---")
print(final_data[6])

In [None]:
from transformers import AutoModelForSequenceClassification

# One output per toxicity category. Derived from label_columns so the model
# head cannot drift out of sync with the label vector built earlier.
num_labels = len(label_columns)

# Load the model, configuring it for multi-label classification
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    num_labels=num_labels,
    problem_type="multi_label_classification"
)

print("Model loaded successfully!")
print("Model configured for multi-label classification.")

In [None]:
# Hold out 20% of the rows for evaluation, using a fixed seed for the split.
data_splits = final_data.train_test_split(seed=42, test_size=0.2)
train_dataset, eval_dataset = data_splits['train'], data_splits['test']

print("\nData split complete:")
print(f"Training samples: {len(train_dataset)}")
print(f"Evaluation samples: {len(eval_dataset)}")

In [None]:
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import EvalPrediction
import torch

def compute_metrics(p: EvalPrediction):
    """Compute micro-averaged F1 and per-label accuracy from raw logits.

    Args:
        p: Evaluation output; `p.predictions` holds the raw logits and
            `p.label_ids` the true labels.

    Returns:
        Dict with 'f1_micro' (micro-averaged F1) and 'accuracy' (fraction of
        individual label slots predicted correctly after flattening).
    """
    cutoff = 0.5  # probability above which a label counts as predicted

    # Element-wise sigmoid turns the logits into probabilities
    raw_scores = p.predictions
    probabilities = 1 / (1 + np.exp(-raw_scores))

    # Binarize the probabilities with the cutoff
    binary_preds = (probabilities > cutoff).astype(int)
    targets = p.label_ids

    return {
        # 'micro' averaging pools all label decisions, which suits imbalanced labels
        'f1_micro': f1_score(targets, binary_preds, average='micro'),
        # Flattened comparison: each individual label slot counts once
        'accuracy': accuracy_score(targets.flatten(), binary_preds.flatten()),
    }

In [None]:
from transformers import TrainingArguments

# Training output is written to this folder on the mounted Google Drive.
model_output_dir = "/drive/MyDrive/msc-project/models/distilbert-jigsaw-finetuned"

training_args = TrainingArguments(
    output_dir=model_output_dir,
    # --- optimisation ---
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,  # helps prevent overfitting
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # --- evaluate and save once per epoch ---
    eval_strategy="epoch",
    save_strategy="epoch",
    # --- best-model selection is driven by the f1_micro metric ---
    load_best_model_at_end=True,
    metric_for_best_model="f1_micro",
    # --- logging ---
    logging_steps=5,
    report_to="none",  # DISABLE WANDB
)

In [None]:
from transformers import Trainer

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # The tokenizer is passed so the Trainer can create batches correctly.
    # `processing_class` is the current name for this argument; the older
    # `tokenizer=` keyword is deprecated and this notebook installs the latest
    # transformers release via --upgrade.
    processing_class=tokenizer,
    compute_metrics=compute_metrics
)

print("\n--- Starting Training ---")
# Runs the full fine-tuning loop configured in training_args.
trainer.train()
print("--- Training Complete ---")