In [1]:
import yaml
import multiprocessing

import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning.pytorch as pl
from torch.optim import AdamW
from torch.utils.data import DataLoader
from lightning.pytorch.callbacks import RichProgressBar
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from transformers import AutoModel, AutoTokenizer
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk

In [2]:
# Use all but one of the reported CPU cores, with a floor of one worker.
_total_cores = multiprocessing.cpu_count()
num_cores_avail = _total_cores - 1 if _total_cores > 1 else 1

In [3]:
# Load the experiment configuration. The encoding is given explicitly because
# open() otherwise falls back to the platform locale, which is not always UTF-8.
with open("../experiments/configs/pitchfork_cls/main.yaml", "r", encoding="utf-8") as f:
    main_config = yaml.safe_load(f)

In [4]:
# Pull the dataset/model checkpoint identifiers and their pinned revisions
# out of the loaded configuration in one go.
(
    dataset_checkpoint,
    dataset_checkpoint_revision,
    model_checkpoint,
    model_checkpoint_revision,
) = [
    main_config[key]
    for key in (
        "dataset_checkpoint",
        "dataset_checkpoint_revision",
        "model_checkpoint",
        "model_checkpoint_revision",
    )
]

In [5]:
# The tokenizer and the encoder are loaded from the same checkpoint at the same
# revision so that vocabulary and weights stay in sync.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, revision=model_checkpoint_revision)
embedding_model = AutoModel.from_pretrained(model_checkpoint, revision=model_checkpoint_revision)

# Dataset previously saved to disk (loaded as-is from the local path).
datasets = load_from_disk("../data/pitchfork/dataset_dbr/")

In [6]:
# Columns to keep after tokenization; every other column currently present in
# the train split is scheduled for removal.
keeper_cols = [
    "artist",
    "album",
    "year_released",
    "rating",
    "input_ids",
    "attention_mask",
]
drop_cols = {name for name in datasets["train"].column_names if name not in keeper_cols}

In [7]:
def tokenize_reviews(examples):
    """Tokenize the ``review`` field of a batch of examples with padding and truncation."""
    return tokenizer(examples["review"], padding=True, truncation=True)


# NOTE(review): with `padding=True` the tokenizer pads to the longest sequence of
# each call, i.e. per map batch here, so padded lengths are presumably only
# uniform when every map batch contains a review that hits the truncation limit.
tokenized_datasets = datasets.map(
    tokenize_reviews,
    batched=True,
    num_proc=num_cores_avail,
).remove_columns(drop_cols)

In [8]:
def collate_reviews(batch, pad_token_id=0):
    """Collate tokenized reviews into rectangular tensors.

    Sequences are right-padded to the longest sequence in ``batch`` so that
    examples with different stored lengths can still be stacked.

    Args:
        batch: Sequence of examples, each a mapping with ``input_ids`` and
            ``attention_mask`` (integer sequences of equal length) and a
            numeric ``rating``.
        pad_token_id: Token id used to fill ``input_ids`` past the end of a
            shorter sequence. Padded positions always get attention mask 0.

    Returns:
        Tuple ``(input_ids, attention_masks, ratings)``. For a non-empty batch
        the first two are ``torch.long`` tensors of shape
        ``(len(batch), longest_sequence)`` and ``ratings`` is a
        ``torch.float32`` tensor of shape ``(len(batch),)``.
    """
    longest = max((len(item["input_ids"]) for item in batch), default=0)

    input_ids, attention_masks, ratings = [], [], []
    for item in batch:
        ids = list(item["input_ids"])
        mask = list(item["attention_mask"])
        n_pad = longest - len(ids)
        input_ids.append(ids + [pad_token_id] * n_pad)
        attention_masks.append(mask + [0] * n_pad)
        ratings.append(item["rating"])

    # Explicit dtypes: integer-valued ratings would otherwise become int64,
    # which a floating-point regression loss cannot consume.
    return (
        torch.tensor(input_ids, dtype=torch.long),
        torch.tensor(attention_masks, dtype=torch.long),
        torch.tensor(ratings, dtype=torch.float32),
    )

In [9]:
class TextRegressor(nn.Module):
    """Text encoder followed by a single linear regression layer."""

    def __init__(self, embedder, embed_dim, output_dim=1):
        """Store the encoder and build the linear head.

        Args:
            embedder: Callable module accepting ``input_ids`` and
                ``attention_mask`` keyword arguments and returning an object
                with a ``last_hidden_state`` attribute (e.g. DistilBERT, BERT).
            embed_dim: Size of the encoder's hidden vectors.
            output_dim: Number of regression outputs.
        """
        super().__init__()
        self.embedder = embedder
        self.regression_head = nn.Linear(embed_dim, output_dim)

    def forward(self, input_ids, attention_mask):
        """Return the regression output computed from the first token's hidden state."""
        encoder_output = self.embedder(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of each sequence holds the [CLS] embedding.
        cls_vector = encoder_output.last_hidden_state[:, 0]
        return self.regression_head(cls_vector)


class LitTextRegressor(pl.LightningModule):
    """Lightning wrapper that trains a text regressor with mean-squared error.

    Args:
        text_regressor: Module called as
            ``text_regressor(input_ids=..., attention_mask=...)`` that exposes
            an ``embedder`` sub-module (used by the freeze/unfreeze helpers).
        lr: Learning rate passed to AdamW.
        weight_decay: Weight decay for every parameter whose name does not
            contain one of ``NO_WEIGHT_DECAY``; those get a decay of 0.0.
    """

    # Parameter-name fragments excluded from weight decay.
    NO_WEIGHT_DECAY = ("word_embeddings", "position_embeddings")

    def __init__(self, text_regressor, lr=1e-3, weight_decay=0.01):
        super().__init__()
        self.text_regressor = text_regressor
        self.lr = lr
        self.weight_decay = weight_decay
        # Loss
        self.criterion = F.mse_loss

    def forward(self, input_ids, attention_mask):
        """Return the wrapped regressor's predictions."""
        return self.text_regressor(input_ids=input_ids, attention_mask=attention_mask)

    def _compute_loss(self, batch):
        """Run the regressor on an ``(input_ids, attention_mask, ratings)`` batch and return the loss."""
        input_ids, attention_mask, ratings = batch
        yhat = self.text_regressor(input_ids=input_ids, attention_mask=attention_mask)
        # Add a trailing dimension to line the targets up with yhat, and force a
        # float target so integer or double ratings cannot cause a dtype mismatch.
        return self.criterion(yhat, ratings.unsqueeze(1).float())

    def training_step(self, batch, batch_idx):
        """Compute, log (epoch-averaged) and return the training loss."""
        loss = self._compute_loss(batch)
        self.log("avg_train_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log (epoch-averaged) the validation loss."""
        loss = self._compute_loss(batch)
        self.log("avg_val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)

    def configure_optimizers(self):
        """Build AdamW with two parameter groups: without and with weight decay."""
        no_decay_params, decay_params = [], []
        for name, param in self.text_regressor.named_parameters():
            if any(fragment in name for fragment in self.NO_WEIGHT_DECAY):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer_grouped_parameters = [
            {"params": no_decay_params, "weight_decay": 0.0},
            {"params": decay_params, "weight_decay": self.weight_decay},
        ]
        return AdamW(optimizer_grouped_parameters, lr=self.lr)

    def _set_embedder_trainable(self, trainable):
        """Set ``requires_grad`` on every parameter of the wrapped embedder."""
        for param in self.text_regressor.embedder.parameters():
            param.requires_grad = trainable

    def freeze_pretrained_model(self):
        """Stop gradient computation for the pretrained embedder."""
        # TODO: re-build optimizers after freezing/un-freezing parameters
        self._set_embedder_trainable(False)

    def unfreeze_pretrained_model(self):
        """Re-enable gradient computation for the pretrained embedder."""
        # TODO: re-build optimizers after freezing/un-freezing parameters
        self._set_embedder_trainable(True)

In [10]:
# Training settings used by the dataloaders and the trainer.
batch_size = 64
epochs = 10

# Run on a CUDA GPU when PyTorch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    accelerator = "gpu"
else:
    accelerator = "cpu"

In [11]:
# Training batches are reshuffled every epoch; validation keeps dataset order.
train_dataloader = DataLoader(
    tokenized_datasets["train"],
    shuffle=True,
    batch_size=batch_size,
    collate_fn=collate_reviews,
)
valid_dataloader = DataLoader(
    tokenized_datasets["validation"],
    batch_size=batch_size,
    collate_fn=collate_reviews,
)

In [12]:
# Grab only the first validation batch for inspection. next() raises
# StopIteration on an empty loader instead of silently leaving `batch`
# undefined (or stale from an earlier run).
batch_idx, batch = next(enumerate(valid_dataloader))

In [13]:
# Split the sampled batch into its three parts (the tuple order returned by collate_reviews).
input_ids, attention_masks, ratings = batch

In [14]:
# Display the token ids of the sampled validation batch.
input_ids

tensor([[  101, 27166, 13146,  ...,   112,   187,   102],
        [  101, 10657,   117,  ..., 13028,   112,   102],
        [  101,   107, 11065,  ..., 10111, 18850,   102],
        ...,
        [  101, 14600, 21213,  ..., 12592, 33944,   102],
        [  101, 12242, 10151,  ..., 56445,   119,   102],
        [  101, 12613, 10105,  ..., 92153, 10146,   102]])

In [15]:
# Decode token id 101 (the first id of the rows displayed above) to see which token it is.
tokenizer.decode([101])

'[CLS]'

In [16]:
# `hidden_size` is the architecture-agnostic config name for the encoder width
# (DistilBERT maps it onto `dim`), so this keeps working if the configured
# checkpoint is swapped for another encoder family.
tr_model = TextRegressor(
    embedding_model,
    embed_dim=embedding_model.config.hidden_size,
)
lit_model = LitTextRegressor(tr_model)
# TODO: fine-tune, then un-freeze
lit_model.freeze_pretrained_model()

In [17]:
# Smoke test: one forward pass on the sampled batch, in eval mode and without
# tracking gradients.
lit_model.eval()
with torch.no_grad():
    yhat = lit_model(input_ids=input_ids, attention_mask=attention_masks)

In [18]:
callbacks = [RichProgressBar()]
# Metrics are written both as CSV and as TensorBoard event files.
loggers = [CSVLogger(".", name="lightning_logs"), TensorBoardLogger(".", name="tb_logs")]
trainer = pl.Trainer(
    max_epochs=epochs,
    accelerator=accelerator,
    callbacks=callbacks,
    # "16-mixed" is the current spelling of 16-bit automatic mixed precision;
    # the bare integer 16 is only kept for backward compatibility and warns.
    precision="16-mixed",
    logger=loggers,
)

  rank_zero_warn(
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [19]:
# Train for up to `epochs` epochs, validating on the held-out split.
trainer.fit(lit_model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

`Trainer.fit` stopped: `max_epochs=10` reached.
