## PyTorch Training

Uses the Trainer included in Hugging Face `transformers` (backed by `accelerate`) since it mitigates a lot of annoying boilerplate.


In [1]:
import polars as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from transformers import Trainer, TrainingArguments, ModernBertConfig, AutoModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
df = (
    pl.scan_parquet(
        "/Users/maxwoolf/Downloads/movie_data_plus_embeds_all.parquet", n_rows=10000
    )
    .select(["tconst", "averageRating", "json"])
    .with_columns(averageRating=pl.col("averageRating").cast(pl.Float32))
    .collect()
)

df

tconst,averageRating,json
str,f32,str
"""tt0000009""",5.4,"""{  ""title"": ""Miss Jerry"",  ""…"
"""tt0000147""",5.3,"""{  ""title"": ""The Corbett-Fitz…"
"""tt0000574""",6.0,"""{  ""title"": ""The Story of the…"
"""tt0000591""",5.6,"""{  ""title"": ""The Prodigal Son…"
"""tt0000630""",3.2,"""{  ""title"": ""Hamlet"",  ""genr…"
…,…,…
"""tt0035474""",6.9,"""{  ""title"": ""True to the Army…"
"""tt0035475""",6.6,"""{  ""title"": ""Trysil-Knut"",  …"
"""tt0035477""",6.0,"""{  ""title"": ""The Tuttles of T…"
"""tt0035479""",5.4,"""{  ""title"": ""Twin Beds"",  ""g…"


## Train a Custom Tokenizer

Use uses the `modernbert` tokenizer as a base (since it has useful special tokens), just reduce the vocabulary significantly and tailor it to this specific dataset.


In [3]:
from transformers import AutoTokenizer

json_docs = df["json"].to_list()

base_tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
print(str(base_tokenizer.vocab)[0:100])
print(len(base_tokenizer(json_docs[0])["input_ids"]))

{'kets': 43846, 'áĢº': 33160, 'ĠDowntown': 46827, 're': 250, 'ĠTan': 22188, 'Ġsinus': 22749, 'Ð¸Ñħ':
169


In [4]:
vocab_size = 5000

tokenizer = base_tokenizer.train_new_from_iterator(
    iter(json_docs), vocab_size=vocab_size
)

print(str(tokenizer.vocab)[0:100])
print(len(tokenizer(json_docs[0])["input_ids"]))




{'ĠPearson': 4183, 'ĠEm': 1047, 'ĠKendall': 3814, 'oll': 910, 'Isabel': 2855, 'ĠRawlinson': 3588, 'b
139


Preencode all the tokens. Set max length to `1024` to be safe.


In [5]:
max_length = 1024

tokens = tokenizer(json_docs, max_length=max_length, padding="max_length")
len(tokens["input_ids"])

10000

In [6]:
input_lengths = [len(x) for x in tokens["input_ids"]]
max(input_lengths)

1024

In [7]:
device = "cpu"

tensor_input_ids = torch.Tensor(tokens["input_ids"]).int().to(device)
tensor_attention_mask = torch.Tensor(tokens["attention_mask"]).int().to(device)
tensor_ratings = torch.from_numpy(df["averageRating"].to_numpy().copy()).to(device)
tensor_dataset = TensorDataset(tensor_input_ids, tensor_attention_mask, tensor_ratings)

In [8]:
test_proportion = 0.05

train_dataset, test_dataset = torch.utils.data.random_split(
    tensor_dataset, [1 - test_proportion, test_proportion]
)

## Build the Model

Due to the new tokenizer, the special tokens for the fresh ModernBERT model have to be explicitly defined.


In [9]:
special_token_dict = dict(
    zip(tokenizer.special_tokens_map.keys(), tokenizer.all_special_ids)
)
special_token_dict

{'unk_token': 2,
 'sep_token': 4,
 'pad_token': 5,
 'cls_token': 3,
 'mask_token': 6}

In [10]:
config = ModernBertConfig(
    vocab_size=vocab_size,
    max_position_embeddings=max_length,
    embedding_size=256,
    hidden_size=256,
    intermediate_size=768,
    num_hidden_layers=8,
    num_attention_heads=4,
    unk_token_id=special_token_dict["unk_token"],
    sep_token_id=special_token_dict["sep_token"],
    pad_token_id=special_token_dict["pad_token"],
    cls_token_id=special_token_dict["cls_token"],
    mask_token_id=special_token_dict["mask_token"],
)

transformer_model = AutoModel.from_config(config)
total_params = sum(p.numel() for p in transformer_model.parameters())
total_params

8100096

In [11]:
class RatingsModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.transformer_model = model
        self.output = nn.Linear(256, 1)

    def forward(self, input_ids, attention_mask, targets=None):
        x = self.transformer_model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        x = x.last_hidden_state[:, 0]  # the [CLS] vector
        x = self.output(x)

        return x.squeeze()  # return 1D output

In [12]:
model = RatingsModel(transformer_model)
_ = model.to(device)
model

RatingsModel(
  (transformer_model): ModernBertModel(
    (embeddings): ModernBertEmbeddings(
      (tok_embeddings): Embedding(5000, 256, padding_idx=5)
      (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (drop): Dropout(p=0.0, inplace=False)
    )
    (layers): ModuleList(
      (0): ModernBertEncoderLayer(
        (attn_norm): Identity()
        (attn): ModernBertAttention(
          (Wqkv): Linear(in_features=256, out_features=768, bias=False)
          (rotary_emb): ModernBertRotaryEmbedding()
          (Wo): Linear(in_features=256, out_features=256, bias=False)
          (out_drop): Identity()
        )
        (mlp_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): ModernBertMLP(
          (Wi): Linear(in_features=256, out_features=1536, bias=False)
          (act): GELUActivation()
          (drop): Dropout(p=0.0, inplace=False)
          (Wo): Linear(in_features=768, out_features=256, bias=False)
        )
      )
      (1-7): 7 x

Validation loss doesn't play nice with the `Trainer` out of the boss, so need [some tweaks](https://discuss.huggingface.co/t/no-log-for-validation-loss-during-training-with-trainer/40094/3).


In [13]:
def collate_fn(examples):
    input_ids = torch.stack([f[0] for f in examples])
    attention_masks = torch.stack([f[1] for f in examples])
    targets = torch.stack([f[2] for f in examples])

    return {
        "input_ids": input_ids,
        "attention_mask": attention_masks,
        "targets": targets,
    }


class MAETrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=0):
        outputs = model(**inputs)
        loss = nn.L1Loss()(outputs, inputs["targets"])  # L1 loss is MAE

        return (loss, outputs) if return_outputs else loss


In [14]:
training_args = TrainingArguments(
    learning_rate=1e-3,
    lr_scheduler_type="cosine_with_restarts",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=1,
    weight_decay=0.001,
    save_strategy="no",
    eval_strategy="steps",
    eval_steps=0.05,
    logging_strategy="steps",
    logging_steps=0.05,
    fp16=False,
    dataloader_num_workers=0,  # since data is in memory
    dataloader_pin_memory=False,
    dataloader_persistent_workers=False,
)

# reinstantiate a clean model
transformer_model = AutoModel.from_config(config)
model = RatingsModel(transformer_model)
_ = model.to(device)

trainer = MAETrainer(
    model,
    training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=collate_fn,
)

trainer.can_return_loss = True

In [15]:
trainer.train()

Step,Training Loss,Validation Loss
119,0.7315,0.652724
238,0.6437,0.658383
357,0.6821,0.694192
476,0.6517,0.687195
595,0.6234,0.66377
714,0.6913,0.657295
833,0.563,0.65648
952,0.6197,0.650622
1071,0.6783,0.653892
1190,0.6171,0.66526


TrainOutput(global_step=2375, training_loss=0.6285669041683799, metrics={'train_runtime': 597.1282, 'train_samples_per_second': 15.909, 'train_steps_per_second': 3.977, 'total_flos': 0.0, 'train_loss': 0.6285669041683799, 'epoch': 1.0})

In [16]:
trainer.evaluate(test_dataset)

{'eval_loss': 0.6347675919532776,
 'eval_runtime': 7.2476,
 'eval_samples_per_second': 68.988,
 'eval_steps_per_second': 17.247,
 'epoch': 1.0}

## Test Model


In [20]:
_ = model.to(device)
eval_dataset = test_dataset[0:10]

with torch.no_grad():
    output = model(input_ids=eval_dataset[0], attention_mask=eval_dataset[1])
    preds = output.detach().cpu()

pl.DataFrame({"Predicted": preds, "Actual": eval_dataset[2]}).with_columns(
    abs_diff=(pl.col("Predicted") - pl.col("Actual")).abs().round(2)
)

Predicted,Actual,abs_diff
f32,f32,f32
6.495729,6.7,0.2
6.369858,6.7,0.33
6.074883,7.1,1.03
6.30686,8.1,1.79
6.118845,6.0,0.12
6.196302,5.1,1.1
5.905068,5.6,0.31
6.30401,5.4,0.9
6.25199,5.7,0.55
6.122697,5.8,0.32
