<a href="https://colab.research.google.com/github/bipinKrishnan/pytorch_lightning_examples/blob/main/my_lightning_module.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn, optim
from torchvision import models
import torch.nn.functional as F

from tqdm.notebook import tqdm

# Lightning Module

In [2]:
class MyLightningModule(nn.Module):
  """Minimal base class describing the hooks that ``Trainer.fit`` calls.

  Subclasses must implement ``training_step`` and ``configure_optimizers``.
  ``Trainer.fit`` does not call ``backward()`` itself, so ``training_step``
  is responsible for doing that on its loss. ``validation_step`` is optional
  and returns ``None`` unless overridden.
  """

  def __init__(self):
    super().__init__()

  def training_step(self, batch):
    """Run one optimisation step's forward/backward pass on ``batch``.

    Must be overridden; raising here surfaces a missing implementation
    immediately instead of silently training nothing.
    """
    raise NotImplementedError(
        f"{type(self).__name__} must implement training_step(batch)")

  def validation_step(self, batch):
    """Evaluate ``batch``; optional hook, returns ``None`` by default."""
    return None

  def configure_optimizers(self):
    """Return the optimizer ``Trainer.fit`` should step.

    Must be overridden; a ``None`` optimizer would otherwise fail later with
    an unhelpful ``AttributeError`` on ``zero_grad``.
    """
    raise NotImplementedError(
        f"{type(self).__name__} must implement configure_optimizers()")

# Trainer class

In [3]:
class Trainer:
  """Minimal training loop driving a ``MyLightningModule``-style model.

  The model supplies ``configure_optimizers``, ``training_step`` and
  ``validation_step``; the trainer owns the epoch and batch loops.
  """

  def __init__(self, max_epochs):
    """
    Args:
      max_epochs: number of full passes over ``trainloader`` made by ``fit``.
    """
    self.max_epochs = max_epochs

  def fit(self, model, trainloader, valloader=None):
    """Train ``model`` for ``max_epochs`` epochs, optionally validating each epoch.

    The trainer zeroes gradients before ``model.training_step`` and steps the
    optimizer after it; ``training_step`` itself must call ``backward()``.

    Args:
      model: object exposing ``configure_optimizers()``,
        ``training_step(batch)`` and ``validation_step(batch)``.
      trainloader: iterable of training batches; must support ``len()``
        (used as the progress-bar total).
      valloader: optional iterable of validation batches.

    After each epoch, prints the output of the last training step and of the
    last validation step (``None`` when no batch of that kind ran).
    """
    opt = model.configure_optimizers()
    is_module = isinstance(model, nn.Module)

    for _epoch in tqdm(range(self.max_epochs), total=self.max_epochs, leave=False):
      # Reset every epoch: never print a stale value from a previous epoch,
      # and never hit an unbound name when a loader is missing or empty.
      train_out, val_out = None, None

      # Re-enable Dropout/BatchNorm training behaviour (eval() below turns it off).
      if is_module:
        model.train()
      for train_batch in tqdm(trainloader, total=len(trainloader), leave=False):
        opt.zero_grad()
        train_out = model.training_step(train_batch)
        opt.step()

      # Explicit None check: a DataLoader's truthiness depends on its length.
      if valloader is not None:
        if is_module:
          model.eval()
        with torch.no_grad():
          for val_batch in valloader:
            val_out = model.validation_step(val_batch)

      print(train_out, val_out)

# Testing our library (Lightning module & Trainer)

### Loading the dataset and building the model component

In [4]:
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torchvision.transforms.transforms import ToTensor

In [5]:
# CIFAR10 train and test splits; download=True is a no-op when the files
# already exist under ./data.
to_tensor = ToTensor()
train_ds = CIFAR10('./data', train=True, transform=to_tensor, download=True)
val_ds = CIFAR10('./data', train=False, transform=to_tensor, download=True)

# Shuffle only the training batches; validation order stays fixed.
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
class Model(MyLightningModule):
  """VGG11 with frozen pretrained weights and a new trainable final layer.

  Every pretrained parameter has ``requires_grad=False``; only the replaced
  last classifier layer is optimised.
  """

  def __init__(self, device, num_classes=10):
    """
    Args:
      device: device the network and every incoming batch are moved to.
      num_classes: output width of the new final layer (10 for CIFAR10).
    """
    super().__init__()
    self.device = device
    self.num_classes = num_classes
    self.clf = self.build_model()

  def build_model(self):
    """Load VGG11, freeze it, and swap in a fresh final classifier layer.

    Returns the network moved to ``self.device``. The network is also kept as
    ``self.vgg`` (for backward compatibility), so it is registered under both
    ``vgg`` and ``clf``.
    """
    # NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in
    # favour of `weights=models.VGG11_Weights.IMAGENET1K_V1`; kept so the
    # notebook still runs on older torchvision releases.
    self.vgg = models.vgg11(pretrained=True)
    for param in self.vgg.parameters():
      param.requires_grad = False
    # Reuse the existing head's input width rather than hard-coding 4096.
    in_features = self.vgg.classifier[-1].in_features
    self.vgg.classifier[-1] = nn.Linear(in_features, self.num_classes)

    return self.vgg.to(self.device)

  def training_step(self, batch):
    """Forward + backward on one ``(inputs, targets)`` batch.

    Calls ``backward()`` itself because ``Trainer.fit`` only zeroes gradients
    and steps the optimizer. Returns ``{"train_loss": <detached loss>}``.
    """
    x, y = batch
    out = self.clf(x.to(self.device))
    loss = F.cross_entropy(out, y.to(self.device))
    loss.backward()

    return {"train_loss": loss.detach()}

  def validation_step(self, batch):
    """Forward pass on one ``(inputs, targets)`` batch.

    Returns ``{"val_loss": <detached loss>}``.
    """
    x, y = batch
    # Use the instance's device, not a notebook-global `device` variable.
    out = self.clf(x.to(self.device))
    loss = F.cross_entropy(out, y.to(self.device))

    return {"val_loss": loss.detach()}

  def configure_optimizers(self):
    """Adam over the parameters that can actually receive gradients."""
    trainable = [p for p in self.clf.parameters() if p.requires_grad]
    return optim.Adam(trainable)

### Training 

In [7]:
# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

trainer = Trainer(max_epochs=5)
model = Model(device)

In [8]:
trainer.fit(model, trainloader=train_dl, valloader=val_dl)

HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))

{'train_loss': tensor(0.9787, device='cuda:0')} {'val_loss': tensor(2.0336, device='cuda:0')}


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))

{'train_loss': tensor(2.1279, device='cuda:0')} {'val_loss': tensor(1.4343, device='cuda:0')}


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))

{'train_loss': tensor(1.3419, device='cuda:0')} {'val_loss': tensor(1.4505, device='cuda:0')}


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))

{'train_loss': tensor(1.1404, device='cuda:0')} {'val_loss': tensor(1.5227, device='cuda:0')}


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))

{'train_loss': tensor(1.5020, device='cuda:0')} {'val_loss': tensor(2.0477, device='cuda:0')}
