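# Transfer Learning CNNs with PyTorch Lightning

This notebook fine-tunes an ImageNet-pretrained ResNet-18 on CIFAR-10 using Quickvision's `lit_cnn`. It then trains the same architecture without pretrained weights, and finally shows how to subclass `lit_cnn` to customise the training loop.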

## Install Quickvision

In [2]:
!pip install -q git+https://github.com/Quick-AI/quickvision.git

[K     |████████████████████████████████| 256kB 21.6MB/s 
[K     |████████████████████████████████| 563kB 52.7MB/s 
[K     |████████████████████████████████| 276kB 55.2MB/s 
[K     |████████████████████████████████| 829kB 49.5MB/s 
[K     |████████████████████████████████| 92kB 12.3MB/s 
[?25h  Building wheel for quickvision (setup.py) ... [?25l[?25hdone
  Building wheel for PyYAML (setup.py) ... [?25l[?25hdone
  Building wheel for future (setup.py) ... [?25l[?25hdone


In [19]:
import os

import pytorch_lightning as pl
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torch.nn import functional as F
from torch.utils.data import DataLoader
from quickvision.models.classification.cnn import lit_cnn

## Create Datasets and DataLoaders

In [4]:
# Per-channel (R, G, B) ImageNet mean/std. CIFAR-10 images are RGB, and the
# ResNet-18 below is loaded with pretrained="imagenet"; ImageNet-pretrained
# torchvision models expect inputs normalised with these statistics
# (presumably Quickvision's "imagenet" weights are the torchvision ones — confirm).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Kept as two separate pipelines so training-only augmentation can be added to
# train_transforms without affecting evaluation.
train_transforms = T.Compose([T.ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
valid_transforms = T.Compose([T.ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])

In [9]:
# Load the CIFAR-10 train and test splits from ./data, downloading them on first use.
DATA_ROOT = "./data"
train_dataset = torchvision.datasets.CIFAR10(
    DATA_ROOT, train=True, download=True, transform=train_transforms
)
valid_dataset = torchvision.datasets.CIFAR10(
    DATA_ROOT, train=False, download=True, transform=valid_transforms
)

Files already downloaded and verified
Files already downloaded and verified


In [10]:
SEED = 42  # Global seed so shuffling and weight initialisation repeat across re-runs
TRAIN_BATCH_SIZE = 512  # Training Batch Size
VALID_BATCH_SIZE = 512  # Validation Batch Size

# Seed the Python/NumPy/torch RNGs before any DataLoader or model is created below.
pl.seed_everything(SEED)

In [11]:
# With the default num_workers=0, every batch is loaded in the main process, which
# can starve the GPU. Use a few worker processes (modest cap, since each worker
# costs memory), and pin host memory only when batches will be copied to CUDA.
NUM_WORKERS = min(4, os.cpu_count() or 0)
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)
valid_loader = DataLoader(
    valid_dataset,
    VALID_BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

## Create a Model and Train

- Create a ResNet-18 model initialised with pretrained ImageNet weights.

In [None]:
# ResNet-18 initialised from ImageNet weights, configured for the 10 CIFAR-10 classes.
model_imagenet = lit_cnn(
    "resnet18",
    num_classes=10,
    pretrained="imagenet",
)

- You can pass any `pl.Trainer` argument.
- Quickvision does not override any of them.

In [None]:
# Train on one GPU when CUDA is available. Otherwise fall back to CPU instead of
# failing at Trainer construction on a machine without a GPU.
num_gpus = 1 if torch.cuda.is_available() else 0
trainer = pl.Trainer(max_epochs=2, gpus=num_gpus)
trainer.fit(model_imagenet, train_loader, valid_loader)

- Train the same architecture from random initialisation (no pretrained weights).

In [14]:
# Same ResNet-18 architecture with pretrained=None, i.e. no pretrained weights.
# Note: despite the "ssl" in the name, no self-supervised weights are loaded here.
model_ssl = lit_cnn(
    "resnet18",
    num_classes=10,
    pretrained=None,
)

In [None]:
# Train on one GPU when CUDA is available. Otherwise fall back to CPU instead of
# failing at Trainer construction on a machine without a GPU.
num_gpus = 1 if torch.cuda.is_available() else 0
trainer = pl.Trainer(max_epochs=2, gpus=num_gpus)
trainer.fit(model_ssl, train_loader, valid_loader)

## Custom Training with Lightning

- To use your own training logic, metrics, or logging, subclass `lit_cnn` and override the relevant Lightning hooks (`training_step`, `validation_step`, `configure_optimizers`, …).

In [16]:
class CustomTraining(lit_cnn):
    """Quickvision ``lit_cnn`` with user-defined training logic.

    Overrides the Lightning hooks to train with Adam and to log loss and top-1
    accuracy for both training and validation. No ``__init__`` is defined here,
    so the constructor (and ``self.learning_rate``) come from ``lit_cnn``.
    """

    def _loss_and_accuracy(self, batch):
        """Return (mean cross-entropy loss, top-1 accuracy) for one ``(images, targets)`` batch.

        Assumes the model output is ``(batch, num_classes)`` logits, as
        ``F.cross_entropy`` with class-index targets requires.
        """
        images, targets = batch
        # Calling the module (rather than .forward) also runs any registered nn.Module hooks.
        logits = self(images)
        # Mean reduction keeps the reported loss per-sample. With reduction='sum' it grew
        # with the batch size and dropped on the smaller final batch of each epoch.
        loss = F.cross_entropy(logits, targets)
        accuracy = (logits.argmax(dim=1) == targets).float().mean()
        return loss, accuracy

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/accuracy; the ``"loss"`` key is what gets optimised."""
        train_loss, train_acc = self._loss_and_accuracy(batch)
        self.log("train_loss", train_loss, prog_bar=True)
        self.log("train_acc", train_acc, prog_bar=True)
        return {"loss": train_loss}

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss/accuracy."""
        val_loss, val_acc = self._loss_and_accuracy(batch)
        # Log explicitly so the metrics reach the progress bar and loggers. A returned
        # dict on its own is not reported unless an epoch-end hook consumes it.
        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_acc", val_acc, prog_bar=True)
        # Same return shape as before, in case an inherited epoch-end hook reads it.
        return {"loss": val_loss}

    def configure_optimizers(self):
        """Use Adam with the learning rate stored on the instance by ``lit_cnn``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)


- Create the model from the custom `lit_cnn` subclass. It takes the same constructor arguments as `lit_cnn`.

In [17]:
# CustomTraining defines no __init__, so it is constructed exactly like lit_cnn.
# Note: this rebinds model_imagenet, replacing the model from the earlier cell.
model_imagenet = CustomTraining(
    "resnet18",
    num_classes=10,
    pretrained="imagenet",
)

- Train with the PyTorch Lightning `Trainer`.

In [None]:
# Train on one GPU when CUDA is available. Otherwise fall back to CPU instead of
# failing at Trainer construction on a machine without a GPU.
num_gpus = 1 if torch.cuda.is_available() else 0
trainer = pl.Trainer(max_epochs=2, gpus=num_gpus)
trainer.fit(model_imagenet, train_loader, valid_loader)