### Import Libraries

In [None]:
import torch
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.data import random_split
import pytorch_lightning as pl

In [None]:
# In the pytorch-lightning framework, certain method names in a LightningModule class
#  are reserved and have specific purposes. 
# These methods are not Python keywords,
#  but rather special methods that pytorch-lightning expects and uses to manage the 
# training, validation, and testing loops. Here's a brief overview of these methods:

# __init__: Initializes the module.
# This is the standard Python constructor, used to initialize the instance attributes.

# forward: Defines the forward pass of the neural network. 
# This is a standard method in PyTorch models, 
# specifying how the input data flows through the network layers.

# training_step: Defines a single step of the training loop. 
# pytorch-lightning uses this method to know what operations to perform during training 
# for each batch of data.

# validation_step: Defines a single step of the validation loop. 
# Similar to training_step, but for the validation phase.

# test_step: Defines a single step of the testing loop. 
# Similar to training_step, but for the testing phase.

# _common_step: This is a custom method you've defined to avoid code duplication between 
# training_step, validation_step, and test_step. 
# It is not a special method recognized by pytorch-lightning, 
# but rather a helper method to reduce redundancy in your code.

# predict_step: Defines a single step for prediction. 
# pytorch-lightning uses this method to perform predictions on new data.

# configure_optimizers: Specifies the optimizer(s) to use during training. 
# pytorch-lightning uses this method to configure the optimizer(s) for training.

In [None]:
class NN(pl.LightningModule):
    """Two-layer fully connected classifier (Linear -> ReLU -> Linear).

    Every input batch is flattened to ``(batch_size, -1)`` before the first
    layer, so the flattened size of one sample must equal ``input_size``.

    Args:
        input_size: Number of features in one flattened sample.
        num_classes: Number of output logits produced by ``forward``.
        hidden_size: Width of the hidden layer (defaults to the original 50).
        lr: Adam learning rate (defaults to the original 0.001).
    """

    def __init__(self, input_size, num_classes, hidden_size=50, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return unnormalized class scores (logits) for an already-flat batch."""
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    @staticmethod
    def _flatten(x):
        """Collapse every dimension except the first (batch) one."""
        return x.reshape(x.size(0), -1)

    def training_step(self, batch, batch_idx):
        """Compute, log ('train_loss') and return the training loss for one batch."""
        loss, _, _ = self._common_step(batch, batch_idx)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log ('val_loss') and return the validation loss for one batch."""
        loss, _, _ = self._common_step(batch, batch_idx)
        # Validation metrics are aggregated per epoch; sync_dist reduces them
        # across processes when running on more than one device.
        self.log('val_loss', loss, sync_dist=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute, log ('test_loss') and return the test loss for one batch."""
        loss, _, _ = self._common_step(batch, batch_idx)
        self.log('test_loss', loss, sync_dist=True)
        return loss

    def _common_step(self, batch, batch_idx):
        """Shared train/val/test logic.

        Expects ``batch`` to unpack into ``(inputs, targets)``.

        Returns:
            Tuple ``(loss, scores, y)``: the cross-entropy loss, the raw
            logits, and the targets.
        """
        x, y = batch
        # Call the module (not .forward) so registered hooks also run.
        scores = self(self._flatten(x))
        loss = self.loss_fn(scores, y)
        return loss, scores, y

    def predict_step(self, batch, batch_idx):
        """Return the argmax class index for every sample in ``batch``.

        ``batch`` must unpack into ``(inputs, targets)``; the targets are ignored.
        """
        x, _ = batch
        scores = self(self._flatten(x))
        return torch.argmax(scores, dim=1)

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return optim.Adam(self.parameters(), lr=self.lr)


In [None]:
# Build the data module that supplies the train/val/test dataloaders.
# NOTE(review): MyDataModule is neither defined nor imported in the visible
# cells; presumably it is a LightningDataModule from another cell/module — confirm.
data_module = MyDataModule(
    batch_size=64,
    data_dir='./data',
    num_workers=3,
)

In [None]:
# Trainer configuration:
#   - accelerator/devices: run on 2 GPUs.
#   - min_epochs: train for at least 3 epochs.
#   - precision=16: mixed-precision training using 16-bit (half-precision) floats.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    min_epochs=3,
    precision=16,
)

### Pass the data_module object, and the trainer automatically uses the right data split for each stage

In [None]:
# fit: trains the model on the training data and periodically evaluates it on the validation data.
# validate: evaluates the model on the validation data outside of the training loop.
# test: evaluates the model on the test data to estimate performance on unseen data.
#
# The data module is passed via `datamodule=`, so each call draws the
# dataloader(s) for its own stage from data_module instead of relying on
# separately created loaders.
# NOTE(review): `model` is not created in the visible cells; presumably it is
# an NN(...) instance built in another cell — confirm before running.
trainer.fit(model, datamodule=data_module)
trainer.validate(model, datamodule=data_module)
trainer.test(model, datamodule=data_module)