<a href="https://colab.research.google.com/github/bipinKrishnan/pytorch_lightning_examples/blob/main/my_pytorch_lightning/cifar10_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn, optim
from torchvision import models
import torch.nn.functional as F

from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torchvision.transforms.transforms import ToTensor

from my_pytorch_lightning import MyLightningModule, Trainer

In [None]:
# Build the CIFAR-10 training and test splits (downloaded into ./data when not
# already present); ToTensor converts every image to a tensor.
train_ds, val_ds = (
    CIFAR10('./data', train=is_train, transform=ToTensor(), download=True)
    for is_train in (True, False)
)

# Wrap both splits in loaders with batches of 64; only the training split is shuffled.
train_dl, val_dl = (
    DataLoader(ds, batch_size=64, shuffle=do_shuffle)
    for ds, do_shuffle in ((train_ds, True), (val_ds, False))
)

In [None]:
#building the model class by subclassing our lightning module
class Model(MyLightningModule):
  """CIFAR-10 classifier: a pretrained VGG11 whose final layer is replaced by a 10-way head.

  All pretrained VGG weights are frozen; only the replacement output layer
  is trained.

  Args:
    device: Device string (e.g. 'cuda' or 'cpu') that the network and every
      incoming batch are moved to.
  """

  def __init__(self, device):
    super().__init__()
    self.device = device
    self.clf = self.build_model()

  def build_model(self):
    """Create the frozen VGG11 backbone with a trainable 10-class output layer.

    The network is also kept as ``self.vgg`` and is returned already moved to
    ``self.device``.
    """
    # NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
    # of `weights=...`; kept as-is so older torchvision versions still work.
    self.vgg = models.vgg11(pretrained=True)
    for param in self.vgg.parameters():
      param.requires_grad = False
    # Swap in a new output layer sized from the old one's input width; a freshly
    # built layer requires grad by default, so it is the only part that learns.
    in_features = self.vgg.classifier[-1].in_features
    self.vgg.classifier[-1] = nn.Linear(in_features, 10)

    return self.vgg.to(self.device)

  def training_step(self, batch):
    """Forward one training batch and backpropagate its cross-entropy loss.

    Only ``loss.backward()`` is called here; the optimizer step/zeroing is
    presumably done by the Trainer — confirm.

    Args:
      batch: ``(inputs, targets)`` pair; both are moved to ``self.device``.

    Returns:
      ``{"train_loss": loss}`` with the loss detached from the graph.
    """
    x, y = batch
    # Re-enable dropout in case a preceding validation pass left eval mode on.
    self.clf.train()
    out = self.clf(x.to(self.device))
    loss = F.cross_entropy(out, y.to(self.device))
    loss.backward()

    return {"train_loss": loss.detach()}

  def validation_step(self, batch):
    """Compute the cross-entropy loss on one validation batch.

    Runs with dropout disabled and without recording an autograd graph.

    Args:
      batch: ``(inputs, targets)`` pair; both are moved to ``self.device``.

    Returns:
      ``{"val_loss": loss}`` with the loss detached.
    """
    x, y = batch
    self.clf.eval()
    # Use the instance's device, not a notebook-global `device` variable.
    with torch.no_grad():
      out = self.clf(x.to(self.device))
      loss = F.cross_entropy(out, y.to(self.device))

    return {"val_loss": loss.detach()}

  def configure_optimizers(self):
    """Return an Adam optimizer (default hyperparameters) over the trainable parameters."""
    # Frozen backbone weights never receive gradients, so leave them out.
    trainable = [p for p in self.clf.parameters() if p.requires_grad]
    return optim.Adam(trainable)

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

model = Model(device)
# Presumably the positional argument is the epoch count (the recorded output
# shows 3 epochs) — confirm against Trainer's definition.
trainer = Trainer(3)

In [None]:
# Run the training loop: fit `model` on `train_dl` and report validation loss on
# `val_dl`. The recorded output below prints one train/val loss pair per epoch.
trainer.fit(model, train_dl, val_dl)

HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))

{'train_loss': tensor(1.8374, device='cuda:0')} {'val_loss': tensor(1.3314, device='cuda:0')}


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))

{'train_loss': tensor(2.0644, device='cuda:0')} {'val_loss': tensor(1.4736, device='cuda:0')}


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))

{'train_loss': tensor(1.7570, device='cuda:0')} {'val_loss': tensor(1.3854, device='cuda:0')}
