<a href="https://colab.research.google.com/github/bipinKrishnan/pytorch_lightning_examples/blob/main/my_lightning_module.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn, optim
from torchvision import models
import torch.nn.functional as F

from tqdm.notebook import tqdm

# Lightning Module

In [None]:
class MyLightningModule(nn.Module):
  """Minimal stand-in for a Lightning-style module.

  Subclasses must implement the two hooks used by ``Trainer.fit``:
  ``configure_optimizer`` and ``training_step``.
  """

  def __init__(self):
    super().__init__()

  def training_step(self, batch):
    """Compute the loss for ``batch`` and call ``backward()`` on it.

    The trainer itself only calls ``zero_grad()`` and ``step()`` on the
    optimizer, so gradients must be produced here. The return value is used
    for logging only.
    """
    # Fail loudly instead of returning None, which would only surface later
    # as a confusing error inside the training loop.
    raise NotImplementedError(
        f"{type(self).__name__} must implement training_step()")

  def configure_optimizer(self):
    """Return the optimizer used to update this module's parameters."""
    raise NotImplementedError(
        f"{type(self).__name__} must implement configure_optimizer()")

# Trainer class

In [None]:
class Trainer:
  """Minimal training loop for a ``MyLightningModule``-style model.

  The model must provide:

  * ``configure_optimizer()``, which returns an object with ``zero_grad()``
    and ``step()``.
  * ``training_step(batch)``, which is expected to compute the loss and call
    ``backward()`` itself. This loop never calls ``backward()``. The return
    value is only printed.
  """

  def __init__(self, model):
    self.model = model

  def fit(self, trainloader, num_epoch):
    """Train ``self.model`` for ``num_epoch`` passes over ``trainloader``.

    Args:
      trainloader: re-iterable source of batches (e.g. a ``DataLoader``); it
        is iterated once per epoch.
      num_epoch: number of passes over ``trainloader``.

    After each epoch, prints the value returned by that epoch's last
    ``training_step`` call. Nothing is printed for an epoch with no batches.
    """
    opt = self.model.configure_optimizer()
    # Sentinel that distinguishes "no batches this epoch" from a step that
    # legitimately returned None.
    no_output = object()

    for _ in tqdm(range(num_epoch), total=num_epoch, leave=False):
      y_ = no_output
      # Create a fresh batch bar every epoch: a tqdm bar closes itself after
      # one full pass, so a single shared bar would stop showing progress
      # after the first epoch. ``total`` is omitted so tqdm infers it via
      # len() when the loader supports it.
      for batch in tqdm(trainloader, leave=False):
        opt.zero_grad()
        y_ = self.model.training_step(batch)  # expected to run backward()
        opt.step()
      if y_ is not no_output:
        print(y_)

# Testing our library (Lightning module & Trainer)

### Loading the dataset and building the model component

In [None]:
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torchvision.transforms.transforms import ToTensor

In [None]:
# CIFAR-10 training split, converted to tensors; downloaded into ./data if needed.
ds = CIFAR10(root='./data', train=True, download=True, transform=ToTensor())
# Mini-batches of 4 samples, shuffled.
dl = DataLoader(dataset=ds, batch_size=4, shuffle=True)

In [None]:
class Model(MyLightningModule):
  """Transfer-learning classifier: frozen pretrained VGG-11 plus a new head.

  Every pretrained weight is frozen; only the replacement final classifier
  layer is trained.

  Args:
    device: device the network and each batch are moved to (e.g. ``'cuda'``
      or ``'cpu'``).
    num_classes: output size of the new final layer. Defaults to 10, matching
      the original hard-coded head.
  """

  def __init__(self, device, num_classes=10):
    super().__init__()
    self.device = device
    self.num_classes = num_classes
    self.clf = self.build_model()

  def build_model(self):
    """Return a frozen VGG-11 whose last classifier layer is freshly created.

    NOTE(review): the network is stored as ``self.vgg`` here and is also
    assigned to ``self.clf`` by ``__init__``. It is therefore registered under
    both names, and its weights appear under both ``vgg.*`` and ``clf.*`` keys
    in ``state_dict()``.
    """
    self.vgg = models.vgg11(pretrained=True)
    for params in self.vgg.parameters():
      params.requires_grad = False
    # Read the head's input width from the layer being replaced rather than
    # hard-coding 4096. The new Linear is trainable by default.
    in_features = self.vgg.classifier[-1].in_features
    self.vgg.classifier[-1] = nn.Linear(in_features, self.num_classes)

    return self.vgg.to(self.device)

  def training_step(self, batch):
    """Compute cross-entropy on ``batch`` and backpropagate it.

    Args:
      batch: ``(x, y)`` pair of inputs and targets (assumed to be class
        indices, as yielded by a CIFAR-10 ``DataLoader``).

    Returns:
      ``{"Loss": loss}`` holding the loss tensor (still attached to the graph).
    """
    x, y = batch
    out = self.clf(x.to(self.device))
    loss = F.cross_entropy(out, y.to(self.device))
    loss.backward()

    return {"Loss": loss}

  def configure_optimizer(self):
    """Return an Adam optimizer over the trainable (unfrozen) parameters only."""
    # Frozen parameters never receive gradients, so leave them out entirely.
    return optim.Adam(p for p in self.clf.parameters() if p.requires_grad)

### Training 

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

model = Model(device)
trainer = Trainer(model)

In [None]:
# Run three epochs over the CIFAR-10 training loader.
trainer.fit(trainloader=dl, num_epoch=3)

HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))

{'Loss': tensor(7.5864, device='cuda:0', grad_fn=<NllLossBackward>)}
{'Loss': tensor(5.3257, device='cuda:0', grad_fn=<NllLossBackward>)}
{'Loss': tensor(1.3992, device='cuda:0', grad_fn=<NllLossBackward>)}
