<a href="https://colab.research.google.com/github/iemio/torch/blob/main/notebook_09.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install pytorch-lightning



In [4]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from pytorch_lightning import Trainer

In [5]:
# Hyper-parameters
# -- model shape --
input_size = 28 * 28  # one flattened 28x28 image -> 784 features
hidden_size = 500  # units in the hidden layer
num_classes = 10  # size of the output layer

# -- optimisation --
num_epochs = 2  # passed to the Trainer as max_epochs
batch_size = 100  # samples per DataLoader batch
learning_rate = 1e-3  # Adam step size

In [28]:
# Fully connected neural network with one hidden layer
class LitNeuralNet(pl.LightningModule):
    """MNIST classifier (Linear -> ReLU -> Linear) packaged as a LightningModule.

    Besides the network itself, the module provides its own optimizer and its
    own MNIST train/validation dataloaders, so ``Trainer.fit(model)`` needs no
    further arguments.

    Note:
        ``learning_rate`` and ``batch_size`` are NOT constructor arguments.
        They are read from module-level (notebook) globals at the moment the
        optimizer / dataloaders are created, so they must be defined before
        ``Trainer.fit`` is called and they are not recorded in ``self.hparams``.

    Args:
        input_size: Number of features per flattened sample (784 for 28x28 images).
        hidden_size: Number of units in the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Record the constructor arguments on self.hparams for logging/checkpoints.
        self.save_hyperparameters()
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return raw class logits for already-flattened inputs ``x``."""
        hidden = self.relu(self.l1(x))
        return self.l2(hidden)

    def _shared_step(self, batch):
        """Flatten the images of ``batch``, run the network and return the loss.

        Shared by the training and validation steps so both compute the
        cross-entropy loss in exactly the same way.
        """
        images, labels = batch
        # Flatten to the width the first layer was built with rather than a
        # hard-coded 28 * 28, so the module stays consistent with `input_size`.
        images = images.view(-1, self.l1.in_features)
        logits = self(images)
        return F.cross_entropy(logits, labels)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the loss for one training batch."""
        loss = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss  # Lightning back-propagates the returned loss

    def validation_step(self, batch, batch_idx):
        """Compute, log (as ``val_loss``) and return the loss for one validation batch."""
        loss = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters.

        The step size comes from the module-level ``learning_rate`` global.
        """
        return torch.optim.Adam(self.parameters(), lr=learning_rate)

    def _mnist_loader(self, train):
        """Build a DataLoader over the MNIST train (``train=True``) or test split.

        ``download=True`` is passed for both splits so that either loader can
        be created first on a machine where ``./data`` is still empty;
        previously only the training split requested the download.
        """
        dataset = torchvision.datasets.MNIST(
            root="./data", train=train, transform=transforms.ToTensor(), download=True
        )
        # Only the training split is shuffled; batch size is the notebook global.
        return torch.utils.data.DataLoader(
            dataset=dataset, batch_size=batch_size, num_workers=4, shuffle=train
        )

    def train_dataloader(self):
        """Return the shuffled DataLoader over the MNIST training split."""
        return self._mnist_loader(train=True)

    def val_dataloader(self):
        """Return the unshuffled DataLoader over the MNIST test split (used for validation)."""
        return self._mnist_loader(train=False)

In [31]:
# Build the network from the hyper-parameters configured earlier in the notebook.
model = LitNeuralNet(
    input_size=input_size,
    hidden_size=hidden_size,
    num_classes=num_classes,
)

In [32]:
# Handy Trainer switches.
# NOTE(review): several names in the original notes (gpus, train_percent_check,
# distributed_backend, log_gpu_memory, auto_lr_find) look like they come from
# older Lightning releases; the 2.x spellings below should be confirmed
# against the Trainer docs for the installed version.
#   accelerator="gpu", devices=8 -> train on 8 GPUs (older: gpus=8)
#   fast_dev_run=True            -> runs a single batch through training and validation
#   limit_train_batches=0.1      -> train on only 10% of the data (older: train_percent_check=0.1)
trainer = Trainer(
    max_epochs=num_epochs,
    fast_dev_run=False,
)
trainer.fit(model)

# Advanced features
#   strategy="ddp" (older: distributed_backend)
#       DistributedDataParallel (DDP) implements data parallelism at the module
#       level and can run across multiple machines.
#   precision="16-mixed"         -> 16-bit (mixed) precision
#   DeviceStatsMonitor callback  -> device/GPU memory statistics (older: log_gpu_memory)
#   accelerator="tpu"            -> TPU support
#
#   Tuner(trainer).lr_find(model) -> automatically finds a good learning rate
#                                    before training (older: auto_lr_find)
#   deterministic=True            -> makes training reproducible
#   gradient_clip_val             -> gradient clipping threshold; no clipping by default

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name | Type   | Params | Mode 
----------------------------------------
0 | l1   | Linear | 392 K  | train
1 | relu | ReLU   | 0      | train
2 | l2   | Linear | 5.0 K  | train
----------------------------------------
397 K     Trainable params
0         Non-trainable params
397 K     Total params
1.590     Total estimated model params size (MB)
3         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached.
