## PyTorch Training

Uses the Trainer included in Hugging Face `transformers` (backed by `accelerate`) since it mitigates a lot of annoying boilerplate.


In [1]:
import polars as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from transformers import Trainer, TrainingArguments, ModernBertConfig, AutoModel
from tqdm import tqdm

In [2]:
df = (
    pl.scan_parquet(
        "movie_data_plus_embeds_all.parquet",
    )
    .select(["tconst", "averageRating", "json"])
    .with_columns(averageRating=pl.col("averageRating").cast(pl.Float32))
    .collect()
    .sample(fraction=1.0, shuffle=True, seed=42)
)

df

tconst,averageRating,json
str,f32,str
"""tt0173052""",4.1,"""{  ""title"": ""The Prince and t…"
"""tt0266288""",7.4,"""{  ""title"": ""Azhakiya Ravanan…"
"""tt6263490""",4.3,"""{  ""title"": ""Getaway"",  ""gen…"
"""tt10049110""",7.8,"""{  ""title"": ""Die Wiese"",  ""g…"
"""tt5761612""",3.8,"""{  ""title"": ""Woman on the Edg…"
…,…,…
"""tt0079376""",6.2,"""{  ""title"": ""The Proud Twins""…"
"""tt1161064""",3.2,"""{  ""title"": ""Super Capers: Th…"
"""tt0179526""",5.7,"""{  ""title"": ""Who's the Caboos…"
"""tt0188233""",5.7,"""{  ""title"": ""That's Erotic"", …"


## Train a Custom Tokenizer

Use the `modernbert` tokenizer as a base, just reduce the vocabulary significantly and tailor it to this specific dataset.


In [3]:
from transformers import AutoTokenizer

json_docs = df["json"].to_list()

base_tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
print(str(base_tokenizer.vocab)[0:100])
print(len(base_tokenizer(json_docs[0])["input_ids"]))

{'Ġbanned': 20374, 'Ġexisting': 5368, 'Ġadmissions': 26120, 'lund': 45815, 'Ġ\\,\\': 28247, 'Ġfav': 
378


In [4]:
vocab_size = 5000

# don't train on all texts because it will take forever
tokenizer = base_tokenizer.train_new_from_iterator(
    iter(json_docs[:50000]), vocab_size=vocab_size,
    new_special_tokens=["  ", "    ", "      "]
)

print(str(tokenizer.vocab)[0:100])
print(len(tokenizer(json_docs[0])["input_ids"]))




{'Erik': 2006, 'ĠNguyen': 4954, 'tti': 1325, 'ifer': 1282, 'ler': 619, 'rass': 4854, 'Ã´me': 4557, '
368


Preencode all the tokens. A `max_length` of 1024 may be excessive but does not cause a proportionate reduction in model training speed over a 512 max length due to ModernBERT's unpadding + RoPE.

In order to avoid OOMs on the host system, generate in batches, then push to the GPU. (ideally we _could_ push to the GPU for each batch but that will cause GPU memory leaks)


In [5]:
def batch(iterable, n=1):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]

In [6]:
max_length = 1024
token_batch_size = 2048
device = "cuda:0"

# input_ids = torch.empty((0, max_length)).to("cpu")
# attention_mask = torch.empty((0, max_length)).to("cpu")

input_ids = []
attention_mask = []

for docs in tqdm(batch(json_docs, token_batch_size), total=len(json_docs) // token_batch_size):
    tokens = tokenizer(docs,
                       max_length=max_length,
                       padding="max_length",
                       truncation=True,
                       return_tensors="pt").to("cpu")
    
    # input_ids = torch.vstack([input_ids, tokens["input_ids"]])
    # attention_mask = torch.vstack([attention_mask, tokens["attention_mask"]])
    
    input_ids.append(tokens["input_ids"])
    attention_mask.append(tokens["attention_mask"])
   
input_ids = torch.vstack(input_ids).to(device)
attention_mask = torch.vstack(attention_mask).to(device)

input_ids.size()

119it [03:38,  1.83s/it]                         
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


torch.Size([242552, 1024])

In [7]:
# token id 7 is the 2-space token for indents
input_ids[0][0:10]

tensor([  3, 100, 208,   7,  11, 359, 271, 267, 525, 377], device='cuda:0')

In [8]:
device = "cuda:0"
n_test = 20000

X_input_ids_train = input_ids[:-n_test].int().to(device)
X_input_ids_test = input_ids[-n_test:].int().to(device)

X_attention_train = attention_mask[:-n_test].int().to(device)
X_attention_test = attention_mask[-n_test:].int().to(device)

y_train = torch.from_numpy(df[:-n_test]["averageRating"].to_numpy().copy()).to(device)
y_test = torch.from_numpy(df[-n_test:]["averageRating"].to_numpy().copy()).to(device)

y_train

tensor([4.1000, 7.4000, 4.3000,  ..., 6.4000, 6.0000, 6.5000], device='cuda:0')

In [9]:
train_dataset = TensorDataset(X_input_ids_train, X_attention_train, y_train)
test_dataset = TensorDataset(X_input_ids_test, X_attention_test, y_test)

## Build the Model

Due to the new tokenizer, the special tokens for the fresh ModernBERT model have to be explicitly defined.


In [10]:
special_token_dict = dict(
    zip(tokenizer.special_tokens_map.keys(), tokenizer.all_special_ids)
)
special_token_dict

{'unk_token': 2,
 'sep_token': 4,
 'pad_token': 5,
 'cls_token': 3,
 'mask_token': 6,
 'additional_special_tokens': 7}

In [26]:
hidden_size = 128
dropout = 0.5

config = ModernBertConfig(
    vocab_size=vocab_size,
    max_position_embeddings=max_length,
    hidden_size=hidden_size,
    intermediate_size=512,
    num_hidden_layers=6,
    num_attention_heads=4,
    global_attn_every_n_layers=2,
    local_attention=16,
    attention_dropout=dropout,
    embeddings_dropout=dropout,
    mlp_dropout=dropout,
    unk_token_id=special_token_dict["unk_token"],
    sep_token_id=special_token_dict["sep_token"],
    pad_token_id=special_token_dict["pad_token"],
    cls_token_id=special_token_dict["cls_token"],
    mask_token_id=special_token_dict["mask_token"],
)

transformer_model = AutoModel.from_config(config)
total_params = sum(p.numel() for p in transformer_model.parameters())
total_params

2214528

In [27]:
class RatingsModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.transformer_model = model
        self.output = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask, targets=None):
        x = self.transformer_model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        x = x.last_hidden_state[:, 0]  # the [CLS] vector
        x = self.output(x)

        return x.squeeze()  # return 1D output

In [28]:
model = RatingsModel(transformer_model)
_ = model.to(device)

torch.set_float32_matmul_precision('high')  # perf increase for ModernBERT

Validation loss doesn't play nice with the `Trainer` out of the boss, so need [some tweaks](https://discuss.huggingface.co/t/no-log-for-validation-loss-during-training-with-trainer/40094/3).


In [29]:
def collate_fn(examples):
    input_ids = torch.stack([f[0] for f in examples])
    attention_masks = torch.stack([f[1] for f in examples])
    targets = torch.stack([f[2] for f in examples])

    return {
        "input_ids": input_ids,
        "attention_mask": attention_masks,
        "targets": targets,
    }


class RegressionTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=0):
        outputs = model(**inputs)
        # loss = nn.L1Loss()(outputs, inputs["targets"])  # L1 loss is MAE
        loss = nn.MSELoss()(outputs, inputs["targets"])

        return (loss, outputs) if return_outputs else loss


In [30]:
training_args = TrainingArguments(
    learning_rate=5e-5,
    lr_scheduler_type="cosine",
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=20,
    weight_decay=0.01,
    save_strategy="no",
    eval_strategy="steps",
    eval_steps=0.05,
    logging_strategy="steps",
    logging_steps=0.05,
    fp16=True,
    dataloader_num_workers=0,  # since data is in memory, as problem is GPU bound
    dataloader_pin_memory=False,
    dataloader_persistent_workers=False,
)

# reinstantiate a clean model
transformer_model = AutoModel.from_config(config)
model = RatingsModel(transformer_model)
_ = model.to(device)

trainer = RegressionTrainer(
    model,
    training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=collate_fn,
)

trainer.can_return_loss = True

In [31]:
trainer.train()

Step,Training Loss,Validation Loss
3478,1.6507,1.121598
6956,1.1454,1.094314
10434,1.119,1.097998
13912,1.0976,1.075459
17390,1.0755,1.076842
20868,1.0517,1.054685
24346,1.0301,1.049676
27824,1.01,1.050678
31302,0.9928,1.045853
34780,0.9741,1.051777


TrainOutput(global_step=69560, training_loss=1.0079467422459027, metrics={'train_runtime': 1870.1251, 'train_samples_per_second': 2380.076, 'train_steps_per_second': 37.195, 'total_flos': 0.0, 'train_loss': 1.0079467422459027, 'epoch': 20.0})

In [32]:
logs = trainer.state.log_history

logs[0:4]

[{'loss': 1.6507,
  'grad_norm': 19.227664947509766,
  'learning_rate': 4.96934436859867e-05,
  'epoch': 1.0,
  'step': 3478},
 {'eval_loss': 1.1215980052947998,
  'eval_runtime': 3.0562,
  'eval_samples_per_second': 6544.165,
  'eval_steps_per_second': 102.416,
  'epoch': 1.0,
  'step': 3478},
 {'loss': 1.1454,
  'grad_norm': 10.553370475769043,
  'learning_rate': 4.8778854084940764e-05,
  'epoch': 2.0,
  'step': 6956},
 {'eval_loss': 1.0943138599395752,
  'eval_runtime': 2.728,
  'eval_samples_per_second': 7331.389,
  'eval_steps_per_second': 114.736,
  'epoch': 2.0,
  'step': 6956}]

In [33]:
logs_consolidated = []
i = 0
while i < len(logs)-1:
    base_log = logs[i]
    base_log.update(logs[i+1])
    logs_consolidated.append(base_log)
    i += 2
    
df_logs = pl.DataFrame(logs_consolidated).sort("epoch")
df_logs.write_parquet("llm_scratch_train_logs.parquet")
df_logs

loss,grad_norm,learning_rate,epoch,step,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second
f64,f64,f64,f64,i64,f64,f64,f64,f64
1.6507,19.227665,0.00005,1.0,3478,1.121598,3.0562,6544.165,102.416
1.1454,10.55337,0.000049,2.0,6956,1.094314,2.728,7331.389,114.736
1.119,4.867875,0.000047,3.0,10434,1.097998,2.725,7339.382,114.861
1.0976,9.907719,0.000045,4.0,13912,1.075459,2.695,7421.251,116.143
1.0755,6.433779,0.000043,5.0,17390,1.076842,2.7045,7395.133,115.734
…,…,…,…,…,…,…,…,…
0.8877,15.479814,0.000005,16.0,55648,1.073634,2.7298,7326.416,114.658
0.8804,13.94219,0.000003,17.0,59126,1.073171,2.6996,7408.445,115.942
0.8748,13.542985,0.000001,18.0,62604,1.079629,2.7387,7302.754,114.288
0.8714,12.781465,3.1311e-7,19.0,66082,1.079943,2.6673,7498.317,117.349


In [34]:
from safetensors.torch import save_model

save_model(model, "imdb_embeddings_llm_scratch.safetensors")

## Test Model

In this case, need to evaluate the LLM in batches to avoid going OOM.

In [35]:
import numpy as np

_ = model.to("cuda:0").eval()  # to disable BatchNorm1D

dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=64,
                                         shuffle=False,
                                         pin_memory=False)
preds_bucket = []

for batch in tqdm(dataloader, smoothing=0):
    with torch.no_grad():
        output = model(input_ids=batch[0],
                       attention_mask=batch[1])
        preds = output.detach().cpu().numpy()

    preds_bucket.append(preds)
        
actual_values = torch.stack([f[2] for f in test_dataset])

test_results = (pl.DataFrame({"Predicted": np.hstack(preds_bucket), "Actual": actual_values.cpu().numpy()})
                .with_columns(
                    abs_diff=(pl.col("Predicted") - pl.col("Actual")).abs(),
                    square_error = ((pl.col("Actual") - pl.col("Predicted")) ** 2)
                )
               )
                
test_results

100%|██████████| 313/313 [00:03<00:00, 103.37it/s]


Predicted,Actual,abs_diff,square_error
f32,f32,f32,f32
7.160156,7.1,0.060156,0.003619
6.417969,6.5,0.082031,0.006729
6.140625,4.1,2.040625,4.164151
6.042969,5.5,0.542969,0.294815
7.324219,7.2,0.124219,0.01543
…,…,…,…
6.484375,6.2,0.284375,0.080869
4.2421875,3.2,1.042187,1.086155
5.90625,5.7,0.20625,0.042539
6.34375,5.7,0.64375,0.414414


In [36]:
# Mean Absolute Error (MAE)
test_results["abs_diff"].mean()

0.7747781070888042

In [37]:
# Mean Square Error (MSE)
test_results["square_error"].mean()

1.0805600972121756