In [2]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import lightning as L

In [3]:
class Encoder(nn.Module):
    """Two-layer MLP that maps a flattened input vector to a small code.

    The layers are stored as ``self.l1`` (Linear -> ReLU -> Linear). The
    attribute name and layer order are kept fixed so that state-dict keys
    (``l1.0.*``, ``l1.2.*``) do not depend on the chosen sizes.

    Args:
        input_dim: Number of features in each input vector. Defaults to
            ``28 * 28``.
        hidden_dim: Width of the hidden layer. Defaults to ``64``.
        latent_dim: Number of features in the produced code. Defaults to ``3``.
    """

    def __init__(self, input_dim: int = 28 * 28, hidden_dim: int = 64,
                 latent_dim: int = 3):
        super().__init__()
        self.l1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the resulting code."""
        return self.l1(x)


class Decoder(nn.Module):
    """Two-layer MLP that maps a small code back to a flattened output vector.

    The layers are stored as ``self.l1`` (Linear -> ReLU -> Linear). The
    attribute name and layer order are kept fixed so that state-dict keys
    (``l1.0.*``, ``l1.2.*``) do not depend on the chosen sizes.

    Args:
        latent_dim: Number of features in each input code. Defaults to ``3``.
        hidden_dim: Width of the hidden layer. Defaults to ``64``.
        output_dim: Number of features in each produced vector. Defaults to
            ``28 * 28``.
    """

    def __init__(self, latent_dim: int = 3, hidden_dim: int = 64,
                 output_dim: int = 28 * 28):
        super().__init__()
        self.l1 = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the reconstruction."""
        return self.l1(x)

In [4]:
class LitAutoEncoder(L.LightningModule):
    """Lightning module that trains an encoder/decoder pair on reconstruction error."""

    def __init__(self, encoder, decoder):
        """Store the two sub-modules; the encoder is registered first."""
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        """Compute the reconstruction loss for one batch.

        ``batch`` is unpacked as a pair and only its first element is used;
        ``batch_idx`` is not read. Returns the mean-squared error between the
        flattened inputs and the decoder output.
        """
        inputs, _ignored = batch
        # Collapse everything after the batch dimension into one feature axis.
        flat_inputs = inputs.view(inputs.size(0), -1)
        reconstruction = self.decoder(self.encoder(flat_inputs))
        return F.mse_loss(reconstruction, flat_inputs)

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters (lr=1e-3)."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [5]:
# MNIST is stored under (and, if missing, downloaded into) the current working
# directory; each image is converted with ToTensor().
dataset = MNIST(
    root=os.getcwd(),
    download=True,
    transform=transforms.ToTensor(),
)
# The loader is left at its defaults (PyTorch documents these as batch_size=1
# and no shuffling). NOTE(review): confirm that is intended for training.
train_loader = DataLoader(dataset=dataset)

In [6]:
# model
autoencoder = LitAutoEncoder(Encoder(), Decoder())

# Checkpoint written by an earlier run of this notebook. It only exists on the
# machine that produced it, so resume from it when present and otherwise train
# from scratch; this keeps the cell runnable after "Restart & Run All" in a
# fresh checkout.
RESUME_CKPT = "lightning_logs/version_3/checkpoints/epoch=0-step=10.ckpt"
ckpt_path = RESUME_CKPT if os.path.isfile(RESUME_CKPT) else None
if ckpt_path is None:
    print(f"Checkpoint {RESUME_CKPT!r} not found; training from scratch.")

# train model
# max_steps caps the total number of optimizer steps at 20.
trainer = L.Trainer(max_steps=20)
trainer.fit(model=autoencoder, train_dataloaders=train_loader,
            ckpt_path=ckpt_path)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at lightning_logs/version_3/checkpoints/epoch=0-step=10.ckpt
/Users/Deependu/miniconda3/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:362: The dirpath has changed from '/Users/Deependu/code/pytorch_lightning_codes/01-basic/lightning_logs/version_3/checkpoints' to '/Users/Deependu/code/pytorch_lightning_codes/01-basic/lightning_logs/version_6/checkpoints', therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded.

  | Name    | Type    | Params | Mode 
--------------------------------------------
0 | encoder | Encoder | 50.4 K | train
1 | decoder | Decoder | 51.2 K | train
--------------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0

Training: |          | 0/? [00:00<?, ?it/s]

/Users/Deependu/miniconda3/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py:221: You're resuming from a checkpoint that ended before the epoch ended and your dataloader is not resumable. This can cause unreliable results if further training is done. Consider using an end-of-epoch checkpoint or make your dataloader resumable by implementing the `state_dict` / `load_state_dict` interface.
`Trainer.fit` stopped: `max_steps=20` reached.


In [7]:
# Number of direct sub-modules registered on the autoencoder.
sum(1 for _ in autoencoder.children())

2