In [None]:
def train_model(model, dataset, training_args):
    """Run one pass over ``dataset`` with gradient accumulation.

    GPU utilization is printed once, right after the first optimizer step.

    Args:
        model: Model exposing ``parameters()``, ``train()`` and ``device``
            (plus ``gradient_checkpointing_enable()`` when checkpointing is
            requested). Calling it with a batch must return an object that
            has a ``loss`` attribute.
        dataset: Data handed to ``DataLoader``; every batch must be a mapping
            whose values support ``.to(device)``.
        training_args: Object providing ``gradient_checkpointing``,
            ``per_device_train_batch_size`` and
            ``gradient_accumulation_steps``.

    Returns:
        None. The model is updated in place.
    """
    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    train_dataloader = DataLoader(dataset, batch_size=training_args.per_device_train_batch_size)
    optimizer = AdamW(model.parameters())
    model.train()

    accumulation_steps = training_args.gradient_accumulation_steps
    gpu_utilization_printed = False
    pending_gradients = False

    def apply_update():
        """Apply the accumulated gradients, then clear them."""
        nonlocal gpu_utilization_printed
        optimizer.step()
        # Return values are not used here; the calls are kept in case the
        # helpers report or log on their own -- TODO confirm and drop if not.
        estimate_memory_of_gradients(model)
        estimate_memory_of_optimizer(optimizer)

        # Report once, after the first step and before gradients are zeroed.
        if not gpu_utilization_printed:
            print_gpu_utilization()
            gpu_utilization_printed = True

        optimizer.zero_grad()

    for step, batch in enumerate(train_dataloader, start=1):
        batch = {k: v.to(model.device) for k, v in batch.items()}

        # Scale the loss so gradients summed over one accumulation window
        # average out instead of adding up.
        loss = model(**batch).loss / accumulation_steps
        loss.backward()
        pending_gradients = True

        if step % accumulation_steps == 0:
            apply_update()
            pending_gradients = False

    # Flush a trailing partial accumulation window. Without this, the last
    # batches' gradients are silently dropped, and a dataloader shorter than
    # one window would never trigger an optimizer step at all.
    if pending_gradients:
        apply_update()

In [None]:
# Reset state via the notebook's cleanup() helper and print the GPU
# utilization before running any experiment.
cleanup()
print_gpu_utilization()

# Two runs at the same batch size: one with gradient accumulation,
# one with gradient checkpointing.
experiment_configs = (
    {"batch_size": 4, "gradient_accumulation_steps": 4},
    {"batch_size": 4, "gradient_checkpointing": True},
)
for experiment_kwargs in experiment_configs:
    gpu_memory_experiment(**experiment_kwargs)

# Release unoccupied cached memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()