<a href="https://colab.research.google.com/github/bipinKrishnan/pytorch_lightning_examples/blob/main/pytorch_lightning_cifar10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pytorch-lightning

In [38]:
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, models
from torchvision.transforms import transforms
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import TQDMProgressBar

In [39]:
def load_dataloader(name, transform, bs, root='./data'):
  """Build the train/validation DataLoaders for a torchvision dataset.

  Args:
    name: Dataset name. Only 'cifar10' is supported.
    transform: Transform applied to every image of both splits.
    bs: Batch size used by both loaders.
    root: Directory the dataset is downloaded to / read from.

  Returns:
    (train_dl, val_dl): the training loader shuffles, the validation loader does not.

  Raises:
    ValueError: if `name` is not a supported dataset (instead of silently
      returning None and failing later when the caller unpacks the result).
  """
  if name != 'cifar10':
    raise ValueError(f"Unsupported dataset {name!r}; expected 'cifar10'")

  train_ds = datasets.CIFAR10(root, train=True, download=True, transform=transform)
  # CIFAR-10's official test split is used as the validation set.
  val_ds = datasets.CIFAR10(root, train=False, download=True, transform=transform)

  train_dl = DataLoader(train_ds, bs, shuffle=True)
  val_dl = DataLoader(val_ds, bs, shuffle=False)
  return train_dl, val_dl

In [71]:
class Model(LightningModule):
  """LightningModule that fine-tunes a torchvision backbone as an image classifier.

  Args:
    model: Name of the backbone to build. Only 'resnet18' is supported.
    freeze: If True, all pretrained backbone parameters are frozen and only the
      newly created classification head is trained.
    num_classes: Output size of the new linear head (10 for CIFAR-10).
  """

  def __init__(self, model='resnet18', freeze=False, num_classes=10):
    super().__init__()
    self.model = self._build_model(model, freeze, num_classes)

  def _build_model(self, name, freeze, num_classes=10):
    """Build backbone `name` with a fresh `num_classes`-way linear head.

    Raises:
      ValueError: if `name` is not a supported backbone (previously the method
        fell through and returned None).
    """
    if name != 'resnet18':
      raise ValueError(f"Unsupported model {name!r}; expected 'resnet18'")

    # NOTE(review): `pretrained=True` is deprecated in newer torchvision in favour
    # of `weights=...`; kept here for compatibility with older releases.
    model = models.resnet18(pretrained=True)

    if freeze:
      for param in model.parameters():
        param.requires_grad = False

    # The head is replaced *after* freezing, so its new parameters stay trainable.
    # Its input width is read from the existing layer instead of hard-coding 512.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

  def forward(self, x):
    """Return the backbone's logits for the input batch `x`."""
    return self.model(x)

  def _shared_step(self, batch):
    """Cross-entropy loss for an (inputs, targets) batch; shared by train/val."""
    x, y = batch
    logits = self(x)
    return F.cross_entropy(logits, y)

  def training_step(self, batch, batch_idx):
    loss = self._shared_step(batch)
    self.log("train_loss", loss, prog_bar=True)
    return loss

  def validation_step(self, batch, batch_idx):
    loss = self._shared_step(batch)
    self.log("val_loss", loss, prog_bar=True)
    return loss

  def configure_optimizers(self):
    # Only trainable parameters are handed to Adam; frozen backbone weights
    # never receive gradients, so there is nothing for the optimizer to track.
    trainable = [p for p in self.model.parameters() if p.requires_grad]
    return optim.Adam(trainable)

In [80]:
SEED = 42
BATCH_SIZE = 64
# The pretrained ResNet weights expect inputs normalized with ImageNet channel
# statistics, so CIFAR images are normalized the same way after ToTensor.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Seed Lightning's global RNGs so data shuffling and head initialization are reproducible.
seed_everything(SEED)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
trainloader, valloader = load_dataloader('cifar10', transform, BATCH_SIZE)
model = Model('resnet18', freeze=True)

Files already downloaded and verified
Files already downloaded and verified


In [83]:
# `progress_bar_refresh_rate=` was removed in Lightning 1.7 and `gpus=` in 2.0, and an
# unpinned install pulls the latest release. Devices are therefore configured via
# accelerator/devices, and the refresh rate via the progress-bar callback.
# accelerator="auto" uses a GPU when one is available and otherwise falls back to CPU.
MAX_EPOCHS = 5
trainer = Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator="auto",
    devices=1,
    callbacks=[TQDMProgressBar(refresh_rate=60)],
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Running in fast_dev_run mode: will run a full train, val and test loop using a single batch


In [84]:
# Train on `trainloader`, evaluating on `valloader` for the logged val_loss.
trainer.fit(model, train_dataloaders=trainloader, val_dataloaders=valloader)


  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 11.2 M


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1