## PyTorch Training

Uses the `Trainer` from Hugging Face `transformers` (backed by `accelerate`), since it removes a lot of tedious training-loop boilerplate.


In [1]:
import os
import polars as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from transformers import Trainer, TrainingArguments
from transformers.trainer_callback import TrainerCallback

In [2]:
# Load just the columns needed for training, cast the target to Float32 (the same
# dtype as the embeddings), then shuffle every row once with a fixed seed so the
# train/test split taken below is random but reproducible.
df = (
    pl.scan_parquet("movie_data_plus_embeds_all.parquet")
    .select("tconst", "averageRating", "embedding")
    .with_columns(pl.col("averageRating").cast(pl.Float32))
    .collect()
    .sample(fraction=1.0, shuffle=True, seed=42)
)

df

tconst,averageRating,embedding
str,f32,"array[f32, 768]"
"""tt0173052""",4.1,"[0.046187, 0.006053, … 0.011911]"
"""tt0266288""",7.4,"[-0.004875, -0.046969, … 0.017516]"
"""tt6263490""",4.3,"[0.005363, -0.018672, … 0.015112]"
"""tt10049110""",7.8,"[-0.009997, -0.029303, … 0.037793]"
"""tt5761612""",3.8,"[0.020259, -0.031869, … -0.01841]"
…,…,…
"""tt0079376""",6.2,"[0.062672, -0.009446, … 0.019441]"
"""tt1161064""",3.2,"[0.022779, 0.053063, … -0.009691]"
"""tt0179526""",5.7,"[0.001937, 0.003111, … -0.002453]"
"""tt0188233""",5.7,"[0.03125, 0.013802, … 0.009849]"


In [3]:
device = "cpu"  # keep the data on CPU; the Trainer moves each batch to the model's device
n_test = 20000  # size of the held-out split, taken from the end of the shuffled frame

# Use an explicit split index: `df[:-n_test]` would silently give an EMPTY train set
# when n_test == 0, and negative/oversized values would produce nonsense splits.
assert 0 <= n_test <= df.height, f"n_test must be in [0, {df.height}], got {n_test}"
split_idx = df.height - n_test
df_train, df_test = df[:split_idx], df[split_idx:]


def column_to_tensor(frame, column):
    """Copy one polars column into a torch tensor on `device`.

    The `.copy()` gives torch its own writable buffer (polars may hand back a
    read-only view from `to_numpy()`).
    """
    return torch.from_numpy(frame[column].to_numpy().copy()).to(device)


X_train = column_to_tensor(df_train, "embedding")
X_test = column_to_tensor(df_test, "embedding")

y_train = column_to_tensor(df_train, "averageRating")
y_test = column_to_tensor(df_test, "averageRating")

y_train

tensor([4.1000, 7.4000, 4.3000,  ..., 6.4000, 6.0000, 6.5000])

In [4]:
# Pair each feature tensor with its aligned target tensor so loaders can index (x, y) rows
train_dataset, test_dataset = TensorDataset(X_train, y_train), TensorDataset(X_test, y_test)

In [5]:
class RatingsModel(nn.Module):
    """MLP regressor mapping a fixed-size embedding to a single rating.

    Architecture: `num_layers` hidden blocks of Linear -> GELU -> BatchNorm1d -> Dropout,
    followed by a Linear head that produces one value per input row.

    Args:
        linear_dims: width of every hidden layer.
        num_layers: number of hidden blocks (0 gives a plain linear model).
        input_dim: size of each incoming embedding (768 for the embeddings used here).
        dropout: dropout probability applied at the end of each hidden block.
    """

    def __init__(self, linear_dims=256, num_layers=8, input_dim=768, dropout=0.5):
        super().__init__()

        dims = [input_dim] + [linear_dims] * num_layers
        self.mlp = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_dim, out_dim),
                nn.GELU(),
                nn.BatchNorm1d(out_dim),
                nn.Dropout(dropout),
            )
            for in_dim, out_dim in zip(dims[:-1], dims[1:])
        ])

        self.output = nn.Linear(dims[-1], 1)

    def forward(self, x, targets=None):
        """Return a 1-D tensor with one prediction per row of `x`.

        Args:
            x: features, expected shape (batch, input_dim).
            targets: accepted and ignored, because the batch dict built by
                `collate_fn` also carries "targets"; the loss is computed outside the model.
        """
        for layer in self.mlp:
            x = layer(x)

        # Squeeze only the trailing size-1 dim: a bare .squeeze() would turn a
        # batch of one into a 0-d scalar.
        return self.output(x).squeeze(-1)

In [6]:
# nn.Module.to moves the parameters and returns the module itself, so chain it
model = RatingsModel().to(device)
model

RatingsModel(
  (mlp): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=768, out_features=256, bias=True)
      (1): GELU(approximate='none')
      (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.5, inplace=False)
    )
    (1-7): 7 x Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): GELU(approximate='none')
      (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.5, inplace=False)
    )
  )
  (output): Linear(in_features=256, out_features=1, bias=True)
)

In [7]:
# Element count over model.parameters() (buffers such as BatchNorm running stats are not parameters)
total_params = sum(map(torch.numel, model.parameters()))
total_params

661761

Validation loss doesn't play nicely with the `Trainer` out of the box, so it needs [some tweaks](https://discuss.huggingface.co/t/no-log-for-validation-loss-during-training-with-trainer/40094/3).


In [8]:
def collate_fn(examples):
    """Stack a list of (features, target) dataset rows into a keyword-argument batch.

    Returns a dict with "x" (stacked features) and "targets" (stacked targets),
    whose keys match the keyword arguments of the model's forward().
    """
    batch = {}
    for key, position in (("x", 0), ("targets", 1)):
        batch[key] = torch.stack([example[position] for example in examples])
    return batch


class MAETrainer(Trainer):
    """`Trainer` that computes a regression loss between the model output and "targets".

    NOTE: despite the class name, the default loss is mean squared error (MSE), and
    that is what the logged training/validation losses report. Set `loss_fn` to
    `F.l1_loss` (L1 loss is MAE) to actually train on MAE.
    """

    # Wrapped in staticmethod so `self.loss_fn(pred, target)` does not bind `self`.
    loss_fn = staticmethod(F.mse_loss)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=0):
        """Run the model on a collated batch and return the loss (plus outputs if requested).

        Args:
            model: the model being trained; called with the batch dict as keyword args.
            inputs: batch dict with "x" (features) and "targets" (ratings), as built by `collate_fn`.
            return_outputs: if True, return `(loss, outputs)` instead of just the loss.
            num_items_in_batch: accepted for Trainer API compatibility; unused here.
        """
        outputs = model(**inputs)
        loss = self.loss_fn(outputs, inputs["targets"])

        return (loss, outputs) if return_outputs else loss


In [9]:
training_args = TrainingArguments(
    learning_rate=2e-3,
    lr_scheduler_type="cosine_with_restarts",
    per_device_train_batch_size=4096,
    per_device_eval_batch_size=4096,
    num_train_epochs=600,
    weight_decay=0.001,
    save_strategy="no",
    eval_strategy="steps",
    eval_steps=0.05,  # values < 1 are a fraction of total training steps -> 20 evaluations
    logging_strategy="steps",
    logging_steps=0.05,  # same cadence as eval, so every train log has a matching eval log
    fp16=False,  # do not need memory savings
    # many workers since big batches; os.cpu_count() can return None, so fall back to 1
    # (persistent workers require at least one worker)
    dataloader_num_workers=os.cpu_count() or 1,
    dataloader_pin_memory=True,
    dataloader_persistent_workers=True,
)

# Reinstantiate a clean model. Seed first so the initial weights are reproducible, and
# place it on the device the Trainer itself uses instead of a hard-coded "cuda:0".
torch.manual_seed(training_args.seed)
model = RatingsModel().to(training_args.device)

trainer = MAETrainer(
    model,
    training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=collate_fn,
)

# Needed for validation loss to be reported (per the forum thread linked above);
# presumably because the Trainer does not recognise `targets` as a label argument.
trainer.can_return_loss = True

In [10]:
# Run the full training loop and display the returned TrainOutput summary
train_output = trainer.train()
train_output

Step,Training Loss,Validation Loss
1650,2.5285,1.10619
3300,1.151,1.079443
4950,1.0449,1.079204
6600,0.9739,1.118531
8250,0.9267,1.108703
9900,0.8944,1.131433
11550,0.8696,1.123753
13200,0.8501,1.167039
14850,0.8343,1.188981
16500,0.8202,1.158032


TrainOutput(global_step=33000, training_loss=0.9325652743252841, metrics={'train_runtime': 1049.8348, 'train_samples_per_second': 127192.584, 'train_steps_per_second': 31.434, 'total_flos': 0.0, 'train_loss': 0.9325652743252841, 'epoch': 600.0})

Write the logs to a parquet file for later visualization. The logged steps match the table above, but each evaluation is recorded as a separate log entry from the training loss at the same step, so the two need to be consolidated (annoyingly).

In [11]:
# The Trainer's history: a list of dicts, with training and eval entries logged separately
logs = trainer.state.log_history

logs[:4]

[{'loss': 2.5285,
  'grad_norm': 0.6826887726783752,
  'learning_rate': 0.001987703228645653,
  'epoch': 30.0,
  'step': 1650},
 {'eval_loss': 1.1061903238296509,
  'eval_runtime': 0.3411,
  'eval_samples_per_second': 58629.983,
  'eval_steps_per_second': 14.657,
  'epoch': 30.0,
  'step': 1650},
 {'loss': 1.151,
  'grad_norm': 0.6862286925315857,
  'learning_rate': 0.0019510859303344694,
  'epoch': 60.0,
  'step': 3300},
 {'eval_loss': 1.0794429779052734,
  'eval_runtime': 0.249,
  'eval_samples_per_second': 80318.032,
  'eval_steps_per_second': 20.08,
  'epoch': 60.0,
  'step': 3300}]

In [14]:
# Each eval is logged as its own entry alongside the training-loss entry for the same
# step, and the run may end with a summary entry (train_runtime, train_loss, ...) that
# has neither loss key. Merge entries by step into NEW dicts instead of pairing
# neighbours in place, because:
#  - pairing by position breaks as soon as entries stop strictly alternating, and
#  - calling `dict.update` on `logs[i]` mutated trainer.state.log_history, so re-running
#    this cell (or reading the trainer's history later) saw already-merged dicts.
merged_by_step = {}
for entry in logs:
    if "loss" not in entry and "eval_loss" not in entry:
        continue  # not a per-step train/eval entry (e.g. the end-of-training summary)
    merged_by_step.setdefault(entry["step"], {}).update(entry)
logs_consolidated = list(merged_by_step.values())

df_logs = pl.DataFrame(logs_consolidated).sort("epoch")
df_logs.write_parquet("mlp_train_logs.parquet")
df_logs

loss,grad_norm,learning_rate,epoch,step,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second
f64,f64,f64,f64,i64,f64,f64,f64,f64
2.5285,0.682689,0.001988,30.0,1650,1.10619,0.3411,58629.983,14.657
1.151,0.686229,0.001951,60.0,3300,1.079443,0.249,80318.032,20.08
1.0449,0.602015,0.001891,90.0,4950,1.079204,0.2411,82941.87,20.735
0.9739,0.92719,0.001809,120.0,6600,1.118531,0.2498,80075.526,20.019
0.9267,0.607478,0.001707,150.0,8250,1.108703,0.2444,81844.869,20.461
…,…,…,…,…,…,…,…,…
0.7682,0.671049,0.000191,480.0,26400,1.207095,0.2515,79516.716,19.879
0.7636,0.707246,0.000109,510.0,28050,1.209251,0.3085,64821.345,16.205
0.7606,0.710131,0.000049,540.0,29700,1.21387,0.249,80315.802,20.079
0.7593,0.769773,0.000012,570.0,31350,1.212028,0.2591,77175.657,19.294


Save the model weights, which are the artifact we would use to deploy the model.

In [15]:
from safetensors.torch import save_model

# Persist the trained model's weights to a safetensors file
weights_path = "imdb_embeddings_mlp.safetensors"
save_model(model, weights_path)

## Test Model


In [16]:
# eval(): BatchNorm1d switches to its running statistics and Dropout is disabled
model.eval()
# Run inference wherever the Trainer left the model rather than hard-coding "cuda:0"
model_device = next(model.parameters()).device

# TensorDataset already holds the full feature/target tensors, so there is no need
# to re-stack them one row at a time. (`eval_dataset` is the feature tensor.)
eval_dataset, actual_values = test_dataset.tensors

with torch.no_grad():
    output = model(x=eval_dataset.to(model_device))
    preds = output.cpu().numpy()

test_results = (
    pl.DataFrame({"Predicted": preds, "Actual": actual_values.cpu().numpy()})
    .with_columns(
        abs_diff=(pl.col("Predicted") - pl.col("Actual")).abs(),
        square_error=(pl.col("Predicted") - pl.col("Actual")) ** 2,
    )
)

test_results

Predicted,Actual,abs_diff,square_error
f32,f32,f32,f32
6.923413,7.1,0.176587,0.031183
6.080589,6.5,0.419411,0.175906
5.10783,4.1,1.00783,1.015722
5.408305,5.5,0.091695,0.008408
7.046608,7.2,0.153392,0.023529
…,…,…,…
5.903197,6.2,0.296803,0.088092
4.862528,3.2,1.662528,2.763999
5.760687,5.7,0.060687,0.003683
5.753356,5.7,0.053356,0.002847


In [17]:
# Mean Absolute Error (MAE): average of the per-row absolute differences
test_results.get_column("abs_diff").mean()

0.8274261032521725

In [18]:
# Mean Squared Error (MSE): average of the per-row squared errors
test_results.get_column("square_error").mean()

1.2076212142897127