[<< Previous: Data Exploration](02_data_explore.ipynb) &nbsp; | &nbsp; [Next: Fit Analysis >>](04_fit_analysis.ipynb)

In [1]:
import numpy as np
import yaml
import multiprocessing
import gc
import warnings

import torch
import lightning.pytorch as pl
from torch.utils.data import DataLoader
from lightning.pytorch.callbacks import RichProgressBar, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from lightning.pytorch.utilities.model_summary import ModelSummary
from transformers import AutoModel, AutoTokenizer
from datasets import load_from_disk

from myutilpy.models.text_regressor import TextRegressor, LitTextRegressor
from myutilpy.models.pooling import pooling_fns, pool_cls, pool_mean

In [2]:
warnings.filterwarnings("ignore", ".*does not have many workers.*")
warnings.filterwarnings("ignore", ".*Detected KeyboardInterrupt, attempting graceful shutdown.*")

In [3]:
# Use all but one CPU core for dataset preprocessing, but never fewer than one.
num_cores_avail = max(multiprocessing.cpu_count() - 1, 1)

In [4]:
# Experiment to run: names the directory under ../experiments/configs/ to read and
# the directory under ../experiments/results/ that receives logs and checkpoints.
config_id = "mlml6_rate_pred_clsp"

In [5]:
# Each experiment is described by two YAML files in its config directory:
#   main.yaml  - dataset and pretrained-model identifiers/revisions
#   model.yaml - pooling choice and optional resume-checkpoint paths
config_dir = f"../experiments/configs/{config_id}"

# Read as UTF-8 explicitly so parsing does not depend on the platform's default
# locale encoding.
with open(f"{config_dir}/main.yaml", "r", encoding="utf-8") as f:
    main_config = yaml.safe_load(f)

with open(f"{config_dir}/model.yaml", "r", encoding="utf-8") as f:
    model_config = yaml.safe_load(f)

In [6]:
# Data / pretrained-encoder identifiers from main.yaml.
# NOTE(review): dataset_checkpoint(_revision) are not referenced elsewhere in this
# notebook; the data is loaded from disk via dataset_id instead.
dataset_checkpoint, dataset_checkpoint_revision = (
    main_config["dataset_checkpoint"],
    main_config["dataset_checkpoint_revision"],
)
pt_model_checkpoint, pt_model_checkpoint_revision = (
    main_config["pt_model_checkpoint"],
    main_config["pt_model_checkpoint_revision"],
)
dataset_id = main_config["dataset_id"]

# Optional resume checkpoints from model.yaml; None means "train this stage anew".
frozen_model_checkpoint_path, finetune_model_checkpoint_path = (
    model_config["frozen_model_checkpoint_path"],
    model_config["finetune_model_checkpoint_path"],
)

In [7]:
# Pretrained encoder and its matching tokenizer, both pinned to the configured revision.
embedding_model = AutoModel.from_pretrained(pt_model_checkpoint, revision=pt_model_checkpoint_revision)
tokenizer = AutoTokenizer.from_pretrained(pt_model_checkpoint, revision=pt_model_checkpoint_revision)

# Pre-split dataset stored on disk; it is indexed below by "train", "validation" and "test".
datasets = load_from_disk(f"../data/pitchfork/{dataset_id}/dataset")

In [8]:
# Columns the rest of the notebook keeps; every other raw column is scheduled for removal.
keeper_cols = ["artist", "album", "year_released", "rating", "input_ids", "attention_mask"]
drop_cols = set(datasets["train"].column_names) - set(keeper_cols)

In [9]:
# Tokenize every split, then keep only `keeper_cols`.
# The columns to drop are computed from the *tokenized* splits, so any extra tokenizer
# outputs (e.g. `token_type_ids`, if the tokenizer returns it) are removed as well.
# The raw-schema `drop_cols` would miss them.
# NOTE: padding=True pads to the longest example within each `map` batch, not globally.
tokenized_datasets = datasets.map(
    lambda examples: tokenizer(examples["review"], padding=True, truncation=True),
    batched=True,
    num_proc=num_cores_avail,
)
tokenized_datasets = tokenized_datasets.remove_columns(
    sorted(set(tokenized_datasets["train"].column_names) - set(keeper_cols))
)

In [10]:
# Seed Python, NumPy and torch so that later random initialisation (e.g. inside
# TextRegressor) and training-set shuffling are reproducible across Restart & Run All.
seed = 42
pl.seed_everything(seed)

# Training hyperparameters.
# NOTE: the epoch counts are overridden further down when a resume checkpoint is configured.
frozen_epochs = 20
finetune_epochs = 10
batch_size = 16
accelerator = "gpu" if torch.cuda.is_available() else "cpu"

In [11]:
def collate_reviews(batch, pad_token_id=0):
    """Collate tokenized review examples into batch tensors.

    Args:
        batch: list of examples, each a mapping with "input_ids" and
            "attention_mask" (integer sequences) and a numeric "rating".
        pad_token_id: value used to right-pad shorter "input_ids" sequences.
            Padded positions always receive attention-mask value 0.

    Returns:
        Tuple ``(input_ids, attention_masks, ratings)``. The first two are
        2-D tensors of shape (len(batch), longest sequence in the batch);
        ``ratings`` is a 1-D tensor with one entry per example.
    """
    ids_per_example = [torch.as_tensor(item["input_ids"]) for item in batch]
    masks_per_example = [torch.as_tensor(item["attention_mask"]) for item in batch]
    ratings = torch.tensor([item["rating"] for item in batch])

    # Tokenization pads per `datasets.map` batch, so examples that meet in one
    # (shuffled) DataLoader batch can have different lengths. torch.tensor() would
    # fail on such ragged lists; pad to the longest example in this batch instead.
    # When all lengths already match, this is identical to torch.tensor(list_of_lists).
    input_ids = torch.nn.utils.rnn.pad_sequence(
        ids_per_example, batch_first=True, padding_value=pad_token_id
    )
    attention_masks = torch.nn.utils.rnn.pad_sequence(
        masks_per_example, batch_first=True, padding_value=0
    )

    return input_ids, attention_masks, ratings

In [12]:
# Set to an int (e.g. 1500) to use a fixed random subset of each split for quick
# development; None uses the full splits.
dev_subset_size = None


def make_review_loader(split, shuffle=False):
    """Build a DataLoader over one tokenized split, optionally subsampled via `dev_subset_size`."""
    data = tokenized_datasets[split]
    if dev_subset_size is not None:
        data = data.shuffle(seed=42).select(range(min(dev_subset_size, len(data))))
    return DataLoader(data, batch_size=batch_size, collate_fn=collate_reviews, shuffle=shuffle)


train_dataloader = make_review_loader("train", shuffle=True)
valid_dataloader = make_review_loader("validation")
test_dataloader = make_review_loader("test")

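# Quick checks

Run one validation batch through the pretrained encoder and confirm that both pooling functions return one embedding per review.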

In [13]:
# Take a single validation batch for the sanity checks below. next(iter(...)) raises
# on an empty loader, instead of silently leaving stale (or undefined) variables.
batch = next(iter(valid_dataloader))
input_ids, attention_masks, ratings = batch

In [14]:
# Encode the sample batch without tracking gradients, then apply both pooling helpers
# to the token-level hidden states.
with torch.no_grad():
    embedding = embedding_model(
        input_ids=input_ids,
        attention_mask=attention_masks,
    ).last_hidden_state
    mp_embedding = pool_mean(embedding, attention_masks)
    cp_embedding = pool_cls(embedding, attention_masks)

In [15]:
# Both pooled embeddings are expected to have one row per example in the batch.
print(mp_embedding.shape, cp_embedding.shape, sep="\n")

torch.Size([16, 384])
torch.Size([16, 384])


# Setup

In [16]:
results_base = f"../experiments/results/{config_id}"

# Log to both CSV and TensorBoard. The checkpoint directory is keyed by the CSV
# logger's version, and save_top_k=1 keeps only the lowest-avg_val_loss checkpoint.
csv_logger = CSVLogger(results_base, "frozen_lightning_logs")
tb_logger = TensorBoardLogger(results_base, name="frozen_tb_logs")
frozen_model_checkpointer = ModelCheckpoint(
    f"{results_base}/frozen_checkpoints/version_{csv_logger.version}",
    filename="checkpoint",
    monitor="avg_val_loss",
    mode="min",
    save_top_k=1
)

loggers = [csv_logger, tb_logger]
callbacks = [frozen_model_checkpointer, RichProgressBar()]

if frozen_model_checkpoint_path is not None:
    # Only the epoch counter is needed here, so load onto the CPU. This works on
    # CPU-only machines and avoids holding the full state dict in GPU memory.
    checkpoint = torch.load(f"../{frozen_model_checkpoint_path}", map_location="cpu")
    # "epoch" is zero-based. Capping max_epochs at the checkpoint's epoch count
    # presumably makes fit() restore the saved state without training further.
    frozen_epochs = checkpoint["epoch"] + 1
    del checkpoint

frozen_trainer = pl.Trainer(
    max_epochs=frozen_epochs,
    accelerator=accelerator,
    callbacks=callbacks,
    precision="16-mixed",
    logger=loggers
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [17]:
# Regression model: the pretrained encoder plus the pooling function named in model.yaml;
# embed_dim is taken from the encoder's own hidden size.
tr_model = TextRegressor(
    embedding_model,
    pooling_fn=pooling_fns[model_config["pooling"]],
    embed_dim=embedding_model.config.hidden_size,
)

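# Frozen training

First stage: freeze the pretrained encoder and train only the remaining parameters (385 trainable, per the model summary below).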

In [18]:
# Wrap the regressor in its LightningModule and freeze the pretrained encoder, so this
# stage only updates the non-encoder parameters (see the parameter counts below).
lit_model = LitTextRegressor(tr_model)
lit_model.freeze_pretrained_model()

In [19]:
# Sanity check: after freezing, only a handful of parameters (presumably the
# regression head) should be trainable.
ModelSummary(lit_model)

  | Name           | Type          | Params
-------------------------------------------------
0 | text_regressor | TextRegressor | 22.7 M
-------------------------------------------------
385       Trainable params
22.7 M    Non-trainable params
22.7 M    Total params
90.854    Total estimated model params size (MB)

In [20]:
# Resume from the configured frozen-stage checkpoint if there is one;
# ckpt_path=None makes Trainer.fit start from the current (fresh) weights.
frozen_ckpt_path = (
    f"../{frozen_model_checkpoint_path}" if frozen_model_checkpoint_path is not None else None
)

if frozen_ckpt_path is not None:
    # Report the exact path handed to Lightning.
    print(f"Loading checkpoint from: {frozen_ckpt_path}")
else:
    print("Training from scratch")

frozen_trainer.fit(
    model=lit_model,
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader,
    ckpt_path=frozen_ckpt_path,
)

Restoring states from the checkpoint path at ../experiments/results/mlml6_rate_pred_clsp/frozen_checkpoints/version_0/checkpoint.ckpt


Loading checkpoint from: experiments/results/mlml6_rate_pred_clsp/frozen_checkpoints/version_0/checkpoint.ckpt


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Restored all states from the checkpoint at ../experiments/results/mlml6_rate_pred_clsp/frozen_checkpoints/version_0/checkpoint.ckpt


Output()

`Trainer.fit` stopped: `max_epochs=20` reached.


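# Unfrozen fine-tuning

Second stage: reload the best frozen-stage checkpoint, unfreeze the encoder, and fine-tune the whole model end to end.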

In [21]:
# Free GPU memory before the fine-tuning stage.
# NOTE(review): `tr_model` (and very likely `frozen_trainer`) still reference the same
# modules, so `del lit_model` alone may release little. The best frozen-stage weights
# are reloaded from disk two cells below in any case.
del lit_model
gc.collect()
torch.cuda.empty_cache()

In [22]:
# Path of the best frozen-stage checkpoint (lowest avg_val_loss, per the ModelCheckpoint
# configured for this trainer); it is reloaded for fine-tuning in the next cell.
frozen_trainer.checkpoint_callback.best_model_path

'/home/carcook/dev/nlp-projects/experiments/results/mlml6_rate_pred_clsp/frozen_checkpoints/version_0/checkpoint.ckpt'

In [23]:
# Rebuild the LightningModule from the best frozen-stage checkpoint. The existing
# TextRegressor instance is supplied explicitly as the `text_regressor` init argument;
# load_from_checkpoint forwards extra keyword arguments to the constructor.
lit_model = LitTextRegressor.load_from_checkpoint(
    frozen_trainer.checkpoint_callback.best_model_path,
    text_regressor=tr_model,
)

In [24]:
# Make the pretrained encoder trainable again so the whole model is fine-tuned.
lit_model.unfreeze_pretrained_model()

In [25]:
# Sanity check: every parameter should now be trainable.
ModelSummary(lit_model)

  | Name           | Type          | Params
-------------------------------------------------
0 | text_regressor | TextRegressor | 22.7 M
-------------------------------------------------
22.7 M    Trainable params
0         Non-trainable params
22.7 M    Total params
90.854    Total estimated model params size (MB)

In [26]:
# Confirm which experiment the fine-tuning logs and checkpoints below are written under.
config_id

'mlml6_rate_pred_clsp'

In [27]:
results_base = f"../experiments/results/{config_id}"

# Separate loggers and checkpoint directory for the fine-tuning stage, versioned by the
# CSV logger. save_top_k=1 keeps only the lowest-avg_val_loss checkpoint.
csv_logger = CSVLogger(results_base, "finetune_lightning_logs")
tb_logger = TensorBoardLogger(results_base, name="finetune_tb_logs")
finetune_model_checkpointer = ModelCheckpoint(
    f"{results_base}/finetune_checkpoints/version_{csv_logger.version}",
    filename="finetune_checkpoint",
    monitor="avg_val_loss",
    mode="min",
    save_top_k=1
)

loggers = [csv_logger, tb_logger]
callbacks = [finetune_model_checkpointer, RichProgressBar()]

if finetune_model_checkpoint_path is not None:
    # Only the epoch counter is needed here, so load onto the CPU. This works on
    # CPU-only machines and avoids holding the full state dict in GPU memory.
    checkpoint = torch.load(f"../{finetune_model_checkpoint_path}", map_location="cpu")
    # "epoch" is zero-based. Capping max_epochs at the checkpoint's epoch count
    # presumably makes fit() restore the saved state without training further.
    finetune_epochs = checkpoint["epoch"] + 1
    del checkpoint

finetune_trainer = pl.Trainer(
    max_epochs=finetune_epochs,
    accelerator=accelerator,
    callbacks=callbacks,
    precision="16-mixed",
    logger=loggers
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [28]:
# Resume fine-tuning from the configured checkpoint if there is one. Otherwise
# ckpt_path=None starts from the frozen-stage weights loaded above.
finetune_ckpt_path = (
    f"../{finetune_model_checkpoint_path}" if finetune_model_checkpoint_path is not None else None
)

if finetune_ckpt_path is not None:
    # Report the fine-tuning checkpoint actually being restored.
    print(f"Loading checkpoint from: {finetune_ckpt_path}")
else:
    print("Fine-tuning from the best frozen-stage weights")

finetune_trainer.fit(
    model=lit_model,
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader,
    ckpt_path=finetune_ckpt_path,
)

Restoring states from the checkpoint path at ../experiments/results/mlml6_rate_pred_clsp/finetune_checkpoints/version_0/finetune_checkpoint.ckpt


Loading checkpoint from: experiments/results/mlml6_rate_pred_clsp/frozen_checkpoints/version_0/checkpoint.ckpt


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Restored all states from the checkpoint at ../experiments/results/mlml6_rate_pred_clsp/finetune_checkpoints/version_0/finetune_checkpoint.ckpt


Output()

`Trainer.fit` stopped: `max_epochs=3` reached.


In [29]:
# Evaluate on the held-out test split rather than the validation split. The validation
# data already (loosely) informed model selection: design choices such as the pooling
# strategy, and spotting overfitting when deciding which checkpoints to keep.
best_ft_checkpoint_path = finetune_trainer.checkpoint_callback.best_model_path
finetune_trainer.test(
    model=lit_model,
    dataloaders=test_dataloader,
    ckpt_path="best",  # resolves to best_ft_checkpoint_path
)

Restoring states from the checkpoint path at /home/carcook/dev/nlp-projects/experiments/results/mlml6_rate_pred_clsp/finetune_checkpoints/version_0/finetune_checkpoint.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/carcook/dev/nlp-projects/experiments/results/mlml6_rate_pred_clsp/finetune_checkpoints/version_0/finetune_checkpoint.ckpt


Output()

[{'avg_test_loss': 0.8739522099494934}]

In [30]:
# Test-set metrics stored on the module by the test run above (MSE and RMSE, per the output).
print(lit_model.test_epoch_metrics)

{'mse': 0.8739522, 'rmse': 0.9348541}


[<< Previous: Data Exploration](02_data_explore.ipynb) &nbsp; | &nbsp; [Next: Fit Analysis >>](04_fit_analysis.ipynb)