## Adding the data-related hooks inside the Lightning model

In [4]:
import os
import torch
import pytorch_lightning as pl
#import lightning as pl
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import nn
from torch.nn import functional as F

from torchmetrics import Accuracy
from torch.utils.data import random_split

# Root directory for downloaded datasets; override with the PATH_DATASETS env var.
PATH_DATASETS = os.getenv("PATH_DATASETS", ".")
# Use bigger batches when a CUDA device is available.
if torch.cuda.is_available():
    BATCH_SIZE = 128
else:
    BATCH_SIZE = 32

class LitModel(pl.LightningModule):
    """MNIST MLP classifier that also owns its data pipeline.

    The network is ``Flatten -> Linear -> ReLU -> Dropout -> Linear -> ReLU ->
    Dropout -> Linear`` producing ``num_classes`` logits, trained with
    cross-entropy and Adam. Downloading MNIST, splitting it into train and
    validation sets, and building the dataloaders all happen in the Lightning
    data hooks (``prepare_data``, ``setup``, ``train_dataloader``,
    ``val_dataloader``), so the model can be trained with ``trainer.fit(model)``.

    Args:
        data_dir: Directory MNIST is downloaded to / read from.
        hidden_size: Width of both hidden layers.
        learning_rate: Learning rate passed to Adam.
        batch_size: Batch size for both dataloaders; ``None`` uses the
            module-level ``BATCH_SIZE``.
        num_workers: Worker processes per dataloader.
        split_seed: Seed for the 55000/5000 train/validation split so the
            split is the same across runs and processes.
    """

    def __init__(
        self,
        data_dir=PATH_DATASETS,
        hidden_size=64,
        learning_rate=2e-4,
        batch_size=None,
        num_workers=4,
        split_seed=42,
    ):
        super().__init__()

        # Keep the init args on the instance so the hooks below can use them.
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate
        self.batch_size = BATCH_SIZE if batch_size is None else batch_size
        self.num_workers = num_workers
        self.split_seed = split_seed

        # Hard-coded MNIST attributes: 10 classes, 1x28x28 images.
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, height, width = self.dims
        # ToTensor, then normalize with mean 0.1307 / std 0.3081.
        # Used by setup() for every split.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * height * width, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes),
        )

        self.val_accuracy = Accuracy(task="multiclass", num_classes=self.num_classes)
        # Per-batch validation losses, averaged and cleared in on_validation_epoch_end.
        self.validation_step_outputs = []

    def forward(self, x):
        """Return unnormalized class logits for a batch of images ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # Logs the per-step value and the epoch average to progress bar and logger.
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute loss and accuracy for one validation batch and log them."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        # Calling the metric (forward) accumulates epoch state AND produces
        # this batch's value. Lightning logs a Metric's per-step value from
        # that forward result, which update() alone does not provide.
        self.val_accuracy(preds, y)
        self.validation_step_outputs.append(loss.detach())

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("val_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True, on_step=True, on_epoch=True)
        return {"val_loss": loss}

    def on_validation_epoch_end(self):
        """Log the mean of the collected validation losses, then reset the buffer."""
        if not self.validation_step_outputs:
            # No validation batches ran (e.g. an empty loader); torch.stack([]) would raise.
            return None
        avg_loss = torch.stack(self.validation_step_outputs).mean()
        self.log("avg_val_loss", avg_loss, prog_bar=True)
        self.validation_step_outputs.clear()
        return {"avg_val_loss": avg_loss}

    def configure_optimizers(self):
        """Adam over all parameters, using the ``learning_rate`` init argument."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        """Download the MNIST training set into ``data_dir``."""
        MNIST(self.data_dir, train=True, download=True)

    def setup(self, stage=None):
        """Create the train/validation datasets from the MNIST training set.

        Handles ``"fit"``, ``"validate"`` and ``None`` so that both
        ``trainer.fit`` and ``trainer.validate`` find ``self.mnist_val``.
        """
        if stage in ("fit", "validate", None):
            # The download already happened in prepare_data. setup runs on
            # every process, so it only reads from disk here.
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            # A seeded generator gives the same split on every run and process.
            self.mnist_train, self.mnist_val = random_split(
                mnist_full,
                [55000, 5000],
                generator=torch.Generator().manual_seed(self.split_seed),
            )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        """Deterministic (unshuffled) loader over the validation split."""
        return DataLoader(
            self.mnist_val,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

#BATCH_SIZE = 128 if torch.cuda.is_available() else 32
#train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))

#train_loader = DataLoader(
#    MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()), num_workers=4, 
#    batch_size = BATCH_SIZE
#)

#val_loader = DataLoader(
#    MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()), num_workers=4,
#    batch_size = BATCH_SIZE
#)

#mnist_full = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
#mnist_train, mnist_val = random_split(mnist_full, [55000, 5000])

#train_loader = DataLoader(
#    mnist_train, num_workers=4, batch_size = BATCH_SIZE
#)

#val_loader = DataLoader(
#    mnist_val, num_workers=4, batch_size = BATCH_SIZE
#)

# Train for 5 epochs; accelerator="auto" lets Lightning pick the device.
trainer = pl.Trainer(max_epochs=5, accelerator="auto")
model = LitModel()
# No dataloaders are passed: LitModel provides them through its own data hooks.
trainer.fit(model=model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type               | Params
----------------------------------------------------
0 | model        | Sequential         | 55.1 K
1 | val_accuracy | MulticlassAccuracy | 0     
----------------------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.
