In [20]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import lightning as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

In [27]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(
            nn.Linear(28*28, 64),
            nn.ReLU(),
            nn.Linear(64,3),
        )

    def forward(self, x):
        return self.l1(x)
    
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28*28)
        )

    def forward(self, x):
        return self.l1(x)
    
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.save_hyperparameters()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log('Validation loss', val_loss)
        return val_loss
    
    def test_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log('Validation loss', test_loss)
        return test_loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [4]:
# Data
transform = transforms.ToTensor()
train_set = MNIST(
    root=os.getcwd(),
    download=True,
    train=True,
    transform=transform
)
test_set = MNIST(
    root=os.getcwd(),
    download=True,
    train=False,
    transform=transform
)

In [5]:
# Model
autoencoder = LitAutoEncoder(Encoder(), Decoder())

In [None]:
# Early Stopping
class MyEarlyStopping(EarlyStopping):
    def on_validation_end(self, trainer, pl_module):
        # override this to disable early stopping at the end of val loop
        pass

    def on_train_end(self, trainer, pl_module):
        # instead, do it at the end of training loop
        self._run_early_stopping_check(trainer)

In [33]:
# Debug the model before training
trainer_debug = L.Trainer(
        fast_dev_run=3,
    limit_train_batches=0.1, # use only 10 of training data
    limit_val_batches=0.01, # use only 1% of validation data
)
trainer_debug.fit(
    model=autoencoder,
    train_dataloaders=DataLoader(train_set)
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 3 batch(es). Logging and checkpointing is suppressed.

  | Name    | Type    | Params | Mode 
--------------------------------------------
0 | encoder | Encoder | 50.4 K | train
1 | decoder | Decoder | 51.2 K | train
--------------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
10        Modules in train mode
0         Modules in eval mode
/Users/khueluu/.local/share/virtualenvs/lightning-playground-veOcVop1/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 3/3 [00:00<00:00, 104.80it/s]

`Trainer.fit` stopped: `max_steps=3` reached.


Epoch 0: 100%|██████████| 3/3 [00:00<00:00, 102.40it/s]


In [40]:
# Run sanity check and print weights summary
from lightning.pytorch.callbacks import ModelSummary

summary = ModelSummary(autoencoder)
print(summary)

trainer_debug = L.Trainer(
    profiler='simple',
    num_sanity_val_steps=2,
    callbacks=[
        ModelSummary(max_depth=-1)
    ]
)
trainer_debug.fit(
    model=autoencoder,
    train_dataloaders=DataLoader(train_set)
)

Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/khueluu/.local/share/virtualenvs/lightning-playground-veOcVop1/lib/python3.12/site-packages/lightning/pytorch/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.

  | Name         | Type       | Params | Mode 
----------------------------------------------------
0 | encoder      | Encoder    | 50.4 K | train
1 | encoder.l1   | Sequential | 50.4 K | train
2 | encoder.l1.0 | Linear     | 50.2 K | train
3 | encoder.l1.1 | ReLU       | 0      | train
4 | encoder.l1.2 | Linear     | 195    | train
5 | decoder      | Decoder    | 51.2 K | train
6 | decoder.l1   | Sequential | 51.2 K | train
7 | decoder.l1.0 | Linear     |

<lightning.pytorch.callbacks.model_summary.ModelSummary object at 0x14730be60>
Epoch 2:  82%|████████▏ | 48974/60000 [04:06<00:55, 199.07it/s, v_num=15]

In [34]:
# Trainer
path_to_ckpt = '/Users/khueluu/Documents/Projects/lightning-playground/lightning_logs/version_6/checkpoints/epoch=14-step=900000.ckpt'
trainer = L.Trainer(
    callbacks=[
        EarlyStopping(monitor='val_loss', mode='min')
    ]
)
trainer.fit(
    model=autoencoder,
    train_dataloaders=DataLoader(train_set),
    # ckpt_path=path_to_ckpt
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/khueluu/.local/share/virtualenvs/lightning-playground-veOcVop1/lib/python3.12/site-packages/lightning/pytorch/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.

  | Name    | Type    | Params | Mode 
--------------------------------------------
0 | encoder | Encoder | 50.4 K | train
1 | decoder | Decoder | 51.2 K | train
--------------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
10        Modules in train mode
0         Modules in eval mode


Epoch 0:   1%|          | 688/60000 [00:03<04:59, 197.71it/s, v_num=10]


Detected KeyboardInterrupt, attempting graceful shutdown ...


Epoch 0:   1%|          | 689/60000 [00:03<05:00, 197.69it/s, v_num=10]

NameError: name 'exit' is not defined

In [28]:
# Test
path_to_ckpt = '/Users/khueluu/Documents/Projects/lightning-playground/lightning_logs/version_6/checkpoints/epoch=14-step=900000.ckpt'
model = LitAutoEncoder.load_from_checkpoint(path_to_ckpt, encoder=Encoder(), decoder=Decoder())
model.eval()


/Users/khueluu/.local/share/virtualenvs/lightning-playground-veOcVop1/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'encoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['encoder'])`.
/Users/khueluu/.local/share/virtualenvs/lightning-playground-veOcVop1/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'decoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['decoder'])`.


LitAutoEncoder(
  (encoder): Encoder(
    (l1): Sequential(
      (0): Linear(in_features=784, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=3, bias=True)
    )
  )
  (decoder): Decoder(
    (l1): Sequential(
      (0): Linear(in_features=3, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=784, bias=True)
    )
  )
)

In [29]:
trainer.test(model=model, dataloaders=DataLoader(test_set))

/Users/khueluu/.local/share/virtualenvs/lightning-playground-veOcVop1/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing: |          | 0/? [01:47<?, ?it/s]
Testing: |          | 10000/? [00:18<00:00, 542.28it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validation loss        0.04106734320521355
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'Validation loss': 0.04106734320521355}]