In [1]:
import pandas as pd
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification

# prefer GPU if available, otherwise fall back to CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "vinai/bertweet-large"

# tokenizer/model setup for regression
# use_fast=False is recommended for BERTweet
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
CONFIG = AutoConfig.from_pretrained(
    MODEL_NAME,
    num_labels=1,
    problem_type="regression",
)
MODEL = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, config=CONFIG).to(DEVICE) 

  from .autonotebook import tqdm as notebook_tqdm
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at vinai/bertweet-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [2]:
# 1) Load data + compute controversy ratio (quote_count / (quote + retweet))
df = pd.read_csv("Cleaned_messi_tweets.csv")
df["engagement_count"] = df["retweet_count"] + df["quote_count"]
df = df[df["engagement_count"] > 0].copy()
df["ratio"] = df["quote_count"] / df["engagement_count"]

# limit to the first 10,000 rows per the request
MAX_ROWS = 10_000
df = df.head(MAX_ROWS).reset_index(drop=True)
print(f"Using {len(df)} rows after filtering and limiting to first {MAX_ROWS}.")

# 2) Train/test split
X_train, X_test, y_train, y_test = train_test_split(
    df["content"], df["ratio"], test_size=0.2, random_state=42
)

class ControversyDataset(Dataset):
    def __init__(self, texts, scores, tokenizer, max_len=64):
        self.texts = list(texts)
        self.scores = list(scores)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            str(self.texts[idx]),
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(self.scores[idx], dtype=torch.float),
        }

train_ds = ControversyDataset(X_train, y_train, TOKENIZER)
test_ds = ControversyDataset(X_test, y_test, TOKENIZER)

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=16, shuffle=False)

# 3) Brief training (keep small for CPU-only runs)
optimizer = torch.optim.AdamW(MODEL.parameters(), lr=2e-5)
NUM_EPOCHS = 1
MODEL.train()
for epoch in range(NUM_EPOCHS):
    total_loss = 0.0
    for batch in train_loader:
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        optimizer.zero_grad()
        outputs = MODEL(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1} average loss: {avg_loss:.4f}")

# 4) Evaluate MSE on the hold-out set
MODEL.eval()
preds, targets = [], []
with torch.no_grad():
    for batch in test_loader:
        labels = batch["labels"].to(DEVICE)
        inputs = {k: v.to(DEVICE) for k, v in batch.items() if k != "labels"}
        outputs = MODEL(**inputs)
        scores = torch.sigmoid(outputs.logits.squeeze(-1))
        preds.extend(scores.cpu().numpy().tolist())
        targets.extend(labels.cpu().numpy().tolist())

mse = mean_squared_error(targets, preds)
print(f"MSE on {len(targets)} samples: {mse:.4f}")

Using 10000 rows after filtering and limiting to first 10000.
Epoch 1 average loss: 0.1107
MSE on 2000 samples: 0.2393
