In [1]:
# Use the %pip magic so the install targets the running kernel's environment,
# and pin torchvision to the release built against torch 2.2.0.
%pip install torch==2.2.0 torchvision==0.17.0


Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simpleNote: you may need to restart the kernel to use updated packages.

Collecting torch==2.2.0
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/58/b8/51b956c2da9729390a3080397cd2f31171394543af7746681466e372f69a/torch-2.2.0-cp311-cp311-win_amd64.whl (198.6 MB)
     ---------------------------------------- 0.0/198.6 MB ? eta -:--:--
     ---------------------------------------- 0.1/198.6 MB 2.6 MB/s eta 0:01:16
     ---------------------------------------- 0.2/198.6 MB 2.1 MB/s eta 0:01:35
     ---------------------------------------- 0.3/198.6 MB 2.0 MB/s eta 0:01:41
     ---------------------------------------- 0.4/198.6 MB 2.1 MB/s eta 0:01:36
     ---------------------------------------- 0.5/198.6 MB 2.0 MB/s eta 0:01:39
     ---------------------------------------- 0.5/198.6 MB 2.1 MB/s eta 0:01:35
     ---------------------------------------- 0.6/198.6 MB 2.1 MB/s eta 0:01:34
     --------------------------------------

In [1]:
# %% [markdown]
# ## Introduction
#
# First, we'll go over a regular `LightningModule` implementation without the use of a `LightningDataModule`

# %%
import os

import lightning as L
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchmetrics.functional import accuracy
from torchvision import transforms

# Note - you must have torchvision installed for this example
from torchvision.datasets import CIFAR10, MNIST


In [5]:
# Root folder for dataset downloads; the PATH_DATASETS environment variable
# overrides it. NOTE(review): the fallback folder is spelled "MINST" (sic) --
# kept as-is so already-downloaded data is still found.
PATH_DATASETS = os.getenv("PATH_DATASETS", "./MINST")
# Smaller batches on CPU, larger ones when a CUDA device is available.
BATCH_SIZE = 64 if not torch.cuda.is_available() else 256

# %% [markdown]
# ## Using DataModules
#
# DataModules are a way of decoupling data-related hooks from the `LightningModule`
# so you can develop dataset-agnostic models.

# %% [markdown]
# ### Defining The MNISTDataModule
#
# Let's go over each function in the class below and talk about what they're doing:
#
# 1. ```__init__```
#     - Takes in a `data_dir` arg that points to where you have downloaded/wish to download the MNIST dataset.
#     - Defines a transform that will be applied across train, val, and test dataset splits.
#     - Defines default `self.dims`.
#
#
# 2. ```prepare_data```
#     - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.
#     - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)
#
# 3. ```setup```
#     - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test).
#     - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.
#     - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage`.
#     - **Note this runs across all GPUs and it *is* safe to make state assignments here**
#
#
# 4. ```x_dataloader```
#     - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`


# %%
class MNISTDataModule(L.LightningDataModule):
    """DataModule that downloads MNIST and serves train/val/test dataloaders.

    Args:
        data_dir: Directory the MNIST files are stored in / downloaded to.
        batch_size: Batch size used by every dataloader. ``None`` (the default)
            falls back to the module-level ``BATCH_SIZE`` at construction time.
        split_seed: Seed for the train/val split, so that every ``setup`` call
            produces the same partition of the training set.
    """

    def __init__(self, data_dir: str = PATH_DATASETS, batch_size=None, split_seed: int = 42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = BATCH_SIZE if batch_size is None else batch_size
        self.split_seed = split_seed
        # Same preprocessing for every split: tensor conversion, then
        # normalisation with a single-channel mean/std pair.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # (channels, width, height) of one sample and number of target classes;
        # read by the notebook to size the model.
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        """Download both MNIST splits into ``data_dir`` (no state is assigned here)."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage``.

        Args:
            stage: ``"fit"``, ``"validate"``, ``"test"`` or ``None``.
                ``None`` prepares every split.
        """
        # "validate" needs the val split too; previously only "fit"/None built
        # it, so a standalone validation run found no `mnist_val` attribute.
        if stage in ("fit", "validate", None):
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            # A dedicated seeded generator keeps the train/val partition
            # identical across repeated setup() calls and across runs.
            split_rng = torch.Generator().manual_seed(self.split_seed)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=split_rng
            )

        # Assign test dataset for use in dataloader(s)
        if stage in ("test", None):
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the training loader; samples are reshuffled every epoch."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size)


# %% [markdown]
# ### Defining the dataset agnostic `LitModel`
#
# Below, we define the same model as the `LitMNIST` model we made earlier.
#
# However, this time our model has the freedom to use any input data that we'd like 🔥.


# %%
class LitModel(L.LightningModule):
    """Small fully connected classifier sized from the input dimensions.

    Args:
        channels: Number of channels of one input sample.
        width: Width of one input sample.
        height: Height of one input sample.
        num_classes: Number of target classes (size of the output layer).
        hidden_size: Width of the two hidden layers.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
        super().__init__()

        # We take in input dimensions as parameters and use those to dynamically build model.
        self.channels = channels
        self.width = width
        self.height = height
        self.num_classes = num_classes
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Flatten -> two hidden Linear/ReLU/Dropout stages -> class scores.
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1) for ``x``."""
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch):
        """Compute the negative log-likelihood loss for one ``(x, y)`` batch."""
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        return loss

    def _evaluate(self, batch, prefix):
        """Log ``{prefix}_loss`` and ``{prefix}_acc`` for one ``(x, y)`` batch."""
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        # Use the configured class count rather than a literal 10 so the metric
        # stays correct for datasets with a different number of classes.
        acc = accuracy(preds, y, task="multiclass", num_classes=self.num_classes)
        self.log(f"{prefix}_loss", loss, prog_bar=True)
        self.log(f"{prefix}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and ``val_acc`` for one validation batch."""
        self._evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` and ``test_acc`` for one test batch."""
        self._evaluate(batch, "test")

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``self.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer


# %% [markdown]
# ### Training the `LitModel` using the `MNISTDataModule`
#
# Now, we initialize and train the `LitModel` using the `MNISTDataModule`'s configuration settings and dataloaders.

# %%
# Init DataModule



In [6]:
# Instantiate the MNIST datamodule with its default arguments.
dm = MNISTDataModule()

In [7]:
# Seed the Python, NumPy and PyTorch RNGs first so weight initialisation,
# dropout and every other use of the global RNGs repeat on a fresh kernel.
L.seed_everything(42, workers=True)
# Init model from datamodule's attributes
model = LitModel(*dm.dims, dm.num_classes)
# Init trainer
trainer = L.Trainer(
    max_epochs=3,
    accelerator="auto",
    devices=1,
)
# Pass the datamodule by keyword so its dataloaders override the model hooks :)
trainer.fit(model, datamodule=dm)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
c:\Users\jinmeng\AppData\Local\Programs\Python\Python311\Lib\site-packages\lightning\pytorch\trainer\connectors\logger_connector\logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
Missing logger folder: d:\Coding\learndl\learndl\lightning_logs

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 55.1 K
-------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220   

Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 62.39it/s]

c:\Users\jinmeng\AppData\Local\Programs\Python\Python311\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


                                                                           

c:\Users\jinmeng\AppData\Local\Programs\Python\Python311\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 2: 100%|██████████| 860/860 [00:19<00:00, 43.58it/s, v_num=0, val_loss=0.188, val_acc=0.943]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 860/860 [00:19<00:00, 43.55it/s, v_num=0, val_loss=0.188, val_acc=0.943]


In [None]:

# %% [markdown]
# ### Defining the CIFAR10 DataModule
#
# Let's prove the `LitModel` we made earlier is dataset agnostic by defining a new datamodule for the CIFAR10 dataset.


# %%
class CIFAR10DataModule(L.LightningDataModule):
    """DataModule that downloads CIFAR10 and serves train/val/test dataloaders.

    Args:
        data_dir: Directory the CIFAR10 files are stored in / downloaded to.
        batch_size: Batch size used by every dataloader. ``None`` (the default)
            falls back to the module-level ``BATCH_SIZE`` at construction time.
        split_seed: Seed for the train/val split, so that every ``setup`` call
            produces the same partition of the training set.
    """

    def __init__(self, data_dir: str = "./", batch_size=None, split_seed: int = 42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = BATCH_SIZE if batch_size is None else batch_size
        self.split_seed = split_seed
        # Same preprocessing for every split: tensor conversion, then
        # per-channel normalisation with mean 0.5 / std 0.5.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

        # (channels, width, height) of one sample and number of target classes;
        # read by the notebook to size the model.
        self.dims = (3, 32, 32)
        self.num_classes = 10

    def prepare_data(self):
        """Download both CIFAR10 splits into ``data_dir`` (no state is assigned here)."""
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage``.

        Args:
            stage: ``"fit"``, ``"validate"``, ``"test"`` or ``None``.
                ``None`` prepares every split.
        """
        # "validate" needs the val split too; previously only "fit"/None built
        # it, so a standalone validation run found no `cifar_val` attribute.
        if stage in ("fit", "validate", None):
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            # A dedicated seeded generator keeps the train/val partition
            # identical across repeated setup() calls and across runs.
            split_rng = torch.Generator().manual_seed(self.split_seed)
            self.cifar_train, self.cifar_val = random_split(
                cifar_full, [45000, 5000], generator=split_rng
            )

        # Assign test dataset for use in dataloader(s)
        if stage in ("test", None):
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the training loader; samples are reshuffled every epoch."""
        return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(self.cifar_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return DataLoader(self.cifar_test, batch_size=self.batch_size)


# %% [markdown]
# ### Training the `LitModel` using the `CIFAR10DataModule`
#
# Our model isn't very good, so it will perform pretty badly on the CIFAR10 dataset.
#
# The point here is that we can see that our `LitModel` has no problem using a different datamodule as its input data.

# %%
# Seed the Python, NumPy and PyTorch RNGs first so weight initialisation,
# dropout and every other use of the global RNGs repeat on a fresh kernel.
L.seed_everything(42, workers=True)
# Same LitModel class as before, sized from the CIFAR10 datamodule's attributes.
dm = CIFAR10DataModule()
model = LitModel(*dm.dims, dm.num_classes, hidden_size=256)
trainer = L.Trainer(
    max_epochs=5,
    accelerator="auto",
    devices=1,
)
# Pass the datamodule by keyword so its dataloaders override the model hooks.
trainer.fit(model, datamodule=dm)