## PyTorch Training

Uses the `Trainer` class included in Hugging Face `transformers` (backed by `accelerate`), since it eliminates a lot of annoying training-loop boilerplate.


In [1]:
import polars as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from transformers import Trainer, TrainingArguments

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Location of the source parquet file. This is a machine-specific absolute
# path: point it at your own copy of the data before running this cell.
DATA_PATH = "/Users/maxwoolf/Downloads/movie_data_plus_embeds_all.parquet"
# Cap on the number of rows read from the parquet file.
N_ROWS = 10000

# Lazily scan the file, keep only the id, target, and embedding columns, and
# materialize the result with `.collect()`.
df = (
    pl.scan_parquet(DATA_PATH, n_rows=N_ROWS)
    .select(["tconst", "averageRating", "embedding"])
    # Store ratings as float32 -- presumably to match the f32 embeddings fed
    # to the model; TODO confirm.
    .with_columns(averageRating=pl.col("averageRating").cast(pl.Float32))
    .collect()
)

df

tconst,averageRating,embedding
str,f32,"array[f32, 768]"
"""tt0000009""",5.4,"[-0.007815, -0.022642, … 0.005391]"
"""tt0000147""",5.3,"[0.012021, 0.014255, … -0.015754]"
"""tt0000574""",6.0,"[-0.010052, -0.015825, … 0.040161]"
"""tt0000591""",5.6,"[0.00765, 0.019661, … -0.010763]"
"""tt0000630""",3.2,"[0.03492, 0.00301, … 0.027586]"
…,…,…
"""tt0035474""",6.9,"[0.007757, -0.011224, … 0.038445]"
"""tt0035475""",6.6,"[0.005858, 0.008654, … 0.039309]"
"""tt0035477""",6.0,"[0.007687, -0.020819, … 0.040466]"
"""tt0035479""",5.4,"[0.03527, 0.018015, … 0.024018]"


In [3]:
device = "cpu"

# Convert both columns the same way: polars Series -> NumPy array -> copied
# buffer -> torch tensor on `device`. The explicit copy presumably gives torch
# a writable array that it owns -- TODO confirm.
tensor_embeddings, tensor_ratings = (
    torch.from_numpy(df[column_name].to_numpy().copy()).to(device)
    for column_name in ("embedding", "averageRating")
)

# Sanity check: displays the (rows, embedding width) shape.
tensor_embeddings.shape

torch.Size([10000, 768])

In [4]:
# Pair each embedding row with its rating: indexing the dataset yields an
# (embedding, rating) tuple. Every loaded row goes into the training set.
train_dataset = TensorDataset(tensor_embeddings, tensor_ratings)

In [5]:
class RatingsModel(nn.Module):
    """Feed-forward regressor mapping a 768-wide embedding to one scalar rating.

    Three hidden stages (768 -> 1536 -> 768 -> 256), each a linear layer
    followed by GELU and then batch normalization, feed a final linear layer
    with a single output unit.
    """

    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(768, 1536)
        self.batchnorm_1 = nn.BatchNorm1d(1536)
        self.linear_2 = nn.Linear(1536, 768)
        self.batchnorm_2 = nn.BatchNorm1d(768)
        self.linear_3 = nn.Linear(768, 256)
        self.batchnorm_3 = nn.BatchNorm1d(256)
        self.output = nn.Linear(256, 1)

    def forward(self, x):
        """Predict one rating per input row.

        Args:
            x: Float tensor of shape ``(batch, 768)``.

        Returns:
            Tensor of shape ``(batch,)`` with one prediction per row.
        """
        # Note the ordering: the activation is applied before each batch norm.
        x = F.gelu(self.linear_1(x))
        x = self.batchnorm_1(x)
        x = F.gelu(self.linear_2(x))
        x = self.batchnorm_2(x)
        x = F.gelu(self.linear_3(x))
        x = self.batchnorm_3(x)
        x = self.output(x)

        # Drop only the trailing singleton feature dimension. A bare squeeze()
        # would also collapse a batch of size 1 into a 0-d scalar, which no
        # longer lines up with targets of shape (1,).
        return x.squeeze(-1)

In [6]:
# Build the model and move it to the target device. `Module.to` returns the
# module itself, so the call can be chained instead of assigning the result to
# `_` (which would overwrite IPython's last-output variable).
model = RatingsModel().to(device)
model

RatingsModel(
  (linear_1): Linear(in_features=768, out_features=1536, bias=True)
  (batchnorm_1): BatchNorm1d(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear_2): Linear(in_features=1536, out_features=768, bias=True)
  (batchnorm_2): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear_3): Linear(in_features=768, out_features=256, bias=True)
  (batchnorm_3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (output): Linear(in_features=256, out_features=1, bias=True)
)

In [7]:
def collate_fn(examples):
    """Batch a list of samples into a pair of stacked tensors.

    Args:
        examples: Sequence of samples; element 0 of each sample is the model
            input and element 1 is the target.

    Returns:
        A ``(inputs, outputs)`` tuple where each member is the per-sample
        tensors stacked along a new leading batch dimension.
    """
    features = []
    targets = []
    for sample in examples:
        features.append(sample[0])
        targets.append(sample[1])

    return (torch.stack(features), torch.stack(targets))


class MAETrainer(Trainer):
    """`Trainer` subclass whose training objective is mean absolute error."""

    def compute_loss(
        self, model, inputs, num_items_in_batch=None, return_outputs=False
    ):
        """Compute the L1 (MAE) loss for one collated batch.

        Args:
            model: Module called on the batch features.
            inputs: Batch as an indexable pair: ``inputs[0]`` holds the
                features and ``inputs[1]`` the targets.
            num_items_in_batch: Accepted for compatibility with the base
                ``Trainer`` call; not used, because the loss is a plain mean
                over the batch. Optional so callers that omit it still work.
            return_outputs: When true, also return the predictions, mirroring
                the base ``Trainer.compute_loss`` contract.

        Returns:
            The scalar loss, or ``(loss, preds)`` when ``return_outputs`` is
            true.
        """
        preds = model(inputs[0])
        loss = nn.L1Loss()(preds, inputs[1])  # L1 loss is MAE

        return (loss, preds) if return_outputs else loss


In [8]:
training_args = TrainingArguments(
    learning_rate=1e-2,
    lr_scheduler_type="cosine_with_restarts",
    per_device_train_batch_size=128,
    num_train_epochs=10,
    weight_decay=0.001,
    save_strategy="no",
    eval_strategy="no",
    logging_strategy="steps",
    logging_steps=0.1,  # a float below 1 is a fraction of the total training steps
    fp16=False,
    dataloader_num_workers=0,  # since data is in memory
    dataloader_pin_memory=False,
    dataloader_persistent_workers=False,
)

# Seed torch *before* building the model so the weight initialization is
# repeatable across Restart & Run All; reuse the Trainer's own seed so there is
# a single source of truth.
torch.manual_seed(training_args.seed)

# reinstantiate a clean model (`Module.to` returns the module itself)
model = RatingsModel().to(device)

trainer = MAETrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=collate_fn,
)

In [9]:
# Run the training loop configured on `trainer`; the value returned by
# `train()` is displayed as this cell's output.
trainer.train()

Step,Training Loss
79,1.9468
158,0.6233
237,0.6207
316,0.6157
395,0.5777
474,0.5684
553,0.5532
632,0.5363
711,0.513
790,0.5061


TrainOutput(global_step=790, training_loss=0.7061055243769778, metrics={'train_runtime': 5.8204, 'train_samples_per_second': 17181.038, 'train_steps_per_second': 135.73, 'total_flos': 0.0, 'train_loss': 0.7061055243769778, 'epoch': 10.0})