# **PyTorch Lightning**


```bash
pip install pytorch-lightning
```

---

## **Understand the Lightning architecture**

Lightning separates your code into **three parts**:

1. **LightningModule** → model + training logic
2. **LightningDataModule (optional)** → dataloaders
3. **Trainer** → hardware, training strategy, logging, callbacks

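This separation keeps the model and its training logic apart from engineering concerns such as devices, precision, and logging, so each part can change without touching the others.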

---

## **Create LightningModule**

Inside this module you implement:

1. `__init__` → define model, loss, hyperparameters
2. `forward` → inference
3. `training_step` → compute loss
4. `validation_step` → compute val metrics
5. `configure_optimizers` → define optimizer & scheduler

Minimal example:

```python
import pytorch_lightning as pl
import torch
from torch import nn


class LitModel(pl.LightningModule):
    def __init__(self, lr=1e-3):
        super().__init__()
        self.model = nn.Linear(10, 2)
        self.loss_fn = nn.CrossEntropyLoss()
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = self.loss_fn(preds, y)
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        # Used mainly for visualization.
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = self.loss_fn(preds, y)
        self.log("val_loss", loss, on_step=False, on_epoch=True)
        # Used for:
        #   - model selection
        #   - early stopping
        #   - monitoring generalization

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
```

---

## **Create DataLoaders**

Two ways:

### A. Write them normally (simplest)

```python
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=32)
```

### B. Or create a LightningDataModule (cleaner for big projects)

## LightningDataModule


A **LightningDataModule** is a standardized place to put:

* dataset download
* dataset preprocessing
* train/val/test splits
* augmentations
* dataloader creation
* batch size configuration
* multiple workers, pin_memory, shuffling

Instead of spreading your data logic across different files, everything is put into **one clean class**.

It replaces messy code like this scattered everywhere:

```python
train_dataset = ...
train_loader = DataLoader(...)
val_loader = DataLoader(...)
test_loader = DataLoader(...)
```

with a **well-organized object**.

---

## **What does a DataModule contain?**

A LightningDataModule has **five** methods you commonly implement, plus two optional ones:

#### **1. `__init__`**

Store configuration like:

* paths
* batch size
* augmentations
* num_workers

#### **2. `prepare_data()`**

Called only **once**, from a single process rather than on every GPU, so do not assign state (such as `self.x = ...`) here:

* download dataset
* tokenize
* heavy preprocessing

#### **3. `setup(stage)`**

Called on **every process** (one per GPU in distributed training):

* create train/val/test splits
* set transforms
* load datasets into memory if needed

#### **4. `train_dataloader()`**

Return a PyTorch DataLoader for training.

#### **5. `val_dataloader()`**

Return DataLoader for validation.

#### **6. `test_dataloader()` (optional)**

#### **7. `predict_dataloader()` (optional)**

---

## **Minimal Example**

```python
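import pytorch_lightning as pl
import torchvision
from torch.utils.data import DataLoader, random_split
from torchvision import transforms


class CIFAR10DataModule(pl.LightningDataModule):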
    def __init__(self, data_dir="./data", batch_size=32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self):
        # downloads only once
        torchvision.datasets.CIFAR10(self.data_dir, train=True, download=True)
        torchvision.datasets.CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # transformations
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        if stage == "fit" or stage is None:
            full = torchvision.datasets.CIFAR10(self.data_dir, train=True, transform=transform)
            self.train_ds, self.val_ds = random_split(full, [45000, 5000])

        if stage == "test" or stage is None:
            self.test_ds = torchvision.datasets.CIFAR10(self.data_dir, train=False, transform=transform)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=self.batch_size)
```
---

## **Configure the Trainer**

This controls training, GPUs, logging, precision, etc.

Examples:

### Single GPU

```python
trainer = pl.Trainer(accelerator="gpu", devices=1)
```

### Mixed precision (recommended)

```python
trainer = pl.Trainer(precision="16-mixed")
```

### Multi-GPU

```python
trainer = pl.Trainer(accelerator="gpu", devices=4, strategy="ddp")
```

### Enable checkpoints & early stopping

```python
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

trainer = pl.Trainer(
    callbacks=[
        ModelCheckpoint(monitor="val_loss", save_top_k=1),
        EarlyStopping(monitor="val_loss", patience=5)
    ]
)
```

---

## **Train the model**

If using dataloaders:

```python
trainer.fit(model, train_loader, val_loader)
```

If using a DataModule:

```python
dm = CIFAR10DataModule()
trainer.fit(model, dm)
```


## **Test (optional)**

Add a `test_step` and then:

```python
trainer.test(model, datamodule=dm)
```

---

## **Load a checkpoint for inference**

```python
model = LitModel.load_from_checkpoint("path/to/check.ckpt")

model.eval()
with torch.no_grad():
    preds = model(x)
```

---


## **PyTorch Lightning callbacks**
---

## **1. BatchSizeFinder**

Automatically finds the **largest batch size** that fits in GPU memory.

Usage:

```python
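# Lightning 2.x removed Trainer(auto_scale_batch_size=...); use the callback instead.
# The model or datamodule needs a `batch_size` attribute for the finder to adjust.
trainer = Trainer(callbacks=[BatchSizeFinder(mode="power")])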
```

It increases batch size until out-of-memory occurs, then backs off.
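
The doubling search described above can be sketched in plain Python. This is a simplified illustration, not Lightning's implementation: `fits` is a hypothetical predicate standing in for "a few training steps ran without an out-of-memory error", and the default values are chosen only for the example.

```python
def find_batch_size_power(fits, init_val=2, max_trials=25):
    """Double the batch size until a trial fails; return the last size that worked.

    `fits` is a hypothetical stand-in for running a few training steps and
    checking that no out-of-memory error occurred.
    """
    batch_size = init_val
    last_ok = None
    for _ in range(max_trials):
        if not fits(batch_size):
            break  # back off to the last size that succeeded
        last_ok = batch_size
        batch_size *= 2
    return last_ok


# Pretend that anything above 100 samples per batch runs out of memory.
best = find_batch_size_power(lambda bs: bs <= 100)
print(best)
```

Because the search only tries powers of two times the starting value, the result can sit well below the true limit; a binary-search mode narrows that gap at the cost of more trials.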

---

## **2. BackboneFinetuning**

For transfer learning or timm/MONAI models.

It:

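* freezes backbone layers at the start (the LightningModule must expose them as a `backbone` attribute)
* unfreezes them after a number of epochs
* optionally trains the backbone at a different learning rate than the rest of the model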

Use case:

* Fine-tuning ResNet
* ViT two-stage training
* MONAI UNet encoder freezing

Example:

```python
BackboneFinetuning(unfreeze_backbone_at_epoch=5)
```

---


## **3. ModelCheckpoint**

Automatically saves model checkpoints based on a metric.

Example:

```python
ModelCheckpoint(
    monitor="val_loss",
    save_top_k=3,
    mode="min"
)
```

Used for:

* saving best model
* saving last model
* resuming training
* model selection
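
To make `save_top_k` and `mode` concrete, here is a small sketch of the bookkeeping in plain Python. It only tracks `(epoch, score)` pairs and writes no files; `update_top_k` is a hypothetical helper, not part of Lightning.

```python
def update_top_k(kept, epoch, score, k=3, mode="min"):
    """Return the checkpoints to keep after seeing a new (epoch, score) pair."""
    candidates = kept + [(epoch, score)]
    # mode="min": lower scores are better; mode="max": higher scores are better.
    candidates.sort(key=lambda item: item[1], reverse=(mode == "max"))
    return candidates[:k]


kept = []
val_losses = [0.90, 0.70, 0.75, 0.60, 0.80]
for epoch, loss in enumerate(val_losses):
    kept = update_top_k(kept, epoch, loss, k=3, mode="min")

best_epoch, best_loss = kept[0]
print(kept)
```

With `mode="min"` the epoch-0 and epoch-4 entries are dropped because three better losses exist; flipping to `mode="max"` would keep the largest scores instead, which is what you want for metrics such as accuracy.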

---

## **4. DeviceStatsMonitor**

Logs:

* GPU memory
* GPU utilization
* CPU usage

Very useful with WandB or TensorBoard.

```python
trainer = Trainer(callbacks=[DeviceStatsMonitor()])
```

---

## **5. EarlyStopping**

Stops training when a metric stops improving.

Example:

```python
EarlyStopping(
    monitor="val_loss",
    patience=5,
    mode="min"
)
```

Prevents overfitting and wasted GPU hours.
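
The `patience` rule can be sketched as a counter that resets on every improvement. This is a simplified model of the idea for `mode="min"`, not Lightning's code; `epochs_until_stop` and the loss values are made up for illustration.

```python
def epochs_until_stop(val_losses, patience=5, min_delta=0.0):
    """Return the epoch index at which training would stop, or None if it never does."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss  # improvement: remember it and reset the counter
            wait = 0
        else:
            wait += 1    # no improvement this epoch
            if wait >= patience:
                return epoch
    return None


losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.73, 0.74, 0.69]
stop_epoch = epochs_until_stop(losses, patience=5)
print(stop_epoch)
```

Note that in this run the better loss at the last epoch is never reached, which is the usual trade-off: a small `patience` saves compute but can stop just before a late improvement.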

---

## **6. GradientAccumulationScheduler**

Changes gradient accumulation behavior at specific epochs.

Useful for:

* warmup of effective batch size
* dynamic training strategies
* stabilizing early epochs

Example:

```python
GradientAccumulationScheduler({
    0: 1,  # epoch 0: accumulate 1 batch
    5: 4,  # epoch 5+: accumulate 4 batches
})
```

This gradually increases effective batch size.
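
Why accumulation behaves like a larger batch can be checked with a tiny worked example. The sketch below assumes a one-parameter model `y = w * x` with a mean-squared-error loss and equally sized micro-batches, and it divides each micro-batch gradient by the accumulation factor so the sum matches the full-batch gradient. All names and numbers are illustrative.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    n = len(xs)
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / n


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]
w = 0.5
accumulate = 4
micro = len(xs) // accumulate  # 2 samples per micro-batch

# One backward pass over the full batch of 8 samples.
full_grad = mse_grad(w, xs, ys)

# Four backward passes over 2 samples each, summed before the optimizer step.
accumulated = 0.0
for i in range(accumulate):
    chunk = slice(i * micro, (i + 1) * micro)
    accumulated += mse_grad(w, xs[chunk], ys[chunk]) / accumulate

print(full_grad, accumulated)
```

The two values agree, so the optimizer sees the same update while only one micro-batch of activations is held in memory at a time.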

---


## **7. LearningRateMonitor**

Logs learning rate each step or epoch.

```python
LearningRateMonitor(logging_interval="epoch")
```

Essential for debugging schedulers.

---

## **8. RichProgressBar**

Replaces default TQDM bar with a beautiful rich-based display.

```python
trainer = Trainer(callbacks=[RichProgressBar()])
```

---

## **9. StochasticWeightAveraging (SWA)**

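Can improve generalization by averaging the model weights visited during the later epochs of training (the average is kept in memory, not built from saved checkpoint files).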

```python
StochasticWeightAveraging(swa_lrs=1e-3)
```
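
The averaging itself is a running mean over weight snapshots. The sketch below uses plain Python lists as stand-ins for parameter tensors and an equal-weight average; `update_average` is a hypothetical helper, and the snapshot values are invented for the example.

```python
def update_average(avg, new, n_averaged):
    """Fold one more weight snapshot into a running equal-weight average."""
    return [a + (w - a) / (n_averaged + 1) for a, w in zip(avg, new)]


# Weights of a two-parameter model at the end of three late epochs.
snapshots = [
    [1.0, 4.0],
    [2.0, 2.0],
    [3.0, 0.0],
]

swa_weights = list(snapshots[0])
n_averaged = 1
for weights in snapshots[1:]:
    swa_weights = update_average(swa_weights, weights, n_averaged)
    n_averaged += 1

print(swa_weights)
```

The incremental form avoids storing every snapshot: only the current average and a counter are kept.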

---

## **10. ModelPruning**

Automatically prunes weights during training.

```python
from pytorch_lightning.callbacks import ModelPruning
```

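Sparsity can reduce model size and speed up inference, but only when the storage format and runtime exploit the zeros; masked dense tensors on their own are neither smaller nor faster.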

---

## **11. ModelSummary**

Shows a clean summary of your architecture at the start of training.

```python
ModelSummary(max_depth=2)
```

---

## **12. QuantizationAwareTraining**

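Integrates PyTorch quantization with Lightning. This callback belongs to Lightning 1.x; it was removed in 2.0, so it is not available alongside the 2.x-style options (such as `precision="16-mixed"`) used elsewhere in this document.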

---

## **13. Timer**

Allows you to stop training based on total training time.

```python
Timer(duration="00:00:30:00")  # format is DD:HH:MM:SS, so this is 30 minutes
```
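
String durations for `Timer` use the `DD:HH:MM:SS` format, which is easy to get wrong by dropping the days field. The helper below is a hypothetical validator written for this tutorial (it is not Lightning's parser) that converts such a string to a `timedelta` and rejects three-field strings.

```python
import re
from datetime import timedelta


def parse_duration(text):
    """Convert a DD:HH:MM:SS string into a timedelta, or raise ValueError."""
    match = re.fullmatch(r"(\d+):(\d\d):(\d\d):(\d\d)", text.strip())
    if match is None:
        raise ValueError(f"expected DD:HH:MM:SS, got {text!r}")
    days, hours, minutes, seconds = (int(group) for group in match.groups())
    return timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)


budget = parse_duration("00:00:30:00")
print(budget)
```

Passing a `datetime.timedelta` or a dict of `timedelta` keyword arguments to `Timer(duration=...)` sidesteps the string format entirely.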

---

## **14. FaultTolerantTraining**

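Intended to restart training automatically after crashes. Lightning 2.x does not appear to ship a callback under this name: fault-tolerant training was an experimental 1.x feature, and nothing restarts a crashed job for you. The usual pattern is to save checkpoints regularly and resume with `trainer.fit(model, ckpt_path=...)`.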

---


## **15. LambdaCallback**

A simple callback where you define your own functions without writing a class.

Example:

```python
callback = LambdaCallback(
    on_train_start=lambda trainer, pl_module: print("Training started"),
)
```

---

# **When do callbacks run?**

Callbacks have dozens of hooks:

* `on_train_start`
* `on_train_end`
* `on_train_batch_start`
* `on_train_batch_end`
* `on_validation_start`
* `on_validation_end`
* `on_validation_batch_end`
* `on_fit_end`
* `on_exception`
* `on_before_optimizer_step`
* `on_save_checkpoint`
* `on_load_checkpoint`

This makes callbacks extremely powerful for custom logic.
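
The order in which hooks fire can be pictured with a toy training loop. This is a deliberately simplified model, not Lightning's `Trainer`: real hooks also receive `trainer` and `pl_module` arguments, and `LogHooks`, `call_hook`, and `toy_fit` are hypothetical names used only here.

```python
class LogHooks:
    """Toy callback that records the hooks it implements, in call order."""

    def __init__(self):
        self.calls = []

    def on_train_start(self):
        self.calls.append("on_train_start")

    def on_train_batch_end(self, batch_idx):
        self.calls.append(f"on_train_batch_end[{batch_idx}]")

    def on_validation_end(self):
        self.calls.append("on_validation_end")

    def on_train_end(self):
        self.calls.append("on_train_end")


def call_hook(callbacks, name, *args):
    """Invoke `name` on every callback that defines it; skip the rest."""
    for cb in callbacks:
        hook = getattr(cb, name, None)
        if hook is not None:
            hook(*args)


def toy_fit(callbacks, epochs=1, batches_per_epoch=2):
    call_hook(callbacks, "on_train_start")
    for _ in range(epochs):
        for batch_idx in range(batches_per_epoch):
            call_hook(callbacks, "on_train_batch_start", batch_idx)
            # forward pass, backward pass and optimizer step would run here
            call_hook(callbacks, "on_train_batch_end", batch_idx)
        call_hook(callbacks, "on_validation_start")
        # validation batches would run here
        call_hook(callbacks, "on_validation_end")
    call_hook(callbacks, "on_train_end")


logger = LogHooks()
toy_fit([logger])
print(logger.calls)
```

A callback only needs to define the hooks it cares about; in this toy loop undefined hooks are skipped, which mirrors how the base `Callback` class provides no-op defaults.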

---

# **How to use callbacks**

Just pass a list to Trainer:

```python
trainer = Trainer(
    callbacks=[
        ModelCheckpoint(monitor="val_loss"),
        EarlyStopping(monitor="val_loss"),
        LearningRateMonitor(),
        DeviceStatsMonitor(),
        RichProgressBar(),
    ]
)
```

Callbacks do not require modifying your LightningModule.

---


# **Safely Combining Gradient Accumulation, Batch Size Finder, and Model Pruning in PyTorch Lightning**

When training deep learning models, it’s common to optimize:

* **Memory usage** (bigger batches, mixed precision)
* **Training stability** (gradient accumulation)
* **Model efficiency** (pruning)
* **Throughput** (optimal batch size)

However, these features interact in complex ways inside PyTorch Lightning.
This tutorial shows the **correct and safe workflow** for combining:

* `GradientAccumulationScheduler`
* `BatchSizeFinder`
* `ModelPruning`
* Standard callbacks (checkpointing, early stopping, logging)

We explain why certain combinations must be avoided and how to structure your training code properly.

---

# **1. Why Batch Size Finder Cannot Run Together with Gradient Accumulation or Model Pruning**

### **BatchSizeFinder determines the maximum batch size that fits in GPU memory.**

It does that by testing multiple forward/backward passes:

* With accumulation **disabled**
* With pruning **not yet applied**

### But gradient accumulation affects memory:

* Gradients remain in memory across several backward passes
* This increases peak memory later in training

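### And model pruning affects memory too:

* Pruning via `torch.nn.utils.prune` re-parametrizes each pruned layer (a `weight_orig` parameter plus a `weight_mask` buffer)
* Tensor and activation shapes stay the same, but the extra mask buffers add memory
* Memory usage therefore changes once pruning is applied during training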

Because of that:

### **BatchSizeFinder’s batch-size probe is NOT valid once gradient accumulation or pruning is active.**

This is why:

#### ❌ You must NOT include BatchSizeFinder in the same trainer that uses:

* `GradientAccumulationScheduler`
* `ModelPruning`

---

# **2. Correct Workflow (Two-Stage Training Strategy)**

## **Stage 1 — (Optional) Tune batch size in a clean environment**

Use a small, dedicated trainer:

```python
batch_size_finder = BatchSizeFinder(mode="power", max_val=max_batch_size)

tune_trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=1,
    callbacks=[batch_size_finder]
)

# Lightning 2.x removed Trainer.tune(); the callback runs its search when fit() starts.
tune_trainer.fit(model, datamodule=data_module)
```

After this:

1. Read the suggested batch size (`batch_size_finder.optimal_batch_size`)
2. Update your `LitDataModule(batch_size=...)`
3. Remove BatchSizeFinder entirely

Now you have a safe per-step batch size.

---

## **Stage 2 — Train normally with gradient accumulation + pruning**

Once the batch size is known, you can add your full callback stack:

* Gradient accumulation scheduler
* Pruning
* Checkpoint
* Early stopping
* LR monitor
* Device stats
* Model summary

Example:

```python
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision="16-mixed",
    logger=wandb_logger,
    max_epochs=100,
    log_every_n_steps=10,
    callbacks=[
        DebugCallback(),
        checkpoint,
        early_stopping,
        learning_rate_monitor,
        device_stats_monitor,
        model_summary,
        gradient_accumulation_scheduler,
        model_pruning,
    ]
)
```

### Why does this work?

Because now:

* Batch size is fixed
* Gradient accumulation is allowed
* Pruning no longer interferes with batch-size discovery
* Training is stable
* Mixed precision reduces memory and improves speed

---

# **3. How GradientAccumulationScheduler Works**

You defined:

```python
GradientAccumulationScheduler({
    0: 1,
    5: 4,
})
```

Meaning:

* Epochs 0–4 → accumulate 1 batch (normal training)
* Epochs 5+ → accumulate 4 batches before optimizer step

This increases **effective batch size**:

If per-step batch size = 32:

Epochs 0–4:
$$
32 \times 1 = 32
$$

Epochs 5+:
$$
32 \times 4 = 128
$$

This is stable because batch size was determined *before* enabling accumulation.
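
The same arithmetic can be written as a small lookup. The sketch assumes zero-indexed epochs and a schedule that always contains an entry for epoch 0; `accumulation_for_epoch` is a hypothetical helper, not a Lightning function.

```python
def accumulation_for_epoch(schedule, epoch):
    """Pick the factor from the latest schedule entry at or before `epoch`."""
    start = max(e for e in schedule if e <= epoch)
    return schedule[start]


schedule = {0: 1, 5: 4}
per_step_batch_size = 32

effective = {
    epoch: per_step_batch_size * accumulation_for_epoch(schedule, epoch)
    for epoch in (0, 4, 5, 10)
}
print(effective)
```

If you scale the learning rate with batch size, this is the number to scale against, not the per-step batch size.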

---

# **4. Why accumulate_grad_batches must NOT be set manually**

Lightning requires:

```python
accumulate_grad_batches = 1
```

when using `GradientAccumulationScheduler`.

The callback mutates this value **during training**, so you must not override it manually in `Trainer()`.

You correctly kept it default:

```python
# Do NOT set accumulate_grad_batches manually
```

---

# **5. Why pruning is safe only during normal training**

Your pruning callback:

```python
model_pruning = ModelPruning(
    pruning_fn="l1_unstructured",
    amount=0.5
)
```

Pruning happens during training, so the graph and memory footprint change progressively.

This is why pruning cannot run during the BatchSizeFinder stage.
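
What `l1_unstructured` with `amount=0.5` does to a single weight tensor can be sketched without PyTorch. The example below is a simplified illustration on a flat list: it masks the half of the weights with the smallest absolute value and leaves the length unchanged. `l1_unstructured_mask` is a hypothetical helper and the weights are invented.

```python
def l1_unstructured_mask(weights, amount):
    """Return a 0/1 mask that zeroes the `amount` fraction of smallest-|w| entries."""
    n_prune = round(amount * len(weights))
    # Indices ordered from smallest to largest absolute value.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:n_prune])
    return [0.0 if i in pruned else 1.0 for i in range(len(weights))]


weights = [0.8, -0.1, 0.05, -0.6, 0.3, -0.02]
mask = l1_unstructured_mask(weights, amount=0.5)
pruned_weights = [w * m for w, m in zip(weights, mask)]

print(mask)
print(pruned_weights)
```

The pruned tensor has the same number of entries as before; only the values change. In this sketch the mask is an extra list of the same length, which is why masking by itself does not save memory.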

---

# **6. Putting It All Together**

Your final trainer:

```python
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision="16-mixed",
    logger=wandb_logger,
    max_epochs=100,
    log_every_n_steps=10,
    callbacks=[
        DebugCallback(),
        checkpoint,
        early_stopping,
        learning_rate_monitor,
        device_stats_monitor,
        model_summary,
        gradient_accumulation_scheduler,
        model_pruning
    ]
)
```

### This setup is now correct because:

* Batch size is fixed and known to be safe
* Gradient accumulation dynamically increases effective batch size
* Mixed precision reduces memory
* Pruning modifies layers safely during real training
* All logging + monitoring callbacks work together
* BatchSizeFinder is isolated in a separate tuning step

---

# **7. Final Recommended Workflow Diagram**

```
[ Stage 1: Batch Size Tuning ]
---------------------------------------
Trainer(callbacks=[BatchSizeFinder(mode="power")])
     ↓
Find max batch size that fits in memory
     ↓
Update LitDataModule(batch_size=X)
     ↓
Remove BatchSizeFinder


[ Stage 2: Real Training ]
---------------------------------------
Trainer(
    callbacks = [
        ModelCheckpoint,
        EarlyStopping,
        DeviceStatsMonitor,
        GradientAccumulationScheduler,
        ModelPruning,
        ...
    ]
)
trainer.fit(...)
```

This tutorial captures the **correct best-practice workflow** for:

* batch-size tuning
* gradient accumulation
* model pruning
* safe memory usage
* stable mixed-precision training
* multi-callback integration

