In [1]:
from datasets import DatasetDict, Dataset
from dataclasses import dataclass
import math, re, numpy as np
from typing import Dict, Any, Optional, List

import torch
import torch.nn as nn
from torch.nn import functional as F

from transformers import (
    AutoTokenizer, AutoModel,
    PreTrainedModel, PretrainedConfig,
    TrainingArguments, Trainer, DataCollatorWithPadding
)

from datasets import Dataset, load_dataset
from sklearn.metrics import cohen_kappa_score, mean_absolute_error
import pandas as pd
from transformers.modeling_outputs import SequenceClassifierOutput
import evaluate

# ---------- Config ----------
MODEL_NAME = "google-bert/bert-base-uncased"
MAX_LEN = 384
NUM_BINS = 10                      # 0.5, 1.0, ..., 5.0  => 10 bins
BIN_VALUES = np.arange(0.5, 5.0 + 0.5, 0.5)  # [0.5, 1.0, ..., 5.0]

def rating_to_bin(r: float) -> int:
    # map 0.5→0, 1.0→1, ..., 5.0→9
    return int(round((r - 0.5) / 0.5))

def bin_to_rating(b: int) -> float:
    return 0.5 + 0.5 * b

def class_to_cumulative_targets(y: torch.Tensor, num_bins: int) -> torch.Tensor:
    # For class c, targets for thresholds k=0..K-2 are 1 if c > k else 0
    # y: (B,) long
    B = y.size(0)
    k = torch.arange(num_bins - 1, device=y.device).unsqueeze(0).expand(B, -1)
    return (y.unsqueeze(1) > k).float()  # (B, K-1)

def preprocess(ex):
    enc = tokenizer(ex["text"], truncation=True, max_length=MAX_LEN)
    # map rating -> class 0..9
    label = rating_to_bin(float(ex["rating"]))
    enc["labels"] = label
    return enc

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

train_set = pd.read_csv('./data/letterboxd_250movie_reviews_train.csv')
val_set  = pd.read_csv('./data/letterboxd_250movie_reviews_val.csv')
test_set  = pd.read_csv('./data/letterboxd_250movie_reviews_test.csv')

dataset = DatasetDict({
    'train': Dataset.from_pandas(train_set),
    'validation': Dataset.from_pandas(val_set),
    'test': Dataset.from_pandas(test_set)
})

dataset = {k: v.map(preprocess, remove_columns=v.column_names) for k,v in dataset.items()}
# dataset = dataset.map(preprocess, batched=True)

# ---------- CORAL Model ----------
class CoralConfig(PretrainedConfig):
    model_type = "coral"
    def __init__(self, base_model_name=MODEL_NAME, num_bins=NUM_BINS, **kwargs):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.num_bins = num_bins

class CoralForOrdinalRegression(PreTrainedModel):
    config_class = CoralConfig

    def __init__(self, config: CoralConfig):
        super().__init__(config)
        self.encoder = AutoModel.from_pretrained(config.base_model_name)
        hidden = self.encoder.config.hidden_size

        # Shared weight vector w (d->1), CORAL: logit_k = w^T h + b_k
        self.shared_linear = nn.Linear(hidden, 1, bias=False)
        self.thresholds = nn.Parameter(torch.zeros(config.num_bins - 1))
        self.dropout = nn.Dropout(getattr(self.encoder.config, "hidden_dropout_prob", 0.1))

        self.post_init()

    def forward(
        self,
        input_ids=None, attention_mask=None, token_type_ids=None,
        labels: Optional[torch.LongTensor] = None
    ):
        # Get [CLS]-like pooled representation (use first token)
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # Use mean pooling of last hidden state → often more stable than CLS for some models
        last = out.last_hidden_state  # (B, T, H)
        mask = attention_mask.unsqueeze(-1)  # (B, T, 1)
        pooled = (last * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        pooled = self.dropout(pooled)

        s = self.shared_linear(pooled).squeeze(-1)                    # (B,)
        logits = s.unsqueeze(1) + self.thresholds.unsqueeze(0)        # (B, K-1)

        loss = None
        if labels is not None:
            targets = class_to_cumulative_targets(labels, self.config.num_bins)  # (B, K-1)
            # BCEWithLogits over all thresholds
            loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="mean")

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None
        )

    @torch.no_grad()
    def predict_classes(self, input_ids, attention_mask, token_type_ids=None, threshold: float = 0.5):
        out = self.forward(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        probs = torch.sigmoid(out["logits"])           # (B, K-1)
        # predicted class = count of thresholds passed (p_k > 0.5)
        return (probs > threshold).sum(dim=1)          # (B,)

accuracy = evaluate.load("accuracy")

# ---------- Metrics ----------
def compute_metrics(eval_pred):
    # eval_pred.predictions is (B, K-1) logits
    logits, labels = eval_pred
    probs = 1 / (1 + np.exp(-logits))
    preds_class = (probs > 0.5).sum(axis=1)  # 0..9
    true_class = labels

    # Map to half-star ratings for MAE
    preds_rating = np.array([bin_to_rating(int(c)) for c in preds_class])
    true_rating  = np.array([bin_to_rating(int(c)) for c in true_class])

    qwk = cohen_kappa_score(true_class, preds_class, weights="quadratic")
    mae = mean_absolute_error(true_rating, preds_rating)
    acc = np.round(accuracy.compute(predictions=preds_rating, references=true_rating)['accuracy'],3)
    return {"qwk": qwk, "mae": mae, "acc": acc}

# ---------- Train ----------
collator = DataCollatorWithPadding(tokenizer=tokenizer)
model = CoralForOrdinalRegression(CoralConfig())

# Freeze all base model params
for name, param in model.encoder.named_parameters():
    param.requires_grad = False

# Unfreeze last encoder layer (optional, for better adaptation)
# for name, param in model.encoder.named_parameters():
#     if "encoder.layer.11" in name:
#         param.requires_grad = True

for name, param in model.encoder.named_parameters():
    if "pooler" in name:
        param.requires_grad = True

# Unfreeze CORAL head and thresholds
model.shared_linear.weight.requires_grad = True
model.thresholds.requires_grad = True

# Print trainable status for all parameters
for name, param in model.named_parameters():
    print(name, param.requires_grad)

Map:   0%|          | 0/1600 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

thresholds True
encoder.embeddings.word_embeddings.weight False
encoder.embeddings.position_embeddings.weight False
encoder.embeddings.token_type_embeddings.weight False
encoder.embeddings.LayerNorm.weight False
encoder.embeddings.LayerNorm.bias False
encoder.encoder.layer.0.attention.self.query.weight False
encoder.encoder.layer.0.attention.self.query.bias False
encoder.encoder.layer.0.attention.self.key.weight False
encoder.encoder.layer.0.attention.self.key.bias False
encoder.encoder.layer.0.attention.self.value.weight False
encoder.encoder.layer.0.attention.self.value.bias False
encoder.encoder.layer.0.attention.output.dense.weight False
encoder.encoder.layer.0.attention.output.dense.bias False
encoder.encoder.layer.0.attention.output.LayerNorm.weight False
encoder.encoder.layer.0.attention.output.LayerNorm.bias False
encoder.encoder.layer.0.intermediate.dense.weight False
encoder.encoder.layer.0.intermediate.dense.bias False
encoder.encoder.layer.0.output.dense.weight False
encode

In [2]:
# hyperparameters
lr = 2e-4
batch_size = 16
num_epochs = 10

args = TrainingArguments(
    output_dir="./bert-letterbox-reviews-classifier_teacher",
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    fp16=torch.cuda.is_available(),
    save_strategy="epoch",
    eval_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="qwk",
    greater_is_better=True,
    logging_steps=50,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    data_collator=collator,
    compute_metrics=compute_metrics,
)

trainer.train()



Epoch,Training Loss,Validation Loss,Qwk,Mae,Acc
1,0.6807,0.679369,0.247753,1.76,0.155
2,0.6656,0.667801,0.35997,1.555,0.185
3,0.6535,0.656538,0.411651,1.42,0.22
4,0.6445,0.651159,0.416425,1.4225,0.195
5,0.6354,0.646139,0.426262,1.395,0.215
6,0.6362,0.6416,0.438493,1.39,0.19
7,0.6285,0.639162,0.447676,1.3575,0.22
8,0.6315,0.63623,0.450721,1.3675,0.19
9,0.6226,0.635494,0.45391,1.36,0.2
10,0.6296,0.635195,0.449754,1.3625,0.205




TrainOutput(global_step=1000, training_loss=0.6453559551239013, metrics={'train_runtime': 564.8947, 'train_samples_per_second': 28.324, 'train_steps_per_second': 1.77, 'total_flos': 3136198358378592.0, 'train_loss': 0.6453559551239013, 'epoch': 10.0})

In [3]:
# apply model to validation dataset
predictions = trainer.predict(dataset["test"])

# Extract the logits and labels from the predictions object
logits = predictions.predictions
labels = predictions.label_ids

# Use your compute_metrics function
metrics = compute_metrics((logits, labels))
print(metrics)



{'qwk': 0.5639322133133833, 'mae': 1.2325, 'acc': np.float64(0.235)}
