# CIFAR-10 Classifier: PyTorch Lightning Version

* 16-bit precision training
* Experiment tracking with Comet.ml

---

@date: 03-Sep-2020 | @author: katnoria

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets, models, utils
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

from pytorch_lightning.loggers import CometLogger

In [2]:
# Fix every RNG Lightning knows about (python, numpy, torch) for reproducible runs.
SEED = 42
pl.seed_everything(SEED)

42

In [3]:
def version_info(cls):
    """Print ``<name>: <version>`` using the object's ``__name__`` and ``__version__``."""
    name, version = cls.__name__, cls.__version__
    print("{}: {}".format(name, version))

In [4]:
# Record library versions alongside the results.
for lib in (torch, pl):
    version_info(lib)

torch: 1.6.0
pytorch_lightning: 0.9.0


Hyperparameters from an earlier Keras Tuner search on Colab (kept for reference only; they are not used by the code below):
```
{'conv_blocks': 5,
 'dropout': 0.0,
 'filters_0': 128,
 'filters_1': 160,
 'filters_2': 160,
 'filters_3': 192,
 'filters_4': 224,
 'hidden_size': 80,
 'learning_rate': 0.0025359172395390105,
 'pooling_0': 'max',
 'pooling_1': 'max',
 'pooling_2': 'max',
 'pooling_3': 'max',
 'pooling_4': 'max',
 'tuner/bracket': 2,
 'tuner/epochs': 30,
 'tuner/initial_epoch': 10,
 'tuner/round': 2,
 'tuner/trial_id': 'b780a5e9191d55c360ad5be6040decc4'}
```

In [60]:
# Comet.ml experiment tracking.
# The API key is read from the COMET_API_KEY environment variable rather than being
# hard-coded: a key written into a notebook is exposed to anyone who can read the file.
# NOTE(review): any key previously committed in this cell should be revoked and rotated.
# If the variable is unset, api_key is None and Lightning decides how to proceed
# (e.g. offline mode since save_dir is given) — confirm for the installed version.
comet_logger = CometLogger(
    api_key=os.environ.get('COMET_API_KEY'),
    workspace='katnoria',  # Optional
    save_dir='.',  # Optional
    project_name='cf10-pl',  # Optional
    experiment_name='pre-2'  # Optional
)

CometLogger will be initialized in online mode
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/katnoria/cf10-pl/3b135773fd4e41f6bc7c672abed2f02f



In [61]:
class CIFARTenLitModel(pl.LightningModule):
    """CIFAR-10 classifier: a pretrained backbone whose ``fc`` head is replaced.

    Args:
        backbone: model exposing an ``fc`` attribute (e.g. a torchvision ResNet);
            its ``fc`` is replaced in place by a fresh linear layer.
        learning_rate: learning rate passed to Adam.
        num_classes: number of output logits (10 for CIFAR-10).
    """

    def __init__(self, backbone, learning_rate, num_classes=10):
        super().__init__()
        self.learning_rate = learning_rate
        self.backbone = backbone
        # Size the new head from the layer it replaces instead of hard-coding 2048
        # (ResNet-50's width), so narrower backbones such as ResNet-18 also work.
        # Falls back to 2048 when the backbone has no ``fc`` with ``in_features``.
        in_features = getattr(getattr(backbone, "fc", None), "in_features", 2048)
        self.backbone.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return raw class logits for the input batch ``x``."""
        return self.backbone(x)

    def _loss(self, batch):
        """Cross-entropy loss of the model on one ``(inputs, targets)`` batch."""
        x, y = batch
        return F.cross_entropy(self(x), y)

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log it to the progress bar."""
        loss = self._loss(batch)
        result = pl.TrainResult(loss)
        result.log("train_loss", loss, prog_bar=True)
        return result

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss; it is also the checkpointing metric."""
        loss = self._loss(batch)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log("val_loss", loss)
        return result

    def configure_optimizers(self):
        """Adam over all parameters; frozen ones receive no gradient and stay fixed."""
        return optim.Adam(self.parameters(), lr=self.learning_rate)

In [62]:
class CIFARTenLitModelV2(pl.LightningModule):
    """CIFAR-10 classifier: pretrained backbone + ReLU hidden projection + linear head.

    The backbone's ``fc`` is replaced by a projection to ``hidden_size`` features,
    followed by ReLU and a final ``fc1`` layer producing class logits.

    Args:
        backbone: model exposing an ``fc`` attribute (e.g. a torchvision ResNet);
            its ``fc`` is replaced in place.
        learning_rate: learning rate passed to Adam.
        hidden_size: width of the hidden projection (default 256).
        num_classes: number of output logits (10 for CIFAR-10).
    """

    def __init__(self, backbone, learning_rate, hidden_size=256, num_classes=10):
        super().__init__()
        self.learning_rate = learning_rate
        self.backbone = backbone
        # Size the projection from the layer it replaces instead of hard-coding 2048
        # (ResNet-50's width); falls back to 2048 if no ``fc.in_features`` exists.
        in_features = getattr(getattr(backbone, "fc", None), "in_features", 2048)
        self.backbone.fc = nn.Linear(in_features, hidden_size)
        self.fc1 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return raw class logits for the input batch ``x``."""
        x = self.backbone(x)
        x = F.relu(x)
        return self.fc1(x)

    def _loss(self, batch):
        """Cross-entropy loss of the model on one ``(inputs, targets)`` batch."""
        x, y = batch
        return F.cross_entropy(self(x), y)

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log it to the progress bar."""
        loss = self._loss(batch)
        result = pl.TrainResult(loss)
        result.log("train_loss", loss, prog_bar=True)
        return result

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss; it is also the checkpointing metric."""
        loss = self._loss(batch)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log("val_loss", loss)
        return result

    def configure_optimizers(self):
        """Adam over all parameters; frozen ones receive no gradient and stay fixed."""
        return optim.Adam(self.parameters(), lr=self.learning_rate)

# Dataset

In [63]:
# Preprocessing for an ImageNet-pretrained backbone.
# Resize(224) upsamples the small CIFAR images to the resolution the backbone was
# trained at. Normalization uses the ImageNet mean/std that torchvision's
# pretrained models expect. The backbone's features were learned on inputs
# normalized this way, so CIFAR-10's own statistics would shift its input distribution.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
tfms = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

In [64]:
# CIFAR-10 training split (downloaded to ./data on first run), shuffled each epoch.
train_ds = datasets.CIFAR10("./data", train=True, transform=tfms, download=True)
train_loader = DataLoader(
    dataset=train_ds,
    batch_size=128,
    shuffle=True,
    num_workers=12,
)

Files already downloaded and verified


In [65]:
# CIFAR-10 test split, iterated in a fixed order.
# NOTE(review): this loader is later passed as the validation set, so early stopping
# and checkpoint selection are tuned on test data; consider a held-out split of train.
test_ds = datasets.CIFAR10("./data", train=False, transform=tfms, download=True)
test_loader = DataLoader(
    dataset=test_ds,
    batch_size=16,
    shuffle=False,
    num_workers=12,
)

Files already downloaded and verified


# Train

In [66]:
# ImageNet-pretrained ResNet-50 used as a frozen feature extractor: turn off gradients
# for all of its current parameters, so only layers attached afterwards are trained.
# NOTE(review): BatchNorm layers still update their running statistics whenever the
# model is in train mode, even with frozen weights — confirm this is intended.
backbone = models.resnet50(pretrained=True)
backbone.requires_grad_(False);

In [67]:
# Stop training once val_loss (lower is better) stops improving for 3 checks.
# strict=False: do not crash if the monitored metric is missing.
early_stop = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=3,
    verbose=False,
    strict=False,
)

In [68]:
# Single-GPU trainer: up to 100 epochs, early stopping on val_loss, Comet logging.
# precision=16 enables the 16-bit (mixed-precision) training this notebook advertises
# in its title; without it, training runs in full fp32.
trainer = pl.Trainer(
    fast_dev_run=False,
    gpus=1,
    precision=16,
    early_stop_callback=early_stop,
    max_epochs=100,
    logger=comet_logger,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


In [69]:
# Build the two-layer-head model on top of the frozen backbone.
LEARNING_RATE = 1e-3
model = CIFARTenLitModelV2(backbone=backbone, learning_rate=LEARNING_RATE)

In [None]:
# Train the model. test_loader (the CIFAR-10 test split) is passed as the validation
# set, so val_loss — which drives early stopping and checkpointing — is measured on test data.
trainer.fit(model, train_loader, val_dataloaders=test_loader)


  | Name     | Type   | Params
------------------------------------
0 | backbone | ResNet | 24 M  
1 | fc1      | Linear | 2 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

In [None]:
# trainer.test(test_dataloaders=test_loader, ckpt_path="lightning_logs/version_12/checkpoints")

In [33]:
# Display the training dataset summary (size, split, root, transform pipeline).
# NOTE(review): the saved output shows ToTensor + Normalize(0.5, ...) with no Resize,
# so it predates the current transform cell — re-run to refresh it.
train_ds

Dataset CIFAR10
    Number of datapoints: 50000
    Split: train
    Root Location: ./data
    Transforms (if any): Compose(
                             ToTensor()
                             Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                         )
    Target Transforms (if any): None

In [37]:
# Peek at only the first training batch and print the image tensor's shape
# (batch unpacks as (x, y), matching training_step).
# NOTE(review): the saved output (32x32) predates Resize(224) — re-run to refresh.
for batch in train_loader:
    image_shape = batch[0].size()
    print(image_shape)
    break

torch.Size([128, 3, 32, 32])
