In [None]:
import torch
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.data import random_split
import pytorch_lightning as pl
import torchmetrics
from torchmetrics import metric

In [None]:
class CustomMSE(torchmetrics.Metric):
    """Mean squared error accumulated across batches with torchmetrics state.

    States (registered via ``add_state`` with ``dist_reduce_fx="sum"`` so that
    distributed runs sum them across processes before ``compute``):
        sum_squared_error: running sum of squared differences.
        total: running count of compared elements.
    """

    def __init__(self):
        super().__init__()
        self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate the squared error of one batch.

        Args:
            preds: Predicted values.
            target: Ground-truth values; must have exactly the same shape as ``preds``.

        Raises:
            ValueError: If ``preds`` and ``target`` differ in shape. Without this
                check, shapes such as (N, C) vs (N,) crash inside the subtraction,
                and shapes such as (N, 1) vs (N,) silently broadcast to an (N, N)
                error tensor, giving a wrong result.
        """
        if preds.shape != target.shape:
            raise ValueError(
                "preds and target must have the same shape, got "
                f"{tuple(preds.shape)} and {tuple(target.shape)}"
            )
        # Compute in the state's dtype so that half-precision inputs (e.g. under
        # 16-bit mixed precision) are not squared and summed in float16.
        state_dtype = self.sum_squared_error.dtype
        diff = preds.to(state_dtype) - target.to(state_dtype)
        self.sum_squared_error += torch.sum(diff ** 2)
        self.total += target.numel()

    def compute(self):
        """Return the mean squared error over all elements seen since the last reset."""
        return self.sum_squared_error / self.total

In [None]:
class NN(pl.LightningModule):
    """Two-layer fully connected classifier (input -> 50 hidden units -> logits).

    Samples are flattened in ``_common_step`` / ``predict_step``, so
    ``input_size`` must equal the number of elements in one sample.

    Args:
        input_size: Number of features per flattened sample.
        num_classes: Number of target classes (width of the output logits).
        learning_rate: Adam learning rate used by ``configure_optimizers``.
    """

    def __init__(self, input_size, num_classes, learning_rate=0.001):
        super().__init__()
        self.learning_rate = learning_rate
        self.fc1 = nn.Linear(input_size, 50)
        self.fc2 = nn.Linear(50, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.f1_score = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)

        # Custom-metric demo: MSE between softmax probabilities and one-hot
        # labels (see training_step).
        self.custom_mse = CustomMSE()

    def forward(self, x):
        """Map flattened inputs of shape (batch, input_size) to raw class logits."""
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log epoch-level loss, accuracy, F1 and MSE."""
        loss, scores, y = self._common_step(batch, batch_idx)

        accuracy = self.accuracy(scores, y)
        f1_score = self.f1_score(scores, y)

        # CustomMSE needs preds and target of identical shape, so compare class
        # probabilities (N, C) with one-hot labels (N, C) rather than raw logits
        # (N, C) with integer labels (N,). It is logged under its own key so it
        # does not overwrite the real accuracy.
        probs = F.softmax(scores, dim=1)
        one_hot = F.one_hot(y, num_classes=scores.size(1)).to(probs.dtype)
        mse = self.custom_mse(probs, one_hot)

        # on_step=False: This means the metrics will not be logged at every training step.
        # on_epoch=True: This means the metrics will be logged at the end of each training epoch.

        # Epoch: A full pass through the entire dataset, encompassing all batches.
        # Step: A single update of model parameters after processing one batch of data.

        self.log_dict(
            {
                'train_loss': loss,
                'train_accuracy': accuracy,
                'train_f1_score': f1_score,
                'train_mse': mse,
            },
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and log it as 'val_loss'."""
        loss, scores, y = self._common_step(batch, batch_idx)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute the test loss and log it as 'test_loss'."""
        loss, scores, y = self._common_step(batch, batch_idx)
        self.log('test_loss', loss)
        return loss

    def _common_step(self, batch, batch_idx):
        """Shared flatten + forward + loss used by the train/val/test steps.

        Returns:
            (loss, scores, y): cross-entropy loss, raw logits, and the batch labels.
        """
        x, y = batch
        # Flatten every non-batch dimension into one feature vector per sample.
        x = x.reshape(x.size(0), -1)
        scores = self(x)
        loss = self.loss_fn(scores, y)
        return loss, scores, y

    def predict_step(self, batch, batch_idx):
        """Return the predicted class index (argmax over logits) for each sample."""
        x, _ = batch
        x = x.reshape(x.size(0), -1)
        scores = self(x)
        return torch.argmax(scores, dim=1)

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return optim.Adam(self.parameters(), lr=self.learning_rate)

In [None]:
# Model and data setup.
# NOTE(review): MyDataModule is not defined or imported anywhere in this
# notebook; it must come from another cell or module. TODO: import it explicitly.
# NOTE(review): INPUT_SIZE / NUM_CLASSES assume 28x28 single-channel images with
# 10 classes (e.g. MNIST). TODO: confirm against MyDataModule.
INPUT_SIZE = 28 * 28
NUM_CLASSES = 10
BATCH_SIZE = 64
DATA_DIR = './data'
NUM_WORKERS = 3

# `model` is consumed by trainer.fit/validate/test below and was previously never created.
model = NN(input_size=INPUT_SIZE, num_classes=NUM_CLASSES)
data_module = MyDataModule(batch_size=BATCH_SIZE, data_dir=DATA_DIR, num_workers=NUM_WORKERS)

In [None]:
# Train on 2 GPUs for at least MIN_EPOCHS and at most MAX_EPOCHS epochs.
# min_epochs alone does not bound the run: when neither max_epochs nor max_steps
# is set, Lightning defaults to max_epochs=1000. Raise MAX_EPOCHS to train longer.
# precision=16: enables mixed-precision training with 16-bit floating point
# (half precision).
MIN_EPOCHS = 3
MAX_EPOCHS = 3

trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    min_epochs=MIN_EPOCHS,
    max_epochs=MAX_EPOCHS,
    precision=16,
)

### Pass the `data_module` object to each Trainer call; Lightning automatically uses the matching split (train/val for `fit`, val for `validate`, test for `test`)

In [None]:
# The three Lightning entry points, all driven by the same data module:
# - fit: trains on the training split and periodically evaluates on the validation split.
# - validate: evaluates on the validation split, outside of (here: after) the training loop.
# - test: evaluates on the test split to estimate performance on new, unseen data.

trainer.fit(model=model, datamodule=data_module)
trainer.validate(model=model, datamodule=data_module)
trainer.test(model=model, datamodule=data_module)